"""HDF5 File Handler for TorchSig datasets.
High-performance HDF5 storage with optimized compression and chunking.
"""
from __future__ import annotations
# Built-In
import threading
from typing import Any
import h5py
# Third Party
import numpy as np
from torchsig import __version__ as torchsig_version
# TorchSig
from torchsig.signals.signal_types import Signal
from torchsig.utils.abstractions import HierarchicalMetadataObject
from torchsig.utils.file_handlers import BaseFileHandler, FileReader, FileWriter
def _load_metadata_lazy(*args, **kwargs):
    """Forward every argument to ``load_dataset_metadata`` and return its result.

    The import is performed inside the function body so that importing this
    module does not touch ``torchsig.datasets``.
    """
    from torchsig.datasets.dataset_metadata import load_dataset_metadata as loader

    return loader(*args, **kwargs)
[docs]
def populate_hdf5_group_with_signal_data(group, signal):
    """Store ``signal.data`` in ``group`` under the key ``str(id(signal))``.

    Args:
        group: HDF5 group (or any object supporting ``in`` and
            ``create_dataset``) holding one dataset per signal id.
        signal: Object whose ``data`` attribute is written.

    Returns:
        bool: True if a new dataset was created; False if an entry with this
        id already existed or the dataset could not be created.
    """
    id_string = str(id(signal))

    # Nothing to do when data with this id is already present.
    if id_string in group:
        return False

    try:
        group.create_dataset(id_string, data=signal.data)
    except ValueError:
        # Best-effort write: the failure is tolerated, but it is reported as
        # "nothing added" so the return value matches what is in the group.
        return False
    return True
[docs]
def populate_hdf5_group_with_component_signals(group, signal):
    """Record the ids of ``signal``'s component signals in ``group``.

    The dataset is keyed by ``str(id(signal))`` and contains one
    ``str(id(...))`` entry per element of ``signal.component_signals``.

    Args:
        group: HDF5 group (or compatible object) exposing ``create_dataset``.
        signal: Object with a ``component_signals`` sequence.

    Returns:
        bool: True when ``signal`` has at least one component signal (even if
        ``create_dataset`` raised ``ValueError``, which is ignored), False when
        it has none.
    """
    if len(signal.component_signals) <= 0:
        return False

    try:
        child_ids = []
        for child in signal.component_signals:
            child_ids.append(str(id(child)))
        group.create_dataset(str(id(signal)), data=child_ids)
    except ValueError:
        # A ValueError from the write is deliberately ignored.
        pass
    return True
def _populate_hdf5_group_with_signal(group, signal):
    """Write ``signal`` and, recursively, all of its component signals.

    For each signal the ``"metadata"``, ``"data"`` and ``"component_signals"``
    sub-groups of ``group`` are populated, in that order, before descending
    into ``signal.component_signals``.

    Args:
        group: Object indexable by the three sub-group names above.
        signal: Signal to write; must expose ``component_signals``.
    """
    # NOTE(review): populate_hdf5_group_with_metadata is not defined in the
    # visible part of this module - confirm where it comes from.
    metadata_group = group["metadata"]
    populate_hdf5_group_with_metadata(metadata_group, signal)

    data_group = group["data"]
    populate_hdf5_group_with_signal_data(data_group, signal)

    components_group = group["component_signals"]
    populate_hdf5_group_with_component_signals(components_group, signal)

    for child in signal.component_signals:
        _populate_hdf5_group_with_signal(group, child)
[docs]
def populate_hdf5_group_with_signal(group, signal, index=True):
    """Write one signal (with its components) and optionally index it.

    Args:
        group: Root group containing the ``"metadata"``, ``"data"``,
            ``"component_signals"`` and ``"index"`` sub-groups.
        signal: Signal to write.
        index (bool, optional): When True, append ``str(id(signal))`` to the
            ``"index"`` sub-group under the next free position, i.e. the
            current number of entries rendered as a string. Defaults to True.
    """
    _populate_hdf5_group_with_signal(group, signal)
    if not index:
        return

    # The key is the current entry count, so positions are assigned in
    # write order starting at "0".
    next_position = str(len(group["index"]))
    group["index"].create_dataset(next_position, data=str(id(signal)))
