diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index c4a980045968e8badcfcd59622ffefba18444f5b..06e00f69aa72ba09d283c16dad334df80d901766 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -14,6 +14,21 @@ from modnef.quantizer import * from snntorch._neurons import SpikingNeuron from snntorch.surrogate import fast_sigmoid + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class QuantizeSTE(torch.autograd.Function): + """Quantization avec Straight-Through Estimator (STE)""" + @staticmethod + def forward(ctx, x, quantizer): + return quantizer(x, True) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None # STE: Passe le gradient inchangé + _quantizer = { "FixedPointQuantizer" : FixedPointQuantizer, "MinMaxQuantizer" : MinMaxQuantizer, @@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron): for p in self.parameters(): p.data = self.quantizer(p.data, unscale=unscale) + #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer) + #p.data = QuantizeSTE.apply(p.data, self.quantizer) def quantize_hp(self, unscale : bool = True): """ @@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron): for p in self.parameters(): p.data = self.quantizer.clamp(p.data) - print("clamp") def run_quantize(self, mode=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index f2c95cbe6d5d7ef0ad7271e134a9a1d600f0e2c5..b0d75862ace0f96e202e63d1e4e2f4334cffd4b4 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -19,6 +19,8 @@ from math import log, ceil from modnef.quantizer import DynamicScaleFactorQuantizer from snntorch import LIF + + class ShiftLIF(ModNEFNeuron): """ ModNEFTorch Shift LIF neuron model @@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron): self.mem.data = self.quantizer(self.mem.data, True) input_.data = self.quantizer(input_.data, True) + self.mem = self.mem+input_ if self.reset_mechanism == "subtract":