"""Target Transforms
"""
# Names exported by a wildcard import of this module. Only the names listed
# here are exported that way; the other classes defined in this file (e.g.
# PassThrough and its field-specific subclasses) must be imported by name.
# NOTE(review): TargetTransform is listed here but its definition is not part
# of the visible portion of this file -- confirm it is defined in this module.
__all__ = [
"TargetTransform",
"FamilyName",
"FamilyIndex",
"CustomLabel",
"YOLOLabel",
]
# TorchSig
from torchsig.transforms.base_transforms import Transform
from torchsig.signals.signal_lists import TorchSigSignalLists
from torchsig.utils.printing import generate_repr_str
# Built-In
from typing import List, Any, Optional, Dict
## Base/Helper Classes
[docs]
class CustomLabel(TargetTransform):
    """Packs several metadata fields into one tuple-valued label entry.

    The values of the fields named in `label_fields` are read from the metadata,
    gathered into a tuple (in the order given), and stored under `label_name`.

    Attributes:
        required_metadata (List[str]): The metadata fields read to build the label
            (the `label_fields` argument).
        targets_metadata (List[str]): One-element list holding the key the tuple is
            written to (the `label_name` argument).
    """

    def __init__(self, label_fields: List[str], label_name: str = 'label', **kwargs):
        """Configure which fields are collected and where the tuple is stored.

        Args:
            label_fields (List[str]): Metadata fields to extract, in output order.
            label_name (str, optional): Metadata key that receives the tuple.
                Defaults to 'label'.
            **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.required_metadata = label_fields
        self.targets_metadata = [label_name]

    def __apply__(self, metadata):
        """Write the label tuple into `metadata` and return the same mapping."""
        # Gather the values first, then resolve the destination key.
        collected = tuple(metadata[name] for name in self.required_metadata)
        destination = self.targets_metadata[0]
        metadata[destination] = collected
        return metadata
[docs]
class PassThrough(TargetTransform):
    """A helper class that does not alter the signal metadata but adds requested fields to the output.

    This class is often used in combination with other transforms.

    Attributes:
        required_metadata (List[str]): Metadata fields this transform requires.
        targets_metadata (List[str]): Metadata fields this transform reports as
            targets; holds the same names as `required_metadata`.
    """

    def __init__(self, field: Optional[List[str]] = None, **kwargs):
        """Record which metadata fields are passed through.

        Args:
            field (Optional[List[str]], optional): Names of the metadata fields to
                pass through. Defaults to None, which is treated as an empty list.
            **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        # A `[]` default would be one list object shared by every instance; use
        # None as the sentinel and build a fresh list per instance instead.
        names = [] if field is None else field
        # Store independent copies so that mutating one attribute (or the
        # caller's list) cannot silently change the other.
        self.required_metadata = list(names)
        self.targets_metadata = list(names)

    def __apply__(self, metadata: dict):
        """Return `metadata` unchanged; the requested fields are already present."""
        return metadata
### Built-In Target Transforms
# These target transforms already have labels within the Signal class,
# which is turned into a dictionary inside the DatasetDict class. Thus,
# they do not need any further processing beyond grabbing the label.
###
[docs]
class CenterFreq(PassThrough):
    """Adds `center_freq` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `center_freq` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['center_freq'], **kwargs)
[docs]
class Bandwidth(PassThrough):
    """Adds `bandwidth` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `bandwidth` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['bandwidth'], **kwargs)
[docs]
class StartInSamples(PassThrough):
    """Adds `start_in_samples` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `start_in_samples` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['start_in_samples'], **kwargs)
[docs]
class DurationInSamples(PassThrough):
    """Adds `duration_in_samples` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `duration_in_samples` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['duration_in_samples'], **kwargs)
[docs]
class SNR(PassThrough):
    """Adds `snr_db` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `snr_db` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['snr_db'], **kwargs)
[docs]
class ClassName(PassThrough):
    """Adds `class_name` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `class_name` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['class_name'], **kwargs)
[docs]
class ClassIndex(PassThrough):
    """Adds `class_index` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `class_index` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['class_index'], **kwargs)
[docs]
class SampleRate(PassThrough):
    """Adds `sample_rate` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `sample_rate` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['sample_rate'], **kwargs)
[docs]
class NumSamples(PassThrough):
    """Adds `num_samples` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `num_samples` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['num_samples'], **kwargs)
[docs]
class Start(PassThrough):
    """Adds `start` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `start` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['start'], **kwargs)
[docs]
class Stop(PassThrough):
    """Adds `stop` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `stop` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['stop'], **kwargs)
[docs]
class Duration(PassThrough):
    """Adds `duration` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `duration` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['duration'], **kwargs)
[docs]
class StopInSamples(PassThrough):
    """Adds `stop_in_samples` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `stop_in_samples` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['stop_in_samples'], **kwargs)
[docs]
class UpperFreq(PassThrough):
    """Adds `upper_freq` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `upper_freq` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['upper_freq'], **kwargs)
