sinabs Tutorial: Getting Started

Take LeNet as an example to train and test a spiking neural network (SNN)
We demonstrate the general steps to construct a useful SNN model in sinabs, taking LeNet-5 on MNIST as an example:

- Build/define a LeNet CNN model in Pytorch
- Train and test this LeNet CNN model in Pytorch
- Convert this LeNet CNN model into an SNN using sinabs
- Test the SNN in sinabs
import os
import torch
import torchvision
import sinabs
import sinabs.layers as sl
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from sinabs.from_torch import from_model
Build/define a LeNet CNN model in Pytorch
We recommend using torch.nn.Sequential composed of torch.nn layers instead of manually defining the forwarding functions between layers in forward().

The standard layers currently supported for automatic conversion to an SNN are:

- Conv2d
- Linear
- AvgPool2d
- MaxPool2d
- ReLU
- Flatten
- Dropout
- BatchNorm

Users can also define their own layers by deriving from torch.nn.Module, as sketched below.
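As an illustration of such a custom layer (a minimal sketch, not part of the original tutorial; ScaledReLU is a made-up example name, and it relies on the torch and nn imports above):

# Hypothetical custom layer: a ReLU whose output is scaled by a constant factor
class ScaledReLU(nn.Module):
    def __init__(self, scale=0.5):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return torch.relu(x) * self.scale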
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.seq = nn.Sequential(
            # 1st Conv + ReLU + Pooling
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # 2nd Conv + ReLU + Pooling
            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            # Dense layers
            nn.Flatten(),
            nn.Linear(4 * 4 * 50, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
        )

    def forward(self, x):
        return self.seq(x)
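As a quick sanity check (added here for illustration), the flattened size 4 * 4 * 50 = 800 follows from a 1x28x28 input: 28 -> 24 after the first 5x5 convolution, -> 12 after pooling, -> 8 after the second convolution, -> 4 after pooling. A dummy forward pass confirms the shapes line up:

# Illustrative shape check: a wrong flatten size would raise a shape
# mismatch error in the first Linear layer
dummy = torch.zeros(1, 1, 28, 28)
print(LeNet5()(dummy).shape)  # torch.Size([1, 10])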
Setting up environment
- Torch device: GPU or CPU
- Torch dataloaders: training / testing / spiking testing
- Input image size: (n_channel, width, height)
def prepare():
    # Setting up environment
    # Declare global environment parameters
    # Torch device: GPU or CPU
    # Torch dataloader: training
    # Torch dataloader: testing
    # Torch dataloader: spiking testing
    # Input image size: (n_channel, width, height)
    global device, train_dataloader, test_dataloader, spiking_test_dataloader, input_image_size

    # Torch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model folder to save trained models
    os.makedirs("models", exist_ok=True)
    # Setting up random seed to reproduce experiments
    torch.manual_seed(0)
    if device.type != "cpu":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Downloading/Loading MNIST dataset as tensors for training
    train_dataset = MNIST(
        "./data/",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    # Downloading/Loading MNIST dataset as tensors for testing
    test_dataset = MNIST(
        "./data/",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    # Define Torch dataloaders for training, testing and spiking testing
    BATCH_SIZE = 512
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)
    spiking_test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    # Define the size of input images
    input_image_size = (1, 28, 28)
    # Return global parameters
    return (
        device,
        train_dataloader,
        test_dataloader,
        spiking_test_dataloader,
        input_image_size,
    )
Train LeNet CNN model in Pytorch
- Define loss
- Define optimizer
- Backpropagation over batches and epochs
def train(model, n_epochs=20):
    # Training a CNN model
    # Set the model to training mode (relevant if Dropout/BatchNorm are present)
    model.train()
    # Define loss
    criterion = torch.nn.CrossEntropyLoss()
    # Define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Visualize and display training loss in a progress bar
    pbar = tqdm(range(n_epochs))
    # backprop over epochs
    for epoch in pbar:
        # over batches
        for imgs, labels in train_dataloader:
            # reset grad to zero for each batch
            optimizer.zero_grad()
            # port to device
            imgs, labels = imgs.to(device), labels.to(device)
            # forward pass
            outputs = model(imgs)
            # calculate loss
            loss = criterion(outputs, labels)
            # display loss in progress bar
            pbar.set_postfix(loss=loss.item())
            # backward pass
            loss.backward()
            # optimize parameters
            optimizer.step()
    return model
Test LeNet CNN model in Pytorch
# Define the function to count the correct predictions
def count_correct(output, target):
    _, predicted = torch.max(output, 1)
    acc = (predicted == target).sum().float()
    return acc.cpu().numpy()
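For example (illustrative values, not from the tutorial):

# Illustrative usage of count_correct: row-wise argmax compared against labels
logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # predicted classes: 1, 0
labels = torch.tensor([1, 1])
print(count_correct(logits, labels))  # 1.0 (only the first sample is correct)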
def test(model):
    # Test the accuracy of a CNN model
    # Disabling gradient computation saves memory and computation on the forward pass
    with torch.no_grad():
        # evaluation uses Dropout and BatchNorm in inference mode
        model.eval()
        # Count correct predictions and total test number
        n_correct = 0
        n_test = 0
        # over batches
        for imgs, labels in test_dataloader:
            # port to device
            imgs, labels = imgs.to(device), labels.to(device)
            # inference
            outputs = model(imgs)
            n_correct += count_correct(outputs, labels)
            n_test += len(labels)
    # calculate accuracy
    ann_accuracy = n_correct / n_test * 100.0
    print("ANN test accuracy: %.2f" % (ann_accuracy))
    return ann_accuracy
Test LeNet SNN model in sinabs
Convert the Pytorch trained CNN model to an SNN model in sinabs:

- The neuron model is different (see the sketch below):
  - A spiking neuron of an SNN holds a membrane potential state (V) at a certain time t over a time period (n_dt).
  - The weighted input adds up to V.
  - A spiking neuron outputs a spike (a binary output per time step dt) when V >= threshold at time t.
  - Once a spike is generated, V is reduced by membrane_subtract, and the lower bound of V is set to min_v_mem (V cannot fall below min_v_mem).
- The network architecture is the same (e.g. convolution, pooling and dense layers).
- The network parameters are the same (e.g. weights and biases).

Tile an image into a sequence of n_dt images as input to the SNN simulation:

- This tiling of images may seem inefficient; however, it is only a software simulation of a continuous current injected into the spiking neurons for a duration of n_dt, which is ultra power efficient on neuromorphic hardware.

sinabs can only infer one input at a time, so batch_size = 1. Classification is calculated from the count of spikes in the output layer over the time period (n_dt).
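Below is a minimal sketch of the integrate-and-fire dynamics just described. This is illustrative Python only, not the actual sinabs implementation; the function iaf_neuron is a made-up name:

# Minimal sketch of the integrate-and-fire dynamics described above
# (illustrative only, NOT the sinabs implementation; iaf_neuron is hypothetical)
def iaf_neuron(weighted_inputs, threshold=1.0, membrane_subtract=1.0, min_v_mem=-1.0):
    v = 0.0  # membrane potential state V
    spikes = []
    for i_t in weighted_inputs:  # one weighted input per time step dt
        v += i_t  # weighted input adds up to V
        if v >= threshold:
            spikes.append(1)  # emit a binary spike
            v -= membrane_subtract  # V is reduced by membrane_subtract
        else:
            spikes.append(0)
        v = max(v, min_v_mem)  # V cannot fall below min_v_mem
    return spikes

# A constant weighted input of 0.4 yields 4 spikes in 10 steps,
# so the output spike count encodes the input magnitude:
print(iaf_neuron([0.4] * 10))  # [0, 0, 1, 0, 1, 0, 0, 1, 0, 1]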
# Define tensor_tile function to generate a sequence of input images
def tensor_tile(a, dim, n_tile):
    # a: input tensor
    # dim: tile on a specific dim or dims in a tuple
    # n_tile: number of tiles to repeat
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*(repeat_idx))
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
    )
    return torch.index_select(a, dim, order_index)
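For example, tiling a single (1, 1, 28, 28) image along dim 0 yields the n_dt-frame input sequence. As an observation added here for clarity (not stated in the original), this matches the built-in torch.repeat_interleave:

# Illustrative usage: one image becomes a sequence of 10 identical frames
img = torch.rand(1, 1, 28, 28)
seq = tensor_tile(img, 0, 10)
print(seq.shape)  # torch.Size([10, 1, 28, 28])
# tensor_tile is equivalent to the built-in repeat_interleave:
assert torch.equal(seq, torch.repeat_interleave(img, 10, dim=0))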
def snn_test(model, n_dt=10, n_test=10000):
    # Testing the accuracy of the SNN in sinabs
    # model: CNN model
    # n_dt: the time window of each simulation
    # n_test: number of test images in total
    # Convert the Pytorch trained CNN model to a sinabs SNN model
    net = from_model(
        model,  # Pytorch trained model
        input_image_size,  # Input image size: (n_channel, width, height)
        spike_threshold=1.0,  # Threshold of the membrane potential of a spiking neuron
        bias_rescaling=1.0,  # Rescaling applied to the biases (1.0 leaves them unchanged)
        min_v_mem=-1.0,  # The lower bound of the membrane potential
        num_timesteps=n_dt,  # The number of time steps
    ).to(device)
    # Disabling gradient computation saves memory and computation on the forward pass
    with torch.no_grad():
        # evaluation uses Dropout and BatchNorm in inference mode
        net.spiking_model.eval()
        # Count correct predictions
        n_correct = 0
        # loop over the input images one at a time
        for i, (imgs, labels) in enumerate(tqdm(spiking_test_dataloader)):
            if i >= n_test:
                break
            # tile the image to a sequence of n_dt frames as input to the SNN
            input_frames = tensor_tile(imgs, 0, n_dt).to(device)
            labels = labels.to(device)
            # Reset the neural states of all the neurons in the network before each inference
            net.reset_states()
            # inference
            outputs = net.spiking_model(input_frames)
            # classify by the total spike count of the output layer over n_dt
            n_correct += count_correct(outputs.sum(0, keepdim=True), labels)
    # calculate accuracy
    snn_accuracy = n_correct / n_test * 100.0
    print("SNN test accuracy: %.2f" % (snn_accuracy))
    return snn_accuracy
# Setting up environment
prepare()
# Init LeNet5 CNN
classifier = LeNet5().to(device)
# Train CNN model
train(classifier, n_epochs=2)
# Test on CNN model
ann_accuracy = test(classifier)
ANN test accuracy: 98.04
# Test on SNN model
snn_accuracy = snn_test(classifier, n_dt=10, n_test=2000)
SNN test accuracy: 98.35