# Source code for torchsig.utils.data_loading
"""Collate function and DataLoader with worker seeding for TorchSig.
Provides:
- metadata_padding_collate_fn: pads variable-length metadata in each batch.
- WorkerSeedingDataLoader: seeds each worker process differently for reproducibility.
"""
import warnings
import numpy as np
import torch
from torch.utils.data import DataLoader, get_worker_info
from torchsig.utils.random import Seedable
[docs]
def _metadata_value_at(metadata_list, position, key, default):
    """Return ``metadata_list[position][key]``, or ``default`` if either is missing."""
    if position < len(metadata_list):
        return metadata_list[position].get(key, default)
    return default


def metadata_padding_collate_fn(batch, *, default_y_value=0):
    """Collate ``(data, metadata_list)`` pairs, padding metadata to a common length.

    Each sample's metadata is a list of dicts. This function:

    1. Collects every metadata key seen anywhere in the batch.
    2. Finds the longest metadata list in the batch.
    3. For every key, builds a grid with one row per list position and one
       column per sample, filling absent entries (list too short, or key not
       present in that dict) with ``default_y_value``.
    4. Converts the stacked data and each grid with ``torch.Tensor(np.array(...))``.

    Args:
        batch: A list where each element is a tuple of:
            - x: any object ``np.array`` can stack (e.g. array, tensor).
            - y: a list of metadata dicts.
        default_y_value: Value used for padding and for keys missing from a
            dict. Defaults to ``0``. Pass a sentinel such as ``-1`` (e.g. via
            ``functools.partial``) when ``0`` is itself a meaningful value.
            A value that cannot be placed in a tensor causes the affected
            keys to be dropped with a warning.

    Returns:
        A tuple ``(data_tensor, metadata_tensors)``:
            - data_tensor: all x values stacked along a new leading axis.
            - metadata_tensors: dict mapping each metadata key to a tensor
              whose first axis is the position in the metadata list and whose
              second axis is the sample index, i.e. shape
              ``(max_sequence_length, batch_size)`` for scalar values. Keys
              whose values cannot be converted are omitted. The dict is empty
              when no sample carries any metadata.

    Raises:
        ValueError: if any element in ``batch`` is not a tuple of length 2.
    """
    iqs = []
    metadata_lists = []
    # A dict serves as an insertion-ordered set so output keys keep the order
    # in which they were first seen.
    seen_keys = {}
    for data_pair in batch:
        if not isinstance(data_pair, tuple) or len(data_pair) != 2:
            raise ValueError(
                f"{data_pair} is not a valid (x, y) pair; this collate function "
                "expects datasets to return tuples of (x, y)"
            )
        x, metadata_list = data_pair
        iqs.append(x)
        metadata_lists.append(metadata_list)
        for metadata_obj in metadata_list:
            for key in metadata_obj:
                seen_keys.setdefault(key, None)

    # NOTE(review): torch.Tensor(...) builds a tensor of the default tensor
    # type; confirm that is intended if x can hold complex IQ samples.
    data_tensor = torch.Tensor(np.array(iqs))

    batch_max_len = max(map(len, metadata_lists), default=0)
    if batch_max_len < 1:
        # Every metadata list is empty, so there are no keys to report.
        return data_tensor, {}

    final_tensor_obj = {}
    for key in seen_keys:
        # Outer index: position in the metadata list; inner index: sample.
        sequences = [
            [
                _metadata_value_at(metadata_list, position, key, default_y_value)
                for metadata_list in metadata_lists
            ]
            for position in range(batch_max_len)
        ]
        try:
            final_tensor_obj[key] = torch.Tensor(np.array(sequences))
        except (ValueError, TypeError, MemoryError) as e:
            # Best effort: non-numeric or ragged fields are skipped, not fatal.
            warnings.warn(
                f"Dropping key value: '{key}' because it contained invalid tensor values: {type(e).__name__}",
                stacklevel=2
            )
    return data_tensor, final_tensor_obj
[docs]
class WorkerSeedingDataLoader(DataLoader, Seedable):
    """DataLoader that seeds the dataset copy in each worker process differently.

    The loader installs its own ``worker_init_fn`` (``init_worker_seed``) and
    refuses a caller-supplied one, so per-worker seeding is always derived
    from this loader's seed.
    """

    def __init__(self, dataset, seed=None, **kwargs):
        """Initialize the DataLoader and Seedable bases and install the worker hook.

        Args:
            dataset: The dataset to load. It must expose ``seed(value)``.
            seed: Optional seed value. If None, an integer is drawn with
                ``np.random.randint(1000)``.
            **kwargs: Forwarded to ``DataLoader.__init__`` only.

        Raises:
            ValueError: if a ``worker_init_fn`` is provided in kwargs.
        """
        # Validate before any side effect so that a rejected call does not
        # leave the caller's dataset re-seeded.
        if kwargs.get("worker_init_fn"):
            raise ValueError(
                "No worker_init_fn should be given to WorkerSeedingDataLoader; "
                "it will set its own worker_init_fn."
            )
        if seed is None:
            # Just pick a random seed if none is given.
            seed = np.random.randint(1000)
        DataLoader.__init__(self, dataset, **kwargs)
        Seedable.__init__(self, seed=seed)
        dataset.seed(seed)
        self.worker_init_fn = self.init_worker_seed

    def seed(self, seed_val):
        """Set the seed value for both the loader and its dataset.

        Args:
            seed_val: The seed value to set.
        """
        Seedable.seed(self, seed_val)
        self.dataset.seed(seed_val)

    def init_worker_seed(self, worker_id):
        """Seed the dataset held by the current worker process.

        Draws one value from ``self.random_generator``, turns it into a
        positive integer base, and scales it by ``worker_id + 1`` so that
        workers sharing the same base receive different seeds.

        Args:
            worker_id: The integer ID of the worker process.
        """
        # NOTE(review): presumably random() returns a float in [0, 1), giving
        # a base in 1..100 -- confirm against Seedable.random_generator.
        base = int(self.random_generator.random() * 100 + 1)
        seed = base * (worker_id + 1)
        # get_worker_info() exposes this worker's own copy of the dataset.
        get_worker_info().dataset.seed(seed)