Skip to content
Snippets Groups Projects
Commit 2891a561 authored by Gaspard Goupy's avatar Gaspard Goupy
Browse files

comments and typo

parent 81cad827
Branches
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ import sys ...@@ -2,7 +2,7 @@ import sys
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from spikenn.snn import Fc from spikenn.snn import Fc
from spikenn.train import S2STDPOptimizer, RSTDPOptimizer, S4NNPOptimizer, AdditiveSTDP, MultiplicativeSTDP, BaseRegularizer, CompetitionRegularizerTwo, CompetitionRegularizerOne from spikenn.train import S2STDPOptimizer, RSTDPOptimizer, S4NNOptimizer, AdditiveSTDP, MultiplicativeSTDP, BaseRegularizer, CompetitionRegularizerTwo, CompetitionRegularizerOne
from spikenn.utils import DecisionMap, Logger, EarlyStopper from spikenn.utils import DecisionMap, Logger, EarlyStopper
from spikenn._impl import spike_sort from spikenn._impl import spike_sort
...@@ -249,7 +249,7 @@ class Readout: ...@@ -249,7 +249,7 @@ class Readout:
config_optim = config["optimizer"] config_optim = config["optimizer"]
# BP-based rule # BP-based rule
if config_optim["method"] == "s4nn": if config_optim["method"] == "s4nn":
optim = S4NNPOptimizer( optim = S4NNOptimizer(
network=network, network=network,
t_gap=config_optim["t_gap"], t_gap=config_optim["t_gap"],
class_inhib=config_optim.get('class_inhib', False), # intra-class WTA, use when multiple neurons per class (e.g. with NCGs), class_inhib=config_optim.get('class_inhib', False), # intra-class WTA, use when multiple neurons per class (e.g. with NCGs),
......
...@@ -126,7 +126,6 @@ def class_inhibition(spks, pots, decision_map, max_time): ...@@ -126,7 +126,6 @@ def class_inhibition(spks, pots, decision_map, max_time):
# SSTDP and S2-STDP weight update # SSTDP and S2-STDP weight update
# NOTE: Do not handle well dropout on the output neurons (contrary to R-STDP)
@njit @njit
def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, ap, am, anti_ap, anti_am, stdp_func, stdp_args): def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, ap, am, anti_ap, anti_am, stdp_func, stdp_args):
n_layers = len(outputs) n_layers = len(outputs)
...@@ -205,8 +204,7 @@ def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_ti ...@@ -205,8 +204,7 @@ def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_ti
# S4NN backward pass # S4NN backward pass
# Almost similar to sstdp function, but in another function for the sake of clarity # Almost similar to sstdp code, but in another function for the sake of clarity
# NOTE: Do not handle well dropout on the output neurons (contrary to R-STDP)
@njit @njit
def s4nn(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, lr): def s4nn(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, lr):
n_layers = len(outputs) n_layers = len(outputs)
......
...@@ -116,9 +116,9 @@ class Fc(SpikingLayer): ...@@ -116,9 +116,9 @@ class Fc(SpikingLayer):
self.input_size = input_size self.input_size = input_size
self.n_neurons = n_neurons self.n_neurons = n_neurons
# NOTE: Implementation is in a function augmented with Numba to speed up computations on CPU # NOTE: The implementation is in a function optimized with Numba to accelerate CPU computations.
def __call__(self, sample): def __call__(self, sample):
# Convert sample to SpikingDataset format (needed for multi-layer networks only) # Convert dense sample to SpikingDataset format (needed for multi-layer networks only)
sample = self.convert_input(sample) sample = self.convert_input(sample)
# Select the employed thresholds # Select the employed thresholds
thresholds = self.thresholds_train if self.train_mode else self.thresholds thresholds = self.thresholds_train if self.train_mode else self.thresholds
......
...@@ -40,9 +40,8 @@ class S2STDPOptimizer(STDPOptimizer): ...@@ -40,9 +40,8 @@ class S2STDPOptimizer(STDPOptimizer):
# NOTE: # NOTE:
# Implementation is in a function augmented with Numba to speed up computations on CPU # The implementation is in a function optimized with Numba to accelerate CPU computations.
# # Hence, code is not very clean, should be reworked...
# code is not very clean, should be reworked.
# However, Numba is not easily implementable in class methods. # However, Numba is not easily implementable in class methods.
def __call__(self, outputs, target_ind, decision_map): def __call__(self, outputs, target_ind, decision_map):
n_layers = len(self.network) n_layers = len(self.network)
...@@ -138,7 +137,7 @@ class RSTDPOptimizer(STDPOptimizer): ...@@ -138,7 +137,7 @@ class RSTDPOptimizer(STDPOptimizer):
# S4NN optimizer (BP-based) # S4NN optimizer (BP-based)
# NOTE: Can only train the output layer of the network # NOTE: Can only train the output layer of the network
class S4NNPOptimizer: class S4NNOptimizer:
__slots__ = ('network', 'lr', 't_gap', 'class_inhib', 'use_time_ranges', 'annealing', 'max_time') __slots__ = ('network', 'lr', 't_gap', 'class_inhib', 'use_time_ranges', 'annealing', 'max_time')
...@@ -157,9 +156,8 @@ class S4NNPOptimizer: ...@@ -157,9 +156,8 @@ class S4NNPOptimizer:
# NOTE: # NOTE:
# Implementation is in a function augmented with Numba to speed up computations on CPU # The implementation is in a function optimized with Numba to accelerate CPU computations.
# # Hence, code is not very clean, should be reworked...
# code is not very clean, should be reworked.
# However, Numba is not easily implementable in class methods. # However, Numba is not easily implementable in class methods.
def __call__(self, outputs, target_ind, decision_map): def __call__(self, outputs, target_ind, decision_map):
n_layers = len(self.network) n_layers = len(self.network)
...@@ -296,7 +294,6 @@ class CompetitionRegularizerOne(BaseRegularizer): ...@@ -296,7 +294,6 @@ class CompetitionRegularizerOne(BaseRegularizer):
# STDP interface to store parameters and callable to core function # STDP interface to store parameters and callable to core function
# TODO: Implementation not very clean, to rework
class MultiplicativeSTDP: class MultiplicativeSTDP:
__slots__ = ('beta', 'w_min', 'w_max') __slots__ = ('beta', 'w_min', 'w_max')
...@@ -312,7 +309,6 @@ class MultiplicativeSTDP: ...@@ -312,7 +309,6 @@ class MultiplicativeSTDP:
# STDP interface to store parameters and callable to core function # STDP interface to store parameters and callable to core function
# TODO: Implementation not very clean, to rework
class AdditiveSTDP: class AdditiveSTDP:
__slots__ = ('max_time') __slots__ = ('max_time')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment