diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index 01f59f90e268fd3d374aa169783b9021733365f8..c759fcf8212af5aea1aec88b703bf488f7021f84 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -18,6 +18,21 @@ from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import log, ceil from modnef.quantizer import * +import torch.autograd as autograd + +class QuantizeMembrane(autograd.Function): + @staticmethod + def forward(ctx, U, quantizer): + max_val = U.abs().max().detach() # Détachement pour éviter de bloquer le gradient + U_quant = quantizer(U, True) + ctx.save_for_backward(U, quantizer.scale_factor) + return U_quant + + @staticmethod + def backward(ctx, grad_output): + grad_input, factor = ctx.saved_tensors + return grad_output, None + class RBLIF(ModNEFNeuron): """ ModNEFTorch reccurent BLIF neuron model @@ -243,10 +258,15 @@ class RBLIF(ModNEFNeuron): rec = self.reccurent(self.spk) + # if self.quantization_flag: + # self.mem.data = self.quantizer(self.mem.data, True) + # input_.data = self.quantizer(input_.data, True) + # rec.data = self.quantizer(rec.data, True) + if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) - input_.data = self.quantizer(input_.data, True) - rec.data = self.quantizer(rec.data, True) + self.mem = QuantizeMembrane.apply(self.mem, self.quantizer) + input_ = QuantizeMembrane.apply(input_, self.quantizer) + rec = QuantizeMembrane.apply(rec, self.quantizer) if self.reset_mechanism == "subtract": self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold