Skip to content
Snippets Groups Projects
Commit 6fd614f6 authored by Aurélie saulquin's avatar Aurélie saulquin
Browse files

add quant linear test

parent ff5d695f
No related branches found
No related tags found
1 merge request!3Dev
...@@ -18,6 +18,10 @@ import torch.nn.functional as F ...@@ -18,6 +18,10 @@ import torch.nn.functional as F
import brevitas.nn as qnn import brevitas.nn as qnn
import brevitas.quant as bq import brevitas.quant as bq
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
class QuantizeSTE(torch.autograd.Function): class QuantizeSTE(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, weights, quantizer): def forward(ctx, weights, quantizer):
...@@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron): ...@@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron):
reset_mechanism=reset_mechanism reset_mechanism=reset_mechanism
) )
#self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5) # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
# weight_quant_type=QuantType.INT,
# weight_bit_width=8,
# weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
# weight_scaling_impl_type=ScalingImplType.CONST,
# weight_scaling_const=1.0
# )
self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False) self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
......
...@@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function): ...@@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function):
q_weights = quantizer(weights, True) q_weights = quantizer(weights, True)
#ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward
return q_weights return q_weights
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
# STE : on passe directement le gradient au poids float # STE : on passe directement le gradient au poids float
#scale_factor, = ctx.saved_tensors scale_factor, = ctx.saved_tensors
return grad_output, None return grad_output*scale_factor, None
class ShiftLIF(ModNEFNeuron): class ShiftLIF(ModNEFNeuron):
...@@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron): ...@@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron):
if self.quantization_flag: if self.quantization_flag:
self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer) self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
#input_.data = self.quantizer(input_.data, True) input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
self.mem = self.mem+input_ self.mem = self.mem+input_
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment