"""Dataset Writer Utils"""
from __future__ import annotations
import concurrent.futures
import math
import threading
from dataclasses import dataclass
from pathlib import Path
from time import time
from typing import TYPE_CHECKING, Any
import yaml
from torch.utils.data._utils.collate import default_collate as torch_default_collate
from tqdm.auto import tqdm
from torchsig.utils.file_handlers.hdf5 import HDF5Writer
# TorchSig
from torchsig.utils.yaml import write_dict_to_yaml
if TYPE_CHECKING:
from torch.utils.data import DataLoader
from torchsig.utils.file_handlers.base_handler import FileWriter
[docs]
def default_collate_fn(batch):
"""Collates a batch by zipping its elements together. Note: not pickle-safe for complex
nested structures, but works for typical (data, label) batches.
Args:
batch (tuple): A batch from the dataloader.
Returns:
tuple: A tuple of zipped elements, where each element corresponds to a single batch item.
"""
return tuple(zip(*batch))
[docs]
def identity_collate_fn(batch):
"""Pickle-safe identity collate for Signal objects (returns list unchanged)."""
return batch
@dataclass(frozen=True)
class _DatasetExistenceProbe:
"""Configurable notion of 'dataset exists' without entering FileWriter.__enter__()."""
root: Path
maybe_data_file: Path | None
def exists(self) -> bool:
if not self.root.exists() or not self.root.is_dir():
return False
if self.maybe_data_file is not None:
return self.maybe_data_file.exists()
# fallback: any content
return any(self.root.iterdir())
def _deep_equal(a: Any, b: Any, *, float_rtol: float = 1e-9, float_atol: float = 0.0) -> bool:
"""Recursive equality for YAML-loaded structures (dict/list/scalars)."""
if a is b:
return True
if a is None or b is None:
return a is b
# Floats: tolerate tiny rounding changes from serialization/IO
if isinstance(a, (float, int)) and isinstance(b, (float, int)):
if isinstance(a, float) or isinstance(b, float):
return math.isclose(float(a), float(b), rel_tol=float_rtol, abs_tol=float_atol)
return int(a) == int(b)
# Dicts
if isinstance(a, dict) and isinstance(b, dict):
if set(a.keys()) != set(b.keys()):
return False
return all(_deep_equal(a[k], b[k], float_rtol=float_rtol, float_atol=float_atol) for k in a)
# Sequences
if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
if len(a) != len(b):
return False
return all(_deep_equal(x, y, float_rtol=float_rtol, float_atol=float_atol) for x, y in zip(a, b, strict=True))
return a == b
[docs]
class DatasetCreator:
"""Class for creating a dataset and saving it to disk in batches.
This class generates a dataset if it does not already exist on disk.
It processes the data in batches and saves it using a specified file handler.
The class allows setting options like whether to overwrite existing datasets,
batch size, and number of worker threads.
Attributes:
dataloader (DataLoader): The DataLoader used to load data in batches.
root (Path): The root directory where the dataset will be saved.
overwrite (bool): Flag indicating whether to overwrite an existing dataset.
tqdm_desc (str): A description for the progress bar.
file_handler (FileWriter): The file handler used for saving the dataset.
"""
[docs]
def __init__(
self,
dataloader: DataLoader,
dataset_length: int | None = None,
root: str = ".",
overwrite: bool = True,
tqdm_desc: str | None = None,
file_handler: FileWriter = HDF5Writer,
multithreading: bool = True,
max_inflight_futures: int = 32,
**kwargs,
):
"""Initializes the DatasetCreator.
Args:
dataloader (DataLoader): DataLoader used to load data in batches. Required.
dataset_length (int): Number of dataset items to be created. Length inferrence attempted if not provided.
root (Path): Root directory where the dataset files will be saved. Defaults to current directory.
overwrite (bool): Flag indicating whether to overwrite an existing dataset. Defaults to True.
tqdm_desc (str): Description for the progress bar.
file_handler (FileWriter): File handler used to write dataset. Defaults to HDF5Writer.
multithreading (bool): Whether to use multithreading for writing batches. Defaults to True.
max_inflight_futures (int): Maximum number of concurrent futures when using multithreading. Defaults to 32.
**kwargs: Additional arguments for the file handler.
"""
# File attributes
self.root = Path(root)
self.dataset_info_filepath = self.root.joinpath("dataset_info.yaml")
self.writer_info_filepath = self.root.joinpath("writer_info.yaml")
self.file_handler = file_handler
self.kwargs = dict(kwargs)
self.overwrite = bool(overwrite)
self.multithreading = bool(multithreading)
self.max_inflight_futures = int(max_inflight_futures)
self.dataloader = dataloader
self.dataset_length_requested = self._infer_dataset_length(dataset_length)
# optional
self.batch_size = getattr(dataloader, "batch_size", None)
self.num_workers = getattr(dataloader, "num_workers", None)
self.tqdm_desc = "Generating Dataset:" if tqdm_desc is None else tqdm_desc
# counters
self.items_written = 0
self._counter_lock = threading.Lock()
self._msg_timer = None
def _infer_dataset_length(self, dataset_length: int | None) -> int:
"""Infer dataset length or require it explicitly for iterable datasets.
For map-style datasets, the length can be inferred from the dataset object.
For iterable datasets, the length cannot be inferred and must be provided.
Args:
dataset_length (int | None): The length of the dataset to be created. If None
the method will attempt to infer the length from the dataloader's dataset.
Returns:
int: dataset length.
"""
if dataset_length is not None:
return int(dataset_length)
# map-style datasets: try len(dataloader.dataset) to infer dataset length
try:
inferred_length = len(self.dataloader.dataset)
except Exception as e: # pylint: disable=broad-except
raise ValueError(
"dataset_length must be provided when writing from an IterableDataset "
"(e.g., TorchSigIterableDataset), because length cannot be inferred."
) from e
return int(inferred_length)
def _existence_probe(self) -> _DatasetExistenceProbe:
"""Instantiate writer without entering context to not enter setup, resetting folder."""
maybe_data_file = None
try:
writer = self.file_handler(root=self.root)
maybe_data_file = getattr(writer, "datapath", None)
if isinstance(maybe_data_file, (str, Path)):
maybe_data_file = Path(maybe_data_file)
except Exception:
# best-effort only
maybe_data_file = None
return _DatasetExistenceProbe(root=self.root, maybe_data_file=maybe_data_file)
def _get_dataset_metadata_dict(self) -> dict[str, Any]:
"""Best-effort extraction of dataset metadata for YAML.
Returns:
dict: dictionary containing dataset metadata information.
"""
ds = self.dataloader.dataset
if hasattr(ds, "get_full_metadata"):
return ds.get_full_metadata()
return {}
[docs]
def get_dataset_info_dict(self, *, dataset_length: int, original_target_labels: Any) -> dict[str, Any]:
"""Get metadata content for the dataset_info.yaml file.
Returns:
Dict[str, Any]: Dictionary containing the dataset metadata information.
"""
ds = self.dataloader.dataset
seed = getattr(ds, "rng_seed", None)
return {
"dataset_length": int(dataset_length),
"seed": None if seed is None else int(seed),
"target_labels": original_target_labels,
"dataset_metadata": self._get_dataset_metadata_dict(),
}
[docs]
def get_writer_info_dict(self, *, complete: bool) -> dict[str, Any]:
"""Returns a dictionary with information about the dataset writing configuration.
Used primarily for creating content for the writer_info.yaml summary file.
Returns:
Dict[str, Any]: Dictionary containing the dataset writing configuration.
"""
return {
"root": str(self.root),
"overwrite": bool(self.overwrite),
"batch_size": None if self.batch_size is None else int(self.batch_size),
"num_workers": None if self.num_workers is None else int(self.num_workers),
"file_handler": getattr(self.file_handler, "__name__", str(self.file_handler)),
"multithreading": bool(self.multithreading),
"dataset_length_requested": int(self.dataset_length_requested),
"items_written": int(self.items_written),
"complete": bool(complete),
"timestamp_unix": int(time()),
}
[docs]
def check_yamls(self, *, expected_dataset_info: dict[str, Any]) -> tuple[bool, list[tuple[str, Any, Any]]]:
"""Returns (complete, differences) without mutating dataset or entering writer context."""
differences: list[tuple[str, Any, Any]] = []
if not self.writer_info_filepath.exists():
return False, [("writer_info.yaml", "missing", "expected present")]
with open(self.writer_info_filepath) as f:
writer_disk = yaml.safe_load(f) or {}
complete = bool(writer_disk.get("complete", False))
if not self.dataset_info_filepath.exists():
differences.append(("dataset_info.yaml", "missing", "expected present"))
return complete, differences
with open(self.dataset_info_filepath) as f:
dataset_disk = yaml.safe_load(f) or {}
stable_keys = ["seed", "target_labels", "dataset_metadata"]
for k in stable_keys:
if k not in dataset_disk:
differences.append((k, "missing_on_disk", expected_dataset_info.get(k)))
continue
if not _deep_equal(dataset_disk.get(k), expected_dataset_info.get(k)):
differences.append((k, dataset_disk.get(k), expected_dataset_info.get(k)))
# Length must match requested for "no regeneration needed"
disk_len = dataset_disk.get("dataset_length", None)
if disk_len is None or int(disk_len) != int(self.dataset_length_requested):
differences.append(("dataset_length", disk_len, int(self.dataset_length_requested)))
return complete, differences
def _ensure_signal_batch_mode(self) -> tuple[Any, Any]:
"""Mutate dataset/dataloader so DataLoader yields Signal objects; return (orig_target_labels, orig_collate_fn)."""
ds = self.dataloader.dataset
orig_target_labels = getattr(ds, "target_labels", None)
orig_collate_fn = getattr(self.dataloader, "collate_fn", None)
# Force dataset to return Signal objects (TorchSigIterableDataset behavior)
if hasattr(ds, "target_labels"):
ds.target_labels = None
# Ensure DataLoader does not try to default-collate Signal objects into tensors
if getattr(ds, "target_labels", None) is None and self.dataloader.collate_fn in (torch_default_collate, default_collate_fn):
self.dataloader.collate_fn = identity_collate_fn
return orig_target_labels, orig_collate_fn
[docs]
def create(self) -> None:
"""Creates the dataset on disk by writing batches to the file handler.
This method generates the dataset in batches and saves it to disk. If the
dataset already exists and `overwrite` is set to False, it will skip regeneration.
The method also writes the dataset metadata and writing information to YAML files.
Raises:
ValueError: If the dataset is already generated and `overwrite` is set to False.
"""
ds = self.dataloader.dataset
orig_target_labels = getattr(ds, "target_labels", None)
orig_collate_fn = getattr(self.dataloader, "collate_fn", None)
expected_dataset_info = self.get_dataset_info_dict(
dataset_length=self.dataset_length_requested,
original_target_labels=orig_target_labels,
)
# Existence/overwrite decision before entering writer context.
probe = self._existence_probe()
if probe.exists() and not self.overwrite:
complete, diffs = self.check_yamls(expected_dataset_info=expected_dataset_info)
if complete and len(diffs) == 0:
print(f"Dataset already exists in {self.root}. Not regenerating.")
return
if not complete:
raise RuntimeError(
f"Dataset only partially exists in {self.root}. "
"Regenerate by setting overwrite=True."
)
print(f"Dataset exists at {self.root} but differs from current dataset config. Using dataset on disk.")
for k, disk_v, cur_v in diffs:
print(f"\t{k}: disk={disk_v} current={cur_v}")
return
# create dataset
try:
orig_target_labels, orig_collate_fn = self._ensure_signal_batch_mode()
self.items_written = 0
self._msg_timer = time()
with self.file_handler(root=self.root) as writer:
# Write initial YAMLs
write_dict_to_yaml(self.dataset_info_filepath, self.get_dataset_info_dict(
dataset_length=0, original_target_labels=orig_target_labels,
))
write_dict_to_yaml(self.writer_info_filepath, self.get_writer_info_dict(complete=False))
remaining = self.dataset_length_requested
# Best-effort progress bar total
total_batches = None
if isinstance(self.batch_size, int) and self.batch_size > 0:
total_batches = math.ceil(self.dataset_length_requested / self.batch_size)
pbar = tqdm(desc=self.tqdm_desc, total=total_batches)
if self.multithreading:
writer_lock = threading.Lock()
futures: list[concurrent.futures.Future[int]] = []
def submit_write(batch_idx: int, batch: Any) -> int:
# Lock writer.write for safety with HDF5Writer buffering
with writer_lock:
writer.write(batch_idx, batch)
return len(batch) if hasattr(batch, "__len__") else 1
batch_idx = 0
# Single executor; max_workers=1 is enough since writer calls are serialized
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for batch in self.dataloader:
if remaining <= 0:
break
if hasattr(batch, "__len__") and len(batch) > remaining:
batch = batch[:remaining]
batch_len = len(batch) if hasattr(batch, "__len__") else 1
futures.append(executor.submit(submit_write, batch_idx, batch))
batch_idx += 1
remaining -= batch_len
if len(futures) >= self.max_inflight_futures:
for fut in futures:
self.items_written += fut.result()
pbar.update(1)
futures.clear()
for fut in futures:
self.items_written += fut.result()
pbar.update(1)
else: # single-threaded writing
batch_idx = 0
for batch in self.dataloader:
if remaining <= 0:
break
if hasattr(batch, "__len__") and len(batch) > remaining:
batch = batch[:remaining]
batch_len = len(batch) if hasattr(batch, "__len__") else 1
writer.write(batch_idx, batch)
batch_idx += 1
remaining -= batch_len
self.items_written += batch_len
pbar.update(1)
pbar.close()
# Validate after successful context close
if self.items_written != self.dataset_length_requested:
raise RuntimeError(
f"DatasetCreator wrote {self.items_written} samples, "
f"expected {self.dataset_length_requested}."
)
# Final YAML update
write_dict_to_yaml(self.dataset_info_filepath, self.get_dataset_info_dict(
dataset_length=self.items_written,
original_target_labels=orig_target_labels,
))
write_dict_to_yaml(self.writer_info_filepath, self.get_writer_info_dict(complete=True))
finally:
# Always restore caller-visible state
if hasattr(ds, "target_labels"):
ds.target_labels = orig_target_labels
if orig_collate_fn is not None:
self.dataloader.collate_fn = orig_collate_fn