Source code for sinabs.layers.alif

from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from sinabs.activation import MembraneSubtract, SingleExponential, SingleSpike

from . import functional
from .reshape import SqueezeMixin
from .stateful_layer import StatefulLayer


class ALIF(StatefulLayer):
    """Adaptive Leaky Integrate and Fire neuron layer that inherits from
    :class:`~sinabs.layers.StatefulLayer`.

    Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018:
    https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html

    Neuron dynamics in discrete time:

    .. math ::
        V(t+1) = \\alpha V(t) + (1-\\alpha) \\sum w.s(t)

        B(t+1) = b0 + \\text{adapt_scale } b(t)

        b(t+1) = \\rho b(t) + (1-\\rho) s(t)

        \\text{if } V_{mem}(t) >= B(t) \\text{, then } V_{mem} \\rightarrow V_{mem} - b0, b \\rightarrow 0

    where :math:`\\alpha = e^{-1/\\tau_{mem}}`, :math:`\\rho = e^{-1/\\tau_{adapt}}` and
    :math:`w.s(t)` is the input current for a spike s and weight w.

    By default, no synaptic current dynamics are used. You can specify tau_syn to
    apply an exponential decay kernel to the input:

    .. math ::
        i(t+1) = \\alpha_{syn} i(t) + (1-\\alpha_{syn}) \\text{ input}

    Parameters:
        tau_mem: Membrane potential time constant.
        tau_adapt: Spike threshold time constant.
        tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used,
            which is the default.
        adapt_scale: The amount by which the spike threshold is bumped up for every spike,
            after which it decays back to the initial threshold.
        spike_threshold: Spikes are emitted if v_mem is above that threshold.
            By default set to 1.0.
        spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of
            states, a spike threshold and a surrogate gradient function and returns spikes.
            Be aware that the class itself is passed here (because torch.autograd methods
            are static) rather than an object instance.
        reset_fn: A function that defines how the membrane potential is reset after a spike.
        surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity
            during the backward pass. This is a function of membrane potential.
        min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
        shape: Optionally initialise the layer state with given shape.
            If None, it will be inferred from the input size.
        train_alphas: When True, the discrete decay factor exp(-1/tau) is used for
            training rather than tau itself.
        norm_input: When True, normalise input current by tau. This helps when training
            time constants.
        record_states: When True, record all internal states such as v_mem or i_syn in a
            dictionary attribute `recordings`. Default is False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
        b: The deviation from the original spike threshold.
        spike_threshold: The current spike threshold that gets updated with every output spike.
""" def __init__( self, tau_mem: Union[float, torch.Tensor], tau_adapt: Union[float, torch.Tensor], tau_syn: Optional[Union[float, torch.Tensor]] = None, adapt_scale: Union[float, torch.Tensor] = 1.8, spike_threshold: torch.Tensor = torch.tensor(1.0), spike_fn: Callable = SingleSpike, reset_fn: Callable = MembraneSubtract(), surrogate_grad_fn: Callable = SingleExponential(), min_v_mem: Optional[float] = None, shape: Optional[torch.Size] = None, train_alphas: bool = False, norm_input: bool = True, record_states: bool = False, ): super().__init__( state_names=["v_mem", "i_syn", "b", "spike_threshold"] if tau_syn is not None else ["v_mem", "b", "spike_threshold"] ) if train_alphas: self.alpha_mem = nn.Parameter(torch.exp(-1 / torch.as_tensor(tau_mem))) self.alpha_adapt = nn.Parameter(torch.exp(-1 / torch.as_tensor(tau_adapt))) self.alpha_syn = ( nn.Parameter(torch.exp(-1 / torch.as_tensor(tau_syn))) if tau_syn is not None else None ) else: self.tau_mem = nn.Parameter(torch.as_tensor(tau_mem)) self.tau_adapt = nn.Parameter(torch.as_tensor(tau_adapt)) self.tau_syn = nn.Parameter(torch.as_tensor(tau_syn)) if tau_syn else None self.adapt_scale = adapt_scale self.b0 = spike_threshold self.spike_fn = spike_fn self.reset_fn = reset_fn self.surrogate_grad_fn = surrogate_grad_fn self.min_v_mem = min_v_mem self.train_alphas = train_alphas self.norm_input = norm_input self.record_states = record_states if shape: self.init_state_with_shape(shape) @property def alpha_mem_calculated(self): """Calculates alpha_mem from tau_mem, if not already known.""" return self.alpha_mem if self.train_alphas else torch.exp(-1 / self.tau_mem) @property def alpha_adapt_calculated(self): """Calculates alpha_adapt from tau_adapt, if not already known.""" return self.alpha_adapt if self.train_alphas else torch.exp(-1 / self.tau_adapt) @property def alpha_syn_calculated(self): """Calculates alpha_syn from tau_syn, if not already known.""" if self.train_alphas: return self.alpha_syn elif not self.train_alphas and self.tau_syn is not None: return torch.exp(-1 / self.tau_syn) else: return None
    def forward(self, input_data: torch.Tensor):
        """
        Parameters:
            input_data: Data to be processed. Expected shape: (batch, time, ...)

        Returns:
            Output data with same shape as `input_data`.
        """
        batch_size, time_steps, *trailing_dim = input_data.shape

        # Ensure the neuron state is initialised and matches the input's batch shape
        if not self.is_state_initialised() or not self.state_has_shape(
            (batch_size, *trailing_dim)
        ):
            self.init_state_with_shape((batch_size, *trailing_dim))

        alpha_mem = self.alpha_mem_calculated
        alpha_syn = self.alpha_syn_calculated
        alpha_adapt = self.alpha_adapt_calculated

        spikes, state, recordings = functional.alif_forward(
            input_data=input_data,
            alpha_mem=alpha_mem,
            alpha_adapt=alpha_adapt,
            alpha_syn=alpha_syn,
            adapt_scale=self.adapt_scale,
            state=dict(self.named_buffers()),
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            min_v_mem=self.min_v_mem,
            b0=self.b0,
            norm_input=self.norm_input,
            record_states=self.record_states,
        )
        self.b = state["b"]
        self.v_mem = state["v_mem"]
        self.spike_threshold = state["spike_threshold"]
        self.recordings = recordings
        self.firing_rate = spikes.sum() / spikes.numel()

        return spikes
    @property
    def shape(self):
        if self.is_state_initialised():
            return self.v_mem.shape
        else:
            return None

    @property
    def _param_dict(self) -> dict:
        param_dict = super()._param_dict
        param_dict.update(
            tau_mem=-1 / torch.log(self.alpha_mem.detach())
            if self.train_alphas
            else self.tau_mem,
            tau_adapt=-1 / torch.log(self.alpha_adapt.detach())
            if self.train_alphas
            else self.tau_adapt,
            tau_syn=-1 / torch.log(self.alpha_syn.detach())
            if self.train_alphas
            else self.tau_syn,
            adapt_scale=self.adapt_scale,
            spike_threshold=self.b0,
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            train_alphas=self.train_alphas,
            shape=self.shape,
            min_v_mem=self.min_v_mem,
            record_states=self.record_states,
        )
        return param_dict
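# Illustrative usage sketch (not part of the library API): shows how an ALIF layer
# might be applied to a batch of spike trains. The input shape (batch, time, channel)
# and the time-constant values below are arbitrary placeholders chosen for the example.
def _alif_usage_example() -> torch.Tensor:
    """Run a toy ALIF layer over a random (batch, time, channel) spike raster."""
    layer = ALIF(tau_mem=20.0, tau_adapt=100.0, adapt_scale=1.8)
    # Sparse binary input: 4 samples, 100 time steps, 16 channels
    spikes_in = (torch.rand(4, 100, 16) > 0.9).float()
    spikes_out = layer(spikes_in)  # output spikes, same shape as the input
    # After the call, layer.v_mem, layer.b and layer.spike_threshold hold the state
    # after the last time step, and layer.firing_rate holds the mean output firing rate.
    return spikes_out
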
class ALIFRecurrent(ALIF):
    """Adaptive Leaky Integrate and Fire neuron layer with recurrent connections
    which inherits from :class:`~sinabs.layers.ALIF`.

    Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018:
    https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html

    Neuron dynamics in discrete time:

    .. math ::
        V(t+1) = \\alpha V(t) + (1-\\alpha) \\sum w.s(t)

        B(t+1) = b0 + \\text{adapt_scale } b(t)

        b(t+1) = \\rho b(t) + (1-\\rho) s(t)

        \\text{if } V_{mem}(t) >= B(t) \\text{, then } V_{mem} \\rightarrow V_{mem} - b0, b \\rightarrow 0

    where :math:`\\alpha = e^{-1/\\tau_{mem}}`, :math:`\\rho = e^{-1/\\tau_{adapt}}` and
    :math:`w.s(t)` is the input current for a spike s and weight w.

    By default, no synaptic current dynamics are used. You can specify tau_syn to
    apply an exponential decay kernel to the input:

    .. math ::
        i(t+1) = \\alpha_{syn} i(t) + (1-\\alpha_{syn}) \\text{ input}

    Parameters:
        tau_mem: Membrane potential time constant.
        tau_adapt: Spike threshold time constant.
        rec_connect: An nn.Module which defines the recurrent connectivity, e.g. nn.Linear
        tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used,
            which is the default.
        adapt_scale: The amount by which the spike threshold is bumped up for every spike,
            after which it decays back to the initial threshold.
        spike_threshold: Spikes are emitted if v_mem is above that threshold.
            By default set to 1.0.
        spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of
            states, a spike threshold and a surrogate gradient function and returns spikes.
            Be aware that the class itself is passed here (because torch.autograd methods
            are static) rather than an object instance.
        reset_fn: A function that defines how the membrane potential is reset after a spike.
        surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity
            during the backward pass. This is a function of membrane potential.
        min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
        shape: Optionally initialise the layer state with given shape.
            If None, it will be inferred from the input size.
        train_alphas: When True, the discrete decay factor exp(-1/tau) is used for
            training rather than tau itself.
        norm_input: When True, normalise input current by tau. This helps when training
            time constants.
        record_states: When True, record all internal states such as v_mem or i_syn in a
            dictionary attribute `recordings`. Default is False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
        b: The deviation from the original spike threshold.
        spike_threshold: The current spike threshold that gets updated with every output spike.
""" def __init__( self, tau_mem: Union[float, torch.Tensor], tau_adapt: Union[float, torch.Tensor], rec_connect: torch.nn.Module, tau_syn: Optional[Union[float, torch.Tensor]] = None, adapt_scale: Union[float, torch.Tensor] = 1.8, spike_threshold: torch.Tensor = torch.tensor(1.0), spike_fn: Callable = SingleSpike, reset_fn: Callable = MembraneSubtract(), surrogate_grad_fn: Callable = SingleExponential(), min_v_mem: Optional[float] = None, shape: Optional[torch.Size] = None, train_alphas: bool = False, norm_input: bool = True, record_states: bool = False, ): super().__init__( tau_mem=tau_mem, tau_adapt=tau_adapt, tau_syn=tau_syn, adapt_scale=adapt_scale, spike_threshold=spike_threshold, spike_fn=spike_fn, reset_fn=reset_fn, surrogate_grad_fn=surrogate_grad_fn, min_v_mem=min_v_mem, shape=shape, train_alphas=train_alphas, norm_input=norm_input, record_states=record_states, ) self.rec_connect = rec_connect
    def forward(self, input_data: torch.Tensor):
        """
        Parameters:
            input_data: Data to be processed. Expected shape: (batch, time, ...)

        Returns:
            Output data with same shape as `input_data`.
        """
        batch_size, time_steps, *trailing_dim = input_data.shape

        # Ensure the neuron state is initialised and matches the input's batch shape
        if not self.is_state_initialised() or not self.state_has_shape(
            (batch_size, *trailing_dim)
        ):
            self.init_state_with_shape((batch_size, *trailing_dim))

        alpha_mem = self.alpha_mem_calculated
        alpha_syn = self.alpha_syn_calculated
        alpha_adapt = self.alpha_adapt_calculated

        spikes, state, recordings = functional.alif_recurrent(
            input_data=input_data,
            alpha_mem=alpha_mem,
            alpha_adapt=alpha_adapt,
            alpha_syn=alpha_syn,
            adapt_scale=self.adapt_scale,
            state=dict(self.named_buffers()),
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            min_v_mem=self.min_v_mem,
            rec_connect=self.rec_connect,
            b0=self.b0,
            norm_input=self.norm_input,
            record_states=self.record_states,
        )
        self.b = state["b"]
        self.v_mem = state["v_mem"]
        self.spike_threshold = state["spike_threshold"]
        self.recordings = recordings
        self.firing_rate = spikes.sum() / spikes.numel()

        return spikes
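# Illustrative usage sketch (not part of the library API): a recurrent ALIF layer
# whose output spikes are fed back through a linear projection. The feature size of 16
# channels and the other argument values are placeholder assumptions for the example.
def _alif_recurrent_usage_example() -> torch.Tensor:
    """Run a toy ALIFRecurrent layer with an all-to-all recurrent projection."""
    rec_connect = nn.Linear(16, 16, bias=False)  # recurrent connectivity over 16 channels
    layer = ALIFRecurrent(tau_mem=20.0, tau_adapt=100.0, rec_connect=rec_connect)
    spikes_in = (torch.rand(4, 100, 16) > 0.9).float()  # (batch, time, channel)
    return layer(spikes_in)  # output spikes, same shape as the input
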
class ALIFSqueeze(ALIF, SqueezeMixin):
    """ALIF layer with 4-dimensional input (Batch*Time, Channel, Height, Width).

    Same as the parent ALIF class, but takes squeezed 4D input
    (Batch*Time, Channel, Height, Width) instead of 5D input
    (Batch, Time, Channel, Height, Width) in order to be compatible with layers
    that can only take a 4D input, such as convolutional and pooling layers.

    Shape:
        - Input: :math:`(Batch \\times Time, Channel, Height, Width)` or :math:`(Batch \\times Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
        b: The deviation from the original spike threshold.
        spike_threshold: The current spike threshold that gets updated with every output spike.
    """

    def __init__(
        self,
        batch_size=None,
        num_timesteps=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.squeeze_init(batch_size, num_timesteps)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Forward call wrapper that flattens the input to and unflattens the
        output from the parent class's forward call."""
        return self.squeeze_forward(input_data, super().forward)

    @property
    def _param_dict(self) -> dict:
        return self.squeeze_param_dict(super()._param_dict)
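
# Illustrative usage sketch (not part of the library API): ALIFSqueeze operates on
# 4D tensors whose leading dimension is batch and time flattened together, the shape
# that convolutional and pooling layers expect. All sizes below are placeholder values.
def _alif_squeeze_usage_example() -> torch.Tensor:
    """Run a toy ALIFSqueeze layer on (batch*time, channel, height, width) input."""
    batch_size, num_timesteps = 4, 100
    layer = ALIFSqueeze(tau_mem=20.0, tau_adapt=100.0, batch_size=batch_size)
    # Input with batch and time flattened into the first dimension
    spikes_in = (torch.rand(batch_size * num_timesteps, 2, 8, 8) > 0.9).float()
    return layer(spikes_in)  # same 4D shape as the input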