From 025865dcd6173f85022ec30d4ffe8818c88ded91 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Thu, 3 Apr 2025 11:01:49 +0200
Subject: [PATCH] Add quantization-aware training (QAT)

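This patch adds quantization-aware training (QAT) support. Synaptic
weights stay in full precision and are quantized on the fly in the
forward pass by the new QuantLinear layer; the new QuantizeSTE autograd
function implements a straight-through estimator (STE), so gradients
flow back to the full-precision weights unchanged. Membrane potentials
and neuron hyper-parameters go through the same function, and the old
`unscale` argument of quantize_hp()/quantize_weight() is dropped.

A minimal QAT loop, as a sketch (`model` stands for any ModNEFModel with
the neurons' quantization_flag enabled; loader, criterion and optimizer
are ordinary PyTorch placeholders, and the model forward call itself is
not part of this patch):

    model.quantize_hp(force_init=True)   # quantize neuron hyper-parameters once

    for spikes, targets in loader:
        optimizer.zero_grad()
        out = model(spikes)              # QuantLinear quantizes weights on the fly
        loss = criterion(out, targets)
        loss.backward()                  # STE: gradients reach the float weights
        optimizer.step()
        model.clamp()                    # keep float weights in quantizer bounds
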
---
 modneflib/modnef/modnef_torch/__init__.py     |  3 +-
 modneflib/modnef/modnef_torch/model.py        | 30 ++++++--
 .../modnef_neurons/blif_model/blif.py         | 41 +++++------
 .../modnef_neurons/blif_model/rblif.py        | 72 ++++++-------------
 .../modnef_neurons/modnef_torch_neuron.py     | 56 +++------------
 .../modnef_neurons/slif_model/rslif.py        | 68 ++++++++----------
 .../modnef_neurons/slif_model/slif.py         | 56 ++++++---------
 .../modnef_neurons/srlif_model/rshiftlif.py   | 50 ++++++-------
 .../modnef_neurons/srlif_model/shiftlif.py    | 57 ++++-----------
 modneflib/modnef/modnef_torch/quantLinear.py  | 59 +++++++++++++++
 modneflib/modnef/quantizer/__init__.py        |  3 +-
 modneflib/modnef/quantizer/ste_quantizer.py   | 62 ++++++++++++++++
 12 files changed, 289 insertions(+), 270 deletions(-)
 create mode 100644 modneflib/modnef/quantizer/ste_quantizer.py
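
QuantizeSTE's contract can be checked in isolation. The stub below is a
stand-in that only mimics the `__call__(data, unscale)` / `scale_factor`
interface this patch assumes; real quantizers live in modnef.quantizer:

    import torch
    from modnef.quantizer import QuantizeSTE

    class StubQuantizer:
        scale_factor = torch.tensor(1.0)
        def __call__(self, data, unscale=True):
            return torch.round(data * 4) / 4  # toy 2-fractional-bit grid

    w = torch.randn(3, requires_grad=True)
    q_w = QuantizeSTE.apply(w, StubQuantizer())
    q_w.sum().backward()
    print(w.grad)  # all ones: the STE backward is the identity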

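The new QuantLinear is a bias-free nn.Linear whose forward() takes an
optional quantizer; a short usage sketch (the commented call assumes
some Quantizer instance `q`, whose construction is outside this patch):

    import torch
    from modnef.modnef_torch import QuantLinear

    fc = QuantLinear(in_features=4, out_features=2)  # always bias=False
    x = torch.rand(8, 4)

    y = fc(x)       # quantizer=None: plain full-precision linear
    # y = fc(x, q)  # with quantizer q: weights pass through QuantizeSTE first
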
diff --git a/modneflib/modnef/modnef_torch/__init__.py b/modneflib/modnef/modnef_torch/__init__.py
index e57895f..2053dab 100644
--- a/modneflib/modnef/modnef_torch/__init__.py
+++ b/modneflib/modnef/modnef_torch/__init__.py
@@ -9,4 +9,5 @@ Descriptions: ModNEF torch lib definition
 
 from .modnef_neurons import *
 from .model_builder import ModNEFModelBuilder
-from .model import ModNEFModel
\ No newline at end of file
+from .model import ModNEFModel
+from .quantLinear import QuantLinear
\ No newline at end of file
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 2c0d687..f81295d 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -1,17 +1,15 @@
 """
 File name: model
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
-Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron
+Dependencies: torch, snntorch, modnef_torch_neuron, modnef_driver
 Descriptions: ModNEF SNN Model
 """
 
-import modnef.modnef_torch.modnef_neurons as mn
 import torch.nn as nn
 import torch
-from modnef.arch_builder import *
 from modnef.modnef_driver import load_driver_from_yaml
 from modnef.modnef_torch.modnef_neurons import ModNEFNeuron
 
@@ -101,6 +99,22 @@ class ModNEFModel(nn.Module):
       if isinstance(m, ModNEFNeuron):
         m.init_quantizer()
 
+  def quantize_hp(self, force_init=False):
+    """
+    Quantize neuron hyper-parameters
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        if force_init:
+          m.init_quantizer()
+        m.quantize_hp()
+
   def quantize_weight(self, force_init=False):
     """
     Quantize synaptic weight
@@ -132,6 +146,14 @@ class ModNEFModel(nn.Module):
         m.quantize(force_init=force_init)
 
   def clamp(self, force_init=False):
+    """
+    Clamp synaptic weights to the quantizer bounds
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 2ba0908..c435ed1 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -1,7 +1,7 @@
 """
 File name: blif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, math, snntorch, modnef.archbuilder, modnef_torch_neuron, modnef.quantizer
@@ -9,10 +9,8 @@ Descriptions: ModNEF torch BLIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
 from math import log, ceil
-from snntorch import Leaky
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
@@ -223,28 +221,30 @@ class BLIF(ModNEFNeuron):
     if not spk==None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
-
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      self.mem.data = self.quantizer(self.mem.data, True)
+    forward_current = self.fc(input_, quant)
 
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
+
+    self.reset = self.mem_reset(self.mem)
+
+    self.mem = self.mem + forward_current
 
-    self.reset = self.mem_reset(self.mem)
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
-      self.mem = (self.mem+input_)*self.beta-self.reset*self.threshold
+      self.mem = self.mem*self.beta-self.reset*self.threshold
     elif self.reset_mechanism == "zero":
-      self.mem = (self.mem+input_)*self.beta-self.reset*self.mem
+      self.mem = self.mem*self.beta-self.reset*self.mem
     else:
       self.mem = self.mem*self.beta
 
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -295,19 +293,14 @@ class BLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume the quantizer is already initialized
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
+    self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 586048c..8b49564 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -1,7 +1,7 @@
 """
 File name: rblif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,29 +9,14 @@ Descriptions: ModNEF torch recurrent BLIF neuron model
 Based on snntorch.RLeaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import Leaky
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
+from modnef.modnef_torch.quantLinear import QuantLinear
 
-import torch.autograd as autograd
-
-class QuantizeMembrane(autograd.Function):
-    @staticmethod
-    def forward(ctx, U, quantizer):
-        max_val = U.abs().max().detach()  # Détachement pour éviter de bloquer le gradient
-        U_quant = quantizer(U, True)
-        ctx.save_for_backward(U, quantizer.scale_factor)
-        return U_quant
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input, factor = ctx.saved_tensors
-        return grad_output, None
 
 class RBLIF(ModNEFNeuron):
   """
@@ -141,7 +126,7 @@ class RBLIF(ModNEFNeuron):
     
     self.register_buffer("beta", torch.tensor(beta))
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self._init_mem()
 
@@ -246,41 +231,35 @@ class RBLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    forward_current = self.fc(input_, quant)
 
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_, device=self.spk.device)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
 
-    self.reset = self.mem_reset(self.mem)
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    rec = self.reccurent(self.spk)
+    self.reset = self.mem_reset(self.mem)
+    
+    rec_current = self.reccurent(self.spk, quant)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec.data = self.quantizer(rec.data, True)
+    self.mem = self.mem + forward_current + rec_current
 
-    # if self.quantization_flag:
-    #   self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
-    #   input_ = QuantizeMembrane.apply(input_, self.quantizer)
-    #   rec = QuantizeMembrane.apply(rec, self.quantizer)
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
-      self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
+      self.mem = self.mem*self.beta-self.reset*self.threshold
     elif self.reset_mechanism == "zero":
-      self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.mem
+      self.mem = self.mem*self.beta-self.reset*self.mem
     else:
       self.mem = self.mem*self.beta
 
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
-
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -331,19 +310,14 @@ class RBLIF(ModNEFNeuron):
     )
     return module
   
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume the quantizer is already initialized
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
+    self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 10ef31c..49b31e5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -1,7 +1,7 @@
 """
 File name: modnef_torch_neuron
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.0.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch
@@ -12,31 +12,7 @@ import torch
 import torch.nn as nn
 from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
-from snntorch.surrogate import fast_sigmoid
-
-import torch.nn.functional as F
-import brevitas.nn as qnn
-import brevitas.quant as bq
-
-from brevitas.core.quant import QuantType
-from brevitas.core.restrict_val import RestrictValueType
-from brevitas.core.scaling import ScalingImplType
-
-class QuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, weights, quantizer):
-        
-        q_weights = quantizer(weights, True)
-
-        #ctx.save_for_backward(quantizer.scale_factor)  # On sauvegarde le scale pour backward
-        return q_weights
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # STE : on passe directement le gradient au poids float
-        #scale_factor, = ctx.saved_tensors
-        return grad_output, None
-
+from ..quantLinear import QuantLinear
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -82,15 +58,8 @@ class ModNEFNeuron(SpikingNeuron):
       spike_grad=spike_grad,
       reset_mechanism=reset_mechanism
     )
-    
-    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
-    #                           weight_quant_type=QuantType.INT, 
-    #                                  weight_bit_width=8,
-    #                                  weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
-    #                                  weight_scaling_impl_type=ScalingImplType.CONST,
-    #                                  weight_scaling_const=1.0
-    # )
-    self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+
+    self.fc = QuantLinear(in_features=in_features, out_features=out_features)
 
 
     self.hardware_estimation_flag = False
@@ -126,7 +95,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self, unscale : bool = True, ema = False):
+  def quantize_weight(self):
     """
     synaptic weight quantization
 
@@ -136,21 +105,13 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      print(p)
       p.data = QuantizeSTE.apply(p.data, self.quantizer)
-
-    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
      
   
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you've already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
     
     raise NotImplementedError()
@@ -174,6 +135,11 @@ class ModNEFNeuron(SpikingNeuron):
   def clamp(self, force_init=False):
     """
     Clamp synaptic weight
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
     """
 
     if force_init:
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 7d9c21f..c1915a7 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -1,7 +1,7 @@
 """
 File name: rslif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,13 @@ Descriptions: ModNEF torch recurrent SLIF neuron model
 Based on snntorch.RLeaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
-from modnef.quantizer import MinMaxQuantizer
+from modnef.quantizer import MinMaxQuantizer, QuantizeSTE
+from modnef.modnef_torch.quantLinear import QuantLinear
 
 class RSLIF(ModNEFNeuron):
   """
@@ -131,7 +130,7 @@ class RSLIF(ModNEFNeuron):
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self._init_mem()
 
@@ -238,45 +237,39 @@ class RSLIF(ModNEFNeuron):
 
     if not spk == None:
       self.spk = spk
-    
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.ones_like(input_)*self.v_rest
-    
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_)
+    forward_current = self.fc(input_, quant)
 
-    self.reset = self.mem_reset(self.mem)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest
 
-    rec_input = self.reccurent(self.spk)
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      rec_input.data = self.quantizer(rec_input.data, True)
-      self.mem = self.quantizer(self.mem.data, True)
 
-    self.mem = self.mem + input_ + rec_input
+    self.reset = self.mem_reset(self.mem)
+
+    rec_current = self.reccurent(self.spk, quant)
+
+    self.mem = self.mem + forward_current + rec_current 
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
+    # update neuron
     self.mem = self.mem - self.v_leak
+    min_reset = (self.mem<self.v_min).to(torch.float32)
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
-    self.spk = self.fire(self.mem)
+    self.spk = self.fire(self.mem)
 
-    do_spike_reset = (self.spk/self.graded_spikes_factor - self.reset)
-    do_min_reset = (self.mem<self.v_min).to(torch.float32)
-
-    self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest)
-    self.mem = self.mem - do_min_reset*(self.mem-self.v_rest)
-
-    return self.spk, self.mem
+    return self.spk, self.mem
   
   def get_builder_module(self, module_name : str, output_path : str = "."):
     """
@@ -324,21 +317,16 @@ class RSLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you've already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
-    self.v_min.data = self.quantizer(self.v_min.data, unscale)
-    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
-    self.threshold.data = self.quantizer(self.threshold, unscale)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 089519b..43e2e1c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -1,7 +1,7 @@
 """
 File name: slif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,12 @@ Descriptions: ModNEF torch SLIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
-from modnef.quantizer import MinMaxQuantizer
+from modnef.quantizer import MinMaxQuantizer, QuantizeSTE
 
 class SLIF(ModNEFNeuron):
   """
@@ -235,36 +233,33 @@ class SLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.ones_like(input_)*self.v_rest
-
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      self.mem.data = self.quantizer(self.mem.data, True)
+    forward_current = self.fc(input_, quant)
 
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest
+    
+    
     self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem + input_
+    self.mem = self.mem + forward_current
+
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
-      
-    self.mem = self.mem-self.v_leak
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
+
+    # update neuron
+    self.mem = self.mem - self.v_leak
+    min_reset = (self.mem<self.v_min).to(torch.float32)
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     spk = self.fire(self.mem)
 
-    do_spike_reset = (spk/self.graded_spikes_factor - self.reset)
-    do_min_reset = (self.mem<self.v_min).to(torch.float32)
-
-    self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest)
-    self.mem = self.mem - do_min_reset*(self.mem-self.v_rest)
-
     return spk, self.mem
   
   def get_builder_module(self, module_name : str, output_path : str = "."):
@@ -314,21 +309,16 @@ class SLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you've already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
-    self.v_min.data = self.quantizer(self.v_min.data, unscale)
-    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
-    self.threshold.data = self.quantizer(self.threshold, unscale)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index f3dda9a..0ea9983 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: rsrlif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,13 @@ Descriptions: ModNEF torch recurrent Shift LIF neuron model
 Based on snntorch.RLeaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
+from modnef.modnef_torch.quantLinear import QuantLinear
 from math import log, ceil
-from modnef.quantizer import DynamicScaleFactorQuantizer
+from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE
 
 
 class RShiftLIF(ModNEFNeuron):
@@ -131,7 +130,7 @@ class RShiftLIF(ModNEFNeuron):
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self.register_buffer("beta", torch.tensor(beta))
 
@@ -246,25 +245,25 @@ class RShiftLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    forward_current = self.fc(input_, quant)
 
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_, device=self.spk.device)
-
-    self.reset = self.mem_reset(self.mem)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
+    
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    rec_input = self.reccurent(self.spk)
+    rec_current = self.reccurent(self.spk, quant)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec_input.data = self.quantizer(rec_input.data, True)
+    self.reset = self.mem_reset(self.mem)
 
+    self.mem = self.mem+forward_current+rec_current
 
-    self.mem = self.mem+input_+rec_input
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
@@ -274,11 +273,7 @@ class RShiftLIF(ModNEFNeuron):
       self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -330,18 +325,13 @@ class RShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume the quantizer is already initialized
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 429d24a..5291733 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: srlif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,32 +9,15 @@ Descriptions: ModNEF torch Shift LIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
 from snntorch.surrogate import fast_sigmoid
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
-from modnef.quantizer import DynamicScaleFactorQuantizer
-from snntorch import LIF
+from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE
 
     
-class QuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, weights, quantizer):
-        
-        q_weights = quantizer(weights, True)
-
-        ctx.save_for_backward(quantizer.scale_factor)  # On sauvegarde le scale pour backward
-        return q_weights
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # STE : on passe directement le gradient au poids float
-        scale_factor, = ctx.saved_tensors
-        return grad_output*scale_factor, None
-    
 
 class ShiftLIF(ModNEFNeuron):
   """
@@ -258,23 +241,20 @@ class ShiftLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
-
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    quant = self.quantizer if self.quantization_flag else None
 
-    self.reset = self.mem_reset(self.mem)
+    forward_current = self.fc(input_, quant)
 
-    if self.quantization_flag:
-      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
-      input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
 
+    self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem+input_
+    self.mem = self.mem+forward_current
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
@@ -283,12 +263,8 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
-
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -338,18 +314,13 @@ class ShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume the quantizer is already initialized
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
index e69de29..57f67b4 100644
--- a/modneflib/modnef/modnef_torch/quantLinear.py
+++ b/modneflib/modnef/modnef_torch/quantLinear.py
@@ -0,0 +1,59 @@
+"""
+File name: quantLinear
+Author: Aurélie Saulquin  
+Version: 0.1.1
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: torch, modnef.quantizer
+Descriptions: Quantized Linear torch layer
+"""
+
+import torch.nn as nn
+from modnef.quantizer import QuantizeSTE
+
+class QuantLinear(nn.Linear):
+  """
+  Quantized Linear torch layer
+  Extended from torch.nn.Linear
+
+  Methods
+  -------
+  forward(x, quantizer=None)
+    Apply the linear forward pass; if quantizer is not None, quantized weights are used
+  """
+
+  def __init__(self, in_features : int, out_features : int):
+    """
+    Initialize class
+
+    Parameters
+    ----------
+    in_features : int
+      input features of layer
+    out_features : int
+      output features of layer
+    """
+
+    super().__init__(in_features=in_features, out_features=out_features, bias=False)
+
+  def forward(self, x, quantizer=None):
+    """
+    Apply the linear forward pass; if quantizer is not None, quantized weights are used
+
+    Parameters
+    ----------
+    x : Tensor
+      input spikes
+    quantizer = None : Quantizer
+      quantization method.
+      If None, full-precision weights are used for the linear pass
+    """
+
+    if quantizer is not None:
+      w = QuantizeSTE.apply(self.weight, quantizer)
+      w.data = quantizer.clamp(w)
+    else:
+      w = self.weight
+
+
+    return nn.functional.linear(x, w)
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/__init__.py b/modneflib/modnef/quantizer/__init__.py
index 2455aa9..2ef9cb9 100644
--- a/modneflib/modnef/quantizer/__init__.py
+++ b/modneflib/modnef/quantizer/__init__.py
@@ -10,4 +10,5 @@ Descriptions: ModNEF quantizer method
 from .quantizer import Quantizer
 from .fixed_point_quantizer import FixedPointQuantizer
 from .min_max_quantizer import MinMaxQuantizer
-from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
\ No newline at end of file
+from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
+from .ste_quantizer import QuantizeSTE
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py
new file mode 100644
index 0000000..565cbd8
--- /dev/null
+++ b/modneflib/modnef/quantizer/ste_quantizer.py
@@ -0,0 +1,62 @@
+"""
+File name: ste_quantizer
+Author: Aurélie Saulquin  
+Version: 0.1.0
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: torch
+Descriptions: Straight-Through Estimator quantization method
+"""
+
+import torch
+
+class QuantizeSTE(torch.autograd.Function):
+  """
+  Straight-Through Estimator quantization method
+
+  Methods
+  -------
+  @staticmethod
+  forward(ctx, data, quantizer)
+    Apply quantization method to data
+  @staticmethod
+  backward(ctx, grad_output)
+    Returns backward gradient
+  """
+  
+  @staticmethod
+  def forward(ctx, data, quantizer):
+    """
+    Apply quantization method to data
+
+    Parameters
+    ----------
+    ctx : torch.autograd.function.BackwardCFunction
+      Autograd context used to store variables for the backward pass
+    data : Tensor
+      data to quantize
+    quantizer : Quantizer
+      quantization method applied to data
+    """
+    
+    q_data = quantizer(data, True)
+
+
+    ctx.scale = quantizer.scale_factor
+
+    return q_data
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    """
+    Return the backward gradient without modification
+
+    Parameters
+    ----------
+    ctx : torch.autograd.function.BackwardCFunction
+      Autograd context used to store variables for the backward pass
+    grad_output : Tensor
+      gradient
+    """
+    
+    return grad_output, None
\ No newline at end of file
-- 
GitLab