Source code for torchsig.datasets.dataset_utils
"""Dataset Utilities
"""
from torchsig.datasets.dataset_metadata import DatasetMetadata, NarrowbandMetadata, WidebandMetadata
from torchsig.signals.signal_types import Signal
from torchsig.utils.dsp import (
frequency_shift,
upconversion_anti_aliasing_filter,
)
import numpy as np
import yaml
# Name of the YAML file where dataset information will be written.
dataset_yaml_name = "create_dataset_info.yaml"
# Name of the YAML file where dataset writing information will be written.
writer_yaml_name = "writer_info.yaml"
[docs]
def dataset_full_path(dataset_type: str, impairment_level: int, train: bool = None) -> str:
"""Generates the full path for a dataset based on its type, impairment level, and whether it is for training.
Args:
dataset_type (str): Type of dataset (e.g., 'narrowband', 'wideband').
impairment_level (int): The impairment level for the dataset (0 = clean, 1 = level 1, 2 = impaired).
train (bool, optional): Whether the dataset is for training (True) or validation (False). Defaults to None.
Returns:
str: The full path to the dataset, e.g., 'torchsig_narrowband_clean/train'.
Example:
full_path = dataset_full_path('narrowband', 0, True)
print(full_path) # Output: 'torchsig_narrowband_clean/train'
"""
impaired_names = [
"clean",
"impaired_level_1",
"impaired"
]
impaired = impaired_names[impairment_level]
# e.g., torchsig_narrowband_clean
full_root = f"torchsig_{dataset_type}_{impaired}"
if train is not None:
# e.g., torchsig_narrowband_clean/train
subpath = "train" if train else "val"
full_root = f"{full_root}/{subpath}"
return full_root
[docs]
def collate_fn(batch):
    """Transposes a batch so that matching fields of every item are grouped together.

    Args:
        batch (tuple): A batch from the dataloader; an iterable of per-item sequences.

    Returns:
        tuple: One tuple per field position, each holding that field's value from every item in the batch.
    """
    # unpack the items as separate arguments so zip pairs them position by position
    grouped_fields = zip(*batch)
    return tuple(grouped_fields)
[docs]
def frequency_shift_signal(
    signal: Signal,
    center_freq_min: float,
    center_freq_max: float,
    sample_rate: float,
    frequency_max: float,
    frequency_min: float,
    random_generator: np.random.Generator | None = None,
) -> Signal:
    """Randomly shifts the frequency of a signal to a new center frequency and applies an anti-aliasing filter if necessary.

    The signal is modified in place (its `data` and `metadata` fields are
    overwritten) and the same object is returned.

    Args:
        signal (Signal): The signal object to be frequency shifted.
        center_freq_min (float): Minimum center frequency for the random shift.
        center_freq_max (float): Maximum center frequency for the random shift.
        sample_rate (float): The sample rate of the signal.
        frequency_max (float): Maximum frequency limit for aliasing.
        frequency_min (float): Minimum frequency limit for aliasing.
        random_generator (np.random.Generator, optional): Random number generator used to draw
            the new center frequency. Defaults to None, in which case a fresh
            `np.random.default_rng()` is created for this call.

    Returns:
        Signal: The frequency-shifted signal with updated metadata.
    """
    # Create the default generator per call. A generator built in the signature
    # is evaluated once at definition time, so every default call (and every
    # forked worker process) would share the same generator state.
    if random_generator is None:
        random_generator = np.random.default_rng()

    # randomize the center frequency
    center_freq = random_generator.uniform(low=center_freq_min, high=center_freq_max)

    # frequency shift to center_freq
    signal.data = frequency_shift(signal.data, center_freq, sample_rate)

    # update center_freq field in metadata
    signal.metadata.center_freq = center_freq

    # upper and lower frequency edges of the signal, read after the
    # center_freq update (presumably derived by the metadata object from
    # center_freq and bandwidth -- TODO confirm)
    upper_freq = signal.metadata.upper_freq
    lower_freq = signal.metadata.lower_freq

    # has aliasing occurred due to the upconversion of the signal?
    if upper_freq > frequency_max or lower_freq < frequency_min:
        # apply an anti-aliasing filter to the signal to attenuate energy that
        # wrapped around -fs/2 or fs/2. additionally, the filtering changes the
        # bandwidth, and therefore the center frequency, so update the two
        # metadata fields accordingly
        (
            signal.data,
            signal.metadata.center_freq,
            signal.metadata.bandwidth,
        ) = upconversion_anti_aliasing_filter(
            signal.data,
            signal.metadata.center_freq,
            signal.metadata.bandwidth,
            sample_rate,
            frequency_max,
            frequency_min,
        )

    return signal
