    # run_lib.py
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    import torch
    from sklearn.metrics import confusion_matrix
    from snntorch import surrogate
    from tqdm import tqdm

    # Surrogate gradient used for backpropagation through spikes
    spike_grad = surrogate.fast_sigmoid(slope=25)
    
    def train_1_epoch(model, trainLoader, optimizer, loss, qat, device, verbose):
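      """Train the model for one epoch, accumulating the loss over all time steps; returns the mean batch loss."""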
      epoch_loss = []
    
      if verbose:
        bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        loader = tqdm(trainLoader, desc="Train", bar_format=bar_format)
      else:
        loader = trainLoader
    
      for data, target in loader:
        model.train(quant=qat)

        # Prepare data
        data = data.to(device)
        data = data.squeeze(0)
        target = target.to(device)
        
        """Forward Pass"""
        spk_rec, mem_rec = model(data)
    
        """Prepare backward"""
        loss_val = torch.zeros((1), dtype=torch.float, device=device)
    
        for step in range(data.shape[1]):
          loss_val += loss(mem_rec[step], target)
    
        epoch_loss.append(loss_val.item())
    
        """Backward"""
        model.zero_grad()
        loss_val.backward()
        optimizer.step()
    
        if verbose:
          loader.set_postfix_str(f"Loss : {np.mean(epoch_loss):.2f}")
    
    
      return np.mean(epoch_loss)
    
    def train(model,
              trainLoader,
              testLoader,
              optimizer,
              loss,
              lr_scheduler=None,
              qat=False,
              qat_scheduler=None,
              device=torch.device("cpu"),
              validationLoader=None,
              n_epoch=10,
              best_model_name="best_model",
              verbose=True,
              save_plot=False,
              save_history=False,
              output_path="."
              ):
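      """Full training loop: optional validation, best-model checkpointing,
      quantization-aware training support, and optional export of history
      arrays and plots. Returns (avg_loss_history, acc_val_history,
      acc_test_history, best_acc)."""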
      
      avg_loss_history = []
      acc_test_history = []
      acc_val_history = []
      lr_val_history = []
      bitwidth_val_history = []
    
      best_acc = 0
    
      model = model.to(device)
    
      if qat: # prepare the model for quantization-aware training
        model.init_quantizer()
        model.quantize_hp()
    
      for epoch in range(n_epoch):
        if verbose:
          print(f"---------- Epoch : {epoch} ----------")
        
        """Model training"""
        epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose, qat=qat)
        avg_loss_history.append(epoch_loss)
    
        """Model Validation"""
        if validationLoader!=None:
          acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device, quant=qat)
          acc_val_history.append(acc_val)
    
        """Model evaluation in test"""
        acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device, quant=qat)
        acc_test_history.append(acc_test)
    
        """Save best model"""
        if best_model_name!="" and acc_test>best_acc: 
          if not qat:
            torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
            best_acc = acc_test
          else: #if QAT
            if qat_scheduler==None: # no QAT scheduler so we save our model
              torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
              best_acc = acc_test
            elif qat_scheduler.save_model(): # if QAT scheduler, we need to check if we quantize at the target bitwidth
              torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
              best_acc = acc_test
    
        """Update schedulers"""
        if lr_scheduler!=None:
          lr_val_history.append(lr_scheduler.get_last_lr()[0])
          lr_scheduler.step()
    
        if qat:
          model.clamp() 
          if qat_scheduler is not None:
            bitwidth_val_history.append(qat_scheduler.bitwidth)
            qat_scheduler.step()
    
      if save_history:
        np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
        np.save(f"{output_path}/acc_test.npy", np.array(acc_test_history))
        
        if len(acc_val_history) != 0:
          np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history))

        if lr_scheduler is not None:
          np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_val_history))

        if qat and qat_scheduler is not None:
          np.save(f"{output_path}/qat_scheduler_bitwidth.npy", np.array(bitwidth_val_history))
    
      if save_plot:
        epochs = range(n_epoch)

        plt.figure()  # Create a new figure
        plt.plot(epochs, avg_loss_history)
        plt.title("Average Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(f"{output_path}/loss.svg")

        plt.figure()
        if len(acc_val_history) != 0:
          plt.plot(epochs, acc_val_history, label="Validation")
        plt.plot(epochs, acc_test_history, label="Test")
        if best_acc in acc_test_history: # best_acc stays 0 when no checkpoint was saved
          plt.scatter([acc_test_history.index(best_acc)], [best_acc], label="Best Accuracy")
        plt.legend()
        plt.title("Accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.savefig(f"{output_path}/accuracy.svg")

        if lr_scheduler is not None:
          plt.figure()
          plt.plot(epochs, lr_val_history, label="LR Value")
          plt.legend()
          plt.title("LR Values")
          plt.xlabel("Epoch")
          plt.ylabel("learning rate")
          plt.savefig(f"{output_path}/lr_values.svg")

        if qat and qat_scheduler is not None:
          plt.figure()
          plt.plot(epochs, bitwidth_val_history, label="bitwidth")
          plt.legend()
          plt.title("Quantizer bitwidth")
          plt.xlabel("Epoch")
          plt.ylabel("bitwidth")
          plt.savefig(f"{output_path}/quant_bitwidth.svg")
    
      return avg_loss_history, acc_val_history, acc_test_history, best_acc
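
    # Usage sketch (assumptions: `net` is a model exposing the custom hooks used
    # above -- train(quant=...), eval(quant), clamp(), init_quantizer(),
    # quantize_hp() -- and the loaders yield (data, target) batches whose second
    # dimension is time):
    #
    #   optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    #   loss_fn = torch.nn.CrossEntropyLoss()
    #   train(net, trainLoader, testLoader, optimizer, loss_fn,
    #         n_epoch=20, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    #         save_plot=True, save_history=True, output_path="results")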
    
    def __run_accuracy(model, testLoader, name, verbose, device):
        """
        Run inference
    
        Parameters
        ----------
        testLoader
          test dataset loader
        name : str
          name of inference
    
        Returns
        -------
        (float, list, list)
          accuracy
          predicted class
          true class
        """
    
        y_true = []
        y_pred = []
        correct = 0
        total = 0
    
        if verbose:
          bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
          loader = tqdm(testLoader, desc=name, bar_format=bar_format)
        else:
          loader = testLoader
    
    
        for data, target in loader:
          
          data = data.to(device)
          target = target.to(device)
    
          y_true.append(target)
    
          spk_rec, mem_rec = model(data)
    
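          # Rate decoding: average spike count per time step; the class with the highest rate wins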
          output = (spk_rec.sum(dim=0))/data.shape[1]
          predicted = output.argmax(dim=1).to(device)
          correct += predicted.eq(target.view_as(predicted)).sum().item()
          y_pred.append(predicted)
          total += target.size(0)
    
          if verbose:
            loader.set_postfix_str(f"Accuracy : {correct/total*100:.2f}")
    
          del data
          del target
          del spk_rec
          del mem_rec
    
        # torch.cat tolerates a smaller final batch (torch.stack requires equal sizes)
        y_true = torch.cat(y_true).reshape(-1).cpu().numpy()
        y_pred = torch.cat(y_pred).reshape(-1).cpu().numpy()

        return (correct/total), y_pred, y_true
    
    def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"), verbose=False, quant=False):
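      """Evaluate the model on testLoader; returns (accuracy, y_pred, y_true)."""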
    
      model = model.to(device)
    
      model.eval(quant)  # model-specific eval hook; quant selects quantized inference
    
      accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
    
      return accuracy, y_pred, y_true   
    
    def hardware_estimation(model, testLoader, name="Hardware Estimation", device=torch.device("cpu"), verbose=False):
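      """Evaluate the model in hardware-estimation mode; returns (accuracy, y_pred, y_true)."""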
    
      model = model.to(device)
    
      model.hardware_estimation()
    
      accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
    
      return accuracy, y_pred, y_true
    
    def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False):
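      """Run inference through the FPGA board; returns (accuracy, y_pred, y_true, run_time) with the timing values reported by the model."""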
    
      device = torch.device("cpu")
    
      model = model.to(device)
    
      model.fpga_eval(board_path, driver_config)
    
      y_true = []
      y_pred = []
      run_time = []
      correct = 0
      total = 0
    
      if verbose:
        bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
        loader = tqdm(testLoader, desc=name, bar_format=bar_format)
      else:
        loader = testLoader
    
    
      for data, target in loader:
        
        data = data.to(device)
        target = target.to(device)
    
        y_true.append(target)
    
        spk_rec, mem_rec, batch_speed = model(data)
    
        run_time.extend(batch_speed)
    
        output = (spk_rec.sum(dim=0))/data.shape[1]
        predicted = output.argmax(dim=1).to(device)
        correct += predicted.eq(target.view_as(predicted)).sum().item()
        y_pred.append(predicted)
        total += target.size(0)
    
        if verbose:
          loader.set_postfix_str(f"Accuracy : {correct/total*100:.2f} Run Time : {np.mean(batch_speed)*1e6:.3f} µs")
    
        del data
        del target
        del spk_rec
    
      y_true = torch.cat(y_true).cpu().reshape(-1).numpy()
      y_pred = torch.cat(y_pred).cpu().reshape(-1).numpy()
      run_time = np.array(run_time)

      accuracy = (correct/total)
    
      return accuracy, y_pred, y_true, run_time
    
    def conf_matrix(y_true, y_pred, file_name, classes):
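      """Plot a row-normalized confusion matrix and save it to file_name."""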
      cm = confusion_matrix(y_true, y_pred)
      # Normalize each row by the number of true samples of that class
      df_cm = pd.DataFrame(cm / np.sum(cm, axis=1)[:, None], index=list(classes), columns=list(classes))
      plt.figure()
      sns.heatmap(df_cm, annot=True)
      plt.title("Evaluation Confusion Matrix")
      plt.savefig(file_name)