From 025865dcd6173f85022ec30d4ffe8818c88ded91 Mon Sep 17 00:00:00 2001 From: ahoni <aurelie.saulq@proton.me> Date: Thu, 3 Apr 2025 11:01:49 +0200 Subject: [PATCH] add qat --- .../modnef/arch_builder/modules/BLIF/blif.py | 4 ++ modneflib/modnef/modnef_torch/__init__.py | 3 +- modneflib/modnef/modnef_torch/model.py | 30 ++++++-- .../modnef_neurons/blif_model/blif.py | 41 +++++------ .../modnef_neurons/blif_model/rblif.py | 72 ++++++------------- .../modnef_neurons/modnef_torch_neuron.py | 56 +++------------ .../modnef_neurons/slif_model/rslif.py | 68 ++++++++---------- .../modnef_neurons/slif_model/slif.py | 56 ++++++--------- .../modnef_neurons/srlif_model/rshiftlif.py | 50 ++++++------- .../modnef_neurons/srlif_model/shiftlif.py | 57 ++++----------- modneflib/modnef/modnef_torch/quantLinear.py | 59 +++++++++++++++ modneflib/modnef/quantizer/__init__.py | 3 +- modneflib/modnef/quantizer/ste_quantizer.py | 62 ++++++++++++++++ 13 files changed, 291 insertions(+), 270 deletions(-) create mode 100644 modneflib/modnef/quantizer/ste_quantizer.py diff --git a/modneflib/modnef/arch_builder/modules/BLIF/blif.py b/modneflib/modnef/arch_builder/modules/BLIF/blif.py index 4d5f09a..b6a0663 100644 --- a/modneflib/modnef/arch_builder/modules/BLIF/blif.py +++ b/modneflib/modnef/arch_builder/modules/BLIF/blif.py @@ -193,12 +193,16 @@ class BLif(ModNEFArchMod): bw = self.quantizer.bitwidth mem_file = open(f"{output_path}/{self.mem_init_file}", 'w') + + truc = open(f"temp_{self.mem_init_file}", 'w') if self.quantizer.signed: for i in range(self.input_neuron): w_line = 0 for j in range(self.output_neuron-1, -1, -1): + w_line = (w_line<<bw) + two_comp(self.quantizer(weights[i][j]), bw) + truc.write(f"{i} {j} {two_comp(self.quantizer(weights[i][j]), bw)}\n") mem_file.write(f"@{to_hex(i)} {to_hex(w_line)}\n") diff --git a/modneflib/modnef/modnef_torch/__init__.py b/modneflib/modnef/modnef_torch/__init__.py index e57895f..2053dab 100644 --- a/modneflib/modnef/modnef_torch/__init__.py +++ b/modneflib/modnef/modnef_torch/__init__.py @@ -9,4 +9,5 @@ Descriptions: ModNEF torch lib definition from .modnef_neurons import * from .model_builder import ModNEFModelBuilder -from .model import ModNEFModel \ No newline at end of file +from .model import ModNEFModel +from .quantLinear import QuantLinear \ No newline at end of file diff --git a/modneflib/modnef/modnef_torch/model.py b/modneflib/modnef/modnef_torch/model.py index 2c0d687..f81295d 100644 --- a/modneflib/modnef/modnef_torch/model.py +++ b/modneflib/modnef/modnef_torch/model.py @@ -1,17 +1,15 @@ """ File name: model Author: Aurélie Saulquin -Version: 1.0.0 +Version: 1.1.0 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr -Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron +Dependencies: torch, snntorch, modnef_torch_neuron, modnef_driver Descriptions: ModNEF SNN Model """ -import modnef.modnef_torch.modnef_neurons as mn import torch.nn as nn import torch -from modnef.arch_builder import * from modnef.modnef_driver import load_driver_from_yaml from modnef.modnef_torch.modnef_neurons import ModNEFNeuron @@ -101,6 +99,22 @@ class ModNEFModel(nn.Module): if isinstance(m, ModNEFNeuron): m.init_quantizer() + def quantize_hp(self, force_init=False): + """ + Quantize neuron hyper parameters + + Parameters + ---------- + force_init = False : bool + force quantizer initialization + """ + + for m in self.modules(): + if isinstance(m, ModNEFNeuron): + if force_init: + m.init_quantizer() + m.quantize_hp() + def quantize_weight(self, force_init=False): """ Quantize synaptic weight @@ -132,6 +146,14 @@ class ModNEFModel(nn.Module): m.quantize(force_init=force_init) def clamp(self, force_init=False): + """ + Clamp synaptic weight with quantizer born + + Parameters + ---------- + force_init = False : bool + force quantizer initialization + """ for m in self.modules(): if isinstance(m, ModNEFNeuron): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py index 2ba0908..c435ed1 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/blif.py @@ -1,7 +1,7 @@ """ File name: blif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, math, snntorch, modnef.archbuilder, modnef_torch_neuron, modnef.quantizer @@ -9,10 +9,8 @@ Descriptions: ModNEF torch BLIF neuron model Based on snntorch.Leaky and snntorch.LIF class """ -import torch.nn as nn import torch from math import log, ceil -from snntorch import Leaky import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer @@ -223,28 +221,28 @@ class BLIF(ModNEFNeuron): if not spk==None: self.spk = spk - input_ = self.fc(input_) + quant = self.quantizer if self.quantization_flag else None - if not self.mem.shape == input_.shape: - self.mem = torch.zeros_like(input_, device=self.mem.device) - - if self.quantization_flag: - input_.data = self.quantizer(input_.data, True) - self.mem.data = self.quantizer(self.mem.data, True) + forward_current = self.fc(input_, quant) + if not self.mem.shape == forward_current.shape: + self.mem = torch.zeros_like(forward_current, device=self.mem.device) + + self.mem = self.mem + forward_current - self.reset = self.mem_reset(self.mem) + if self.hardware_estimation_flag: + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) if self.reset_mechanism == "subtract": - self.mem = (self.mem+input_)*self.beta-self.reset*self.threshold + self.mem = self.mem*self.beta#-self.reset*self.threshold elif self.reset_mechanism == "zero": - self.mem = (self.mem+input_)*self.beta-self.reset*self.mem + self.mem = self.mem*self.beta-self.reset*self.mem else: self.mem = self.mem*self.beta - if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) + if self.quantization_flag: + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) self.spk = self.fire(self.mem) @@ -295,19 +293,14 @@ class BLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization. We assume you already initialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) + self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer) + self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer) @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py index 586048c..8b49564 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/blif_model/rblif.py @@ -1,7 +1,7 @@ """ File name: rblif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer @@ -9,29 +9,14 @@ Descriptions: ModNEF torch reccurrent BLIF neuron model Based on snntorch.RLeaky and snntorch.LIF class """ -import torch.nn as nn import torch -from snntorch import Leaky import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import log, ceil from modnef.quantizer import * +from modnef.modnef_torch.quantLinear import QuantLinear -import torch.autograd as autograd - -class QuantizeMembrane(autograd.Function): - @staticmethod - def forward(ctx, U, quantizer): - max_val = U.abs().max().detach() # Détachement pour éviter de bloquer le gradient - U_quant = quantizer(U, True) - ctx.save_for_backward(U, quantizer.scale_factor) - return U_quant - - @staticmethod - def backward(ctx, grad_output): - grad_input, factor = ctx.saved_tensors - return grad_output, None class RBLIF(ModNEFNeuron): """ @@ -141,7 +126,7 @@ class RBLIF(ModNEFNeuron): self.register_buffer("beta", torch.tensor(beta)) - self.reccurent = nn.Linear(out_features, out_features, bias=False) + self.reccurent = QuantLinear(out_features, out_features) self._init_mem() @@ -246,41 +231,35 @@ class RBLIF(ModNEFNeuron): if not spk == None: self.spk = spk - input_ = self.fc(input_) + quant = self.quantizer if self.quantization_flag else None - if not self.mem.shape == input_.shape: - self.mem = torch.zeros_like(input_, device=self.mem.device) + forward_current = self.fc(input_, quant) - if not self.spk.shape == input_.shape: - self.spk = torch.zeros_like(input_, device=self.spk.device) + if not self.mem.shape == forward_current.shape: + self.mem = torch.zeros_like(forward_current, device=self.mem.device) - self.reset = self.mem_reset(self.mem) + if not self.spk.shape == forward_current.shape: + self.spk = torch.zeros_like(forward_current, device=self.spk.device) - rec = self.reccurent(self.spk) + self.reset = self.mem_reset(self.mem) + + rec_current = self.reccurent(self.spk, quant) - if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) - input_.data = self.quantizer(input_.data, True) - rec.data = self.quantizer(rec.data, True) + self.mem = self.mem + forward_current + rec_current - # if self.quantization_flag: - # self.mem = QuantizeMembrane.apply(self.mem, self.quantizer) - # input_ = QuantizeMembrane.apply(input_, self.quantizer) - # rec = QuantizeMembrane.apply(rec, self.quantizer) + if self.hardware_estimation_flag: + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) if self.reset_mechanism == "subtract": - self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.threshold + self.mem = self.mem*self.beta-self.reset*self.threshold elif self.reset_mechanism == "zero": - self.mem = (self.mem+input_+rec)*self.beta-self.reset*self.mem + self.mem = self.mem*self.beta-self.reset*self.mem else: self.mem = self.mem*self.beta - if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) - - # if self.quantization_flag: - # self.mem.data = self.quantizer(self.mem.data, True) + if self.quantization_flag: + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) self.spk = self.fire(self.mem) @@ -331,19 +310,14 @@ class RBLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization. We assume you already initialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.threshold.data = self.quantizer(self.threshold.data, unscale) - self.beta.data = self.quantizer(self.beta.data, unscale) + self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer) + self.beta.data = QuantizeSTE.apply(self.beta, self.quantizer) @classmethod diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py index 10ef31c..49b31e5 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/modnef_torch_neuron.py @@ -1,7 +1,7 @@ """ File name: modnef_torch_neuron Author: Aurélie Saulquin -Version: 1.0.0 +Version: 1.0.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch @@ -12,31 +12,7 @@ import torch import torch.nn as nn from modnef.quantizer import * from snntorch._neurons import SpikingNeuron -from snntorch.surrogate import fast_sigmoid - -import torch.nn.functional as F -import brevitas.nn as qnn -import brevitas.quant as bq - -from brevitas.core.quant import QuantType -from brevitas.core.restrict_val import RestrictValueType -from brevitas.core.scaling import ScalingImplType - -class QuantizeSTE(torch.autograd.Function): - @staticmethod - def forward(ctx, weights, quantizer): - - q_weights = quantizer(weights, True) - - #ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward - return q_weights - - @staticmethod - def backward(ctx, grad_output): - # STE : on passe directement le gradient au poids float - #scale_factor, = ctx.saved_tensors - return grad_output, None - +from ..quantLinear import QuantLinear _quantizer = { "FixedPointQuantizer" : FixedPointQuantizer, @@ -82,15 +58,8 @@ class ModNEFNeuron(SpikingNeuron): spike_grad=spike_grad, reset_mechanism=reset_mechanism ) - - # self.fc = qnn.QuantLinear(in_features=in_features, out_features=out_features, bias=False, - # weight_quant_type=QuantType.INT, - # weight_bit_width=8, - # weight_restrict_scaling_type=RestrictValueType.POWER_OF_TWO, - # weight_scaling_impl_type=ScalingImplType.CONST, - # weight_scaling_const=1.0 - # ) - self.fc = nn.Linear(in_features=in_features, out_features=out_features, bias=False) + + self.fc = QuantLinear(in_features=in_features, out_features=out_features) self.hardware_estimation_flag = False @@ -126,7 +95,7 @@ class ModNEFNeuron(SpikingNeuron): else: self.quantizer.init_from_weight(param[0], param[1]) - def quantize_weight(self, unscale : bool = True, ema = False): + def quantize_weight(self): """ synaptic weight quantization @@ -136,21 +105,13 @@ class ModNEFNeuron(SpikingNeuron): """ for p in self.parameters(): - print(p) p.data = QuantizeSTE.apply(p.data, self.quantizer) - - #self.fc.weight = QuantizeSTE.apply(self.fc.weight, self.quantizer) - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization We assume you've already intialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ raise NotImplementedError() @@ -174,6 +135,11 @@ class ModNEFNeuron(SpikingNeuron): def clamp(self, force_init=False): """ Clamp synaptic weight + + Parameters + ---------- + force_init = Fasle : bool + force quantizer initialization """ if force_init: diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py index 7d9c21f..c1915a7 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/rslif.py @@ -1,7 +1,7 @@ """ File name: rslif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer @@ -9,14 +9,13 @@ Descriptions: ModNEF torch reccurent SLIF neuron model Based on snntorch.RLeaky and snntroch.LIF class """ -import torch.nn as nn import torch -from snntorch import LIF import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import ceil, log -from modnef.quantizer import MinMaxQuantizer +from modnef.quantizer import MinMaxQuantizer, QuantizeSTE +from modnef.modnef_torch.quantLinear import QuantLinear class RSLIF(ModNEFNeuron): """ @@ -131,7 +130,7 @@ class RSLIF(ModNEFNeuron): self.register_buffer("v_min", torch.as_tensor(v_min)) self.register_buffer("v_rest", torch.as_tensor(v_rest)) - self.reccurent = nn.Linear(out_features, out_features, bias=False) + self.reccurent = QuantLinear(out_features, out_features) self._init_mem() @@ -238,45 +237,39 @@ class RSLIF(ModNEFNeuron): if not spk == None: self.spk = spk - - input_ = self.fc(input_) + quant = self.quantizer if self.quantization_flag else None - if not self.mem.shape == input_.shape: - self.mem = torch.ones_like(input_)*self.v_rest - - if not self.spk.shape == input_.shape: - self.spk = torch.zeros_like(input_) + forward_current = self.fc(input_, quant) - self.reset = self.mem_reset(self.mem) + if not self.mem.shape == forward_current.shape: + self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest - rec_input = self.reccurent(self.spk) + if not self.spk.shape == forward_current.shape: + self.spk = torch.zeros_like(forward_current, device=self.spk.device) - if self.quantization_flag: - input_.data = self.quantizer(input_.data, True) - rec_input.data = self.quantizer(rec_input.data, True) - self.mem = self.quantizer(self.mem.data, True) - self.mem = self.mem + input_ + rec_input + self.reset = self.mem_reset(self.mem) + + rec_current = self.reccurent(self.spk, quant) + + self.mem = self.mem + forward_current + rec_current if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) + # update neuron self.mem = self.mem - self.v_leak + min_reset = (self.mem<self.v_min).to(torch.float32) + self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) - self.spk = self.fire(self.mem) + spk = self.fire(self.mem) - do_spike_reset = (self.spk/self.graded_spikes_factor - self.reset) - do_min_reset = (self.mem<self.v_min).to(torch.float32) - - self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest) - self.mem = self.mem - do_min_reset*(self.mem-self.v_rest) - - return self.spk, self.mem + return spk, self.mem def get_builder_module(self, module_name : str, output_path : str = "."): """ @@ -324,21 +317,16 @@ class RSLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization We assume you've already intialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.v_leak.data = self.quantizer(self.v_leak.data, unscale) - self.v_min.data = self.quantizer(self.v_min.data, unscale) - self.v_rest.data = self.quantizer(self.v_rest.data, unscale) - self.threshold.data = self.quantizer(self.threshold, unscale) + self.v_leak.data = QuantizeSTE(self.v_leak, self.quantizer) + self.v_min.data = QuantizeSTE(self.v_min, self.quantizer) + self.v_rest.data = QuantizeSTE(self.v_rest, self.quantizer) + self.threshold.data = QuantizeSTE(self.threshold, self.quantizer) @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py index 089519b..43e2e1c 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/slif_model/slif.py @@ -1,7 +1,7 @@ """ File name: slif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer @@ -9,14 +9,12 @@ Descriptions: ModNEF torch SLIF neuron model Based on snntorch.Leaky and snntroch.LIF class """ -import torch.nn as nn import torch -from snntorch import LIF import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import ceil, log -from modnef.quantizer import MinMaxQuantizer +from modnef.quantizer import MinMaxQuantizer, QuantizeSTE class SLIF(ModNEFNeuron): """ @@ -235,36 +233,33 @@ class SLIF(ModNEFNeuron): if not spk == None: self.spk = spk - input_ = self.fc(input_) + quant = self.quantizer if self.quantization_flag else None - if not self.mem.shape == input_.shape: - self.mem = torch.ones_like(input_)*self.v_rest - - if self.quantization_flag: - input_.data = self.quantizer(input_.data, True) - self.mem.data = self.quantizer(self.mem.data, True) + forward_current = self.fc(input_, quant) + if not self.mem.shape == forward_current.shape: + self.mem = torch.ones_like(forward_current, device=self.mem.device)*self.v_rest + + self.reset = self.mem_reset(self.mem) - self.mem = self.mem + input_ + self.mem = self.mem + forward_current + if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) - - self.mem = self.mem-self.v_leak + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) + + # update neuron + self.mem = self.mem - self.v_leak + min_reset = (self.mem<self.v_min).to(torch.float32) + self.mem = self.mem-self.reset*(self.mem-self.v_rest)-min_reset*(self.mem-self.v_rest) if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) spk = self.fire(self.mem) - do_spike_reset = (spk/self.graded_spikes_factor - self.reset) - do_min_reset = (self.mem<self.v_min).to(torch.float32) - - self.mem = self.mem - do_spike_reset*(self.mem-self.v_rest) - self.mem = self.mem - do_min_reset*(self.mem-self.v_rest) - return spk, self.mem def get_builder_module(self, module_name : str, output_path : str = "."): @@ -314,21 +309,16 @@ class SLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization We assume you've already intialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.v_leak.data = self.quantizer(self.v_leak.data, unscale) - self.v_min.data = self.quantizer(self.v_min.data, unscale) - self.v_rest.data = self.quantizer(self.v_rest.data, unscale) - self.threshold.data = self.quantizer(self.threshold, unscale) + self.v_leak.data = QuantizeSTE(self.v_leak, self.quantizer) + self.v_min.data = QuantizeSTE(self.v_min, self.quantizer) + self.v_rest.data = QuantizeSTE(self.v_rest, self.quantizer) + self.threshold.data = QuantizeSTE(self.threshold, self.quantizer) @classmethod def detach_hidden(cls): diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py index f3dda9a..0ea9983 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/rshiftlif.py @@ -1,7 +1,7 @@ """ File name: rsrlif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer @@ -9,14 +9,13 @@ Descriptions: ModNEF torch reccurent Shift LIF neuron model Based on snntorch.RLeaky and snntorch.LIF class """ -import torch.nn as nn import torch -from snntorch import LIF import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer +from modnef.modnef_torch.quantLinear import QuantLinear from math import log, ceil -from modnef.quantizer import DynamicScaleFactorQuantizer +from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE class RShiftLIF(ModNEFNeuron): @@ -131,7 +130,7 @@ class RShiftLIF(ModNEFNeuron): print(f"initial value of beta ({beta}) has been change for {1-2**-self.shift} = 1-2**-{self.shift}") beta = 1-2**-self.shift - self.reccurent = nn.Linear(out_features, out_features, bias=False) + self.reccurent = QuantLinear(out_features, out_features) self.register_buffer("beta", torch.tensor(beta)) @@ -246,25 +245,25 @@ class RShiftLIF(ModNEFNeuron): if not spk == None: self.spk = spk - input_ = self.fc(input_) + quant = self.quantizer if self.quantization_flag else None - if not self.mem.shape == input_.shape: - self.mem = torch.zeros_like(input_, device=self.mem.device) + forward_current = self.fc(input_, quant) - if not self.spk.shape == input_.shape: - self.spk = torch.zeros_like(input_, device=self.spk.device) - - self.reset = self.mem_reset(self.mem) + if not self.mem.shape == forward_current.shape: + self.mem = torch.zeros_like(forward_current, device=self.mem.device) + + if not self.spk.shape == forward_current.shape: + self.spk = torch.zeros_like(forward_current, device=self.spk.device) - rec_input = self.reccurent(self.spk) + rec_current = self.reccurent(self.spk, quant) - if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) - input_.data = self.quantizer(input_.data, True) - rec_input.data = self.quantizer(rec_input.data, True) + self.reset = self.mem_reset(self.mem) + self.mem = self.mem+forward_current+rec_current - self.mem = self.mem+input_+rec_input + if self.hardware_estimation_flag: + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) if self.reset_mechanism == "subtract": self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold @@ -274,11 +273,7 @@ class RShiftLIF(ModNEFNeuron): self.mem = self.mem-self.__shift(self.mem) if self.quantization_flag: - self.mem.data = self.quantizer(self.mem.data, True) - - if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) self.spk = self.fire(self.mem) @@ -330,18 +325,13 @@ class RShiftLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization. We assume you already initialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.threshold.data = self.quantizer(self.threshold.data, unscale) + self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer) @classmethod diff --git a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py index 429d24a..5291733 100644 --- a/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py +++ b/modneflib/modnef/modnef_torch/modnef_neurons/srlif_model/shiftlif.py @@ -1,7 +1,7 @@ """ File name: srlif Author: Aurélie Saulquin -Version: 1.1.0 +Version: 1.2.1 License: GPL-3.0-or-later Contact: aurelie.saulquin@univ-lille.fr Dependencies: torch, snntorch, modnef.archbuilder, modnef_torch_neuron, math, modnef.quantizer @@ -9,32 +9,15 @@ Descriptions: ModNEF torch Shift LIF neuron model Based on snntorch.Leaky and snntroch.LIF class """ -import torch.nn as nn import torch from snntorch.surrogate import fast_sigmoid import modnef.arch_builder as builder from modnef.arch_builder.modules.utilities import * from ..modnef_torch_neuron import ModNEFNeuron, _quantizer from math import log, ceil -from modnef.quantizer import DynamicScaleFactorQuantizer -from snntorch import LIF +from modnef.quantizer import DynamicScaleFactorQuantizer, QuantizeSTE -class QuantizeSTE(torch.autograd.Function): - @staticmethod - def forward(ctx, weights, quantizer): - - q_weights = quantizer(weights, True) - - ctx.save_for_backward(quantizer.scale_factor) # On sauvegarde le scale pour backward - return q_weights - - @staticmethod - def backward(ctx, grad_output): - # STE : on passe directement le gradient au poids float - scale_factor, = ctx.saved_tensors - return grad_output*scale_factor, None - class ShiftLIF(ModNEFNeuron): """ @@ -258,23 +241,20 @@ class ShiftLIF(ModNEFNeuron): if not spk == None: self.spk = spk - input_ = self.fc(input_) - - if not self.mem.shape == input_.shape: - self.mem = torch.zeros_like(input_, device=self.mem.device) + quant = self.quantizer if self.quantization_flag else None - self.reset = self.mem_reset(self.mem) + forward_current = self.fc(input_, quant) - if self.quantization_flag: - self.mem.data = QuantizeSTE.apply(self.mem.data, self.quantizer) - input_.data = QuantizeSTE.apply(input_.data, self.quantizer) + if not self.mem.shape == forward_current.shape: + self.mem = torch.zeros_like(forward_current, device=self.mem.device) + self.reset = self.mem_reset(self.mem) - self.mem = self.mem+input_ + self.mem = self.mem+forward_current if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) + self.val_min = torch.min(self.mem.min(), self.val_min) + self.val_max = torch.max(self.mem.max(), self.val_max) if self.reset_mechanism == "subtract": self.mem = self.mem-self.__shift(self.mem)-self.reset*self.threshold @@ -283,12 +263,8 @@ class ShiftLIF(ModNEFNeuron): else: self.mem = self.mem-self.__shift(self.mem) - # if self.quantization_flag: - # self.mem.data = self.quantizer(self.mem.data, True) - - if self.hardware_estimation_flag: - self.val_min = torch.min(torch.min(input_.min(), self.mem.min()), self.val_min) - self.val_max = torch.max(torch.max(input_.max(), self.mem.max()), self.val_max) + if self.quantization_flag: + self.mem.data = QuantizeSTE.apply(self.mem, self.quantizer) self.spk = self.fire(self.mem) @@ -338,18 +314,13 @@ class ShiftLIF(ModNEFNeuron): ) return module - def quantize_hp(self, unscale : bool = True): + def quantize_hp(self): """ neuron hyper-parameters quantization. We assume you already initialize quantizer - - Parameters - ---------- - unscale : bool = True - set to true if quantization must be simulate """ - self.threshold.data = self.quantizer(self.threshold.data, unscale) + self.threshold.data = QuantizeSTE.apply(self.threshold, self.quantizer) @classmethod diff --git a/modneflib/modnef/modnef_torch/quantLinear.py b/modneflib/modnef/modnef_torch/quantLinear.py index e69de29..57f67b4 100644 --- a/modneflib/modnef/modnef_torch/quantLinear.py +++ b/modneflib/modnef/modnef_torch/quantLinear.py @@ -0,0 +1,59 @@ +""" +File name: quantLinear +Author: Aurélie Saulquin +Version: 0.1.1 +License: GPL-3.0-or-later +Contact: aurelie.saulquin@univ-lille.fr +Dependencies: torch, modnef.quantizer +Descriptions: Quantized Linear torch layer +""" + +import torch.nn as nn +from modnef.quantizer import QuantizeSTE + +class QuantLinear(nn.Linear): + """ + Quantized Linear torch layer + Extended from torch.nn.Linear + + Methods + ------- + forward(x, quantizer=None) + Apply linear forward, if quantizer!=None, quantized weight are used for linear + """ + + def __init__(self, in_features : int, out_features : int): + """ + Initialize class + + Parameters + ---------- + in_features : int + input features of layer + out_features : int + output features of layer + """ + + super().__init__(in_features=in_features, out_features=out_features, bias=False) + + def forward(self, x, quantizer=None): + """ + Apply linear forward, if quantizer!=None, quantized weight are used for linear + + Parameters + ---------- + x : Torch + input spikes + quantizer = None : Quantizer + quantization method. + If None, full precision weight are used for linear + """ + + if quantizer!=None: + w = QuantizeSTE.apply(self.weight, quantizer) + w.data = quantizer.clamp(w) + else: + w = self.weight + + + return nn.functional.linear(x, w) \ No newline at end of file diff --git a/modneflib/modnef/quantizer/__init__.py b/modneflib/modnef/quantizer/__init__.py index 2455aa9..2ef9cb9 100644 --- a/modneflib/modnef/quantizer/__init__.py +++ b/modneflib/modnef/quantizer/__init__.py @@ -10,4 +10,5 @@ Descriptions: ModNEF quantizer method from .quantizer import Quantizer from .fixed_point_quantizer import FixedPointQuantizer from .min_max_quantizer import MinMaxQuantizer -from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer \ No newline at end of file +from .dynamic_scale_quantizer import DynamicScaleFactorQuantizer +from .ste_quantizer import QuantizeSTE \ No newline at end of file diff --git a/modneflib/modnef/quantizer/ste_quantizer.py b/modneflib/modnef/quantizer/ste_quantizer.py new file mode 100644 index 0000000..565cbd8 --- /dev/null +++ b/modneflib/modnef/quantizer/ste_quantizer.py @@ -0,0 +1,62 @@ +""" +File name: ste_quantizer +Author: Aurélie Saulquin +Version: 0.1.0 +License: GPL-3.0-or-later +Contact: aurelie.saulquin@univ-lille.fr +Dependencies: torch +Descriptions: Straight-Throught Estimator quantization method +""" + +import torch + +class QuantizeSTE(torch.autograd.Function): + """ + Straight-Throught Estimator quantization method + + Methods + ------- + @staticmethod + forward(ctx, data, quantizer) + Apply quantization method to data + @staticmethod + backward(ctx, grad_output) + Returns backward gradient + """ + + @staticmethod + def forward(ctx, data, quantizer): + """ + Apply quantization method to data + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFunction + Autograd context used to store variables for the backward pass + data : Tensor + data to quantize + quantizer : Quantizer + quantization method applied to data + """ + + q_data = quantizer(data, True) + + + ctx.scale = quantizer.scale_factor + + return q_data + + @staticmethod + def backward(ctx, grad_output): + """ + Return backward gradient without modificiation + + Parameters + ---------- + ctx : torch.autograd.function.BackwardCFunction + Autograd context used to store variables for the backward pass + grad_output : Tensor + gradient + """ + + return grad_output, None \ No newline at end of file -- GitLab