Source code for sinabs.activation.surrogate_gradient_fn

import math
from dataclasses import dataclass

import torch


@dataclass
class Heaviside:
    """Heaviside surrogate gradient with optional shift.

    Parameters:
        window: Distance between the step of the Heaviside surrogate gradient
            and the threshold, relative to the threshold.
    """

    window: float = 1.0

    def __call__(self, v_mem, spike_threshold):
        return (v_mem >= (spike_threshold - self.window)).float() / spike_threshold
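# Illustrative usage sketch (added for this listing; not part of the sinabs
# module). It shows how Heaviside assigns a constant gradient to every
# membrane potential within `window` of the threshold. The tensor values are
# made up for demonstration.
def _example_heaviside():
    heaviside = Heaviside(window=1.0)
    v_mem = torch.tensor([-0.5, 0.3, 1.2])
    # With window == spike_threshold == 1.0, every value >= 0 receives a
    # gradient of 1 / spike_threshold.
    return heaviside(v_mem, spike_threshold=1.0)  # tensor([0., 1., 1.])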
def gaussian(x: torch.Tensor, mu: float, sigma: float):
    """Evaluate the normalized Gaussian probability density at x."""
    return torch.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (
        sigma * torch.sqrt(2 * torch.tensor(math.pi))
    )
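# Illustrative check (added for this listing; not part of the sinabs module):
# because the helper is a normalized pdf, its peak at x == mu equals
# 1 / (sigma * sqrt(2 * pi)).
def _example_gaussian_peak():
    peak = gaussian(torch.tensor(0.0), mu=0.0, sigma=0.5)
    expected = 1.0 / (0.5 * math.sqrt(2 * math.pi))
    assert torch.isclose(peak, torch.tensor(expected))
    return peak  # ~0.7979 for sigma == 0.5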
@dataclass
class Gaussian:
    """Gaussian surrogate gradient function.

    Parameters:
        mu: The mean of the Gaussian.
        sigma: The standard deviation of the Gaussian.
        grad_scale: Scale the gradients arbitrarily.
    """

    mu: float = 0.0
    sigma: float = 0.5
    grad_scale: float = 1.0

    def __call__(self, v_mem, spike_threshold):
        return (
            gaussian(x=v_mem - spike_threshold, mu=self.mu, sigma=self.sigma)
            * self.grad_scale
        )
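# Illustrative usage sketch (added for this listing; not part of the sinabs
# module): the Gaussian surrogate is largest where v_mem crosses the spike
# threshold and decays symmetrically around it.
def _example_gaussian_surrogate():
    surrogate = Gaussian(mu=0.0, sigma=0.5)
    v_mem = torch.tensor([0.0, 1.0, 2.0])
    # Peak (~0.798) at v_mem == spike_threshold, equal smaller values
    # (~0.108) one threshold-width away on either side.
    return surrogate(v_mem, spike_threshold=1.0)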
@dataclass
class MultiGaussian:
    """Surrogate gradient as defined in Yin et al., 2021.

    https://www.biorxiv.org/content/10.1101/2021.03.22.436372v2

    Parameters:
        mu: The mean of the Gaussian.
        sigma: The standard deviation of the Gaussian.
        h: Controls the magnitude of the negative parts of the kernel.
        s: Controls the width of the negative parts of the kernel.
        grad_scale: Scale the gradients arbitrarily.
    """

    mu: float = 0.0
    sigma: float = 0.5
    h: float = 0.15
    s: float = 6.0
    grad_scale: float = 1.0

    def __call__(self, v_mem, spike_threshold):
        return (
            (1 + self.h)
            * gaussian(x=v_mem - spike_threshold, mu=self.mu, sigma=self.sigma)
            - self.h
            * gaussian(
                x=v_mem - spike_threshold, mu=self.sigma, sigma=self.s * self.sigma
            )
            - self.h
            * gaussian(
                x=v_mem - spike_threshold, mu=-self.sigma, sigma=self.s * self.sigma
            )
        ) * self.grad_scale
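# Illustrative usage sketch (added for this listing; not part of the sinabs
# module): MultiGaussian subtracts two wide Gaussians centered at +/- sigma,
# producing negative side lobes around the central positive peak.
def _example_multi_gaussian():
    surrogate = MultiGaussian()
    v_mem = torch.linspace(-1.0, 3.0, steps=9)
    # Central positive peak at v_mem == spike_threshold, flanked by slightly
    # negative regions whose magnitude and width are set by h and s.
    return surrogate(v_mem, spike_threshold=1.0)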
@dataclass
class SingleExponential:
    """Surrogate gradient as defined in Shrestha and Orchard, 2018.

    https://papers.nips.cc/paper/2018/hash/82f2b308c3b01637c607ce05f52a2fed-Abstract.html

    Parameters:
        grad_width: Width of the exponential, relative to the spike threshold.
        grad_scale: Scale the gradients arbitrarily.
    """

    grad_width: float = 0.5
    grad_scale: float = 1.0

    def __call__(self, v_mem, spike_threshold):
        abs_width = spike_threshold * self.grad_width
        return (
            self.grad_scale
            / abs_width
            * torch.exp(-torch.abs(v_mem - spike_threshold) / abs_width)
        )
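# Illustrative usage sketch (added for this listing; not part of the sinabs
# module): the surrogate decays exponentially with the distance of v_mem from
# the threshold, with a width scaled by the threshold itself.
def _example_single_exponential():
    surrogate = SingleExponential(grad_width=0.5, grad_scale=1.0)
    v_mem = torch.tensor([0.5, 1.0, 1.5])
    # Peak of grad_scale / (grad_width * spike_threshold) == 2.0 at the
    # threshold, falling off symmetrically on both sides.
    return surrogate(v_mem, spike_threshold=1.0)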
@dataclass
class PeriodicExponential:
    """Surrogate gradient as defined in Weidel and Sheik, 2021.

    https://arxiv.org/abs/2111.01456

    Parameters:
        grad_width: Decay width of the exponential kernel.
        grad_scale: Scale the gradients arbitrarily.
    """

    grad_width: float = 0.5
    grad_scale: float = 1.0

    def __call__(self, v_mem, spike_threshold):
        # Normalize v_mem between -0.5 and 0.5
        vmem_normalized = v_mem / spike_threshold - 0.5

        # This is a periodic, stepwise linear function with discontinuities at
        # vmem == (N + 0.5) * spike_threshold (limit from the left:
        # -spike_threshold, limit from the right: +spike_threshold) and zeros
        # at vmem == N * spike_threshold.
        vmem_periodic = vmem_normalized - torch.floor(vmem_normalized)
        vmem_periodic = spike_threshold * (2 * vmem_periodic - 1)

        # Combine different curves for vmem below and above spike_threshold
        vmem_below = (v_mem - spike_threshold) * (v_mem < spike_threshold)
        vmem_above = vmem_periodic * (v_mem >= spike_threshold)
        vmem_new = vmem_above + vmem_below

        surrogate = torch.exp(-torch.abs(vmem_new) / self.grad_width)

        return self.grad_scale * surrogate
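# Illustrative usage sketch (added for this listing; not part of the sinabs
# module): above the threshold the surrogate repeats with period
# spike_threshold, so membrane potentials large enough to emit several spikes
# per step still receive gradient at every multiple of the threshold.
def _example_periodic_exponential():
    surrogate = PeriodicExponential(grad_width=0.5, grad_scale=1.0)
    v_mem = torch.tensor([1.0, 2.0, 3.0])  # integer multiples of the threshold
    # vmem_new is 0 at every multiple of the threshold, so the gradient is
    # maximal (== grad_scale) at each of them.
    return surrogate(v_mem, spike_threshold=1.0)  # tensor([1., 1., 1.])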