Source code for torchsig.utils.file_handlers.npy

r"""torchsig.utils.file_handlers.npy
================================
File-handler that exposes a directory of standard NumPy ``*.npy`` files as a
TorchSig :class:`~torchsig.signals.signal_types.Signal` dataset.
A TorchSig dataset is described by three co-located artefacts:
* **One or more ``*.npy`` files** - each file stores a 1-D NumPy array of
  complex samples.
* **A ``metadata.csv`` file** - one row per *global* waveform index,
  containing ``index,label,modcod,sample_rate``.
* **An ``info.json`` file** - a tiny JSON document that must contain at least
  ``{"size": <int>}`` and defines the advertised length of the dataset.
The heavy binary payload lives in the ``*.npy`` files; the human-readable
description lives in the CSV.  This separation keeps loading fast (memory-mapped
NumPy) while allowing easy inspection and editing of labels, modulation codes,
etc.
"""

# ----------------------------------------------------------------------
# Standard / third-party imports
# ----------------------------------------------------------------------
import bisect
import csv
import itertools
import json
from pathlib import Path

import numpy as np

# ----------------------------------------------------------------------
# TorchSig imports
# ----------------------------------------------------------------------
from torchsig.signals.signal_types import Signal
from torchsig.utils.file_handlers import FileReader


