From 58d75437ce9bed03112b37371e8edc095f7a58bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Fri, 14 Mar 2025 11:37:04 +0100
Subject: [PATCH] add clamp method to neuron models

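Add a clamp() method to ModNEFNeuron and to each concrete neuron model
(BLIF, RBLIF, SLIF, RSLIF, ShiftLIF, RShiftLIF). Each model clamps its
synaptic weights through its quantizer; the recurrent variants clamp
both the feed-forward and the recurrent weight matrices.
ModNEFModel.clamp() applies the clamp to every ModNEFNeuron of the
network.

The clamping bounds, previously computed but left unused in
DynamicScaleFactorQuantizer.clamp(), now live in the generic
Quantizer.clamp(), which clamps directly instead of delegating to the
removed _clamp() hook. For a signed quantizer the bounds are
[-2**(bitwidth-1), 2**(bitwidth-1)-1] (e.g. [-128, 127] for 8 bits);
for an unsigned one, [0, 2**bitwidth-1].

Typical use is to clamp after each optimizer step so that weights stay
inside the quantizer's representable range (sketch only; the loss,
optimizer and data names are illustrative):

    out = model(x)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    model.clamp()
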
---
 modneflib/modnef/modnef_torch/model.py           |  6 ++++++
 .../modnef_neurons/blif_model/blif.py            |  7 +++++++
 .../modnef_neurons/blif_model/rblif.py           |  8 ++++++++
 .../modnef_neurons/modnef_torch_neuron.py        |  7 +++++++
 .../modnef_neurons/slif_model/rslif.py           |  8 ++++++++
 .../modnef_neurons/slif_model/slif.py            |  6 ++++++
 .../modnef_neurons/srlif_model/rshiftlif.py      |  8 ++++++++
 .../modnef_neurons/srlif_model/shiftlif.py       |  7 +++++++
 .../modnef/quantizer/dynamic_scale_quantizer.py  |  4 ----
 modneflib/modnef/quantizer/quantizer.py          | 19 +++++++++----------
 10 files changed, 66 insertions(+), 14 deletions(-)

diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index bf63ae0..44443dc 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -106,6 +106,12 @@ class ModNEFModel(nn.Module):
       if isinstance(m, ModNEFNeuron):
         m.quantize(force_init=force_init)
 
+  def clamp(self):
+    """Clamp the synaptic weights of every ModNEFNeuron in the model."""
+    for m in self.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.clamp()
+
   def train(self, mode : bool = True, quant : bool = False):
     """
     Set neuron model for trainning
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
index 3c83a40..d223c86 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py
@@ -342,6 +342,13 @@ class BLIF(Leaky, ModNEFNeuron):
     self.quantize_parameters()
     self.quantization_flag = True
 
+  def clamp(self):
+    """
+    Clamp synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
index f0ddcdd..d48960d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py
@@ -357,6 +357,14 @@ class RBLIF(Leaky, ModNEFNeuron):
     self.quantize_parameters()
     self.quantization_flag = True
 
+  def clamp(self):
+    """
+    Clamp feed-forward and recurrent synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index dc49183..7322594 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -91,6 +91,13 @@ class ModNEFNeuron():
     
     raise NotImplementedError()
   
+  def clamp(self):
+    """
+    Clamp synaptic weights (must be implemented by concrete neuron models)
+    """
+
+    raise NotImplementedError()
+  
   def set_quant(self, mode=False):
     self.quantization_flag = mode
   
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
index c950a76..694971e 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py
@@ -375,6 +375,14 @@ class RSLIF(LIF, ModNEFNeuron):
 
     self.quantize_weight()
     self.quantize_parameters()
+
+  def clamp(self):
+    """
+    Clamp feed-forward and recurrent synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
   
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
index 315a07d..fa370b9 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py
@@ -366,6 +366,12 @@ class SLIF(LIF, ModNEFNeuron):
     self.quantize_weight()
     self.quantize_parameters()
 
+  def clamp(self):
+    """
+    Clamp synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
   
   @classmethod
   def detach_hidden(cls):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
index 427a1c0..0193876 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py
@@ -375,6 +375,14 @@ class RShiftLIF(LIF, ModNEFNeuron):
     self.quantize_weight()
     self.quantize_parameters()
 
+  def clamp(self):
+    """
+    Clamp feed-forward and recurrent synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+    self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data)
+
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index 0ffa99b..9bb6d1d 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -357,6 +357,13 @@ class ShiftLIF(LIF, ModNEFNeuron):
     self.quantize_weight()
     self.quantize_parameters()
 
+  def clamp(self):
+    """
+    Clamp synaptic weights to the quantizer's representable range
+    """
+
+    self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data)
+
   @classmethod
   def detach_hidden(cls):
     """Returns the hidden states, detached from the current graph.
diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
index 4f4b096..724a8f4 100644
--- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
+++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py
@@ -145,11 +145,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
     -------
     Tensor
     """
-    
-    born_min = -int(self.signed)*2**(self.bitwidth-1)
-    born_max = 2**(self.bitwidth-int(self.signed))-1
 
-    #scaled = torch.clamp(data/self.scale_factor, min=born_min, max=born_max).to(dtype)
     scaled = torch.round(data/self.scale_factor).to(self.dtype)
 
     if unscale:
diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py
index 2a8ef79..235981d 100644
--- a/modneflib/modnef/quantizer/quantizer.py
+++ b/modneflib/modnef/quantizer/quantizer.py
@@ -152,21 +152,20 @@ class Quantizer():
     int | float | list | numpy.array | Tensor (depending on type of data)
     """
 
-    b_min = -1 #(-2**(self.bitwidth-int(self.signed))*int(self.signed))
-    b_max = 1 #2**(self.bitwidth-int(self.signed))-1
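+    # Representable integer range on `bitwidth` bits:
+    #   signed   -> [-2**(bitwidth-1), 2**(bitwidth-1)-1], e.g. [-128, 127] for 8 bits
+    #   unsigned -> [0, 2**bitwidth-1],                    e.g. [0, 255] for 8 bits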
+    born_min = -int(self.signed)*2**(self.bitwidth-1)
+    born_max = 2**(self.bitwidth-int(self.signed))-1
     
     if isinstance(data, (int, float)):
-      return self._clamp(torch.tensor(data)).item()
+      return torch.clamp(torch.tensor(data), min=born_min, max=born_max).item()
     elif isinstance(data, list):
-      return self._clamp(torch.tensor(data)).tolist()
+      return torch.clamp(torch.tensor(data), min=born_min, max=born_max).tolist()
     elif isinstance(data, np.ndarray):
-      return self._clamp(torch.tensor(data)).numpy()
+      return torch.clamp(torch.tensor(data), min=born_min, max=born_max).numpy()
     elif torch.is_tensor(data):
-      return self._clamp(data).detach()
+      return torch.clamp(data, min=born_min, max=born_max).detach()
     else:
       raise TypeError("Unsupported data type")
-    
-  def _clamp(self, data):
-
-    raise NotImplementedError()
     
\ No newline at end of file
-- 
GitLab