From 2daec63a28a7655f37eef6f3d289939e4c407b78 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 31 Mar 2025 12:49:41 +0200
Subject: [PATCH] add QAT weight quantization test

---
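Notes (review only, not part of the commit message): the sketch below shows one
way the QuantizeSTE helper added in this patch could be used for
quantization-aware training, so the forward pass runs on quantized weights while
the straight-through estimator routes gradients to the float master weights. It
assumes the QuantizeSTE class defined in modnef_torch_neuron.py below and any
quantizer following the quantizer(tensor, unscale) call signature used in this
patch; the module name and feature sizes are hypothetical.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class STEQuantLinear(nn.Module):
        """Linear layer whose weights are fake-quantized on every forward pass."""

        def __init__(self, in_features, out_features, quantizer):
            super().__init__()
            self.fc = nn.Linear(in_features, out_features, bias=False)
            self.quantizer = quantizer

        def forward(self, x):
            # Forward sees the quantized weights; backward passes grad_output
            # unchanged to self.fc.weight (straight-through estimator), because
            # the rounding inside the quantizer has a zero gradient almost
            # everywhere and cannot be differentiated through directly.
            q_weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
            return F.linear(x, q_weight)
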
 .../modnef_neurons/modnef_torch_neuron.py     | 33 ++++++++++++++-----
 .../modnef_neurons/srlif_model/shiftlif.py    | 12 +++++--
 .../quantizer/dynamic_scale_quantizer.py      |  1 -
 3 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index b6a0432..56e3d10 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -14,20 +14,25 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
 
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import brevitas.nn as qnn
 
 class QuantizeSTE(torch.autograd.Function):
-    """Quantization avec Straight-Through Estimator (STE)"""
     @staticmethod
-    def forward(ctx, x, quantizer):
-        return quantizer(x, True)
+    def forward(ctx, weights, quantizer):
+        # quantize the weights for the forward pass
+        q_weights = quantizer(weights, True)
+
+        # ctx.scale = quantizer.scale_factor  # save the scale factor for the backward pass
+        return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None  # STE: Passe le gradient inchangé
+        # STE: pass the gradient straight through to the float master weights
+        return grad_output, None
+
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -74,6 +77,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
+    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -87,6 +91,9 @@ class ModNEFNeuron(SpikingNeuron):
 
     self.quantizer = quantizer
 
+
+    self.alpha = 0.9  # EMA coefficient for the commented-out weight quantization experiment in quantize_weight
+
   @classmethod
   def from_dict(cls, dict):
     """
@@ -107,7 +114,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self, unscale : bool = True):
+  def quantize_weight(self, unscale : bool = True, ema : bool = False):
     """
     synaptic weight quantization
 
@@ -117,9 +124,15 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      p.data = self.quantizer(p.data, unscale=unscale)
-      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
-      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+
+      # if ema:
+      #   print(self.alpha)
+      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
+      #   self.alpha *= 0.1
+      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      # else:
+      p.data = self.quantizer(p.data, unscale=unscale)
+
   
   def quantize_hp(self, unscale : bool = True):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 5d938b4..a7c1b1f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -20,7 +20,15 @@ from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
     
-
+class QuantizeSTE(torch.autograd.Function):
+    """Quantization avec Straight-Through Estimator (STE)"""
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # STE: pass the gradient through unchanged
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,7 +260,7 @@ class ShiftLIF(ModNEFNeuron):
 
     if self.quantization_flag:
       self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
+      # input_.data = self.quantizer(input_.data, True)  # input quantization disabled for this QAT experiment
 
 
     self.mem = self.mem+input_
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 0f25c44..f1d7497 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -161,5 +161,4 @@ class DynamicScaleFactorQuantizer(Quantizer):
     -------
     Tensor
     """
-
     return data*self.scale_factor
\ No newline at end of file
-- 
GitLab
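
Appended reviewer note (not part of the patch): the commented-out QuantLinear
line in ModNEFNeuron.__init__ hints at a Brevitas-based alternative in which the
fake-quantization and the straight-through backward are handled by the layer
itself. A minimal sketch, assuming Brevitas is installed; the feature sizes are
placeholders and the keyword argument is weight_bit_width:

    import torch
    import brevitas.nn as qnn

    # Brevitas quantizes the weights to 5 bits and applies the STE internally,
    # so no manual weight quantization pass is needed during training.
    fc = qnn.QuantLinear(in_features=128, out_features=10, bias=False,
                         weight_bit_width=5)

    x = torch.randn(32, 128)
    out = fc(x)  # forward pass runs with the quantized weights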