Add custom hooks to monitor network properties#

As shown in the corresponding how-to, Sinabs provides functions to monitor network activity such as synaptic operations and firing rates. For this it uses the hook mechanism of PyTorch modules, which makes it easy to monitor custom statistics by writing our own hooks. In this how-to we will see how to use this mechanism to keep track of the number of neurons and the shapes of the spiking layers in our network.

Setup and network definition#

Let’s start by importing all necessary packages and by setting up a simple SNN in sinabs.

from typing import Any, List
import torch
from torch import nn
from sinabs import layers as sl
import sinabs.hooks

# - Define SNN
class SNN(nn.Sequential):
    def __init__(self, batch_size):
        super().__init__(
            sl.FlattenTime(),
            nn.Conv2d(1, 16, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(16, 32, 5, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.SumPool2d(2),
            nn.Conv2d(32, 120, 4, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            nn.Flatten(),
            nn.Linear(120, 10, bias=False),
            sl.IAFSqueeze(batch_size=batch_size),
            sl.UnflattenTime(batch_size=batch_size),
        )

batch_size = 5
snn = SNN(batch_size=batch_size)

Set up hook#

Now let’s define a hook that captures the shape of our spiking layers. It will do so by looking at the shape of the output of the layers.

In general, a hook is a function with the following signature:
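
def hook(module: nn.Module, input_: List[Any], output: Any):
    ...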

It has three parameters: the first is the module to which the hook is registered; the other two are a list of all inputs to the layer and the output of the layer. Once registered with a PyTorch module, the hook is executed every time after the forward method of that module has been called.

Hook definition#

Let’s define our custom hook:

# Define hook
def monitor_shape(module: nn.Module, input_: List[Any], output: Any):
    batch_size, *neuron_shape = output.shape
    hook_data = sinabs.hooks.get_hook_data_dict(module)
    hook_data["batch_size"] = batch_size
    hook_data["neuron_shape"] = neuron_shape
    hook_data["num_neurons"] = output[0].numel()

The hook calls the get_hook_data_dict function from the sinabs.hooks module. This convenience function checks whether the module already has a hook_data attribute. If so, it returns it; otherwise it creates an empty dictionary, assigns it to that attribute, and returns it.
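
Conceptually, get_hook_data_dict behaves roughly like the following sketch (an illustration based on the description above, not the actual implementation):

def get_hook_data_dict(module: nn.Module) -> dict:
    # Create the hook_data dictionary on first use, then reuse it
    if not hasattr(module, "hook_data"):
        module.hook_data = {}
    return module.hook_data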

We then extract the information we need from the layer output and write it into the hook_data dictionary. The layer input is ignored in this case.

Note that in principle a hook can do pretty much anything, and we are not required to use a hook_data attribute. However, it is convenient to keep some consistency between different hooks.

Hook registration#

To register the hook, we call the register_forward_hook method that every PyTorch Module (including all Sinabs layers) provides. We could therefore register our hook with any layer of the network, but in this case we are only interested in the spiking layers.

# Register hooks
for layer in snn:
    if isinstance(layer, sl.IAFSqueeze):
        layer.register_forward_hook(monitor_shape)

Now our hook will be called automatically with each forward call of the spiking layers; there is no need to call it manually. The hook then updates the data in the hook_data dictionary of these layers, where we can access it. Let's run some data through the network to see if everything works as expected.

# Forward pass
rand_input_spikes = torch.ones((batch_size, 10, 1, 28, 28)).float()
snn(rand_input_spikes)

# Access and print hook data
for idx, layer in enumerate(snn):
    if hasattr(layer, "hook_data"):
        print(f"Layer {idx}:")
        print(f"\tBatch size: {layer.hook_data['batch_size']}")
        print(f"\tShape: {layer.hook_data['neuron_shape']} - {layer.hook_data['num_neurons']} neurons in total")
Layer 2:
	Batch size: 50
	Shape: [16, 24, 24] - 9216 neurons in total
Layer 5:
	Batch size: 50
	Shape: [32, 8, 8] - 2048 neurons in total
Layer 8:
	Batch size: 50
	Shape: [120, 1, 1] - 120 neurons in total
Layer 11:
	Batch size: 50
	Shape: [10] - 10 neurons in total

You might be surprised that the batch size is recorded as 50 and not 5 as defined above. This is because in Sinabs the batch and time dimensions are usually flattened into one, so that the shape is compatible with the 2D operations of PyTorch, such as Conv2d. This is done by the FlattenTime layer at the beginning of the network and undone by the UnflattenTime layer at the end. In between, batch and time dimensions are only separated internally within the IAFSqueeze layers. To the hook it therefore looks as if the batch size were multiplied by the number of time steps (5 × 10 = 50).
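
We can see this directly by passing a dummy tensor through a FlattenTime layer on its own (assuming the (batch, time, channels, height, width) input layout used above):

dummy = torch.zeros(batch_size, 10, 1, 28, 28)  # (batch, time, channels, height, width)
print(sl.FlattenTime()(dummy).shape)  # torch.Size([50, 1, 28, 28]) -> batch * time = 50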

Are there backward hooks?#

You may have noticed that we are using forward hooks in this example and might be wondering whether there are also backward hooks. The answer is yes: Sinabs layers are essentially PyTorch Modules and therefore support both forward and backward hooks. To learn more about hooks, check out the PyTorch documentation.
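
For example, a backward hook with a similar structure could record the gradient norm at the output of each spiking layer. This is a minimal sketch for illustration; monitor_grad_norm is not part of Sinabs:

# Sketch of a backward hook: record the norm of the gradient w.r.t. the layer output
def monitor_grad_norm(module, grad_input, grad_output):
    hook_data = sinabs.hooks.get_hook_data_dict(module)
    hook_data["output_grad_norm"] = grad_output[0].norm().item()

for layer in snn:
    if isinstance(layer, sl.IAFSqueeze):
        layer.register_full_backward_hook(monitor_grad_norm)

As with forward hooks, the recorded data only gets updated once the corresponding pass has been run, in this case a backward pass through the network.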