From dbfdd4f21bdfb492f68ddca32bd78ee63345eb43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 28 Mar 2025 14:36:52 +0100
Subject: [PATCH] add quantizer test

---
 .../modnef_neurons/blif_model/rblif.py        | 26 ++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 01f59f9..c759fcf 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -18,6 +18,21 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
 
+import torch.autograd as autograd
+
+class QuantizeMembrane(autograd.Function):
+    @staticmethod
+    def forward(ctx, U, quantizer):
+        max_val = U.abs().max().detach()  # detached so it cannot block the gradient; NOTE(review): max_val is currently unused — remove or use it
+        U_quant = quantizer(U, True)
+        ctx.save_for_backward(U, quantizer.scale_factor)  # NOTE(review): backward unpacks these but never uses them — confirm they are needed
+        return U_quant
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, factor = ctx.saved_tensors
+        return grad_output, None
+
 class RBLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent BLIF neuron model
@@ -243,10 +258,15 @@ class RBLIF(ModNEFNeuron):
 
     rec = self.reccurent(self.spk)
 
+    # NOTE(review): the in-place ``.data`` quantization removed below bypassed
+    # autograd entirely; it is superseded by the straight-through-estimator
+    # QuantizeMembrane.apply calls that follow, which keep gradients flowing.
+    # (Commented-out copy of the old code dropped -- the diff already records it.)
+
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec.data = self.quantizer(rec.data, True)
+      self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
+      input_ = QuantizeMembrane.apply(input_, self.quantizer)
+      rec = QuantizeMembrane.apply(rec, self.quantizer)
 
     if self.reset_mechanism == "subtract":
       self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
-- 
GitLab