# Source code for sinabs.layers.iaf

import torch
from copy import deepcopy
from typing import Optional, Callable
from sinabs.activation import MultiSpike, MembraneSubtract, SingleExponential
from .reshape import SqueezeMixin
from .lif import LIF, LIFRecurrent
import numpy as np


class IAF(LIF):
    r"""
    Integrate-and-Fire neuron layer.

    Implemented as the special case of :class:`~sinabs.layers.LIF` with an
    infinite membrane time constant (``tau_mem=inf``), i.e. the membrane
    potential does not leak between time steps.

    Neuron dynamics in discrete time:

    .. math ::
        V_{mem}(t+1) = V_{mem}(t) + \sum z(t)

        \text{if } V_{mem}(t) >= V_{th} \text{, then } V_{mem} \rightarrow V_{reset}

    where :math:`\sum z(t)` is the sum of all input currents at time :math:`t`.

    Parameters:
        spike_threshold: Threshold :math:`V_{th}`; spikes are emitted when
            v_mem reaches it (see the formula above). Defaults to 1.0.
        spike_fn: A Sinabs or custom torch.autograd.Function that takes a
            dict of states, a spike threshold and a surrogate gradient
            function, and returns spikes. Pass the class itself (its
            torch.autograd methods are static), not an instance.
        reset_fn: Function defining how the membrane potential is reset
            after a spike.
        surrogate_grad_fn: Function of the membrane potential that defines
            the gradient of the spiking non-linearity in the backward pass.
        tau_syn: Synaptic decay time constant. If None (default), no
            synaptic dynamics are used.
        min_v_mem: Lower bound for v_mem, clipped at every time step.
        shape: Optional shape used to initialise the layer state up front.
            If None, the state shape is presumably taken from the first
            input (handled by the LIF base class).
        record_states: If True, all internal states such as v_mem or i_syn
            are recorded in the dictionary attribute ``recordings``.
            Defaults to False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or
          :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: Membrane potential, reset according to reset_fn on every spike.
        i_syn: Synaptic current; only available if tau_syn is not None.
    """

    def __init__(
        self,
        spike_threshold: float = 1.0,
        spike_fn: Callable = MultiSpike,
        reset_fn: Callable = MembraneSubtract(),
        surrogate_grad_fn: Callable = SingleExponential(),
        tau_syn: Optional[float] = None,
        min_v_mem: Optional[float] = None,
        shape: Optional[torch.Size] = None,
        record_states: bool = False,
    ):
        # Build the parent LIF with an infinite membrane time constant and
        # its input-normalisation option switched off.
        lif_kwargs = dict(
            tau_mem=np.inf,
            norm_input=False,
            tau_syn=tau_syn,
            spike_threshold=spike_threshold,
            spike_fn=spike_fn,
            reset_fn=reset_fn,
            surrogate_grad_fn=surrogate_grad_fn,
            min_v_mem=min_v_mem,
            shape=shape,
            record_states=record_states,
        )
        super().__init__(**lif_kwargs)
        # There is no membrane time constant for an IAF neuron, so the value
        # stored by LIF is cleared after construction.
        self.tau_mem = None

    @property
    def alpha_mem_calculated(self):
        """Membrane decay factor: always a tensor of 1 (no leak)."""
        return torch.tensor(1.0)

    @property
    def _param_dict(self) -> dict:
        params = super()._param_dict
        # Remove LIF-only options that IAF.__init__ does not accept.
        for lif_only_key in ("tau_mem", "train_alphas", "norm_input"):
            params.pop(lif_only_key)
        return params
