diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 929ed2612460c4901afae0cb5ac1c7b210cb9c6a..35bb7344303f59276e43ce80ea060b7090519d4f 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -212,7 +212,7 @@ class ShiftLif(ModNEFArchMod): self.v_threshold = self.quantizer(self.v_threshold) - mem_file.close() + mem_file.close() def to_debugger(self, output_file : str = ""): """ diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index ae40072c5329af8bd93e3af75f5d0bc358364a00..0da59bc486daf528582ead68609895f0df203b5f 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -92,6 +92,15 @@ class ModNEFModel(nn.Module): return super().train(mode=mode) + def init_quantizer(self): + """ + initialize quantizer of laters + """ + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + m.init_quantizer() + def quantize(self, force_init=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -106,11 +115,11 @@ class ModNEFModel(nn.Module): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) - def clamp(self): + def clamp(self, force_init=False): for m in self.modules(): if isinstance(m, ModNEFNeuron): - m.clamp() + m.clamp(force_init=force_init) def train(self, mode : bool = True, quant : bool = False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 06e00f69aa72ba09d283c16dad334df80d901766..b6a0432c865c52a819d13cdce460fa1a8f1d96b8 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron): self.quantize_weight() self.quantize_hp() - def clamp(self): + def clamp(self, force_init=False): """ Clamp synaptic weight """ + if force_init: + self.init_quantizer() + for p in self.parameters(): p.data = self.quantizer.clamp(p.data) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 3cd77b3f33527d1cbeb15ed0cada7384f00b5e30..f3dda9a66825ea3e42ede3c0eeb4d44c48cff5d0 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron): """ self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) @classmethod diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index b0d75862ace0f96e202e63d1e4e2f4334cffd4b4..1906fb5ceac64fc22f8b4adae6a25d14656ae814 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron): else: self.mem = self.mem-self.__shift(self.mem) + if self.quantization_flag: + self.mem.data = self.quantizer(self.mem.data, True) + if self.hardware_estimation_flag: self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) @@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron): """ self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) @classmethod