Source code for sinabs.layers.lif

from typing import Callable, Optional, Union

import torch
import torch.nn as nn

from sinabs.activation import MembraneSubtract, MultiSpike, SingleExponential

from . import functional
from .reshape import SqueezeMixin
from .stateful_layer import StatefulLayer


class LIF(StatefulLayer):
    """Leaky Integrate and Fire neuron layer that inherits from
    :class:`~sinabs.layers.StatefulLayer`.

    Neuron dynamics in discrete time for norm_input=True:

    .. math ::
        V_{mem}(t+1) = max(\\alpha V_{mem}(t) + (1-\\alpha)\\sum z(t), V_{min})

    Neuron dynamics for norm_input=False:

    .. math ::
        V_{mem}(t+1) = max(\\alpha V_{mem}(t) + \\sum z(t), V_{min})

    where :math:`\\alpha = e^{-1/\\tau_{mem}}`, :math:`V_{min}` is a minimum membrane
    potential and :math:`\\sum z(t)` represents the sum of all input currents at time
    :math:`t`. We also reset the membrane potential according to reset_fn:

    .. math ::
        \\text{if } V_{mem}(t) >= V_{th} \\text{, then } V_{mem} \\rightarrow V_{reset}

    Parameters:
        tau_mem: Membrane potential time constant.
        tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used,
            which is the default.
        spike_threshold: Spikes are emitted if v_mem is above that threshold. By default
            set to 1.0. The value is copied, so layers never share threshold storage.
        spike_fn: Choose a Sinabs or custom torch.autograd.Function that takes a dict of
            states, a spike threshold and a surrogate gradient function and returns
            spikes. Be aware that the class itself is passed here (because
            torch.autograd methods are static) rather than an object instance.
        reset_fn: Specify how a neuron's membrane potential should be reset after a
            spike.
        surrogate_grad_fn: Choose how to define gradients for the spiking non-linearity
            during the backward pass. This is a function of membrane potential.
        min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
            None (the default) disables the bound; 0.0 is a valid bound.
        train_alphas: When True, the discrete decay factor exp(-1/tau) is used for
            training rather than tau itself.
        shape: Optionally initialise the layer state with given shape. If None, will be
            inferred from input_size.
        norm_input: When True, normalise input current by tau. This helps when training
            time constants.
        record_states: When True, will record all internal states such as v_mem or i_syn
            in a dictionary attribute `recordings`. Default is False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or
          :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
    """

    def __init__(
        self,
        tau_mem: Union[float, torch.Tensor],
        tau_syn: Optional[Union[float, torch.Tensor]] = None,
        spike_threshold: torch.Tensor = torch.tensor(1.0),
        spike_fn: Callable = MultiSpike,
        reset_fn: Callable = MembraneSubtract(),
        surrogate_grad_fn: Callable = SingleExponential(),
        min_v_mem: Optional[float] = None,
        train_alphas: bool = False,
        shape: Optional[torch.Size] = None,
        norm_input: bool = True,
        record_states: bool = False,
    ):
        super().__init__(
            state_names=["v_mem", "i_syn"] if tau_syn is not None else ["v_mem"]
        )
        if train_alphas:
            # Train the discrete decay factors directly instead of the time constants.
            self.alpha_mem = nn.Parameter(
                torch.exp(-1.0 / torch.as_tensor(tau_mem, dtype=torch.float32))
            )
            self.alpha_syn = (
                nn.Parameter(
                    torch.exp(-1.0 / torch.as_tensor(tau_syn, dtype=torch.float32))
                )
                if tau_syn is not None
                else None
            )
        else:
            self.tau_mem = nn.Parameter(torch.as_tensor(tau_mem, dtype=torch.float32))
            self.tau_syn = (
                nn.Parameter(torch.as_tensor(tau_syn, dtype=torch.float32))
                if tau_syn is not None
                else None
            )
        self.spike_fn = spike_fn
        self.reset_fn = reset_fn
        self.surrogate_grad_fn = surrogate_grad_fn
        self.train_alphas = train_alphas
        self.norm_input = norm_input
        self.record_states = record_states
        # Compare against None explicitly: a bound of 0.0 is falsy but perfectly valid.
        self.min_v_mem = (
            nn.Parameter(torch.as_tensor(min_v_mem), requires_grad=False)
            if min_v_mem is not None
            else None
        )
        # Compare against None explicitly (a multi-element tensor has no truth value),
        # and copy the value: torch.as_tensor returns the very same tensor object when
        # given a tensor, so without the clone every layer built with the default
        # argument would share storage and see each other's in-place updates.
        self.spike_threshold = (
            nn.Parameter(
                torch.as_tensor(spike_threshold).detach().clone(),
                requires_grad=False,
            )
            if spike_threshold is not None
            else None
        )
        if shape:
            self.init_state_with_shape(shape)

    @property
    def alpha_mem_calculated(self) -> torch.Tensor:
        """Calculates alpha_mem from tau_mem, if not already known."""
        if self.train_alphas:
            return self.alpha_mem
        return torch.exp(-1.0 / self.tau_mem)

    @property
    def alpha_syn_calculated(self) -> Optional[torch.Tensor]:
        """Calculates alpha_syn from tau_syn, if not already known.

        Returns None when the layer has no synaptic dynamics.
        """
        if self.train_alphas:
            return self.alpha_syn
        if self.tau_syn is not None:
            # Computed in tau_syn's own dtype (created as float32 in __init__).
            return torch.exp(-1.0 / self.tau_syn)
        return None

    @property
    def tau_mem_calculated(self) -> Optional[torch.Tensor]:
        """Calculates tau_mem from alpha_mem, if not already known."""
        if not self.train_alphas:
            return self.tau_mem
        if self.alpha_mem is None:
            return None
        return -1.0 / torch.log(self.alpha_mem)

    @property
    def tau_syn_calculated(self) -> Optional[torch.Tensor]:
        """Calculates tau_syn from alpha_syn, if not already known.

        Returns None when the layer has no synaptic dynamics.
        """
        if not self.train_alphas:
            return self.tau_syn
        if self.alpha_syn is None:
            return None
        return -1.0 / torch.log(self.alpha_syn)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Run the layer over all time steps of `input_data`.

        Besides returning the spikes, this updates `v_mem`, `i_syn`, `recordings`
        and `firing_rate` on the layer.

        Parameters:
            input_data: Data to be processed. Expected shape: (batch, time, ...)

        Returns:
            Output data with same shape as `input_data`.
        """
        batch_size, time_steps, *trailing_dim = input_data.shape
        state_shape = (batch_size, *trailing_dim)

        # Ensure the state is initialized.
        if not self.is_state_initialised():
            self.init_state_with_shape(state_shape)

        if not self.state_has_shape(state_shape):
            if not self.has_trailing_dimension(trailing_dim):
                # If the trailing dim has changed, we reinitialize the states.
                self.init_state_with_shape(state_shape)
            else:
                # Otherwise only the batch size has changed; reconcile the states
                # with the input batch size.
                self.handle_state_batch_size_mismatch(batch_size)

        alpha_mem = self.alpha_mem_calculated
        alpha_syn = self.alpha_syn_calculated

        spikes, state, recordings = functional.lif_forward(
            input_data=input_data,
            alpha_mem=alpha_mem,
            alpha_syn=alpha_syn,
            state=dict(self.named_buffers()),
            spike_threshold=self.spike_threshold,
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            min_v_mem=self.min_v_mem,
            norm_input=self.norm_input,
            record_states=self.record_states,
        )
        self.v_mem = state["v_mem"]
        self.i_syn = state["i_syn"] if alpha_syn is not None else None
        self.recordings = recordings

        self.firing_rate = spikes.mean()
        return spikes

    @property
    def shape(self):
        """Shape of `v_mem`, or None while the state is uninitialised."""
        if self.is_state_initialised():
            return self.v_mem.shape
        return None

    @property
    def _param_dict(self) -> dict:
        """Parent `_param_dict` extended with this layer's constructor arguments.

        Time constants are reported as detached tensors.
        """
        param_dict = super()._param_dict
        tau_syn = self.tau_syn_calculated
        if tau_syn is not None:
            tau_syn = tau_syn.detach()
        tau_mem = self.tau_mem_calculated
        if tau_mem is not None:
            tau_mem = tau_mem.detach()

        param_dict.update(
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            spike_threshold=self.spike_threshold,
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            train_alphas=self.train_alphas,
            shape=self.shape,
            min_v_mem=self.min_v_mem,
            norm_input=self.norm_input,
            record_states=self.record_states,
        )
        return param_dict
class LIFRecurrent(LIF):
    """Leaky Integrate and Fire neuron layer with recurrent connections which inherits
    from :class:`~sinabs.layers.LIF`.

    Neuron dynamics in discrete time for norm_input=True:

    .. math ::
        V_{mem}(t+1) = max(\\alpha V_{mem}(t) + (1-\\alpha)\\sum z_{in}(t) z_{rec}(t), V_{min})

    Neuron dynamics for norm_input=False:

    .. math ::
        V_{mem}(t+1) = max(\\alpha V_{mem}(t) + \\sum z_{in}(t) z_{rec}(t), V_{min})

    where :math:`\\alpha = e^{-1/tau_{mem}}`, :math:`V_{min}` is a minimum membrane
    potential and :math:`\\sum z_{in}(t) z_{rec}(t)` represents the sum of all input
    and recurrent currents at time :math:`t`. We also reset the membrane potential
    according to reset_fn:

    .. math ::
        \\text{if } V_{mem}(t) >= V_{th} \\text{, then } V_{mem} \\rightarrow V_{reset}

    Parameters:
        tau_mem: Membrane potential time constant.
        rec_connect: An nn.Module which defines the recurrent connectivity, e.g.
            nn.Linear
        tau_syn: Synaptic decay time constants. If None, no synaptic dynamics are used,
            which is the default.
        spike_threshold: Spikes are emitted if v_mem is above that threshold. By default
            set to 1.0.
        spike_fn: Specify how many spikes per time step per neuron can be emitted.
        reset_fn: Specify how a neuron's membrane potential should be reset after a
            spike.
        surrogate_grad_fn: Choose a surrogate gradient function from sinabs.activation
        min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
        train_alphas: When True, the discrete decay factor exp(-1/tau) is used for
            training rather than tau itself.
        shape: Optionally initialise the layer state with given shape. If None, will be
            inferred from input_size.
        norm_input: When True, normalise input current by tau. This helps when training
            time constants.
        record_states: When True, will record all internal states such as v_mem or i_syn
            in a dictionary attribute `recordings`. Default is False.

    Shape:
        - Input: :math:`(Batch, Time, Channel, Height, Width)` or
          :math:`(Batch, Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
    """

    def __init__(
        self,
        tau_mem: Union[float, torch.Tensor],
        rec_connect: torch.nn.Module,
        tau_syn: Optional[Union[float, torch.Tensor]] = None,
        spike_threshold: torch.Tensor = torch.tensor(1.0),
        spike_fn: Callable = MultiSpike,
        reset_fn: Callable = MembraneSubtract(),
        surrogate_grad_fn: Callable = SingleExponential(),
        min_v_mem: Optional[float] = None,
        train_alphas: bool = False,
        shape: Optional[torch.Size] = None,
        norm_input: bool = True,
        record_states: bool = False,
    ):
        # Everything except the recurrent module is handled by the LIF constructor.
        lif_kwargs = dict(
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            spike_threshold=spike_threshold,
            spike_fn=spike_fn,
            reset_fn=reset_fn,
            surrogate_grad_fn=surrogate_grad_fn,
            min_v_mem=min_v_mem,
            shape=shape,
            train_alphas=train_alphas,
            norm_input=norm_input,
            record_states=record_states,
        )
        super().__init__(**lif_kwargs)
        self.rec_connect = rec_connect

    def forward(self, input_data: torch.Tensor):
        """Run the recurrent layer over all time steps of `input_data`.

        Also stores the resulting `v_mem`, `i_syn`, `recordings` and `firing_rate`
        on the layer.

        Parameters:
            input_data: Data to be processed. Expected shape: (batch, time, ...)

        Returns:
            Output data with same shape as `input_data`.
        """
        n_batch, _, *feature_dims = input_data.shape
        expected_shape = (n_batch, *feature_dims)

        # (Re)create the neuron states unless they already exist with this shape.
        states_ready = self.is_state_initialised() and self.state_has_shape(
            expected_shape
        )
        if not states_ready:
            self.init_state_with_shape(expected_shape)

        decay_mem = self.alpha_mem_calculated
        decay_syn = self.alpha_syn_calculated

        spikes, new_state, recordings = functional.lif_recurrent(
            input_data=input_data,
            alpha_mem=decay_mem,
            alpha_syn=decay_syn,
            state=dict(self.named_buffers()),
            spike_threshold=self.spike_threshold,
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            min_v_mem=self.min_v_mem,
            norm_input=self.norm_input,
            rec_connect=self.rec_connect,
            record_states=self.record_states,
        )

        self.v_mem = new_state["v_mem"]
        # Without synaptic dynamics there is no synaptic current to keep.
        if decay_syn is not None:
            self.i_syn = new_state["i_syn"]
        else:
            self.i_syn = None
        self.recordings = recordings

        # Fraction of emitted spikes over all output elements.
        self.firing_rate = spikes.sum() / spikes.numel()
        return spikes

    @property
    def _param_dict(self) -> dict:
        """Parameter dict of the parent LIF layer plus the recurrent module."""
        params = super()._param_dict
        params.update(rec_connect=self.rec_connect)
        return params
class LIFSqueeze(LIF, SqueezeMixin):
    """:class:`~sinabs.layers.LIF` layer with 4-dimensional input (Batch*Time, Channel,
    Height, Width).

    Same as parent :class:`~sinabs.layers.LIF` class, only takes in squeezed 4D input
    (Batch*Time, Channel, Height, Width) instead of 5D input (Batch, Time, Channel,
    Height, Width) in order to be compatible with layers that can only take a 4D input,
    such as convolutional and pooling layers.

    Shape:
        - Input: :math:`(Batch \\times Time, Channel, Height, Width)` or
          :math:`(Batch \\times Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential resets according to reset_fn for every spike.
        i_syn: This attribute is only available if tau_syn is not None.
    """

    def __init__(
        self,
        batch_size=None,
        num_timesteps=None,
        **kwargs,
    ):
        # The remaining keyword arguments are forwarded unchanged to LIF.
        super().__init__(**kwargs)
        self.squeeze_init(batch_size, num_timesteps)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Forward call wrapper that will flatten the input to and unflatten the output
        from the super class forward call."""
        parent_forward = super().forward
        return self.squeeze_forward(input_data, parent_forward)

    @property
    def _param_dict(self) -> dict:
        """Parent parameter dict passed through `squeeze_param_dict`."""
        parent_params = super()._param_dict
        return self.squeeze_param_dict(parent_params)