"""Dataset Base Classes for creation and static loading.
"""
from __future__ import annotations
# TorchSig
from torchsig.datasets.dataset_metadata import DatasetMetadata
from torchsig.signals.signal_types import DatasetSignal, DatasetDict
from torchsig.signals.builder import SignalBuilder
import torchsig.signals.builders as signal_builders
from torchsig.utils.random import Seedable
from torchsig.utils.dsp import compute_spectrogram
from torchsig.datasets.dataset_utils import (
to_dataset_metadata,
frequency_shift_signal,
dataset_full_path
)
from torchsig.utils.printing import generate_repr_str
from torchsig.utils.verify import verify_transforms, verify_target_transforms
from torchsig.utils.file_handlers.zarr import ZarrFileHandler
from torchsig.datasets.dataset_utils import dataset_yaml_name, writer_yaml_name
# Third Party
from torch.utils.data import Dataset, IterableDataset
import numpy as np
# Built-In
from typing import Tuple, Dict, TYPE_CHECKING
from pathlib import Path
import yaml
import warnings
if TYPE_CHECKING:
from torchsig.utils.file_handlers.zarr import TorchSigFileHandler
[docs]
class TorchsigIterableDataset(IterableDataset, Seedable):
    """Iterable TorchSig dataset that synthesizes a fresh sample on every ``next()`` call.

    Each sample is a complex noise floor with a random number of generated
    signals added onto it. The sample is then passed through the dataset-level
    impairments, the user transforms and finally the target transforms that are
    configured in the dataset metadata.

    NOTE(review): ``__next__`` never raises ``StopIteration``; nothing in this
    class bounds iteration by ``dataset_metadata.num_samples``, so callers that
    need a finite number of samples must stop iterating themselves.
    """

    def __init__(
        self,
        dataset_metadata: DatasetMetadata | str | dict,
        **kwargs
    ):
        """Initializes the dataset and creates one signal builder per supported class.

        Args:
            dataset_metadata (DatasetMetadata | str | dict): The dataset metadata,
                or anything `to_dataset_metadata` can convert into one.
            **kwargs: Additional keyword arguments forwarded to `Seedable.__init__`.
        """
        Seedable.__init__(self, **kwargs)
        self._dataset_metadata: DatasetMetadata = to_dataset_metadata(dataset_metadata)
        self._dataset_metadata.add_parent(self)
        # running count of samples returned by __next__
        self.num_samples_generated = 0
        self.builders: Dict[str, SignalBuilder] = self._initialize_builders()

    def __iter__(self):
        """Returns the dataset itself; it is its own iterator."""
        return self

    def __next__(self):
        """Generates the next dataset sample and its corresponding targets.

        Returns:
            Tuple[np.ndarray, Tuple]: The sample data and the target values. When no
            target transforms are configured, the targets are the sample metadata.
        """
        # generate new sample
        sample = self.__generate_new_signal__()
        # apply dataset transforms
        sample = self.dataset_metadata.impairments.dataset_transforms(sample)
        # apply user transforms
        for transform in self.dataset_metadata.transforms:
            sample = transform(sample)
        # convert to DatasetDict
        sample = DatasetDict(signal=sample)
        targets = self._collect_targets(sample)
        self.num_samples_generated += 1
        return sample.data, targets

    def _collect_targets(self, sample: DatasetDict):
        """Applies the target transforms to `sample` and shapes their outputs.

        Args:
            sample (DatasetDict): The sample whose metadata is transformed in place.

        Returns:
            The sample metadata when no target transforms are configured; otherwise
            the requested target fields, grouped per signal and unwrapped whenever a
            level contains exactly one element.
        """
        target_transforms = self.dataset_metadata.target_transforms
        outputs_by_transform = []
        for target_transform in target_transforms:
            # apply transform to all metadatas
            sample.metadata = target_transform(sample.metadata)
            # for every signal, pull out the fields this transform declares as its targets
            outputs_by_transform.append([
                tuple(signal_metadata[field] for field in target_transform.targets_metadata)
                for signal_metadata in sample.metadata
            ])

        if len(target_transforms) == 0:
            # no target transform applied
            return sample.metadata

        # reorder from [(transform 1 output for all signals), (transform 2 ...), ...]
        # to [signal 1 outputs, signal 2 outputs, ...]
        outputs_by_signal = list(zip(*outputs_by_transform))

        if self.dataset_metadata.dataset_type == 'narrowband':
            # only one signal in list for narrowband: flatten and unwrap single-field outputs
            flat = [item[0] if len(item) == 1 else item for row in outputs_by_signal for item in row]
            # unwrap when only a single target value remains
            return flat[0] if len(flat) == 1 else tuple(flat)

        # wideband: keep one entry per signal, unwrapping single-field outputs
        rows = [tuple([item[0] if len(item) == 1 else item for item in row]) for row in outputs_by_signal]
        # unwrap any per-signal tuple that holds a single value
        return [row[0] if len(row) == 1 else row for row in rows]

    def reset(self):
        """Resets the dataset to its initial state.

        Clears the generated-sample counter on the metadata object (kept for
        backward compatibility) and the counter maintained by this dataset.
        """
        self._dataset_metadata.num_samples_generated = 0
        # this is the counter __next__ increments; it must be cleared as well
        self.num_samples_generated = 0

    def _initialize_builders(self) -> Dict[str, SignalBuilder]:
        """Initializes one signal builder per class in the metadata class list.

        Returns:
            Dict[str, SignalBuilder]: A dictionary where the key is the signal class name,
            and the value is the corresponding SignalBuilder object.
        """
        builders = {}
        for builder_name in signal_builders.__all__:
            builder = getattr(signal_builders, builder_name)  # get builder class
            # classes requested by the dataset that this builder can produce
            matching_classes = set(self._dataset_metadata.class_list) & set(builder.supported_classes)
            for class_name in matching_classes:
                builders[class_name] = builder(self._dataset_metadata, class_name)
                builders[class_name].add_parent(self)
        return builders

    def __str__(self) -> str:
        """Returns a string representation of the dataset, including its metadata and the signal builders.

        Returns:
            str: String representation of the dataset.
        """
        max_width = 100
        builders_str = "\n".join([f"{key:<15}: {value}" for key, value in self.builders.items()])
        class_str = f"{self.__class__.__name__}"
        center_width = (max_width - len(class_str)) // 2
        return (
            f"\n{'-' * center_width} {self.__class__.__name__} {'-' * center_width}\n"
            f"{self.dataset_metadata}\n"
            f"\nBuilders"
            f"{'-' * max_width}\n"
            f"{builders_str}\n"
        )

    def __repr__(self):
        """Returns the string produced by `generate_repr_str` for this object.

        Returns:
            str: String representation of the object.
        """
        return generate_repr_str(self)

    def _build_noise_floor(self) -> np.ndarray:
        """Generates complex Gaussian noise scaled to the configured noise power.

        The noise level is estimated from the mean of a spectrogram of the noise and
        then corrected so that the estimate matches `dataset_metadata.noise_power_db`.

        Returns:
            np.ndarray: The generated IQ samples representing the noise floor (complex64).
        """
        num_iq_samples = self.dataset_metadata.num_iq_samples_dataset
        # draw real part first, then imaginary part (order matters for reproducibility)
        real_samples = self.random_generator.normal(0, 1, num_iq_samples)
        imag_samples = self.random_generator.normal(0, 1, num_iq_samples)
        # combine real and imaginary portions of noise
        iq_samples = real_samples + 1j * imag_samples

        # compute an estimate of the noise floor in the frequency domain. use a large stride to process a subset
        # of the data since not many FFTs are needed to be averaged for the noise
        noise_spectrogram_db = compute_spectrogram(
            iq_samples,
            self.dataset_metadata.fft_size,
            self.dataset_metadata.fft_stride * 16,
        )
        # average over axis 1, then over the remaining values, for a single dB estimate
        noise_fft_db = np.mean(noise_spectrogram_db, axis=1)
        noise_avg_db = np.mean(noise_fft_db)

        # compute the correction factor as the distance from the desired level
        correction_db = self.dataset_metadata.noise_power_db - noise_avg_db
        # dB -> linear power ratio; the square root turns it into an amplitude scale
        correction = 10 ** (correction_db / 10)
        iq_samples = np.sqrt(correction) * iq_samples
        return iq_samples.astype(np.complex64)

    def __generate_new_signal__(self) -> DatasetSignal:
        """Generates a new dataset signal/sample.

        Returns:
            DatasetSignal: A new generated dataset signal containing the data and metadata.
        """
        # build noise floor
        iq_samples = self._build_noise_floor()
        signals = []
        # determine number of signals in sample (upper bound is made inclusive with +1)
        num_signals_to_generate = self.random_generator.integers(
            low=self.dataset_metadata.num_signals_min,
            high=self.dataset_metadata.num_signals_max + 1,
        )
        # generate individual bursts
        for _ in range(num_signals_to_generate):
            # choose random signal class and fetch its builder
            class_name = self._random_signal_class()
            builder = self.builders[class_name]
            # generate signal at complex baseband
            new_signal = builder.build()
            # apply signal transforms
            new_signal = self.dataset_metadata.impairments.signal_transforms(new_signal)
            # frequency shift signal, after signal transforms applied at complex baseband
            new_signal = frequency_shift_signal(
                new_signal,
                center_freq_min=self.dataset_metadata.signal_center_freq_min,
                center_freq_max=self.dataset_metadata.signal_center_freq_max,
                sample_rate=self.dataset_metadata.sample_rate,
                frequency_max=self.dataset_metadata.frequency_max,
                frequency_min=self.dataset_metadata.frequency_min,
                random_generator=self.random_generator,
            )
            # place signal on iq sample cut
            start = new_signal.metadata.start_in_samples
            stop = new_signal.metadata.stop_in_samples
            iq_samples[start:stop] += new_signal.data
            signals.append(new_signal)
        # form the sample (dataset object)
        return DatasetSignal(data=iq_samples, signals=signals)

    # Read-Only properties
    @property
    def dataset_metadata(self):
        """Returns the dataset metadata.

        Returns:
            DatasetMetadata: The dataset metadata.
        """
        return self._dataset_metadata

    # Functions
    def _random_signal_class(self):
        """Randomly selects which signal to create next.

        Returns:
            str: A signal class name drawn from the class list, weighted by the class distribution.
        """
        return self.random_generator.choice(
            self.dataset_metadata.class_list,
            p=self.dataset_metadata.class_distribution,
        )
