# Number of synaptic operations and how to minimise them

As described in Sorbaro et al. 2020, reducing the number of synaptic operations (synops) in a network can lower both its power consumption and its latency. For each layer, this number is essentially the layer's output activation (the number of spikes it emits) multiplied by the number of outgoing connections per neuron (fan-out) to the next layer. Below we show how Sinabs' synops counters make it easy to add a synops term to your loss function so that it can be minimised during training.
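As a minimal sketch of this definition in plain PyTorch, assume a fully connected layer, where every input spike reaches all `out_features` neurons, so the fan-out is simply `out_features`. The helper `linear_synops` and the weight `synops_weight` are hypothetical names for illustration, not part of Sinabs:

```python
import torch
import torch.nn as nn


def linear_synops(spikes: torch.Tensor, layer: nn.Linear) -> torch.Tensor:
    """Synops caused by `spikes` arriving at a fully connected `layer`.

    Every input spike is delivered to each of the layer's output neurons,
    so the fan-out of an nn.Linear layer is its number of output features.
    """
    fan_out = layer.out_features
    return spikes.sum() * fan_out


layer = nn.Linear(120, 10, bias=False)
spikes = torch.ones(2, 120)  # 240 input spikes in total

synops = linear_synops(spikes, layer)

# Hypothetical weighting of the synops term against the task loss
synops_weight = 1e-4
task_loss = torch.tensor(0.5)
loss = task_loss + synops_weight * synops
print(synops.item(), loss.item())
```

If `spikes` are produced by the network with a usable gradient (for example a surrogate gradient), minimising `loss` also pushes the network towards lower activity.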

## Training an ANN with fewer synops

Let's start by defining our ANN. Note that we use `NeuromorphicReLU` layers here: in the forward pass they discretize the output to simulate the number of spikes the layer would emit, while in the backward pass the derivative of the ReLU function is used.
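The forward/backward behaviour described above can be sketched as a custom autograd function. `DiscretizedReLU` is a hypothetical stand-in written in plain PyTorch; Sinabs' `NeuromorphicReLU` may differ in details such as the rounding mode:

```python
import torch


class DiscretizedReLU(torch.autograd.Function):
    """Hypothetical sketch: floor(ReLU(x)) forward, ReLU derivative backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)
        # Round the ReLU output down to an integer "spike count"
        return torch.floor(torch.relu(x))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        # Ignore the rounding and use the ReLU derivative: 1 where x > 0, else 0
        return grad_output * (x > 0).to(grad_output.dtype)


x = torch.tensor([-1.0, 0.5, 1.7, 3.2], requires_grad=True)
y = DiscretizedReLU.apply(x)
y.sum().backward()
print(y.tolist())       # [0.0, 0.0, 1.0, 3.0]
print(x.grad.tolist())  # [0.0, 1.0, 1.0, 1.0]
```

Note that the input 0.5 emits no "spike" in the forward pass but still receives a gradient, which is what lets training adjust activations that are below one spike.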

```python
import torch
import torch.nn as nn
import sinabs
import sinabs.layers as sl

ann = nn.Sequential(
    nn.Conv2d(1, 16, 5, bias=False),
    sl.NeuromorphicReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(16, 32, 5, bias=False),
    sl.NeuromorphicReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 120, 4, bias=False),
    sl.NeuromorphicReLU(),
    nn.Flatten(),
    nn.Linear(120, 10, bias=False),
)
```
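To see why the final `nn.Linear` layer takes 120 inputs, the spatial size of a 28×28 input can be traced through the layers. This plain-Python check assumes stride 1 and no padding for the convolutions and non-overlapping 2×2 pooling (the PyTorch defaults used above); the helper names are illustrative only:

```python
def conv_out(size: int, kernel: int) -> int:
    # Output size of an unpadded, stride-1 convolution
    return size - kernel + 1


def pool_out(size: int, kernel: int = 2) -> int:
    # Output size of non-overlapping average pooling (stride = kernel)
    return size // kernel


size = 28
size = pool_out(conv_out(size, 5))  # 28 -> 24 -> 12
size = pool_out(conv_out(size, 5))  # 12 -> 8 -> 4
size = conv_out(size, 4)            # 4 -> 1
flat_features = 120 * size * size
print(size, flat_features)  # 1 120
```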


Let's attach a `SynOpCounter` to our ANN to track how many synaptic operations it would need when run as an SNN.

```python
batch_size = 5

# Build a synops counter over all modules of the ANN
synops_counter_ann = sinabs.SynOpCounter(ann.modules(), sum_activations=True)

rand_input = torch.rand((batch_size, 1, 28, 28))
ann(rand_input)
print(f"Synops after feeding input: {synops_counter_ann()}")
```

```text
Synops after feeding input: 0.0
```
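One plausible way to build such a counter, sketched here without assuming anything about Sinabs' internals, is with PyTorch forward hooks that accumulate a layer's output every time the model runs. `ActivationCounter` is hypothetical and hooks `nn.ReLU` instead of `NeuromorphicReLU` so that it runs with PyTorch alone:

```python
import torch
import torch.nn as nn


class ActivationCounter:
    """Hypothetical sketch: accumulate the summed output of every ReLU layer."""

    def __init__(self, model: nn.Module):
        self.total = 0.0
        self.handles = [
            m.register_forward_hook(self._hook)
            for m in model.modules()
            if isinstance(m, nn.ReLU)
        ]

    def _hook(self, module, inputs, output):
        # Called after each forward pass of a hooked layer
        self.total += output.sum().item()

    def reset(self):
        self.total = 0.0

    def remove(self):
        # Detach the hooks from the model
        for handle in self.handles:
            handle.remove()


model = nn.Sequential(nn.Linear(2, 2, bias=False), nn.ReLU())
with torch.no_grad():
    model[0].weight.copy_(torch.eye(2))  # identity weights for a predictable result

counter = ActivationCounter(model)
print(counter.total)  # 0.0 before any input
model(torch.tensor([[1.0, -2.0]]))  # ReLU output: [[1.0, 0.0]]
print(counter.total)  # 1.0
```

This sketch converts the count to a Python float for reporting; to use it as a loss term, you would accumulate the tensor itself so that gradients can flow through it.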


## Training an SNN with fewer synops

Let's start by defining our SNN model.

```python
class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            # (batch, time, ...) -> (batch * time, ...)
            sl.FlattenTime(),
            nn.Conv2d(1, 16, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.AvgPool2d(2),
            nn.Conv2d(16, 32, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.AvgPool2d(2),
            nn.Conv2d(32, 120, 4, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),
            nn.Linear(120, 10, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            # (batch * time, ...) -> (batch, time, ...)
            sl.UnflattenTime(batch_size=batch_size),
        )

    @property
    def spiking_layers(self):
        # The model is itself an nn.Sequential, so iterate over its own children
        return [layer for layer in self.children() if isinstance(layer, sl.StatefulLayer)]

    def reset_states(self):
        # Reset the internal state (membrane potentials) of all spiking layers
        for layer in self.spiking_layers:
            layer.reset_states()


batch_size = 5
snn = SNN(batch_size=batch_size)
```
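`FlattenTime` and `UnflattenTime` let standard `nn.Conv2d` and `nn.Linear` layers process an input that carries an extra time dimension. Assuming they act as plain reshapes between `(batch, time, ...)` and `(batch * time, ...)` (presumably also why the `IAFSqueeze` layers need `batch_size`, to recover the time axis internally), the idea can be sketched with tensor reshapes alone:

```python
import torch

batch_size, n_timesteps = 5, 10
x = torch.rand(batch_size, n_timesteps, 1, 28, 28)

# Merge batch and time so that Conv2d sees a plain (N, C, H, W) batch
flat = x.flatten(start_dim=0, end_dim=1)
print(flat.shape)  # torch.Size([50, 1, 28, 28])

# Split the merged dimension back into (batch, time) at the output
restored = flat.reshape(batch_size, -1, *flat.shape[1:])
print(restored.shape)  # torch.Size([5, 10, 1, 28, 28])
print(torch.equal(restored, x))  # True
```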


If we attach an `SNNSynOpCounter` to the model, we can track the number of synops as we feed new inputs to it.

```python
synops_counter = sinabs.SNNSynOpCounter(snn)
print(f"Synops before feeding input: {synops_counter.get_total_synops()}")

# Random input spikes: (batch, time, channel, height, width), firing probability 0.3
rand_input_spikes = (torch.rand((batch_size, 10, 1, 28, 28)) < 0.3).float()
y_hat = snn(rand_input_spikes)
print(f"Synops after feeding input: {synops_counter.get_total_synops()}")
```

```text
Synops before feeding input: 0.0
Synops after feeding input: 9196760.0
```


You can also get a more detailed count for each layer.

```python
synops_counter.get_synops()
```

| Layer | In | Fanout_Prev | SynOps | N. timesteps | Time window (ms) | SynOps/s |
|---|---|---|---|---|---|---|
| 2 | 11737 | 400.0 | 4694800 | 50 | 50.0 | 93896000 |
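The numbers in this row can be reproduced by hand. The sketch below assumes that `In` counts the input spikes arriving at the first convolution and that `Fanout_Prev` is that convolution's fan-out; both assumptions are consistent with the values shown, but they are read off the table rather than taken from Sinabs' source:

```python
batch_size, n_timesteps, height, width = 5, 10, 28, 28
spike_prob = 0.3

# Expected number of input spikes from the Bernoulli(0.3) input above
n_inputs = batch_size * n_timesteps * 1 * height * width  # 39200
expected_spikes = n_inputs * spike_prob  # about 11760; the table reports 11737

# Assumed fan-out of the first convolution (1 -> 16 channels, 5x5 kernel, stride 1):
# ignoring border effects, each input pixel reaches 16 * 5 * 5 output units
fan_out = 16 * 5 * 5  # 400, matching Fanout_Prev

# SynOps = In * Fanout_Prev; SynOps/s divides by the time window in seconds
measured_in = 11737
synops = measured_in * fan_out  # 4694800
time_window_ms = 50.0
synops_per_s = synops * 1000 / time_window_ms  # 93896000.0
print(n_inputs, expected_spikes, fan_out, synops, synops_per_s)
```

The 50 timesteps in the table equal `batch_size * 10`, which fits the batch and time dimensions being flattened together by `FlattenTime`.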
# done once