diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py
index 152cd9337d81bd3a92af86598dc16dcf89877fb7..2c0d687f052fee19c932f904d2c2baf49e675762 100644
--- a/modneflib/modnef/modnef_torch/model.py
+++ b/modneflib/modnef/modnef_torch/model.py
@@ -113,7 +113,8 @@ class ModNEFModel(nn.Module):
 
     for m in self.modules():
       if isinstance(m, ModNEFNeuron):
-        m.init_quantizer()
+        if force_init:
+          m.init_quantizer()
         m.quantize_weight()
 
   def quantize(self, force_init=False):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
index 56e3d103ce02d058f3291421f2ddb15ea280c194..55298a14af1dca60bda5e2c4aab3e00a97d6131f 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py
@@ -16,6 +16,7 @@
 from snntorch.surrogate import fast_sigmoid
 import torch.nn.functional as F
 import brevitas.nn as qnn
+import brevitas.quant as bq
 
 class QuantizeSTE(torch.autograd.Function):
   @staticmethod
@@ -23,12 +24,13 @@
 
     q_weights = quantizer(weights, True)
 
-    #ctx.scale = quantizer.scale_factor # On sauvegarde le scale pour backward
+    #ctx.save_for_backward(quantizer.scale_factor)  # save the scale factor for backward
     return q_weights
 
   @staticmethod
   def backward(ctx, grad_output):
     # STE : on passe directement le gradient au poids float
+    #scale_factor, = ctx.saved_tensors
     return grad_output, None
 
 
@@ -77,7 +79,7 @@ class ModNEFNeuron(SpikingNeuron):
       reset_mechanism=reset_mechanism
     )
 
-    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_bit_witdh=5)
+    #self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, weight_quant=bq.Int8WeightPerTensorFixedPoint, bit_width=5)
     self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
 
 
@@ -124,14 +126,10 @@
     """
 
     for p in self.parameters():
-
-      # if ema:
-      #   print(self.alpha)
-      #   p.data = self.alpha * p.data + (1-self.alpha) * QuantizeSTE.apply(p.data, self.quantizer)
-      #   self.alpha *= 0.1
-      #   #p.data = QuantizeSTE.apply(p.data, self.quantizer)
-      # else:
-      p.data = self.quantizer(p.data, True)
+      print(p)
+      p.data = QuantizeSTE.apply(p.data, self.quantizer)
+
+    #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer)
 
 
   def quantize_hp(self, unscale : bool = True):
diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
index a7c1b1f403ed372cbce42b6bb57c044ce2c33add..b0409a43311971e674f3a840f9ef58a921a607b1 100644
--- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
+++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py
@@ -21,14 +21,21 @@
 from snntorch import LIF
 
 class QuantizeSTE(torch.autograd.Function):
-  """Quantization avec Straight-Through Estimator (STE)"""
   @staticmethod
-  def forward(ctx, x, quantizer):
-    return quantizer(x, True)
+  def forward(ctx, weights, quantizer):
+
+    q_weights = quantizer(weights, True)
+
+    #ctx.save_for_backward(quantizer.scale_factor)  # save the scale factor for backward
+    return q_weights
 
   @staticmethod
   def backward(ctx, grad_output):
-    return grad_output, None # STE: Passe le gradient inchangé
+    # STE: pass the gradient straight through to the float weights
+    #scale_factor, = ctx.saved_tensors
+    return grad_output, None
+
+
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -259,7 +266,7 @@ class ShiftLIF(ModNEFNeuron):
     self.reset = self.mem_reset(self.mem)
 
     if self.quantization_flag:
-      self.mem.data = self.quantizer(self.mem.data, True)
+      self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer)
       #input_.data = self.quantizer(input_.data, True)
 
 
@@ -343,7 +350,6 @@ class ShiftLIF(ModNEFNeuron):
     """
 
     self.threshold.data = self.quantizer(self.threshold.data, unscale)
-    print(self.threshold)
 
 
   @classmethod