Commit 2891a561 authored by Gaspard Goupy

comments and typo

parent 81cad827
@@ -2,7 +2,7 @@ import sys
import numpy as np
from tqdm import tqdm
from spikenn.snn import Fc
from spikenn.train import S2STDPOptimizer, RSTDPOptimizer, S4NNPOptimizer, AdditiveSTDP, MultiplicativeSTDP, BaseRegularizer, CompetitionRegularizerTwo, CompetitionRegularizerOne
from spikenn.train import S2STDPOptimizer, RSTDPOptimizer, S4NNOptimizer, AdditiveSTDP, MultiplicativeSTDP, BaseRegularizer, CompetitionRegularizerTwo, CompetitionRegularizerOne
from spikenn.utils import DecisionMap, Logger, EarlyStopper
from spikenn._impl import spike_sort
@@ -249,7 +249,7 @@ class Readout:
config_optim = config["optimizer"]
# BP-based rule
if config_optim["method"] == "s4nn":
optim = S4NNPOptimizer(
optim = S4NNOptimizer(
network=network,
t_gap=config_optim["t_gap"],
class_inhib=config_optim.get('class_inhib', False), # intra-class WTA, use when multiple neurons per class (e.g. with NCGs),
......
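The hunk above selects the optimizer from a configuration dictionary. Based only on the keys visible here ("method", "t_gap", "class_inhib"), a minimal optimizer section might look like the sketch below; the values and any further keys are illustrative assumptions, not taken from the repository.

# Hypothetical optimizer config; only the keys "method", "t_gap" and
# "class_inhib" appear in the diff above -- the values are placeholders.
config = {
    "optimizer": {
        "method": "s4nn",      # routes to the BP-based S4NNOptimizer branch
        "t_gap": 100,          # time margin used to build target firing times (illustrative value)
        "class_inhib": False,  # intra-class WTA, useful with several neurons per class (e.g. NCGs)
    }
}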
@@ -126,7 +126,6 @@ def class_inhibition(spks, pots, decision_map, max_time):
# SSTDP and S2-STDP weight update
# NOTE: Does not handle dropout on the output neurons well (contrary to R-STDP)
@njit
def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, ap, am, anti_ap, anti_am, stdp_func, stdp_args):
n_layers = len(outputs)
@@ -205,8 +204,7 @@ def s2stdp(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_ti
# S4NN backward pass
# Almost similar to sstdp function, but in another function for the sake of clarity
# NOTE: Does not handle dropout on the output neurons well (contrary to R-STDP)
# Almost similar to sstdp code, but in another function for the sake of clarity
@njit
def s4nn(outputs, network_weights, y, decision_map, t_gap, class_inhib, use_time_ranges, max_time, lr):
n_layers = len(outputs)
......
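Both s2stdp and s4nn receive a t_gap argument together with the output spike times and the decision map. As a rough, hedged illustration of how such a margin is commonly turned into per-neuron target firing times in SSTDP-style rules (the exact computation in spikenn._impl may differ), the targets can be spread around the mean output time so that the correct class is pushed exactly t_gap earlier than the other neurons while the layer's mean firing time stays unchanged:

import numpy as np

# Illustrative only -- not necessarily spikenn's exact target computation.
def temporal_targets(spike_times, target_ind, t_gap):
    n = spike_times.shape[0]
    t_mean = spike_times.mean()
    targets = np.full(n, t_mean + t_gap / n)            # non-target neurons: desired to fire a bit later
    targets[target_ind] = t_mean - t_gap * (n - 1) / n  # target neuron: desired to fire t_gap earlier
    return targets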
@@ -116,9 +116,9 @@ class Fc(SpikingLayer):
self.input_size = input_size
self.n_neurons = n_neurons
# NOTE: Implementation is in a function augmented with Numba to speed up computations on CPU
# NOTE: The implementation is in a function optimized with Numba to accelerate CPU computations.
def __call__(self, sample):
# Convert sample to SpikingDataset format (needed for multi-layer networks only)
# Convert dense sample to SpikingDataset format (needed for multi-layer networks only)
sample = self.convert_input(sample)
# Select the employed thresholds
thresholds = self.thresholds_train if self.train_mode else self.thresholds
......
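The updated comment notes that Fc.__call__ first converts a dense sample into the SpikingDataset format. As a hedged sketch of what such a conversion typically does in time-to-first-spike networks (the actual convert_input in spikenn may use a different encoding), a latency code maps larger input values to earlier spike times:

import numpy as np

# Hypothetical latency-coding conversion; dense_to_spike_times is not part of
# spikenn, it only illustrates the idea behind the comment above.
def dense_to_spike_times(sample, max_time):
    sample = np.asarray(sample, dtype=np.float64)
    norm = sample / max(sample.max(), 1e-12)   # scale intensities to [0, 1]
    times = (1.0 - norm) * max_time            # high intensity -> early spike
    times[sample <= 0] = np.inf                # silent inputs never spike
    return times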
@@ -40,9 +40,8 @@ class S2STDPOptimizer(STDPOptimizer):
# NOTE:
# Implementation is in a function augmented with Numba to speed up computations on CPU
#
# code is not very clean, should be reworked.
# The implementation is in a function optimized with Numba to accelerate CPU computations.
# Hence, code is not very clean, should be reworked...
# However, Numba is not easily implementable in class methods.
def __call__(self, outputs, target_ind, decision_map):
n_layers = len(self.network)
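The NOTE above captures a recurring pattern in this commit: because Numba's nopython mode cannot easily compile ordinary class methods, each optimizer keeps its state in a thin Python class whose __call__ merely unpacks attributes and forwards them to a module-level @njit function. A minimal, self-contained sketch of that pattern (the names and the placeholder update are illustrative, not spikenn's API):

from numba import njit
import numpy as np

@njit
def _apply_update(weights, errors, lr):
    # Compiled kernel: plain loops over NumPy arrays, no Python objects.
    for i in range(weights.shape[0]):
        for j in range(weights.shape[1]):
            weights[i, j] -= lr * errors[i]    # placeholder rule, not S2-STDP
    return weights

class SketchOptimizer:
    __slots__ = ('weights', 'lr')

    def __init__(self, weights, lr):
        self.weights = weights
        self.lr = lr

    def __call__(self, errors):
        # The class only stores parameters; the numerical work happens in the njit kernel.
        self.weights = _apply_update(self.weights, errors, self.lr)

This split is also why the NOTE apologizes for the code not being very clean: all per-layer state has to be passed explicitly into the compiled free function.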
@@ -138,7 +137,7 @@ class RSTDPOptimizer(STDPOptimizer):
# S4NN optimizer (BP-based)
# NOTE: Can only train the output layer of the network
class S4NNPOptimizer:
class S4NNOptimizer:
__slots__ = ('network', 'lr', 't_gap', 'class_inhib', 'use_time_ranges', 'annealing', 'max_time')
@@ -157,9 +156,8 @@ class S4NNPOptimizer:
# NOTE:
# Implementation is in a function augmented with Numba to speed up computations on CPU
#
# code is not very clean, should be reworked.
# The implementation is in a function optimized with Numba to accelerate CPU computations.
# Hence, code is not very clean, should be reworked...
# However, Numba is not easily implementable in class methods.
def __call__(self, outputs, target_ind, decision_map):
n_layers = len(self.network)
@@ -296,7 +294,6 @@ class CompetitionRegularizerOne(BaseRegularizer):
# STDP interface to store parameters and a callable to the core function
# TODO: Implementation not very clean, to rework
class MultiplicativeSTDP:
__slots__ = ('beta', 'w_min', 'w_max')
@@ -312,7 +309,6 @@ class MultiplicativeSTDP:
# STDP interface to store parameters and a callable to the core function
# TODO: Implementation not very clean, to rework
class AdditiveSTDP:
__slots__ = ('max_time')
......
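The two small classes at the end of the diff, MultiplicativeSTDP (slots beta, w_min, w_max) and AdditiveSTDP (slot max_time), wrap the two classic weight-dependence flavours of STDP. As a generic, hedged illustration of the distinction (spikenn's actual update rules also involve spike timing and may differ), an additive rule applies a weight-independent step and relies on hard clipping, while a multiplicative (soft-bound) rule scales the step by the distance to the bounds:

import numpy as np

# Textbook-style illustration; not spikenn's actual STDP kernels.
def additive_step(w, dw, w_min=0.0, w_max=1.0):
    return np.clip(w + dw, w_min, w_max)                 # fixed step, hard bounds

def multiplicative_step(w, dw, beta=1.0, w_min=0.0, w_max=1.0):
    # Potentiation shrinks as w approaches w_max, depression as w approaches w_min.
    scale = np.where(dw >= 0.0, (w_max - w) ** beta, (w - w_min) ** beta)
    return w + dw * scale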