diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index bf63ae0a948894ad047788c1173bf4de8d5d171e..44443dc4da4759f9c0c785f4df1e2b3c9493cca2 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -106,6 +106,12 @@ class ModNEFModel(nn.Module): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) + def clamp(self): + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + m.clamp() + def train(self, mode : bool = True, quant : bool = False): """ Set neuron model for trainning diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 3c83a40741ef86dc3818e88d41df8ecd1de3cd48..d223c86d20afea7be6e97ceec852ed2fb7346009 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -342,6 +342,13 @@ class BLIF(Leaky, ModNEFNeuron): self.quantize_parameters() self.quantization_flag = True + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index f0ddcdd17ba6fc38b7207a5107a7469a3a550f61..d48960ddde25bfcb3a8c02ae049cc1fafe93a86e 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -357,6 +357,15 @@ class RBLIF(Leaky, ModNEFNeuron): self.quantize_parameters() self.quantization_flag = True + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index dc491839d2d9ecb2f63fdecd6786eb0871f82261..73225949001afeb0d5700f6f4a5c15c78b2079aa 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -91,6 +91,13 @@ class ModNEFNeuron(): raise NotImplementedError() + def clamp(self): + """ + Clamp synaptic weight + """ + + raise NotImplementedError() + def set_quant(self, mode=False): self.quantization_flag = mode diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index c950a7664ab40907a166c7f8e75f513ff0261cc7..694971edc6b2b40a4b7d5a2f83760d45095fd043 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -375,6 +375,15 @@ class RSLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 315a07d3bb5aa351e55de8d05d1e126ac6868d7b..fa370b99a19bbe8ea30bf26807b506bcddfa1e44 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -366,6 +366,12 @@ class SLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 427a1c0db9505115dc71c9ded3a29fc226626da1..0193876ac1a13774dcc1062ba9642cacffdfbec3 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -375,6 +375,15 @@ class RShiftLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + self.reccurent.weight.data = self.quantizer.clamp(self.reccurent.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 0ffa99b0e89b1df72b11ad2a116dcb861988848c..9bb6d1df8eeff1ab9044378d73e3625e4a9e4438 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -357,6 +357,14 @@ class ShiftLIF(LIF, ModNEFNeuron): self.quantize_weight() self.quantize_parameters() + def clamp(self): + """ + Clamp synaptic weight and neuron hyper-parameters + """ + + self.fc.weight.data = self.quantizer.clamp(self.fc.weight.data) + + @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py index 4f4b096ed47a376e2c0a323f1e5f94be62a78e87..724a8f4b739bc34c14ea691dd19cfc8df795bcba 100644 --- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py +++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py @@ -145,11 +145,7 @@ class DynamicScaleFactorQuantizer(Quantizer): ------- Tensor """ - - born_min = -int(self.signed)*2**(self.bitwidth-1) - born_max = 2**(self.bitwidth-int(self.signed))-1 - #scaled = torch.clamp(data/self.scale_factor, min=born_min, max=born_max).to(dtype) scaled = torch.round(data/self.scale_factor).to(self.dtype) if unscale: diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index 2a8ef79a5162fb0d2ab331554b5bc41c6a781b96..235981d76cf33be6881d9f6d16181aa68fc76c5e 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -152,21 +152,17 @@ class Quantizer(): int | float | list | numpy.array | Tensor (depending on type of data) """ - b_min = -1 #(-2**(self.bitwidth-int(self.signed))*int(self.signed)) - b_max = 1 #2**(self.bitwidth-int(self.signed))-1 + born_min = -int(self.signed)*2**(self.bitwidth-1) + born_max = 2**(self.bitwidth-int(self.signed))-1 if isinstance(data, (int, float)): - return self._clamp(torch.tensor(data)).item() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).item() elif isinstance(data, list): - return self._clamp(torch.tensor(data)).tolist() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).tolist() elif isinstance(data, np.ndarray): - return self._clamp(torch.tensor(data)).numpy() + return torch.clamp(torch.tensor(data), min=born_min, max=born_max).numpy() elif torch.is_tensor(data): - return self._clamp(data).detach() + return torch.clamp(data, min=born_min, max=born_max).detach() else: raise TypeError("Unsupported data type") - - def _clamp(self, data): - - raise NotImplementedError() \ No newline at end of file