"""
zarr==2.18.3
"""
from __future__ import annotations
# TorchSig
from torchsig.utils.file_handlers.base_handler import TorchSigFileHandler
from torchsig.datasets.dataset_metadata import DatasetMetadata
from torchsig.datasets.dataset_utils import writer_yaml_name
# Third Party
import zarr
import numpy as np
# Built-In
from typing import TYPE_CHECKING, Tuple, List, Dict, Any
import os
import pickle
import yaml
class ZarrFileHandler(TorchSigFileHandler):
    """Handler for reading and writing dataset batches in the Zarr file format.

    This class extends `TorchSigFileHandler`. Each batch is stored as its own Zarr
    array at ``<root>/data/<batch_idx zero-padded to 10 digits>.zarr``; the targets
    of every sample are kept in that array's attributes, keyed by the sample's
    global index (as a string).

    Attributes:
        datapath_filename_base (str): The name of the folder (inside the dataset
            root) used to store the data in Zarr format.
    """

    datapath_filename_base = "data"

    def __init__(
        self,
        root: str,
        batch_size: int = 1,
    ):
        """Initializes the ZarrFileHandler.

        Args:
            root (str): Where to write dataset on disk.
            batch_size (int, optional): Size of each batch write. Defaults to 1.
        """
        super().__init__(
            root = root,
            batch_size = batch_size
        )
        # folder that holds one Zarr array per batch
        self.datapath = f"{self.root}/{ZarrFileHandler.datapath_filename_base}"

        # compressor applied to every array written by this handler
        self.compressor = zarr.Blosc(
            cname = 'zstd', # compression algorithm
            clevel = 4, # compression level
            shuffle = 2 # 2 selects bit shuffle
        )

    def exists(self) -> bool:
        """Checks if the Zarr data folder exists at the handler's data path.

        Returns:
            bool: True if the data path exists, otherwise False.
        """
        return os.path.exists(self.datapath)

    def write(self, batch_idx: int, batch: Any) -> None:
        """Writes one batch (data and targets) to its own Zarr array on disk.

        Args:
            batch_idx (int): Index of the batch. It determines the file name and,
                multiplied by the handler's batch size, the global index of the
                first sample in the batch.
            batch (Any): Pair ``(data, targets)``. ``data`` must be a non-empty
                sequence of equally shaped numpy arrays; ``targets`` holds one
                entry per sample.

        Notes:
            The array is opened with mode ``'w'``, so an existing file for the same
            ``batch_idx`` is overwritten. Targets are written to the array's
            attributes, which Zarr serializes as JSON.
        """
        data, targets = batch
        # global index of the first sample in this batch
        start_idx = batch_idx * self.batch_size

        # write batch of data into file
        zarr_array = zarr.open(
            # filenames will have 10 digits
            # might need to change if you have more than 10 billion batches
            f"{self.datapath}/{batch_idx:010}.zarr",
            mode = 'w', # create or overwrite if exists
            # array will be shape (num samples,) + shape of a single sample
            shape = (len(data),) + data[0].shape,
            # data type
            dtype = data[0].dtype,
            # compression
            compressor = self.compressor
        )
        zarr_array[:] = np.array(data)

        # add targets to zarr array attributes, keyed by global sample index
        attrs_dict = {str(start_idx + tidx): target for tidx, target in enumerate(targets)}
        zarr_array.attrs.update(attrs_dict)

    @staticmethod
    def size(dataset_path: str) -> int:
        """Return size of dataset.

        Every batch file except the last is assumed to hold exactly ``batch_size``
        samples; the last file is opened to read its actual length.

        Args:
            dataset_path (str): path to dataset on disk

        Returns:
            int: size of dataset (0 if the data folder contains no batch files)
        """
        # find batch size
        batch_size = TorchSigFileHandler._calculate_batch_size(dataset_path)

        # list batch files; zero-padded names make lexicographic order == batch order
        data_dir = f"{dataset_path}/{ZarrFileHandler.datapath_filename_base}"
        all_zarr_arrays = sorted(os.listdir(data_dir))
        if not all_zarr_arrays:
            # nothing written yet; avoid indexing into an empty listing
            return 0

        # check last file, since it might have less than batch_size data points
        last_array = zarr.open(f"{data_dir}/{all_zarr_arrays[-1]}", mode = 'r')

        # full batches + size of last batch file
        return batch_size * (len(all_zarr_arrays) - 1) + last_array.shape[0]

    @staticmethod
    def static_load(filename:str, idx: int) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
        """Loads a sample from the Zarr file at the specified index (without instantiating a ZarrFileHandler)

        Args:
            filename (str): Path to the dataset root directory containing the Zarr data folder.
            idx (int): The global index of the sample to load.

        Returns:
            The data of the sample and its targets. Targets stored as a flat list
            come back as a tuple; targets stored as a list of lists come back as a
            list of tuples (nested lists are converted to tuples as well). A single
            non-list target, or an empty target list, is returned unchanged.

        Raises:
            IndexError: If the index is out of bounds of the batch array.
                NOTE(review): a missing batch file or attribute key surfaces as
                whatever error zarr raises; confirm against callers.
        """
        # calculate batch size
        batch_size = TorchSigFileHandler._calculate_batch_size(filename)
        batch_idx = idx // batch_size
        batch_file_idx = idx % batch_size

        # find correct file
        batch_filename = f"{batch_idx:010}.zarr"

        # load in
        # root/data/batch filename.zarr
        zarr_arr = zarr.open(f"{filename}/{ZarrFileHandler.datapath_filename_base}/{batch_filename}", mode = 'r')
        data = zarr_arr[batch_file_idx]
        targets = zarr_arr.attrs[str(idx)]

        # the emptiness check keeps targets[0] from raising on a sample without targets
        if isinstance(targets, (tuple, list)) and targets:
            # target has multiple outputs
            if isinstance(targets[0], list):
                # convert wideband targets (2D list) to a list of tuples
                # also convert any nested lists into tuples
                targets = [
                    tuple(tuple(item) if isinstance(item, list) else item for item in target)
                    for target in targets
                ]
            else:
                # convert narrowband targets (1D list) to a tuple
                targets = tuple(targets)
        # else: narrowband target (single item) or empty list, return itself

        return data, targets