From 5f5d13434f53ebbfc58737ef865cd660e3f37b62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Wed, 19 Mar 2025 22:21:33 +0100
Subject: [PATCH] rework quantizer API and move weight quantization into ModNEFNeuron

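* make ModNEFNeuron inherit from snntorch's SpikingNeuron and take
  threshold, reset_mechanism and spike_grad directly
* implement weight quantization once in ModNEFNeuron.quantize_weight,
  with an unscaled option that maps quantized weights back to full scale
* rename Quantizer.init_from_weight to init_quantizer and
  ModNEFNeuron.quantize_parameters to quantize_hp

A minimal sketch of the intended usage (feature sizes here are
illustrative, not part of the change):

    layer = ShiftLIF(in_features=64, out_features=10, beta=0.9375)
    layer.init_quantizer()                # fit the quantizer to the weights
    layer.quantize_weight(unscaled=True)  # quantize, then rescale (QAT-style)
    layer.quantize_hp()                   # quantize neuron hyper-parameters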
---
 .../modnef_neurons/modnef_torch_neuron.py     | 53 +++++++++++++++++++---
 .../modnef_neurons/srlif_model/shiftlif.py    | 31 ++++---------
 .../quantizer/dynamic_scale_quantizer.py      |  2 +-
 .../modnef/quantizer/fixed_point_quantizer.py |  2 +-
 .../modnef/quantizer/min_max_quantizer.py     |  2 +-
 modneflib/modnef/quantizer/quantizer.py       |  8 ++--
 6 files changed, 62 insertions(+), 36 deletions(-)
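
Note: a minimal sketch of the renamed quantizer API, assuming the 8-bit
DynamicScaleFactorQuantizer default used in shiftlif.py and an
unscale=False default in Quantizer.__call__:

    import torch
    from modnef.quantizer import DynamicScaleFactorQuantizer

    w = torch.randn(10, 10)
    q = DynamicScaleFactorQuantizer(8)
    q.init_quantizer(w)          # was: init_from_weight(w)
    q_int = q(w)                 # quantized weights
    q_fake = q(w, unscale=True)  # quantize, then map back to full scale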

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 7322594..0a1acd9 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -10,6 +10,8 @@ Descriptions: ModNEF torch neuron interface builder
 
 import torch
 from modnef.quantizer import *
+from snntorch._neurons import SpikingNeuron
+from snntorch.surrogate import fast_sigmoid
 
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -17,7 +19,7 @@ _quantizer = {
   "DynamicScaleFactorQuantizer" : DynamicScaleFactorQuantizer
 }
 
-class ModNEFNeuron():
+class ModNEFNeuron(SpikingNeuron):
   """
   ModNEF torch neuron interface
 
@@ -42,7 +44,27 @@ class ModNEFNeuron():
     create and return the corresponding modnef archbuilder module from internal neuron parameters
   """
 
-  def __init__(self, quantizer : Quantizer):
+  def __init__(self, 
+               threshold, 
+               reset_mechanism, 
+               quantizer : Quantizer, 
+               spike_grad=fast_sigmoid(slope=25)):
+
+    SpikingNeuron.__init__(
+      self=self,
+      threshold=threshold,
+      reset_mechanism=reset_mechanism,
+      spike_grad=spike_grad,
+      surrogate_disable=False,
+      init_hidden=False,
+      inhibition=False,
+      learn_threshold=False,
+      state_quant=False,
+      output=False,
+      graded_spikes_factor=1.0,
+      learn_graded_spikes_factor=False
+    )
+
     self.hardware_estimation_flag = False
     self.quantization_flag = False
 
@@ -61,7 +83,20 @@ class ModNEFNeuron():
 
     raise NotImplementedError()
   
-  def quantize_weight(self):
+  def init_quantizer(self):
+
+    params = list(self.parameters())
+
+    w1 = params[0].data  # feed-forward synaptic weights
+
+    if len(params)==2:
+      w2 = params[1].data  # recurrent synaptic weights
+    else:
+      w2 = torch.zeros((1))  # placeholder when the neuron has no recurrent weights
+
+    self.quantizer.init_quantizer(w1, w2)
+  
+  def quantize_weight(self, unscaled : bool = False):
     """
     synaptic weight quantization
 
@@ -68,11 +103,13 @@
-    Raises
-    ------
-    NotImplementedError()
+    Parameters
+    ----------
+    unscaled : bool = False
+      if True, weights are quantized then de-quantized back to full scale
     """
     
-    raise NotImplementedError()
+    for param in self.parameters():
+      param.data = self.quantizer(param.data, unscale=unscaled)
   
-  def quantize_parameters(self):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     """
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 9bb6d1d..195e965 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -11,14 +11,14 @@ Based on snntorch.Leaky and snntorch.LIF class
 
 import torch.nn as nn
 import torch
-from snntorch import LIF
+from snntorch.surrogate import fast_sigmoid
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 
-class ShiftLIF(LIF, ModNEFNeuron):
+class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
 
@@ -86,7 +86,7 @@ class ShiftLIF(LIF, ModNEFNeuron):
                out_features,
                beta,
                threshold=1.0,
-               spike_grad=None,
+               spike_grad=fast_sigmoid(slope=25),
                reset_mechanism="subtract",
                quantizer=DynamicScaleFactorQuantizer(8)
             ):
@@ -117,24 +117,13 @@ class ShiftLIF(LIF, ModNEFNeuron):
       print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift
 
-    LIF.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
 
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
+    ModNEFNeuron.__init__(self=self, 
+                          threshold=threshold,
+                          reset_mechanism=reset_mechanism,
+                          spike_grad=spike_grad,
+                          quantizer=quantizer
+                          )
 
     self.fc = nn.Linear(in_features, out_features, bias=False)
 
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 2631170..43a6aa1 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index 106a242..c48e690 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/min_max_quantizer.py b/modneflib/modnef/quantizer/min_max_quantizer.py
index 6ca8131..ebc4ae9 100644
--- a/modneflib/modnef/quantizer/min_max_quantizer.py
+++ b/modneflib/modnef/quantizer/min_max_quantizer.py
@@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index d753162..6ab752e 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -77,7 +77,7 @@ class Quantizer():
     
     raise NotImplementedError()
 
-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
 
@@ -112,10 +112,10 @@ class Quantizer():
     else:
       tdata = data
 
+    qdata = self._quant(tdata)
+
     if unscale:
-      qdata = self._unquant(self._quant(tdata))
-    else:
-      qdata = self._quant(tdata)
+      qdata = self._unquant(qdata)
     
     if isinstance(data, (int, float)):
       return qdata.item()
-- 
GitLab