From c1b4b37f9f9c22ce624bec95d610ef780712e0c7 Mon Sep 17 00:00:00 2001
From: Hammouda Elbez <hammouda.elbez@univ-lille.fr>
Date: Fri, 26 May 2023 16:51:32 +0200
Subject: [PATCH] 2D view: adding ability to filter based on layer

---
 src/Modules/General/callbacks.py | 29 +++++++++++++----------------
 src/Modules/General/layout.py    | 13 +++++++------
 2 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py
index b9e4258..75bd78d 100755
--- a/src/Modules/General/callbacks.py
+++ b/src/Modules/General/callbacks.py
@@ -544,17 +544,16 @@ class callbacks(callbacksOp):
 
                 # Callback to handle the 2D view spiking visualization
                 @app.callback(
-                    Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'),
-                    Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
+                    Output("cytoscape-compound", "elements"),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'),
+                    Input("vis-update", "n_intervals"),Input("v-step", "children"),
                     State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
-                def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
+                def animation2DView(visUpdateInterval,sliderValue, updateInterval, elements, Layer2DViewFilter):
                     """ Function called each step to update the 2D view
 
                     Args:
                         sliderValue (int): value of the slider
                         updateInterval (int): update Interval (s)
                         visUpdateInterval : interval instance that will cause this function to be called each step
-                        mouseOverNodeData : contains data of the hovered node
                         elements : nodes description
                         heatmapData : heatmap data
                         Layer2DViewFilter : selected layers
@@ -563,7 +562,7 @@ class callbacks(callbacksOp):
                         if information tab should be opened or closed
                     """
                     try:
-                        elements = super.Spikes2D
+                        elements = super.generate2DView(g,Layer2DViewFilter)
                         matrix = {}
                         indices = {}
                         labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval)
@@ -601,18 +600,16 @@ class callbacks(callbacksOp):
                                 matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
                                 indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
 
-
-
                             if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
                                 for layer in super.AccumulatedSpikes2D:
                                     if layer not in matrix:
                                         matrix[layer] = []
                                         indices[layer] = []
 
-                            heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
+                            heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
 
-                            SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput]
-                            return [elements,[],heatmaps,SpikesActivityPerInput]
+                            SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput]
+                            return [elements,heatmaps,SpikesActivityPerInput]
                         else:
                             
                             try:
@@ -626,10 +623,10 @@ class callbacks(callbacksOp):
                                             matrix[layer] = []
                                             indices[layer] = []
 
-                                heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
+                                heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
 
-                                SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput]
-                                return [elements,f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}" if 'spikes' in mouseOverNodeData else "", heatmaps,SpikesActivityPerInput]
+                                SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput]
+                                return [elements,heatmaps,SpikesActivityPerInput]
                             except Exception:
                                 print("OnHover:"+traceback.format_exc())
                                 return no_update
@@ -692,11 +689,11 @@ class callbacks(callbacksOp):
                     res.append(getLoss(timestamp, g.updateInterval))
                 return res
                 
-            def make_SpikeActivityPerInput(layer,dataLabel):
+            def make_SpikeActivityPerInput(layer,dataLabel,Layer2DViewFilter):
                 fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]])
-                fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer], colorscale= 'Reds',hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1)
+                fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer] if layer in Layer2DViewFilter else [], colorscale= 'Reds', zmin=0, hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1)
                 fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=[i for i in range(g.nbrClasses+1)]),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'})
-                if dataLabel == None:
+                if (dataLabel == None) or (layer not in Layer2DViewFilter):
                     fig.add_trace(go.Bar(x=[],y=[],hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1)
                 else:
                     fig.add_trace(go.Bar(x=list(dataLabel[0].keys()),y=list(dataLabel[0].values()),hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1)
diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py
index 7cc7508..446c73d 100755
--- a/src/Modules/General/layout.py
+++ b/src/Modules/General/layout.py
@@ -46,15 +46,16 @@ class layout(layoutOp):
 
     # 2D view --------------------------------------------------------
 
-    def generate2DView(self, g):
+    def generate2DView(self, g, layers):
         """ Generates a 2D View of the neural network
 
         Args:
             g (Global_Var): reference to access global variables
+            layers: (array): list of layer names
         """
         Nodes = []
         # Create the neurones and layers
-        for L in g.LayersNeuronsInfo:
+        for L in [l for l in g.LayersNeuronsInfo if l["layer"] in layers]:
             Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}})
             for i in range(L["neuronNbr"]):
                 Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % int(math.sqrt(L["neuronNbr"]))) * 70, 'y': (i // int(math.sqrt(L["neuronNbr"]))) * 70},'height': 20,'width': 20})
@@ -82,7 +83,7 @@ class layout(layoutOp):
         self.MaxPotential.clear()
         self.MaxSpike.clear()
         self.MaxSynapse.clear()
-        self.Spikes2D = self.generate2DView(self.g)
+        self.Spikes2D = self.generate2DView(self.g,[str(i) for i in (i["layer"] for i in self.g.LayersNeuronsInfo)])
         self.AccumulatedSpikes2D = {i["layer"]:[0 for n in self.Spikes2D if n["data"]["spiked"] != -1 and i["layer"] == n["data"]["parent"]] for i in self.g.LayersNeuronsInfo}
         self.SpikesActivityPerInput = {i["layer"]:[[0 for j in range(self.g.nbrClasses+1)] for _ in range(self.g.Layer_Neuron[i["layer"]])] for i in self.g.LayersNeuronsInfo}
         self.Max = 0
@@ -269,7 +270,7 @@ class layout(layoutOp):
                                             ],
                                             elements=self.Spikes2D
                                         )],style={'position': 'absolute','width': '100%','height': '100%','z-index': 999,"background": "rgba(68, 71, 99, 0.05)"}),
-                                        html.P(id="spikes_info", style={"padding": "8px","margin":"0px"})
+                                        #html.P(id="spikes_info", style={"padding": "8px","margin":"0px"})
 
                                     ], style={ "textAlign": "start", "padding": "0px"},className="col-lg-5 col-sm-12 col-xs-12"),
                                     
@@ -333,8 +334,8 @@ class layout(layoutOp):
                             html.Span("Update Interval (s)",
                                       className="input-group-text")
                         ], className="input-group-prepend"),
-                        dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.005,
-                                  max=180, step=0.005, style={"width": "30%", "textAlign": "center"})
+                        dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.001,
+                                  max=180, step=0.001, style={"width": "30%", "textAlign": "center"})
                     ], className="input-group col-md-12 col-sm-12 col-lg-4", style={"height": "38px", "paddingTop": "12px"})
                 ], className="d-flex justify-content-center"), dbc.Col(
                     [
-- 
GitLab