Generating rate-based spike trains from normalised floating-point values
We start by importing all relevant libraries.
import torch
In practice, the inputs to spiking neural networks (SNNs) are binary: at each time step, a 1 means the input neuron emits a spike and a 0 means it stays silent.
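As a small illustration with made-up values, a spike train for three input neurons over four time steps can be stored as a tensor of shape `(time_steps, neurons)`. Averaging over the time axis gives each neuron's firing rate, which is the quantity that rate coding matches to a pixel intensity. The variable names here are hypothetical and only serve this example.

```python
import torch

# Hypothetical spike train: 4 time steps (rows) x 3 input neurons (columns).
# A 1 means the neuron spikes at that time step, a 0 means it stays silent.
spikes = torch.tensor([
    [1., 0., 0.],
    [0., 0., 1.],
    [1., 0., 1.],
    [0., 0., 1.],
])

# Firing rate = number of spikes divided by the number of time steps
rates = spikes.mean(dim=0)
print(rates)  # tensor([0.5000, 0.0000, 0.7500])
```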
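Below is a simple method that converts a static image into spike trains using rate coding: each pixel value, assumed to be normalised to [0, 1], is used as the probability of a spike at every time step, so averaging the spikes over the time window recovers the pixel value. The longer the time window, the more accurate this reconstruction becomes, as the decreasing L2 distance between the original and the reconstructed image shows.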
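Why the distance shrinks can be made precise. In the sketch below, the symbols are introduced here and are not part of the original notebook: p_i is the intensity of pixel i, T is `time_win`, s_i(t) is the spike of pixel i at step t, and N is the number of pixels. Under the assumption that spikes are drawn independently at each step, the reconstruction is an average of T Bernoulli samples, so its squared error falls as 1/T:

```latex
\begin{align*}
\hat{p}_i &= \frac{1}{T}\sum_{t=1}^{T} s_i(t), \qquad s_i(t) \sim \mathrm{Bernoulli}(p_i) \\
\mathbb{E}[\hat{p}_i] &= p_i, \qquad \operatorname{Var}[\hat{p}_i] = \frac{p_i(1-p_i)}{T} \\
\mathbb{E}\left[d^2\right] &= \sum_{i=1}^{N} \mathbb{E}\left[(\hat{p}_i - p_i)^2\right] = \frac{1}{T}\sum_{i=1}^{N} p_i(1-p_i) \\
p_i \sim \mathcal{U}(0,1) &\;\Rightarrow\; \mathbb{E}[p_i(1-p_i)] = \tfrac{1}{2} - \tfrac{1}{3} = \tfrac{1}{6}
\;\Rightarrow\; d \approx \sqrt{\frac{N}{6T}}
\end{align*}
```

For the uniformly random 64 × 64 image used in this notebook (N = 4096), this predicts distances of roughly 8.26, 2.61 and 0.83 for T = 10, 100 and 1000. In other words, a tenfold longer window reduces the error by about a factor of √10 ≈ 3.16.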
def get_spike_train(time_win, input_image):
    # input_image: tensor of shape (1, 64, 64) with values normalised to [0, 1]
    # Draw time_win uniform random numbers in [0, 1) for every pixel
    random_tensor = torch.rand(time_win, *input_image.shape)
    # Emit a spike (1) wherever the random number is below the pixel value,
    # so each pixel fires with probability equal to its intensity
    converted_spike_train = (random_tensor < input_image).float()
    # img_original is the 64x64 input image without its channel dimension
    img_original = input_image[0]
    # img_converted is the spike count over the window divided by time_win (the firing rate)
    img_converted = converted_spike_train.sum(0)[0] / time_win
    # L2 distance between the original and the reconstructed image
    dist = torch.dist(img_original, img_converted, 2).item()
    print("L2 distance between original image and converted spike trains: ", dist)
    return converted_spike_train
# Define a random image
input_image = torch.rand(1, 64, 64)
# Longer time_win results in more precise conversion
time_win_list = [10, 100, 1000]
for time_win in time_win_list:
    get_spike_train(time_win, input_image)
Example output (the exact values differ between runs because both the image and the spikes are random):
L2 distance between original image and converted spike trains: 8.164995193481445
L2 distance between original image and converted spike trains: 2.593313694000244
L2 distance between original image and converted spike trains: 0.8180820941925049
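The same rate coding can be written more compactly with `torch.bernoulli`, which samples 0/1 values with the probabilities given in its input tensor. This is a minimal sketch that assumes the input is already normalised to [0, 1]. The helper name `bernoulli_spike_train` is hypothetical and not part of the original notebook.

```python
import torch

def bernoulli_spike_train(time_win, input_image):
    """Hypothetical helper: rate-code input_image (values in [0, 1]) into
    time_win binary frames of shape (time_win, *input_image.shape)."""
    # Repeat the image along a new leading time axis, then materialise it as a regular tensor
    probs = input_image.expand(time_win, *input_image.shape).contiguous()
    # Each element becomes 1 with probability equal to the pixel intensity
    return torch.bernoulli(probs)

torch.manual_seed(0)
image = torch.rand(1, 64, 64)
spikes = bernoulli_spike_train(1000, image)
# The firing rate per pixel approximates the original intensity
reconstruction = spikes.mean(0)
print(spikes.shape)  # torch.Size([1000, 1, 64, 64])
print(torch.dist(image, reconstruction).item())
```

Because a spike is emitted with probability equal to the pixel value in both versions, this produces spike trains with the same distribution as `(random_tensor < input_image).float()` in `get_spike_train`. It simply moves the comparison into a single library call.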