Source code for torchsig.models.model_utils.general_layers

from torch import mean
from torch.nn import Module, LSTM

class DebugPrintLayer(Module):
    """Pass-through layer for debugging PyTorch models.

    At runtime it prints the shape and dtype of whatever tensor flows
    through it, then returns that same tensor object unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # print() joins the two values with a single space, e.g.
        # "torch.Size([2, 3]) torch.float32"
        shape, dtype = x.shape, x.dtype
        print(shape, dtype)
        return x
class ScalingLayer(Module):
    """Multiply every input by a constant factor fixed at construction.

    Given input ``x`` the layer returns ``scale_val * x``.
    """

    def __init__(self, scale_val):
        super().__init__()
        # stored as a plain attribute, not a Parameter, so it is never trained
        self.scale_val = scale_val

    def forward(self, x):
        # keep scale_val as the left operand, exactly as written originally
        scaled = self.scale_val * x
        return scaled
class DropChannel(Module):
    """Drop the last entry along dimension 1 (the channel axis).

    Intended for channel-first images. The input is indexed with four
    subscripts, so it must have at least four dimensions; dimension 1 is
    treated as the channel axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        everything = slice(None)
        all_but_last = slice(None, -1)
        return x[everything, all_but_last, everything, everything]
class LSTMImageReader(Module):
    """Turn a feature sequence into a single-channel image with a projected LSTM.

    ``forward`` expects a 3-D tensor laid out as
    ``(batch, input_width, seq_len)``. It is transposed to
    ``(batch, seq_len, input_width)`` for the batch-first LSTM.

    The LSTM projects each hidden state down to ``img_height`` features
    (``proj_size``), so every timestep yields one image column. The first
    ``img_width`` timesteps are kept and returned as
    ``(batch, 1, img_height, img_width)``.

    Args:
        input_width: number of input features per timestep (LSTM input size).
        lstm_width: LSTM hidden size. PyTorch requires ``img_height`` to be
            smaller than this, because it is used as ``proj_size``.
        img_shape: ``(img_height, img_width)`` of the output image.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_width, lstm_width, img_shape, num_layers=2):
        super().__init__()
        self.img_shape = img_shape
        self.img_height = img_shape[0]
        self.img_width = img_shape[1]
        self.input_width = input_width
        self.lstm_width = lstm_width
        self.lstm_model = LSTM(
            input_size=self.input_width,
            hidden_size=self.lstm_width,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=0.0,
            bidirectional=False,
            # one output feature per image row -> each timestep is an image column
            proj_size=self.img_height,
        )

    def forward(self, x):
        # the final (h, c) states are not needed
        output, _ = self.lstm_model(x.transpose(1, 2))
        # (batch, seq_len, img_height) -> (batch, img_height, seq_len)
        columns = output.transpose(1, 2)
        # Keep the FIRST img_width timesteps. The img_height slice is a no-op
        # because proj_size == img_height. Requires seq_len >= img_width, or
        # the reshape below fails.
        # NOTE(review): if the last timesteps were intended, this would be
        # [..., -self.img_width:] -- confirm against the model's training.
        img_tensor = columns[:, : self.img_height, : self.img_width]
        return img_tensor.reshape([x.size(0), 1, self.img_height, self.img_width])
class Reshape(Module):
    """Reshape the input tensor to a fixed target shape.

    If ``keep_batch_dim`` is True (the default), the size of dimension 0 of
    the input is kept and only the remaining dimensions are reshaped to
    ``shape``. Otherwise the whole tensor, dimension 0 included, is reshaped
    to ``shape``.

    Args:
        shape: target shape (a sequence of ints; ``-1`` may be used once).
        keep_batch_dim: whether to exclude dimension 0 from the reshape.
    """

    def __init__(self, shape, keep_batch_dim=True):
        super().__init__()
        self.shape = shape
        self.keep_batch_dim = keep_batch_dim

    def forward(self, x):
        # reshape() returns a view whenever view() would succeed. It falls back
        # to a copy for non-contiguous inputs (e.g. after transpose/permute),
        # where view() raises.
        if self.keep_batch_dim:
            return x.reshape([x.size(0), *self.shape])
        return x.reshape(self.shape)
class Mean(Module):
    """Average the input tensor over the dimension(s) given by ``dim``.

    The reduced dimension(s) are removed from the result (``torch.mean``'s
    default ``keepdim=False``).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return mean(x, dim=self.dim)