# Source code for sinabs.synopcounter

import math
import warnings

import torch
from numpy import product  # NOTE(review): numpy.product is deprecated and was removed in NumPy 2.0.

from sinabs.layers import NeuromorphicReLU

def synops_hook(layer, inp, out):
    """Forward hook that records spike and synaptic-operation counts on ``layer``.

    ``SNNSynOpCounter`` attaches this via ``register_forward_hook``. After each
    forward pass it stores these attributes on the layer:

    - ``tot_in``: sum of the input tensor (total input spikes),
    - ``tot_out``: sum of the output tensor,
    - ``synops``: ``tot_in * layer.fanout``,
    - ``tsteps``: size of the input's first dimension.

    Arguments:
        layer: the module the hook is attached to; it must carry a ``fanout``
            attribute.
        inp: tuple of positional inputs passed to the module's forward;
            exactly one input is supported.
        out: the module's output tensor.
    """
    assert len(inp) == 1, "Multiple inputs not supported for synops hook"
    inp = inp[0]
    layer.tot_in = inp.sum()
    layer.tot_out = out.sum()
    layer.synops = layer.tot_in * layer.fanout
    # Assumes the first input dimension is time (number of simulation steps);
    # TODO confirm for every input layout this counter is used with.
    layer.tsteps = inp.shape[0]

