Quickstart Sinabs#

If you’re familiar with how SNNs work, you might find this quick overview about Sinabs useful.

Sinabs is based on PyTorch#

All of Sinabs’ layers inherit from torch.nn.Module. Thus you will be able to access your parameters, wrap layers in a nn.Sequential module and all the other things that you would do with a normal PyTorch layer.

How to define your network#

We want to re-use as much PyTorch functionality as possible. We use Linear, Conv2d and AvgPool layers to define weight matrices, whereas Sinabs layers add state as well as the non-linear activation to each of those weight layers. The following defines a simple SNN that takes an input tensor of shape (Batch, Time, Channels):

import torch
import torch.nn as nn

import sinabs.activation
import sinabs.layers as sl

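model = nn.Sequential(
    nn.Linear(16, 64),
    sl.LIF(tau_mem=10.0),  # spiking layer; the tau_mem value is an assumed example
    nn.Linear(64, 4),
    sl.LIF(tau_mem=10.0),  # spiking layer; the tau_mem value is an assumed example
)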

Inference with SNNs#

For simple inference with SNNs, use the model like any other torch model:

# Define an input (Batch, Time, Channels)
input = (torch.rand(1, 100, 16) > 0.2).float()

# Compute output with the model
with torch.no_grad():
    output = model(input)

output.shape
torch.Size([1, 100, 4])

The output of the SNN model defined above has the shape (batch, time, neurons), where neurons is the number of neurons in the final layer of the model.

Note that the network state is retained after any forward pass/inference. To reset the states or their gradients, call the corresponding method on each layer: layer.reset_states() or layer.zero_grad().
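To illustrate what state retention means in practice, here is a minimal pure-Python sketch of a stateful leaky neuron. The class `ToyLeakyNeuron`, its `decay` and `threshold` parameters, and its update rule are hypothetical and are not the Sinabs implementation; only the method name `reset_states` mirrors the Sinabs method mentioned above.

```python
class ToyLeakyNeuron:
    """Hypothetical stateful neuron for illustration; NOT the Sinabs LIF layer."""

    def __init__(self, decay: float = 0.5, threshold: float = 1.0):
        self.decay = decay
        self.threshold = threshold
        self.v_mem = 0.0  # membrane potential, kept between forward calls

    def forward(self, inputs):
        """Process one input value per time step and return a list of 0/1 spikes."""
        spikes = []
        for x in inputs:
            # Leak the previous potential, then integrate the new input
            self.v_mem = self.decay * self.v_mem + x
            if self.v_mem >= self.threshold:
                spikes.append(1)
                self.v_mem -= self.threshold  # subtract the threshold after a spike
            else:
                spikes.append(0)
        return spikes

    def reset_states(self):
        """Return the membrane potential to its initial condition."""
        self.v_mem = 0.0


neuron = ToyLeakyNeuron()
first = neuron.forward([0.6, 0.6])   # starts from v_mem = 0
second = neuron.forward([0.6, 0.6])  # starts from the leftover v_mem
neuron.reset_states()
third = neuron.forward([0.6, 0.6])   # starts from v_mem = 0 again
```

In this toy example the same input produces a different spike train on the second call because the leftover membrane potential carries over, and resetting the state restores the behaviour of the first call.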

Training with BPTT#

BPTT (Back-Propagation-Through-Time) refers to training a model with data that spans several time steps. Crucially, to train models on such data, the model needs to learn the temporal dependence in the data and therefore, the computed gradients need to be propagated back in time in addition to the propagation along its layers.
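As a rough illustration of gradients flowing back in time, the sketch below applies the idea to a scalar linear recurrence h_t = a * h_(t-1) + w * x_t rather than a spiking network. The recurrence, the loss, and the function names `forward`, `loss_fn` and `bptt_grad` are assumptions chosen for this example and are not part of Sinabs.

```python
def forward(w, a, xs):
    """Run h_t = a * h_{t-1} + w * x_t from h_0 = 0 and return all states."""
    h, states = 0.0, []
    for x in xs:
        h = a * h + w * x
        states.append(h)
    return states


def loss_fn(states, target):
    """Squared error between the summed states and a target value."""
    return (sum(states) - target) ** 2


def bptt_grad(w, a, xs, target):
    """Gradient of the loss with respect to w, accumulated backwards through time."""
    states = forward(w, a, xs)
    # Direct dependence of the loss on each h_t (identical for every time step)
    dloss_dh = 2.0 * (sum(states) - target)
    grad_w, carry = 0.0, 0.0
    for x in reversed(xs):
        # carry is the total dL/dh_t: the direct term plus what flows back from t+1
        carry = dloss_dh + a * carry
        # Local dependence of h_t on w is x_t
        grad_w += carry * x
    return grad_w


xs = [1.0, 0.0, 1.0]
grad = bptt_grad(0.3, 0.5, xs, 2.0)

# Finite-difference check of the same gradient
eps = 1e-6
numeric = (
    loss_fn(forward(0.3 + eps, 0.5, xs), 2.0)
    - loss_fn(forward(0.3 - eps, 0.5, xs), 2.0)
) / (2 * eps)
```

The backward loop visits the time steps in reverse order, so an early input contributes to the gradient through every later state it influenced. An autograd framework such as PyTorch performs the equivalent bookkeeping automatically when `backward()` is called on a loss computed over a whole sequence.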

Sinabs enables you to train SNNs using BPTT to take full advantage of the temporal computation and memory afforded by spiking neurons. Below is a small example of how you can train your Sinabs models using BPTT.

We start with a couple of helper functions that loop over all the layers in our model and reset their states and gradients. You will see how they come in handy in the next code block.

# Some helper functions to reset our model during the training loops
def reset_model_states(seq_model: nn.Sequential, randomize: bool = False):
    """Method to reset the internal states of a model"""
    for lyr in seq_model:
        # Only the spiking layers hold state
        if isinstance(lyr, sl.LIF):
            lyr.reset_states(randomize=randomize)


def zero_grad_states(seq_model: nn.Sequential):
    """Method to reset the gradients of the internal states of a model"""
    for lyr in seq_model:
        if isinstance(lyr, sl.LIF):
            lyr.zero_grad()

For the purpose of this demonstration, we define a very simple toy task:

Train the model to produce 10 spikes in response to an input spike pattern from 16 spiking neurons.

For simplicity, we generate a random spike train and use that as our input spike pattern.

As with any standard training loop in PyTorch, we start by defining an optimizer and loop over several training epochs.

In each iteration of the training loop, the following steps are carried out.

  1. Reset the parameter gradients.

  2. Reset the state/vmem gradients.

  3. Reset the model state/vmem to an initial condition.

  4. Perform a forward pass.

  5. Calculate the loss.

  6. Backpropagate gradients based on the computed loss.

  7. Update parameters.

Note the additional steps 2 and 3. They are required in order to account for the stateful nature of the spiking layers in our model.

# Define an input (Batch, Time, Channels)
input_data = (torch.rand(1, 100, 16) > 0.2).float()

# Training routine
optim = torch.optim.RMSprop(model.parameters(), lr=1e-3)
num_epochs = 100
target_num_spikes = 10

for epoch in range(num_epochs):
    # Reset the gradients of the parameters
    optim.zero_grad()

    # We will also need to reset the gradients of neuron states.
    zero_grad_states(model)
    # Reset the states themselves to their initial condition.
    reset_model_states(model, randomize=False)

    # Forward pass
    out = model(input_data)
    print(f"Epoch {epoch}: Output spikes: {out.sum().item()}")

    # Compute loss
    loss = (out.sum() - target_num_spikes) ** 2

    # Back-propagate the gradients.
    loss.backward()

    # Update parameters
    optim.step()

    # Early stopping once the target spike count is reached
    if not loss:
        break

Epoch 0: Output spikes: 0.0
Epoch 1: Output spikes: 0.0
Epoch 2: Output spikes: 0.0
Epoch 3: Output spikes: 0.0
Epoch 4: Output spikes: 0.0
Epoch 5: Output spikes: 0.0
Epoch 6: Output spikes: 0.0
Epoch 7: Output spikes: 0.0
Epoch 8: Output spikes: 0.0
Epoch 9: Output spikes: 0.0
Epoch 10: Output spikes: 0.0
Epoch 11: Output spikes: 0.0
Epoch 12: Output spikes: 0.0
Epoch 13: Output spikes: 0.0
Epoch 14: Output spikes: 0.0
Epoch 15: Output spikes: 0.0
Epoch 16: Output spikes: 0.0
Epoch 17: Output spikes: 0.0
Epoch 18: Output spikes: 0.0
Epoch 19: Output spikes: 0.0
Epoch 20: Output spikes: 0.0
Epoch 21: Output spikes: 0.0
Epoch 22: Output spikes: 0.0
Epoch 23: Output spikes: 0.0
Epoch 24: Output spikes: 0.0
Epoch 25: Output spikes: 0.0
Epoch 26: Output spikes: 0.0
Epoch 27: Output spikes: 0.0
Epoch 28: Output spikes: 0.0
Epoch 29: Output spikes: 0.0
Epoch 30: Output spikes: 0.0
Epoch 31: Output spikes: 0.0
Epoch 32: Output spikes: 0.0
Epoch 33: Output spikes: 0.0
Epoch 34: Output spikes: 0.0
Epoch 35: Output spikes: 0.0
Epoch 36: Output spikes: 0.0
Epoch 37: Output spikes: 0.0
Epoch 38: Output spikes: 0.0
Epoch 39: Output spikes: 0.0
Epoch 40: Output spikes: 1.0
Epoch 41: Output spikes: 1.0
Epoch 42: Output spikes: 4.0
Epoch 43: Output spikes: 5.0
Epoch 44: Output spikes: 7.0
Epoch 45: Output spikes: 7.0
Epoch 46: Output spikes: 8.0
Epoch 47: Output spikes: 8.0
Epoch 48: Output spikes: 8.0
Epoch 49: Output spikes: 7.0
Epoch 50: Output spikes: 11.0
Epoch 51: Output spikes: 8.0
Epoch 52: Output spikes: 11.0
Epoch 53: Output spikes: 11.0
Epoch 54: Output spikes: 8.0
Epoch 55: Output spikes: 11.0
Epoch 56: Output spikes: 11.0
Epoch 57: Output spikes: 8.0
Epoch 58: Output spikes: 11.0
Epoch 59: Output spikes: 11.0
Epoch 60: Output spikes: 8.0
Epoch 61: Output spikes: 11.0
Epoch 62: Output spikes: 10.0
out.sum(), out.shape
(tensor(10., grad_fn=<SumBackward0>), torch.Size([1, 100, 4]))

We see above that the model trains to produce 10 spikes as intended.

That is it! Now you know everything you need to know about training models with Sinabs!

Working with Convolutional networks#

When working with convolutional connectivity, an nn.Conv2d layer only accepts an input tensor of shape (Batch, Channels, Height, Width). If we feed a tensor with an additional time dimension (Batch, Time, Channels, Height, Width) to such a layer, we will receive an error. To apply 2D convolutions across time, we use a small trick: we flatten the batch and time dimensions into one before feeding the data to the Conv layer. If the input is flattened, the Squeeze versions of the spiking Sinabs layers take care of expanding the time dimension appropriately, without any major changes to your model definition.

batch_size = 8
time_steps = 100

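conv_model = nn.Sequential(
    nn.Conv2d(2, 16, kernel_size=3),
    sl.LIFSqueeze(tau_mem=20.0, batch_size=batch_size),
    nn.Conv2d(16, 32, kernel_size=3),
    sl.LIFSqueeze(tau_mem=20.0, batch_size=batch_size),
    # An 8x8 input shrinks to 4x4 after two 3x3 convolutions: 32 * 4 * 4 = 512
    nn.Flatten(),
    nn.Linear(512, 4),
    # Final spiking layer, assumed here so that the model outputs spikes
    sl.LIFSqueeze(tau_mem=20.0, batch_size=batch_size),
)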

# (Batch, Time, Channels, Height, Width)
data = torch.rand(batch_size, time_steps, 2, 8, 8)

# Data reshaped to fit the flattened model definition
input = data.view(batch_size * time_steps, 2, 8, 8)

input.shape
torch.Size([800, 2, 8, 8])

The rest of the forward pass and the training loop remain the same as described in the sections above.

with torch.no_grad():
    output = conv_model(input)

This output then has to be reshaped to split the flattened dimension and restore the separate batch and time dimensions.

output_spike_raster = output.view(batch_size, time_steps, 4)
output_spike_raster.shape
torch.Size([8, 100, 4])
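To check that flattening batch and time does not mix up samples, the sketch below uses plain PyTorch only (no Sinabs layers) and compares a convolution applied to the flattened tensor with the same convolution applied to each time step separately. The small sizes and the variable names are chosen for illustration.

```python
import torch
import torch.nn as nn

batch_size, time_steps = 2, 3
conv = nn.Conv2d(2, 4, kernel_size=3)

# (Batch, Time, Channels, Height, Width)
data = torch.rand(batch_size, time_steps, 2, 8, 8)

with torch.no_grad():
    # Flatten batch and time, convolve once, then restore both dimensions
    flat_out = conv(data.view(batch_size * time_steps, 2, 8, 8))
    restored = flat_out.view(batch_size, time_steps, 4, 6, 6)

    # Reference: convolve each time step separately and stack along the time axis
    per_step = torch.stack([conv(data[:, t]) for t in range(time_steps)], dim=1)

same = torch.allclose(restored, per_step, atol=1e-5)
```

Because `view` keeps the elements in memory order, index `b * time_steps + t` of the flattened tensor corresponds to sample `b` at time step `t`, which is why reshaping the result back to (Batch, Time, ...) recovers the per-step outputs.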