from_torch
This module provides support for importing models from PyTorch into Sinabs.
- class sinabs.from_torch.SpkConverter(input_shape: typing.Optional[typing.Tuple] = None, spike_threshold=1.0, spike_fn: typing.Callable = <sinabs.activation.spike_generation.MultiSpike object>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem: float = -1.0, bias_rescaling: float = 1.0, num_timesteps: typing.Optional[int] = None, batch_size: int = 1, synops: bool = False, add_spiking_output: bool = False, backend: str = 'bptt', kwargs_backend: typing.Optional[dict] = None)
Converts a Torch model and returns a Sinabs network object. The modules in the model are analyzed, and a copy is returned, with all ReLUs, LeakyReLUs and NeuromorphicReLUs turned into SpikingLayers.
- Parameters
input_shape (Optional[Tuple]) – If provided, the layer dimensions are computed. Otherwise they will be computed at the first forward pass.
spike_threshold – The membrane potential threshold for spiking layers (same for all layers).
spike_fn (Callable) – The spike-generation function that determines the number of output spikes
reset_fn (Callable) – The reset mechanism of the neuron (like reset to zero, or subtract)
surrogate_grad_fn (Callable) – The surrogate gradient method for the spiking dynamics
min_v_mem (float) – The lower bound of the potential in convolutional and linear layers (same for all layers).
bias_rescaling (float) – Biases are divided by this value.
num_timesteps (Optional[int]) – Number of timesteps per sample. If None, batch_size must be provided to separate batch and time dimensions.
batch_size (int) – Must be provided if num_timesteps is None and is ignored otherwise.
synops (bool) – If True (default: False), register hooks for counting synaptic operations during forward passes.
add_spiking_output (bool) – If True (default: False), add a spiking layer to the end of a sequential model if not present.
backend (str) – String defining the simulation backend (currently sinabs or exodus)
kwargs_backend (dict) – Dict with additional kwargs for the simulation backend
- convert(model: torch.nn.modules.module.Module) → sinabs.network.Network
Converts the Torch model and returns a Sinabs network object.
- Parameters
model (torch.nn.modules.module.Module) – A torch module.
- Returns
The Sinabs network object created by conversion.
- Return type
sinabs.network.Network
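To make the roles of spike_threshold, spike_fn, reset_fn and min_v_mem more concrete, the following is a minimal pure-Python sketch of one non-leaky integrate-and-fire neuron. It is not the Sinabs implementation: the helper names `if_neuron_step` and `run_neuron` are hypothetical, and it assumes that a multi-spike function emits one spike per threshold crossed, that a membrane-subtract reset removes one threshold per emitted spike, and that the lower bound is applied after the reset.

```python
import math


def if_neuron_step(v_mem, input_current, spike_threshold=1.0, min_v_mem=-1.0):
    """One timestep of a non-leaky integrate-and-fire neuron (illustrative only)."""
    # Integrate the input into the membrane potential.
    v_mem = v_mem + input_current
    # Assumed multi-spike behaviour: one spike per full threshold crossed.
    n_spikes = max(0, math.floor(v_mem / spike_threshold))
    # Assumed membrane-subtract reset: remove one threshold per emitted spike.
    v_mem = v_mem - n_spikes * spike_threshold
    # Clamp the potential to the lower bound (assumed to happen after the reset).
    v_mem = max(v_mem, min_v_mem)
    return n_spikes, v_mem


def run_neuron(inputs, spike_threshold=1.0, min_v_mem=-1.0):
    """Feed a sequence of inputs to one neuron; return spike counts and final potential."""
    v_mem = 0.0
    spikes = []
    for x in inputs:
        n, v_mem = if_neuron_step(v_mem, x, spike_threshold, min_v_mem)
        spikes.append(n)
    return spikes, v_mem


spikes, v_final = run_neuron([0.5, 2.5, -3.0, 0.5])
print(spikes, v_final)
```

In this sketch the second input pushes the potential to three times the threshold, so three spikes are emitted in a single step, and the large negative input is clipped at min_v_mem rather than driving the potential arbitrarily low.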
- sinabs.from_torch.from_model(model, input_shape=None, spike_threshold=1.0, spike_fn: typing.Callable = <class 'sinabs.activation.spike_generation.MultiSpike'>, reset_fn: typing.Callable = MembraneSubtract(subtract_value=None), surrogate_grad_fn: typing.Callable = SingleExponential(grad_width=0.5, grad_scale=1.0), min_v_mem=-1.0, bias_rescaling=1.0, num_timesteps=None, batch_size=1, synops=False, add_spiking_output=False, backend='sinabs', kwargs_backend=None)
Converts a Torch model and returns a Sinabs network object. The modules in the model are analyzed, and a copy is returned, with all ReLUs, LeakyReLUs and NeuromorphicReLUs turned into SpikingLayers.
- Parameters
model – Torch model
input_shape – If provided, the layer dimensions are computed. Otherwise they will be computed at the first forward pass.
spike_threshold – The membrane potential threshold for spiking (same for all layers).
spike_fn (Callable) – The spike-generation function that determines the number of output spikes
reset_fn (Callable) – The reset mechanism of the neuron (like reset to zero, or subtract)
surrogate_grad_fn (Callable) – The surrogate gradient method for the spiking dynamics
min_v_mem – The lower bound of the potential in convolutional and linear layers (same for all layers).
bias_rescaling – Biases are divided by this value.
num_timesteps – Number of timesteps per sample. If None, batch_size must be provided to separate batch and time dimensions.
batch_size – Must be provided if num_timesteps is None and is ignored otherwise.
synops – If True (default: False), register hooks for counting synaptic operations during forward passes.
add_spiking_output – If True (default: False), add a spiking layer to the end of a sequential model if not present.
backend – String defining the simulation backend (currently sinabs or exodus)
kwargs_backend – Dict with additional kwargs for the simulation backend
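The interplay between num_timesteps and batch_size can be illustrated with a small shape calculation. The sketch below is an assumption about the arithmetic implied by the parameter descriptions, not code taken from Sinabs: it supposes that batch and time are flattened into the first input dimension, and the helper name `split_batch_and_time` is hypothetical.

```python
def split_batch_and_time(first_dim, num_timesteps=None, batch_size=1):
    """Return (batch, time) for an input whose first dimension is batch * time.

    Illustrative only: num_timesteps takes precedence when given, and
    batch_size is used only when num_timesteps is None.
    """
    if num_timesteps is not None:
        if first_dim % num_timesteps != 0:
            raise ValueError("first dimension is not a multiple of num_timesteps")
        return first_dim // num_timesteps, num_timesteps
    if first_dim % batch_size != 0:
        raise ValueError("first dimension is not a multiple of batch_size")
    return batch_size, first_dim // batch_size


# 40 flattened frames with 10 timesteps per sample.
print(split_batch_and_time(40, num_timesteps=10))
# 40 flattened frames with a known batch size of 8.
print(split_batch_and_time(40, batch_size=8))
```

When both values are passed, the sketch follows the documented rule that batch_size is ignored, so only num_timesteps determines the split.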