Commit 5f5d1343 authored by Aurélie Saulquin

add quantizer modifications

parent 7c7480bf
1 merge request: !3 Dev
@@ -10,6 +10,8 @@ Descriptions: ModNEF torch neuron interface builder

 import torch
 from modnef.quantizer import *
+from snntorch._neurons import SpikingNeuron
+from snntorch.surrogate import fast_sigmoid

 _quantizer = {
   "FixedPointQuantizer" : FixedPointQuantizer,
@@ -17,7 +19,7 @@ _quantizer = {
   "DynamicScaleFactorQuantizer" : DynamicScaleFactorQuantizer
 }

-class ModNEFNeuron():
+class ModNEFNeuron(SpikingNeuron):
   """
   ModNEF torch neuron interface
@@ -42,7 +44,27 @@ class ModNEFNeuron():
     create and return the corresponding modnef archbuilder module from internal neuron parameters
   """

-  def __init__(self, quantizer : Quantizer):
+  def __init__(self,
+               threshold,
+               reset_mechanism,
+               quantizer : Quantizer,
+               spike_grad=fast_sigmoid(slope=25)):
+
+    SpikingNeuron.__init__(
+      self=self,
+      threshold=threshold,
+      reset_mechanism=reset_mechanism,
+      spike_grad=spike_grad,
+      surrogate_disable=False,
+      init_hidden=False,
+      inhibition=False,
+      learn_threshold=False,
+      state_quant=False,
+      output=False,
+      graded_spikes_factor=1.0,
+      learn_graded_spikes_factor=False
+    )
     self.hardware_estimation_flag = False
     self.quantization_flag = False
@@ -61,7 +83,20 @@ class ModNEFNeuron():

     raise NotImplementedError()

-  def quantize_weight(self):
+  def init_quantizer(self):
+
+    params = list(self.parameters())
+
+    w1 = params[0].data
+
+    if len(params)==2:
+      w2 = params[1].data
+    else:
+      w2 = torch.zeros((1))
+
+    self.quantizer.init_quantizer(w1, w2)
+
+  def quantize_weight(self, unscaled : bool = False):
     """
     synaptic weight quantization
@@ -70,9 +105,10 @@ class ModNEFNeuron():
       NotImplementedError()
     """

-    raise NotImplementedError()
+    for param in self.parameters():
+      param.data = self.quantizer(param.data, unscale=unscaled)

-  def quantize_parameters(self):
+  def quantize_hp(self):
     """
     neuron hyper-parameters quantization
     """
......
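Taken together, the changes above move the quantization plumbing into the shared ModNEFNeuron base class: init_quantizer() collects the layer's weight tensors and hands them to the quantizer, and quantize_weight() quantizes every parameter in place, optionally de-quantizing back to the original scale. A minimal usage sketch of the new interface follows; the layer shapes and the import path for ShiftLIF are illustrative assumptions, not code from this merge request.

import torch
from modnef.quantizer import DynamicScaleFactorQuantizer
from modnef.modnef_torch import ShiftLIF  # import path is an assumption

# Hypothetical layer using the ModNEFNeuron quantization interface added here.
layer = ShiftLIF(in_features=64, out_features=10, beta=0.9,
                 quantizer=DynamicScaleFactorQuantizer(8))

# Derive quantizer parameters (e.g. scale factors) from the current weights.
layer.init_quantizer()

# Quantize all parameters in place; unscaled=True quantizes then de-quantizes,
# so the weights keep their floating-point range but carry the quantization error.
layer.quantize_weight(unscaled=True)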
@@ -11,14 +11,14 @@ Based on snntorch.Leaky and snntorch.LIF class

 import torch.nn as nn
 import torch
-from snntorch import LIF
+from snntorch.surrogate import fast_sigmoid
 import modnef.arch_builder as builder
 from modnef.arch_builder.modules.utilities import *
 from ..modnef_torch_neuron import ModNEFNeuron, _quantizer
 from math import log, ceil
 from modnef.quantizer import DynamicScaleFactorQuantizer

-class ShiftLIF(LIF, ModNEFNeuron):
+class ShiftLIF(ModNEFNeuron):
   """
   ModNEFTorch Shift LIF neuron model
@@ -86,7 +86,7 @@ class ShiftLIF(LIF, ModNEFNeuron):
                out_features,
                beta,
                threshold=1.0,
-               spike_grad=None,
+               spike_grad=fast_sigmoid(slope=25),
                reset_mechanism="subtract",
                quantizer=DynamicScaleFactorQuantizer(8)
              ):
@@ -117,25 +117,14 @@ class ShiftLIF(LIF, ModNEFNeuron):
       print(f"initial value of beta ({beta}) has been changed to {1-2**-self.shift} = 1-2**-{self.shift}")
       beta = 1-2**-self.shift

-    LIF.__init__(
-      self=self,
-      beta=beta,
-      threshold=threshold,
-      spike_grad=spike_grad,
-      surrogate_disable=False,
-      init_hidden=False,
-      inhibition=False,
-      learn_beta=False,
-      learn_threshold=False,
-      reset_mechanism=reset_mechanism,
-      state_quant=False,
-      output=False,
-      graded_spikes_factor=1.0,
-      learn_graded_spikes_factor=False,
-    )
-
-    ModNEFNeuron.__init__(self=self, quantizer=quantizer)
+    ModNEFNeuron.__init__(self=self,
+                          threshold=threshold,
+                          reset_mechanism=reset_mechanism,
+                          spike_grad=spike_grad,
+                          quantizer=quantizer
+                          )

     self.fc = nn.Linear(in_features, out_features, bias=False)
     self._init_mem()
......
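For context, ShiftLIF replaces the multiplicative leak with a right shift, so the constructor snaps beta to a value of the form 1 - 2**-shift and prints the notice quoted in the hunk above when it adjusts the value. A rough sketch of that mapping, assuming the shift is simply the rounded value of -log2(1 - beta) (the actual computation in the file is not shown in this diff and may use ceil instead):

from math import log

def nearest_shift(beta: float) -> int:
    # Assumed mapping: pick the shift whose effective leak 1 - 2**-shift
    # is closest to the requested beta.
    return round(-log(1.0 - beta, 2))

for beta in (0.5, 0.875, 0.9, 0.95):
    shift = nearest_shift(beta)
    print(f"beta={beta} -> shift={shift}, effective beta={1 - 2**-shift}")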
@@ -96,7 +96,7 @@ class DynamicScaleFactorQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )

-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
......
@@ -109,7 +109,7 @@ class FixedPointQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )

-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
......
@@ -105,7 +105,7 @@ class MinMaxQuantizer(Quantizer):
       is_initialize=config["is_initialize"]
     )

-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
......
@@ -77,7 +77,7 @@ class Quantizer():

     raise NotImplementedError()

-  def init_from_weight(self, weight, rec_weight=torch.zeros((1))):
+  def init_quantizer(self, weight, rec_weight=torch.zeros((1))):
     """
     initialize quantizer parameters from synaptic weight
@@ -112,11 +112,11 @@ class Quantizer():
     else:
       tdata = data

-    if unscale:
-      qdata = self._unquant(self._quant(tdata))
-    else:
-      qdata = self._quant(tdata)
+    qdata = self._quant(tdata)
+
+    if unscale:
+      qdata = self._unquant(qdata)

     if isinstance(data, (int, float)):
       return qdata.item()
     elif isinstance(data, list):
......
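The last hunk reorders Quantizer.__call__ so the tensor is quantized once and only optionally de-quantized, instead of quantizing twice on the unscale path; the result is the same, the control flow is just simpler. A self-contained sketch of the intended behaviour, with stand-in _quant/_unquant functions assuming a plain fixed-point scheme (the real methods live in the quantizer subclasses):

import torch

SCALE = 2**6  # assumed fixed-point scale factor for an 8-bit signed range

def _quant(t: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(t * SCALE), -128, 127)

def _unquant(q: torch.Tensor) -> torch.Tensor:
    return q / SCALE

def quantize(data: torch.Tensor, unscale: bool = False) -> torch.Tensor:
    # Mirrors the reordered control flow: quantize once, then optionally
    # map the integer levels back to the original scale.
    qdata = _quant(data)
    if unscale:
        qdata = _unquant(qdata)
    return qdata

weights = torch.tensor([0.40, -0.73, 0.05])
print(quantize(weights))                 # integer levels, e.g. for hardware export
print(quantize(weights, unscale=True))   # same values rounded to the fixed-point grid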