From 58d75437ce9bed03112b37371e8edc095f7a58bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com> Date: Fri, 14 Mar 2025 11:37:04 +0100 Subject: [PATCH] add clamp to neuron model --- modneflib/modnef/modnef_torch/model.py | 6 ++++++ .../modnef_neurons/blif_model/blif.py | 7 +++++++ .../modnef_neurons/blif_model/rblif.py | 9 +++++++++ .../modnef_neurons/modnef_torch_neuron.py | 7 +++++++ .../modnef_neurons/slif_model/rslif.py | 9 +++++++++ .../modnef_neurons/slif_model/slif.py | 6 ++++++ .../modnef_neurons/srlif_model/rshiftlif.py | 9 +++++++++ .../modnef_neurons/srlif_model/shiftlif.py | 8 ++++++++ .../modnef/quantizer/dynamic_scale_quantizer.py | 4 ---- modneflib/modnef/quantizer/quantizer.py | 16 ++++++---------- 10 files changed, 67 insertions(+), 14 deletions(-) diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index bf63ae0..44443dc 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -106,6 +106,12 @@ class ModNEFModel(nn.Module): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) + def clamp(self): + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + m.clamp() + def train(self, mode : bool = True, quant : bool = False): """ Set neuron model for trainning diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 3c83a40..d223c86 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -342,6 +342,13 @@ class BLIF(Leaky, ModNEFNeuron): self.quantize_parameters() self.quantization_flag = True + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index f0ddcdd..d48960d 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -357,6 +357,15 @@ class RBLIF(Leaky, ModNEFNeuron): self.quantize_parameters() self.quantization_flag = True + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index dc49183..7322594 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -91,6 +91,13 @@ class ModNEFNeuron(): raise NotImplementedError() + def clamp(self): + """ + Clamp synaptic weight + """ + + raise NotImplementedError() + def set_quant(self, mode=False): self.quantization_flag = mode diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index c950a76..694971e 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -375,6 +375,15 @@ class RSLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 315a07d..fa370b9 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -366,6 +366,12 @@ class SLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 427a1c0..0193876 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -375,6 +375,15 @@ class RShiftLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 0ffa99b..9bb6d1d 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -357,6 +357,14 @@ class ShiftLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py index 4f4b096..724a8f4 100644 --- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py +++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py @@ -145,11 +145,7 @@ class DynamicScaleFactorQuantizer(Quantizer): ------- Tensor """ - - born_min = -int(self.signed)*2**(self.bitwidth-1) - born_max = 2**(self.bitwidth-int(self.signed))-1 - #scaled = torch.clamp(data/self.scale_factor, min=born_min, max=born_max).to(dtype) scaled = torch.round(data/self.scale_factor).to(self.dtype) if unscale: diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index 2a8ef79..235981d 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -152,21 +152,17 @@ class Quantizer(): int | float | list | numpy.array | Tensor (depending on type of data) """ - b_min = -1 #(-2**(self.bitwidth-int(self.signed))*int(self.signed)) - b_max = 1 #2**(self.bitwidth-int(self.signed))-1 + born_min = -int(self.signed)*2**(self.bitwidth-1) + born_max = 2**(self.bitwidth-int(self.signed))-1 if isinstance(data, (int, float)): - return self._clamp(torch.tensor(data)).item() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).item() elif isinstance(data, list): - return self._clamp(torch.tensor(data)).tolist() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).tolist() elif isinstance(data, np.ndarray): - return self._clamp(torch.tensor(data)).numpy() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).numpy() elif torch.is_tensor(data): - return self._clamp(data).detach() + return torch.clamp(data, min=born_min, max=born_max).detach() else: raise TypeError("Unsupported data type") - - def _clamp(self, data): - - raise NotImplementedError() \ No newline at end of file -- GitLab