class SNNSynOpCounter:
    """
    Counter for the synaptic operations emitted by all SpikingLayers in a
    spiking model.

    Note that this is automatically instantiated by `from_torch` and by
    `Network` if they are passed `synops=True`.

    A forward hook (``synops_hook``) is registered on every ``torch.nn.Conv2d``
    and ``torch.nn.Linear`` submodule of ``model``. The hook stores the
    per-layer counts of the latest forward pass on the layer itself, and the
    methods below read them back.

    Usage:
        counter = SNNSynOpCounter(my_spiking_model)
        output = my_spiking_model(input)  # forward pass
        synops_table = counter.get_synops()

    Arguments:
        model: Spiking model.
        dt: the number of milliseconds corresponding to a time step in the
            simulation (default 1.0).
    """

    def __init__(self, model, dt=1.0):
        self.model = model
        self.handles = []
        self.dt = dt
        for layer in model.modules():
            self._register_synops_hook(layer)

    def _register_synops_hook(self, layer):
        """Attach ``synops_hook`` to ``layer`` if it is a Conv2d or Linear.

        The layer's ``fanout`` is stored on the layer so the hook can convert
        input spike counts into synaptic operations. Other layer types are
        ignored.
        """
        if isinstance(layer, torch.nn.Conv2d):
            # Average fanout of one input unit: output channels times kernel
            # area, divided by the stride area. math.prod is used because
            # numpy.product was removed in NumPy 2.0.
            layer.fanout = (
                layer.out_channels
                * math.prod(layer.kernel_size)
                / math.prod(layer.stride)
            )
        elif isinstance(layer, torch.nn.Linear):
            layer.fanout = layer.out_features
        else:
            return None
        handle = layer.register_forward_hook(synops_hook)
        self.handles.append(handle)

    @staticmethod
    def _as_pair(value):
        """Normalise an int-or-sequence pooling parameter to a tuple (int -> (v, v))."""
        return (value, value) if isinstance(value, int) else tuple(value)

    def get_synops(self) -> dict:
        """
        Method to compute a table of synaptic operations for the latest forward pass.

        The returned dictionary can be parsed into a table using pandas like so,
        ```
        synops_map = counter.get_synops()
        SynOps_dataframe = pandas.DataFrame.from_dict(synops_map, "index")
        SynOps_dataframe.set_index("Layer", inplace=True)
        ```

        Input spike counts and synops of a hooked layer are multiplied by the
        kernel areas of all ``AvgPool2d`` layers met since the previous hooked
        layer (in ``model.modules()`` order). A warning is emitted for pooling
        layers whose kernel size differs from their stride.

        NOTE: this may not be accurate in presence of average pooling.

        Returns:
            SynOps_map: A dictionary keyed by the layer's index in
            ``model.modules()``. For the latest forward pass, each entry holds the
            layer's number of input spikes, fanout, synaptic operations, number
            of timesteps, total duration of simulation (ms), and number of
            synaptic operations per second.
        """
        SynOps_map = {}
        scale_facts = []
        for i, lyr in enumerate(self.model.modules()):
            if isinstance(lyr, torch.nn.AvgPool2d):
                kernel = self._as_pair(lyr.kernel_size)
                # Compare as 2-tuples so that e.g. kernel_size=2 and
                # stride=(2, 2) count as matching instead of warning spuriously.
                if kernel != self._as_pair(lyr.stride):
                    warnings.warn(
                        f"In order for the Synops counter to work accurately the pooling "
                        f"layers kernel size should match their strides. At the moment at layer {i}, "
                        f"the kernel_size = {lyr.kernel_size}, the stride = {lyr.stride}."
                    )
                scale_facts.append(kernel[0] * kernel[1])
            if hasattr(lyr, "synops"):
                # Consume all pooling factors accumulated since the last
                # hooked layer.
                scale_factor = 1
                while scale_facts:
                    scale_factor *= scale_facts.pop()
                scaled_synops = lyr.synops * scale_factor
                SynOps_map[i] = {
                    "Layer": i,
                    "In": lyr.tot_in * scale_factor,
                    "Fanout_Prev": lyr.fanout,
                    "SynOps": scaled_synops,
                    "N. timesteps": lyr.tsteps,
                    "Time window (ms)": lyr.tsteps * self.dt,
                    # synops per ms of simulated time, converted to per second
                    "SynOps/s": scaled_synops / lyr.tsteps / self.dt * 1000,
                }
        return SynOps_map

    def get_total_synops(self, per_second=False) -> float:
        """
        Faster method for computing total synaptic operations without using Pandas.

        NOTE: this may not be accurate in presence of average pooling; unlike
        `get_synops`, no average-pooling rescaling is applied here.

        Arguments:
            per_second (bool, default False): if True, gives synops per second
                instead of total synops in the last forward pass.

        Returns:
            synops: the total synops in the network, based on the last forward pass.
        """
        synops = 0.0
        for lyr in self.model.modules():
            if not hasattr(lyr, "synops"):
                continue
            if per_second:
                synops = synops + lyr.synops / lyr.tsteps / self.dt * 1000
            else:
                synops = synops + lyr.synops
        return synops

    def get_total_power_use(self, j_per_synop=1e-11):
        """
        Method to quickly get the total power use of the network, estimated
        over the latest forward pass.

        Arguments:
            j_per_synop: Energy use per synaptic operation, in joules.
                Default 1e-11 J.

        Returns:
            estimated power in mW.
        """
        tot_synops_per_s = self.get_total_synops(per_second=True)
        # synops/s * J/synop = W; * 1000 converts to mW
        power_in_mW = tot_synops_per_s * j_per_synop * 1000
        return power_in_mW

    def __del__(self):
        # getattr guards against instances whose __init__ never completed.
        for handle in getattr(self, "handles", ()):
            handle.remove()
class SynOpCounter:
    """
    Counter for the synaptic operations emitted by all Neuromorphic ReLUs in an
    analog CNN model.

    Only NeuromorphicReLU modules with a positive ``fanout`` are tracked.

    Usage:
        counter = SynOpCounter(MyTorchModel.modules(), sum_activations=True)
        output = MyTorchModule(input)  # forward pass
        synop_count = counter()

    :param modules: list of modules, e.g. MyTorchModel.modules()
    :param sum_activations: If True (default), returns a single number of synops,
        otherwise a list of layer synops.
    :raises ValueError: if no NeuromorphicReLU with positive fanout is found.
    """

    def __init__(self, modules, sum_activations=True):
        self.modules = [
            candidate
            for candidate in modules
            if isinstance(candidate, NeuromorphicReLU) and candidate.fanout > 0
        ]
        if not self.modules:
            raise ValueError("No NeuromorphicReLU found in module list.")
        self.sum_activations = sum_activations

    def __call__(self):
        """Return the tracked modules' ``activity`` values, summed or as a list."""
        activities = [tracked.activity for tracked in self.modules]
        if not self.sum_activations:
            return activities
        return torch.stack(activities).sum()