[docs]
def save_type(transforms: list, target_transforms: list):
    """Reports whether the dataset will produce 'raw' IQ data, i.e. neither transforms nor target transforms are configured.

    Args:
        transforms (list): Transformations to be applied to the data.
        target_transforms (list): Transformations to be applied to the targets.

    Returns:
        bool: `True` when both lists are empty (raw IQ data); `False` as soon as either one has an entry.
    """
    # raw only if there is nothing to apply on either side
    no_transforms = len(transforms) == 0
    return no_transforms and len(target_transforms) == 0
[docs]
def to_dataset_metadata(dataset_metadata: DatasetMetadata | str | dict):
    """Converts the input dataset metadata to an appropriate DatasetMetadata object.

    A dictionary passed in by the caller is not modified; the function works
    on shallow copies of its 'required' and 'overrides' sections.

    Args:
        dataset_metadata (DatasetMetadata | str | dict): The dataset metadata, which can be:
            - A `DatasetMetadata` object (returned unchanged),
            - A string representing the path to a YAML file containing the metadata,
            - A dictionary representing the dataset metadata, with a mandatory
              'required' section and optional 'overrides' and 'write' sections.

    Returns:
        DatasetMetadata: The corresponding `NarrowbandMetadata` or `WidebandMetadata` object initialized with the provided parameters.

    Raises:
        ValueError: If the input `dataset_metadata` is not valid or if required fields are missing from the metadata.
    """
    if isinstance(dataset_metadata, DatasetMetadata):
        return dataset_metadata

    if isinstance(dataset_metadata, str):
        # NOTE(review): yaml.FullLoader can construct more than plain data
        # types. If these files can come from untrusted sources, consider
        # switching to yaml.safe_load.
        with open(dataset_metadata, 'r') as f:
            dataset_metadata = yaml.load(f, Loader=yaml.FullLoader)

    if not isinstance(dataset_metadata, dict):
        # The input is neither DatasetMetadata, str, nor dict
        raise ValueError("Invalid dataset_metadata.")

    # check that yaml file has minimum required params
    if "required" not in dataset_metadata:
        raise ValueError("Invalid dataset_metadata. Does not have required field.")

    # Copy the sections so the caller's dictionary is never mutated (the same
    # dict can then be converted more than once). An empty YAML section loads
    # as None, so treat that the same as an empty mapping.
    required = dict(dataset_metadata['required'] or {})
    overrides = dict(dataset_metadata.get('overrides') or {})
    write = dataset_metadata.get('write') or {}

    # validate dataset_type exists
    if "dataset_type" not in required:
        raise ValueError("Invalid dataset_metadata. Does not have dataset_type field under required.")

    # get dataset_type
    dataset_type = required['dataset_type'].lower()

    # check if accidentally set dataset_type wrong
    if "num_signals_max" in required and dataset_type == "narrowband":
        raise ValueError("num_signals_max defined in required params but dataset_type is narrowband. Should dataset_type be wideband?")

    # use appropriate dataset metadata type
    if dataset_type == "narrowband":
        metadata = NarrowbandMetadata
    elif dataset_type == "wideband":
        metadata = WidebandMetadata
    else:
        raise ValueError("Invalid dataset_type in dataset_metadata")

    # Validate minimum parameters given in yaml to instantiate
    for min_param in metadata.minimum_params:
        if min_param not in required:
            raise ValueError(f"Missing required parameter {min_param} in dataset_metadata.")

    # dataset_type only selects the metadata class; it is not a constructor argument
    del required['dataset_type']

    # Transforms and target transforms are not passed to the constructor
    overrides.pop('transforms', None)
    overrides.pop('target_transforms', None)

    # Placeholder strings written to the yaml stand for "use the default",
    # which the metadata classes receive as None
    placeholders = (
        ("class_distribution", "uniform"),
        ("class_list", "all"),
        ("num_signals_distribution", "uniform"),
    )
    for key, placeholder in placeholders:
        if key in overrides and overrides[key] == placeholder:
            overrides[key] = None

    # Flatten into one dictionary; later sections win on duplicate keys.
    # Any 'read_only' section is intentionally left out.
    init_params_dict = required | overrides | write

    # Unpack dataset metadata and return the appropriate metadata object
    return metadata(**init_params_dict)