From f5298795e20795b72df3d4dfdd8cfce022c8e7ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Sat, 29 Mar 2025 20:10:30 +0100
Subject: [PATCH] add qat test

---
 modneflib/modnef/modnef_torch/model.py        | 15 ++++++++++++++
 .../modnef_neurons/blif_model/blif.py         |  4 +---
 .../modnef_neurons/blif_model/rblif.py        | 20 +++++++++----------
 .../modnef_neurons/srlif_model/shiftlif.py    |  5 +++--
 .../quantizer/dynamic_scale_quantizer.py      |  2 +-
 5 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 0da59bc..152cd93 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -100,6 +100,21 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.init_quantizer()
+
+  def quantize_weight(self, force_init=False):
+    """
+    Quantize synaptic weight
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.init_quantizer()
+        m.quantize_weight()
   
   def quantize(self, force_init=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 50f18d8..2ba0908 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -232,6 +232,7 @@ class BLIF(ModNEFNeuron):
       input_.data = self.quantizer(input_.data, True)
       self.mem.data = self.quantizer(self.mem.data, True)
 
+
     self.reset = self.mem_reset(self.mem)
 
     if self.reset_mechanism == "subtract":
@@ -241,9 +242,6 @@ class BLIF(ModNEFNeuron):
     else:
       self.mem = self.mem*self.beta
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index c759fcf..586048c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -258,15 +258,15 @@ class RBLIF(ModNEFNeuron):
 
     rec = self.reccurent(self.spk)
 
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
-    #   input_.data = self.quantizer(input_.data, True)
-    #   rec.data = self.quantizer(rec.data, True)
-
     if self.quantization_flag:
-      self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
-      input_ = QuantizeMembrane.apply(input_, self.quantizer)
-      rec = QuantizeMembrane.apply(rec, self.quantizer)
+      self.mem.data = self.quantizer(self.mem.data, True)
+      input_.data = self.quantizer(input_.data, True)
+      rec.data = self.quantizer(rec.data, True)
+
+    # if self.quantization_flag:
+    #   self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
+    #   input_ = QuantizeMembrane.apply(input_, self.quantizer)
+    #   rec = QuantizeMembrane.apply(rec, self.quantizer)
 
     if self.reset_mechanism == "subtract":
       self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
@@ -279,8 +279,8 @@ class RBLIF(ModNEFNeuron):
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+    # if self.quantization_flag:
+    #   self.mem.data = self.quantizer(self.mem.data, True)
 
     self.spk = self.fire(self.mem)
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 1906fb5..654027d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -264,8 +264,8 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+    # if self.quantization_flag:
+    #   self.mem.data = self.quantizer(self.mem.data, True)
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
@@ -331,6 +331,7 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    print(self.threshold)
 
 
   @classmethod
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 2631170..0f25c44 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -114,7 +114,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       weight = torch.Tensor(weight)
 
     if not torch.is_tensor(rec_weight):
-      rec_weight = torch.Tensor(weight)
+      rec_weight = torch.Tensor(rec_weight)
 
     if self.signed==None:
       self.signed = torch.min(weight.min(), rec_weight.min())<0.0
-- 
GitLab