From 6f0352f59a38f5505324d41ac948320dee706bb6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 14 Mar 2025 11:39:28 +0100
Subject: [PATCH 01/23] add dev branch

---
 modneflib/modnef/quantizer/quantizer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 235981d..2aca58a 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -138,14 +138,12 @@ class Quantizer():
 
   def clamp(self, data):
     """
-    Call quantization function
+    Call clamp function
 
     Parameters
     ----------
     data : int | float | list | numpy.array | Tensor
       data to quantize
-    unscale = False : bool
-      If true, apply quantization and then, unquantize data to simulate quantization
 
     Returns
     -------
-- 
GitLab
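
Note on patch 01: `Quantizer.clamp(data)` takes no `unscale` argument, so the stale parameter entry is dropped from its docstring; the method only bounds data to the range representable with the configured `bitwidth` and `signed` flag, it does not quantize. A minimal standalone sketch of that bound computation, mirroring the `born_min`/`born_max` expressions visible in quantizer.py later in this series (patch 02 additionally maps these bounds back through `_unquant`); the helper name below is illustrative, not part of modneflib:

import torch

def clamp_to_representable_range(data, bitwidth=8, signed=True):
  # Same bound expressions as Quantizer.clamp: the integer range that fits
  # on `bitwidth` bits, signed or unsigned. Values are bounded, not quantized.
  born_min = -int(signed) * 2 ** (bitwidth - 1)       # -128 for signed 8 bit, 0 if unsigned
  born_max = 2 ** (bitwidth - int(signed)) - 1        # 127 for signed 8 bit, 255 if unsigned
  return torch.clamp(torch.as_tensor(data), min=born_min, max=born_max)

print(clamp_to_representable_range([-300.0, 0.5, 300.0]))  # tensor([-128.0000, 0.5000, 127.0000])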


From 8f8b87a5eb4ba802c69be56fab68f9049dfaf209 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 17 Mar 2025 11:48:18 +0100
Subject: [PATCH 02/23] remove unused states in VHDL code

---
 ModNEF_Sources/modules/bias.vhd               |  0
 .../modules/neurons/BLIF/blif_parallel.vhd    | 14 +++---
 .../modules/neurons/BLIF/blif_sequential.vhd  | 14 +-----
 .../modules/neurons/BLIF/rblif_sequential.vhd | 14 +-----
 .../modules/neurons/SLIF/rslif_sequential.vhd | 14 +-----
 .../modules/neurons/SLIF/slif_parallel.vhd    | 14 +++---
 .../modules/neurons/SLIF/slif_sequential.vhd  | 14 +-----
 .../neurons/ShiftLif/rshiftlif_sequential.vhd | 14 +-----
 .../neurons/ShiftLif/shiftlif_parallel.vhd    | 14 +++---
 .../neurons/ShiftLif/shiftlif_sequential.vhd  | 14 +-----
 .../modnef_neurons/blif_model/blif.py         |  3 ++
 .../modnef_neurons/blif_model/rblif.py        |  3 ++
 .../modnef_neurons/slif_model/rslif.py        |  3 ++
 .../modnef_neurons/slif_model/slif.py         |  3 ++
 .../modnef_neurons/srlif_model/rshiftlif.py   |  5 ++-
 .../modnef_neurons/srlif_model/shiftlif.py    |  3 ++
 .../quantizer/dynamic_scale_quantizer.py      | 29 ++++++++----
 .../modnef/quantizer/fixed_point_quantizer.py | 27 +++++++-----
 .../modnef/quantizer/min_max_quantizer.py     | 25 ++++++++---
 modneflib/modnef/quantizer/quantizer.py       | 44 ++++++++++++++-----
 20 files changed, 131 insertions(+), 140 deletions(-)
 delete mode 100644 ModNEF_Sources/modules/bias.vhd

diff --git a/ModNEF_Sources/modules/bias.vhd b/ModNEF_Sources/modules/bias.vhd
deleted file mode 100644
index e69de29..0000000
diff --git a/ModNEF_Sources/modules/neurons/BLIF/blif_parallel.vhd b/ModNEF_Sources/modules/neurons/BLIF/blif_parallel.vhd
index 261ab77..6564a36 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/blif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/blif_parallel.vhd
@@ -130,7 +130,7 @@ architecture Behavioral of BLif_Parallel is
 
   -- type definition
   type reception_state_t    is (idle, request, get_data);
-  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration, arbitration_finish);
+  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration);
 
   -- ram signals
   signal data_read  : std_logic_vector((output_neuron*weight_size)-1 downto 0);
@@ -249,8 +249,9 @@ begin
           when check_arbitration =>
             
             if spikes = no_spike then
-              transmission_state <= arbitration_finish;
               o_emu_busy <= '0';
+              transmission_state <= idle;
+              tr_fsm_en := '0';
             else
               transmission_state <= request;
               arb_spikes <= spikes;
@@ -278,15 +279,12 @@ begin
           when wait_arbitration =>
             start_arb <= '0';
             if arb_busy = '0' then
-              transmission_state <= arbitration_finish;
+              transmission_state <= idle;
+              o_emu_busy <= '0';
+              tr_fsm_en := '0';
             else
               transmission_state <= wait_arbitration;
             end if;  
-              
-          when arbitration_finish =>
-            transmission_state <= idle;
-            o_emu_busy <= '0';
-            tr_fsm_en := '0';
         end case;
       end if;
     end if;
diff --git a/ModNEF_Sources/modules/neurons/BLIF/blif_sequential.vhd b/ModNEF_Sources/modules/neurons/BLIF/blif_sequential.vhd
index 00dd851..6cc760a 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/blif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/blif_sequential.vhd
@@ -88,7 +88,7 @@ architecture Behavioral of BLif_Sequential is
   -- type definition
   type array_t              is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
-  type reception_state_t    is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t    is (idle, request, get_data);
   
   -- ram signals
   signal data_read  : std_logic_vector((output_neuron*weight_size)-1 downto 0);
@@ -151,13 +151,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -169,10 +162,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -256,7 +245,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
             else
               transmission_state <= accept;
               o_req <= '0';
diff --git a/ModNEF_Sources/modules/neurons/BLIF/rblif_sequential.vhd b/ModNEF_Sources/modules/neurons/BLIF/rblif_sequential.vhd
index 11b43b5..745760a 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/rblif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/rblif_sequential.vhd
@@ -88,7 +88,7 @@ architecture Behavioral of RBLif_Sequential is
 
   -- type definition
   type array_t is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
-  type reception_state_t is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t is (idle, request, get_data);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
 
   -- output signals
@@ -161,13 +161,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -179,10 +172,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -305,7 +294,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
               rec_ram_en <= '1';
               rec_current_en <= '1';
             else
diff --git a/ModNEF_Sources/modules/neurons/SLIF/rslif_sequential.vhd b/ModNEF_Sources/modules/neurons/SLIF/rslif_sequential.vhd
index d975f60..cc01a68 100644
--- a/ModNEF_Sources/modules/neurons/SLIF/rslif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/SLIF/rslif_sequential.vhd
@@ -88,7 +88,7 @@ architecture Behavioral of RSLif_Sequential is
   
   -- type definition
   type array_t              is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
-  type reception_state_t    is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t    is (idle, request, get_data);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
 
   -- output signals
@@ -161,13 +161,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -179,10 +172,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -305,7 +294,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
               rec_ram_en <= '1';
               rec_current_en <= '1';
             else
diff --git a/ModNEF_Sources/modules/neurons/SLIF/slif_parallel.vhd b/ModNEF_Sources/modules/neurons/SLIF/slif_parallel.vhd
index 5a50090..8c4b605 100644
--- a/ModNEF_Sources/modules/neurons/SLIF/slif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/SLIF/slif_parallel.vhd
@@ -131,7 +131,7 @@ architecture Behavioral of SLif_Parallel is
 
   -- type definition
   type reception_state_t    is (idle, request, get_data);
-  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration, arbitration_finish);
+  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration);
 
   -- ram signals
   signal data_read  : std_logic_vector((output_neuron*weight_size)-1 downto 0);
@@ -253,8 +253,9 @@ begin
 
           when check_arbitration =>
             if spikes = no_spike then
-              transmission_state <= arbitration_finish;
+              transmission_state <= idle;
               o_emu_busy <= '0';
+              tr_fsm_en := '0';
             else
               transmission_state <= request;
               arb_spikes <= spikes;
@@ -281,15 +282,12 @@ begin
           when wait_arbitration =>
             start_arb <= '0';
             if arb_busy = '0' then
-              transmission_state <= arbitration_finish;
+              transmission_state <= idle;
+              o_emu_busy <= '0';
+              tr_fsm_en := '0';
             else
               transmission_state <= wait_arbitration;
             end if;  
-              
-          when arbitration_finish =>
-            transmission_state <= idle;
-            o_emu_busy <= '0';
-            tr_fsm_en := '0';
         end case;
       end if;
     end if;
diff --git a/ModNEF_Sources/modules/neurons/SLIF/slif_sequential.vhd b/ModNEF_Sources/modules/neurons/SLIF/slif_sequential.vhd
index 391eaa6..362d982 100644
--- a/ModNEF_Sources/modules/neurons/SLIF/slif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/SLIF/slif_sequential.vhd
@@ -89,7 +89,7 @@ architecture Behavioral of SLif_Sequential is
   -- type definition
   type array_t              is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
-  type reception_state_t    is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t    is (idle, request, get_data);
 
   -- ram signals
   signal data_read  : std_logic_vector((output_neuron*weight_size)-1 downto 0);
@@ -149,13 +149,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -167,10 +160,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -259,7 +248,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
             else
               transmission_state <= accept;
               o_req <= '0';
diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_sequential.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_sequential.vhd
index fa3c0cb..9854bca 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_sequential.vhd
@@ -85,7 +85,7 @@ architecture Behavioral of RShiftLif_Sequential is
 
   -- type definition
   type array_t              is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
-  type reception_state_t    is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t    is (idle, request, get_data);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
 
   -- output signals
@@ -158,13 +158,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -176,10 +169,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -302,7 +291,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
               rec_ram_en <= '1';
               rec_current_en <= '1';
             else
diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
index f03d9e8..6c6989f 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
@@ -127,7 +127,7 @@ architecture Behavioral of ShiftLif_Parallel is
   end component;
 
   -- type definition
-  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration, arbitration_finish);
+  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration);
   type reception_state_t    is (idle, request, get_data);
 
   -- ram signals
@@ -249,8 +249,9 @@ begin
 
           when check_arbitration =>
             if spikes = no_spike then
-              transmission_state <= arbitration_finish;
+              transmission_state <= idle;
               o_emu_busy <= '0';
+              tr_fsm_en := '0';
             else
               transmission_state <= request;
               arb_spikes <= spikes;
@@ -277,15 +278,12 @@ begin
           when wait_arbitration =>
             start_arb <= '0';
             if arb_busy = '0' then
-              transmission_state <= arbitration_finish;
+              transmission_state <= idle;
+              o_emu_busy <= '0';
+              tr_fsm_en := '0';
             else
               transmission_state <= wait_arbitration;
             end if;  
-              
-          when arbitration_finish =>
-            transmission_state <= idle;
-            o_emu_busy <= '0';
-            tr_fsm_en := '0';
         end case;
       end if;
     end if;
diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_sequential.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_sequential.vhd
index b86bb30..0533fdc 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_sequential.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_sequential.vhd
@@ -85,7 +85,7 @@ architecture Behavioral of ShiftLif_Sequential is
 
   -- type definition
   type array_t is array(output_neuron-1 downto 0) of std_logic_vector(variable_size-1 downto 0);
-  type reception_state_t is (idle, request, wait_data, get_data, update_current);
+  type reception_state_t is (idle, request, get_data);
   type transmission_state_t is (idle, request, accept, get_voltage, emulate, set_voltage, emulate_finish);
 
   -- ram signals
@@ -146,13 +146,6 @@ begin
               reception_state <= request;
             end if;
 
-          when wait_data =>
-            if i_emu_busy = '1' then
-              reception_state <= get_data;
-            else
-              reception_state <= wait_data;
-            end if;
-
           when get_data =>
             spike_flag <= i_spike_flag;
             if i_emu_busy='0' and spike_flag = '0' then
@@ -164,10 +157,6 @@ begin
               current_en <= '1';
               reception_state <= get_data;
             end if;
-
-          when update_current =>
-            reception_state <= idle;
-            rec_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -252,7 +241,6 @@ begin
           when accept =>
             if i_ack <= '0' then
               transmission_state <= get_voltage;
-              o_emu_busy <= '1';
             else
               transmission_state <= accept;
               o_req <= '0';
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index d223c86..83cc04f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -252,6 +252,9 @@ class BLIF(Leaky, ModNEFNeuron):
     else:
       self.mem = self.mem*self.beta
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index d48960d..bf702cb 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -269,6 +269,9 @@ class RBLIF(Leaky, ModNEFNeuron):
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     self.spk = self.fire(self.mem)
 
     return self.spk, self.mem
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 694971e..2d78aa9 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -280,6 +280,9 @@ class RSLIF(LIF, ModNEFNeuron):
 
     self.mem = self.mem - self.v_leak
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     self.spk = self.fire(self.mem)
 
     do_spike_reset = (self.spk/self.graded_spikes_factor - self.reset)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index fa370b9..9d01858 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -270,6 +270,9 @@ class SLIF(LIF, ModNEFNeuron):
       
     self.mem = self.mem-self.v_leak
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     spk = self.fire(self.mem)
 
     do_spike_reset = (spk/self.graded_spikes_factor - self.reset)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 0193876..cdf4d26 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -281,7 +281,10 @@ class RShiftLIF(LIF, ModNEFNeuron):
     elif self.reset_mechanism == "zero":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem
     else:
-      self.mem = self.mem*self.beta
+      self.mem = self.mem-self.__shift(self.mem)
+
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 9bb6d1d..d4f58d3 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -269,6 +269,9 @@ class ShiftLIF(LIF, ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 724a8f4..2631170 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -110,10 +110,10 @@ class DynamicScaleFactorQuantizer(Quantizer):
 
     self.is_initialize = True
 
-    if torch.is_tensor(weight):
+    if not torch.is_tensor(weight):
       weight = torch.Tensor(weight)
 
-    if torch.is_tensor(rec_weight):
+    if not torch.is_tensor(rec_weight):
       rec_weight = torch.Tensor(weight)
 
     if self.signed==None:
@@ -130,7 +130,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
     #self.scale_factor = torch.max(torch.abs(weight).max(), torch.abs(weight).max())/2**(self.bitwidth-1)
 
 
-  def _quant(self, data, unscale) -> torch.Tensor:
+  def _quant(self, data) -> torch.Tensor:
     """
     Apply quantization
 
@@ -138,8 +138,6 @@ class DynamicScaleFactorQuantizer(Quantizer):
     ----------
     data : Tensor
       data to quantize
-    unscale = False : bool
-      If true, apply quantization and then, unquantize data to simulate quantization
 
     Returns
     -------
@@ -147,8 +145,21 @@ class DynamicScaleFactorQuantizer(Quantizer):
     """
 
     scaled = torch.round(data/self.scale_factor).to(self.dtype)
+    
+    return scaled
+    
+  def _unquant(self, data) -> torch.Tensor:
+    """
+    Unquantize data
 
-    if unscale:
-      return scaled*self.scale_factor
-    else:
-      return scaled
\ No newline at end of file
+    Parameters
+    ----------
+    data : Tensor
+      data to unquantize
+
+    Returns
+    -------
+    Tensor
+    """
+
+    return data*self.scale_factor
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index de6c789..106a242 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -147,7 +147,7 @@ class FixedPointQuantizer(Quantizer):
         self.scale_factor = 2**self.fixed_point
 
 
-  def _quant(self, data, unscale) -> torch.Tensor:
+  def _quant(self, data) -> torch.Tensor:
     """
     Apply quantization
 
@@ -155,8 +155,6 @@ class FixedPointQuantizer(Quantizer):
     ----------
     data : Tensor
       data to quantize
-    unscale = False : bool
-      If true, apply quantization and then, unquantize data to simulate quantization
 
     Returns
     -------
@@ -167,13 +165,20 @@ class FixedPointQuantizer(Quantizer):
 
     #scaled = torch.clamp(scaled, -2**(self.bitwidth-1), 2**(self.bitwidth-1)-1)
     
-    if unscale:
-      return (scaled.to(torch.float32))/self.scale_factor
-    else:
-      return scaled
+    return scaled
     
-  def _clamp(self, data):
-    b_min = (-2**(self.bitwidth-int(self.signed))*int(self.signed))/self.scale_factor
-    b_max = (2**(self.bitwidth-int(self.signed))-1)/self.scale_factor
+  def _unquant(self, data) -> torch.Tensor:
+    """
+    Unquantize data
+
+    Parameters
+    ----------
+    data : Tensor
+      data to unquantize
+
+    Returns
+    -------
+    Tensor
+    """
 
-    return torch.clamp(data, min=b_min, max=b_max)
\ No newline at end of file
+    return (data.to(torch.float32))/self.scale_factor
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index a700f6b..6ca8131 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -134,7 +134,7 @@ class MinMaxQuantizer(Quantizer):
     self.b_max = 2**(self.bitwidth-int(self.signed))-1
     self.b_min = -int(self.signed)*self.b_max
 
-  def _quant(self, data, unscale) -> torch.Tensor:
+  def _quant(self, data) -> torch.Tensor:
     """
     Apply quantization
 
@@ -142,8 +142,6 @@ class MinMaxQuantizer(Quantizer):
     ----------
     data : Tensor
       data to quantize
-    unscale = False : bool
-      If true, apply quantization and then, unquantize data to simulate quantization
 
     Returns
     -------
@@ -152,7 +150,20 @@ class MinMaxQuantizer(Quantizer):
 
     scaled = ((data-self.x_min)/(self.x_max-self.x_min)*(self.b_max-self.b_min)+self.b_min).to(self.dtype)
     
-    if unscale:
-      return (scaled-self.b_min)/(self.b_max-self.b_min)*(self.x_max-self.x_min)+self.x_min
-    else:
-      return scaled
\ No newline at end of file
+    return scaled
+    
+  def _unquant(self, data) -> torch.Tensor:
+    """
+    Unquantize data
+
+    Parameters
+    ----------
+    data : Tensor
+      data to unquantize
+
+    Returns
+    -------
+    Tensor
+    """
+
+    return (data-self.b_min)/(self.b_max-self.b_min)*(self.x_max-self.x_min)+self.x_min
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 2aca58a..d753162 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -106,19 +106,29 @@ class Quantizer():
     -------
     int | float | list | numpy.array | Tensor (depending on type of data)
     """
+
+    if not torch.is_tensor(data):
+      tdata = torch.tensor(data)
+    else:
+      tdata = data
+
+    if unscale:
+      qdata = self._unquant(self._quant(tdata))
+    else:
+      qdata = self._quant(tdata)
     
     if isinstance(data, (int, float)):
-      return self._quant(data=torch.tensor(data), unscale=unscale).item()
+      return qdata.item()
     elif isinstance(data, list):
-      return self._quant(data=torch.tensor(data), unscale=unscale).tolist()
+      return qdata.tolist()
     elif isinstance(data, np.ndarray):
-      return self._quant(data=torch.tensor(data), unscale=unscale).numpy()
+      return qdata.numpy()
     elif torch.is_tensor(data):
-      return self._quant(data=data, unscale=unscale).detach()
+      return qdata.detach()
     else:
       raise TypeError("Unsupported data type")
 
-  def _quant(self, data, unscale) -> torch.Tensor:
+  def _quant(self, data) -> torch.Tensor:
     """
     Apply quantization
 
@@ -126,15 +136,29 @@ class Quantizer():
     ----------
     data : Tensor
       data to quantize
-    unscale = False : bool
-      If true, apply quantization and then, unquantize data to simulate quantization
 
     Returns
     -------
     Tensor
     """
 
-    pass
+    raise NotImplementedError()
+
+  def _unquant(self, data) -> torch.Tensor:
+    """
+    Unquantize data
+
+    Parameters
+    ----------
+    data : Tensor
+      data to unquantize
+
+    Returns
+    -------
+    Tensor
+    """
+
+    raise NotImplementedError()
 
   def clamp(self, data):
     """
@@ -150,8 +174,8 @@ class Quantizer():
     int | float | list | numpy.array | Tensor (depending on type of data)
     """
 
-    born_min = -int(self.signed)*2**(self.bitwidth-1)
-    born_max = 2**(self.bitwidth-int(self.signed))-1
+    born_min = self._unquant(torch.tensor(-int(self.signed)*2**(self.bitwidth-1))).item()
+    born_max = self._unquant(torch.tensor(2**(self.bitwidth-int(self.signed))-1)).item()
     
     if isinstance(data, (int, float)):
       return torch.clamp(torch.tensor(data), min=born_min, max=born_max).item()
-- 
GitLab
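
The Python-side core of patch 02 is splitting `_quant(data, unscale)` into `_quant(data)` plus a new `_unquant(data)`, with the callable quantizer composing the two when `unscale` is requested (the neuron models rely on this through `self.quantizer(self.mem.data, True)`). A self-contained sketch of that contract, loosely modelled on DynamicScaleFactorQuantizer rather than importing modneflib, so class and method bodies here are illustrative:

import torch

class ToyScaleQuantizer:
  """Minimal illustration of the refactored Quantizer contract."""

  def __init__(self, bitwidth=8):
    self.bitwidth = bitwidth
    self.scale_factor = None

  def init_quantizer(self, weight):
    # scale chosen so the largest weight maps onto the signed integer range
    # (the real init_quantizer also receives recurrent weights)
    self.scale_factor = weight.abs().max() / 2 ** (self.bitwidth - 1)

  def _quant(self, data):
    return torch.round(data / self.scale_factor)

  def _unquant(self, data):
    return data * self.scale_factor

  def __call__(self, data, unscale=False):
    # unscale=True quantizes then de-quantizes, simulating quantization error
    return self._unquant(self._quant(data)) if unscale else self._quant(data)

w = torch.randn(4)
q = ToyScaleQuantizer()
q.init_quantizer(w)
print(q(w))                 # integer-valued quantization levels
print(q(w, unscale=True))   # same scale as w, but snapped to the quantization grid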


From 33502b1d19f7a773144ae9dc62ddb4ab120a6f36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Mon, 17 Mar 2025 16:04:19 +0100
Subject: [PATCH 03/23] remove single_step uart module

---
 ModNEF_Sources/modules/uart/uart_1step.vhd    | 390 ------------------
 .../arch_builder/modules/UART/uart_1step.py   | 242 -----------
 .../drivers/single_step_driver.py             | 136 ------
 3 files changed, 768 deletions(-)
 delete mode 100644 ModNEF_Sources/modules/uart/uart_1step.vhd
 delete mode 100644 modneflib/modnef/arch_builder/modules/UART/uart_1step.py
 delete mode 100644 modneflib/modnef/modnef_driver/drivers/single_step_driver.py

diff --git a/ModNEF_Sources/modules/uart/uart_1step.vhd b/ModNEF_Sources/modules/uart/uart_1step.vhd
deleted file mode 100644
index 50ecbe8..0000000
--- a/ModNEF_Sources/modules/uart/uart_1step.vhd
+++ /dev/null
@@ -1,390 +0,0 @@
-----------------------------------------------------------------------------------
---
--- Project : ModNEF
--- Component name : uart_1step
--- Depencies : uart_controller
---
--- Authors : Aurelie Saulquin
--- Email : aurelie.saulquin@univ-lille.fr
---
--- Version : 1.0
--- Version comment : stable version 
---
--- Licenses : cern-ohl-s-2.0
---
--- Description : 
--- UART component where one data transmission is use for one emulation step
--- Component will receive data, send all data to network and receive data 
--- from network and transmit it to computer
--- 
-----------------------------------------------------------------------------------
-
-library IEEE;
-use IEEE.std_logic_1164.all;
-use work.math.all;
-
-entity uart_1Step is
-  generic(
-    clk_freq          : integer := 100_000_000;
-    baud_rate         : integer := 115_200;
-
-    queue_read_depth  : integer := 32;
-    queue_read_type   : string := "fifo";
-
-    queue_write_type  : string := "fifo";
-
-    input_layer_size  : integer := 8;
-    output_layer_size : integer := 8
-  );
-  port (
-    i_clk             : in  std_logic;
-    i_en              : in  std_logic;
-
-    i_tx              : in  std_logic;
-    o_rx              : out std_logic;
-
-    i_emu_ready       : in  std_logic;
-    o_start_emu       : out std_logic;
-    o_reset_membrane  : out std_logic;
-
-    i_req             : in  std_logic;
-    o_ack             : out std_logic;
-    i_emu_busy        : in  std_logic;
-    i_spike_flag      : in  std_logic;
-    i_aer             : in  std_logic_vector(log_b(output_layer_size, 2)-1 downto 0);
-
-    o_req             : out std_logic;
-    i_ack             : in  std_logic;
-    o_emu_busy        : out std_logic;
-    o_spike_flag      : out std_logic;
-    o_aer             : out std_logic_vector(log_b(input_layer_size, 2)-1 downto 0)
-  );
-end uart_1Step;
-
-architecture Behavioral of uart_1Step is
-
-  component uart_controller is
-    generic(
-      clk_freq          : integer := 100_000_000;
-      baud_rate         : integer := 115_200;
-      oversamp_rate     : integer := 16;
-      
-      queue_read_depth  : integer := 64;
-      queue_read_width  : integer := 1;
-      queue_read_type   : string  := "fifo";
-
-      queue_write_depth : integer := 16;
-      queue_write_width : integer := 1;
-      queue_write_type  : string := "fifo"   -- fifo or lifo
-    );
-    port(
-      i_clk                 : in  std_logic;
-      i_en                  : in  std_logic;
-      o_busy                : out std_logic;
-      o_reset_detected      : out std_logic;
-
-      -- UART pins
-      i_tx                  : in  std_logic;
-      o_rx                  : out std_logic;
-      o_uart_busy           : out std_logic;       
-      
-      -- read I/O
-      o_read_data           : out std_logic_vector(queue_read_width*8-1 downto 0);
-      i_read_pop            : in  std_logic;
-      o_read_busy           : out std_logic;
-      o_read_queue_empty    : out std_logic;
-      o_read_queue_full     : out std_logic;
-
-      -- write I/O
-      i_start_transmission  : in  std_logic;
-      i_write_data          : in  std_logic_vector(queue_write_width*8-1 downto 0);
-      i_write_push          : in  std_logic;
-      o_write_busy          : out std_logic;
-      o_write_queue_empty   : out std_logic;
-      o_write_queue_full    : out std_logic
-        
-    );
-  end component;
-  
-  -- type definition
-  type emu_state_t              is (idle, wait_data, check_emu, emulate, wait_out_aer, send_aer, wait_transmission);
-  type uart_to_network_state_t  is (idle, check_emu, request, accept, transfert);
-  type network_to_uart_state_t  is (idle, wait_request, accept, wait_aer);
-
-  -- queue constant definition
-  --constant queue_read_depth : integer := 255;
-  constant queue_read_width : integer := log_b(queue_read_depth, 256);
-
-  constant queue_write_depth : integer := output_layer_size;
-  constant queue_write_width : integer := log_b(output_layer_size, 256);
-
-  -- read queue signals
-  signal read_data  : std_logic_vector(queue_read_width*8-1 downto 0) := (others=>'0');
-  signal read_pop   : std_logic := '0';
-  signal read_busy  : std_logic;
-  signal read_empty : std_logic;
-
-  -- write queue signals
-  signal write_data : std_logic_vector(queue_write_width*8-1 downto 0) := (others=>'0');
-  signal write_push : std_logic := '0';
-  signal write_busy : std_logic;
-  
-  -- uart signals
-  signal start_uart_transmission : std_logic := '0';
-
-  -- emulation signals
-  signal emu_state : emu_state_t := idle;
-  signal start_emu : std_logic;
-
-  -- membrane reset signals
-  signal reset_detected       : std_logic := '0';
-  signal reset_membrane       : std_logic := '0';
-
-  -- uart to network signals
-  signal uart_to_network_state  : uart_to_network_state_t := idle; 
-  signal uart_to_network_busy   : std_logic := '0';
-
-  -- network to uart signals
-  signal network_to_uart_state  : network_to_uart_state_t := idle;
-  signal network_to_uart_busy   : std_logic := '0';
-
-begin
-
-  o_start_emu <= start_emu;
-
-  o_reset_membrane <= reset_membrane;
-
-  -- controller FSM
-  process(i_clk, i_en)
-  begin
-    if rising_edge(i_clk) then
-      if i_en = '0' then
-        emu_state <= idle;
-        start_emu <= '0';
-        start_uart_transmission <= '0';
-      else
-        case emu_state is
-          when idle => 
-            start_emu <= '0';
-            start_uart_transmission <= '0';
-            reset_membrane <= '0';
-            if read_busy = '1' then
-              emu_state <= wait_data;
-            else
-              emu_state <= idle;
-            end if;
-
-          when wait_data =>
-            if read_busy = '0' then
-              emu_state <= check_emu;
-            else
-              emu_state <= wait_data;
-            end if;
-
-          when check_emu =>
-            if i_emu_ready = '1' then
-              emu_state <= emulate;
-              start_emu <= '1';
-            else
-              emu_state <= check_emu;
-            end if;
-
-          when emulate =>
-            start_emu <= '0';
-            if network_to_uart_busy = '1' then
-              emu_state <= wait_out_aer;
-            else
-              emu_state <= emulate;
-            end if;
-
-          when wait_out_aer =>
-            if i_emu_ready = '1' then
-              emu_state <= send_aer;
-              start_uart_transmission <= '1';
-              reset_membrane <= reset_detected;
-            else
-              emu_state <= wait_out_aer;
-            end if;
-
-          when send_aer =>
-            start_uart_transmission <= '0';
-            reset_membrane <= '0';
-            if  write_busy = '1' then
-              emu_state <= wait_transmission;
-            else
-              emu_state <= send_aer;
-            end if;
-
-          when wait_transmission =>
-            if write_busy = '0' then
-              emu_state <= idle;
-            else
-              emu_state <= wait_transmission;
-            end if;
-        end case;
-      end if;
-    end if;
-  end process;
-
-  -- Controller to network FSM
-  o_aer <= read_data(log_b(input_layer_size, 2)-1 downto 0) when uart_to_network_state = transfert else (others=>'0');
-  process(i_clk, i_en)
-  begin
-    if rising_edge(i_clk) then
-      if i_en = '0' then
-        o_req <= '0';
-        o_spike_flag <= '0';
-        o_emu_busy <= '0';
-        read_pop <= '0';
-        uart_to_network_busy <= '0';
-        uart_to_network_state <= idle;
-      else
-        case uart_to_network_state is
-          when idle =>
-            o_req <= '0';
-            o_spike_flag <= '0';
-            read_pop <= '0';
-            if start_emu = '1' then
-              uart_to_network_state <= check_emu;
-              o_emu_busy <= '1';
-              uart_to_network_busy <= '1';
-            else
-              uart_to_network_state <= idle;
-              o_emu_busy <= '0';
-              uart_to_network_busy <= '0';
-            end if;
-
-          when check_emu =>
-            if read_empty = '1' then
-              uart_to_network_state <= idle;
-              o_emu_busy <= '0';
-              uart_to_network_busy <= '0';
-            else
-              uart_to_network_state <= request;
-              o_req <= '1';
-            end if;
-
-          when request =>
-            if i_ack = '1' then
-              uart_to_network_state <= accept;
-              o_req <= '0';
-            else
-              uart_to_network_state <= request;
-              o_req <= '1';
-            end if;
-
-          when accept =>
-            if i_ack = '0' then
-              uart_to_network_state <= transfert;
-              read_pop <= '1';
-            else
-              uart_to_network_state <= accept;
-            end if;
-
-          when transfert =>
-            if read_empty = '1' then
-              uart_to_network_state <= idle;
-              o_emu_busy <= '0';
-              read_pop <= '0';
-              o_spike_flag <= '0';
-              uart_to_network_busy <= '0';
-            else
-              uart_to_network_state <= transfert;
-              read_pop <= '1';
-              o_spike_flag <= '1';
-            end if;
-        end case;
-      end if;
-    end if;
-  end process;
-  
-
-  write_data(log_b(output_layer_size, 2)-1 downto 0) <= i_aer when network_to_uart_state = wait_aer else (others=>'0');
-  write_push <= i_spike_flag when network_to_uart_state = wait_aer else '0';
-
-  -- Network to Controller FSM
-  process(i_clk, i_en)
-  begin
-    if i_en = '0' then
-      network_to_uart_state <= idle;
-      network_to_uart_busy <= '0';
-      o_ack <= '0';
-    else
-      if rising_edge(i_clk) then
-        case network_to_uart_state is
-          when idle =>
-            o_ack <= '0';
-            if i_emu_busy = '1' then
-              network_to_uart_state <= wait_request;
-              network_to_uart_busy <= '1';
-            else
-              network_to_uart_state <= idle;
-              network_to_uart_busy <= '0';
-            end if;
-
-          when wait_request =>
-            if i_emu_busy = '0' then
-              network_to_uart_state <= idle;
-              network_to_uart_busy <= '0';
-            elsif i_req = '1' then
-              o_ack <= '1';
-              network_to_uart_state <= accept;
-            else
-              network_to_uart_state <= wait_request;
-            end if;
-
-          when accept =>
-            if i_req = '0' then
-              network_to_uart_state <= wait_aer;
-              o_ack <= '0';
-            else
-              network_to_uart_state <= accept;
-              o_ack <= '1';
-            end if;
-
-          when wait_aer => 
-            if i_emu_busy = '0' then
-              network_to_uart_state <= idle;
-              network_to_uart_busy <= '0';
-            else 
-              network_to_uart_state <= wait_aer;
-            end if;
-        end case;
-      end if;
-    end if;
-  end process;
-
-  c_uart_controller : uart_controller generic map(
-    clk_freq => clk_freq,
-    baud_rate => baud_rate,
-    oversamp_rate => 16,
-
-    queue_read_depth => queue_read_depth,
-    queue_read_width => queue_read_width,
-    queue_read_type => queue_read_type,
-
-    queue_write_depth => queue_write_depth,
-    queue_write_width => queue_write_width,
-    queue_write_type => queue_write_type
-  ) port map(
-    i_clk => i_clk,
-    i_en => i_en,
-    o_busy => open,
-    o_reset_detected => reset_detected,
-    i_tx => i_tx,
-    o_rx => o_rx,
-    o_uart_busy => open,
-    o_read_data => read_data,
-    i_read_pop => read_pop,
-    o_read_busy => read_busy,
-    o_read_queue_empty => read_empty,
-    o_read_queue_full => open,
-    i_start_transmission => start_uart_transmission,
-    i_write_data => write_data,
-    i_write_push => write_push,
-    o_write_busy => write_busy,
-    o_write_queue_empty => open,
-    o_write_queue_full => open
-  );
-
-end Behavioral;
diff --git a/modneflib/modnef/arch_builder/modules/UART/uart_1step.py b/modneflib/modnef/arch_builder/modules/UART/uart_1step.py
deleted file mode 100644
index 8fca341..0000000
--- a/modneflib/modnef/arch_builder/modules/UART/uart_1step.py
+++ /dev/null
@@ -1,242 +0,0 @@
-"""
-File name: uart_1step
-Author: Aurélie Saulquin  
-Version: 2.0.0
-License: GPL-3.0-or-later
-Contact: aurelie.saulquin@univ-lille.fr
-Dependencies: io_arch, yaml
-Descriptions: UART_1Step ModNEF archbuilder module
-"""
-
-from ..io_arch import IOArch
-import yaml
-
-_UART_1STEP_DEFINITION = """
-  component uart_1Step is
-    generic(
-      clk_freq          : integer := 100_000_000;
-      baud_rate         : integer := 115_200;
-
-      queue_read_depth  : integer := 32;
-      queue_read_type   : string  := "fifo";
-
-      queue_write_type  : string  := "fifo";
-
-      input_layer_size  : integer := 8;
-      output_layer_size : integer := 8
-    );
-    port(
-      i_clk             : in  std_logic;
-      i_en              : in  std_logic;
-
-      i_tx              : in  std_logic;
-      o_rx              : out std_logic;
-
-      i_emu_ready       : in  std_logic;
-      o_start_emu       : out std_logic;
-      o_reset_membrane  : out std_logic;
-
-      i_req             : in  std_logic;
-      o_ack             : out std_logic;
-      i_emu_busy        : in  std_logic;
-      i_spike_flag      : in  std_logic;
-      i_aer             : in  std_logic_vector(log_b(output_layer_size, 2)-1 downto 0);
-
-      o_req             : out std_logic;
-      i_ack             : in  std_logic;
-      o_emu_busy        : out std_logic;
-      o_spike_flag      : out std_logic;
-      o_aer             : out std_logic_vector(log_b(input_layer_size, 2)-1 downto 0)
-    );
-  end component;
-"""
-
-class Uart_1Step(IOArch):
-  """
-  Uart_1Step module class
-  Each UART transmission correspond to an emulation step
-
-  Attributes
-  ----------
-  name : str
-    name of module
-  input_layer_size : int
-    size in neurons of input layer
-  output_layer_size : int
-    size in neurons of output layer
-  clk_freq : int
-    board clock frequency
-  baud_rate : int
-    data baud rate
-  queue_read_type : int
-    type of reaad queue
-  queue_write_type : int
-    type of write queue
-  tx_name : str
-    name of tx signal
-  rx_name : str
-    name of rx signal
-
-  Methods
-  -------
-  vhdl_component_name()
-    return component name
-  vhdl_component_definition()
-    return vhdl component definition
-  to_vhdl(vhdl_file, pred, suc, clock_name):
-    write vhdl component instanciation 
-  to_yaml(file):
-    generate yaml configuration file for driver
-  write_io(vhdl_file):
-    write signals into entity definition section
-  """
-
-  def __init__(self, 
-               name: str,
-               input_layer_size: int, 
-               output_layer_size: int,
-               clk_freq: int, 
-               baud_rate: int, 
-               tx_name: str, 
-               rx_name: str, 
-               queue_read_depth : int,
-               queue_read_type: str = "fifo", 
-               queue_write_type: str = "fifo"
-              ):
-    """
-    Initialize attributes
-
-    Parameters
-    ----------
-    name : str
-      name of module
-    clk_freq : int
-      board clock frequency
-    baud_rate : int
-      data baud rate
-    tx_name : str
-      name of tx signal
-    rx_name : str
-      name of rx signal
-    input_layer_size : int = -1
-      size in neurons of input layer
-    output_layer_size : int = -1
-      size in neurons of output layer
-    queue_read_type : str = "fifo"
-      read queue type : "fifo" or "lifo"
-    queue_write_type : str = "fifo"
-      write queue type : "fifo" or "lifo"
-    """
-
-    self.name = name
-    
-    self.input_neuron = output_layer_size
-    self.output_neuron = input_layer_size
-
-    self.input_layer_size = input_layer_size
-    self.output_layer_size = output_layer_size
-
-    self.clk_freq = clk_freq
-    self.baud_rate = baud_rate
-
-    self.queue_read_type = queue_read_type
-    self.queue_read_depth = queue_read_depth
-    self.queue_write_type = queue_write_type
-
-    self.tx_name = tx_name
-    self.rx_name = rx_name
-  
-  def vhdl_component_name(self):
-    """
-    Module identifier use during component definition
-
-    Returns
-    -------
-    str
-    """
-
-    return "Uart_1_Step"
-
-  
-  def vhdl_component_definition(self):
-    """
-    VHDL component definition
-
-    Returns
-    -------
-    str
-    """
-
-    return _UART_1STEP_DEFINITION
-
-  def to_vhdl(self, vhdl_file, pred, suc, clock_name):
-    """
-    Write vhdl componenent 
-
-    Parameters
-    ----------
-    vhdl_file : TextIOWrapper
-      vhdl file 
-    pred : List of ModNEFArchMod 
-      list of predecessor module (1 pred for this module)
-    suc : List of ModNEFArchMod
-      list of successor module (1 suc for this module)
-    clock_name : str
-      clock signal name
-    """
-
-    vhdl_file.write(f"\t{self.name} : uart_1step generic map(\n")
-    vhdl_file.write(f"\t\tclk_freq => {self.clk_freq},\n")
-    vhdl_file.write(f"\t\tbaud_rate => {self.baud_rate},\n")
-    vhdl_file.write(f"\t\tqueue_read_depth => {self.queue_read_depth},\n")
-    vhdl_file.write(f"\t\tqueue_read_type => \"{self.queue_read_type}\",\n")
-    vhdl_file.write(f"\t\tqueue_write_type => \"{self.queue_write_type}\",\n")
-    vhdl_file.write(f"\t\tinput_layer_size => {self.input_layer_size},\n")
-    vhdl_file.write(f"\t\toutput_layer_size => {self.output_layer_size}\n")
-    vhdl_file.write(f"\t) port map(\n")
-    vhdl_file.write(f"\t\ti_clk => {clock_name},\n")
-    vhdl_file.write("\t\ti_en => '1',\n")
-    vhdl_file.write(f"\t\ti_tx => {self.tx_name},\n")
-    vhdl_file.write(f"\t\to_rx => {self.rx_name},\n")
-    vhdl_file.write("\t\ti_emu_ready => emu_ready,\n")
-    vhdl_file.write("\t\to_start_emu => start_emu,\n")
-    vhdl_file.write("\t\to_reset_membrane => reset_membrane,\n")
-    self._write_port_map(vhdl_file, pred[0].name, self.name, "in", "", False)
-    self._write_port_map(vhdl_file, self.name, suc[0].name, "out", "", True)
-    vhdl_file.write("\t);\n")
-  
-  def to_yaml(self, file):
-    """
-    Generate yaml driver description file
-
-    Parameters
-    ----------
-    file : str
-      configuration file name
-    """
-    d = {}
-
-    
-
-    d["module"] = "1Step"
-    d["input_layer_size"] = self.input_layer_size
-    d["output_layer_size"] = self.output_layer_size
-    d["baud_rate"] = self.baud_rate
-    d["queue_read_depth"] = self.queue_read_depth
-    d["queue_write_depth"] = self.output_layer_size
-
-    with open(file, 'w') as f:
-      yaml.dump(d, f)
-
-  def write_io(self, vhdl_file):
-    """
-    Write port IO in entity definition section
-
-    Parameters
-    ----------
-    vhdl_file : TextIOWrapper
-      vhdl file
-    """
-
-    vhdl_file.write(f"\t\t{self.tx_name} : in std_logic;\n")
-    vhdl_file.write(f"\t\t{self.rx_name} : out std_logic\n")
diff --git a/modneflib/modnef/modnef_driver/drivers/single_step_driver.py b/modneflib/modnef/modnef_driver/drivers/single_step_driver.py
deleted file mode 100644
index 0feba5e..0000000
--- a/modneflib/modnef/modnef_driver/drivers/single_step_driver.py
+++ /dev/null
@@ -1,136 +0,0 @@
-"""
-File name: single_step_driver
-Author: Aurélie Saulquin  
-Version: 2.0.0
-License: GPL-3.0-or-later
-Contact: aurelie.saulquin@univ-lille.fr
-Dependencies: default_driver, yaml
-Descriptions: Driver class for UART_1Step uart module
-"""
-
-from .default_driver import default_transformation
-from .default_driver import ModNEF_Driver, ClosedDriverError
-import yaml
-
-class SingleStep_Driver(ModNEF_Driver):
-  """
-  Driver of Uart_SingleStep module
-
-  Attributes
-  ----------
-  board_path : str
-    fpga driver board path
-  baud_rate : int
-    data baud rate
-  input_layer_size : int
-    number of neuron in input layer
-  output_layer_size : int
-    number of neuron in output layer
-  queue_read_depth : int
-    number of word of read queue
-  queue_write_depth : int
-    number of word of write queue
-
-  Methods
-  -------
-  from_yaml(yaml_file, board_path) : classmethod
-    create driver from yaml configuration file
-  run_sample(input_sample, transformation, reset_membrane):
-    run data communication to run a data sample
-  """
-
-  def __init__(self, board_path, baud_rate, input_layer_size, output_layer_size, queue_read_depth, queue_write_depth):
-    """
-    Constructor
-
-    Parameters
-    ----------
-    board_path : str
-      fpga driver board path
-    baud_rate : int
-      data baud rate
-    input_layer_size : int
-      number of neuron in input layer
-    output_layer_size : int
-      number of neuron in output layer
-    queue_read_depth : int
-      number of word of read queue
-    queue_write_depth : int
-      number of word of write queue
-    """
-
-    
-
-    super().__init__(board_path, baud_rate, input_layer_size, output_layer_size, queue_read_depth, queue_write_depth)
-
-  @classmethod
-  def from_yaml(cls, yaml_file, board_path):
-    """
-    classmethod
-
-    create driver from driver configuration file
-
-    Parameters
-    ----------
-    yaml_file : str
-      configuration file
-    board_path : str
-      path to board driver
-    """
-
-    
-    
-    with open(yaml_file, 'r') as f:
-      config = yaml.safe_load(f)
-
-    print("coucou")
-
-    d = cls(board_path = board_path,
-      baud_rate = config["baud_rate"],
-      input_layer_size = config["input_layer_size"],
-      output_layer_size = config["output_layer_size"],
-      queue_read_depth = config["queue_read_depth"],
-      queue_write_depth = config["queue_write_depth"]
-    )
-    return d
-  
-
-  def run_sample(self, input_sample, transformation = default_transformation, reset_membrane=False, extra_step = 0):
-    """
-    Run an entire data sample by using run_step function (for more details see run_step)
-
-    Parameters
-    ----------
-    input_sample : list
-      list of spikes of sample
-    transformation : function
-      function call to tranform input spikes to AER representation
-    reset_membrane : bool
-      set to true if reset voltage membrane after sample transmission
-
-    Returns
-    -------
-    list of list of int:
-      list of list of output AER data for all emulation step
-    """
-
-    if self._is_close:
-      raise ClosedDriverError()
-
-    sample_res = [0 for _ in range(self.output_layer_size)]
-
-    sample_aer = [transformation(s) for s in input_sample]
-
-    for es in range(extra_step):
-      sample_aer.append([])
-
-    for step in range(len(input_sample)):
-      step_spikes = sample_aer[step]      
-      if step == len(input_sample)-1:
-        step_res = self.rust_driver.data_transmission(step_spikes, reset_membrane)
-      else:
-        step_res = self.rust_driver.data_transmission(step_spikes, False)
-      for s in step_res:
-        sample_res[s] += 1
-
-    return sample_res
-- 
GitLab


From 7c7480bf567989071aa5847af1b194c6fccc21d7 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Tue, 18 Mar 2025 05:50:14 +0100
Subject: [PATCH 04/23] remove post-update quantization

---
 modneflib/modnef/arch_builder/modules/UART/__init__.py         | 1 -
 modneflib/modnef/modnef_driver/drivers/__init__.py             | 1 -
 modneflib/modnef/modnef_driver/modnef_drivers.py               | 1 -
 .../modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py | 3 ---
 modneflib/modnef/templates/run_lib.py                          | 2 +-
 5 files changed, 1 insertion(+), 7 deletions(-)

diff --git a/modneflib/modnef/arch_builder/modules/UART/__init__.py b/modneflib/modnef/arch_builder/modules/UART/__init__.py
index 1bf9604..3cbbe30 100644
--- a/modneflib/modnef/arch_builder/modules/UART/__init__.py
+++ b/modneflib/modnef/arch_builder/modules/UART/__init__.py
@@ -7,7 +7,6 @@ Dependencies: uart_1step, uart_classifier, uart_classifier_timer, uart_xstep, ua
 Descriptions: UART module builder init
 """
 
-from .uart_1step import Uart_1Step
 from .uart_classifier import Uart_Classifier
 from .uart_classifier_timer import Uart_Classifier_Timer
 from .uart_xstep import Uart_XStep
diff --git a/modneflib/modnef/modnef_driver/drivers/__init__.py b/modneflib/modnef/modnef_driver/drivers/__init__.py
index 0615fd9..1bd6c71 100644
--- a/modneflib/modnef/modnef_driver/drivers/__init__.py
+++ b/modneflib/modnef/modnef_driver/drivers/__init__.py
@@ -11,6 +11,5 @@ from .classifier_driver import Classifier_Driver
 from .classifier_timer_driver import Classifier_Timer_Driver
 from .debugger_driver import Debugger_Driver
 from .default_driver import ModNEF_Driver
-from .single_step_driver import SingleStep_Driver
 from .xstep_driver import XStep_Driver
 from .xstep_timer_driver import XStep_Timer_Driver
\ No newline at end of file
diff --git a/modneflib/modnef/modnef_driver/modnef_drivers.py b/modneflib/modnef/modnef_driver/modnef_drivers.py
index be6f4db..11e2cc3 100644
--- a/modneflib/modnef/modnef_driver/modnef_drivers.py
+++ b/modneflib/modnef/modnef_driver/modnef_drivers.py
@@ -13,7 +13,6 @@ from .drivers import *
 import yaml
 
 drivers_dict = {
-  "1Step" : SingleStep_Driver,
   "XStep" : XStep_Driver,
   "Classifier" : Classifier_Driver,
   "Debugger" : Debugger_Driver,
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index d4f58d3..9bb6d1d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -269,9 +269,6 @@ class ShiftLIF(LIF, ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py
index 07ccb6d..21f536b 100644
--- a/modneflib/modnef/templates/run_lib.py
+++ b/modneflib/modnef/templates/run_lib.py
@@ -69,7 +69,7 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
     acc_test_history.append(acc_test)
 
     if best_model_name!="" and acc_test>best_acc:
-      torch.save(model.state_dict(), best_model_name)
+      torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
       best_acc = acc_test
 
   if save_history:
-- 
GitLab
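
The functional change of interest here is in templates/run_lib.py: the best checkpoint is now written under output_path instead of the current working directory. A minimal sketch of the resulting checkpointing step, assuming the same variable names as run_lib.train (the save_best_checkpoint wrapper itself is hypothetical):

import torch

def save_best_checkpoint(model, acc_test, best_acc, best_model_name="best_model", output_path="."):
    # mirrors the updated run_lib.train behaviour: only save when accuracy improves,
    # and place the file under output_path
    if best_model_name != "" and acc_test > best_acc:
        torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
        return acc_test
    return best_acc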


From 5f5d13434f53ebbfc58737ef865cd660e3f37b62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Wed, 19 Mar 2025 22:21:33 +0100
Subject: [PATCH 05/23] add quantizer modifications

---
 .../modnef_neurons/modnef_torch_neuron.py     | 46 +++++++++++++++++--
 .../modnef_neurons/srlif_model/shiftlif.py    | 29 ++++--------
 .../quantizer/dynamic_scale_quantizer.py      |  2 +-
 .../modnef/quantizer/fixed_point_quantizer.py |  2 +-
 .../modnef/quantizer/min_max_quantizer.py     |  2 +-
 modneflib/modnef/quantizer/quantizer.py       |  8 ++--
 6 files changed, 57 insertions(+), 32 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 7322594..0a1acd9 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -10,6 +10,8 @@ Descriptions: ModNEF torch neuron interface builder
 
 import torch
 from modnef.quantizer import *
+from snntorch._neurons import SpikingNeuron
+from snntorch.surrogate import fast_sigmoid
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -17,7 +19,7 @@ _quantizer = {
   "DynamicScaleFactorQuantizer" : DynamicScaleFactorQuantizer
 }
 
-class ModNEFNeuron():
+class ModNEFNeuron(SpikingNeuron):
   """
   ModNEF torch neuron interface
 
@@ -42,7 +44,27 @@ class ModNEFNeuron():
     create and return the corresponding modnef archbuilder module from internal neuron parameters
   """
 
-  def __init__(self, quantizer : Quantizer):
+  def __init__(self, 
+               threshold, 
+               reset_mechanism, 
+               quantizer : Quantizer, 
+               spike_grad=fast_sigmoid(slope=25)):
+
+    SpikingNeuron.__init__(
+      self=self,
+      threshold=threshold,
+      reset_mechanism=reset_mechanism,
+      spike_grad=spike_grad,
+      surrogate_disable=False,
+      init_hidden=False,
+      inhibition=False,
+      learn_threshold=False,
+      state_quant=False,
+      output=False,
+      graded_spikes_factor=1.0,
+      learn_graded_spikes_factor=False
+    )
+
     self.hardware_estimation_flag = False
     self.quantization_flag = False
 
@@ -61,7 +83,20 @@ class ModNEFNeuron():
 
     raise NotImplementedError()
   
-  def quantize_weight(self):
+  def init_quantizer(self):
+
+    params = list(self.parameters())
+
+    w1 = params[0].data
+
+    if len(params)==2:
+      w2 = params[1].data
+    else:
+      w2 = torch.zeros((1))
+
+    self.quantizer.init_quantizer(w1, w2)
+  
+  def quantize_weight(self, unscaled : bool = False):
     """
     synaptic weight quantization
 
@@ -70,9 +105,10 @@ class ModNEFNeuron():
     NotImplementedError()
     """
     
-    raise NotImplementedError()
+    for param in self.parameters():
+      param.data = self.quantizer(param.data, unscale=unscaled)
   
-  def quantize_parameters(self):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 9bb6d1d..195e965 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -11,14 +11,14 @@ Based on snntorch.Leaky and snntroch.LIF class
 
 import torch.nn as nn
 import torch
-from snntorch import LIF
+from snntorch.surrogate import fast_sigmoid
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 
-class ShiftLIF(LIF, ModNEFNeuron):
+class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
 
@@ -86,7 +86,7 @@ class ShiftLIF(LIF, ModNEFNeuron):
                out_features,
                beta,
                threshold=1.0,
-               spike_grad=None,
+               spike_grad=fast_sigmoid(slope=25),
                reset_mechanism="subtract",
                quantizer=DynamicScaleFactorQuantizer(8)
             ):
@@ -117,24 +117,13 @@ class ShiftLIF(LIF, ModNEFNeuron):
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
 
-    LIF.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
 
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
+    ModNEFNeuron.__init__(self=self, 
+                          threshold=threshold,
+                          reset_mechanism=reset_mechanism,
+                          spike_grad=spike_grad,
+                          quantizer=quantizer
+                          )
 
     self.fc = nn.Linear(in_features, out_features, bias=False)
 
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 2631170..43a6aa1 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index 106a242..c48e690 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index 6ca8131..ebc4ae9 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index d753162..6ab752e 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -77,7 +77,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
@@ -112,10 +112,10 @@ class Quantizer():
     else:
       tdata = data
 
+    qdata = self._quant(tdata)
+
     if unscale:
-      qdata = self._unquant(self._quant(tdata))
-    else:
-      qdata = self._quant(tdata)
+      qdata = self._unquant(qdata)
     
     if isinstance(data, (int, float)):
       return qdata.item()
-- 
GitLab
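
Patch 05 also restructures Quantizer.__call__ so the data is always quantized once and only unquantized afterwards when unscale is set. A self-contained toy quantizer, not the library's actual scale computation, illustrating that control flow:

import torch

class ToyScaleQuantizer:
    # toy 8-bit symmetric quantizer; only the __call__ flow mirrors the patched Quantizer
    def __init__(self, bitwidth=8):
        self.scale = 2 ** (bitwidth - 1) - 1

    def _quant(self, tdata):
        return torch.round(tdata * self.scale).to(torch.int32)

    def _unquant(self, qdata):
        return qdata.to(torch.float32) / self.scale

    def __call__(self, data, unscale=False):
        tdata = data if torch.is_tensor(data) else torch.tensor(data, dtype=torch.float32)
        qdata = self._quant(tdata)          # always quantize first
        if unscale:                         # then optionally undo the scaling to simulate quantization
            qdata = self._unquant(qdata)
        if isinstance(data, (int, float)):
            return qdata.item()
        return qdata

q = ToyScaleQuantizer(bitwidth=8)
print(q(torch.tensor([0.5, -0.25])))                 # integer codes
print(q(torch.tensor([0.5, -0.25]), unscale=True))   # floats carrying the quantization error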


From e872088a13243712b9022390d09c0069f8e9d4ca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Sun, 23 Mar 2025 18:42:59 +0100
Subject: [PATCH 06/23] change modnefneuron

---
 modneflib/modnef/modnef_torch/model.py        |  2 +-
 .../modnef_neurons/blif_model/blif.py         | 75 ++++------------
 .../modnef_neurons/blif_model/rblif.py        | 79 ++++-------------
 .../modnef_neurons/modnef_torch_neuron.py     | 87 ++++++++++++-------
 .../modnef_neurons/slif_model/rslif.py        | 82 ++++-------------
 .../modnef_neurons/slif_model/slif.py         | 80 ++++-------------
 .../modnef_neurons/srlif_model/rshiftlif.py   | 78 ++++-------------
 .../modnef_neurons/srlif_model/shiftlif.py    | 62 ++++---------
 .../quantizer/dynamic_scale_quantizer.py      |  2 +-
 .../modnef/quantizer/fixed_point_quantizer.py |  2 +-
 .../modnef/quantizer/min_max_quantizer.py     |  2 +-
 modneflib/modnef/quantizer/quantizer.py       |  6 +-
 12 files changed, 166 insertions(+), 391 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 44443dc..ae40072 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -88,7 +88,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.hardware_estimation(hardware)
-        m.set_quant(quant)
+        m.run_quantize(quant)
 
     return super().train(mode=mode)
   
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 83cc04f..50f18d8 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -18,7 +18,7 @@ from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from modnef.quantizer import *
 
-class BLIF(Leaky, ModNEFNeuron):
+class BLIF(ModNEFNeuron):
   """
   ModNEFTorch BLIF neuron model
 
@@ -111,26 +111,15 @@ class BLIF(Leaky, ModNEFNeuron):
       quantization method
     """
     
-    Leaky.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
-
-    self.fc = nn.Linear(in_features, out_features, bias=False)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
+    
+    self.register_buffer("beta", torch.tensor(beta))
 
     self._init_mem()
 
@@ -307,50 +296,20 @@ class BLIF(Leaky, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
-    """
 
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization.
+    We assume the quantizer has already been initialized
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-    self.quantization_flag = True
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index bf702cb..29bd284 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
 
-class RBLIF(Leaky, ModNEFNeuron):
+class RBLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent BLIF neuron model
 
@@ -116,27 +116,17 @@ class RBLIF(Leaky, ModNEFNeuron):
       quantization method
     """
     
-    Leaky.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
+    
+    self.register_buffer("beta", torch.tensor(beta))
 
-    self.fc = nn.Linear(in_features, out_features, bias=False)
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = nn.Linear(out_features, out_features, bias=True)
 
     self._init_mem()
 
@@ -321,52 +311,19 @@ class RBLIF(Leaky, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization.
+    We assume the quantizer has already been initialized
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-    self.quantization_flag = True
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 0a1acd9..08e37ab 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -9,6 +9,7 @@ Descriptions: ModNEF torch neuron interface builder
 """
 
 import torch
+import torch.nn as nn
 from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
@@ -44,26 +45,22 @@ class ModNEFNeuron(SpikingNeuron):
     create and return the corresponding modnef archbuilder module from internal neuron parameters
   """
 
-  def __init__(self, 
-               threshold, 
-               reset_mechanism, 
-               quantizer : Quantizer, 
-               spike_grad=fast_sigmoid(slope=25)):
-
-    SpikingNeuron.__init__(
-      self=self,
+  def __init__(self,
+               in_features,
+               out_features,
+               threshold,
+               reset_mechanism,
+               spike_grad, 
+               quantizer):
+    
+    super().__init__(
       threshold=threshold,
-      reset_mechanism=reset_mechanism,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_threshold=False,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
+      spike_grad=spike_grad,
+      reset_mechanism=reset_mechanism
     )
+    
+    self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+
 
     self.hardware_estimation_flag = False
     self.quantization_flag = False
@@ -84,19 +81,20 @@ class ModNEFNeuron(SpikingNeuron):
     raise NotImplementedError()
   
   def init_quantizer(self):
+    """
+    Initialize or re-initialize the internal quantizer
+    """
 
-    params = list(self.parameters())
-
-    w1 = params[0].data
+    param = list(self.parameters())
 
-    if len(params)==2:
-      w2 = params[1].data
+    if len(param)==1:
+      self.quantizer.init_from_weight(param[0])
+      print("init no rec")
     else:
-      w2 = torch.zeros((1))
-
-    self.quantizer.init_quantizer(w1, w2)
+      self.quantizer.init_from_weight(param[0], param[1])
+      print("init rec")
   
-  def quantize_weight(self, unscaled : bool = False):
+  def quantize_weight(self, unscale : bool = True):
     """
     synaptic weight quantization
 
@@ -105,12 +103,20 @@ class ModNEFNeuron(SpikingNeuron):
     NotImplementedError()
     """
     
-    for param in self.parameters():
-      param.data = self.quantizer(param.data, unscale=unscaled)
+    for p in self.parameters():
+      p.data = self.quantizer(p.data, unscale=unscale)
+      print(p)
+      print("quantize weight")
   
-  def quantize_hp(self):
+  def quantize_hp(self, unscale : bool = True):
     """
     neuron hyper-parameters quantization
+    We assume you've already initialized the quantizer
+
+    Parameters
+    ----------
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
     
     raise NotImplementedError()
@@ -125,16 +131,31 @@ class ModNEFNeuron(SpikingNeuron):
       force quantizer initialization
     """
     
-    raise NotImplementedError()
+    if force_init:
+      self.init_quantizer()
+
+    self.quantize_weight()
+    self.quantize_hp()
   
   def clamp(self):
     """
     Clamp synaptic weight
     """
 
-    raise NotImplementedError()
+    for p in self.parameters():
+      p.data = self.quantizer.clamp(p.data)
+      print("clamp")
   
-  def set_quant(self, mode=False):
+  def run_quantize(self, mode=False):
+    """
+    Set the quantization flag
+
+    Parameters
+    ----------
+    mode : bool = False
+      whether to enable quantization at run time
+    """
+
     self.quantization_flag = mode
   
   def hardware_estimation(self, mode = False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 2d78aa9..190ddc8 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
 from modnef.quantizer import MinMaxQuantizer
 
-class RSLIF(LIF, ModNEFNeuron):
+class RSLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Simplifed LIF neuron model
 
@@ -119,33 +119,18 @@ class RSLIF(LIF, ModNEFNeuron):
       quantization function
     """
 
-    LIF.__init__(
-      self=self,
-      beta = v_leak,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism="zero",
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
-    )
-
-    ModNEFNeuron.__init__(self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism="zero",
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
 
     self.register_buffer("v_leak", torch.as_tensor(v_leak))
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.in_features = in_features
-    self.out_features = out_features
-
-    self.fc = nn.Linear(self.in_features, self.out_features, bias=False)
     self.reccurent = nn.Linear(self.out_features, self.out_features, bias=False)
 
     self._init_mem()
@@ -338,56 +323,23 @@ class RSLIF(LIF, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
+  def quantize_hp(self, unscale : bool = True):
     """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.v_leak.data = self.quantizer(self.v_leak.data, True)
-    self.v_min.data = self.quantizer(self.v_min.data, True)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True)
-    self.threshold.data = self.quantizer(self.threshold, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization
+    We assume you've already initialized the quantizer
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.quantize_weight()
-    self.quantize_parameters()
+    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
+    self.v_min.data = self.quantizer(self.v_min.data, unscale)
+    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
+    self.threshold.data = self.quantizer(self.threshold, unscale)
 
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
-    """
-
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
-
-  
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 9d01858..78cc36b 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
 from modnef.quantizer import MinMaxQuantizer
 
-class SLIF(LIF, ModNEFNeuron):
+class SLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Simplifed LIF neuron model
 
@@ -119,34 +119,18 @@ class SLIF(LIF, ModNEFNeuron):
       quantization method
     """
 
-    LIF.__init__(
-      self=self,
-      beta = v_leak,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism="zero",
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
-    )
-
-    ModNEFNeuron.__init__(self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism="zero",
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
 
     self.register_buffer("v_leak", torch.as_tensor(v_leak))
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.in_features = in_features
-    self.out_features = out_features
-
-    self.fc = nn.Linear(self.in_features, self.out_features, bias=False)
-
     self._init_mem()
     
     self.hardware_description = {
@@ -329,53 +313,23 @@ class SLIF(LIF, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
 
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.v_leak.data = self.quantizer(self.v_leak.data, True)
-    self.v_min.data = self.quantizer(self.v_min.data, True)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True)
-    self.threshold.data = self.quantizer(self.threshold, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization
+    We assume you've already initialized the quantizer
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
 
-    self.quantize_weight()
-    self.quantize_parameters()
+    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
+    self.v_min.data = self.quantizer(self.v_min.data, unscale)
+    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
+    self.threshold.data = self.quantizer(self.threshold, unscale)
 
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
-    """
-
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-  
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index cdf4d26..e6354a2 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -19,7 +19,7 @@ from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 
 
-class RShiftLIF(LIF, ModNEFNeuron):
+class RShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Shift LIF neuron model
 
@@ -116,35 +116,25 @@ class RShiftLIF(LIF, ModNEFNeuron):
     quantizer = DynamicScaleFactoirQuantizer(8) : Quantizer
       quantization method
     """
+
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
     
     self.shift = int(-log(1-beta)/log(2))
 
     if (1-2**-self.shift) != beta:
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
-    
-    LIF.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
 
-    self.fc = nn.Linear(in_features, out_features, bias=False)
     self.reccurent = nn.Linear(out_features, out_features, bias=False)
 
+    self.register_buffer("beta", torch.tensor(beta))
+
     self._init_mem()
 
     self.hardware_description = {
@@ -340,51 +330,19 @@ class RShiftLIF(LIF, ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization.
+    We assume the quantizer has already been initialized
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 195e965..f2c95cb 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -17,6 +17,7 @@ from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
+from snntorch import LIF
 
 class ShiftLIF(ModNEFNeuron):
   """
@@ -110,6 +111,15 @@ class ShiftLIF(ModNEFNeuron):
     quantizer = DynamicScaleFactorQuantizer(8) : Quantizer
       quantization method
     """
+
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
+
     
     self.shift = int(-log(1-beta)/log(2))
 
@@ -118,14 +128,7 @@ class ShiftLIF(ModNEFNeuron):
       beta = 1-2**-self.shift
 
 
-    ModNEFNeuron.__init__(self=self, 
-                          threshold=threshold,
-                          reset_mechanism=reset_mechanism,
-                          spike_grad=spike_grad,
-                          quantizer=quantizer
-                          )
-
-    self.fc = nn.Linear(in_features, out_features, bias=False)
+    self.register_buffer("beta", torch.tensor(beta))
 
     self._init_mem()
 
@@ -310,48 +313,19 @@ class ShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
+  def quantize_hp(self, unscale : bool = True):
     """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    neuron hyper-parameters quantization.
+    We assume the quantizer has already been initialized
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      set to True if quantization must be simulated
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 43a6aa1..2631170 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index c48e690..106a242 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index ebc4ae9..6ca8131 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 6ab752e..c56f506 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -77,7 +77,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
@@ -108,14 +108,14 @@ class Quantizer():
     """
 
     if not torch.is_tensor(data):
-      tdata = torch.tensor(data)
+      tdata = torch.tensor(data, dtype=torch.float32)
     else:
       tdata = data
 
     qdata = self._quant(tdata)
 
     if unscale:
-      qdata = self._unquant(qdata)
+      qdata = self._unquant(qdata).to(torch.float32)
     
     if isinstance(data, (int, float)):
       return qdata.item()
-- 
GitLab
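
With patch 06, weight and hyper-parameter quantization live in the ModNEFNeuron base class, so every neuron model exposes the same flow. A hedged usage sketch follows; the constructor arguments are borrowed from the template added in the next patch and the concrete values are illustrative only:

import modnef.modnef_torch as mt
from modnef.quantizer import FixedPointQuantizer
from snntorch.surrogate import fast_sigmoid

layer = mt.BLIF(in_features=64, out_features=10, threshold=0.8, beta=0.9,
                reset_mechanism="subtract",
                spike_grad=fast_sigmoid(slope=25),
                quantizer=FixedPointQuantizer(bitwidth=8, fixed_point=7, signed=None))

layer.quantize(force_init=True)  # init_quantizer(), then quantize_weight() and quantize_hp()
layer.clamp()                    # clamp weights to the quantizer's representable range
layer.run_quantize(True)         # flag quantized execution for subsequent forward passes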


From 77bbfb1a0692d5bb84c74c8a6994532881ac9327 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 24 Mar 2025 10:22:09 +0100
Subject: [PATCH 07/23] add model to template

---
 .../modnef_neurons/blif_model/rblif.py        |   2 +-
 .../modnef_neurons/modnef_torch_neuron.py     |   4 -
 .../modnef_neurons/slif_model/rslif.py        |   2 +-
 .../modnef_neurons/slif_model/slif.py         |   2 +-
 .../modnef_neurons/srlif_model/rshiftlif.py   |   2 +-
 modneflib/modnef/templates/evaluation.py      |   7 +-
 modneflib/modnef/templates/model.py           | 175 ++++++++++++++++++
 modneflib/modnef/templates/run_lib.py         |   5 +
 modneflib/modnef/templates/train.py           |  11 +-
 modneflib/modnef/templates/vhdl_generation.py |   7 +-
 10 files changed, 201 insertions(+), 16 deletions(-)
 create mode 100644 modneflib/modnef/templates/model.py

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 29bd284..01f59f9 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -126,7 +126,7 @@ class RBLIF(ModNEFNeuron):
     
     self.register_buffer("beta", torch.tensor(beta))
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=True)
+    self.reccurent = nn.Linear(out_features, out_features, bias=False)
 
     self._init_mem()
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 08e37ab..c4a9800 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -89,10 +89,8 @@ class ModNEFNeuron(SpikingNeuron):
 
     if len(param)==1:
       self.quantizer.init_from_weight(param[0])
-      print("init no rec")
     else:
       self.quantizer.init_from_weight(param[0], param[1])
-      print("init rec")
   
   def quantize_weight(self, unscale : bool = True):
     """
@@ -105,8 +103,6 @@ class ModNEFNeuron(SpikingNeuron):
     
     for p in self.parameters():
       p.data = self.quantizer(p.data, unscale=unscale)
-      print(p)
-      print("quantize weight")
   
   def quantize_hp(self, unscale : bool = True):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 190ddc8..7d9c21f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -131,7 +131,7 @@ class RSLIF(ModNEFNeuron):
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.reccurent = nn.Linear(self.out_features, self.out_features, bias=False)
+    self.reccurent = nn.Linear(out_features, out_features, bias=False)
 
     self._init_mem()
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 78cc36b..089519b 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -287,7 +287,7 @@ class SLIF(ModNEFNeuron):
       if self.hardware_estimation_flag:
         val_max = max(abs(self.val_max), abs(self.val_min))
         print(val_max)
-        val_max = self.quantizer(val_max, dtype=torch.int32)
+        val_max = self.quantizer(val_max)
         print(val_max)
         self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8
       else:
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index e6354a2..3cd77b3 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -303,7 +303,7 @@ class RShiftLIF(ModNEFNeuron):
     if self.hardware_description["variable_size"]==-1:
       if self.hardware_estimation_flag:
         val_max = max(abs(self.val_max), abs(self.val_min))
-        val_max = val_max*2**(self.hardware_description["compute_fp"])
+        val_max = self.quantizer(self.val_max)
         self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8
       else:
         self.hardware_description["variable_size"]=16
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index aa74e22..ebcc8f8 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -6,6 +6,7 @@ from snntorch.surrogate import fast_sigmoid
 import torch
 from run_lib import *
 import sys
+from model import MyModel
 
 if __name__ == "__main__":
 
@@ -13,10 +14,12 @@ if __name__ == "__main__":
   exp_name = "Evaluation"
 
   """Model definition"""
-  model_path = "model_template.json"
   best_model_name = "best_model"
 
-  model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+  # model_path = "model_template.json"
+  # model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+
+  model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25))
 
   model.load_state_dict(torch.load(best_model_name))
 
diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py
new file mode 100644
index 0000000..ac2fb76
--- /dev/null
+++ b/modneflib/modnef/templates/model.py
@@ -0,0 +1,175 @@
+import modnef.modnef_torch as mt
+from modnef.arch_builder import *
+from snntorch.surrogate import fast_sigmoid
+from modnef.quantizer import *
+import torch
+
+class MyModel(mt.ModNEFModel):
+
+  def __init__(self, name, spike_grad=fast_sigmoid(slope=25)):
+
+    super().__init__()
+
+    self.name = name
+    
+    self.layer1 = mt.SLIF(in_features=2312,
+                          out_features=128,
+                          threshold=0.8,
+                          v_leak=0.015,
+                          v_min=0.0,
+                          v_rest=0.0,
+                          spike_grad=spike_grad,
+                          quantizer=MinMaxQuantizer(
+                            bitwidth=8,
+                            signed=None
+                          ))
+    
+    self.layer2 = mt.ShiftLIF(in_features=128,
+                              out_features=64,
+                              threshold=0.8,
+                              beta=0.875,
+                              reset_mechanism="subtract",
+                              spike_grad=spike_grad,
+                              quantizer=DynamicScaleFactorQuantizer(
+                                bitwidth=8,
+                                signed=None
+                              ))
+    
+    self.layer3 = mt.BLIF(in_features=64,
+                          out_features=10,
+                          threshold=0.8,
+                          beta=0.9,
+                          reset_mechanism="subtract",
+                          spike_grad=spike_grad,
+                          quantizer=FixedPointQuantizer(
+                            bitwidth=8,
+                            fixed_point=7,
+                            signed=None
+                          ))
+    
+  def software_forward(self, input_spikes):
+    """
+    Run layers update
+
+    Parameters
+    ----------
+    input_spikes : Tensor
+      input spikes
+
+    Returns
+    -------
+    tuple of tensor
+      output_spike, output_mem
+    """
+    
+    spk1, mem1 = self.layer1.reset_mem()
+    spk2, mem2 = self.layer2.reset_mem()
+    spk3, mem3 = self.layer3.reset_mem()
+
+    spk_rec = []
+    mem_rec = []
+
+    batch_size = input_spikes.shape[0]
+    n_steps = input_spikes.shape[1]
+
+    for step in range(n_steps):
+      x = input_spikes[:, step].reshape(batch_size, -1)
+
+      spk1, mem1 = self.layer1(x, mem1, spk1)
+      spk2, mem2 = self.layer2(spk1, mem2, spk2)
+      spk3, mem3 = self.layer3(spk2, mem3, spk3)
+
+      spk_rec.append(spk3)
+      mem_rec.append(mem3)
+
+    return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)
+  
+  def fpga_forward(self, input_spikes):
+    """
+    Transmit input spike to FPGA
+
+    Parameters
+    ----------
+    input_spikes : Tensor
+      input spikes
+
+    Returns
+    -------
+    tuple of tensor
+      output_spike, None
+    """
+
+    def to_aer(input):
+      input = input.reshape(-1).to(torch.int32)
+
+      aer = []
+      for i in range(input.shape[0]):
+        for _ in range(input[i]):
+          aer.append(i)
+
+      return aer
+    
+    if self.driver == None:
+      raise Exception("please open the FPGA driver first")
+    
+    batch_result = []
+
+    for sample in input_spikes:
+      sample_res = self.driver.run_sample(sample, to_aer, True, len(self.layers))
+      batch_result.append([sample_res])
+
+    return torch.tensor(batch_result).permute(1, 0, 2), None
+  
+  def to_vhdl(self, file_name=None, output_path = ".", driver_config_path = "./driver.yml"):
+    """
+    Generate VHDL file of model
+
+    Parameters
+    ----------
+    file_name = None : str
+      VHDL file name
+      if default, file name is model name
+    output_path = "." : str
+      output file path
+    driver_config_path = "./driver.yml" : str
+      driver configuration file
+    """
+    
+    if file_name==None:
+      file_name = f"{output_path}/{self.name}.vhd"
+
+    builder = ModNEFBuilder(self.name, 2312, 10)
+
+
+    uart = Uart_XStep(
+      name="uart",
+      input_layer_size=2312,
+      output_layer_size=10,
+      clk_freq=125_000_000,
+      baud_rate=921_600,
+      queue_read_depth=10240,
+      queue_write_depth=1024,
+      tx_name="uart_rxd",
+      rx_name="uart_txd"
+    )
+
+    builder.add_module(uart)
+    builder.set_io(uart)
+
+    layer1_module = self.layer1.get_builder_module(f"{self.name}_layer1", output_path)
+    builder.add_module(layer1_module)
+
+    layer2_module = self.layer2.get_builder_module(f"{self.name}_layer2", output_path)
+    builder.add_module(layer2_module)
+
+    layer3_module = self.layer3.get_builder_module(f"{self.name}_layer3", output_path)
+    builder.add_module(layer3_module)
+
+    builder.add_link(uart, layer1_module)
+    builder.add_link(layer1_module, layer2_module)
+    builder.add_link(layer2_module, layer3_module)
+    builder.add_link(layer3_module, uart)
+    
+
+    builder.get_driver_yaml(f"{output_path}/{driver_config_path}")
+    builder.to_vhdl(file_name, "clock")
\ No newline at end of file
diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py
index 21f536b..85a1352 100644
--- a/modneflib/modnef/templates/run_lib.py
+++ b/modneflib/modnef/templates/run_lib.py
@@ -169,6 +169,9 @@ def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"),
 
   model.eval(quant)
 
+  if quant:
+    model.quantize(force_init=True)
+
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true   
@@ -182,6 +185,8 @@ def hardware_estimation(model, testLoader, name="Hardware Estimation", device=to
 
   model.hardware_estimation()
 
+  model.quantize(force_init=True)
+
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true
diff --git a/modneflib/modnef/templates/train.py b/modneflib/modnef/templates/train.py
index 03603ec..b67de45 100644
--- a/modneflib/modnef/templates/train.py
+++ b/modneflib/modnef/templates/train.py
@@ -5,12 +5,15 @@ import os
 from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
+from model import MyModel
 
 if __name__ == "__main__":
 
   """Model definition"""
-  model_path = "model_template.json"
-  model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+  # model_path = "model_template.json"
+  # model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+
+  model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25))
 
   """Optimizer"""
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
@@ -22,7 +25,7 @@ if __name__ == "__main__":
   device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
   """Train variable definition"""
-  n_epoch = 1
+  n_epoch = 2
   best_model_name = "best_model"
   verbose = True
   save_plot = False
@@ -34,7 +37,7 @@ if __name__ == "__main__":
 
   # data set definition, change to your dataset
   sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=10)
+  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=5)
 
   train_set = tonic.datasets.NMNIST(save_to=dataset_path, train=True, transform=frame_transform)
   test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index 82dbb71..fcca214 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -5,6 +5,7 @@ import os
 from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
+from model import MyModel
 
 if __name__ == "__main__":
 
@@ -12,10 +13,12 @@ if __name__ == "__main__":
   exp_name = "Evaluation"
 
   """Model definition"""
-  model_path = "model_template.json"
   best_model_name = "best_model"
 
-  model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+  # model_path = "model_template.json"
+  # model = ModNEFModelBuilder(model_path, spike_grad=fast_sigmoid(slope=25))
+
+  model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25))
 
   model.load_state_dict(torch.load(best_model_name))
 
-- 
GitLab
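
The new templates/model.py wires three ModNEF layers for NMNIST (2 x 34 x 34 = 2312 inputs, 10 classes). A short, hedged usage sketch of the software path; the random input is only a stand-in for a real event-frame batch and the spike-count readout is just one possible decoding:

import torch
from model import MyModel

model = MyModel("template_model")
x = (torch.rand(4, 5, 2, 34, 34) > 0.5).float()   # (batch, time steps, polarity, H, W)
spk_rec, mem_rec = model.software_forward(x)       # spk_rec: (time steps, batch, 10)
pred = spk_rec.sum(dim=0).argmax(dim=-1)           # spike-count readout per sample
print(pred)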


From cae5478e03365289f3222f4c0698e73c1c2569b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Mon, 24 Mar 2025 15:33:18 +0100
Subject: [PATCH 08/23] change template

---
 .../neurons/ShiftLif/shiftlif_parallel.vhd    | 15 ++++++++------
 .../modnef_driver/drivers/xstep_driver.py     |  3 ++-
 modneflib/modnef/templates/dataset.py         | 20 +++++++++++++++++++
 modneflib/modnef/templates/evaluation.py      | 16 ++-------------
 modneflib/modnef/templates/train.py           | 18 +----------------
 modneflib/modnef/templates/vhdl_generation.py | 17 ++--------------
 6 files changed, 36 insertions(+), 53 deletions(-)
 create mode 100644 modneflib/modnef/templates/dataset.py

diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
index 6c6989f..32a3aca 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/shiftlif_parallel.vhd
@@ -127,7 +127,7 @@ architecture Behavioral of ShiftLif_Parallel is
   end component;
 
   -- type definition
-  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration);
+  type transmission_state_t is (idle, voltage_update, check_arbitration, request, accept, wait_arbitration, arbitration_finish);
   type reception_state_t    is (idle, request, get_data);
 
   -- ram signals
@@ -249,9 +249,8 @@ begin
 
           when check_arbitration =>
             if spikes = no_spike then
-              transmission_state <= idle;
+              transmission_state <= arbitration_finish;
               o_emu_busy <= '0';
-              tr_fsm_en := '0';
             else
               transmission_state <= request;
               arb_spikes <= spikes;
@@ -278,12 +277,15 @@ begin
           when wait_arbitration =>
             start_arb <= '0';
             if arb_busy = '0' then
-              transmission_state <= idle;
-              o_emu_busy <= '0';
-              tr_fsm_en := '0';
+              transmission_state <= arbitration_finish;
             else
               transmission_state <= wait_arbitration;
             end if;  
+              
+          when arbitration_finish =>
+            transmission_state <= idle;
+            o_emu_busy <= '0';
+            tr_fsm_en := '0';
         end case;
       end if;
     end if;
@@ -337,3 +339,4 @@ begin
   end generate neuron_generation;
 
 end Behavioral;
+
diff --git a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
index 778b95b..6b5b216 100644
--- a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
+++ b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
@@ -160,8 +160,9 @@ class XStep_Driver(ModNEF_Driver):
         step_send += 1
 
       if step == len(sample_aer)-1:
+        print(len(data))
         emulation_result = self.rust_driver.data_transmission(data, reset_membrane)
-        
+        print("hi")
         res_step = self._unpack_data(emulation_result)
         for rs in res_step:
           for aer in rs:
diff --git a/modneflib/modnef/templates/dataset.py b/modneflib/modnef/templates/dataset.py
new file mode 100644
index 0000000..e0b86fc
--- /dev/null
+++ b/modneflib/modnef/templates/dataset.py
@@ -0,0 +1,20 @@
+import os
+import tonic
+from torch.utils.data import DataLoader
+
+"""DataSet Definition"""
+dataset_path = f"{os.environ['HOME']}/datasets"
+
+# data set definition, change to your dataset
+sensor_size = tonic.datasets.NMNIST.sensor_size
+frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=5)
+
+train_set = tonic.datasets.NMNIST(save_to=dataset_path, train=True, transform=frame_transform)
+test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
+
+# batch loader
+batch_size = 64
+
+trainLoader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+validationLoader = None
\ No newline at end of file
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index ebcc8f8..91cba66 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -7,6 +7,7 @@ import torch
 from run_lib import *
 import sys
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -53,20 +54,7 @@ if __name__ == "__main__":
   conf_matrix_file = "confusion_matrix.png"
   conf_matrix_classes = [str(i) for i in range(10)]
   
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=10)
-
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
   
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-
   if kind == "eval":
     acc, y_pred, y_true = evaluation(
       model=model, 
@@ -99,7 +87,7 @@ if __name__ == "__main__":
     exit(-1)
 
   if save_conf_matrix:
-    confusion_matrix(
+    conf_matrix(
       y_true=y_true,
       y_pred=y_pred,
       file_name=conf_matrix_file,
diff --git a/modneflib/modnef/templates/train.py b/modneflib/modnef/templates/train.py
index b67de45..0f48125 100644
--- a/modneflib/modnef/templates/train.py
+++ b/modneflib/modnef/templates/train.py
@@ -6,6 +6,7 @@ from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -31,23 +32,6 @@ if __name__ == "__main__":
   save_plot = False
   save_history = False
   output_path = "."
-  
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=5)
-
-  train_set = tonic.datasets.NMNIST(save_to=dataset_path, train=True, transform=frame_transform)
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
-  
-  trainLoader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
-  validationLoader = None
 
   train(
     model=model, 
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index fcca214..563d1a1 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -6,6 +6,7 @@ from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
+from dataset import *
 
 if __name__ == "__main__":
 
@@ -33,21 +34,7 @@ if __name__ == "__main__":
   """VHDL file definition"""
   output_path = "."
   file_name = "template_vhdl_model.vhd"
-  driver_config_path = "driver_config"
-  
-  """DataSet Definition"""
-  dataset_path = f"{os.environ['HOME']}/datasets"
-
-  # data set definition, change to your dataset
-  sensor_size = tonic.datasets.NMNIST.sensor_size
-  frame_transform = tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=10)
-
-  test_set = tonic.datasets.NMNIST(save_to=dataset_path, train=False, transform=frame_transform)
-
-  # batch loader
-  batch_size = 64
-  
-  testLoader = DataLoader(test_set, batch_size=batch_size, shuffle=True, drop_last = True, collate_fn = tonic.collation.PadTensors(batch_first=True))
+  driver_config_path = "driver_config.yml"
 
   acc, y_pred, y_true = hardware_estimation(
     model=model, 
-- 
GitLab
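
The patch above factors the tonic/NMNIST loader setup out of evaluation.py, train.py and
vhdl_generation.py into a single templates/dataset.py that the scripts import with
"from dataset import *". A minimal sketch of another project script consuming the shared
module (the loader names come from dataset.py above; the batch-counting helper is
illustrative only):

    # sketch: consuming the shared dataset module from a project script
    from dataset import trainLoader, testLoader, validationLoader

    def count_batches(loader):
        """Count the batches a loader yields; validationLoader may be None."""
        if loader is None:
            return 0
        return sum(1 for _ in loader)

    if __name__ == "__main__":
        print("train batches:", count_batches(trainLoader))
        print("test batches:", count_batches(testLoader))
        print("validation batches:", count_batches(validationLoader))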


From 256235ce6ca935468b54e51c5a884536f08ac4bf Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 24 Mar 2025 17:24:18 +0100
Subject: [PATCH 09/23] remove print

---
 modneflib/modnef/modnef_driver/drivers/xstep_driver.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
index 6b5b216..1f6d31a 100644
--- a/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
+++ b/modneflib/modnef/modnef_driver/drivers/xstep_driver.py
@@ -160,9 +160,7 @@ class XStep_Driver(ModNEF_Driver):
         step_send += 1
 
       if step == len(sample_aer)-1:
-        print(len(data))
         emulation_result = self.rust_driver.data_transmission(data, reset_membrane)
-        print("hi")
         res_step = self._unpack_data(emulation_result)
         for rs in res_step:
           for aer in rs:
-- 
GitLab


From 3fe54c44ce67f18b7de34498e1b563d4f30c936a Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Tue, 25 Mar 2025 22:59:14 +0100
Subject: [PATCH 10/23] add test quantizer

---
 .../modnef_neurons/modnef_torch_neuron.py      | 18 +++++++++++++++++-
 .../modnef_neurons/srlif_model/shiftlif.py     |  3 +++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index c4a9800..06e00f6 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -14,6 +14,21 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
 
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class QuantizeSTE(torch.autograd.Function):
+    """Quantization avec Straight-Through Estimator (STE)"""
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # STE: pass the gradient through unchanged
+
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
   "MinMaxQuantizer" : MinMaxQuantizer,
@@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron):
     
     for p in self.parameters():
       p.data = self.quantizer(p.data, unscale=unscale)
+      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
+      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
   
   def quantize_hp(self, unscale : bool = True):
     """
@@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron):
 
     for p in self.parameters():
       p.data = self.quantizer.clamp(p.data)
-      print("clamp")
   
   def run_quantize(self, mode=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index f2c95cb..b0d7586 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -19,6 +19,8 @@ from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
+    
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron):
       self.mem.data = self.quantizer(self.mem.data, True)
       input_.data = self.quantizer(input_.data, True)
 
+
     self.mem = self.mem+input_
 
     if self.reset_mechanism == "subtract":
-- 
GitLab
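
The QuantizeSTE class added above is a straight-through estimator: the forward pass returns
the quantized tensor, while the backward pass hands the incoming gradient back unchanged, so
quantized values are used in the forward computation but the float weights keep receiving
useful gradients. A self-contained sketch of the mechanism (FakeQuantizer is a hypothetical
stand-in for the ModNEF quantizers, which are called as quantizer(tensor, unscale)):

    # sketch: straight-through estimator around a fake fixed-point quantizer
    import torch

    class FakeQuantizer:
        """Hypothetical stand-in: round to a fixed-point grid and scale back (unscale=True)."""
        def __init__(self, bitwidth=8):
            self.scale = 2.0 ** (bitwidth - 1)
        def __call__(self, x, unscale=False):
            q = torch.round(x * self.scale)
            return q / self.scale if unscale else q

    class QuantizeSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, quantizer):
            return quantizer(x, True)        # forward: quantized (and unscaled) value
        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, None         # backward: gradient passes through unchanged

    w = torch.randn(4, requires_grad=True)
    QuantizeSTE.apply(w, FakeQuantizer()).sum().backward()
    assert torch.allclose(w.grad, torch.ones_like(w))  # STE leaves the gradient untouched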


From adc7744bbcea3982f5c136f14efa937cb412689a Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Fri, 28 Mar 2025 09:30:25 +0100
Subject: [PATCH 11/23] add qat test

---
 .../arch_builder/modules/ShiftLIF/shiftlif.py       |  2 +-
 modneflib/modnef/modnef_torch/model.py              | 13 +++++++++++--
 .../modnef_neurons/modnef_torch_neuron.py           |  5 ++++-
 .../modnef_neurons/srlif_model/rshiftlif.py         |  1 -
 .../modnef_neurons/srlif_model/shiftlif.py          |  4 +++-
 5 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
index 929ed26..35bb734 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
@@ -212,7 +212,7 @@ class ShiftLif(ModNEFArchMod):
 
       self.v_threshold = self.quantizer(self.v_threshold)
     
-    mem_file.close()        
+    mem_file.close() 
 
   def to_debugger(self, output_file : str = ""):
     """
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index ae40072..0da59bc 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -92,6 +92,15 @@ class ModNEFModel(nn.Module):
 
     return super().train(mode=mode)
   
+  def init_quantizer(self):
+    """
+    Initialize the quantizer of each layer
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.init_quantizer()
+  
   def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
@@ -106,11 +115,11 @@ class ModNEFModel(nn.Module):
       if isinstance(m, ModNEFNeuron):
         m.quantize(force_init=force_init)
 
-  def clamp(self):
+  def clamp(self, force_init=False):
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.clamp()
+        m.clamp(force_init=force_init)
 
   def train(self, mode : bool = True, quant : bool = False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 06e00f6..b6a0432 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron):
     self.quantize_weight()
     self.quantize_hp()
   
-  def clamp(self):
+  def clamp(self, force_init=False):
     """
     Clamp synaptic weight
     """
 
+    if force_init:
+      self.init_quantizer()
+
     for p in self.parameters():
       p.data = self.quantizer.clamp(p.data)
   
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 3cd77b3..f3dda9a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index b0d7586..1906fb5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
+    if self.quantization_flag:
+      self.mem.data = self.quantizer(self.mem.data, True)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
@@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
-- 
GitLab
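
With init_quantizer() and the force_init flag on clamp() exposed at the model level, a
quantization-aware training step can keep the float weights inside the quantizer's range
after every optimizer update. A sketch of how the calls might be ordered (the real loop
lives in run_lib, which is not shown here; the forward call shape, loss_fn and optimizer
are assumptions):

    # sketch of one QAT-style training step using the model-level API
    def qat_step(model, batch, target, loss_fn, optimizer, first_batch=False):
        model.train(mode=True, quant=True)   # quantized forward pass (ModNEFModel.train)
        if first_batch:
            model.init_quantizer()           # initialize quantizer ranges from current weights
        optimizer.zero_grad()
        output = model(batch)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        model.clamp(force_init=False)        # keep float weights inside the quantizer bounds
        return loss.item()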


From dbfdd4f21bdfb492f68ddca32bd78ee63345eb43 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 28 Mar 2025 14:36:52 +0100
Subject: [PATCH 12/23] add quantizer test

---
 .../modnef_neurons/blif_model/rblif.py        | 26 ++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 01f59f9..c759fcf 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -18,6 +18,21 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
 
+import torch.autograd as autograd
+
+class QuantizeMembrane(autograd.Function):
+    @staticmethod
+    def forward(ctx, U, quantizer):
+        max_val = U.abs().max().detach()  # detach so the gradient is not blocked
+        U_quant = quantizer(U, True)
+        ctx.save_for_backward(U, quantizer.scale_factor)
+        return U_quant
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input, factor = ctx.saved_tensors
+        return grad_output, None
+
 class RBLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent BLIF neuron model
@@ -243,10 +258,15 @@ class RBLIF(ModNEFNeuron):
 
     rec = self.reccurent(self.spk)
 
+    # if self.quantization_flag:
+    #   self.mem.data = self.quantizer(self.mem.data, True)
+    #   input_.data = self.quantizer(input_.data, True)
+    #   rec.data = self.quantizer(rec.data, True)
+
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec.data = self.quantizer(rec.data, True)
+      self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
+      input_ = QuantizeMembrane.apply(input_, self.quantizer)
+      rec = QuantizeMembrane.apply(rec, self.quantizer)
 
     if self.reset_mechanism == "subtract":
       self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
-- 
GitLab


From f5298795e20795b72df3d4dfdd8cfce022c8e7ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Sat, 29 Mar 2025 20:10:30 +0100
Subject: [PATCH 13/23] add qat test

---
 modneflib/modnef/modnef_torch/model.py        | 15 ++++++++++++++
 .../modnef_neurons/blif_model/blif.py         |  4 +---
 .../modnef_neurons/blif_model/rblif.py        | 20 +++++++++----------
 .../modnef_neurons/srlif_model/shiftlif.py    |  5 +++--
 .../quantizer/dynamic_scale_quantizer.py      |  2 +-
 5 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 0da59bc..152cd93 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -100,6 +100,21 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.init_quantizer()
+
+  def quantize_weight(self, force_init=False):
+    """
+    Quantize synaptic weight
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.init_quantizer()
+        m.quantize_weight()
   
   def quantize(self, force_init=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 50f18d8..2ba0908 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -232,6 +232,7 @@ class BLIF(ModNEFNeuron):
       input_.data = self.quantizer(input_.data, True)
       self.mem.data = self.quantizer(self.mem.data, True)
 
+
     self.reset = self.mem_reset(self.mem)
 
     if self.reset_mechanism == "subtract":
@@ -241,9 +242,6 @@ class BLIF(ModNEFNeuron):
     else:
       self.mem = self.mem*self.beta
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index c759fcf..586048c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -258,15 +258,15 @@ class RBLIF(ModNEFNeuron):
 
     rec = self.reccurent(self.spk)
 
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
-    #   input_.data = self.quantizer(input_.data, True)
-    #   rec.data = self.quantizer(rec.data, True)
-
     if self.quantization_flag:
-      self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
-      input_ = QuantizeMembrane.apply(input_, self.quantizer)
-      rec = QuantizeMembrane.apply(rec, self.quantizer)
+      self.mem.data = self.quantizer(self.mem.data, True)
+      input_.data = self.quantizer(input_.data, True)
+      rec.data = self.quantizer(rec.data, True)
+
+    # if self.quantization_flag:
+    #   self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
+    #   input_ = QuantizeMembrane.apply(input_, self.quantizer)
+    #   rec = QuantizeMembrane.apply(rec, self.quantizer)
 
     if self.reset_mechanism == "subtract":
       self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
@@ -279,8 +279,8 @@ class RBLIF(ModNEFNeuron):
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
       self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+    # if self.quantization_flag:
+    #   self.mem.data = self.quantizer(self.mem.data, True)
 
     self.spk = self.fire(self.mem)
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 1906fb5..654027d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -264,8 +264,8 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+    # if self.quantization_flag:
+    #   self.mem.data = self.quantizer(self.mem.data, True)
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
@@ -331,6 +331,7 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    print(self.threshold)
 
 
   @classmethod
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 2631170..0f25c44 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -114,7 +114,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       weight = torch.Tensor(weight)
 
     if not torch.is_tensor(rec_weight):
-      rec_weight = torch.Tensor(weight)
+      rec_weight = torch.Tensor(rec_weight)
 
     if self.signed==None:
       self.signed = torch.min(weight.min(), rec_weight.min())<0.0
-- 
GitLab


From 0c9b2c213a9b81381648ab37eca335779545dcf2 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Sat, 29 Mar 2025 21:06:25 +0100
Subject: [PATCH 14/23] fix xstep timer uart handshake, driver queue checks and shiftlif
 hardware estimation

---
 ModNEF_Sources/modules/uart/uart_xstep_timer.vhd       |  2 +-
 .../modnef/modnef_driver/drivers/xstep_timer_driver.py | 10 ++++++----
 .../modnef_neurons/srlif_model/shiftlif.py             |  4 ++++
 3 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/ModNEF_Sources/modules/uart/uart_xstep_timer.vhd b/ModNEF_Sources/modules/uart/uart_xstep_timer.vhd
index e38d9b2..53d5a40 100644
--- a/ModNEF_Sources/modules/uart/uart_xstep_timer.vhd
+++ b/ModNEF_Sources/modules/uart/uart_xstep_timer.vhd
@@ -220,7 +220,7 @@ begin
             end if;
 
           when wait_out_aer =>
-            if i_emu_ready = '1' then
+            if i_emu_ready = '1' and network_to_uart_busy='0' then
               count_time <= '0';
               if read_empty = '1' then -- no more data to process
                 emu_state <= push_timer;
diff --git a/modneflib/modnef/modnef_driver/drivers/xstep_timer_driver.py b/modneflib/modnef/modnef_driver/drivers/xstep_timer_driver.py
index 58de027..ebb2357 100644
--- a/modneflib/modnef/modnef_driver/drivers/xstep_timer_driver.py
+++ b/modneflib/modnef/modnef_driver/drivers/xstep_timer_driver.py
@@ -144,7 +144,7 @@ class XStep_Timer_Driver(ModNEF_Driver):
     sample_aer = [transformation(input_sample[step]) for step in range(len(input_sample))]
 
     for es in range(extra_step):
-      sample_aer.append([])
+      sample_aer.append([0, 1, 2, 3])
     
     step_send = 0
     res = [0 for _ in range(self.output_layer_size)]
@@ -155,11 +155,11 @@ class XStep_Timer_Driver(ModNEF_Driver):
 
     for step in range(len(sample_aer)):
       next_data = []
+      if(len(sample_aer[step])+len(next_data) > 256**self.queue_read_width):
+        print(f"warning, the read queue cannot encode the len of emulation step : acutal len {len(sample_aer[step])}, maximum len {256**self.qeue_read_width}")
       next_data.append(len(sample_aer[step]))
       next_data.extend(sample_aer[step])
-      if(len(sample_aer[step]) > 256**self.queue_read_width):
-        print(f"warning, the read queue cannot encode the len of emulation step : acutal len {len(sample_aer[step])}, maximum len {256**self.qeue_read_width}")
-      if len(data) + len(next_data) > self.queue_read_depth or (step_send+1)*self.output_layer_size > self.queue_write_depth-2:
+      if len(data) + len(next_data) > self.queue_read_depth or step_send*self.output_layer_size > self.queue_write_depth-2:
         emulation_result = self.rust_driver.data_transmission(data, False)
 
         res_step = self._unpack_data(emulation_result)
@@ -201,8 +201,10 @@ class XStep_Timer_Driver(ModNEF_Driver):
 
     #print((data[0]*256 + data[1])*self.clock_period)
 
+
     self.sample_time += (data[0]*256 + data[1])*self.clock_period
 
+
     while index < len(data):
       n_data = data[index]
       index += 1
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 1906fb5..395b4bc 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -257,6 +257,10 @@ class ShiftLIF(ModNEFNeuron):
 
     self.mem = self.mem+input_
 
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
+      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+
     if self.reset_mechanism == "subtract":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
     elif self.reset_mechanism == "zero":
-- 
GitLab
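
The warning added in the driver above compares the per-step AER length against
256**queue_read_width, the number of distinct values that fit in queue_read_width bytes of
the read-queue length field. A quick illustration of that capacity check (the widths and
lengths below are made-up values, not taken from a real board configuration):

    # capacity of the length field as a function of queue_read_width (in bytes)
    for queue_read_width in (1, 2, 3):
        print(f"queue_read_width={queue_read_width}: up to {256 ** queue_read_width} values")

    # the same guard as in the driver, on made-up numbers
    step_len, pending, queue_read_width = 300, 4, 1
    if step_len + pending > 256 ** queue_read_width:
        print("warning: the read queue cannot encode the length of this emulation step")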


From 2daec63a28a7655f37eef6f3d289939e4c407b78 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 31 Mar 2025 12:49:41 +0200
Subject: [PATCH 15/23] add test qat

---
 .../modnef_neurons/modnef_torch_neuron.py     | 35 +++++++++++++------
 .../modnef_neurons/srlif_model/shiftlif.py    | 12 +++++--
 .../quantizer/dynamic_scale_quantizer.py      |  1 -
 3 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index b6a0432..56e3d10 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -14,20 +14,23 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
 
-
-import torch
-import torch.nn as nn
 import torch.nn.functional as F
+import brevitas.nn as qnn
 
 class QuantizeSTE(torch.autograd.Function):
-    """Quantization avec Straight-Through Estimator (STE)"""
     @staticmethod
-    def forward(ctx, x, quantizer):
-        return quantizer(x, True)
+    def forward(ctx, weights, quantizer):
+        
+        q_weights = quantizer(weights, True)
+
+        #ctx.scale = quantizer.scale_factor  # save the scale for backward
+        return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None  # STE: pass the gradient through unchanged
+        # STE: pass the gradient straight through to the float weights
+        return grad_output, None
+
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -74,6 +77,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -87,6 +91,9 @@ class ModNEFNeuron(SpikingNeuron):
 
     self.quantizer = quantizer
 
+
+    self.alpha = 0.9
+
   @classmethod
   def from_dict(cls, dict):
     """
@@ -107,7 +114,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self, unscale : bool = True):
+  def quantize_weight(self, unscale : bool = True, ema = False):
     """
     synaptic weight quantization
 
@@ -117,9 +124,15 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      p.data = self.quantizer(p.data, unscale=unscale)
-      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
-      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      
+      # if ema:
+      #   print(self.alpha)
+      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
+      #   self.alpha *= 0.1
+      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      # else:
+      p.data = self.quantizer(p.data, True)
+     
   
   def quantize_hp(self, unscale : bool = True):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 5d938b4..a7c1b1f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -20,7 +20,15 @@ from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
     
-
+class QuantizeSTE(torch.autograd.Function):
+    """Quantization avec Straight-Through Estimator (STE)"""
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # STE: pass the gradient through unchanged
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,7 +260,7 @@ class ShiftLIF(ModNEFNeuron):
 
     if self.quantization_flag:
       self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
+      #input_.data = self.quantizer(input_.data, True)
 
 
     self.mem = self.mem+input_
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 0f25c44..f1d7497 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -161,5 +161,4 @@ class DynamicScaleFactorQuantizer(Quantizer):
     -------
     Tensor
     """
-
     return data*self.scale_factor
\ No newline at end of file
-- 
GitLab
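
The commented-out branch in quantize_weight above sketches an EMA-style variant in which the
stored weights are only partially pulled toward their quantized values on each call, with
self.alpha decaying over time. A self-contained sketch of that blending idea (the alpha
schedule and the toy rounding quantizer are assumptions, not the library's settled behaviour):

    # sketch: blend a float weight toward its quantized value in place
    import torch

    def ema_quantize_(weight, quantize, alpha):
        # alpha close to 1.0 keeps mostly the float value;
        # alpha close to 0.0 snaps the weight onto the quantized grid
        with torch.no_grad():
            weight.copy_(alpha * weight + (1.0 - alpha) * quantize(weight))

    w = torch.randn(5)
    quantize = lambda t: torch.round(t * 16) / 16   # toy 1/16-step quantizer
    for alpha in (0.9, 0.5, 0.1):                   # decaying schedule, as in the comment
        ema_quantize_(w, quantize, alpha)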


From ff5d695f35dca9fd19eb98c8469d7620fd69e3f0 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Tue, 1 Apr 2025 10:02:51 +0200
Subject: [PATCH 16/23] brevitas test

---
 modneflib/modnef/modnef_torch/model.py         |  3 ++-
 .../modnef_neurons/modnef_torch_neuron.py      | 18 ++++++++----------
 .../modnef_neurons/srlif_model/shiftlif.py     | 18 ++++++++++++------
 3 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 152cd93..2c0d687 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -113,7 +113,8 @@ class ModNEFModel(nn.Module):
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.init_quantizer()
+        if force_init:
+          m.init_quantizer()
         m.quantize_weight()
   
   def quantize(self, force_init=False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 56e3d10..55298a1 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -16,6 +16,7 @@ from snntorch.surrogate import fast_sigmoid
 
 import torch.nn.functional as F
 import brevitas.nn as qnn
+import brevitas.quant as bq
 
 class QuantizeSTE(torch.autograd.Function):
     @staticmethod
@@ -23,12 +24,13 @@ class QuantizeSTE(torch.autograd.Function):
         
         q_weights = quantizer(weights, True)
 
-        #ctx.scale = quantizer.scale_factor  # save the scale for backward
+        #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for backward
         return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
         # STE: pass the gradient straight through to the float weights
+        #scale_factor, = ctx.saved_tensors
         return grad_output, None
 
 
@@ -77,7 +79,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_width=5)
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -124,14 +126,10 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      
-      # if ema:
-      #   print(self.alpha)
-      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
-      #   self.alpha *= 0.1
-      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # else:
-      p.data = self.quantizer(p.data, True)
+      print(p)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+
+    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
      
   
   def quantize_hp(self, unscale : bool = True):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index a7c1b1f..b0409a4 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -21,14 +21,21 @@ from snntorch import LIF
 
     
 class QuantizeSTE(torch.autograd.Function):
-    """Quantization avec Straight-Through Estimator (STE)"""
     @staticmethod
-    def forward(ctx, x, quantizer):
-        return quantizer(x, True)
+    def forward(ctx, weights, quantizer):
+        
+        q_weights = quantizer(weights, True)
+
+        #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for backward
+        return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None  # STE: pass the gradient through unchanged
+        # STE: pass the gradient straight through to the float weights
+        #scale_factor, = ctx.saved_tensors
+        return grad_output, None
+    
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -259,7 +266,7 @@ class ShiftLIF(ModNEFNeuron):
     self.reset = self.mem_reset(self.mem)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
       #input_.data = self.quantizer(input_.data, True)
 
 
@@ -343,7 +350,6 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    print(self.threshold)
 
 
   @classmethod
-- 
GitLab


From 6fd614f6d3b10b05f7c469fa8d8b18f876b64ea2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Tue, 1 Apr 2025 17:52:37 +0200
Subject: [PATCH 17/23] add quant linear test

---
 .../modnef_neurons/modnef_torch_neuron.py            | 12 +++++++++++-
 .../modnef_neurons/srlif_model/shiftlif.py           |  8 ++++----
 modneflib/modnef/modnef_torch/quantLinear.py         |  0
 3 files changed, 15 insertions(+), 5 deletions(-)
 create mode 100644 modneflib/modnef/modnef_torch/quantLinear.py

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 55298a1..10ef31c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -18,6 +18,10 @@ import torch.nn.functional as F
 import brevitas.nn as qnn
 import brevitas.quant as bq
 
+from brevitas.core.quant import QuantType
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+
 class QuantizeSTE(torch.autograd.Function):
     @staticmethod
     def forward(ctx, weights, quantizer):
@@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
     
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
+    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
+    #                           weight_quant_type=QuantType.INT, 
+    #                                  weight_bit_width=8,
+    #                                  weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
+    #                                  weight_scaling_impl_type=ScalingImplType.CONST,
+    #                                  weight_scaling_const=1.0
+    # )
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index b0409a4..429d24a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function):
         
         q_weights = quantizer(weights, True)
 
-        #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for backward
+        ctx.save_for_backward(quantizer.scale_factor)  # save the scale for backward
         return q_weights
 
     @staticmethod
     def backward(ctx, grad_output):
         # STE: pass the gradient straight through to the float weights
-        #scale_factor, = ctx.saved_tensors
-        return grad_output, None
+        scale_factor, = ctx.saved_tensors
+        return grad_output*scale_factor, None
     
 
 class ShiftLIF(ModNEFNeuron):
@@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron):
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
-      #input_.data = self.quantizer(input_.data, True)
+      input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
 
 
     self.mem = self.mem+input_
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
new file mode 100644
index 0000000..e69de29
-- 
GitLab
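
This patch creates an empty modnef_torch/quantLinear.py; the following patch fills it in and
switches the neuron models to calls of the form self.fc(input_, quant). The real
implementation is not shown in this excerpt, so the following is only a guess at the shape
those call sites imply, using the detach trick for straight-through weight quantization:

    # sketch: a bias-free linear layer that optionally quantizes its weights per forward pass
    import torch.nn as nn
    import torch.nn.functional as F

    class QuantLinear(nn.Linear):
        def __init__(self, in_features, out_features):
            super().__init__(in_features, out_features, bias=False)

        def forward(self, input_, quantizer=None):
            weight = self.weight
            if quantizer is not None:
                # forward uses quantized weights, backward sees the float weights (STE)
                weight = weight + (quantizer(weight, True) - weight).detach()
            return F.linear(input_, weight)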


From 025865dcd6173f85022ec30d4ffe8818c88ded91 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Thu, 3 Apr 2025 11:01:49 +0200
Subject: [PATCH 18/23] add qat

---
 .../modnef/arch_builder/modules/BLIF/blif.py  |  4 ++
 modneflib/modnef/modnef_torch/__init__.py     |  3 +-
 modneflib/modnef/modnef_torch/model.py        | 30 ++++++--
 .../modnef_neurons/blif_model/blif.py         | 41 +++++------
 .../modnef_neurons/blif_model/rblif.py        | 72 ++++++-------------
 .../modnef_neurons/modnef_torch_neuron.py     | 56 +++------------
 .../modnef_neurons/slif_model/rslif.py        | 68 ++++++++----------
 .../modnef_neurons/slif_model/slif.py         | 56 ++++++---------
 .../modnef_neurons/srlif_model/rshiftlif.py   | 50 ++++++-------
 .../modnef_neurons/srlif_model/shiftlif.py    | 57 ++++-----------
 modneflib/modnef/modnef_torch/quantLinear.py  | 59 +++++++++++++++
 modneflib/modnef/quantizer/__init__.py        |  3 +-
 modneflib/modnef/quantizer/ste_quantizer.py   | 62 ++++++++++++++++
 13 files changed, 291 insertions(+), 270 deletions(-)
 create mode 100644 modneflib/modnef/quantizer/ste_quantizer.py

diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
index 4d5f09a..b6a0663 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
@@ -193,12 +193,16 @@ class BLif(ModNEFArchMod):
     bw = self.quantizer.bitwidth
 
     mem_file = open(f"{output_path}/{self.mem_init_file}", 'w')
+
+    truc = open(f"temp_{self.mem_init_file}", 'w')
     
     if self.quantizer.signed:
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
+
           w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw)
+          truc.write(f"{i} {j} {two_comp(self.quantizer(weights[i][j]), bw)}\n")
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/modnef_torch/__init__.py b/modneflib/modnef/modnef_torch/__init__.py
index e57895f..2053dab 100644
--- a/modneflib/modnef/modnef_torch/__init__.py
+++ b/modneflib/modnef/modnef_torch/__init__.py
@@ -9,4 +9,5 @@ Descriptions: ModNEF torch lib definition
 
 from .modnef_neurons import *
 from .model_builder import ModNEFModelBuilder
-from .model import ModNEFModel
\ No newline at end of file
+from .model import ModNEFModel
+from .quantLinear import QuantLinear
\ No newline at end of file
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 2c0d687..f81295d 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -1,17 +1,15 @@
 """
 File name: model
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
-Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron
+Dependencies: torch, snntorch, modnef_torch_neuron, modnef_driver
 Descriptions: ModNEF SNN Model
 """
 
-import modnef.modnef_torch.modnef_neurons as mn
 import torch.nn as nn
 import torch
-from modnef.arch_builder import *
 from modnef.modnef_driver import load_driver_from_yaml
 from modnef.modnef_torch.modnef_neurons import ModNEFNeuron
 
@@ -101,6 +99,22 @@ class ModNEFModel(nn.Module):
       if isinstance(m, ModNEFNeuron):
         m.init_quantizer()
 
+  def quantize_hp(self, force_init=False):
+    """
+    Quantize neuron hyper parameters
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
+
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        if force_init:
+          m.init_quantizer()
+        m.quantize_hp()
+
   def quantize_weight(self, force_init=False):
     """
     Quantize synaptic weight
@@ -132,6 +146,14 @@ class ModNEFModel(nn.Module):
         m.quantize(force_init=force_init)
 
   def clamp(self, force_init=False):
+    """
+    Clamp synaptic weight within the quantizer bounds
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
+    """
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 2ba0908..c435ed1 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -1,7 +1,7 @@
 """
 File name: blif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, math, snntorch, modnef.archbuilder, modnef_torch_neuron, modnef.quantizer
@@ -9,10 +9,8 @@ Descriptions: ModNEF torch BLIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
 from math import log, ceil
-from snntorch import Leaky
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
@@ -223,28 +221,28 @@ class BLIF(ModNEFNeuron):
     if not spk==None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
-
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      self.mem.data = self.quantizer(self.mem.data, True)
+    forward_current = self.fc(input_, quant)
 
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
+    
+    self.mem = self.mem + forward_current
 
-    self.reset = self.mem_reset(self.mem)
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
-      self.mem = (self.mem+input_)*self.beta-self.reset*self.threshold
+      self.mem = self.mem*self.beta#-self.reset*self.threshold
     elif self.reset_mechanism == "zero":
-      self.mem = (self.mem+input_)*self.beta-self.reset*self.mem
+      self.mem = self.mem*self.beta-self.reset*self.mem
     else:
       self.mem = self.mem*self.beta
 
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -295,19 +293,14 @@ class BLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume you already initialize quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
+    self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 586048c..8b49564 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -1,7 +1,7 @@
 """
 File name: rblif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,29 +9,14 @@ Descriptions: ModNEF torch reccurrent BLIF neuron model
 Based on snntorch.RLeaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import Leaky
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
+from modnef.modnef_torch.quantLinear import QuantLinear
 
-import torch.autograd as autograd
-
-class QuantizeMembrane(autograd.Function):
-    @staticmethod
-    def forward(ctx, U, quantizer):
-        max_val = U.abs().max().detach()  # Détachement pour éviter de bloquer le gradient
-        U_quant = quantizer(U, True)
-        ctx.save_for_backward(U, quantizer.scale_factor)
-        return U_quant
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        grad_input, factor = ctx.saved_tensors
-        return grad_output, None
 
 class RBLIF(ModNEFNeuron):
   """
@@ -141,7 +126,7 @@ class RBLIF(ModNEFNeuron):
     
     self.register_buffer("beta", torch.tensor(beta))
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self._init_mem()
 
@@ -246,41 +231,35 @@ class RBLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    forward_current = self.fc(input_, quant)
 
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_, device=self.spk.device)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
 
-    self.reset = self.mem_reset(self.mem)
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    rec = self.reccurent(self.spk)
+    self.reset = self.mem_reset(self.mem)
+    
+    rec_current = self.reccurent(self.spk, quant)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec.data = self.quantizer(rec.data, True)
+    self.mem = self.mem + forward_current + rec_current
 
-    # if self.quantization_flag:
-    #   self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
-    #   input_ = QuantizeMembrane.apply(input_, self.quantizer)
-    #   rec = QuantizeMembrane.apply(rec, self.quantizer)
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
-      self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
+      self.mem = self.mem*self.beta-self.reset*self.threshold
     elif self.reset_mechanism == "zero":
-      self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.mem
+      self.mem = self.mem*self.beta-self.reset*self.mem
     else:
       self.mem = self.mem*self.beta
 
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
-
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -331,19 +310,14 @@ class RBLIF(ModNEFNeuron):
     )
     return module
   
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume you already initialize quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    self.beta.data = self.quantizer(self.beta.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
+    self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 10ef31c..49b31e5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -1,7 +1,7 @@
 """
 File name: modnef_torch_neuron
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.0.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch
@@ -12,31 +12,7 @@ import torch
 import torch.nn as nn
 from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
-from snntorch.surrogate import fast_sigmoid
-
-import torch.nn.functional as F
-import brevitas.nn as qnn
-import brevitas.quant as bq
-
-from brevitas.core.quant import QuantType
-from brevitas.core.restrict_val import RestrictValueType
-from brevitas.core.scaling import ScalingImplType
-
-class QuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, weights, quantizer):
-        
-        q_weights = quantizer(weights, True)
-
-        #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for backward
-        return q_weights
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # STE: pass the gradient straight through to the float weights
-        #scale_factor, = ctx.saved_tensors
-        return grad_output, None
-
+from ..quantLinear import QuantLinear
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -82,15 +58,8 @@ class ModNEFNeuron(SpikingNeuron):
       spike_grad=spike_grad,
       reset_mechanism=reset_mechanism
     )
-    
-    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
-    #                           weight_quant_type=QuantType.INT, 
-    #                                  weight_bit_width=8,
-    #                                  weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
-    #                                  weight_scaling_impl_type=ScalingImplType.CONST,
-    #                                  weight_scaling_const=1.0
-    # )
-    self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+
+    self.fc = QuantLinear(in_features=in_features, out_features=out_features)
 
 
     self.hardware_estimation_flag = False
@@ -126,7 +95,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self, unscale : bool = True, ema = False):
+  def quantize_weight(self):
     """
     synaptic weight quantization
 
@@ -136,21 +105,13 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      print(p)
       p.data = QuantizeSTE.apply(p.data, self.quantizer)
-
-    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
      
   
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you've already intialize quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
     
     raise NotImplementedError()
@@ -174,6 +135,11 @@ class ModNEFNeuron(SpikingNeuron):
   def clamp(self, force_init=False):
     """
     Clamp synaptic weight
+
+    Parameters
+    ----------
+    force_init = False : bool
+      force quantizer initialization
     """
 
     if force_init:
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 7d9c21f..c1915a7 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -1,7 +1,7 @@
 """
 File name: rslif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,13 @@ Descriptions: ModNEF torch reccurent SLIF neuron model
 Based on snntorch.RLeaky and snntroch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
-from modnef.quantizer import MinMaxQuantizer
+from modnef.quantizer import MinMaxQuantizer, QuantizeSTE
+from modnef.modnef_torch.quantLinear import QuantLinear
 
 class RSLIF(ModNEFNeuron):
   """
@@ -131,7 +130,7 @@ class RSLIF(ModNEFNeuron):
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self._init_mem()
 
@@ -238,45 +237,39 @@ class RSLIF(ModNEFNeuron):
 
     if not spk == None:
       self.spk = spk
-    
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.ones_like(input_)*self.v_rest
-    
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_)
+    forward_current = self.fc(input_, quant)
 
-    self.reset = self.mem_reset(self.mem)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest
 
-    rec_input = self.reccurent(self.spk)
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      rec_input.data = self.quantizer(rec_input.data, True)
-      self.mem = self.quantizer(self.mem.data, True)
 
-    self.mem = self.mem + input_ + rec_input
+    self.reset = self.mem_reset(self.mem)
+
+    rec_current = self.reccurent(self.spk, quant)
+
+    self.mem = self.mem + forward_current + rec_current 
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
+    # update neuron
     self.mem = self.mem - self.v_leak
+    min_reset = (self.mem<self.v_min).to(torch.float32)
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
-    self.spk = self.fire(self.mem)
+    spk = self.fire(self.mem)
 
-    do_spike_reset = (self.spk/self.graded_spikes_factor - self.reset)
-    do_min_reset = (self.mem<self.v_min).to(torch.float32)
-
-    self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest)
-    self.mem = self.mem - do_min_reset*(self.mem-self.v_rest)
-
-    return self.spk, self.mem
+    return spk, self.mem
   
   def get_builder_module(self, module_name : str, output_path : str = "."):
     """
@@ -324,21 +317,16 @@ class RSLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you've already intialize quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
-    self.v_min.data = self.quantizer(self.v_min.data, unscale)
-    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
-    self.threshold.data = self.quantizer(self.threshold, unscale)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 089519b..43e2e1c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -1,7 +1,7 @@
 """
 File name: slif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,12 @@ Descriptions: ModNEF torch SLIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
-from modnef.quantizer import MinMaxQuantizer
+from modnef.quantizer import MinMaxQuantizer, QuantizeSTE
 
 class SLIF(ModNEFNeuron):
   """
@@ -235,36 +233,33 @@ class SLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.ones_like(input_)*self.v_rest
-
-    if self.quantization_flag:
-      input_.data = self.quantizer(input_.data, True)
-      self.mem.data = self.quantizer(self.mem.data, True)
+    forward_current = self.fc(input_, quant)
 
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest
+    
+    
     self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem + input_
+    self.mem = self.mem + forward_current
+
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
-      
-    self.mem = self.mem-self.v_leak
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
+
+    # update neuron
+    self.mem = self.mem - self.v_leak
+    min_reset = (self.mem<self.v_min).to(torch.float32)
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     spk = self.fire(self.mem)
 
-    do_spike_reset = (spk/self.graded_spikes_factor - self.reset)
-    do_min_reset = (self.mem<self.v_min).to(torch.float32)
-
-    self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest)
-    self.mem = self.mem - do_min_reset*(self.mem-self.v_rest)
-
     return spk, self.mem
   
   def get_builder_module(self, module_name : str, output_path : str = "."):
@@ -314,21 +309,16 @@ class SLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     We assume you have already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
-    self.v_min.data = self.quantizer(self.v_min.data, unscale)
-    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
-    self.threshold.data = self.quantizer(self.threshold, unscale)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index f3dda9a..0ea9983 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: rsrlif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,14 +9,13 @@ Descriptions: ModNEF torch reccurent Shift LIF neuron model
 Based on snntorch.RLeaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
-from snntorch import LIF
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
+from modnef.modnef_torch.quantLinear import QuantLinear
 from math import log, ceil
-from modnef.quantizer import DynamicScaleFactorQuantizer
+from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE
 
 
 class RShiftLIF(ModNEFNeuron):
@@ -131,7 +130,7 @@ class RShiftLIF(ModNEFNeuron):
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
 
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = QuantLinear(out_features, out_features)
 
     self.register_buffer("beta", torch.tensor(beta))
 
@@ -246,25 +245,25 @@ class RShiftLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
+    quant = self.quantizer if self.quantization_flag else None
 
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    forward_current = self.fc(input_, quant)
 
-    if not self.spk.shape == input_.shape:
-      self.spk = torch.zeros_like(input_, device=self.spk.device)
-
-    self.reset = self.mem_reset(self.mem)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
+    
+    if not self.spk.shape == forward_current.shape:
+      self.spk = torch.zeros_like(forward_current, device=self.spk.device)
 
-    rec_input = self.reccurent(self.spk)
+    rec_current = self.reccurent(self.spk, quant)
 
-    if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
-      rec_input.data = self.quantizer(rec_input.data, True)
+    self.reset = self.mem_reset(self.mem)
 
+    self.mem = self.mem+forward_current+rec_current
 
-    self.mem = self.mem+input_+rec_input
+    if self.hardware_estimation_flag:
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
@@ -274,11 +273,7 @@ class RShiftLIF(ModNEFNeuron):
       self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
-
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -330,18 +325,13 @@ class RShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume you have already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 429d24a..5291733 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: srlif
 Author: Aurélie Saulquin  
-Version: 1.1.0
+Version: 1.2.1
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -9,32 +9,15 @@ Descriptions: ModNEF torch Shift LIF neuron model
 Based on snntorch.Leaky and snntorch.LIF class
 """
 
-import torch.nn as nn
 import torch
 from snntorch.surrogate import fast_sigmoid
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
-from modnef.quantizer import DynamicScaleFactorQuantizer
-from snntorch import LIF
+from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE
 
     
-class QuantizeSTE(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, weights, quantizer):
-        
-        q_weights = quantizer(weights, True)
-
-        ctx.save_for_backward(quantizer.scale_factor)  # save the scale factor for the backward pass
-        return q_weights
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        # STE: pass the gradient straight through to the float weights
-        scale_factor, = ctx.saved_tensors
-        return grad_output*scale_factor, None
-    
 
 class ShiftLIF(ModNEFNeuron):
   """
@@ -258,23 +241,20 @@ class ShiftLIF(ModNEFNeuron):
     if not spk == None:
       self.spk = spk
 
-    input_ = self.fc(input_)
-
-    if not self.mem.shape == input_.shape:
-      self.mem = torch.zeros_like(input_, device=self.mem.device)
+    quant = self.quantizer if self.quantization_flag else None
 
-    self.reset = self.mem_reset(self.mem)
+    forward_current = self.fc(input_, quant)
 
-    if self.quantization_flag:
-      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
-      input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
+    if not self.mem.shape == forward_current.shape:
+      self.mem = torch.zeros_like(forward_current, device=self.mem.device)
 
+    self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem+input_
+    self.mem = self.mem+forward_current
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min)
+      self.val_max = torch.max(self.mem.max(), self.val_max)
 
     if self.reset_mechanism == "subtract":
       self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
@@ -283,12 +263,8 @@ class ShiftLIF(ModNEFNeuron):
     else:
       self.mem = self.mem-self.__shift(self.mem)
 
-    # if self.quantization_flag:
-    #   self.mem.data = self.quantizer(self.mem.data, True)
-
-    if self.hardware_estimation_flag:
-      self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
-      self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
+    if self.quantization_flag:
+      self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
 
     self.spk = self.fire(self.mem)
 
@@ -338,18 +314,13 @@ class ShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_hp(self, unscale : bool = True):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization.
     We assume you have already initialized the quantizer
-
-    Parameters
-    ----------
-    unscale : bool = True
-      set to true if quantization must be simulate
     """
 
-    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
index e69de29..57f67b4 100644
--- a/modneflib/modnef/modnef_torch/quantLinear.py
+++ b/modneflib/modnef/modnef_torch/quantLinear.py
@@ -0,0 +1,59 @@
+"""
+File name: quantLinear
+Author: Aurélie Saulquin  
+Version: 0.1.1
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: torch, modnef.quantizer
+Descriptions: Quantized Linear torch layer
+"""
+
+import torch.nn as nn
+from modnef.quantizer import QuantizeSTE
+
+class QuantLinear(nn.Linear):
+  """
+  Quantized Linear torch layer
+  Extended from torch.nn.Linear
+
+  Methods
+  -------
+  forward(x, quantizer=None)
+    Apply the linear forward pass; if quantizer is not None, the quantized weights are used
+  """
+
+  def __init__(self, in_features : int, out_features : int):
+    """
+    Initialize class
+
+    Parameters
+    ----------
+    in_features : int
+      input features of layer
+    out_features : int
+      output features of layer
+    """
+
+    super().__init__(in_features=in_features, out_features=out_features, bias=False)
+
+  def forward(self, x, quantizer=None):
+    """
+    Apply the linear forward pass; if quantizer is not None, the quantized weights are used
+
+    Parameters
+    ----------
+    x : Tensor
+      input spikes
+    quantizer = None : Quantizer
+      quantization method.
+      If None, full-precision weights are used
+    """
+
+    if quantizer!=None:
+      w = QuantizeSTE.apply(self.weight, quantizer)
+      w.data = quantizer.clamp(w)
+    else:
+      w = self.weight
+
+
+    return nn.functional.linear(x, w)
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/__init__.py b/modneflib/modnef/quantizer/__init__.py
index 2455aa9..2ef9cb9 100644
--- a/modneflib/modnef/quantizer/__init__.py
+++ b/modneflib/modnef/quantizer/__init__.py
@@ -10,4 +10,5 @@ Descriptions: ModNEF quantizer method
 from .quantizer import Quantizer
 from .fixed_point_quantizer import FixedPointQuantizer
 from .min_max_quantizer import MinMaxQuantizer
-from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
\ No newline at end of file
+from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
+from .ste_quantizer import QuantizeSTE
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py
new file mode 100644
index 0000000..565cbd8
--- /dev/null
+++ b/modneflib/modnef/quantizer/ste_quantizer.py
@@ -0,0 +1,62 @@
+"""
+File name: ste_quantizer
+Author: Aurélie Saulquin  
+Version: 0.1.0
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: torch
+Descriptions: Straight-Through Estimator quantization method
+"""
+
+import torch
+
+class QuantizeSTE(torch.autograd.Function):
+  """
+  Straight-Through Estimator quantization method
+
+  Methods
+  -------
+  @staticmethod
+  forward(ctx, data, quantizer)
+    Apply quantization method to data
+  @staticmethod
+  backward(ctx, grad_output)
+    Returns backward gradient
+  """
+  
+  @staticmethod
+  def forward(ctx, data, quantizer):
+    """
+    Apply quantization method to data
+
+    Parameters
+    ----------
+    ctx : torch.autograd.function.BackwardCFunction
+      Autograd context used to store variables for the backward pass
+    data : Tensor
+      data to quantize
+    quantizer : Quantizer
+      quantization method applied to data
+    """
+    
+    q_data = quantizer(data, True)
+
+
+    ctx.scale = quantizer.scale_factor
+
+    return q_data
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    """
+    Return the backward gradient without modification
+
+    Parameters
+    ----------
+    ctx : torch.autograd.function.BackwardCFunction
+      Autograd context used to store variables for the backward pass
+    grad_output : Tensor
+      gradient
+    """
+    
+    return grad_output, None
\ No newline at end of file
-- 
GitLab


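Note on the patch above: QuantLinear and QuantizeSTE are the two halves of the quantization-aware forward pass. QuantLinear quantizes (and clamps) its weights on the fly when a quantizer is passed, and QuantizeSTE makes that step differentiable by passing the incoming gradient straight through to the full-precision weights. A minimal sketch of that behaviour follows; the ToyQuantizer is a stand-in written purely for illustration (it is not a ModNEF class), so no assumptions are made about the real quantizers' constructor signatures.

import torch
from modnef.quantizer import QuantizeSTE
from modnef.modnef_torch.quantLinear import QuantLinear

class ToyQuantizer:
  """Illustrative stand-in for a ModNEF Quantizer: 8-bit signed fixed-point."""
  def __init__(self, fixed_point=6, bitwidth=8):
    self.scale_factor = torch.tensor(float(2**fixed_point))
    self.bitwidth = bitwidth
    self.signed = True
  def __call__(self, data, unscale=False, clamp=False):
    q = torch.round(torch.as_tensor(data)*self.scale_factor)
    if clamp:
      q = torch.clamp(q, -2**(self.bitwidth-1), 2**(self.bitwidth-1)-1)
    return q/self.scale_factor if unscale else q
  def clamp(self, data):
    return torch.clamp(data, -1.0, 1.0-1.0/float(self.scale_factor))

w = torch.randn(4, 16, requires_grad=True)
w_q = QuantizeSTE.apply(w, ToyQuantizer())          # forward: quantize, then unscale
w_q.sum().backward()                                # backward: gradient passes straight through
print(torch.allclose(w.grad, torch.ones_like(w)))   # True: STE identity gradient

layer = QuantLinear(in_features=16, out_features=4)
out = layer(torch.rand(2, 16), ToyQuantizer())      # quantized weights used for the projection
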
From dcd96623965c1a4ffe3152769c4873a8355f6fdc Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Fri, 4 Apr 2025 08:34:39 +0200
Subject: [PATCH 19/23] add qat test

---
 ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd          | 4 +++-
 ModNEF_Sources/modules/uart/uart_xstep.vhd                | 2 +-
 modneflib/modnef/arch_builder/modules/BLIF/blif.py        | 3 ---
 .../modnef/modnef_torch/modnef_neurons/blif_model/blif.py | 8 ++++++--
 .../modnef_torch/modnef_neurons/modnef_torch_neuron.py    | 1 +
 5 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
index 83722e0..468271e 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
@@ -121,7 +121,9 @@ begin
             case state is 
               when multiplication =>
                 V_mult := std_logic_vector(signed(V) * signed(beta));
-                V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
+                V_mult := std_logic_vector(shift_right(signed(V_mult), fixed_point));
+                V_buff := V_mult(variable_size-1 downto 0);
+                --V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
 
                 if signed(V_buff) >= signed(v_threshold) then
                   spike <= '1';
diff --git a/ModNEF_Sources/modules/uart/uart_xstep.vhd b/ModNEF_Sources/modules/uart/uart_xstep.vhd
index efa6275..43ade23 100644
--- a/ModNEF_Sources/modules/uart/uart_xstep.vhd
+++ b/ModNEF_Sources/modules/uart/uart_xstep.vhd
@@ -196,7 +196,7 @@ begin
             end if;
 
           when wait_out_aer =>
-            if i_emu_ready = '1' then
+            if i_emu_ready = '1' and network_to_uart_busy='0' then
               if read_empty = '1' then -- no more data to process
                 start_uart_transmission <= '1';
                 emu_state <= send_aer;
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
index b6a0663..9dc9af6 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
@@ -193,8 +193,6 @@ class BLif(ModNEFArchMod):
     bw = self.quantizer.bitwidth
 
     mem_file = open(f"{output_path}/{self.mem_init_file}", 'w')
-
-    truc = open(f"temp_{self.mem_init_file}", 'w')
     
     if self.quantizer.signed:
       for i in range(self.input_neuron):
@@ -202,7 +200,6 @@ class BLif(ModNEFArchMod):
         for j in range(self.output_neuron-1, -1, -1):
 
           w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw)
-          truc.write(f"{i} {j} {two_comp(self.quantizer(weights[i][j]), bw)}\n")
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index c435ed1..08f17e7 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -228,14 +228,17 @@ class BLIF(ModNEFNeuron):
     if not self.mem.shape == forward_current.shape:
       self.mem = torch.zeros_like(forward_current, device=self.mem.device)
     
-    self.mem = self.mem + forward_current
+    self.reset = self.mem_reset(self.mem)
+
+    self.mem = self.mem + forward_current - self.reset*self.threshold
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
+
     if self.reset_mechanism == "subtract":
-      self.mem = self.mem*self.beta#-self.reset*self.threshold
+      self.mem = self.mem*self.beta
     elif self.reset_mechanism == "zero":
       self.mem = self.mem*self.beta-self.reset*self.mem
     else:
@@ -268,6 +271,7 @@ class BLIF(ModNEFNeuron):
       if self.hardware_estimation_flag:
         val_max = max(abs(self.val_max), abs(self.val_min))
         val_max = self.quantizer(val_max)
+        print(val_max)
         self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8
       else:
         self.hardware_description["variable_size"]=16
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 49b31e5..9a48b0c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -60,6 +60,7 @@ class ModNEFNeuron(SpikingNeuron):
     )
 
     self.fc = QuantLinear(in_features=in_features, out_features=out_features)
+    #self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
     self.hardware_estimation_flag = False
-- 
GitLab


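A note on the hardware-estimation path touched above: when hardware_estimation_flag is set, BLIF records the extreme membrane values, quantizes the largest magnitude, and derives the on-chip variable width as variable_size = ceil(log(val_max)/log(256))*8, i.e. the smallest whole number of bytes able to represent val_max. A quick standalone check of that arithmetic (the sample values are illustrative):

from math import ceil, log

def variable_size(val_max):
  """Width in bits, rounded up to whole bytes, mirroring the rule in blif.py."""
  return ceil(log(val_max)/log(256))*8

print(variable_size(200))    # 8  -> one byte is enough
print(variable_size(3000))   # 16 -> two bytes
print(variable_size(70000))  # 24 -> three bytes
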
From d08d6933b098c70afc903223f4f36e79c682073b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 4 Apr 2025 12:42:29 +0200
Subject: [PATCH 20/23] add clamp to arch builder

---
 ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd  |  4 +---
 .../modnef/arch_builder/modules/BLIF/blif.py      |  4 ++--
 .../arch_builder/modules/BLIF/blif_debugger.py    |  4 ++--
 .../modnef/arch_builder/modules/BLIF/rblif.py     |  8 ++++----
 .../modnef/arch_builder/modules/SLIF/rslif.py     |  8 ++++----
 .../modnef/arch_builder/modules/SLIF/slif.py      |  4 ++--
 .../arch_builder/modules/SLIF/slif_debugger.py    |  4 ++--
 .../arch_builder/modules/ShiftLIF/rshiftlif.py    |  8 ++++----
 .../arch_builder/modules/ShiftLIF/shiftlif.py     |  4 ++--
 modneflib/modnef/modnef_torch/model.py            |  3 ++-
 .../modnef_neurons/blif_model/blif.py             | 15 +++++++--------
 .../modnef_neurons/blif_model/rblif.py            | 12 ++++++------
 .../modnef_neurons/modnef_torch_neuron.py         |  2 +-
 .../modnef_neurons/slif_model/rslif.py            |  4 +++-
 .../modnef_neurons/slif_model/slif.py             |  3 ++-
 .../modnef_neurons/srlif_model/rshiftlif.py       | 12 ++++++------
 .../modnef_neurons/srlif_model/shiftlif.py        | 12 ++++++------
 modneflib/modnef/quantizer/quantizer.py           |  7 ++++++-
 modneflib/modnef/quantizer/ste_quantizer.py       |  4 ++--
 19 files changed, 64 insertions(+), 58 deletions(-)

diff --git a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
index 468271e..83722e0 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd
@@ -121,9 +121,7 @@ begin
             case state is 
               when multiplication =>
                 V_mult := std_logic_vector(signed(V) * signed(beta));
-                V_mult := std_logic_vector(shift_right(signed(V_mult), fixed_point));
-                V_buff := V_mult(variable_size-1 downto 0);
-                --V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
+                V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point);
 
                 if signed(V_buff) >= signed(v_threshold) then
                   spike <= '1';
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
index 9dc9af6..0caecf1 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py
@@ -199,7 +199,7 @@ class BLif(ModNEFArchMod):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
 
-          w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw)
+          w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), bw)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -210,7 +210,7 @@ class BLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
index f6f7ca9..2940019 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py
@@ -213,7 +213,7 @@ class BLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -224,7 +224,7 @@ class BLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
index 81f87c6..d427c78 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
@@ -205,7 +205,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -213,7 +213,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
index 1c15613..8ebfced 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
@@ -214,14 +214,14 @@ class RSLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     else:
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     mem_file.close()
@@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     else:
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     mem_file.close()
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif.py b/modneflib/modnef/arch_builder/modules/SLIF/slif.py
index ddf5c96..991c16e 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/slif.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/slif.py
@@ -201,7 +201,7 @@ class SLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
       self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size)
@@ -213,7 +213,7 @@ class SLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
 
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
index 7fb587c..dfca469 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py
@@ -215,7 +215,7 @@ class SLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
       self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size)
@@ -227,7 +227,7 @@ class SLif_Debugger(ModNEFDebuggerMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
 
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
index f3077ae..bc4cb58 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
@@ -206,7 +206,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -214,7 +214,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
index 35bb734..698307e 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
@@ -196,7 +196,7 @@ class ShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -206,7 +206,7 @@ class ShiftLif(ModNEFArchMod):
       for i in range(self.input_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j])
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
         
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index f81295d..5e26bc2 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -131,7 +131,7 @@ class ModNEFModel(nn.Module):
           m.init_quantizer()
         m.quantize_weight()
   
-  def quantize(self, force_init=False):
+  def quantize(self, force_init=False, clamp=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
@@ -144,6 +144,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.quantize(force_init=force_init)
+        m.clamp(False)
 
   def clamp(self, force_init=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 08f17e7..0d4ae6e 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -230,19 +230,18 @@ class BLIF(ModNEFNeuron):
     
     self.reset = self.mem_reset(self.mem)
 
-    self.mem = self.mem + forward_current - self.reset*self.threshold
+    self.mem = self.mem + forward_current
+    
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem*self.beta
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem*self.beta-self.reset*self.mem
-    else:
-      self.mem = self.mem*self.beta
+    self.mem = self.mem*self.beta
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 8b49564..f52eed0 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -247,16 +247,16 @@ class RBLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current + rec_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem*self.beta-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem*self.beta-self.reset*self.mem
-    else:
-      self.mem = self.mem*self.beta
+    self.mem = self.mem*self.beta
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 9a48b0c..208da8d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer, True)
      
   
   def quantize_hp(self):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index c1915a7..38364e2 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -255,6 +255,8 @@ class RSLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current + rec_current 
 
+    self.mem = self.mem-self.reset*(self.mem-self.v_rest)
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
@@ -262,7 +264,7 @@ class RSLIF(ModNEFNeuron):
     # update neuron
     self.mem = self.mem - self.v_leak
     min_reset = (self.mem<self.v_min).to(torch.float32)
-    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
+    self.mem = self.mem-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 43e2e1c..6a6c7d3 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -245,6 +245,7 @@ class SLIF(ModNEFNeuron):
 
     self.mem = self.mem + forward_current
 
+    self.mem = self.mem - self.reset*(self.mem-self.v_rest)
 
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
@@ -253,7 +254,7 @@ class SLIF(ModNEFNeuron):
     # update neuron
     self.mem = self.mem - self.v_leak
     min_reset = (self.mem<self.v_min).to(torch.float32)
-    self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest)
+    self.mem = self.mem-min_reset*(self.mem-self.v_rest)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 0ea9983..65d8ba5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -261,16 +261,16 @@ class RShiftLIF(ModNEFNeuron):
 
     self.mem = self.mem+forward_current+rec_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem
-    else:
-      self.mem = self.mem-self.__shift(self.mem)
+    self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 5291733..70f9d96 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -252,16 +252,16 @@ class ShiftLIF(ModNEFNeuron):
 
     self.mem = self.mem+forward_current
 
+    if self.reset_mechanism == "subtract":
+      self.mem = self.mem-self.reset*self.threshold
+    elif self.reset_mechanism == "zero":
+      self.mem = self.mem-self.reset*self.mem
+
     if self.hardware_estimation_flag:
       self.val_min = torch.min(self.mem.min(), self.val_min)
       self.val_max = torch.max(self.mem.max(), self.val_max)
 
-    if self.reset_mechanism == "subtract":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold
-    elif self.reset_mechanism == "zero":
-      self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem
-    else:
-      self.mem = self.mem-self.__shift(self.mem)
+    self.mem = self.mem-self.__shift(self.mem)
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer)
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index c56f506..e264950 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -91,7 +91,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def __call__(self, data, unscale=False):
+  def __call__(self, data, unscale=False, clamp=False):
     """
     Call quantization function
 
@@ -114,6 +114,11 @@ class Quantizer():
 
     qdata = self._quant(tdata)
 
+    if clamp:
+      born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1))
+      born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1)
+      qdata = torch.clamp(qdata, min=born_min, max=born_max)
+
     if unscale:
       qdata = self._unquant(qdata).to(torch.float32)
     
diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py
index 565cbd8..c5ffd32 100644
--- a/modneflib/modnef/quantizer/ste_quantizer.py
+++ b/modneflib/modnef/quantizer/ste_quantizer.py
@@ -25,7 +25,7 @@ class QuantizeSTE(torch.autograd.Function):
   """
   
   @staticmethod
-  def forward(ctx, data, quantizer):
+  def forward(ctx, data, quantizer, clamp=False):
     """
     Apply quantization method to data
 
@@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function):
       quantization method applied to data
     """
     
-    q_data = quantizer(data, True)
+    q_data = quantizer(data, True, clamp)
 
 
     ctx.scale = quantizer.scale_factor
-- 
GitLab


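The clamp option added above clips quantized integers to the representable range before any unscaling: [-2**(bitwidth-1), 2**(bitwidth-1)-1] when the quantizer is signed, and [0, 2**bitwidth - 1] otherwise. A small standalone sketch of those bounds, reproducing the expression used in Quantizer.__call__:

import torch

def clamp_bounds(bitwidth, signed):
  """Integer range used by Quantizer.__call__ when clamp=True."""
  born_min = -int(signed)*2**(bitwidth-1)
  born_max = 2**(bitwidth-int(signed))-1
  return born_min, born_max

print(clamp_bounds(8, True))    # (-128, 127)
print(clamp_bounds(8, False))   # (0, 255)

q = torch.tensor([-300., -12., 140.])
lo, hi = clamp_bounds(8, True)
print(torch.clamp(q, min=lo, max=hi))  # tensor([-128., -12., 127.])
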
From 906a34d86f852f4850917f230bb10a6ade329dfd Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Fri, 4 Apr 2025 21:59:34 +0200
Subject: [PATCH 21/23] add quantization aware training algorithm

---
 modneflib/modnef/quantizer/__init__.py        |   3 +-
 .../modnef/quantizer/fixed_point_quantizer.py |   5 +-
 .../modnef/quantizer/quantizer_scheduler.py   |  56 +++++++
 modneflib/modnef/templates/evaluation.py      |  20 ++-
 modneflib/modnef/templates/model.py           |   2 +-
 modneflib/modnef/templates/run_lib.py         | 154 +++++++++++++++---
 modneflib/modnef/templates/train.py           |  14 +-
 modneflib/modnef/templates/vhdl_generation.py |   3 +
 8 files changed, 218 insertions(+), 39 deletions(-)
 create mode 100644 modneflib/modnef/quantizer/quantizer_scheduler.py

diff --git a/modneflib/modnef/quantizer/__init__.py b/modneflib/modnef/quantizer/__init__.py
index 2ef9cb9..7ca95c0 100644
--- a/modneflib/modnef/quantizer/__init__.py
+++ b/modneflib/modnef/quantizer/__init__.py
@@ -11,4 +11,5 @@ from .quantizer import Quantizer
 from .fixed_point_quantizer import FixedPointQuantizer
 from .min_max_quantizer import MinMaxQuantizer
 from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
-from .ste_quantizer import QuantizeSTE
\ No newline at end of file
+from .ste_quantizer import QuantizeSTE
+from .quantizer_scheduler import QuantizerScheduler
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index 106a242..16b08e5 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -67,9 +67,6 @@ class FixedPointQuantizer(Quantizer):
     dtype = torch.int32 : torch.dtype
       type use during conversion
     """
-
-    if bitwidth==-1 and fixed_point==-1:
-      raise Exception("You must fix at least one value to compute the other one")
     
     super().__init__(
       bitwidth=bitwidth,
@@ -145,7 +142,7 @@ class FixedPointQuantizer(Quantizer):
       elif self.fixed_point==-1:
         self.fixed_point = self.bitwidth-int_part_size
         self.scale_factor = 2**self.fixed_point
-
+  
 
   def _quant(self, data) -> torch.Tensor:
     """
diff --git a/modneflib/modnef/quantizer/quantizer_scheduler.py b/modneflib/modnef/quantizer/quantizer_scheduler.py
new file mode 100644
index 0000000..850dd9b
--- /dev/null
+++ b/modneflib/modnef/quantizer/quantizer_scheduler.py
@@ -0,0 +1,56 @@
+"""
+File name: quantizer_scheduler
+Author: Aurélie Saulquin  
+Version: 0.1.0
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: modnef.modnef_torch
+Descriptions: ModNEF quantizer scheduler
+"""
+
+from modnef.modnef_torch import ModNEFNeuron
+
+class QuantizerScheduler():
+
+  def __init__(self, model, bit_range, T, quantizationMethod):
+
+    self.num_bits = [i for i in range(bit_range[0], bit_range[1]-1, -1)]
+
+    self.model = model
+
+    self.period = T
+
+    self.epoch_max = self.period*(len(self.num_bits)-1)
+
+    print(self.num_bits)
+    print(self.epoch_max)
+
+    self.bit_counter = 0
+
+    self.quantizationMethod = quantizationMethod
+
+    self.__update()
+
+    self.epoch_counter = 0
+
+  def __update(self):
+    self.bitwidth = self.num_bits[self.bit_counter]
+    for m in self.model.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.quantizer = self.quantizationMethod(self.bitwidth)
+        m.init_quantizer()
+        m.quantize_hp()
+
+  def step(self):
+
+    self.epoch_counter += 1
+
+    if self.epoch_counter > self.epoch_max:
+      return
+    else:
+      if self.epoch_counter%self.period==0:
+        self.bit_counter += 1
+        self.__update()
+
+  def save_model(self):
+    return self.epoch_counter >= self.epoch_max
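Usage sketch for the scheduler above: bit_range=(high, low) builds the descending bitwidth list, the bitwidth drops every T epochs, and save_model() only returns True once the final bitwidth is active, so early high-precision checkpoints are not kept. The snippet assumes quantizationMethod is any callable taking a bitwidth (here wrapped around DynamicScaleFactorQuantizer, whose exact constructor arguments are an assumption); model, n_epoch and train_one_epoch are placeholders for the user's own training setup.

import torch
from modnef.quantizer import QuantizerScheduler, DynamicScaleFactorQuantizer

qat_scheduler = QuantizerScheduler(
  model=model,                                  # an existing ModNEFModel (placeholder)
  bit_range=(8, 4),                             # anneal 8 -> 7 -> 6 -> 5 -> 4 bits
  T=5,                                          # lower the bitwidth every 5 epochs
  quantizationMethod=lambda bw: DynamicScaleFactorQuantizer(bw)  # assumed signature
)

for epoch in range(n_epoch):                    # n_epoch, train_one_epoch: placeholders
  train_one_epoch(model)
  if qat_scheduler.save_model():                # True only once the target bitwidth is active
    torch.save(model.state_dict(), "best_model")
  qat_scheduler.step()                          # advance the annealing schedule
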
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index 91cba66..1c2898b 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -38,7 +38,6 @@ if __name__ == "__main__":
 
   """Evaluation variable definition"""
   verbose = True
-  save_conf_matrix = False
   output_path = "."
 
   """FPGA file definition"""
@@ -50,9 +49,14 @@ if __name__ == "__main__":
   acc = 0.0
   y_true = None
   y_pred = None
+  run_time = None
+  
   save_conf_matrix = False
-  conf_matrix_file = "confusion_matrix.png"
-  conf_matrix_classes = [str(i) for i in range(10)]
+  conf_matrix_file = "confusion_matrix.svg"
+  num_class = 10
+  conf_matrix_classes = [str(i) for i in range(num_class)]
+
+  save_array = False
   
   
   if kind == "eval":
@@ -74,7 +78,7 @@ if __name__ == "__main__":
       quant=True
       )
   elif kind == "feval":
-    acc, y_pred, y_true = fpga_evaluation(
+    acc, y_pred, y_true, run_time = fpga_evaluation(
       model=model, 
       driver_config=driver_config_path,
       board_path=board_path,
@@ -92,4 +96,10 @@ if __name__ == "__main__":
       y_pred=y_pred,
       file_name=conf_matrix_file,
       classes=conf_matrix_classes
-      )
\ No newline at end of file
+      )
+
+  if save_array:
+    np.save(f"{output_path}/y_true.npy", y_true)
+    np.save(f"{output_path}/y_pred.npy", y_pred)
+    if kind=="feval":
+      np.save(f"{output_path}/run_time.npy", run_time)
\ No newline at end of file
diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py
index ac2fb76..b8bbb3c 100644
--- a/modneflib/modnef/templates/model.py
+++ b/modneflib/modnef/templates/model.py
@@ -141,7 +141,7 @@ class MyModel(mt.ModNEFModel):
     builder = ModNEFBuilder(self.name, 2312, 10)
 
 
-    uart = Uart_XStep(
+    uart = Uart_XStep_Timer(
       name="uart",
       input_layer_size=2312,
       output_layer_size=10,
diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py
index 85a1352..73dd352 100644
--- a/modneflib/modnef/templates/run_lib.py
+++ b/modneflib/modnef/templates/run_lib.py
@@ -1,4 +1,3 @@
-import torch.nn as nn
 from snntorch import surrogate
 import torch
 spike_grad = surrogate.fast_sigmoid(slope=25)
@@ -9,8 +8,7 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
 import seaborn as sns
 
-
-def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
+def train_1_epoch(model, trainLoader, optimizer, loss, qat, device, verbose):
   epoch_loss = []
 
   if verbose:
@@ -20,13 +18,17 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
     loader = trainLoader
 
   for _, (data, target) in enumerate(loader):
-    model.train()
+    model.train(quant=qat)
+
+    """Prepare data"""
     data = data.to(device)
     data = data.squeeze(0)
     target = target.to(device)
     
+    """Forward Pass"""
     spk_rec, mem_rec = model(data)
 
+    """Prepare backward"""
     loss_val = torch.zeros((1), dtype=torch.float, device=device)
 
     for step in range(data.shape[1]):
@@ -34,6 +36,7 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
 
     epoch_loss.append(loss_val.item())
 
+    """Backward"""
     model.zero_grad()
     loss_val.backward()
     optimizer.step()
@@ -44,33 +47,76 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
 
   return np.mean(epoch_loss)
 
-def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("cpu"), validationLoader=None, n_epoch=10, best_model_name="best_model", verbose=True, save_plot=False, save_history=False, output_path="."):
+def train(model, 
+          trainLoader, 
+          testLoader, 
+          optimizer, 
+          loss, 
+          lr_scheduler = None,
+          qat = False, 
+          qat_scheduler = None,
+          device=torch.device("cpu"), 
+          validationLoader=None, 
+          n_epoch=10, 
+          best_model_name="best_model", 
+          verbose=True, 
+          save_plot=False, 
+          save_history=False, 
+          output_path="."
+          ):
+  
   avg_loss_history = []
   acc_test_history = []
   acc_val_history = []
-
+  lr_val_history = []
+  bitwidth_val_history = []
 
   best_acc = 0
 
   model = model.to(device)
 
+  if qat: # we prepare model for QAT
+    model.init_quantizer()
+    model.quantize_hp()
+
   for epoch in range(n_epoch):
     if verbose:
       print(f"---------- Epoch : {epoch} ----------")
     
-    epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose)
+    """Model training"""
+    epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose, qat=qat)
     avg_loss_history.append(epoch_loss)
 
+    """Model Validation"""
     if validationLoader!=None:
-      acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device)
+      acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device, quant=qat)
       acc_val_history.append(acc_val)
 
-    acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device)
+    """Model evaluation in test"""
+    acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device, quant=qat)
     acc_test_history.append(acc_test)
 
-    if best_model_name!="" and acc_test>best_acc:
-      torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
-      best_acc = acc_test
+    """Save best model"""
+    if best_model_name!="" and acc_test>best_acc: 
+      if not qat:
+        torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+        best_acc = acc_test
+      else: #if QAT
+        if qat_scheduler==None: # no QAT scheduler so we save our model
+          torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+          best_acc = acc_test
+        elif qat_scheduler.save_model(): # with a QAT scheduler, only save once the target bitwidth is reached
+          torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+          best_acc = acc_test
+
+    """Update schedulers"""
+    if lr_scheduler!=None:
+      lr_val_history.append(lr_scheduler.get_last_lr()[0])
+      lr_scheduler.step()
+
+    if qat_scheduler!=None:
+      bitwidth_val_history.append(qat_scheduler.bitwidth)
+      qat_scheduler.step()
 
   if save_history:
     np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
@@ -79,13 +125,19 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
     if len(acc_val_history)!=0:
       np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history))
 
+    if lr_scheduler!=None:
+      np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler))
+
+    if qat_scheduler!=None:
+      np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history))
+
   if save_plot:
     plt.figure()  # Create a new figure
     plt.plot([i for i in range(n_epoch)], avg_loss_history)
     plt.title('Average Loss')
     plt.xlabel("Epoch")
     plt.ylabel("Loss")
-    plt.savefig(f"{output_path}/loss.png")
+    plt.savefig(f"{output_path}/loss.svg")
     
     plt.figure()
     if len(acc_val_history)!=0:
@@ -96,7 +148,25 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
     plt.title("Accuracy")
     plt.xlabel("Epoch")
     plt.ylabel("Accuracy")
-    plt.savefig(f"{output_path}/accuracy.png")
+    plt.savefig(f"{output_path}/accuracy.svg")
+
+    if lr_scheduler!=None:
+      plt.figure()
+      plt.plot([i for i in range(n_epoch)], lr_val_history, label="LR Value")
+      plt.legend()
+      plt.title("LR Values")
+      plt.xlabel("Epoch")
+      plt.ylabel("learning rate")
+      plt.savefig(f"{output_path}/lr_values.svg")
+
+    if qat_scheduler!=None:
+      plt.figure()
+      plt.plot([i for i in range(n_epoch)], bitwidth_val_history, label="bitwidth")
+      plt.legend()
+      plt.title("Quantizer bitwidth")
+      plt.xlabel("Epoch")
+      plt.ylabel("bitwidth")
+      plt.savefig(f"{output_path}/quant_bitwidth.svg")
 
   return avg_loss_history, acc_val_history, acc_test_history, best_acc
 
@@ -154,8 +224,8 @@ def __run_accuracy(model, testLoader, name, verbose, device):
       del spk_rec
       del mem_rec
 
-    y_true = torch.stack(y_true).reshape(-1)
-    y_pred = torch.stack(y_pred).reshape(-1)
+    y_true = torch.stack(y_true).reshape(-1).cpu().numpy()
+    y_pred = torch.stack(y_pred).reshape(-1).cpu().numpy()
     
       
     return (correct/total), y_pred, y_true
@@ -169,9 +239,6 @@ def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"),
 
   model.eval(quant)
 
-  if quant:
-    model.quantize(force_init=True)
-
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true   
@@ -185,13 +252,10 @@ def hardware_estimation(model, testLoader, name="Hardware Estimation", device=to
 
   model.hardware_estimation()
 
-  model.quantize(force_init=True)
-
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true
 
-
 def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False):
   accuracy = 0
   y_pred = []
@@ -203,9 +267,51 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva
 
   model.fpga_eval(board_path, driver_config)
 
-  accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
+  y_true = []
+  y_pred = []
+  run_time = []
+  correct = 0
+  total = 0
 
-  return accuracy, y_pred, y_true   
+  if verbose:
+    bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
+    loader = tqdm(testLoader, desc=name, bar_format=bar_format)
+  else:
+    loader = testLoader
+
+
+  for _, (data, target) in enumerate(loader):
+    
+    data = data.to(device)
+    target = target.to(device)
+
+    y_true.append(target)
+
+    spk_rec, batch_speed = model(data)
+
+    run_time.extend(batch_speed)
+
+    output = (spk_rec.sum(dim=0))/data.shape[1]
+    predicted = output.argmax(dim=1).to(device)
+    correct += predicted.eq(target.view_as(predicted)).sum().item()
+    y_pred.append(predicted)
+    total += target.size(0)
+
+    if verbose:
+      loader.set_postfix_str(f"Accuracy : {np.mean(correct/total*100):0<3.2f} Run Time : {np.mean(batch_speed)*1e6:.3f} µs")
+
+    del data
+    del target
+    del spk_rec
+
+  y_true = torch.stack(y_true).reshape(-1)
+  y_pred = torch.stack(y_pred).reshape(-1)
+  run_time = np.array(run_time)
+  
+    
+  accuracy = (correct/total)
+
+  return accuracy, y_pred, y_true, run_time
 
 def conf_matrix(y_true, y_pred, file_name, classes):
   cm = confusion_matrix(y_true, y_pred)
diff --git a/modneflib/modnef/templates/train.py b/modneflib/modnef/templates/train.py
index 0f48125..5d4866c 100644
--- a/modneflib/modnef/templates/train.py
+++ b/modneflib/modnef/templates/train.py
@@ -1,12 +1,9 @@
-import tonic
-from torch.utils.data import DataLoader
-from modnef.modnef_torch import ModNEFModelBuilder
-import os
 from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
 from dataset import *
+from modnef.quantizer import QuantizerScheduler
 
 if __name__ == "__main__":
 
@@ -18,6 +15,7 @@ if __name__ == "__main__":
 
   """Optimizer"""
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
+  lr_scheduler = None #torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=5e-4)
 
   """Loss"""
   loss = torch.nn.CrossEntropyLoss()
@@ -27,12 +25,17 @@ if __name__ == "__main__":
 
   """Train variable definition"""
   n_epoch = 2
+  qat = False
   best_model_name = "best_model"
   verbose = True
   save_plot = False
   save_history = False
   output_path = "."
 
+  """Quantization Aware Training"""
+  qat=False
+  qat_scheduler = None #QuantizerScheduler(model, (8,3), 3, lambda x : FixedPointQuantizer(x, x-1, True, True))
+
   train(
     model=model, 
     trainLoader=trainLoader, 
@@ -40,6 +43,9 @@ if __name__ == "__main__":
     validationLoader=validationLoader,
     optimizer=optimizer, 
     loss=loss, 
+    lr_scheduler = lr_scheduler,
+    qat = qat, 
+    qat_scheduler = qat_scheduler,
     device=device,  
     n_epoch=n_epoch, 
     best_model_name=best_model_name, 
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index 563d1a1..7765dae 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -36,6 +36,9 @@ if __name__ == "__main__":
   file_name = "template_vhdl_model.vhd"
   driver_config_path = "driver_config.yml"
 
+  """Prepare model for hardware estimation"""
+  model.quantize(force_init=True, clamp=True)
+
   acc, y_pred, y_true = hardware_estimation(
     model=model, 
     testLoader=testLoader, 
-- 
GitLab
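
Side note on the QAT path added to train() above: per epoch it reduces to the flow
sketched below in Python. The model methods init_quantizer()/quantize_hp() and the
scheduler step() calls follow the hunks above; the forward pass and loss computation
are simplified (the real loop lives in train_1_epoch) and the helper name
qat_train_sketch is only illustrative.

    def qat_train_sketch(model, trainLoader, optimizer, loss, n_epoch,
                         lr_scheduler=None, qat_scheduler=None):
      # one-time preparation, mirroring the `if qat:` branch of train()
      model.init_quantizer()
      model.quantize_hp()
      for epoch in range(n_epoch):
        for data, target in trainLoader:
          spk_rec, mem_rec = model(data)              # forward pass (simplified)
          loss_val = loss(spk_rec.sum(dim=0), target) # loss on summed spikes (simplified)
          model.zero_grad()
          loss_val.backward()
          optimizer.step()
        if lr_scheduler is not None:
          lr_scheduler.step()
        if qat_scheduler is not None:
          qat_scheduler.step()                        # lowers the bitwidth every T epochs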


From 94a2736415a26c5f18ff0a46a10ea13c13c5cd4a Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 7 Apr 2025 10:19:11 +0200
Subject: [PATCH 22/23] fix bug on arch builder

---
 .../modnef/arch_builder/modules/BLIF/rblif.py |  4 ++--
 .../modnef/arch_builder/modules/SLIF/rslif.py |  4 ++--
 .../modules/ShiftLIF/rshiftlif.py             |  4 ++--
 .../arch_builder/modules/ShiftLIF/shiftlif.py |  4 +++-
 modneflib/modnef/modnef_torch/model.py        |  3 +--
 .../modnef_neurons/blif_model/blif.py         |  6 +++---
 .../modnef_neurons/blif_model/rblif.py        |  4 ++--
 .../modnef_neurons/modnef_torch_neuron.py     |  8 ++++----
 .../modnef_neurons/slif_model/rslif.py        |  4 ++--
 .../modnef_neurons/slif_model/slif.py         |  4 ++--
 .../modnef_neurons/srlif_model/rshiftlif.py   |  4 ++--
 .../modnef_neurons/srlif_model/shiftlif.py    | 10 ++++++++--
 modneflib/modnef/modnef_torch/quantLinear.py  |  2 +-
 modneflib/modnef/quantizer/quantizer.py       |  4 ++--
 .../modnef/quantizer/quantizer_scheduler.py   | 17 ++++++++--------
 modneflib/modnef/quantizer/ste_quantizer.py   |  2 +-
 modneflib/modnef/templates/evaluation.py      | 11 +++++-----
 modneflib/modnef/templates/model.py           | 17 ++++++++++++----
 modneflib/modnef/templates/run_lib.py         | 20 ++++++++++---------
 modneflib/modnef/templates/vhdl_generation.py |  9 ++++-----
 20 files changed, 79 insertions(+), 62 deletions(-)

diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
index d427c78..d402fda 100644
--- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
+++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py
@@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
index 8ebfced..b0a66a0 100644
--- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
+++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py
@@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     else:
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True)
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
     mem_file.close()
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
index bc4cb58..3d3427b 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py
@@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
+          w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth)
 
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
@@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod):
       for i in range(self.output_neuron):
         w_line = 0
         for j in range(self.output_neuron-1, -1, -1):
-          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True)
+          w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True)
         
         rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
     
diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
index 698307e..4c9b3e9 100644
--- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
+++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py
@@ -200,7 +200,9 @@ class ShiftLif(ModNEFArchMod):
 
         mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n")
 
-      self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size)
+      self.v_threshold = two_comp(self.quantizer(self.v_threshold, unscale=False, clamp=False), self.variable_size)
+      
+      
 
     else:
       for i in range(self.input_neuron):
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 5e26bc2..84146c6 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -143,8 +143,7 @@ class ModNEFModel(nn.Module):
     
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.quantize(force_init=force_init)
-        m.clamp(False)
+        m.quantize(force_init=force_init, clamp=clamp)
 
   def clamp(self, force_init=False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 0d4ae6e..99b6154 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -238,8 +238,8 @@ class BLIF(ModNEFNeuron):
       self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     self.mem = self.mem*self.beta
 
@@ -265,12 +265,12 @@ class BLIF(ModNEFNeuron):
     -------
     BLIF
     """
+    
 
     if self.hardware_description["variable_size"]==-1:
       if self.hardware_estimation_flag:
         val_max = max(abs(self.val_max), abs(self.val_min))
         val_max = self.quantizer(val_max)
-        print(val_max)
         self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8
       else:
         self.hardware_description["variable_size"]=16
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index f52eed0..f930deb 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -253,8 +253,8 @@ class RBLIF(ModNEFNeuron):
       self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     self.mem = self.mem*self.beta
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 208da8d..681b88a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -96,7 +96,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])
   
-  def quantize_weight(self):
+  def quantize_weight(self, clamp=False):
     """
     synaptic weight quantization
 
@@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron):
     """
     
     for p in self.parameters():
-      p.data = QuantizeSTE.apply(p.data, self.quantizer, True)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer, clamp)
      
   
   def quantize_hp(self):
@@ -117,7 +117,7 @@ class ModNEFNeuron(SpikingNeuron):
     
     raise NotImplementedError()
   
-  def quantize(self, force_init=False):
+  def quantize(self, force_init=False, clamp=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
@@ -130,7 +130,7 @@ class ModNEFNeuron(SpikingNeuron):
     if force_init:
       self.init_quantizer()
 
-    self.quantize_weight()
+    self.quantize_weight(clamp=clamp)
     self.quantize_hp()
   
   def clamp(self, force_init=False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 38364e2..a2ab9a4 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -258,8 +258,8 @@ class RSLIF(ModNEFNeuron):
     self.mem = self.mem-self.reset*(self.mem-self.v_rest)
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     # update neuron
     self.mem = self.mem - self.v_leak
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 6a6c7d3..2ee8653 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -248,8 +248,8 @@ class SLIF(ModNEFNeuron):
     self.mem = self.mem - self.reset*(self.mem-self.v_rest)
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     # update neuron
     self.mem = self.mem - self.v_leak
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 65d8ba5..07a631b 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -267,8 +267,8 @@ class RShiftLIF(ModNEFNeuron):
       self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     self.mem = self.mem-self.__shift(self.mem)
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 70f9d96..e2cb265 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -120,6 +120,8 @@ class ShiftLIF(ModNEFNeuron):
                      quantizer=quantizer
                      )
 
+
+    
     
     self.shift = int(-log(1-beta)/log(2))
 
@@ -138,6 +140,8 @@ class ShiftLIF(ModNEFNeuron):
       "variable_size" : -1
     }
 
+    print(threshold)
+
   @classmethod
   def from_dict(cls, dict, spike_grad):
     """
@@ -258,8 +262,8 @@ class ShiftLIF(ModNEFNeuron):
       self.mem = self.mem-self.reset*self.mem
 
     if self.hardware_estimation_flag:
-      self.val_min = torch.min(self.mem.min(), self.val_min)
-      self.val_max = torch.max(self.mem.max(), self.val_max)
+      self.val_min = torch.min(self.mem.min(), self.val_min).detach()
+      self.val_max = torch.max(self.mem.max(), self.val_max).detach()
 
     self.mem = self.mem-self.__shift(self.mem)
 
@@ -294,6 +298,8 @@ class ShiftLIF(ModNEFNeuron):
       else:
         self.hardware_description["variable_size"]=16
 
+    #self.clamp(force_init=True)
+
     module = builder.ShiftLif(
         name=module_name,
         input_neuron=self.fc.in_features,
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
index 57f67b4..f22f0a5 100644
--- a/modneflib/modnef/modnef_torch/quantLinear.py
+++ b/modneflib/modnef/modnef_torch/quantLinear.py
@@ -51,7 +51,7 @@ class QuantLinear(nn.Linear):
 
     if quantizer!=None:
       w = QuantizeSTE.apply(self.weight, quantizer)
-      w.data = quantizer.clamp(w)
+      #w.data = quantizer.clamp(w)
     else:
       w = self.weight
 
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index e264950..20914b6 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -115,8 +115,8 @@ class Quantizer():
     qdata = self._quant(tdata)
 
     if clamp:
-      born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1))
-      born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1)
+      born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)).to(qdata.device)
+      born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1).to(qdata.device)
       qdata = torch.clamp(qdata, min=born_min, max=born_max)
 
     if unscale:
diff --git a/modneflib/modnef/quantizer/quantizer_scheduler.py b/modneflib/modnef/quantizer/quantizer_scheduler.py
index 850dd9b..2d51f18 100644
--- a/modneflib/modnef/quantizer/quantizer_scheduler.py
+++ b/modneflib/modnef/quantizer/quantizer_scheduler.py
@@ -16,28 +16,26 @@ class QuantizerScheduler():
 
     self.num_bits = [i for i in range(bit_range[0], bit_range[1]-1, -1)]
 
-    self.model = model
+    self.bit_counter = 0
+    self.epoch_counter = 0
 
-    self.period = T
+    self.bitwidth = self.num_bits[self.bit_counter]
 
+    self.period = T
     self.epoch_max = self.period*(len(self.num_bits)-1)
 
-    print(self.num_bits)
-    print(self.epoch_max)
-
-    self.bit_counter = 0
+    self.model = model
 
     self.quantizationMethod = quantizationMethod
 
     self.__update()
 
-    self.epoch_counter = 0
+    
 
   def __update(self):
-    print(self.num_bits[self.bit_counter])
     for m in self.model.modules():
       if isinstance(m, ModNEFNeuron):
-        m.quantizer = self.quantizationMethod(self.num_bits[self.bit_counter])
+        m.quantizer = self.quantizationMethod(self.bitwidth)
         m.init_quantizer()
         m.quantize_hp()
 
@@ -50,6 +48,7 @@ class QuantizerScheduler():
     else:
       if self.epoch_counter%self.period==0:
         self.bit_counter += 1
+        self.bitwidth = self.num_bits[self.bit_counter]
         self.__update()
 
   def save_model(self):
diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py
index c5ffd32..b66832e 100644
--- a/modneflib/modnef/quantizer/ste_quantizer.py
+++ b/modneflib/modnef/quantizer/ste_quantizer.py
@@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function):
       quantization method applied to data
     """
     
-    q_data = quantizer(data, True, clamp)
+    q_data = quantizer(data, unscale=True, clamp=clamp)
 
 
     ctx.scale = quantizer.scale_factor
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index 1c2898b..8fa94db 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -14,6 +14,9 @@ if __name__ == "__main__":
   """Experience name"""
   exp_name = "Evaluation"
 
+  """Device definition"""
+  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
   """Model definition"""
   best_model_name = "best_model"
 
@@ -22,7 +25,7 @@ if __name__ == "__main__":
 
   model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25))
 
-  model.load_state_dict(torch.load(best_model_name))
+  model.load_state_dict(torch.load(best_model_name, map_location=device))
 
 
   """Kind of run
@@ -33,15 +36,12 @@ if __name__ == "__main__":
   
   kind = sys.argv[1] 
 
-  """Device definition"""
-  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
   """Evaluation variable definition"""
   verbose = True
   output_path = "."
 
   """FPGA file definition"""
-  driver_config_path = "driver_config"
+  driver_config_path = "driver_config.yml"
   board_path = ""
 
 
@@ -69,6 +69,7 @@ if __name__ == "__main__":
       quant=False
       )
   elif kind == "qeval":
+    model.quantize(force_init=True, clamp=True)
     acc, y_pred, y_true = evaluation(
       model=model, 
       testLoader=testLoader, 
diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py
index b8bbb3c..6868daf 100644
--- a/modneflib/modnef/templates/model.py
+++ b/modneflib/modnef/templates/model.py
@@ -114,11 +114,20 @@ class MyModel(mt.ModNEFModel):
     
     batch_result = []
 
+    run_time = []
+
+    n_layer = 0
+
+    for m in self.modules():
+      if isinstance(m, mt.ModNEFNeuron):
+        n_layer += 1
+
     for sample in input_spikes:
-      sample_res = self.driver.run_sample(sample, to_aer, True, len(self.layers))
+      sample_res = self.driver.run_sample(sample, to_aer, True, n_layer)
+      run_time.append(self.driver.sample_time)
       batch_result.append([sample_res])
 
-    return torch.tensor(batch_result).permute(1, 0, 2), None
+    return torch.tensor(batch_result).permute(1, 0, 2), None, run_time
   
   def to_vhdl(self, file_name=None, output_path = ".", driver_config_path = "./driver.yml"):
     """
@@ -149,8 +158,8 @@ class MyModel(mt.ModNEFModel):
       baud_rate=921_600,
       queue_read_depth=10240,
       queue_write_depth=1024,
-      tx_name="uart_rxd",
-      rx_name="uart_txd"
+      tx_name="uart_txd",
+      rx_name="uart_rxd"
     )
 
     builder.add_module(uart)
diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py
index 73dd352..7cee071 100644
--- a/modneflib/modnef/templates/run_lib.py
+++ b/modneflib/modnef/templates/run_lib.py
@@ -105,7 +105,7 @@ def train(model,
         if qat_scheduler==None: # no QAT scheduler so we save our model
           torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
           best_acc = acc_test
-        elif qat_scheduler.save(): # if QAT scheduler, we need to check if we quantize at the target bitwidth
+        elif qat_scheduler.save_model(): # if QAT scheduler, we need to check if we quantize at the target bitwidth
           torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
           best_acc = acc_test
 
@@ -114,9 +114,11 @@ def train(model,
       lr_val_history.append(lr_scheduler.get_last_lr()[0])
       lr_scheduler.step()
 
-    if qat_scheduler!=None:
-      bitwidth_val_history.append(qat_scheduler.bitwidth)
-      qat_scheduler.step()
+    if qat:
+      model.clamp() 
+      if qat_scheduler!=None:
+        bitwidth_val_history.append(qat_scheduler.bitwidth)
+        qat_scheduler.step()
 
   if save_history:
     np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
@@ -128,7 +130,7 @@ def train(model,
     if lr_scheduler!=None:
       np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler))
 
-    if qat_scheduler!=None:
+    if qat and qat_scheduler!=None:
       np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history))
 
   if save_plot:
@@ -159,7 +161,7 @@ def train(model,
       plt.ylabel("learning rate")
       plt.savefig(f"{output_path}/lr_values.svg")
 
-    if qat_scheduler!=None:
+    if qat and qat_scheduler!=None:
       plt.figure()
       plt.plot([i for i in range(n_epoch)], bitwidth_val_history, label="bitwidth")
       plt.legend()
@@ -287,7 +289,7 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva
 
     y_true.append(target)
 
-    spk_rec, batch_speed = model(data)
+    spk_rec, mem_rec, batch_speed = model(data)
 
     run_time.extend(batch_speed)
 
@@ -304,8 +306,8 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva
     del target
     del spk_rec
 
-  y_true = torch.stack(y_true).reshape(-1)
-  y_pred = torch.stack(y_pred).reshape(-1)
+  y_true = torch.stack(y_true).cpu().reshape(-1).numpy()
+  y_pred = torch.stack(y_pred).cpu().reshape(-1).numpy()
   run_time = np.array(run_time)
   
     
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index 7765dae..3ba0a9b 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -13,6 +13,9 @@ if __name__ == "__main__":
   """Experience name"""
   exp_name = "Evaluation"
 
+  """Device definition"""
+  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
   """Model definition"""
   best_model_name = "best_model"
 
@@ -21,11 +24,7 @@ if __name__ == "__main__":
 
   model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25))
 
-  model.load_state_dict(torch.load(best_model_name))
-
-
-  """Device definition"""
-  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+  model.load_state_dict(torch.load(best_model_name, map_location=device))
 
   """Hardware Estimation variable definition"""
   verbose = True
-- 
GitLab
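
For reference, the bitwidth bookkeeping of the restructured QuantizerScheduler above
can be reduced to the stand-alone Python sketch below. The names bit_range, T,
bitwidth, epoch_max and save_model follow the hunks; the exact bounds handling in
step() and the re-quantization that __update() performs on every ModNEFNeuron are
not fully visible in the diff, so those parts are an assumption.

    class BitwidthScheduleSketch:
      def __init__(self, bit_range=(8, 3), T=3):
        # bit_range=(8, 3) -> bitwidths 8, 7, 6, 5, 4, 3
        self.num_bits = list(range(bit_range[0], bit_range[1] - 1, -1))
        self.period = T
        self.bit_counter = 0
        self.epoch_counter = 0
        self.bitwidth = self.num_bits[self.bit_counter]
        self.epoch_max = self.period * (len(self.num_bits) - 1)

      def step(self):
        # called once per epoch; move to the next bitwidth every `period` epochs
        self.epoch_counter += 1
        if self.epoch_counter <= self.epoch_max and self.epoch_counter % self.period == 0:
          self.bit_counter += 1
          self.bitwidth = self.num_bits[self.bit_counter]

      def save_model(self):
        # checkpoints are only kept once the target (smallest) bitwidth is reached
        return self.bitwidth == self.num_bits[-1]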


From 292106dd6a22282e8df26d02adf0778386ce05c2 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Mon, 14 Apr 2025 11:23:26 +0200
Subject: [PATCH 23/23] remove print

---
 .../modules/neurons/BLIF/rblif_parallel.vhd   | 11 ++--
 .../modules/neurons/SLIF/rslif_parallel.vhd   | 12 ++--
 .../modules/neurons/SLIF/simplified_lif.vhd   | 60 ++++++++++++-------
 .../neurons/ShiftLif/rshiftlif_parallel.vhd   | 12 ++--
 .../modules/neurons/ShiftLif/shift_lif.vhd    | 57 ++++++++++++------
 modneflib/modnef/modnef_torch/model.py        |  1 +
 .../modnef_neurons/slif_model/rslif.py        |  8 +--
 .../modnef_neurons/slif_model/slif.py         | 10 ++--
 .../modnef_neurons/srlif_model/shiftlif.py    |  2 -
 9 files changed, 106 insertions(+), 67 deletions(-)

diff --git a/ModNEF_Sources/modules/neurons/BLIF/rblif_parallel.vhd b/ModNEF_Sources/modules/neurons/BLIF/rblif_parallel.vhd
index 16f9528..4fdb9c1 100644
--- a/ModNEF_Sources/modules/neurons/BLIF/rblif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/BLIF/rblif_parallel.vhd
@@ -7,7 +7,7 @@
 -- Authors : Aurelie Saulquin
 -- Email : aurelie.saulquin@univ-lille.fr
 --
--- Version : 1.1.0
+-- Version : 1.1.1
 -- Version comment : stable version 
 --
 -- Licenses : cern-ohl-s-2.0
@@ -251,7 +251,6 @@ begin
 
     if i_start_emu = '1' then
       tr_fsm_en := '1';
-      transmission_neuron_en <= '1';
     end if;
     
     if rising_edge(i_clk) then
@@ -260,8 +259,8 @@ begin
         start_calc <= '0';
         o_emu_busy <= '0';
         o_req <= '0';
-        rec_ram_en <= '1';
-        rec_neuron_en <= '1';
+        rec_ram_en <= '0';
+        rec_neuron_en <= '0';
         rec_spike_flag <= '0';
       else
         case transmission_state is
@@ -283,7 +282,7 @@ begin
             end if;
 
           when voltage_update =>
-            transmission_neuron_en <= '0';
+            transmission_neuron_en <= '1';
             start_calc <= '0';
             transmission_state <= check_arbitration;
 
@@ -311,6 +310,7 @@ begin
               transmission_state <= wait_arbitration;
               start_arb <= '1';
               rec_ram_en <= '1';
+              rec_neuron_en <= '1';
               
               rec_spike_flag <= arb_spike_flag;
             else
@@ -331,6 +331,7 @@ begin
             transmission_state <= idle;
             o_emu_busy <= '0';
             rec_neuron_en <= '0';
+            rec_ram_en <= '0';
             tr_fsm_en := '0';
         end case;
       end if;
diff --git a/ModNEF_Sources/modules/neurons/SLIF/rslif_parallel.vhd b/ModNEF_Sources/modules/neurons/SLIF/rslif_parallel.vhd
index 4e18183..894bed2 100644
--- a/ModNEF_Sources/modules/neurons/SLIF/rslif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/SLIF/rslif_parallel.vhd
@@ -7,7 +7,7 @@
 -- Authors : Aurelie Saulquin
 -- Email : aurelie.saulquin@univ-lille.fr
 --
--- Version : 1.1.0
+-- Version : 1.1.1
 -- Version comment : stable version 
 --
 -- Licenses : cern-ohl-s-2.0
@@ -240,8 +240,8 @@ begin
         start_calc <= '0';
         o_emu_busy <= '0';
         o_req <= '0';
-        rec_ram_en <= '1';
-        rec_neuron_en <= '1';
+        rec_ram_en <= '0';
+        rec_neuron_en <= '0';
         rec_spike_flag <= '0';
       else
         case transmission_state is
@@ -263,7 +263,7 @@ begin
             end if;
 
           when voltage_update =>
-            transmission_neuron_en <= '0';
+            transmission_neuron_en <= '1';
             start_calc <= '0';
             transmission_state <= check_arbitration;
 
@@ -291,6 +291,7 @@ begin
               transmission_state <= wait_arbitration;
               start_arb <= '1';
               rec_ram_en <= '1';
+              rec_neuron_en <= '1';
               
               rec_spike_flag <= arb_spike_flag;
             else
@@ -311,6 +312,7 @@ begin
             transmission_state <= idle;
             o_emu_busy <= '0';
             rec_neuron_en <= '0';
+            rec_ram_en <= '0';
             tr_fsm_en := '0';
         end case;
       end if;
@@ -338,7 +340,7 @@ begin
     mem_init_file => mem_init_file_rec
   ) port map (
     i_clk => i_clk,
-    i_en => '1',
+    i_en => rec_ram_en,
     i_addr => output_aer,
     o_data => rec_data_read
   );
diff --git a/ModNEF_Sources/modules/neurons/SLIF/simplified_lif.vhd b/ModNEF_Sources/modules/neurons/SLIF/simplified_lif.vhd
index 892bcdb..793a13b 100644
--- a/ModNEF_Sources/modules/neurons/SLIF/simplified_lif.vhd
+++ b/ModNEF_Sources/modules/neurons/SLIF/simplified_lif.vhd
@@ -7,7 +7,7 @@
 -- Authors : Aurelie Saulquin
 -- Email : aurelie.saulquin@univ-lille.fr
 --
--- Version : 1.2.0
+-- Version : 1.3.0
 -- Version comment : stable version 
 --
 -- Licenses : cern-ohl-s-2.0
@@ -70,6 +70,8 @@ begin
   o_spike <= spike;
 
   process(i_clk, i_inc_I, i_calc, i_en)
+    variable I : std_logic_vector(weight_size-1 downto 0);
+    variable I_rec : std_logic_vector(weight_size-1 downto 0);
   begin
     if rising_edge(i_clk) then
       if i_reset = '1' then
@@ -78,14 +80,21 @@ begin
       
       if i_en = '1' then
         if weight_signed then
-          if spike_flag = '1' or spike_flag_rec = '1' then
-            if spike_flag = '1' and spike_flag_rec = '0' then
-              V <= std_logic_vector(signed(V)+signed(weight));
-            elsif spike_flag = '0' and spike_flag_rec = '1' then
-              V <= std_logic_vector(signed(V)+signed(weight_rec));
+          if i_inc_I = '1' or i_inc_I_rec = '1' then
+            
+            if i_inc_I = '1' then
+              I := std_logic_vector(signed(i_w));
             else
-              V <= std_logic_vector(signed(V)+signed(weight)+signed(weight_rec));
-            end if; 
+              I := (others=>'0');
+            end if;
+
+            if i_inc_I_rec = '1' then
+              I_rec := std_logic_vector(signed(i_w_rec));
+            else
+              I_rec := (others=>'0');
+            end if;
+
+            V <= std_logic_vector(signed(V) + signed(I) + signed(I_rec));
           elsif i_calc = '1' then
             if signed(V) >= signed(v_threshold+v_leak) then
               spike <= '1';
@@ -99,15 +108,24 @@ begin
             end if;
           end if;
         else
-          if spike_flag = '1' or spike_flag_rec = '1' then
-            if spike_flag = '1' and spike_flag_rec = '0' then
-              V <= std_logic_vector(unsigned(V)+unsigned(weight));
-            elsif spike_flag = '0' and spike_flag_rec = '1' then
-              V <= std_logic_vector(unsigned(V)+unsigned(weight_rec));
-            else
-              V <= std_logic_vector(unsigned(V)+unsigned(weight)+unsigned(weight_rec));
-            end if; 
-          elsif i_calc = '1' then
+          if i_inc_I = '1' or i_inc_I_rec = '1' then
+
+            if i_inc_I = '1' then
+              I := std_logic_vector(unsigned(i_w));
+            else
+              I := (others=>'0');
+            end if;
+
+            if i_inc_I_rec = '1' then
+              I_rec := std_logic_vector(unsigned(i_w_rec));
+            else
+              I_rec := (others=>'0');
+            end if;
+
+            V <= std_logic_vector(unsigned(V) + unsigned(I) + unsigned(I_rec));
+
+          elsif i_calc = '1' then
+
             if unsigned(V) >= unsigned(v_threshold+v_leak) then
               spike <= '1';
               V <= V_rest;
@@ -121,10 +139,10 @@ begin
           end if;
         end if;
 
-        spike_flag <= i_inc_I;
-        weight <= i_w;
-        spike_flag_rec <= i_inc_I_rec;
-        weight_rec <= i_w_rec;
+        -- spike_flag <= i_inc_I;
+        -- weight <= i_w;
+        -- spike_flag_rec <= i_inc_I_rec;
+        -- weight_rec <= i_w_rec;
 
       end if;
     end if;
diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_parallel.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_parallel.vhd
index 1af3b8a..8d3d6ad 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_parallel.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/rshiftlif_parallel.vhd
@@ -7,7 +7,7 @@
 -- Authors : Aurelie Saulquin
 -- Email : aurelie.saulquin@univ-lille.fr
 --
--- Version : 1.1.0
+-- Version : 1.1.1
 -- Version comment : stable version 
 --
 -- Licenses : cern-ohl-s-2.0
@@ -237,8 +237,8 @@ begin
         start_calc <= '0';
         o_emu_busy <= '0';
         o_req <= '0';
-        rec_ram_en <= '1';
-        rec_neuron_en <= '1';
+        rec_ram_en <= '0';
+        rec_neuron_en <= '0';
         rec_spike_flag <= '0';
       else
         case transmission_state is
@@ -260,7 +260,7 @@ begin
             end if;
 
           when voltage_update =>
-            transmission_neuron_en <= '0';
+            transmission_neuron_en <= '1';
             start_calc <= '0';
             transmission_state <= check_arbitration;
 
@@ -288,6 +288,7 @@ begin
               transmission_state <= wait_arbitration;
               start_arb <= '1';
               rec_ram_en <= '1';
+              rec_neuron_en <= '1';
               
               rec_spike_flag <= arb_spike_flag;
             else
@@ -308,6 +309,7 @@ begin
             transmission_state <= idle;
             o_emu_busy <= '0';
             rec_neuron_en <= '0';
+            rec_ram_en <= '0';
             tr_fsm_en := '0';
         end case;
       end if;
@@ -335,7 +337,7 @@ begin
     mem_init_file => mem_init_file_rec
   ) port map (
     i_clk => i_clk,
-    i_en => '1',
+    i_en => rec_ram_en,
     i_addr => output_aer,
     o_data => rec_data_read
   );
diff --git a/ModNEF_Sources/modules/neurons/ShiftLif/shift_lif.vhd b/ModNEF_Sources/modules/neurons/ShiftLif/shift_lif.vhd
index 84e97fa..9f6b460 100644
--- a/ModNEF_Sources/modules/neurons/ShiftLif/shift_lif.vhd
+++ b/ModNEF_Sources/modules/neurons/ShiftLif/shift_lif.vhd
@@ -7,7 +7,7 @@
 -- Authors : Aurelie Saulquin
 -- Email : aurelie.saulquin@univ-lille.fr
 --
--- Version : 1.1.0
+-- Version : 1.2.0
 -- Version comment : stable version 
 --
 -- Licenses : cern-ohl-s-2.0
@@ -69,6 +69,9 @@ begin
   o_spike <= spike;
 
   process(i_clk, i_inc_I, i_calc, i_en)
+    variable I : std_logic_vector(weight_size-1 downto 0);
+    variable I_rec : std_logic_vector(weight_size-1 downto 0);
+    variable V_buff : std_logic_vector(variable_size-1 downto 0);
   begin
     if rising_edge(i_clk) then
       if i_reset = '1' then
@@ -77,46 +80,62 @@ begin
 
       if i_en = '1' then
         if weight_signed then
-          if spike_flag = '1' or spike_flag_rec = '1' then
-            if spike_flag = '1' and spike_flag_rec = '0' then
-              V <= std_logic_vector(signed(V) + signed(weight));
-            elsif spike_flag = '0' and spike_flag_rec = '1' then
-              V <= std_logic_vector(signed(V) + signed(weight_rec));
+          if i_inc_I = '1' or i_inc_I_rec = '1' then
+            
+            if i_inc_I = '1' then
+              I := std_logic_vector(signed(i_w));
             else
-              V <= std_logic_vector(signed(V) + signed(weight) + signed(weight_rec));
+              I := (others=>'0');
             end if;
+
+            if i_inc_I_rec = '1' then
+              I_rec := std_logic_vector(signed(i_w_rec));
+            else
+              I_rec := (others=>'0');
+            end if;
+
+            V <= std_logic_vector(signed(V) + signed(I) + signed(I_rec));
           elsif i_calc='1' then
-            if signed(V) >= signed(v_threshold) then
+            V_buff := std_logic_vector(signed(V)-signed(shift_right(signed(V), shift)));
+            if signed(V_buff) >= signed(v_threshold) then
               spike <= '1';
               if reset = "zero" then
                 V <= (others=>'0');
               else 
-                V <= std_logic_vector(signed(V) - signed(v_threshold));
+                V <= std_logic_vector(signed(V_buff) - signed(v_threshold));
               end if;
             else
-              V <= std_logic_vector(signed(V)-signed(shift_right(signed(V), shift)));
+              V <= V_buff;
               spike <= '0';
             end if;
           end if;
         else
-          if spike_flag = '1' or spike_flag_rec = '1' then
-            if spike_flag = '1' and spike_flag_rec = '0' then
-              V <= std_logic_vector(unsigned(V) + unsigned(weight));
-            elsif spike_flag = '0' and spike_flag_rec = '1' then
-              V <= std_logic_vector(unsigned(V) + unsigned(weight_rec));
+          if i_inc_I = '1' or i_inc_I_rec = '1' then
+              
+            if i_inc_I = '1' then
+              I := std_logic_vector(unsigned(i_w));
             else
-              V <= std_logic_vector(unsigned(V) + unsigned(weight) + unsigned(weight_rec));
+              I := (others=>'0');
             end if;
+
+            if i_inc_I_rec = '1' then
+              I_rec := std_logic_vector(unsigned(i_w_rec));
+            else
+              I_rec := (others=>'0');
+            end if;
+
+            V <= std_logic_vector(unsigned(V) + unsigned(I) + unsigned(I_rec));
           elsif i_calc='1' then
-            if unsigned(V) >= unsigned(v_threshold) then
+            V_buff := std_logic_vector(unsigned(V)-unsigned(shift_right(unsigned(V), shift)));
+            if unsigned(V_buff) >= unsigned(v_threshold) then
               spike <= '1';
               if reset = "zero" then
                 V <= (others=>'0');
               else 
-                V <= std_logic_vector(unsigned(V) - unsigned(v_threshold));
+                V <= std_logic_vector(unsigned(V_buff) - unsigned(v_threshold));
               end if;
             else
-              V <= std_logic_vector(unsigned(V)-unsigned(shift_right(unsigned(V), shift)));
+              V <= V_buff;
               spike <= '0';
             end if;
           end if;
diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 84146c6..f63da28 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -198,6 +198,7 @@ class ModNEFModel(nn.Module):
 
     if self.driver != None:
       self.driver.close()
+      self.driver = None
 
   def forward(self, input_spikes):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index a2ab9a4..895adce 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -325,10 +325,10 @@ class RSLIF(ModNEFNeuron):
     We assume the quantizer has already been initialized
     """
 
-    self.v_leak.data = QuantizeSTE(self.v_leak, self.quantizer)
-    self.v_min.data = QuantizeSTE(self.v_min, self.quantizer)
-    self.v_rest.data = QuantizeSTE(self.v_rest, self.quantizer)
-    self.threshold.data = QuantizeSTE(self.threshold, self.quantizer)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 2ee8653..5e93c39 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -282,9 +282,7 @@ class SLIF(ModNEFNeuron):
     if self.hardware_description["variable_size"]==-1:
       if self.hardware_estimation_flag:
         val_max = max(abs(self.val_max), abs(self.val_min))
-        print(val_max)
         val_max = self.quantizer(val_max)
-        print(val_max)
         self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8
       else:
         self.hardware_description["variable_size"]=16
@@ -316,10 +314,10 @@ class SLIF(ModNEFNeuron):
     We assume the quantizer has already been initialized
     """
 
-    self.v_leak.data = QuantizeSTE(self.v_leak, self.quantizer)
-    self.v_min.data = QuantizeSTE(self.v_min, self.quantizer)
-    self.v_rest.data = QuantizeSTE(self.v_rest, self.quantizer)
-    self.threshold.data = QuantizeSTE(self.threshold, self.quantizer)
+    self.v_leak.data = QuantizeSTE.apply(self.v_leak, self.quantizer)
+    self.v_min.data = QuantizeSTE.apply(self.v_min, self.quantizer)
+    self.v_rest.data = QuantizeSTE.apply(self.v_rest, self.quantizer)
+    self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index e2cb265..92698f8 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -140,8 +140,6 @@ class ShiftLIF(ModNEFNeuron):
       "variable_size" : -1
     }
 
-    print(threshold)
-
   @classmethod
   def from_dict(cls, dict, spike_grad):
     """
-- 
GitLab
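
The shift_lif.vhd rework above changes the membrane update so that the shift-based
leak is applied before the threshold test (via the V_buff variable). Expressed as a
small Python sketch on integer membrane values (variable names mirror the VHDL, the
zero/subtract choice follows the `reset` generic; the function itself is only
illustrative):

    def shift_lif_step(V, v_threshold, shift, reset="subtract"):
      # leak first: V_buff = V - (V >> shift), i.e. roughly V * (1 - 2**-shift)
      V_buff = V - (V >> shift)
      if V_buff >= v_threshold:
        spike = 1
        V = 0 if reset == "zero" else V_buff - v_threshold
      else:
        spike = 0
        V = V_buff
      return V, spike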