diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index b6a0432c865c52a819d13cdce460fa1a8f1d96b8..56e3d103ce02d058f3291421f2ddb15ea280c194 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -14,20 +14,23 @@ from modnef.quantizer import * from snntorch._neurons import SpikingNeuron from snntorch.surrogate import fast_sigmoid - -import torch -import torch.nn as nn import torch.nn.functional as F +import brevitas.nn as qnn class QuantizeSTE(torch.autograd.Function): - """Quantization avec Straight-Through Estimator (STE)""" @staticmethod - def forward(ctx, x, quantizer): - return quantizer(x, True) + def forward(ctx, weights, quantizer): + + q_weights = quantizer(weights, True) + + #ctx.scale = quantizer.scale_factor # On sauvegarde le scale pour backward + return q_weights @staticmethod def backward(ctx, grad_output): - return grad_output, None # STE: Passe le gradient inchangé + # STE : on passe directement le gradient au poids float + return grad_output, None + _quantizer = { "FixedPointQuantizer" : FixedPointQuantizer, @@ -74,6 +77,7 @@ class ModNEFNeuron(SpikingNeuron): reset_mechanism=reset_mechanism ) + #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_witdh=5) self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False) @@ -87,6 +91,9 @@ class ModNEFNeuron(SpikingNeuron): self.quantizer = quantizer + + self.alpha = 0.9 + @classmethod def from_dict(cls, dict): """ @@ -107,7 +114,7 @@ class ModNEFNeuron(SpikingNeuron): else: self.quantizer.init_from_weight(param[0], param[1]) - def quantize_weight(self, unscale : bool = True): + def quantize_weight(self, unscale : bool = True, ema = False): """ synaptic weight quantization @@ -117,9 +124,15 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - p.data = self.quantizer(p.data, unscale=unscale) - #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer) - #p.data = QuantizeSTE.apply(p.data, self.quantizer) + + # if ema: + # print(self.alpha) + # p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer) + # self.alpha *= 0.1 + # #p.data = QuantizeSTE.apply(p.data, self.quantizer) + # else: + p.data = self.quantizer(p.data, True) + def quantize_hp(self, unscale : bool = True): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 5d938b43d64ea10d526d1d4ca74f5f24962e3bca..a7c1b1f403ed372cbce42b6bb57c044ce2c33add 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -20,7 +20,15 @@ from modnef.quantizer import DynamicScaleFactorQuantizer from snntorch import LIF - +class QuantizeSTE(torch.autograd.Function): + """Quantization avec Straight-Through Estimator (STE)""" + @staticmethod + def forward(ctx, x, quantizer): + return quantizer(x, True) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None # STE: Passe le gradient inchangé class ShiftLIF(ModNEFNeuron): """ ModNEFTorch Shift LIF neuron model @@ -252,7 +260,7 @@ class ShiftLIF(ModNEFNeuron): if self.quantization_flag: self.mem.data = self.quantizer(self.mem.data, True) - input_.data = self.quantizer(input_.data, True) + #input_.data = self.quantizer(input_.data, True) self.mem = self.mem+input_ diff --git a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py index 0f25c44624f8e2766923149e70ecf45060e712dd..f1d7497a8893891ea15c8d2446f4a2b512848b1c 100644 --- a/modneflib/modnef/quantizer/dynamic_scale_quantizer.py +++ b/modneflib/modnef/quantizer/dynamic_scale_quantizer.py @@ -161,5 +161,4 @@ class DynamicScaleFactorQuantizer(Quantizer): ------- Tensor """ - return data*self.scale_factor \ No newline at end of file