import warnings
import torch
from sinabs.layers import NeuromorphicReLU
from numpy import product
def synops_hook(layer, inp, out):
    """Forward hook that stores spike statistics of the latest pass on ``layer``.

    Meant to be attached with ``register_forward_hook``. It reads
    ``layer.fanout``, so that attribute must exist before the first forward
    pass.

    Attributes written on ``layer``:
        tot_in: sum over the input tensor.
        tot_out: sum over the output tensor.
        synops: ``tot_in * fanout``.
        tw: length of the input's first dimension (reported as the number of
            timesteps by the counters in this module).
    """
    assert len(inp) == 1, "Multiple inputs not supported for synops hook"
    spikes_in = inp[0]
    total_in = spikes_in.sum()
    layer.tot_in = total_in
    layer.tot_out = out.sum()
    layer.synops = total_in * layer.fanout
    layer.tw = spikes_in.shape[0]
class SNNSynOpCounter:
    """
    Counter for the synaptic operations emitted by all SpikingLayers in a
    spiking model.

    A forward hook is attached to every ``torch.nn.Conv2d`` and
    ``torch.nn.Linear`` module of the model. After each forward pass those
    modules carry the statistics of that pass, which the ``get_*`` methods
    read back.

    Note that this is automatically instantiated by `from_torch` and by
    `Network` if they are passed `synops=True`.

    Arguments:
        model: Spiking model.
        dt: the number of milliseconds corresponding to a time step in the \
            simulation (default 1.0).

    Example:
        >>> counter = SNNSynOpCounter(my_spiking_model)
        >>> output = my_spiking_model(input)  # forward pass
        >>> synops_table = counter.get_synops()
    """

    def __init__(self, model, dt=1.0):
        self.model = model
        self.handles = []
        self.dt = dt
        for layer in model.modules():
            self._register_synops_hook(layer)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a tuple, expanding a bare int to ``(value, value)``."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)

    def _register_synops_hook(self, layer):
        """Set ``layer.fanout`` and attach ``synops_hook`` for Conv2d/Linear layers.

        Other module types are ignored (returns None without registering).
        """
        if isinstance(layer, torch.nn.Conv2d):
            # Plain integer products instead of numpy.product, which is
            # deprecated in NumPy 1.25 and removed in NumPy 2.0.
            kernel_h, kernel_w = self._as_pair(layer.kernel_size)
            stride_h, stride_w = self._as_pair(layer.stride)
            layer.fanout = (
                layer.out_channels * (kernel_h * kernel_w) / (stride_h * stride_w)
            )
        elif isinstance(layer, torch.nn.Linear):
            layer.fanout = layer.out_features
        else:
            return None
        handle = layer.register_forward_hook(synops_hook)
        self.handles.append(handle)

    def get_synops(self) -> dict:
        """Method to compute a table of synaptic operations for the latest forward pass.

        Average-pooling layers are accounted for by multiplying the counts of
        the next hooked layer by the product of the preceding pooling kernel
        areas. A warning is emitted when a pooling layer's kernel size differs
        from its stride. Both are compared after normalization to 2-tuples, so
        ``kernel_size=2`` and ``stride=(2, 2)`` count as matching.

        .. note:: this may not be accurate in presence of average pooling.

        Returns:
            A dictionary containing layer IDs and respectively, for
            the latest forward pass performed, their number of input spikes,
            fanout, synaptic operations, number of timesteps, total duration
            of simulation, number of synaptic operations per second.

        Example:
            >>> synops_map = counter.get_synops()
            >>> SynOps_dataframe = pandas.DataFrame.from_dict(synops_map, "index")
            >>> SynOps_dataframe.set_index("Layer", inplace=True)
        """
        SynOps_map = {}
        scale_facts = []
        for i, lyr in enumerate(self.model.modules()):
            if isinstance(lyr, torch.nn.AvgPool2d):
                ks = self._as_pair(lyr.kernel_size)
                # torch uses kernel_size as the stride when none is given.
                stride = ks if lyr.stride is None else self._as_pair(lyr.stride)
                if ks != stride:
                    warnings.warn(
                        f"In order for the Synops counter to work accurately the pooling "
                        f"layers kernel size should match their strides. At the moment at layer {i}, "
                        f"the kernel_size = {lyr.kernel_size}, the stride = {lyr.stride}."
                    )
                scale_facts.append(ks[0] * ks[1])
            if hasattr(lyr, "synops"):
                # Apply (and consume) every pooling factor seen since the
                # previous hooked layer.
                scale_factor = 1
                while len(scale_facts) != 0:
                    scale_factor *= scale_facts.pop()
                SynOps_map[i] = {
                    "Layer": i,
                    "In": lyr.tot_in * scale_factor,
                    "Fanout_Prev": lyr.fanout,
                    "SynOps": lyr.synops * scale_factor,
                    "N. timesteps": lyr.tw,
                    "Time window (ms)": lyr.tw * self.dt,
                    "SynOps/s": (lyr.synops * scale_factor) / lyr.tw / self.dt * 1000,
                }
        return SynOps_map

    def get_total_synops(self, per_second=False) -> float:
        """
        Sums up total number of synaptic operations across the network.

        Unlike ``get_synops``, no average-pooling rescaling is applied here.

        .. note:: this may not be accurate in presence of average pooling.

        Arguments:
            per_second (bool, default False): if True, gives synops per second
                instead of total synops in the last forward pass.

        Returns:
            synops: the total synops in the network, based on the last forward pass.
        """
        synops = 0.0
        for lyr in self.model.modules():
            if hasattr(lyr, "synops"):
                if per_second:
                    # tw timesteps of dt milliseconds each -> per second.
                    layer_synops = lyr.synops / lyr.tw / self.dt * 1000
                else:
                    layer_synops = lyr.synops
                synops = synops + layer_synops
        return synops

    def get_total_power_use(self, j_per_synop=1e-11):
        """
        Method to quickly get the total power use of the network, estimated
        over the latest forward pass.

        Arguments:
            j_per_synop: Energy use per synaptic operation, in joules.\
                Default 1e-11 J.

        Returns:
            estimated power in mW.
        """
        tot_synops_per_s = self.get_total_synops(per_second=True)
        # synops/s * J/synop = W; * 1000 -> mW.
        power_in_mW = tot_synops_per_s * j_per_synop * 1000
        return power_in_mW

    def __del__(self):
        # Detach the forward hooks this counter registered.
        for handle in self.handles:
            handle.remove()
