Skip to content
Snippets Groups Projects
Commit 3152dcee authored by Hammouda Elbez's avatar Hammouda Elbez :computer:
Browse files

2D view working with two graphs

parent 19a9428b
No related branches found
No related tags found
1 merge request!26Custom 2d view
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from collections import deque from collections import deque
import dash import dash
import pymongo import pymongo
import numpy as np
from bson.json_util import dumps from bson.json_util import dumps
from bson.json_util import loads from bson.json_util import loads
import plotly.graph_objects as go import plotly.graph_objects as go
...@@ -374,6 +375,7 @@ class callbacks(callbacksOp): ...@@ -374,6 +375,7 @@ class callbacks(callbacksOp):
""" This is the callback function. It is called when play/stop button is clicked. """ This is the callback function. It is called when play/stop button is clicked.
Args: Args:
visUpdateInterval : interval instance that will cause this function to be called each step
playButton (int): number of clicks on the start/stop button playButton (int): number of clicks on the start/stop button
sliderValue (int): value of the slider sliderValue (int): value of the slider
playButtonText (String): text on the start/stop button playButtonText (String): text on the start/stop button
...@@ -542,6 +544,52 @@ class callbacks(callbacksOp): ...@@ -542,6 +544,52 @@ class callbacks(callbacksOp):
except Exception as e: except Exception as e:
print("informationTabController:" + str(e)) print("informationTabController:" + str(e))
# Callback to handle the 2D view spiking visualization
@app.callback(
Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"),
Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'tapNodeData'),
State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data"))
def animation2DView(visUpdateInterval,sliderValue, tapNodeData, updateInterval, elements, StoredData):
""" Function called each step to update the 2D view
Args:
sliderValue (int): value of the slider
updateInterval (int): update Interval (s)
visUpdateInterval : interval instance that will cause this function to be called each step
tapNodeData : contains data of the clickde node
elements : nodes description
StoredData : data stored for shared access
heatmapData : heatmap data
Returns:
if information tab should be opened or closed
"""
try:
if dash.callback_context.triggered[0]['prop_id'].split('.')[0] == "v-step":
elements = StoredData[0]
spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True)
if spikes:
maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes])
for spike in spikes:
if list(spike.keys())[0] == "Layer1":
# update the spikes neurons
i = 0
for element in elements[1:]:
if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2)
element["data"]["spikes"] = list(list(spike.values())[0].values())[0]
StoredData[1][i] = list(list(spike.values())[0].values())[0]
i+=1
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData]
else:
try:
return [elements,{'name': 'grid','animate': False},f"Neuron {tapNodeData['label']} : {tapNodeData['spikes']}", {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData]
except Exception as e:
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData]
except Exception as e:
print("animation2DViewController:" + str(e))
except Exception as e: except Exception as e:
print("Done loading:"+str(e)) print("Done loading:"+str(e))
...@@ -577,9 +625,9 @@ class callbacks(callbacksOp): ...@@ -577,9 +625,9 @@ class callbacks(callbacksOp):
for f in filter: for f in filter:
if f == "Spikes": if f == "Spikes":
if res == []: if res == []:
res = [getSpike(timestamp, g.updateInterval,layers)] res = [getSpike(timestamp, g.updateInterval,layers,False)]
else: else:
res.append(getSpike(timestamp, g.updateInterval,layers)) res.append(getSpike(timestamp, g.updateInterval,layers,False))
if f == "Synapses": if f == "Synapses":
if res == []: if res == []:
res = [getSynapse(timestamp, g.updateInterval,layers)] res = [getSynapse(timestamp, g.updateInterval,layers)]
...@@ -641,30 +689,41 @@ class callbacks(callbacksOp): ...@@ -641,30 +689,41 @@ class callbacks(callbacksOp):
return [L, Max] return [L, Max]
def getSpike(timestamp, interval, layer): def getSpike(timestamp, interval, layer, perNeuron):
""" Get spikes activity in a given interval. """ Get spikes activity in a given interval.
Args: Args:
timestamp (int): timestamp value timestamp (int): timestamp value
interval (int): interval value interval (int): interval value
layers (array): array of selected layers layers (array): array of selected layers
perNeuron (boolean): return global or perNeuron Spikes
Returns: Returns:
array contains spikes array contains spikes
""" """
# MongoDB--------------------- # MongoDB---------------------
col = pymongo.collection.Collection(g.db, 'spikes') col = pymongo.collection.Collection(g.db, 'spikes')
spikes = col.aggregate([
if perNeuron:
spikes = col.aggregate([
{"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}},
{"$group": {"_id": {"L":"$i.L","N":"$i.N"},"spikes": {"$sum":1}}},{"$sort": {"_id": 1}}
])
else:
spikes = col.aggregate([
{"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}}, {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}},
{"$group": {"_id": "$i.L","spikes": {"$sum":1}}},{"$sort": {"_id": 1}} {"$group": {"_id": "$i.L","spikes": {"$sum":1}}},{"$sort": {"_id": 1}}
]) ])
# ----------------------------
# ToJson---------------------- # ToJson----------------------
spikes = loads(dumps(spikes)) spikes = loads(dumps(spikes))
# ---------------------------- # ----------------------------
spikes = {s["_id"]:s for s in spikes} if perNeuron:
spikes = [{s["_id"]["L"]:{s["_id"]["N"]:s["spikes"]}} for s in spikes]
else:
spikes = {s["_id"]:s for s in spikes}
if not spikes: if not spikes:
return None return None
......
...@@ -53,11 +53,14 @@ class layout(layoutOp): ...@@ -53,11 +53,14 @@ class layout(layoutOp):
for L in g.LayersNeuronsInfo: for L in g.LayersNeuronsInfo:
Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}}) Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}})
for i in range(L["neuronNbr"]): for i in range(L["neuronNbr"]):
Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 10}, Nodes.append({'classes': 'neuron', 'data': {'id': L["layer"]+"_"+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % 5) * 50, 'y': (i // 5) * 50}})
'position': {'x': 25*i, 'y': 0}})
# Add connections # Add connections
def toMatrix(self, l,n):
""" 1D array to 2D
"""
return [l[i:i+n] for i in range(0, len(l), n)]
def clearData(self): def clearData(self):
""" Clear the data when moved forward or backward for more than one step """ Clear the data when moved forward or backward for more than one step
...@@ -165,30 +168,51 @@ class layout(layoutOp): ...@@ -165,30 +168,51 @@ class layout(layoutOp):
dcc.Tab(dbc.Card( dcc.Tab(dbc.Card(
dbc.CardBody([ dbc.CardBody([
html.Div( html.Div(
[ [dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]),
cyto.Cytoscape( cyto.Cytoscape(
id='cytoscape-compound', id='cytoscape-compound',
responsive=True, layout={'name': 'grid','animate': False},
layout={'name': 'grid'},
style={'width': '100%', style={'width': '100%',
'height': '100%'}, 'height': '100%'},
stylesheet=[ stylesheet=[
{ {
'selector': 'node', 'selector': 'node',
'style': {'content': 'data(label)'} 'style': {'label': 'data(label)'}
}, },
{ {
'selector': '.layers', 'selector': '[spiked <= 1.0]',
'style': {'width': 5} 'style': {
'background-color': 'rgb(70,227,70)',
}
}, },
{ {
'selector': '.neurons', 'selector': '[spiked < 0.8]',
'style': {'line-style': 'dashed'} 'style': {
}, 'background-color': 'rgb(100,227,100)'
}
},
{ {
'selector': '[spiked = 10]', 'selector': '[spiked < 0.6]',
'style': { 'style': {
'background-color': 'rgb(180,180,180)' 'background-color': 'rgb(130,227,130)'
}
},
{
'selector': '[spiked < 0.4]',
'style': {
'background-color': 'rgb(160,227,160)'
}
},
{
'selector': '[spiked < 0.2]',
'style': {
'background-color': 'rgb(190,227,190)'
}
},
{
'selector': '[spiked = 0.0]',
'style': {
'background-color': 'rgb(199,197,197)'
} }
}, },
{ {
...@@ -199,11 +223,25 @@ class layout(layoutOp): ...@@ -199,11 +223,25 @@ class layout(layoutOp):
} }
], ],
elements=self.Nodes elements=self.Nodes
) ),
html.P(id="spikes_info", style={"padding": "8px"})
], style={"background": "rgb(227, 245, 251)", "height": "60vh", "textAlign": "start", "padding": "0px", "width":"70%"}), ], style={"background": "rgb(227, 245, 251)", "height": "50vh", "textAlign": "start", "padding": "0px","paddingBottom":"12px", "width":"70%"}),
html.Div([], style={"height": "60vh", "textAlign": "start", "padding": "0px", "width":"30%"}) html.Div([
# Layers filter
dcc.Dropdown(
id='2DViewLayerFilter',
options=[{'label': str(i), 'value': str(i)} for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
value=[str(i) for i in (
i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))],
multi=True,
style={"minWidth": "20%", "textAlign": "start"}),
# HeatMap
dcc.Graph(id='2DView-heatmap', config={"displaylogo": False}
),
], style={"height": "50vh", "textAlign": "start", "padding": "8px", "width":"30%"})
], className="row")), label="2D view", value="2Dview")], id="tabinfo", value="General information"), ], className="row")), label="2D view", value="2Dview")], id="tabinfo", value="General information"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment