Select Git revision
sco_formsemestre_status.py
Forked from
Jean-Marie Place / SCODOC_R6A06
Source project has a limited visibility.
run_lib.py 9.04 KiB
from snntorch import surrogate
import torch
spike_grad = surrogate.fast_sigmoid(slope=25)
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
def train_1_epoch(model, trainLoader, optimizer, loss, qat, device, verbose):
    """
    Run one training pass over every batch of ``trainLoader``.

    Parameters
    ----------
    model
        network to train; must accept ``model.train(quant=...)`` and return
        ``(spk_rec, mem_rec)`` when called on a batch
    trainLoader
        iterable yielding ``(data, target)`` batches
    optimizer
        optimizer stepped once per batch
    loss
        callable evaluated as ``loss(mem_rec[step], target)``
    qat : bool
        forwarded to ``model.train(quant=qat)``
    device
        device on which data and targets are placed
    verbose : bool
        if True, show a tqdm progress bar with the running mean loss

    Returns
    -------
    float
        mean over batches of the per-batch accumulated loss
    """
    batch_losses = []

    batches = trainLoader
    if verbose:
        batches = tqdm(
            trainLoader,
            desc="Train",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]",
        )

    for data, target in batches:
        model.train(quant=qat)

        # Prepare data
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # Forward pass (spike record is not needed for the loss)
        _, mem_rec = model(data)

        # Accumulate the loss over data.shape[1] steps
        # NOTE(review): assumes dim 1 of data is the time axis and mem_rec is
        # indexed by time step on dim 0 -- confirm against the model.
        total_loss = torch.zeros((1), dtype=torch.float, device=device)
        for step in range(data.shape[1]):
            total_loss += loss(mem_rec[step], target)
        batch_losses.append(total_loss.item())

        # Backward pass and parameter update
        model.zero_grad()
        total_loss.backward()
        optimizer.step()

        if verbose:
            batches.set_postfix_str(f"Loss : {np.mean(batch_losses):0<3.2f}")

    return np.mean(batch_losses)