class SynOpCounter:
    """
    Counter for the synaptic operations emitted by all Neuromorphic ReLUs in an
    analog CNN model.

    Only ``NeuromorphicReLU`` modules with a positive ``fanout`` are tracked.
    Each call reads the tracked modules' ``activity`` attribute (presumably
    updated by their forward pass; verify against NeuromorphicReLU).

    Parameters:
        modules: list of modules, e.g. MyTorchModel.modules()
        sum_activations: If True (default), returns a single number of synops
            (a 0-d tensor), otherwise a list of layer synops.

    Raises:
        ValueError: if no NeuromorphicReLU with positive fanout is found.

    Example:
        >>> counter = SynOpCounter(MyTorchModel.modules(), sum_activations=True)
        >>> output = MyTorchModel(input)  # forward pass
        >>> synop_count = counter()
    """

    def __init__(self, modules, sum_activations=True):
        self.modules = [
            module
            for module in modules
            if isinstance(module, NeuromorphicReLU) and module.fanout > 0
        ]
        if not self.modules:
            raise ValueError("No NeuromorphicReLU found in module list.")
        self.sum_activations = sum_activations

    def __call__(self):
        synops = [module.activity for module in self.modules]
        if self.sum_activations:
            # torch.stack requires every activity to be a tensor of equal shape.
            synops = torch.stack(synops).sum()
        return synops