From 3fe54c44ce67f18b7de34498e1b563d4f30c936a Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Tue, 25 Mar 2025 22:59:14 +0100
Subject: [PATCH] add test quantizer

---
 .../modnef_neurons/modnef_torch_neuron.py      | 18 +++++++++++++++++-
 .../modnef_neurons/srlif_model/shiftlif.py     |  3 +++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index c4a9800..06e00f6 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -14,6 +14,21 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
 
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class QuantizeSTE(torch.autograd.Function):
+    """Quantization avec Straight-Through Estimator (STE)"""
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # STE: Passe le gradient inchangé
+
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
   "MinMaxQuantizer" : MinMaxQuantizer,
@@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron):
     
     for p in self.parameters():
       p.data = self.quantizer(p.data, unscale=unscale)
+      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
+      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
   
   def quantize_hp(self, unscale : bool = True):
     """
@@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron):
 
     for p in self.parameters():
       p.data = self.quantizer.clamp(p.data)
-      print("clamp")
   
   def run_quantize(self, mode=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index f2c95cb..b0d7586 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -19,6 +19,8 @@ from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
+    
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron):
       self.mem.data = self.quantizer(self.mem.data, True)
       input_.data = self.quantizer(input_.data, True)
 
+
     self.mem = self.mem+input_
 
     if self.reset_mechanism == "subtract":
-- 
GitLab