from typing import Optional, Union
import torch
from .lif import LIF
from .reshape import SqueezeMixin
class ExpLeak(LIF):
    """Leaky Integrator layer which is a special case of :class:`~sinabs.layers.LIF` without
    activation function.

    Neuron dynamics in discrete time:

    .. math ::
        V_{mem}(t+1) = \\alpha V_{mem}(t) + (1-\\alpha)\\sum z(t)

    where :math:`\\alpha = e^{-1/tau_{mem}}` and :math:`\\sum z(t)` represents the sum of all
    input currents at time :math:`t`.

    Parameters:
        tau_mem: Membrane potential time constant.
        shape: Optionally initialise the layer state with given shape. If None, will be
            inferred from input_size.
        train_alphas: When True, the discrete decay factor exp(-1/tau) is used for training
            rather than tau itself.
        min_v_mem: Lower bound for membrane potential v_mem, clipped at every time step.
        norm_input: When True, normalise input current by tau. This helps when training
            time constants.
        record_states: When True, will record all internal states such as v_mem in a
            dictionary attribute `recordings`. Default is False.
    """

    def __init__(
        self,
        tau_mem: Union[float, torch.Tensor],
        shape: Optional[torch.Size] = None,
        train_alphas: bool = False,
        min_v_mem: Optional[float] = None,
        norm_input: bool = False,
        record_states: bool = False,
    ):
        # Delegate to LIF with all spiking machinery disabled: no synaptic
        # dynamics (tau_syn=None) and no spike generation/reset, which turns
        # the LIF neuron into a pure leaky integrator.
        super().__init__(
            tau_mem=tau_mem,
            tau_syn=None,
            spike_threshold=None,
            train_alphas=train_alphas,
            min_v_mem=min_v_mem,
            shape=shape,
            spike_fn=None,
            reset_fn=None,
            surrogate_grad_fn=None,
            norm_input=norm_input,
            record_states=record_states,
        )

    @property
    def _param_dict(self) -> dict:
        """Constructor-argument dict for this layer, used e.g. for deepcopy/serialisation.

        Starts from the LIF parameter dict and drops the spiking-related
        arguments that ExpLeak.__init__ hard-codes to None and therefore
        does not accept.
        """
        param_dict = super()._param_dict
        for key in ("tau_syn", "spike_fn", "reset_fn", "surrogate_grad_fn", "spike_threshold"):
            param_dict.pop(key)
        return param_dict
class ExpLeakSqueeze(ExpLeak, SqueezeMixin):
    """ExpLeak layer with 4-dimensional input (Batch*Time, Channel, Height, Width).

    Same as parent ExpLeak class, only takes in squeezed 4D input
    (Batch*Time, Channel, Height, Width) instead of 5D input
    (Batch, Time, Channel, Height, Width) in order to be compatible with
    layers that can only take a 4D input, such as convolutional and pooling layers.

    Shape:
        - Input: :math:`(Batch \\times Time, Channel, Height, Width)` or
          :math:`(Batch \\times Time, Channel)`
        - Output: Same as input.

    Attributes:
        v_mem: The membrane potential of the leaky integrator. Since ExpLeak
            disables spiking (spike_fn=None, reset_fn=None), v_mem decays
            exponentially and is never reset.
    """

    def __init__(self, batch_size=None, num_timesteps=None, **kwargs):
        # Forward all ExpLeak parameters, then set up the squeeze/unsqueeze
        # bookkeeping (exactly one of batch_size / num_timesteps is needed
        # to recover the 5D shape from the flattened 4D input).
        super().__init__(**kwargs)
        self.squeeze_init(batch_size, num_timesteps)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Forward call wrapper that will flatten the input to and unflatten the output
        from the super class forward call."""
        return self.squeeze_forward(input_data, super().forward)

    @property
    def _param_dict(self) -> dict:
        """Constructor-argument dict, extended with the squeeze parameters
        (batch_size / num_timesteps) via the SqueezeMixin."""
        return self.squeeze_param_dict(super()._param_dict)