Skip to content
Snippets Groups Projects
Commit f5298795 authored by Aurélie saulquin's avatar Aurélie saulquin
Browse files

add qat test

parent dbfdd4f2
Branches
No related tags found
1 merge request!3Dev
...@@ -101,6 +101,21 @@ class ModNEFModel(nn.Module): ...@@ -101,6 +101,21 @@ class ModNEFModel(nn.Module):
if isinstance(m, ModNEFNeuron): if isinstance(m, ModNEFNeuron):
m.init_quantizer() m.init_quantizer()
def quantize_weight(self, force_init=False):
"""
Quantize synaptic weight
Parameters
----------
force_init = False : bool
force quantizer initialization
"""
for m in self.modules():
if isinstance(m, ModNEFNeuron):
m.init_quantizer()
m.quantize_weight()
def quantize(self, force_init=False): def quantize(self, force_init=False):
""" """
Quantize synaptic weight and neuron hyper-parameters Quantize synaptic weight and neuron hyper-parameters
......
...@@ -232,6 +232,7 @@ class BLIF(ModNEFNeuron): ...@@ -232,6 +232,7 @@ class BLIF(ModNEFNeuron):
input_.data = self.quantizer(input_.data, True) input_.data = self.quantizer(input_.data, True)
self.mem.data = self.quantizer(self.mem.data, True) self.mem.data = self.quantizer(self.mem.data, True)
self.reset = self.mem_reset(self.mem) self.reset = self.mem_reset(self.mem)
if self.reset_mechanism == "subtract": if self.reset_mechanism == "subtract":
...@@ -241,9 +242,6 @@ class BLIF(ModNEFNeuron): ...@@ -241,9 +242,6 @@ class BLIF(ModNEFNeuron):
else: else:
self.mem = self.mem*self.beta self.mem = self.mem*self.beta
if self.quantization_flag:
self.mem.data = self.quantizer(self.mem.data, True)
if self.hardware_estimation_flag: if self.hardware_estimation_flag:
self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
......
...@@ -258,15 +258,15 @@ class RBLIF(ModNEFNeuron): ...@@ -258,15 +258,15 @@ class RBLIF(ModNEFNeuron):
rec = self.reccurent(self.spk) rec = self.reccurent(self.spk)
# if self.quantization_flag:
# self.mem.data = self.quantizer(self.mem.data, True)
# input_.data = self.quantizer(input_.data, True)
# rec.data = self.quantizer(rec.data, True)
if self.quantization_flag: if self.quantization_flag:
self.mem = QuantizeMembrane.apply(self.mem, self.quantizer) self.mem.data = self.quantizer(self.mem.data, True)
input_ = QuantizeMembrane.apply(input_, self.quantizer) input_.data = self.quantizer(input_.data, True)
rec = QuantizeMembrane.apply(rec, self.quantizer) rec.data = self.quantizer(rec.data, True)
# if self.quantization_flag:
# self.mem = QuantizeMembrane.apply(self.mem, self.quantizer)
# input_ = QuantizeMembrane.apply(input_, self.quantizer)
# rec = QuantizeMembrane.apply(rec, self.quantizer)
if self.reset_mechanism == "subtract": if self.reset_mechanism == "subtract":
self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold
...@@ -279,8 +279,8 @@ class RBLIF(ModNEFNeuron): ...@@ -279,8 +279,8 @@ class RBLIF(ModNEFNeuron):
self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
if self.quantization_flag: # if self.quantization_flag:
self.mem.data = self.quantizer(self.mem.data, True) # self.mem.data = self.quantizer(self.mem.data, True)
self.spk = self.fire(self.mem) self.spk = self.fire(self.mem)
......
...@@ -264,8 +264,8 @@ class ShiftLIF(ModNEFNeuron): ...@@ -264,8 +264,8 @@ class ShiftLIF(ModNEFNeuron):
else: else:
self.mem = self.mem-self.__shift(self.mem) self.mem = self.mem-self.__shift(self.mem)
if self.quantization_flag: # if self.quantization_flag:
self.mem.data = self.quantizer(self.mem.data, True) # self.mem.data = self.quantizer(self.mem.data, True)
if self.hardware_estimation_flag: if self.hardware_estimation_flag:
self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
...@@ -331,6 +331,7 @@ class ShiftLIF(ModNEFNeuron): ...@@ -331,6 +331,7 @@ class ShiftLIF(ModNEFNeuron):
""" """
self.threshold.data = self.quantizer(self.threshold.data, unscale) self.threshold.data = self.quantizer(self.threshold.data, unscale)
print(self.threshold)
@classmethod @classmethod
......
...@@ -114,7 +114,7 @@ class DynamicScaleFactorQuantizer(Quantizer): ...@@ -114,7 +114,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
weight = torch.Tensor(weight) weight = torch.Tensor(weight)
if not torch.is_tensor(rec_weight): if not torch.is_tensor(rec_weight):
rec_weight = torch.Tensor(weight) rec_weight = torch.Tensor(rec_weight)
if self.signed==None: if self.signed==None:
self.signed = torch.min(weight.min(), rec_weight.min())<0.0 self.signed = torch.min(weight.min(), rec_weight.min())<0.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment