diff --git a/modneflib/modnef/quantizer/__init__.py b/modneflib/modnef/quantizer/__init__.py
index 2ef9cb9816e32de17524204fe25486ea81790efa..7ca95c0dde63c9d819a40ee3cd54da799384ebe1 100644
--- a/modneflib/modnef/quantizer/__init__.py
+++ b/modneflib/modnef/quantizer/__init__.py
@@ -11,4 +11,5 @@ from .quantizer import Quantizer
 from .fixed_point_quantizer import FixedPointQuantizer
 from .min_max_quantizer import MinMaxQuantizer
 from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer
-from .ste_quantizer import QuantizeSTE
\ No newline at end of file
+from .ste_quantizer import QuantizeSTE
+from .quantizer_scheduler import QuantizerScheduler
\ No newline at end of file
diff --git a/modneflib/modnef/quantizer/fixed_point_quantizer.py b/modneflib/modnef/quantizer/fixed_point_quantizer.py
index 106a242180b8b83b5d47ebe1dbf0e4f7604f63cd..16b08e53490ba071295f0c960eac0fc53c44335d 100644
--- a/modneflib/modnef/quantizer/fixed_point_quantizer.py
+++ b/modneflib/modnef/quantizer/fixed_point_quantizer.py
@@ -67,9 +67,6 @@ class FixedPointQuantizer(Quantizer):
     dtype = torch.int32 : torch.dtype
       type use during conversion
     """
-
-    if bitwidth==-1 and fixed_point==-1:
-      raise Exception("You must fix at least one value to compute the other one")
     
     super().__init__(
       bitwidth=bitwidth,
@@ -145,7 +142,7 @@ class FixedPointQuantizer(Quantizer):
       elif self.fixed_point==-1:
         self.fixed_point = self.bitwidth-int_part_size
         self.scale_factor = 2**self.fixed_point
-
+  
 
   def _quant(self, data) -> torch.Tensor:
     """
diff --git a/modneflib/modnef/quantizer/quantizer_scheduler.py b/modneflib/modnef/quantizer/quantizer_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..850dd9b00f66033098fa7869693fd9f38c2e8eb7
--- /dev/null
+++ b/modneflib/modnef/quantizer/quantizer_scheduler.py
@@ -0,0 +1,56 @@
+"""
+File name: quantizer_scheduler
+Author: Aurélie Saulquin  
+Version: 0.1.0
+License: GPL-3.0-or-later
+Contact: aurelie.saulquin@univ-lille.fr
+Dependencies: modnef.modnef_torch
+Descriptions: ModNEF quantizer scheduler
+"""
+
+from modnef.modnef_torch import ModNEFNeuron
+
+class QuantizerScheduler():
+
+  def __init__(self, model, bit_range, T, quantizationMethod):
+
+    self.num_bits = [i for i in range(bit_range[0], bit_range[1]-1, -1)]
+
+    self.model = model
+
+    self.period = T
+
+    self.epoch_max = self.period*(len(self.num_bits)-1)
+
+    print(self.num_bits)
+    print(self.epoch_max)
+
+    self.bit_counter = 0
+
+    self.quantizationMethod = quantizationMethod
+
+    self.__update()
+
+    self.epoch_counter = 0
+
+  def __update(self):
+    print(self.num_bits[self.bit_counter])
+    for m in self.model.modules():
+      if isinstance(m, ModNEFNeuron):
+        m.quantizer = self.quantizationMethod(self.num_bits[self.bit_counter])
+        m.init_quantizer()
+        m.quantize_hp()
+
+  def step(self):
+
+    self.epoch_counter += 1
+
+    if self.epoch_counter > self.epoch_max:
+      return
+    else:
+      if self.epoch_counter%self.period==0:
+        self.bit_counter += 1
+        self.__update()
+
+  def save_model(self):
+    return self.epoch_counter >= self.epoch_max
diff --git a/modneflib/modnef/templates/evaluation.py b/modneflib/modnef/templates/evaluation.py
index 91cba66dd98812a08f9cdbe433d2c6db85887f6e..1c2898bc4458dc387d7c87db15c0294fe94978f6 100644
--- a/modneflib/modnef/templates/evaluation.py
+++ b/modneflib/modnef/templates/evaluation.py
@@ -38,7 +38,6 @@ if __name__ == "__main__":
 
   """Evaluation variable definition"""
   verbose = True
-  save_conf_matrix = False
   output_path = "."
 
   """FPGA file definition"""
@@ -50,9 +49,14 @@ if __name__ == "__main__":
   acc = 0.0
   y_true = None
   y_pred = None
+  run_time = None
+  
   save_conf_matrix = False