def train(model,
          trainLoader,
          testLoader,
          optimizer,
          loss,
          lr_scheduler = None,
          qat = False,
          qat_scheduler = None,
          device=torch.device("cpu"),
          validationLoader=None,
          n_epoch=10,
          best_model_name="best_model",
          verbose=True,
          save_plot=False,
          save_history=False,
          output_path="."
          ):
    """
    Train ``model`` for ``n_epoch`` epochs, evaluating after each epoch.

    Parameters
    ----------
    model
        network to train; moved to ``device``
    trainLoader
        training batches, passed to ``train_1_epoch``
    testLoader
        test batches, evaluated after every epoch
    optimizer
        optimizer passed to ``train_1_epoch``
    loss
        loss callable passed to ``train_1_epoch``
    lr_scheduler
        optional scheduler; its last learning rate is recorded, then it is
        stepped once per epoch
    qat : bool
        enable quantization aware training: the model's quantizer is
        initialised before training and ``model.clamp()`` is called every epoch
    qat_scheduler
        optional scheduler used with ``qat``; its ``bitwidth`` is recorded and
        it is stepped once per epoch. When present, a checkpoint is written
        only if ``qat_scheduler.save_model()`` is true
    device
        device used for training and evaluation
    validationLoader
        optional validation batches, evaluated after every epoch
    n_epoch : int
        number of epochs
    best_model_name : str
        file name of the best checkpoint; an empty string disables saving
    verbose : bool
        print epoch headers and show progress bars
    save_plot : bool
        write loss / accuracy / scheduler plots as SVG into ``output_path``
    save_history : bool
        write the recorded histories as ``.npy`` files into ``output_path``
    output_path : str
        directory receiving checkpoints, histories and plots

    Returns
    -------
    (list, list, list, float)
        average loss per epoch
        validation accuracy per epoch (empty without ``validationLoader``)
        test accuracy per epoch
        best checkpointed test accuracy (0 if nothing was saved)
    """
    avg_loss_history = []
    acc_test_history = []
    acc_val_history = []
    lr_val_history = []
    bitwidth_val_history = []
    best_acc = 0

    model = model.to(device)

    if qat:  # prepare the model for QAT
        model.init_quantizer()
        model.quantize_hp()

    for epoch in range(n_epoch):
        if verbose:
            print(f"---------- Epoch : {epoch} ----------")

        # Model training
        epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer,
                                   loss=loss, device=device, verbose=verbose, qat=qat)
        avg_loss_history.append(epoch_loss)

        # Model validation
        if validationLoader is not None:
            acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation",
                                       verbose=verbose, device=device, quant=qat)
            acc_val_history.append(acc_val)

        # Model evaluation on the test set
        acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test",
                                    verbose=verbose, device=device, quant=qat)
        acc_test_history.append(acc_test)

        # Save best model
        if best_model_name != "" and acc_test > best_acc:
            # Without QAT, or with QAT but no scheduler, always checkpoint.
            # With a QAT scheduler, defer to it (it is only queried in that case).
            may_save = (not qat) or qat_scheduler is None or qat_scheduler.save_model()
            if may_save:
                torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
                best_acc = acc_test

        # Update schedulers
        if lr_scheduler is not None:
            lr_val_history.append(lr_scheduler.get_last_lr()[0])
            lr_scheduler.step()

        if qat:
            model.clamp()
            if qat_scheduler is not None:
                bitwidth_val_history.append(qat_scheduler.bitwidth)
                qat_scheduler.step()

    if save_history:
        np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
        np.save(f"{output_path}/acc_test.npy", np.array(acc_test_history))

        if acc_val_history:
            np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history))

        if lr_scheduler is not None:
            # Save the recorded learning rates, not the scheduler object itself.
            np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_val_history))

        if qat and qat_scheduler is not None:
            # File name (including its spelling) is kept for compatibility with
            # existing consumers of this output.
            np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history))

    if save_plot:
        epochs = list(range(n_epoch))

        plt.figure()
        plt.plot(epochs, avg_loss_history)
        plt.title('Average Loss')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(f"{output_path}/loss.svg")

        plt.figure()
        if acc_val_history:
            plt.plot(epochs, acc_val_history, label="Validation")
        plt.plot(epochs, acc_test_history, label="Test")
        # best_acc stays at its initial value when no checkpoint was taken, in
        # which case it may not appear in the history at all.
        if best_acc in acc_test_history:
            plt.scatter([acc_test_history.index(best_acc)], [best_acc], label="Best Accuracy")
        plt.legend()
        plt.title("Accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.savefig(f"{output_path}/accuracy.svg")

        if lr_scheduler is not None:
            plt.figure()
            plt.plot(epochs, lr_val_history, label="LR Value")
            plt.legend()
            plt.title("LR Values")
            plt.xlabel("Epoch")
            plt.ylabel("learning rate")
            plt.savefig(f"{output_path}/lr_values.svg")

        if qat and qat_scheduler is not None:
            plt.figure()
            # Plot the recorded bitwidths (the learning-rate history was used here by mistake).
            plt.plot(epochs, bitwidth_val_history, label="bitwidth")
            plt.legend()
            plt.title("Quantizer bitwidth")
            plt.xlabel("Epoch")
            plt.ylabel("bitwidth")
            plt.savefig(f"{output_path}/quant_bitwidth.svg")

    return avg_loss_history, acc_val_history, acc_test_history, best_acc
def __run_accuracy(model, testLoader, name, verbose, device):
    """
    Run inference over ``testLoader`` and compute the classification accuracy.

    Parameters
    ----------
    model
        callable returning ``(spk_rec, mem_rec)`` for a batch of data
    testLoader
        test dataset loader yielding ``(data, target)`` batches
    name : str
        name of inference, used as the progress bar description
    verbose : bool
        if True, show a tqdm progress bar with the running accuracy
    device
        device on which data and targets are placed

    Returns
    -------
    (float, numpy.ndarray, numpy.ndarray)
        accuracy (fraction of correct predictions, in [0, 1])
        predicted class
        true class
    """
    y_true = []
    y_pred = []
    correct = 0
    total = 0

    if verbose:
        bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        loader = tqdm(testLoader, desc=name, bar_format=bar_format)
    else:
        loader = testLoader

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            y_true.append(target)

            spk_rec, mem_rec = model(data)

            # Sum spikes over dim 0 and take the most active output as the prediction.
            # NOTE(review): assumes spk_rec is (steps, batch, classes) and
            # data.shape[1] is the number of steps -- confirm against the model.
            output = spk_rec.sum(dim=0) / data.shape[1]
            predicted = output.argmax(dim=1)
            correct += predicted.eq(target.view_as(predicted)).sum().item()
            y_pred.append(predicted)
            total += target.size(0)

            if verbose:
                loader.set_postfix_str(f"Accuracy : {np.mean(correct/total*100):0<3.2f}")

            del data, target, spk_rec, mem_rec

    # Concatenate rather than stack: the last batch may be smaller than the others.
    y_true = torch.cat([t.reshape(-1) for t in y_true]).cpu().numpy()
    y_pred = torch.cat([p.reshape(-1) for p in y_pred]).cpu().numpy()

    return (correct/total), y_pred, y_true
def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"), verbose=False, quant=False):
    """
    Switch the model to evaluation mode and measure its accuracy.

    Parameters
    ----------
    model
        network to evaluate; moved to ``device`` and put in eval mode via
        ``model.eval(quant)``
    testLoader
        dataset loader yielding ``(data, target)`` batches
    name : str
        label shown on the progress bar
    device
        device used for inference
    verbose : bool
        if True, show a progress bar
    quant : bool
        forwarded to ``model.eval``

    Returns
    -------
    (float, numpy.ndarray, numpy.ndarray)
        accuracy
        predicted class
        true class
    """
    model = model.to(device)
    model.eval(quant)

    return __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
def hardware_estimation(model, testLoader, name="Hardware Estimation", device=torch.device("cpu"), verbose=False):
    """
    Switch the model to hardware-estimation mode and measure its accuracy.

    Parameters
    ----------
    model
        network to evaluate; moved to ``device`` and prepared with
        ``model.hardware_estimation()``
    testLoader
        dataset loader yielding ``(data, target)`` batches
    name : str
        label shown on the progress bar
    device
        device used for inference
    verbose : bool
        if True, show a progress bar

    Returns
    -------
    (float, numpy.ndarray, numpy.ndarray)
        accuracy
        predicted class
        true class
    """
    model = model.to(device)
    model.hardware_estimation()

    return __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False):
    """
    Run inference through the model's FPGA mode and measure accuracy and run time.

    Parameters
    ----------
    model
        network to evaluate; prepared with ``model.fpga_eval(board_path, driver_config)``
        and expected to return ``(spk_rec, mem_rec, batch_speed)`` when called
    testLoader
        dataset loader yielding ``(data, target)`` batches
    board_path
        forwarded to ``model.fpga_eval``
    driver_config
        forwarded to ``model.fpga_eval``
    name : str
        label shown on the progress bar
    verbose : bool
        if True, show a progress bar with running accuracy and batch run time

    Returns
    -------
    (float, numpy.ndarray, numpy.ndarray, numpy.ndarray)
        accuracy (fraction of correct predictions, in [0, 1])
        predicted class
        true class
        run times collected from every batch
    """
    # Host-side tensors always stay on the CPU for FPGA evaluation.
    device = torch.device("cpu")

    model = model.to(device)
    model.fpga_eval(board_path, driver_config)

    y_true = []
    y_pred = []
    run_time = []
    correct = 0
    total = 0

    if verbose:
        bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        loader = tqdm(testLoader, desc=name, bar_format=bar_format)
    else:
        loader = testLoader

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            y_true.append(target)

            spk_rec, mem_rec, batch_speed = model(data)
            run_time.extend(batch_speed)

            # Sum spikes over dim 0 and take the most active output as the prediction.
            # NOTE(review): assumes spk_rec is (steps, batch, classes) and
            # data.shape[1] is the number of steps -- confirm against the model.
            output = spk_rec.sum(dim=0) / data.shape[1]
            predicted = output.argmax(dim=1)
            correct += predicted.eq(target.view_as(predicted)).sum().item()
            y_pred.append(predicted)
            total += target.size(0)

            if verbose:
                loader.set_postfix_str(f"Accuracy : {np.mean(correct/total*100):0<3.2f} Run Time : {np.mean(batch_speed)*1e6:.3f} µs")

            del data, target, spk_rec, mem_rec

    # Concatenate rather than stack: the last batch may be smaller than the others.
    y_true = torch.cat([t.reshape(-1) for t in y_true]).cpu().numpy()
    y_pred = torch.cat([p.reshape(-1) for p in y_pred]).cpu().numpy()

    run_time = np.array(run_time)
    accuracy = correct / total

    return accuracy, y_pred, y_true, run_time
def conf_matrix(y_true, y_pred, file_name, classes):
    """
    Plot a row-normalised confusion matrix as a heatmap and save it.

    Parameters
    ----------
    y_true
        true class of each sample
    y_pred
        predicted class of each sample
    file_name
        destination passed to ``plt.savefig``
    classes
        iterable of labels used for both the rows and the columns
    """
    counts = confusion_matrix(y_true, y_pred)

    # Divide each row by its own total so every row sums to 1.
    row_totals = np.sum(counts, axis=1)[:, None]
    normalised = counts / row_totals

    frame = pd.DataFrame(normalised, index=list(classes), columns=list(classes))

    plt.figure()
    sns.heatmap(frame, annot=True)
    plt.title(f"Evaluation Confusion Matrix")
    plt.savefig(file_name)