class NPYReader(FileReader):
    """Read a directory that contains ``*.npy`` files, a ``metadata.csv`` and an ``info.json``.

    The class presents the whole collection as a flat, indexable dataset:
    :meth:`read` returns a :class:`~torchsig.signals.signal_types.Signal`
    whose ``data`` attribute holds the entry stored at the requested global
    index (wrapped with ``np.atleast_1d``) and whose ``metadata`` attribute
    holds the parsed CSV row for that index.

    Args:
        root: Path to the directory that holds the ``*.npy`` files,
            ``metadata.csv`` and ``info.json``.

    Attributes:
        root_dir: The ``root`` argument exactly as it was passed in.
        npy_files: Sorted list of discovered ``*.npy`` files.
        file_start_indices: Cumulative start index of each file in the
            global index space.
        total_elements: Sum of ``shape[0]`` over all ``*.npy`` files, i.e.
            the number of entries that can actually be read.
        class_list: Ordered list of class names used to compute
            ``class_index``.
        dataset_metadata: Parsed content of ``info.json``.
        dataset_size: Size advertised by ``info.json`` (returned by
            ``len(reader)``).

    Note:
        File paths for ``info.json`` and ``metadata.csv`` are built from
        ``self.root``, which is not assigned in this class; it is presumably
        set by ``FileReader.__init__`` - confirm against the base class.
    """

    # Column order of ``metadata.csv``. The names are supplied explicitly to
    # ``csv.DictReader``, so the first line of the file is treated as data
    # (the file is expected to have no header row).
    _METADATA_FIELDS = ("index", "label", "modcod", "sample_rate")

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def __init__(self, root: str):
        """Index the ``*.npy`` files under ``root`` and load ``info.json``.

        Args:
            root: Directory containing the ``*.npy`` files, ``metadata.csv``
                and ``info.json``.

        Raises:
            FileNotFoundError: If ``root`` contains no ``*.npy`` file.
            ValueError: If ``info.json`` cannot be loaded or does not provide
                a ``"size"`` entry.
        """
        super().__init__(root=root)

        # Keep the caller-supplied value untouched for backward compatibility.
        self.root_dir: str = root

        # Sorted so that the global index -> file mapping is deterministic.
        self.npy_files = sorted(Path(root).glob("*.npy"))
        if not self.npy_files:
            raise FileNotFoundError("No .npy files found in directory.")

        # ``file_start_indices[i]`` is the global index of the first entry
        # stored in ``npy_files[i]``.
        self.file_start_indices: list[int] = []
        total = 0
        for file_path in self.npy_files:
            # Memory-map only to inspect the shape; no payload is read here.
            length = np.load(file_path, mmap_mode="r").shape[0]
            self.file_start_indices.append(total)
            total += length
        self.total_elements: int = total

        # Known class names - position in this list is the ``class_index``.
        self.class_list: list[str] = ["BPSK", "QPSK", "Noise"]

        # ``info.json`` is parsed exactly once; the advertised size is taken
        # from the parsed document instead of re-reading the file.
        self.dataset_metadata: dict = self._load_json_metadata()
        try:
            self.dataset_size: int = self.dataset_metadata["size"]
        except (KeyError, TypeError) as exc:
            # KeyError: no "size" entry; TypeError: top-level JSON value is
            # not an object. Both are reported like any other load failure.
            raise ValueError(f"Error loading {self.root}/info.json") from exc

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _load_json_metadata(self) -> dict:
        """Load ``info.json`` and return the parsed document.

        Returns:
            The value produced by ``json.load`` for ``<root>/info.json``.

        Raises:
            ValueError: If the file cannot be opened or parsed. The original
                exception is chained as ``__cause__``.
        """
        try:
            with open(f"{self.root}/info.json") as f:
                return json.load(f)
        except Exception as exc:  # pragma: no cover
            # Deliberately broad: every failure mode is surfaced to callers
            # as a single, consistent ValueError.
            raise ValueError(f"Error loading {self.root}/info.json") from exc

    def _read_metadata_row(self, idx: int) -> dict:
        """Return the parsed ``metadata.csv`` row at zero-based position ``idx``.

        The file is re-opened and scanned from the top on every call, so
        edits made to the CSV between calls are picked up.

        Args:
            idx: Zero-based row number (equal to the global entry index).

        Returns:
            The row as a dict with ``index`` converted to ``int`` and
            ``sample_rate`` to ``float``, plus the derived keys
            ``class_name``, ``class_index`` and ``num_signals_max``.

        Raises:
            IndexError: If the CSV has no row at position ``idx``.
            ValueError: If ``index`` / ``sample_rate`` cannot be converted or
                the label is not in ``self.class_list``.
        """
        # ``newline=""`` is what the csv module requires so that quoted
        # fields containing line breaks are parsed correctly.
        with open(f"{self.root}/metadata.csv", newline="") as f:
            reader = csv.DictReader(f, fieldnames=list(self._METADATA_FIELDS))
            # Skip straight to the requested row without materialising the
            # preceding ones.
            row = next(itertools.islice(reader, idx, idx + 1), None)

        if row is None:
            raise IndexError(f"Metadata idx {idx} is out of bounds")

        row["index"] = int(row["index"])
        row["sample_rate"] = float(row["sample_rate"])
        row["class_name"] = row["label"].lower()
        # ``list.index`` raises ValueError for labels outside ``class_list``.
        row["class_index"] = self.class_list.index(row["label"])
        row["num_signals_max"] = 1
        return row

    # ------------------------------------------------------------------
    # Public API - retrieve a single sample
    # ------------------------------------------------------------------
    def read(self, idx: int) -> Signal:
        """Return the stored entry and its metadata for the global index ``idx``.

        Args:
            idx: Zero-based global index of the entry to retrieve.

        Returns:
            Signal: ``data`` is ``np.atleast_1d(arr[i])`` where ``arr`` is the
            memory-mapped file containing ``idx`` (shape ``(1,)`` when the
            file is 1-D); ``metadata`` is the parsed CSV row for ``idx``;
            ``component_signals`` is an empty list.

        Raises:
            IndexError: If ``idx`` is negative, greater than or equal to
                ``self.total_elements``, or beyond the last CSV row.
            ValueError: If the CSV row cannot be parsed or its label is not
                in ``self.class_list``.
        """
        if idx < 0 or idx >= self.total_elements:
            raise IndexError(
                f"Index {idx} out of range (0 <= idx < {self.total_elements})."
            )

        # Binary search: the owning file is the last one whose start index
        # is <= idx.
        file_idx = bisect.bisect_right(self.file_start_indices, idx) - 1
        in_file_idx = idx - self.file_start_indices[file_idx]

        # Memory-map the file so that only the requested entry is touched.
        arr = np.load(self.npy_files[file_idx], mmap_mode="r")
        # ``atleast_1d`` guarantees that ``len(data)`` works for scalars.
        data = np.atleast_1d(arr[in_file_idx])

        metadata = self._read_metadata_row(idx)

        return Signal(data=data, component_signals=[], metadata=metadata)

    # ------------------------------------------------------------------
    # Length protocol - reports the *advertised* dataset size from info.json
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the size declared in ``info.json``.

        This value is not cross-checked against ``self.total_elements``, so
        ``len(reader)`` may differ from the number of entries that can
        actually be read (useful for testing out-of-bounds handling).
        """
        return self.dataset_size