diff --git a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd index 468271ea6ba3bf46f53401973a106edf3564a651..83722e01ae4a66b269a62f1762ba4d73b2e90685 100644 --- a/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd +++ b/ModNEF_Sources/modules/neurons/BLIF/beta_lif.vhd @@ -121,9 +121,7 @@ begin case state is when multiplication => V_mult := std_logic_vector(signed(V) * signed(beta)); - V_mult := std_logic_vector(shift_right(signed(V_mult), fixed_point)); - V_buff := V_mult(variable_size-1 downto 0); - --V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point); + V_buff := V_mult(fixed_point + variable_size-1 downto fixed_point); if signed(V_buff) >= signed(v_threshold) then spike <= '1'; diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py index 9dc9af6c60627682068086fb893a789f7e899930..0caecf1e51726951047ea6f88f1a08202a8275d6 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py @@ -199,7 +199,7 @@ class BLif(ModNEFArchMod): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw) + w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), bw) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -210,7 +210,7 @@ class BLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.weight_size) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py index f6f7ca9881e1d6252aff9e5cdd10b15037aa1ad3..294001987bc026786cf2ce953f1cf52183b1dd93 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/blif_debugger.py @@ -213,7 +213,7 @@ class BLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -224,7 +224,7 @@ class BLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py index 81f87c65b8ccbe95398c478cd6942ecdd21b4a40..d427c7826a90ad7f5619db74dd378261cea5073d 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py @@ -205,7 +205,7 @@ class RBLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -213,7 +213,7 @@ class RBLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py index 1c1561348cb39c35213b86b0ab2824b63767a7c1..8ebfcedb3759c5d63f4b930125c0ee951c0232f3 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py @@ -214,14 +214,14 @@ class RSLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() @@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif.py b/modneflib/modnef/arch_builder/modules/SLIF/slif.py index ddf5c96f6e242224437fca9c8282bcdedcaf7e6e..991c16e5ed45a422b51199bf228697b87057e7ad 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/slif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/slif.py @@ -201,7 +201,7 @@ class SLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) @@ -213,7 +213,7 @@ class SLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py index 7fb587c792b02f91ceb9c375469e0d68b437062e..dfca4698ce362f8a6b979bf48ee49b1bc3b56640 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/slif_debugger.py @@ -215,7 +215,7 @@ class SLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) @@ -227,7 +227,7 @@ class SLif_Debugger(ModNEFDebuggerMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py index f3077ae9316df4d056ea3f2611867196c5adb71a..bc4cb586e861eb60c30d9a938dcc25f8e0b193a7 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py @@ -206,7 +206,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -214,7 +214,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 35bb7344303f59276e43ce80ea060b7090519d4f..698307e7c85df5af31a236e4abc6a0e004f26461 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -196,7 +196,7 @@ class ShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j]), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -206,7 +206,7 @@ class ShiftLif(ModNEFArchMod): for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j]) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index f81295db9b5acd3b38ec0695410742bc9e602df8..5e26bc2aee28221c452d7c802a559ce42dda6fe8 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -131,7 +131,7 @@ class ModNEFModel(nn.Module): m.init_quantizer() m.quantize_weight() - def quantize(self, force_init=False): + def quantize(self, force_init=False, clamp=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -144,6 +144,7 @@ class ModNEFModel(nn.Module): for m in self.modules(): if isinstance(m, ModNEFNeuron): m.quantize(force_init=force_init) + m.clamp(False) def clamp(self, force_init=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 08f17e763ec7194a98a93e2647ba287c5bd37f2e..0d4ae6ea746ff5d482e0d0c40dd1f5a35e0a15e5 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -230,19 +230,18 @@ class BLIF(ModNEFNeuron): self.reset = self.mem_reset(self.mem) - self.mem = self.mem + forward_current - self.reset*self.threshold + self.mem = self.mem + forward_current + + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - - if self.reset_mechanism == "subtract": - self.mem = self.mem*self.beta - elif self.reset_mechanism == "zero": - self.mem = self.mem*self.beta-self.reset*self.mem - else: - self.mem = self.mem*self.beta + self.mem = self.mem*self.beta if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index 8b49564ea2549316450ffeb6fa464a36304d9737..f52eed0911e69860603856c7c754621052fd2793 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -247,16 +247,16 @@ class RBLIF(ModNEFNeuron): self.mem = self.mem + forward_current + rec_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem*self.beta-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem*self.beta-self.reset*self.mem - else: - self.mem = self.mem*self.beta + self.mem = self.mem*self.beta if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 9a48b0c61d4cd08aad983bab0aca0942a4f5d0dc..208da8d63ea24c098bdb812b5a89db1e162a15de 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - p.data = QuantizeSTE.apply(p.data, self.quantizer) + p.data = QuantizeSTE.apply(p.data, self.quantizer, True) def quantize_hp(self): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index c1915a78a6b6251553b61137537388b3735aae21..38364e28ebb82f5e4ca21eed5b3eb81479e537a1 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -255,6 +255,8 @@ class RSLIF(ModNEFNeuron): self.mem = self.mem + forward_current + rec_current + self.mem = self.mem-self.reset*(self.mem-self.v_rest) + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) @@ -262,7 +264,7 @@ class RSLIF(ModNEFNeuron): # update neuron self.mem = self.mem - self.v_leak min_reset = (self.mem<self.v_min).to(torch.float32) - self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) + self.mem = self.mem-min_reset*(self.mem-self.v_rest) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 43e2e1c7f12c33cacc10b07dffbfc2fbf92d76dc..6a6c7d39aa2ea655a355f4f919dc2b9aa3911aac 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -245,6 +245,7 @@ class SLIF(ModNEFNeuron): self.mem = self.mem + forward_current + self.mem = self.mem - self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) @@ -253,7 +254,7 @@ class SLIF(ModNEFNeuron): # update neuron self.mem = self.mem - self.v_leak min_reset = (self.mem<self.v_min).to(torch.float32) - self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) + self.mem = self.mem-min_reset*(self.mem-self.v_rest) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 0ea9983f38fbe56c3d8caece2751d049ac34d9cb..65d8ba530ed1f1af64df5f2789e4f6e8de00de82 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -261,16 +261,16 @@ class RShiftLIF(ModNEFNeuron): self.mem = self.mem+forward_current+rec_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem - else: - self.mem = self.mem-self.__shift(self.mem) + self.mem = self.mem-self.__shift(self.mem) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 5291733f9424f4024443e564efb014a339a73fe9..70f9d96c04f51d74647285f40955590ae2a3d3d5 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -252,16 +252,16 @@ class ShiftLIF(ModNEFNeuron): self.mem = self.mem+forward_current + if self.reset_mechanism == "subtract": + self.mem = self.mem-self.reset*self.threshold + elif self.reset_mechanism == "zero": + self.mem = self.mem-self.reset*self.mem + if self.hardware_estimation_flag: self.val_min = torch.min(self.mem.min(), self.val_min) self.val_max = torch.max(self.mem.max(), self.val_max) - if self.reset_mechanism == "subtract": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold - elif self.reset_mechanism == "zero": - self.mem = self.mem-self.__shift(self.mem)-self.reset*self.mem - else: - self.mem = self.mem-self.__shift(self.mem) + self.mem = self.mem-self.__shift(self.mem) if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index c56f50616be06d383a33840424ae199967f35403..e26495079afd1651990f8cdfadb6676e3e77f616 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -91,7 +91,7 @@ class Quantizer(): raise NotImplementedError() - def __call__(self, data, unscale=False): + def __call__(self, data, unscale=False, clamp=False): """ Call quantization function @@ -114,6 +114,11 @@ class Quantizer(): qdata = self._quant(tdata) + if clamp: + born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)) + born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1) + qdata = torch.clamp(qdata, min=born_min, max=born_max) + if unscale: qdata = self._unquant(qdata).to(torch.float32) diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py index 565cbd8a14354bf97f5056dd2045f61659c4499e..c5ffd3292637b7bf0cecb9f25f88c645d462d84e 100644 --- a/modneflib/modnef/quantizer/ste_quantizer.py +++ b/modneflib/modnef/quantizer/ste_quantizer.py @@ -25,7 +25,7 @@ class QuantizeSTE(torch.autograd.Function): """ @staticmethod - def forward(ctx, data, quantizer): + def forward(ctx, data, quantizer, clamp=False): """ Apply quantization method to data @@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function): quantization method applied to data """ - q_data = quantizer(data, True) + q_data = quantizer(data, True, clamp) ctx.scale = quantizer.scale_factor