diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 44443dc4da4759f9c0c785f4df1e2b3c9493cca2..ae40072c5329af8bd93e3af75f5d0bc358364a00 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -88,7 +88,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         m.hardware_estimation(hardware)
-        m.set_quant(quant)
+        m.run_quantize(quant)
 
     return super().train(mode=mode)
   
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 83cc04fcb4300966ea80fba91575c66944fa52df..50f18d89dd311cfcba5fffe6ead72cdabfe1fa3f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -18,7 +18,7 @@ from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from modnef.quantizer import *
 
-class BLIF(Leaky, ModNEFNeuron):
+class BLIF(ModNEFNeuron):
   """
   ModNEFTorch BLIF neuron model
 
@@ -111,26 +111,15 @@ class BLIF(Leaky, ModNEFNeuron):
       quantization method
     """
     
-    Leaky.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
-
-    self.fc = nn.Linear(in_features, out_features, bias=False)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
+    
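+    # membrane leak factor kept as a buffer: saved in the state_dict but not trained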
+    self.register_buffer("beta", torch.tensor(beta))
 
     self._init_mem()
 
@@ -307,50 +296,20 @@ class BLIF(Leaky, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
-    """
 
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-    self.quantization_flag = True
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
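+    # hyper-parameters are quantized with the scaling the quantizer derived from the synaptic weights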
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index bf702cb48e12f6e598d3d0c12c0af77ac01fa183..29bd28444b1fd195c469d1b70132cbbc1a8a9b63 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import *
 
-class RBLIF(Leaky, ModNEFNeuron):
+class RBLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent BLIF neuron model
 
@@ -116,27 +116,17 @@ class RBLIF(Leaky, ModNEFNeuron):
       quantization method
     """
     
-    Leaky.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
+    
+    self.register_buffer("beta", torch.tensor(beta))
 
-    self.fc = nn.Linear(in_features, out_features, bias=False)
-    self.reccurent = nn.Linear(out_features, out_features, bias=False)
+    self.reccurent = nn.Linear(out_features, out_features, bias=False)
 
     self._init_mem()
 
@@ -321,52 +311,19 @@ class RBLIF(Leaky, ModNEFNeuron):
     )
     return module
   
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-    self.quantization_flag = True
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 0a1acd91e0af892d9b43ca10623d5d6203af65dc..08e37abea078bacf98ae9e6ef2dd2c53e3bd3c6a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -9,6 +9,7 @@ Descriptions: ModNEF torch neuron interface builder
 """
 
 import torch
+import torch.nn as nn
 from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
@@ -44,26 +45,22 @@ class ModNEFNeuron(SpikingNeuron):
     create and return the corresponding modnef archbuilder module from internal neuron parameters
   """
 
-  def __init__(self, 
-               threshold, 
-               reset_mechanism, 
-               quantizer : Quantizer, 
-               spike_grad=fast_sigmoid(slope=25)):
-
-    SpikingNeuron.__init__(
-      self=self,
+  def __init__(self,
+               in_features,
+               out_features,
+               threshold,
+               reset_mechanism,
+               spike_grad, 
+               quantizer):
+    
+    super().__init__(
       threshold=threshold,
-      reset_mechanism=reset_mechanism,
-      spike_gard=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_threshold=False,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
+      spike_grad=spike_grad,
+      reset_mechanism=reset_mechanism
     )
+    
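+    # bias-free input synapses shared by every ModNEF neuron model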
+    self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+
+    self.in_features = in_features
+    self.out_features = out_features
 
     self.hardware_estimation_flag = False
     self.quantization_flag = False
@@ -84,19 +81,20 @@ class ModNEFNeuron(SpikingNeuron):
     raise NotImplementedError()
   
   def init_quantizer(self):
+    """
+    Initialize or re-initialize the internal quantizer from the synaptic weights
+    """
 
-    params = list(self.parameters())
-
-    w1 = params[0].data
+    param = list(self.parameters())
 
-    if len(params)==2:
-      w2 = params[0].data
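+    # relies on parameter registration order: feed-forward weight first, recurrent weight second (if any)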
+    if len(param)==1:
+      self.quantizer.init_from_weight(param[0])
+      print("init no rec")
     else:
-      w2 = torch.zeros((1))
-
-    self.quantizer.init_quantizer(w1, w2)
+      self.quantizer.init_from_weight(param[0], param[1])
+      print("init rec")
   
-  def quantize_weight(self, unscaled : bool = False):
+  def quantize_weight(self, unscale : bool = True):
     """
     synaptic weight quantization
 
@@ -105,12 +103,20 @@ class ModNEFNeuron(SpikingNeuron):
     NotImplementedError()
     """
     
-    for param in self.parameters():
-      param.data = self.quantizer(param.data, unscale=unscaled)
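+    # quantize every registered weight tensor (fc, plus the recurrent layer when present)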
+    for p in self.parameters():
+      p.data = self.quantizer(p.data, unscale=unscale)
   
-  def quantize_hp(self):
+  def quantize_hp(self, unscale : bool = True):
     """
     neuron hyper-parameters quantization
+    Assumes the quantizer has already been initialized.
+
+    Parameters
+    ----------
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
     
     raise NotImplementedError()
@@ -125,16 +131,31 @@ class ModNEFNeuron(SpikingNeuron):
       force quantizer initialization
     """
     
-    raise NotImplementedError()
+    if force_init or not self.quantizer.is_initialize:
+      self.init_quantizer()
+
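+    # quantize synaptic weights first, then the neuron hyper-parameters, with the same quantizer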
+    self.quantize_weight()
+    self.quantize_hp()
   
   def clamp(self):
     """
     Clamp synaptic weight
     """
 
-    raise NotImplementedError()
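+    # clip every weight tensor to the range representable by the quantizer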
+    for p in self.parameters():
+      p.data = self.quantizer.clamp(p.data)
+      print("clamp")
   
-  def set_quant(self, mode=False):
+  def run_quantize(self, mode=False):
+    """
+    Set the quantization flag
+
+    Parameters
+    ----------
+    mode : bool = False
+      if True, quantization is applied during the forward pass
+    """
+
     self.quantization_flag = mode
   
   def hardware_estimation(self, mode = False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index 2d78aa9440df32af8fa63f9e11218a89d5588483..190ddc84029476ef7c36e144c4fc9c57d3c21403 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
 from modnef.quantizer import MinMaxQuantizer
 
-class RSLIF(LIF, ModNEFNeuron):
+class RSLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Simplifed LIF neuron model
 
@@ -119,33 +119,18 @@ class RSLIF(LIF, ModNEFNeuron):
       quantization function
     """
 
-    LIF.__init__(
-      self=self,
-      beta = v_leak,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism="zero",
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
-    )
-
-    ModNEFNeuron.__init__(self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism="zero",
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
 
     self.register_buffer("v_leak", torch.as_tensor(v_leak))
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.in_features = in_features
-    self.out_features = out_features
-
-    self.fc = nn.Linear(self.in_features, self.out_features, bias=False)
     self.reccurent = nn.Linear(self.out_features, self.out_features, bias=False)
 
     self._init_mem()
@@ -338,56 +323,23 @@ class RSLIF(LIF, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
 
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
+  def quantize_hp(self, unscale : bool = True):
     """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.v_leak.data = self.quantizer(self.v_leak.data, True)
-    self.v_min.data = self.quantizer(self.v_min.data, True)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True)
-    self.threshold.data = self.quantizer(self.threshold, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
 
-    self.quantize_weight()
-    self.quantize_parameters()
+    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
+    self.v_min.data = self.quantizer(self.v_min.data, unscale)
+    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
 
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
-    """
-
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
-
-  
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 9d01858ccd432208fe1924dd629f53b881b0ba31..78cc36b27622e701758b8f10b18ecb17634dd84c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -18,7 +18,7 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import ceil, log
 from modnef.quantizer import MinMaxQuantizer
 
-class SLIF(LIF, ModNEFNeuron):
+class SLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Simplifed LIF neuron model
 
@@ -119,34 +119,18 @@ class SLIF(LIF, ModNEFNeuron):
       quantization method
     """
 
-    LIF.__init__(
-      self=self,
-      beta = v_leak,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism="zero",
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False
-    )
-
-    ModNEFNeuron.__init__(self, quantizer=quantizer)
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism="zero",
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
 
     self.register_buffer("v_leak", torch.as_tensor(v_leak))
     self.register_buffer("v_min", torch.as_tensor(v_min))
     self.register_buffer("v_rest", torch.as_tensor(v_rest))
 
-    self.in_features = in_features
-    self.out_features = out_features
-
-    self.fc = nn.Linear(self.in_features, self.out_features, bias=False)
-
     self._init_mem()
     
     self.hardware_description = {
@@ -329,53 +313,23 @@ class SLIF(LIF, ModNEFNeuron):
       output_path=output_path
     )
     return module
-  
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
 
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.v_leak.data = self.quantizer(self.v_leak.data, True)
-    self.v_min.data = self.quantizer(self.v_min.data, True)
-    self.v_rest.data = self.quantizer(self.v_rest.data, True)
-    self.threshold.data = self.quantizer(self.threshold, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
 
-    self.quantize_weight()
-    self.quantize_parameters()
+    self.v_leak.data = self.quantizer(self.v_leak.data, unscale)
+    self.v_min.data = self.quantizer(self.v_min.data, unscale)
+    self.v_rest.data = self.quantizer(self.v_rest.data, unscale)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
 
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
-    """
-
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-  
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index cdf4d2678e00a525c30a4ffdc106699ffd7ac344..e6354a2086b3b742ffbf6270df64e31e8f46c8a7 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -19,7 +19,7 @@ from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 
 
-class RShiftLIF(LIF, ModNEFNeuron):
+class RShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch reccurent Shift LIF neuron model
 
@@ -116,35 +116,25 @@ class RShiftLIF(LIF, ModNEFNeuron):
     quantizer = DynamicScaleFactoirQuantizer(8) : Quantizer
       quantization method
     """
+
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
     
     self.shift = int(-log(1-beta)/log(2))
 
     if (1-2**-self.shift) != beta:
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
-    
-    LIF.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
 
-    self.fc = nn.Linear(in_features, out_features, bias=False)
     self.reccurent = nn.Linear(out_features, out_features, bias=False)
 
+    self.register_buffer("beta", torch.tensor(beta))
+
     self._init_mem()
 
     self.hardware_description = {
@@ -340,51 +330,19 @@ class RShiftLIF(LIF, ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-    self.reccurent.weight.data = self.quantizer(self.reccurent.weight.data, True)
-
-  def quantize_parameters(self):
+  def quantize_hp(self, unscale : bool = True):
     """
-    Quantize neuron hyper-parameters
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight, rec_weight=self.reccurent.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    self.beta.data = self.quantizer(self.beta.data, True)
-
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight, self.reccurent.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
-    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 195e965f8f99ef994bca289886b182ad7af402fa..f2c95cbe6d5d7ef0ad7271e134a9a1d600f0e2c5 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
 
 class ShiftLIF(ModNEFNeuron):
   """
@@ -110,6 +111,15 @@ class ShiftLIF(ModNEFNeuron):
     quantizer = DynamicScaleFactorQuantizer(8) : Quantizer
       quantization method
     """
+
+    super().__init__(threshold=threshold,
+                     in_features=in_features,
+                     out_features=out_features,
+                     reset_mechanism=reset_mechanism,
+                     spike_grad=spike_grad,
+                     quantizer=quantizer
+                     )
     
     self.shift = int(-log(1-beta)/log(2))
 
@@ -118,14 +128,7 @@ class ShiftLIF(ModNEFNeuron):
       beta = 1-2**-self.shift
 
 
-    ModNEFNeuron.__init__(self=self, 
-                          threshold=threshold,
-                          reset_mechanism=reset_mechanism,
-                          spike_grad=spike_grad,
-                          quantizer=quantizer
-                          )
-
-    self.fc = nn.Linear(in_features, out_features, bias=False)
+    self.register_buffer("beta", torch.tensor(beta))
 
     self._init_mem()
 
@@ -310,48 +313,19 @@ class ShiftLIF(ModNEFNeuron):
     )
     return module
 
-  def quantize_weight(self):
-    """
-    Quantize synaptic weight
-    """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.fc.weight.data = self.quantizer(self.fc.weight.data, True)
-
-  def quantize_parameters(self):
-    """
-    Quantize neuron hyper-parameters
+  def quantize_hp(self, unscale : bool = True):
     """
-
-    if not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(weight=self.fc.weight)
-
-    self.threshold.data = self.quantizer(self.threshold.data, True)
-    
-  def quantize(self, force_init=False):
-    """
-    Quantize synaptic weight and neuron hyper-parameters
+    Quantize neuron hyper-parameters.
+    Assumes the quantizer has already been initialized.
 
     Parameters
     ----------
-    force_init = Fasle : bool
-      force quantizer initialization
-    """
-    
-    if force_init or not self.quantizer.is_initialize:
-      self.quantizer.init_from_weight(self.fc.weight)
-
-    self.quantize_weight()
-    self.quantize_parameters()
-
-  def clamp(self):
-    """
-    Clamp synaptic weight and neuron hyper-parameters
+    unscale : bool = True
+      if True, quantized values are scaled back to full precision (simulated quantization)
     """
 
-    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.threshold.data = self.quantizer(self.threshold.data, unscale)
+    self.beta.data = self.quantizer(self.beta.data, unscale)
 
 
   @classmethod
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 43a6aa1fca47cb9e36775b0596a95a0b9c61c6b9..2631170b678e8ece885ed61e4e16a2efe17340db 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index c48e690cdbc31e131a1351c02330c7a8fc3e148c..106a242180b8b83b5d47ebe1dbf0e4f7604f63cd 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index ebc4ae9fb6cde07e030b35dad1a89f0c82f6ae00..6ca813180dfc9d0de66ff4428564c6240b9f4649 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 6ab752eb28bb5c601293cfe46cd1ce7e05ff4e5a..c56f50616be06d383a33840424ae199967f35403 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -77,7 +77,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
+  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
@@ -108,14 +108,14 @@ class Quantizer():
     """
 
     if not torch.is_tensor(data):
-      tdata = torch.tensor(data)
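+      # non-tensor inputs (e.g. python scalars) are converted to float32 tensors before quantization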
+      tdata = torch.tensor(data, dtype=torch.float32)
     else:
       tdata = data
 
     qdata = self._quant(tdata)
 
     if unscale:
-      qdata = self._unquant(qdata)
+      qdata = self._unquant(qdata).to(torch.float32)
     
     if isinstance(data, (int, float)):
       return qdata.item()