[docs]
def populate_hdf5_group_with_signals(group, signals, index=True):
    """Write every signal in ``signals`` to ``group``, in iteration order.

    Args:
        group: Root group handed to ``populate_hdf5_group_with_signal``.
        signals: Iterable of signals to write.
        index (bool, optional): Forwarded unchanged for each signal.
            Defaults to True.
    """
    for item in signals:
        populate_hdf5_group_with_signal(group, item, index=index)
[docs]
class HDF5Writer(FileWriter):
    """Handles writing Signal data to HDF5 files with specified compression and buffering.

    Batches passed to :meth:`write` are buffered in memory and written to
    ``<root>/data.h5`` once ``max_batches_in_memory`` batches have
    accumulated, or when :meth:`teardown` is called.
    """

    def __init__(
        self,
        root,
        compression: str = "gzip",
        compression_opts: int = 6,
        shuffle: bool = True,
        fletcher32: bool = True,
        chunk_cache_size: int = 1024 * 1024 * 10,  # 10MB cache
        max_batches_in_memory: int = 4,
    ):
        """Initializes the HDF5Writer.

        Args:
            root (str): Where to write dataset on disk.
            compression (str, optional): Compression algorithm ('gzip', 'szip', 'lzf'). Defaults to 'gzip'.
            compression_opts (int, optional): Compression level (0-9 for gzip). Defaults to 6.
            shuffle (bool, optional): Enable shuffle filter for better compression. Defaults to True.
            fletcher32 (bool, optional): Enable Fletcher32 checksum filter. Defaults to True.
            chunk_cache_size (int, optional): HDF5 chunk cache size in bytes. Defaults to 10MB.
            max_batches_in_memory (int, optional): Maximum batches to keep in memory before flushing. Defaults to 4.
        """
        # Compression / filter settings.
        # NOTE(review): these values are stored here, and ``compression`` is
        # recorded as a file attribute in ``_setup``, but none of them is
        # passed to a ``create_dataset`` call visible in this module - confirm
        # whether they are meant to be applied to the datasets.
        self.compression = compression
        self.compression_opts = compression_opts
        self.shuffle = shuffle
        self.fletcher32 = fletcher32
        self.chunk_cache_size = chunk_cache_size
        self.max_batches_in_memory = max_batches_in_memory

        # Internal state
        self._file = None
        self._data_group = None
        self._batch_buffer: list[tuple[int, Any]] = []
        self._current_sample_index = 0

        super().__init__(root=root)
        # ``self.root`` is read only after the base initializer has run.
        self.datapath = self.root.joinpath("data.h5")

        # Thread safety
        self._lock = threading.Lock()

    def _setup(self) -> None:
        """Set up HDF5 file and initial structure.

        Opens ``self.datapath`` in ``"w"`` mode, records a few file-level
        attributes and creates the four top-level groups used by the
        ``populate_hdf5_group_*`` helpers.
        """
        # Create HDF5 file with optimized settings
        self._file = h5py.File(
            self.datapath,
            "w",
            libver="latest",  # Use latest HDF5 format for better performance
            swmr=False,  # Single writer mode for dataset creation
            rdcc_nbytes=self.chunk_cache_size,  # Chunk cache size
            rdcc_w0=0.75,  # Chunk cache policy
        )

        # Set global attributes
        self._file.attrs["torchsig_version"] = torchsig_version
        self._file.attrs["compression"] = self.compression
        self._file.attrs["created_by"] = "TorchSig HDF5FileHandler"

        for group_name in ("data", "metadata", "index", "component_signals"):
            self._file.create_group(group_name)

    def teardown(self) -> None:
        """Clean up resources and close HDF5 file.

        Any buffered batches are flushed first. The file is closed even when
        that final flush raises; the flush exception still propagates to the
        caller afterwards.
        """
        try:
            # Flush any remaining data if buffer exists
            if getattr(self, "_batch_buffer", None):
                self._flush_buffer()
        finally:
            handle = getattr(self, "_file", None)
            if handle is not None:
                try:
                    try:
                        handle.flush()
                    finally:
                        # Close even if the explicit flush failed.
                        handle.close()
                except Exception:
                    pass  # File might already be closed
                del self._file

    def _write_batch_to_hdf5(self, data) -> None:
        """Writes a batch of signals (as List[Signal]) to the file.

        Args:
            data (List[Signal]): The list of signals to write to the HDF5 file.
        """
        populate_hdf5_group_with_signals(self._file, data)

    def _flush_buffer(self) -> None:
        """Flush buffered batches to HDF5 file.

        Batches are written in ascending ``batch_idx`` order. If writing a
        batch raises, the batches already written are removed from the buffer
        so that a later retry does not write them (and their index entries)
        a second time; the failing batch and everything after it stay
        buffered and the exception propagates.
        """
        if not self._batch_buffer:
            return

        # Ensure file is open for writing
        if not self._file:
            self._setup()

        if not hasattr(self, "_lock"):
            self._lock = threading.Lock()

        with self._lock:
            # Sort buffer by batch index to maintain order
            self._batch_buffer.sort(key=lambda entry: entry[0])

            written = 0
            try:
                for _, data in self._batch_buffer:
                    self._write_batch_to_hdf5(data)
                    written += 1
            finally:
                # Drop exactly what reached the file (everything on success).
                del self._batch_buffer[:written]

            # Force flush to disk
            if self._file:
                self._file.flush()

    def write(self, batch_idx: int, data) -> None:
        """Write a batch of data to HDF5 file.

        The batch is buffered; the buffer is flushed once it holds at least
        ``max_batches_in_memory`` batches.

        Args:
            batch_idx (int): Index of the batch being written.
            data (Any): Signal data to write.
        """
        # Add to buffer
        self._batch_buffer.append((batch_idx, data))

        # Flush buffer if it's getting too large
        if len(self._batch_buffer) >= self.max_batches_in_memory:
            self._flush_buffer()

    def __len__(self) -> int:
        """Returns the number of entries in the file's ``"index"`` group.

        Batches still waiting in the in-memory buffer are not counted.
        """
        return len(self._file["index"])