class IAFRecurrent(LIFRecurrent):
    r"""
    Integrate-and-Fire neuron layer with recurrent connections.

    Implemented as the special case of :class:`~sinabs.layers.LIFRecurrent`
    with an infinite membrane time constant (``tau_mem=inf``), i.e. the
    membrane potential does not leak between time steps.

    Neuron dynamics in discrete time:

    .. math ::
        V_{mem}(t+1) = V_{mem}(t) + \sum z(t)

        \text{if } V_{mem}(t) >= V_{th} \text{, then } V_{mem} \rightarrow V_{reset}

    where :math:`\sum z(t)` is the sum of all input currents at time :math:`t`.

    Parameters:
        rec_connect: An nn.Module defining the recurrent connectivity,
            e.g. nn.Linear.
        spike_threshold: Threshold :math:`V_{th}`; spikes are emitted when
            v_mem reaches it (see the formula above). Defaults to 1.0.
        spike_fn: A Sinabs or custom torch.autograd.Function that takes a
            dict of states, a spike threshold and a surrogate gradient
            function, and returns spikes. Pass the class itself (its
            torch.autograd methods are static), not an instance.
        reset_fn: Function defining how the membrane potential is reset
            after a spike.
        surrogate_grad_fn: Function of the membrane potential that defines
            the gradient of the spiking non-linearity in the backward pass.
        tau_syn: Synaptic decay time constant. If None (default), no
            synaptic dynamics are used.
        min_v_mem: Lower bound for v_mem, clipped at every time step.
        shape: Optional shape used to initialise the layer state up front.
            If None, the state shape is presumably taken from the first
            input (handled by the LIFRecurrent base class).
        record_states: If True, all internal states such as v_mem or i_syn
            are recorded in the dictionary attribute ``recordings``.
            Defaults to False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or
          :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: Membrane potential, reset according to reset_fn on every spike.
        i_syn: Synaptic current; only available if tau_syn is not None.
    """

    def __init__(
        self,
        rec_connect: torch.nn.Module,
        spike_threshold: float = 1.0,
        spike_fn: Callable = MultiSpike,
        reset_fn: Callable = MembraneSubtract(),
        surrogate_grad_fn: Callable = SingleExponential(),
        tau_syn: Optional[float] = None,
        min_v_mem: Optional[float] = None,
        shape: Optional[torch.Size] = None,
        record_states: bool = False,
    ):
        # Build the parent LIFRecurrent with an infinite membrane time
        # constant and its input-normalisation option switched off.
        lif_kwargs = dict(
            rec_connect=rec_connect,
            tau_mem=np.inf,
            norm_input=False,
            tau_syn=tau_syn,
            spike_threshold=spike_threshold,
            spike_fn=spike_fn,
            reset_fn=reset_fn,
            surrogate_grad_fn=surrogate_grad_fn,
            min_v_mem=min_v_mem,
            shape=shape,
            record_states=record_states,
        )
        super().__init__(**lif_kwargs)
        # There is no membrane time constant for an IAF neuron, so the value
        # stored by the parent is cleared after construction.
        self.tau_mem = None

    @property
    def alpha_mem_calculated(self):
        """Membrane decay factor: always a tensor of 1 (no leak)."""
        return torch.tensor(1.0)

    @property
    def _param_dict(self) -> dict:
        params = super()._param_dict
        # Remove LIF-only options that IAFRecurrent.__init__ does not accept.
        for lif_only_key in ("tau_mem", "train_alphas", "norm_input"):
            params.pop(lif_only_key)
        return params
class IAFSqueeze(IAF, SqueezeMixin):
    r"""
    IAF layer operating on squeezed input whose first dimension is
    Batch*Time.

    Behaves like the parent :class:`IAF` class. The difference is that it
    takes input whose batch and time dimensions are merged into one, e.g.
    (Batch*Time, Channel, Height, Width) instead of
    (Batch, Time, Channel, Height, Width). This makes it compatible with
    layers that only accept such input, such as convolutional and pooling
    layers.

    Parameters:
        batch_size: Forwarded to ``squeeze_init`` (from SqueezeMixin).
            Presumably used to split the merged Batch*Time dimension;
            confirm against SqueezeMixin.
        num_timesteps: Forwarded to ``squeeze_init`` (from SqueezeMixin).
            Presumably an alternative to batch_size for splitting the merged
            dimension; confirm against SqueezeMixin.
        **kwargs: Passed unchanged to :class:`IAF`.

    Shape:
        - Input: :math:`(Batch \times Time, Channel, Height, Width)` or
          :math:`(Batch \times Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: Membrane potential, reset according to reset_fn on every spike.
        i_syn: Synaptic current; only available if tau_syn is not None.
    """

    def __init__(
        self,
        batch_size=None,
        num_timesteps=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.squeeze_init(batch_size, num_timesteps)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """
        Run the IAF dynamics on squeezed (Batch*Time, ...) input.

        The parent IAF forward expects separate Batch and Time dimensions.
        This wrapper therefore hands IAF's forward to ``squeeze_forward``
        (from SqueezeMixin). That method is expected to unflatten the input
        before the call and flatten the output back afterwards.
        """
        iaf_forward = super().forward
        return self.squeeze_forward(input_data, iaf_forward)

    @property
    def _param_dict(self) -> dict:
        # Extend IAF's parameters with the squeeze-specific ones.
        iaf_params = super()._param_dict
        return self.squeeze_param_dict(iaf_params)