sinabs Tutorial: Getting Started#

Train and test a spiking neural network (SNN) with LeNet as an example#

We demonstrate the general steps to construct a useful SNN model in sinabs, taking LeNet-5 on MNIST as an example:

  1. Build/define a LeNet CNN model in PyTorch

  2. Train and test this LeNet CNN model in PyTorch

  3. Convert this LeNet CNN model into an SNN using sinabs

  4. Test the SNN in sinabs

import os
import torch
import torchvision
import sinabs
import sinabs.layers as sl
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from sinabs.from_torch import from_model

Build/define a LeNet CNN model in PyTorch#

  1. We recommend composing the model as a torch.nn.Sequential of torch.nn layers, rather than wiring layers together manually in forward().

  2. Standard layers currently supported for automatic conversion to an SNN:

    • Conv2d

    • Linear

    • AvgPool2d

    • MaxPool2d

    • ReLU

    • Flatten

    • Dropout

    • BatchNorm

  3. Users can also define their own layers deriving from torch.nn.Module, as in the sketch below.
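
For instance, a minimal custom layer might look like the following (a hypothetical Scale layer, shown only as a sketch; it is not part of sinabs):

class Scale(nn.Module):
    # Hypothetical custom layer: multiplies its input by a fixed factor
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor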

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.seq = nn.Sequential(
            # 1st Conv + ReLU + Pooling
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # 2nd Conv + ReLU + Pooling
            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Dense layers
            nn.Flatten(),
            nn.Linear(4 * 4 * 50, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        return self.seq(x)

Setting up the environment#

  1. Torch device: GPU or CPU

  2. Torch dataloaders: for training, testing, and spiking testing

  3. Input image size: (n_channel, width, height)

def prepare():
    # Setting up environment

    # Declare global environment parameters
    # Torch device: GPU or CPU
    # Torch dataloader: training
    # Torch dataloader: testing
    # Torch dataloader: spiking testing
    # Input image size: (n_channel, width, height)
    global device, train_dataloader, test_dataloader, spiking_test_dataloader, input_image_size

    # Torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model folder to save trained models
    os.makedirs("models", exist_ok=True)

    # Setting up random seed to reproduce experiments
    torch.manual_seed(0)
    if device.type != "cpu":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Downloading/Loading MNIST dataset as tensors for training
    train_dataset = MNIST(
        "./data/",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Downloading/Loading MNIST dataset as tensors for testing
    test_dataset = MNIST(
        "./data/",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Define Torch dataloaders for training, testing and spiking testing
    BATCH_SIZE = 512
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
    spiking_test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    # Define the size of input images
    input_image_size = (1, 28, 28)

    # Return global parameters
    return (
        device,
        train_dataloader,
        test_dataloader,
        spiking_test_dataloader,
        input_image_size,
    )

Train the LeNet CNN model in PyTorch#

  1. Define the loss function

  2. Define the optimizer

  3. Run backpropagation over batches and epochs

def train(model, n_epochs=20):
    # Training a CNN model

    # Define loss
    criterion = torch.nn.CrossEntropyLoss()
    # Define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Visualize and display training loss in a progress bar
    pbar = tqdm(range(n_epochs))

    # backprop over epochs
    for epoch in pbar:
        # over batches
        for imgs, labels in train_dataloader:
            # reset grad to zero for each batch
            optimizer.zero_grad()

            # port to device
            imgs, labels = imgs.to(device), labels.to(device)
            # forward pass
            outputs = model(imgs)
            # calculate loss
            loss = criterion(outputs, labels)
            # display loss in progress bar
            pbar.set_postfix(loss=loss.item())

            # backward pass
            loss.backward()
            # optimize parameters
            optimizer.step()
    return model

Test the LeNet CNN model in PyTorch#

# Define the function to count the correct prediction
def count_correct(output, target):
    _, predicted = torch.max(output, 1)
    acc = (predicted == target).sum().float()
    return acc.cpu().numpy()


def test(model):
    # Test the accuracy of a CNN model

    # Disabling gradients saves memory and computation on the forward pass
    with torch.no_grad():
        # eval() puts Dropout and BatchNorm layers in inference mode
        model.eval()
        # Count correct prediction and total test number
        n_correct = 0
        n_test = 0

        # over batches
        for imgs, labels in test_dataloader:
            # port to device
            imgs, labels = imgs.to(device), labels.to(device)
            # inference
            outputs = model(imgs)
            n_correct += count_correct(outputs, labels)
            n_test += len(labels)
    # calculate accuracy
    ann_accuracy = n_correct / n_test * 100.0
    print("ANN test accuracy: %.2f" % (ann_accuracy))
    return ann_accuracy

Test the LeNet SNN model in sinabs#

  1. Convert the PyTorch-trained CNN model into an SNN model in sinabs

    • the neuron model is different (see the sketch after this list)

      • a spiking neuron of an SNN holds a membrane potential state (V) at each time t over a time period (n_dt)

      • the weighted input is added onto V

      • a spiking neuron emits a spike (a binary output per time step dt) whenever V >= threshold at time t

      • once a spike is generated, membrane_subtract is subtracted from V, and V is bounded from below by min_v_mem

    • the network architecture is the same (e.g. convolution, pooling and dense layers)

    • the network parameters are the same (e.g. weights and biases)

  2. Tile an image into a sequence of n_dt copies as input to the SNN simulation

    • processing these tiled-up images may look inefficient

    • however, it is only a software simulation of a constant current injected into the spiking neurons for a duration of n_dt

    • which is ultra power-efficient on neuromorphic hardware

  3. sinabs can only infer one input at a time, so batch_size = 1

  4. Classification is read from the spike count of the output layer over the time period (n_dt)
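
As a minimal sketch of these dynamics (assuming a plain integrate-and-fire update with hypothetical parameter names; the actual sinabs layer internals may differ):

def iaf_dynamics(weighted_input, spike_threshold=1.0, membrane_subtract=1.0, min_v_mem=-1.0):
    # weighted_input: tensor of shape (n_dt, n_neurons), weighted input per time step
    v = torch.zeros(weighted_input.shape[1])  # membrane potential V per neuron
    spikes = []
    for x_t in weighted_input:
        v = v + x_t  # the weighted input adds up onto V
        out = (v >= spike_threshold).float()  # binary spike wherever V >= threshold
        v = v - out * membrane_subtract  # subtract membrane_subtract on spiking
        v = v.clamp(min=min_v_mem)  # V cannot go below min_v_mem
        spikes.append(out)
    return torch.stack(spikes)  # binary spike trains, shape (n_dt, n_neurons)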

# Define tensor_tile function to generate a sequence of input images
def tensor_tile(a, dim, n_tile):
    # a: input tensor
    # dim: dimension along which to tile
    # n_tile: number of times to repeat each element
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(a, dim, order_index)
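
As a quick sanity check of tensor_tile (to our understanding, its behaviour matches torch.repeat_interleave along the same dim):

a = torch.tensor([[1, 2], [3, 4]])
tiled = tensor_tile(a, dim=0, n_tile=3)
# tiled is [[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
assert torch.equal(tiled, torch.repeat_interleave(a, 3, dim=0))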


def snn_test(model, n_dt=10, n_test=10000):
    # Testing the accuracy of SNN on sinabs
    # model: CNN model
    # n_dt: the time window of each simulation
    # n_test: number of test images in total

    # Convert the PyTorch-trained CNN model into a sinabs SNN model
    net = from_model(
        model,  # PyTorch-trained model
        input_image_size,  # Input image size: (n_channel, width, height)
        spike_threshold=1.0,  # Threshold of the membrane potential of a spiking neuron
        bias_rescaling=1.0,  # Rescaling factor applied to the bias values
        min_v_mem=-1.0,  # The lower bound of the membrane potential
    ).to(device)

    # Disabling gradients saves memory and computation on the forward pass
    with torch.no_grad():
        # eval() puts Dropout and BatchNorm layers in inference mode
        net.spiking_model.eval()
        # Count correct prediction and total test number
        n_correct = 0
        # loop over the test images one at a time
        for i, (imgs, labels) in enumerate(tqdm(spiking_test_dataloader)):
            if i >= n_test:
                break
            # tile image to a sequence of n_dt length as input to SNN
            input_frames = tensor_tile(imgs, 0, n_dt).to(device)
            labels = labels.to(device)
            # Reset neural states of all the neurons in the network for each inference
            net.reset_states()
            # inference
            outputs = net.spiking_model(input_frames)
            n_correct += count_correct(outputs.sum(0, keepdim=True), labels)
    # calculate accuracy
    snn_accuracy = n_correct / n_test * 100.0
    print("SNN test accuracy: %.2f" % (snn_accuracy))
    return snn_accuracy

# Setting up environment
prepare()
# Init LeNet5 CNN
classifier = LeNet5().to(device)
# Train CNN model
train(classifier, n_epochs=2)
# Test on CNN model
ann_accuracy = test(classifier)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
100%|██████████| 2/2 [01:10<00:00, 35.02s/it, loss=0.0234]

ANN test accuracy: 98.09

# Test the SNN model
snn_accuracy = snn_test(classifier, n_dt=10, n_test=2000)
SNN test accuracy: 98.45