Skip to content
Snippets Groups Projects
Commit 3fe54c44 authored by ahoni's avatar ahoni
Browse files

add test quantizer

parent 256235ce
No related branches found
No related tags found
1 merge request!3Dev
......@@ -14,6 +14,21 @@ from modnef.quantizer import *
from snntorch._neurons import SpikingNeuron
from snntorch.surrogate import fast_sigmoid
import torch
import torch.nn as nn
import torch.nn.functional as F
class QuantizeSTE(torch.autograd.Function):
"""Quantization avec Straight-Through Estimator (STE)"""
@staticmethod
def forward(ctx, x, quantizer):
return quantizer(x, True)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None # STE: Passe le gradient inchangé
_quantizer = {
"FixedPointQuantizer" : FixedPointQuantizer,
"MinMaxQuantizer" : MinMaxQuantizer,
......@@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron):
for p in self.parameters():
p.data = self.quantizer(p.data, unscale=unscale)
#p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
#p.data = QuantizeSTE.apply(p.data, self.quantizer)
def quantize_hp(self, unscale : bool = True):
"""
......@@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron):
for p in self.parameters():
p.data = self.quantizer.clamp(p.data)
print("clamp")
def run_quantize(self, mode=False):
"""
......
......@@ -19,6 +19,8 @@ from math import log, ceil
from modnef.quantizer import DynamicScaleFactorQuantizer
from snntorch import LIF
class ShiftLIF(ModNEFNeuron):
"""
ModNEFTorch Shift LIF neuron model
......@@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron):
self.mem.data = self.quantizer(self.mem.data, True)
input_.data = self.quantizer(input_.data, True)
self.mem = self.mem+input_
if self.reset_mechanism == "subtract":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment