Skip to content
Snippets Groups Projects
Commit 2358ebbf authored by Hammouda Elbez's avatar Hammouda Elbez :computer:
Browse files

2D View multilayer support added

parent 5448099b
No related branches found
No related tags found
1 merge request!26Custom 2d view
......@@ -4,15 +4,13 @@
"""
from collections import deque
import dash
import pymongo
import traceback
from bson.json_util import dumps
from bson.json_util import loads
import plotly.graph_objects as go
from dash.exceptions import PreventUpdate
from dash import (no_update, Input, Output, State, ALL, callback_context)
from plotly.subplots import make_subplots
from dash.dependencies import Input, Output, State, MATCH
from src.templates.callbacksOp import callbacksOp
class callbacks(callbacksOp):
......@@ -287,7 +285,7 @@ class callbacks(callbacksOp):
array of outputs that are selected in the callback
"""
try:
context = dash.callback_context.triggered[0]['prop_id'].split('.')[
context = callback_context.triggered[0]['prop_id'].split('.')[
0]
# update interval value if changed
if(g.updateInterval != float(updateInterval)):
......@@ -331,7 +329,7 @@ class callbacks(callbacksOp):
super.clearData()
clearGraphs = not clearGraphs
else:
if context == "btn-next" and dash.callback_context.triggered[0]['value'] != None:
if context == "btn-next" and callback_context.triggered[0]['value'] != None:
if(sliderValue < g.stepMax):
sliderValue = sliderValue + 1
......@@ -383,7 +381,7 @@ class callbacks(callbacksOp):
Returns:
array of outputs that are selected in the callback
"""
if dash.callback_context.triggered[0]['prop_id'].split('.')[0] == "btnControle":
if callback_context.triggered[0]['prop_id'].split('.')[0] == "btnControle":
if playButtonText == "Start":
if(int(g.stepMax) <= sliderValue):
super.visStopped = True
......@@ -396,13 +394,13 @@ class callbacks(callbacksOp):
super.visStopped = True
return ["Start", "btn btn-success", True]
else:
raise PreventUpdate
return no_update
else:
if(int(g.stepMax) <= sliderValue):
super.visStopped = True
return ["Start", "btn btn-success", True]
else:
raise PreventUpdate
return no_update
# Callback to handle general graph content
@app.callback(
......@@ -420,14 +418,14 @@ class callbacks(callbacksOp):
generalGraphSwitchIsOn (bool): general graph switch value
Raises:
PreventUpdate: in case we don't want to update the content we rise this execption
no_update: in case we don't want to update the content we rise this execption
Returns:
content of the graph that contains general information on the network activity
"""
if generalGraphSwitchIsOn:
if len(super.xAxisLabel) > 0 and "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(g.updateInterval, sliderValue+1)+"]" == super.xAxisLabel[-1]:
raise PreventUpdate
return no_update
if(not super.visStopped):
generalData = GeneralModuleData(
......@@ -444,7 +442,7 @@ class callbacks(callbacksOp):
else:
if(sliderValue > g.stepMax):
raise PreventUpdate
return no_update
else:
generalData = GeneralModuleData(
int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)
......@@ -452,7 +450,7 @@ class callbacks(callbacksOp):
generalData, int(sliderValue), generalGraphFilter, generalLayerFilter)
return [generalGraph]
else:
raise PreventUpdate
return no_update
# Callback to handle label graph content
@app.callback(
......@@ -497,7 +495,7 @@ class callbacks(callbacksOp):
labelData)
return [labelInfoTreemap]
else:
raise PreventUpdate
return no_update
# Callback to update the speed of visualization
@app.callback(
......@@ -537,7 +535,7 @@ class callbacks(callbacksOp):
if information tab should be opened or closed
"""
try:
if dash.callback_context.triggered[0]["value"] != None:
if callback_context.triggered[0]["value"] != None:
return [not isTabOpen]
else:
return [isTabOpen]
......@@ -546,10 +544,10 @@ class callbacks(callbacksOp):
# Callback to handle the 2D view spiking visualization
@app.callback(
Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),
Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),
Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
State("interval", "value"),State('cytoscape-compound', 'elements'),State({"index": ALL, "type": '2DView-heatmap'}, "id"),State("2DViewLayerFilter", "value"))
def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, selectedItem, Layer2DViewFilter):
""" Function called each step to update the 2D view
Args:
......@@ -559,41 +557,70 @@ class callbacks(callbacksOp):
mouseOverNodeData : contains data of the hovered node
elements : nodes description
heatmapData : heatmap data
selectedItem (list): selected layer
Layer2DViewFilter : selected layers
Returns:
if information tab should be opened or closed
"""
try:
elements = super.Spikes2D
if dash.callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]:
for element in elements[1:]:
matrix = {}
indices = {}
if callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]:
for element in elements:
if element["data"]['spiked'] != -1:
element["data"]["spiked"] = 0
element["data"]["spikes"] = 0
spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True)
spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,Layer2DViewFilter,True)
for layer in Layer2DViewFilter:
if spikes:
maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes])
for spike in spikes:
if list(spike.keys())[0] == "Layer1":
if list(spike.keys())[0] == layer:
# update the spikes neurons
i = 0
for element in elements[1:]:
if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
for element in elements:
if element["data"]['spiked'] != -1:
if (element["data"]["id"] == layer+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2)
element["data"]["spikes"] = list(list(spike.values())[0].values())[0]
super.AccumulatedSpikes2D[i] += list(list(spike.values())[0].values())[0]
super.AccumulatedSpikes2D[layer][int(element["data"]["label"])] += list(list(spike.values())[0].values())[0]
i+=1
matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
for layer in super.AccumulatedSpikes2D:
if layer not in matrix:
matrix[layer] = []
indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
return [elements,[],heatmaps]
else:
try:
matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
for layer in Layer2DViewFilter:
matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
for layer in super.AccumulatedSpikes2D:
if layer not in matrix:
matrix[layer] = []
indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
return [elements,f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}" if 'spikes' in mouseOverNodeData else "", heatmaps]
except Exception:
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
print("OnHover:"+traceback.format_exc())
return no_update
except Exception:
print("animation2DViewController:" + traceback.format_exc())
......
......@@ -3,6 +3,7 @@
import importlib
import traceback
import math
import dash_daq as daq
from collections import deque
import dash_cytoscape as cyto
......@@ -29,7 +30,7 @@ class layout(layoutOp):
MaxSpike = dict()
MaxPotential = dict()
MaxSynapse = dict()
AccumulatedSpikes2D = dict()
AccumulatedSpikes2D = []
Spikes2D = dict()
# LabelPie Data --------------------------------------------------
Label = [[], []]
......@@ -55,14 +56,15 @@ class layout(layoutOp):
for L in g.LayersNeuronsInfo:
Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}})
for i in range(L["neuronNbr"]):
Nodes.append({'classes': 'neuron', 'data': {'id': L["layer"]+"_"+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % 5) * 70, 'y': (i // 5) * 70},'height': 20,'width': 20})
Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % int(math.sqrt(L["neuronNbr"]))) * 70, 'y': (i // int(math.sqrt(L["neuronNbr"]))) * 70},'height': 20,'width': 20})
# Add connections
return Nodes
def toMatrix(self, l,n):
def toMatrix(self, l):
""" 1D array to 2D
"""
n = int(math.sqrt(len(l)))
Matrix = [l[i:i+n] for i in range(0, len(l), n)]
return Matrix
......@@ -80,7 +82,7 @@ class layout(layoutOp):
self.MaxSpike.clear()
self.MaxSynapse.clear()
self.Spikes2D = self.generate2DView(self.g)
self.AccumulatedSpikes2D = [0 for n in self.Spikes2D if n["data"]["spiked"] != -1]
self.AccumulatedSpikes2D = {i:[0 for n in self.Spikes2D if n["data"]["spiked"] != -1 and i == n["data"]["parent"]] for i in self.g.Layer_Neuron if i != "Input"}
self.Max = 0
def Vis(self):
......@@ -149,9 +151,9 @@ class layout(layoutOp):
dcc.Dropdown(
id='GeneralLayerFilter',
options=[{'label': str(i), 'value': str(i)} for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
i for i in self.g.Layer_Neuron if i != "Input")],
value=[str(i) for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
i for i in self.g.Layer_Neuron if i != "Input")],
multi=True,
style={"minWidth": "20%","marginLeft": "5px", "textAlign": "start"})], className="d-flex", style={"paddingLeft": "20px", 'width': '100%'})
], className="col-12")
......@@ -173,27 +175,34 @@ class layout(layoutOp):
dcc.Tab(dbc.Card(
dbc.CardBody([
html.Div([
# Layers filter
html.P("Layers: ", style={
"textAlign": "start", "marginRight": "10px", "marginTop": "4px"}),
dcc.Dropdown(
id='2DViewLayerFilter',
options=[{'label': str(i), 'value': str(i)} for i in (
i for i in self.g.Layer_Neuron if i != "Input")],
value=[str(i) for i in (
i for i in self.g.Layer_Neuron if i != "Input")],
multi=True,
style={"minWidth": "80%", "textAlign": "start"}),
], style={"textAlign": "start", },className="d-flex col-lg-12 col-sm-12 col-xs-12"),
html.Div([
daq.PowerButton(
id="2DView-global-switch",
on='True',
size=30,
color="#28a745",
style={"marginLeft": "20px"}
),
html.P("Accumulated Spikes", style={
"textAlign": "start", "marginLeft": "10px", "marginTop": "4px"})], className="d-flex"),
html.Div([html.P("Accumulated Spikes", style={"margin":"0px"})]),
# Accumulated Spikes HeatMap
dcc.Graph(id='2DView-heatmap', config={"displaylogo": False}
),
], style={"textAlign": "start", "padding": "8px"}, className="col-lg-3 col-sm-12 col-xs-12")
dcc.Tabs([dcc.Tab(dbc.Card(dbc.CardBody([
dcc.Graph(id={"type":"2DView-heatmap","index":i}, config={"displaylogo": False})
])),label=i, value='2Dview-'+str(x)) for x, i in enumerate(self.g.Layer_Neuron) if i != "Input"],value="2Dview-1"),
], style={"textAlign": "start", }, className="col-lg-3 col-sm-12 col-xs-12")
,
html.Div(
[
#dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]),
html.Div([html.P("2D Space", style={"margin":"0px"})]),
html.Div([
cyto.Cytoscape(
id='cytoscape-compound',
layout={'name': 'grid','animate': False},
layout={'name': 'preset'},
boxSelectionEnabled=False,
style={'width': '100%',
'height': '100%'},
stylesheet=[
......@@ -257,22 +266,18 @@ class layout(layoutOp):
}
],
elements=self.Spikes2D
),
html.P(id="spikes_info", style={"padding": "8px"})
)],style={'position': 'absolute','width': '100%','height': '100%','z-index': 999,"background": "rgba(68, 71, 99, 0.05)"}),
html.P(id="spikes_info", style={"padding": "8px","margin":"0px"})
], style={"background": "rgba(68, 71, 99, 0.05)", "height": "50vh", "textAlign": "start", "padding": "0px","marginBottom":"12px"},className="col-lg-6 col-sm-12 col-xs-12"),
], style={"height": "50vh", "textAlign": "start", "padding": "0px","marginBottom":"12px"},className="col-lg-6 col-sm-12 col-xs-12"),
# 3D destribution
html.Div([
# Layers filter
dcc.Dropdown(
id='2DViewLayerFilter',
options=[{'label': str(i), 'value': str(i)} for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
value=[str(i) for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
multi=True,
style={"minWidth": "20%", "textAlign": "start"}),
], style={"textAlign": "start", "padding": "8px"},className="col-lg-3 col-sm-12 col-xs-12")
html.Div([html.P("3D Destribution", style={"margin":"0px"})]),
dcc.Tabs([dcc.Tab(dbc.Card(dbc.CardBody([
dcc.Graph(id={"type":"30NeuronDestribution","index":i}, config={"displaylogo": False})
])),label=i, value='30Neuron-'+str(x)) for x, i in enumerate(self.g.Layer_Neuron) if i != "Input"],value="30Neuron-1")
], style={"textAlign": "start", },className="col-lg-3 col-sm-12 col-xs-12")
], className="row")), label="2D view", value="2Dview")], id="tabinfo", value="General information"),
......
......@@ -12,9 +12,9 @@ from collections import deque
import plotly.graph_objects as go
from dash import dcc, html
from bson.json_util import dumps, loads
from dash.exceptions import PreventUpdate
from dash.dependencies import ALL, MATCH, Input, Output, State
from src.templates.callbacksOp import callbacksOp
from dash import no_update
class callbacks(callbacksOp):
""" Callbacks class
"""
......@@ -185,7 +185,7 @@ class callbacks(callbacksOp):
isOnAcc (bool): whether spikes class graph is active or not
Raises:
PreventUpdate: in case we don't want to update the content we rise this execption
no_update: in case we don't want to update the content we rise this execption
Returns:
three graphs of the selected neuron
......@@ -213,7 +213,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
else:
if(sliderValue == 0):
data = getSpikesOfNeuron(
......@@ -229,7 +229,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
else:
# after adding to the screen
if(selectedItem["index"] not in super.xAxisSpikeNbrLabel):
......@@ -248,7 +248,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
except Exception:
print("processSpikeRelatedGraphs: " + traceback.format_exc())
......@@ -267,7 +267,7 @@ class callbacks(callbacksOp):
isOn (bool): whether this graph is active or not
Raises:
PreventUpdate: in case we don't want to update the content we rise this execption
return no_update: in case we don't want to update the content we rise this execption
Returns:
neuron potential graph of the selected neuron
......@@ -293,7 +293,7 @@ class callbacks(callbacksOp):
output = [neuronPotentialDrawGraph(selectedItem["index"], data, super.xAxisPotentialGraph[selectedItem["index"]], super.xAxisPotentialLabel[selectedItem["index"]], super.yAxisPotentialGraph, isOn)]
return output
else:
raise PreventUpdate
return no_update
else:
if(sliderValue == 0):
data = getPotentialOfNeuron(
......@@ -308,7 +308,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
else:
# after adding to the screen
if(selectedItem["index"] not in super.xAxisPotentialLabel):
......@@ -326,7 +326,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
except Exception:
print("processPotential: " + traceback.format_exc())
......
......@@ -10,7 +10,7 @@ import traceback
from dash import dcc, html
import dash_daq as daq
from bson.json_util import loads
from dash.exceptions import PreventUpdate
from dash import no_update
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import deque
......@@ -181,7 +181,7 @@ class callbacks(callbacksOp):
heatmapEditInfo (Array): heatmap edit stats
Raises:
PreventUpdate: in case we don't want to update the content we rise this execption
no_update: in case we don't want to update the content we rise this execption
Returns:
synapse frequency graph and heatmap content (data and layout)
......@@ -206,7 +206,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
else:
if(sliderValue == 0):
data = self.getSynapseWeights(int(sliderValue)*float(updateInterval), g, selectedItem)
......@@ -222,7 +222,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
else:
# after adding to the screen
if(selectedItem["index"] not in super.xAxisLabel):
......@@ -242,7 +242,7 @@ class callbacks(callbacksOp):
return output
else:
raise PreventUpdate
return no_update
except Exception:
print("processSynapseData: "+ traceback.format_exc())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment