From e851459f9a4b1ed406f1af0ac6d4a075983bf641 Mon Sep 17 00:00:00 2001
From: ahoni <aurelie.saulq@proton.me>
Date: Sun, 9 Mar 2025 14:36:19 +0100
Subject: [PATCH] add type to quantizer)

---
 modneflib/modnef/modnef_torch/model.py        | 14 ++++++
 .../modnef_neurons/blif_model/blif.py         | 38 ++++++----------
 .../modnef_neurons/blif_model/rblif.py        | 40 +++++++----------
 .../modnef_neurons/modnef_torch_neuron.py     |  9 ++--
 .../modnef_neurons/slif_model/rslif.py        | 45 +++++++------------
 .../modnef_neurons/slif_model/slif.py         | 43 +++++++-----------
 .../modnef_neurons/srlif_model/rshiftlif.py   | 41 +++++++----------
 .../modnef_neurons/srlif_model/shiftlif.py    | 39 ++++++----------
 .../quantizer/dynamic_scale_quantizer.py      | 20 +++++----
 .../modnef/quantizer/fixed_point_quantizer.py | 20 +++++----
 .../modnef/quantizer/min_max_quantizer.py     | 20 +++++----
 modneflib/modnef/quantizer/quantizer.py       | 31 ++++++-------
 modneflib/setup.py                            |  2 +-
 13 files changed, 161 insertions(+), 201 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index f35d486..0451bfe 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -91,6 +91,20 @@ class ModNEFModel(nn.Module):
         m.set_quant(quant)
 
     return super().train(mode=mode)
