diff --git a/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala b/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala
index 5cca257b38e11c6bcdaf237643547f097349372d..94937a07d09d3395dbb81ac97c2994b79062f537 100644
--- a/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala
+++ b/n2s3/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/core/ConnectionIndex.scala
@@ -18,6 +18,7 @@ class ConnectionIndex(inputLayer : NeuronGroupRef, outputLayer : NeuronGroupRef)
}
}.toMap
+
def getConnectionsBetween(input : NetworkEntityPath, output : NetworkEntityPath) : Seq[ConnectionPath] = {
index.getOrElse(output, Map()).getOrElse(input, Seq())
}
diff --git a/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala b/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala
index 03cc19dc39954e87a7a5c76f916bb02805721584..9e3ee29a127acd5cd3ccf31e3ed47cbd9df43f51 100644
--- a/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala
+++ b/n2s3_examples/src/main/scala/fr/univ_lille/cristal/emeraude/n2s3/apps/ExampleMnist2.scala
@@ -101,18 +101,20 @@ object ExampleMnist2 extends App {
val unsupervisedLayer1 = n2s3.createNeuronGroup()
.setIdentifier("Layer1")
- .setNumberOfNeurons(5)
+ .setNumberOfNeurons(20)
.setNeuronModel(LIF, Seq(
- (MembranePotentialThreshold, 35 millivolts)))
+ (MembranePotentialThreshold, 15 millivolts)))
val unsupervisedLayer2 = n2s3.createNeuronGroup()
.setIdentifier("Layer2")
- .setNumberOfNeurons(5)
+ .setNumberOfNeurons(10)
.setNeuronModel(LIF, Seq(
(MembranePotentialThreshold, 5 millivolts)))
inputLayer.connectTo(unsupervisedLayer1, new FullConnection(() => new SimplifiedSTDP()))
var Layer1WTAconnection = unsupervisedLayer1.connectTo(unsupervisedLayer1, new FullConnection(() => new InhibitorySynapse()))
+ unsupervisedLayer1.connectTo(unsupervisedLayer2, new FullConnection(() => new SimplifiedSTDP()))
+ var Layer2WTAconnection = unsupervisedLayer2.connectTo(unsupervisedLayer2, new FullConnection(() => new InhibitorySynapse()))
n2s3.create()
@@ -161,7 +163,6 @@ object ExampleMnist2 extends App {
simTime = "Train L1"
n2s3.runAndWait()
- println(System.currentTimeMillis() - globalTime)
simTime = simTime + " | " + (System.currentTimeMillis() - globalTime) + "\n"
if (log) {
@@ -172,8 +173,6 @@ object ExampleMnist2 extends App {
unsupervisedLayer1.fixNeurons()
Layer1WTAconnection.disconnect()
- unsupervisedLayer1.connectTo(unsupervisedLayer2, new FullConnection(() => new SimplifiedSTDP()))
- var Layer2WTAconnection = unsupervisedLayer2.connectTo(unsupervisedLayer2, new FullConnection(() => new InhibitorySynapse()))
n2s3.first = false
stream = InputMnist.DataParts(dataFile, labelFile, sizeChunk, sizeChunk)
inputStream.append(stream)
@@ -181,7 +180,7 @@ object ExampleMnist2 extends App {
println("Start Training L2")
var layerTolayerIndex = new ConnectionIndex(unsupervisedLayer1, unsupervisedLayer2)
- print(layerTolayerIndex)
+
var listOfConnexions2 = for (outputIndex <- 0 until unsupervisedLayer2.shape.getNumberOfPoints) yield {
for (_ <- 0 until 1) yield {
for {