"""Dataset Writer Utils"""
from __future__ import annotations

# Built-In
import concurrent.futures
import gc
import itertools
import os
import threading
import warnings
from pathlib import Path
from shutil import disk_usage
from time import time
from typing import Any, TYPE_CHECKING

# Third Party
import numpy as np
from torch.utils.data._utils.collate import default_collate as torch_default_collate
from tqdm.auto import tqdm

# TorchSig
from torchsig.utils.file_handlers.hdf5 import HDF5Writer
from torchsig.utils.yaml import write_dict_to_yaml

if TYPE_CHECKING:
    from torch.utils.data import DataLoader
    from torchsig.utils.file_handlers.base_handler import FileWriter
def default_collate_fn(batch):
    """Collate a batch by transposing it from per-sample to per-field.

    Args:
        batch (Iterable): Samples from the dataloader, each an iterable of
            fields, e.g. ``[(data0, meta0), (data1, meta1)]``.

    Returns:
        tuple: One tuple per field, each holding that field from every sample,
        e.g. ``((data0, data1), (meta0, meta1))``. Like ``zip``, the output is
        truncated to the shortest sample; an empty batch gives ``()``.
    """
    return tuple(zip(*batch))
class DatasetCreator:
    """Generate a dataset from a DataLoader and write it to disk in batches.

    If a dataset already exists at ``root`` and ``overwrite`` is False, the
    on-disk dataset is reused instead of being regenerated. Otherwise batches
    are drawn from ``dataloader`` and passed to the writer created from
    ``file_handler``, either sequentially or on a thread pool. A
    ``writer_info.yaml`` file in ``root`` records the writing configuration
    and whether writing finished.

    Attributes:
        dataloader (DataLoader): The DataLoader used to load data in batches.
        dataset_length (int): Number of samples to write.
        root (Path): The root directory where the dataset will be saved.
        overwrite (bool): Flag indicating whether to overwrite an existing dataset.
        batch_size (int): Batch size taken from ``dataloader``.
        num_workers (int): Worker count taken from ``dataloader``.
        num_batches (int): Number of batches to write, counting a trailing partial batch.
        multithreading (bool): Write batches on a thread pool when True.
        tqdm_desc (str): Prefix for the progress bar description.
        file_handler (FileWriter): The file handler class used for saving the dataset.
        pbar (tqdm | None): Progress bar of the most recent ``create`` call;
            None until ``create`` starts writing.
        minimum_remaining_disk_gigabytes (int | float): Free-space floor (GB)
            below which writing is aborted.
    """

    def __init__(
        self,
        dataloader: DataLoader | None = None,
        dataset_length: int | None = None,
        root: str = ".",
        overwrite: bool = True,  # will overwrite any existing dataset on disk
        tqdm_desc: str | None = None,
        file_handler: FileWriter = HDF5Writer,
        multithreading: bool = True,
        **kwargs,  # any additional file handler args
    ):
        """Initializes the DatasetCreator.

        Args:
            dataloader (DataLoader): The DataLoader used to load data in batches.
                Required; it must have a ``batch_size`` (automatic batching on).
            dataset_length (int): The number of samples to draw from the dataloader.
                Required.
            root (str | Path): The root directory where the dataset will be saved.
            overwrite (bool): Flag indicating whether to overwrite an existing dataset.
            tqdm_desc (str, optional): Prefix for the progress bar description.
                Defaults to ``"Generating Dataset:"``.
            file_handler (FileWriter): The file handler class used for saving the dataset.
            multithreading (bool): Write batches on a thread pool when True.
            **kwargs: Accepted for compatibility. NOTE(review): these are not
                currently forwarded to ``file_handler`` — confirm intent.

        Raises:
            ValueError: If ``dataloader`` or ``dataset_length`` is missing, or the
                dataloader has no ``batch_size``.
        """
        if dataloader is None:
            raise ValueError("DatasetCreator requires a dataloader.")
        if dataset_length is None:
            raise ValueError("DatasetCreator requires a dataset_length.")
        if dataloader.batch_size is None:
            raise ValueError(
                "DatasetCreator requires a dataloader with a batch_size "
                "(automatic batching must be enabled)."
            )
        self.root = Path(root)
        self.dataset_info_filepath = self.root.joinpath("dataset_info.yaml")
        self.writer_info_filepath = self.root.joinpath("writer_info.yaml")
        self.dataset_length = dataset_length
        self.overwrite = overwrite
        self.batch_size = dataloader.batch_size
        self.num_workers = dataloader.num_workers
        self.multithreading = multithreading
        # ceil(dataset_length / batch_size): a trailing partial batch still counts
        full_batches, remainder = divmod(self.dataset_length, self.batch_size)
        self.num_batches = full_batches + (1 if remainder else 0)
        self.dataloader = dataloader
        if (
            self.dataloader.dataset.target_labels is None
            and self.dataloader.collate_fn is torch_default_collate
        ):
            # DataLoader should just return Signal objects: keep each batch as
            # the plain list of samples instead of running torch's default collate.
            # NOTE(review): ``create`` forces target_labels to None, but this swap
            # only happens when labels were already None here — confirm the
            # default collate copes with Signal objects otherwise.
            self.dataloader.collate_fn = self._identity_collate
        self.file_handler = file_handler
        # Created per ``create`` call so constructing the object does not print
        # an empty, never-closed progress bar.
        self.pbar = None
        self.tqdm_desc = "Generating Dataset:" if tqdm_desc is None else tqdm_desc
        # limit in gigabytes for remaining space on disk for which writer stops writing
        self.minimum_remaining_disk_gigabytes = 1
        # Thread lock for updating tqdm message to avoid race conditions
        self._tqdm_lock = threading.Lock()
        # time of the last progress-message refresh; set when writing starts
        self._msg_timer = None

    @staticmethod
    def _identity_collate(batch):
        """Return the batch (a list of samples) unchanged.

        A named function rather than a lambda: it pickles by qualified name, so
        it survives DataLoader worker processes started with ``spawn``.
        """
        return batch

    def get_writing_info_dict(self) -> dict[str, Any]:
        """Returns a dictionary with information about the dataset being written.

        Returns:
            dict[str, Any]: The root, overwrite flag, batch size, number of
            dataloader workers, and a ``complete`` flag that is always False
            here (``create`` sets it to True once writing finishes).
        """
        return {
            "root": str(self.root),
            "overwrite": self.overwrite,
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "complete": False,
        }

    def _write_batch(self, writer, batch_idx: int, batch: Any) -> None:
        """Write one batch through ``writer`` (used by both writing modes).

        Args:
            writer: The open writer produced by ``file_handler``.
            batch_idx (int): batch index
            batch (Any): batch
        """
        writer.write(batch_idx, batch)

    def create(self) -> None:
        """Creates the dataset on disk by writing batches to the file handler.

        If the dataset already exists and ``overwrite`` is False, the dataset on
        disk is reused and nothing is written. Otherwise ``writer_info.yaml`` is
        written with ``complete: False``, every batch is written, and the YAML is
        rewritten with ``complete: True``.

        The dataset's ``target_labels`` is set to None while writing and is
        restored afterwards, including on early return or error.

        Raises:
            RuntimeError: If a partially written dataset exists on disk and
                ``overwrite`` is False.
            ValueError: If free disk space drops below
                ``minimum_remaining_disk_gigabytes`` or is projected to run out.
        """
        saved_target_labels = self.dataloader.dataset.target_labels
        self.dataloader.dataset.target_labels = None
        try:
            with self.file_handler(root=self.root) as writer:
                if writer.exists() and not self.overwrite:
                    self._use_existing_dataset()
                    return
                # generate info yamls
                write_dict_to_yaml(self.writer_info_filepath, self.get_writing_info_dict())
                # store start time
                self._msg_timer = time()
                self.pbar = tqdm(total=self.num_batches, desc=self.tqdm_desc)
                try:
                    if self.multithreading:
                        batches_written = self._write_multithreaded(writer)
                    else:
                        batches_written = self._write_single_threaded(writer)
                finally:
                    self.pbar.close()
                if batches_written < self.num_batches:
                    warnings.warn(
                        f"Dataloader was exhausted after {batches_written} of "
                        f"{self.num_batches} batches; the dataset in {self.root} "
                        f"has fewer than dataset_length={self.dataset_length} samples.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                # indicate writing dataset to disk was successful
                updated_writer_yaml = self.get_writing_info_dict()
                updated_writer_yaml["complete"] = True
                write_dict_to_yaml(self.writer_info_filepath, updated_writer_yaml)
        finally:
            self.dataloader.dataset.target_labels = saved_target_labels

    def _use_existing_dataset(self) -> None:
        """Report why an existing on-disk dataset is reused instead of rewritten.

        Raises:
            RuntimeError: If the dataset on disk was only partially written.
        """
        # NOTE(review): ``check_yamls`` is not defined in this class as shown;
        # confirm it exists, otherwise this line raises AttributeError.
        complete, different_params = self.check_yamls()
        if len(different_params) == 0 and complete:
            print(f"Dataset already exists in {self.root}. Not regenerating.")
            return
        if not complete:
            # dataset on disk is corrupted: it was not fully written to disk
            raise RuntimeError(
                f"Dataset only partially exists in {self.root} (writing dataset to disk was cancelled early). Regenerate the dataset by setting overwrite = True for DatasetCreator"
            )
        # dataset exists on disk with different params: use the dataset on disk
        # instead, but warn users that the params are different
        print(
            f"Dataset exists at {self.root} but is different than current dataset."
        )
        print("Differences:")
        for key, disk_value, current_value in different_params:
            print(f"\t{key} = {current_value} ({disk_value} found)")
        print(
            "If you want to overwrite dataset on disk, set overwrite = True for the DatasetCreator."
        )
        print("Not regenerating. Using dataset on disk.")

    def _trim_final_batch(self, batch_idx: int, batch: Any) -> Any:
        """Cut the last batch down so exactly ``dataset_length`` samples are written.

        NOTE(review): slicing assumes the batch is a sequence of samples (as
        with the identity collate); a field-wise collated batch would be
        sliced by field instead.
        """
        remainder = self.dataset_length % self.batch_size
        if remainder and batch_idx == self.num_batches - 1:
            return batch[:remainder]
        return batch

    def _record_progress(self, num_new_batches: int, batch_idx: int) -> None:
        """Advance the progress bar and run the periodic size/disk-space check."""
        if self.pbar is not None:
            self.pbar.update(num_new_batches)
        self._update_tqdm_message(batch_idx)

    def _write_single_threaded(self, writer) -> int:
        """Write batches one at a time on the calling thread.

        Returns:
            int: Number of batches written; fewer than ``num_batches`` only if
            the dataloader ran out early.
        """
        batches_written = 0
        for batch_idx, batch in enumerate(
            itertools.islice(self.dataloader, self.num_batches)
        ):
            self._write_batch(writer, batch_idx, self._trim_final_batch(batch_idx, batch))
            # drop our reference so the periodic collection below can free it
            del batch
            batches_written += 1
            self._record_progress(1, batch_idx)
            # Force garbage collection every 10 batches
            if batch_idx % 10 == 0:
                gc.collect()
        return batches_written

    def _write_multithreaded(self, writer) -> int:
        """Write batches on a thread pool, one chunk of batches at a time.

        Each chunk's writes are awaited before the next chunk is read, which
        bounds how many batches are queued for writing at once.

        Returns:
            int: Number of batches written; fewer than ``num_batches`` only if
            the dataloader ran out early.
        """
        chunk_size = max(1, min(100, self.num_batches) // 10)
        batches = enumerate(itertools.islice(self.dataloader, self.num_batches))
        batches_written = 0
        # num_threads defaults to: min(32, os.cpu_count() + 4)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            while True:
                # submit each batch as soon as it is loaded so writing overlaps loading
                futures = [
                    executor.submit(
                        self._write_batch,
                        writer,
                        batch_idx,
                        self._trim_final_batch(batch_idx, batch),
                    )
                    for batch_idx, batch in itertools.islice(batches, chunk_size)
                ]
                if not futures:
                    # No more batches to process
                    break
                num_in_chunk = len(futures)
                for future in futures:
                    future.result()  # re-raises any exception from a worker thread
                del futures
                batches_written += num_in_chunk
                # indices are consecutive from 0, so the chunk's last index is written - 1
                self._record_progress(num_in_chunk, batches_written - 1)
                # Force garbage collection between chunks
                gc.collect()
        return batches_written

    def _update_tqdm_message(self, batch_idx: int) -> None:
        """Updates the tqdm progress bar with size estimates and remaining disk space.

        Runs on the first batch and then at most about once per second, because
        each run walks the whole output directory. Includes a check to stop
        writing to disk in case the disk is at risk of being completely filled.

        Args:
            batch_idx (int): Index of the most recently written batch.

        Raises:
            ValueError: If the free disk space is below
                ``minimum_remaining_disk_gigabytes``, or smaller than the
                projected size of the rest of the dataset.
        """
        with self._tqdm_lock:
            now = time()
            if self._msg_timer is None:
                self._msg_timer = now
            if batch_idx and now - self._msg_timer <= 1:
                return
            # restart the one-second window; without this every later batch
            # would re-walk the output directory
            self._msg_timer = now

            # get the amount of disk space remaining, in GB rounded to two places
            disk_size_available_gigabytes = np.round(
                disk_usage(self.root).free / (1000**3), 2
            )
            # get size of dataset written so far
            dataset_size_current_gigabytes = self._get_directory_size_gigabytes(self.root)
            # a trailing partial batch must not push the count past dataset_length
            num_samples_written = min(
                (batch_idx + 1) * self.batch_size, self.dataset_length
            )
            num_samples_remaining = self.dataset_length - num_samples_written
            # estimate size per sample, then extrapolate
            dataset_size_per_sample_gigabytes = (
                dataset_size_current_gigabytes / num_samples_written
            )
            dataset_size_remaining_gigabytes = np.round(
                dataset_size_per_sample_gigabytes * num_samples_remaining, 2
            )
            dataset_size_total_gigabytes = np.round(
                dataset_size_per_sample_gigabytes * self.dataset_length, 2
            )
            # concatenate disk size for progress bar message
            updated_tqdm_desc = f"{self.tqdm_desc} estimated dataset size = {dataset_size_total_gigabytes} GB, dataset remaining = {dataset_size_remaining_gigabytes} GB, remaining disk = {disk_size_available_gigabytes} GB"
            if disk_size_available_gigabytes < self.minimum_remaining_disk_gigabytes:
                # remaining disk size is below a hard cutoff value to avoid crashing operating system
                raise ValueError(
                    f"Disk nearly full! Remaining space is {disk_size_available_gigabytes} GB. Please make space before continuing."
                )
            if dataset_size_remaining_gigabytes > disk_size_available_gigabytes:
                # projected size of dataset too large for available disk space
                raise ValueError(
                    f"Not enough disk space. Projected dataset size is {dataset_size_remaining_gigabytes} GB. Remaining space is {disk_size_available_gigabytes} GB. Please reduce dataset size or make space before continuing."
                )
            # set the progress bar message
            if self.pbar is not None:
                self.pbar.set_description(updated_tqdm_desc)

    def _get_directory_size_gigabytes(self, start_path: str | Path) -> float:
        """Calculate the total size of a directory (including subdirectories) in gigabytes.

        This function recursively walks through all files in the specified directory
        and its subdirectories, summing their sizes. Files that cannot be accessed
        (due to permissions, deletion, etc.) are skipped with a warning.

        Args:
            start_path: Path to the directory to calculate size for. Can be either
                a string or Path object.

        Returns:
            Total size of the directory in gigabytes (1 GB = 1000**3 bytes) as a float.

        Raises:
            NotADirectoryError: If the provided path is not a directory.
            FileNotFoundError: If the provided path doesn't exist.
        """
        start_path = Path(start_path)
        if not start_path.exists():
            raise FileNotFoundError(f"Path does not exist: {start_path}")
        if not start_path.is_dir():
            raise NotADirectoryError(f"Path is not a directory: {start_path}")
        total_size = 0
        for path, _, files in os.walk(start_path):
            for f in files:
                fp = Path(path) / f
                try:
                    total_size += fp.stat().st_size
                except OSError as e:  # includes FileNotFoundError
                    # Skip files that can't be accessed
                    warnings.warn(
                        f"Skipping file {fp} due to error: {e}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
        return total_size / (1000**3)