diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 73225949001afeb0d5700f6f4a5c15c78b2079aa..0a1acd91e0af892d9b43ca10623d5d6203af65dc 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -10,6 +10,8 @@ Descriptions: ModNEF torch neuron interface builder import torch from modnef.quantizer import * +from snntorch._neurons import SpikingNeuron +from snntorch.surrogate import fast_sigmoid _quantizer = { "FixedPointQuantizer" : FixedPointQuantizer, @@ -17,7 +19,7 @@ _quantizer = { "DynamicScaleFactorQuantizer" : DynamicScaleFactorQuantizer } -class ModNEFNeuron(): +class ModNEFNeuron(SpikingNeuron): """ ModNEF torch neuron interface @@ -42,7 +44,27 @@ class ModNEFNeuron(): create and return the corresponding modnef archbuilder module from internal neuron parameters """ - def __init__(self, quantizer : Quantizer): + def __init__(self, + threshold, + reset_mechanism, + quantizer : Quantizer, + spike_grad=fast_sigmoid(slope=25)): + + SpikingNeuron.__init__( + self=self, + threshold=threshold, + reset_mechanism=reset_mechanism, + spike_gard=spike_grad, + surrogate_disable=False, + init_hidden=False, + inhibition=False, + learn_threshold=False, + state_quant=False, + output=False, + graded_spikes_factor=1.0, + learn_graded_spikes_factor=False + ) + self.hardware_estimation_flag = False self.quantization_flag = False @@ -61,7 +83,20 @@ class ModNEFNeuron(): raise NotImplementedError() - def quantize_weight(self): + def init_quantizer(self): + + params = list(self.parameters()) + + w1 = params[0].data + + if len(params)==2: + w2 = params[0].data + else: + w2 = torch.zeros((1)) + + self.quantizer.init_quantizer(w1, w2) + + def quantize_weight(self, unscaled : bool = False): """ synaptic weight quantization @@ -70,9 +105,10 @@ class ModNEFNeuron(): NotImplementedError() """ - raise NotImplementedError() + for param in self.parameters(): + param.data = self.quantizer(param.data, unscale=unscaled) - def quantize_parameters(self): + def quantize_hp(self): """ neuron hyper-parameters quantization """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 9bb6d1df8eeff1ab9044378d73e3625e4a9e4438..195e965f8f99ef994bca289886b182ad7af402fa 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -11,14 +11,14 @@ Based on snntorch.Leaky and snntroch.LIF class import torch.nn as nn import torch -from snntorch import LIF +from snntorch.surrogate import fast_sigmoid import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import log, ceil from modnef.quantizer import DynamicScaleFactorQuantizer -class ShiftLIF(LIF, ModNEFNeuron): +class ShiftLIF(ModNEFNeuron): """ ModNEFTorch Shift LIF neuron model @@ -86,7 +86,7 @@ class ShiftLIF(LIF, ModNEFNeuron): out_features, beta, threshold=1.0, - spike_grad=None, + spike_grad=fast_sigmoid(slope=25), reset_mechanism="subtract", quantizer=DynamicScaleFactorQuantizer(8) ): @@ -117,24 +117,13 @@ class ShiftLIF(LIF, ModNEFNeuron): print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}") beta = 1-2**-self.shift - LIF.__init__( - self=self, - beta=beta, - threshold=threshold, - spike_grad=spike_grad, - surrogate_disable=False, - init_hidden=False, - inhibition=False, - learn_beta=False, - learn_threshold=False, - reset_mechanism=reset_mechanism, - state_quant=False, - output=False, - graded_spikes_factor=1.0, - learn_graded_spikes_factor=False, - ) - ModNEFNeuron.__init__(self=self, quantizer=quantizer) + ModNEFNeuron.__init__(self=self, + threshold=threshold, + reset_mechanism=reset_mechanism, + spike_grad=spike_grad, + quantizer=quantizer + ) self.fc = nn.Linear(in_features, out_features, bias=False) diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py index 2631170b678e8ece885ed61e4e16a2efe17340db..43a6aa1fca47cb9e36775b0596a95a0b9c61c6b9 100644 --- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py +++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py @@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer): is_initialize=config["is_initialize"] ) - def init_from_weight(self, weight, rec_weight=torch.zeros((1))): + def init_quantizer(self, weight, rec_weight=torch.zeros((1))): """ initialize quantizer parameters from synaptic weight diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py index 106a242180b8b83b5d47ebe1dbf0e4f7604f63cd..c48e690cdbc31e131a1351c02330c7a8fc3e148c 100644 --- a/modneflib/modnef/quantizer/fixed_point_quantizer.py +++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py @@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer): is_initialize=config["is_initialize"] ) - def init_from_weight(self, weight, rec_weight=torch.zeros((1))): + def init_quantizer(self, weight, rec_weight=torch.zeros((1))): """ initialize quantizer parameters from synaptic weight diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py index 6ca813180dfc9d0de66ff4428564c6240b9f4649..ebc4ae9fb6cde07e030b35dad1a89f0c82f6ae00 100644 --- a/modneflib/modnef/quantizer/min_max_quantizer.py +++ b/modneflib/modnef/quantizer/min_max_quantizer.py @@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer): is_initialize=config["is_initialize"] ) - def init_from_weight(self, weight, rec_weight=torch.zeros((1))): + def init_quantizer(self, weight, rec_weight=torch.zeros((1))): """ initialize quantizer parameters from synaptic weight diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index d75316275c5aef34558d855e9fd015c900748ece..6ab752eb28bb5c601293cfe46cd1ce7e05ff4e5a 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -77,7 +77,7 @@ class Quantizer(): raise NotImplementedError() - def init_from_weight(self, weight, rec_weight=torch.zeros((1))): + def init_quantizer(self, weight, rec_weight=torch.zeros((1))): """ initialize quantizer parameters from synaptic weight @@ -112,10 +112,10 @@ class Quantizer(): else: tdata = data + qdata = self._quant(tdata) + if unscale: - qdata = self._unquant(self._quant(tdata)) - else: - qdata = self._quant(tdata) + qdata = self._unquant(qdata) if isinstance(data, (int, float)): return qdata.item()