From 3fe54c44ce67f18b7de34498e1b563d4f30c936a Mon Sep 17 00:00:00 2001 From: ahoni <aurelie.saulq@proton.me> Date: Tue, 25 Mar 2025 22:59:14 +0100 Subject: [PATCH] add test quantizer --- .../modnef_neurons/modnef_torch_neuron.py | 18 +++++++++++++++++- .../modnef_neurons/srlif_model/shiftlif.py | 3 +++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index c4a9800..06e00f6 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -14,6 +14,21 @@ from modnef.quantizer import * from snntorch._neurons import SpikingNeuron from snntorch.surrogate import fast_sigmoid + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class QuantizeSTE(torch.autograd.Function): + """Quantization avec Straight-Through Estimator (STE)""" + @staticmethod + def forward(ctx, x, quantizer): + return quantizer(x, True) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None # STE: Passe le gradient inchangé + _quantizer = { "FixedPointQuantizer" : FixedPointQuantizer, "MinMaxQuantizer" : MinMaxQuantizer, @@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron): for p in self.parameters(): p.data = self.quantizer(p.data, unscale=unscale) + #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer) + #p.data = QuantizeSTE.apply(p.data, self.quantizer) def quantize_hp(self, unscale : bool = True): """ @@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron): for p in self.parameters(): p.data = self.quantizer.clamp(p.data) - print("clamp") def run_quantize(self, mode=False): """ diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index f2c95cb..b0d7586 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -19,6 +19,8 @@ from math import log, ceil from modnef.quantizer import DynamicScaleFactorQuantizer from snntorch import LIF + + class ShiftLIF(ModNEFNeuron): """ ModNEFTorch Shift LIF neuron model @@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron): self.mem.data = self.quantizer(self.mem.data, True) input_.data = self.quantizer(input_.data, True) + self.mem = self.mem+input_ if self.reset_mechanism == "subtract": -- GitLab