From adc7744bbcea3982f5c136f14efa937cb412689a Mon Sep 17 00:00:00 2001 From: ahoni <aurelie.saulq@proton.me> Date: Fri, 28 Mar 2025 09:30:25 +0100 Subject: [PATCH] add qat test --- .../arch_builder/modules/ShiftLIF/shiftlif.py | 2 +- modneflib/modnef/modnef_torch/model.py | 13 +++++++++++-- .../modnef_neurons/modnef_torch_neuron.py | 5 ++++- .../modnef_neurons/srlif_model/rshiftlif.py | 1 - .../modnef_neurons/srlif_model/shiftlif.py | 4 +++- 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 929ed26..35bb734 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -212,7 +212,7 @@ class ShiftLif(ModNEFArchMod): self.v_threshold = self.quantizer(self.v_threshold) - mem_file.close() + mem_file.close() def to_debugger(self, output_file : str = ""): """ diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index ae40072..0da59bc 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -92,6 +92,15 @@ class ModNEFModel(nn.Module): return super().train(mode=mode) + def init_quantizer(self): + """ + initialize quantizer of laters + """ + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + m.init_quantizer() + def quantize(self, force_init=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -106,11 +115,11 @@ class ModNEFModel(nn.Module): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) - def clamp(self): + def clamp(self, force_init=False): for m in self.modules(): if isinstance(m, ModNEFNeuron): - m.clamp() + m.clamp(force_init=force_init) def train(self, mode : bool = True, quant : bool = False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 06e00f6..b6a0432 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron): self.quantize_weight() self.quantize_hp() - def clamp(self): + def clamp(self, force_init=False): """ Clamp synaptic weight """ + if force_init: + self.init_quantizer() + for p in self.parameters(): p.data = self.quantizer.clamp(p.data) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 3cd77b3..f3dda9a 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron): """ self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) @classmethod diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index b0d7586..1906fb5 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron): else: self.mem = self.mem-self.__shift(self.mem) + if self.quantization_flag: + self.mem.data = self.quantizer(self.mem.data, True) + if self.hardware_estimation_flag: self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) @@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron): """ self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) @classmethod -- GitLab