-  conf_matrix_file = "confusion_matrix.png"
-  conf_matrix_classes = [str(i) for i in range(10)]
+  conf_matrix_file = "confusion_matrix.svg"
+  num_class = 10
+  conf_matrix_classes = [str(i) for i in range(num_class)]
+
+  save_array = False
   
   
   if kind == "eval":
@@ -74,7 +78,7 @@ if __name__ == "__main__":
       quant=True
       )
   elif kind == "feval":
-    acc, y_pred, y_true = fpga_evaluation(
+    acc, y_pred, y_true, run_time = fpga_evaluation(
       model=model, 
       driver_config=driver_config_path,
       board_path=board_path,
@@ -92,4 +96,10 @@ if __name__ == "__main__":
       y_pred=y_pred,
       file_name=conf_matrix_file,
       classes=conf_matrix_classes
-      )
\ No newline at end of file
+      )
+
+  if save_array:
+    np.save(f"{output_path}/y_true.npy", y_true)
+    np.save(f"{output_path}/y_pred.npy", y_pred)
+    if kind=="feval":
+      np.save(f"{output_path}/run_time.npy", run_time)
\ No newline at end of file
diff --git a/modneflib/modnef/templates/model.py b/modneflib/modnef/templates/model.py
index ac2fb76d6257efaf126d5552ecfb0cd259bc83fc..b8bbb3ceb6bb9a4790ab37cf5f6223f4cbd5aa07 100644
--- a/modneflib/modnef/templates/model.py
+++ b/modneflib/modnef/templates/model.py
@@ -141,7 +141,7 @@ class MyModel(mt.ModNEFModel):
     builder = ModNEFBuilder(self.name, 2312, 10)
 
 
-    uart = Uart_XStep(
+    uart = Uart_XStep_Timer(
       name="uart",
       input_layer_size=2312,
       output_layer_size=10,
diff --git a/modneflib/modnef/templates/run_lib.py b/modneflib/modnef/templates/run_lib.py
index 85a13529268e8ad7a81c86f0c78467adda158dbd..73dd3521012b282bf9c2f44f86b7dc6b41d65ba4 100644
--- a/modneflib/modnef/templates/run_lib.py
+++ b/modneflib/modnef/templates/run_lib.py
@@ -1,4 +1,3 @@
-import torch.nn as nn
 from snntorch import surrogate
 import torch
 spike_grad = surrogate.fast_sigmoid(slope=25)
@@ -9,8 +8,7 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import confusion_matrix
 import seaborn as sns
 
-
-def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
+def train_1_epoch(model, trainLoader, optimizer, loss, qat, device, verbose):
   epoch_loss = []
 
   if verbose:
@@ -20,13 +18,17 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
     loader = trainLoader
 
   for _, (data, target) in enumerate(loader):
-    model.train()
+    model.train(quant=qat)
+
+    """Prepare data"""
     data = data.to(device)
     data = data.squeeze(0)
     target = target.to(device)
     
+    """Forward Pass"""
     spk_rec, mem_rec = model(data)
 
+    """Prepare backward"""
     loss_val = torch.zeros((1), dtype=torch.float, device=device)
 
     for step in range(data.shape[1]):
@@ -34,6 +36,7 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
 
     epoch_loss.append(loss_val.item())
 
+    """Backward"""
     model.zero_grad()
     loss_val.backward()
     optimizer.step()
@@ -44,33 +47,76 @@ def train_1_epoch(model, trainLoader, optimizer, loss, device, verbose):
 
   return np.mean(epoch_loss)
 
-def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("cpu"), validationLoader=None, n_epoch=10, best_model_name="best_model", verbose=True, save_plot=False, save_history=False, output_path="."):
+def train(model, 
+          trainLoader, 
+          testLoader, 
+          optimizer, 
+          loss, 
+          lr_scheduler = None,
+          qat = False, 
+          qat_scheduler = None,
+          device=torch.device("cpu"), 
+          validationLoader=None, 
+          n_epoch=10, 
+          best_model_name="best_model", 
+          verbose=True, 
+          save_plot=False, 
+          save_history=False, 
+          output_path="."
+          ):
+  
   avg_loss_history = []
   acc_test_history = []
   acc_val_history = []
-
+  lr_val_history = []
+  bitwidth_val_history = []
 
   best_acc = 0
 
   model = model.to(device)
 
+  if qat: # we prepare model for QAT
+    model.init_quantizer()
+    model.quantize_hp()
+
   for epoch in range(n_epoch):
     if verbose:
       print(f"---------- Epoch : {epoch} ----------")
     
-    epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose)
+    """Model training"""
+    epoch_loss = train_1_epoch(model=model, trainLoader=trainLoader, optimizer=optimizer, loss=loss, device=device, verbose=verbose, qat=qat)
     avg_loss_history.append(epoch_loss)
 
