Source code for torchsig.image_datasets.datasets.composites

import numpy as np
import torch
from torch.utils.data import Dataset

from torchsig.image_datasets.transforms.impairments import normalize_image

"""
Defines a concatinated Dataset of all the given datasets joined as one dataset object.
Inputs:
    component_datasets: a list of Dataset objects which contain instances of each component, represented as (image_component: ndarray(c,height,width), class_id: int)
    balance: whether or not to balance the dataset by selecting samples from component datasets at random
"""
[docs] class ConcatDataset(Dataset):
[docs] def __init__(self, component_datasets, balance = True, transforms = []): self.component_datasets = component_datasets self.balance = balance self.transforms = transforms
def __len__(self): if self.balance: return max([len(ds) for ds in self.component_datasets])*len(self.component_datasets) #assume the length wed get if we upsampled other datasets to max the largest component dataset else: return np.sum([len(ds) for ds in self.component_datasets]) def __getitem__(self, idx): x = None if self.balance: # ignores given index; will always be random with replacement ds = self.component_datasets[np.random.randint(0, len(self.component_datasets))] x = ds[np.random.randint(0, len(ds))] else: ds_ind = 0 ds = self.component_datasets[ds_ind] ds_idx = idx while(ds_idx >= len(ds)): ds_idx -= len(ds) ds_ind += 1 ds = self.component_datasets[ds_ind] x = ds[ds_idx] if self.transforms: if type(self.transforms) == list: for transform in self.transforms: x = transform(x) else: x = self.transforms(x) return x
[docs] def next(self): return self[np.random.randint(0,len(self))]