diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 55298a14af1dca60bda5e2c4aab3e00a97d6131f..10ef31c19cd5fb77642faee04403d6251480048a 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -18,6 +18,10 @@ import torch.nn.functional as F import brevitas.nn as qnn import brevitas.quant as bq +from brevitas.core.quant import QuantType +from brevitas.core.restrict_val import RestrictValueType +from brevitas.core.scaling import ScalingImplType + class QuantizeSTE(torch.autograd.Function): @staticmethod def forward(ctx, weights, quantizer): @@ -79,7 +83,13 @@ class ModNEFNeuron(SpikingNeuron): reset_mechanism=reset_mechanism ) - #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5) + # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, + # weight_quant_type=QuantType.INT, + # weight_bit_width=8, + # weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO, + # weight_scaling_impl_type=ScalingImplType.CONST, + # weight_scaling_const=1.0 + # ) self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index b0409a43311971e674f3a840f9ef58a921a607b1..429d24a8392bcb180d7ccc47b624ff2a06e4575a 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -26,14 +26,14 @@ class QuantizeSTE(torch.autograd.Function): q_weights = quantizer(weights, True) - #ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward + ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward return q_weights @staticmethod def backward(ctx, grad_output): # STE : on passe directement le gradient au poids float - #scale_factor, = ctx.saved_tensors - return grad_output, None + scale_factor, = ctx.saved_tensors + return grad_output*scale_factor, None class ShiftLIF(ModNEFNeuron): @@ -267,7 +267,7 @@ class ShiftLIF(ModNEFNeuron): if self.quantization_flag: self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer) - #input_.data = self.quantizer(input_.data, True) + input_.data = QuantizeSTE.apply(input_.data, self.quantizer) self.mem = self.mem+input_ diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391