- class sinabs.synopcounter.SNNAnalyzer(model: Module, dt: float = 1.0)
Helper class to acquire statistics for spiking and parameter layers at the same time. A simple scaling factor based on the kernel size is not enough to calculate the number of synapses between neurons accurately, because neurons at the edges of the input have a different number of connections than neurons in the center. This is why we use a transposed convolution layer to calculate this synaptic connection map once. The number of synapses between two layers depends on all parameters of a conv layer, such as kernel size, stride and groups. A transposed convolution takes all of those parameters into account and ‘reprojects’ the output of a conv layer. As long as the spatial dimensions don’t change during training, we can reuse the same connection map, which is a tensor of the same dimensions as the layer output. We can therefore calculate the number of synaptic operations for each layer accurately by multiplying the respective connection map element-wise with the output and summing the result.
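The connection-map idea can be illustrated without any library. The sketch below is a dependency-free illustration, not the sinabs implementation: the helper names `fanout_map` and `synops` are hypothetical, and it assumes a 2D convolution with a square kernel, one input channel, `groups=1` and no dilation. It builds the map the way a transposed convolution does, by scattering a one from every output position back onto each input position its receptive field covers. It then multiplies the map element-wise with a spike map and sums the result.

```python
def fanout_map(in_h, in_w, kernel_size, stride=1, padding=0, out_channels=1):
    """Count how many output neurons each input neuron of a 2D conv drives.

    Equivalent to a transposed convolution of an all-ones output map with an
    all-ones kernel (assumes square kernel, groups=1, no dilation).
    """
    out_h = (in_h + 2 * padding - kernel_size) // stride + 1
    out_w = (in_w + 2 * padding - kernel_size) // stride + 1
    fan_out = [[0] * in_w for _ in range(in_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            for ky in range(kernel_size):
                for kx in range(kernel_size):
                    iy = oy * stride - padding + ky
                    ix = ox * stride - padding + kx
                    # Taps that fall into the zero padding have no presynaptic neuron.
                    if 0 <= iy < in_h and 0 <= ix < in_w:
                        # With groups=1, every output channel receives this input.
                        fan_out[iy][ix] += out_channels
    return fan_out


def synops(spikes, fan_out):
    """Synaptic operations: each spike is delivered to every connected output."""
    return sum(
        s * f
        for s_row, f_row in zip(spikes, fan_out)
        for s, f in zip(s_row, f_row)
    )


cmap = fanout_map(4, 4, kernel_size=3, stride=1, padding=1)
for row in cmap:
    print(row)
# [4, 6, 6, 4]
# [6, 9, 9, 6]
# [6, 9, 9, 6]
# [4, 6, 6, 4]

spikes = [[1] * 4 for _ in range(4)]  # every input neuron spikes once
print(synops(spikes, cmap))  # 100
print(16 * 3 * 3)  # 144: naive kernel-size scaling overcounts the edges
```

In PyTorch, the same map can be obtained by applying `torch.nn.functional.conv_transpose2d` to a tensor of ones shaped like the conv output, using an all-ones kernel and the conv's stride, padding, dilation and groups. When the stride does not divide the padded input evenly, `output_padding` may be needed to recover the original input size.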
- Parameters:
model (Module) – Your PyTorch model.
dt (float) – The number of milliseconds corresponding to one time step in the simulation (default 1.0).
>>> from sinabs.synopcounter import SNNAnalyzer
>>> analyzer = SNNAnalyzer(my_spiking_model)
>>> output = my_spiking_model(input_)  # forward pass
>>> layer_stats = analyzer.get_layer_statistics()
>>> model_stats = analyzer.get_model_statistics()
- get_layer_statistics(average: bool = False) → dict
Outputs a dictionary with statistics for each individual layer.
- Parameters:
average (bool) – If True, statistics such as the firing rate per neuron, the number of neurons or the number of synops are averaged across batches.
- Return type:
dict
- class sinabs.synopcounter.SynOpCounter(modules, sum_activations=True)
Counter for the synaptic operations emitted by all Neuromorphic ReLUs in an analog CNN model.
- Parameters:
modules – A list of modules, e.g. MyTorchModel.modules().
sum_activations – If True (default), returns the total number of synops as a single number; otherwise, returns a list with the synop count of each layer.
>>> from sinabs.synopcounter import SynOpCounter
>>> counter = SynOpCounter(MyTorchModel.modules(), sum_activations=True)
>>> output = MyTorchModel(input_)  # forward pass
>>> synop_count = counter()
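To make the `sum_activations` switch concrete, here is a hypothetical, dependency-free sketch, not the sinabs implementation. It assumes that each layer's activations are already integer spike counts, as a neuromorphic ReLU would produce, and that every spike from a layer is delivered to a fixed number of downstream synapses, as in a fully connected layer. `count_synops` and its arguments are illustrative names only.

```python
def count_synops(layer_activations, layer_fanouts, sum_activations=True):
    """Hypothetical synop counter illustrating the sum_activations switch.

    layer_activations: one list of non-negative spike counts per layer.
    layer_fanouts: for each layer, how many synapses every spike drives
                   (assumed constant per layer, e.g. out_features of a Linear).
    """
    # Each spike triggers one synaptic operation per downstream synapse.
    per_layer = [sum(acts) * fan for acts, fan in zip(layer_activations, layer_fanouts)]
    # Mirror the documented behavior: a single total, or one count per layer.
    return sum(per_layer) if sum_activations else per_layer


activations = [[1, 0, 2], [3, 1]]  # spike counts of two ReLU layers
fanouts = [4, 10]                  # synapses driven by each spike in that layer
print(count_synops(activations, fanouts))                         # 52
print(count_synops(activations, fanouts, sum_activations=False))  # [12, 40]
```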