Utilities¶
Extra utilities, such as reading from and writing to disk and type checking, are included in the torchsig/utils folder.
The following utilities are available:
Digital Signal Processing Utils¶
Digital Signal Processing (DSP) Utils
- torchsig.utils.dsp.slice_tail_to_length(input_signal: ndarray, num_samples: int) ndarray [source]¶
Slices the tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Maximum number of samples for a signal
- Raises:
ValueError – If signal is too short to be sliced
- Returns:
Signal with length sliced to num_samples
- Return type:
np.ndarray
- torchsig.utils.dsp.slice_head_tail_to_length(input_signal: ndarray, num_samples: int) ndarray [source]¶
Slices the head and tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Maximum number of samples for a signal
- Raises:
ValueError – If signal is too short to be sliced
- Returns:
Signal with length sliced to num_samples
- Return type:
np.ndarray
- torchsig.utils.dsp.pad_head_tail_to_length(input_signal: ndarray, num_samples: int) ndarray [source]¶
Zero pads the head and tail of a signal
- Parameters:
input_signal (np.ndarray) – Input signal
num_samples (int) – Desired length of signal
- Raises:
ValueError – If signal is too long to be padded
- Returns:
Signal with length padded to num_samples
- Return type:
np.ndarray
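The three length helpers above can be illustrated with a minimal pure-Python sketch. It is not the torchsig implementation: the function names are hypothetical, plain lists stand in for np.ndarray, and splitting the excess as evenly as possible between head and tail (with any odd sample going to the tail) is an assumption.

```python
def slice_tail(signal, num_samples):
    """Keep the first num_samples samples and discard the tail."""
    if len(signal) < num_samples:
        raise ValueError("signal is too short to be sliced")
    return signal[:num_samples]


def slice_head_tail(signal, num_samples):
    """Discard samples from both ends (assumed split: as even as possible)."""
    if len(signal) < num_samples:
        raise ValueError("signal is too short to be sliced")
    head = (len(signal) - num_samples) // 2
    return signal[head:head + num_samples]


def pad_head_tail(signal, num_samples):
    """Zero-pad both ends (assumed split: as even as possible)."""
    if len(signal) > num_samples:
        raise ValueError("signal is too long to be padded")
    excess = num_samples - len(signal)
    head = excess // 2
    return [0] * head + list(signal) + [0] * (excess - head)


x = [1, 2, 3, 4, 5, 6]
tail_sliced = slice_tail(x, 4)       # drops the last two samples
both_sliced = slice_head_tail(x, 4)  # drops one sample from each end
padded = pad_head_tail(x, 9)         # one zero in front, two behind
```

In every case the output length equals num_samples, and a ValueError is raised when the input is on the wrong side of that length.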
- torchsig.utils.dsp.upconversion_anti_aliasing_filter(input_signal: ndarray, center_freq: float, bandwidth: float, sample_rate: float, frequency_max: float, frequency_min: float)[source]¶
Applies a BPF to avoid aliasing
Upconversion of a signal away from baseband can force energy to alias around the -fs/2 or +fs/2 boundary, depending on the sidelobe levels and how far the signal has been frequency shifted. The function checks the combination of the center frequency and bandwidth to see if any of the energy exceeds the limits specified by frequency_max and frequency_min, and if so, builds and applies the filter.
- Parameters:
input_signal (np.ndarray) – Input signal modulated to band-pass
center_freq (float) – Center frequency of the signal
bandwidth (float) – Bandwidth of the signal
sample_rate (float) – Sample rate of the signal
frequency_max (float) – The maximum frequency where energy content can reside
frequency_min (float) – The minimum frequency where energy content can reside
- Returns:
Anti-aliased signal (np.ndarray), updated center frequency for the bounding box (float), and updated bandwidth for the bounding box (float)
- Return type:
Tuple[np.ndarray, float, float]
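The band-edge check described above can be sketched as follows. This only illustrates the check and one plausible way to update the bounding box; the function name is hypothetical, the filter design itself is omitted, and clipping the band edges to the limits is an assumption about how the updated center frequency and bandwidth are derived.

```python
def clip_band_to_limits(center_freq, bandwidth, frequency_min, frequency_max):
    """Check whether a band exceeds the limits and clip it if it does."""
    lower = center_freq - bandwidth / 2
    upper = center_freq + bandwidth / 2
    # A filter is only needed when some energy falls outside the limits.
    needs_filter = lower < frequency_min or upper > frequency_max
    lower = max(lower, frequency_min)
    upper = min(upper, frequency_max)
    return needs_filter, (lower + upper) / 2, upper - lower


# Normalized frequencies (sample_rate = 1.0): the upper edge at 0.625
# exceeds frequency_max = 0.5, so the band has to be trimmed.
needs_filter, new_center, new_bandwidth = clip_band_to_limits(
    center_freq=0.375, bandwidth=0.5, frequency_min=-0.5, frequency_max=0.5
)
```

Here the lower edge (0.125) is untouched, the upper edge is clipped to 0.5, and the bounding box shrinks accordingly.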
- torchsig.utils.dsp.is_even(number)[source]¶
Is the number even?
Returns true if the number is even, false if the number is odd
- Parameters:
number – Any number
- Returns:
Returns true if number is even, false if number is odd
- Return type:
bool
- torchsig.utils.dsp.is_odd(number)[source]¶
Is the number odd?
Returns true if the number is odd, false if the number is even
- Parameters:
number – Any number
- Returns:
Returns true if number is odd, false if number is even
- Return type:
bool
- torchsig.utils.dsp.is_multiple_of_4(number)[source]¶
Is the number a multiple of 4?
Returns true if the number is a multiple of 4, false otherwise. A number is a multiple of 4 if the number is even and the number divided by 2 is also even.
- Parameters:
number – Any number
- Returns:
Returns true if number is a multiple of 4, false otherwise
- Return type:
bool
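The three parity checks above follow directly from their descriptions. A minimal sketch, assuming integer inputs and using hypothetical `_sketch` names to keep it distinct from the library functions:

```python
def is_even_sketch(number):
    """True when the number is divisible by 2."""
    return number % 2 == 0


def is_odd_sketch(number):
    """True when the number is not even."""
    return not is_even_sketch(number)


def is_multiple_of_4_sketch(number):
    """True when the number is even and half of it is also even."""
    return is_even_sketch(number) and is_even_sketch(number // 2)


checks = [(n, is_even_sketch(n), is_odd_sketch(n), is_multiple_of_4_sketch(n))
          for n in (6, 7, 8)]
```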
- torchsig.utils.dsp.interpolate_power_of_2_resampler(input_signal: ndarray, interpolation_rate: int) ndarray [source]¶
Applies power of 2 resampling
- Parameters:
input_signal (np.ndarray) – Input signal to be resampled
interpolation_rate (int) – Interpolation rate. Must be an integer power of 2 that is greater than or equal to 2 (interpolation_rate >= 2).
- Raises:
ValueError – Throws error if the interpolation rate is not an integer
ValueError – Throws error if the interpolation rate is not >= 2.
ValueError – Throws error if the interpolation rate is not a power of 2.
- Returns:
Interpolated signal
- Return type:
np.ndarray
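The three ValueError conditions listed above can be sketched as a standalone validator. The function name is hypothetical, and returning the number of factor-of-2 stages is an assumption added for illustration; the library function returns the interpolated signal instead.

```python
def check_power_of_2_rate(interpolation_rate):
    """Validate the rate and return how many x2 stages it implies."""
    if not isinstance(interpolation_rate, int):
        raise ValueError("interpolation rate must be an integer")
    if interpolation_rate < 2:
        raise ValueError("interpolation rate must be >= 2")
    # A power of 2 has exactly one bit set, so n & (n - 1) is zero.
    if interpolation_rate & (interpolation_rate - 1) != 0:
        raise ValueError("interpolation rate must be a power of 2")
    return interpolation_rate.bit_length() - 1


num_stages = check_power_of_2_rate(8)  # 8 = 2 * 2 * 2
```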
- torchsig.utils.dsp.design_half_band_filter(stage_number: int = 0, passband_percentage: float = 0.8, attenuation_db: float = 120) ndarray [source]¶
Designs half band filter weights for dyadic resampling
Implements the filter design for dyadic (power of 2) resampling, see fred harris, Multirate Signal Processing for Communication Systems, 2nd Edition, Chapter 8.7.
The dyadic filter uses a series of stages, a multi-stage structure, to efficiently implement large resampling rates. For interpolation, each additional stage increases the resampling rate by a factor of 2, and therefore the signal bandwidth becomes a smaller relative proportion of the sample rate. The passband edge is therefore decreased for each successive stage, which allows the transition bandwidth to be increased for each successive stage and reduces the amount of computation needed.
- Parameters:
stage_number (int, optional) – Stage number in the cascade, must be greater than or equal to zero. Defaults to 0.
passband_percentage (float, optional) – The proportion of the available bandwidth used for the passband edge. The default of 0.8 translates into the passband edge being 80% of maximum passband of fs/4, or 0.8*fs/4.
attenuation_db (float, optional) – The sidelobe attenuation level, must be greater than zero. Defaults to 120.
- Raises:
ValueError – Checks to ensure that the filter length has the appropriate length
- Returns:
Half band filter weights
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_resampler(input_signal: ndarray, resample_rate: float) ndarray [source]¶
Multi-stage polyphase filterbank-based resampling.
If the resampling rate is 1.0, then nothing is done and the same input signal is returned. If the resampling rate is greater than 1, then it performs interpolation using multistage_polyphase_interpolator. If the resampling rate is less than 1, then it performs decimation using multistage_polyphase_decimator.
- Parameters:
input_signal (np.ndarray) – The input signal to be resampled
resample_rate (float) – The resampling rate. Must be greater than 0.
- Returns:
The resampled signal
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_decimator(input_signal: ndarray, decimation_rate: float) ndarray [source]¶
Multi-stage polyphase filterbank-based decimation
The decimation is applied with two possible stages. The first stage implements the integer rate portion and the second stage implements the fractional rate portion.
For example, a resampling rate of 0.4 is a decimation by 2.5. The decimation of 2.5 is represented by an integer decimation of 2, and the fractional rate is therefore 2.5/2 = 1.25. Therefore a decimation by 2 is applied followed by a decimation of 1.25.
- Parameters:
input_signal (np.ndarray) – The input signal
decimation_rate (float) – The decimation rate. Must be greater than or equal to 1.
- Returns:
The decimated signal
- Return type:
np.ndarray
- torchsig.utils.dsp.multistage_polyphase_interpolator(input_signal: ndarray, resample_rate_ideal: float) ndarray [source]¶
Multi-stage polyphase filterbank-based interpolation
The interpolation is applied with two possible stages. The first stage implements the fractional rate portion and the second stage implements the integer rate portion.
For example, a resampling rate of 2.5 is an interpolation of 2.5. The interpolation of 2.5 is represented by an integer interpolation of 2, and the fractional rate is therefore 2.5/2 = 1.25. Therefore an interpolation of 1.25 is applied followed by an interpolation of 2.
- Parameters:
input_signal (np.ndarray) – The input signal
resample_rate_ideal (float) – The interpolation rate. Must be greater than or equal to 1.
- Returns:
The interpolated signal
- Return type:
np.ndarray
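The dispatch and rate decomposition described for the three multistage functions above can be sketched as a planning step. No filtering is done here; the function names and the returned list of (operation, rate) stages are hypothetical, and dropping stages whose rate is exactly 1 is an assumption.

```python
import math


def split_rate(rate):
    """Split a rate >= 1 into an integer part and a fractional remainder."""
    integer_rate = math.floor(rate)
    return integer_rate, rate / integer_rate


def plan_resampling(resample_rate):
    """Return the ordered (operation, rate) stages for a resampling rate."""
    if resample_rate <= 0:
        raise ValueError("resample_rate must be greater than 0")
    if resample_rate == 1.0:
        return []  # nothing to do, the input is returned unchanged
    if resample_rate > 1:
        integer_rate, fractional_rate = split_rate(resample_rate)
        # Interpolation: fractional stage first, then the integer stage.
        stages = [("interpolate", fractional_rate), ("interpolate", integer_rate)]
    else:
        integer_rate, fractional_rate = split_rate(1 / resample_rate)
        # Decimation: integer stage first, then the fractional stage.
        stages = [("decimate", integer_rate), ("decimate", fractional_rate)]
    # Assumption: a stage with rate 1 does nothing and can be skipped.
    return [stage for stage in stages if stage[1] != 1]


interp_plan = plan_resampling(2.5)  # 1.25 first, then 2
decim_plan = plan_resampling(0.4)   # decimation by 2.5: 2 first, then 1.25
```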
- torchsig.utils.dsp.polyphase_fractional_resampler(input_signal: ndarray, fractional_rate: float) ndarray [source]¶
Fractional rate polyphase resampler
Implements a fractional rate resampler through the SciPy upfirdn() function with a large number of branches. A fixed “up” rate of 10,000 is used and the fractional rate then determines the “down” rate, such that up/down reasonably approximates the desired fractional resampling rate.
- Parameters:
input_signal (np.ndarray) – Input signal to be resampled
fractional_rate (float) – The fractional interpolation rate, must be greater than 0.
- Returns:
Resampled signal
- Return type:
np.ndarray
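The relationship between the fixed “up” rate and the derived “down” rate can be sketched as below. The fixed value of 10,000 comes from the description above; the function name is hypothetical, and rounding to the nearest integer is an assumption about how the “down” rate is chosen.

```python
UP_RATE = 10_000  # fixed "up" rate stated in the description


def up_down_rates(fractional_rate):
    """Pick an integer down rate so that up/down approximates the rate."""
    if fractional_rate <= 0:
        raise ValueError("fractional_rate must be greater than 0")
    down_rate = round(UP_RATE / fractional_rate)  # assumed rounding rule
    achieved_rate = UP_RATE / down_rate
    return UP_RATE, down_rate, achieved_rate


up, down, achieved = up_down_rates(1.25)
```

Because the “up” rate is large, the achieved rate up/down stays close to the requested rate even when the requested rate is not a ratio of small integers.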
- torchsig.utils.dsp.prototype_polyphase_filter_interpolation(num_branches: int, attenuation_db=120) ndarray [source]¶
Designs polyphase filterbank weights for interpolation
- torchsig.utils.dsp.prototype_polyphase_filter_decimation(num_branches: int, attenuation_db=120) ndarray [source]¶
Designs polyphase filterbank weights for decimation
- torchsig.utils.dsp.prototype_polyphase_filter(num_branches: int, attenuation_db=120) ndarray [source]¶
Designs the prototype filter for a polyphase filter bank
- torchsig.utils.dsp.polyphase_integer_interpolator(input_signal: ndarray, interpolation_rate: int) ndarray [source]¶
Integer-rate polyphase filterbank-based interpolation
- Parameters:
input_signal (np.ndarray) – Input signal to be interpolated
interpolation_rate (int) – The interpolation rate
- Raises:
ValueError – Throws an error if the right number of samples are not produced
- Returns:
Interpolated output signal
- Return type:
np.ndarray
- torchsig.utils.dsp.polyphase_decimator(input_signal: ndarray, decimation_rate: int) ndarray [source]¶
Integer-rate polyphase filterbank-based decimation
- Parameters:
input_signal (np.ndarray) – Input signal to be decimated
decimation_rate (int) – The decimation rate
- Raises:
ValueError – Throws an error if the right number of samples are not produced
- Returns:
Decimated output signal
- Return type:
np.ndarray
- torchsig.utils.dsp.upsample(signal: ndarray, rate: int) ndarray [source]¶
Upsamples a signal
Upsamples a signal by insertion of zeros. For example, upsampling by 2 produces sample, 0, sample, 0, sample, 0, etc., and upsampling by 3 produces sample, 0, 0, sample, 0, 0, etc.
- Parameters:
signal (np.ndarray) – The input signal
rate (int) – The upsampling rate, must be > 1
- Raises:
ValueError – Throws an error when the rate is less than or equal to 1
ValueError – Throws an error when the rate is not an integer
- Returns:
The upsampled signal
- Return type:
np.ndarray
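Zero insertion as described above can be sketched in pure Python. The function name is hypothetical and a plain list stands in for np.ndarray.

```python
def zero_stuff(signal, rate):
    """Insert rate - 1 zeros after every input sample."""
    if not isinstance(rate, int):
        raise ValueError("rate must be an integer")
    if rate <= 1:
        raise ValueError("rate must be > 1")
    output = [0] * (len(signal) * rate)
    output[::rate] = signal  # every rate-th slot holds an input sample
    return output


by_two = zero_stuff([1, 2, 3], 2)
by_three = zero_stuff([1, 2, 3], 3)
```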
- torchsig.utils.dsp.center_freq_from_lower_upper_freq(lower_freq: float, upper_freq: float) float [source]¶
Calculates center frequency from lower frequency and upper frequency
- torchsig.utils.dsp.bandwidth_from_lower_upper_freq(lower_freq: float, upper_freq: float) float [source]¶
Calculates bandwidth from lower frequency and upper frequency
- torchsig.utils.dsp.lower_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) float [source]¶
Calculates the lower frequency from center frequency and bandwidth
- torchsig.utils.dsp.upper_freq_from_center_freq_bandwidth(center_freq: float, bandwidth: float) float [source]¶
Calculates upper frequency from center frequency and bandwidth
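The four conversions above are inverses of one another. A short sketch of the underlying arithmetic, using hypothetical shortened names rather than the library functions:

```python
def center_from_edges(lower_freq, upper_freq):
    """Center frequency is the midpoint of the two band edges."""
    return (lower_freq + upper_freq) / 2


def bandwidth_from_edges(lower_freq, upper_freq):
    """Bandwidth is the distance between the two band edges."""
    return upper_freq - lower_freq


def lower_from_center(center_freq, bandwidth):
    """Lower edge sits half a bandwidth below the center."""
    return center_freq - bandwidth / 2


def upper_from_center(center_freq, bandwidth):
    """Upper edge sits half a bandwidth above the center."""
    return center_freq + bandwidth / 2


center = center_from_edges(-0.25, 0.75)
bandwidth = bandwidth_from_edges(-0.25, 0.75)
edges = (lower_from_center(center, bandwidth), upper_from_center(center, bandwidth))
```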
- torchsig.utils.dsp.frequency_shift(signal: ndarray, frequency: float, sample_rate: float) ndarray [source]¶
Performs a frequency shift
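A frequency shift is conventionally implemented by multiplying the signal with a complex exponential. The sketch below assumes that convention and assumes a positive frequency moves the spectrum upward; the function name is hypothetical and a list of complex numbers stands in for np.ndarray.

```python
import cmath


def shift_sketch(signal, frequency, sample_rate):
    """Multiply sample n by exp(j * 2 * pi * frequency * n / sample_rate)."""
    return [
        sample * cmath.exp(2j * cmath.pi * frequency * n / sample_rate)
        for n, sample in enumerate(signal)
    ]


# Shifting a DC signal by fs/4 yields a tone that advances a quarter
# turn per sample: approximately 1, j, -1, -j.
shifted = shift_sketch([1, 1, 1, 1], frequency=0.25, sample_rate=1.0)
```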
- torchsig.utils.dsp.compute_spectrogram(iq_samples: ndarray, fft_size: int, fft_stride: int) ndarray [source]¶
Computes two-dimensional spectrogram values in dB.
- Parameters:
iq_samples (np.ndarray) – Input signal.
fft_size (int) – The size of the FFT in number of bins.
fft_stride (int) – The stride is the amount by which the input sample pointer increases for each FFT. When fft_stride=fft_size, then there is no overlap of input samples in successive FFTs. When fft_stride=fft_size/2, there is 50% overlap of input samples between successive FFTs.
- Raises:
ValueError – Throws an error if fft_stride is less than 0 or greater than fft_size.
- Returns:
Two-dimensional array of spectrogram values in dB.
- Return type:
np.ndarray
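The effect of fft_stride on overlap can be sketched by listing where each FFT frame starts. The function names are hypothetical, and dropping a trailing partial frame is an assumption; the library may handle the end of the signal differently. A stride of 0 is rejected here as well, since the frame position would never advance.

```python
def frame_starts(num_samples, fft_size, fft_stride):
    """Start index of each FFT frame for a given size and stride."""
    if fft_stride <= 0 or fft_stride > fft_size:
        raise ValueError("fft_stride must satisfy 0 < fft_stride <= fft_size")
    return list(range(0, num_samples - fft_size + 1, fft_stride))


def overlap_fraction(fft_size, fft_stride):
    """Fraction of input samples shared by successive frames."""
    return 1 - fft_stride / fft_size


no_overlap = frame_starts(16, fft_size=8, fft_stride=8)
half_overlap = frame_starts(16, fft_size=8, fft_stride=4)
```

With fft_stride equal to fft_size the frames tile the input, and halving the stride adds an extra frame between each pair.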
- torchsig.utils.dsp.estimate_tone_bandwidth(num_samples: int, sample_rate: float)[source]¶
Estimate the bandwidth of a tone
The bandwidth of a tone is completely defined by the number of samples in the time-series.
- torchsig.utils.dsp.convolve(signal: ndarray, taps: ndarray) ndarray [source]¶
Wrapper function to implement convolution()
A wrapped version of SciPy’s convolve(), which discards transition regions resulting from the convolution process.
- Parameters:
signal (np.ndarray) – The input signal
taps (np.ndarray) – The filter weights
- Returns:
The convolution output
- Return type:
np.ndarray
- torchsig.utils.dsp.low_pass(cutoff: float, transition_bandwidth: float, sample_rate: float, attenuation_db: float = 120) ndarray [source]¶
Low-pass filter design
- Parameters:
cutoff (float) – The filter cutoff, 0 < cutoff < sample_rate/2. Must be in the same units as sample_rate.
transition_bandwidth (float) – The transition bandwidth of the filter, 0 < transition_bandwidth < sample_rate/2. Must be in the same units as sample_rate.
sample_rate (float) – The sampling rate associated with the filter design.
attenuation_db (float, optional) – Sidelobe attenuation level. Defaults to 120.
- Returns:
Filter weights
- Return type:
np.ndarray
- torchsig.utils.dsp.estimate_filter_length(transition_bandwidth: float, attenuation_db: float, sample_rate: float) int [source]¶
Estimates FIR filter length
Estimate the length of an FIR filter using fred harris’ approximation, Multirate Signal Processing for Communication Systems, Second Edition, p.59.
- Parameters:
- Returns:
The estimated filter length
- Return type:
int
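fred harris’ approximation estimates the filter length as (sample_rate / transition_bandwidth) * (attenuation_db / 22). A minimal sketch of that formula follows; the function name is hypothetical, and rounding up to the next integer is an assumption, so the library may round differently or adjust the length further.

```python
import math


def harris_filter_length(transition_bandwidth, attenuation_db, sample_rate):
    """Estimate FIR length as (fs / transition bandwidth) * (attenuation / 22)."""
    estimate = (sample_rate / transition_bandwidth) * (attenuation_db / 22)
    return math.ceil(estimate)  # assumed rounding rule


length = harris_filter_length(
    transition_bandwidth=10.0, attenuation_db=110.0, sample_rate=100.0
)
```

Halving the transition bandwidth doubles the estimated length, which is why the later stages of a multi-stage resampler, with their wider transition bands, are cheaper.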
- torchsig.utils.dsp.srrc_taps(iq_samples_per_symbol: int, filter_span_in_symbols: int, alpha: float = 0.35) ndarray [source]¶
Designs square-root raised cosine (SRRC) pulse shaping filter
- Parameters:
iq_samples_per_symbol (int) – The samples-per-symbol (SPS) of the underlying modulation, equivalent to the oversampling rate.
filter_span_in_symbols (int) – The filter span in number of symbols.
alpha (float, optional) – The alpha roll-off value of the pulse shaping filter, which is the amount of excess bandwidth. Defaults to 0.35.
- Returns:
SRRC filter weights
- Return type:
np.ndarray
Dataset Utils¶
TorchSig Dataset generation code for command line
- torchsig.utils.generate.generate(root: str, dataset_metadata: DatasetMetadata, batch_size: int, num_workers: int)[source]¶
Generates and saves a dataset to disk.
This function selects the dataset type (‘narrowband’ or ‘wideband’) based on the provided metadata and then calls the DatasetCreator class to generate the dataset and save it to disk. It writes the dataset in batches using the specified batch size and number of workers.
- Parameters:
root (str) – The root directory where the dataset will be saved.
dataset_metadata (DatasetMetadata) – Metadata that defines the dataset type and properties.
batch_size (int) – The number of samples per batch to process.
num_workers (int) – The number of worker threads to use for loading the data in parallel.
- Raises:
ValueError – If the dataset type is unknown or invalid.
Reading/Writing Utils¶
Writer¶
Dataset Writer Utils
- class torchsig.utils.writer.DatasetCreator(dataset: ~torchsig.datasets.datasets.NewTorchSigDataset, root: str, overwrite: bool = False, batch_size: int = 1, num_workers: int = 1, collate_fn: ~typing.Callable = <function collate_fn>, tqdm_desc: str | None = None, file_handler: ~torchsig.utils.file_handlers.base_handler.TorchSigFileHandler = <class 'torchsig.utils.file_handlers.zarr.ZarrFileHandler'>, train: bool | None = None)[source]¶
Bases:
object
Class for creating a dataset and saving it to disk in batches.
This class generates a dataset if it doesn’t already exist on disk. It processes the data in batches and saves it using a specified file handler. The class allows setting options like whether to overwrite existing datasets, batch size, and number of worker threads.
- root¶
The root directory where the dataset will be saved.
- Type:
Path
- writer¶
The file handler used for saving the dataset.
- Type:
TorchSigFileHandler
- dataloader¶
The DataLoader used to load data in batches.
- Type:
DataLoader
- get_writing_info_dict() Dict[str, Any] [source]¶
Returns a dictionary with information about the dataset being written.
This method gathers information regarding the root, overwrite status, batch size, number of workers, file handler class, and the save type of the dataset.
- Returns:
Dictionary containing the dataset writing configuration.
- Return type:
Dict[str, Any]
- check_yamls() List[Tuple[str, Any, Any]] [source]¶
Checks for differences between the dataset metadata on disk and the dataset metadata in memory.
Compares the dataset metadata that would be written to disk against the existing metadata on disk. Returns a list of differences.
- Returns:
List of differences between metadata on disk and in memory.
- Return type:
List[Tuple[str, Any, Any]]
- create() None [source]¶
Creates the dataset on disk by writing batches to the file handler.
This method generates the dataset in batches and saves it to disk. If the dataset already exists and overwrite is set to False, it will skip regeneration.
The method also writes the dataset metadata and writing information to YAML files.
- Raises:
ValueError – If the dataset is already generated and overwrite is set to False.
YAML Utils¶
YAML utilities
- torchsig.utils.yaml.custom_representer(dumper, value)[source]¶
Custom representer for YAML to handle sequences (lists).
This function customizes how lists are represented in the YAML output, using flow style for sequences (inline lists).
- Parameters:
dumper (yaml.Dumper) – The YAML dumper responsible for serializing the data.
value (list) – The list to be represented in YAML.
- Returns:
The dumper with the custom representation for the list.
- Return type:
yaml.Dumper
- torchsig.utils.yaml.write_dict_to_yaml(filename: str, info_dict: dict) None [source]¶
Writes a dictionary to a YAML file with customized settings.
This function writes the provided info_dict to a YAML file. It customizes the representation of lists by using the custom_representer, and it uses specific formatting options (e.g., no sorting of keys, custom line width).
File Handlers¶
File Handlers for writing and reading datasets to/from disk
Only write one item from a TorchSigDataset’s __getitem__ method
- class torchsig.utils.file_handlers.base_handler.TorchSigFileHandler(root: str, dataset_metadata: DatasetMetadata, batch_size: int, train: bool | None = None)[source]¶
Bases:
BaseFileHandler
- class torchsig.utils.file_handlers.zarr.ZarrFileHandler(root: str, dataset_metadata: DatasetMetadata, batch_size: int, train: bool | None = None)[source]¶
Bases:
TorchSigFileHandler
Handler for reading and writing data to/from a Zarr file format.
This class extends the TorchSigFileHandler and provides functionality to handle reading, writing, and managing Zarr-based storage for dataset samples.
- datapath_filename = 'data.zarr'¶
- chunk_size = (100,)¶
- exists() bool [source]¶
Checks if the Zarr file exists at the specified path.
- Returns:
True if the Zarr file exists, otherwise False.
- Return type:
bool
- write(batch_idx: int, batch: Any) None [source]¶
Writes a sample (data and targets) to the Zarr file at the specified index.
- Parameters:
idx (int) – The index at which to store the data in the Zarr file.
data (np.ndarray) – The data to write to the Zarr file.
targets (Any) – The corresponding targets to write as metadata for the sample.
Notes
If the index is greater than the current size of the array, the array is expanded to accommodate the new sample.
- static static_load(filename: str, idx: int) Tuple[ndarray, List[Dict[str, Any]]] [source]¶
Loads a sample from the Zarr file at the specified index.
- Parameters:
- Returns:
The data and the associated metadata for the sample.
- Return type:
Tuple[np.ndarray, List[Dict[str, Any]]]
- Raises:
IndexError – If the index is out of bounds.
- load(idx: int) Tuple[ndarray, Dict[str, Any]] | Tuple[Any, ...] [source]¶
Loads a sample from the Zarr file at the specified index into memory.
- Parameters:
idx (int) – The index of the sample to load.
- Returns:
The data and the corresponding targets for the sample.
- Return type:
Tuple[np.ndarray, List[Dict[str, Any]] | Tuple[Any, …]]
- Raises:
IndexError – If the index is out of bounds.
Variable and Data Verification Utils¶
Data verification and error checking utils
- torchsig.utils.verify.verify_int(a: int, name: str, low: int = 0, high: int | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) int [source]¶
Verifies that the value a is an integer and within the specified bounds.
- Parameters:
a (int) – The value to be checked.
name (str) – The name of the value to be used in error messages.
low (int, optional) – The lower bound of the value. Defaults to 0.
high (int, optional) – The upper bound of the value. Defaults to None.
clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.
clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.
exclude_low (bool, optional) – If True, a must be strictly greater than low. Defaults to False.
exclude_high (bool, optional) – If True, a must be strictly less than high. Defaults to False.
- Raises:
ValueError – If a is not an integer or out of bounds.
- Returns:
The verified integer value a.
- Return type:
int
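The interaction of the bound, clip, and exclude flags can be sketched as follows. This is an illustration of the documented parameters, not the library code: the function name is hypothetical, and applying clipping before the bound checks is an assumption.

```python
def verify_int_sketch(a, name, low=0, high=None, clip_low=False,
                      clip_high=False, exclude_low=False, exclude_high=False):
    """Check that a is an int within [low, high], optionally clipping."""
    if not isinstance(a, int):
        raise ValueError(f"{name} is not an int")
    # Assumed order: clip first, then enforce the (possibly strict) bounds.
    if clip_low and low is not None and a < low:
        a = low
    if clip_high and high is not None and a > high:
        a = high
    if low is not None and (a < low or (exclude_low and a == low)):
        raise ValueError(f"{name}={a} is below the lower bound {low}")
    if high is not None and (a > high or (exclude_high and a == high)):
        raise ValueError(f"{name}={a} is above the upper bound {high}")
    return a


unchanged = verify_int_sketch(5, "num_signals", low=0, high=10)
clipped = verify_int_sketch(-3, "num_signals", clip_low=True)
```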
- torchsig.utils.verify.verify_float(f: float, name: str, low: float = 0.0, high: float | None = None, clip_low: bool = False, clip_high: bool = False, exclude_low: bool = False, exclude_high: bool = False) float [source]¶
Verifies that the value f is a float and within the specified bounds.
- Parameters:
f (float) – The value to be checked.
name (str) – The name of the value to be used in error messages.
low (float, optional) – The lower bound of the value. Defaults to 0.0.
high (float, optional) – The upper bound of the value. Defaults to None.
clip_low (bool, optional) – If True, the value will be clipped to low if it is below low. Defaults to False.
clip_high (bool, optional) – If True, the value will be clipped to high if it exceeds high. Defaults to False.
exclude_low (bool, optional) – If True, f must be strictly greater than low. Defaults to False.
exclude_high (bool, optional) – If True, f must be strictly less than high. Defaults to False.
- Raises:
ValueError – If f is not a float or out of bounds.
- Returns:
The verified float value f.
- Return type:
float
- torchsig.utils.verify.verify_str(s: str, name: str, valid: List[str] = [], str_format: str = 'lower') str [source]¶
Verifies that the value s is a string and optionally formats it according to the specified format.
- Parameters:
s (str) – The value to be checked.
name (str) – The name of the value to be used in error messages.
valid (List[str], optional) – A list of valid string values. Defaults to an empty list.
str_format (str, optional) – The format for the string. Can be “lower”, “upper”, or “title”. Defaults to “lower”.
- Raises:
ValueError – If s is not a string or if it is not in the list of valid values.
- Returns:
The verified string value s in the specified format.
- Return type:
str
- torchsig.utils.verify.verify_distribution_list(distro: List[float], required_length: int, distro_name: str, list_name: str) List[float] [source]¶
Verifies and normalizes a given distribution list.
If the distribution list is None, it assumes a uniform distribution and returns it as is. If the distribution list is not of the required length or does not sum to 1.0, it raises an error or normalizes the list to sum to 1.0.
- Parameters:
distro (List[float]) – The distribution list to verify. Can be None for a uniform distribution.
required_length (int) – The expected length of the distribution list.
distro_name (str) – The name of the distribution list (used for error messages).
list_name (str) – The name of the list this distribution corresponds to (used for error messages).
- Returns:
The verified and possibly normalized distribution list.
- Return type:
List[float]
- Raises:
ValueError – If the distribution list is not of the required length or does not sum to 1.0 and cannot be normalized.
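One possible reading of the behavior described above is sketched below. The description leaves open exactly when the list is normalized and when an error is raised, so the choices here are assumptions: None is passed through, a wrong length raises ValueError, a non-positive sum raises ValueError, and any other list is rescaled to sum to 1.0. The function name is hypothetical.

```python
def verify_distribution_sketch(distro, required_length, distro_name, list_name):
    """Validate a distribution list and rescale it to sum to 1.0."""
    if distro is None:
        return distro  # None stands for a uniform distribution
    if len(distro) != required_length:
        raise ValueError(
            f"{distro_name} has {len(distro)} entries but {list_name} "
            f"has {required_length}"
        )
    total = sum(distro)
    if total <= 0:
        raise ValueError(f"{distro_name} cannot be normalized")
    return [value / total for value in distro]


normalized = verify_distribution_sketch(
    [1, 1, 2], 3, "class_distribution", "class_list"
)
```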
- torchsig.utils.verify.verify_list(l: list, name: str, no_duplicates: bool = False, data_type=None) list [source]¶
Verifies that the value l is a list and optionally checks for duplicates or verifies item types.
- Parameters:
l (list) – The value to be checked.
name (str) – The name of the value to be used in error messages.
no_duplicates (bool, optional) – If True, raises an error if the list contains duplicates. Defaults to False.
data_type (type, optional) – The type each item in the list should have. Defaults to None.
- Raises:
ValueError – If l is not a list, if it contains duplicates (when no_duplicates=True), or if any item in the list is not of the required type.
- Returns:
The verified list l.
- Return type:
list
- torchsig.utils.verify.verify_numpy_array(n: ndarray, name: str, min_length: int | None = None, max_length: int | None = None, exact_length: int | None = None, data_type=None) ndarray [source]¶
Verifies that the value n is a NumPy array and optionally checks its length or item types.
- Parameters:
n (np.ndarray) – The value to be checked.
name (str) – The name of the value to be used in error messages.
min_length (int, optional) – The minimum length of the array. Defaults to None.
max_length (int, optional) – The maximum length of the array. Defaults to None.
exact_length (int, optional) – The exact length of the array. Defaults to None.
data_type (type, optional) – The type each item in the array should have. Defaults to None.
- Raises:
ValueError – If n is not a NumPy array or its length is not within the specified bounds, or if any item in the array is not of the required type.
- Returns:
The verified NumPy array n.
- Return type:
np.ndarray
- torchsig.utils.verify.verify_transforms(t: Transform) List[Transform | Callable] [source]¶
Verifies that the value t is a valid transform, which can be a single transform or a list of transforms.
- Parameters:
t (Transform) – The transform(s) to be checked.
- Raises:
ValueError – If t is not a valid transform.
- Returns:
The verified list of transforms.
- Return type:
List[Transform | Callable]
- torchsig.utils.verify.verify_target_transforms(tt: TargetTransform) List[TargetTransform | Callable] [source]¶
Verifies that the value tt is a valid target transform, which can be a single target transform or a list of transforms.
- Parameters:
tt (TargetTransform) – The target transform(s) to be checked.
- Raises:
ValueError – If tt is not a valid target transform.
- Returns:
The verified list of target transforms.
- Return type:
List[TargetTransform | Callable]
- torchsig.utils.verify.verify_dict(d: dict, name: str, required_keys: list = [], required_types: list = [])[source]¶
Verifies that the value d is a dictionary and optionally checks for required keys and their types.
- Parameters:
d (dict) – The value to be checked.
name (str) – The name of the value to be used in error messages.
required_keys (list, optional) – A list of required keys in the dictionary. Defaults to an empty list.
required_types (list, optional) – A list of types for each required key. Defaults to an empty list.
- Raises:
ValueError – If d is not a dictionary, or if any required key is missing or has an incorrect type.
- Returns:
The verified dictionary d.
- Return type:
dict
Printing Utils¶
Contains helpful methods for implementing the __str__ and __repr__ methods of classes
- torchsig.utils.printing.generate_repr_str(class_object: Any, exclude_params: List[str] = []) str [source]¶
Generates a string representation of the class object, excluding specified parameters.
This function creates a human-readable string representation of the given class object, including its class name and parameters. It excludes any parameters specified in the exclude_params list. If the class object is an instance of Seedable, certain attributes related to seeding are handled specifically.
- Parameters:
class_object (Any) – The class object to generate the string representation for.
exclude_params (List[str], optional) – A list of parameter names to exclude from the string representation. Defaults to an empty list.
- Returns:
A formatted string representation of the class object with parameters.
- Return type:
str
- Raises:
AttributeError – If the class object does not have a __dict__ attribute or any other required attributes for the operation.
Example
>>> class Example:
...     def __init__(self, param1, param2):
...         self.param1 = param1
...         self.param2 = param2
>>> e = Example(1, 2)
>>> generate_repr_str(e)
'Example(param1=1,param2=2)'
Notes
If the class object is an instance of Seedable, the seed and parent attributes will be added back into the string representation.
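The output format shown in the example can be reproduced with a short sketch built on the object's __dict__. It is an assumption that the library formats values this way, the special handling of Seedable objects is omitted, and the function name is hypothetical.

```python
def repr_sketch(class_object, exclude_params=()):
    """Build 'ClassName(name=value,...)' from the instance attributes."""
    # Accessing __dict__ raises AttributeError for objects without one.
    params = [
        f"{name}={value}"
        for name, value in class_object.__dict__.items()
        if name not in exclude_params
    ]
    return f"{type(class_object).__name__}({','.join(params)})"


class Example:
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


full = repr_sketch(Example(1, 2))
partial = repr_sketch(Example(1, 2), exclude_params=["param2"])
```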
- torchsig.utils.printing.dataset_metadata_str(dataset_metadata, max_width: int = 100, first_col_width: int = 29, array_width_indent_offset: int = 2) str [source]¶
Custom string representation for the class.
This method returns a formatted string that provides a detailed summary of the object’s key attributes, including signal parameters, dataset configuration, and transform details. It uses textwrap.fill to format long attributes such as lists or arrays into a neatly wrapped format for easier readability.
The string includes information on the dataset’s configuration, signal characteristics, transformations, and other attributes in a human-readable way. The result is intended to provide a concise yet comprehensive overview of the object’s state, useful for debugging, logging, or displaying object details.
- Parameters:
dataset_metadata (Any) – The dataset metadata object to generate a string for.
max_width (int, optional) – Maximum width of the output string. Defaults to 100.
first_col_width (int, optional) – Width of the first column in the output string. Defaults to 29.
array_width_indent_offset (int, optional) – Indentation offset for array-like attributes. Defaults to 2.
- Returns:
A formatted string that represents the object’s attributes in a readable format.
- Return type:
str
- Example Output:
MyClass
----------------------------------------------------------------------------------------------------
num_iq_samples_dataset       1000
num_samples                  5000
impairment_level             0.8
fft_size                     512
sample_rate                  1000.0
num_signals_min              1
num_signals_max              5
num_signals_distribution     [0.2, 0.3, 0.5]
snr_db_min                   5.0
snr_db_max                   30.0
signal_duration_percent_min  0.1
signal_duration_percent_max  0.9
transforms                   [TransformA, TransformB]
target_transforms            [TargetTransform1]
class_list                   [Class1, Class2, Class3]
class_distribution           [0.3, 0.4, 0.3]
seed                         42
- torchsig.utils.printing.dataset_metadata_repr(dataset_metadata) str [source]¶
Return a string representation of the object for debugging and inspection.
This method generates a string that provides a concise yet detailed summary of the object’s state, useful for debugging or interacting with the object in an interactive environment (e.g., REPL, Jupyter notebooks).
The __repr__ method is intended to give an unambiguous, readable string that represents the object. The returned string includes key attributes and their values, formatted in a way that can be interpreted back as code, i.e., it aims to provide a string that could be used to recreate the object (though not necessarily identical, as it is for debugging purposes).
- Returns:
A detailed, formatted string that represents the object’s state, showing key attributes and their current values.
- Return type:
str
Randomization Utils¶
Utility to handle random number generators.
- class torchsig.utils.random.Seedable(seed: int | None = None, parent=None)[source]¶
Bases:
object
A class/interface representing objects capable of accessing random numbers and being seeded. Stores an internal random number generator object. Can be seeded with the Seedable.seed(seed_value : long) function. Two Seedable objects with the same seed will always generate/access the same random values in the same order. Objects that contain or compose Seedable objects are generally responsible for seeding them.
- seed(seed: int) None [source]¶
Seeds the random number generator with the given seed.
- Parameters:
seed (int) – Seed to use.
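The reproducibility guarantee described for Seedable can be illustrated with a stand-in class. This sketch is not the library class: SeedableSketch and its draw method are hypothetical, and Python's standard random.Random is used in place of whatever generator torchsig stores internally.

```python
import random


class SeedableSketch:
    """Owns a private random number generator that can be reseeded."""

    def __init__(self, seed=None, parent=None):
        self.parent = parent
        self.random_generator = random.Random(seed)

    def seed(self, seed):
        """Reset the internal generator to a known state."""
        self.random_generator.seed(seed)

    def draw(self, count):
        """Return count random floats from the internal generator."""
        return [self.random_generator.random() for _ in range(count)]


first = SeedableSketch(seed=42)
second = SeedableSketch(seed=42)
same_sequence = first.draw(3) == second.draw(3)

# Reseeding restarts the sequence from the beginning.
first.seed(42)
second.seed(42)
restarted = first.draw(3) == second.draw(3)
```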