[docs]
def handle_bytes_as_string(bts):
    """Convert byte-like values read from HDF5 into text where applicable.

    Args:
        bts: Value to normalise.

    Returns:
        * a ``str`` when ``bts`` is a ``bytes`` instance (decoded with the
          default codec);
        * a NumPy string array when ``bts`` is an ``ndarray`` of object dtype;
        * ``bts`` itself, unchanged, in every other case.
    """
    if isinstance(bts, bytes):
        return bts.decode()

    is_object_array = isinstance(bts, np.ndarray) and bts.dtype == np.dtype("O")
    return bts.astype(np.str_) if is_object_array else bts
[docs]
def load_value_from_group(group, key):
    """Read the full value stored at ``group[key]`` and normalise bytes to text.

    Args:
        group: Object indexable by ``key`` whose items support ``[()]``.
        key: Name of the dataset to read.

    Returns:
        The stored value after passing through ``handle_bytes_as_string``.
    """
    raw_value = group[key][()]
    return handle_bytes_as_string(raw_value)
[docs]
def load_signal_from_group_by_id(group, id_str):
    """Rebuild a ``Signal`` (and, recursively, its components) from ``group``.

    Args:
        group: Root group containing the ``"data"`` and ``"component_signals"``
            sub-groups.
        id_str: Id string under which the signal was stored.

    Returns:
        Signal: The signal built from ``group["data"][id_str]`` and its loaded
        component signals, after being passed through
        ``fill_object_metadata_from_group_and_id``.

    Raises:
        KeyError: If the data of this signal or of any component signal is
            missing. Only the absence of a component-signal entry for
            ``id_str`` is treated as "no components".
    """
    # A signal without components has no entry in "component_signals", so a
    # missing key is expected here. The try covers only this lookup: errors
    # raised while loading a child must surface rather than silently produce
    # a signal with no components.
    try:
        component_ids = load_value_from_group(group["component_signals"], id_str)
    except KeyError:
        component_ids = []

    component_signals = [
        load_signal_from_group_by_id(group, component_id)
        for component_id in component_ids
    ]

    signal = Signal(
        data=load_value_from_group(group["data"], id_str),
        component_signals=component_signals,
    )
    # NOTE(review): fill_object_metadata_from_group_and_id is not defined in
    # the visible part of this module - confirm where it comes from.
    signal = fill_object_metadata_from_group_and_id(signal, group, id_str)
    return signal