[docs]
class NewTorchSigDataset(Dataset, Seedable):
    """Map-style TorchSig dataset that synthesizes a new sample on every access.

    Samples are generated on the fly and are not stored: an index can be requested
    only once, and an index lower than the number of samples already generated is
    rejected. If `num_samples` in the dataset metadata is None the dataset is
    infinite; otherwise valid indices are ``0 <= idx < num_samples``.
    """

    def __init__(
        self,
        dataset_metadata: DatasetMetadata | str | dict,
        **kwargs
    ):
        """Initializes the dataset and creates one signal builder per supported class.

        Emits a FutureWarning announcing the planned move to an iterable dataset.

        Args:
            dataset_metadata (DatasetMetadata | str | dict): The dataset metadata,
                or anything `to_dataset_metadata` can convert into one.
            **kwargs: Additional keyword arguments forwarded to `Seedable.__init__`.
        """
        Seedable.__init__(self, **kwargs)
        self._dataset_metadata: DatasetMetadata = to_dataset_metadata(dataset_metadata)
        self._dataset_metadata.add_parent(self)
        # running count of samples returned by __getitem__
        self.num_samples_generated = 0
        self.builders: Dict[str, SignalBuilder] = self._initialize_builders()
        self._current_idx: int = 0  # Internal counter for iterator usage
        warnings.warn(
            "NewTorchSigDataset will become a torch.IterableDataset in the future.",
            FutureWarning,
            stacklevel=2,  # attribute the warning to the code constructing the dataset
        )

    def __iter__(self):
        """Returns the dataset itself; it is its own iterator."""
        return self

    def __next__(self):
        """Returns the sample at the internal iterator index and advances the index.

        Returns:
            Tuple[np.ndarray, Tuple]: The sample data and the target values.

        Raises:
            StopIteration: If the dataset is finite and every index has been consumed.
        """
        num_samples = self.dataset_metadata.num_samples
        # a finite dataset must end iteration cleanly instead of leaking the
        # IndexError that __getitem__ raises for an out-of-bounds index
        if num_samples is not None and self._current_idx >= num_samples:
            raise StopIteration
        result = self[self._current_idx]
        self._current_idx += 1
        return result

    def reset(self):
        """Resets the dataset to its initial state.

        Clears the generated-sample counter on the metadata object (kept for
        backward compatibility), the counter maintained by this dataset and the
        internal iterator index, so that index 0 can be requested again.
        """
        self._dataset_metadata.num_samples_generated = 0
        # _verify_idx compares against this counter; without clearing it every
        # index below the old count would still be rejected after a reset
        self.num_samples_generated = 0
        self._current_idx = 0

    def _initialize_builders(self) -> Dict[str, SignalBuilder]:
        """Initializes one signal builder per class in the metadata class list.

        Returns:
            Dict[str, SignalBuilder]: A dictionary where the key is the signal class name,
            and the value is the corresponding SignalBuilder object.
        """
        builders = {}
        for builder_name in signal_builders.__all__:
            builder = getattr(signal_builders, builder_name)  # get builder class
            # classes requested by the dataset that this builder can produce
            matching_classes = set(self._dataset_metadata.class_list) & set(builder.supported_classes)
            for class_name in matching_classes:
                builders[class_name] = builder(self._dataset_metadata, class_name)
                builders[class_name].add_parent(self)
        return builders

    def __len__(self) -> int:
        """Returns the size of the dataset.

        Returns:
            int: `num_samples` from the metadata for a finite dataset, otherwise the
            number of samples generated so far.
        """
        # If infinite dataset, return how many samples have been generated
        if self.dataset_metadata.num_samples is None:
            return self.num_samples_generated
        return self.dataset_metadata.num_samples

    def __str__(self) -> str:
        """Returns a string representation of the dataset, including its metadata and the signal builders.

        Returns:
            str: String representation of the dataset.
        """
        max_width = 100
        builders_str = "\n".join([f"{key:<15}: {value}" for key, value in self.builders.items()])
        class_str = f"{self.__class__.__name__}"
        center_width = (max_width - len(class_str)) // 2
        return (
            f"\n{'-' * center_width} {self.__class__.__name__} {'-' * center_width}\n"
            f"{self.dataset_metadata}\n"
            f"\nBuilders"
            f"{'-' * max_width}\n"
            f"{builders_str}\n"
        )

    def __repr__(self):
        """Returns the string produced by `generate_repr_str` for this object.

        Returns:
            str: String representation of the object.
        """
        return generate_repr_str(self)

    def _build_noise_floor(self) -> np.ndarray:
        """Generates complex Gaussian noise scaled to the configured noise power.

        The noise level is estimated from the mean of a spectrogram of the noise and
        then corrected so that the estimate matches `dataset_metadata.noise_power_db`.

        Returns:
            np.ndarray: The generated IQ samples representing the noise floor (complex64).
        """
        num_iq_samples = self.dataset_metadata.num_iq_samples_dataset
        # draw real part first, then imaginary part (order matters for reproducibility)
        real_samples = self.random_generator.normal(0, 1, num_iq_samples)
        imag_samples = self.random_generator.normal(0, 1, num_iq_samples)
        # combine real and imaginary portions of noise
        iq_samples = real_samples + 1j * imag_samples

        # compute an estimate of the noise floor in the frequency domain. use a large stride to process a subset
        # of the data since not many FFTs are needed to be averaged for the noise
        noise_spectrogram_db = compute_spectrogram(
            iq_samples,
            self.dataset_metadata.fft_size,
            self.dataset_metadata.fft_stride * 16,
        )
        # average over axis 1, then over the remaining values, for a single dB estimate
        noise_fft_db = np.mean(noise_spectrogram_db, axis=1)
        noise_avg_db = np.mean(noise_fft_db)

        # compute the correction factor as the distance from the desired level
        correction_db = self.dataset_metadata.noise_power_db - noise_avg_db
        # dB -> linear power ratio; the square root turns it into an amplitude scale
        correction = 10 ** (correction_db / 10)
        iq_samples = np.sqrt(correction) * iq_samples
        return iq_samples.astype(np.complex64)

    def __generate_new_signal__(self) -> DatasetSignal:
        """Generates a new dataset signal/sample.

        Returns:
            DatasetSignal: A new generated dataset signal containing the data and metadata.
        """
        # build noise floor
        iq_samples = self._build_noise_floor()
        signals = []
        # determine number of signals in sample (upper bound is made inclusive with +1)
        num_signals_to_generate = self.random_generator.integers(
            low=self.dataset_metadata.num_signals_min,
            high=self.dataset_metadata.num_signals_max + 1,
        )
        # generate individual bursts
        for _ in range(num_signals_to_generate):
            # choose random signal class and fetch its builder
            class_name = self._random_signal_class()
            builder = self.builders[class_name]
            # generate signal at complex baseband
            new_signal = builder.build()
            # apply signal transforms
            new_signal = self.dataset_metadata.impairments.signal_transforms(new_signal)
            # frequency shift signal, after signal transforms applied at complex baseband
            new_signal = frequency_shift_signal(
                new_signal,
                center_freq_min=self.dataset_metadata.signal_center_freq_min,
                center_freq_max=self.dataset_metadata.signal_center_freq_max,
                sample_rate=self.dataset_metadata.sample_rate,
                frequency_max=self.dataset_metadata.frequency_max,
                frequency_min=self.dataset_metadata.frequency_min,
                random_generator=self.random_generator,
            )
            # place signal on iq sample cut
            start = new_signal.metadata.start_in_samples
            stop = new_signal.metadata.stop_in_samples
            iq_samples[start:stop] += new_signal.data
            signals.append(new_signal)
        # form the sample (dataset object)
        return DatasetSignal(data=iq_samples, signals=signals)

    def _verify_idx(self, idx: int) -> None:
        """Validates that `idx` may be generated next.

        Skipping ahead is allowed; only negative indices, indices past the end of a
        finite dataset and indices below the generated-sample count are rejected.

        Args:
            idx (int): The requested sample index.

        Raises:
            IndexError: If `idx` is negative, is not below `num_samples` for a finite
                dataset, or is lower than the number of samples already generated.
        """
        if idx < 0:
            raise IndexError(f"index {idx} is less than zero and is out of bounds.")

        num_samples = self.dataset_metadata.num_samples
        # num_samples of None marks an infinite dataset with no upper bound
        if num_samples is not None and idx >= num_samples:
            raise IndexError(f"index {idx} is out of bounds for finite dataset with {num_samples} num_samples.")

        if idx < self.num_samples_generated:
            # samples are not stored, so an earlier index cannot be served again
            raise IndexError(f"cannot access previously generated samples in {self.__class__.__name__} for index {idx}. Ensure you are accessing dataset in order (0, 1, 2,...) or save dataset with DatasetCreator")

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, Tuple]:
        """Generates a new dataset sample and corresponding targets for a given index.

        The index is only validated; the sample itself is freshly generated and does
        not otherwise depend on `idx`.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            Tuple[np.ndarray, Tuple]: The sample data and the target values. When no
            target transforms are configured, the targets are the sample metadata.

        Raises:
            IndexError: If the index is rejected by `_verify_idx`.
        """
        self._verify_idx(idx)
        # generate new sample
        sample = self.__generate_new_signal__()
        # apply dataset transforms
        sample = self.dataset_metadata.impairments.dataset_transforms(sample)
        # apply user transforms
        for transform in self.dataset_metadata.transforms:
            sample = transform(sample)
        # convert to DatasetDict
        sample = DatasetDict(signal=sample)
        targets = self._collect_targets(sample)
        self.num_samples_generated += 1
        return sample.data, targets

    def _collect_targets(self, sample: DatasetDict):
        """Applies the target transforms to `sample` and shapes their outputs.

        Args:
            sample (DatasetDict): The sample whose metadata is transformed in place.

        Returns:
            The sample metadata when no target transforms are configured; otherwise
            the requested target fields, grouped per signal and unwrapped whenever a
            level contains exactly one element.
        """
        target_transforms = self.dataset_metadata.target_transforms
        outputs_by_transform = []
        for target_transform in target_transforms:
            # apply transform to all metadatas
            sample.metadata = target_transform(sample.metadata)
            # for every signal, pull out the fields this transform declares as its targets
            outputs_by_transform.append([
                tuple(signal_metadata[field] for field in target_transform.targets_metadata)
                for signal_metadata in sample.metadata
            ])

        if len(target_transforms) == 0:
            # no target transform applied
            return sample.metadata

        # reorder from [(transform 1 output for all signals), (transform 2 ...), ...]
        # to [signal 1 outputs, signal 2 outputs, ...]
        outputs_by_signal = list(zip(*outputs_by_transform))

        if self.dataset_metadata.dataset_type == 'narrowband':
            # only one signal in list for narrowband: flatten and unwrap single-field outputs
            flat = [item[0] if len(item) == 1 else item for row in outputs_by_signal for item in row]
            # unwrap when only a single target value remains
            return flat[0] if len(flat) == 1 else tuple(flat)

        # wideband: keep one entry per signal, unwrapping single-field outputs
        rows = [tuple([item[0] if len(item) == 1 else item for item in row]) for row in outputs_by_signal]
        # unwrap any per-signal tuple that holds a single value
        return [row[0] if len(row) == 1 else row for row in rows]

    # Read-Only properties
    @property
    def dataset_metadata(self):
        """Returns the dataset metadata.

        Returns:
            DatasetMetadata: The dataset metadata.
        """
        return self._dataset_metadata

    # Functions
    def _random_signal_class(self):
        """Randomly selects which signal to create next.

        Returns:
            str: A signal class name drawn from the class list, weighted by the class distribution.
        """
        return self.random_generator.choice(
            self.dataset_metadata.class_list,
            p=self.dataset_metadata.class_distribution,
        )
