diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py index d427c7826a90ad7f5619db74dd378261cea5073d..d402fda94c19ba45b074b1adcaadb7ed7065a0cb 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py @@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py index 8ebfcedb3759c5d63f4b930125c0ee951c0232f3..b0a66a05f55d3038ef179d6ef54541924c53806e 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py @@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) 
mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py index bc4cb586e861eb60c30d9a938dcc25f8e0b193a7..3d3427bc77b9619d1909c923ad9ee1ca6f6b5047 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py @@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 698307e7c85df5af31a236e4abc6a0e004f26461..4c9b3e9f6d25e787521b98b72bad682b72c394ff 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ 
b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -200,7 +200,9 @@ class ShiftLif(ModNEFArchMod): mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") - self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) + self.v_threshold = two_comp(self.quantizer(self.v_threshold, unscale=False, clamp=False), self.variable_size) + + else: for i in range(self.input_neuron): diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index 5e26bc2aee28221c452d7c802a559ce42dda6fe8..84146c6c46b4d0f6bbdeed7f5d5e5703bba8e2f5 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -143,8 +143,7 @@ class ModNEFModel(nn.Module): for m in self.modules(): if isinstance(m, ModNEFNeuron): - m.quantize(force_init=force_init) - m.clamp(False) + m.quantize(force_init=force_init, clamp=clamp) def clamp(self, force_init=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 0d4ae6ea746ff5d482e0d0c40dd1f5a35e0a15e5..99b6154532bfe2b6d09c4852c5b7db8c7f1a1963 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -238,8 +238,8 @@ class BLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem*self.beta @@ -265,12 +265,12 @@ class BLIF(ModNEFNeuron): ------- BLIF """ + if self.hardware_description["variable_size"]==-1: if self.hardware_estimation_flag: val_max = max(abs(self.val_max), abs(self.val_min)) val_max = self.quantizer(val_max) - print(val_max) 
self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8 else: self.hardware_description["variable_size"]=16 diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index f52eed0911e69860603856c7c754621052fd2793..f930deb08db239d9e47b64cbea9f2e527e4b6b51 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -253,8 +253,8 @@ class RBLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem*self.beta diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 208da8d63ea24c098bdb812b5a89db1e162a15de..681b88aa426edc61ceaba9377e5a7f10da48c19f 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -96,7 +96,7 @@ class ModNEFNeuron(SpikingNeuron): else: self.quantizer.init_from_weight(param[0], param[1]) - def quantize_weight(self): + def quantize_weight(self, clamp=False): """ synaptic weight quantization @@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - p.data = QuantizeSTE.apply(p.data, self.quantizer, True) + p.data = QuantizeSTE.apply(p.data, self.quantizer, clamp) def quantize_hp(self): @@ -117,7 +117,7 @@ class ModNEFNeuron(SpikingNeuron): raise NotImplementedError() - def quantize(self, force_init=False): + def quantize(self, force_init=False, clamp=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -130,7 +130,7 @@ 
class ModNEFNeuron(SpikingNeuron): if force_init: self.init_quantizer() - self.quantize_weight() + self.quantize_weight(clamp=clamp) self.quantize_hp() def clamp(self, force_init=False): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index 38364e28ebb82f5e4ca21eed5b3eb81479e537a1..a2ab9a4205c4e85ff231949bd90de46f6f1e88f6 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -258,8 +258,8 @@ class RSLIF(ModNEFNeuron): self.mem = self.mem-self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() # update neuron self.mem = self.mem - self.v_leak diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 6a6c7d39aa2ea655a355f4f919dc2b9aa3911aac..2ee8653a5c442eecfc3cce4a571515035ad59519 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -248,8 +248,8 @@ class SLIF(ModNEFNeuron): self.mem = self.mem - self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() # update neuron self.mem = self.mem - self.v_leak diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 
65d8ba530ed1f1af64df5f2789e4f6e8de00de82..07a631bc275e75bdc04896a02f0a04cc9f39d41f 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -267,8 +267,8 @@ class RShiftLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem-self.__shift(self.mem) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 70f9d96c04f51d74647285f40955590ae2a3d3d5..e2cb2652ad0dd4fd6c1cc600618797c0baab4bff 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -120,6 +120,8 @@ class ShiftLIF(ModNEFNeuron): quantizer=quantizer ) + + self.shift = int(-log(1-beta)/log(2)) @@ -138,6 +140,8 @@ class ShiftLIF(ModNEFNeuron): "variable_size" : -1 } + print(threshold) + @classmethod def from_dict(cls, dict, spike_grad): """ @@ -258,8 +262,8 @@ class ShiftLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem-self.__shift(self.mem) @@ -294,6 +298,8 @@ class ShiftLIF(ModNEFNeuron): else: self.hardware_description["variable_size"]=16 + #self.clamp(force_init=True) + module = builder.ShiftLif( name=module_name, input_neuron=self.fc.in_features, diff --git 
a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py index 57f67b4478d0883efe4e20a08026b94b7a0487b4..f22f0a520133df3baa9723a1db5dc1311fbeaa61 100644 --- a/modneflib/modnef/modnef_torch/quantLinear.py +++ b/modneflib/modnef/modnef_torch/quantLinear.py @@ -51,7 +51,7 @@ class QuantLinear(nn.Linear): if quantizer!=None: w = QuantizeSTE.apply(self.weight, quantizer) - w.data = quantizer.clamp(w) + #w.data = quantizer.clamp(w) else: w = self.weight diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index e26495079afd1651990f8cdfadb6676e3e77f616..20914b619575f2200ecf21dcb152e902eef42c85 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -115,8 +115,8 @@ class Quantizer(): qdata = self._quant(tdata) if clamp: - born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)) - born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1) + born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)).to(qdata.device) + born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1).to(qdata.device) qdata = torch.clamp(qdata, min=born_min, max=born_max) if unscale: diff --git a/modneflib/modnef/quantizer/quantizer_scheduler.py b/modneflib/modnef/quantizer/quantizer_scheduler.py index 850dd9b00f66033098fa7869693fd9f38c2e8eb7..2d51f186aabc5136fe00984dfee9ff6938a084f6 100644 --- a/modneflib/modnef/quantizer/quantizer_scheduler.py +++ b/modneflib/modnef/quantizer/quantizer_scheduler.py @@ -16,28 +16,26 @@ class QuantizerScheduler(): self.num_bits = [i for i in range(bit_range[0], bit_range[1]-1, -1)] - self.model = model + self.bit_counter = 0 + self.epoch_counter = 0 - self.period = T + self.bitwidth = self.num_bits[self.bit_counter] + self.period = T self.epoch_max = self.period*(len(self.num_bits)-1) - print(self.num_bits) - print(self.epoch_max) - - self.bit_counter = 0 + self.model = model self.quantizationMethod = quantizationMethod 
self.__update() - self.epoch_counter = 0 + def __update(self): - print(self.num_bits[self.bit_counter]) for m in self.model.modules(): if isinstance(m, ModNEFNeuron): - m.quantizer = self.quantizationMethod(self.num_bits[self.bit_counter]) + m.quantizer = self.quantizationMethod(self.bitwidth) m.init_quantizer() m.quantize_hp() @@ -50,6 +48,7 @@ class QuantizerScheduler(): else: if self.epoch_counter%self.period==0: self.bit_counter += 1 + self.bitwidth = self.num_bits[self.bit_counter] self.__update() def save_model(self): diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py index c5ffd3292637b7bf0cecb9f25f88c645d462d84e..b66832e320fbb464ae48abcb65b866b9ace0d662 100644 --- a/modneflib/modnef/quantizer/ste_quantizer.py +++ b/modneflib/modnef/quantizer/ste_quantizer.py @@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function): quantization method applied to data """ - q_data = quantizer(data, True, clamp) + q_data = quantizer(data, unscale=True, clamp=clamp) ctx.scale = quantizer.scale_factor diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py index 1c2898bc4458dc387d7c87db15c0294fe94978f6..8fa94db51748bc8254de854504312ab584edd0ad 100644 --- a/modneflib/modnef/templates/evaluation.py +++ b/modneflib/modnef/templates/evaluation.py @@ -14,6 +14,9 @@ if __name__ == "__main__": """Experience name""" exp_name = "Evaluation" + """Device definition""" + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + """Model definition""" best_model_name = "best_model" @@ -22,7 +25,7 @@ if __name__ == "__main__": model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25)) - model.load_state_dict(torch.load(best_model_name)) + model.load_state_dict(torch.load(best_model_name, map_location=device)) """Kind of run @@ -33,15 +36,12 @@ if __name__ == "__main__": kind = sys.argv[1] - """Device definition""" - device = torch.device("cuda") if 
torch.cuda.is_available() else torch.device("cpu") - """Evaluation variable definition""" verbose = True output_path = "." """FPGA file definition""" - driver_config_path = "driver_config" + driver_config_path = "driver_config.yml" board_path = "" @@ -69,6 +69,7 @@ if __name__ == "__main__": quant=False ) elif kind == "qeval": + model.quantize(force_init=True, clamp=True) acc, y_pred, y_true = evaluation( model=model, testLoader=testLoader, diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py index b8bbb3ceb6bb9a4790ab37cf5f6223f4cbd5aa07..6868daf3e509f7cf5215a18f678c2b364a809f62 100644 --- a/modneflib/modnef/templates/model.py +++ b/modneflib/modnef/templates/model.py @@ -114,11 +114,20 @@ class MyModel(mt.ModNEFModel): batch_result = [] + run_time = [] + + n_layer = 0 + + for m in self.modules(): + if isinstance(m, mt.ModNEFNeuron): + n_layer += 1 + for sample in input_spikes: - sample_res = self.driver.run_sample(sample, to_aer, True, len(self.layers)) + sample_res = self.driver.run_sample(sample, to_aer, True, n_layer) + run_time.append(self.driver.sample_time) batch_result.append([sample_res]) - return torch.tensor(batch_result).permute(1, 0, 2), None + return torch.tensor(batch_result).permute(1, 0, 2), None, run_time def to_vhdl(self, file_name=None, output_path = ".", driver_config_path = "./driver.yml"): """ @@ -149,8 +158,8 @@ class MyModel(mt.ModNEFModel): baud_rate=921_600, queue_read_depth=10240, queue_write_depth=1024, - tx_name="uart_rxd", - rx_name="uart_txd" + tx_name="uart_txd", + rx_name="uart_rxd" ) builder.add_module(uart) diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py index 73dd3521012b282bf9c2f44f86b7dc6b41d65ba4..7cee0717607bc16f362d25196ceb235622db2662 100644 --- a/modneflib/modnef/templates/run_lib.py +++ b/modneflib/modnef/templates/run_lib.py @@ -105,7 +105,7 @@ def train(model, if qat_scheduler==None: # no QAT scheduler so we save our model 
torch.save(model.state_dict(), f"{output_path}/{best_model_name}") best_acc = acc_test - elif qat_scheduler.save(): # if QAT scheduler, we need to check if we quantize at the target bitwidth + elif qat_scheduler.save_model(): # if QAT scheduler, we need to check if we quantize at the target bitwidth torch.save(model.state_dict(), f"{output_path}/{best_model_name}") best_acc = acc_test @@ -114,9 +114,11 @@ def train(model, lr_val_history.append(lr_scheduler.get_last_lr()[0]) lr_scheduler.step() - if qat_scheduler!=None: - bitwidth_val_history.append(qat_scheduler.bitwidth) - qat_scheduler.step() + if qat: + model.clamp() + if qat_scheduler!=None: + bitwidth_val_history.append(qat_scheduler.bitwidth) + qat_scheduler.step() if save_history: np.save(f"{output_path}/loss.npy", np.array(avg_loss_history)) @@ -128,7 +130,7 @@ def train(model, if lr_scheduler!=None: np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler)) - if qat_scheduler!=None: + if qat and qat_scheduler!=None: np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history)) if save_plot: @@ -159,7 +161,7 @@ def train(model, plt.ylabel("learning rate") plt.savefig(f"{output_path}/lr_values.svg") - if qat_scheduler!=None: + if qat and qat_scheduler!=None: plt.figure() plt.plot([i for i in range(n_epoch)], lr_val_history, label="bitwidth") plt.legend() @@ -287,7 +289,7 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva y_true.append(target) - spk_rec, batch_speed = model(data) + spk_rec, mem_rec, batch_speed = model(data) run_time.extend(batch_speed) @@ -304,8 +306,8 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva del target del spk_rec - y_true = torch.stack(y_true).reshape(-1) - y_pred = torch.stack(y_pred).reshape(-1) + y_true = torch.stack(y_true).cpu().reshape(-1).numpy() + y_pred = torch.stack(y_pred).cpu().reshape(-1).numpy() run_time = np.array(run_time) diff --git 
a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py index 7765dae21c343bb89607ce680b0958e24c91de71..3ba0a9bc435ea769ed5b8f4fb48441eea649abf0 100644 --- a/modneflib/modnef/templates/vhdl_generation.py +++ b/modneflib/modnef/templates/vhdl_generation.py @@ -13,6 +13,9 @@ if __name__ == "__main__": """Experience name""" exp_name = "Evaluation" + """Device definition""" + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + """Model definition""" best_model_name = "best_model" @@ -21,11 +24,7 @@ if __name__ == "__main__": model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25)) - model.load_state_dict(torch.load(best_model_name)) - - - """Device definition""" - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model.load_state_dict(torch.load(best_model_name, map_location=device)) """Hardware Estimation variable definition""" verbose = True