import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from .layers import StatefulLayer
from .synopcounter import SNNAnalyzer
from .utils import get_activations, get_network_activations
# Loose alias for sequence-like arguments such as input shapes and lists of
# layer names: a NumPy array, a list, or a tuple.
ArrayLike = Union[np.ndarray, List, Tuple]
class Network(torch.nn.Module):
    """Container pairing a spiking neural network with its analog counterpart.

    Args:
        analog_model: An artificial (non-spiking) ``torch.nn.Module``, used by
            :meth:`compare_activations` and :meth:`plot_comparison`.
        spiking_model: The spiking ``torch.nn.Module``; :meth:`forward`
            delegates to it.
        input_shape: Shape of a single input sample, without the batch
            dimension. If given together with ``spiking_model``, a dummy
            forward pass is run at construction time and every submodule of
            the spiking model gets an ``out_shape`` attribute.
        synops: If True (default: False), instantiate an ``SNNAnalyzer`` on
            the spiking model and store it as ``synops_counter``, for counting
            synaptic operations during forward passes.
        batch_size: Batch size of the dummy input used for shape inference.
        num_timesteps: Number of time steps of the dummy input; they are
            folded into the first dimension, which therefore has length
            ``batch_size * num_timesteps``.

    Attributes:
        spiking_model: The spiking model passed to the constructor.
        analog_model: The analog model passed to the constructor.
        input_shape: The ``input_shape`` passed to the constructor.
        synops: Whether a synaptic-operation counter was attached.
    """

    def __init__(
        self,
        analog_model=None,
        spiking_model=None,
        input_shape: Optional[ArrayLike] = None,
        synops: bool = False,
        batch_size: int = 1,
        num_timesteps: int = 1,
    ):
        super().__init__()
        self.spiking_model: nn.Module = spiking_model
        self.analog_model: nn.Module = analog_model
        self.input_shape = input_shape
        self.synops = synops
        if synops:
            self.synops_counter = SNNAnalyzer(self.spiking_model)
        if input_shape is not None and spiking_model is not None:
            self._compute_shapes(
                input_shape, batch_size=batch_size, num_timesteps=num_timesteps
            )

    @property
    def layers(self):
        """List of ``(name, module)`` pairs for the direct children of the spiking model."""
        return list(self.spiking_model.named_children())

    def _compute_shapes(self, input_shape, batch_size=1, num_timesteps=1):
        """Record the output shape of every submodule of the spiking model.

        Runs one forward pass on an all-zero dummy input of shape
        ``(batch_size * num_timesteps, *input_shape)`` with forward hooks
        attached. Every submodule whose output is a tensor gets an
        ``out_shape`` attribute holding that output's shape without its
        leading dimension. The hooks are removed afterwards even if the
        forward pass raises.

        Note that the dummy pass runs through the real model, so layers that
        keep state between forward passes may carry state from it.

        Args:
            input_shape: Shape of one sample, without the leading dimension.
            batch_size: Batch size of the dummy input; ``None`` means 1.
            num_timesteps: Time steps folded into the leading dimension of the
                dummy input; ``None`` means 1.
        """

        def hook(module, inp, out):
            # Modules that return tuples or other non-tensor outputs are
            # skipped instead of aborting the whole shape inference.
            if isinstance(out, torch.Tensor):
                module.out_shape = out.shape[1:]

        # The dummy input must live where the spiking model lives. A model
        # with neither parameters nor buffers (e.g. only parameter-free
        # layers) gets a CPU input instead of failing with StopIteration.
        device = infer_module_device(self.spiking_model)
        if device is None:
            device = torch.device("cpu")

        if batch_size is None:
            batch_size = 1
        if num_timesteps is None:
            num_timesteps = 1
        shape = [batch_size * num_timesteps] + list(input_shape)
        dummy_input = torch.zeros(shape, device=device)

        hook_handles = [
            layer.register_forward_hook(hook)
            for layer in self.spiking_model.modules()
        ]
        try:
            # Shape inference does not need an autograd graph.
            with torch.no_grad():
                self(dummy_input)
        finally:
            for handle in hook_handles:
                handle.remove()

    def forward(self, tsrInput) -> torch.Tensor:
        """Forward pass for this model: run ``tsrInput`` through the spiking model."""
        return self.spiking_model(tsrInput)

    def compare_activations(
        self,
        data,
        name_list: Optional[ArrayLike] = None,
        compute_rate: bool = False,
        verbose: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, ArrayLike]:
        """Compare activations of the analog model and the SNN for a given data sample.

        Args:
            data: Input for the spiking model. The analog model receives a
                single sample obtained by reducing axis 0 of ``data``
                (presumably the time axis): mean if ``compute_rate``, else sum.
            name_list: Names of the layers whose activations are compared.
                Defaults to ``"Input"`` followed by the qualified names of all
                ``StatefulLayer`` modules in the spiking model.
            compute_rate: True to compare firing rates; by default spike
                counts are compared.
            verbose: Print debugging logs to the terminal.

        Returns:
            tuple: ``(ann_activity, snn_activity, name_list)``
                - ann_activity: output activity of the ann layers
                - snn_activity: output activity of the snn layers
                - name_list: the layer names used, for plotting the comparison
        """
        if name_list is None:
            name_list = ["Input"] + [
                layer_name
                for layer_name, lyr in self.spiking_model.named_modules()
                if isinstance(lyr, StatefulLayer)
            ]
        if verbose:
            print(f"Comparing activations for {name_list}")

        # Collapse axis 0 so the analog model sees a single sample whose
        # magnitude matches the requested statistic (rate vs. count).
        if compute_rate:
            tsrAnalogData = data.mean(0).unsqueeze(0)
        else:
            tsrAnalogData = data.sum(0).unsqueeze(0)

        with torch.no_grad():
            analog_activations = get_activations(
                self.analog_model, tsrAnalogData, name_list=name_list
            )
        spike_rates = get_network_activations(
            self.spiking_model, data, name_list=name_list, bRate=compute_rate
        )
        return analog_activations, spike_rates, name_list

    def plot_comparison(
        self, data, name_list: Optional[ArrayLike] = None, compute_rate=False
    ):
        """Scatter-plot SNN activity against analog activations, one series per layer.

        Args:
            data: Data to be processed (see :meth:`compare_activations`).
            name_list: Names of the layers to compare; see
                :meth:`compare_activations` for the default.
            compute_rate: Compare firing rates instead of spike counts.

        Returns:
            tuple: ``(ann_activity, snn_activity)``
                - ann_activity: output activity of the ann layers
                - snn_activity: output activity of the snn layers
        """
        # Imported lazily so pylab (matplotlib) is only needed when plotting.
        import pylab

        analog_activations, spike_rates, name_list = self.compare_activations(
            data, name_list=name_list, compute_rate=compute_rate
        )
        for snn_activity, ann_activity, layer_name in zip(
            spike_rates, analog_activations, name_list
        ):
            pylab.scatter(snn_activity, ann_activity, label=layer_name)

        if compute_rate:
            pylab.xlabel("Spike rates (Hz)")
        else:
            pylab.xlabel("# Spike count")
        pylab.ylabel("Analog activations")
        pylab.legend()
        return analog_activations, spike_rates

    def reset_states(
        self,
        randomize: bool = False,
        value_ranges: Optional[List[Dict[str, Tuple[float, float]]]] = None,
    ):
        """Reset all neuron states in the submodules.

        Parameters
        ----------
        randomize: Bool
            If true, reset the states between a range provided. Else, the states are reset to zero.
        value_ranges: Optional[List[Dict[str, Tuple[float, float]]]]
            A list of value_range dictionaries with the same length as the total stateful layers in the module,
            in the order in which ``self.modules()`` yields them.
            Each dictionary is a key value pair: buffer_name -> (min, max) for each state that needs to be reset.
            The states are reset with a uniform distribution between the min and max values specified.
            Any state with an undefined key in this dictionary will be reset between 0 and 1
            This parameter is only used if randomize is set to true.

        Raises
        ------
        TypeError
            If ``value_ranges`` is given and its length differs from the number of stateful layers.
        """
        stateful_layers = [
            lyr for lyr in self.modules() if isinstance(lyr, StatefulLayer)
        ]
        # `is not None` (rather than truthiness) so that an empty list is
        # validated as well, instead of failing with an IndexError when indexed.
        if value_ranges is not None and len(value_ranges) != len(stateful_layers):
            raise TypeError(
                "The number of entries in value_ranges does not match the number of stateful sub modules"
            )
        for idx, lyr in enumerate(stateful_layers):
            vr = None if value_ranges is None else value_ranges[idx]
            lyr.reset_states(randomize=randomize, value_ranges=vr)

    def zero_grad(self, set_to_none: bool = False) -> None:
        """Clear gradients of the spiking model.

        Works for any spiking model, not only iterable containers such as
        ``nn.Sequential``. Submodules that override ``zero_grad`` (e.g. to
        clear state beyond parameter gradients) have their override called,
        at any nesting depth.

        Args:
            set_to_none: Forwarded to every ``zero_grad`` call.
        """
        # Clears the gradient of every parameter of the spiking model.
        self.spiking_model.zero_grad(set_to_none)
        # `nn.Module.zero_grad` does not call submodules' own `zero_grad`, so
        # invoke any overrides explicitly.
        for module in self.spiking_model.modules():
            if module is self.spiking_model:
                continue
            if type(module).zero_grad is not nn.Module.zero_grad:
                module.zero_grad(set_to_none)

    def get_synops(self, num_evs_in=None) -> dict:
        """Return synaptic-operation statistics from the attached counter.

        Delegates to ``self.synops_counter.get_synops()``. The counter is an
        ``SNNAnalyzer`` that only exists if the network was constructed with
        ``synops=True``.

        Args:
            num_evs_in: Deprecated and ignored; passing it emits a warning.

        Raises:
            AttributeError: If the network was constructed without ``synops=True``.
        """
        if num_evs_in is not None:
            warnings.warn("num_evs_in is deprecated and has no effect", stacklevel=2)
        counter = getattr(self, "synops_counter", None)
        if counter is None:
            raise AttributeError(
                "No synaptic operation counter available; "
                "construct the Network with synops=True."
            )
        return counter.get_synops()
