Source code for torchsig.utils.defaults

from torchsig.datasets.datasets import TorchSigIterableDataset
from torchsig.transforms.impairments import Impairments
from torchsig.utils.data_loading import WorkerSeedingDataLoader
from torchsig.utils.writer import default_collate_fn


class TorchSigDefaults:
    """Central holder for the default configuration values used by TorchSig.

    The defaults are kept in a private dictionary built at construction time
    and are exposed read-only through :attr:`default_dataset_metadata`.
    """

    def __init__(self):
        """Build the private dictionary of default dataset metadata."""
        total_iq_samples = 262144  # 512**2

        self._default_dataset_metadata = {
            "num_iq_samples_dataset": total_iq_samples,
            "num_signals_min": 1,
            "num_signals_max": 1,
            "fft_size": 512,
            "fft_stride": 512,
            "sample_rate": 10000000,
            "noise_power_db": 0.0,
            "snr_db_min": 0.0,
            "snr_db_max": 50.0,
            "cochannel_overlap_probability": 0.2,
            # Durations are expressed as fractions of the total sample count;
            # multiplying by a float makes both values floats.
            "signal_duration_in_samples_min": total_iq_samples * 0.8,
            "signal_duration_in_samples_max": total_iq_samples * 1.0,
            "bandwidth_min": 2500000,
            "bandwidth_max": 3333333,
            "signal_center_freq_min": -2500000,
            "signal_center_freq_max": 2499999,
            "frequency_min": -2500000,
            "frequency_max": 2499999,
        }

    @property
    def default_dataset_metadata(self):
        """Give callers their own copy of the default dataset metadata.

        Returns:
            dict: A new (shallow) dictionary with the default metadata, so
            changes made by the caller do not affect the stored defaults.
        """
        return dict(self._default_dataset_metadata)
def default_dataset(
    impairment_level=None, transforms=None, component_transforms=None, **kwargs
):
    """Create a TorchSigIterableDataset from the default metadata.

    The dataset is built with a copy of the ``TorchSigDefaults`` metadata.
    When ``impairment_level`` is given, the matching ``Impairments`` transforms
    are placed in front of the caller-supplied transform lists.

    Args:
        impairment_level: Optional impairment level passed to ``Impairments``.
            If None, no impairments are added.
        transforms: Dataset-level transforms to apply. ``None`` (the default)
            means an empty list.
        component_transforms: Signal-level transforms assigned to each signal
            generator. ``None`` (the default) means an empty list.
        **kwargs: Additional keyword arguments forwarded to the
            ``TorchSigIterableDataset`` constructor.

    Returns:
        TorchSigIterableDataset: The configured dataset instance.
    """
    # A fresh list per call: a mutable default in the signature would be one
    # shared object handed to every dataset created without explicit transforms.
    if transforms is None:
        transforms = []
    if component_transforms is None:
        component_transforms = []

    dataset_metadata = TorchSigDefaults().default_dataset_metadata

    if impairment_level is not None:
        impairments = Impairments(impairment_level)
        signal_level_impairments = impairments.signal_transforms
        dataset_level_impairments = impairments.dataset_transforms
        # Impairments run first, followed by the caller's transforms.
        new_transforms = [dataset_level_impairments, *transforms]
        new_component_transforms = [signal_level_impairments, *component_transforms]
    else:
        new_transforms = transforms
        new_component_transforms = component_transforms

    new_dataset = TorchSigIterableDataset(
        metadata=dataset_metadata, transforms=new_transforms, **kwargs
    )

    # The same list object is assigned to every signal generator.
    for signal_gen in new_dataset.signal_generators:
        try:
            signal_gen["transforms"] = new_component_transforms
        except Exception:
            # Best effort: this object has no transforms to set. Catching
            # Exception (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            continue

    return new_dataset
def default_dataloader(
    seed=False, collate_fn=default_collate_fn, batch_size=1, num_workers=1, **kwargs
) -> WorkerSeedingDataLoader:
    """Create a WorkerSeedingDataLoader over a default dataset.

    A dataset is built with ``default_dataset(**kwargs)`` and wrapped in a
    ``WorkerSeedingDataLoader``. If a seed is supplied, the loader's ``seed``
    method is called with it before the loader is returned.

    Args:
        seed: Optional seed value for reproducible data loading. ``False``
            (the default) or ``None`` means no seeding is applied; any other
            value, including ``0``, is passed to ``dataloader.seed``.
        collate_fn: Function used to collate samples into batches.
            Defaults to default_collate_fn.
        batch_size: Number of samples per batch. Defaults to 1.
        num_workers: Number of worker subprocesses. Defaults to 1.
        **kwargs: Additional keyword arguments forwarded to
            ``default_dataset`` (e.g. ``impairment_level``, ``transforms``).

    Returns:
        WorkerSeedingDataLoader: The configured data loader instance.
    """
    dataset = default_dataset(**kwargs)
    dataloader = WorkerSeedingDataLoader(
        dataset, collate_fn=collate_fn, batch_size=batch_size, num_workers=num_workers
    )

    # Identity checks instead of truthiness: 0 is a legitimate seed, and a
    # plain `if seed:` would silently skip seeding for it.
    if seed is not False and seed is not None:
        dataloader.seed(seed)

    return dataloader