diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 55298a14af1dca60bda5e2c4aab3e00a97d6131f..10ef31c19cd5fb77642faee04403d6251480048a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -18,6 +18,10 @@ import torch.nn.functional as F
 import brevitas.nn as qnn
 import brevitas.quant as bq
 
+from brevitas.core.quant import QuantType
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+
 class QuantizeSTE(torch.autograd.Function):
     @staticmethod
     def forward(ctx, weights, quantizer):
@@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
+    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
+    #                           weight_quant_type=QuantType.INT,
+    #                           weight_bit_width=8,
+    #                           weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
+    #                           weight_scaling_impl_type=ScalingImplType.CONST,
+    #                           weight_scaling_const=1.0
+    # )
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
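Note: the commented-out qnn.QuantLinear call above preserves the Brevitas configuration that was tried (8-bit integer weights, power-of-two scaling held at a constant scale of 1.0) while the committed code falls back to a plain nn.Linear. A minimal sketch of that configuration in isolation, reusing exactly the keyword arguments from the comment (they follow the older Brevitas keyword-driven API, so they may need adapting to the installed Brevitas release; the layer sizes are placeholders):

    import brevitas.nn as qnn
    from brevitas.core.quant import QuantType
    from brevitas.core.restrict_val import RestrictValueType
    from brevitas.core.scaling import ScalingImplType

    # 8-bit integer weight quantization with a constant, power-of-two scale of 1.0,
    # so quantized weights sit on an integer grid suited to a shift-based datapath.
    fc = qnn.QuantLinear(
        in_features=64, out_features=10, bias=False,  # placeholder sizes
        weight_quant_type=QuantType.INT,
        weight_bit_width=8,
        weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
        weight_scaling_impl_type=ScalingImplType.CONST,
        weight_scaling_const=1.0,
    )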
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index b0409a43311971e674f3a840f9ef58a921a607b1..429d24a8392bcb180d7ccc47b624ff2a06e4575a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function):
         
         q_weights = quantizer(weights, True)
 
-        #ctx.save_for_backward(quantizer.scale_factor)  # On sauvegarde le scale pour backward
+        ctx.save_for_backward(quantizer.scale_factor)  # save the quantizer's scale factor for the backward pass
         return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        # STE : on passe directement le gradient au poids float
-        #scale_factor, = ctx.saved_tensors
-        return grad_output, None
+        # STE: pass the gradient back to the float weights, rescaled by the quantizer's scale factor
+        scale_factor, = ctx.saved_tensors
+        return grad_output * scale_factor, None
     
 
 class ShiftLIF(ModNEFNeuron):
@@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron):
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
-      #input_.data = self.quantizer(input_.data, True)
+      input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
 
 
     self.mem = self.mem+input_
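Note: the QuantizeSTE change above stops discarding the quantizer's scale factor: the forward pass saves it and the backward pass rescales the straight-through gradient by it, and the same STE is now applied to both the membrane potential and the synaptic input in ShiftLIF's forward path. A self-contained sketch of the pattern with a stand-in quantizer (DummyShiftQuantizer is hypothetical; only the quantizer(x, True) call signature and the scale_factor attribute mirror the modnef quantizer):

    import torch

    class DummyShiftQuantizer:
        # Stand-in for the modnef quantizer: snaps values to a fixed grid
        # and exposes the scale as a tensor so it can be saved for backward.
        def __init__(self, scale_factor=0.125):
            self.scale_factor = torch.tensor(scale_factor)

        def __call__(self, x, unscale=True):
            return torch.round(x / self.scale_factor) * self.scale_factor

    class ScaledSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, weights, quantizer):
            q_weights = quantizer(weights, True)            # quantize in the forward pass
            ctx.save_for_backward(quantizer.scale_factor)   # keep the scale for backward
            return q_weights

        @staticmethod
        def backward(ctx, grad_output):
            scale_factor, = ctx.saved_tensors
            # Straight-through estimator: the gradient reaches the float weights,
            # rescaled by the quantization scale; the quantizer gets no gradient.
            return grad_output * scale_factor, None

    w = torch.randn(4, requires_grad=True)
    ScaledSTE.apply(w, DummyShiftQuantizer()).sum().backward()
    print(w.grad)   # every entry equals the scale factor, 0.125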
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391