Training an SNN with fewer synops

As in the previous tutorial, we start by defining a spiking model.

import torch
import torch.nn as nn
import sinabs
import sinabs.layers as sl

class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            # Merge the batch and time dimensions so the 2D conv layers are applied per time step
            sl.FlattenTime(),
            nn.Conv2d(1, 16, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(16, 32, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(32, 120, 4, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),
            nn.Linear(120, 10, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            # Split the merged dimension back into (batch, time, ...)
            sl.UnflattenTime(batch_size=batch_size),
        )

batch_size = 5
snn = SNN(batch_size=batch_size)
snn
SNN(
  (0): FlattenTime(start_dim=0, end_dim=1)
  (1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (2): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), batch_size=5, num_timesteps=-1)
  (3): SumPool2d(norm_type=1, kernel_size=2, stride=None, ceil_mode=False)
  (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (5): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), batch_size=5, num_timesteps=-1)
  (6): SumPool2d(norm_type=1, kernel_size=2, stride=None, ceil_mode=False)
  (7): Conv2d(32, 120, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (8): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), batch_size=5, num_timesteps=-1)
  (9): Flatten(start_dim=1, end_dim=-1)
  (10): Linear(in_features=120, out_features=10, bias=False)
  (11): IAFSqueeze(spike_threshold=Parameter containing:
  tensor(1.), batch_size=5, num_timesteps=-1)
  (12): UnflattenTime()
)

The SNNAnalyzer class tracks different statistics for spiking layers (such as IAF/LIF) and parameter layers (such as Conv2d/Linear). The number of synaptic operations is tracked as part of the parameter layers' statistics. If we attach such an analyzer to the model, we can use layer-wise or model-wide statistics during training, for optimization or logging purposes.

analyzer = sinabs.SNNAnalyzer(snn)
print(f"Synops before feeding input: {analyzer.get_model_statistics()['synops']}")

# Constant input with one spike per pixel at every time step: (batch, time, channel, height, width)
rand_input_spikes = torch.ones((batch_size, 10, 1, 28, 28)).float()
y_hat = snn(rand_input_spikes)
print(f"Synops after feeding input: {analyzer.get_model_statistics()['synops']}")
Synops before feeding input: 0.0
Synops after feeding input: 14127320.0

You can break down the statistics for each layer:

layer_stats = analyzer.get_layer_statistics()

for layer_type in layer_stats.keys():
    print(f"The {layer_type} layer statistics cover layers {list(layer_stats[layer_type].keys())}")
The spiking layer statistics cover layers ['2', '5', '8', '11']
The parameter layer statistics cover layers ['1', '4', '7', '10']

Once we can calculate the total number of synops, we might want to choose a target synops number as part of our objective function. If we set the number too low, the network will fail to learn anything, as there won't be any activity at all. As a rule of thumb, first do a training run without adding synops to your loss and observe how the numbers evolve. You can then set a target synops number accordingly.
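In a complete training setup you would typically add the synops term to your task loss rather than optimising it in isolation. The following is a minimal, hypothetical sketch of such a combined objective; the spike-count readout, the CrossEntropyLoss criterion and the weighting factor w_synops are illustrative assumptions and not part of the training loop used later in this tutorial.

criterion = nn.CrossEntropyLoss()
w_synops = 0.1  # hypothetical trade-off between task loss and synops penalty

def combined_loss(y_hat, targets, analyzer, target_synops):
    # Sum the output spikes over the time dimension to obtain class scores
    task_loss = criterion(y_hat.sum(1), targets)
    # Penalise the relative deviation from the synops target
    synops = analyzer.get_model_statistics()["synops"]
    synops_loss = (target_synops - synops).square() / target_synops.square()
    return task_loss + w_synops * synops_loss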

In this tutorial we're only going to optimise for the number of synaptic operations, given a constant input. We set the target to twice the number of operations of the untrained network. We log the number of synops over time, as well as the average firing rate in the network, which is closely related.

# Find out the target number of operations
target_synops = 2 * analyzer.get_model_statistics()['synops'].detach_()

optim = torch.optim.Adam(snn.parameters(), lr=1e-3)

n_synops = []
firing_rates = []
for epoch in range(100):
    # Reset neuron states and any accumulated gradients before each forward pass
    sinabs.reset_states(snn)
    sinabs.zero_grad(snn)
    optim.zero_grad()
    
    snn(rand_input_spikes)

    model_stats = analyzer.get_model_statistics()
    synops = model_stats['synops']
    firing_rate = model_stats['firing_rate']
    
    n_synops.append(synops.detach().cpu().numpy())
    firing_rates.append(firing_rate.detach().cpu().numpy())
    
    # Quadratic penalty on the relative deviation from the target synops
    synops_loss = (target_synops - synops).square() / target_synops.square()
    synops_loss.backward()
    optim.step()
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

ax1.plot(n_synops, label="Synops during training")
ax1.axhline(y=target_synops.item(), color='black', label="Target synops")
ax1.set_ylabel("Synaptic ops\nper mini-batch")
ax1.legend()

ax2.plot(firing_rates)
ax2.set_ylabel("Average firing rate\nacross all neurons")
ax2.set_xlabel("Epoch")
[Figure: synaptic operations per mini-batch approaching the target (top) and average firing rate across all neurons (bottom), over training epochs]

Using the Adam optimizer, which incorporates a form of momentum, the network quickly converges to the target number of synaptic operations. The average firing rate across all neurons follows a closely related curve, although it is always bounded between 0 and 1. We can also plot some statistics for each layer:

layer_stats = analyzer.get_layer_statistics()

for layer_name in ['2', '5', '8', '11']:
    print(f"Layer {layer_name} has {layer_stats['spiking'][layer_name]['n_neurons']} neurons.")
Layer 2 has 9216 neurons.
Layer 5 has 2048 neurons.
Layer 8 has 120 neurons.
Layer 11 has 10 neurons.
fig, axes = plt.subplots(4, 1, sharex=True, figsize=(6,6))

for axis, layer_name in zip(axes, ['2', '5', '8', '11']):
    axis.hist(layer_stats['spiking'][layer_name]['firing_rate_per_neuron'].detach().numpy().ravel(), bins=10)
    axis.set_ylabel(f"Layer {layer_name}")
axes[0].set_title("Histogram of firing rates")
axes[-1].set_xlabel("Spikes / neuron / time step");
[Figure: histograms of per-neuron firing rates for spiking layers 2, 5, 8 and 11]
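If you also want to see how the total number of synops is distributed across the network, the parameter layers can be queried in the same way as the spiking layers. This is a minimal sketch, assuming that each parameter-layer entry exposes a 'synops' statistic; check the available keys in your sinabs version if it differs.

layer_stats = analyzer.get_layer_statistics()

for layer_name in ['1', '4', '7', '10']:
    # 'synops' is assumed to hold the per-layer synaptic operation count
    print(f"Layer {layer_name}: {layer_stats['parameter'][layer_name]['synops']} synops")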