+    """Model Validation"""
     if validationLoader!=None:
-      acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device)
+      acc_val, _, _ = evaluation(model=model, testLoader=validationLoader, name="Validation", verbose=verbose, device=device, quant=qat)
       acc_val_history.append(acc_val)
 
-    acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device)
+    """Model evaluation in test"""
+    acc_test, _, _ = evaluation(model=model, testLoader=testLoader, name="Test", verbose=verbose, device=device, quant=qat)
     acc_test_history.append(acc_test)
 
-    if best_model_name!="" and acc_test>best_acc:
-      torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
-      best_acc = acc_test
+    """Save best model"""
+    if best_model_name!="" and acc_test>best_acc: 
+      if not qat:
+        torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+        best_acc = acc_test
+      else: #if QAT
+        if qat_scheduler==None: # no QAT scheduler so we save our model
+          torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+          best_acc = acc_test
+        elif qat_scheduler.save(): # if QAT scheduler, we need to check if we quantize at the target bitwidth
+          torch.save(model.state_dict(), f"{output_path}/{best_model_name}")
+          best_acc = acc_test
+
+    """Update schedulers"""
+    if lr_scheduler!=None:
+      lr_val_history.append(lr_scheduler.get_last_lr()[0])
+      lr_scheduler.step()
+
+    if qat_scheduler!=None:
+      bitwidth_val_history.append(qat_scheduler.bitwidth)
+      qat_scheduler.step()
 
   if save_history:
     np.save(f"{output_path}/loss.npy", np.array(avg_loss_history))
@@ -79,13 +125,19 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
     if len(acc_val_history)!=0:
       np.save(f"{output_path}/acc_validation.npy", np.array(acc_val_history))
 
+    if lr_scheduler!=None:
+      np.save(f"{output_path}/lr_scheduler.npy", np.array(lr_scheduler))
+
+    if qat_scheduler!=None:
+      np.save(f"{output_path}/qat_scheudler_bitwidth.npy", np.array(bitwidth_val_history))
+
   if save_plot:
     plt.figure()  # Create a new figure
     plt.plot([i for i in range(n_epoch)], avg_loss_history)
     plt.title('Average Loss')
     plt.xlabel("Epoch")
     plt.ylabel("Loss")
-    plt.savefig(f"{output_path}/loss.png")
+    plt.savefig(f"{output_path}/loss.svg")
     
     plt.figure()
     if len(acc_val_history)!=0:
