Source code for sinabs.from_torch

import copy
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple
from warnings import warn
import torch
from torch import nn
from sinabs.activation import MembraneSubtract, MultiSpike, SingleExponential

from sinabs import Network
import sinabs.layers as sl

_backends = {"sinabs": sl}

try:
    import sinabs.exodus.layers as el
except ModuleNotFoundError:
    pass
else:
    _backends["exodus"] = el


def from_model(
    model: nn.Module,
    input_shape: Optional[Tuple[int, int, int]] = None,
    spike_threshold: float = 1.0,
    spike_fn: Callable = MultiSpike,
    reset_fn: Callable = MembraneSubtract(),
    surrogate_grad_fn: Callable = SingleExponential(),
    min_v_mem: float = -1.0,
    bias_rescaling: float = 1.0,
    batch_size: Optional[int] = None,
    num_timesteps: Optional[int] = None,
    synops: bool = False,
    add_spiking_output: bool = False,
    backend: str = "sinabs",
    kwargs_backend: dict = dict(),
):
    """
    Converts a Torch model and returns a Sinabs network object.
    The modules in the model are analyzed, and a copy is returned, with all
    ReLUs, LeakyReLUs and NeuromorphicReLUs turned into SpikingLayers.

    Parameters:
        model: Torch model
        input_shape: If provided, the layer dimensions are computed. Otherwise
            they will be computed at the first forward pass.
        spike_threshold: The membrane potential threshold for spiking (same for all layers).
        spike_fn: The spike dynamics that determine the number of output spikes
        reset_fn: The reset mechanism of the neuron (e.g. reset to zero, or subtract)
        surrogate_grad_fn: The surrogate gradient method for the spiking dynamics
        min_v_mem: The lower bound of the membrane potential (same for all layers).
        bias_rescaling: Biases are divided by this value.
        batch_size: Must be provided if `num_timesteps` is None and is ignored otherwise.
        num_timesteps: Number of timesteps per sample. If None, `batch_size` must
            be provided to separate batch and time dimensions.
        synops: If True (default: False), register hooks for counting synaptic
            operations during forward passes.
        add_spiking_output: If True (default: False), add a spiking layer to the
            end of a sequential model if not present.
        backend: String defining the simulation backend (currently sinabs or exodus)
        kwargs_backend: Dict with additional kwargs for the simulation backend
    """
    return SpkConverter(
        input_shape=input_shape,
        spike_threshold=spike_threshold,
        spike_fn=spike_fn,
        reset_fn=reset_fn,
        surrogate_grad_fn=surrogate_grad_fn,
        min_v_mem=min_v_mem,
        bias_rescaling=bias_rescaling,
        batch_size=batch_size,
        num_timesteps=num_timesteps,
        synops=synops,
        add_spiking_output=add_spiking_output,
        backend=backend,
        kwargs_backend=kwargs_backend,
    ).convert(model)
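

# A minimal usage sketch, assuming a small convolutional ANN and a 28x28
# single-channel input; the architecture, shapes and the helper name
# `_example_from_model_usage` are illustrative and not part of the module.
def _example_from_model_usage():
    ann = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, bias=False),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 26 * 26, 10, bias=False),
        nn.ReLU(),
    )
    # batch_size is needed so the spiking layers can split the flattened
    # (batch * time) first dimension of the input back into batch and time.
    snn = from_model(ann, input_shape=(1, 28, 28), batch_size=4)
    # 4 samples with 10 timesteps each, flattened into the first dimension
    return snn(torch.rand((4 * 10, 1, 28, 28)))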


@dataclass
class SpkConverter:
    """
    Converts a Torch model and returns a Sinabs network object.
    The modules in the model are analyzed, and a copy is returned, with all
    ReLUs and NeuromorphicReLUs turned into SpikingLayers.
    A short usage sketch follows this class definition.
    """

    input_shape: Optional[Tuple[int, int, int]] = None
    spike_threshold: float = 1.0
    spike_fn: Callable = MultiSpike
    reset_fn: Callable = MembraneSubtract()
    surrogate_grad_fn: Callable = SingleExponential()
    min_v_mem: float = -1.0
    bias_rescaling: float = 1.0
    batch_size: Optional[int] = None
    num_timesteps: Optional[int] = None
    synops: bool = False
    add_spiking_output: bool = False
    backend: str = "sinabs"
    kwargs_backend: dict = field(default_factory=dict)

    def relu2spiking(self):
        # Build a spiking IAFSqueeze layer from the selected backend,
        # configured with this converter's parameters.
        try:
            backend_module = _backends[self.backend]
        except KeyError:
            raise ValueError(
                f"Backend '{self.backend}' is not available. Available backends: "
                + ", ".join(_backends.keys())
            )

        return backend_module.IAFSqueeze(
            spike_threshold=self.spike_threshold,
            spike_fn=self.spike_fn,
            reset_fn=self.reset_fn,
            surrogate_grad_fn=self.surrogate_grad_fn,
            min_v_mem=self.min_v_mem,
            batch_size=self.batch_size,
            num_timesteps=self.num_timesteps,
            **self.kwargs_backend,
        ).to(self.device)

    def convert(self, model: nn.Module) -> Network:
        """
        Converts the Torch model and returns a Sinabs network object.

        Parameters:
            model: A torch module.

        Returns:
            network: The Sinabs network object created by conversion.
        """
        spk_model = copy.deepcopy(model)

        # The device is taken from the model's first parameter;
        # fall back to CPU for models without parameters.
        try:
            self.device = next(model.parameters()).device
        except StopIteration:
            self.device = torch.device("cpu")

        if self.add_spiking_output:
            # Add a spiking output to a sequential model
            if isinstance(spk_model, nn.Sequential) and not isinstance(
                spk_model[-1], (nn.ReLU, sl.NeuromorphicReLU)
            ):
                spk_model.add_module("Spiking output", nn.ReLU())
            else:
                warn(
                    "Spiking output can only be added to sequential models that "
                    "do not end in a ReLU. No layer has been added."
                )

        self.convert_module(spk_model)
        network = Network(
            model,
            spk_model,
            input_shape=self.input_shape,
            synops=self.synops,
            batch_size=self.batch_size,
            num_timesteps=self.num_timesteps,
        )

        return network

    def convert_module(self, module):
        submodules = list(module.named_children())

        for name, subm in submodules:
            # If it is one of the layers we are looking for, substitute it
            if isinstance(subm, (nn.ReLU, sl.NeuromorphicReLU)):
                setattr(module, name, self.relu2spiking())

            elif self.bias_rescaling != 1.0 and isinstance(
                subm, (nn.Linear, nn.Conv2d)
            ):
                if subm.bias is not None:
                    subm.bias.data = (
                        subm.bias.data.clone().detach() / self.bias_rescaling
                    )

            # If it in turn has children, recurse into it
            elif len(list(subm.named_children())):
                self.convert_module(subm)

            # Otherwise it is a base layer we are not interested in
            else:
                pass
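

# A short sketch of using `SpkConverter` directly, e.g. to reuse one conversion
# configuration for several models or to pick the exodus backend when it is
# installed; the tiny ANNs and the helper name are illustrative assumptions.
def _example_spkconverter_usage():
    converter = SpkConverter(
        bias_rescaling=2.0,  # biases of Linear/Conv2d layers are divided by 2
        batch_size=2,
        backend="exodus" if "exodus" in _backends else "sinabs",
    )
    ann_a = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4), nn.ReLU())
    ann_b = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
    return converter.convert(ann_a), converter.convert(ann_b)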