Commit ff5d695f authored by ahoni

brevitas test

parent 2daec63a
Merge request: !3 Dev
```diff
@@ -113,6 +113,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.init_quantizer()
+        if force_init:
+          m.init_quantizer()
         m.quantize_weight()
```
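For context on what this hunk changes: `force_init` lets the caller decide whether each neuron's quantizer is re-derived from the current float weights before quantization, or whether previously computed quantizer state is reused. A self-contained sketch of that flow, using a stub in place of the real `ModNEFNeuron` (whose implementation is not shown in this diff):

```python
import torch
import torch.nn as nn

class ModNEFNeuronStub(nn.Module):
    """Stand-in for ModNEFNeuron: just enough to show the force_init flow."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4, bias=False)
        self.scale = None

    def init_quantizer(self):
        # (Re)derive quantizer state, e.g. a scale, from the current float weights.
        self.scale = self.fc.weight.abs().max() / 127

    def quantize_weight(self):
        # Snap the weights onto the grid defined by the scale.
        w = self.fc.weight.data
        self.fc.weight.data = torch.round(w / self.scale) * self.scale

def quantize(model: nn.Module, force_init: bool = False):
    for m in model.modules():
        if isinstance(m, ModNEFNeuronStub):
            if force_init:
                m.init_quantizer()  # refresh the scale from current weights
            m.quantize_weight()

net = nn.Sequential(ModNEFNeuronStub(), ModNEFNeuronStub())
quantize(net, force_init=True)   # first call: derive scales, then quantize
quantize(net, force_init=False)  # later calls: reuse the existing scales
```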
```diff
@@ -16,6 +16,7 @@ from snntorch.surrogate import fast_sigmoid
 import torch.nn.functional as F
 import brevitas.nn as qnn
+import brevitas.quant as bq

 class QuantizeSTE(torch.autograd.Function):

   @staticmethod
@@ -23,12 +24,13 @@ class QuantizeSTE(torch.autograd.Function):
     q_weights = quantizer(weights, True)
-    #ctx.scale = quantizer.scale_factor # save the scale for the backward pass
+    #ctx.save_for_backward(quantizer.scale_factor) # save the scale for the backward pass
     return q_weights

   @staticmethod
   def backward(ctx, grad_output):
     # STE: pass the gradient straight through to the float weights
+    #scale_factor, = ctx.saved_tensors
     return grad_output, None
```
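As background on the `QuantizeSTE` function being edited here: a straight-through estimator uses quantized weights in the forward pass but propagates the gradient to the underlying float weights as if quantization were the identity. A minimal sketch; `uniform_quantizer` is a hypothetical stand-in for the project's quantizer object, whose `(tensor, unscale)` call signature is only partially visible in the diff:

```python
import torch

class QuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, quantizer):
        # Quantize for the forward computation only.
        return quantizer(weights)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pretend quantization was the identity, so the gradient
        # reaches the underlying float weights unchanged.
        return grad_output, None  # no gradient w.r.t. the quantizer

def uniform_quantizer(w, scale=0.05):
    # Hypothetical uniform grid standing in for the project's quantizer.
    return torch.round(w / scale) * scale

w = torch.randn(3, 3, requires_grad=True)
q = QuantizeSTE.apply(w, uniform_quantizer)
q.sum().backward()
print(torch.allclose(w.grad, torch.ones_like(w)))  # True: gradient passed through
```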
```diff
@@ -77,7 +79,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )

-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_witdh=5)
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
```
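The new commented-out `QuantLinear` call swaps the misspelled `weight_bit_witdh` keyword for an explicit `weight_quant` class. For reference, a hedged example of configuring Brevitas' `QuantLinear` with a per-tensor fixed-point weight quantizer; note that bit-width overrides on the weight quantizer are usually passed as `weight_bit_width` (prefix included), and exact keyword support depends on the Brevitas version:

```python
import torch
import brevitas.nn as qnn
import brevitas.quant as bq

# 5-bit weights built from the Int8WeightPerTensorFixedPoint base quantizer,
# with its bit width overridden via the weight_-prefixed keyword.
fc = qnn.QuantLinear(
    in_features=16,
    out_features=8,
    bias=False,
    weight_quant=bq.Int8WeightPerTensorFixedPoint,
    weight_bit_width=5,
)

x = torch.randn(2, 16)
y = fc(x)  # forward uses fake-quantized weights; Brevitas applies its own STE
```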
```diff
@@ -124,14 +126,10 @@ class ModNEFNeuron(SpikingNeuron):
     """
     for p in self.parameters():
-      print(p)
-      # if ema:
-      #   print(self.alpha)
-      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
-      #   self.alpha *= 0.1
-      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # else:
-      p.data = self.quantizer(p.data, True)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)

   def quantize_hp(self, unscale : bool = True):
```
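The removed comment block sketches an EMA-style annealing between float and quantized weights that this commit abandons in favor of a single STE call. One detail worth noting: because the assignment targets `p.data`, which is detached from autograd, `QuantizeSTE.apply(p.data, ...)` behaves the same as calling the quantizer directly (its `backward` never runs); the STE only has an effect when applied to a live tensor in the forward pass. A self-contained sketch of both variants, with the quantizer and `alpha` schedule assumed from the commented-out code:

```python
import torch

def quantizer(w, scale=0.05):
    # Hypothetical uniform quantizer standing in for self.quantizer.
    return torch.round(w / scale) * scale

def ste_quantize_(p):
    # Final variant: snap the parameter onto the quantization grid in place.
    p.data = quantizer(p.data)

def ema_quantize_(p, alpha):
    # Abandoned variant: blend float and quantized weights, annealing toward
    # fully quantized values as alpha decays.
    p.data = alpha * p.data + (1 - alpha) * quantizer(p.data)

p = torch.nn.Parameter(torch.randn(4))
alpha = 0.5
for _ in range(3):
    ema_quantize_(p, alpha)
    alpha *= 0.1  # decay schedule from the removed comment
ste_quantize_(p)  # end state: weights exactly on the grid
```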
```diff
@@ -21,14 +21,21 @@ from snntorch import LIF

 class QuantizeSTE(torch.autograd.Function):
+  """Quantization with a Straight-Through Estimator (STE)"""

   @staticmethod
-  def forward(ctx, x, quantizer):
-    return quantizer(x, True)
+  def forward(ctx, weights, quantizer):
+    q_weights = quantizer(weights, True)
+    #ctx.save_for_backward(quantizer.scale_factor) # save the scale for the backward pass
+    return q_weights

   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, None # STE: pass the gradient through unchanged
+    # STE: pass the gradient straight through to the float weights
+    #scale_factor, = ctx.saved_tensors
+    return grad_output, None

 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
```
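The commented-out `ctx.save_for_backward(quantizer.scale_factor)` and `ctx.saved_tensors` lines suggest a variant that keeps the quantizer's scale around for the backward pass. A hedged sketch of what that could look like; the `scale_factor` attribute and the gradient rescaling are assumptions, since the diff leaves both lines disabled:

```python
import torch

class ScaledQuantizeSTE(torch.autograd.Function):
    """STE variant that keeps the quantizer scale for the backward pass."""

    @staticmethod
    def forward(ctx, weights, quantizer):
        q_weights = quantizer(weights)
        # Save the scale so backward can use it (the diff leaves this commented out).
        ctx.save_for_backward(quantizer.scale_factor)
        return q_weights

    @staticmethod
    def backward(ctx, grad_output):
        (scale_factor,) = ctx.saved_tensors
        # One plausible use of the saved scale: rescale the pass-through
        # gradient. Plain STE would return grad_output unchanged.
        return grad_output * scale_factor, None

class ToyQuantizer:
    # Hypothetical quantizer exposing the scale_factor the comments assume.
    def __init__(self, scale=0.05):
        self.scale_factor = torch.tensor(scale)

    def __call__(self, w):
        return torch.round(w / self.scale_factor) * self.scale_factor

w = torch.randn(3, requires_grad=True)
out = ScaledQuantizeSTE.apply(w, ToyQuantizer())
out.sum().backward()
```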
```diff
@@ -259,7 +266,7 @@ class ShiftLIF(ModNEFNeuron):
     self.reset = self.mem_reset(self.mem)

     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
       #input_.data = self.quantizer(input_.data, True)
```
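Routing the membrane potential through `QuantizeSTE` mirrors the weight path, though as written it still targets `self.mem.data`, so the custom backward is bypassed. The sketch below instead quantizes the live membrane tensor inside an LIF-style update, where the STE actually preserves the gradient path across time steps; the leak factor, grid, and reset scheme are assumptions:

```python
import torch

class QuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, quantizer):
        return quantizer(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

quantizer = lambda v: torch.round(v / 0.1) * 0.1  # hypothetical membrane grid
beta = 0.9  # assumed leak factor

def lif_step(mem, input_, threshold=1.0):
    mem = beta * mem + input_
    mem = QuantizeSTE.apply(mem, quantizer)  # STE keeps the gradient path intact
    spk = (mem >= threshold).float()
    mem = mem - spk * threshold  # reset by subtraction
    return spk, mem

mem = torch.zeros(4, requires_grad=True)
spk, mem = lif_step(mem, torch.randn(4))
```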
```diff
@@ -343,7 +350,6 @@ class ShiftLIF(ModNEFNeuron):
     """
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    print(self.threshold)

   @classmethod
```