@@ -96,7 +148,25 @@ def train(model, trainLoader, testLoader, optimizer, loss, device=torch.device("
     plt.title("Accuracy")
     plt.xlabel("Epoch")
     plt.ylabel("Accuracy")
-    plt.savefig(f"{output_path}/accuracy.png")
+    plt.savefig(f"{output_path}/accuracy.svg")
+
+    if lr_scheduler!=None:
+      plt.figure()
+      plt.plot([i for i in range(n_epoch)], lr_val_history, label="LR Value")
+      plt.legend()
+      plt.title("LR Values")
+      plt.xlabel("Epoch")
+      plt.ylabel("learning rate")
+      plt.savefig(f"{output_path}/lr_values.svg")
+
+    if qat_scheduler!=None:
+      plt.figure()
+      plt.plot([i for i in range(n_epoch)], lr_val_history, label="bitwidth")
+      plt.legend()
+      plt.title("Quantizer bitwidth")
+      plt.xlabel("Epoch")
+      plt.ylabel("bitwidth")
+      plt.savefig(f"{output_path}/quant_bitwidth.svg")
 
   return avg_loss_history, acc_val_history, acc_test_history, best_acc
 
@@ -154,8 +224,8 @@ def __run_accuracy(model, testLoader, name, verbose, device):
       del spk_rec
       del mem_rec
 
-    y_true = torch.stack(y_true).reshape(-1)
-    y_pred = torch.stack(y_pred).reshape(-1)
+    y_true = torch.stack(y_true).reshape(-1).cpu().numpy()
+    y_pred = torch.stack(y_pred).reshape(-1).cpu().numpy()
     
       
     return (correct/total), y_pred, y_true
@@ -169,9 +239,6 @@ def evaluation(model, testLoader, name="Evaluation", device=torch.device("cpu"),
 
   model.eval(quant)
 
-  if quant:
-    model.quantize(force_init=True)
-
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true   
@@ -185,13 +252,10 @@ def hardware_estimation(model, testLoader, name="Hardware Estimation", device=to
 
   model.hardware_estimation()
 
-  model.quantize(force_init=True)
-
   accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
 
   return accuracy, y_pred, y_true
 
-
 def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Evaluation", verbose=False):
   accuracy = 0
   y_pred = []
@@ -203,9 +267,51 @@ def fpga_evaluation(model, testLoader, board_path, driver_config, name="FPGA Eva
 
   model.fpga_eval(board_path, driver_config)
 
-  accuracy, y_pred, y_true = __run_accuracy(model=model, testLoader=testLoader, name=name, verbose=verbose, device=device)
+  y_true = []
+  y_pred = []
+  run_time = []
+  correct = 0
+  total = 0
 
-  return accuracy, y_pred, y_true   
+  if verbose:
+    bar_format = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
+    loader = tqdm(testLoader, desc=name, bar_format=bar_format)
+  else:
+    loader = testLoader
+
+
+  for _, (data, target) in enumerate(loader):
+    
+    data = data.to(device)
+    target = target.to(device)
+
+    y_true.append(target)
+
+    spk_rec, batch_speed = model(data)
+
+    run_time.extend(batch_speed)
+
+    output = (spk_rec.sum(dim=0))/data.shape[1]
+    predicted = output.argmax(dim=1).to(device)
+    correct += predicted.eq(target.view_as(predicted)).sum().item()
+    y_pred.append(predicted)
+    total += target.size(0)
+
+    if verbose:
+      loader.set_postfix_str(f"Accuracy : {np.mean(correct/total*100):0<3.2f} Run Time : {np.mean(batch_speed)*1e6:.3f} µs")
+
+    del data
+    del target
+    del spk_rec
+
+  y_true = torch.stack(y_true).reshape(-1)
+  y_pred = torch.stack(y_pred).reshape(-1)
+  run_time = np.array(run_time)
+  
+    
+  accuracy = (correct/total)
+
+  return accuracy, y_pred, y_true, run_time
 
 def conf_matrix(y_true, y_pred, file_name, classes):
   cm = confusion_matrix(y_true, y_pred)
diff --git a/modneflib/modnef/templates/train.py b/modneflib/modnef/templates/train.py
index 0f481251196a869236f019f777c3ce83f9e01f5e..5d4866c4ec4e7faf5724bd760bdf042b3d2ebd2a 100644
--- a/modneflib/modnef/templates/train.py
+++ b/modneflib/modnef/templates/train.py
@@ -1,12 +1,9 @@
-import tonic
-from torch.utils.data import DataLoader
-from modnef.modnef_torch import ModNEFModelBuilder
-import os
 from snntorch.surrogate import fast_sigmoid
 from run_lib import *
 import torch
 from model import MyModel
 from dataset import *
+from modnef.quantizer import QuantizerScheduler
 
 if __name__ == "__main__":
 
@@ -18,6 +15,7 @@ if __name__ == "__main__":
 
   """Optimizer"""
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
+  lr_scheduler = None #torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=5e-4)
 
   """Loss"""
   loss = torch.nn.CrossEntropyLoss()
@@ -27,12 +25,17 @@ if __name__ == "__main__":
 
   """Train variable definition"""
   n_epoch = 2
+  qat = False
   best_model_name = "best_model"
   verbose = True
   save_plot = False
   save_history = False
   output_path = "."
 
+  """Quantization Aware Training"""
+  qat=False
+  qat_scheduler = None #QuantizerScheduler(model, (8,3), 3, lambda x : FixedPointQuantizer(x, x-1, True, True))
+
   train(
     model=model, 
     trainLoader=trainLoader, 
@@ -40,6 +43,9 @@ if __name__ == "__main__":
     validationLoader=validationLoader,
     optimizer=optimizer, 
     loss=loss, 
+    lr_scheduler = lr_scheduler,
+    qat = qat, 
+    qat_scheduler = qat_scheduler,
     device=device,  
     n_epoch=n_epoch, 
     best_model_name=best_model_name, 
diff --git a/modneflib/modnef/templates/vhdl_generation.py b/modneflib/modnef/templates/vhdl_generation.py
index 563d1a1fb7de009c76d06f4efdd45217cd0c65f7..7765dae21c343bb89607ce680b0958e24c91de71 100644
--- a/modneflib/modnef/templates/vhdl_generation.py
+++ b/modneflib/modnef/templates/vhdl_generation.py
@@ -36,6 +36,9 @@ if __name__ == "__main__":
   file_name = "template_vhdl_model.vhd"
   driver_config_path = "driver_config.yml"
 
+  """Prepare model for hardware estimation"""
+  model.quantize(force_init=True, clamp=True)
+
   acc, y_pred, y_true = hardware_estimation(
     model=model, 
     testLoader=testLoader,