{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# sinabs Tutorial 使用入门\n",
    "## Take LeNet as an example to train and test a spiking neural network (SNN)\n",
    "## 以LeNet为例训练使用脉冲神经网络(SNN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We demonstrate the general steps to construct a useful SNN model in *sinabs* and take LeNet-5 on MNIST for an example.   \n",
    "我们以LeNet-5在MNIST的工作为例，介绍创建并使用SNN的通用步骤：\n",
    "1. Build/define a LeNet **CNN** model in Pytorch   \n",
    "   用Pytorch创建一个LeNet **CNN**\n",
    "2. Train and test this LeNet **CNN** model in Pytorch   \n",
    "   用Pytorch训练并测试这个**CNN**模型\n",
    "3. Convert this LeNet **CNN** model into **SNN** using *sinabs*   \n",
    "   用sinabs将**CNN**模型转化为**SNN**模型\n",
    "4. Test on **SNN** in *sinabs*   \n",
    "   用sinabs测试**SNN**模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torchvision\n",
    "import sinabs\n",
    "import sinabs.layers as sl\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "from sinabs.from_torch import from_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build/define a LeNet CNN model in Pytorch\n",
    "### 用Pytorch创建一个LeNet CNN\n",
    "1. Recommand to use *torch.nn.Sequential* of *torch.nn* layers instead of manually added forwarding functions among layers.   \n",
    "   推荐使用*torch.nn.Sequential*模型，其中每一层都使用*torch.nn*定义的layers,而不推荐在forward()函数中自定义层间函数\n",
    "2. Current supporting standard layers:   \n",
    "   目前sinabs支持自动转化为SNN的标准层有:\n",
    "    - *Conv2d*\n",
    "    - *Linear*\n",
    "    - *AvgPool2d*\n",
    "    - *MaxPool2d*\n",
    "    - *ReLU*\n",
    "    - *Flatten*\n",
    "    - *Dropout*\n",
    "    - *BatchNorm*\n",
    "3. Users can also define their own layers deriving from *torch.nn.Module*   \n",
    "   用户也可以自定义层，继承*torch.nn.Module*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LeNet5(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.seq = nn.Sequential(\n",
    "            # 1st Conv + ReLU + Pooling\n",
    "            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            # 2nd Conv + ReLU + Pooling\n",
    "            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            # Dense layers\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(4 * 4 * 50, 500),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(500, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.seq(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up environment\n",
    "### 设置环境\n",
    "1. Torch device: GPU or CPU   \n",
    "   设置Torch运行的设备: GPU 或者 CPU\n",
    "2. Torch dataloader: training/testing/spiking_testing   \n",
    "   设置Torch的dataloader: 分别用在训练/测试/脉冲神经网络测试\n",
    "3. Input image size: (n_channel, width, height)   \n",
    "   输入图像大小: (通道数，图像宽，图像高)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare():\n",
    "    # Setting up environment\n",
    "\n",
    "    # Declare global environment parameters\n",
    "    # Torch device: GPU or CPU\n",
    "    # Torch dataloader: training\n",
    "    # Torch dataloader: testing\n",
    "    # Torch dataloader: spiking testing\n",
    "    # Input image size: (n_channel, width, height)\n",
    "    global device, train_dataloader, test_dataloader, spiking_test_dataloader, input_image_size\n",
    "\n",
    "    # Torch device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    # Model folder to save trained models\n",
    "    os.makedirs(\"models\", exist_ok=True)\n",
    "\n",
    "    # Setting up random seed to reproduce experiments\n",
    "    torch.manual_seed(0)\n",
    "    if device != \"cpu\":\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Downloading/Loading MNIST dataset as tensors for training\n",
    "    train_dataset = MNIST(\n",
    "        \"./data/\",\n",
    "        train=True,\n",
    "        download=True,\n",
    "        transform=torchvision.transforms.ToTensor(),\n",
    "    )\n",
    "\n",
    "    # Downloading/Loading MNIST dataset as tensors for testing\n",
    "    test_dataset = MNIST(\n",
    "        \"./data/\",\n",
    "        train=False,\n",
    "        download=True,\n",
    "        transform=torchvision.transforms.ToTensor(),\n",
    "    )\n",
    "\n",
    "    # Define Torch dataloaders for training, testing and spiking testing\n",
    "    BATCH_SIZE = 512\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "    spiking_test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)\n",
    "\n",
    "    # Define the size of input images\n",
    "    input_image_size = (1, 28, 28)\n",
    "\n",
    "    # Return global prameters\n",
    "    return (\n",
    "        device,\n",
    "        train_dataloader,\n",
    "        test_dataloader,\n",
    "        spiking_test_dataloader,\n",
    "        input_image_size,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train LeNet CNN model in Pytorch\n",
    "### 用Pytorch训练LeNet CNN模型\n",
    "1. Define loss   \n",
    "   定义损失函数\n",
    "2. Define optimizer   \n",
    "   定义优化器\n",
    "3. Backpropagation over batches and epochs   \n",
    "   反向传播"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, n_epochs=20):\n",
    "    # Training a CNN model\n",
    "\n",
    "    # Define loss\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    # Define optimizer\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "    # Visualize and display training loss in a progress bar\n",
    "    pbar = tqdm(range(n_epochs))\n",
    "\n",
    "    # backprop over epochs\n",
    "    for epoch in pbar:\n",
    "        # over batches\n",
    "        for imgs, labels in train_dataloader:\n",
    "            # reset grad to zero for each batch\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # port to device\n",
    "            imgs, labels = imgs.to(device), labels.to(device)\n",
    "            # forward pass\n",
    "            outputs = model(imgs)\n",
    "            # calculate loss\n",
    "            loss = criterion(outputs, labels)\n",
    "            # display loss in progress bar\n",
    "            pbar.set_postfix(loss=loss.item())\n",
    "\n",
    "            # backward pass\n",
    "            loss.backward()\n",
    "            # optimze parameters\n",
    "            optimizer.step()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test LeNet CNN model in Pytorch\n",
    "### 用Pytorch测试LeNet CNN模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the function to count the correct prediction\n",
    "def count_correct(output, target):\n",
    "    _, predicted = torch.max(output, 1)\n",
    "    acc = (predicted == target).sum().float()\n",
    "    return acc.cpu().numpy()\n",
    "\n",
    "\n",
    "def test(model):\n",
    "    # Test the accuracy of a CNN model\n",
    "\n",
    "    # With no gradient means less memory and calculation on forward pass\n",
    "    with torch.no_grad():\n",
    "        # evaluation usese Dropout and BatchNorm in inference mode\n",
    "        model.eval()\n",
    "        # Count correct prediction and total test number\n",
    "        n_correct = 0\n",
    "        n_test = 0\n",
    "\n",
    "        # over batches\n",
    "        for imgs, labels in test_dataloader:\n",
    "            # port to device\n",
    "            imgs, labels = imgs.to(device), labels.to(device)\n",
    "            # inference\n",
    "            outputs = model(imgs)\n",
    "            n_correct += count_correct(outputs, labels)\n",
    "            n_test += len(labels)\n",
    "    # calculate accuracy\n",
    "    ann_accuracy = n_correct / n_test * 100.0\n",
    "    print(\"ANN test accuracy: %.2f\" % (ann_accuracy))\n",
    "    return ann_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test LeNet SNN model in *sinabs*\n",
    "### 用*sinabs*测试LeNet SNN模型\n",
    "\n",
    "1. Transfer pytorch trained CNN model to SNN model in *sinabs*   \n",
    "   将CNN模型转化为SNN模型\n",
    "    - neural model is different   \n",
    "      神经元模型是不同的\n",
    "        - a spiking neuron of an SNN holds a membrane potential state (V) of a certain time t over a time period (n_dt)   \n",
    "          每一个脉冲神经元都有一个膜电压V作为其某一时刻t的状态,整个SNN模拟的时间长度为n_dt\n",
    "        - weighted input adds up to the V   \n",
    "          输入乘以权重的值会加到V上\n",
    "        - a spiking neuron outputs a spike (binary output per time step dt) when V >= threshold at time t   \n",
    "          一个脉冲神经元在V>=threshold的时刻t，会释放一个脉冲，即每个时间步长为一个二进制的输出\n",
    "        - once a spike is generated, the V is subtraced by membrane_subtract, and the lower bound of V is set to min_v_mem   \n",
    "          每当产生一个脉冲，V会减去membrane_subtract，并且其下界设置为min_v_mem(V不能低于min_v_mem)\n",
    "\n",
    "    - network architecture is the same (e.g. convolution, pooling and dense)   \n",
    "      神经网络的架构是一致的(比如 convolution, pooling and dense)\n",
    "    - network parameters are the same (e.g. weights and biases)   \n",
    "      神经网络的参数值是一致的(例如 权重和偏置)\n",
    "2. Tile an image to a sequence of n_dt images as input to SNN simulations   \n",
    "   将一个图片扩展为n_dt个图片序列作为SNN的输入\n",
    "    - This processing on tile-up images seems inefficient   \n",
    "      这种把一副图片复制为图片序列的方法看起来很低效\n",
    "    - however, it is only a software simulation on continous current flow injecting to spiking neurons for n_dt length    \n",
    "      但是这只是受限于软件模拟的限制，其实质是输入脉冲神经元中耗时n_dt的固定电流\n",
    "    - which is ultra power efficient on Neuromorphic hardware   \n",
    "      在神经形态硬件中的实现是非常低功耗且高效的\n",
    "3. *sinabs* can only infer one input as a time, so batch_size = 1   \n",
    "   *sinabs*每次只能推理一个输入，因此batch_size = 1\n",
    "4. Classification is calculated on the count of spikes on the output layer over time period (n_dt)   \n",
    "   分类的结果是在n_dt的时间段中读取输出神经元的脉冲总数来计算的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tensor_tile function to generate sequence of input images\n",
    "def tensor_tile(a, dim, n_tile):\n",
    "    # a: input tensor\n",
    "    # dim: tile on a specific dim or dims in a tuple\n",
    "    # n_tile: number of tile to repeat\n",
    "    init_dim = a.size(dim)\n",
    "    repeat_idx = [1] * a.dim()\n",
    "    repeat_idx[dim] = n_tile\n",
    "    a = a.repeat(*(repeat_idx))\n",
    "    order_index = torch.LongTensor(\n",
    "        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])\n",
    "    )\n",
    "    return torch.index_select(a, dim, order_index)\n",
    "\n",
    "\n",
    "def snn_test(model, n_dt=10, n_test=10000):\n",
    "    # Testing the accuracy of SNN on sinabs\n",
    "    # model: CNN model\n",
    "    # n_dt: the time window of each simulation\n",
    "    # n_test: number of test images in total\n",
    "\n",
    "    # Transfer Pytorch trained CNN model to sinabs SNN model\n",
    "    net = from_model(\n",
    "        model,  # Pytorch trained model\n",
    "        input_image_size,  # Input image size: (n_channel, width, height)\n",
    "        spike_threshold=1.0,  # Threshold of the membrane potential of a Spiking neuron\n",
    "        bias_rescaling=1.0,  # Subtract membrane potential when the neuron fires a spike\n",
    "        min_v_mem=-1.0,  # The lower bound of the membrane potential\n",
    "        num_timesteps=n_dt, # The number of time steps\n",
    "    ).to(device)\n",
    "\n",
    "    # With no gradient means less memory and calculation on forward pass\n",
    "    with torch.no_grad():\n",
    "        # evaluation usese Dropout and BatchNorm in inference mode\n",
    "        net.spiking_model.eval()\n",
    "        # Count correct prediction and total test number\n",
    "        n_correct = 0\n",
    "        # loop over the input files once a time\n",
    "        for i, (imgs, labels) in enumerate(tqdm(spiking_test_dataloader)):\n",
    "            if i > n_test:\n",
    "                break\n",
    "            # tile image to a sequence of n_dt length as input to SNN\n",
    "            input_frames = tensor_tile(imgs, 0, n_dt).to(device)\n",
    "            labels = labels.to(device)\n",
    "            # Reset neural states of all the neurons in the network for each inference\n",
    "            net.reset_states()\n",
    "            # inference\n",
    "            outputs = net.spiking_model(input_frames)\n",
    "            n_correct += count_correct(outputs.sum(0, keepdim=True), labels)\n",
    "    # calculate accuracy\n",
    "    snn_accuracy = n_correct / n_test * 100.0\n",
    "    print(\"SNN test accuracy: %.2f\" % (snn_accuracy))\n",
    "    return snn_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c93ae950d5547439c8072baf0365336",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ANN test accuracy: 98.04\n"
     ]
    }
   ],
   "source": [
    "# Setting up environment\n",
    "prepare()\n",
    "# Init LeNet5 CNN\n",
    "classifier = LeNet5().to(device)\n",
    "# Train CNN model\n",
    "train(classifier, n_epochs=2)\n",
    "# Test on CNN model\n",
    "ann_accuracy = test(classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3dd85ad61fd74d908c7b2986aeae901b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SNN test accuracy: 98.35\n"
     ]
    }
   ],
   "source": [
    "# Test on SNN model\n",
    "snn_accuracy = snn_test(classifier, n_dt=10, n_test=2000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "65b6f4b806bbaf5b54d6ccaa27abf7e5307b1f0e4411e9da36d5256169cebdd7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
