From d08d6933b098c70afc903223f4f36e79c682073b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com> Date: Fri, 4 Apr 2025 12:42:29 +0200 Subject: [PATCH] add clamp to arch builder --- ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd | 4 +--- .../modnef/arch_builder/modules/BLIF/blif.py | 4 ++-- .../arch_builder/modules/BLIF/blif_debugger.py | 4 ++-- .../modnef/arch_builder/modules/BLIF/rblif.py | 8 ++++---- .../modnef/arch_builder/modules/SLIF/rslif.py | 8 ++++---- .../modnef/arch_builder/modules/SLIF/slif.py | 4 ++-- .../arch_builder/modules/SLIF/slif_debugger.py | 4 ++-- .../arch_builder/modules/ShiftLIF/rshiftlif.py | 8 ++++---- .../arch_builder/modules/ShiftLIF/shiftlif.py | 4 ++-- modneflib/modnef/modnef_torch/model.py | 3 ++- .../modnef_neurons/blif_model/blif.py | 15 +++++++-------- .../modnef_neurons/blif_model/rblif.py | 12 ++++++------ .../modnef_neurons/modnef_torch_neuron.py | 2 +- .../modnef_neurons/slif_model/rslif.py | 4 +++- .../modnef_neurons/slif_model/slif.py | 3 ++- .../modnef_neurons/srlif_model/rshiftlif.py | 12 ++++++------ .../modnef_neurons/srlif_model/shiftlif.py | 12 ++++++------ modneflib/modnef/quantizer/quantizer.py | 7 ++++++- modneflib/modnef/quantizer/ste_quantizer.py | 4 ++-- 19 files changed, 64 insertions(+), 58 deletions(-) diff --git a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd index 468271e..83722e0 100644 --- a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd +++ b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd @@ -121,9 +121,7 @@ begin case state is when multiplication => V_mult := std_logic_vector(signed(V) * signed(beta)); - V_mult := std_logic_vector(shift_right(signed(V_mult), fixed_point)); - V_buff := V_mult(variable_size-1 downto 0); - --V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point); + V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point); if signed(V_buff) >= signed(v_threshold) then spike <= '1'; diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py index 9dc9af6..0caecf1 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py @@ -199,7 +199,7 @@ class BLif(ModNEFArchMod): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw) + w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), bw) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -210,7 +210,7 @@ class BLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py index f6f7ca9..2940019 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py @@ -213,7 +213,7 @@ class BLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -224,7 +224,7 @@ class BLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py index 81f87c6..d427c78 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py @@ -205,7 +205,7 @@ class RBLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -213,7 +213,7 @@ class RBLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py index 1c15613..8ebfced 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py @@ -214,14 +214,14 @@ class RSLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() @@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif.py b/modneflib/modnef/arch_builder/modules/SLIF/slif.py index ddf5c96..991c16e 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/slif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/slif.py @@ -201,7 +201,7 @@ class SLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) @@ -213,7 +213,7 @@ class SLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py index 7fb587c..dfca469 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py @@ -215,7 +215,7 @@ class SLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) @@ -227,7 +227,7 @@ class SLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py index f3077ae..bc4cb58 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py @@ -206,7 +206,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -214,7 +214,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 35bb734..698307e 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -196,7 +196,7 @@ class ShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -206,7 +206,7 @@ class ShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index f81295d..5e26bc2 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -131,7 +131,7 @@ class ModNEFModel(nn.Module): m.init_quantizer() m.quantize_weight() - def quantize(self, force_init=False): + def quantize(self, force_init=False, clamp=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -144,6 +144,7 @@ class ModNEFModel(nn.Module): for m in self.modules(): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) + m.clamp(False) def clamp(self, force_init=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 08f17e7..0d4ae6e 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -230,19 +230,18 @@ class BLIF(ModNEFNeuron): self.reset = self.mem_reset(self.mem) - self.mem = self.mem + forward_current - self.reset*self.threshold + self.mem = self.mem + forward_current + + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - - if self.reset_mechanism == "subtract": - self.mem = self.mem*self.beta - elif self.reset_mechanism == "zero": - self.mem = self.mem*self.beta-self.reset*self.mem - else: - self.mem = self.mem*self.beta + self.mem = self.mem*self.beta if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index 8b49564..f52eed0 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -247,16 +247,16 @@ class RBLIF(ModNEFNeuron): self.mem = self.mem + forward_current + rec_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem*self.beta-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem*self.beta-self.reset*self.mem - else: - self.mem = self.mem*self.beta + self.mem = self.mem*self.beta if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 9a48b0c..208da8d 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - p.data = QuantizeSTE.apply(p.data, self.quantizer) + p.data = QuantizeSTE.apply(p.data, self.quantizer, True) def quantize_hp(self): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index c1915a7..38364e2 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -255,6 +255,8 @@ class RSLIF(ModNEFNeuron): self.mem = self.mem + forward_current + rec_current + self.mem = self.mem-self.reset*(self.mem-self.v_rest) + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) @@ -262,7 +264,7 @@ class RSLIF(ModNEFNeuron): # update neuron self.mem = self.mem - self.v_leak min_reset = (self.mem<self.v_min).to(torch.float32) - self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) + self.mem = self.mem-min_reset*(self.mem-self.v_rest) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 43e2e1c..6a6c7d3 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -245,6 +245,7 @@ class SLIF(ModNEFNeuron): self.mem = self.mem + forward_current + self.mem = self.mem - self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) @@ -253,7 +254,7 @@ class SLIF(ModNEFNeuron): # update neuron self.mem = self.mem - self.v_leak min_reset = (self.mem<self.v_min).to(torch.float32) - self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) + self.mem = self.mem-min_reset*(self.mem-self.v_rest) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 0ea9983..65d8ba5 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -261,16 +261,16 @@ class RShiftLIF(ModNEFNeuron): self.mem = self.mem+forward_current+rec_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem - else: - self.mem = self.mem-self.__shift(self.mem) + self.mem = self.mem-self.__shift(self.mem) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 5291733..70f9d96 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -252,16 +252,16 @@ class ShiftLIF(ModNEFNeuron): self.mem = self.mem+forward_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem - else: - self.mem = self.mem-self.__shift(self.mem) + self.mem = self.mem-self.__shift(self.mem) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index c56f506..e264950 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -91,7 +91,7 @@ class Quantizer(): raise NotImplementedError() - def __call__(self, data, unscale=False): + def __call__(self, data, unscale=False, clamp=False): """ Call quantization function @@ -114,6 +114,11 @@ class Quantizer(): qdata = self._quant(tdata) + if clamp: + born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)) + born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1) + qdata = torch.clamp(qdata, min=born_min, max=born_max) + if unscale: qdata = self._unquant(qdata).to(torch.float32) diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py index 565cbd8..c5ffd32 100644 --- a/modneflib/modnef/quantizer/ste_quantizer.py +++ b/modneflib/modnef/quantizer/ste_quantizer.py @@ -25,7 +25,7 @@ class QuantizeSTE(torch.autograd.Function): """ @staticmethod - def forward(ctx, data, quantizer): + def forward(ctx, data, quantizer, clamp=False): """ Apply quantization method to data @@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function): quantization method applied to data """ - q_data = quantizer(data, True) + q_data = quantizer(data, True, clamp) ctx.scale = quantizer.scale_factor -- GitLab