Training by backpropagation through time (BPTT)
BPTT is normally used to train recurrent neural networks. In spiking networks, however, even a network without recurrent connections has a memory of its previous processing steps, because membrane potentials persist from one time step to the next. Unlike conventional artificial neural networks, spiking networks therefore have an internal state that lasts in time.
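To make the idea of a persistent state concrete, here is a minimal sketch of a single non-leaky integrate-and-fire neuron in plain Python. The threshold of 1.0 and the reset by subtraction are assumptions chosen for illustration, not necessarily the sinabs defaults. The last input is identical in both calls, yet the output differs, because the membrane potential remembers what came before.

```python
def iaf_neuron(inputs, threshold=1.0):
    """Simulate one integrate-and-fire neuron and return its spike train.

    Assumptions for illustration: no leak, reset by subtraction.
    """
    v = 0.0  # membrane potential: the neuron's persistent state
    spikes = []
    for x in inputs:
        v += x  # integrate the input current
        if v >= threshold:
            spikes.append(1)
            v -= threshold  # reset by subtraction
        else:
            spikes.append(0)
    return spikes


# Same final input (0.5), different history, different output at the last step.
print(iaf_neuron([0.0, 0.5]))  # [0, 0]
print(iaf_neuron([0.6, 0.5]))  # [0, 1]
```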
This is why BPTT can be used for more precise (but also much more computationally expensive) training of SNNs on sequential tasks. In sinabs, backpropagation through the spiking network relies on a surrogate gradient method, because the spiking nonlinearity is not differentiable.
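The sketch below shows what a surrogate gradient combined with BPTT looks like in plain PyTorch. It is a generic illustration, not the sinabs implementation. The hypothetical `SpikeFn` applies a hard threshold in the forward pass and, in the backward pass, substitutes the derivative of a fast sigmoid; that shape and its scale factor of 10 are assumptions. After the neuron is unrolled over three time steps, calling `backward()` on the total spike count gives the first input a gradient larger than its direct surrogate term. The difference comes from the path through the membrane potential to the spike at the next step.

```python
import torch


class SpikeFn(torch.autograd.Function):
    """Heaviside spike forward, smooth surrogate derivative backward (a sketch)."""

    @staticmethod
    def forward(ctx, v, threshold):
        ctx.save_for_backward(v)
        ctx.threshold = threshold
        return (v >= threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        (v,) = ctx.saved_tensors
        # Assumption: fast-sigmoid surrogate centred on the threshold, scale 10.
        surrogate = 1.0 / (1.0 + 10.0 * (v - ctx.threshold).abs()) ** 2
        return grad_output * surrogate, None  # no gradient for the threshold


def iaf_unrolled(x, threshold=1.0):
    """Run an IAF neuron over x of shape (time,), keeping the graph for BPTT."""
    v = torch.zeros(())
    spikes = []
    for t in range(x.shape[0]):
        v = v + x[t]  # membrane state carried across time steps
        s = SpikeFn.apply(v, threshold)
        v = v - s * threshold  # reset by subtraction (also differentiated)
        spikes.append(s)
    return torch.stack(spikes)


x = torch.tensor([0.6, 0.5, 0.0], requires_grad=True)
spikes = iaf_unrolled(x)
spikes.sum().backward()  # backpropagation through time over all 3 steps
print(spikes.detach().tolist())  # [0.0, 1.0, 0.0]
print(x.grad)  # x.grad[0] exceeds its direct term: gradient also flows via v
```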
In this notebook, we train a spiking network directly (without training an analog network first) on the Sequential MNIST task. In Sequential MNIST, the network is shown each 28x28-pixel MNIST digit one row at a time: the input at each time step is a single row of 28 pixels, starting from the first row, until all 28 rows have been shown. Only then does the network make a prediction of the digit label.
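In tensor terms, presenting a digit row by row needs no special processing: a 28x28 image read along its first axis already is a sequence of 28 time steps with 28 input channels each. The sketch below uses a random tensor as a stand-in for a real digit and also shows the pixel-by-pixel layout produced by the `single_channel` option of the dataset class defined next.

```python
import torch

# Random stand-in for one 28x28 MNIST digit with values in [0, 1].
img = torch.rand(28, 28)

# Row-by-row presentation: the input at time step t is row t (28 channels).
rows = [img[t] for t in range(img.shape[0])]
sequence = torch.stack(rows)  # (time, channels) = (28, 28)
assert torch.equal(sequence, img)  # the image already is the sequence

# Pixel-by-pixel variant (the single_channel layout): 784 steps, 1 channel each.
pixel_sequence = img.reshape(-1).unsqueeze(1)  # (784, 1)
```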
First, we define the MNIST dataset. Pixel values are rescaled from 0 to 255 into the range [0, 1], and each value is then used as the probability that the corresponding input emits a spike.
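The spike encoding described above is one Bernoulli draw per pixel: a pixel of intensity p spikes when a uniform random number falls below p. This sketch uses made-up intensities and 10,000 draws per pixel to check that the empirical firing rate approaches p. Because the comparison is strict, a pixel at 0 never spikes and a pixel at 1 always does.

```python
import torch

torch.manual_seed(0)

# Made-up pixel intensities, already rescaled to [0, 1].
p = torch.tensor([0.0, 0.25, 0.5, 1.0])

# 10,000 independent draws per pixel: a spike occurs when rand < p.
n_draws = 10_000
spikes = (torch.rand(n_draws, p.shape[0]) < p).float()

# Empirical firing rate per pixel; it should be close to p.
rates = spikes.mean(dim=0)
print(rates)
```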
```python
import torch
from torchvision import datasets

torch.manual_seed(0)


class MNIST(datasets.MNIST):
    def __init__(self, root, train=True, single_channel=False):
        super().__init__(root, train=train, download=True)
        self.single_channel = single_channel

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = img.float() / 255.0  # rescale pixel values to [0, 1]
        # default is by row, output is [time, channels] = [28, 28]
        # OR if we want by single item, output is [784, 1]
        if self.single_channel:
            img = img.reshape(-1).unsqueeze(1)
        # each pixel spikes with probability equal to its intensity
        spikes = torch.rand(size=img.shape) < img
        spikes = spikes.float()
        return spikes, target
```
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64

dataset_test = MNIST(root="./data/", train=False)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=BATCH_SIZE, drop_last=True
)

dataset = MNIST(root="./data/", train=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)
```
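Both loaders use `drop_last=True`. A likely reason is that the spiking model is later built with a fixed `batch_size`, and its outputs are reshaped with `BATCH_SIZE`, so a smaller final batch would not fit. The hypothetical helper below counts how many samples this setting discards from the standard MNIST splits of 60,000 training and 10,000 test images.

```python
BATCH_SIZE = 64


def batch_split(n_samples, batch_size):
    """Hypothetical helper: (full batches, samples dropped by drop_last=True)."""
    return divmod(n_samples, batch_size)


for split, n_samples in [("train", 60_000), ("test", 10_000)]:
    full, dropped = batch_split(n_samples, BATCH_SIZE)
    print(f"{split}: {full} full batches, {dropped} samples dropped")
# train: 937 full batches, 32 samples dropped
# test: 156 full batches, 16 samples dropped
```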
Training a baseline
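Although the task is sequential, we should first check that it cannot be solved with similar accuracy by a memory-less analog network. Let us therefore train such a baseline: a feed-forward ANN that classifies each row independently, with its per-row outputs summed over the 28 rows at test time.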
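The baseline sees one row at a time, so it needs a label for each row and a rule for combining its per-row answers. The sketch below mirrors the reshapes in the baseline training and evaluation code, using a toy batch of 4 made-up labels and random stand-in scores. During training, each label is repeated for all 28 rows. At test time, the row scores are summed per image before taking the arg-max.

```python
import torch

BATCH, T, N_CLASSES = 4, 28, 10  # toy batch size; the notebook uses 64
target = torch.tensor([3, 1, 4, 1])  # made-up labels, one per image

# Training: every row inherits its image's label, as in the baseline loop.
row_targets = target.unsqueeze(1).repeat([1, T]).reshape([-1])  # (BATCH*T,)

# Evaluation: random stand-ins for the per-row outputs of the ANN ...
row_scores = torch.randn(BATCH * T, N_CLASSES)
# ... are grouped back per image and summed over the 28 rows.
image_scores = row_scores.reshape([BATCH, T, N_CLASSES]).sum(1)  # (BATCH, 10)
predicted = image_scores.argmax(dim=1)  # (BATCH,)
```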
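```python
from torch import nn

# Memory-less baseline: a feed-forward ANN applied to one 28-pixel row at a time.
# Because the last layer is also followed by a ReLU, the output layer emits spikes after conversion.
ann = nn.Sequential(
    nn.Linear(28, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.ReLU(),
)
```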
```python
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters())

for epoch in range(2):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        # every row of an image is labelled with that image's digit
        target = target.unsqueeze(1).repeat([1, 28])
        # rows become independent samples: (batch, 28, 28) -> (batch*28, 28)
        img = img.reshape([-1, 28])
        target = target.reshape([-1])
        out = ann(img)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        pbar.set_postfix(loss=loss.item())
```
```python
accs = []
pbar = tqdm(dataloader_test)
with torch.no_grad():
    for img, target in pbar:
        img = img.reshape([-1, 28])
        out = ann(img)
        out = out.reshape([BATCH_SIZE, 28, 10])
        out = out.sum(1)  # aggregate the per-row outputs over all 28 rows
        predicted = out.argmax(dim=1)  # torch.max(out, 1) would return (values, indices)
        acc = (predicted == target).sum().item() / BATCH_SIZE
        accs.append(acc)
print(sum(accs) / len(accs))
```
Defining a spiking network
We then convert the same architecture into a fully connected spiking neural network: the five linear layers are kept, and each ReLU is replaced by a spiking layer.
```python
from sinabs.from_torch import from_model

# batch_size lets the spiking layers split the merged (batch*time) dimension
model = from_model(ann, batch_size=BATCH_SIZE).to(device)
model = model.train()
```
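As we understand `from_model`, it returns a copy of the model in which the ReLU activations are turned into spiking layers. The sketch below mimics that idea on a small `nn.Sequential`. The helper `convert_relu_layers` and the module `ThresholdSpike` are hypothetical and not part of sinabs. For brevity, the placeholder is stateless and has no surrogate gradient, unlike a real spiking layer.

```python
import copy

import torch
from torch import nn


class ThresholdSpike(nn.Module):
    """Hypothetical placeholder for a spiking layer: emits 1 where x >= threshold.

    Unlike a real spiking layer, it keeps no membrane potential between calls.
    """

    def __init__(self, threshold=1.0):
        super().__init__()
        self.threshold = threshold

    def forward(self, x):
        return (x >= self.threshold).float()


def convert_relu_layers(ann: nn.Sequential) -> nn.Sequential:
    """Return a copy of `ann` with every nn.ReLU swapped for ThresholdSpike."""
    layers = [
        ThresholdSpike() if isinstance(layer, nn.ReLU) else copy.deepcopy(layer)
        for layer in ann
    ]
    return nn.Sequential(*layers)


small_ann = nn.Sequential(nn.Linear(28, 16), nn.ReLU(), nn.Linear(16, 10), nn.ReLU())
snn = convert_relu_layers(small_ann)
print(snn)  # linear weights are copied unchanged; only the activations differ
```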
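Here, we begin training. The state of the network must be reset at every iteration. Otherwise, the membrane potentials left over from the previous batch would carry into the next one, even though each digit is an independent sequence.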
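A toy example shows what happens when the reset is skipped. The hypothetical `StatefulIAF` class below keeps its membrane potential between calls and exposes a `reset_states` method named after the one used in this notebook. It is not the sinabs class, and its threshold of 1.0 and reset by subtraction are assumptions. Feeding the same two-step sequence twice without a reset produces a spurious spike caused by the charge left over from the first sequence.

```python
import torch


class StatefulIAF:
    """Hypothetical IAF layer whose membrane potential persists between calls."""

    def __init__(self, threshold=1.0):
        self.threshold = threshold
        self.v = None

    def reset_states(self):
        self.v = None  # forget all integrated input

    def __call__(self, x):
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.v + x  # integrate
        spikes = (self.v >= self.threshold).float()
        self.v = self.v - spikes * self.threshold  # reset by subtraction
        return spikes


def run(layer, seq):
    """Feed a sequence one time step at a time and collect the spikes."""
    return [layer(x).item() for x in seq]


seq = torch.tensor([0.6, 0.0])  # toy two-step input sequence
layer = StatefulIAF()

first = run(layer, seq)  # [0.0, 0.0]; v is left at 0.6
second = run(layer, seq)  # no reset: v starts at 0.6, so [1.0, 0.0]
layer.reset_states()
third = run(layer, seq)  # after reset: [0.0, 0.0] again
```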
PyTorch layers have no notion of a time dimension. A convolutional layer, for example, expects inputs of the form (batch, channels, height, width), and a linear layer treats all leading dimensions as batch dimensions.
As a workaround, when using sinabs you have to merge the batch and time dimensions. Starting from data in the form (batch, time, channels, …), reshape it to (batch*time, channels, …).
The spiking layers automatically split this merged dimension back into batch and time, provided this convention is followed exactly and the batch_size=... parameter of the spiking model is set correctly.
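The convention can be checked with plain tensor reshapes. In this sketch, which uses random data in the shapes of this notebook, merging with `reshape(-1, C)` keeps each sample's time steps contiguous (batch-major order). Reshaping back with the known batch size therefore recovers the original tensor exactly. Merging in time-major order instead, for example after a `transpose`, interleaves samples, and splitting by batch size then scrambles them.

```python
import torch

BATCH_SIZE, T, C = 64, 28, 28
x = torch.rand(BATCH_SIZE, T, C)  # (batch, time, channels)

# Merge batch and time: rows b*T .. b*T + T-1 belong to sample b.
merged = x.reshape(-1, C)  # (batch*time, channels) = (1792, 28)
b, t = 5, 3
assert torch.equal(merged[b * T + t], x[b, t])

# Knowing the batch size is enough to split the merged dimension again.
restored = merged.reshape(BATCH_SIZE, -1, C)  # (batch, time, channels)
assert torch.equal(restored, x)

# Merging in time-major order breaks the convention: splitting by batch size
# then mixes time steps of different samples.
time_major = x.transpose(0, 1).reshape(-1, C)
assert not torch.equal(time_major.reshape(BATCH_SIZE, -1, C), x)
```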
Here, we can see that the input dimensionality is (64, 28, 28) = (batch, time, channels):
```python
# the loader yields (spikes, target) pairs, so unpack both
for img, target in dataloader:
    print(img.shape)
    break
```
torch.Size([64, 28, 28])
For this reason, the training loop includes:
- reshape((-1, 28)) on the input, to merge the batch and time dimensions;
- reshape((BATCH_SIZE, 28, 10)) on the output, to restore the original form.
```python
from tqdm.notebook import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(1):
    pbar = tqdm(dataloader)
    for img, target in pbar:
        optimizer.zero_grad()
        model.reset_states()  # clear membrane potentials from the previous batch
        img = img.reshape((-1, 28))  # merging time and batch dimensions
        out = model.spiking_model(img.to(device))
        out = out.reshape((BATCH_SIZE, 28, 10))  # restoring original dimensions
        # the output of the network is summed over the 28 time steps (rows)
        out = out.sum(1)
        loss = criterion(out, target.to(device))
        loss.backward()  # BPTT: gradients flow back through all 28 time steps
        optimizer.step()
        pbar.set_postfix(loss=loss.item())
```
```python
accs = []
pbar = tqdm(dataloader_test)
with torch.no_grad():
    for img, target in pbar:
        model.reset_states()
        img = img.reshape((-1, 28))  # merging time and batch dimensions
        out = model.spiking_model(img.to(device))
        out = out.reshape((BATCH_SIZE, 28, 10))  # restoring original dimensions
        out = out.sum(1)  # spike counts over the 28 time steps
        predicted = out.argmax(dim=1)  # torch.max(out, 1) would return (values, indices)
        acc = (predicted == target.to(device)).sum().item() / BATCH_SIZE
        accs.append(acc)
print(sum(accs) / len(accs))
```
The resulting accuracy, although not very high, serves as a proof of concept. The persistent state of the spiking network (its membrane potentials) can be exploited as a short-term memory for solving sequential tasks, provided the training procedure takes it into account, as BPTT does.