def get_parent_module_by_name(
    root: torch.nn.Module, name: str
) -> Tuple[torch.nn.Module, str]:
    """Locate the parent of a nested Module addressed by a dotted path.

    Args:
        root: The Module in which the search starts.
        name: Dotted path of the wanted Module relative to ``root``, listing
            every intermediate module, e.g.
            ``"nested_module1.nested_module2.desired_module"`` (the form of the
            names yielded by ``root.named_modules()``).

    Returns:
        torch.nn.Module: The Module that directly holds the wanted one
            (``nested_module2`` in the example above).
        str: The last path component, e.g. ``"desired_module"``.

    Raises:
        KeyError: If the last path component is not an attribute of its
            parent. Only attribute existence is checked; it need not be a Module.
        AttributeError: If an intermediate path component does not exist.
    """
    *ancestors, leaf = name.split(".")
    parent = root
    for attr_name in ancestors:
        parent = getattr(parent, attr_name)
    if not hasattr(parent, leaf):
        raise KeyError(f"The requested module `{leaf}` could not be found.")
    return parent, leaf
def infer_module_device(module: torch.nn.Module) -> Union[torch.device, None]:
    """Infer the device a module operates on.

    The first parameter of ``module`` decides; a module without any
    parameters falls back to its first buffer.

    Args:
        module: The module whose device is to be inferred.

    Returns:
        torch.device: The device of the first parameter (or, failing that,
            the first buffer) of ``module``, or ``None`` if it has neither.
    """
    reference = next(module.parameters(), None)
    if reference is None:
        # No parameters at all: consult the buffers instead.
        reference = next(module.buffers(), None)
    return None if reference is None else reference.device