From e64b1aea7a48b8c837ed2594dec91ff2e77b9548 Mon Sep 17 00:00:00 2001
From: Hammouda Elbez <hammouda.elbez@univ-lille.fr>
Date: Mon, 13 Mar 2023 11:32:35 +0100
Subject: [PATCH] 2DView: added reset when moving backward

---
 src/Modules/General/callbacks.py | 25 ++++++++++++++++---------
 src/Modules/General/layout.py    | 24 ++++++++++++------------
 2 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py
index 2b26e3e..35206d7 100755
--- a/src/Modules/General/callbacks.py
+++ b/src/Modules/General/callbacks.py
@@ -546,10 +546,10 @@ class callbacks(callbacksOp):
 
                 # Callback to handle the 2D view spiking visualization
                 @app.callback(
-                    Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"),
+                    Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),
                     Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
-                    State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data"))
-                def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, StoredData):
+                    State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
+                def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
                     """ Function called each step to update the 2D view
 
                     Args:
@@ -558,16 +558,19 @@ class callbacks(callbacksOp):
                         visUpdateInterval : interval instance that will cause this function to be called each step
                         mouseOverNodeData : contains data of the hovered node
                         elements : nodes description
-                        StoredData : data stored for shared access
                         heatmapData : heatmap data
+                        Layer2DViewFilter : selected layers
 
 
                     Returns:
                         if information tab should be opened or closed
                     """
                     try:
-                        elements = StoredData[0]
+                        elements = super.Spikes2D
                         if dash.callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]:
+                            for element in elements[1:]:
+                                element["data"]["spiked"] = 0
+                                element["data"]["spikes"] = 0
                             spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True)
                             if spikes:
                                 maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes]) 
@@ -579,14 +582,18 @@ class callbacks(callbacksOp):
                                             if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
                                                 element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2)
                                                 element["data"]["spikes"] = list(list(spike.values())[0].values())[0]
-                                                StoredData[1][i] += list(list(spike.values())[0].values())[0]
+                                                super.AccumulatedSpikes2D[i] += list(list(spike.values())[0].values())[0]
                                             i+=1
-                            return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData]
+                            matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
+                            indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
+                            return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
                         else:
                             try:
-                                return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData]
+                                matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
+                                indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
+                                return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
                             except Exception as e:
-                                return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData]
+                                return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
                     except Exception as e:
                         print("animation2DViewController:" + str(e))
 
diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py
index 9f6d242..fc29291 100755
--- a/src/Modules/General/layout.py
+++ b/src/Modules/General/layout.py
@@ -28,11 +28,12 @@ class layout(layoutOp):
     MaxSpike = dict()
     MaxPotential = dict()
     MaxSynapse = dict()
+    AccumulatedSpikes2D = dict()
+    Spikes2D = dict()
     # LabelPie Data --------------------------------------------------
     Label = [[], []]
     Max = 0
     # ----------------------------------------------------------------
-    Nodes = []
     tabs = []
     label = " "
     visStopped = True
@@ -61,8 +62,9 @@ class layout(layoutOp):
     def toMatrix(self, l,n):
         """ 1D array to 2D
         """
-        return [l[i:i+n] for i in range(0, len(l), n)]
-
+        Matrix = [l[i:i+n] for i in range(0, len(l), n)]
+        return Matrix
+    
     def clearData(self):
         """ Clear the data when moved forward or backward for more than one step
         """
@@ -73,11 +75,12 @@ class layout(layoutOp):
         self.PotentialGraphY.clear()
         self.xAxisLabel.clear()
         self.Label.clear()
-        self.MaxPotential = dict()
-        self.MaxSpike = dict()
-        self.MaxSynapse = dict()
+        self.MaxPotential.clear()
+        self.MaxSpike.clear()
+        self.MaxSynapse.clear()
+        self.Spikes2D = self.generate2DView(self.g)
+        self.AccumulatedSpikes2D = [0 for n in self.Spikes2D if n["data"]["spiked"] != -1]
         self.Max = 0
-        self.Nodes = []
 
     def Vis(self):
         """ Create layer components
@@ -101,9 +104,6 @@ class layout(layoutOp):
                 "Accuracy", style={"width": "25%", "fontWeight": "500"}), html.Td(str(self.g.Accuracy)+" %", style={"width": "25%"})])
         ]
 
-        # Generate 2D View -------------------------------------------
-        self.Nodes = self.generate2DView(self.g)
-
         # Tabs content -----------------------------------------------
         info_vis = dbc.Card(
             dbc.CardBody([dbc.Card([dbc.CardHeader(
@@ -189,7 +189,7 @@ class layout(layoutOp):
                                 ,
                                 html.Div(
                                     [
-                                        dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]),
+                                        #dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]),
                                         cyto.Cytoscape(
                                             id='cytoscape-compound',
                                             layout={'name': 'grid','animate': False},
@@ -255,7 +255,7 @@ class layout(layoutOp):
                                                     }
                                                 }
                                             ],
-                                            elements=self.Nodes
+                                            elements=self.Spikes2D
                                         ),
                                         html.P(id="spikes_info", style={"padding": "8px"})
 
-- 
GitLab