LIF
- class sinabs.layers.LIF(tau_mem: float | torch.Tensor, tau_syn: float | torch.Tensor | None = None, spike_threshold: torch.Tensor = tensor(1.), spike_fn: Callable = sinabs.activation.spike_generation.MultiSpike, reset_fn: Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem: float | None = None, train_alphas: bool = False, shape: torch.Size | None = None, norm_input: bool = True, record_states: bool = False)
Leaky Integrate and Fire neuron layer that inherits from StatefulLayer.
Neuron dynamics in discrete time for norm_input=True:
\[V_{mem}(t+1) = \max\left(\alpha V_{mem}(t) + (1-\alpha)\sum z(t),\ V_{min}\right)\]
Neuron dynamics for norm_input=False:
\[V_{mem}(t+1) = \max\left(\alpha V_{mem}(t) + \sum z(t),\ V_{min}\right)\]
where \(\alpha = e^{-1/\tau_{mem}}\), \(V_{min}\) is the minimum membrane potential set by min_v_mem and \(\sum z(t)\) represents the sum of all input currents at time \(t\). We also reset the membrane potential according to reset_fn:
\[\text{if } V_{mem}(t) \geq V_{th} \text{, then } V_{mem} \rightarrow V_{reset}\]
- Parameters:
tau_mem (float | Tensor) – Membrane potential time constant.
tau_syn (float | Tensor | None) – Synaptic decay time constants. If None, no synaptic dynamics are used, which is the default.
spike_threshold (Tensor) – Spikes are emitted if v_mem reaches or exceeds this threshold. Defaults to 1.0.
spike_fn (Callable) – A Sinabs or custom torch.autograd.Function that takes a dict of states, a spike threshold and a surrogate gradient function, and returns spikes. Note that the class itself is passed here, not an instance, because torch.autograd.Function methods are static.
reset_fn (Callable) – Specify how a neuron’s membrane potential should be reset after a spike.
surrogate_grad_fn (Callable) – Choose how to define gradients for the spiking non-linearity during the backward pass. This is a function of membrane potential.
min_v_mem (float | None) – Lower bound for membrane potential v_mem, clipped at every time step.
train_alphas (bool) – When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
shape (Size | None) – Optionally initialise the layer state with the given shape. If None, the shape is inferred from the input data.
norm_input (bool) – When True, scale the input current by \((1-\alpha)\), which is approximately \(1/\tau_{mem}\) for large time constants. This helps when training time constants.
record_states (bool) – When True, will record all internal states such as v_mem or i_syn in a dictionary attribute recordings. Default is False.
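As a rough illustration of surrogate_grad_fn, the sketch below evaluates a single-exponential surrogate gradient in plain Python. It assumes the form grad_scale / grad_width * exp(-|v_mem - spike_threshold| / grad_width) with the defaults grad_width=0.5 and grad_scale=1.0 taken from the signature; the helper name single_exponential_surrogate is hypothetical, and sinabs' own SingleExponential may scale the curve differently.

```python
import math


def single_exponential_surrogate(
    v_mem: float,
    spike_threshold: float = 1.0,
    grad_width: float = 0.5,
    grad_scale: float = 1.0,
) -> float:
    """Hypothetical stand-in for d(spikes)/d(v_mem); the form is an assumption."""
    # The gradient is largest exactly at the threshold and decays
    # exponentially with the distance between v_mem and the threshold.
    distance = abs(v_mem - spike_threshold)
    return grad_scale / grad_width * math.exp(-distance / grad_width)


for v in (0.0, 0.5, 1.0, 1.5):
    print(v, round(single_exponential_surrogate(v), 4))
```

Shrinking grad_width concentrates the gradient around the threshold, so only neurons close to spiking receive a meaningful learning signal.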
- Shape:
Input: \((Batch, Time, Channel, Height, Width)\) or \((Batch, Time, Channel)\)
Output: Same as input.
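The update equations at the top of this entry can be traced step by step with a minimal pure-Python sketch of a single neuron without synaptic state. It assumes the defaults from the signature behave as follows: MultiSpike emits floor(v_mem / spike_threshold) spikes whenever v_mem is positive, and MembraneSubtract(subtract_value=None) subtracts that many multiples of the threshold. The helpers lif_step and simulate are hypothetical and not part of the sinabs API.

```python
import math


def lif_step(v_mem, input_current, alpha, spike_threshold=1.0,
             min_v_mem=None, norm_input=True):
    """One discrete-time update of a single LIF neuron (assumed semantics)."""
    # Leaky integration; with norm_input the input is scaled by (1 - alpha).
    scale = (1.0 - alpha) if norm_input else 1.0
    v_mem = alpha * v_mem + scale * input_current
    # Optional lower bound on the membrane potential.
    if min_v_mem is not None:
        v_mem = max(v_mem, min_v_mem)
    # Multi-spike generation followed by a subtractive reset.
    n_spikes = math.floor(v_mem / spike_threshold) if v_mem > 0 else 0
    v_mem -= n_spikes * spike_threshold
    return v_mem, n_spikes


def simulate(inputs, tau_mem, **kwargs):
    """Run lif_step over a sequence of inputs, starting from v_mem = 0."""
    alpha = math.exp(-1.0 / tau_mem)
    v_mem, spikes = 0.0, []
    for z in inputs:
        v_mem, n = lif_step(v_mem, z, alpha, **kwargs)
        spikes.append(n)
    return spikes, v_mem


# Constant input of 3.0 for 10 steps with tau_mem = 5.
print(simulate([3.0] * 10, tau_mem=5.0, norm_input=True))
print(simulate([3.0] * 10, tau_mem=5.0, norm_input=False))
```

With norm_input=False the input of 3.0 produces three spikes at every step; with norm_input=True each step only adds (1 - alpha) * 3.0, roughly 0.54, so the first spike appears at the third step.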
- v_mem
The membrane potential resets according to reset_fn for every spike.
- i_syn
This attribute is only available if tau_syn is not None.
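When tau_syn is set, the input passes through a decaying synaptic current i_syn before reaching the membrane. The sketch below assumes the update i_syn(t+1) = alpha_syn * i_syn(t) + z(t), followed by the membrane equation with norm_input=True, where alpha_syn = exp(-1/tau_syn). This ordering is an assumption rather than something stated here, and lif_syn_step is a hypothetical helper.

```python
import math


def lif_syn_step(v_mem, i_syn, input_current, alpha_mem, alpha_syn,
                 spike_threshold=1.0):
    """One update of a LIF neuron with a decaying synaptic current (assumed form)."""
    # Synaptic current: leaky accumulation of the input.
    i_syn = alpha_syn * i_syn + input_current
    # Membrane: leaky integration of the synaptic current (norm_input=True).
    v_mem = alpha_mem * v_mem + (1.0 - alpha_mem) * i_syn
    # Multi-spike generation with subtractive reset.
    n_spikes = math.floor(v_mem / spike_threshold) if v_mem > 0 else 0
    v_mem -= n_spikes * spike_threshold
    return v_mem, i_syn, n_spikes


alpha_mem, alpha_syn = math.exp(-1 / 10.0), math.exp(-1 / 5.0)
v_mem, i_syn = 0.0, 0.0
trace = []
# A single input pulse at t=0 followed by silence.
for z in [1.0, 0.0, 0.0, 0.0, 0.0]:
    v_mem, i_syn, _ = lif_syn_step(v_mem, i_syn, z, alpha_mem, alpha_syn)
    trace.append(v_mem)
print([round(v, 4) for v in trace])
```

Because i_syn decays gradually, v_mem keeps rising for several steps after the pulse has ended, whereas without synaptic state it would start decaying immediately.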
- property alpha_mem_calculated: Tensor
Calculates alpha_mem from tau_mem, if not already known.
- property alpha_syn_calculated: Tensor
Calculates alpha_syn from tau_syn, if not already known.
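Both alpha properties rest on the relation alpha = exp(-1/tau) from the dynamics above, with tau measured in time steps. A minimal sketch of that conversion, using a hypothetical helper alpha_from_tau:

```python
import math


def alpha_from_tau(tau: float) -> float:
    """Discrete-time decay factor for a time constant tau (in time steps)."""
    return math.exp(-1.0 / tau)


# Larger time constants give decay factors closer to 1, i.e. a slower leak.
for tau in (1.0, 5.0, 20.0):
    print(tau, round(alpha_from_tau(tau), 4))
```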
- forward(input_data: Tensor) → Tensor
- Parameters:
input_data (Tensor) – Data to be processed. Expected shape: (batch, time, …)
- Returns:
Output data with same shape as input_data.
- Return type:
Tensor
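The shape contract of forward can be illustrated with a pure-Python sketch that takes a nested list shaped (batch, time, channel) and returns spike counts of the same shape, carrying one membrane state per sample and channel along the time axis. It assumes norm_input=True, no synaptic state and the subtractive multi-spike reset. The function lif_forward is hypothetical, handles only the 3-D case and starts from zero state on every call, whereas the actual layer, being a StatefulLayer, keeps its state between calls.

```python
import math


def lif_forward(input_data, tau_mem, spike_threshold=1.0):
    """Map a (batch, time, channel) nested list to spikes of the same shape."""
    alpha = math.exp(-1.0 / tau_mem)
    output = []
    for sample in input_data:                  # iterate over the batch
        v_mem = [0.0] * len(sample[0])         # one state per channel
        sample_out = []
        for frame in sample:                   # iterate over time
            frame_out = []
            for c, z in enumerate(frame):      # iterate over channels
                v = alpha * v_mem[c] + (1.0 - alpha) * z
                n = math.floor(v / spike_threshold) if v > 0 else 0
                v_mem[c] = v - n * spike_threshold
                frame_out.append(n)
            sample_out.append(frame_out)
        output.append(sample_out)
    return output


# Batch of 2 samples, 3 time steps, 2 channels.
example = [[[5.0, 0.0]] * 3, [[0.0, 0.0]] * 3]
print(lif_forward(example, tau_mem=2.0))
```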
- property tau_mem_calculated: Tensor
Calculates tau_mem from alpha_mem, if not already known.
- property tau_syn_calculated: Tensor
Calculates tau_syn from alpha_syn, if not already known.
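Going the other way, tau follows from inverting alpha = exp(-1/tau), which gives tau = -1 / ln(alpha) for 0 < alpha < 1. A sketch with a hypothetical helper tau_from_alpha:

```python
import math


def tau_from_alpha(alpha: float) -> float:
    """Invert alpha = exp(-1 / tau); only defined for 0 < alpha < 1."""
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must lie strictly between 0 and 1")
    return -1.0 / math.log(alpha)


tau = tau_from_alpha(math.exp(-1.0 / 8.0))
print(tau)  # recovers 8 up to floating-point error
```

The range check rejects alpha = 1 (no leak, where the logarithm is zero and tau would be infinite) as well as values outside (0, 1), for which the result would be undefined or negative.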