Source code for torchsig.datasets.datamodules

"""PyTorch Lightning DataModules
Learn More: https://lightning.ai/docs/pytorch/stable/data/datamodule.html
If dataset does not exist at root, creates new dataset and writes to disk
If dataset does exist, simply loaded it back in
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

# Built-In
# Third Party
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split

# TorchSig
from torchsig.datasets.datasets import StaticTorchSigDataset, TorchSigIterableDataset
from torchsig.transforms.impairments import Impairments
from torchsig.utils.file_handlers.hdf5 import HDF5Reader, HDF5Writer
from torchsig.utils.writer import DatasetCreator

if TYPE_CHECKING:
    from torchsig.utils.file_handlers import BaseFileHandler

[docs] class TorchSigDataModule(pl.LightningDataModule): """PyTorch Lightning DataModule for creating and loading TorchSig datasets. This DataModule handles: - Dataset creation or loading from disk via a file handler. - Splitting into train/val/test subsets. - Batching, collation, and worker seeding for training. Attributes: root: Directory where datasets are stored or created. dataset_size: Total number of samples in the dataset. dataset_splits: Fractions or counts for train/val/test splits. dataset_metadata: Metadata describing the dataset. impairment_level: Optional interference level for synthetic impairments. transforms: Transforms applied to the input data. target_labels: Names of target metadata fields to include. batch_size: Batch size for the training/validation/testing DataLoaders. num_workers: Number of worker processes for data loading. collate_fn: Custom collate function for batching. create_batch_size: Batch size used during on-disk dataset creation. create_num_workers: Number of workers used during dataset creation. file_writer: FileHandler class for disk I/O. file_reader: FileReader class for disk I/O. overwrite: If True, existing on-disk data will be overwritten. seed: Optional random seed for reproducibility. train: Initialized training dataset (set in `setup()`). val: Initialized validation dataset (set in `setup()`). test: Initialized test dataset (set in `setup()`). """
[docs] def __init__( self, root: str, metadata, dataset_size: int, dataset_splits: list[float] | list[int] = [0.70, 0.20, 0.10], # dataloader params batch_size: int = 1, num_workers: int = 1, collate_fn: callable | None = None, # dataset creator params create_batch_size: int = 8, create_num_workers: int = 4, file_writer: BaseFileHandler = HDF5Writer, file_reader: BaseFileHandler = HDF5Reader, overwrite: bool = False, # transforms impairment_level: int = 0, transforms=[], target_labels: list[str] | None = None, seed: int | None = None, ): """Initialize the TorchSigDataModule. Args: root: Path to store or load the dataset. metadata: Metadata object, YAML file path, or dict describing classes and settings. dataset_size: Total number of samples to generate or load. dataset_splits: Fractions or counts for train/val/test splits. Defaults to [0.70, 0.20, 0.10]. batch_size: Batch size for data loaders. Defaults to 1. num_workers: Number of worker processes for data loading. Defaults to 1. collate_fn: Custom function to collate batch samples. Defaults to None. create_batch_size: Batch size when writing data to disk. Defaults to 8. create_num_workers: Workers used when creating the on-disk dataset. Defaults to 4. file_writer: FileWriter class for disk I/O. file_reader: FileReader class for disk I/O. overwrite: If True, existing data at `root` will be overwritten. Defaults to False. impairment_level: Level of synthetic impairment to apply. Defaults to 0 (no impairment). transforms: List of transforms applied to each sample's input. Defaults to []. target_labels: Names of metadata fields to include. Defaults to None. seed: Seed for randomness and reproducibility. Defaults to None. Raises: ValueError: If dataset_splits don't sum to 1.0 (when using fractions). FileNotFoundError: If metadata file path is invalid. """ # read from yaml or dataset metadata or code inputs super().__init__() # filepaths self.root = Path(root) self.dataset_size = dataset_size self.dataset_splits = dataset_splits # metadatas self.metadata = metadata # transforms, based on Impairment level self.impairment_level = impairment_level impairments = Impairments(level=impairment_level) self.burst_impairments = impairments.signal_transforms self.whole_signal_impairments = impairments.dataset_transforms self.transforms = [self.whole_signal_impairments, *transforms] self.target_labels = target_labels # initialize dataloader params self.batch_size = batch_size self.num_workers = 0 if num_workers is None else num_workers self.collate_fn = collate_fn # dataset creator params self.create_batch_size = create_batch_size self.create_num_workers = create_num_workers self.file_writer = file_writer self.file_reader = file_reader self.overwrite = overwrite # to be initialized in setup() self.train: StaticTorchSigDataset = None self.val: StaticTorchSigDataset = None self.test: StaticTorchSigDataset = None self.seed = seed
[docs] def prepare_data(self) -> None: """Prepares the dataset by creating new datasets if they do not exist on disk. The datasets are created using the `DatasetCreator` class. If the dataset already exists on disk, it is loaded back into memory. Raises: FileNotFoundError: If the root directory cannot be created. RuntimeError: If dataset creation fails. """ dataset = TorchSigIterableDataset( metadata=self.metadata, transforms=self.transforms, component_transforms=[self.burst_impairments], target_labels=self.target_labels, seed=self.seed ) loader = DataLoader( dataset=dataset, batch_size=self.create_batch_size, collate_fn=self.collate_fn, ) creator = DatasetCreator( dataloader=loader, dataset_length=self.dataset_size, root=self.root, overwrite=self.overwrite, file_writer=self.file_writer, ) print(f"Full Dataset: Impairment Level {self.impairment_level}, {self.dataset_size} dataset size") creator.create()
[docs] def setup(self, stage: str = "train") -> None: """Sets up the train and validation datasets for the given stage. Args: stage: The stage of the DataModule, typically 'train' or 'test'. Defaults to 'train'. Raises: FileNotFoundError: If the dataset files are not found at the specified root. ValueError: If dataset splits are invalid. """ full_dataset = StaticTorchSigDataset( root=self.root, file_handler_class=self.file_reader, target_labels=self.target_labels, ) self.train, self.val, self.test = random_split( full_dataset, self.dataset_splits )
[docs] def train_dataloader(self) -> DataLoader: """Returns the DataLoader for the training dataset. Returns: A PyTorch DataLoader for the training dataset. Raises: RuntimeError: If the training dataset is not initialized. """ return DataLoader( dataset=self.train, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn, num_workers=self.num_workers, pin_memory=True, )
[docs] def val_dataloader(self) -> DataLoader: """Returns the DataLoader for the validation dataset. Returns: A PyTorch DataLoader for the validation dataset. Raises: RuntimeError: If the validation dataset is not initialized. """ return DataLoader( dataset=self.val, batch_size=self.batch_size, shuffle=False, collate_fn=self.collate_fn, num_workers=self.num_workers, pin_memory=True, )
[docs] def test_dataloader(self) -> DataLoader: """Returns the DataLoader for the test dataset. Returns: A PyTorch DataLoader for the test dataset. Raises: RuntimeError: If the test dataset is not initialized. """ return DataLoader( dataset=self.test, batch_size=self.batch_size, shuffle=False, collate_fn=self.collate_fn, num_workers=self.num_workers, pin_memory=True, )