From d08d6933b098c70afc903223f4f36e79c682073b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 4 Apr 2025 12:42:29 +0200
Subject: [PATCH] add clamp to arch builder

---
 ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd  |  4 +---
 .../modnef/arch_builder/modules/BLIF/blif.py      |  4 ++--
 .../arch_builder/modules/BLIF/blif_debugger.py    |  4 ++--
 .../modnef/arch_builder/modules/BLIF/rblif.py     |  8 ++++----
 .../modnef/arch_builder/modules/SLIF/rslif.py     |  8 ++++----
 .../modnef/arch_builder/modules/SLIF/slif.py      |  4 ++--
 .../arch_builder/modules/SLIF/slif_debugger.py    |  4 ++--
 .../arch_builder/modules/ShiftLIF/rshiftlif.py    |  8 ++++----
 .../arch_builder/modules/ShiftLIF/shiftlif.py     |  4 ++--
 modneflib/modnef/modnef_torch/model.py            |  3 ++-
 .../modnef_neurons/blif_model/blif.py             | 15 +++++++--------
 .../modnef_neurons/blif_model/rblif.py            | 12 ++++++------
 .../modnef_neurons/modnef_torch_neuron.py         |  2 +-
 .../modnef_neurons/slif_model/rslif.py            |  4 +++-
 .../modnef_neurons/slif_model/slif.py             |  3 ++-
 .../modnef_neurons/srlif_model/rshiftlif.py       | 12 ++++++------
 .../modnef_neurons/srlif_model/shiftlif.py        | 12 ++++++------
 modneflib/modnef/quantizer/quantizer.py           |  7 ++++++-
 modneflib/modnef/quantizer/ste_quantizer.py       |  4 ++--
 19 files changed, 64 insertions(+), 58 deletions(-)

diff --git a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
index 468271e..83722e0 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
@@ -121,9 +121,7 @@ begin
             case state is 
               when multiplication =>
                 V_mult := std_logic_vector(signed(V) * signed(beta));
-                V_mult := std_logic_vector(shift_right(signed(V_mult), fixed_point));
-                V_buff := V_mult(variable_size-1 downto 0);
-                --V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
+                V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
 
                 if signed(V_buff) >= signed(v_threshold) then
                   spike <= '1';
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
index 9dc9af6..0caecf1 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
@@ -199,7 +199,7 @@ class BLif(ModNEFArchMod):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
 
-          w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw)
+          w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), bw)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -210,7 +210,7 @@ class BLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
index f6f7ca9..2940019 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
@@ -213,7 +213,7 @@ class BLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -224,7 +224,7 @@ class BLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
index 81f87c6..d427c78 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
@@ -205,7 +205,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -213,7 +213,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
index 1c15613..8ebfced 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
@@ -214,14 +214,14 @@ class RSLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     else:
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     mem_file.close()
@@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     else:
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     mem_file.close()
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif.py b/modneflib/modnef/arch_builder/modules/SLIF/slif.py
index ddf5c96..991c16e 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/slif.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/slif.py
@@ -201,7 +201,7 @@ class SLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
       self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size)
@@ -213,7 +213,7 @@ class SLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
 
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
index 7fb587c..dfca469 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
@@ -215,7 +215,7 @@ class SLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
       self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size)
@@ -227,7 +227,7 @@ class SLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
 
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
index f3077ae..bc4cb58 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
@@ -206,7 +206,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -214,7 +214,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
index 35bb734..698307e 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
@@ -196,7 +196,7 @@ class ShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -206,7 +206,7 @@ class ShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index f81295d..5e26bc2 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -131,7 +131,7 @@ class ModNEFModel(nn.Module):
           m.init_quantizer()
         m.quantize_weight()
   
-  def quantize(self, force_init=False):
+  def quantize(self, force_init=False, clamp=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
@@ -144,6 +144,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.quantize(force_init=force_init)
+        m.clamp(False)
 
   def clamp(self, force_init=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 08f17e7..0d4ae6e 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -230,19 +230,18 @@ class BLIF(ModNEFNeuron):
     
     self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem + forward_current - self.reset*self.threshold
+    self.mem = self.mem + forward_current
+    
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem*self.beta
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem*self.beta-self.reset*self.mem
-    else:
-      self.mem = self.mem*self.beta
+    self.mem = self.mem*self.beta
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 8b49564..f52eed0 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -247,16 +247,16 @@ class RBLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current + rec_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem*self.beta-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem*self.beta-self.reset*self.mem
-    else:
-      self.mem = self.mem*self.beta
+    self.mem = self.mem*self.beta
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 9a48b0c..208da8d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer, True)
      
   
   def quantize_hp(self):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index c1915a7..38364e2 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -255,6 +255,8 @@ class RSLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current + rec_current 
 
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
@@ -262,7 +264,7 @@ class RSLIF(ModNEFNeuron):
     # update neuron
     self.mem = self.mem - self.v_leak
     min_reset = (self.mem<self.v_min).to(torch.float32)
-    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
+    self.mem = self.mem-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 43e2e1c..6a6c7d3 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -245,6 +245,7 @@ class SLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current
 
+    self.mem = self.mem - self.reset*(self.mem-self.v_rest)
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
@@ -253,7 +254,7 @@ class SLIF(ModNEFNeuron):
     # update neuron
     self.mem = self.mem - self.v_leak
     min_reset = (self.mem<self.v_min).to(torch.float32)
-    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
+    self.mem = self.mem-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 0ea9983..65d8ba5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -261,16 +261,16 @@ class RShiftLIF(ModNEFNeuron):
 
     self.mem = self.mem+forward_current+rec_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem
-    else:
-      self.mem = self.mem-self.__shift(self.mem)
+    self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 5291733..70f9d96 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -252,16 +252,16 @@ class ShiftLIF(ModNEFNeuron):
 
     self.mem = self.mem+forward_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem
-    else:
-      self.mem = self.mem-self.__shift(self.mem)
+    self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index c56f506..e264950 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -91,7 +91,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def __call__(self, data, unscale=False):
+  def __call__(self, data, unscale=False, clamp=False):
     """
     Call quantization function
 
@@ -114,6 +114,11 @@ class Quantizer():
 
     qdata = self._quant(tdata)
 
+    if clamp:
+      born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1))
+      born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1)
+      qdata = torch.clamp(qdata, min=born_min, max=born_max)
+
     if unscale:
       qdata = self._unquant(qdata).to(torch.float32)
     
diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py
index 565cbd8..c5ffd32 100644
--- a/modneflib/modnef/quantizer/ste_quantizer.py
+++ b/modneflib/modnef/quantizer/ste_quantizer.py
@@ -25,7 +25,7 @@ class QuantizeSTE(torch.autograd.Function):
   """
   
   @staticmethod
-  def forward(ctx, data, quantizer):
+  def forward(ctx, data, quantizer, clamp=False):
     """
     Apply quantization method to data
 
@@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function):
       quantization method applied to data
     """
     
-    q_data = quantizer(data, True)
+    q_data = quantizer(data, True, clamp)
 
 
     ctx.scale = quantizer.scale_factor
-- 
GitLab