class sinabs.layers.IAFRecurrent(rec_connect: torch.nn.modules.module.Module, spike_threshold: torch.Tensor = tensor(1.), spike_fn: typing.Callable = <class 'sinabs.activation.spike_generation.MultiSpike'>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), tau_syn: typing.Optional[float] = None, min_v_mem: typing.Optional[float] = None, shape: typing.Optional[torch.Size] = None, record_states: bool = False)

Integrate-and-fire (IAF) neuron layer with recurrent connections. Inherits from LIFRecurrent.

Neuron dynamics in discrete time:

\[ \begin{aligned} V_{mem}(t+1) &= V_{mem}(t) + \sum z(t) \\ &\text{if } V_{mem}(t) \geq V_{th} \text{, then } V_{mem} \rightarrow V_{reset} \end{aligned} \]

where \(\sum z(t)\) represents the sum of all input currents at time \(t\).
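
The update rule above can be sketched in plain Python for a single neuron. This is an illustration only, not the Sinabs implementation: the function name simulate_iaf is hypothetical, the threshold check is assumed to run right after the input is integrated, and a hard reset to v_reset = 0 is assumed.

```python
def simulate_iaf(inputs, v_th=1.0, v_reset=0.0):
    """Integrate summed input currents; reset when the threshold is reached.

    Hypothetical helper for illustration. `inputs` holds one summed input
    current per time step for a single neuron.
    """
    v_mem = 0.0
    spikes, trace = [], []
    for z in inputs:
        v_mem = v_mem + z          # V_mem(t+1) = V_mem(t) + sum z(t)
        if v_mem >= v_th:          # threshold reached: spike and reset
            spikes.append(1)
            v_mem = v_reset
        else:
            spikes.append(0)
        trace.append(v_mem)
    return spikes, trace


spikes, trace = simulate_iaf([0.5, 0.25, 0.5, 0.25, 1.0])
print(spikes)  # [0, 0, 1, 0, 1]
print(trace)   # [0.5, 0.75, 0.0, 0.25, 0.0]
```

Because there is no leak term, the membrane potential simply holds its value between inputs until the threshold is reached.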

  • rec_connect (torch.nn.modules.module.Module) – An nn.Module which defines the recurrent connectivity, e.g. nn.Linear

  • spike_threshold (torch.Tensor) – Spikes are emitted when v_mem reaches or exceeds this threshold. Defaults to 1.0.

  • spike_fn (Callable) – Choose a Sinabs or custom torch.autograd.Function that takes a dict of states, a spike threshold and a surrogate gradient function and returns spikes. Be aware that the class itself is passed here (because torch.autograd methods are static) rather than an object instance.

  • reset_fn (Callable) – A function that defines how the membrane potential is reset after a spike.

  • surrogate_grad_fn (Callable) – Choose how to define gradients for the spiking non-linearity during the backward pass. This is a function of membrane potential.

  • tau_syn (Optional[float]) – Synaptic decay time constant. If None (the default), no synaptic dynamics are used.

  • min_v_mem (Optional[float]) – Lower bound for membrane potential v_mem, clipped at every time step.

  • shape (Optional[torch.Size]) – Optionally initialise the layer state with the given shape. If None, the shape is inferred from input_size.

  • record_states (bool) – When True, records all internal states, such as v_mem or i_syn, in the dictionary attribute recordings. Defaults to False.

  • Input: \((Batch, Time, Channel, Height, Width)\) or \((Batch, Time, Channel)\)

  • Output: Same as input.


  • v_mem – The membrane potential. It is reset according to reset_fn after every spike.


  • i_syn – The synaptic current. This attribute is only available if tau_syn is not None.
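
To illustrate what resetting the membrane potential for every spike can look like, the sketch below assumes that the default spike_fn (MultiSpike) emits one spike for each whole multiple of the threshold contained in v_mem, and that MembraneSubtract with subtract_value=None subtracts the threshold once per emitted spike. Both behaviours are assumptions made for this sketch, and multi_spike_subtract is a hypothetical helper, not part of Sinabs.

```python
def multi_spike_subtract(v_mem, threshold=1.0):
    """Return (n_spikes, new_v_mem) for one neuron at one time step.

    Assumed behaviour: one spike per whole threshold contained in a
    positive v_mem, and one threshold subtracted per spike.
    """
    n_spikes = int(v_mem / threshold) if v_mem > 0 else 0
    new_v_mem = v_mem - n_spikes * threshold
    return n_spikes, new_v_mem


print(multi_spike_subtract(2.5))    # (2, 0.5)
print(multi_spike_subtract(0.75))   # (0, 0.75)
print(multi_spike_subtract(-0.5))   # (0, -0.5)
```

With a subtractive reset, the charge above the threshold is carried over to the next time step instead of being discarded.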

property alpha_mem_calculated

Always returns a tensor of 1, because an integrate-and-fire neuron has no membrane leak.