Minimise the number of synaptic operations#

As described in Sorbaro et al. 2020, reducing the number of synaptic operations (synops) can benefit both the power consumption and the latency of a network. A layer's synops count is essentially its output activation multiplied by the number of outward connections (fan-out) to the next layer. Here we show how Sinabs' synops counters make it easy to add a corresponding term to your loss function, which can then be minimised during training.
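
To make the definition concrete, here is a minimal back-of-the-envelope sketch (not part of Sinabs; the layer shape and spike probability are made up) of how the synops between two layers relate to the previous layer's output spikes and its fan-out:

import torch
import torch.nn as nn

# Hypothetical previous-layer output: spike counts of 128 neurons for one sample.
spikes_out = (torch.rand(1, 128) < 0.1).float()

# Fully connected next layer: each of the 128 neurons fans out to all 10 units.
next_layer = nn.Linear(128, 10, bias=False)
fan_out = next_layer.out_features

# Synops = total output activation of the previous layer x its fan-out.
synops = spikes_out.sum() * fan_out
print(synops)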

Training an ANN with fewer synops#

Let’s start by defining our ANN. Keep in mind that we use NeuromorphicReLU layers here, as we need to discretize the output in the forward pass to simulate the number of spikes that layer would emit. In the backward pass the derivative of the ordinary ReLU function is used (a rough sketch of this straight-through behaviour follows the model definition below).

import torch
import torch.nn as nn
import sinabs
import sinabs.layers as sl


ann = nn.Sequential(
    nn.Conv2d(1, 16, 5, bias=False),
    sl.NeuromorphicReLU(),
    sl.SumPool2d(2),
    nn.Conv2d(16, 32, 5, bias=False),
    sl.NeuromorphicReLU(),
    sl.SumPool2d(2),
    nn.Conv2d(32, 120, 4, bias=False),
    sl.NeuromorphicReLU(),
    nn.Flatten(),
    nn.Linear(120, 10, bias=False),
)
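
As a rough illustration of the behaviour described above (this is a simplified sketch, not Sinabs' actual implementation), the forward pass rounds the rectified activation to an integer spike count, while the backward pass lets gradients flow through the rounding unchanged, so only the ReLU derivative remains:

# Simplified sketch of straight-through rounding: round in the forward pass,
# pass gradients through unchanged in the backward pass.
class RoundPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Treat the rounding step as identity for gradients.
        return grad_output


def quantized_relu(x):
    # The ReLU derivative comes from torch.relu; rounding adds no gradient of its own.
    return RoundPassThrough.apply(torch.relu(x))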

Let’s apply a SynOpCounter to our ANN to track how many synaptic operations it would need in an SNN.

batch_size = 5

synops_counter_ann = sinabs.SynOpCounter(ann.modules(), sum_activations=True)

rand_input = torch.rand((batch_size, 1, 28, 28))
ann(rand_input)
print(f"Synops after feeding input: {synops_counter_ann()}")
Synops after feeding input: 0.6000000238418579
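
The counter is callable, so its value can simply be added to the task loss as a regulariser. Here is a minimal sketch of how that might look in a training loop; the data loader `train_loader`, the learning rate and the weighting factor `alpha` are illustrative placeholders:

# Hypothetical training loop: loader, learning rate and alpha are placeholders.
optimizer = torch.optim.Adam(ann.parameters(), lr=1e-3)
alpha = 1e-5  # weight of the synops regulariser

for images, labels in train_loader:
    optimizer.zero_grad()
    output = ann(images)
    # Task loss plus the synops term accumulated during the forward pass.
    loss = nn.functional.cross_entropy(output, labels) + alpha * synops_counter_ann()
    loss.backward()
    optimizer.step()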

Training an SNN with fewer synops#

Let’s start by defining our SNN model.

class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            sl.FlattenTime(),
            nn.Conv2d(1, 16, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(16, 32, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(32, 120, 4, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),
            nn.Linear(120, 10, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.UnflattenTime(batch_size=batch_size),
        )

    @property
    def spiking_layers(self):
        return [layer for layer in self.children() if isinstance(layer, sl.StatefulLayer)]

    def reset_states(self):
        for layer in self.spiking_layers:
            layer.reset_states()


batch_size = 5
snn = SNN(batch_size=batch_size)
snn
SNN(
  (0): FlattenTime(start_dim=0, end_dim=1)
  (1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (2): IAFSqueeze(spike_threshold=1.0)
  (3): SumPool2d(norm_type=1, kernel_size=2, stride=None, ceil_mode=False)
  (4): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (5): IAFSqueeze(spike_threshold=1.0)
  (6): SumPool2d(norm_type=1, kernel_size=2, stride=None, ceil_mode=False)
  (7): Conv2d(32, 120, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (8): IAFSqueeze(spike_threshold=1.0)
  (9): Flatten(start_dim=1, end_dim=-1)
  (10): Linear(in_features=120, out_features=10, bias=False)
  (11): IAFSqueeze(spike_threshold=1.0)
  (12): UnflattenTime()
)

If we apply an SNNSynOpCounter to the model, we’ll be able to track the number of synops as we feed new inputs to it.

synops_counter = sinabs.SNNSynOpCounter(snn)
print(f"Synops before feeding input: {synops_counter.get_total_synops()}")

rand_input_spikes = (torch.rand((batch_size, 10, 1, 28, 28)) < 0.1).float()
y_hat = snn(rand_input_spikes)
print(f"Synops after feeding input: {synops_counter.get_total_synops()}")
Synops before feeding input: 0.0
Synops after feeding input: 7078570.0
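
Note that the returned count is an ordinary tensor that stays attached to the autograd graph (visible as the grad_fn entries in the per-layer breakdown below), which is what allows it to be used directly as a loss term:

synops = synops_counter.get_total_synops()
print(synops.requires_grad)  # expected to be True, so gradients can flow through the count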

You can also get a more detailed count for each layer.

synops_counter.get_synops()
{2: {'Layer': 2,
  'In': tensor(3866.),
  'Fanout_Prev': 400.0,
  'SynOps': tensor(1546400.),
  'N. timesteps': 50,
  'Time window (ms)': 50.0,
  'SynOps/s': tensor(30928000.)},
 5: {'Layer': 5,
  'In': tensor(5876., grad_fn=<MulBackward0>),
  'Fanout_Prev': 800.0,
  'SynOps': tensor(4700800., grad_fn=<MulBackward0>),
  'N. timesteps': 50,
  'Time window (ms)': 50.0,
  'SynOps/s': tensor(94016000., grad_fn=<MulBackward0>)},
 8: {'Layer': 8,
  'In': tensor(433., grad_fn=<MulBackward0>),
  'Fanout_Prev': 1920.0,
  'SynOps': tensor(831360., grad_fn=<MulBackward0>),
  'N. timesteps': 50,
  'Time window (ms)': 50.0,
  'SynOps/s': tensor(16627199., grad_fn=<MulBackward0>)},
 11: {'Layer': 11,
  'In': tensor(1., grad_fn=<MulBackward0>),
  'Fanout_Prev': 10,
  'SynOps': tensor(10., grad_fn=<MulBackward0>),
  'N. timesteps': 50,
  'Time window (ms)': 50.0,
  'SynOps/s': tensor(200., grad_fn=<MulBackward0>)}}
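
Such a breakdown is handy for seeing which layer dominates the operation count. As a small sketch, the per-layer totals can be printed like this:

layer_synops = synops_counter.get_synops()
for idx, stats in layer_synops.items():
    print(f"Layer {idx}: {float(stats['SynOps']):.0f} synops")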

Once we can calculate the total synops, we might want to choose a target synops number in order to decrease power consumption. As a rule of thumb, we take half of the initial synops count as a constant target.

# done once
target_synops = synops_counter.get_total_synops() / 2

# in your training loop
synops = synops_counter.get_total_synops()
synops_loss = (target_synops - synops).square() / target_synops.square()

# combine with your usual task loss, e.g. a cross-entropy on y_hat.sum(1)
loss = task_loss + synops_loss
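
Putting everything together, here is a rough sketch of how this could sit in a complete training loop. The `spike_loader`, the learning rate and the use of cross-entropy on the summed output spikes are illustrative assumptions; note the call to reset_states() so membrane potentials do not carry over between batches.

optimizer = torch.optim.Adam(snn.parameters(), lr=1e-3)

for spikes, labels in spike_loader:  # hypothetical loader of (spike raster, label) pairs
    snn.reset_states()               # clear membrane potentials from the previous batch
    optimizer.zero_grad()

    y_hat = snn(spikes)                                            # shape (batch, time, 10)
    task_loss = nn.functional.cross_entropy(y_hat.sum(1), labels)

    synops = synops_counter.get_total_synops()
    synops_loss = (target_synops - synops).square() / target_synops.square()

    (task_loss + synops_loss).backward()
    optimizer.step()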