class sinabs.layers.StatefulLayer(state_names: List[str])

A base class that instantiates buffers/states which update at every time step and provides helper methods that manage those states.


Parameters:

state_names (List[str]) – Names of the PyTorch buffers to initialise. These are registered as buffers, not as parameters.
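As a rough illustration of the constructor, the sketch below assumes that each name in state_names is registered as an empty PyTorch buffer (shape 0) that is resized once input arrives. SketchStatefulLayer is a hypothetical name; this is not the sinabs source.

```python
import torch
from typing import List


class SketchStatefulLayer(torch.nn.Module):
    """Hypothetical stand-in for a base class that owns named state buffers."""

    def __init__(self, state_names: List[str]):
        super().__init__()
        for name in state_names:
            # Buffers follow .to(device) and are saved in state_dict(),
            # but .parameters() does not return them, so optimizers skip them.
            self.register_buffer(name, torch.zeros(0))


layer = SketchStatefulLayer(["v_mem", "i_syn"])
buffer_names = [name for name, _ in layer.named_buffers()]  # ['v_mem', 'i_syn']
n_params = len(list(layer.parameters()))  # 0
```

The state names v_mem and i_syn are illustrative only.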

property arg_dict: dict

A public getter function for the constructor arguments.

property does_spike: bool

Return True if the layer has an activation function.

forward(*args, **kwargs)

Not implemented: child classes must implement their own forward method.
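A minimal sketch of a child class supplying its own forward. To stay self-contained it defines a tiny hypothetical base (MinimalBase) that mirrors the documented behaviour by raising NotImplementedError; LeakyIntegrator and its (batch, time, features) input layout are likewise assumptions for illustration, not a sinabs layer.

```python
import torch


class MinimalBase(torch.nn.Module):
    def __init__(self, state_names):
        super().__init__()
        for name in state_names:
            self.register_buffer(name, torch.zeros(0))

    def forward(self, *args, **kwargs):
        # Mirrors the documented behaviour: children must override this.
        raise NotImplementedError


class LeakyIntegrator(MinimalBase):
    """Hypothetical child layer: v[t] = alpha * v[t-1] + x[t]."""

    def __init__(self, alpha: float = 0.5):
        super().__init__(state_names=["v_mem"])
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, time, features); the state keeps (batch, features).
        if self.v_mem.shape != x[:, 0].shape:
            self.v_mem = torch.zeros_like(x[:, 0])
        outputs = []
        for t in range(x.shape[1]):
            self.v_mem = self.alpha * self.v_mem + x[:, t]
            outputs.append(self.v_mem)
        return torch.stack(outputs, dim=1)


layer = LeakyIntegrator(alpha=0.5)
out = layer(torch.ones(1, 3, 2))  # out[0, :, 0] -> [1.0, 1.5, 1.75]
```

Because the state lives in a buffer, it persists between calls until it is reset.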

handle_state_batch_size_mismatch(new_batch_size: int)

Handles a mismatch between the batch size of the stored states and new_batch_size by randomly selecting new_batch_size states from the previous batch, repeating states as needed.

Parameters:

new_batch_size (int) – New batch size.
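One way to read "randomly selecting ... in a repeated way" is sketched below: tile the old batch until it is large enough, then draw a random subset without replacement. The exact selection scheme used by sinabs is an assumption here, and resample_batch is a hypothetical helper operating on a single state tensor whose first dimension is the batch.

```python
import torch


def resample_batch(state: torch.Tensor, new_batch_size: int) -> torch.Tensor:
    """Hypothetical helper: pick new_batch_size entries from an old batch of states.

    The old batch (dim 0) is tiled until it holds at least new_batch_size
    entries, then a random subset of exactly that size is drawn without
    replacement, so each old state is reused roughly equally often.
    """
    old_batch_size = state.shape[0]
    repeats = -(-new_batch_size // old_batch_size)  # ceiling division
    tiled = torch.cat([state] * repeats, dim=0)
    idx = torch.randperm(tiled.shape[0])[:new_batch_size]
    return tiled[idx]


old = torch.arange(3.0).unsqueeze(1)  # batch of 3 states with values 0, 1, 2
grown = resample_batch(old, 5)   # shape (5, 1); every old value appears at least once
shrunk = resample_batch(old, 2)  # shape (2, 1)
```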

has_trailing_dimension(trailing_dim: Tuple[int, int, int]) → bool

Checks whether the trailing dimensions (channel, y, x) of all states match the given ones.

Parameters:

trailing_dim (Tuple[int, int, int]) – Three-tuple of (channel, y, x) dimensions.

Returns:

Whether the trailing dimensions of all states match.

Return type:

bool
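A sketch of the check as a free function over a module's buffers, assuming it compares only the last len(trailing_dim) sizes and ignores any leading batch dimension. The function and the buffer names are hypothetical.

```python
import torch
from typing import Tuple


def has_trailing_dimension(module: torch.nn.Module, trailing_dim: Tuple[int, ...]) -> bool:
    """Hypothetical free-function version of the check on a module's buffers."""
    n = len(trailing_dim)
    # Compare only the last n sizes, so leading batch dims are ignored.
    return all(tuple(buf.shape[-n:]) == tuple(trailing_dim) for buf in module.buffers())


m = torch.nn.Module()
m.register_buffer("v_mem", torch.zeros(8, 2, 4, 4))  # (batch, channel, y, x)
m.register_buffer("i_syn", torch.zeros(8, 2, 4, 4))
ok = has_trailing_dimension(m, (2, 4, 4))   # True
bad = has_trailing_dimension(m, (2, 5, 5))  # False
```

An uninitialised buffer of shape 0 fails this check for any three-tuple of positive sizes.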

init_state_with_shape(shape, randomize: bool = False) → None

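Initialise the states/buffers with tensors of the given shape, filled either with zeros or with random values.

Parameters:

  • shape – Shape of the new state tensors.

  • randomize (bool) – If True, fill the states with random values; otherwise, fill them with zeros.

Return type:

None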
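A minimal sketch of re-initialising every direct buffer of a module with a new shape while keeping it on its current device. The uniform [0, 1) distribution for randomize=True is an assumption (it matches the default range described under reset_states), and init_state_with_shape here is a hypothetical free function, not the method itself.

```python
import torch


def init_state_with_shape(module: torch.nn.Module, shape, randomize: bool = False) -> None:
    """Hypothetical sketch: replace every direct buffer with a tensor of `shape`."""
    for name, buf in list(module.named_buffers(recurse=False)):
        # Keep each state on the device it already lives on.
        if randomize:
            new = torch.rand(shape, device=buf.device)  # assumed: uniform in [0, 1)
        else:
            new = torch.zeros(shape, device=buf.device)
        module.register_buffer(name, new)


m = torch.nn.Module()
m.register_buffer("v_mem", torch.zeros(0))
init_state_with_shape(m, (4, 3))
```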

is_state_initialised() → bool

Checks whether any buffer still has shape 0 (i.e. is uninitialised); returns True only if none of them does.

Return type:

bool
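The check can be sketched as a single all() over the buffers, assuming an uninitialised state is a tensor of shape (0,). The helper name is hypothetical.

```python
import torch


def is_state_initialised(module: torch.nn.Module) -> bool:
    """Hypothetical sketch: False while any buffer still has shape 0."""
    return all(buf.shape != torch.Size([0]) for buf in module.buffers())


m = torch.nn.Module()
m.register_buffer("v_mem", torch.zeros(0))
before = is_state_initialised(m)  # False: the buffer is still empty
m.v_mem = torch.zeros(2, 3)       # assigning a tensor replaces the buffer
after = is_state_initialised(m)   # True
```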

reset_states(randomize: bool = False, value_ranges: Dict[str, Tuple[float, float]] | None = None)

Resets the states/buffers of the layer.

Parameters:

  • randomize (bool) – If True, reset the states to random values within the ranges given by value_ranges. Otherwise, reset them to zero.

  • value_ranges (Dict[str, Tuple[float, float]] | None) – A dictionary mapping buffer_name -> (min, max) for each state that needs to be reset. These states are drawn from a uniform distribution between the specified min and max values. Any state whose name is not a key in this dictionary is reset uniformly between 0 and 1. This parameter is only used if randomize is True.


If you would like to reset a state with a custom distribution, you can do so individually for each state as follows:

layer.<state_name>.data = <your desired data>
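The documented reset semantics (zeros by default; otherwise uniform within a per-state range, falling back to (0, 1)) can be sketched as below. This is a hypothetical free function over a plain torch.nn.Module, not the sinabs method; the last line shows the per-state custom assignment pattern from the note above.

```python
import torch
from typing import Dict, Optional, Tuple


def reset_states(
    module: torch.nn.Module,
    randomize: bool = False,
    value_ranges: Optional[Dict[str, Tuple[float, float]]] = None,
) -> None:
    """Hypothetical sketch of the documented reset semantics."""
    value_ranges = value_ranges or {}
    with torch.no_grad():
        for name, buf in module.named_buffers(recurse=False):
            if randomize:
                # States without an entry fall back to the documented (0, 1) range.
                low, high = value_ranges.get(name, (0.0, 1.0))
                buf.uniform_(low, high)
            else:
                buf.zero_()


m = torch.nn.Module()
m.register_buffer("v_mem", torch.zeros(1000))
m.register_buffer("i_syn", torch.zeros(1000))
reset_states(m, randomize=True, value_ranges={"v_mem": (-1.0, -0.5)})
# Custom distribution for a single state via .data:
m.i_syn.data = torch.full_like(m.i_syn, 0.25)
```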


state_has_shape(shape) → bool

Checks whether all states have the given shape.

Return type:

bool
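A sketch of the full-shape comparison as a hypothetical free function. Unlike has_trailing_dimension, it compares every dimension, including the batch dimension.

```python
import torch


def state_has_shape(module: torch.nn.Module, shape) -> bool:
    """Hypothetical sketch: True only if every buffer has exactly `shape`."""
    target = torch.Size(shape)
    return all(buf.shape == target for buf in module.buffers())


m = torch.nn.Module()
m.register_buffer("v_mem", torch.zeros(8, 2))
m.register_buffer("i_syn", torch.zeros(8, 2))
same = state_has_shape(m, (8, 2))   # True
other = state_has_shape(m, (4, 2))  # False: batch size differs
```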

zero_grad(set_to_none: bool = False) → None

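Zeros the gradients of the buffers/states along with those of the parameters. See torch.nn.Module.zero_grad() for details.

Parameters:

set_to_none (bool) – As in torch.nn.Module.zero_grad(): if True, parameter gradients are set to None instead of being filled with zeros.

Return type:

None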
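Because states computed in one batch carry autograd history, one plausible reading of "zeroing the gradients" of a buffer is to detach it in place before delegating to torch.nn.Module.zero_grad(), so the next backward pass does not reach into the previous batch's graph. That interpretation is an assumption; DetachingModule and its weight and v_mem are hypothetical.

```python
import torch


class DetachingModule(torch.nn.Module):
    """Hypothetical module whose zero_grad also detaches its buffers."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3))
        self.register_buffer("v_mem", torch.zeros(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The state now depends on self.weight, so it carries autograd history.
        self.v_mem = self.v_mem + self.weight * x
        return self.v_mem

    def zero_grad(self, set_to_none: bool = False) -> None:
        for buf in self.buffers():
            buf.detach_()  # drop the graph built in previous steps, keep the values
        super().zero_grad(set_to_none)


m = DetachingModule()
m(torch.ones(3)).sum().backward()
m.zero_grad()  # m.v_mem keeps its values but no longer requires grad
```

After this call the state can be carried into the next batch without backpropagating through the old graph.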