# Source code for sinabs.backend.dynapcnn.config_builder
from abc import ABC, abstractmethod
from typing import Dict, List
import samna
from .dynapcnn_layer import DynapcnnLayer
from .mapping import LayerConstraints, get_valid_mapping
[docs]
class ConfigBuilder(ABC):
    """Abstract interface for building and manipulating samna device configurations.

    Device-specific subclasses provide the samna module, default configuration,
    layer constraints and buffer objects; this base class supplies shared helpers
    (core mapping, validation, neuron-state reset) built on top of those hooks.
    """

    @classmethod
    @abstractmethod
    def get_samna_module(cls):
        """Get the samna parent module that hosts all the appropriate sub-modules and classes.

        Returns:
            samna module
        """

    @classmethod
    @abstractmethod
    def get_default_config(cls):
        """Get the default configuration object for this device type.

        Returns:
            Default configuration for the device type
        """

    @classmethod
    @abstractmethod
    def build_config(
        cls,
        layers: Dict[int, DynapcnnLayer],
        layer2core_map: Dict[int, int],
        destination_map: Dict[int, List[int]],
    ):
        """Build the configuration given a model.

        Args:
            layers (Dict): Keys are layer indices, values are DynapcnnLayer instances.
            layer2core_map (Dict): Keys are layer indices, values are corresponding
                cores on hardware. Needed to map the destinations.
            destination_map (Dict): Indices of destination layers for `layer`.

        Returns:
            Samna Configuration object
        """

    @classmethod
    @abstractmethod
    def get_constraints(cls) -> List[LayerConstraints]:
        """Return the layer constraints of the given device.

        Returns:
            List[LayerConstraints]
        """

    @classmethod
    @abstractmethod
    def monitor_layers(cls, config, layers: List[int]):
        """Enable the monitor for a given set of layers in the config object.

        Args:
            config: Configuration object to modify.
            layers: Indices of the layers whose monitor should be enabled.
        """

    @classmethod
    def map_layers_to_cores(cls, layers: Dict[int, DynapcnnLayer]) -> Dict[int, int]:
        """Find a mapping from DynapcnnLayers onto on-chip cores.

        Args:
            layers: Dict with layer indices as keys and DynapcnnLayer instances as values.

        Returns:
            Dict mapping layer indices (keys) to assigned core IDs (values).
        """
        return get_valid_mapping(layers, cls.get_constraints())

    @classmethod
    def validate_configuration(cls, config) -> bool:
        """Check if a given configuration is valid.

        If the configuration is invalid, the validation message returned by
        samna is printed.

        Args:
            config: Configuration object.

        Returns:
            True if the configuration is valid, else False.
        """
        is_valid, message = cls.get_samna_module().validate_configuration(config)
        if not is_valid:
            print(message)
        return is_valid

    @classmethod
    @abstractmethod
    def get_output_buffer(cls):
        """Initialize and return the appropriate output buffer object.

        Note that this is just the buffer object; it does not actually connect
        the buffer object to the graph.
        """

    @classmethod
    def get_input_buffer(cls):
        """Return a source node used to write events to the device.

        `set_all_v_mem_to_zeros` relies on this hook. Subclasses that support
        writing events to the device must override it.

        Raises:
            NotImplementedError: If the subclass does not provide an input buffer.
        """
        raise NotImplementedError(
            f"{cls.__name__} does not implement get_input_buffer()"
        )

    @classmethod
    @abstractmethod
    def reset_states(cls, config, randomize=False):
        """Randomize or reset the neuron states.

        Args:
            config: Configuration object to modify.
            randomize (bool): If true, the states will be set to random initial values.
                Else they will be set to zero.
        """

    @classmethod
    def set_all_v_mem_to_zeros(cls, samna_device, layer_id: int) -> None:
        """Write zero to every neuron-state address of one layer/core.

        One `WriteNeuronValue` event is generated per neuron-memory address of
        the layer. The events are sent through a temporary graph that connects
        the input buffer to the device's sink node.

        Args:
            samna_device: samna device object to erase vmem memory.
            layer_id: Layer index; also used as the index into `get_constraints()`.

        Raises:
            IndexError: If `layer_id` is not a valid index into the device's
                layer constraints.
        """
        mod = cls.get_samna_module()
        constraints = cls.get_constraints()
        # Reject negative ids explicitly: plain list indexing would silently wrap
        # around and emit events addressed to a negative layer.
        if not 0 <= layer_id < len(constraints):
            raise IndexError(
                f"layer_id {layer_id} out of range for {len(constraints)} layers"
            )
        layer_constraint: LayerConstraints = constraints[layer_id]

        events = []
        for address in range(layer_constraint.neuron_memory):
            event = mod.event.WriteNeuronValue()
            event.address = address
            event.layer = layer_id
            event.neuron_state = 0
            events.append(event)

        temporary_source_node = cls.get_input_buffer()
        temporary_graph = samna.graph.sequential(
            [temporary_source_node, samna_device.get_model().get_sink_node()]
        )
        temporary_graph.start()
        try:
            temporary_source_node.write(events)
        finally:
            # Always stop the temporary graph, even if writing fails, so it is
            # not left running.
            temporary_graph.stop()