from typing import Dict, List, Optional, Tuple

import torch
import random

class StatefulLayer(torch.nn.Module):
    """A base class that instantiates buffers/states which update at every time step and
    provides helper methods that manage those states.

    Every state is registered as a PyTorch buffer holding an empty tensor of shape
    ``(0,)``; that shape is the marker for "not yet initialised" used by
    :meth:`is_state_initialised`.

    Parameters:
        state_names: the PyTorch buffers to initialise. These are not parameters.
    """

    def __init__(self, state_names: List[str]):
        super().__init__()
        for state_name in state_names:
            # Shape (0,) acts as the "uninitialised" sentinel.
            self.register_buffer(state_name, torch.zeros((0)))

    def zero_grad(self, set_to_none: bool = False) -> None:
        r"""
        Zeroes the gradients for buffers/state along with the parameters.
        See :meth:`torch.nn.Module.zero_grad` for details

        Parameters:
            set_to_none: forwarded unchanged to :meth:`torch.nn.Module.zero_grad`.
        """
        # Zero grad parameters
        super().zero_grad(set_to_none)
        if self.is_state_initialised():
            # Zero grad buffers
            for b in self.buffers():
                if b.grad_fn is not None:
                    # Buffer is the result of an operation: cut it from the graph.
                    b.detach_()
                else:
                    b.requires_grad_(False)

    def forward(self, *args, **kwargs):
        """
        Not implemented - You need to implement a forward method in child class

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "No forward method has been implemented for this class"
        )

    def is_state_initialised(self) -> bool:
        """Checks if buffers are of shape 0 and returns True only if none of them are."""
        for buffer in self.buffers():
            if buffer.shape == torch.Size([0]):
                return False
        return True

    def state_has_shape(self, shape) -> bool:
        """Checks if all state have a given shape.

        Args:
            shape: the shape every buffer is compared against.

        Returns:
            bool: True if no buffer has a shape different from `shape`.
        """
        for buff in self.buffers():
            if buff.shape != shape:
                return False
        return True

    def handle_state_batch_size_mismatch(self, new_batch_size: int):
        """Handles the state mismatch based on the new batch size by randomly selecting
        `new_batch_size` number of states from the previous batch_size in a repeated way.

        Indices are sampled with replacement along dimension 0, and a fresh set of
        indices is drawn for each buffer.

        Args:
            new_batch_size: int
                New batch size.
        """
        # Snapshot the buffers first: they are re-registered inside the loop.
        for name, buffer in list(self.named_buffers()):
            indices = torch.randint(
                low=0, high=buffer.shape[0], size=(new_batch_size,)
            ).to(buffer.device)
            new_buffer = torch.index_select(buffer, 0, indices)
            self.register_buffer(name, new_buffer)

    def has_trailing_dimension(self, trailing_dim: Tuple[int, int, int]) -> bool:
        """Checks if the trailing dimension (ch, y, x) matches the given.

        Args:
            trailing_dim: Tuple[int, int, int]
                Three tuple in (channel, y, x) dimensions.

        Returns:
            bool: Whether all the states dimensions match.
        """
        for buff in self.buffers():
            # Everything after the first dimension must match.
            if buff.shape[1:] != torch.Size(trailing_dim):
                return False
        return True

    def init_state_with_shape(self, shape, randomize: bool = False) -> None:
        """Initialise state/buffers with either zeros or random tensor of specific shape.

        Args:
            shape: shape of the new state tensors.
            randomize: passed on to :meth:`reset_states`.
        """
        # Snapshot the buffers first: they are re-registered inside the loop.
        for name, buffer in list(self.named_buffers()):
            self.register_buffer(name, torch.zeros(shape, device=buffer.device))
        self.reset_states(randomize=randomize)

    def reset_states(
        self,
        randomize: bool = False,
        value_ranges: Optional[Dict[str, Tuple[float, float]]] = None,
    ):
        """Reset the state/buffers in a layer.

        Nothing happens if the states have not been initialised yet.

        Parameters:
            randomize: If true, reset the states between a range provided. Else, the
                states are reset to zero.
            value_ranges: A dictionary of key value pairs: buffer_name -> (min, max)
                for each state that needs to be reset. The states are reset with a
                uniform distribution between the min and max values specified. Any
                state with an undefined key in this dictionary will be reset between
                0 and 1. This parameter is only used if randomize is set to true.

        .. note::
            If you would like to reset the state with a custom distribution, you can
            do this individually for each parameter as follows::

                layer.<state_name>.data = <your desired data>
                layer.<state_name>.detach_()
        """
        if not self.is_state_initialised():
            return

        for name, buffer in self.named_buffers():
            if randomize:
                if value_ranges and name in value_ranges:
                    min_value, max_value = value_ranges[name]
                else:
                    min_value, max_value = (0.0, 1.0)
                # Initialize with uniform distribution
                torch.nn.init.uniform_(buffer)
                # Rescale in place so the registered buffer itself ends up in the
                # requested range (binding the result to a local would discard it).
                with torch.no_grad():
                    buffer.mul_(max_value - min_value).add_(min_value)
            else:
                buffer.zero_()
            buffer.detach_()

    def __repr__(self):
        """Return ``ClassName(key=value, ...)`` for the known, non-None constructor arguments."""
        shown_keys = (
            "tau_mem",
            "tau_syn",
            "tau_adapt",
            "adapt_scale",
            "spike_threshold",
            "min_v_mem",
            "norm_input",
            "batch_size",
            "num_timesteps",
        )
        param_strings = [
            f"{key}={value}"
            for key, value in self._param_dict.items()
            if key in shown_keys and value is not None
        ]
        joined = ", ".join(param_strings)
        return f"{self.__class__.__name__}({joined})"

    def __deepcopy__(self, memo=None):
        """Create a new instance from `_param_dict` and copy parameter and buffer data into it."""
        new_inst = self.__class__(**self._param_dict)

        # Copy parameters
        for name, param in self.named_parameters():
            new_inst_param = getattr(new_inst, name)
            new_inst_param.data = param.data.clone()

        # Copy buffers (using state dict will fail if buffers have non-default shapes)
        for name, buffer in self.named_buffers():
            new_inst_buffer = getattr(new_inst, name)
            new_inst_buffer.data = buffer.data.clone()

        return new_inst

    @property
    def _param_dict(self) -> dict:
        """Dict of all parameters relevant for creating a new instance with same parameters as
        `self`."""
        return dict()

    @property
    def arg_dict(self) -> dict:
        """A public getter function for the constructor arguments."""
        return self._param_dict

    @property
    def does_spike(self) -> bool:
        """Return True if the layer has an activation function."""
        return hasattr(self, "spike_fn") and self.spike_fn is not None