{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Converting an ANN to an SNN\n",
    "\n",
    "This tutorial walks you through converting a pre-trained model into a spiking version.\n",
    "Let's start by installing all the necessary packages."
   ]
  },
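  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal install cell might look like the following. The package list is an assumption inferred from the imports used in this notebook (`sinabs`, `torch`, `torchvision`, `tqdm`, `matplotlib`); adjust it to your environment and skip the cell if everything is already installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed package list, derived from the imports used in this notebook; pin versions as needed.\n",
    "%pip install sinabs torch torchvision tqdm matplotlib"
   ]
  },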
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining an ANN\n",
    "We define a simple convolutional architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "ann = nn.Sequential(\n",
    "    nn.Conv2d(1, 20, 5, 1, bias=False),\n",
    "    nn.ReLU(),\n",
    "    nn.AvgPool2d(2, 2),\n",
    "    nn.Conv2d(20, 32, 5, 1, bias=False),\n",
    "    nn.ReLU(),\n",
    "    nn.AvgPool2d(2, 2),\n",
    "    nn.Conv2d(32, 128, 3, 1, bias=False),\n",
    "    nn.ReLU(),\n",
    "    nn.AvgPool2d(2, 2),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(128, 500, bias=False),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(500, 10, bias=False),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a custom dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll need to fine-tune our model on our dataset of choice. Here we'll use MNIST. Since we intend to run a spiking neural network simulation, we override this `Dataset` so that it can *optionally* return a spike raster instead of an image.\n",
    "\n",
    "In this implementation of the `Dataset` we use *rate coding*: at every time step, each pixel emits a spike with a probability proportional to its gray level."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "\n",
    "class MNIST(datasets.MNIST):\n",
    "    def __init__(self, root, train=True, is_spiking=False, time_window=100):\n",
    "        super().__init__(\n",
    "            root=root, train=train, download=True, transform=transforms.ToTensor()\n",
    "        )\n",
    "        self.is_spiking = is_spiking\n",
    "        self.time_window = time_window\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img, target = self.data[index].unsqueeze(0) / 255, self.targets[index]\n",
    "        # img is now a tensor of 1x28x28\n",
    "\n",
    "        if self.is_spiking:\n",
    "            img = (torch.rand(self.time_window, *img.shape) < img).float()\n",
    "\n",
    "        return img, target"
   ]
  },
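  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for rate coding, the sketch below applies the same comparison against uniform random numbers to a few hand-picked intensities (illustrative values, not taken from MNIST) and measures how often each one spikes. Over a long enough window, the fraction of time steps with a spike should approach the intensity itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)  # only to make this illustration reproducible\n",
    "\n",
    "# Hypothetical normalized gray levels in [0, 1]\n",
    "intensities = torch.tensor([0.0, 0.25, 0.5, 1.0])\n",
    "n_steps = 1000\n",
    "\n",
    "# One random draw per time step and pixel, as in MNIST.__getitem__ above\n",
    "spikes = (torch.rand(n_steps, *intensities.shape) < intensities).float()\n",
    "\n",
    "# The empirical firing rate approximates the gray level\n",
    "print(spikes.mean(dim=0))"
   ]
  },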
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune the ANN\n",
    "\n",
    "We'll make sure that classification accuracy is high enough with this model on MNIST. Note that we are not yet using spiking input (`is_spiking=False`). This is vanilla training for standard image classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "mnist_train = MNIST(\"./data\", train=True, is_spiking=False)\n",
    "train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)\n",
    "\n",
    "mnist_test = MNIST(\"./data\", train=False, is_spiking=False)\n",
    "test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We iterate over our data loader `train_loader` and train our parameters using the `Adam` optimizer with a learning rate of `1e-3`. Since the last layer of our network outputs raw scores without an activation function, `cross_entropy` loss, which applies the softmax internally, is a good candidate to train our network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f260aac6f8c40d88f094f5317f34a76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "ann = ann.to(device)\n",
    "ann.train()\n",
    "\n",
    "optim = torch.optim.Adam(ann.parameters(), lr=1e-3)\n",
    "\n",
    "n_epochs = 2\n",
    "\n",
    "for n in tqdm(range(n_epochs)):\n",
    "    for data, target in iter(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        output = ann(data)\n",
    "        optim.zero_grad()\n",
    "\n",
    "        loss = F.cross_entropy(output, target)\n",
    "        loss.backward()\n",
    "        optim.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification accuracy: 97.5%\n"
     ]
    }
   ],
   "source": [
    "correct_predictions = []\n",
    "\n",
    "ann.eval()\n",
    "with torch.no_grad():  # gradients are not needed for evaluation\n",
    "    for data, target in test_loader:\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        output = ann(data)\n",
    "\n",
    "        # get the index of the largest logit\n",
    "        pred = output.argmax(dim=1, keepdim=True)\n",
    "\n",
    "        # record which predictions were correct\n",
    "        correct_predictions.append(pred.eq(target.view_as(pred)))\n",
    "\n",
    "correct_predictions = torch.cat(correct_predictions)\n",
    "print(\n",
    "    f\"Classification accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training this model on `MNIST` is fairly straightforward, and you should reach accuracies of around 98% within a small number of epochs. In the script above we only train for 2 epochs!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model conversion to SNN\n",
    "\n",
    "Up until this point we have only operated on images using standard CNN architectures. Now we look at how to build an equivalent spiking convolutional neural network (`SCNN`).\n",
    "\n",
    "`sinabs` has a handy function for this. Given a standard CNN model, the `from_model` function in `sinabs.from_torch` converts it into a spiking neural network. It is a *one liner*!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sinabs.from_torch import from_model\n",
    "\n",
    "input_shape = (1, 28, 28)\n",
    "num_timesteps = 100  # per sample\n",
    "\n",
    "sinabs_model = from_model(\n",
    "    ann, input_shape=input_shape, add_spiking_output=True, synops=False, num_timesteps=num_timesteps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that this function takes a few more parameters in addition to the model to be converted.\n",
    "\n",
    "`input_shape` is needed in order to instantiate an SNN with the appropriate number of neurons because, unlike traditional CNNs, SNNs are *stateful*.\n",
    "\n",
    "`add_spiking_output` is a boolean flag that specifies whether or not to add a spiking layer as the last layer in the network. This ensures that both the input and the output of our network take the form of spikes.\n",
    "\n",
    "`synops` controls whether sinabs includes the machinery for calculating synaptic operations. We don't need it in this tutorial, so it is set to `False` here.\n",
    "\n",
    "`num_timesteps` is the number of time steps per sample. It has to match the `time_window` of the spiking dataset we use below.\n",
    "\n",
    "Let us now look at the generated SCNN. You should see that the only major difference is that the `ReLU` layers are replaced by spiking `IAFSqueeze` layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)\n",
       "  (1): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0)\n",
       "  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (3): Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)\n",
       "  (4): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0)\n",
       "  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (6): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)\n",
       "  (7): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0)\n",
       "  (8): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
       "  (9): Flatten(start_dim=1, end_dim=-1)\n",
       "  (10): Linear(in_features=128, out_features=500, bias=False)\n",
       "  (11): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0)\n",
       "  (12): Linear(in_features=500, out_features=10, bias=False)\n",
       "  (Spiking output): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sinabs_model.spiking_model"
   ]
  },
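  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why an integrate-and-fire layer can stand in for a `ReLU`, consider the simplified neuron below. It is a plain-Python sketch, not the sinabs implementation: the function name `iaf_firing_rate` is made up for this illustration, and its defaults merely mirror the `spike_threshold=1.0` and `min_v_mem=-1.0` values in the printout above. For a constant input between 0 and the threshold, the firing rate should track the input, while negative inputs produce no spikes, which is the behaviour of a `ReLU`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iaf_firing_rate(input_current, n_steps=100, threshold=1.0, min_v_mem=-1.0):\n",
    "    \"\"\"Simplified integrate-and-fire neuron driven by a constant input (illustrative only).\"\"\"\n",
    "    v_mem = 0.0\n",
    "    n_spikes = 0\n",
    "    for _ in range(n_steps):\n",
    "        # integrate the input, but never let the membrane potential drop below min_v_mem\n",
    "        v_mem = max(v_mem + input_current, min_v_mem)\n",
    "        if v_mem >= threshold:\n",
    "            # emit a spike and subtract the threshold from the membrane potential\n",
    "            n_spikes += 1\n",
    "            v_mem -= threshold\n",
    "    return n_spikes / n_steps\n",
    "\n",
    "\n",
    "for current in (-0.5, 0.0, 0.25, 0.5):\n",
    "    print(f\"input={current:+.2f} -> firing rate={iaf_firing_rate(current):.2f}\")"
   ]
  },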
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model validation in sinabs simulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test our SCNN model to verify that the network is in fact \"equivalent\" to the CNN model in terms of performance. As before, we start by defining a data loader (this time it produces spikes, `is_spiking=True`) and then iterate over it in our test loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch_size = 10\n",
    "\n",
    "spike_mnist_test = MNIST(\n",
    "    \"./data\", train=False, is_spiking=True, time_window=num_timesteps\n",
    ")\n",
    "spike_test_loader = DataLoader(\n",
    "    spike_mnist_test, batch_size=test_batch_size, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the spiking simulations are significantly slower on a PC, we are going to limit our test to 300 samples here. You can of course test it on the entire 10k samples if you want to verify that it works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7e045462fda44ee98cd7db4bd184f854",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification accuracy: 97.0%\n"
     ]
    }
   ],
   "source": [
    "import sinabs.layers as sl\n",
    "\n",
    "correct_predictions = []\n",
    "\n",
    "for data, target in tqdm(spike_test_loader):\n",
    "    data, target = data.to(device), target.to(device)\n",
    "    data = sl.FlattenTime()(data)\n",
    "    with torch.no_grad():\n",
    "        output = sinabs_model(data)\n",
    "        output = output.unflatten(\n",
    "            0, (test_batch_size, output.shape[0] // test_batch_size)\n",
    "        )\n",
    "\n",
    "    # sum the output spikes over time and pick the most active neuron\n",
    "    pred = output.sum(1).argmax(dim=1, keepdim=True)\n",
    "\n",
    "    # Compute the total correct predictions\n",
    "    correct_predictions.append(pred.eq(target.view_as(pred)))\n",
    "    if len(correct_predictions) * test_batch_size >= 300:\n",
    "        break\n",
    "\n",
    "correct_predictions = torch.cat(correct_predictions)\n",
    "print(\n",
    "    f\"Classification accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the performance of this auto-generated spiking network (`sinabs_model`) is close to that of the `ann`! Yay!\n",
    "\n",
    "You may have noticed a free parameter that was added along the way: `time_window`. This is a critical parameter that determines how well your SNN is going to work. The longer `time_window` is, the more spikes we produce as input and the better the network tends to perform, at the cost of a longer simulation. Feel free to experiment with this parameter and see how it changes your network's performance."
   ]
  },
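  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to run such an experiment is sketched below. It reuses `ann`, `MNIST`, `device` and `input_shape` from the cells above; the helper name `snn_accuracy` and the choice of evaluating only the first 100 test samples are assumptions made to keep the run short. The model is converted again for every window length because `num_timesteps` is passed to `from_model`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sinabs.layers as sl\n",
    "from sinabs.from_torch import from_model\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "def snn_accuracy(time_window, n_samples=100, batch_size=10):\n",
    "    \"\"\"Hypothetical helper: accuracy (in %) of the converted SNN for one time window.\"\"\"\n",
    "    model = from_model(\n",
    "        ann, input_shape=input_shape, add_spiking_output=True, synops=False, num_timesteps=time_window\n",
    "    )\n",
    "    dataset = MNIST(\"./data\", train=False, is_spiking=True, time_window=time_window)\n",
    "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    correct, seen = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(sl.FlattenTime()(data))\n",
    "            # restore the (batch, time, classes) layout before summing spikes over time\n",
    "            output = output.unflatten(0, (data.shape[0], time_window))\n",
    "            pred = output.sum(1).argmax(dim=1)\n",
    "            correct += (pred == target).sum().item()\n",
    "            seen += len(target)\n",
    "            if seen >= n_samples:\n",
    "                break\n",
    "    return correct / seen * 100\n",
    "\n",
    "\n",
    "for tw in (10, 50, 100):\n",
    "    print(f\"time_window={tw}: {snn_accuracy(tw):.1f}%\")"
   ]
  },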
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualisation of specific example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get one sample from the dataset\n",
    "img, label = spike_mnist_test[10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's visualize this data, just so we know what to expect. We can do this by collapsing the time dimension of the spike raster returned by the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAO40lEQVR4nO3de4xc5XnH8d/Dst7FN1hjZ2sMKeA6BESKaTd2K1BFYxWMpcagSjSORJzI1KiK2xChKCiRGpqmkoNyIUVNUhO7OBGF0ALFrdwQs4pqURrjNXF9BRt8Abtrb6iJfAGv9/L0jz2OFtjzznjmzMV+vh9pNbPnmTPn8dg/nzPnnTmvubsAnPvOa3QDAOqDsANBEHYgCMIOBEHYgSDOr+fGxlmbt2tCPTcJhHJSJ3TK+22sWlVhN7P5kr4jqUXSD9x9eerx7ZqguTavmk0CSNjg3bm1ig/jzaxF0t9LulXSNZIWmdk1lT4fgNqq5j37HEmvuvsedz8l6XFJC4tpC0DRqgn7DElvjPr9QLbsXcxsqZn1mFnPgPqr2ByAatT8bLy7r3D3LnfvalVbrTcHIEc1YT8o6bJRv1+aLQPQhKoJ+0ZJs8zsCjMbJ+kTktYU0xaAolU89Obug2a2TNKzGhl6W+Xu2wvrDEChqhpnd/e1ktYW1AuAGuLjskAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EUdcpm1EbJ/5kbm5twpMbkuue/OM5yfqb16b/icxZuDVZf6H72tzaUJsn1+18MVnWxCd+nn4A3oU9OxAEYQeCIOxAEIQdCIKwA0EQdiAIwg4EYe7psc4iTbYpPtfm1W17Z4uWjo5kfdy/jkvWH7j8qdxau6X/ft8YHJ+s39DeuP3BroETyfrLp6Yl61/+h0/n1i75xguVtNT0Nni3jvoRG6tW1YdqzGyfpGOShiQNuntXNc8HoHaK+ATdH7r7mwU8D4Aa4j07EES1YXdJPzWzTWa2dKwHmNlSM+sxs54B9Ve5OQCVqvYw/kZ3P2hmH5C0zsxedvf1ox/g7iskrZBGTtBVuT0AFapqz+7uB7PbPklPS0p/hQpAw1QcdjObYGaTTt+XdLOkbUU1BqBY1RzGd0p62sxOP88/uftPCukqmFf+6qpk/bVZ3y/xDBMq3vYPf/XhZP0Hfemx7GMDbcn6nrcuzq1Nm5AeR3/26n9P1j/U+nay/tG/fCC3tmjH55Prtq3dmKyfjSoOu7vvkXRdgb0AqCGG3oAgCDsQBGEHgiDsQBCEHQiCS0nXgX30I8n6yo+vqOr5H3zr8tzaY8tvTa7bseNosm479ybrw2+nvwM1TYn6eS3Jda/8u7uT9Vdu/26yPv38ibm18V84mFzXXupM1gcPHU7WmxF7diAIwg4EQdiBIAg7EARhB4Ig7EAQhB0IgnH2OhhqS48n33TBcLI+4EPJ+sp/XJBbu+RH6Usml7p0UE0vLTSc/nPNWpaebvq3+/4iWd9y90O5tbVXrU2uO2f+nyfrHY8wzg6gSRF2IAjCDgRB2IEgCDsQBGEHgiDsQBCMs9fBwOTWqta/fsOnkvUZ5+j0w6V88KvpP/d/Lc5/3Ut9tqF90aH0xh9Jl5sRe3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIJx9jqY+7fVTf/rGy8sqJNYPrP+M7m1vbesTK67+IP/naw/od+oqKdGKrlnN7NVZtZnZttGLZtiZuvMbHd221HbNgFUq5zD+EckzX/Psvskdbv7LEnd2e8AmljJsLv7eklH3rN4oaTV2f3Vkm4rti0ARav0PXunu/dm9w9Jyp0Yy8yWSloqSe0aX+HmAFSr6rPx7u5KXJfQ3Ve4e5e7d7WqrdrNAahQpWE/bGbTJSm77SuuJQC1UGnY10hanN1fLOmZYtoBUCsl37Ob
2WOSbpI01cwOSPqKpOWSnjCzJZL2S7qjlk02u/OuuzpZv6JtfbK+a+BEsj512+AZ9wRp4vbE28Zb0uvOHFfqYPXsG2cvGXZ3X5RTmldwLwBqiI/LAkEQdiAIwg4EQdiBIAg7EARfcS3A7jvTX0FdcuHryfrNOz6ZrLf/24tn3BOk1uOVTzjdNzSpwE6aA3t2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCcfYCLLv1J8l679A7ybp9fWqJLew/w44gSe98wCpe938Hzr0LJrNnB4Ig7EAQhB0IgrADQRB2IAjCDgRB2IEgGGcvwLTzjyXrDxxOX4i39blNRbaDzKTX87/P/tbQ28l1L2l9q8SzX3TmDTUYe3YgCMIOBEHYgSAIOxAEYQeCIOxAEIQdCIJx9jK1TJuWWzs5fKiOnaBc5w3kj7N3tIxPrrvr5PSi22m4knt2M1tlZn1mtm3UsvvN7KCZbc5+FtS2TQDVKucw/hFJ88dY/m13n539rC22LQBFKxl2d18v6UgdegFQQ9WcoFtmZluyw/zcC3aZ2VIz6zGzngH1V7E5ANWoNOzfkzRT0mxJvZK+mfdAd1/h7l3u3tWqtgo3B6BaFYXd3Q+7+5C7D0t6WNKcYtsCULSKwm5mo8clbpe0Le+xAJpDyXF2M3tM0k2SpprZAUlfkXSTmc2W5JL2Sbq7di02h1/Nm5lb+9Tk9HXjB7wlWd+t/DF8VG7KXa9XvO6h/gtLPOLsO/9UMuzuvmiMxStr0AuAGuLjskAQhB0IgrADQRB2IAjCDgTBV1zLNNieP/1vq6WH1p77v6tLPPubFXSEwY/9brL+8MzvJKoTk+v2PHh9sn6hfp6sNyP27EAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBOPsZXpnav44+4APJdedOTE9jr65koYC8N+/Llnv/Js9yfr0lgtya/f0diXXvehffpGs51+kunmxZweCIOxAEIQdCIKwA0EQdiAIwg4EQdiBIBhnL9OkA8O5tRf788fgJal/uNTLPFhBR2e/lsmTk/UZD76arH/30vXJ+n+8PSm3tulr6e/Cj+/fkKyfjdizA0EQdiAIwg4EQdiBIAg7EARhB4Ig7EAQjLOXadLj+dcJ3/fX6SmXP3xBb7K++9LZyfrggYPJeiO1XD0rWd/71fbc2i1X7kyu+7XO50tsvTVZfeiuP82tjf/Pc28cvZSSe3Yzu8zMfmZmO8xsu5l9Lls+xczWmdnu7Laj9u0CqFQ5h/GDku5192sk/Z6kz5rZNZLuk9Tt7rMkdWe/A2hSJcPu7r3u/lJ2/5iknZJmSFooaXX2sNWSbqtRjwAKcEbv2c3scknXS9ogqdPdT78ZPSSpM2edpZKWSlK7xlfcKIDqlH023swmSnpS0j3ufnR0zd1dOdfgc/cV7t7l7l2taquqWQCVKyvsZtaqkaA/6u5PZYsPm9n0rD5dUl9tWgRQhJKH8WZmklZK2unu3xpVWiNpsaTl2e0zNenwLLC3Pz309sWLtyfrJ59NDyE9undOsj6r45e5tQPHL0quOzCc/v++c/zxZP0Ll/04Wb+hvfKPcmw/lb5g820vLEnWZ/3itdxa+uLf56Zy3rPfIOlOSVvNbHO27EsaCfkTZrZE0n5Jd9SkQwCFKBl2d39eUt7VGeYV2w6AWuHjskAQhB0IgrADQRB2IAjCDgTBV1wL8M8rP5as7//kxcl6qUsi3zX75WS9zfLH6VutJblu9SrfX+waOJGsf/zpe5P13/p8/teOpZhj6Sns2YEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCBu5yEx9TLYpPtfifVGu5UMzk/WP/Dj/e9eS9PXOzQV2827Hh08m66kxfEm6qvvPkvXJG/MvJd350AvJdXHmNni3jvqRMb+lyp4dCIKwA0EQdiAIwg4EQdiBIAg7EARhB4JgnB04hzDODoCwA1EQdiAIwg4EQdiBIAg7EARhB4IoGXYzu8zMfmZmO8xsu5l9
Llt+v5kdNLPN2c+C2rcLoFLlTBIxKOled3/JzCZJ2mRm67Lat939G7VrD0BRypmfvVdSb3b/mJntlDSj1o0BKNYZvWc3s8slXS9pQ7ZomZltMbNVZtaRs85SM+sxs54B9VfXLYCKlR12M5so6UlJ97j7UUnfkzRT0myN7Pm/OdZ67r7C3bvcvatVbdV3DKAiZYXdzFo1EvRH3f0pSXL3w+4+5O7Dkh6WNKd2bQKoVjln403SSkk73f1bo5ZPH/Ww2yVtK749AEUp52z8DZLulLTVzDZny74kaZGZzZbkkvZJursG/QEoSDln45+XNNb3Y9cW3w6AWuETdEAQhB0IgrADQRB2IAjCDgRB2IEgCDsQBGEHgiDsQBCEHQiCsANBEHYgCMIOBEHYgSDqOmWzmf1S0v5Ri6ZKerNuDZyZZu2tWfuS6K1SRfb2m+4+baxCXcP+vo2b9bh7V8MaSGjW3pq1L4neKlWv3jiMB4Ig7EAQjQ77igZvP6VZe2vWviR6q1Rdemvoe3YA9dPoPTuAOiHsQBANCbuZzTezV8zsVTO7rxE95DGzfWa2NZuGuqfBvawysz4z2zZq2RQzW2dmu7PbMefYa1BvTTGNd2Ka8Ya+do2e/rzu79nNrEXSLkl/JOmApI2SFrn7jro2ksPM9knqcveGfwDDzP5A0nFJP3T3a7NlD0g64u7Ls/8oO9z9i03S2/2Sjjd6Gu9stqLpo6cZl3SbpE+rga9doq87VIfXrRF79jmSXnX3Pe5+StLjkhY2oI+m5+7rJR15z+KFklZn91dr5B9L3eX01hTcvdfdX8ruH5N0eprxhr52ib7qohFhnyHpjVG/H1Bzzffukn5qZpvMbGmjmxlDp7v3ZvcPSepsZDNjKDmNdz29Z5rxpnntKpn+vFqcoHu/G939dyTdKumz2eFqU/KR92DNNHZa1jTe9TLGNOO/1sjXrtLpz6vViLAflHTZqN8vzZY1BXc/mN32SXpazTcV9eHTM+hmt30N7ufXmmka77GmGVcTvHaNnP68EWHfKGmWmV1hZuMkfULSmgb08T5mNiE7cSIzmyDpZjXfVNRrJC3O7i+W9EwDe3mXZpnGO2+acTX4tWv49OfuXvcfSQs0ckb+NUlfbkQPOX1dKel/sp/tje5N0mMaOawb0Mi5jSWSLpbULWm3pOckTWmi3n4kaaukLRoJ1vQG9XajRg7Rt0janP0saPRrl+irLq8bH5cFguAEHRAEYQeCIOxAEIQdCIKwA0EQdiAIwg4E8f9QLFhTnxGeqgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "plt.imshow(img.sum(0)[0]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now take this data (including the time dimension), and pass it to the Sinabs SNN model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "snn_output = sinabs_model(img.to(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us now display the output in time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAATP0lEQVR4nO3df7DldX3f8efLZdllV5AfEqNACo0MKYEIeEVAY41oi0ggdWiKNW2acbLTiW0wcSZjmk6jzaQztmpMOtGZrZqQxGLqio1jDYqGaDPTLCyIsrCQ4A9+rJAFo/IjFhZ4949zdvdwe8/Zc+/u55y7fJ6PmTt7vt/z/X4/7/O9n3Nf+/2dqkKS1K/nzLsASdJ8GQSS1DmDQJI6ZxBIUucMAknq3GHzLmDU4VlX69k47zIk6ZDxCN95qKqOP5BlrKogWM9GXp4L512GJB0yPl9b7j7QZbhrSJI6ZxBIUucMAknqnEEgSZ0zCCSpcwaBJHWuaRAkuTLJ9iS3JXlby7YkSSvTLAiSnAH8PHAu8BLgkiQvbtWeJGllWm4R/ANga1X9XVU9CXwReGPD9iRJK9AyCLYDP57kuCQbgIuBkxZPlGRTkm1Jtu3m8YblSJKW0uwWE1W1I8m7gc8BjwG3AE8tMd1mYDPAUTnWx6VJ0ow1PVhcVR+uqpdW1auA7wB/1bI9SdLyNb3pXJIfqKpdSX6IwfGB81q2J0lavtZ3H/1EkuOA3cBbq+q7jduTJC1T0yCoqh9vuXxJ0oHzymJJ6pxBIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTONQ2CJL+U5LYk25NcnWR9y/YkScvXLAiSnAD8IrBQVWcAa4ArWrUnSVqZ1ruGDgOOSHIYsAH4VuP2JEnL1CwIqmon8B7gHuB+4HtV9bnF0yXZlGRbkm27ebxVOZKkMVruGjoGuAw4BXgRsDHJzyyerqo2V9VCVS2sZV2rciRJY7TcNfRa4BtV9WBV7QauAS5o2J4kaQVaBsE9wHlJNiQJcCGwo2F7kqQVaHmMYCuwBbgZuHXY1uZW7UmSVuawlguvql8Hfr1lG5KkA+OVxZLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUucMAknqnEEgSZ1r+YSy05LcMvLzcJK3tWpPkrQyzW5DXVV3AmcBJFkD7AQ+2ao9SdLKzGrX0IXA16rq7hm1J0maUtMH04y4Arh6qTeSbAI2Aaxnw4zKkSTt0XyLIMnhwKXAx5d6v6o2V9VCVS2sZV3rciRJi8xi19DrgZur6m9m0JYkaZlmEQRvYsxuIUnS/DUNgiQbgdcB17RsR5K0ck0PFlfVY8BxLduQJB0YryyWpM4ZBJLUOYNAkjpnEEhS5wwCSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI6ZxBIUuda33306CRbktyRZEeS81u2J0lavtaPqvxt4Nqqunz4pDKfRSlJq0yzIEjyPOBVwL8CqKongCdatSdJWpmWu4ZOAR4Efi/Jl5N8aPigmmdIsinJtiTbdvN4w3IkSUtpGQSHAecAH6yqs4HHgHcsnsiH10vSfLUMgvuA+6pq63B4C4NgkCStIs2CoKoeAO5Nctpw1IXA7a3akyStTOuzhv4t8NHhGUNfB36ucXuSpGVq/fD6W4CFlm1Ikg6MVxZLUucMAknqnEEgSZ0zCCSpc/s9WJzkTOBHhoM7qmp725IkSbM0NgiG9wr6E+Ak4KtAgDOT3ANcVlUPz6ZESVJLk3YN/QawDTi1qv5JVf0UcCpwI/CbM6hNkjQDk3YNvRb4sap6es+Iqno6yb8Dbm1emSRpJiZtETxRVU8uHjkc521CJelZYtIWwfokZzM4NjAq4G1CJenZYlIQPAC8b8J7kqRngbFBUFWvnmEdkqQ5mXT66Bsn
zVhV1xz8ciRJszZp19BPTnivgP0GQZJvAo8ATwFPVpV3IpWkVWbSrqGD9eyAn6iqhw7SsiRJB5n3GpKkzrUOggI+l+SmJJuWmiDJpiTbkmzb7eUJkjRzrR9V+cqq2pnkB4DrktxRVV8anaCqNgObAY7KsdW4HknSIlMFQZILgJNHp6+qP9jffFW1c/jvriSfBM4FvjR5LknSLE1zG+o/BH4YuIXB2T8w2OUzMQiSbASeU1WPDF//I+A/HlC1kqSDbpotggXg9Kpa7m6bFwCfTLKnnf9eVdcucxmSpMamCYLtwA8C9y9nwVX1deAlKylKkjQ70wTB84Hbk9zAyF1Hq+rSZlVJkmZmmiB4Z+siJEnzs98gqKovJnkB8LLhqBuqalfbsiRJs7LfC8qS/DRwA/BPgZ8Gtia5vHVhkqTZmGbX0K8BL9uzFZDkeODzwJaWhUmSZmOaW0w8Z9GuoG9POZ8k6RAwzRbBtUk+C1w9HP5nwGfalSRJmqWJQZDB1WC/w+BA8SuHozdX1SdbFyZJmo2JQVBVleQzVXUmUzyIRpJ06JlmX//NSV62/8kkSYeiaY4RvBx4c5K7gceAMNhY+LGmlUmSZmKaIPjHzauQJM3NNEHgw2Ik6VlsmiD4XwzCIMB64BTgTuBHp2kgyRpgG7Czqi5ZYZ2SpEamudfQmaPDSc4BfmEZbVwJ7ACOWl5pkqRZWPYVwlV1M4MDyPuV5ETgDcCHltuOJGk2pnlU5S+PDD4HOAf41pTLfz/wK8CRE5a/CdgEsJ4NUy5WknSwTLNFcOTIzzoGxwwu299MSS4BdlXVTZOmq6rNVbVQVQtrWTdFOZKkg2maYwTvAkiyoar+bhnLfgVwaZKLGRxkPirJH1XVz6ysVElSC9M8j+D8JLcDdwyHX5LkA/ubr6p+tapOrKqTgSuAPzMEJGn1mWbX0PsZXFT2bYCq+grwqoY1SZJmaJrrCKiqewc3It3rqeU0UlV/Dvz5cuaRJM3GNEFwb5ILgEqyln3XBUiSngWm2TX0r4G3AicAO4GzhsOSpGeBac4aegh48wxqkSTNwdggSPIfJsxXVfUbDeqRJM3YpC2Cx5YYtxF4C3AcYBBI0rPA2CCoqvfueZ3kSAYHiX8O+Bjw3nHzSZIOLft7eP2xwC8zOEZwFXBOVX1nFoVJkmZj0jGC/wK8EdgMnFlVj86sKknSzEw6ffTtwIuAfw98K8nDw59Hkjw8m/IkSa1NOkaw7GcVSJIOPf6xl6TOGQSS1DmDQJI61ywIkqxPckOSryS5Lcm7WrUlSVq5qW5DvUKPA6+pqkeHdy39iyR/WlV/2bBNSdIyNQuCqipgz7UHa4c/1ao9SdLKND1GkGRNkluAXcB1VbV1iWk2JdmWZNtuHm9ZjiRpCU2DoKqeqqqzgBOBc5OcscQ0m6tqoaoW1rKuZTmSpCXM5KyhqvoucD1w0SzakyRNr+VZQ8cnOXr4+gjgdcAdrdqTJK1My7OGXghclWQNg8D5H1X16YbtSZJWoOVZQ18Fzm61fEnSweGVxZLUOYNAkjpnEEhS5wwCSeqcQSBJnWt5+uhBk4UznzG85v5vLzndQ6/7e88YPmb7vscsj87z1AuPG7u8x84+ce/r9Q98f2xN//cHjxj73rj5Ftc9bVuj8y3+jM+/7u6pljda7+h7iz/H6HuT1tnofBu/fN+SNSxVxzTtjqt1cU2TLK53XNvjPu+kZSyufd2nb9j7+vFLzh3b7uh6GrV4nY0ub7TvT+pzk9oZ/RyLlzHud7d4GePem/RdGu2ro9/FxXVM+h0sXjdLzbPYpO/LEQ89NXa+ScscNe738P3nr3nG8Li/P5O+I5P6/lg3bpluugncIpCkzhkEktQ5g0CSOmcQSFLnDAJJ6pxBIEmda3kb6pOSXJ/k9uHD669s1ZYkaeVaXkfwJPD2qro5yZHATUmuq6rbG7YpSVqmZlsEVXV/Vd08fP0IsAM4oVV7kqSVmcmVxUlOZvBsgiUf
Xg9sAljPhlmUI0ka0fxgcZLnAp8A3lZVDy9+34fXS9J8NQ2CJGsZhMBHq+qalm1Jklam5VlDAT4M7Kiq97VqR5J0YFpuEbwC+BfAa5LcMvy5uGF7kqQVaPnw+r8A0mr5kqSDwyuLJalzBoEkdc4gkKTOGQSS1DmDQJI6l6qadw17HZVj6+W5cN5lSNIh4/O15aaqWjiQZbhFIEmdMwgkqXMGgSR1ziCQpM4ZBJLUOYNAkjpnEEhS51o+j+AjSXYl2d6qDUnSgWu5RfD7wEUNly9JOgiaBUFVfQn421bLlyQdHM0eTDOtJJuATQDr2TDnaiSpP3M/WFxVm6tqoaoW1rJu3uVIUnfmHgSSpPkyCCSpcy1PH70a+D/AaUnuS/KWVm1Jklau2cHiqnpTq2VLkg4edw1JUucMAknqnEEgSZ0zCCSpcwaBJHVu7reYGOeb/+mCva//+J//1jPeu/H7p+x9/fPPu3/v6//2vReOXd6XH/2hva83Hf/FscsbN8/i+c46fP3Ydkfn+8AJW6eq72VHfGOqmhYb/fy/sPPlS7a7uO1x88AzP+NoDYvrG2dx3aPr4uzn3jN2vnG/x0nrZfF7mx/8h3tfj36O0fGL35v0e1zJZ570eUeXN+lzjPvdr7SPjBpdzwAX3v6Te1+/98UfHzvfuHW72LjPtbjWce9N+s6Na2fxfOP68KTpJlncfyb143Gm7RejbS2ebtzflTXj/6xMzS0CSeqcQSBJnTMIJKlzBoEkdc4gkKTOGQSS1DmDQJI61zQIklyU5M4kdyV5R8u2JEkr0/J5BGuA3wVeD5wOvCnJ6a3akyStTMstgnOBu6rq61X1BPAx4LKG7UmSVqBlEJwA3DsyfN9w3DMk2ZRkW5Jtu3m8YTmSpKXM/WBxVW2uqoWqWljLunmXI0ndaRkEO4GTRoZPHI6TJK0iLYPgRuDUJKckORy4AvhUw/YkSSvQ8uH1Tyb5N8BngTXAR6rqtlbtSZJWpunzCKrqM8BnWrYhSTowcz9YLEmaL4NAkjpnEEhS5wwCSeqcQSBJnUtVzbuGvZI8Atw57zpWiecDD827iFXA9bCP62If18U+p1XVkQeygKanj67AnVW1MO8iVoMk21wXrodRrot9XBf7JNl2oMtw15Akdc4gkKTOrbYg2DzvAlYR18WA62Ef18U+rot9DnhdrKqDxZKk2VttWwSSpBkzCCSpc6siCJJclOTOJHclece865mlJCcluT7J7UluS3LlcPyxSa5L8tfDf4+Zd62zkmRNki8n+fRw+JQkW4f944+Hz7d41ktydJItSe5IsiPJ+b32iyS/NPx+bE9ydZL1vfSLJB9JsivJ9pFxS/aDDPzOcJ18Nck507Qx9yBIsgb4XeD1wOnAm5KcPt+qZupJ4O1VdTpwHvDW4ed/B/CFqjoV+MJwuBdXAjtGht8N/FZVvRj4DvCWuVQ1e78NXFtVPwK8hME66a5fJDkB+EVgoarOYPB8kyvop1/8PnDRonHj+sHrgVOHP5uAD07TwNyDADgXuKuqvl5VTwAfAy6bc00zU1X3V9XNw9ePMPiyn8BgHVw1nOwq4KfmUuCMJTkReAPwoeFwgNcAW4aTdLEukjwPeBXwYYCqeqKqvkun/YLBxa9HJDkM2ADcTyf9oqq+BPztotHj+sFlwB/UwF8CRyd54f7aWA1BcAJw78jwfcNx3UlyMnA2sBV4QVXdP3zrAeAF86prxt4P/Arw9HD4OOC7VfXkcLiX/nEK8CDwe8PdZB9KspEO+0VV7QTeA9zDIAC+B9xEn/1ij3H9YEV/T1dDEAhI8lzgE8Dbqurh0fdqcI7vs/483ySXALuq6qZ517IKHAacA3ywqs4GHmPRbqCO+sUxDP6newrwImAj//+ukm4djH6wGoJgJ3DSyPCJw3HdSLKWQQh8tKquGY7+mz2bdMN/d82rvhl6BXBpkm8y2EX4Ggb7yY8e7hKAfvrHfcB9VbV1OLyFQTD02C9eC3yj
qh6sqt3ANQz6So/9Yo9x/WBFf09XQxDcCJw6PAPgcAYHgT4155pmZrgP/MPAjqp638hbnwJ+dvj6Z4E/mXVts1ZVv1pVJ1bVyQz6wZ9V1ZuB64HLh5P1si4eAO5Nctpw1IXA7XTYLxjsEjovyYbh92XPuuiuX4wY1w8+BfzL4dlD5wHfG9mFNF5Vzf0HuBj4K+BrwK/Nu54Zf/ZXMtis+ypwy/DnYgb7xr8A/DXweeDYedc64/XyauDTw9d/H7gBuAv4OLBu3vXNaB2cBWwb9o3/CRzTa78A3gXcAWwH/hBY10u/AK5mcGxkN4MtxbeM6wdAGJyF+TXgVgZnWu23DW8xIUmdWw27hiRJc2QQSFLnDAJJ6pxBIEmdMwgkqXMGgbqU5Lgktwx/Hkiyc/j60SQfmHd90ix5+qi6l+SdwKNV9Z551yLNg1sE0ogkrx55DsI7k1yV5H8nuTvJG5P85yS3Jrl2eGsQkrw0yReT3JTks9Pc7VFaTQwCabIfZnDPo0uBPwKur6ozge8DbxiGwX8FLq+qlwIfAX5zXsVKK3HY/ieRuvanVbU7ya0MHohy7XD8rcDJwGnAGcB1g9vgsIbB7QCkQ4ZBIE32OEBVPZ1kd+07qPY0g+9PgNuq6vx5FSgdKHcNSQfmTuD4JOfD4JbiSX50zjVJy2IQSAegBo9XvRx4d5KvMLh77AVzLUpaJk8flaTOuUUgSZ0zCCSpcwaBJHXOIJCkzhkEktQ5g0CSOmcQSFLn/h/kqldUEYI74wAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "plt.pcolormesh(snn_output.T.detach().cpu())\n",
    "\n",
    "plt.ylabel(\"Neuron ID\")\n",
    "plt.yticks(np.arange(10) + 0.5, np.arange(10))\n",
    "plt.xlabel(\"Time\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the majority of spikes are emitted by the output neuron corresponding to the digit plotted above, which is a correct inference."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "65b6f4b806bbaf5b54d6ccaa27abf7e5307b1f0e4411e9da36d5256169cebdd7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
