Skip to content
Snippets Groups Projects
Commit adc7744b authored by ahoni's avatar ahoni
Browse files

add qat test

parent 3fe54c44
No related branches found
No related tags found
1 merge request!3Dev
......@@ -92,6 +92,15 @@ class ModNEFModel(nn.Module):
return super().train(mode=mode)
def init_quantizer(self):
"""
initialize quantizer of laters
"""
for m in self.modules():
if isinstance(m, ModNEFNeuron):
m.init_quantizer()
def quantize(self, force_init=False):
"""
Quantize synaptic weight and neuron hyper-parameters
......@@ -106,11 +115,11 @@ class ModNEFModel(nn.Module):
if isinstance(m, ModNEFNeuron):
m.quantize(force_init=force_init)
def clamp(self):
def clamp(self, force_init=False):
for m in self.modules():
if isinstance(m, ModNEFNeuron):
m.clamp()
m.clamp(force_init=force_init)
def train(self, mode : bool = True, quant : bool = False):
"""
......
......@@ -150,11 +150,14 @@ class ModNEFNeuron(SpikingNeuron):
self.quantize_weight()
self.quantize_hp()
def clamp(self):
def clamp(self, force_init=False):
"""
Clamp synaptic weight
"""
if force_init:
self.init_quantizer()
for p in self.parameters():
p.data = self.quantizer.clamp(p.data)
......
......@@ -342,7 +342,6 @@ class RShiftLIF(ModNEFNeuron):
"""
self.threshold.data = self.quantizer(self.threshold.data, unscale)
self.beta.data = self.quantizer(self.beta.data, unscale)
@classmethod
......
......@@ -264,6 +264,9 @@ class ShiftLIF(ModNEFNeuron):
else:
self.mem = self.mem-self.__shift(self.mem)
if self.quantization_flag:
self.mem.data = self.quantizer(self.mem.data, True)
if self.hardware_estimation_flag:
self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min)
self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max)
......@@ -328,7 +331,6 @@ class ShiftLIF(ModNEFNeuron):
"""
self.threshold.data = self.quantizer(self.threshold.data, unscale)
self.beta.data = self.quantizer(self.beta.data, unscale)
@classmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment