Skip to content
Snippets Groups Projects
Commit adc7744b authored by ahoni's avatar ahoni
Browse files

add qat test

parent 3fe54c44
Branches
No related tags found
1 merge request!3Dev
...@@ -92,6 +92,15 @@ class ModNEFModel(nn.Module): ...@@ -92,6 +92,15 @@ class ModNEFModel(nn.Module):
return super().train(mode=mode) return super().train(mode=mode)
def init_quantizer(self):
"""
initialize quantizer of laters
"""
for m in self.modules():
if isinstance(m, ModNEFNeuron):
m.init_quantizer()
def quantize(self, force_init=False): def quantize(self, force_init=False):
""" """
Quantize synaptic weight and neuron hyper-parameters Quantize synaptic weight and neuron hyper-parameters
...@@ -106,11 +115,11 @@ class ModNEFModel(nn.Module): ...@@ -106,11 +115,11 @@ class ModNEFModel(nn.Module):
if isinstance(m, ModNEFNeuron): if isinstance(m, ModNEFNeuron):
m.quantize(force_init=force_init) m.quantize(force_init=force_init)
def clamp(self): def clamp(self, force_init=False):
for m in self.modules(): for m in self.modules():
if isinstance(m, ModNEFNeuron): if isinstance(m, ModNEFNeuron):
m.clamp() m.clamp(force_init=force_init)
def train(self, mode : bool = True, quant : bool = False): def train(self, mode : bool = True, quant : bool = False):
""" """
......
...@@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron): ...@@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron):
self.quantize_weight() self.quantize_weight()
self.quantize_hp() self.quantize_hp()
def clamp(self): def clamp(self, force_init=False):
""" """
Clamp synaptic weight Clamp synaptic weight
""" """
if force_init:
self.init_quantizer()
for p in self.parameters(): for p in self.parameters():
p.data = self.quantizer.clamp(p.data) p.data = self.quantizer.clamp(p.data)
......
...@@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron): ...@@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron):
""" """
self.threshold.data = self.quantizer(self.threshold.data, unscale) self.threshold.data = self.quantizer(self.threshold.data, unscale)
self.beta.data = self.quantizer(self.beta.data, unscale)
@classmethod @classmethod
......
...@@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron): ...@@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron):
else: else:
self.mem = self.mem-self.__shift(self.mem) self.mem = self.mem-self.__shift(self.mem)
if self.quantization_flag:
self.mem.data = self.quantizer(self.mem.data, True)
if self.hardware_estimation_flag: if self.hardware_estimation_flag:
self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
...@@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron): ...@@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron):
""" """
self.threshold.data = self.quantizer(self.threshold.data, unscale) self.threshold.data = self.quantizer(self.threshold.data, unscale)
self.beta.data = self.quantizer(self.beta.data, unscale)
@classmethod @classmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment