From adc7744bbcea3982f5c136f14efa937cb412689a Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Fri, 28 Mar 2025 09:30:25 +0100
Subject: [PATCH] add qat test

---
 .../arch_builder/modules/ShiftLIF/shiftlif.py       |  2 +-
 modneflib/modnef/modnef_torch/model.py              | 13 +++++++++++--
 .../modnef_neurons/modnef_torch_neuron.py           |  5 ++++-
 .../modnef_neurons/srlif_model/rshiftlif.py         |  1 -
 .../modnef_neurons/srlif_model/shiftlif.py          |  4 +++-
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
index 929ed26..35bb734 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
@@ -212,7 +212,7 @@ class ShiftLif(ModNEFArchMod):
 
       self.v_threshold = self.quantizer(self.v_threshold)
     
-    mem_file.close()        
+    mem_file.close() 
 
   def to_debugger(self, output_file : str = ""):
     """
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index ae40072..0da59bc 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -92,6 +92,15 @@ class ModNEFModel(nn.Module):
 
     return super().train(mode=mode)
   
+  def init_quantizer(self):
+    """
+    initialize quantizer of laters
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.init_quantizer()
+  
   def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
@@ -106,11 +115,11 @@ class ModNEFModel(nn.Module):
       if isinstance(m, ModNEFNeuron):
         m.quantize(force_init=force_init)
 
-  def clamp(self):
+  def clamp(self, force_init=False):
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.clamp()
+        m.clamp(force_init=force_init)
 
   def train(self, mode : bool = True, quant : bool = False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 06e00f6..b6a0432 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron):
     self.quantize_weight()
     self.quantize_hp()
   
-  def clamp(self):
+  def clamp(self, force_init=False):
     """
     Clamp synaptic weight
     """
 
+    if force_init:
+      self.init_quantizer()
+
     for p in self.parameters():
       p.data = self.quantizer.clamp(p.data)
   
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 3cd77b3..f3dda9a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index b0d7586..1906fb5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
@@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
-- 
GitLab