[docs]
def load_signal_from_group_by_index(group, ind):
    """Load the signal registered at position ``ind`` of the ``"index"`` group.

    Args:
        group: Root group containing the ``"index"`` sub-group.
        ind: Position of the sample; converted with ``str`` to form the key.

    Returns:
        The result of ``load_signal_from_group_by_id`` for the id stored at
        that position.
    """
    index_group = group["index"]
    signal_id = load_value_from_group(index_group, str(ind))
    return load_signal_from_group_by_id(group, signal_id)
[docs]
class HDF5Reader(FileReader):
    """Handles reading Signal data from HDF5 files.

    The file ``<root>/data.h5`` is opened read-only at construction time and
    stays open until :meth:`teardown` is called.
    """

    def __init__(self, root) -> None:
        """Initializes the HDF5Reader.

        Args:
            root (str): The root directory containing the HDF5 file.
        """
        super().__init__(root=root)
        # ``self.root`` is read only after the base initializer has run.
        self.datapath = self.root.joinpath("data.h5")
        self._file = h5py.File(self.datapath, "r")

    def read(self, idx: int) -> Signal:
        """Reads the sample stored at position ``idx`` of the file's index.

        Args:
            idx (int): The index of the sample to read.

        Returns:
            Signal: The sample as a Signal object.
        """
        sample = load_signal_from_group_by_index(self._file, idx)
        return sample

    def __len__(self) -> int:
        """Returns the number of entries in the file's ``"index"`` group.

        Returns:
            int: The number of samples in the dataset.
        """
        index_group = self._file["index"]
        return len(index_group)

    def teardown(self) -> None:
        """Closes the HDF5 file handle, if one is currently held."""
        if not self._file:
            return
        self._file.close()
        self._file = None
[docs]
class HDF5FileHandler(BaseFileHandler):
    """HDF5FileHandler creates a reader or writer for HDF5 files."""

    # Classes instantiated by ``create_handler`` for "r" and "w" respectively.
    reader_class: FileReader = HDF5Reader
    writer_class: FileWriter = HDF5Writer

    @staticmethod
    def create_handler(mode: str, root: str, **kwargs) -> HDF5Writer | HDF5Reader:
        """Creates an instance of HDF5Reader or HDF5Writer based on the mode.

        Args:
            mode (str): The mode, either "r" for read or "w" for write.
            root (str): The root directory for the file handler.
            **kwargs: Additional arguments for the file handler.

        Returns:
            HDF5Writer | HDF5Reader: The created file handler.

        Raises:
            ValueError: If the mode is invalid.
        """
        if mode == "r":
            factory = HDF5FileHandler.reader_class
        elif mode == "w":
            factory = HDF5FileHandler.writer_class
        else:
            raise ValueError(f"Invalid File Handler mode: {mode}")
        return factory(root, **kwargs)