# Source code for sinabs.layers.crop2d

from typing import List, Tuple, Union

import numpy as np
from torch import nn

ArrayLike = Union[np.ndarray, List, Tuple]


class Cropping2dLayer(nn.Module):
    """Crop the last two (spatial) axes of a 4D input tensor.

    ``forward`` unpacks the input shape as ``(N, channels, height, width)``
    and removes rows from the top/bottom of the height axis and columns from
    the left/right of the width axis.

    Parameters:
        cropping: ((top, bottom), (left, right)) -- number of rows/columns to
            remove from each side. Crop amounts are assumed to be
            non-negative. Defaults to no cropping.
    """

    def __init__(
        self,
        cropping: ArrayLike = ((0, 0), (0, 0)),
    ):
        super().__init__()
        self.top_crop, self.bottom_crop = cropping[0]
        self.left_crop, self.right_crop = cropping[1]

    def forward(self, binary_input):
        """Crop ``binary_input`` and record a few attributes about the result.

        Side effects: sets ``self.channels_in`` (size of axis 1 of the
        input), ``self.out_shape`` (output shape without axis 0),
        ``self.spikes_number`` (sum of absolute values of the output) and
        ``self.tw`` (size of axis 0 of the output).

        Parameters:
            binary_input: 4D tensor of shape (N, channels, height, width).

        Returns:
            The cropped tensor of shape
            (N, channels, height - top - bottom, width - left - right).
            A spatial axis is empty if its crop exceeds the input size.
        """
        _, self.channels_in, h, w = list(binary_input.shape)
        # Clamp the stop indices at 0. If a bottom/right crop exceeds the
        # input size, ``h - bottom_crop`` would be negative, and Python
        # slicing would then count from the end rather than give an empty
        # axis.
        h_stop = max(h - self.bottom_crop, 0)
        w_stop = max(w - self.right_crop, 0)
        crop_out = binary_input[
            :,
            :,
            self.top_crop : h_stop,
            self.left_crop : w_stop,
        ]
        self.out_shape = crop_out.shape[1:]
        self.spikes_number = crop_out.abs().sum()
        self.tw = len(crop_out)
        return crop_out

    def get_output_shape(self, input_shape: Tuple) -> Tuple:
        """Return the output dimensions for a given input shape.

        Parameters:
            input_shape: (channels, height, width)

        Returns:
            (channels, height, width) after cropping. The spatial sizes are
            clamped at 0, which matches what ``forward`` produces when a crop
            exceeds the input size.
        """
        channels, height, width = input_shape
        return (
            channels,
            max(height - self.top_crop - self.bottom_crop, 0),
            max(width - self.left_crop - self.right_crop, 0),
        )