LIFRecurrent

class sinabs.layers.LIFRecurrent(tau_mem: typing.Union[float, torch.Tensor], rec_connect: torch.nn.modules.module.Module, tau_syn: typing.Optional[typing.Union[float, torch.Tensor]] = None, spike_threshold: float = 1.0, spike_fn: typing.Callable = <class 'sinabs.activation.spike_generation.MultiSpike'>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem: typing.Optional[float] = None, train_alphas: bool = False, shape: typing.Optional[torch.Size] = None, norm_input: bool = True, record_states: bool = False)

Leaky integrate-and-fire (LIF) neuron layer with recurrent connections. Inherits from LIF.

Neuron dynamics in discrete time for norm_input=True:

\[V_{mem}(t+1) = \max\left(\alpha V_{mem}(t) + (1-\alpha)\sum \left(z_{in}(t) + z_{rec}(t)\right), V_{min}\right)\]

Neuron dynamics for norm_input=False:

\[V_{mem}(t+1) = \max\left(\alpha V_{mem}(t) + \sum \left(z_{in}(t) + z_{rec}(t)\right), V_{min}\right)\]

where \(\alpha = e^{-1/\tau_{mem}}\), \(V_{min}\) is the minimum membrane potential (min_v_mem) and \(\sum \left(z_{in}(t) + z_{rec}(t)\right)\) is the sum of all input and recurrent currents at time \(t\). Whenever the membrane potential reaches the threshold \(V_{th}\) (spike_threshold), it is reset according to reset_fn:

\[\text{if } V_{mem}(t) \geq V_{th} \text{, then } V_{mem} \rightarrow V_{reset}\]
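The update and reset rules above can be sketched in plain Python for a single sample. This is a minimal illustration, not the sinabs implementation: the function name, the weight matrix w_rec, the zero initial state, the use of the previous step's spikes as recurrent input, the limit of one spike per step and the subtractive reset by the threshold are all assumptions made for the example.

```python
import math


def lif_recurrent_sketch(inputs, w_rec, tau_mem, threshold=1.0,
                         min_v_mem=None, norm_input=True):
    """Hypothetical helper: iterate the membrane update for one sample.

    inputs: list of time steps, each a list of per-neuron input currents.
    w_rec:  assumed recurrent weights; w_rec[i][j] connects neuron j to neuron i.
    """
    alpha = math.exp(-1.0 / tau_mem)
    # norm_input=True scales the summed current by (1 - alpha).
    gain = (1.0 - alpha) if norm_input else 1.0
    n = len(w_rec)
    v_mem = [0.0] * n          # assumed zero initial state
    prev_spikes = [0.0] * n
    output = []
    for z_in in inputs:
        # Assumption: recurrent current comes from the previous step's spikes.
        z_rec = [sum(w_rec[i][j] * prev_spikes[j] for j in range(n))
                 for i in range(n)]
        spikes = []
        for i in range(n):
            v = alpha * v_mem[i] + gain * (z_in[i] + z_rec[i])
            if min_v_mem is not None:
                v = max(v, min_v_mem)  # clip at the lower bound
            # Assumption: at most one spike per step, subtractive reset.
            s = 1.0 if v >= threshold else 0.0
            v_mem[i] = v - s * threshold
            spikes.append(s)
        output.append(spikes)
        prev_spikes = spikes
    return output, v_mem


# Neuron 0 receives a strong input and drives neuron 1 one step later.
spikes, v_final = lif_recurrent_sketch(
    inputs=[[1.5, 0.0], [0.0, 0.0], [0.0, 0.0]],
    w_rec=[[0.0, 0.0], [1.0, 0.0]],
    tau_mem=1.0,
    norm_input=False,
)
print(spikes)  # [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
```

In the example, the spike of neuron 0 at the first step reaches neuron 1 through the recurrent weight at the second step, which is the behaviour that rec_connect adds on top of a plain LIF layer.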
Parameters
  • tau_mem (Union[float, torch.Tensor]) – Membrane potential time constant.

  • rec_connect (torch.nn.modules.module.Module) – An nn.Module that defines the recurrent connectivity, e.g. nn.Linear.

  • tau_syn (Optional[Union[float, torch.Tensor]]) – Synaptic decay time constant. If None (the default), no synaptic dynamics are used.

  • spike_threshold (float) – Spikes are emitted when v_mem reaches or exceeds this threshold. Default is 1.0.

  • spike_fn (Callable) – Specify how many spikes per time step per neuron can be emitted.

  • reset_fn (Callable) – Specify how a neuron’s membrane potential should be reset after a spike.

  • surrogate_grad_fn (Callable) – Surrogate gradient function, chosen from sinabs.activation.

  • min_v_mem (Optional[float]) – Lower bound for membrane potential v_mem, clipped at every time step.

  • train_alphas (bool) – When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.

  • shape (Optional[torch.Size]) – Optionally initialise the layer state with the given shape. If None, the shape is inferred from the input.

  • norm_input (bool) – When True, normalise the input current by the tau-dependent factor \((1-\alpha)\). This helps when training time constants.

  • record_states (bool) – When True, record all internal states, such as v_mem or i_syn, in a dictionary attribute named recordings. Default is False.
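To illustrate why norm_input can help when time constants are trained, and how tau relates to the decay factor that train_alphas refers to, the following sketch iterates the sub-threshold update with a constant input current. The helper names are hypothetical, tau is expressed in time steps, and spiking is deliberately left out (the threshold is assumed to lie above the values reached).

```python
import math


def alpha_from_tau(tau):
    """Discrete decay factor alpha = exp(-1 / tau), tau in time steps."""
    return math.exp(-1.0 / tau)


def tau_from_alpha(alpha):
    """Inverse mapping, tau = -1 / ln(alpha)."""
    return -1.0 / math.log(alpha)


def steady_state_v_mem(current, tau_mem, norm_input, steps=2000):
    """Iterate v <- alpha * v + gain * current without spiking or reset."""
    alpha = alpha_from_tau(tau_mem)
    gain = (1.0 - alpha) if norm_input else 1.0
    v = 0.0
    for _ in range(steps):
        v = alpha * v + gain * current
    return v


for tau in (10.0, 20.0):
    normed = steady_state_v_mem(0.5, tau, norm_input=True)
    raw = steady_state_v_mem(0.5, tau, norm_input=False)
    print(f"tau={tau}: norm_input=True -> {normed:.3f}, norm_input=False -> {raw:.3f}")
```

With normalisation, the membrane potential settles at the input current itself regardless of tau, because the fixed point of \(v = \alpha v + (1-\alpha) I\) is \(v = I\). Without it, the fixed point is \(I / (1-\alpha)\), so changing tau also changes the effective input scale.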

Shape:
  • Input: \((Batch, Time, Channel, Height, Width)\) or \((Batch, Time, Channel)\)

  • Output: Same as input.

v_mem

The membrane potential. It is reset according to reset_fn after every spike.

i_syn

This attribute is only available if tau_syn is not None.

forward(input_data: torch.Tensor)
Parameters

input_data (torch.Tensor) – Data to be processed. Expected shape: (batch, time, …)

Returns

Output data with same shape as input_data.