Source code for sinabs.network

import warnings

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union, List, Tuple, Dict
from .utils import get_network_activations, get_activations
from .layers import StatefulLayer
from .synopcounter import SNNSynOpCounter

ArrayLike = Union[np.ndarray, List, Tuple]


[docs]class Network(torch.nn.Module): """ Class of a spiking neural network Attributes: spiking_model: torch.nn.Module, a spiking neural network model analog_model: torch.nn.Module, an artifical neural network model input_shape: Tuple, size of input synops: If True (default: False), register hooks for counting synaptic \ operations during forward passes, instantiating `sinabs.SNNSynOpCounter`. """ def __init__( self, analog_model=None, spiking_model=None, input_shape: Optional[ArrayLike] = None, synops: bool = False, batch_size: int = 1, num_timesteps: int = 1, ): super().__init__() self.spiking_model: nn.Module = spiking_model self.analog_model: nn.Module = analog_model self.input_shape = input_shape self.synops = synops if synops: self.synops_counter = SNNSynOpCounter(self.spiking_model) if input_shape is not None and spiking_model is not None: self._compute_shapes( input_shape, batch_size=batch_size, num_timesteps=num_timesteps ) @property def layers(self): return list(self.spiking_model.named_children()) def _compute_shapes(self, input_shape, batch_size=1, num_timesteps=1): def hook(module, inp, out): module.out_shape = out.shape[1:] hook_list = [] for layer in self.spiking_model.modules(): this_hook = layer.register_forward_hook(hook) hook_list.append(this_hook) device = next(self.parameters()).device # Infer shape if batch_size is None: batch_size = 1 if num_timesteps is None: num_timesteps = 1 shape = [batch_size * num_timesteps] + list(input_shape) dummy_input = torch.zeros(shape, requires_grad=False).to(device) # do a forward pass self(dummy_input) [this_hook.remove() for this_hook in hook_list]
[docs] def forward(self, tsrInput) -> torch.Tensor: """ Forward pass for this model """ return self.spiking_model(tsrInput)
[docs] def compare_activations( self, data, name_list: Optional[ArrayLike] = None, compute_rate: bool = False, verbose: bool = False, ) -> Tuple[np.ndarray, np.ndarray, str]: """ Compare activations of the analog model and the SNN for a given data sample Args: data (np.ndarray): Data to process name_list (List[str]): list of all layer names (str) whose activations need to be compared compute_rate (bool): True if you want to compute firing rate. By default spike count is returned verbose (bool): print debugging logs to the terminal Returns: tuple: A tuple of lists (ann_activity, snn_activity, name_list) - ann_activity: output activity of the ann layers - snn_activity: output activity of the snn layers - name_list: spiking layers' name list for plotting comparison """ if name_list is None: name_list = ["Input"] for layer_name, lyr in self.spiking_model.named_modules(): if isinstance(lyr, StatefulLayer): name_list.append(layer_name) if verbose: print("Comparing activations for {0}".format(name_list)) # Calculate activations for the torch analog model if compute_rate: tsrAnalogData = data.mean(0).unsqueeze(0) else: tsrAnalogData = data.sum(0).unsqueeze(0) with torch.no_grad(): analog_activations = get_activations( self.analog_model, tsrAnalogData, name_list=name_list ) # Calculate activations for spiking model spike_rates = get_network_activations( self.spiking_model, data, name_list=name_list, bRate=compute_rate ) return analog_activations, spike_rates, name_list
[docs] def plot_comparison( self, data, name_list: Optional[ArrayLike] = None, compute_rate=False ): """ Plots a scatter plot of all the activations Args: data: Data to be processed name_list: ArrayLike with names of all the layers of interest to be compared compute_rate: Compare firing rates instead of spike count Returns: tuple: A tuple of lists (ann_activity, snn_activity) - ann_activity: output activity of the ann layers - snn_activity: output activity of the snn layers """ import pylab analog_activations, spike_rates, name_list = self.compare_activations( data, name_list=name_list, compute_rate=compute_rate ) for nLyrIdx in range(len(name_list)): pylab.scatter( spike_rates[nLyrIdx], analog_activations[nLyrIdx], label=name_list[nLyrIdx], ) if compute_rate: pylab.xlabel("Spike rates (Hz)") else: pylab.xlabel("# Spike count") pylab.ylabel("Analog activations") pylab.legend() return analog_activations, spike_rates
[docs] def reset_states( self, randomize: bool = False, value_ranges: Optional[List[Dict[str, Tuple[float, float]]]] = None, ): """ Reset all neuron states in the submodules. Parameters ---------- randomize: Bool If true, reset the states between a range provided. Else, the states are reset to zero. value_ranges: Optional[List[Dict[str, Tuple[float, float]]]] A list of value_range dictionaries with the same length as the total stateful layers in the module. Each dictionary is a key value pair: buffer_name -> (min, max) for each state that needs to be reset. The states are reset with a uniform distribution between the min and max values specified. Any state with an undefined key in this dictionary will be reset between 0 and 1 This parameter is only used if randomize is set to true. """ if value_ranges: num_stateful_layers = len( [None for mod in self.modules() if isinstance(mod, StatefulLayer)] ) if len(value_ranges) != num_stateful_layers: raise TypeError( "The number of entries in value_ranges does not match the number of stateful sub modules" ) i = 0 for lyr in self.modules(): if isinstance(lyr, StatefulLayer): if value_ranges is None: vr = None else: vr = value_ranges[i] i += 1 lyr.reset_states(randomize=randomize, value_ranges=vr)
[docs] def get_synops(self, num_evs_in=None) -> dict: """ Please see docs for `sinabs.SNNSynOpCounter.get_synops()`. """ if num_evs_in is not None: warnings.warn("num_evs_in is deprecated and has no effect") return self.synops_counter.get_synops()
def get_parent_module_by_name( root: torch.nn.Module, name: str ) -> Tuple[torch.nn.Module, str]: """ Find a nested Module of a given name inside a Module, and return its parent Module. Args: root: The Module inside which to look for the nested Module name: Name of the Module that is being searched for within root. Must contain all parent modules, separated by a `.` , e.g. "root.nested_module1.nested_module2.desired_module" Returns: torch.nn.Module: The Module that contains the Module with the given name. In the example above this would be `nested_module2`. str: The name of the child, without parent modules, e.g. "desired_module" """ if "." not in name: if not hasattr(root, name): raise KeyError(f"The requested module `{name}` could not be found.") return root, name else: child_name, *rest = name.split(".") child = getattr(root, child_name) return get_parent_module_by_name(child, ".".join(rest)) def infer_module_device(module: torch.nn.Module) -> Union[torch.device, None]: """ Infere on which device a module is operating by first looking at its parameters and then, if no parameters are found, at its buffers. Args: module: The module whose device is to be inferred. Returns: torch.device: The device of 'module', or `None` if no device has been found. """ try: return next(module.parameters()).device except StopIteration: # No parameters, try buffers try: return next(module.buffers()).device except StopIteration: # No buffers, don't infer device return None