From f5298795e20795b72df3d4dfdd8cfce022c8e7ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com> Date: Sat, 29 Mar 2025 20:10:30 +0100 Subject: [PATCH] add qat test --- modneflib/modnef/modnef_torch/model.py | 15 ++++++++++++++ .../modnef_neurons/blif_model/blif.py | 4 +--- .../modnef_neurons/blif_model/rblif.py | 20 +++++++++---------- .../modnef_neurons/srlif_model/shiftlif.py | 5 +++-- .../quantizer/dynamic_scale_quantizer.py | 2 +- 5 files changed, 30 insertions(+), 16 deletions(-) diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index 0da59bc..152cd93 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -100,6 +100,21 @@ class ModNEFModel(nn.Module): for m in self.modules(): if isinstance(m, ModNEFNeuron): m.init_quantizer() + + def quantize_weight(self, force_init=False): + """ + Quantize synaptic weight + + Parameters + ---------- + force_init = False : bool + force quantizer initialization + """ + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + m.init_quantizer() + m.quantize_weight() def quantize(self, force_init=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 50f18d8..2ba0908 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -232,6 +232,7 @@ class BLIF(ModNEFNeuron): input_.data = self.quantizer(input_.data, True) self.mem.data = self.quantizer(self.mem.data, True) + self.reset = self.mem_reset(self.mem) if self.reset_mechanism == "subtract": @@ -241,9 +242,6 @@ class BLIF(ModNEFNeuron): else: self.mem = self.mem*self.beta - if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) - if self.hardware_estimation_flag: self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index c759fcf..586048c 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -258,15 +258,15 @@ class RBLIF(ModNEFNeuron): rec = self.reccurent(self.spk) - # if self.quantization_flag: - # self.mem.data = self.quantizer(self.mem.data, True) - # input_.data = self.quantizer(input_.data, True) - # rec.data = self.quantizer(rec.data, True) - if self.quantization_flag: - self.mem = QuantizeMembrane.apply(self.mem, self.quantizer) - input_ = QuantizeMembrane.apply(input_, self.quantizer) - rec = QuantizeMembrane.apply(rec, self.quantizer) + self.mem.data = self.quantizer(self.mem.data, True) + input_.data = self.quantizer(input_.data, True) + rec.data = self.quantizer(rec.data, True) + + # if self.quantization_flag: + # self.mem = QuantizeMembrane.apply(self.mem, self.quantizer) + # input_ = QuantizeMembrane.apply(input_, self.quantizer) + # rec = QuantizeMembrane.apply(rec, self.quantizer) if self.reset_mechanism == "subtract": self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold @@ -279,8 +279,8 @@ class RBLIF(ModNEFNeuron): self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) - if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) + # if self.quantization_flag: + # self.mem.data = self.quantizer(self.mem.data, True) self.spk = self.fire(self.mem) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 1906fb5..654027d 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -264,8 +264,8 @@ class ShiftLIF(ModNEFNeuron): else: self.mem = self.mem-self.__shift(self.mem) - if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) + # if self.quantization_flag: + # self.mem.data = self.quantizer(self.mem.data, True) if self.hardware_estimation_flag: self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) @@ -331,6 +331,7 @@ class ShiftLIF(ModNEFNeuron): """ self.threshold.data = self.quantizer(self.threshold.data, unscale) + print(self.threshold) @classmethod diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py index 2631170..0f25c44 100644 --- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py +++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py @@ -114,7 +114,7 @@ class DynamicScaleFactorQuantizer(Quantizer): weight = torch.Tensor(weight) if not torch.is_tensor(rec_weight): - rec_weight = torch.Tensor(weight) + rec_weight = torch.Tensor(rec_weight) if self.signed==None: self.signed = torch.min(weight.min(), rec_weight.min())<0.0 -- GitLab