Source code for torchsig.transforms.metadata_transforms

"""Metadata Transforms"""

__all__ = [
    "MetadataTransform",
    "YOLOLabel"
]

from typing import Optional

from torchsig.signals.signal_types import Signal
from torchsig.transforms.base_transforms import Transform
from torchsig.utils.printing import generate_repr_str


## Base/Helper Classes
class MetadataTransform(Transform):
    """Base class for metadata transforms.

    This class defines the basic structure of a metadata transform:

    - ``__validate__(signal)``: checks that the input is a ``Signal`` and
      exposes every field named in ``required_metadata``.
    - ``__apply__(signal)``: per-signal hook; must be overridden by subclasses.
    - ``__call__(signal)``: runs ``__apply__`` on every entry of
      ``signal.component_signals`` and returns the input signal.
    - ``__repr__()``: string representation built by ``generate_repr_str``.

    Attributes:
        required_metadata: List of metadata fields required for applying the
            transform. It is forwarded to the parent class constructor and
            read back in ``__validate__`` (presumably stored by the parent
            class -- confirm against ``Transform``).
    """

    def __init__(
        self, required_metadata: Optional[list[str]] = None, **kwargs
    ) -> None:
        """Initialize the MetadataTransform.

        Args:
            required_metadata: List of metadata fields required for applying
                the transform. ``None`` (the default) means no required
                fields and is replaced by a fresh empty list.
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        # A ``None`` sentinel replaces the former ``[]`` default so that
        # instances created without the argument never share one list object.
        if required_metadata is None:
            required_metadata = []
        super().__init__(required_metadata=required_metadata, **kwargs)

    def __validate__(self, signal):
        """Validate a signal before applying the transform.

        Makes sure the signal has every attribute named in
        ``self.required_metadata``.

        Args:
            signal: The signal to validate.

        Returns:
            The original signal, unchanged, if it is valid.

        Raises:
            TypeError: If ``signal`` is not a ``Signal`` object.
            ValueError: If ``signal`` is missing a required metadata field.
        """
        if not isinstance(signal, Signal):
            raise TypeError(f"input ({type(signal)}) is not a Signal object.")

        for required_metadatum in self.required_metadata:
            if not hasattr(signal, required_metadatum):
                raise ValueError(
                    f"key: {required_metadatum} is missing from signal metadata, but is required by {self.__class__.__name__}"
                )

        return signal

    def __call__(self, signal: Signal) -> Signal:
        """Apply the transform to every component signal of ``signal``.

        Args:
            signal: The signal whose ``component_signals`` are transformed.

        Returns:
            The same ``signal`` object that was passed in.
        """
        # NOTE(review): __validate__ is not invoked here; callers that need
        # validation must call it themselves -- confirm this is intended.
        for component_signal in signal.component_signals:
            self.__apply__(component_signal)
        return signal

    def __apply__(self, signal):
        """Apply the transform to a single signal.

        Args:
            signal: The signal to transform.

        Raises:
            NotImplementedError: Always; subclasses must implement this method.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a string representation of the transform object.

        Returns:
            The string produced by ``generate_repr_str`` with the
            ``required_metadata`` parameter excluded.
        """
        return generate_repr_str(self, exclude_params=["required_metadata"])
class YOLOLabel(MetadataTransform):
    """Attach a YOLO-style label to a signal.

    The label is one tuple ``(class_index, x_center, y_center, width, height)``
    written to the signal under the ``"yolo_label"`` key.

    Attributes:
        required_metadata: Metadata fields declared as required by this
            transform (forwarded to the parent constructor).
        targets_metadata: Metadata fields added by the transform,
            i.e. ``["yolo_label"]``.
    """

    def __init__(self, **kwargs):
        """Initialize the YOLOLabel transform.

        Args:
            **kwargs: Additional keyword arguments passed to the parent class.
        """
        # NOTE(review): __apply__ also reads ``duration`` and ``sample_rate``,
        # which are not listed here, and does not read ``dataset_metadata`` --
        # confirm whether this list should be updated.
        needed_fields = [
            "class_index",
            "start",
            "bandwidth",
            "center_freq",
            "dataset_metadata",
        ]
        super().__init__(required_metadata=needed_fields, **kwargs)
        self.targets_metadata = ["yolo_label"]

    def __apply__(self, signal: Signal) -> Signal:
        """Compute the YOLO label for one signal and store it on that signal.

        Args:
            signal: The signal to label. Its ``class_index``, ``duration``,
                ``bandwidth``, ``sample_rate``, ``start`` and ``center_freq``
                attributes are read.

        Returns:
            The same signal, with ``signal["yolo_label"]`` set.
        """
        label_class = signal.class_index

        # Box width is the duration as-is (presumably already normalized to
        # the sample length -- confirm against the Signal definition).
        box_width = signal.duration

        # Box height is the bandwidth expressed as a fraction of sample rate.
        occupied_bandwidth = signal.bandwidth
        box_height = occupied_bandwidth / signal.sample_rate

        # Horizontal center sits half a box width after the start.
        box_center_x = signal.start + (box_width / 2.0)

        # Shift the center frequency by half the sample rate and divide by
        # the sample rate to get its position measured from the bottom edge.
        height_from_bottom = (
            (signal.sample_rate / 2.0) + signal.center_freq
        ) / signal.sample_rate

        # YOLO places (0, 0) at the upper left while this code measures from
        # the lower left, so flip the vertical axis.
        box_center_y = 1 - height_from_bottom

        signal["yolo_label"] = (
            label_class,
            box_center_x,
            box_center_y,
            box_width,
            box_height,
        )
        return signal