Training by backpropagation through time (BPTT)

BPTT is normally used to train recurrent neural networks. A spiking network, even when it has no recurrent connections, retains a memory of its previous processing steps through the persistence of its membrane potentials. Unlike conventional (non-spiking) networks, spiking networks have an internal state that persists over time.

This is why BPTT can be used for more precise (but also much more computationally expensive) training of SNNs on sequential tasks. In sinabs, backpropagation through the spiking network relies on a surrogate gradient method, since the spiking nonlinearity is not differentiable.
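To make the surrogate gradient idea concrete, here is a minimal sketch in plain PyTorch. It is not the sinabs implementation: the class name `SurrogateSpike`, the threshold value, and the choice of a fast-sigmoid derivative as the surrogate are all assumptions made for illustration. The forward pass applies the non-differentiable step function, while the backward pass substitutes a smooth stand-in for its derivative.

```python
import torch


class SurrogateSpike(torch.autograd.Function):
    """Step function in the forward pass, smooth surrogate in the backward pass.

    Illustrative only; the surrogate shape is an assumption, not sinabs' choice.
    """

    @staticmethod
    def forward(ctx, v_mem, threshold):
        ctx.save_for_backward(v_mem)
        ctx.threshold = threshold
        # Forward: emit a spike (1.0) wherever the membrane potential
        # reaches the threshold, and 0.0 elsewhere.
        return (v_mem >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v_mem,) = ctx.saved_tensors
        # Backward: the true derivative is zero almost everywhere, so we use
        # the derivative of a fast sigmoid centred on the threshold instead.
        surrogate = 1.0 / (1.0 + torch.abs(v_mem - ctx.threshold)) ** 2
        # The threshold is a plain float, so it receives no gradient.
        return grad_output * surrogate, None


v_mem = torch.tensor([-0.5, 0.9, 1.0, 2.0], requires_grad=True)
spikes = SurrogateSpike.apply(v_mem, 1.0)
spikes.sum().backward()

print(spikes)      # binary spike pattern
print(v_mem.grad)  # non-zero gradients, largest near the threshold
```

With this substitution, gradients are largest for membrane potentials close to the threshold and decay smoothly away from it, which is what lets the optimizer adjust weights even though the spikes themselves are binary.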

In this notebook, we train a spiking network directly (without training an analog network first) on the Sequential MNIST task. In Sequential MNIST, the network is shown each 28x28-pixel MNIST digit one row at a time: the input is the first row of 28 pixels, then the second, and so on, until all 28 rows have been shown. Only then does the network predict the digit label.

First, we define the MNIST dataset. Pixel values are scaled to lie between 0 and 1, and each value is then used as the probability that the corresponding input emits a spike.
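As a small illustration of this conversion, the sketch below samples spikes for three hypothetical pixel intensities and checks that the empirical firing rate approaches the intensity. The pixel values and the number of samples are made up for the example.

```python
import torch

torch.manual_seed(0)

# Hypothetical pixel intensities, already scaled to [0, 1].
pixels = torch.tensor([0.0, 0.25, 1.0])

# Draw 10,000 independent samples per pixel: a spike occurs whenever a
# uniform random number in [0, 1) falls below the pixel intensity.
n_samples = 10_000
spikes = (torch.rand(n_samples, pixels.numel()) < pixels).float()

# The empirical firing rate approximates the pixel intensity.
rates = spikes.mean(dim=0)
print(rates)
```

A black pixel (0.0) never spikes, a white pixel (1.0) always spikes, and intermediate values spike with a matching probability. The dataset class that follows applies the same comparison to whole images, drawing one random sample per pixel.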

from torchvision import datasets
import torch

torch.manual_seed(0)


class MNIST(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        datasets.MNIST.__init__(self, root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.0

        # default is by row: output is [time, channels] = [28, 28]
        # alternatively, pixel by pixel: output is [time, channels] = [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)

        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()

        return spikes, target


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 64

dataset_test = MNIST(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True
)

dataset = MNIST(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)

Training a baseline

To show that the task genuinely requires memory, we first check that a memory-less analog network, which sees each row in isolation, cannot reach similar accuracy. Let us train such a baseline.

from torch import nn

ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU(),
)

Training

from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(2):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()

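        # the baseline has no memory: every row is treated as an independent
        # sample that carries the label of the image it came from
        target = target.unsqueeze(1).repeat([1, 28])
        img = img.reshape([-1, 28])
        target = target.reshape([-1])

        out = ann(img)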
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

Testing

accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    img = img.reshape([-1, 28])  # every row is an independent sample
    out = ann(img)
    out = out.reshape([BATCH_SIZE, 28, 10])
    # sum the per-row outputs, then pick the class with the highest total
    out = out.sum(1)

    predicted = out.argmax(dim=1)
    acc = (predicted == target).sum().numpy() / BATCH_SIZE
    accs.append(acc)

print(sum(accs) / len(accs))
0.4560296474358974

Defining a spiking network

We then define a fully connected spiking neural network with five linear layers, built by converting the analog architecture above.

from sinabs.from_torch import from_model

model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()

Training

Here, we begin training. Note that the state of the network must be reset at every iteration, so that membrane potentials left over from one batch do not carry into the next.
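The effect of skipping the reset can be seen on a single neuron. The sketch below is a toy, non-leaky integrate-and-fire neuron with a subtractive reset, written in plain Python; the function `run_neuron` and its parameters are hypothetical and are not the sinabs neuron model. It presents the same input sequence twice, once with leftover state and once from rest.

```python
def run_neuron(inputs, v_mem=0.0, threshold=1.0):
    """Integrate inputs step by step.

    Returns the list of spikes and the final membrane potential.
    """
    spikes = []
    for current in inputs:
        v_mem += current            # integrate the input
        if v_mem >= threshold:      # fire, then subtract the threshold
            spikes.append(1)
            v_mem -= threshold
        else:
            spikes.append(0)
    return spikes, v_mem


sample = [0.5, 0.5, 0.5]

# First presentation, starting from rest.
first, leftover = run_neuron(sample)

# Same sample again without a reset: the leftover potential shifts the spikes.
no_reset, _ = run_neuron(sample, v_mem=leftover)

# Same sample again with a reset: the output matches the first presentation.
with_reset, _ = run_neuron(sample, v_mem=0.0)

print(first, no_reset, with_reset)
```

Without the reset, the response to a sample depends on whatever was shown before it, so consecutive batches would no longer be independent training examples.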

PyTorch convolutional layers don’t support inputs that aren’t 4-dimensional (batch, channels, height, width).

As a workaround, when using sinabs, you have to merge the time and batch dimensions. Starting from data in the form (batch, time, channels, …), the data should be reshaped to (batch*time, channels, …).

The spiking layers automatically unpack that merged dimension into batch and time, provided this convention is followed exactly and the batch_size=... parameter is set correctly for the spiking model.
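The reshape convention can be checked on a small tensor. The sketch below uses made-up sizes and only plain PyTorch reshapes; it does not call sinabs. It shows where each time step of each sample ends up after merging, and that the inverse reshape restores the original layout.

```python
import torch

BATCH, TIME, CHANNELS = 2, 3, 4  # small hypothetical sizes for illustration

# Data in the form (batch, time, channels), filled with distinct values.
x = torch.arange(BATCH * TIME * CHANNELS).reshape(BATCH, TIME, CHANNELS)

# Merge batch and time into a single leading dimension: (batch*time, channels).
merged = x.reshape(-1, CHANNELS)
print(merged.shape)

# Row b*TIME + t of the merged tensor is time step t of sample b, so all
# time steps of one sample stay contiguous and in order.
b, t = 1, 2
assert torch.equal(merged[b * TIME + t], x[b, t])

# Restoring the original layout is the inverse reshape.
restored = merged.reshape(BATCH, TIME, CHANNELS)
assert torch.equal(restored, x)
```

Because the batch index varies slowest, a layer that knows the batch size can recover the time dimension unambiguously, which is why the batch size has to match the data exactly.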

Here, we can see that the input dimensionality is (64, 28, 28) = (batch, time, channels):

for img in dataloader:
    print(img[0].shape)
    break
torch.Size([64, 28, 28])

For this reason, the training loop includes a reshape((-1, 28)) on the input, to merge the batch and time dimensions, and a reshape((BATCH_SIZE, 28, 10)) on the output, to restore the original form.

from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(1):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()

        img = img.reshape((-1, 28))  # merging time and batch dimensions
        out = model.spiking_model(img.to(device))
        out = out.reshape((BATCH_SIZE, 28, 10))  # restoring original dimensions

        # the output of the network is summed over the 28 time steps (rows)
        out = out.sum(1)
        loss = criterion(out, target.to(device))
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

Testing

accs = []

pbar = tqdm(dataloader_test)
for img, target in pbar:
    model.reset_states()

    img = img.reshape((-1, 28))  # merging time and batch dimensions
    out = model.spiking_model(img.to(device))
    out = out.reshape((BATCH_SIZE, 28, 10))  # restoring original dimensions

    # sum the output spikes over the 28 time steps, then pick the most active class
    out = out.sum(1)
    predicted = out.argmax(dim=1)
    acc = (predicted == target.to(device)).sum().cpu().numpy() / BATCH_SIZE
    accs.append(acc)

print(sum(accs) / len(accs))
0.7990785256410257

This accuracy, although not very high, is well above the roughly 46% reached by the memory-less baseline. As a proof of concept, it shows that the persistent state of the spiking network (the membrane potentials) can be exploited as a short-term memory for solving sequential tasks, provided the training procedure takes it into account.