Source code for sinabs.backend.dynapcnn.config_builder

import time
from abc import ABC, abstractmethod
from typing import List

import samna

from .dvs_layer import DVSLayer
from .mapping import LayerConstraints, get_valid_mapping


class ConfigBuilder(ABC):
    """Abstract interface for building and managing a device configuration.

    Concrete subclasses supply the device-specific pieces: the samna module,
    the default configuration, layer constraints and I/O buffers. The
    non-abstract class methods defined here are built on top of those.
    """

    @classmethod
    @abstractmethod
    def get_samna_module(cls):
        """Get the samna parent module that hosts all the appropriate
        sub-modules and classes.

        Returns
        -------
        samna module
        """

    @classmethod
    @abstractmethod
    def get_default_config(cls):
        """
        Returns
        -------
        Returns the default configuration for the device type
        """

    @classmethod
    @abstractmethod
    def build_config(cls, model: "DynapcnnNetwork", chip_layers: List[int]):
        """Build the configuration given a model.

        Parameters
        ----------
        model:
            The target model
        chip_layers:
            Chip layers where the given model layers are to be mapped.

        Returns
        -------
        Samna Configuration object
        """

    @classmethod
    @abstractmethod
    def get_constraints(cls) -> List[LayerConstraints]:
        """Return the layer constraints of the given device.

        Returns
        -------
        List[LayerConstraints]
        """

    @classmethod
    @abstractmethod
    def monitor_layers(cls, config, layers: List[int]):
        """Enable the monitor for a given set of layers in the config object."""

    @classmethod
    def get_valid_mapping(cls, model: "DynapcnnNetwork") -> List[int]:
        """Find a valid set of layers for a given model.

        Parameters
        ----------
        model (DynapcnnNetwork):
            A model

        Returns
        -------
        List of core indices corresponding to each layer of the model:
            The index of the core on chip to which the i-th layer in the
            model is mapped is the value of the i-th entry in the list.
        """
        mapping = get_valid_mapping(model, cls.get_constraints())
        # Turn the sequence of (layer index, core index) pairs into a dict.
        mapping = {m[0]: m[1] for m in mapping}

        # A DVSLayer at the head of the sequence is not counted among the
        # layers that receive a core assignment.
        num_dynapcnn_cores = len(model.sequence)
        if isinstance(model.sequence[0], DVSLayer):
            num_dynapcnn_cores -= 1

        # Apply the mapping in layer order.
        chip_layers_ordering = [mapping[i] for i in range(num_dynapcnn_cores)]
        return chip_layers_ordering

    @classmethod
    def validate_configuration(cls, config) -> bool:
        """Check if a given configuration is valid.

        The check is delegated to ``validate_configuration`` of the samna
        module returned by :meth:`get_samna_module`; if the configuration is
        invalid, the message it reports is printed.

        Parameters
        ----------
        config:
            Configuration object

        Returns
        -------
        True if the configuration is valid, else False
        """
        is_valid, message = cls.get_samna_module().validate_configuration(config)
        if not is_valid:
            print(message)
        return is_valid

    @classmethod
    @abstractmethod
    def get_input_buffer(cls):
        """Initialize and return the appropriate input buffer object.

        Note that this is just the buffer object. This does not actually
        connect the buffer object to the graph. (It is needed as of
        samna 0.21.0)
        """

    @classmethod
    @abstractmethod
    def get_output_buffer(cls):
        """Initialize and return the appropriate output buffer object.

        Note that this is just the buffer object. This does not actually
        connect the buffer object to the graph.
        """

    @classmethod
    @abstractmethod
    def reset_states(cls, config, randomize=False):
        """Randomize or reset the neuron states.

        Parameters
        ----------
        config:
            Configuration object
        randomize (bool):
            If true, the states will be set to random initial values.
            Else they will be set to zero
        """

    @classmethod
    def set_all_v_mem_to_zeros(cls, samna_device, layer_id: int) -> None:
        """Reset all memory states of a layer to zeros.

        One ``WriteNeuronValue`` event with ``neuron_state = 0`` is created for
        every neuron memory address of the layer (``neuron_memory`` of the
        layer's constraints). The events are written to the device's sink node
        through a temporary graph fed by an input buffer.

        Parameters
        ----------
        samna_device:
            samna device object to erase vmem memory.
        layer_id:
            layer index (used both to pick the layer's constraints and as the
            target layer of each event)
        """
        mod = cls.get_samna_module()
        layer_constraint: LayerConstraints = cls.get_constraints()[layer_id]

        events = []
        for address in range(layer_constraint.neuron_memory):
            event = mod.event.WriteNeuronValue()
            event.address = address
            event.layer = layer_id
            event.neuron_state = 0
            events.append(event)

        temporary_source_node = cls.get_input_buffer()
        temporary_graph = samna.graph.sequential(
            [temporary_source_node, samna_device.get_model().get_sink_node()]
        )
        temporary_graph.start()
        try:
            temporary_source_node.write(events)
        finally:
            # Stop the temporary graph even if writing fails, so it is never
            # left running after this method returns or raises.
            temporary_graph.stop()