Training NMNIST for deployment on Speck using EXODUS

import sinabs
import sinabs.layers as sl
import torch
import torch.nn as nn
import numpy as np

Let's first visualize the neuron model that's supported on chip: the Integrate-and-Fire (IAF) neuron.

iaf = sl.IAF(record_states=True, spike_threshold=5.)  # record v_mem so we can plot it

n_steps = 400
input_ = (torch.rand((1, n_steps, 1)) < 0.05).float()  # random spike train, ~5% spike probability per step
output = iaf(input_)
import matplotlib.pyplot as plt

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(15,5))

ax1.eventplot(torch.where(input_)[1])
ax1.set_ylabel("Input events")
ax2.plot(iaf.recordings['v_mem'].squeeze().numpy())
ax2.set_ylabel("IF Vmem")
ax3.eventplot(torch.where(output)[1])
ax3.set_ylabel("Output Events")
ax3.set_xlabel("Time")
Figure: input events (top), IAF membrane potential (middle), and output events (bottom) over time.
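
To make the dynamics above explicit, here is a minimal integrate-and-fire step in plain PyTorch, assuming sinabs's default membrane-subtract reset (simplified to emit at most one spike per time step):

def iaf_step(v_mem, x, threshold=5.0):
    # Integrate the input (no leak), fire once the threshold is crossed,
    # then reset by subtracting the threshold.
    v_mem = v_mem + x
    spike = (v_mem >= threshold).float()
    v_mem = v_mem - spike * threshold
    return v_mem, spike

torch.manual_seed(0)
inp = (torch.rand(400) < 0.05).float()  # random spike train like the one above
v, spikes = torch.zeros(()), []
for x in inp:
    v, s = iaf_step(v, x)
    spikes.append(s)
print(f"{int(inp.sum())} input events -> {int(torch.stack(spikes).sum())} output spikes")
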
from tonic import datasets, transforms

trainset = datasets.NMNIST('data', train=True)
testset = datasets.NMNIST('data', train=False)
transform = transforms.Compose([
    # Rasterize the event stream into 30 frames of shape (polarity, height, width)
    transforms.ToFrame(sensor_size=trainset.sensor_size, n_time_bins=30, include_incomplete=True),
    lambda x: x.astype(np.float32),
])

events, label = trainset[0]
frames = transform(events)
frames.shape
(30, 2, 34, 34)
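
Before binning, each sample is a structured NumPy array of individual events, which ToFrame rasterizes into dense frames. A quick way to inspect the raw stream (the field names below follow tonic's x/y/t/p convention with timestamps in microseconds; check events.dtype on your install):

print(events.dtype)
print(f"{len(events)} events, label {label}")
print(f"duration: {(events['t'][-1] - events['t'][0]) / 1e6:.3f} s")
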
plt.imshow(frames[:10, 0].sum(0))
Figure: ON-polarity events of the first 10 frames, summed over time.
trainset = datasets.NMNIST('data', train=True, transform=transform)
testset = datasets.NMNIST('data', train=False, transform=transform)
batch_size = 16

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=4, drop_last=True)
frames = next(iter(trainloader))[0]
frames.shape
torch.Size([16, 30, 2, 34, 34])
import sinabs.exodus.layers as sel

backend = sl   # plain sinabs layers
backend = sel  # EXODUS layers; comment this line out to fall back to plain sinabs
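
EXODUS layers implement the same forward dynamics as their sinabs counterparts but compute gradients with fused CUDA kernels, which makes training considerably faster. A quick sanity check that the two backends agree on the forward pass (a sketch; EXODUS requires a CUDA device):

x = (torch.rand(batch_size * 30, 2, 34, 34) < 0.2).float().cuda()
out_sinabs = sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1).cuda()(x)
out_exodus = sel.IAFSqueeze(batch_size=batch_size, min_v_mem=-1).cuda()(x)
assert torch.allclose(out_sinabs, out_exodus)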

model = nn.Sequential(
    sl.FlattenTime(),  # (batch, time, C, H, W) -> (batch*time, C, H, W)
    nn.Conv2d(2, 8, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    sl.SumPool2d(2),
    nn.Conv2d(64, 10, kernel_size=2, padding=0, bias=False),
    backend.IAFSqueeze(batch_size=batch_size, min_v_mem=-1),
    nn.Flatten(),
    sl.UnflattenTime(batch_size=batch_size),  # back to (batch, time, num_classes)
).cuda()
model(frames.cuda()).shape
torch.Size([16, 30, 10])
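
FlattenTime and UnflattenTime are what let the standard Conv2d layers process event movies: the time axis is folded into the batch axis on the way in and unfolded again at the output. This is roughly equivalent to the following reshapes (a sketch, not the actual implementation):

x = torch.rand(batch_size, 30, 2, 34, 34)
flat = x.flatten(start_dim=0, end_dim=1)        # (batch*time, 2, 34, 34): what Conv2d sees
restored = flat.unflatten(0, (batch_size, 30))  # back to (batch, time, ...)
assert torch.equal(x, restored)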
from tqdm.notebook import tqdm


n_epochs = 1
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
crit = nn.functional.cross_entropy

for epoch in range(n_epochs):
    losses = []
    for data, targets in tqdm(trainloader):
        data, targets = data.cuda(), targets.cuda()
        sinabs.reset_states(model)
        optimizer.zero_grad()
        y_hat = model(data)
        pred = y_hat.sum(1)  # spike counts over time serve as class logits
        loss = crit(pred, targets)
        loss.backward()
        losses.append(loss.detach())  # detach so the graph isn't kept alive
        optimizer.step()
    print(f"Loss: {torch.stack(losses).mean()}")
Loss: 0.27892711758613586
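
A single epoch is enough for this demo; to judge convergence you can plot the per-batch losses collected above (a quick sketch):

plt.figure()
plt.plot(torch.stack(losses).cpu().numpy())
plt.xlabel("Batch")
plt.ylabel("Cross-entropy loss")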
import torchmetrics

acc = torchmetrics.Accuracy('multiclass', num_classes=10).cuda()
model.eval()

for data, targets in tqdm(testloader):
    data, targets = data.cuda(), targets.cuda()
    sinabs.reset_states(model)
    with torch.no_grad():
        y_hat = model(data)
    pred = y_hat.sum(1)  # spike counts over time as class scores
    acc(pred, targets)   # torchmetrics accumulates across batches
print(f"Test accuracy: {100*acc.compute():.2f}%")
Test accuracy: 97.07%
from sinabs.exodus.conversion import exodus_to_sinabs

# Replace the EXODUS layers with equivalent plain sinabs layers so the
# model no longer requires a GPU and can be parsed by deployment tooling.
sinabs_model = exodus_to_sinabs(model)
torch.save(sinabs_model, "nmnist_model.pth")
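
From here, the saved sinabs model can be mapped onto Speck with the sinabs-dynapcnn package. The sketch below assumes that package is installed and a devkit is attached; the exact device string depends on your board, and depending on the version you may need to strip the FlattenTime/UnflattenTime wrappers before conversion:

from sinabs.backend.dynapcnn import DynapcnnNetwork

dynapcnn_net = DynapcnnNetwork(
    sinabs_model,
    input_shape=(2, 34, 34),  # polarity channels, height, width
    discretize=True,          # quantize weights and thresholds to chip precision
)
dynapcnn_net.to("speck2fmodule")  # device name varies by devkit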