[docs]
class StaticTorchSigDataset(Dataset):
    """Static Dataset class, which loads pre-generated data from a directory.

    This class assumes that the dataset has already been generated and saved to disk.
    The writer info file found in the dataset directory tells whether raw data was
    saved (transforms and target transforms are then applied on load) or already
    processed data and targets (returned as loaded).

    Args:
        root (str): The root directory where the dataset is stored.
        impairment_level (int): Defines impairment level 0, 1, 2.
        dataset_type (str): Type of the dataset, either "narrowband" or "wideband".
        transforms (list, optional): Transforms to apply to the data (default: no transforms).
        target_transforms (list, optional): Target transforms to apply (default: no target transforms).
        file_handler_class (TorchSigFileHandler, optional): Class used for reading the dataset (default: ZarrFileHandler).
        train (bool, optional): Forwarded to `dataset_full_path` when resolving the
            dataset directory name (default: None).
    """

    def __init__(
        self,
        root: str,
        impairment_level: int,
        dataset_type: str,
        transforms: list | None = None,
        target_transforms: list | None = None,
        file_handler_class: TorchSigFileHandler = ZarrFileHandler,
        train: bool | None = None,
    ):
        self.root = Path(root)
        self.impairment_level = impairment_level
        self.dataset_type = dataset_type
        # fresh list per instance; a `[]` default would be shared by every instance
        self.transforms = [] if transforms is None else transforms
        self.target_transforms = [] if target_transforms is None else target_transforms
        self.file_handler = file_handler_class
        self.train = train

        # create filepath to saved dataset
        # e.g., root/torchsig_narrowband_clean/
        self.full_root = dataset_full_path(
            dataset_type=self.dataset_type,
            impairment_level=self.impairment_level,
            train=self.train,
        )
        self.full_root = f"{self.root}/{self.full_root}"

        # check dataset data type from the writer info file
        # NOTE(review): yaml.FullLoader can construct more than plain data types; this
        # file is presumably written locally by the dataset writer. Consider
        # yaml.safe_load if it only ever contains plain types.
        with open(f"{self.full_root}/{writer_yaml_name}", 'r') as f:
            writer_info = yaml.load(f, Loader=yaml.FullLoader)
        self.raw = writer_info['save_type'] == "raw"

        # need to create new dataset metadata from the dataset info file
        self.dataset_metadata = to_dataset_metadata(f"{self.full_root}/{dataset_yaml_name}")

        # dataset size
        self.num_samples = self.file_handler.size(self.full_root)

        self._verify()

    def _verify(self):
        """Validates and normalizes the configured transforms and target transforms."""
        # Transforms
        self.transforms = verify_transforms(self.transforms)
        # Target Transforms
        self.target_transforms = verify_target_transforms(self.target_transforms)

    def __len__(self) -> int:
        """Returns the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return self.num_samples

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, Tuple]:
        """Retrieves a sample from the dataset by index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            Tuple[np.ndarray, Tuple]: The data and targets for the sample. For raw
            datasets without target transforms, the targets are the sample metadata.

        Raises:
            IndexError: If the index is out of bounds.
        """
        if not 0 <= idx < len(self):
            # the upper bound is exclusive
            raise IndexError(f"Index {idx} is out of bounds. Must be in [0, {len(self)})")

        if not self.raw:
            # loading in transformed data and targets from target transform
            data, targets = self.file_handler.static_load(self.full_root, idx)
            return data, targets

        # loading in raw IQ data and signal metadata
        data, signal_metadatas = self.file_handler.static_load(self.full_root, idx)
        # convert to DatasetSignal
        sample = DatasetSignal(
            data=data,
            signals=signal_metadatas,
            dataset_metadata=self.dataset_metadata,
        )
        # apply user transforms
        for transform in self.transforms:
            sample = transform(sample)
        # convert to DatasetDict
        sample = DatasetDict(signal=sample)
        return sample.data, self._collect_targets(sample)

    def _collect_targets(self, sample: DatasetDict):
        """Applies the target transforms to `sample` and shapes their outputs.

        Args:
            sample (DatasetDict): The sample whose metadata is transformed in place.

        Returns:
            The sample metadata when no target transforms are configured; otherwise
            the requested target fields, grouped per signal and unwrapped whenever a
            level contains exactly one element.
        """
        outputs_by_transform = []
        for target_transform in self.target_transforms:
            # apply transform to all metadatas
            sample.metadata = target_transform(sample.metadata)
            # for every signal, pull out the fields this transform declares as its targets
            outputs_by_transform.append([
                tuple(signal_metadata[field] for field in target_transform.targets_metadata)
                for signal_metadata in sample.metadata
            ])

        if len(self.target_transforms) == 0:
            # no target transform applied
            return sample.metadata

        # reorder from [(transform 1 output for all signals), (transform 2 ...), ...]
        # to [signal 1 outputs, signal 2 outputs, ...]
        outputs_by_signal = list(zip(*outputs_by_transform))

        if self.dataset_type == 'narrowband':
            # only one signal in list for narrowband: flatten and unwrap single-field outputs
            flat = [item[0] if len(item) == 1 else item for row in outputs_by_signal for item in row]
            # unwrap when only a single target value remains
            return flat[0] if len(flat) == 1 else tuple(flat)

        # wideband: keep one entry per signal, unwrapping single-field outputs
        rows = [tuple([item[0] if len(item) == 1 else item for item in row]) for row in outputs_by_signal]
        # unwrap any per-signal tuple that holds a single value
        return [row[0] if len(row) == 1 else row for row in rows]

    def __str__(self) -> str:
        """Returns the class name followed by the resolved dataset directory."""
        return f"{self.__class__.__name__}: {self.full_root}"

    def __repr__(self) -> str:
        """Returns a constructor-style representation of the dataset."""
        return (
            f"{self.__class__.__name__}"
            f"(root={self.root}, "
            f"impairment_level={self.impairment_level}, "
            f"dataset_type={self.dataset_type}, "
            f"transforms={self.transforms!r}, "
            f"target_transforms={self.target_transforms!r}, "
            f"file_handler_class={self.file_handler}, "
            f"train={self.train})"
        )