diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index b6a0432c865c52a819d13cdce460fa1a8f1d96b8..56e3d103ce02d058f3291421f2ddb15ea280c194 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -14,20 +14,24 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
 
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import brevitas.nn as qnn
 
 class QuantizeSTE(torch.autograd.Function):
-    """Quantization avec Straight-Through Estimator (STE)"""
     @staticmethod
-    def forward(ctx, x, quantizer):
-        return quantizer(x, True)
+    def forward(ctx, weights, quantizer):
+        
+        q_weights = quantizer(weights, True)
+
+        #ctx.scale = quantizer.scale_factor  # On sauvegarde le scale pour backward
+        return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None  # STE: Passe le gradient inchangé
+        # STE : on passe directement le gradient au poids float
+        return grad_output, None
+
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -74,6 +78,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -87,6 +92,8 @@
 
     self.quantizer = quantizer
 
+    self.alpha = 0.9  # EMA coefficient for the optional smoothed weight quantization (currently disabled)
+
   @classmethod
   def from_dict(cls, dict):
     """
@@ -107,7 +114,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self, unscale : bool = True):
+  def quantize_weight(self, unscale : bool = True, ema : bool = False):
     """
     synaptic weight quantization
 
@@ -117,9 +124,14 @@
+    ema : bool = False
+      experimental: if True, blend the float weights toward their quantized values with an EMA (currently disabled)
     """
     
     for p in self.parameters():
-      p.data = self.quantizer(p.data, unscale=unscale)
-      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
-      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      # if ema:
+      #   # blend the float weights toward their quantized values instead of snapping them
+      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
+      #   self.alpha *= 0.1
+      # else:
+      p.data = self.quantizer(p.data, unscale=unscale)
   
   def quantize_hp(self, unscale : bool = True):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 5d938b43d64ea10d526d1d4ca74f5f24962e3bca..a7c1b1f403ed372cbce42b6bb57c044ce2c33add 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -20,7 +20,18 @@ from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
     
-
+class QuantizeSTE(torch.autograd.Function):
+    """Weight quantization with a Straight-Through Estimator (STE)."""
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # STE: pass the gradient through unchanged
+        return grad_output, None
+
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,7 +263,7 @@
 
     if self.quantization_flag:
       self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
+      #input_.data = self.quantizer(input_.data, True)
 
 
     self.mem = self.mem+input_
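Note that this file's diff redefines QuantizeSTE locally instead of reusing the copy added to modnef_torch_neuron.py above. Assuming no circular-import constraint inside modnef.modnef_torch, a single shared definition could be imported instead:

    # shiftlif.py: reuse the autograd function defined in modnef_torch_neuron.py
    from modnef.modnef_torch.modnef_neurons.modnef_torch_neuron import QuantizeSTE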
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 0f25c44624f8e2766923149e70ecf45060e712dd..f1d7497a8893891ea15c8d2446f4a2b512848b1c 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -161,5 +161,4 @@ class DynamicScaleFactorQuantizer(Quantizer):
     -------
     Tensor
     """
-
     return data*self.scale_factor
\ No newline at end of file
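The method closing this diff multiplies by the stored scale factor, i.e. it maps quantized values back to the float range. Its counterpart, the quantize step, is not shown in this patch, so the sketch below is an assumption about how a dynamic scale-factor round trip typically works rather than DynamicScaleFactorQuantizer's actual code; quantize_sketch and unscale_sketch are illustrative names.

    import torch

    def quantize_sketch(data: torch.Tensor, scale: float, bitwidth: int) -> torch.Tensor:
        # Assumed forward direction: round onto the signed integer grid
        # and clamp to the representable range for the given bit width.
        qmin, qmax = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
        return torch.clamp(torch.round(data / scale), qmin, qmax)

    def unscale_sketch(data: torch.Tensor, scale: float) -> torch.Tensor:
        # Mirrors the method shown above: map back to the float range.
        return data * scale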