StatefulLayer
- class sinabs.layers.StatefulLayer(state_names: List[str])
A base class that instantiates buffers/states which update at every time step and provides helper methods that manage those states.
- Parameters:
state_names (List[str]) – Names of the PyTorch buffers to initialise. These are buffers, not parameters.
- property arg_dict: dict
A public getter function for the constructor arguments.
- property does_spike: bool
Return True if the layer has an activation function.
- forward(*args, **kwargs)
Not implemented. Child classes must implement their own forward method.
- handle_state_batch_size_mismatch(new_batch_size: int)
Handles a mismatch between the batch size of the stored states and a new batch size by randomly selecting new_batch_size states from those of the previous batch size, repeating them as needed.
- Parameters:
new_batch_size (int) – New batch size.
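The selection described above can be illustrated with a framework-free sketch. This is one possible reading of the behaviour, not the sinabs implementation: each entry of the new batch copies a randomly chosen entry of the old batch, and an old entry may be picked more than once. The helper name `resample_states` and the use of plain lists in place of tensors are assumptions made for illustration.

```python
import random


def resample_states(states, new_batch_size, seed=None):
    """Hypothetical helper: build a batch of `new_batch_size` states by
    drawing entries from `states` at random, with repetition allowed."""
    rng = random.Random(seed)
    # Draw one index into the old batch for every entry of the new batch.
    indices = [rng.randrange(len(states)) for _ in range(new_batch_size)]
    # Copy the selected entries so the new batch does not alias the old one.
    return [list(states[i]) for i in indices]


old_states = [[0.1, 0.2], [0.3, 0.4]]  # old batch size is 2
new_states = resample_states(old_states, new_batch_size=5, seed=0)
print(len(new_states))
```

Because the new batch here is larger than the old one, some old states necessarily appear more than once in the result.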
- has_trailing_dimension(trailing_dim: Tuple[int, int, int]) → bool
Checks whether the trailing dimensions (ch, y, x) of the states match the given shape.
- Parameters:
trailing_dim (Tuple[int, int, int]) – Three-tuple of (channel, y, x) dimensions.
- Returns:
Whether the trailing dimensions of all states match.
- Return type:
bool
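As a rough illustration of this check, the sketch below compares shapes only. It assumes each state is laid out as (batch, channel, y, x), so that the trailing dimensions are everything after the leading batch axis. The function is a hypothetical stand-in operating on a dictionary of shapes, and the state names `v_mem` and `i_syn` are illustrative only.

```python
from typing import Dict, Tuple


def has_trailing_dimension(
    state_shapes: Dict[str, Tuple[int, ...]],
    trailing_dim: Tuple[int, int, int],
) -> bool:
    """Hypothetical stand-in: True only if every state's shape, minus the
    leading batch dimension, equals (channel, y, x)."""
    return all(
        shape[1:] == tuple(trailing_dim) for shape in state_shapes.values()
    )


# Assumed layout: (batch, channel, y, x)
shapes = {"v_mem": (8, 2, 32, 32), "i_syn": (8, 2, 32, 32)}
matches = has_trailing_dimension(shapes, (2, 32, 32))
mismatch = has_trailing_dimension(shapes, (2, 16, 16))
print(matches, mismatch)
```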
- init_state_with_shape(shape, randomize: bool = False) → None
Initialise the states/buffers as tensors of the given shape, filled with either zeros or random values.
- Parameters:
randomize (bool)
- Return type:
None
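The zero-or-random initialisation can be sketched without any tensor library. The helper `make_state` below is hypothetical and uses nested lists in place of tensors. Drawing the random values uniformly from [0, 1) is an assumption, since the description above does not specify the distribution.

```python
import random


def make_state(shape, randomize=False, rng=None):
    """Hypothetical helper: nested list of the given shape, filled with
    zeros or (assumed) uniform random values in [0, 1)."""
    rng = rng or random.Random()
    if len(shape) == 0:
        # Reached a single element: emit one value.
        return rng.random() if randomize else 0.0
    # Build one sub-structure per entry along the leading axis.
    return [make_state(shape[1:], randomize, rng) for _ in range(shape[0])]


zeros = make_state((2, 3))
noisy = make_state((2, 3), randomize=True, rng=random.Random(0))
print(zeros)
```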
- is_state_initialised() → bool
Checks whether any buffer has shape 0 and returns True only if none of them do.
- Return type:
bool
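A minimal sketch of this check, operating on shapes only: it assumes that a buffer "of shape 0" means a one-dimensional buffer with shape `(0,)`. The function is a hypothetical stand-in rather than the library method, and the state names are illustrative.

```python
def is_state_initialised(state_shapes):
    """Hypothetical stand-in: a state counts as uninitialised while its
    shape is (0,); the layer is initialised only if no state is."""
    return all(tuple(shape) != (0,) for shape in state_shapes.values())


# One state still has the placeholder shape (0,).
before = is_state_initialised({"v_mem": (0,), "i_syn": (4, 2, 8, 8)})
# Every state has a real shape.
after = is_state_initialised({"v_mem": (4, 2, 8, 8), "i_syn": (4, 2, 8, 8)})
print(before, after)
```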
- reset_states(randomize: bool = False, value_ranges: Dict[str, Tuple[float, float]] | None = None)
Reset the state/buffers in a layer.
- Parameters:
randomize (bool) – If true, reset the states to random values within the ranges provided. Otherwise, the states are reset to zero.
value_ranges (Dict[str, Tuple[float, float]] | None) – A dictionary of key-value pairs: buffer_name -> (min, max) for each state that needs to be reset. The states are reset with a uniform distribution between the min and max values specified. Any state without a key in this dictionary is reset between 0 and 1. This parameter is only used if randomize is set to true.
Note
If you would like to reset a state with a custom distribution, you can do this individually for each state as follows:
layer.<state_name>.data = <your desired data>
layer.<state_name>.detach_()
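The interplay of randomize and value_ranges described for reset_states can be sketched as follows. This is a framework-free illustration under stated assumptions, not the sinabs implementation: `reset_values` is a hypothetical helper, each state is a flat list of `n` values, and the state names `v_mem` and `i_syn` are illustrative.

```python
import random


def reset_values(state_names, n, randomize=False, value_ranges=None, seed=None):
    """Hypothetical helper: return fresh values for each named state."""
    rng = random.Random(seed)
    value_ranges = value_ranges or {}
    new_states = {}
    for name in state_names:
        if randomize:
            # States without an entry in value_ranges fall back to (0, 1).
            low, high = value_ranges.get(name, (0.0, 1.0))
            new_states[name] = [rng.uniform(low, high) for _ in range(n)]
        else:
            # value_ranges is ignored when randomize is false.
            new_states[name] = [0.0] * n
    return new_states


zeroed = reset_values(["v_mem", "i_syn"], n=4)
randomized = reset_values(
    ["v_mem", "i_syn"],
    n=4,
    randomize=True,
    value_ranges={"v_mem": (-1.0, 0.0)},
    seed=0,
)
print(zeroed)
```

In the second call, `v_mem` is drawn from its explicit range while `i_syn`, which has no entry, falls back to the default range of 0 to 1.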