diff --git a/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala b/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala index 5cca257b38e11c6bcdaf237643547f097349372d..94937a07d09d3395dbb81ac97c2994b79062f537 100644 --- a/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala +++ b/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala @@ -18,6 +18,7 @@ class ConnectionIndex(inputLayer : NeuronGroupRef, outputLayer : NeuronGroupRef) } }.toMap + def getConnectionsBetween(input : NetworkEntityPath, output : NetworkEntityPath) : Seq[ConnectionPath] = { index.getOrElse(output, Map()).getOrElse(input, Seq()) } diff --git a/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala b/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala index 03cc19dc39954e87a7a5c76f916bb02805721584..9e3ee29a127acd5cd3ccf31e3ed47cbd9df43f51 100644 --- a/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala +++ b/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala @@ -101,18 +101,20 @@ object ExampleMnist2 extends App { val unsupervisedLayer1 = n2s3.createNeuronGroup() .setIdentifier("Layer1") - .setNumberOfNeurons(5) + .setNumberOfNeurons(20) .setNeuronModel(LIF, Seq( - (MembranePotentialThreshold, 35 millivolts))) + (MembranePotentialThreshold, 15 millivolts))) val unsupervisedLayer2 = n2s3.createNeuronGroup() .setIdentifier("Layer2") - .setNumberOfNeurons(5) + .setNumberOfNeurons(10) .setNeuronModel(LIF, Seq( (MembranePotentialThreshold, 5 millivolts))) inputLayer.connectTo(unsupervisedLayer1, new FullConnection(() => new SimplifiedSTDP())) var Layer1WTAconnection = unsupervisedLayer1.connectTo(unsupervisedLayer1, new FullConnection(() => new InhibitorySynapse())) + unsupervisedLayer1.connectTo(unsupervisedLayer2, new FullConnection(() => new SimplifiedSTDP())) + var Layer2WTAconnection = unsupervisedLayer2.connectTo(unsupervisedLayer2, new FullConnection(() => new InhibitorySynapse())) n2s3.create() @@ -161,7 +163,6 @@ object ExampleMnist2 extends App { simTime = "Train L1" n2s3.runAndWait() - println(System.currentTimeMillis() - globalTime) simTime = simTime + " | " + (System.currentTimeMillis() - globalTime) + "\n" if (log) { @@ -172,8 +173,6 @@ object ExampleMnist2 extends App { unsupervisedLayer1.fixNeurons() Layer1WTAconnection.disconnect() - unsupervisedLayer1.connectTo(unsupervisedLayer2, new FullConnection(() => new SimplifiedSTDP())) - var Layer2WTAconnection = unsupervisedLayer2.connectTo(unsupervisedLayer2, new FullConnection(() => new InhibitorySynapse())) n2s3.first = false stream = InputMnist.DataParts(dataFile, labelFile, sizeChunk, sizeChunk) inputStream.append(stream) @@ -181,7 +180,7 @@ object ExampleMnist2 extends App { println("Start Training L2") var layerTolayerIndex = new ConnectionIndex(unsupervisedLayer1, unsupervisedLayer2) - print(layerTolayerIndex) + var listOfConnexions2 = for (outputIndex <- 0 until unsupervisedLayer2.shape.getNumberOfPoints) yield { for (_ <- 0 until 1) yield { for {