Commit 3fe54c44 authored by ahoni

add test quantizer

parent 256235ce
1 merge request: !3 Dev
@@ -14,6 +14,21 @@ from modnef.quantizer import *
 from snntorch._neurons import SpikingNeuron
 from snntorch.surrogate import fast_sigmoid
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class QuantizeSTE(torch.autograd.Function):
+    """Quantization with Straight-Through Estimator (STE)"""
+
+    @staticmethod
+    def forward(ctx, x, quantizer):
+        return quantizer(x, True)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None  # STE: pass the gradient through unchanged
 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
   "MinMaxQuantizer" : MinMaxQuantizer,
@@ -103,6 +118,8 @@ class ModNEFNeuron(SpikingNeuron):
     for p in self.parameters():
       p.data = self.quantizer(p.data, unscale=unscale)
+      #p.data = 0.9 * p.data + (1-0.9) * QuantizeSTE.apply(p.data, self.quantizer)
+      #p.data = QuantizeSTE.apply(p.data, self.quantizer)
 
   def quantize_hp(self, unscale : bool = True):
     """
@@ -140,7 +157,6 @@ class ModNEFNeuron(SpikingNeuron):
     for p in self.parameters():
       p.data = self.quantizer.clamp(p.data)
-    print("clamp")
 
   def run_quantize(self, mode=False):
     """
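Note: the commit adds QuantizeSTE but leaves its use on the parameters commented out. The standalone sketch below (not repository code) shows the behaviour the class is meant to provide: values are rounded in the forward pass while gradients flow through unchanged in the backward pass. The step_quantizer function is a hypothetical stand-in for a ModNEF quantizer object and only mimics the quantizer(x, True) call used in QuantizeSTE.forward.

import torch

class QuantizeSTE(torch.autograd.Function):
    # same class as added in this commit
    @staticmethod
    def forward(ctx, x, quantizer):
        return quantizer(x, True)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None  # gradient passes straight through the rounding

def step_quantizer(x, unscale=True, step=0.125):
    # hypothetical toy quantizer: round to a fixed step, optionally rescale back
    q = torch.round(x / step)
    return q * step if unscale else q

w = torch.randn(4, requires_grad=True)
w_q = QuantizeSTE.apply(w, step_quantizer)  # forward sees rounded weights
loss = (w_q ** 2).sum()
loss.backward()
print(w.grad)  # equals 2 * w_q; torch.round alone would have produced a zero gradient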
@@ -19,6 +19,8 @@ from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer
 from snntorch import LIF
 
 class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -252,6 +254,7 @@ class ShiftLIF(ModNEFNeuron):
       self.mem.data = self.quantizer(self.mem.data, True)
       input_.data = self.quantizer(input_.data, True)
 
     self.mem = self.mem+input_
 
     if self.reset_mechanism == "subtract":
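Note: the second hunk shows the membrane potential and the input current being passed through the layer's quantizer before integration. A minimal sketch of that pattern (not repository code), reusing the same hypothetical step_quantizer as in the sketch above:

import torch

def step_quantizer(x, unscale=True, step=0.125):
    # hypothetical toy quantizer: round to a fixed step, optionally rescale back
    q = torch.round(x / step)
    return q * step if unscale else q

mem = torch.zeros(3)
input_ = torch.tensor([0.30, 0.61, 0.07])
mem = step_quantizer(mem, True)        # quantize the state, as in the hunk
input_ = step_quantizer(input_, True)  # quantize the incoming current
mem = mem + input_                     # integrate
print(mem)                             # tensor([0.2500, 0.6250, 0.1250])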