diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 01f59f90e268fd3d374aa169783b9021733365f8..c759fcf8212af5aea1aec88b703bf488f7021f84 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -18,6 +18,21 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
 
+import torch.autograd as autograd
+
+class QuantizeMembrane(autograd.Function):
+  """Straight-through estimator for membrane quantization."""
+
+  @staticmethod
+  def forward(ctx, U, quantizer):
+    # Quantize the membrane potential; nothing needs saving for backward
+    return quantizer(U, True)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    # Pass the gradient through unchanged; the quantizer gets no gradient
+    return grad_output, None
+
 class RBLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent BLIF neuron model
@@ -243,10 +258,10 @@ class RBLIF(ModNEFNeuron):
 
     rec = self.reccurent(self.spk)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec.data = self.quantizer(rec.data, True)
+      self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
+      input_ = QuantizeMembrane.apply(input_, self.quantizer)
+      rec = QuantizeMembrane.apply(rec, self.quantizer)
 
     if self.reset_mechanism == "subtract":
       self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
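
The rewritten `QuantizeMembrane` follows the straight-through estimator (STE) pattern: quantization is applied on the forward pass, while the backward pass treats it as the identity so gradients keep flowing through the quantized membrane potential. Mutating `.data` in place, as in the replaced lines, happens outside autograd's view, so the recorded graph no longer matches the quantized values actually used. Below is a minimal standalone sketch of the pattern; the uniform rounding step is a stand-in for ModNEF's `quantizer`, whose exact interface is not shown in this diff.

import torch
import torch.autograd as autograd

class STEQuantize(autograd.Function):
  """Straight-through estimator: quantize forward, identity gradient backward."""

  @staticmethod
  def forward(ctx, x, scale):
    # Stand-in quantizer: round to the nearest step of size 1/scale
    return torch.round(x * scale) / scale

  @staticmethod
  def backward(ctx, grad_output):
    # Identity gradient for x; the scale argument gets no gradient
    return grad_output, None

x = torch.randn(4, requires_grad=True)
y = STEQuantize.apply(x, 8.0)
y.sum().backward()
print(x.grad)  # all ones: the rounding step is invisible to autograd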