Utilities
Additional utilities, such as reading from and writing to disk and type checking, are included in the torchsig/utils folder.
The following utilities are available:
Digital Signal Processing Utils
Digital Signal Processing (DSP) Utils
- torchsig.utils.dsp.slice_tail_to_length(input_signal: ndarray, num_samples: int) → ndarray
Slices the tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Maximum number of samples for a signal
- Raises:
ValueError – If signal is too short to be sliced
- Returns:
Signal with length sliced to num_samples
- Return type:
np.ndarray
- torchsig.utils.dsp.slice_head_tail_to_length(input_signal: ndarray, num_samples: int) → ndarray
Slices the head and tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Maximum number of samples for a signal
- Raises:
ValueError – If signal is too short to be sliced
- Returns:
Signal with length sliced to num_samples
- Return type:
np.ndarray
- torchsig.utils.dsp.pad_head_tail_to_length(input_signal: ndarray, num_samples: int) → ndarray
Zero-pads the head and tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Desired length of signal
- Raises:
ValueError – If signal is too long to be padded
- Returns:
Signal with length padded to num_samples
- Return type:
np.ndarray
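The three length-adjustment helpers above can be illustrated with NumPy slicing and np.pad. This is a minimal sketch, not torchsig's source: it assumes that slice_head_tail_to_length removes roughly equal numbers of samples from each end and that pad_head_tail_to_length splits the zeros roughly evenly between head and tail. The *_sketch names are hypothetical.

```python
import numpy as np


def slice_tail_to_length_sketch(input_signal: np.ndarray, num_samples: int) -> np.ndarray:
    # Keep the first num_samples samples and drop the tail.
    if len(input_signal) < num_samples:
        raise ValueError("signal is too short to be sliced")
    return input_signal[:num_samples]


def slice_head_tail_to_length_sketch(input_signal: np.ndarray, num_samples: int) -> np.ndarray:
    # Assumption: remove roughly equal amounts from the head and the tail.
    excess = len(input_signal) - num_samples
    if excess < 0:
        raise ValueError("signal is too short to be sliced")
    head = excess // 2
    return input_signal[head:head + num_samples]


def pad_head_tail_to_length_sketch(input_signal: np.ndarray, num_samples: int) -> np.ndarray:
    # Assumption: split the zero padding roughly evenly between head and tail.
    deficit = num_samples - len(input_signal)
    if deficit < 0:
        raise ValueError("signal is too long to be padded")
    head = deficit // 2
    # np.pad defaults to constant (zero) padding.
    return np.pad(input_signal, (head, deficit - head))
```

For example, slicing np.arange(10) to 4 samples keeps [3, 4, 5, 6] in the head-and-tail variant, since 3 samples are dropped from each end.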
- torchsig.utils.dsp.upconversion_anti_aliasing_filter(input_signal: ndarray, center_freq: float, bandwidth: float, sample_rate: float, frequency_max: float, frequency_min: float)
Applies a band-pass filter (BPF) to avoid aliasing
Upconverting a signal away from baseband can cause energy to alias around the -fs/2 or +fs/2 boundary, depending on the sidelobe level and the size of the frequency shift. The function checks whether the combination of center frequency and bandwidth places any energy beyond the limits specified by frequency_max and frequency_min, and if so, builds and applies the filter.
- Parameters:
input_signal (np.ndarray) – Input signal modulated to band-pass
center_freq (float) – Center frequency of the signal
bandwidth (float) – Bandwidth of the signal
sample_rate (float) – Sample rate of the signal
frequency_max (float) – The maximum frequency where energy content can reside
frequency_min (float) – The minimum frequency where energy content can reside
- Returns:
Anti-aliased signal (np.ndarray), updated center frequency for the bounding box (float), and updated bandwidth for the bounding box (float)
- Return type:
tuple[np.ndarray, float, float]
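The decision described above amounts to comparing the band edges implied by center_freq and bandwidth against frequency_min and frequency_max. A minimal sketch of that check only, assuming band edges of center_freq ± bandwidth/2; needs_anti_aliasing_filter is a hypothetical helper, and the filter design itself is omitted.

```python
def needs_anti_aliasing_filter(center_freq: float, bandwidth: float,
                               frequency_max: float, frequency_min: float) -> bool:
    # Assumed band edges: center frequency plus/minus half the bandwidth.
    lower_edge = center_freq - bandwidth / 2
    upper_edge = center_freq + bandwidth / 2
    # Filtering is needed if any energy falls outside [frequency_min, frequency_max].
    return upper_edge > frequency_max or lower_edge < frequency_min
```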
- torchsig.utils.dsp.is_even(number)
Is the number even?
Returns true if the number is even, false if the number is odd
- Parameters:
number – Any number
- Returns:
Returns true if number is even, false if number is odd
- Return type:
bool
- torchsig.utils.dsp.is_odd(number)
Is the number odd?
Returns true if the number is odd, false if the number is even
- Parameters:
number – Any number
- Returns:
Returns true if number is odd, false if number is even
- Return type:
bool
- torchsig.utils.dsp.is_multiple_of_4(number)
Is the number a multiple of 4?
Returns true if the number is a multiple of 4, false otherwise. A number is a multiple of 4 if the number is even and the number divided by 2 is also even.
- Parameters:
number – Any number
- Returns:
Returns true if number is a multiple of 4, false otherwise
- Return type:
bool
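A minimal sketch of the three parity predicates, written from the descriptions above rather than copied from torchsig. The multiple-of-4 test follows the stated rule: the number is even, and half of it is also even.

```python
def is_even(number) -> bool:
    return number % 2 == 0


def is_odd(number) -> bool:
    return not is_even(number)


def is_multiple_of_4(number) -> bool:
    # Multiple of 4 <=> the number is even and number // 2 is also even.
    return is_even(number) and is_even(number // 2)
```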
- torchsig.utils.dsp.interpolate_power_of_2_resampler(input_signal: ndarray, interpolation_rate: int) → ndarray
Applies power-of-2 resampling
- Parameters:
input_signal (np.ndarray) – Input signal to be resampled
interpolation_rate (int) – Interpolation rate. Must be an integer power of 2 greater than or equal to 2.
- Raises:
ValueError – Throws error if the interpolation rate is not an integer
ValueError – Throws error if the interpolation rate is not >= 2.
ValueError – Throws error if the interpolation rate is not a power of 2.
- Returns:
Interpolated signal
- Return type:
np.ndarray
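The three ValueError conditions above, together with the number of 2x stages a dyadic cascade would need, can be sketched as follows. This assumes one half-band stage per factor of 2; validate_power_of_2_rate is a hypothetical helper, not a torchsig function.

```python
def validate_power_of_2_rate(interpolation_rate) -> int:
    """Validate the rate and return the number of 2x stages it requires."""
    if not isinstance(interpolation_rate, int):
        raise ValueError("interpolation rate must be an integer")
    if interpolation_rate < 2:
        raise ValueError("interpolation rate must be >= 2")
    # A positive integer is a power of 2 iff exactly one bit is set.
    if (interpolation_rate & (interpolation_rate - 1)) != 0:
        raise ValueError("interpolation rate must be a power of 2")
    # Each half-band stage doubles the rate, so log2(rate) stages are needed.
    return interpolation_rate.bit_length() - 1
```

For instance, a rate of 8 maps to three cascaded 2x stages.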
- torchsig.utils.dsp.design_half_band_filter(stage_number: int = 0, passband_percentage: float = 0.8, attenuation_db: float = 120) → ndarray
Designs half band filter weights for dyadic resampling
Implements the filter design for dyadic (power of 2) resampling, see fred harris, Multirate Signal Processing for Communication Systems, 2nd Edition, Chapter 8.7.
The dyadic resampler uses a cascade of stages (a multi-stage structure) to efficiently implement large resampling rates. For interpolation, each additional stage increases the sample rate by a factor of 2, so the signal bandwidth occupies a progressively smaller fraction of the sample rate. As a result, the passband edge, relative to the sample rate, decreases with each successive stage, which allows the transition bandwidth to widen at each stage and reduces the computation required.
- Parameters:
stage_number (int, optional) – Stage number in the cascade, must be greater than or equal to zero. Defaults to 0.
passband_percentage (float, optional) – The proportion of the available bandwidth used for the passband edge. Defaults to 0.8, which places the passband edge at 80% of the maximum passband edge of fs/4, i.e., 0.8*fs/4.
attenuation_db (float, optional) – The sidelobe attenuation level, must be greater than zero. Defaults to 120.
- Raises:
ValueError – If the designed filter does not have the appropriate length
- Returns:
Half band filter weights
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_resampler(input_signal: ndarray, resample_rate: float) → ndarray
Multi-stage polyphase filterbank-based resampling.
If the resampling rate is 1.0, the input signal is returned unchanged. If the resampling rate is greater than 1, interpolation is performed using multistage_polyphase_interpolator. If the resampling rate is less than 1, decimation is performed using multistage_polyphase_decimator.
- Parameters:
input_signal (np.ndarray) – The input signal to be resampled
resample_rate (float) – The resampling rate. Must be greater than 0.
- Returns:
The resampled signal
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_decimator(input_signal: ndarray, decimation_rate: float) → ndarray
Multi-stage polyphase filterbank-based decimation
The decimation is applied in up to two stages. The first stage implements the integer-rate portion and the second stage implements the fractional-rate portion.
For example, a resampling rate of 0.4 is a decimation by 2.5. The decimation of 2.5 is represented by an integer decimation of 2, and the fractional rate is therefore 2.5/2 = 1.25. Therefore a decimation by 2 is applied followed by a decimation of 1.25.
- Parameters:
input_signal (np.ndarray) – The input signal
decimation_rate (float) – The decimation rate. Must be greater than or equal to 1.
- Returns:
The decimated signal
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_interpolator(input_signal: ndarray, resample_rate_ideal: float) → ndarray
Multi-stage polyphase filterbank-based interpolation
The interpolation is applied in up to two stages. The first stage implements the fractional-rate portion and the second stage implements the integer-rate portion.
For example, a resampling rate of 2.5 is an interpolation by 2.5. This is represented by an integer interpolation of 2 and a fractional rate of 2.5/2 = 1.25. Therefore, an interpolation by 1.25 is applied, followed by an interpolation by 2.
- Parameters:
input_signal (np.ndarray) – The input signal
resample_rate_ideal (float) – The interpolation rate. Must be greater than or equal to 1.
- Returns:
The interpolated signal
- Return type:
np.ndarray
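Both multistage functions split a rate of at least 1 into an integer part and a fractional remainder, as in the 2.5 = 2 × 1.25 example. A minimal sketch of that split, assuming the integer part is floor(rate); split_resampling_rate is a hypothetical helper, and the order in which the two stages are applied (integer first for decimation, fractional first for interpolation) is left to the caller.

```python
import math


def split_resampling_rate(rate: float) -> tuple[int, float]:
    """Split a rate >= 1 into (integer_rate, fractional_rate) with integer_rate * fractional_rate == rate."""
    if rate < 1:
        raise ValueError("rate must be >= 1")
    # Assumption: the integer portion is the floor of the rate.
    integer_rate = math.floor(rate)
    fractional_rate = rate / integer_rate
    return integer_rate, fractional_rate
```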
- torchsig.utils.dsp.polyphase_fractional_resampler(input_signal: ndarray, fractional_rate: float) → ndarray
Fractional rate polyphase resampler
Implements a fractional-rate resampler through the SciPy upfirdn() function with a large number of branches. A fixed “up” rate of 10,000 is used, and the fractional rate then determines the “down” rate, such that up/down closely approximates the desired fractional resampling rate.
- Parameters:
input_signal (np.ndarray) – Input signal to be resampled
fractional_rate (float) – The fractional interpolation rate, must be greater than 0.
- Returns:
Resampled signal
- Return type:
np.ndarray
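The up/down selection described above can be sketched as follows. Note the swapped-in technique: torchsig is described as calling upfirdn() with its own polyphase filter, whereas this sketch uses scipy.signal.resample_poly, which designs its own anti-aliasing FIR filter and reduces up/down by their greatest common divisor. fractional_resample_sketch is a hypothetical name.

```python
import numpy as np
from scipy.signal import resample_poly


def fractional_resample_sketch(input_signal, fractional_rate: float, up: int = 10_000) -> np.ndarray:
    if fractional_rate <= 0:
        raise ValueError("fractional_rate must be greater than 0")
    # Choose "down" so that up/down approximates the requested rate.
    down = int(round(up / fractional_rate))
    # resample_poly filters and resamples by up/down (reduced by their gcd internally).
    return resample_poly(np.asarray(input_signal), up, down)
```

With a rate of 1.25, down is 8000, which resample_poly reduces to 5/4, so 1000 input samples yield 1250 output samples.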
- torchsig.utils.dsp.prototype_polyphase_filter_interpolation(num_branches: int, attenuation_db=120) → ndarray
Designs polyphase filterbank weights for interpolation
- torchsig.utils.dsp.prototype_polyphase_filter_decimation(num_branches: int, attenuation_db=120) → ndarray
Designs polyphase filterbank weights for decimation
- torchsig.utils.dsp.prototype_polyphase_filter(num_branches: int, attenuation_db: float = 120) → ndarray
Designs the prototype filter for a polyphase filter bank
- torchsig.utils.dsp.polyphase_integer_interpolator(input_signal: ndarray, interpolation_rate: int) → ndarray
Integer-rate polyphase filterbank-based interpolation
- Parameters:
input_signal (np.ndarray) – Input signal to be interpolated
interpolation_rate (int) – The interpolation rate
- Raises:
ValueError – Throws an error if the right number of samples are not produced
- Returns:
Interpolated output signal
- Return type:
np.ndarray
- torchsig.utils.dsp.polyphase_decimator(input_signal: ndarray, decimation_rate: int) → ndarray
Integer-rate polyphase filterbank-based decimation
- Parameters:
input_signal (np.ndarray) – Input signal to be decimated
decimation_rate (int) – The decimation rate
- Raises:
ValueError – Throws an error if the right number of samples are not produced
- Returns:
Decimated output signal
- Return type:
np.ndarray
- torchsig.utils.dsp.upsample(signal: ndarray, rate: int) → ndarray
Upsamples a signal
Upsamples a signal by insertion of zeros. For example, upsampling by 2 produces sample, 0, sample, 0, sample, 0, etc., and upsampling by 3 produces sample, 0, 0, sample, 0, 0, etc.
- Parameters:
signal (np.ndarray) – The input signal
rate (int) – The upsampling rate, must be > 1
- Raises:
ValueError – Throws an error when the rate is less than or equal to 1
ValueError – Throws an error when the rate is not an integer
- Returns:
The upsampled signal
- Return type:
np.ndarray
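Zero insertion as described above can be sketched in a few lines of NumPy: allocate rate × N zeros and place the input samples every rate positions. This is an illustration, not torchsig's source; upsample_sketch is a hypothetical name.

```python
import numpy as np


def upsample_sketch(signal: np.ndarray, rate: int) -> np.ndarray:
    if not isinstance(rate, int):
        raise ValueError("rate must be an integer")
    if rate <= 1:
        raise ValueError("rate must be > 1")
    # Output has rate * N samples; every rate-th sample carries an input sample.
    output = np.zeros(len(signal) * rate, dtype=signal.dtype)
    output[::rate] = signal
    return output
```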
- torchsig.utils.dsp.center_freq_from_lower_upper_freq(lower_freq: float, upper_freq: float) → float
Calculates center frequency from lower frequency and upper frequency
- torchsig.utils.dsp.bandwidth_from_lower_upper_freq(lower_freq: float, upper_freq: float) → float
Calculates bandwidth from lower frequency and upper frequency
- torchsig.utils.dsp.lower_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) → float
Calculates the lower frequency from center frequency and bandwidth
- torchsig.utils.dsp.upper_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) → float
Calculates upper frequency from center frequency and bandwidth
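The four frequency helpers above are related by simple identities: the center is the midpoint of the band edges, the bandwidth is their difference, and each edge sits half a bandwidth from the center. The sketch below restates these identities with hypothetical names; it is not copied from torchsig.

```python
def center_from_edges(lower_freq: float, upper_freq: float) -> float:
    return (lower_freq + upper_freq) / 2


def bandwidth_from_edges(lower_freq: float, upper_freq: float) -> float:
    return upper_freq - lower_freq


def lower_edge(center_freq: float, bandwidth: float) -> float:
    return center_freq - bandwidth / 2


def upper_edge(center_freq: float, bandwidth: float) -> float:
    return center_freq + bandwidth / 2
```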
- torchsig.utils.dsp.frequency_shift(signal: ndarray, frequency: float, sample_rate: float) → ndarray
Performs a frequency shift
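A frequency shift is conventionally implemented by mixing the signal with a complex exponential at the normalized frequency frequency/sample_rate. The documentation does not state torchsig's exact method, so this is a minimal sketch under that assumption; frequency_shift_sketch is a hypothetical name.

```python
import numpy as np


def frequency_shift_sketch(signal: np.ndarray, frequency: float, sample_rate: float) -> np.ndarray:
    n = np.arange(len(signal))
    # Multiply by exp(j*2*pi*(f/fs)*n) to move the spectrum up by `frequency`.
    return signal * np.exp(2j * np.pi * (frequency / sample_rate) * n)
```

Shifting a 64-sample DC signal by fs/4 moves its FFT peak to bin 16.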
- torchsig.utils.dsp.compute_spectrogram(iq_samples: ndarray, fft_size: int, fft_stride: int) → ndarray
Computes two-dimensional spectrogram values in dB.
- Parameters:
iq_samples (np.ndarray) – Input signal.
fft_size (int) – The size of the FFT in number of bins.
fft_stride (int) – The stride is the amount by which the input sample pointer increases for each FFT. When fft_stride=fft_size, then there is no overlap of input samples in successive FFTs. When fft_stride=fft_size/2, there is 50% overlap of input samples between successive FFTs.
- Raises:
ValueError – Throws an error if fft_stride is less than 0 or greater than fft_size.
- Returns:
Two-dimensional array of spectrogram values in dB.
- Return type:
np.ndarray
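The role of fft_size and fft_stride can be illustrated with a short-time FFT: successive windowed frames start fft_stride samples apart, and each frame's power is converted to dB. This is a sketch, not torchsig's implementation; the Blackman window, the fftshift, the small floor inside the logarithm, and the frequency-by-time output layout are all assumptions. It also rejects fft_stride == 0, which the documented check does not mention.

```python
import numpy as np


def compute_spectrogram_sketch(iq_samples: np.ndarray, fft_size: int, fft_stride: int) -> np.ndarray:
    # Zero is also rejected here because a zero stride would never advance.
    if fft_stride <= 0 or fft_stride > fft_size:
        raise ValueError("fft_stride must be in (0, fft_size]")
    window = np.blackman(fft_size)  # assumed window choice
    num_frames = (len(iq_samples) - fft_size) // fft_stride + 1
    frames = np.stack([
        iq_samples[i * fft_stride:i * fft_stride + fft_size] * window
        for i in range(num_frames)
    ])
    # FFT each frame and center DC in the middle of the frequency axis.
    spectrum = np.fft.fftshift(np.fft.fft(frames, axis=1), axes=1)
    power_db = 10 * np.log10(np.abs(spectrum) ** 2 + 1e-12)
    # Rows are frequency bins, columns are time frames (assumed layout).
    return power_db.T
```

With 1024 samples, fft_size=64, and fft_stride=32 (50% overlap), this yields 31 frames of 64 bins.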
- torchsig.utils.dsp.estimate_tone_bandwidth(num_samples: int, sample_rate: float)
Estimate the bandwidth of a tone
The bandwidth of a tone is determined entirely by the length of the time series, that is, by the number of samples together with the sample rate.
- torchsig.utils.dsp.convolve(signal: ndarray, taps: ndarray) → ndarray
Wrapper function for convolution
A wrapped version of SciPy’s convolve() that discards the transition regions resulting from the convolution process.
- Parameters:
signal (np.ndarray) – The input signal
taps (np.ndarray) – The filter weights
- Returns:
The convolution output
- Return type:
np.ndarray
- torchsig.utils.dsp.low_pass(cutoff: float, transition_bandwidth: float, sample_rate: float, attenuation_db: float = 120) → ndarray
Low-pass filter design
- Parameters:
cutoff (float) – The filter cutoff, 0 < cutoff < sample_rate/2. Must be in the same units as sample_rate.
transition_bandwidth (float) – The transition bandwidth of the filter, 0 < transition_bandwidth < sample_rate/2. Must be in the same units as sample_rate.
sample_rate (float) – The sampling rate associated with the filter design.
attenuation_db (float, optional) – Sidelobe attenuation level. Defaults to 120.
- Returns:
Filter weights
- Return type:
np.ndarray
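- torchsig.utils.dsp.estimate_filter_length(transition_bandwidth: float, attenuation_db: float, sample_rate: float) → int
Estimates FIR filter length
Estimate the length of an FIR filter using fred harris’ approximation, Multirate Signal Processing for Communication Systems, Second Edition, p.59.
- Parameters:
transition_bandwidth (float) – The transition bandwidth of the filter. Must be in the same units as sample_rate.
attenuation_db (float) – Sidelobe attenuation level in dB.
sample_rate (float) – The sampling rate associated with the filter design.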
The estimated filter length
- Return type:
int
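fred harris’ rule of thumb estimates the number of taps as N ≈ (fs / Δf) · (A / 22), where Δf is the transition bandwidth and A the attenuation in dB. The sketch below applies that estimate and then, as an assumption (the design method of low_pass is not stated above), builds a Kaiser-window low-pass with scipy.signal.firwin. The *_sketch names are hypothetical.

```python
import math

import numpy as np
from scipy.signal import firwin, kaiser_beta


def estimate_filter_length_sketch(transition_bandwidth: float, attenuation_db: float,
                                  sample_rate: float) -> int:
    # fred harris' approximation: N ≈ (fs / transition_bandwidth) * (attenuation_db / 22).
    return int(math.ceil((sample_rate / transition_bandwidth) * (attenuation_db / 22)))


def low_pass_sketch(cutoff: float, transition_bandwidth: float, sample_rate: float,
                    attenuation_db: float = 120) -> np.ndarray:
    num_taps = estimate_filter_length_sketch(transition_bandwidth, attenuation_db, sample_rate)
    # Force an odd length so the linear-phase filter has an integer group delay.
    if num_taps % 2 == 0:
        num_taps += 1
    # Kaiser window with beta chosen for the requested sidelobe attenuation.
    window = ("kaiser", kaiser_beta(attenuation_db))
    return firwin(num_taps, cutoff, window=window, fs=sample_rate)
```

For example, a transition bandwidth of fs/8 with 44 dB of attenuation gives 8 · 2 = 16 taps.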
- torchsig.utils.dsp.srrc_taps(iq_samples_per_symbol: int, filter_span_in_symbols: int, alpha: float = 0.35) → ndarray
Designs square-root raised cosine (SRRC) pulse shaping filter
- Parameters:
iq_samples_per_symbol (int) – The samples-per-symbol (SPS) of the underlying modulation, equivalent to the oversampling rate.
filter_span_in_symbols (int) – The filter span in number of symbols.
alpha (float, optional) – The alpha roll-off value of the pulse shaping filter, which is the amount of excess bandwidth. Defaults to 0.35.
- Returns:
SRRC filter weights
- Return type:
np.ndarray
- torchsig.utils.dsp.gaussian_taps(samples_per_symbol: int, bt: float = 0.35) → ndarray
Designs Gaussian filter weights
- torchsig.utils.dsp.low_pass_iterative_design(cutoff: float, transition_bandwidth: float, sample_rate: float, desired_attenuation_db: float = 120) → ndarray
Iteratively designs a low-pass filter using the window method, adjusting the filter length to meet the desired stopband attenuation.
The filter design process starts with an initial filter design, and then iteratively increases the filter length based on the measured stopband attenuation. This process continues until the desired stopband attenuation is achieved or the maximum number of iterations is reached.
- Parameters:
cutoff (float) – The cutoff frequency of the low-pass filter (in Hz).
transition_bandwidth (float) – The transition bandwidth of the filter (in Hz).
sample_rate (float) – The sample rate of the system (in Hz).
desired_attenuation_db (float, optional) – The desired stopband attenuation in decibels (dB). Defaults to 120 dB.
- Returns:
The designed low-pass filter coefficients.
- Return type:
np.ndarray
- Raises:
Warning – If the filter design process exceeds the maximum number of iterations, a warning is raised and the initial filter design is returned.
Notes
The iterative design process adjusts the filter length based on the ratio of desired and measured stopband attenuation. If the process doesn’t converge within a reasonable number of iterations, the initial design is returned.
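The loop described in the Notes can be sketched as: design, measure the worst-case stopband response, and grow the length by the ratio of desired to measured attenuation until the target is met. This is a minimal sketch, not torchsig's code. The Kaiser window, the stopband edge at cutoff + transition_bandwidth/2, the 8192-point frequency grid, and max_iterations=10 are all assumptions, and all names are hypothetical.

```python
import warnings

import numpy as np
from scipy.signal import firwin, freqz, kaiser_beta


def _odd_length(n: float) -> int:
    # Round up and force an odd tap count.
    n = int(np.ceil(n))
    return n if n % 2 == 1 else n + 1


def stopband_attenuation_db(taps, stopband_edge, sample_rate):
    # Largest response magnitude at or above the stopband edge, expressed in dB of attenuation.
    freqs, response = freqz(taps, worN=8192, fs=sample_rate)
    peak = np.max(np.abs(response[freqs >= stopband_edge]))
    return -20 * np.log10(peak + 1e-20)


def low_pass_iterative_sketch(cutoff, transition_bandwidth, sample_rate,
                              desired_attenuation_db=120, max_iterations=10):
    window = ("kaiser", kaiser_beta(desired_attenuation_db))
    # Assumption: cutoff lies at the middle of the transition band.
    stopband_edge = cutoff + transition_bandwidth / 2
    # Initial length from fred harris' rule of thumb.
    num_taps = _odd_length(sample_rate / transition_bandwidth * desired_attenuation_db / 22)
    initial = taps = firwin(num_taps, cutoff, window=window, fs=sample_rate)
    for _ in range(max_iterations):
        measured = stopband_attenuation_db(taps, stopband_edge, sample_rate)
        if measured >= desired_attenuation_db:
            return taps
        # Grow the length by the desired/measured attenuation ratio (guarding against ~0 dB).
        num_taps = _odd_length(num_taps * desired_attenuation_db / max(measured, 1.0))
        taps = firwin(num_taps, cutoff, window=window, fs=sample_rate)
    warnings.warn("iterative design did not converge; returning the initial design")
    return initial
```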
- torchsig.utils.dsp.noise_generator(num_samples: int = 1024, power: float = 1.0, color: str = 'white', continuous: bool = True, rng: Generator | None = None) → ndarray
Generates additive complex noise of specified power and type.
- Parameters:
num_samples (int) – Number of noise samples to generate. Defaults to 1024.
power (float) – Desired noise power (linear, positive). Defaults to 1.0 W (0 dBW).
color (str) – Noise color, supports ‘white’, ‘pink’, or ‘red’ noise frequency spectrum types. Defaults to ‘white’.
continuous (bool) – Sets noise to continuous (True) or impulsive (False). Defaults to True.
rng (np.random.Generator, optional) – Random number generator. Defaults to np.random.default_rng(seed=None).
- Raises:
ValueError – If invalid noise power specified.
ValueError – If unsupported noise type specified.
- Returns:
Complex noise samples with specified power.
- Return type:
np.ndarray
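For the default white, continuous case, complex noise of a given linear power can be generated by splitting the power equally between independent Gaussian I and Q components. This sketch covers only that case (pink, red, and impulsive noise are not shown), and white_complex_noise is a hypothetical name, not torchsig's implementation.

```python
import numpy as np


def white_complex_noise(num_samples: int = 1024, power: float = 1.0, rng=None) -> np.ndarray:
    if power <= 0:
        raise ValueError("noise power must be positive")
    rng = np.random.default_rng() if rng is None else rng
    # Each of I and Q carries half the power, so E[|x|^2] == power.
    scale = np.sqrt(power / 2)
    return scale * (rng.standard_normal(num_samples) + 1j * rng.standard_normal(num_samples))
```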
- torchsig.utils.dsp.update_signal_snr_bandwidth(dataset: TorchSigIterableDataset, new_signal: Signal) → None
Updates the SNR and bandwidth of a signal based on dataset parameters.
This function performs two main operations:
1. Corrects the SNR of the signal by comparing the SNR estimated from the signal’s spectrogram with the target SNR range defined in the signal metadata.
2. Updates the signal’s bandwidth metadata to better fit the bounding box by estimating the 99% bandwidth from the signal’s spectral content.
- Parameters:
dataset (TorchSigIterableDataset) – The dataset object containing FFT parameters, noise floor information, and other metadata needed for processing.
new_signal (Signal) – The signal object to be processed, containing: data (the time-domain signal data), snr_db_min (minimum target SNR in dB), snr_db_max (maximum target SNR in dB), and bandwidth (the current bandwidth value, which will be updated).
- Returns:
The function modifies the new_signal object in place.
- Return type:
None
Notes
The SNR correction is performed by:
1. Computing a spectrogram of the signal
2. Estimating the current SNR from the spectrogram
3. Calculating a correction factor to match the target SNR
4. Applying this correction to the signal data
The bandwidth update is performed by:
1. Finding frequency bins where the signal exceeds the noise floor by 3 dB
2. Determining the frequency range of these bins
3. Widening this range by half the FFT frequency resolution
4. Updating the signal’s bandwidth metadata with this new range
The SNR correction modifies the signal data, but the data is never resampled; the bandwidth update changes only the metadata.
Data Coordinate System
Library for overlap detection in spectrograms to control co-channel interference.
This module provides classes and functions to define 2D coordinates and axis-aligned rectangles, and to detect overlaps between rectangles using line-segment intersection and containment tests.
- class torchsig.utils.coordinate_system.Coordinate(x: float, y: float)
Bases: object
Represents a point in 2D space with x and y coordinates.
- class torchsig.utils.coordinate_system.Rectangle(lower_coord: Coordinate, upper_coord: Coordinate)
Bases: object
Represents an axis-aligned rectangle defined by two opposite corners.
The rectangle is built from a lower-left and an upper-right corner, from which the other two corners are inferred.
- coord_lower_left
Lower-left corner.
- Type:
Coordinate
- coord_upper_right
Upper-right corner.
- Type:
Coordinate
- coord_upper_left
Upper-left corner.
- Type:
Coordinate
- coord_lower_right
Lower-right corner.
- Type:
Coordinate
- torchsig.utils.coordinate_system.counter_clock_wise(a: Coordinate, b: Coordinate, c: Coordinate) → bool
Determine if three points a, b, c are in counter-clockwise order.
- Parameters:
a (Coordinate) – First point.
b (Coordinate) – Second point.
c (Coordinate) – Third point.
- Returns:
True if the sequence (a → b → c) is counter-clockwise.
- Return type:
bool
- torchsig.utils.coordinate_system.line_intersection(a: Coordinate, b: Coordinate, c: Coordinate, d: Coordinate) → bool
Check if the line segments AB and CD intersect.
Uses the counter-clockwise orientation test.
- Parameters:
a (Coordinate) – First endpoint of segment AB.
b (Coordinate) – Second endpoint of segment AB.
c (Coordinate) – First endpoint of segment CD.
d (Coordinate) – Second endpoint of segment CD.
- Returns:
True if segments AB and CD intersect.
- Return type:
bool
- torchsig.utils.coordinate_system.is_within_range(test_coord_x: float, rectangle_left_x: float, rectangle_right_x: float) → bool
Check if a coordinate lies within a closed interval on the x-axis.
- torchsig.utils.coordinate_system.is_corner_in_rectangle(corner_coord: Coordinate, reference_box: Rectangle) → bool
Check if a corner point is within the bounds of a reference rectangle.
- Parameters:
corner_coord (Coordinate) – The corner to test.
reference_box (Rectangle) – The rectangle in which to test containment.
- Returns:
True if the corner is inside reference_box (including edges).
- Return type:
bool
- torchsig.utils.coordinate_system.is_rectangle_inside_rectangle(rectangle_1: Rectangle, rectangle_2: Rectangle) → bool
Check if rectangle_1 is completely inside rectangle_2.
Tests whether all four corners of rectangle_1 lie within rectangle_2.
- torchsig.utils.coordinate_system.is_rectangle_overlap(rectangle_a: Rectangle, rectangle_b: Rectangle) → bool
Check if two rectangles overlap by intersection or containment.
- Overlap occurs if:
Any side of rectangle_a intersects any side of rectangle_b.
One rectangle is fully contained within the other.
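Taken together, the functions above describe a standard orientation-based overlap test: two segments cross when each one straddles the line through the other, and two rectangles overlap when any pair of edges crosses or one contains the other. The following self-contained sketch illustrates that approach. The function names mirror the documented ones for readability, but the bodies (including the dataclass for Coordinate and the corner ordering used to build edges) are assumptions rather than torchsig's source. The strict counter-clockwise test treats collinear points as not intersecting, so rectangles whose edges merely touch may not be reported as overlapping unless one contains the other.

```python
from dataclasses import dataclass


@dataclass
class Coordinate:
    x: float
    y: float


class Rectangle:
    def __init__(self, lower_coord: Coordinate, upper_coord: Coordinate):
        self.coord_lower_left = lower_coord
        self.coord_upper_right = upper_coord
        # The other two corners are inferred from the given pair.
        self.coord_upper_left = Coordinate(lower_coord.x, upper_coord.y)
        self.coord_lower_right = Coordinate(upper_coord.x, lower_coord.y)

    def corners(self):
        return [self.coord_lower_left, self.coord_lower_right,
                self.coord_upper_right, self.coord_upper_left]

    def edges(self):
        c = self.corners()
        return [(c[i], c[(i + 1) % 4]) for i in range(4)]


def counter_clock_wise(a, b, c) -> bool:
    # True when the cross product (b - a) x (c - a) is positive.
    return (c.y - a.y) * (b.x - a.x) > (b.y - a.y) * (c.x - a.x)


def line_intersection(a, b, c, d) -> bool:
    # AB and CD cross when each segment straddles the line through the other.
    return (counter_clock_wise(a, c, d) != counter_clock_wise(b, c, d)
            and counter_clock_wise(a, b, c) != counter_clock_wise(a, b, d))


def is_corner_in_rectangle(corner, box) -> bool:
    # Inclusive bounds, so corners on an edge count as inside.
    return (box.coord_lower_left.x <= corner.x <= box.coord_upper_right.x
            and box.coord_lower_left.y <= corner.y <= box.coord_upper_right.y)


def is_rectangle_inside_rectangle(rectangle_1, rectangle_2) -> bool:
    return all(is_corner_in_rectangle(c, rectangle_2) for c in rectangle_1.corners())


def is_rectangle_overlap(rectangle_a, rectangle_b) -> bool:
    edges_cross = any(line_intersection(a, b, c, d)
                      for a, b in rectangle_a.edges() for c, d in rectangle_b.edges())
    return (edges_cross
            or is_rectangle_inside_rectangle(rectangle_a, rectangle_b)
            or is_rectangle_inside_rectangle(rectangle_b, rectangle_a))
```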
Reading/Writing Utils
Writer
Dataset Writer Utils
- torchsig.utils.writer.default_collate_fn(batch)
Collates a batch by zipping its elements together.
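“Zipping” a batch usually means transposing a list of per-sample tuples into per-field tuples. A minimal sketch under that reading; zip_collate is a hypothetical name, and torchsig's exact return container (tuple or list) is not stated above.

```python
def zip_collate(batch):
    # [(x0, y0), (x1, y1), ...] -> ((x0, x1, ...), (y0, y1, ...))
    return tuple(zip(*batch))
```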
- class torchsig.utils.writer.DatasetCreator(dataloader: DataLoader = None, dataset_length: int | None = None, root: str = '.', overwrite: bool = True, tqdm_desc: str | None = None, file_handler: FileWriter = <class 'torchsig.utils.file_handlers.hdf5.HDF5Writer'>, multithreading: bool = True, **kwargs)
Bases: object
Class for creating a dataset and saving it to disk in batches.
This class generates a dataset if it doesn’t already exist on disk. It processes the data in batches and saves it using a specified file handler. The class allows setting options like whether to overwrite existing datasets, batch size, and number of worker threads.
- dataloader¶
The DataLoader used to load data in batches.
- Type:
DataLoader
- root¶
The root directory where the dataset will be saved.
- Type:
Path
- file_handler¶
The file handler used for saving the dataset.
- Type:
FileWriter
- get_writing_info_dict() → dict[str, Any]
Returns a dictionary with information about the dataset being written.
This method gathers information regarding the root, overwrite status, batch size, number of workers, file handler class, and the save type of the dataset.
- Returns:
Dictionary containing the dataset writing configuration.
- Return type:
Dict[str, Any]
- create() → None
Creates the dataset on disk by writing batches to the file handler.
This method generates the dataset in batches and saves it to disk. If the dataset already exists and overwrite is set to False, it will skip regeneration.
The method also writes the dataset metadata and writing information to YAML files.
- Raises:
ValueError – If the dataset is already generated and overwrite is set to False.
Data Loading
Collate function and DataLoader with worker seeding for TorchSig. Provides:
metadata_padding_collate_fn: pads variable-length metadata in each batch.
WorkerSeedingDataLoader: seeds each worker process differently for reproducibility.
- torchsig.utils.data_loading.metadata_padding_collate_fn(batch)
Collate a batch of (data, metadata_list) pairs, padding metadata to equal lengths.
- Metadata for each sample is a list of dicts. This function:
Finds the maximum metadata-list length in the batch.
Pads shorter metadata lists with default values.
Stacks data tensors and metadata fields into batched tensors.
- Parameters:
batch – A list where each element is a tuple of: - x: any object convertible to a NumPy array (e.g., tensor, array). - y: a list of metadata dicts, where each dict shares the same set of keys.
- Returns:
A tuple containing:
data_tensor: stacked torch.Tensor of all x values, shape (batch_size, …).
metadata_tensors: dict mapping each metadata key to a Tensor of shape (batch_size, max_sequence_length).
- Return type:
tuple
- Raises:
ValueError – if any element in batch is not a tuple of length 2.
- class torchsig.utils.data_loading.WorkerSeedingDataLoader(dataset, seed=None, **kwargs)
Bases: DataLoader, Seedable
DataLoader that seeds each worker process differently using a shared seed.
This loader prohibits external worker_init_fn definitions and sets its own init function to ensure reproducible randomness in multi-worker pipelines.
- seed(seed_val)
Set the seed value for both the loader and its dataset.
- Parameters:
seed_val – The seed value to set.
- init_worker_seed(worker_id)
Set a unique random seed for each worker process.
Uses the shared random_generator from the Seedable mixin to derive a new seed per worker_id.
- Parameters:
worker_id – The integer ID of the worker process.
- dataset: Dataset[_T_co]
- sampler: Sampler | Iterable
YAML Utils
YAML utilities
- torchsig.utils.yaml.custom_representer(dumper, value: list) → Dumper
Custom representer for YAML to handle sequences (lists).
This function customizes how lists are represented in the YAML output, using flow style for sequences (inline lists).
- Parameters:
dumper – The YAML dumper responsible for serializing the data.
value – The list to be represented in YAML.
- Returns:
The dumper with the custom representation for the list.
- torchsig.utils.yaml.write_dict_to_yaml(filename: str, info_dict: dict[str, Any]) → None
Writes a dictionary to a YAML file with customized settings.
This function writes the provided info_dict to a YAML file. It customizes the representation of lists by using the custom_representer, and it uses specific formatting options (e.g., no sorting of keys, custom line width).
- Parameters:
filename – The name of the YAML file to which the dictionary will be written.
info_dict – The dictionary to be written to the YAML file.
- Returns:
This function does not return any value.
- Return type:
None
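The flow-style list representation and unsorted keys described above can be reproduced with PyYAML's representer hooks. This sketch registers the representer on a SafeDumper subclass so global PyYAML state is left untouched; the class and function names, and the width of 200, are hypothetical choices rather than torchsig's settings.

```python
import yaml


class FlowListDumper(yaml.SafeDumper):
    """SafeDumper subclass so the custom list representer stays local."""


def flow_style_list_representer(dumper, value):
    # Emit lists inline, e.g. [1, 2, 3], instead of one item per line.
    return dumper.represent_sequence("tag:yaml.org,2002:seq", value, flow_style=True)


FlowListDumper.add_representer(list, flow_style_list_representer)


def write_dict_to_yaml_sketch(filename: str, info_dict: dict) -> None:
    with open(filename, "w") as f:
        # Keep insertion order of keys and allow long lines before wrapping.
        yaml.dump(info_dict, f, Dumper=FlowListDumper, sort_keys=False, width=200)
```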
- torchsig.utils.yaml.dataset_from_yaml_dict(yaml_dict: dict[str, Any]) → TorchSigIterableDataset
Creates a TorchSigIterableDataset from a YAML dictionary.
Passes data from the yaml_dict as needed into the TorchSigIterableDataset constructor and returns a new TorchSigIterableDataset.
- Parameters:
yaml_dict – Dictionary containing the dataset configuration with the keys “dataset_metadata” (dataset metadata), “target_labels” (list of target labels), and “seed” (random seed value).
- Returns:
Configured TorchSigIterableDataset instance.
- torchsig.utils.yaml.load_dataset_yaml(filepath: str) → TorchSigIterableDataset
Loads YAML data from specified filepath and constructs a dataset.
Loads YAML data from the specified filepath and uses it to construct and return a new TorchSigIterableDataset.
- Parameters:
filepath – Path to the YAML file containing dataset configuration.
- Returns:
Configured TorchSigIterableDataset instance.
- torchsig.utils.yaml.save_dataset_yaml(filepath: str, dataset: TorchSigIterableDataset) → None
Saves dataset configuration to a YAML file.
Saves YAML data to the specified filepath to represent the input TorchSigIterableDataset.
- Parameters:
filepath – Path where the YAML file will be saved.
dataset – TorchSigIterableDataset instance to save.
- torchsig.utils.yaml.dataset_metadata_to_yaml_dict(dataset_metadata: Any) → dict[str, Any]
Converts DatasetMetadata to a dictionary for YAML storage.
Returns a dictionary representation of a DatasetMetadata object for storing as YAML.
- Parameters:
dataset_metadata – DatasetMetadata object to convert.
- Returns:
dictionary containing the metadata for YAML storage.
File Handlers
File Handler Base and Utility Classes for reading and writing datasets to/from disk.
- class torchsig.utils.file_handlers.base_handler.FileWriter(root: str, **kwargs)
Bases: object
- setup() → None
Prepare resources before writing begins.
This resets the root folder and then calls the subclass’s _setup() method.
- class torchsig.utils.file_handlers.base_handler.FileReader(root: str, **kwargs)
Bases: object
- class torchsig.utils.file_handlers.base_handler.BaseFileHandler
Bases: object
File handler base class. Not to be instantiated directly.
- Usage:
>>> BaseFileHandler.create_handler(mode="r", root="./")  # create a reader
>>> BaseFileHandler.create_handler(mode="w", root="./")  # create a writer
- reader_class
alias of FileReader
- writer_class
alias of FileWriter
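- static create_handler(mode: str, root: str, **kwargs) → FileWriter | FileReader
Creates a FileWriter or FileReader
- Parameters:
mode (str) – "r" to create a reader or "w" to create a writer.
root (str) – Root directory of the dataset.
- Raises:
ValueError – If mode is invalid.
- Returns:
The file handler’s reader or writer.
- Return type:
FileWriter | FileReader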
HDF5 File Handler for TorchSig datasets.
High-performance HDF5 storage with optimized compression and chunking.
- torchsig.utils.file_handlers.hdf5.populate_hdf5_group_with_metadata(group, metadata_obj)
Makes sure this and all parent metadata objects are represented in the hdf5 group (returns true iff a new group was added)
- torchsig.utils.file_handlers.hdf5.populate_hdf5_group_with_signal_data(group, signal)
Makes sure this and all parent metadata objects are represented in the hdf5 group (returns true iff a new group was added)
- torchsig.utils.file_handlers.hdf5.populate_hdf5_group_with_component_signals(group, signal)
- torchsig.utils.file_handlers.hdf5.populate_hdf5_group_with_signal(group, signal, index=True)
- torchsig.utils.file_handlers.hdf5.populate_hdf5_group_with_signals(group, signals, index=True)
- class torchsig.utils.file_handlers.hdf5.HDF5Writer(root, compression: str = 'gzip', compression_opts: int = 6, shuffle: bool = True, fletcher32: bool = True, chunk_cache_size: int = 10485760, max_batches_in_memory: int = 4)
Bases: FileWriter
Handles writing Signal data to HDF5 files with specified compression and buffering.
- torchsig.utils.file_handlers.hdf5.fill_object_metadata_from_group_and_id(obj, group, id_str)
- class torchsig.utils.file_handlers.hdf5.HDF5Reader(root)
Bases: FileReader
Handles reading Signal data from HDF5 files.
- class torchsig.utils.file_handlers.hdf5.HDF5FileHandler
Bases: BaseFileHandler
HDF5FileHandler creates a reader or writer for HDF5 files.
- reader_class
alias of HDF5Reader
- writer_class
alias of HDF5Writer
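- static create_handler(mode: str, root: str, **kwargs) → HDF5Writer | HDF5Reader
Creates an instance of HDF5Reader or HDF5Writer based on the mode.
- Parameters:
mode (str) – "r" to create an HDF5Reader or "w" to create an HDF5Writer.
root (str) – Root directory of the dataset.
- Returns:
The created file handler.
- Return type:
HDF5Writer | HDF5Reader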
- Raises:
ValueError – If the mode is invalid.
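The reader_class/writer_class aliases and create_handler together form a small factory: the mode string selects which class to instantiate. The sketch below illustrates that pattern with hypothetical Sketch* classes. torchsig documents create_handler as a staticmethod redefined in each subclass; this sketch uses a classmethod instead, so that a subclass only needs to override the two class attributes.

```python
class SketchReader:
    def __init__(self, root: str, **kwargs):
        self.root = root


class SketchWriter:
    def __init__(self, root: str, **kwargs):
        self.root = root


class SketchFileHandler:
    reader_class = SketchReader
    writer_class = SketchWriter

    @classmethod
    def create_handler(cls, mode: str, root: str, **kwargs):
        # "r" builds the reader, "w" builds the writer; anything else is rejected.
        if mode == "r":
            return cls.reader_class(root, **kwargs)
        if mode == "w":
            return cls.writer_class(root, **kwargs)
        raise ValueError(f"invalid mode: {mode!r}")
```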
Variable and Data Verification Utils
Data verification and error checking utils
- torchsig.utils.verify.verify_dict(d: dict, name: str, required_keys: list = [], required_types: list = [])
Verifies that the value d is a dictionary and optionally checks for required keys and their types.
- Parameters:
d (dict) – The value to be checked.
name (str) – The name of the value to be used in error messages.
required_keys (list, optional) – A list of required keys in the dictionary. Defaults to an empty list.
required_types (list, optional) – A list of types for each required key. Defaults to an empty list.
- Raises:
ValueError – If d is not a dictionary, or if any required key is missing or has an incorrect type.
- Returns:
The verified dictionary d.
- Return type:
dict
- torchsig.utils.verify.verify_distribution_list(distro: list[float], required_length: int, distro_name: str, list_name: str) → list[float]
Verifies and normalizes a given distribution list.
If the distribution list is None, a uniform distribution is assumed and None is returned as is. If the distribution list does not have the required length, an error is raised; if it does not sum to 1.0, it is normalized to sum to 1.0.
- Parameters:
distro (List[float]) – The distribution list to verify. Can be None for a uniform distribution.
required_length (int) – The expected length of the distribution list.
distro_name (str) – The name of the distribution list (used for error messages).
list_name (str) – The name of the list this distribution corresponds to (used for error messages).
- Returns:
The verified and possibly normalized distribution list.
- Return type:
List[float]
- Raises:
ValueError – If the distribution list is not of the required length or does not sum to 1.0 and cannot be normalized.
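A minimal sketch of the length check and normalization described above follows. normalize_distribution is a hypothetical helper, and treating a non-positive sum as "cannot be normalized" is an assumption; the real function may apply other criteria.

```python
import math


def normalize_distribution(distro, required_length, distro_name, list_name):
    """Hypothetical sketch of the behavior verify_distribution_list documents."""
    if distro is None:
        # None means "uniform"; it is passed through unchanged.
        return None
    if len(distro) != required_length:
        raise ValueError(
            f"{distro_name} has {len(distro)} entries, "
            f"but {list_name} has {required_length}"
        )
    total = sum(distro)
    if total <= 0:
        # Assumed failure mode: a non-positive total cannot be rescaled.
        raise ValueError(f"{distro_name} cannot be normalized (sum={total})")
    if math.isclose(total, 1.0):
        return list(distro)
    # Divide by the total so the weights sum to 1.0.
    return [p / total for p in distro]


print(normalize_distribution([1, 1, 2], 3, "class_distribution", "class_list"))
# [0.25, 0.25, 0.5]
```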
- torchsig.utils.verify.verify_float(f: float, name: str, low: float = 0.0, high: float | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) float[source]¶
Verifies that the value f is a float and within the specified bounds.
- Parameters:
f (float) – The value to be checked.
name (str) – The name of the value to be used in error messages.
low (float, optional) – The lower bound of the value. Defaults to 0.0.
high (float, optional) – The upper bound of the value. Defaults to None.
clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.
clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.
exclude_low (bool, optional) – If True, f must be strictly greater than low. Defaults to False.
exclude_high (bool, optional) – If True, f must be strictly less than high. Defaults to False.
- Raises:
ValueError – If f is not a float or out of bounds.
- Returns:
The verified float value f.
- Return type:
float
- torchsig.utils.verify.verify_int(a: int, name: str, low: int = 0, high: int | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) int[source]¶
Verifies that the value a is an integer and within the specified bounds.
- Parameters:
a (int) – The value to be checked.
name (str) – The name of the value to be used in error messages.
low (int, optional) – The lower bound of the value. Defaults to 0.
high (int, optional) – The upper bound of the value. Defaults to None.
clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.
clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.
exclude_low (bool, optional) – If True, a must be strictly greater than low. Defaults to False.
exclude_high (bool, optional) – If True, a must be strictly less than high. Defaults to False.
- Raises:
ValueError – If a is not an integer or out of bounds.
- Returns:
The verified integer value a.
- Return type:
int
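verify_float and verify_int share the same bound parameters. The sketch below shows one plausible interpretation of low, high, clip_* and exclude_*; check_bounds is hypothetical, it skips the type check the real functions perform, and applying clipping before the bound check is an assumption.

```python
def check_bounds(value, name, low=0, high=None, clip_low=False, clip_high=False,
                 exclude_low=False, exclude_high=False):
    """Hypothetical sketch of the bound handling in verify_float / verify_int."""
    # Clipping happens first, so a clipped value can still fail an exclusive bound.
    if clip_low and low is not None and value < low:
        value = low
    if clip_high and high is not None and value > high:
        value = high
    # exclude_* turns an inclusive bound into a strict one.
    if low is not None and (value <= low if exclude_low else value < low):
        raise ValueError(f"{name}={value} is below the lower bound {low}")
    if high is not None and (value >= high if exclude_high else value > high):
        raise ValueError(f"{name}={value} is above the upper bound {high}")
    return value


print(check_bounds(-3.0, "snr_db", low=0.0, clip_low=True))  # 0.0
```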
- torchsig.utils.verify.verify_list(l: list, name: str, no_duplicates: bool = False, data_type=None) list[source]¶
Verifies that the value l is a list and optionally checks for duplicates or verifies item types.
- Parameters:
l (list) – The value to be checked.
name (str) – The name of the value to be used in error messages.
no_duplicates (bool, optional) – If True, raises an error if the list contains duplicates. Defaults to False.
data_type (type, optional) – The type each item in the list should have. Defaults to None.
- Raises:
ValueError – If l is not a list, if it contains duplicates (when no_duplicates=True), or if any item in the list is not of the required type.
- Returns:
The verified list l.
- Return type:
list
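The duplicate and item-type checks of verify_list might look like the hypothetical check_list below. The parameter is renamed items because l is easily confused with 1, and the set-based duplicate test assumes hashable items.

```python
def check_list(items, name, no_duplicates=False, data_type=None):
    """Hypothetical sketch of the checks verify_list documents."""
    if not isinstance(items, list):
        raise ValueError(f"{name} must be a list, got {type(items).__name__}")
    # A set drops repeats, so a size mismatch means duplicates (items must be hashable).
    if no_duplicates and len(set(items)) != len(items):
        raise ValueError(f"{name} contains duplicate items")
    if data_type is not None:
        for i, item in enumerate(items):
            if not isinstance(item, data_type):
                raise ValueError(
                    f"{name}[{i}] must be {data_type.__name__}, "
                    f"got {type(item).__name__}"
                )
    return items


class_list = check_list(["bpsk", "qpsk"], "class_list", no_duplicates=True, data_type=str)
```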
- torchsig.utils.verify.verify_metadata_transforms(tt: MetadataTransform) list[MetadataTransform | callable][source]¶
Verifies that the value tt is a valid target transform, which can be a single target transform or a list of transforms.
- Parameters:
tt (MetadataTransform) – The target transform(s) to be checked.
- Raises:
ValueError – If tt is not a valid target transform.
- Returns:
The verified list of target transforms.
- Return type:
List[MetadataTransform | callable]
- torchsig.utils.verify.verify_numpy_array(n: ndarray, name: str, min_length: int | None = None, max_length: int | None = None, exact_length: int | None = None, data_type=None) ndarray[source]¶
Verifies that the value n is a NumPy array and optionally checks its length or item types.
- Parameters:
n (np.ndarray) – The value to be checked.
name (str) – The name of the value to be used in error messages.
min_length (int, optional) – The minimum length of the array. Defaults to None.
max_length (int, optional) – The maximum length of the array. Defaults to None.
exact_length (int, optional) – The exact length of the array. Defaults to None.
data_type (type, optional) – The type each item in the array should have. Defaults to None.
- Raises:
ValueError – If n is not a NumPy array or its length is not within the specified bounds, or if any item in the array is not of the required type.
- Returns:
The verified NumPy array n.
- Return type:
np.ndarray
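The length and item-type checks of verify_numpy_array can be sketched with plain NumPy as follows; check_array is a hypothetical name, and the order in which the length checks run is an assumption.

```python
import numpy as np


def check_array(n, name, min_length=None, max_length=None, exact_length=None,
                data_type=None):
    """Hypothetical sketch of the checks verify_numpy_array documents."""
    if not isinstance(n, np.ndarray):
        raise ValueError(f"{name} must be a NumPy array, got {type(n).__name__}")
    length = len(n)
    if exact_length is not None and length != exact_length:
        raise ValueError(f"{name} must have length {exact_length}, got {length}")
    if min_length is not None and length < min_length:
        raise ValueError(f"{name} must have length >= {min_length}, got {length}")
    if max_length is not None and length > max_length:
        raise ValueError(f"{name} must have length <= {max_length}, got {length}")
    # Iterating a 1-D array yields NumPy scalars, so pass NumPy scalar types
    # such as np.complexfloating or np.floating as data_type.
    if data_type is not None and not all(isinstance(x, data_type) for x in n):
        raise ValueError(f"{name} contains items that are not {data_type.__name__}")
    return n


iq = np.zeros(1024, dtype=np.complex64)
check_array(iq, "iq_samples", min_length=256, data_type=np.complexfloating)
```

For large arrays, checking n.dtype once with np.issubdtype is cheaper than the per-item loop, at the cost of not matching the per-item wording above.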
- torchsig.utils.verify.verify_str(s: str, name: str, valid: list[str] = [], str_format: str = 'lower') str[source]¶
Verifies that the value s is a string and optionally formats it according to the specified format.
- Parameters:
s (str) – The value to be checked.
name (str) – The name of the value to be used in error messages.
valid (List[str], optional) – A list of valid string values. Defaults to an empty list.
str_format (str, optional) – The format for the string. Can be “lower”, “upper”, or “title”. Defaults to “lower”.
- Raises:
ValueError – If s is not a string or if it is not in the list of valid values.
- Returns:
The verified string value s in the specified format.
- Return type:
str
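A sketch of verify_str's formatting and membership check follows. check_str is hypothetical, and applying the format before comparing against valid is an assumption; with the default "lower" format it makes the check case-insensitive.

```python
def check_str(s, name, valid=(), str_format="lower"):
    """Hypothetical sketch of the checks verify_str documents."""
    if not isinstance(s, str):
        raise ValueError(f"{name} must be a string, got {type(s).__name__}")
    formatters = {"lower": str.lower, "upper": str.upper, "title": str.title}
    if str_format not in formatters:
        raise ValueError(f"Unknown str_format {str_format!r}")
    # Assumption: formatting is applied before the membership check.
    s = formatters[str_format](s)
    if valid and s not in valid:
        raise ValueError(f"{name}={s!r} is not one of {list(valid)}")
    return s


print(check_str("QPSK", "modulation", valid=["bpsk", "qpsk"]))  # qpsk
```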
- torchsig.utils.verify.verify_transforms(t: Transform) list[Transform | callable][source]¶
Verifies that the value t is a valid transform, which can be a single transform or a list of transforms.
- Parameters:
t (Transform) – The transform(s) to be checked.
- Raises:
ValueError – If t is not a valid transform.
- Returns:
The verified list of transforms.
- Return type:
List[Transform | callable]
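verify_transforms (and, analogously, verify_metadata_transforms) accepts either a single transform or a list and always returns a list. The hypothetical check_transforms below shows that normalization; treating "valid" as "callable" is an assumption, since the real function may also check for the Transform type.

```python
def check_transforms(t, name="transforms"):
    """Hypothetical sketch: accept one transform or a list, always return a list."""
    # Wrap a single transform so callers can always iterate over the result.
    transforms = t if isinstance(t, list) else [t]
    for i, transform in enumerate(transforms):
        # Assumption: any callable counts as a valid transform here.
        if not callable(transform):
            raise ValueError(f"{name}[{i}] is not callable: {transform!r}")
    return transforms


pipeline = check_transforms([abs, round])
```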
Printing Utils¶
Contains helpful methods for implementing the __str__ and __repr__ methods of classes.
- torchsig.utils.printing.generate_repr_str(class_object: Any, exclude_params: list[str] = []) str[source]¶
Generates a string representation of the class object, excluding specified parameters.
This function creates a human-readable string representation of the given class object, including its class name and parameters. It excludes any parameters specified in the exclude_params list. If the class object is an instance of Seedable, its seeding-related attributes are handled specially.
- Parameters:
class_object (Any) – The class object to generate the string representation for.
exclude_params (List[str], optional) – A list of parameter names to exclude from the string representation. Defaults to an empty list.
- Returns:
A formatted string representation of the class object with parameters.
- Return type:
str
- Raises:
AttributeError – If the class object does not have a __dict__ attribute or lacks another attribute required for the operation.
Example
>>> from torchsig.utils.printing import generate_repr_str
>>> class Example:
...     def __init__(self, param1, param2):
...         self.param1 = param1
...         self.param2 = param2
>>> e = Example(1, 2)
>>> generate_repr_str(e)
'Example(param1=1,param2=2)'
Notes
If the class object is an instance of Seedable, the seed and parent attributes will be added back into the string representation.
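The output format shown in the example (class name, then comma-separated name=value pairs without spaces) can be reproduced from an instance's __dict__. The sketch below is a hypothetical simple_repr, not the torchsig implementation: it omits the Seedable special-casing, and using repr() for the values is an assumption.

```python
def simple_repr(obj, exclude_params=()):
    """Hypothetical sketch of the output format shown for generate_repr_str."""
    # vars() reads the instance __dict__; it raises TypeError if there is none.
    params = {
        key: value
        for key, value in vars(obj).items()
        if key not in exclude_params
    }
    # Comma-separated without spaces, matching the documented example output.
    body = ",".join(f"{key}={value!r}" for key, value in params.items())
    return f"{type(obj).__name__}({body})"


class Example:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


print(simple_repr(Example(1, 2)))  # Example(param1=1,param2=2)
```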
- torchsig.utils.printing.dataset_metadata_str(dataset_metadata, max_width: int = 100, first_col_width: int = 29, array_width_indent_offset: int = 2) str[source]¶
Custom string representation for a dataset metadata object.
This function returns a formatted string that summarizes the object's key attributes, including signal parameters, dataset configuration, and transform details. It uses textwrap.fill to wrap long attributes, such as lists or arrays, so they remain readable.
The result is a concise yet comprehensive, human-readable overview of the object's state, useful for debugging, logging, or displaying object details.
- Parameters:
dataset_metadata (Any) – The dataset metadata object to generate a string for.
max_width (int, optional) – Maximum width of the output string. Defaults to 100.
first_col_width (int, optional) – Width of the first column in the output string. Defaults to 29.
array_width_indent_offset (int, optional) – Indentation offset for array-like attributes. Defaults to 2.
- Returns:
A formatted string that represents the object’s attributes in a readable format.
- Return type:
str
- Example Output:
MyClass
----------------------------------------------------------------------------------------------------
num_iq_samples_dataset       1000
fft_size                     512
sample_rate                  1000.0
num_signals_min              1
num_signals_max              5
num_signals_distribution     [0.2, 0.3, 0.5]
snr_db_min                   5.0
snr_db_max                   30.0
signal_duration_min          0.001
signal_duration_max          0.01
signal_bandwidth_min         10
signal_bandwidth_max         100
signal_center_freq_min       -10
signal_center_freq_max       10
class_list                   [Class1, Class2, Class3]
class_distribution           [0.3, 0.4, 0.3]
seed                         42
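The layout of that output (a title, a rule max_width characters long, then labels padded to first_col_width with wrapped values) can be approximated with textwrap. two_column_str is a hypothetical helper that takes a plain dict of attributes; it does not model array_width_indent_offset, and the exact wrapping rules of the real function may differ.

```python
import textwrap


def two_column_str(title, attrs, max_width=100, first_col_width=29):
    """Hypothetical sketch of a two-column, width-limited attribute summary."""
    lines = [title, "-" * max_width]
    indent = " " * first_col_width
    for key, value in attrs.items():
        # Wrap only the value column; continuation lines are indented so they
        # stay aligned under the first line of the value.
        wrapped = textwrap.wrap(str(value), width=max_width - first_col_width) or [""]
        lines.append(f"{key:<{first_col_width}}{wrapped[0]}")
        lines.extend(indent + part for part in wrapped[1:])
    return "\n".join(lines)


print(two_column_str("MyClass", {"fft_size": 512, "class_list": list(range(40))}))
```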
- torchsig.utils.printing.dataset_metadata_repr(dataset_metadata) str[source]¶
Returns a string representation of the object for debugging and inspection.
This function generates a concise yet detailed summary of the object's state, useful for debugging or for working with the object in an interactive environment (e.g., a REPL or Jupyter notebook).
Like a __repr__ method, it aims to give an unambiguous, readable string. The returned string includes key attributes and their values, formatted to resemble code that could recreate the object; because it is intended for debugging, an object recreated from it is not guaranteed to be identical.
- Returns:
A detailed, formatted string that represents the object's state, showing key attributes and their current values.
- Return type:
str
Randomization Utils¶
Utility to handle random number generators.
- class torchsig.utils.random.Seedable(seed: int | None = None, parent: Seedable | None = None, **kwargs)[source]¶
Bases: object
A class/interface representing objects capable of accessing random numbers and being seeded.
Stores an internal random number generator object. Can be seeded with the Seedable.seed(seed) method. Two Seedable objects with the same seed will always generate/access the same random values in the same order. Seedable objects that contain or compose other Seedable objects are generally responsible for seeding them.
- add_parent(parent: Seedable) None[source]¶
Add parent Seedable object and set up RNGs accordingly.
- Parameters:
parent – Parent Seedable object to add.
- seed(seed: int) None[source]¶
Seeds the random number generators with the given seed.
- Parameters:
seed – Seed to use.
- get_second_seed(seed: int) int[source]¶
Gets second seed, usually used to seed both torch and numpy generators with slightly different seeds.
- Parameters:
seed – Seed to use.
- Returns:
New seed.
- get_distribution(params: list | tuple | float, scaling: str = 'linear') Distribution[source]¶
Create distribution function with proper seeding.
- Parameters:
params – Parameters for distribution.
scaling – Scaling param for distribution. Defaults to ‘linear’.
- Returns:
Distribution function, seeded.
- Return type:
Distribution
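The seeding contract described for Seedable (same seed, same sequence; parents responsible for their children) can be illustrated with NumPy's Generator. MiniSeedable is a hypothetical class, and having children draw from the root ancestor's generator is an assumed design, not necessarily how Seedable.add_parent sets up its RNGs.

```python
import numpy as np


class MiniSeedable:
    """Hypothetical minimal analogue of Seedable, built on NumPy's Generator."""

    def __init__(self, seed=None, parent=None):
        self.parent = parent
        self._rng = np.random.default_rng(seed)

    @property
    def rng(self):
        # Assumed design: children draw from their root ancestor's generator,
        # so reseeding the root reseeds everything composed under it.
        if self.parent is not None:
            return self.parent.rng
        return self._rng

    def seed(self, seed: int) -> None:
        self._rng = np.random.default_rng(seed)

    def add_parent(self, parent: "MiniSeedable") -> None:
        self.parent = parent


a = MiniSeedable(seed=42)
b = MiniSeedable(seed=42)
print(np.array_equal(a.rng.random(3), b.rng.random(3)))  # True
```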
- torchsig.utils.random.make_distribution(params: list | tuple | float, scaling: str = 'linear') Distribution[source]¶
Creates a distribution from the given params.
- Parameters:
params – Params for distribution.
scaling – Scaling param for distribution. Defaults to ‘linear’.
- Raises:
NotImplementedError – If params is of an unimplemented type.
ValueError – If the distribution is undefined.
- Returns:
Distribution function from params.
- Return type:
Distribution
- class torchsig.utils.random.Distribution(params: Any, **kwargs)[source]¶
Bases: Seedable
A class for representing random distributions.
Created by calling get_distribution(params) on a Seedable object. Distributions are callable, such that some_seedable.get_distribution(params)() should return a random number from the distribution.
- get_value() Any[source]¶
Samples from distribution function, returns a value.
- Raises:
NotImplementedError – Subclasses must implement this method.
- Returns:
Value(s) from distribution.
- class torchsig.utils.random.ChoiceDistribution(params: list | ndarray | int, **kwargs)[source]¶
Bases: Distribution
A class for handling random choices from lists.
- class torchsig.utils.random.UniformRangeDistribution(params: tuple[float, float], **kwargs)[source]¶
Bases: Distribution
A class for handling random uniform ranges.
- class torchsig.utils.random.Log10UniformRangeDistribution(params: tuple[float, float], **kwargs)[source]¶
Bases: Distribution
A class for handling log10-weighted random uniform ranges.
- get_value() Any[source]¶
Samples a random value from the log10-weighted uniform distribution.
- Returns:
Random value from the log10-weighted uniform distribution.
- Raises:
ValueError – If params contain 0 or negative numbers.
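The three distribution classes can be sketched on top of NumPy's Generator. All class names below are hypothetical stand-ins; the log10-weighted variant samples uniformly between log10(low) and log10(high) and maps back with 10**x, which makes every decade equally likely. Whether torchsig samples exactly this way is an assumption.

```python
import numpy as np


class SketchDistribution:
    """Hypothetical base class: callable, delegates to get_value()."""

    def __init__(self, params, rng: np.random.Generator):
        self.params = params
        self.rng = rng

    def get_value(self):
        raise NotImplementedError("Subclasses must implement get_value")

    def __call__(self):
        return self.get_value()


class SketchChoice(SketchDistribution):
    def get_value(self):
        # Generator.choice picks one element; an int n means range(n).
        return self.rng.choice(self.params)


class SketchUniformRange(SketchDistribution):
    def get_value(self):
        low, high = self.params
        return self.rng.uniform(low, high)


class SketchLog10UniformRange(SketchDistribution):
    def get_value(self):
        low, high = self.params
        if low <= 0 or high <= 0:
            raise ValueError("log10-weighted ranges need positive bounds")
        # Uniform in log10 space, then mapped back: each decade is equally likely.
        return 10 ** self.rng.uniform(np.log10(low), np.log10(high))


rng = np.random.default_rng(0)
bandwidth = SketchLog10UniformRange((1e3, 1e6), rng)
print(bandwidth())
```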