diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 152cd9337d81bd3a92af86598dc16dcf89877fb7..2c0d687f052fee19c932f904d2c2baf49e675762 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -113,7 +113,8 @@ class ModNEFModel(nn.Module):
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.init_quantizer()
+        if force_init:
+          m.init_quantizer()
         m.quantize_weight()
   
   def quantize(self, force_init=False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 56e3d103ce02d058f3291421f2ddb15ea280c194..55298a14af1dca60bda5e2c4aab3e00a97d6131f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -16,6 +16,7 @@ from snntorch.surrogate import fast_sigmoid
 
 import torch.nn.functional as F
 import brevitas.nn as qnn
+import brevitas.quant as bq
 
 class QuantizeSTE(torch.autograd.Function):
     @staticmethod
@@ -23,12 +24,13 @@ class QuantizeSTE(torch.autograd.Function):
         
         q_weights = quantizer(weights, True)
 
-        #ctx.scale = quantizer.scale_factor  # On sauvegarde le scale pour backward
+        #ctx.save_for_backward(quantizer.scale_factor)  # On sauvegarde le scale pour backward
         return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
         # STE : on passe directement le gradient au poids float
+        #scale_factor, = ctx.saved_tensors
         return grad_output, None
 
 
@@ -77,7 +79,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_witdh=5)
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -124,14 +126,10 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      
-      # if ema:
-      #   print(self.alpha)
-      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
-      #   self.alpha *= 0.1
-      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # else:
-      p.data = self.quantizer(p.data, True)
+      print(p)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+
+    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
      
   
   def quantize_hp(self, unscale : bool = True):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index a7c1b1f403ed372cbce42b6bb57c044ce2c33add..b0409a43311971e674f3a840f9ef58a921a607b1 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -21,14 +21,21 @@ from snntorch import LIF
 
     
 class QuantizeSTE(torch.autograd.Function):
-    """Quantization avec Straight-Through Estimator (STE)"""
     @staticmethod
-    def forward(ctx, x, quantizer):
-        return quantizer(x, True)
+    def forward(ctx, weights, quantizer):
+        
+        q_weights = quantizer(weights, True)
+
+        #ctx.save_for_backward(quantizer.scale_factor)  # On sauvegarde le scale pour backward
+        return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None  # STE: Passe le gradient inchangé
+        # STE : on passe directement le gradient au poids float
+        #scale_factor, = ctx.saved_tensors
+        return grad_output, None
+    
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -259,7 +266,7 @@ class ShiftLIF(ModNEFNeuron):
     self.reset = self.mem_reset(self.mem)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
       #input_.data = self.quantizer(input_.data, True)
 
 
@@ -343,7 +350,6 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    print(self.threshold)
 
 
   @classmethod