Commit 2daec63a authored by ahoni

add test qat

parent 8b0940ba
1 merge request: !3 Dev
@@ -14,20 +14,23 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
+import brevitas.nn as qnn

 class QuantizeSTE(torch.autograd.Function):
+  """Quantization with a Straight-Through Estimator (STE)."""

   @staticmethod
-  def forward(ctx, x, quantizer):
-    return quantizer(x, True)
+  def forward(ctx, weights, quantizer):
+    q_weights = quantizer(weights, True)
+    #ctx.scale = quantizer.scale_factor  # save the scale factor for backward
+    return q_weights

   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, None  # STE: pass the gradient through unchanged
+    # STE: pass the gradient straight through to the float weights
+    return grad_output, None

 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -74,6 +77,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )

+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
@@ -87,6 +91,9 @@ class ModNEFNeuron(SpikingNeuron):
     self.quantizer = quantizer

+    self.alpha = 0.9

   @classmethod
   def from_dict(cls, dict):
     """
@@ -107,7 +114,7 @@ class ModNEFNeuron(SpikingNeuron):
     else:
       self.quantizer.init_from_weight(param[0], param[1])

-  def quantize_weight(self, unscale : bool = True):
+  def quantize_weight(self, unscale : bool = True, ema = False):
     """
     synaptic weight quantization
@@ -117,9 +124,15 @@
     """

     for p in self.parameters():
-      p.data = self.quantizer(p.data, unscale=unscale)
-      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
-      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      # if ema:
+      #   print(self.alpha)
+      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
+      #   self.alpha *= 0.1
+      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
+      # else:
+      p.data = self.quantizer(p.data, True)

   def quantize_hp(self, unscale : bool = True):
     """
...
@@ -20,7 +20,15 @@ from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF

+class QuantizeSTE(torch.autograd.Function):
+  """Quantization with a Straight-Through Estimator (STE)."""
+
+  @staticmethod
+  def forward(ctx, x, quantizer):
+    return quantizer(x, True)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    return grad_output, None  # STE: pass the gradient through unchanged

 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,7 +260,7 @@ class ShiftLIF(ModNEFNeuron):
     if self.quantization_flag:
       self.mem.data = self.quantizer(self.mem.data, True)
-      input_.data = self.quantizer(input_.data, True)
+      #input_.data = self.quantizer(input_.data, True)
     self.mem = self.mem+input_
...
@@ -161,5 +161,4 @@ class DynamicScaleFactorQuantizer(Quantizer):
     -------
     Tensor
     """
     return data*self.scale_factor
\ No newline at end of file