# Source code for sinabs.backend.dynapcnn.crop2d

from typing import List, Tuple, Union

import numpy as np
from torch import nn

# Type hint for array-like arguments, e.g. the ``cropping`` argument of Crop2d.
ArrayLike = Union[np.ndarray, List, Tuple]


class Crop2d(nn.Module):
    """Crop the last two (spatial) dimensions of the input to a fixed window.

    The window is given as absolute slice bounds rather than amounts to trim:
    rows ``top:bottom`` and columns ``left:right`` are kept, with ``bottom``
    and ``right`` exclusive, exactly like Python slicing.
    """

    def __init__(
        self,
        cropping: ArrayLike = ((0, 0), (0, 0)),
    ):
        """Crop input to the given rectangle.

        :param cropping: ((top, bottom), (left, right)) slice bounds along the
            height and width axes. Note that the default ``((0, 0), (0, 0))``
            selects an empty window.
        """
        super().__init__()
        self.top_crop, self.bottom_crop = cropping[0]
        self.left_crop, self.right_crop = cropping[1]

    def forward(self, binary_input):
        """Return ``binary_input[:, :, top:bottom, left:right]``.

        Side effects: stores ``self.out_shape`` (the cropped shape without the
        leading dimension) and ``self.spikes_number`` (sum of absolute values
        of the cropped data divided by the size of the leading dimension).

        :param binary_input: tensor with at least four dimensions; presumably
            (batch, channels, height, width) -- TODO confirm with callers.
        :return: the cropped tensor.
        """
        crop_out = binary_input[
            :,
            :,
            self.top_crop : self.bottom_crop,
            self.left_crop : self.right_crop,
        ]
        self.out_shape = crop_out.shape[1:]
        # Normalise by the leading dimension (presumably the batch size, giving
        # an average per sample). NOTE(review): confirm this is the intended
        # statistic; the scraped line was garbled into an invalid assignment.
        self.spikes_number = crop_out.abs().sum() / len(crop_out)
        return crop_out

    def get_output_shape(self, input_shape: Tuple) -> Tuple:
        """Return the shape ``forward`` produces for an input of ``input_shape``.

        Crop bounds are resolved against the actual input size with Python
        slice semantics. Bounds past the input edge are therefore clamped, and
        negative or ``None`` bounds are interpreted, exactly as tensor slicing
        in ``forward`` interprets them.

        :param input_shape: (channels, height, width)
        :return: (channels, height, width)
        """
        channels, height, width = input_shape
        # ``range(n)[a:b]`` applies the same clamping / negative-index rules as
        # slicing a tensor, keeping this method consistent with ``forward``.
        out_height = len(range(int(height))[self.top_crop : self.bottom_crop])
        out_width = len(range(int(width))[self.left_crop : self.right_crop])
        return (channels, out_height, out_width)

    def __repr__(self):
        """Show the crop window as ``Crop2d((top, bottom), (left, right))``."""
        return f"Crop2d(({self.top_crop}, {self.bottom_crop}), ({self.left_crop}, {self.right_crop}))"