[docs]
class LowerFreq(PassThrough):
    """Adds `lower_freq` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `lower_freq` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['lower_freq'], **kwargs)
[docs]
class OversamplingRate(PassThrough):
    """Adds `oversampling_rate` from signal metadata.
    """

    def __init__(self, **kwargs):
        """Request the `oversampling_rate` field.

        Args:
            **kwargs: Forwarded to `PassThrough.__init__`.
        """
        # Forward kwargs so caller-supplied options reach the parent
        # constructor instead of being silently discarded.
        super().__init__(field=['oversampling_rate'], **kwargs)
# Special Target Transforms
# Target Transforms that require calculation to generate.
# They also need their metadata label field added to the metadata.
[docs]
class FamilyName(TargetTransform):
    """
    Adds a `family_name` to a signal's metadata based on its `class_name`.

    Attributes:
        class_family_dict (Dict[str, str]): Class name to family name dict
            (keys = class name, values = family name).
    """

    def __init__(self, class_family_dict: Optional[Dict[str, str]] = TorchSigSignalLists.family_dict, **kwargs):
        """Configure the class-name to family-name mapping.

        Args:
            class_family_dict (Optional[Dict[str, str]], optional): Class name to
                family name dict. Defaults to TorchSigSignalLists.family_dict;
                passing None selects that same default.
            **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.required_metadata = ["class_name"]
        self.targets_metadata = ["family_name"]
        # The parameter is annotated Optional, so honour an explicit None by
        # falling back to the default mapping rather than failing later on lookup.
        if class_family_dict is None:
            class_family_dict = TorchSigSignalLists.family_dict
        self.class_family_dict = class_family_dict

    def __apply__(self, metadata):
        """Set `metadata["family_name"]` from the mapping and return `metadata`.

        Raises:
            KeyError: If `class_name` is missing from `metadata` or its value has
                no entry in `class_family_dict`.
        """
        metadata["family_name"] = self.class_family_dict[metadata["class_name"]]
        return metadata
[docs]
class FamilyIndex(TargetTransform):
    """
    Adds a `family_id` to a signal's metadata based on its `class_name`.

    The id is the position of the signal's family name within `family_list`.

    Attributes:
        class_family_dict (Dict[str, str]): Class name to family name dict
            (keys = class name, values = family name).
        family_list (List[str]): Family list to index by.
    """

    def __init__(self, class_family_dict: Optional[Dict[str, str]] = TorchSigSignalLists.family_dict, family_list: Optional[List[str]] = None, **kwargs):
        """Configure the family mapping and the list used for indexing.

        Args:
            class_family_dict (Optional[Dict[str, str]], optional): Class name to
                family name dict. Defaults to TorchSigSignalLists.family_dict;
                passing None selects that same default.
            family_list (Optional[List[str]], optional): Family list to index by.
                Defaults to the sorted list of unique family names found in
                `class_family_dict`.
            **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        self.required_metadata = ["class_name"]
        self.targets_metadata = ["family_id"]
        # The parameter is annotated Optional, so honour an explicit None by
        # falling back to the default mapping rather than failing below.
        if class_family_dict is None:
            class_family_dict = TorchSigSignalLists.family_dict
        self.class_family_dict = class_family_dict
        if family_list is None:
            # De-duplicate the family names, then sort for a stable index order.
            family_list = sorted(set(self.class_family_dict.values()))
        self.family_list = family_list

    def __apply__(self, metadata):
        """Set `metadata["family_id"]` and return `metadata`.

        Raises:
            KeyError: If `class_name` is missing from `metadata` or its value has
                no entry in `class_family_dict`.
            ValueError: If the resolved family name is not in `family_list`.
        """
        fam_name = self.class_family_dict[metadata["class_name"]]
        metadata["family_id"] = self.family_list.index(fam_name)
        return metadata
[docs]
class YOLOLabel(TargetTransform):
    """
    Adds a `yolo_label` to a signal's metadata, in the form of a single tuple
    (class_index, x_center, y_center, width, height).

    Attributes:
        output_list (List[str]): Class-level list ["list", "dict"]. It is not
            referenced by the code in this class.
            NOTE(review): looks like a leftover from an `output` option -- confirm.
    """
    output_list = ["list", "dict"]

    def __init__(self, **kwargs):
        """Declare the metadata fields read and written by this transform.

        Args:
            **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)
        # `duration` is listed because __apply__ reads it for the box width.
        self.required_metadata = ["class_index", "start", "duration", "bandwidth", "center_freq", "sample_rate"]
        self.targets_metadata = ["yolo_label"]

    def __apply__(self, metadata):
        """Compute the label tuple, store it under `yolo_label`, return `metadata`.

        Raises:
            KeyError: If any of the fields read below is missing from `metadata`.
        """
        class_index = metadata["class_index"]

        # Box width is taken directly from `duration`.
        # NOTE(review): presumably `start` and `duration` are already normalized
        # to the width of the sample -- confirm against the metadata producer.
        width = metadata["duration"]
        # normalize bandwidth with sample rate
        height = metadata["bandwidth"] / metadata["sample_rate"]

        x_center = metadata["start"] + (width / 2.0)
        # normalize center frequency with sample rate
        # subtract from 1 since (0,0) for YOLO is upper left, but we define (0,0) lower left
        y_center = 1 - ((metadata["sample_rate"] / 2.0) + metadata["center_freq"]) / metadata["sample_rate"]

        yolo_label = (class_index, x_center, y_center, width, height)
        metadata["yolo_label"] = yolo_label
        return metadata