Utility Function

Utility Function

Utility function for plots in the documentation.

import matplotlib.pyplot as plt
import torch

import sinabs


def plot_evolution(neuron_model: torch.nn.Module, input: torch.Tensor) -> None:
    """Plot the recorded traces and the output of a neuron layer.

    The layer's states are reset, the layer is evaluated once on ``input``,
    and a new 10x3 inch figure is drawn containing:

    * every entry of ``neuron_model.recordings``, labelled by its key,
    * the layer output, drawn in black,
    * a line at ``neuron_model.spike_threshold``, unless a trace with the key
      ``"spike_threshold"`` was already recorded.

    The figure is left as the current matplotlib figure; nothing is returned
    and ``plt.show()`` is not called here.

    Args:
        neuron_model: Layer to simulate. It must be callable on ``input`` and
            provide ``reset_states()`` and a ``recordings`` mapping of
            tensors; ``spike_threshold`` is read only when it was not
            recorded. NOTE(review): annotated as ``torch.nn.Module`` on the
            assumption that sinabs layers derive from it -- confirm.
        input: Tensor passed to the layer. ``input.shape[1]`` sets the length
            of the threshold line, so dimension 1 is assumed to be time --
            TODO confirm against the callers.
    """
    # Start from a clean state so repeated calls produce comparable plots.
    neuron_model.reset_states()
    output = neuron_model(input)

    plt.figure(figsize=(10, 3))
    # Every trace is flattened to 1-D before plotting; this presumably expects
    # a single sample and a single neuron, otherwise values get concatenated.
    for key, recording in neuron_model.recordings.items():
        plt.plot(recording.detach().flatten(), drawstyle="steps", label=key)
    plt.plot(
        output.detach().flatten(), drawstyle="steps", color="black", label="output"
    )
    # Only draw the threshold here when it was not already plotted as one of
    # the recorded traces above.
    if "spike_threshold" not in neuron_model.recordings:
        plt.plot(
            [neuron_model.spike_threshold] * input.shape[1], label="spike_threshold"
        )
    plt.xlabel("time")
    plt.title(f"{neuron_model.__class__.__name__} neuron dynamics")
    plt.legend()