Commit ff5d695f authored by ahoni

brevitas test

parent 2daec63a
1 merge request: !3 Dev
@@ -113,6 +113,7 @@ class ModNEFModel(nn.Module):
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
         if force_init:
+          m.init_quantizer()
         m.quantize_weight()
@@ -16,6 +16,7 @@ from snntorch.surrogate import fast_sigmoid
 import torch.nn.functional as F
 import brevitas.nn as qnn
+import brevitas.quant as bq

 class QuantizeSTE(torch.autograd.Function):

   @staticmethod
@@ -23,12 +24,13 @@ class QuantizeSTE(torch.autograd.Function):
     q_weights = quantizer(weights, True)
     #ctx.scale = quantizer.scale_factor  # save the scale for the backward pass
+    #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for the backward pass
     return q_weights

   @staticmethod
   def backward(ctx, grad_output):
     # STE: pass the gradient straight through to the float weights
     #scale_factor, = ctx.saved_tensors
     return grad_output, None
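
For reference, a minimal self-contained sketch of the straight-through-estimator pattern this commit settles on. The quantizer(tensor, unscale) calling convention is taken from the diff; the uniform 5-bit quantizer below is a stand-in assumption, not the repo's quantizer:

import torch

class QuantizeSTE(torch.autograd.Function):
    """Fake-quantize on the forward pass, identity gradient on the backward pass."""

    @staticmethod
    def forward(ctx, weights, quantizer):
        # Forward: replace the float weights by their quantized values
        return quantizer(weights, True)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pass the gradient through unchanged; None is the
        # "gradient" for the non-tensor quantizer argument
        return grad_output, None

def uniform_quantizer(x, unscale=True, bits=5):
    # Stand-in symmetric uniform quantizer, only to make the sketch runnable
    scale = x.abs().max() / (2 ** (bits - 1) - 1)
    q = torch.clamp(torch.round(x / scale), -2 ** (bits - 1), 2 ** (bits - 1) - 1)
    return q * scale if unscale else q

w = torch.randn(4, 4, requires_grad=True)
QuantizeSTE.apply(w, uniform_quantizer).sum().backward()
print(w.grad)  # all ones: the rounding was skipped by the STE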
@@ -77,7 +79,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )

-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_witdh=5)
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
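
If the Brevitas layer is picked up again, a declaration along these lines should work (a sketch based on Brevitas' keyword conventions, with hypothetical layer sizes). Per-layer overrides of the weight quantizer are passed with a weight_ prefix, so the override is weight_bit_width=5; the first commented attempt has a typo (weight_bit_witdh), and the unprefixed bit_width in the second may not be routed to the weight quantizer:

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFixedPoint

# 5-bit weights with Brevitas' default weight quantizer
fc = qnn.QuantLinear(in_features=128, out_features=10, bias=False,
                     weight_bit_width=5)

# Same layer with a per-tensor fixed-point quantizer; weight_bit_width=5
# narrows the quantizer's 8-bit default
fc_fp = qnn.QuantLinear(in_features=128, out_features=10, bias=False,
                        weight_quant=Int8WeightPerTensorFixedPoint,
                        weight_bit_width=5)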
@@ -124,14 +126,10 @@ class ModNEFNeuron(SpikingNeuron):
     """
     for p in self.parameters():
       print(p)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # if ema:
-      #   print(self.alpha)
-      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
-      #   self.alpha *= 0.1
-      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # else:
-      p.data = self.quantizer(p.data, True)
+    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)

   def quantize_hp(self, unscale : bool = True):
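
For context, a sketch of how quantize_weight appears intended to be driven during quantization-aware training. The driver names (model, loader, criterion, optimizer) are hypothetical; ModNEFNeuron and quantize_weight are the ones in this diff. Because the assignment goes through p.data, the STE backward never actually fires here; gradients still reach the parameters through the ordinary autograd path, since the forward pass uses the freshly quantized p.data:

def qat_epoch(model, loader, criterion, optimizer):
    """One hypothetical quantization-aware training epoch."""
    for x, target in loader:
        # Re-quantize every neuron's weights in place before the forward pass
        for m in model.modules():
            if isinstance(m, ModNEFNeuron):
                m.quantize_weight()
        out = model(x)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # updates the parameters; they are re-quantized on the next step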
@@ -21,14 +21,21 @@ from snntorch import LIF

 class QuantizeSTE(torch.autograd.Function):
   """Quantization with a Straight-Through Estimator (STE)"""

   @staticmethod
-  def forward(ctx, x, quantizer):
-    return quantizer(x, True)
+  def forward(ctx, weights, quantizer):
+    q_weights = quantizer(weights, True)
+    #ctx.save_for_backward(quantizer.scale_factor)  # save the scale for the backward pass
+    return q_weights

   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, None  # STE: passes the gradient unchanged
+    # STE: pass the gradient straight through to the float weights
+    #scale_factor, = ctx.saved_tensors
+    return grad_output, None
class ShiftLIF(ModNEFNeuron):
"""
ModNEFTorch Shift LIF neuron model
@@ -259,7 +266,7 @@ class ShiftLIF(ModNEFNeuron):
     self.reset = self.mem_reset(self.mem)

     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
       #input_.data = self.quantizer(input_.data, True)
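
The switch from calling the quantizer directly to QuantizeSTE.apply matters whenever the quantized tensor stays in the autograd graph: round() has zero gradient almost everywhere, so a direct call starves training of gradient, while the STE passes it through. A small demonstration, using the QuantizeSTE class above and a stand-in fixed-scale quantizer:

import torch

def fixed_scale_quantizer(x, unscale=True, scale=0.1):
    # Stand-in quantizer with a constant scale, for the demo only
    q = torch.round(x / scale)
    return q * scale if unscale else q

mem = torch.linspace(-1.0, 1.0, 5, requires_grad=True)

fixed_scale_quantizer(mem, True).sum().backward()
print(mem.grad)   # zeros: round() blocked the gradient

mem.grad = None
QuantizeSTE.apply(mem, fixed_scale_quantizer).sum().backward()
print(mem.grad)   # ones: the STE passed it through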
@@ -343,7 +350,6 @@ class ShiftLIF(ModNEFNeuron):
     """
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    print(self.threshold)

   @classmethod