Quick Start With N-MNIST#

This tutorial explains all steps necessary to deploy a pretrained SNN to the devkit.

  1. The pretrained SNN will be obtained by:

    • CNN-to-SNN conversion.

    • Training the SNN directly with BPTT (backpropagation through time).

  2. For the dataset, we use N-MNIST. The Neuromorphic-MNIST (N-MNIST) dataset is a spiking version of the original frame-based MNIST dataset. It consists of the same 60 000 training and 10 000 testing samples, at a resolution of 34×34 pixels. Tonic provides publicly available event-based vision and audio datasets as well as event transformations, and it is fully compatible with PyTorch, so we use it for data preparation.

  3. In the deployment stage, we use the auxiliary class DynapcnnNetwork to convert the pretrained SNN into a configuration object for the devkit.

  4. Furthermore, we take an in-depth look at the deployment process and reveal more details about the chip.

Data Preparation#

try:
    from tonic.datasets.nmnist import NMNIST
except ImportError:
    ! pip install tonic
    from tonic.datasets.nmnist import NMNIST
    
# download dataset
root_dir = "./NMNIST"
_ = NMNIST(save_to=root_dir, train=True)
_ = NMNIST(save_to=root_dir, train=False)
sample_data, label = NMNIST(save_to=root_dir, train=False)[0]

print(f"type of data is: {type(sample_data)}")
print(f"time length of sample data is: {sample_data['t'][-1] - sample_data['t'][0]} micro seconds")
print(f"there are {len(sample_data)} events in the sample data")
print(f"the label of the sample data is: {label}")
type of data is: <class 'numpy.ndarray'>
time length of sample data is: 300760 micro seconds
there are 4686 events in the sample data
the label of the sample data is: 5

CNN-To-SNN#

Tips for a “devkit-friendly” model architecture can be found here.

The layer types supported by the devkit can be found here.

Define CNN#

from torch import nn


# define a CNN model
cnn = nn.Sequential(
    # [2, 34, 34] -> [8, 17, 17]
    nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(3, 3), padding=(1, 1), bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2, 2),
    # [8, 17, 17] -> [16, 8, 8]
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1), bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2, 2),
    # [16, 8, 8] -> [16, 4, 4]
    nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2),  bias=False),
    nn.ReLU(),
    # [16 * 4 * 4] -> [10]
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10, bias=False),
    nn.ReLU(),
)

# init the model weights
for layer in cnn.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(layer.weight.data)
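
The shape comments in the model definition can be checked with the standard output-size formula for convolution and pooling windows. The snippet below is a small sketch in plain Python; the helper name `conv_out_size` and the layer table are ours, not part of PyTorch or sinabs.

```python
def conv_out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for each spatial layer of the CNN above
spatial_layers = [
    ("conv1", 3, 1, 1),
    ("pool1", 2, 2, 0),
    ("conv2", 3, 1, 1),
    ("pool2", 2, 2, 0),
    ("conv3", 3, 2, 1),
]

size = 34  # N-MNIST frames are 34 x 34
sizes = {}
for name, kernel, stride, padding in spatial_layers:
    size = conv_out_size(size, kernel, stride, padding)
    sizes[name] = size

# the last convolution has 16 output channels
flattened_features = 16 * size * size
print(sizes, flattened_features)
```

The final value matches the `16 * 4 * 4` input features of the `nn.Linear` layer.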

You might notice that the output layer is a ReLU rather than a bare Linear layer. This is because the conversion to an SNN directly replaces each ReLU activation layer with an IAF layer, and the conversion function needs an IAF layer as the output layer. In addition, the devkit outputs spikes at inference time, so using IAF as the output layer is also more in line with how the hardware operates.
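
To see why an IAF layer can stand in for a ReLU, it helps to simulate a single neuron. The sketch below is a simplified model, assuming a constant input per time step, at most one spike per step, and a subtractive reset; the actual sinabs IAF layer may differ in these details. The function name `iaf_spike_count` is ours.

```python
def iaf_spike_count(input_per_step, n_steps, threshold=1.0, min_v_mem=-1.0):
    """Count the spikes of a simplified integrate-and-fire neuron.

    Assumptions: constant input, at most one spike per step, subtractive reset.
    """
    v_mem = 0.0
    spikes = 0
    for _ in range(n_steps):
        # integrate the input, but never go below the lower bound
        v_mem = max(v_mem + input_per_step, min_v_mem)
        if v_mem >= threshold:
            spikes += 1
            v_mem -= threshold
    return spikes


n_steps = 100
for x in (-0.5, 0.0, 0.25, 0.5):
    rate = iaf_spike_count(x, n_steps) / n_steps
    print(f"input {x:+.2f} -> firing rate {rate:.2f}")
```

For inputs between 0 and the threshold, the firing rate follows the input, and negative inputs produce no spikes, which is the ReLU-like behaviour the conversion relies on.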

Define CNN Training & Testing Datasets#

from tonic.transforms import ToFrame
from tonic.datasets import nmnist

# define a transform that accumulates the events into a single frame image
to_frame = ToFrame(sensor_size=NMNIST.sensor_size, n_time_bins=1)

cnn_train_dataset = NMNIST(save_to=root_dir, train=True, transform=to_frame)
cnn_test_dataset = NMNIST(save_to=root_dir, train=False, transform=to_frame)

# check the transformed data
sample_data, label = cnn_train_dataset[0]
print(f"The transformed array is in shape [Time-Step, Channel, Height, Width] --> {sample_data.shape}")
The transformed array is in shape [Time-Step, Channel, Height, Width] --> (1, 2, 34, 34)
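
To make the frame layout concrete, the sketch below bins a handful of events into frames by hand. It is an illustration in plain Python under our own assumptions (equal-width time bins, events given as `(x, y, t, p)` tuples sorted by time); Tonic's `ToFrame` may treat bin boundaries differently. The helper name `events_to_frames` is hypothetical.

```python
def events_to_frames(events, sensor_size, n_time_bins):
    """Accumulate (x, y, t, p) events into nested lists shaped [Time][Channel][Height][Width].

    sensor_size is (width, height, polarities); events must be sorted by t.
    """
    width, height, n_polarities = sensor_size
    frames = [
        [[[0] * width for _ in range(height)] for _ in range(n_polarities)]
        for _ in range(n_time_bins)
    ]
    t_start = events[0][2]
    duration = events[-1][2] - t_start + 1
    for x, y, t, p in events:
        # equal-width bins over the duration of the recording
        bin_index = (t - t_start) * n_time_bins // duration
        frames[bin_index][p][y][x] += 1
    return frames


toy_events = [(0, 0, 0, 1), (1, 2, 50, 0), (1, 2, 99, 0)]
frames = events_to_frames(toy_events, sensor_size=(4, 4, 2), n_time_bins=2)
print(frames[0][1][0][0], frames[1][0][2][1])
```

With `n_time_bins=1` all events land in a single frame, which is the input used for the CNN; a larger number of bins gives the raster-like input used for the SNN later on.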

Train & Test CNN#

!pip install ipywidgets
import torch
from torch.utils.data import DataLoader
from torch.optim import SGD
from tqdm.notebook import tqdm
from torch.nn import CrossEntropyLoss

epochs = 3
lr = 1e-3
batch_size = 4
num_workers = 4
device = "cuda:0"
shuffle = True

cnn = cnn.to(device=device)

cnn_train_dataloader = DataLoader(cnn_train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=shuffle)
cnn_test_dataloader = DataLoader(cnn_test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=shuffle)

optimizer = SGD(params=cnn.parameters(), lr=lr)
criterion = CrossEntropyLoss()

for e in range(epochs):

    # train
    train_p_bar = tqdm(cnn_train_dataloader)
    for data, label in train_p_bar:
        # remove the time-step axis since we are training CNN
        # move the data to accelerator
        data = data.squeeze(dim=1).to(dtype=torch.float, device=device)
        label = label.to(dtype=torch.long, device=device)
        # forward
        optimizer.zero_grad()
        output = cnn(data)
        loss = criterion(output, label)
        # backward
        loss.backward()
        optimizer.step()
        # set progressing bar
        train_p_bar.set_description(f"Epoch {e} - Training Loss: {round(loss.item(), 4)}")

    # validate
    correct_predictions = []
    with torch.no_grad():
        test_p_bar = tqdm(cnn_test_dataloader)
        for data, label in test_p_bar:
            # remove the time-step axis since the CNN takes single frames
            # move the data to accelerator
            data = data.squeeze(dim=1).to(dtype=torch.float, device=device)
            label = label.to(dtype=torch.long, device=device)
            # forward
            output = cnn(data)
            # calculate accuracy
            pred = output.argmax(dim=1, keepdim=True)
            # compute the total correct predictions
            correct_predictions.append(pred.eq(label.view_as(pred)))
            # set progressing bar
            test_p_bar.set_description(f"Epoch {e} - Testing Model...")
    
        correct_predictions = torch.cat(correct_predictions)
        print(f"Epoch {e} - accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%")
Epoch 0 - accuracy: 55.55%
Epoch 1 - accuracy: 63.949999999999996%
Epoch 2 - accuracy: 74.72999999999999%

Convert CNN To SNN#

from sinabs.from_torch import from_model

snn_convert = from_model(model=cnn, input_shape=(2, 34, 34), batch_size=batch_size).spiking_model
snn_convert
Sequential(
  (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (7): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=256, out_features=10, bias=False)
  (10): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
)

Test Converted SNN#

# define a transform that accumulates the events into a raster-like tensor
n_time_steps = 100
to_raster = ToFrame(sensor_size=NMNIST.sensor_size, n_time_bins=n_time_steps)
snn_test_dataset = NMNIST(save_to=root_dir, train=False, transform=to_raster)
snn_test_dataloader = DataLoader(snn_test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=False)

snn_convert = snn_convert.to(device)

correct_predictions = []
with torch.no_grad():
    test_p_bar = tqdm(snn_test_dataloader)
    for data, label in test_p_bar:
        # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
        data = data.reshape(-1, 2, 34, 34).to(dtype=torch.float, device=device)
        label = label.to(dtype=torch.long, device=device)
        # forward
        output = snn_convert(data)
        # reshape the output from [Batch*Time,num_classes] into [Batch, Time, num_classes]
        output = output.reshape(batch_size, n_time_steps, -1)
        # accumulate all time-steps output for final prediction
        output = output.sum(dim=1)
        # calculate accuracy
        pred = output.argmax(dim=1, keepdim=True)
        # compute the total correct predictions
        correct_predictions.append(pred.eq(label.view_as(pred)))
        # set progressing bar
        test_p_bar.set_description(f"Testing SNN Model...")

    correct_predictions = torch.cat(correct_predictions)
    print(f"accuracy of converted SNN: {correct_predictions.sum().item()/(len(correct_predictions))*100}%")
accuracy of converted SNN: 70.47%
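
The reshape, sum and argmax steps in the test loop can be spelled out without tensors. The sketch below assumes that the rows of the flattened output are ordered batch-major, so sample `b` owns rows `b * n_time_steps` to `(b + 1) * n_time_steps - 1`; this matches the `[Batch*Time, ...]` layout produced by the reshape. The function name is ours.

```python
def predict_from_flat_output(flat_output, batch_size, n_time_steps):
    """Turn a [Batch*Time][num_classes] list of spike counts into one prediction per sample."""
    assert len(flat_output) == batch_size * n_time_steps
    predictions = []
    for b in range(batch_size):
        # rows that belong to sample b
        rows = flat_output[b * n_time_steps:(b + 1) * n_time_steps]
        # accumulate the spikes of each class over all time steps
        class_totals = [sum(column) for column in zip(*rows)]
        # the class with the most output spikes is the prediction
        predictions.append(max(range(len(class_totals)), key=class_totals.__getitem__))
    return predictions


flat = [
    [0, 1, 0], [0, 1, 0], [1, 0, 0],  # sample 0, time steps 0..2
    [0, 0, 1], [1, 0, 0], [0, 0, 1],  # sample 1, time steps 0..2
]
print(predict_from_flat_output(flat, batch_size=2, n_time_steps=3))
```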

Degraded Performance After Conversion#

You might observe degraded performance after the CNN-to-SNN conversion, which might be caused by:

  • IAF neurons that stay silent after the conversion to an SNN.

  • A distribution mismatch between the outputs of the ReLU activation layers and the outputs of the spiking IAF layers.

  • Differences between synchronous and asynchronous convolution.

To mitigate this issue, we usually apply the following tricks to the CNN/SNN:

  1. Rescale the weights of the SNN’s first parameter layer, which prevents the SNN from staying silent.

  2. If trick No. 1 does not work well, sinabs provides an auxiliary function, sinabs.utils.normalize_weights, which helps to normalize the activations of the CNN. More details can be found here.

If the tricks above do not help, try training an SNN directly with BPTT. Here we focus only on how to use the devkit, not on optimizing the performance of the model.
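
The effect of trick No. 1 can be illustrated on a single neuron. The sketch below uses a simplified integrate-and-fire model (constant input, one spike per step at most, subtractive reset), and the gain value is an arbitrary number picked for the illustration, not a recommended setting.

```python
def count_spikes(input_per_step, n_steps, threshold=1.0, min_v_mem=-1.0):
    """Spike count of a simplified integrate-and-fire neuron with subtractive reset."""
    v_mem = 0.0
    spikes = 0
    for _ in range(n_steps):
        v_mem = max(v_mem + input_per_step, min_v_mem)
        if v_mem >= threshold:
            spikes += 1
            v_mem -= threshold
    return spikes


n_steps = 100
weak_drive = 1 / 256  # weighted input per step, too small to reach the threshold
gain = 8              # hypothetical rescaling factor applied to the first layer's weights

silent = count_spikes(weak_drive, n_steps)
rescaled = count_spikes(gain * weak_drive, n_steps)
print(f"spikes without rescaling: {silent}, with rescaling: {rescaled}")
```

Without rescaling, the membrane potential only reaches about 0.39 after 100 steps, so the neuron never fires and every later layer receives no input.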

Train SNN with BPTT#

Instead of using CNN-to-SNN conversion, we can train an SNN directly with BPTT. sinabs-exodus provides a CUDA-accelerated IAF neuron layer, which speeds up SNN training, so we recommend using layers from sinabs-exodus. Installation instructions for sinabs-exodus can be found here.

Define SNN#

import sinabs.layers as sl
from torch import nn
from sinabs.activation.surrogate_gradient_fn import PeriodicExponential

# just replace the ReLU layer with the sl.IAFSqueeze
snn_bptt = nn.Sequential(
    # [2, 34, 34] -> [8, 17, 17]
    nn.Conv2d(in_channels=2, out_channels=8, kernel_size=(3, 3), padding=(1, 1), bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    nn.AvgPool2d(2, 2),
    # [8, 17, 17] -> [16, 8, 8]
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3, 3), padding=(1, 1), bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    nn.AvgPool2d(2, 2),
    # [16, 8, 8] -> [16, 4, 4]
    nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=(1, 1), stride=(2, 2),  bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
    # [16 * 4 * 4] -> [10]
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 10, bias=False),
    sl.IAFSqueeze(batch_size=batch_size, min_v_mem=-1.0, surrogate_grad_fn=PeriodicExponential()),
)

# init the model weights
for layer in snn_bptt.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(layer.weight.data)

Why Disable All “Bias” Of The Convolutional Layers?#

The bias term is in fact related to the neuron’s leak mechanism on the hardware. The rate at which the neuron’s v_mem leaks is affected by an external slow clock.

More details about the bias and the neuron leak can be found in the following two docs:

  1. Bias is a no no.

  2. How to leak the neuron.

Convert To Exodus Model If Exodus Available#

try:
    from sinabs.exodus import conversion
    snn_bptt = conversion.sinabs_to_exodus(snn_bptt)
except ImportError:
    print("Sinabs-exodus is not installed.")

snn_bptt
Sequential(
  (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): EXODUS IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): EXODUS IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (7): EXODUS IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=256, out_features=10, bias=False)
  (10): EXODUS IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
)

Define SNN Training & Testing Datasets#

n_time_steps = 100
to_raster = ToFrame(sensor_size=NMNIST.sensor_size, n_time_bins=n_time_steps)

snn_train_dataset = NMNIST(save_to=root_dir, train=True, transform=to_raster)
snn_test_dataset = NMNIST(save_to=root_dir, train=False, transform=to_raster)

Train & Test SNN With BPTT#

epochs = 1
lr = 1e-3
batch_size = 4
num_workers = 4
device = "cuda:0"
shuffle = True

snn_train_dataloader = DataLoader(snn_train_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True)
snn_test_dataloader = DataLoader(snn_test_dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=False)

snn_bptt = snn_bptt.to(device=device)

optimizer = SGD(params=snn_bptt.parameters(), lr=lr)
criterion = CrossEntropyLoss()

for e in range(epochs):

    # train
    train_p_bar = tqdm(snn_train_dataloader)
    for data, label in train_p_bar:
        # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
        data = data.reshape(-1, 2, 34, 34).to(dtype=torch.float, device=device)
        label = label.to(dtype=torch.long, device=device)
        # forward
        optimizer.zero_grad()
        output = snn_bptt(data)
        # reshape the output from [Batch*Time,num_classes] into [Batch, Time, num_classes]
        output = output.reshape(batch_size, n_time_steps, -1)
        # accumulate all time-steps output for final prediction
        output = output.sum(dim=1)
        loss = criterion(output, label)
        # backward
        loss.backward()
        optimizer.step()
        
        # detach the neuron states and activations from the current computation graph (necessary)
        for layer in snn_bptt.modules():
            if isinstance(layer, sl.StatefulLayer):
                for name, buffer in layer.named_buffers():
                    buffer.detach_()
        
        # set progressing bar
        train_p_bar.set_description(f"Epoch {e} - BPTT Training Loss: {round(loss.item(), 4)}")

    # validate
    correct_predictions = []
    with torch.no_grad():
        test_p_bar = tqdm(snn_test_dataloader)
        for data, label in test_p_bar:
            # reshape the input from [Batch, Time, Channel, Height, Width] into [Batch*Time, Channel, Height, Width]
            data = data.reshape(-1, 2, 34, 34).to(dtype=torch.float, device=device)
            label = label.to(dtype=torch.long, device=device)
            # forward
            output = snn_bptt(data)
            # reshape the output from [Batch*Time,num_classes] into [Batch, Time, num_classes]
            output = output.reshape(batch_size, n_time_steps, -1)
            # accumulate all time-steps output for final prediction
            output = output.sum(dim=1)
            # calculate accuracy
            pred = output.argmax(dim=1, keepdim=True)
            # compute the total correct predictions
            correct_predictions.append(pred.eq(label.view_as(pred)))
            # set progressing bar
            test_p_bar.set_description(f"Epoch {e} - BPTT Testing Model...")
    
        correct_predictions = torch.cat(correct_predictions)
        print(f"Epoch {e} - BPTT accuracy: {correct_predictions.sum().item()/(len(correct_predictions))*100}%")
Epoch 0 - BPTT accuracy: 92.34%

To obtain an SNN with better on-chip performance, please refer to the training tips.

Convert Back To Sinabs Model If Using Exodus Model For Training#

try:
    from sinabs.exodus import conversion
    snn_bptt = conversion.exodus_to_sinabs(snn_bptt)
except ImportError:
    print("Sinabs-exodus is not installed.")

snn_bptt
Sequential(
  (0): Conv2d(2, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (7): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=256, out_features=10, bias=False)
  (10): IAFSqueeze(spike_threshold=1.0, min_v_mem=-1.0, batch_size=4, num_timesteps=-1)
)

Deploy SNN To The Devkit#

To deploy the SNN to the devkit, we use an auxiliary class DynapcnnNetwork to convert the pretrained SNN to a configuration object of the devkit.

In the example below, we use the “Speck2fModuleDevKit” as the inference device. More devkit names can be found here.

from sinabs.backend.dynapcnn import DynapcnnNetwork

# cpu_snn = snn_convert.to(device="cpu")
cpu_snn = snn_bptt.to(device="cpu")
dynapcnn = DynapcnnNetwork(snn=cpu_snn, input_shape=(2, 34, 34), discretize=True, dvs_input=False)
devkit_name = "speck2fmodule"

# use the `to` method of DynapcnnNetwork to deploy the SNN to the devkit
dynapcnn.to(device=devkit_name, chip_layers_ordering="auto")
print(f"The SNN is deployed on the core: {dynapcnn.chip_layers_ordering}")
Network is valid
The SNN is deployed on the core: [0, 1, 2, 3]

Press the “reset button” on the hardware and re-run the code block above if you encounter a “time-out error”!

Inference On The Devkit#

Our devkits take a stream of samna events as input. So before sending the raw N-MNIST events to the devkit, we first need to convert them into a stream of samna.speck2f.event.Spike events.

Notice: Different types of devkit need different event types as input. For example, if you’re using a DynapcnnDevKit, you need to use samna.dynapcnn.event.Spike as input.

import samna
from collections import Counter
from torch.utils.data import Subset

snn_test_dataset = NMNIST(save_to=root_dir, train=False)
# to save time, we only select a subset for on-chip inference, here we select 1/100 for an example run
subset_indices = list(range(0, len(snn_test_dataset), 100))
snn_test_dataset = Subset(snn_test_dataset, subset_indices)

inferece_p_bar = tqdm(snn_test_dataset)

test_samples = 0
correct_samples = 0

for events, label in inferece_p_bar:

    # create samna Spike events stream
    samna_event_stream = []
    for ev in events:
        spk = samna.speck2f.event.Spike()
        spk.x = ev['x']
        spk.y = ev['y']
        spk.timestamp = ev['t'] - events['t'][0]
        spk.feature = ev['p']
        # Spikes will be sent to layer/core #0, since the SNN is deployed on core: [0, 1, 2, 3]
        spk.layer = 0
        samna_event_stream.append(spk)

    # inference on chip
    # output_events is also a list of Spike, but each Spike.layer is 3, since layer#3 is the output layer
    output_events = dynapcnn(samna_event_stream)
    
    # use the most frequent output neuron index as the final prediction
    neuron_index = [each.feature for each in output_events]
    if len(neuron_index) != 0:
        frequent_counter = Counter(neuron_index)
        prediction = frequent_counter.most_common(1)[0][0]
    else:
        prediction = -1
    inferece_p_bar.set_description(f"label: {label}, prediction: {prediction}, output spikes num: {len(output_events)}") 

    if prediction == label:
        correct_samples += 1

    test_samples += 1
    
print(f"On chip inference accuracy: {correct_samples / test_samples}")        
On chip inference accuracy: 0.94

In-depth Investigation Of SNN Deployment Stage#

Basically, there are two key factors in deploying the SNN to the devkit and interacting with it:

  • Hardware configuration. It contains the SNN’s quantized weights, spiking thresholds, connectivity, etc. See samna’s documentation for more details about the devkit configuration.

  • Samna graph. It defines how input and output event streams flow between the devkit and the host machine. See samna’s documentation for more details about the samna graph.

In the example above, DynapcnnNetwork’s .to method implicitly calls .make_config to build a “hardware configuration” object for the devkit, and applies that configuration to the devkit before the call to .to returns.

In addition, when .to is called, it builds a simple samna graph that supports basic input writing and output reading for the devkit.

How Is The SNN Deployed To The Processor?#

The “hardware configuration” has an attribute called .cnn_layers, which contains the SNN’s quantized weights and connectivity. The weights of the SNN are quantized to int8 precision, and the membrane potential (v_mem, the neuron state) of the IAF neurons is quantized to int16 precision.
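
As a rough picture of what such a quantization involves, the sketch below scales the weights of one layer into the int8 range and scales the spiking threshold by the same factor, so that the ratio between weights and threshold is preserved. This is a simplified illustration under our own assumptions (one symmetric scale per layer, round to nearest); it is not the routine sinabs uses, and `quantize_layer` is a hypothetical helper.

```python
def quantize_layer(weights, threshold, n_bits=8):
    """Scale float weights into a signed integer range and scale the threshold alongside."""
    max_int = 2 ** (n_bits - 1) - 1  # 127 for int8
    scale = max_int / max(abs(w) for w in weights)
    q_weights = [round(w * scale) for w in weights]
    # the threshold lives in the same units as v_mem, so it gets the same scale
    q_threshold = round(threshold * scale)
    return q_weights, q_threshold, scale


q_weights, q_threshold, scale = quantize_layer([1.0, -0.25, 0.1], threshold=1.0)
print(q_weights, q_threshold, scale)

# the scaled threshold has to fit into the int16 range used for v_mem
assert -(2 ** 15) <= q_threshold <= 2 ** 15 - 1
```

Because weights and threshold share one scale, a neuron that needed four inputs of weight 0.25 to fire still needs four inputs after quantization, up to rounding error.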

Before the hardware configuration is generated, the nn.Linear layers in the model are converted to equivalent nn.Conv2d layers, and the nn.AvgPool2d layers are converted to sinabs.layers.SumPool2d. More details can be found in the FAQs.
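
The equivalence between a linear layer and a convolution can be verified by hand: a linear layer acting on a flattened `[C, H, W]` feature map computes the same values as a convolution whose kernel covers the whole feature map. The sketch below checks this with small integer tensors in plain Python; all helper names are ours, and the sizes are toy values rather than the ones used in the tutorial model.

```python
import random

C, H, W, OUT = 2, 2, 2, 3
random.seed(0)
feature_map = [[[random.randint(0, 3) for _ in range(W)] for _ in range(H)] for _ in range(C)]
linear_weight = [[random.randint(-2, 2) for _ in range(C * H * W)] for _ in range(OUT)]


def flatten(fmap):
    """Flatten [C][H][W] in channel-major order, like nn.Flatten does for one sample."""
    return [value for channel in fmap for row in channel for value in row]


def linear(flat_input, weight):
    return [sum(w * v for w, v in zip(weight_row, flat_input)) for weight_row in weight]


def linear_to_conv_kernels(weight, channels, height, width):
    """Reshape each weight row of length C*H*W into one [C][H][W] kernel."""
    return [
        [[[weight_row[c * height * width + h * width + col] for col in range(width)]
          for h in range(height)]
         for c in range(channels)]
        for weight_row in weight
    ]


def full_size_conv(fmap, kernels):
    """Convolution whose kernel spans the whole map: one output value per output channel."""
    return [
        sum(kernel[c][h][col] * fmap[c][h][col]
            for c in range(C) for h in range(H) for col in range(W))
        for kernel in kernels
    ]


kernels = linear_to_conv_kernels(linear_weight, C, H, W)
linear_out = linear(flatten(feature_map), linear_weight)
conv_out = full_size_conv(feature_map, kernels)
print(linear_out, conv_out)
```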

The current version of sinabs-dynapcnn only supports parsing nn.Sequential-like architectures. A network graph extraction feature will be integrated in a future version. If you would like to try more complex architectures, such as a “residual connection”, please refer to the FAQs.

The SNN is deployed to the devkit once the “hardware configuration” is applied to the devkit. Before the call to the .to method returns, it implicitly applies the configuration by executing self.samna_device.get_model().apply_configuration(config).

How Is Input Data Sent To The Processor?#

[Figure: input data flow]

As stated above, the samna graph controls how data flows between the devkit and the host machine. The samna graph usually has an “input buffer node” for receiving input from the host machine. After preparing a list of “samna events” (Spike, DvsEvent, ReadNeuronValue, etc.), we call the .write method of the “input buffer node”, e.g. input_buffer.write(events_list), and the entire list of “samna events” is sent to the devkit.

In addition, the devkit has a stopwatch (in the FPGA) that counts in microseconds. An event is only passed to the DynapCNN layer or DVS layer for processing once the stopwatch time is greater than or equal to the event’s timestamp. DynapcnnNetwork’s .forward method automatically restarts the stopwatch every time before the “samna events” are written to the devkit. This is why we usually shift the input events’ timestamps so that they start from 0.

If the current stopwatch time is already greater than the timestamps of all input events, all events are sent for processing immediately!
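
The stopwatch rule can be written down as a small model: an event becomes eligible for processing at the later of its own timestamp and the stopwatch time at which it was written. The sketch below is our own illustration of that rule in plain Python, with made-up timestamps; it does not talk to any hardware, and both helper names are hypothetical.

```python
def shift_to_zero(timestamps):
    """Shift timestamps (in microseconds) so that the first event is at t = 0."""
    first = timestamps[0]
    return [t - first for t in timestamps]


def processing_times(timestamps, stopwatch_at_write=0):
    """Stopwatch time at which each event is passed on for processing."""
    return [max(t, stopwatch_at_write) for t in timestamps]


raw = [1_000_250, 1_000_900, 1_003_000]  # made-up raw timestamps
shifted = shift_to_zero(raw)

print(processing_times(shifted))                           # replayed with the original timing
print(processing_times(raw))                               # first event waits about one second
print(processing_times(shifted, stopwatch_at_write=5000))  # stopwatch already past all events
```

The second call shows what happens when timestamps are not shifted: nothing is processed until the stopwatch catches up with the first raw timestamp.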

How Is Output Data Read From The Devkit?#

Similarly, the samna graph also has an “output buffer node” that lets users read output events from the devkit. By calling the .get_events method of the “output buffer node”, users obtain the output events as a list: output_list = output_buffer.get_events(). DynapcnnNetwork’s .forward method implicitly calls output_buffer.get_events() to get the output.

It is possible to set up multiple output buffer nodes in one samna graph. By placing a different type of “Filter Node” before each output buffer node, users can read different types of output events from different output buffers more efficiently.

An Advanced Example With Visualizer#

The example code below shows how to further modify the “hardware configuration” to exploit more features of the devkit, such as:

  • monitor the input events

  • monitor the hidden layers’ output spikes

Besides, we use a Visualizer to show the input event streams on a GUI window.

import samna
# first define a callback function to modify the devkit configuration
# the callback function should take exactly 1 devkit config instance as its input argument
def config_modify_callback(devkit_cfg):

    # enable visualizing the output from dvs(pre-processing) layer
    devkit_cfg.dvs_layer.monitor_enable = True
    # disable visualizing the events generated by the embedded dvs on Speck
    devkit_cfg.dvs_layer.raw_monitor_enable = False
    # prevent the events generated by the embedded dvs from being fed to the DynapCNN Core.
    devkit_cfg.dvs_layer.pass_sensor_events = False
    # point the dvs layer output destination to the core#0 
    devkit_cfg.dvs_layer.destinations[0].enable = True
    devkit_cfg.dvs_layer.destinations[0].layer = 0

    # the callback must return the modified devkit config
    return devkit_cfg

# close the devkit before reopen
samna.device.close_device(dynapcnn.samna_device)

# init DynapcnnNetwork instance
dynapcnn = DynapcnnNetwork(snn=cpu_snn, input_shape=(2, 34, 34), discretize=True, dvs_input=True)

devkit_name = "speck2fmodule"
# define which layers output you want to monitor
layers_to_monitor = [0, 1, 2, 3]
# pass the callback function into the `.to` method
dynapcnn.to(device=devkit_name, chip_layers_ordering=[0, 1, 2, 3], monitor_layers=layers_to_monitor, config_modifier=config_modify_callback)
print(f"The SNN is deployed on the core: {dynapcnn.chip_layers_ordering}")
Network is valid
/home/allan/newssd/synsense_codes/sinabs-dynapcnn/dynapcnn_venv/lib/python3.8/site-packages/sinabs/backend/dynapcnn/chips/dynapcnn.py:289: UserWarning: Layer 0 has pooling and is being monitored. Note that pooling will not be reflected in the monitored events.
  warn(
/home/allan/newssd/synsense_codes/sinabs-dynapcnn/dynapcnn_venv/lib/python3.8/site-packages/sinabs/backend/dynapcnn/chips/dynapcnn.py:289: UserWarning: Layer 1 has pooling and is being monitored. Note that pooling will not be reflected in the monitored events.
  warn(
The SNN is deployed on the core: [0, 1, 2, 3]

Instead of passing a list of layer indices to monitor_layers to enable reading the output of specific layers, you can achieve the same goal by adding some code to config_modify_callback:

def config_modify_callback(devkit_cfg):

    # enable visualizing the output from dvs(pre-processing) layer
    devkit_cfg.dvs_layer.monitor_enable = True
    # disable visualizing the events generated by the embedded dvs
    devkit_cfg.dvs_layer.raw_monitor_enable = False
    # prevent the events generated by the embedded dvs from being fed to the DynapCNN Core.
    devkit_cfg.dvs_layer.pass_sensor_events = False
    # point the dvs layer output destination to the core#0 
    devkit_cfg.dvs_layer.destinations[0].enable = True
    devkit_cfg.dvs_layer.destinations[0].layer = 0

    # **enable monitoring all layers output**
    for layer in [0, 1, 2, 3]:
        devkit_cfg.cnn_layers[layer].monitor_enable = True

    # the callback must return the modified devkit config
    return devkit_cfg

The point here is that, in most cases, exploiting a different feature of the devkit essentially comes down to modifying the devkit configuration.

Use DynapcnnVisualizer#

After the SNN is deployed to the devkit, we can use the DynapcnnVisualizer to visualize the input events.

from sinabs.backend.dynapcnn.dynapcnn_visualizer import DynapcnnVisualizer


visualizer = DynapcnnVisualizer(
    window_scale=(4, 8),
    dvs_shape=(34, 34),
    spike_collection_interval=50,
)

visualizer.connect(dynapcnn)
Connecting: Please wait until the JIT compilation is done, this might take a while. You will get notified on completion.
Set up completed!

After the visualizer is built, you should see a GUI window pop up.

[Figure: visualizer GUI window]

By running the code block below, the input events will be displayed on the GUI window.

We now start to write inputs to the devkit. In the earlier example, we used samna.speck2f.event.Spike as the input event type. Since we now need to visualize the input events that we write into the devkit, we need to use samna.speck2f.event.DvsEvent instead, because only events of this type can be captured by the visualizer. More details about “DvsEvent” can be found here.

[Figure: DvsEvent input flow]

from collections import Counter
from torch.utils.data import Subset

snn_test_dataset = NMNIST(save_to=root_dir, train=False)
# to save time, we only select a subset for on-chip inference, here we select 1/100 for an example run
subset_indices = list(range(0, len(snn_test_dataset), 100))
snn_test_dataset = Subset(snn_test_dataset, subset_indices)

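inferece_p_bar = tqdm(snn_test_dataset)

# reset the counters so that the accuracy below only covers this run
test_samples = 0
correct_samples = 0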

for events, label in inferece_p_bar:

    # instead of creating Spike events and sending them to core #0 directly, we now create DvsEvent (for visualization) and send them to the DVS layer
    # since the "config_modify_callback" points the output destination of the DVS layer to layer/core #0,
    # the DynapCNN core still receives the same input as before.
    samna_event_stream = []
    for ev in events:
        dvs_ev = samna.speck2f.event.DvsEvent()
        dvs_ev.x = ev['x']
        dvs_ev.y = ev['y']
        dvs_ev.timestamp = ev['t'] - events['t'][0]
        dvs_ev.p = ev['p']
        samna_event_stream.append(dvs_ev)

    # inference on chip
    # output_events is also a list of Spike, but .layer will have 0, 1, 2, 3 since we choose to monitor all layers' output
    output_events = dynapcnn(samna_event_stream)
    
    # get each layers output spikes
    layer0_spks = [each.feature for each in output_events if each.layer == 0]
    layer1_spks = [each.feature for each in output_events if each.layer == 1]
    layer2_spks = [each.feature for each in output_events if each.layer == 2]
    layer3_spks = [each.feature for each in output_events if each.layer == 3]
    # use the most frequent output neuron index as the final prediction
    if len(layer3_spks) != 0:
        frequent_counter = Counter(layer3_spks)
        prediction = frequent_counter.most_common(1)[0][0]
    else:
        prediction = -1
    inferece_p_bar.set_description(f"label: {label} prediction: {prediction},layer 0 output spks: {len(layer0_spks)},layer 1 output spikes num: {len(layer1_spks)}, layer 2 output spikes num: {len(layer2_spks)},layer 3 output spikes num: {len(layer3_spks)}") 

    if prediction == label:
        correct_samples += 1

    test_samples += 1
    
print(f"On chip inference accuracy: {correct_samples / test_samples}")    
On chip inference accuracy: 0.94
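
The per-layer bookkeeping in the loop above can also be done in a single pass with `Counter`. The sketch below uses a `namedtuple` as a hypothetical stand-in for the output events, modelling only the `layer` and `feature` fields that the loop reads, so it runs without a devkit; `summarize_output` is our own helper name.

```python
from collections import Counter, namedtuple

# hypothetical stand-in for an output spike event; only the two fields used above
FakeSpike = namedtuple("FakeSpike", ["layer", "feature"])


def summarize_output(output_events, output_layer=3):
    """Return the number of spikes per layer and the predicted class (-1 if no output spikes)."""
    spikes_per_layer = Counter(ev.layer for ev in output_events)
    output_features = Counter(ev.feature for ev in output_events if ev.layer == output_layer)
    prediction = output_features.most_common(1)[0][0] if output_features else -1
    return spikes_per_layer, prediction


toy_output = [
    FakeSpike(layer=0, feature=5),
    FakeSpike(layer=0, feature=2),
    FakeSpike(layer=1, feature=7),
    FakeSpike(layer=3, feature=4),
    FakeSpike(layer=3, feature=4),
    FakeSpike(layer=3, feature=9),
]
spikes_per_layer, prediction = summarize_output(toy_output)
print(dict(spikes_per_layer), prediction)
```

A `Counter` returns 0 for layers that produced no spikes, so the progress-bar message can index it for every monitored layer without extra checks.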

Yay! Success. You have completed all steps!