From 6fd614f6d3b10b05f7c469fa8d8b18f876b64ea2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lie=20saulquin?= <aurelie.saulq@gmail.com>
Date: Tue, 1 Apr 2025 17:52:37 +0200
Subject: [PATCH] add quant linear test

---
 .../modnef_neurons/modnef_torch_neuron.py    | 12 +++++++++++-
 .../modnef_neurons/srlif_model/shiftlif.py   |  8 ++++----
 modneflib/modnef/modnef_torch/quantLinear.py |  0
 3 files changed, 15 insertions(+), 5 deletions(-)
 create mode 100644 modneflib/modnef/modnef_torch/quantLinear.py

diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 55298a1..10ef31c 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -18,6 +18,10 @@ import torch.nn.functional as F
 import brevitas.nn as qnn
 import brevitas.quant as bq
 
+from brevitas.core.quant import QuantType
+from brevitas.core.restrict_val import RestrictValueType
+from brevitas.core.scaling import ScalingImplType
+
 class QuantizeSTE(torch.autograd.Function):
   @staticmethod
   def forward(ctx, weights, quantizer):
@@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
 
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
+    # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False,
+    #                           weight_quant_type=QuantType.INT,
+    #                           weight_bit_width=8,
+    #                           weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
+    #                           weight_scaling_impl_type=ScalingImplType.CONST,
+    #                           weight_scaling_const=1.0
+    # )
 
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index b0409a4..429d24a 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function):
 
     q_weights = quantizer(weights, True)
 
-    #ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward
+    ctx.save_for_backward(quantizer.scale_factor)  # save the scale factor for the backward pass
 
     return q_weights
 
   @staticmethod
   def backward(ctx, grad_output):
     # STE: pass the gradient straight through to the float weights
-    #scale_factor, = ctx.saved_tensors
-    return grad_output, None
+    scale_factor, = ctx.saved_tensors
+    return grad_output*scale_factor, None
 
 class ShiftLIF(ModNEFNeuron):
@@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron):
 
     if self.quantization_flag:
       self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
-      #input_.data = self.quantizer(input_.data, True)
+      input_.data = QuantizeSTE.apply(input_.data, self.quantizer)
 
     self.mem = self.mem+input_
 
diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py
new file mode 100644
index 0000000..e69de29
--
GitLab
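
Note on the QuantizeSTE change in shiftlif.py: the backward pass no longer passes the incoming gradient through unchanged, it now rescales it by the scale factor saved during forward. Below is a minimal, self-contained sketch of the patched behaviour. FakeQuantizer is a hypothetical stand-in for the project's quantizer object (callable as quantizer(weights, True) and exposing a scale_factor tensor, as used in the patch); it is not part of the repository.

import torch

class FakeQuantizer:
    """Hypothetical stand-in for the project's quantizer: callable with
    (weights, flag) and exposing a scale_factor tensor, as in the patch."""
    def __init__(self, scale_factor=0.125):
        self.scale_factor = torch.tensor(scale_factor)

    def __call__(self, weights, unscale=True):
        # Toy power-of-two quantization: round in the scaled domain.
        return torch.round(weights / self.scale_factor) * self.scale_factor

class QuantizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weights, quantizer):
        q_weights = quantizer(weights, True)
        ctx.save_for_backward(quantizer.scale_factor)  # save scale for backward
        return q_weights

    @staticmethod
    def backward(ctx, grad_output):
        # Patched STE: the gradient is rescaled by the saved scale factor.
        scale_factor, = ctx.saved_tensors
        return grad_output * scale_factor, None

if __name__ == "__main__":
    w = torch.randn(4, 3, requires_grad=True)
    q = QuantizeSTE.apply(w, FakeQuantizer())
    q.sum().backward()
    # Every gradient entry equals the scale factor (d(sum)/dq = 1, then rescaled).
    print(w.grad)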
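
The Brevitas-backed layer in modnef_torch_neuron.py stays commented out and plain nn.Linear is kept. For reference only, re-enabling it with exactly the keyword arguments from the commented block would look roughly as follows. These keywords (weight_quant_type, weight_scaling_impl_type, ...) belong to the older Brevitas API, so this sketch assumes a Brevitas version that still accepts them; the feature sizes are placeholders.

import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType

# Sketch only: 8-bit integer weights with a constant, power-of-two restricted
# scale, mirroring the commented-out call in modnef_torch_neuron.py.
fc = qnn.QuantLinear(
    in_features=128,   # placeholder sizes for illustration
    out_features=10,
    bias=False,
    weight_quant_type=QuantType.INT,
    weight_bit_width=8,
    weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
    weight_scaling_impl_type=ScalingImplType.CONST,
    weight_scaling_const=1.0,
)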