Source code for torchsig.image_datasets.datasets.synthetic_signals

import numpy as np
import torch
from torch.utils.data import Dataset
import cv2

from torchsig.image_datasets.transforms.impairments import BlurTransform, normalize_image, pad_border

"""
A Dataset class for generating 2d greyscale image data
Inputs:
    generator_function: a function taking no arguments which returns a generated image
    class_id: the integer class id to associate with this image type; must be specified; generally should not be the same as other instances
    transforms: either a single function or list of functions from images to images to be applied to each generated image; used for adding noise and impairments to data; defaults to None
"""
[docs] class GeneratorFunctionDataset(Dataset):
[docs] def __init__(self, generator_function, transforms = None): self.generator_function = generator_function self.transforms = transforms
def __len__(self): return 1 # this is somewhat arbitrary; it will generate as many instances as are asked for def __getitem__(self, idx): #image = normalize_image(self.generator_function()) image = self.generator_function() if self.transforms: if type(self.transforms) == list: for transform in self.transforms: image = transform(image) else: image = self.transforms(image) return image
[docs] def next(self): return self[0]
""" Takes no arguments and returns a 2d ndarray representing the spectrogram of a randomly generated tone Inputs: tone_width: the width of the tone to be generated; corresponds to the length in time of a simulated signal Outputs: 2d ndarray representing the spectrogram of a randomly generated tone """
[docs] def generate_tone(tone_width: int, max_height: int = 10, min_height: int = 3): height = np.random.randint(min_height, high=max_height+1) width = tone_width first_axis = torch.arange(height)*(3.141592653589792*2/height) second_axis = torch.ones(width) image = -torch.cos(torch.matmul(first_axis[:,None], second_axis[None,:])).unsqueeze(0) return image
""" curried implementation of 'generate_tone' """
[docs] def tone_generator_function(tone_width: int, max_height: int = 10, min_height: int = 3): return lambda: generate_tone(tone_width, max_height=max_height, min_height=min_height)
""" Takes no arguments and returns a 2d ndarray representing the spectrogram of a randomly generated tone Inputs: chirp_width: the width of the chirp to be generated; corresponds to the thickness of the chirp in the spectrogram height: the height of the chirp to be generat Outputs: 2d ndarray representing the spectrogram of a randomly generated chirp """
[docs] def generate_chirp(chirp_width: int, height: int, width: int, random_height_scale: float = [1,1], random_width_scale: float = [1,1]): thickness = chirp_width x_size = int(width * (np.random.rand()*(random_width_scale[1] - random_width_scale[0]) + random_width_scale[0])) y_size = int(height * (np.random.rand()*(random_height_scale[1] - random_height_scale[0]) + random_height_scale[0])) img = np.zeros([y_size,x_size,3]) cv2.line(img, (x_size,0), (0,y_size), (255,255,255), thickness) img = (img[:,:,0]/255)[None,:,:] return torch.Tensor(img)
""" curried implementation of 'generate_chirp' """
[docs] def chirp_generator_function(chirp_width: int, height: int, width: int, random_height_scale: float = [1,1], random_width_scale: float = [1,1]): return lambda: generate_chirp(chirp_width, height, width, random_height_scale=random_height_scale, random_width_scale=random_width_scale)
""" Takes no arguments and returns a 2d ndarray representing the spectrogram of a randomly generated 'signal' rectangle Inputs: tone_width: the width of the tone to be generated; corresponds to the length in time of a simulated signal Outputs: 2d ndarray representing the spectrogram of a randomly generated signal """
[docs] def generate_rectangle_signal(min_width: int = 10, max_width: int = 100, max_height: int = 50, min_height: int = 5, use_blur=True): height = np.random.randint(min_height, high=max_height+1) width = np.random.randint(min_width, high=max_width+1) image = torch.ones([1,height,width])#.unsqueeze(0) #if use_blur: # blur_transform = BlurTransform(blur_shape=max([x//4 for x in image.shape]), strength=1) # image = blur_transform(image) return pad_border(image,1)
""" curried implementation of 'generate_rectangle_signal' """
[docs] def rectangle_signal_generator_function(min_width: int = 10, max_width: int = 100, max_height: int = 50, min_height: int = 5, use_blur=True): return lambda: generate_rectangle_signal(min_width= min_width, max_width=max_width, max_height=max_height, min_height=min_height, use_blur=use_blur)
""" Takes in a function which returns an image representing a signal, and returns that image repeated with a fixed offset Inputs: generator_fn: the function called to produce the signal to repeat min_gap: the smallest allowable interval (in pixels) between signal repetitions max_gap: the largest allowable interval (in pixels) between signal repetitions repeat_axis: the axis over which the signal repeats min_repeats: the fewest repeats allowed max_repeats: the most repeasts allowed Outputs: 2d ndarray representing the spectrogram of a randomly generated signal """
[docs] def generate_repeated_signal(generator_fn, min_gap: int = 2, max_gap: int = 10, repeat_axis=-1, min_repeats=8, max_repeats=16): signal = generator_fn() gap = np.random.randint(min_gap, high=max_gap+1) n_repeats = np.random.randint(min_repeats, high=max_repeats+1) signal_length = list(signal.shape)[repeat_axis] total_length = (signal_length + gap)*n_repeats - gap pad_shape = [0,0]*len(signal.shape) pad_shape[repeat_axis*2 + 1] = gap padded_signal = torch.nn.functional.pad(signal,pad_shape[::-1]) tile_shape = [1]*len(signal.shape) tile_shape[repeat_axis] = n_repeats - 1 repeated_signal = torch.concat([signal, padded_signal.tile(tile_shape)], dim=repeat_axis) return torch.Tensor(repeated_signal)
""" curried implementation of 'generate_repeated_signal' """
[docs] def repeated_signal_generator_function(generator_fn, min_gap: int = 2, max_gap: int = 10, repeat_axis=-1, min_repeats=30, max_repeats=50): return lambda: generate_repeated_signal(generator_fn, min_gap = min_gap, max_gap = max_gap, repeat_axis=repeat_axis, min_repeats=min_repeats, max_repeats=max_repeats)