{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A quick (and over-simplified) introduction to spiking neurons\n",
    "\n",
    "Author: Sadique Sheik\n",
    "\n",
    "Spiking neural networks are considered to be the `third generation` of neural networks preceeded by McCulloch-Pitts threshold neurons (`first generation`) which produced digital outputs and Artificial Neural Networks with continuous activations, like sigmoids and hyperbolic tangets, (`second generation`) that are commonly used these days.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Artificial Neuron (AN) Model\n",
    "\n",
    "The transfer function of a neuron in ANNs can be expressed as follows.\n",
    "\n",
    "$$\\vec{y} = f(W.\\vec{x} + b)$$\n",
    "\n",
    "where $f$ is typically a non linear activation function such as a `sigmoid` or `hyperbolic tangent` function.\n",
    "\n",
    "\n",
    "> Note that the output directly depends on the inputs. The neuron does not have an internal state that would affect its output.\n",
    "\n",
    "\n",
    "This is where spiking neurons and spiking neural networks (SNNs) come into play. They add an additional dependence on the current `state` of the neuron."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Leaky Integrate and Fire (LIF)  model\n",
    "\n",
    "One of the simplest models of spiking neurons is the Leaky Integrate and Fire (LIF) neuron model, where the state of the neuron, typically the `membrane potential` $v$,  is dependent on its previous state in addition to its inputs. The dynamics of this state can be described as follows.\n",
    "\n",
    "$$\\tau \\dot{v} = - v(t) + R \\cdot \\big(I_{syn} + I_{bias}\\big)$$\n",
    "\n",
    "where $\\tau$ is the membrane time constant, which determines how much the neuron depends on its previous states and inputs and therefore defines the neuron's memory. The linear equation above describes the dynamics of a system with `leaky` dynamics ie over time, the membrane potential $v$ slowly `leaks` to $0$. This is why the neuron model is referred to as a `Leaky Integrate and Fire Neuron model`.\n",
    "\n",
    "\n",
 Note that there is an">
    "> Note that there is an additional notion of time, which does not occur explicitly in an artificial neuron model.\n",
    "\n",
    "\n",
    "$I_{syn} := W \\cdot \\vec{x}$ is the weighted sum of all the input synaptic contributions and $I_{bias}$ is a constant bias. The constant resistance $R$ has unity value and only serves to match units between membrane potential and currents.\n",
    "\n",
    "The output of this neuron is binary and instantaneous. It is $1$ only when $v$ reaches a threshold value $v_{th}$, upon which the membrane potential is immediately reset to a reset value $v_{reset} < v_{th}$. The dynamics of $v$ can be expressed as:\n",
    "\n",
    "$$\n",
    "v(t+\\delta) = \n",
    "\\left\\{\n",
    "\\begin{array}{}\n",
    "v_{reset},& \\text{if } v(t) \\geq v_{th} \\\\ \n",
    "v(t),& \\text{otherwise} \n",
    "\\end{array}\n",
    "\\right.\n",
    "$$\n",
    "\n",
    "\n",
    "This instantanious output of $1$ is typically referred to as `spike` at time $t$. A series of such spikes will hence forth be referred to as $s(t)$.\n",
    "\n",
    "$$\n",
    "s(t) = \n",
    "\\left\\{\n",
    "\\begin{array}{}\n",
    "1,& \\text{if } v(t)\\geq v_{th}\\\\\n",
    "0,& \\text{otherwise}\n",
    "\\end{array}\n",
    "\\right.\n",
    "$$"
   ]
  },
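  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a constant total input $I = I_{syn} + I_{bias}$, the LIF equation can be solved in closed form up to the first spike. A sketch, assuming the neuron starts at $v(0) = 0$:\n",
    "\n",
    "$$v(t) = R I \\big(1 - e^{-t/\\tau}\\big)$$\n",
    "\n",
    "The membrane potential relaxes towards $R I$, so the neuron can only spike if $R I > v_{th}$. In that case, setting $v(t^*) = v_{th}$ gives the time of the first spike:\n",
    "\n",
    "$$t^* = \\tau \\ln\\frac{R I}{R I - v_{th}}$$\n",
    "\n",
    "With the illustrative values $\\tau = 20$ ms, $R I = 11$ and $v_{th} = 10$, this gives $t^* = 20 \\ln 11 \\approx 48$ ms."
   ]
  },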
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple LIF neuron simulation for `t_sim=100` time steps\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "t_sim = 100  # ms\n",
    "tau = 20  # ms\n",
    "v_th = 10.0  # Membrane spiking threshold\n",
    "v_reset = 0  # Resting potential\n",
    "vmem = [0]  # Initial membrane potential\n",
    "i_syn = 11  # unit free, assuming constant synaptic input\n",
    "spike_train = []\n",
    "for t in range(t_sim):\n",
    "    # Check threshold\n",
    "    if vmem[t] >= v_th:\n",
    "        vmem[t] = v_reset\n",
    "        spike_train.append(t)\n",
    "    # Membrane dynamics\n",
    "    dv = (1 / tau) * (-vmem[t] + i_syn)\n",
    "    # Save data\n",
    "    vmem.append(vmem[t] + dv)\n",
    "\n",
    "print(\n",
    "    f\"Neuron spikes {len(spike_train)} times at the following simulation time steps :{spike_train}\"\n",
    ")"
   ]
  },
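  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above integrates the LIF equation with a forward Euler step. The step size $\\Delta t$ is implicit in the code; reading it as $1$ ms is an assumption based on the `# ms` comments. Between spikes, each iteration computes:\n",
    "\n",
    "$$v_{n+1} = v_n + \\frac{\\Delta t}{\\tau}\\big(-v_n + R I_{syn}\\big)$$\n",
    "\n",
    "Starting from $v_0 = 0$ with constant input, this recursion has the solution $v_n = R I_{syn}\\big(1 - (1 - \\Delta t/\\tau)^n\\big)$. With $\\tau = 20$, $R I_{syn} = 11$ and $v_{th} = 10$, the threshold is first reached when $0.95^n \\leq 1/11$, i.e. at $n = 47$. After the reset to $0$ the same ramp repeats, so the next spike follows $47$ steps later."
   ]
  },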
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can plot the membrane potential trace to see how it evolves over time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Plot membrane potential\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(range(t_sim + 1), vmem)\n",
    "plt.title(\"LIF membrame dynamics\")\n",
    "plt.xlabel(\"$t$\")\n",
    "_ = plt.ylabel(\"$V_{mem}$\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constant-leak Integrate and Fire (IAF) model\n",
    "\n",
    "The LIF model describes a dynamical system where $v$ is continually being updated. While this is a realistic description of most systems in the real world, for the sake of simplicity (and ease of computation on a digital system like a PC), if we assume that there is only a constant `leak` (or no `leak` if $v_{leak} = 0$) in its membrane potential, such a model will henceforth be referred to as an `IAF` mode and is described by the following equation:\n",
    "\n",
    "$\\tau \\dot{v} = - v_{leak} + R \\cdot \\big(I_{syn} + I_{bias}\\big)$\n",
    "\n",
    "with $v_{leak}$ being the constant `leak`. Because it is a constant, it can be joined with the bias, such that the equation for $v$ can be simplified to\n",
    "\n",
    "$\\tau \\dot{v} = R \\cdot \\big(I_{syn} + \\tilde{I}_{bias}\\big)$\n",
    "\n",
    "with $\\tilde{I}_{bias} := I_{bias} + \\frac{v_{leak}}{R}$."
   ]
  },
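  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a constant input, the simplified IAF equation is easy to integrate by hand. A sketch, writing $I = I_{syn} + \\tilde{I}_{bias}$ and assuming the same threshold-and-reset rule as for the LIF neuron:\n",
    "\n",
    "$$v(t) = v(0) + \\frac{R I}{\\tau}\\, t$$\n",
    "\n",
    "The membrane potential is a straight line in time. If $I > 0$, it climbs from $v_{reset}$ to $v_{th}$ in a time $T = \\tau (v_{th} - v_{reset}) / (R I)$, so the neuron fires regularly at a rate proportional to the input:\n",
    "\n",
    "$$\\frac{1}{T} = \\frac{R I}{\\tau\\,(v_{th} - v_{reset})}$$\n",
    "\n",
    "If $I \\leq 0$, the membrane potential never rises to the threshold and the neuron stays silent."
   ]
  },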
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IAF model in SINABS\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sinabs.layers as sl\n",
    "\n",
    "\n",
    "# Define a neuron in 'SINABS'\n",
    "neuron = sl.IAF()\n",
    "\n",
    "# Set a constant bias current to excite the neuron\n",
    "input_current = torch.zeros(1, t_sim, 1) + 0.025\n",
    "\n",
    "# Membrane potential trace\n",
    "vmem = [0]\n",
    "\n",
    "# Output spike raster\n",
    "spike_train = []\n",
    "for t in range(t_sim):\n",
    "    with torch.no_grad():\n",
    "        out = neuron(input_current[:, t : t + 1])\n",
    "        # Check if there is a spike\n",
    "        if (out != 0).any():\n",
    "            spike_train.append(t)\n",
    "        # Record membrane potential\n",
    "        vmem.append(neuron.v_mem[0, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot membrane potential\n",
    "plt.figure()\n",
    "plt.plot(range(t_sim + 1), vmem)\n",
    "plt.title(\"IAF membrame dynamics\")\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"$V_{mem}$\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Activation function\n",
    "\n",
    "The activation functions for artificial neuron models are fairly straight forward to understand. Given an input $x$ and an activation function $\\Phi$ the output $y$ of this neuron would be:\n",
    " \n",
    "$$y = \\Phi (W.x + b)$$\n",
    "\n",
    "For a spiking neuron model such as `IAF`, what would this activation function look like? Defining such a function is fairly complex because the output $y$ doesn't just depend on input $x$ but also its state. Moreover the output of these neurons is a series of spikes $s(t)$ ie series of Dirac delta functions.\n",
    "\n",
    "One way to interpret the resonse of the spiking neuron to its inputs is to look at the number of spikes the neuron produces within a time window. For instance, in the example above we can say that the neuron produces an output of $2$ for a `bias` of $0.025$ (which was our choice of `bias` parameter) and input `zero`. \n",
    "\n",
    "```\n",
    "We will refer to this `interpretation` of spikes as `rate coding` because we are looking at spike rate of neurons as our information medium.\n",
    "```\n",
    "\n",
    "With this `rate coding` interpretation of spiking neurons let us look at what the transfer function $\\Phi$ looks like for `IAF` neuron model."
   ]
  },
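  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The spike count quoted above can be checked by hand. A sketch, assuming a discrete-time neuron without leak that adds a constant input $x$ to $v$ at every time step, fires when $v$ reaches a threshold $\\theta$, and subtracts $\\theta$ from $v$ for each spike. These are assumptions about how the neuron is configured, and $\\theta = 1$ is assumed in the numbers below. Over $T$ time steps, for inputs with $x \\leq \\theta$ (or if several spikes per time step are allowed), the number of output spikes is then\n",
    "\n",
    "$$N(x) = \\left\\lfloor \\frac{T \\cdot \\max(0, x)}{\\theta} \\right\\rfloor$$\n",
    "\n",
    "For $T = 100$ time steps and $x = 0.025$ this gives $N = \\lfloor 2.5 \\rfloor = 2$ spikes, and for any $x \\leq 0$ it gives $N = 0$."
   ]
  },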
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# IAF neuron activation function\n",
    "import numpy as np\n",
    "\n",
    "# Define a neuron in 'SINABS'\n",
    "neuron = sl.IAF()\n",
    "\n",
    "# Simulation time\n",
    "t_sim = 100\n",
    "\n",
    "\n",
    "spike_rates_out = []\n",
    "\n",
    "# Input\n",
    "input_spikes = torch.zeros((1, t_sim, 1))  # (batch, time, neuron)\n",
    "\n",
    "# Range of biases to be added.\n",
    "biases = torch.arange(-5, 5, 0.2)  # Range of inputs to iterate over\n",
    "\n",
    "for bias in biases:\n",
    "    # Reset the neuron\n",
    "    neuron.reset_states()\n",
    "    # Set a constant bias current to excite the neuron\n",
    "    inp = input_spikes + bias\n",
    "    # Output spike raster\n",
    "    with torch.no_grad():\n",
    "        # Simulate neuron for t_sim timesteps\n",
    "        out = neuron(inp)\n",
    "        \"\"\"\n",
    "        Notice that we can do the entire simulation in one line,\n",
    "        unlike the code block earlier where we had a for loop to iterate over all time steps.\n",
    "        \"\"\"\n",
    "        spike_rates_out.append(out.sum())\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(biases, spike_rates_out)\n",
    "plt.title(\"IAF neuron activation funciton $\\Phi$\")\n",
    "plt.ylabel(\"Output spike count\")\n",
    "plt.xlabel(\"Input current\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IAF neuron's transfer function is equivalent to ReLU\n",
    "\n",
    "The above figure must look familiar if you have worked with ANNs, especially deep learning models; the activation function of a `IAF neuron` can be interpreted as that of a `Rectified Linear Unit (ReLU)`'s activation function.\n",
    "\n",
    "__This observation is the entry point to porting Deep Learning (DL) models consisting of layers with ReLU (which is a majority) activations to SNNs.__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "- Farabet C, Paz R, Pérez-Carrasco J, Zamarreño C, Linares-Barranco A, LeCun Y, Culurciello E, Serrano-Gotarredona T, Linares-Barranco B. Comparison between frame-constrained fix-pixel-value and frame-free spiking-dynamic-pixel convNets for visual processing. Frontiers in neuroscience. 2012 Apr 12;6:32.\n",
    "- Neil D, Pfeiffer M, Liu SC. Learning to be efficient: Algorithms for training low-latency, low-compute deep spiking neural networks. InProceedings of the 31st annual ACM symposium on applied computing 2016 Apr 4 (pp. 293-298). ACM.\n",
    "- Rueckauer B, Lungu IA, Hu Y, Pfeiffer M, Liu SC. Conversion of continuous-valued deep networks to efficient event-driven networks for image classification. Frontiers in neuroscience. 2017 Dec 7;11:682.\n",
    "- Sengupta A, Ye Y, Wang R, Liu C, Roy K. Going deeper in spiking neural networks: VGG and residual architectures. Frontiers in neuroscience. 2019;13.\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
