ALIF
An Adaptive Leaky Integrate-and-Fire (ALIF) layer and variations thereof.
- class sinabs.layers.ALIF(tau_mem: typing.Union[float, torch.Tensor], tau_adapt: typing.Union[float, torch.Tensor], tau_syn: typing.Optional[typing.Union[float, torch.Tensor]] = None, adapt_scale: typing.Union[float, torch.Tensor] = 1.8, spike_threshold: float = 1.0, spike_fn: typing.Callable = <class 'sinabs.activation.spike_generation.SingleSpike'>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem: typing.Optional[float] = None, shape: typing.Optional[torch.Size] = None, train_alphas: bool = False, norm_input: bool = True, record_states: bool = False)
PyTorch implementation of the Long Short-Term Memory SNN (LSNN) by Bellec et al., 2018: https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html
Neuron dynamics in discrete time:
\[ \begin{aligned}
V_{mem}(t+1) &= \alpha V_{mem}(t) + (1-\alpha) \sum w \cdot s(t)\\
B(t+1) &= b_0 + \mathrm{adapt\_scale}\, b(t)\\
b(t+1) &= \rho\, b(t) + (1-\rho)\, s(t)\\
\text{if } V_{mem}(t) &\geq B(t) \text{, then } V_{mem} \rightarrow V_{mem} - b_0,\; b \rightarrow 0
\end{aligned} \]
where \(\alpha = e^{-1/\tau_{mem}}\), \(\rho = e^{-1/\tau_{adapt}}\), \(b_0\) is the initial spike threshold (spike_threshold), and \(w \cdot s(t)\) is the input current for a spike \(s\) and weight \(w\).
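The equations above can be traced step by step for a single neuron. The following is a minimal pure-Python sketch, not the sinabs implementation: it assumes the weighted input current w · s(t) is already summed into one number per step, and it assumes one particular ordering of the updates within a step (integrate, compare against the threshold, reset, then update b). sinabs operates on tensors and may order or index these updates differently. The name simulate_alif is hypothetical.

```python
import math


def simulate_alif(inputs, tau_mem, tau_adapt, b0=1.0, adapt_scale=1.8):
    """Simulate a single ALIF neuron (hypothetical helper, not the sinabs API).

    inputs: one already-weighted input current w.s(t) per time step.
    Returns the output spikes and the threshold used at each step.
    """
    alpha = math.exp(-1.0 / tau_mem)   # membrane decay factor
    rho = math.exp(-1.0 / tau_adapt)   # decay factor of the adaptation variable b
    v, b = 0.0, 0.0
    spikes, thresholds = [], []
    for x in inputs:
        # Leaky integration of the input current
        v = alpha * v + (1.0 - alpha) * x
        # Adaptive threshold B = b0 + adapt_scale * b (b from the previous step)
        threshold = b0 + adapt_scale * b
        s = 1 if v >= threshold else 0
        if s:
            v -= b0   # subtractive reset by the base threshold b0
            b = 0.0   # reset of the adaptation variable, as in the equations
        # Decay b and add the neuron's own spike (assumed update order)
        b = rho * b + (1.0 - rho) * s
        spikes.append(s)
        thresholds.append(threshold)
    return spikes, thresholds


spikes, thresholds = simulate_alif([2.0] * 20, tau_mem=5.0, tau_adapt=10.0)
print(spikes.index(1))  # 3: first spike at step index 3
print(thresholds[4] > thresholds[0])  # True: the threshold rose after the spike
```

With a constant input of 2.0, tau_mem=5 and b_0=1, the membrane potential follows 2(1 - α^n) until it first crosses the threshold at step index 3; on the next step the threshold is raised by adapt_scale · (1 - ρ) and then decays back towards b_0.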
By default, no synaptic current dynamics are used. Specify tau_syn to apply an exponential decay kernel to the input:
\[ i(t+1) = \alpha_{syn} i(t) (1-\alpha_{syn}) + \text{input} \]
- Parameters
tau_mem (float) – Membrane potential time constant.
tau_adapt (float) – Time constant of the spike-threshold adaptation.
tau_syn (float) – Synaptic decay time constant. If None (the default), no synaptic dynamics are used.
adapt_scale (float) – The amount that the spike threshold is bumped up for every spike, after which it decays back to the initial threshold.
spike_threshold (float) – Set initial spike threshold. By default set to 1.0.
spike_fn (torch.autograd.Function) – A Sinabs or custom torch.autograd.Function that takes a dict of states, a spike threshold and a surrogate gradient function, and returns spikes. Pass the class itself rather than an instance, because torch.autograd.Function methods are static.
reset_fn (Callable) – A function that defines how the membrane potential is reset after a spike.
surrogate_grad_fn (Callable) – Choose how to define gradients for the spiking non-linearity during the backward pass. This is a function of membrane potential.
min_v_mem (float or None) – Lower bound for membrane potential v_mem, clipped at every time step.
shape (torch.Size) – Optionally initialise the layer state with the given shape. If None, the shape is inferred from the input data.
train_alphas (bool) – When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
norm_input (bool) – When True, normalise input current by tau. This helps when training time constants.
record_states (bool) – When True, will record all internal states such as v_mem or i_syn in a dictionary attribute recordings. Default is False.
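Several of the parameters above (tau_mem, tau_adapt, tau_syn, train_alphas) rely on the mapping between a time constant and its discrete decay factor exp(-1/tau). A minimal pure-Python sketch of that mapping follows. The helper names are hypothetical, and the norm_input part is an assumption: it takes norm_input=True to mean the (1 - alpha) input gain from the V_mem equation and norm_input=False to mean an unscaled input.

```python
import math


def alpha_from_tau(tau):
    # Discrete decay factor used in the neuron equations: alpha = exp(-1/tau)
    return math.exp(-1.0 / tau)


def tau_from_alpha(alpha):
    # Inverse mapping, e.g. to read a time constant back when the decay
    # factors themselves are trained (train_alphas=True)
    return -1.0 / math.log(alpha)


def steady_state_v_mem(x, tau, norm_input=True):
    # Fixed point of v = alpha * v + gain * x for a constant input x.
    # Assumption: norm_input=True uses gain = 1 - alpha, otherwise gain = 1.
    alpha = alpha_from_tau(tau)
    gain = (1.0 - alpha) if norm_input else 1.0
    return gain * x / (1.0 - alpha)


print(tau_from_alpha(alpha_from_tau(20.0)))
for tau in (5.0, 50.0):
    print(tau, steady_state_v_mem(2.0, tau), steady_state_v_mem(2.0, tau, norm_input=False))
```

For large tau, 1 - alpha is approximately 1/tau, so the unnormalised fixed point x / (1 - alpha) grows roughly linearly with tau, while the normalised one stays at x. Keeping the membrane potential on the same scale as tau changes is one reason normalisation helps when time constants are trained.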
- forward(input_data: torch.Tensor)
Forward pass with given data.
- Parameters
input_data (torch.Tensor) – Data to be processed. Expected shape: (batch, time, …)
- Returns
- torch.Tensor
Output data. Same shape as input_data.
ALIFRecurrent
- class sinabs.layers.ALIFRecurrent(tau_mem: typing.Union[float, torch.Tensor], tau_adapt: typing.Union[float, torch.Tensor], rec_connect: torch.nn.modules.module.Module, tau_syn: typing.Optional[typing.Union[float, torch.Tensor]] = None, adapt_scale: typing.Union[float, torch.Tensor] = 1.8, spike_threshold: float = 1.0, spike_fn: typing.Callable = <class 'sinabs.activation.spike_generation.SingleSpike'>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem: typing.Optional[float] = None, shape: typing.Optional[torch.Size] = None, train_alphas: bool = False, norm_input: bool = True, record_states: bool = False)
PyTorch implementation of the Long Short-Term Memory SNN (LSNN) by Bellec et al., 2018: https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html
Neuron dynamics in discrete time:
\[ \begin{aligned}
V_{mem}(t+1) &= \alpha V_{mem}(t) + (1-\alpha) \sum w \cdot s(t)\\
B(t+1) &= b_0 + \mathrm{adapt\_scale}\, b(t)\\
b(t+1) &= \rho\, b(t) + (1-\rho)\, s(t)\\
\text{if } V_{mem}(t) &\geq B(t) \text{, then } V_{mem} \rightarrow V_{mem} - b_0,\; b \rightarrow 0
\end{aligned} \]
where \(\alpha = e^{-1/\tau_{mem}}\), \(\rho = e^{-1/\tau_{adapt}}\), \(b_0\) is the initial spike threshold (spike_threshold), and \(w \cdot s(t)\) is the input current for a spike \(s\) and weight \(w\).
By default, no synaptic current dynamics are used. Specify tau_syn to apply an exponential decay kernel to the input:
\[ i(t+1) = \alpha_{syn} i(t) (1-\alpha_{syn}) + \text{input} \]
- Parameters
tau_mem (float) – Membrane potential time constant.
tau_adapt (float) – Time constant of the spike-threshold adaptation.
rec_connect (torch.nn.Module) – Module that implements the layer's recurrent connections.
tau_syn (float) – Synaptic decay time constant. If None (the default), no synaptic dynamics are used.
adapt_scale (float) – The amount that the spike threshold is bumped up for every spike, after which it decays back to the initial threshold.
spike_threshold (float) – Set initial spike threshold. By default set to 1.0.
spike_fn (torch.autograd.Function) – A Sinabs or custom torch.autograd.Function that takes a dict of states, a spike threshold and a surrogate gradient function, and returns spikes. Pass the class itself rather than an instance, because torch.autograd.Function methods are static.
reset_fn (Callable) – A function that defines how the membrane potential is reset after a spike.
surrogate_grad_fn (Callable) – Choose how to define gradients for the spiking non-linearity during the backward pass. This is a function of membrane potential.
min_v_mem (float or None) – Lower bound for membrane potential v_mem, clipped at every time step.
shape (torch.Size) – Optionally initialise the layer state with the given shape. If None, the shape is inferred from the input data.
train_alphas (bool) – When True, the discrete decay factor exp(-1/tau) is used for training rather than tau itself.
norm_input (bool) – When True, normalise input current by tau. This helps when training time constants.
record_states (bool) – When True, will record all internal states such as v_mem or i_syn in a dictionary attribute recordings. Default is False.
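Both signatures default to surrogate_grad_fn=SingleExponential(grad_width=0.5, grad_scale=1.0). To illustrate the idea, here is a minimal pure-Python sketch of a generic single-exponential surrogate gradient: the forward pass uses a hard threshold, whose true derivative is zero almost everywhere, while the backward pass substitutes an exponential bump in v_mem - threshold. The exact formula is an assumption for illustration; sinabs's SingleExponential may, for example, scale grad_width differently. The function names are hypothetical.

```python
import math


def heaviside_spike(v_mem, threshold):
    # Forward pass: emit a spike (1.0) once the membrane potential reaches the threshold
    return 1.0 if v_mem >= threshold else 0.0


def single_exponential_grad(v_mem, threshold, grad_width=0.5, grad_scale=1.0):
    # Backward pass (generic form, an assumption): an exponential bump centred
    # on the threshold stands in for d(spike)/d(v_mem)
    return grad_scale / grad_width * math.exp(-abs(v_mem - threshold) / grad_width)


# With dL/d(spike) = 1.0, the surrogate value is the gradient that reaches v_mem
for v in (0.0, 0.5, 1.0, 1.5):
    print(v, heaviside_spike(v, 1.0), single_exponential_grad(v, 1.0))
```

The surrogate peaks at the threshold and falls off symmetrically, so neurons whose membrane potential is close to firing receive the largest gradients, while those far from the threshold receive almost none.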
- forward(input_data: torch.Tensor)
Forward pass with given data.
- Parameters
input_data (torch.Tensor) – Data to be processed. Expected shape: (batch, time, …)
- Returns
- torch.Tensor
Output data. Same shape as input_data.
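ALIFRecurrent additionally takes rec_connect, a torch.nn.Module that implements the recurrent connections. The sketch below shows one way such a recurrent loop can work in pure Python. It assumes that the recurrent connections are applied to the spikes of the previous time step and that the result is added to the external input of the current step. A plain weight matrix, rec_weights, stands in for the module, the batch dimension is omitted (inputs are indexed as [time][neuron]), and the within-step update order (integrate, threshold, reset, adapt) is also an assumption. simulate_alif_recurrent and rec_weights are hypothetical names, not part of sinabs.

```python
import math


def simulate_alif_recurrent(inputs, rec_weights, tau_mem, tau_adapt, b0=1.0, adapt_scale=1.8):
    """inputs[t][i]: external current to neuron i at step t (hypothetical layout).
    rec_weights[i][j]: weight from neuron j's spike to neuron i (stand-in for rec_connect)."""
    n = len(rec_weights)
    alpha = math.exp(-1.0 / tau_mem)
    rho = math.exp(-1.0 / tau_adapt)
    v = [0.0] * n
    b = [0.0] * n
    prev_spikes = [0] * n
    out = []
    for x_t in inputs:
        # Recurrent current computed from the previous step's output spikes
        rec = [sum(w * s for w, s in zip(row, prev_spikes)) for row in rec_weights]
        spikes = []
        for i in range(n):
            v[i] = alpha * v[i] + (1.0 - alpha) * (x_t[i] + rec[i])
            threshold = b0 + adapt_scale * b[i]
            s = 1 if v[i] >= threshold else 0
            if s:
                v[i] -= b0   # subtractive reset by the base threshold
                b[i] = 0.0   # reset of the adaptation variable
            b[i] = rho * b[i] + (1.0 - rho) * s
            spikes.append(s)
        out.append(spikes)
        prev_spikes = spikes
    return out


# Neuron 0 is driven externally; neuron 1 only receives recurrent input from neuron 0
out = simulate_alif_recurrent([[2.0, 0.0]] * 10, [[0.0, 0.0], [20.0, 0.0]], tau_mem=5.0, tau_adapt=10.0)
print([s[0] for s in out].index(1))  # 3: neuron 0 spikes first
print([s[1] for s in out].index(1))  # 4: neuron 1 follows one step later
```

The one-step delay in the recurrent path is what keeps the loop causal: each step only needs spikes that were already computed on the previous step.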