+  
+  def quantize(self, force_init=False):
+    """
+    Quantize synaptic weight and neuron hyper-parameters
+
+    Parameters
+    ----------
+    force_init = Fasle : bool
+      force quantizer initialization
+    """
+    
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.quantize(force_init=force_init)
 
   def train(self, mode : bool = True, quant : bool = False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 78cac10..3c83a40 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -1,7 +1,7 @@
 """
 File name: blif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, math, snntorch, modnef.archbuilder, modnef_torch_neuron, modnef.quantizer
@@ -56,11 +56,11 @@ class BLIF(Leaky, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters()
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -305,51 +305,41 @@ class BLIF(Leaky, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight)
 
-    self.threshold.data = self.quantizer(self.threshold.data, True, dtype)
-    self.beta.data = self.quantizer(self.beta.data, True, dtype)
+    self.threshold.data = self.quantizer(self.threshold.data, True)
+    self.beta.data = self.quantizer(self.beta.data, True)
 
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
+    self.quantize_weight()
+    self.quantize_parameters()
     self.quantization_flag = True
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index 06b766e..f0ddcdd 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -1,7 +1,7 @@
 """
 File name: rblif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -58,11 +58,11 @@ class RBLIF(Leaky, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters()
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -318,53 +318,43 @@ class RBLIF(Leaky, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
+    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.threshold.data = self.quantizer(self.threshold.data, True, dtype)
-    self.beta.data = self.quantizer(self.beta.data, True, dtype)
+    self.threshold.data = self.quantizer(self.threshold.data, True)
+    self.beta.data = self.quantizer(self.beta.data, True)
 
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
+    self.quantize_weight()
+    self.quantize_parameters()
     self.quantization_flag = True
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 544c765..dc49183 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -79,23 +79,20 @@ class ModNEFNeuron():
     
     raise NotImplementedError()
   
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     raise NotImplementedError()
   
   def set_quant(self, mode=False):
     self.quantization_flag = mode
-
-    if mode:
-      self.quantize(False)
   
   def hardware_estimation(self, mode = False):
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 0f6415a..c950a76 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -1,7 +1,7 @@
 """
 File name: rslif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -60,11 +60,11 @@ class RSLIF(LIF, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters(d)
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -336,56 +336,45 @@ class RSLIF(LIF, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
+    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, True, dtype)
-    self.v_min.data = self.quantizer(self.v_min.data, True, dtype)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True, dtype)
-    self.threshold.data = self.quantizer(self.threshold, True, dtype)
+    self.v_leak.data = self.quantizer(self.v_leak.data, True)
+    self.v_min.data = self.quantizer(self.v_min.data, True)
+    self.v_rest.data = self.quantizer(self.v_rest.data, True)
+    self.threshold.data = self.quantizer(self.threshold, True)
 
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
-    self.quantization_flag = True
+    self.quantize_weight()
+    self.quantize_parameters()
   
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 0bd22d6..315a07d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -1,7 +1,7 @@
 """
 File name: slif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -60,11 +60,11 @@ class SLIF(LIF, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters()
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -327,55 +327,44 @@ class SLIF(LIF, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight)
 
-    self.v_leak.data = self.quantizer(self.v_leak.data, True, dtype)
-    self.v_min.data = self.quantizer(self.v_min.data, True, dtype)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True, dtype)
-    self.threshold.data = self.quantizer(self.threshold, True, dtype)
+    self.v_leak.data = self.quantizer(self.v_leak.data, True)
+    self.v_min.data = self.quantizer(self.v_min.data, True)
+    self.v_rest.data = self.quantizer(self.v_rest.data, True)
+    self.threshold.data = self.quantizer(self.threshold, True)
 
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
-    self.quantization_flag = True
+    self.quantize_weight()
+    self.quantize_parameters()
 
   
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 692a5f1..71194ac 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: rsrlif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -59,11 +59,11 @@ class RShiftLIF(LIF, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters()
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -336,54 +336,43 @@ class RShiftLIF(LIF, ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
+    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.threshold.data = self.quantizer(self.threshold.data, True, dtype)
-    self.beta.data = self.quantizer(self.beta.data, True, dtype)
+    self.threshold.data = self.quantizer(self.threshold.data, True)
+    self.beta.data = self.quantizer(self.beta.data, True)
 
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
-    self.quantization_flag = True
+    self.quantize_weight()
+    self.quantize_parameters()
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index af97a70..adcd3d4 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -1,7 +1,7 @@
 """
 File name: srlif
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer
@@ -56,11 +56,11 @@ class ShiftLIF(LIF, ModNEFNeuron):
     Update membrane voltage
   get_builder_module(module_name, output_path=".")
     Create ModNEFArchBuilder from internal neuron description
-  quantize_weight(dtype=torch.int32)
+  quantize_weight()
     Quantize synaptic weight
-  quantize_parameters(dtype=torch.int32)
+  quantize_parameters()
     Quantize neuron hyper-parameters
-  quantize(dtype=torch.int32)
+  quantize(force_init=False)
     Quantize synaptic weight and neuron hyper-parameters
   hardware_estimation()
     Toggle hardware estimation calculation
@@ -321,53 +321,42 @@ class ShiftLIF(LIF, ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self, dtype=torch.int32):
+  def quantize_weight(self):
     """
     Quantize synaptic weight
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True, dtype)
+    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
 
-  def quantize_parameters(self, dtype=torch.int32):
+  def quantize_parameters(self):
     """
     Quantize neuron hyper-parameters
-
-    Parameters
-    ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
     """
 
     if not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(weight=self.fc.weight)
 
-    self.threshold.data = self.quantizer(self.threshold.data, True, dtype)
-    self.beta.data = self.quantizer(self.beta.data, True, dtype)
+    self.threshold.data = self.quantizer(self.threshold.data, True)
+    self.beta.data = self.quantizer(self.beta.data, True)
     
-  def quantize(self, force_init=False, dtype=torch.int32):
+  def quantize(self, force_init=False):
     """
     Quantize synaptic weight and neuron hyper-parameters
 
     Parameters
     ----------
-    dtype = torch.int32 : dtype
-      type use during quantization
+    force_init = Fasle : bool
+      force quantizer initialization
     """
     
     if force_init or not self.quantizer.is_initialize:
       self.quantizer.init_from_weight(self.fc.weight)
 
-    self.quantize_weight(dtype)
-    self.quantize_parameters(dtype)
-    self.quantization_flag = True
+    self.quantize_weight()
+    self.quantize_parameters()
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index c6e95d4..4f4b096 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -1,7 +1,7 @@
 """
 File name: dynamic_scale_quantizer
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, modnef.quantizer.Quantizer
@@ -23,6 +23,8 @@ class DynamicScaleFactorQuantizer(Quantizer):
     set to true if quantizer is signed, false if not
   is_initializer : bool
     set to true when quantizer parameters are initialize
+  dtype = torch.int32 : torch.dtype
+    type use during quantization
   scale_factor : float
     scale factor
 
@@ -33,14 +35,15 @@ class DynamicScaleFactorQuantizer(Quantizer):
     generate quantizer from dictionnary description
   init_from_weight(weight, rec_weight)
     initialize quantizer parameters from synaptic weight
-  __call__(data, unscale, dtype)
+  __call__(data, unscale)
     Call quantization function, if unscale=true, return dequantization data
   """
 
   def __init__(self, 
                bitwidth, 
                signed=None, 
-               is_initialize=False
+               is_initialize=False,
+               dtype=torch.int32
                ):
     """
     Construct class
@@ -54,12 +57,15 @@ class DynamicScaleFactorQuantizer(Quantizer):
       If None, determine during initialization
     is_initialize = False : bool
       set to true when quantizer parameters are initialize
+    dtype = troch.int32 : torch.dtype
+      type use during quatization
     """
 
     super().__init__(
       bitwidth=bitwidth,
       signed=signed,
-      is_initialize=is_initialize
+      is_initialize=is_initialize,
+      dtype=dtype
     )
     
     self.scale_factor = 0
@@ -124,7 +130,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
     #self.scale_factor = torch.max(torch.abs(weight).max(), torch.abs(weight).max())/2**(self.bitwidth-1)
 
 
-  def _quant(self, data, unscale, dtype=torch.int32) -> torch.Tensor:
+  def _quant(self, data, unscale) -> torch.Tensor:
     """
     Apply quantization
 
@@ -134,8 +140,6 @@ class DynamicScaleFactorQuantizer(Quantizer):
       data to quantize
     unscale = False : bool
       If true, apply quantization and then, unquantize data to simulate quantization
-    dtype=torch.int32 : dtype
-      data type
 
     Returns
     -------
@@ -146,7 +150,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
     born_max = 2**(self.bitwidth-int(self.signed))-1
 
     #scaled = torch.clamp(data/self.scale_factor, min=born_min, max=born_max).to(dtype)
-    scaled = torch.round(data/self.scale_factor).to(dtype)
+    scaled = torch.round(data/self.scale_factor).to(self.dtype)
 
     if unscale:
       return scaled*self.scale_factor
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index afb36ac..564fceb 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -1,7 +1,7 @@
 """
 File name: fixed_point_quantizer
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, math, modnef.quantizer.Quantizer
@@ -24,6 +24,8 @@ class FixedPointQuantizer(Quantizer):
     set to true if quantizer is signed, false if not
   is_initializer : bool
     set to true when quantizer parameters are initialize
+  dtype = torch.int32 : torch.dtype
+    type use during quantization
   fixed_point : int
     fixed point position
   scale_factor : int
@@ -36,7 +38,7 @@ class FixedPointQuantizer(Quantizer):
     generate quantizer from dictionnary description
   init_from_weight(weight, rec_weight)
     initialize quantizer parameters from synaptic weight
-  __call__(data, unscale, dtype)
+  __call__(data, unscale)
     Call quantization function, if unscale=true, return dequantization data
   """
 
@@ -44,7 +46,8 @@ class FixedPointQuantizer(Quantizer):
                bitwidth, 
                fixed_point=-1, 
                signed=None, 
-               is_initialize=False
+               is_initialize=False,
+               dtype=torch.int32
                ):
     """
     Construct class
@@ -61,6 +64,8 @@ class FixedPointQuantizer(Quantizer):
       If None, determine during initialization
     is_initialize = False : bool
       set to true when quantizer parameters are initialize
+    dtype = torch.int32 : torch.dtype
+      type use during conversion
     """
 
     if bitwidth==-1 and fixed_point==-1:
@@ -69,7 +74,8 @@ class FixedPointQuantizer(Quantizer):
     super().__init__(
       bitwidth=bitwidth,
       signed=signed,
-      is_initialize=is_initialize
+      is_initialize=is_initialize,
+      dtype=dtype
     )
 
     self.fixed_point = fixed_point
@@ -141,7 +147,7 @@ class FixedPointQuantizer(Quantizer):
         self.scale_factor = 2**self.fixed_point
 
 
-  def _quant(self, data, unscale, dtype) -> torch.Tensor:
+  def _quant(self, data, unscale) -> torch.Tensor:
     """
     Apply quantization
 
@@ -151,15 +157,13 @@ class FixedPointQuantizer(Quantizer):
       data to quantize
     unscale = False : bool
       If true, apply quantization and then, unquantize data to simulate quantization
-    dtype=torch.int32 : dtype
-      data type
 
     Returns
     -------
     Tensor
     """
 
-    scaled = torch.round(data*self.scale_factor).to(dtype)
+    scaled = torch.round(data*self.scale_factor).to(self.dtype)
     
     if unscale:
       return (scaled.to(torch.float32))/self.scale_factor
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index 28d1cf3..a700f6b 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -1,7 +1,7 @@
 """
 File name: quantizer
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, modnef.quantizer.Quantizer
@@ -23,6 +23,8 @@ class MinMaxQuantizer(Quantizer):
     set to true if quantizer is signed, false if not
   is_initializer : bool
     set to true when quantizer parameters are initialize
+  dtype : torch.dtype
+    type use during quantization
   x_min : float
     minimal value of weight
   x_max : float
@@ -39,14 +41,15 @@ class MinMaxQuantizer(Quantizer):
     generate quantizer from dictionnary description
   init_from_weight(weight, rec_weight)
     initialize quantizer parameters from synaptic weight
-  __call__(data, unscale, dtype)
+  __call__(data, unscale)
     Call quantization function, if unscale=true, return dequantization data
   """
 
   def __init__(self, 
                bitwidth, 
                signed=None, 
-               is_initialize=False
+               is_initialize=False,
+               dtype=torch.int32
                ):
     """
     Construct class
@@ -60,12 +63,15 @@ class MinMaxQuantizer(Quantizer):
       If None, determine during initialization
     is_initialize = False : bool
       set to true when quantizer parameters are initialize
+    dtype = torch.int32 : toch.dtype
+      type use during quantization
     """
 
     super().__init__(
       bitwidth=bitwidth,
       signed=signed,
-      is_initialize=is_initialize
+      is_initialize=is_initialize,
+      dtype=dtype
     )
     
     self.x_min = 0
@@ -128,7 +134,7 @@ class MinMaxQuantizer(Quantizer):
     self.b_max = 2**(self.bitwidth-int(self.signed))-1
     self.b_min = -int(self.signed)*self.b_max
 
-  def _quant(self, data, unscale, dtype) -> torch.Tensor:
+  def _quant(self, data, unscale) -> torch.Tensor:
     """
     Apply quantization
 
@@ -138,15 +144,13 @@ class MinMaxQuantizer(Quantizer):
       data to quantize
     unscale = False : bool
       If true, apply quantization and then, unquantize data to simulate quantization
-    dtype=torch.int32 : dtype
-      data type
 
     Returns
     -------
     Tensor
     """
 
-    scaled = ((data-self.x_min)/(self.x_max-self.x_min)*(self.b_max-self.b_min)+self.b_min).to(dtype)
+    scaled = ((data-self.x_min)/(self.x_max-self.x_min)*(self.b_max-self.b_min)+self.b_min).to(self.dtype)
     
     if unscale:
       return (scaled-self.b_min)/(self.b_max-self.b_min)*(self.x_max-self.x_min)+self.x_min
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 94fc8f5..5a37433 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -1,7 +1,7 @@
 """
 File name: quantizer
 Author: Aurélie Saulquin  
-Version: 1.0.0
+Version: 1.1.0
 License: GPL-3.0-or-later
 Contact: aurelie.saulquin@univ-lille.fr
 Dependencies: torch, numpy
@@ -22,7 +22,9 @@ class Quantizer():
   signed : bool
     set to true if quantizer is signed, false if not
   is_initializer : bool
-    set to true when quantizer parameters are initialize
+    set to true when quantizer parameters are initialize*
+  dtype : torch.dtype
+    type use during quantization
 
   Methods
   -------
@@ -31,14 +33,15 @@ class Quantizer():
     generate quantizer from dictionnary description
   init_from_weight(weight, rec_weight)
     initialize quantizer parameters from synaptic weight
-  __call__(data, unscale, dtype)
+  __call__(data, unscale)
     Call quantization function, if unscale=true, return dequantization data
   """
   
   def __init__(self, 
                bitwidth, 
                signed=None, 
-               is_initialize=False
+               is_initialize=False,
+               dtype=torch.int32
                ):
     """
     Construct class
@@ -52,12 +55,14 @@ class Quantizer():
       If None, determine during initialization
     is_initialize = False : bool
       set to true when quantizer parameters are initialize
+    dtype = torch.int32 : torch.dtype
+      type use during quantization
     """
 
     self.bitwidth = bitwidth
     self.is_initialize = is_initialize
     self.signed = signed
-    pass
+    self.dtype=dtype
 
   @classmethod
   def from_dict(cls, dict):
@@ -86,7 +91,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def __call__(self, data, unscale=False, dtype=torch.int16):
+  def __call__(self, data, unscale=False):
     """
     Call quantization function
 
@@ -96,8 +101,6 @@ class Quantizer():
       data to quantize
     unscale = False : bool
       If true, apply quantization and then, unquantize data to simulate quantization
-    dtype=torch.int32 : dtype
-      data type
 
     Returns
     -------
@@ -105,17 +108,17 @@ class Quantizer():
     """
     
     if isinstance(data, (int, float)):
-      return self._quant(data=torch.tensor(data), unscale=unscale, dtype=dtype).item()
+      return self._quant(data=torch.tensor(data), unscale=unscale).item()
     elif isinstance(data, list):
-      return self._quant(data=torch.tensor(data), unscale=unscale, dtype=dtype).tolist()
+      return self._quant(data=torch.tensor(data), unscale=unscale).tolist()
     elif isinstance(data, np.ndarray):
-      return self._quant(data=torch.tensor(data), unscale=unscale, dtype=dtype).numpy()
+      return self._quant(data=torch.tensor(data), unscale=unscale).numpy()
     elif torch.is_tensor(data):
-      return self._quant(data=data, unscale=unscale, dtype=dtype).detach()
+      return self._quant(data=data, unscale=unscale).detach()
     else:
       raise TypeError("Unsupported data type")
 
-  def _quant(self, data, unscale, dtype) -> torch.Tensor:
+  def _quant(self, data, unscale) -> torch.Tensor:
     """
     Apply quantization
 
@@ -125,8 +128,6 @@ class Quantizer():
       data to quantize
     unscale = False : bool
       If true, apply quantization and then, unquantize data to simulate quantization
-    dtype=torch.int32 : dtype
-      data type
 
     Returns
     -------
diff --git a/modneflib/setup.py b/modneflib/setup.py
index 191c203..6f56208 100644
--- a/modneflib/setup.py
+++ b/modneflib/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 setup(
         name = "modnef",
         packages=find_packages(),
-        version = "=2.0.0",
+        version = "2.0.0",
         description="ModNEF python librairy",
         author="Aurelie Saulquin",
         install_requires=["networkx", "matplotlib", "pyyaml", "torch", "snntorch"],
-- 
GitLab