From 94a2736415a26c5f18ff0a46a10ea13c13c5cd4a Mon Sep 17 00:00:00 2001 From: ahoni <aurelie.saulq@proton.me> Date: Mon, 7 Apr 2025 10:19:11 +0200 Subject: [PATCH] fix bug on arch builder --- .../modnef/arch_builder/modules/BLIF/rblif.py | 4 ++-- .../modnef/arch_builder/modules/SLIF/rslif.py | 4 ++-- .../modules/ShiftLIF/rshiftlif.py | 4 ++-- .../arch_builder/modules/ShiftLIF/shiftlif.py | 4 +++- modneflib/modnef/modnef_torch/model.py | 3 +-- .../modnef_neurons/blif_model/blif.py | 6 +++--- .../modnef_neurons/blif_model/rblif.py | 4 ++-- .../modnef_neurons/modnef_torch_neuron.py | 8 ++++---- .../modnef_neurons/slif_model/rslif.py | 4 ++-- .../modnef_neurons/slif_model/slif.py | 4 ++-- .../modnef_neurons/srlif_model/rshiftlif.py | 4 ++-- .../modnef_neurons/srlif_model/shiftlif.py | 10 ++++++++-- modneflib/modnef/modnef_torch/quantLinear.py | 2 +- modneflib/modnef/quantizer/quantizer.py | 4 ++-- .../modnef/quantizer/quantizer_scheduler.py | 17 ++++++++-------- modneflib/modnef/quantizer/ste_quantizer.py | 2 +- modneflib/modnef/templates/evaluation.py | 11 +++++----- modneflib/modnef/templates/model.py | 17 ++++++++++++---- modneflib/modnef/templates/run_lib.py | 20 ++++++++++--------- modneflib/modnef/templates/vhdl_generation.py | 9 ++++----- 20 files changed, 79 insertions(+), 62 deletions(-) diff --git a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py index d427c78..d402fda 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/rblif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/rblif.py @@ -224,7 +224,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), 
self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -232,7 +232,7 @@ class RBLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py index 8ebfced..b0a66a0 100644 --- a/modneflib/modnef/arch_builder/modules/SLIF/rslif.py +++ b/modneflib/modnef/arch_builder/modules/SLIF/rslif.py @@ -232,14 +232,14 @@ class RSLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") else: for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") mem_file.close() diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py index bc4cb58..3d3427b 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/rshiftlif.py @@ -226,7 +226,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line 
= 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) + w_line = (w_line<<self.quantizer.bitwidth) + two_comp(self.quantizer(rec_weights[i][j], unscale=False, clamp=True), self.quantizer.bitwidth) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") @@ -234,7 +234,7 @@ class RShiftLif(ModNEFArchMod): for i in range(self.output_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): - w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(weights[i][j], unscale=False, clamp=True) + w_line = (w_line<<self.quantizer.bitwidth) + self.quantizer(rec_weights[i][j], unscale=False, clamp=True) rec_mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py index 698307e..4c9b3e9 100644 --- a/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py +++ b/modneflib/modnef/arch_builder/modules/ShiftLIF/shiftlif.py @@ -200,7 +200,9 @@ class ShiftLif(ModNEFArchMod): mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") - self.v_threshold = two_comp(self.quantizer(self.v_threshold), self.variable_size) + self.v_threshold = two_comp(self.quantizer(self.v_threshold, unscale=False, clamp=False), self.variable_size) + + else: for i in range(self.input_neuron): diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index 5e26bc2..84146c6 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -143,8 +143,7 @@ class ModNEFModel(nn.Module): for m in self.modules(): if isinstance(m, ModNEFNeuron): - m.quantize(force_init=force_init) - m.clamp(False) + m.quantize(force_init=force_init, clamp=clamp) def clamp(self, force_init=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py 
b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 0d4ae6e..99b6154 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -238,8 +238,8 @@ class BLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem*self.beta @@ -265,12 +265,12 @@ class BLIF(ModNEFNeuron): ------- BLIF """ + if self.hardware_description["variable_size"]==-1: if self.hardware_estimation_flag: val_max = max(abs(self.val_max), abs(self.val_min)) val_max = self.quantizer(val_max) - print(val_max) self.hardware_description["variable_size"] = ceil(log(val_max)/log(256))*8 else: self.hardware_description["variable_size"]=16 diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index f52eed0..f930deb 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -253,8 +253,8 @@ class RBLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem*self.beta diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 208da8d..681b88a 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ 
b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -96,7 +96,7 @@ class ModNEFNeuron(SpikingNeuron): else: self.quantizer.init_from_weight(param[0], param[1]) - def quantize_weight(self): + def quantize_weight(self, clamp=False): """ synaptic weight quantization @@ -106,7 +106,7 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - p.data = QuantizeSTE.apply(p.data, self.quantizer, True) + p.data = QuantizeSTE.apply(p.data, self.quantizer, clamp) def quantize_hp(self): @@ -117,7 +117,7 @@ class ModNEFNeuron(SpikingNeuron): raise NotImplementedError() - def quantize(self, force_init=False): + def quantize(self, force_init=False, clamp=False): """ Quantize synaptic weight and neuron hyper-parameters @@ -130,7 +130,7 @@ class ModNEFNeuron(SpikingNeuron): if force_init: self.init_quantizer() - self.quantize_weight() + self.quantize_weight(clamp=clamp) self.quantize_hp() def clamp(self, force_init=False): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index 38364e2..a2ab9a4 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -258,8 +258,8 @@ class RSLIF(ModNEFNeuron): self.mem = self.mem-self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() # update neuron self.mem = self.mem - self.v_leak diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 6a6c7d3..2ee8653 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ 
b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -248,8 +248,8 @@ class SLIF(ModNEFNeuron): self.mem = self.mem - self.reset*(self.mem-self.v_rest) if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() # update neuron self.mem = self.mem - self.v_leak diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index 65d8ba5..07a631b 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -267,8 +267,8 @@ class RShiftLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem-self.__shift(self.mem) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 70f9d96..e2cb265 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -120,6 +120,8 @@ class ShiftLIF(ModNEFNeuron): quantizer=quantizer ) + + self.shift = int(-log(1-beta)/log(2)) @@ -138,6 +140,8 @@ class ShiftLIF(ModNEFNeuron): "variable_size" : -1 } + print(threshold) + @classmethod def from_dict(cls, dict, spike_grad): """ @@ -258,8 +262,8 @@ class ShiftLIF(ModNEFNeuron): self.mem = self.mem-self.reset*self.mem if self.hardware_estimation_flag: - self.val_min = 
torch.min(self.mem.min(), self.val_min) - self.val_max = torch.max(self.mem.max(), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min).detach() + self.val_max = torch.max(self.mem.max(), self.val_max).detach() self.mem = self.mem-self.__shift(self.mem) @@ -294,6 +298,8 @@ class ShiftLIF(ModNEFNeuron): else: self.hardware_description["variable_size"]=16 + #self.clamp(force_init=True) + module = builder.ShiftLif( name=module_name, input_neuron=self.fc.in_features, diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py index 57f67b4..f22f0a5 100644 --- a/modneflib/modnef/modnef_torch/quantLinear.py +++ b/modneflib/modnef/modnef_torch/quantLinear.py @@ -51,7 +51,7 @@ class QuantLinear(nn.Linear): if quantizer!=None: w = QuantizeSTE.apply(self.weight, quantizer) - w.data = quantizer.clamp(w) + #w.data = quantizer.clamp(w) else: w = self.weight diff --git a/modneflib/modnef/quantizer/quantizer.py b/modneflib/modnef/quantizer/quantizer.py index e264950..20914b6 100644 --- a/modneflib/modnef/quantizer/quantizer.py +++ b/modneflib/modnef/quantizer/quantizer.py @@ -115,8 +115,8 @@ class Quantizer(): qdata = self._quant(tdata) if clamp: - born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)) - born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1) + born_min = torch.tensor(-int(self.signed)*2**(self.bitwidth-1)).to(qdata.device) + born_max = torch.tensor(2**(self.bitwidth-int(self.signed))-1).to(qdata.device) qdata = torch.clamp(qdata, min=born_min, max=born_max) if unscale: diff --git a/modneflib/modnef/quantizer/quantizer_scheduler.py b/modneflib/modnef/quantizer/quantizer_scheduler.py index 850dd9b..2d51f18 100644 --- a/modneflib/modnef/quantizer/quantizer_scheduler.py +++ b/modneflib/modnef/quantizer/quantizer_scheduler.py @@ -16,28 +16,26 @@ class QuantizerScheduler(): self.num_bits = [i for i in range(bit_range[0], bit_range[1]-1, -1)] - self.model = model + self.bit_counter = 0 + 
self.epoch_counter = 0 - self.period = T + self.bitwidth = self.num_bits[self.bit_counter] + self.period = T self.epoch_max = self.period*(len(self.num_bits)-1) - print(self.num_bits) - print(self.epoch_max) - - self.bit_counter = 0 + self.model = model self.quantizationMethod = quantizationMethod self.__update() - self.epoch_counter = 0 + def __update(self): - print(self.num_bits[self.bit_counter]) for m in self.model.modules(): if isinstance(m, ModNEFNeuron): - m.quantizer = self.quantizationMethod(self.num_bits[self.bit_counter]) + m.quantizer = self.quantizationMethod(self.bitwidth) m.init_quantizer() m.quantize_hp() @@ -50,6 +48,7 @@ class QuantizerScheduler(): else: if self.epoch_counter%self.period==0: self.bit_counter += 1 + self.bitwidth = self.num_bits[self.bit_counter] self.__update() def save_model(self): diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py index c5ffd32..b66832e 100644 --- a/modneflib/modnef/quantizer/ste_quantizer.py +++ b/modneflib/modnef/quantizer/ste_quantizer.py @@ -39,7 +39,7 @@ class QuantizeSTE(torch.autograd.Function): quantization method applied to data """ - q_data = quantizer(data, True, clamp) + q_data = quantizer(data, unscale=True, clamp=clamp) ctx.scale = quantizer.scale_factor diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py index 1c2898b..8fa94db 100644 --- a/modneflib/modnef/templates/evaluation.py +++ b/modneflib/modnef/templates/evaluation.py @@ -14,6 +14,9 @@ if __name__ == "__main__": """Experience name""" exp_name = "Evaluation" + """Device definition""" + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + """Model definition""" best_model_name = "best_model" @@ -22,7 +25,7 @@ if __name__ == "__main__": model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25)) - model.load_state_dict(torch.load(best_model_name)) + model.load_state_dict(torch.load(best_model_name, 
map_location=device)) """Kind of run @@ -33,15 +36,12 @@ if __name__ == "__main__": kind = sys.argv[1] - """Device definition""" - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - """Evaluation variable definition""" verbose = True output_path = "." """FPGA file definition""" - driver_config_path = "driver_config" + driver_config_path = "driver_config.yml" board_path = "" @@ -69,6 +69,7 @@ if __name__ == "__main__": quant=False ) elif kind == "qeval": + model.quantize(force_init=True, clamp=True) acc, y_pred, y_true = evaluation( model=model, testLoader=testLoader, diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py index b8bbb3c..6868daf 100644 --- a/modneflib/modnef/templates/model.py +++ b/modneflib/modnef/templates/model.py @@ -114,11 +114,20 @@ class MyModel(mt.ModNEFModel): batch_result = [] + run_time = [] + + n_layer = 0 + + for m in self.modules(): + if isinstance(m, mt.ModNEFNeuron): + n_layer += 1 + for sample in input_spikes: - sample_res = self.driver.run_sample(sample, to_aer, True, len(self.layers)) + sample_res = self.driver.run_sample(sample, to_aer, True, n_layer) + run_time.append(self.driver.sample_time) batch_result.append([sample_res]) - return torch.tensor(batch_result).permute(1, 0, 2), None + return torch.tensor(batch_result).permute(1, 0, 2), None, run_time def to_vhdl(self, file_name=None, output_path = ".", driver_config_path = "./driver.yml"): """ @@ -149,8 +158,8 @@ class MyModel(mt.ModNEFModel): baud_rate=921_600, queue_read_depth=10240, queue_write_depth=1024, - tx_name="uart_rxd", - rx_name="uart_txd" + tx_name="uart_txd", + rx_name="uart_rxd" ) builder.add_module(uart) diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py index 73dd352..7cee071 100644 --- a/modneflib/modnef/templates/run_lib.py +++ b/modneflib/modnef/templates/run_lib.py @@ -105,7 +105,7 @@ def train(model, if qat_scheduler==None: # no QAT scheduler so we 
save our model torch.save(model.state_dict(), f"{output_path}/{best_model_name}") best_acc = acc_test - elif qat_scheduler.save(): # if QAT scheduler, we need to check if we quantize at the target bitwidth + elif qat_scheduler.save_model(): # if QAT scheduler, we need to check if we quantize at the target bitwidth torch.save(model.state_dict(), f"{output_path}/{best_model_name}") best_acc = acc_test @@ -114,9 +114,11 @@ def train(model, lr_val_history.append(lr_scheduler.get_last_lr()[0]) lr_scheduler.step() - if qat_scheduler!=None: - bitwidth_val_history.append(qat_scheduler.bitwidth) - qat_scheduler.step() + if qat: + model.clamp() + if qat_scheduler!=None: + bitwidth_val_history.append(qat_scheduler.bitwidth) + qat_scheduler.step() if save_history: np.save(f"{output_path}/loss.npy", np.array(avg_loss_history)) @@ -128,7 +130,7 @@ def train(model, if lr_scheduler!=None: np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler)) - if qat_scheduler!=None: + if qat and qat_scheduler!=None: np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history)) if save_plot: @@ -159,7 +161,7 @@ def train(model, plt.ylabel("learning rate") plt.savefig(f"{output_path}/lr_values.svg") - if qat_scheduler!=None: + if qat and qat_scheduler!=None: plt.figure() plt.plot([i for i in range(n_epoch)], lr_val_history, label="bitwidth") plt.legend() @@ -287,7 +289,7 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva y_true.append(target) - spk_rec, batch_speed = model(data) + spk_rec, mem_rec, batch_speed = model(data) run_time.extend(batch_speed) @@ -304,8 +306,8 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva del target del spk_rec - y_true = torch.stack(y_true).reshape(-1) - y_pred = torch.stack(y_pred).reshape(-1) + y_true = torch.stack(y_true).cpu().reshape(-1).numpy() + y_pred = torch.stack(y_pred).cpu().reshape(-1).numpy() run_time = np.array(run_time) diff --git 
a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py index 7765dae..3ba0a9b 100644 --- a/modneflib/modnef/templates/vhdl_generation.py +++ b/modneflib/modnef/templates/vhdl_generation.py @@ -13,6 +13,9 @@ if __name__ == "__main__": """Experience name""" exp_name = "Evaluation" + """Device definition""" + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + """Model definition""" best_model_name = "best_model" @@ -21,11 +24,7 @@ if __name__ == "__main__": model = MyModel("template_model", spike_grad=fast_sigmoid(slope=25)) - model.load_state_dict(torch.load(best_model_name)) - - - """Device definition""" - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model.load_state_dict(torch.load(best_model_name, map_location=device)) """Hardware Estimation variable definition""" verbose = True -- GitLab