Skip to content
Snippets Groups Projects
Commit 906a34d8 authored by ahoni's avatar ahoni
Browse files

add quantization-aware training algorithm

parent d08d6933
No related branches found
No related tags found
1 merge request!3Dev
...@@ -12,3 +12,4 @@ from .fixed_point_quantizer import FixedPointQuantizer ...@@ -12,3 +12,4 @@ from .fixed_point_quantizer import FixedPointQuantizer
from .min_max_quantizer import MinMaxQuantizer from .min_max_quantizer import MinMaxQuantizer
from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
from .ste_quantizer import QuantizeSTE from .ste_quantizer import QuantizeSTE
from .quantizer_scheduler import QuantizerScheduler
\ No newline at end of file
...@@ -68,9 +68,6 @@ class FixedPointQuantizer(Quantizer): ...@@ -68,9 +68,6 @@ class FixedPointQuantizer(Quantizer):
type use during conversion type use during conversion
""" """
if bitwidth==-1 and fixed_point==-1:
raise Exception("You must fix at least one value to compute the other one")
super().__init__( super().__init__(
bitwidth=bitwidth, bitwidth=bitwidth,
signed=signed, signed=signed,
......
"""
File name: quantizer_scheduler
Author: Aurélie Saulquin
Version: 0.1.0
License: GPL-3.0-or-later
Contact: aurelie.saulquin@univ-lille.fr
Dependencies: modnef.modnef_torch
Descriptions: ModNEF quantizer scheduler
"""
from modnef.modnef_torch import ModNEFNeuron
class QuantizerScheduler:
    """Step the quantizer bitwidth of a model's ModNEF neurons during training.

    The scheduler walks a descending list of bitwidths. Every ``T`` calls to
    :meth:`step` it moves to the next bitwidth and installs a fresh quantizer,
    built by ``quantizationMethod(bitwidth)``, on every ``ModNEFNeuron`` found
    in ``model.modules()``.

    Parameters
    ----------
    model
        Object exposing ``modules()``; each ``ModNEFNeuron`` it yields gets a
        new quantizer on every bitwidth change.
    bit_range
        Pair ``(first, last)``; the schedule is ``first, first-1, ..., last``
        (both ends included), so ``first`` must be >= ``last``.
    T
        Number of :meth:`step` calls spent on each bitwidth before moving on.
    quantizationMethod
        Callable taking a bitwidth and returning the quantizer to install.
    verbose
        When true (default, matching the historical behaviour) the schedule,
        the last scheduled epoch and every applied bitwidth are printed.

    Raises
    ------
    ValueError
        If ``bit_range`` produces an empty schedule.
    """

    def __init__(self, model, bit_range, T, quantizationMethod, verbose=True):
        # Descending schedule, inclusive of both bounds.
        self.num_bits = list(range(bit_range[0], bit_range[1] - 1, -1))
        if not self.num_bits:
            raise ValueError(
                f"bit_range {bit_range!r} yields an empty schedule; "
                "expected (first, last) with first >= last"
            )

        self.model = model
        self.period = T
        # Epoch at which the last bitwidth of the schedule becomes active.
        self.epoch_max = self.period * (len(self.num_bits) - 1)
        self.verbose = verbose

        if self.verbose:
            print(self.num_bits)
            print(self.epoch_max)

        self.bit_counter = 0
        self.epoch_counter = 0
        self.quantizationMethod = quantizationMethod

        # Install the quantizer for the first bitwidth right away.
        self.__update()

    @property
    def bitwidth(self):
        """Bitwidth currently applied to the model's neurons."""
        return self.num_bits[self.bit_counter]

    def __update(self):
        """Install a quantizer at the current bitwidth on every ModNEFNeuron."""
        current = self.num_bits[self.bit_counter]
        if self.verbose:
            print(current)

        for m in self.model.modules():
            if isinstance(m, ModNEFNeuron):
                m.quantizer = self.quantizationMethod(current)
                m.init_quantizer()
                m.quantize_hp()

    def step(self):
        """Advance by one epoch, switching bitwidth every ``period`` epochs.

        Once ``epoch_max`` has been passed the bitwidth no longer changes;
        the epoch counter still keeps counting.
        """
        self.epoch_counter += 1

        if self.epoch_counter > self.epoch_max:
            return

        if self.epoch_counter % self.period == 0:
            self.bit_counter += 1
            self.__update()

    def save_model(self):
        """Return True once the epoch counter has reached ``epoch_max``."""
        return self.epoch_counter >= self.epoch_max

    def save(self):
        """Alias of :meth:`save_model`, the name used by the training loop."""
        return self.save_model()
...@@ -38,7 +38,6 @@ if __name__ == "__main__": ...@@ -38,7 +38,6 @@ if __name__ == "__main__":
"""Evaluation variable definition""" """Evaluation variable definition"""
verbose = True verbose = True
save_conf_matrix = False
output_path = "." output_path = "."
"""FPGA file definition""" """FPGA file definition"""
...@@ -50,9 +49,14 @@ if __name__ == "__main__": ...@@ -50,9 +49,14 @@ if __name__ == "__main__":
acc = 0.0 acc = 0.0
y_true = None y_true = None
y_pred = None y_pred = None
run_time = None
save_conf_matrix = False save_conf_matrix = False
conf_matrix_file = "confusion_matrix.png" conf_matrix_file = "confusion_matrix.svg"
conf_matrix_classes = [str(i) for i in range(10)] num_class = 10
conf_matrix_classes = [str(i) for i in range(num_class)]
save_array = False
if kind == "eval": if kind == "eval":
...@@ -74,7 +78,7 @@ if __name__ == "__main__": ...@@ -74,7 +78,7 @@ if __name__ == "__main__":
quant=True quant=True
) )
elif kind == "feval": elif kind == "feval":
acc, y_pred, y_true = fpga_evaluation( acc, y_pred, y_true, run_time = fpga_evaluation(
model=model, model=model,
driver_config=driver_config_path, driver_config=driver_config_path,
board_path=board_path, board_path=board_path,
...@@ -93,3 +97,9 @@ if __name__ == "__main__": ...@@ -93,3 +97,9 @@ if __name__ == "__main__":
file_name=conf_matrix_file, file_name=conf_matrix_file,
classes=conf_matrix_classes classes=conf_matrix_classes
) )
if save_array:
np.save(f"{output_path}/y_true.npy", y_true)
np.save(f"{output_path}/y_pred.npy", y_pred)
if kind=="feval":
np.save(f"{output_path}/run_time.npy", run_time)
\ No newline at end of file
...@@ -141,7 +141,7 @@ class MyModel(mt.ModNEFModel): ...@@ -141,7 +141,7 @@ class MyModel(mt.ModNEFModel):
builder = ModNEFBuilder(self.name, 2312, 10) builder = ModNEFBuilder(self.name, 2312, 10)
uart = Uart_XStep( uart = Uart_XStep_Timer(
name="uart", name="uart",
input_layer_size=2312, input_layer_size=2312,
output_layer_size=10, output_layer_size=10,
......
import torch.nn as nn
from snntorch import surrogate from snntorch import surrogate
import torch import torch
spike_grad = surrogate.fast_sigmoid(slope=25) spike_grad = surrogate.fast_sigmoid(slope=25)
...@@ -9,8 +8,7 @@ import matplotlib.pyplot as plt ...@@ -9,8 +8,7 @@ import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
import seaborn as sns import seaborn as sns
def train_1_epoch(model, trainLoader, optimizer, loss, qat, device, verbose):
def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
epoch_loss = [] epoch_loss = []
if verbose: if verbose:
...@@ -20,13 +18,17 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose): ...@@ -20,13 +18,17 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
loader = trainLoader loader = trainLoader
for _, (data, target) in enumerate(loader): for _, (data, target) in enumerate(loader):
model.train() model.train(quant=qat)
"""Prepare data"""
data = data.to(device) data = data.to(device)
data = data.squeeze(0) data = data.squeeze(0)
target = target.to(device) target = target.to(device)
"""Forward Pass"""
spk_rec, mem_rec = model(data) spk_rec, mem_rec = model(data)
"""Prepare backward"""
loss_val = torch.zeros((1), dtype=torch.float, device=device) loss_val = torch.zeros((1), dtype=torch.float, device=device)
for step in range(data.shape[1]): for step in range(data.shape[1]):
...@@ -34,6 +36,7 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose): ...@@ -34,6 +36,7 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
epoch_loss.append(loss_val.item()) epoch_loss.append(loss_val.item())
"""Backward"""
model.zero_grad() model.zero_grad()
loss_val.backward() loss_val.backward()
optimizer.step() optimizer.step()
...@@ -44,34 +47,77 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose): ...@@ -44,34 +47,77 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
return np.mean(epoch_loss) return np.mean(epoch_loss)
def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("cpu"), validationLoader=None, n_epoch=10, best_model_name="best_model", verbose=True, save_plot=False, save_history=False, output_path="."): def train(model,
trainLoader,
testLoader,
optimizer,
loss,
lr_scheduler = None,
qat = False,
qat_scheduler = None,
device=torch.device("cpu"),
validationLoader=None,
n_epoch=10,
best_model_name="best_model",
verbose=True,
save_plot=False,
save_history=False,
output_path="."
):
avg_loss_history = [] avg_loss_history = []
acc_test_history = [] acc_test_history = []
acc_val_history = [] acc_val_history = []
lr_val_history = []
bitwidth_val_history = []
best_acc = 0 best_acc = 0
model = model.to(device) model = model.to(device)
if qat: # we prepare model for QAT
model.init_quantizer()
model.quantize_hp()
for epoch in range(n_epoch): for epoch in range(n_epoch):
if verbose: if verbose:
print(f"---------- Epoch : {epoch} ----------") print(f"---------- Epoch : {epoch} ----------")
epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose) """Model training"""
epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose, qat=qat)
avg_loss_history.append(epoch_loss) avg_loss_history.append(epoch_loss)
"""Model Validation"""
if validationLoader!=None: if validationLoader!=None:
acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device) acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device, quant=qat)
acc_val_history.append(acc_val) acc_val_history.append(acc_val)
acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device) """Model evaluation in test"""
acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device, quant=qat)
acc_test_history.append(acc_test) acc_test_history.append(acc_test)
"""Save best model"""
if best_model_name!="" and acc_test>best_acc: if best_model_name!="" and acc_test>best_acc:
if not qat:
torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
best_acc = acc_test
else: #if QAT
if qat_scheduler==None: # no QAT scheduler so we save our model
torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
best_acc = acc_test
elif qat_scheduler.save(): # if QAT scheduler, we need to check if we quantize at the target bitwidth
torch.save(model.state_dict(), f"{output_path}/{best_model_name}") torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
best_acc = acc_test best_acc = acc_test
"""Update schedulers"""
if lr_scheduler!=None:
lr_val_history.append(lr_scheduler.get_last_lr()[0])
lr_scheduler.step()
if qat_scheduler!=None:
bitwidth_val_history.append(qat_scheduler.bitwidth)
qat_scheduler.step()
if save_history: if save_history:
np.save(f"{output_path}/loss.npy", np.array(avg_loss_history)) np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
np.save(f"{output_path}/acc_test.npy", np.array(acc_test_history)) np.save(f"{output_path}/acc_test.npy", np.array(acc_test_history))
...@@ -79,13 +125,19 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device(" ...@@ -79,13 +125,19 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
if len(acc_val_history)!=0: if len(acc_val_history)!=0:
np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history)) np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history))
if lr_scheduler!=None:
np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler))
if qat_scheduler!=None:
np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history))
if save_plot: if save_plot:
plt.figure() # Create a new figure plt.figure() # Create a new figure
plt.plot([i for i in range(n_epoch)], avg_loss_history) plt.plot([i for i in range(n_epoch)], avg_loss_history)
plt.title('Average Loss') plt.title('Average Loss')
plt.xlabel("Epoch") plt.xlabel("Epoch")
plt.ylabel("Loss") plt.ylabel("Loss")
plt.savefig(f"{output_path}/loss.png") plt.savefig(f"{output_path}/loss.svg")
plt.figure() plt.figure()
if len(acc_val_history)!=0: if len(acc_val_history)!=0:
...@@ -96,7 +148,25 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device(" ...@@ -96,7 +148,25 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
plt.title("Accuracy") plt.title("Accuracy")
plt.xlabel("Epoch") plt.xlabel("Epoch")
plt.ylabel("Accuracy") plt.ylabel("Accuracy")
plt.savefig(f"{output_path}/accuracy.png") plt.savefig(f"{output_path}/accuracy.svg")
if lr_scheduler!=None:
plt.figure()
plt.plot([i for i in range(n_epoch)], lr_val_history, label="LR Value")
plt.legend()
plt.title("LR Values")
plt.xlabel("Epoch")
plt.ylabel("learning rate")
plt.savefig(f"{output_path}/lr_values.svg")
if qat_scheduler!=None:
plt.figure()
plt.plot([i for i in range(n_epoch)], lr_val_history, label="bitwidth")
plt.legend()
plt.title("Quantizer bitwidth")
plt.xlabel("Epoch")
plt.ylabel("bitwidth")
plt.savefig(f"{output_path}/quant_bitwidth.svg")
return avg_loss_history, acc_val_history, acc_test_history, best_acc return avg_loss_history, acc_val_history, acc_test_history, best_acc
...@@ -154,8 +224,8 @@ def __run_accuracy(model, testLoader, name, verbose, device): ...@@ -154,8 +224,8 @@ def __run_accuracy(model, testLoader, name, verbose, device):
del spk_rec del spk_rec
del mem_rec del mem_rec
y_true = torch.stack(y_true).reshape(-1) y_true = torch.stack(y_true).reshape(-1).cpu().numpy()
y_pred = torch.stack(y_pred).reshape(-1) y_pred = torch.stack(y_pred).reshape(-1).cpu().numpy()
return (correct/total), y_pred, y_true return (correct/total), y_pred, y_true
...@@ -169,9 +239,6 @@ def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"), ...@@ -169,9 +239,6 @@ def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"),
model.eval(quant) model.eval(quant)
if quant:
model.quantize(force_init=True)
accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device) accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
return accuracy, y_pred, y_true return accuracy, y_pred, y_true
...@@ -185,13 +252,10 @@ def hardware_estimation(model, testLoader, name="Hardware Estimation", device=to ...@@ -185,13 +252,10 @@ def hardware_estimation(model, testLoader, name="Hardware Estimation", device=to
model.hardware_estimation() model.hardware_estimation()
model.quantize(force_init=True)
accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device) accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
return accuracy, y_pred, y_true return accuracy, y_pred, y_true
def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False): def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False):
accuracy = 0 accuracy = 0
y_pred = [] y_pred = []
...@@ -203,9 +267,51 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva ...@@ -203,9 +267,51 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva
model.fpga_eval(board_path, driver_config) model.fpga_eval(board_path, driver_config)
accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device) y_true = []
y_pred = []
run_time = []
correct = 0
total = 0
return accuracy, y_pred, y_true if verbose:
bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
loader = tqdm(testLoader, desc=name, bar_format=bar_format)
else:
loader = testLoader
for _, (data, target) in enumerate(loader):
data = data.to(device)
target = target.to(device)
y_true.append(target)
spk_rec, batch_speed = model(data)
run_time.extend(batch_speed)
output = (spk_rec.sum(dim=0))/data.shape[1]
predicted = output.argmax(dim=1).to(device)
correct += predicted.eq(target.view_as(predicted)).sum().item()
y_pred.append(predicted)
total += target.size(0)
if verbose:
loader.set_postfix_str(f"Accuracy : {np.mean(correct/total*100):0<3.2f} Run Time : {np.mean(batch_speed)*1e6:.3f} µs")
del data
del target
del spk_rec
y_true = torch.stack(y_true).reshape(-1)
y_pred = torch.stack(y_pred).reshape(-1)
run_time = np.array(run_time)
accuracy = (correct/total)
return accuracy, y_pred, y_true, run_time
def conf_matrix(y_true, y_pred, file_name, classes): def conf_matrix(y_true, y_pred, file_name, classes):
cm = confusion_matrix(y_true, y_pred) cm = confusion_matrix(y_true, y_pred)
......
import tonic
from torch.utils.data import DataLoader
from modnef.modnef_torch import ModNEFModelBuilder
import os
from snntorch.surrogate import fast_sigmoid from snntorch.surrogate import fast_sigmoid
from run_lib import * from run_lib import *
import torch import torch
from model import MyModel from model import MyModel
from dataset import * from dataset import *
from modnef.quantizer import QuantizerScheduler
if __name__ == "__main__": if __name__ == "__main__":
...@@ -18,6 +15,7 @@ if __name__ == "__main__": ...@@ -18,6 +15,7 @@ if __name__ == "__main__":
"""Optimizer""" """Optimizer"""
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999)) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
lr_scheduler = None #torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=5e-4)
"""Loss""" """Loss"""
loss = torch.nn.CrossEntropyLoss() loss = torch.nn.CrossEntropyLoss()
...@@ -27,12 +25,17 @@ if __name__ == "__main__": ...@@ -27,12 +25,17 @@ if __name__ == "__main__":
"""Train variable definition""" """Train variable definition"""
n_epoch = 2 n_epoch = 2
qat = False
best_model_name = "best_model" best_model_name = "best_model"
verbose = True verbose = True
save_plot = False save_plot = False
save_history = False save_history = False
output_path = "." output_path = "."
"""Quantization Aware Training"""
qat=False
qat_scheduler = None #QuantizerScheduler(model, (8,3), 3, lambda x : FixedPointQuantizer(x, x-1, True, True))
train( train(
model=model, model=model,
trainLoader=trainLoader, trainLoader=trainLoader,
...@@ -40,6 +43,9 @@ if __name__ == "__main__": ...@@ -40,6 +43,9 @@ if __name__ == "__main__":
validationLoader=validationLoader, validationLoader=validationLoader,
optimizer=optimizer, optimizer=optimizer,
loss=loss, loss=loss,
lr_scheduler = lr_scheduler,
qat = qat,
qat_scheduler = qat_scheduler,
device=device, device=device,
n_epoch=n_epoch, n_epoch=n_epoch,
best_model_name=best_model_name, best_model_name=best_model_name,
......
...@@ -36,6 +36,9 @@ if __name__ == "__main__": ...@@ -36,6 +36,9 @@ if __name__ == "__main__":
file_name = "template_vhdl_model.vhd" file_name = "template_vhdl_model.vhd"
driver_config_path = "driver_config.yml" driver_config_path = "driver_config.yml"
"""Prepare model for hardware estimation"""
model.quantize(force_init=True, clamp=True)
acc, y_pred, y_true = hardware_estimation( acc, y_pred, y_true = hardware_estimation(
model=model, model=model,
testLoader=testLoader, testLoader=testLoader,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment