from __future__ import annotations
from typing import TYPE_CHECKING, Any
from torchsig.utils.abstractions import HierarchicalMetadataObject
if TYPE_CHECKING:
from torchsig.signals import Signal
[docs]
class BaseSignalGenerator(HierarchicalMetadataObject):
    """Defines a callable object which takes no arguments and returns a Signal.

    Takes metadata (handled by the HierarchicalMetadataObject base) in init to
    specify values for things like min and max bandwidth. Subclasses implement
    ``generate()``; calling the instance runs ``generate()``, attaches this
    generator as a parent of the new signal, optionally stamps ``class_name``,
    and then applies every transform in order.

    Attributes:
        metadata: A metadata object to be used in signal generation.
        transforms: Transforms to be applied to generated signals before
            returning them in the __call__() method.
    """

    def __init__(self, transforms: list[Any] | None = None, **kwargs: Any) -> None:
        """Initializes the signal generator.

        Args:
            transforms: List of transforms to be applied to generated signals
                before returning them in the __call__() method. When omitted
                (or ``None``), each instance gets its own new empty list.
            **kwargs: Additional keyword arguments to pass to the parent class
                (HierarchicalMetadataObject).
        """
        # Fresh list per instance: a shared ``[]`` default would be the same
        # object for every generator built without explicit transforms.
        self.transforms = [] if transforms is None else transforms
        HierarchicalMetadataObject.__init__(self, **kwargs)

    def set_default_class_name(self, name: str) -> None:
        """Sets the class_name to name if there wasn't already a class name set.

        Args:
            name: The class name to set if no class name exists.
        """
        # ``class_name`` is written through item assignment but tested/read as
        # an attribute (see ``__call__``); presumably HierarchicalMetadataObject
        # exposes metadata entries as attributes — confirm in the base class.
        if not hasattr(self, "class_name"):
            self["class_name"] = name

    def copy(self) -> BaseSignalGenerator:
        """Creates a copy of the SignalGenerator with copied transforms.

        The generator itself is copied via ``HierarchicalMetadataObject.copy``;
        each transform is then copied individually so the new generator does
        not share transform objects with this one.

        Returns:
            A new instance of the SignalGenerator with copied metadata and transforms.
        """
        cpy = HierarchicalMetadataObject.copy(self)
        cpy.transforms = [transform.copy() for transform in self.transforms]
        return cpy

    def __call__(self) -> Signal:
        """Generates a new signal and applies all transforms.

        Returns:
            The generated signal after applying all transforms.
        """
        new_signal = self.generate()
        # Transient parent link: register=False is passed to add_parent.
        new_signal.add_parent(self, register=False)
        if hasattr(self, "class_name"):
            # A class_name on the generator overrides any class_name already
            # present in the signal's metadata.
            new_signal["class_name"] = self.class_name
        # Transforms are applied sequentially; each receives the previous output.
        for transform in self.transforms:
            new_signal = transform(new_signal)
        return new_signal

    def __repr__(self) -> str:
        """Returns a string representation of the SignalGenerator.

        Returns:
            A string of the form ``ClassName(metadata=..., transforms=...)``,
            omitting any field that is ``None``.
        """
        fields = []
        # ``_metadata`` is not set in this class; presumably it is maintained by
        # HierarchicalMetadataObject.
        if self._metadata is not None:
            fields.append("metadata=" + str(self._metadata))
        if self.transforms is not None:
            fields.append("transforms=" + str(self.transforms))
        return f"{self.__class__.__name__}({', '.join(fields)})"

    def generate(self) -> Signal:
        """Generates a new signal.

        This method must be implemented by subclasses; ``__call__()`` invokes
        it and then post-processes the result.

        Returns:
            A new Signal object.

        Raises:
            NotImplementedError: If the method is not implemented by a subclass.
        """
        raise NotImplementedError("Subclasses must implement 'generate'")
[docs]
class ConcatSignalGenerator(BaseSignalGenerator):
    """A Signal Generator that wraps other signal generators and returns one of their outputs at random when called.

    This generator randomly selects one of the provided signal generators and
    returns its output. Each wrapped signal generator is expected to be a
    BaseSignalGenerator (or at least provide ``add_parent()``, ``copy()`` and
    be callable); this is not type-checked.

    Attributes:
        signal_generators: List of BaseSignalGenerator instances to choose from.
        random_generator: Random number generator used to select a signal
            generator. It is not assigned in this class; presumably it is
            supplied through the HierarchicalMetadataObject hierarchy.
    """

    def __init__(
        self, signal_generators: list[BaseSignalGenerator], **kwargs: Any
    ) -> None:
        """Initializes the ConcatSignalGenerator.

        Each wrapped generator is given this object as a parent. If this
        object's ``validate_init`` setting is truthy, each wrapped generator's
        ``validate_metadata_fields()`` is also called when it has one.

        Args:
            signal_generators: List of BaseSignalGenerator instances to wrap.
            **kwargs: Additional keyword arguments to pass to the parent class.

        Raises:
            Any exception raised by a wrapped generator's ``add_parent()`` or
            ``validate_metadata_fields()`` is propagated.
        """
        BaseSignalGenerator.__init__(self, **kwargs)
        self.signal_generators = signal_generators
        for signal_generator in self.signal_generators:
            signal_generator.add_parent(self)
            # A missing ``validate_init`` setting means "do not validate".
            if not getattr(self, "validate_init", False):
                continue
            validator = getattr(signal_generator, "validate_metadata_fields", None)
            # Generators without a validate function are accepted as-is (at the
            # user's own risk). Errors raised *inside* a validator are no
            # longer silently swallowed.
            if validator is not None:
                validator()

    def copy(self) -> ConcatSignalGenerator:
        """Creates a copy of the ConcatSignalGenerator with copied signal generators.

        The base copy (metadata and transforms) is made by
        ``BaseSignalGenerator.copy``; each wrapped generator is then copied
        individually.

        Returns:
            A new instance of ConcatSignalGenerator with copied metadata and signal generators.
        """
        cpy = BaseSignalGenerator.copy(self)
        cpy.signal_generators = [
            signal_generator.copy() for signal_generator in self.signal_generators
        ]
        return cpy

    def __repr__(self) -> str:
        """Returns a string representation of the ConcatSignalGenerator.

        Returns:
            A string of the form ``ClassName(metadata=..., signal_generators=...)``,
            omitting any field that is ``None``.
        """
        fields = []
        if self._metadata is not None:
            fields.append("metadata=" + str(self._metadata))
        if self.signal_generators is not None:
            fields.append("signal_generators=" + str(self.signal_generators))
        return f"{self.__class__.__name__}({', '.join(fields)})"

    def generate(self) -> Signal:
        """Generates a signal by randomly selecting one of the wrapped generators.

        The selected generator is *called* (not just ``generate()``-ed), so its
        own transforms and class_name handling apply before this generator's.

        Returns:
            Signal: The output of a randomly selected signal generator.
        """
        # NOTE(review): behaviour with an empty ``signal_generators`` list
        # depends on ``random_generator.choice`` (likely raises) — confirm.
        return self.random_generator.choice(self.signal_generators)()