Skip to content
Snippets Groups Projects
Commit 9c1608dc authored by Hammouda Elbez's avatar Hammouda Elbez :computer:
Browse files

Multiple classes updated

parent acb7e5fc
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ int main(int argc, char** argv) {
experiment.template add_test<dataset::Mnist>(input_path+"t10k-images.idx3-ubyte", input_path+"t10k-labels.idx1-ubyte");
float th_lr = 1.0f;
float t_obj = 0.50f;
float t_obj = 0.75f;
float w_lr = 0.1f;
float alpha = 0.05f;
......
......@@ -34,7 +34,7 @@ int main(int argc, char** argv) {
experiment.template add_test<dataset::Mnist>(input_path+"t10k-images.idx3-ubyte", input_path+"t10k-labels.idx1-ubyte");
float th_lr = 1.0f;
float t_obj = 0.50f;
float t_obj = 0.75f;
float alpha = 0.05f;
float alpha_p= 0.01f;
float alpha_n= 0.005f;
......
......@@ -36,7 +36,7 @@ AbstractExperiment::AbstractExperiment(const std::string& name) :
std::cout << "Experiment renamed in " << _name << std::endl;
}
_isRandom = "";
_isRandom = "MY_5IM";
std::seed_seq seed(std::begin(_name), std::end(_name));
_random_generator.seed(seed);
......
......@@ -257,6 +257,7 @@ void _priv::DenseImpl::train(const std::vector<Spike>& input_spike, const Tensor
_a.at(0, 0, z) += w.at(spike.x, spike.y, spike.z, z);
// check if the voltage crossed the threshold
if(_a.at(0, 0, z) >= th.at(z)) {
_model.layer_Spikes+=1;
for(size_t z1=0; z1<depth; z1++) {
th.at(z1) -= _model._lr_th*(spike.time - _model._t_obj);
......@@ -321,4 +322,4 @@ void _priv::DenseImpl::test(const std::vector<Spike>& input_spike, const Tensor<
}
}
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment