Utilities

Extra utilities, such as reading from and writing to disk and type checking, are included in the torchsig/utils folder.

The following utilities are available:

Digital Signal Processing Utils

Digital Signal Processing (DSP) Utils

torchsig.utils.dsp.slice_tail_to_length(input_signal: ndarray, num_samples: int) ndarray[source]

Slices the tail of a signal

Parameters:
  • input_signal (np.ndarray) – Input signal

  • num_samples (int) – Maximum number of samples for a signal

Raises:

ValueError – If signal is too short to be sliced

Returns:

Signal with length sliced to num_samples

Return type:

np.ndarray

torchsig.utils.dsp.slice_head_tail_to_length(input_signal: ndarray, num_samples: int) ndarray[source]

Slices the head and tail of a signal

Parameters:
  • input_signal (np.ndarray) – Input signal

  • num_samples (int) – Maximum number of samples for a signal

Raises:

ValueError – If signal is too short to be sliced

Returns:

Signal with length sliced to num_samples

Return type:

np.ndarray

torchsig.utils.dsp.pad_head_tail_to_length(input_signal: ndarray, num_samples: int) ndarray[source]

Zero pads the head and tail of a signal

Parameters:
  • input_signal (np.ndarray) – Input signal

  • num_samples (int) – Desired length of signal

Raises:

ValueError – If signal is too long to be padded

Returns:

Signal with length padded to num_samples

Return type:

np.ndarray
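The three length helpers above can be pictured with a minimal sketch. It is not the TorchSig implementation: plain Python lists stand in for np.ndarray to keep it dependency-free, and the way an odd number of excess or missing samples is split between head and tail is an assumption.

```python
def slice_tail(signal, num_samples):
    """Keep the first num_samples samples and drop the tail."""
    if len(signal) < num_samples:
        raise ValueError("signal is too short to be sliced")
    return signal[:num_samples]


def slice_head_tail(signal, num_samples):
    """Drop samples from both ends, keeping a centered window."""
    excess = len(signal) - num_samples
    if excess < 0:
        raise ValueError("signal is too short to be sliced")
    head = excess // 2  # assumption: an odd leftover sample is removed from the tail
    return signal[head:head + num_samples]


def pad_head_tail(signal, num_samples):
    """Zero pad both ends until the signal has num_samples samples."""
    missing = num_samples - len(signal)
    if missing < 0:
        raise ValueError("signal is too long to be padded")
    head = missing // 2  # assumption: an odd leftover zero is appended to the tail
    return [0] * head + list(signal) + [0] * (missing - head)


samples = [1, 2, 3, 4, 5, 6]
tail_sliced = slice_tail(samples, 4)
center_sliced = slice_head_tail(samples, 4)
padded = pad_head_tail([1, 2, 3], 6)
```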

torchsig.utils.dsp.upconversion_anti_aliasing_filter(input_signal: ndarray, center_freq: float, bandwidth: float, sample_rate: float, frequency_max: float, frequency_min: float)[source]

Applies a BPF to avoid aliasing

Upconversion of a signal away from baseband can force energy to alias around the -fs/2 or +fs/2 boundary, depending on the amount of sidelobes and the amount the signal has been frequency shifted by. The function checks the combination of the center frequency and bandwidth to see if any of the energy exceeds the limits specified by frequency_max and frequency_min, and if so, builds and applies the filter.

Parameters:
  • input_signal (np.ndarray) – Input signal modulated to band-pass

  • center_freq (float) – Center frequency of the signal

  • bandwidth (float) – Bandwidth of the signal

  • sample_rate (float) – Sample rate of the signal

  • frequency_max (float) – The maximum frequency where energy content can reside

  • frequency_min (float) – The minimum frequency where energy content can reside

Returns:

Anti-aliased signal (np.ndarray), updated center frequency for the bounding box (float), and updated bandwidth for the bounding box (float)

Return type:

Tuple[np.ndarray, float, float]
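The band-edge check described above can be sketched as follows. This only illustrates the decision of whether filtering is needed; the helper name is hypothetical, and the filter design and the way the library updates the center frequency and bandwidth are not shown.

```python
def exceeds_band_limits(center_freq, bandwidth, frequency_max, frequency_min):
    """Return True if any part of the signal's band lies outside the allowed limits."""
    upper_edge = center_freq + bandwidth / 2
    lower_edge = center_freq - bandwidth / 2
    return upper_edge > frequency_max or lower_edge < frequency_min


# With sample_rate = 1.0, the usable band is assumed to be [-0.5, 0.5].
needs_filter = exceeds_band_limits(0.45, 0.2, frequency_max=0.5, frequency_min=-0.5)
already_clean = exceeds_band_limits(0.0, 0.2, frequency_max=0.5, frequency_min=-0.5)
```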

torchsig.utils.dsp.is_even(number)[source]

Is the number even?

Returns true if the number is even, false if the number is odd

Parameters:

number – Any number

Returns:

Returns true if number is even, false if number is odd

Return type:

bool

torchsig.utils.dsp.is_odd(number)[source]

Is the number odd?

Returns true if the number is odd, false if the number is even

Parameters:

number – Any number

Returns:

Returns true if number is odd, false if number is even

Return type:

bool

torchsig.utils.dsp.is_multiple_of_4(number)[source]

Is the number a multiple of 4?

Returns true if the number is a multiple of 4, false otherwise. A number is a multiple of 4 if the number is even and the number divided by 2 is also even.

Parameters:

number – Any number

Returns:

Returns true if number is a multiple of 4, false otherwise

Return type:

bool
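A minimal sketch of the three parity checks, following the rule stated above (a multiple of 4 is an even number whose half is also even). It assumes integer inputs; how the library handles non-integer numbers is not covered here.

```python
def is_even(number):
    """True if the integer is divisible by 2."""
    return number % 2 == 0


def is_odd(number):
    """True if the integer is not divisible by 2."""
    return not is_even(number)


def is_multiple_of_4(number):
    """True if the number is even and half of it is also even."""
    return is_even(number) and is_even(number // 2)
```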

torchsig.utils.dsp.interpolate_power_of_2_resampler(input_signal: ndarray, interpolation_rate: int) ndarray[source]

Applies power of 2 resampling

Parameters:
  • input_signal (np.ndarray) – Input signal to be resampled

  • interpolation_rate (int) – Interpolation rate, must be greater than 0. For interpolation, interpolation_rate >= 2.

Raises:
  • ValueError – Throws error if the interpolation rate is not an integer

  • ValueError – Throws error if the interpolation rate is not >= 2.

  • ValueError – Throws error if the interpolation rate is not a power of 2.

Returns:

Interpolated signal

Return type:

np.ndarray
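The three error conditions listed above can be checked as in the sketch below. The function name and the returned stage count are illustrative assumptions, not part of the TorchSig API; the stage count simply reflects that each cascaded stage doubles the rate.

```python
def validate_power_of_2_rate(interpolation_rate):
    """Validate the rate and return how many x2 stages it corresponds to."""
    if not isinstance(interpolation_rate, int):
        raise ValueError("interpolation rate must be an integer")
    if interpolation_rate < 2:
        raise ValueError("interpolation rate must be >= 2")
    # A power of 2 has exactly one bit set, so n & (n - 1) is zero.
    if interpolation_rate & (interpolation_rate - 1) != 0:
        raise ValueError("interpolation rate must be a power of 2")
    return interpolation_rate.bit_length() - 1


stages_for_8 = validate_power_of_2_rate(8)
```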

torchsig.utils.dsp.design_half_band_filter(stage_number: int = 0, passband_percentage: float = 0.8, attenuation_db: float = 120) ndarray[source]

Designs half band filter weights for dyadic resampling

Implements the filter design for dyadic (power of 2) resampling, see fred harris, Multirate Signal Processing for Communication Systems, 2nd Edition, Chapter 8.7.

The dyadic filter uses a multi-stage structure, a series of stages, to efficiently implement large resampling rates. For interpolation, each additional stage increases the resampling rate by a factor of 2, so the signal bandwidth becomes a smaller proportion of the sample rate. The passband edge therefore decreases with each successive stage, which allows the transition bandwidth to increase and reduces the amount of computation needed.

Parameters:
  • stage_number (int, optional) – Stage number in the cascade, must be greater than or equal to zero. Defaults to 0.

  • passband_percentage (float, optional) – The proportion of the available bandwidth used for the passband edge. The default of 0.8 translates into the passband edge being 80% of maximum passband of fs/4, or 0.8*fs/4.

  • attenuation_db (float, optional) – The sidelobe attenuation level, must be greater than zero. Defaults to 120.

Raises:

ValueError – If the designed filter does not have the expected length

Returns:

Half band filter weights

Return type:

np.ndarray

torchsig.utils.dsp.multistage_polyphase_resampler(input_signal: ndarray, resample_rate: float) ndarray[source]

Multi-stage polyphase filterbank-based resampling.

If the resampling rate is 1.0, then nothing is done and the same input signal is returned. If the resampling rate is greater than 1, then it performs interpolation using multistage_polyphase_interpolator. If the resampling rate is less than 1, then it performs decimation using multistage_polyphase_decimator.

Parameters:
  • input_signal (np.ndarray) – The input signal to be resampled

  • resample_rate (float) – The resampling rate. Must be greater than 0.

Returns:

The resampled signal

Return type:

np.ndarray
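The dispatch described above can be sketched as a small planning function. It only decides which path applies and with what rate; the name plan_resampling is hypothetical, and passing 1 / resample_rate to the decimator is an assumption based on the decimator's documented requirement that its rate be at least 1.

```python
def plan_resampling(resample_rate):
    """Return (operation, rate) for a given resampling rate."""
    if resample_rate <= 0:
        raise ValueError("resample_rate must be greater than 0")
    if resample_rate == 1.0:
        return ("passthrough", 1.0)
    if resample_rate > 1.0:
        return ("interpolate", resample_rate)
    # A rate below 1 is expressed as a decimation by its reciprocal.
    return ("decimate", 1.0 / resample_rate)
```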

torchsig.utils.dsp.multistage_polyphase_decimator(input_signal: ndarray, decimation_rate: float) ndarray[source]

Multi-stage polyphase filterbank-based decimation

The decimation is applied with two possible stages. The first stage implements the integer rate portion and the second stage implements the fractional rate portion.

For example, a resampling rate of 0.4 is a decimation by 2.5. The decimation of 2.5 is represented by an integer decimation of 2, and the fractional rate is therefore 2.5/2 = 1.25. Therefore a decimation by 2 is applied followed by a decimation of 1.25.

Parameters:
  • input_signal (np.ndarray) – The input signal

  • decimation_rate (float) – The decimation rate. Must be greater than or equal to 1.

Returns:

The decimated signal

Return type:

np.ndarray
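The split into integer and fractional portions can be reproduced with a short sketch. It assumes the integer portion is the floor of the rate, which matches the 2.5 example above; the library may choose the integer portion differently.

```python
import math


def split_rate(rate):
    """Split a rate >= 1 into (integer_rate, fractional_rate) with rate = integer * fractional."""
    if rate < 1:
        raise ValueError("rate must be greater than or equal to 1")
    integer_rate = math.floor(rate)
    fractional_rate = rate / integer_rate
    return integer_rate, fractional_rate


integer_rate, fractional_rate = split_rate(2.5)
```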

torchsig.utils.dsp.multistage_polyphase_interpolator(input_signal: ndarray, resample_rate_ideal: float) ndarray[source]

Multi-stage polyphase filterbank-based interpolation

The interpolation is applied with two possible stages. The first stage implements the fractional rate portion and the second stage implements the integer rate portion.

For example, a resampling rate of 2.5 is an interpolation of 2.5. The interpolation of 2.5 is represented by an integer interpolation of 2, and the fractional rate is therefore 2.5/2 = 1.25. Therefore an interpolation of 1.25 is applied followed by an interpolation of 2.

Parameters:
  • input_signal (np.ndarray) – The input signal

  • resample_rate_ideal (float) – The interpolation rate. Must be greater than or equal to 1.

Returns:

The interpolated signal

Return type:

np.ndarray

torchsig.utils.dsp.polyphase_fractional_resampler(input_signal: ndarray, fractional_rate: float) ndarray[source]

Fractional rate polyphase resampler

Implements a fractional rate resampler through the SciPy upfirdn() function with a large number of branches. A fixed “up” rate of 10,000 is used and the fractional rate then determines the “down” rate, such that up/down reasonably approximates the desired fractional resampling rate.

Parameters:
  • input_signal (np.ndarray) – Input signal to be resampled

  • fractional_rate (float) – The fractional interpolation rate, must be greater than 0.

Returns:

Resampled signal

Return type:

np.ndarray
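To illustrate how a fixed up rate of 10,000 leads to a down rate, the sketch below rounds up / fractional_rate to the nearest integer. The rounding rule and the helper name are assumptions; the library may derive the down rate differently before calling upfirdn().

```python
UP_RATE = 10_000  # fixed "up" rate from the description above


def up_down_rates(fractional_rate):
    """Return (up, down, achieved_rate) approximating fractional_rate as up / down."""
    if fractional_rate <= 0:
        raise ValueError("fractional_rate must be greater than 0")
    down_rate = round(UP_RATE / fractional_rate)  # assumption: nearest-integer rounding
    return UP_RATE, down_rate, UP_RATE / down_rate


up, down, achieved = up_down_rates(1.25)
```

Rates that do not divide 10,000 evenly are only approximated; for example, a requested rate of 1.3 yields a down rate of 7692 and an achieved rate close to, but not exactly, 1.3.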

torchsig.utils.dsp.prototype_polyphase_filter_interpolation(num_branches: int, attenuation_db=120) ndarray[source]

Designs polyphase filterbank weights for interpolation

Parameters:
  • num_branches (int) – Number of branches in the polyphase filterbank.

  • attenuation_db (int, optional) – Sidelobe attenuation level in dB. Defaults to 120.

Returns:

Filter weights

Return type:

np.ndarray

torchsig.utils.dsp.prototype_polyphase_filter_decimation(num_branches: int, attenuation_db=120) ndarray[source]

Designs polyphase filterbank weights for decimation

Parameters:
  • num_branches (int) – Number of branches in the polyphase filterbank.

  • attenuation_db (int, optional) – Sidelobe attenuation level in dB. Defaults to 120.

Returns:

Filter weights

Return type:

np.ndarray

torchsig.utils.dsp.prototype_polyphase_filter(num_branches: int, attenuation_db=120) ndarray[source]

Designs the prototype filter for a polyphase filter bank

Parameters:
  • num_branches (int) – Number of branches in the polyphase filterbank

  • attenuation_db (int, optional) – Sidelobe attenuation level. Defaults to 120.

Returns:

Filter weights

Return type:

np.ndarray

torchsig.utils.dsp.polyphase_integer_interpolator(input_signal: ndarray, interpolation_rate: int) ndarray[source]

Integer-rate polyphase filterbank-based interpolation

Parameters:
  • input_signal (np.ndarray) – Input signal to be interpolated

  • interpolation_rate (int) – The interpolation rate

Raises:

ValueError – Throws an error if the right number of samples are not produced

Returns:

Interpolated output signal

Return type:

np.ndarray

torchsig.utils.dsp.polyphase_decimator(input_signal: ndarray, decimation_rate: int) ndarray[source]

Integer-rate polyphase filterbank-based decimation

Parameters:
  • input_signal (np.ndarray) – Input signal to be decimated

  • decimation_rate (int) – The decimation rate

Raises:

ValueError – Throws an error if the right number of samples are not produced

Returns:

Decimated output signal

Return type:

np.ndarray

torchsig.utils.dsp.upsample(signal: ndarray, rate: int) ndarray[source]

Upsamples a signal

Upsamples a signal by insertion of zeros. Ex: upsample by 2 produces: sample, 0, sample, 0, sample, 0, etc., and upsample by 3 produces: sample, 0, 0, sample, 0, 0, etc.

Parameters:
  • signal (np.ndarray) – The input signal

  • rate (int) – The upsampling rate, must be > 1

Raises:
  • ValueError – Throws an error when the rate is less than or equal to 1

  • ValueError – Throws an error when the rate is not an integer

Returns:

The upsampled signal

Return type:

np.ndarray
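Zero-insertion upsampling as described above can be sketched in a few lines. This is an illustration using plain Python lists rather than np.ndarray, not the library code.

```python
def upsample(signal, rate):
    """Insert rate - 1 zeros after every input sample."""
    if not isinstance(rate, int):
        raise ValueError("rate must be an integer")
    if rate <= 1:
        raise ValueError("rate must be greater than 1")
    output = [0] * (len(signal) * rate)
    output[::rate] = signal  # every rate-th slot holds an original sample
    return output


by_two = upsample([1, 2, 3], 2)
by_three = upsample([1, 2, 3], 3)
```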

torchsig.utils.dsp.center_freq_from_lower_upper_freq(lower_freq: float, upper_freq: float) float[source]

Calculates center frequency from lower frequency and upper frequency

Parameters:
  • lower_freq (float) – The lower frequency corresponding to the 3 dB bandwidth of the signal

  • upper_freq (float) – The upper frequency corresponding to the 3 dB bandwidth of the signal

Returns:

The center frequency

Return type:

float

torchsig.utils.dsp.bandwidth_from_lower_upper_freq(lower_freq: float, upper_freq: float) float[source]

Calculates bandwidth from lower frequency and upper frequency

Parameters:
  • lower_freq (float) – The lower frequency corresponding to the 3 dB bandwidth of the signal

  • upper_freq (float) – The upper frequency corresponding to the 3 dB bandwidth of the signal

Returns:

The bandwidth

Return type:

float

torchsig.utils.dsp.lower_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) float[source]

Calculates the lower frequency from center frequency and bandwidth

Parameters:
  • center_freq (float) – The center frequency of the signal

  • bandwidth (float) – The bandwidth of the signal

Returns:

The lower frequency

Return type:

float

torchsig.utils.dsp.upper_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) float[source]

Calculates upper frequency from center frequency and bandwidth

Parameters:
  • center_freq (float) – The center frequency of the signal

  • bandwidth (float) – The bandwidth of the signal

Returns:

The upper frequency

Return type:

float
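The four conversions above are inverses of each other. The sketch below assumes the usual definitions (center is the midpoint of the band edges, bandwidth is their difference), which the documentation does not state explicitly.

```python
def center_freq_from_lower_upper_freq(lower_freq, upper_freq):
    """Midpoint of the two band edges."""
    return (lower_freq + upper_freq) / 2


def bandwidth_from_lower_upper_freq(lower_freq, upper_freq):
    """Distance between the two band edges."""
    return upper_freq - lower_freq


def lower_freq_from_center_freq_bandwidth(center_freq, bandwidth):
    """Lower band edge, half a bandwidth below the center."""
    return center_freq - bandwidth / 2


def upper_freq_from_center_freq_bandwidth(center_freq, bandwidth):
    """Upper band edge, half a bandwidth above the center."""
    return center_freq + bandwidth / 2


center = center_freq_from_lower_upper_freq(100.0, 300.0)
bandwidth = bandwidth_from_lower_upper_freq(100.0, 300.0)
```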

torchsig.utils.dsp.frequency_shift(signal: ndarray, frequency: float, sample_rate: float) ndarray[source]

Performs a frequency shift

Parameters:
  • signal (np.ndarray) – Input signal

  • frequency (float) – The frequency to shift by. Must have the same units as sample_rate.

  • sample_rate (float) – The sample rate of the signal. Must have the same units as frequency.

Returns:

The frequency shifted signal

Return type:

np.ndarray
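A frequency shift is conventionally implemented by multiplying the signal by a complex exponential. The sketch below assumes that convention and uses only the standard library; the sign convention and implementation details of the library function are not confirmed by the documentation above.

```python
import cmath
import math


def frequency_shift(signal, frequency, sample_rate):
    """Multiply sample n by exp(j * 2 * pi * frequency * n / sample_rate)."""
    return [
        sample * cmath.exp(2j * math.pi * frequency * n / sample_rate)
        for n, sample in enumerate(signal)
    ]


# Shifting a constant (0 Hz) signal by sample_rate / 4 produces a quarter-rate tone.
shifted = frequency_shift([1, 1, 1, 1], frequency=0.25, sample_rate=1.0)
```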

torchsig.utils.dsp.compute_spectrogram(iq_samples: ndarray, fft_size: int, fft_stride: int) ndarray[source]

Computes two-dimensional spectrogram values in dB.

Parameters:
  • iq_samples (np.ndarray) – Input signal.

  • fft_size (int) – The size of the FFT in number of bins.

  • fft_stride (int) – The stride is the amount by which the input sample pointer increases for each FFT. When fft_stride=fft_size, then there is no overlap of input samples in successive FFTs. When fft_stride=fft_size/2, there is 50% overlap of input samples between successive FFTs.

Raises:

ValueError – Throws an error if fft_stride is less than 0 or greater than fft_size.

Returns:

Two-dimensional array of spectrogram values in dB.

Return type:

np.ndarray
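The relationship between fft_size, fft_stride, and overlap can be made concrete with a small sketch. It counts only complete FFT frames and rejects a stride of 0; how the library treats a trailing partial frame is an assumption not covered here.

```python
def spectrogram_frames(num_samples, fft_size, fft_stride):
    """Return (number of full FFT frames, fractional overlap between frames)."""
    if fft_stride <= 0 or fft_stride > fft_size:
        raise ValueError("fft_stride must be in the range (0, fft_size]")
    overlap = 1 - fft_stride / fft_size
    if num_samples < fft_size:
        return 0, overlap
    num_ffts = (num_samples - fft_size) // fft_stride + 1
    return num_ffts, overlap


no_overlap = spectrogram_frames(1024, fft_size=256, fft_stride=256)
half_overlap = spectrogram_frames(1024, fft_size=256, fft_stride=128)
```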

torchsig.utils.dsp.estimate_tone_bandwidth(num_samples: int, sample_rate: float)[source]

Estimate the bandwidth of a tone

The bandwidth of a tone is completely defined by the number of samples in the time-series.

Parameters:
  • num_samples (int) – The length of the tone in samples.

  • sample_rate (float) – The sample rate associated with the tone.

Returns:

Bandwidth estimate of the tone

Return type:

np.ndarray

torchsig.utils.dsp.convolve(signal: ndarray, taps: ndarray) ndarray[source]

Wrapper function to implement convolution()

A wrapped version of SciPy’s convolve(), which discards transition regions resulting from the convolution process.

Parameters:
  • signal (np.ndarray) – The input signal

  • taps (np.ndarray) – The filter weights

Returns:

The convolution output

Return type:

np.ndarray

torchsig.utils.dsp.low_pass(cutoff: float, transition_bandwidth: float, sample_rate: float, attenuation_db: float = 120) ndarray[source]

Low-pass filter design

Parameters:
  • cutoff (float) – The filter cutoff, 0 < cutoff < sample_rate/2. Must be in the same units as sample_rate.

  • transition_bandwidth (float) – The transition bandwidth of the filter, 0 < transition_bandwidth < sample_rate/2. Must be in the same units as sample_rate.

  • sample_rate (float) – The sampling rate associated with the filter design.

  • attenuation_db (float, optional) – Sidelobe attenuation level. Defaults to 120.

Returns:

Filter weights

Return type:

np.ndarray

torchsig.utils.dsp.estimate_filter_length(transition_bandwidth: float, attenuation_db: float, sample_rate: float) int[source]

Estimates FIR filter length

Estimate the length of an FIR filter using fred harris’ approximation, Multirate Signal Processing for Communication Systems, Second Edition, p.59.

Parameters:
  • transition_bandwidth (float) – The transition bandwidth of the filter, 0 < transition_bandwidth < sample_rate/2.

  • attenuation_db (float) – Sidelobe attenuation level in dB.

  • sample_rate (float) – The sampling rate associated with the filter design.

Returns:

The estimated filter length

Return type:

int
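The harris approximation referenced above is commonly written as N ≈ (sample_rate / transition_bandwidth) × (attenuation_db / 22). The sketch below applies that formula and rounds up; the rounding, and any adjustment the library makes afterwards (for example forcing an odd length), are assumptions.

```python
import math


def estimate_filter_length(transition_bandwidth, attenuation_db, sample_rate):
    """Estimate FIR length as (fs / transition bandwidth) * (attenuation in dB / 22)."""
    estimate = (sample_rate / transition_bandwidth) * (attenuation_db / 22)
    return math.ceil(estimate)  # assumption: round up to the next whole tap


length_narrow = estimate_filter_length(1.0, 120, 8.0)
length_wide = estimate_filter_length(0.5, 60, 1.0)
```

A narrower transition band or higher attenuation increases the estimated length, which is why successive half-band stages with wider transition bands need fewer taps.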

torchsig.utils.dsp.srrc_taps(iq_samples_per_symbol: int, filter_span_in_symbols: int, alpha: float = 0.35) ndarray[source]

Designs square-root raised cosine (SRRC) pulse shaping filter

Parameters:
  • iq_samples_per_symbol (int) – The samples-per-symbol (SPS) of the underlying modulation, equivalent to the oversampling rate.

  • filter_span_in_symbols (int) – The filter span in number of symbols.

  • alpha (float, optional) – The alpha roll-off value of the pulse shaping filter, which is the amount of excess bandwidth. Defaults to 0.35.

Returns:

SRRC filter weights

Return type:

np.ndarray

torchsig.utils.dsp.gaussian_taps(samples_per_symbol: int, bt: float = 0.35) ndarray[source]

Designs Gaussian filter weights

Parameters:
  • samples_per_symbol (int) – Samples-per-symbol (SPS) for the underlying modulation, equivalent to the oversampling rate.

  • bt (float, optional) – Time-bandwidth product. Defaults to 0.35.

Returns:

Gaussian filter weights

Return type:

np.ndarray

torchsig.utils.dsp.low_pass_iterative_design(cutoff: float, transition_bandwidth: float, sample_rate: float, desired_attenuation_db: float = 120) ndarray[source]

Dataset Utils

TorchSig Dataset generation code for command line

torchsig.utils.generate.generate(root: str, dataset_metadata: DatasetMetadata, batch_size: int, num_workers: int)[source]

Generates and saves a dataset to disk.

This function selects the dataset type (‘narrowband’ or ‘wideband’) based on the provided metadata and then calls the DatasetCreator class to generate the dataset and save it to disk. It writes the dataset in batches using the specified batch size and number of workers.

Parameters:
  • root (str) – The root directory where the dataset will be saved.

  • dataset_metadata (DatasetMetadata) – Metadata that defines the dataset type and properties.

  • batch_size (int) – The number of samples per batch to process.

  • num_workers (int) – The number of worker threads to use for loading the data in parallel.

Raises:

ValueError – If the dataset type is unknown or invalid.

Reading/Writing Utils

Writer

Dataset Writer Utils

class torchsig.utils.writer.DatasetCreator(dataset: ~torchsig.datasets.datasets.NewTorchSigDataset, root: str, overwrite: bool = False, batch_size: int = 1, num_workers: int = 1, collate_fn: ~typing.Callable = <function collate_fn>, tqdm_desc: str | None = None, file_handler: ~torchsig.utils.file_handlers.base_handler.TorchSigFileHandler = <class 'torchsig.utils.file_handlers.zarr.ZarrFileHandler'>, train: bool | None = None)[source]

Bases: object

Class for creating a dataset and saving it to disk in batches.

This class generates a dataset if it doesn’t already exist on disk. It processes the data in batches and saves it using a specified file handler. The class allows setting options like whether to overwrite existing datasets, batch size, and number of worker threads.

root

The root directory where the dataset will be saved.

Type:

Path

overwrite

Flag indicating whether to overwrite an existing dataset.

Type:

bool

batch_size

The number of samples in each batch.

Type:

int

num_workers

The number of worker threads to use for data loading.

Type:

int

save_type

The type of dataset being saved (“raw” or “processed”).

Type:

str

tqdm_desc

A description for the progress bar.

Type:

str

writer

The file handler used for saving the dataset.

Type:

TorchSigFileHandler

dataloader

The DataLoader used to load data in batches.

Type:

DataLoader

get_writing_info_dict() Dict[str, Any][source]

Returns a dictionary with information about the dataset being written.

This method gathers information regarding the root, overwrite status, batch size, number of workers, file handler class, and the save type of the dataset.

Returns:

Dictionary containing the dataset writing configuration.

Return type:

Dict[str, Any]

check_yamls() List[Tuple[str, Any, Any]][source]

Checks for differences between the dataset metadata on disk and the dataset metadata in memory.

Compares the dataset metadata that would be written to disk against the existing metadata on disk. Returns a list of differences.

Returns:

List of differences between metadata on disk and in memory.

Return type:

List[Tuple[str, Any, Any]]

create() None[source]

Creates the dataset on disk by writing batches to the file handler.

This method generates the dataset in batches and saves it to disk. If the dataset already exists and overwrite is set to False, it will skip regeneration.

The method also writes the dataset metadata and writing information to YAML files.

Raises:

ValueError – If the dataset is already generated and overwrite is set to False.

YAML Utils

YAML utilities

torchsig.utils.yaml.custom_representer(dumper, value)[source]

Custom representer for YAML to handle sequences (lists).

This function customizes how lists are represented in the YAML output, using flow style for sequences (inline lists).

Parameters:
  • dumper (yaml.Dumper) – The YAML dumper responsible for serializing the data.

  • value (list) – The list to be represented in YAML.

Returns:

The dumper with the custom representation for the list.

Return type:

yaml.Dumper

torchsig.utils.yaml.write_dict_to_yaml(filename: str, info_dict: dict) None[source]

Writes a dictionary to a YAML file with customized settings.

This function writes the provided info_dict to a YAML file. It customizes the representation of lists by using the custom_representer, and it uses specific formatting options (e.g., no sorting of keys, custom line width).

Parameters:
  • filename (str) – The name of the YAML file to which the dictionary will be written.

  • info_dict (dict) – The dictionary to be written to the YAML file.

Returns:

This function does not return any value.

Return type:

None

File Handlers

File Handlers for writing and reading datasets to/from disk

Only write one item from a TorchSigDataset’s __getitem__ method

class torchsig.utils.file_handlers.base_handler.BaseFileHandler(root: str)[source]

Bases: object

setup() None[source]
teardown() None[source]
exists() bool[source]
write(batch_idx: int, batch: Any) None[source]
load(idx: int) Any[source]
static static_load(filename: str, idx: int) Any[source]
class torchsig.utils.file_handlers.base_handler.TorchSigFileHandler(root: str, dataset_metadata: DatasetMetadata, batch_size: int, train: bool | None = None)[source]

Bases: BaseFileHandler

write(batch_idx: int, batch: Any) None[source]
static size(dataset_path: str) int[source]
static static_load(filename: str, idx: int) Tuple[ndarray, List[Dict[str, Any]]][source]
load(idx: int) Tuple[ndarray, List[Dict[str, Any]]][source]
class torchsig.utils.file_handlers.zarr.ZarrFileHandler(root: str, dataset_metadata: DatasetMetadata, batch_size: int, train: bool | None = None)[source]

Bases: TorchSigFileHandler

Handler for reading and writing data to/from a Zarr file format.

This class extends the TorchSigFileHandler and provides functionality to handle reading, writing, and managing Zarr-based storage for dataset samples.

datapath_filename

The name of the file used to store the data in Zarr format.

Type:

str

datapath_filename = 'data.zarr'
chunk_size = (100,)
exists() bool[source]

Checks if the Zarr file exists at the specified path.

Returns:

True if the Zarr file exists, otherwise False.

Return type:

bool

write(batch_idx: int, batch: Any) None[source]

Writes a sample (data and targets) to the Zarr file at the specified index.

Parameters:
  • idx (int) – The index at which to store the data in the Zarr file.

  • data (np.ndarray) – The data to write to the Zarr file.

  • targets (Any) – The corresponding targets to write as metadata for the sample.

Notes

If the index is greater than the current size of the array, the array is expanded to accommodate the new sample.

static size(dataset_path: str) int[source]
static static_load(filename: str, idx: int) Tuple[ndarray, List[Dict[str, Any]]][source]

Loads a sample from the Zarr file at the specified index.

Parameters:
  • filename (str) – Path to the directory containing the Zarr file.

  • idx (int) – The index of the sample to load.

Returns:

The data and the associated metadata for the sample.

Return type:

Tuple[np.ndarray, List[Dict[str, Any]]]

Raises:

IndexError – If the index is out of bounds.

load(idx: int) Tuple[ndarray, Dict[str, Any]] | Tuple[Any, ...][source]

Loads a sample from the Zarr file at the specified index into memory.

Parameters:

idx (int) – The index of the sample to load.

Returns:

The data and the corresponding targets for the sample.

Return type:

Tuple[np.ndarray, List[Dict[str, Any]]] | Tuple[Any, …]

Raises:

IndexError – If the index is out of bounds.

Variable and Data Verification Utils

Data verification and error checking utils

torchsig.utils.verify.verify_int(a: int, name: str, low: int = 0, high: int | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) int[source]

Verifies that the value a is an integer and within the specified bounds.

Parameters:
  • a (int) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • low (int, optional) – The lower bound of the value. Defaults to 0.

  • high (int, optional) – The upper bound of the value. Defaults to None.

  • clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.

  • clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.

  • exclude_low (bool, optional) – If True, a must be strictly greater than low. Defaults to False.

  • exclude_high (bool, optional) – If True, a must be strictly less than high. Defaults to False.

Raises:

ValueError – If a is not an integer or out of bounds.

Returns:

The verified integer value a.

Return type:

int

torchsig.utils.verify.verify_float(f: float, name: str, low: float = 0.0, high: float | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) float[source]

Verifies that the value f is a float and within the specified bounds.

Parameters:
  • f (float) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • low (float, optional) – The lower bound of the value. Defaults to 0.0.

  • high (float, optional) – The upper bound of the value. Defaults to None.

  • clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.

  • clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.

  • exclude_low (bool, optional) – If True, f must be strictly greater than low. Defaults to False.

  • exclude_high (bool, optional) – If True, f must be strictly less than high. Defaults to False.

Raises:

ValueError – If f is not a float or out of bounds.

Returns:

The verified float value f.

Return type:

float
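The bound, clip, and exclude options shared by verify_int and verify_float can be sketched with one generic helper. The name verify_number is hypothetical, and the order of operations (clip first, then check the bounds) is an assumption rather than documented behavior.

```python
def verify_number(value, name, low=0.0, high=None, clip_low=False,
                  clip_high=False, exclude_low=False, exclude_high=False):
    """Clip (if requested), then check that value lies within [low, high]."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"{name} must be a number")
    if clip_low and value < low:
        value = low
    if high is not None and clip_high and value > high:
        value = high
    too_low = value <= low if exclude_low else value < low
    if too_low:
        raise ValueError(f"{name}={value} is below the lower bound {low}")
    if high is not None:
        too_high = value >= high if exclude_high else value > high
        if too_high:
            raise ValueError(f"{name}={value} is above the upper bound {high}")
    return value


in_range = verify_number(5, "num_samples", low=0, high=10)
clipped_low = verify_number(-3, "snr", low=0, clip_low=True)
clipped_high = verify_number(12, "snr", low=0, high=10, clip_high=True)
```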

torchsig.utils.verify.verify_str(s: str, name: str, valid: List[str] = [], str_format: str = 'lower') str[source]

Verifies that the value s is a string and optionally formats it according to the specified format.

Parameters:
  • s (str) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • valid (List[str], optional) – A list of valid string values. Defaults to an empty list.

  • str_format (str, optional) – The format for the string. Can be “lower”, “upper”, or “title”. Defaults to “lower”.

Raises:

ValueError – If s is not a string or if it is not in the list of valid values.

Returns:

The verified string value s in the specified format.

Return type:

str
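A minimal sketch of the string check described above. It assumes the string is formatted before it is compared against the valid list; the library's order of operations is not documented here.

```python
def verify_str(s, name, valid=None, str_format="lower"):
    """Check that s is a string, format it, and check it against the valid values."""
    if not isinstance(s, str):
        raise ValueError(f"{name} must be a string")
    formatters = {"lower": str.lower, "upper": str.upper, "title": str.title}
    if str_format not in formatters:
        raise ValueError(f"unknown str_format: {str_format}")
    s = formatters[str_format](s)
    # An empty or missing valid list means any string is accepted.
    if valid and s not in valid:
        raise ValueError(f"{name}={s} is not one of {valid}")
    return s


modulation = verify_str("QPSK", "modulation", valid=["bpsk", "qpsk"])
title_case = verify_str("hello world", "label", str_format="title")
```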

torchsig.utils.verify.verify_distribution_list(distro: List[float], required_length: int, distro_name: str, list_name: str) List[float][source]

Verifies and normalizes a given distribution list.

If the distribution list is None, it assumes a uniform distribution and returns it as is. If the distribution list is not of the required length or does not sum to 1.0, it raises an error or normalizes the list to sum to 1.0.

Parameters:
  • distro (List[float]) – The distribution list to verify. Can be None for a uniform distribution.

  • required_length (int) – The expected length of the distribution list.

  • distro_name (str) – The name of the distribution list (used for error messages).

  • list_name (str) – The name of the list this distribution corresponds to (used for error messages).

Returns:

The verified and possibly normalized distribution list.

Return type:

List[float]

Raises:

ValueError – If the distribution list is not of the required length or does not sum to 1.0 and cannot be normalized.
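The normalization behavior can be illustrated as follows. This sketch always normalizes a list whose sum is positive but not 1.0 and rejects non-positive sums; exactly when the library raises instead of normalizing is an assumption.

```python
import math


def verify_distribution_list(distro, required_length, distro_name, list_name):
    """Return None for a uniform distribution, else a list that sums to 1.0."""
    if distro is None:
        return None  # None is treated as a uniform distribution
    if len(distro) != required_length:
        raise ValueError(
            f"{distro_name} has {len(distro)} entries but {list_name} has {required_length}"
        )
    total = sum(distro)
    if total <= 0:
        raise ValueError(f"{distro_name} cannot be normalized")
    if math.isclose(total, 1.0):
        return list(distro)
    return [weight / total for weight in distro]


normalized = verify_distribution_list([1, 1, 2], 3, "class_distribution", "class_list")
```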

torchsig.utils.verify.verify_list(l: list, name: str, no_duplicates: bool = False, data_type=None) list[source]

Verifies that the value l is a list and optionally checks for duplicates or verifies item types.

Parameters:
  • l (list) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • no_duplicates (bool, optional) – If True, raises an error if the list contains duplicates. Defaults to False.

  • data_type (type, optional) – The type each item in the list should have. Defaults to None.

Raises:

ValueError – If l is not a list, if it contains duplicates (when no_duplicates=True), or if any item in the list is not of the required type.

Returns:

The verified list l.

Return type:

list

torchsig.utils.verify.verify_numpy_array(n: ndarray, name: str, min_length: int | None = None, max_length: int | None = None, exact_length: int | None = None, data_type=None) ndarray[source]

Verifies that the value n is a NumPy array and optionally checks its length or item types.

Parameters:
  • n (np.ndarray) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • min_length (int, optional) – The minimum length of the array. Defaults to None.

  • max_length (int, optional) – The maximum length of the array. Defaults to None.

  • exact_length (int, optional) – The exact length of the array. Defaults to None.

  • data_type (type, optional) – The type each item in the array should have. Defaults to None.

Raises:

ValueError – If n is not a NumPy array or its length is not within the specified bounds, or if any item in the array is not of the required type.

Returns:

The verified NumPy array n.

Return type:

np.ndarray

torchsig.utils.verify.verify_transforms(t: Transform) List[Transform | Callable][source]

Verifies that the value t is a valid transform, which can be a single transform or a list of transforms.

Parameters:

t (Transform) – The transform(s) to be checked.

Raises:

ValueError – If t is not a valid transform.

Returns:

The verified list of transforms.

Return type:

List[Transform | Callable]

torchsig.utils.verify.verify_target_transforms(tt: TargetTransform) List[TargetTransform | Callable][source]

Verifies that the value tt is a valid target transform, which can be a single target transform or a list of transforms.

Parameters:

tt (TargetTransform) – The target transform(s) to be checked.

Raises:

ValueError – If tt is not a valid target transform.

Returns:

The verified list of target transforms.

Return type:

List[TargetTransform | Callable]

torchsig.utils.verify.verify_dict(d: dict, name: str, required_keys: list = [], required_types: list = [])[source]

Verifies that the value d is a dictionary and optionally checks for required keys and their types.

Parameters:
  • d (dict) – The value to be checked.

  • name (str) – The name of the value to be used in error messages.

  • required_keys (list, optional) – A list of required keys in the dictionary. Defaults to an empty list.

  • required_types (list, optional) – A list of types for each required key. Defaults to an empty list.

Raises:

ValueError – If d is not a dictionary, or if any required key is missing or has an incorrect type.

Returns:

The verified dictionary d.

Return type:

dict
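
A minimal sketch of the checks described above, written without torchsig so it runs on its own. verify_dict_sketch is a hypothetical name, and pairing required_types with required_keys by position is an assumption based on the parameter description, not confirmed behavior of the library.

```python
from typing import Any, List, Optional


def verify_dict_sketch(
    d: Any,
    name: str,
    required_keys: Optional[List[str]] = None,
    required_types: Optional[List[type]] = None,
) -> dict:
    """Hypothetical stand-in for the dictionary checks described above."""
    required_keys = required_keys or []
    required_types = required_types or []
    if not isinstance(d, dict):
        raise ValueError(f"{name} is not a dict: {type(d).__name__}")
    for i, key in enumerate(required_keys):
        if key not in d:
            raise ValueError(f"{name} is missing required key {key!r}")
        # Assumption: required_types[i] is the expected type of required_keys[i].
        if i < len(required_types) and not isinstance(d[key], required_types[i]):
            raise ValueError(
                f"{name}[{key!r}] is not of type {required_types[i].__name__}"
            )
    return d


config = verify_dict_sketch(
    {"sample_rate": 1e6, "label": "qpsk"},
    name="config",
    required_keys=["sample_rate", "label"],
    required_types=[float, str],
)
```

The sketch uses None defaults rather than empty lists to avoid sharing one mutable default across calls; the documented signature uses empty-list defaults.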

Printing Utils

Contains helpful methods for properly implementing the __str__ and __repr__ methods of classes.

torchsig.utils.printing.generate_repr_str(class_object: Any, exclude_params: List[str] = []) str[source]

Generates a string representation of the class object, excluding specified parameters.

This function creates a human-readable string representation of the given class object, including its class name and parameters. It excludes any parameters specified in the exclude_params list. If the class object is an instance of Seedable, certain attributes related to seeding are handled specifically.

Parameters:
  • class_object (Any) – The class object to generate the string representation for.

  • exclude_params (List[str], optional) – A list of parameter names to exclude from the string representation. Defaults to an empty list.

Returns:

A formatted string representation of the class object with parameters.

Return type:

str

Raises:

AttributeError – If the class object does not have a __dict__ attribute or any other required attributes for the operation.

Example

>>> from torchsig.utils.printing import generate_repr_str
>>> class Example:
...     def __init__(self, param1, param2):
...         self.param1 = param1
...         self.param2 = param2
...
>>> e = Example(1, 2)
>>> generate_repr_str(e)
'Example(param1=1,param2=2)'

Notes

  • If the class object is an instance of Seedable, the seed and parent attributes will be added back into the string representation.
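
The core idea, building the string from the object's __dict__ while skipping excluded names, can be sketched as follows. generate_repr_str_sketch is a hypothetical stand-in, not the library code: it omits the Seedable-specific handling mentioned in the note above, and the comma-without-space separator is taken from the documented example output.

```python
from typing import Any, List, Optional


def generate_repr_str_sketch(
    class_object: Any, exclude_params: Optional[List[str]] = None
) -> str:
    """Hypothetical stand-in: ClassName(key=value,...) from __dict__."""
    excluded = set(exclude_params or [])
    # Accessing __dict__ raises AttributeError for objects without one.
    params = [
        f"{key}={value}"
        for key, value in class_object.__dict__.items()
        if key not in excluded
    ]
    return f"{type(class_object).__name__}({','.join(params)})"


class Example:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


full = generate_repr_str_sketch(Example(1, 2))
partial = generate_repr_str_sketch(Example(1, 2), exclude_params=["param2"])
```

A class could then define __repr__ as a one-line call to such a helper, passing the names of any attributes that should stay out of the output.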

torchsig.utils.printing.dataset_metadata_str(dataset_metadata, max_width: int = 100, first_col_width: int = 29, array_width_indent_offset: int = 2) str[source]

Custom string representation for the class.

This method returns a formatted string that provides a detailed summary of the object’s key attributes, including signal parameters, dataset configuration, and transform details. It uses textwrap.fill to format long attributes such as lists or arrays into a neatly wrapped format for easier readability.

The string includes information on the dataset’s configuration, signal characteristics, transformations, and other attributes in a human-readable way. The result is intended to provide a concise yet comprehensive overview of the object’s state, useful for debugging, logging, or displaying object details.

Parameters:
  • dataset_metadata (Any) – The dataset metadata object to generate a string for.

  • max_width (int, optional) – Maximum width of the output string. Defaults to 100.

  • first_col_width (int, optional) – Width of the first column in the output string. Defaults to 29.

  • array_width_indent_offset (int, optional) – Indentation offset for array-like attributes. Defaults to 2.

Returns:

A formatted string that represents the object’s attributes in a readable format.

Return type:

str

Example Output:

MyClass
----------------------------------------------------------------------------------------------------
num_iq_samples_dataset            1000
num_samples                       5000
impairment_level                  0.8
fft_size                          512
sample_rate                       1000.0
num_signals_min                   1
num_signals_max                   5
num_signals_distribution          [0.2, 0.3, 0.5]
snr_db_min                        5.0
snr_db_max                        30.0
signal_duration_percent_min       0.1
signal_duration_percent_max       0.9
transforms                        [TransformA, TransformB]
target_transforms                 [TargetTransform1]
class_list                        [Class1, Class2, Class3]
class_distribution                [0.3, 0.4, 0.3]
seed                              42
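
One way to produce a two-column row of this kind with textwrap.fill is sketched below. format_row is a hypothetical helper, not part of torchsig, and the way array_width_indent_offset is applied to continuation lines is an assumption made for illustration.

```python
import textwrap


def format_row(
    name: str,
    value: object,
    max_width: int = 100,
    first_col_width: int = 29,
    array_width_indent_offset: int = 2,
) -> str:
    """Hypothetical helper: one 'name value' row, wrapped to max_width."""
    # Left-justify the attribute name so every value starts in the same column.
    prefix = f"{name:<{first_col_width}}"
    return textwrap.fill(
        str(value),
        width=max_width,
        initial_indent=prefix,
        # Assumption: wrapped lines are indented slightly past the value column.
        subsequent_indent=" " * (first_col_width + array_width_indent_offset),
    )


short_row = format_row("sample_rate", 1000.0)
long_row = format_row("class_list", list(range(60)), max_width=60)
print(short_row)
print(long_row)
```

Short values stay on one line, while long lists or arrays wrap onto indented continuation lines that never exceed max_width.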

torchsig.utils.printing.dataset_metadata_repr(dataset_metadata) str[source]

Return a string representation of the object for debugging and inspection.

This method generates a string that provides a concise yet detailed summary of the object’s state, useful for debugging or interacting with the object in an interactive environment (e.g., REPL, Jupyter notebooks).

The __repr__ method is intended to give an unambiguous, readable string that represents the object. The returned string includes key attributes and their values, formatted in a way that can be interpreted back as code, i.e., it aims to provide a string that could be used to recreate the object (though not necessarily identical, as it is for debugging purposes).

Returns:

A detailed, formatted string that represents the object’s state, showing key attributes and their current values.

Return type:

str

Randomization Utils

Utility to handle random number generators.

class torchsig.utils.random.Seedable(seed: int | None = None, parent=None)[source]

Bases: object

A class/interface representing objects capable of accessing random numbers and being seeded. Stores an internal random number generator object. Can be seeded with the Seedable.seed(seed: int) function. Two Seedable objects with the same seed will always generate/access the same random values in the same order. A Seedable object that contains or is composed of other Seedable objects is generally responsible for seeding them.

add_parent(parent) None[source]

Add parent Seedable object and set up RNGs accordingly

update_from_parent() None[source]

Update the NumPy and torch random number generators using the parent's seed.

seed(seed: int) None[source]

Seed number generators with given seed.

Parameters:

seed (int) – Seed to use.

get_second_seed(seed: int) int[source]

Gets a second seed derived from the given seed, typically so that the torch and NumPy generators can be seeded with slightly different values.

Parameters:

seed (int) – Seed to use.

Returns:

New seed.

Return type:

int

setup_rngs() None[source]

Initialize the torch and NumPy random number generators, and update this object's children.
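
The seeding contract described for Seedable (equal seeds give equal sequences, and a containing object seeds the objects it contains) can be illustrated with a standard-library sketch. SeedableSketch is hypothetical: it uses random.Random in place of the NumPy and torch generators, and deriving a child's seed by drawing from the parent's generator is an assumption, not the documented mechanism.

```python
import random
from typing import List, Optional


class SeedableSketch:
    """Hypothetical illustration of parent/child seeding, stdlib only."""

    def __init__(self, seed: Optional[int] = None, parent=None):
        self.children: List["SeedableSketch"] = []
        self.parent = None
        self.rng = random.Random(seed)
        if parent is not None:
            self.add_parent(parent)

    def add_parent(self, parent: "SeedableSketch") -> None:
        # Register with the parent, then take a seed from it.
        self.parent = parent
        parent.children.append(self)
        self.update_from_parent()

    def update_from_parent(self) -> None:
        # Assumption: the child's seed is drawn from the parent's generator.
        self.seed(self.parent.rng.getrandbits(32))

    def seed(self, seed: int) -> None:
        # Reseed this object, then propagate to everything it contains.
        self.rng = random.Random(seed)
        for child in self.children:
            child.update_from_parent()


a, b = SeedableSketch(seed=42), SeedableSketch(seed=42)
same = [a.rng.random() for _ in range(3)] == [b.rng.random() for _ in range(3)]
```

Because children are reseeded whenever the parent is, seeding only the top-level object is enough to make an entire tree of such objects reproducible.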