Number of synaptic operations and how to minimise them#

As described in Sorbaro et al. 2020, reducing the number of synaptic operations (synops) in a network can be beneficial for both power consumption and latency. This number is essentially a layer's output activation multiplied by the number of outward connections (fan-out) to the next layer. Below we describe how, using Sinabs' synops counters, you can easily add a term to your loss function that can then be minimised.
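
To make the definition concrete, here is a minimal back-of-the-envelope sketch (the spike counts and the fan-out value are made up purely for illustration): a layer's synops are its summed output activations multiplied by the fan-out to the next layer.

import torch

# toy example with made-up numbers: spike counts emitted by 4 neurons
spikes = torch.tensor([3.0, 0.0, 1.0, 2.0])
# assume each neuron connects to 5 neurons in the next layer (fan-out)
fanout = 5

# synaptic operations caused by this layer's output
synops = spikes.sum() * fanout
print(synops)  # tensor(30.)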

Training an ANN with fewer synops#

Let’s start by defining our ANN. Keep in mind that we use NeuromorphicReLUs here, as we need to discretize the output in the forward pass to simulate the number of spikes that layer would emit. In the backward pass the derivative of the ReLU function is used. A short demonstration of this behaviour follows the model definition.

import torch
import torch.nn as nn
import sinabs
import sinabs.layers as sl


ann = nn.Sequential(
    nn.Conv2d(1, 16, 5, bias=False),
    sl.NeuromorphicReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(16, 32, 5, bias=False),
    sl.NeuromorphicReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 120, 4, bias=False),
    sl.NeuromorphicReLU(),
    nn.Flatten(),
    nn.Linear(120, 10, bias=False),
)
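
As a quick sanity check of the behaviour described above (assuming the default quantize=True), we can pass a small tensor through a single NeuromorphicReLU: negative inputs are clipped, the remaining activations are discretized to integer spike counts, and gradients flow as for an ordinary ReLU.

x = torch.tensor([[-1.0, 0.4, 1.7, 3.2]], requires_grad=True)
relu = sl.NeuromorphicReLU()

y = relu(x)
print(y)  # non-negative, integer-valued "spike counts"

y.sum().backward()
print(x.grad)  # gradient of the ReLU used in the backward pass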

Let’s apply a SynOpCounter to our ANN to track how many synaptic operations it would need in an SNN.

batch_size = 5

synops_counter_ann = sinabs.SynOpCounter(ann.modules(), sum_activations=True)

rand_input = torch.rand((batch_size, 1, 28, 28))
ann(rand_input)
print(f"Synops after feeding input: {synops_counter_ann()}")
Synops after feeding input: 57.20000076293945
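
Because the synops count is returned as a tensor, it can be added directly as a term to a training loss, as described above. The sketch below is purely illustrative: the criterion, optimiser, random targets and the weighting factor are assumptions you would adapt to your own task.

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters(), lr=1e-3)
targets = torch.randint(0, 10, (batch_size,))  # placeholder labels

optimizer.zero_grad()
output = ann(rand_input)  # updates the activations seen by the counter
task_loss = criterion(output, targets)
synops_loss = 1e-5 * synops_counter_ann()  # illustrative weighting factor

loss = task_loss + synops_loss
loss.backward()
optimizer.step()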

Training an SNN with fewer synops#

Let’s start by defining our SNN model.

class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            sl.FlattenTime(),
            nn.Conv2d(1, 16, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.AvgPool2d(2),
            nn.Conv2d(16, 32, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.AvgPool2d(2),
            nn.Conv2d(32, 120, 4, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),
            nn.Linear(120, 10, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.UnflattenTime(batch_size=batch_size),
        )

    @property
    def spiking_layers(self):
        return [layer for layer in self.children() if isinstance(layer, sl.StatefulLayer)]

    def reset_states(self):
        for layer in self.spiking_layers:
            layer.reset_states()


batch_size = 5
snn = SNN(batch_size=batch_size)

If we apply an SNNSynOpCounter to the model, we’ll be able to track the number of synops as we feed new inputs to the model.

synops_counter = sinabs.SNNSynOpCounter(snn)
print(f"Synops before feeding input: {synops_counter.get_total_synops()}")

rand_input_spikes = (torch.rand((batch_size, 10, 1, 28, 28)) < 0.3).float()
y_hat = snn(rand_input_spikes)
print(f"Synops after feeding input: {synops_counter.get_total_synops()}")
Synops before feeding input: 0.0
Synops after feeding input: 12060000.0

You can also get a more detailed count for each layer.

synops_counter.get_synops()
        In                                       Fanout_Prev  SynOps                                      N. timesteps  Time window (ms)  SynOps/s
Layer
2       tensor(11695.)                           400.0        tensor(4678000.)                            50            50.0              tensor(93560000.)
5       tensor(35926., grad_fn=<MulBackward0>)   800.0        tensor(28740800., grad_fn=<MulBackward0>)  50            50.0              tensor(5.7482e+08, grad_fn=<MulBackward0>)
8       tensor(410., grad_fn=<MulBackward0>)     1920.0       tensor(787200., grad_fn=<MulBackward0>)     50            50.0              tensor(15744000., grad_fn=<MulBackward0>)
11      tensor(0., grad_fn=<MulBackward0>)       10.0         tensor(0., grad_fn=<MulBackward0>)          50            50.0              tensor(0., grad_fn=<MulBackward0>)
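
Since the per-layer breakdown is returned as a table (rendered above as a pandas DataFrame), individual columns can be read out, for example to inspect or penalise specific layers. This is only a sketch based on the columns shown above:

layer_stats = synops_counter.get_synops()
# read out the per-layer synops column from the table above
print(layer_stats["SynOps"])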

Once we can calculate the total synops, we might want to choose a target number of synops in order to decrease power consumption. As a rule of thumb, we take half of the initial number of synops as a constant target.

# done once, before training: fix a constant synops target
# (detached so that the target is treated as a constant rather than
# being backpropagated through)
target_synops = synops_counter.get_total_synops().detach() / 2

# in your training loop: read out the synops caused by the latest forward pass
# and penalise the squared relative deviation from the target
synops = synops_counter.get_total_synops()
synops_loss = (target_synops - synops).square() / target_synops.square()

# add the synops term to your task loss; the summed output spikes here merely
# stand in for a proper task loss
loss = y_hat.sum(1) + synops_loss
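
Putting everything together, a single training step could look roughly like the sketch below. The criterion, optimiser, random labels and the equal weighting of the two loss terms are illustrative assumptions rather than part of the Sinabs API; summing the output spikes over time serves as a simple classification readout.

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=1e-3)
labels = torch.randint(0, 10, (batch_size,))  # placeholder labels

optimizer.zero_grad()
snn.reset_states()  # clear membrane potentials between batches

y_hat = snn(rand_input_spikes)  # shape: (batch, time, classes)
task_loss = criterion(y_hat.sum(1), labels)  # sum spikes over time as readout

synops = synops_counter.get_total_synops()
synops_loss = (target_synops - synops).square() / target_synops.square()

loss = task_loss + synops_loss
loss.backward()
optimizer.step()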