Skip to content
Snippets Groups Projects
Commit aa9bb56d authored by Elbez Hammouda's avatar Elbez Hammouda
Browse files

General module multilayer support v1

parent d2b31f7c
No related branches found
No related tags found
1 merge request!22V0.351
......@@ -3,6 +3,7 @@
Dash callbacks are the responsible on updating graphs each step.
"""
from collections import deque
import dash
import pymongo
from bson.json_util import dumps
......@@ -32,13 +33,14 @@ class callbacks():
try:
def processGeneralGraph(data, sliderValue, generalGraphFilter):
def processGeneralGraph(data, sliderValue, generalGraphFilter, generalLayerFilter):
""" Create general graph components.
Args:
data (array): data to be presented in the graph
sliderValue (int): value of the slider
generalGraphFilter (list): actual filter of GeneralGraph visualization
generalLayerFilter (list): selected layers for GeneralGraph visualization
Returns:
general graph content
......@@ -46,7 +48,6 @@ class callbacks():
X = 0
Xlabel = ""
try:
if sliderValue != None:
Xlabel = "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(
g.updateInterval, sliderValue+1)+"]"
......@@ -60,78 +61,85 @@ class callbacks():
super.InfoGraphX.append(X)
i = 0
graphs = []
graphs = {l:[] for l in generalLayerFilter}
annotations = []
for layer in generalLayerFilter:
i = 0
for f in generalGraphFilter:
if(f == "Spikes"):
if data != None and data[i] != None:
super.SpikeGraphY.append(data[i])
if(layer not in super.SpikeGraphY):
super.SpikeGraphY[layer] = deque(maxlen=100)
super.MaxSpike[layer] = 0
if data != None and layer in data[i]:
super.SpikeGraphY[layer].append(data[i][layer]["spikes"])
else:
super.SpikeGraphY.append(0)
super.SpikeGraphY[layer].append(0)
super.MaxSpike[layer] = max(super.MaxSpike[layer], max(
super.SpikeGraphY[layer]) if super.SpikeGraphY[layer] else 0)
super.MaxSpike = max(super.MaxSpike, max(
super.SpikeGraphY) if super.SpikeGraphY else 0)
graphs.append(
graphs[layer].append(
go.Scatter(
x=list(super.InfoGraphX),
y=list([norm(i, super.MaxSpike)
for i in super.SpikeGraphY]),
y=list([norm(i, super.MaxSpike[layer])
for i in super.SpikeGraphY[layer]]),
fill='tozeroy' if len(
generalGraphFilter) == 1 else 'none',
line=dict(color="rgb(31, 119, 180)"),
name='Spikes',
mode='lines+markers',
text=list(super.xAxisLabel),
customdata=list(super.SpikeGraphY),
customdata=list(super.SpikeGraphY[layer]),
hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(
super.MaxSpike),
))
super.MaxSpike[layer])))
if(f == "Synapses"):
if data != None and data[i] != None:
super.SynapseGraphY.append(data[i])
if(layer not in super.SynapseGraphY):
super.SynapseGraphY[layer] = deque(maxlen=100)
super.MaxSynapse[layer] = 0
if data != None and layer in data[i]:
super.SynapseGraphY[layer].append(data[i][layer]["synapseUpdate"])
else:
super.SynapseGraphY.append(0)
super.MaxSynapse = max(
max(super.SynapseGraphY) if super.SynapseGraphY else 0, super.MaxSynapse)
graphs.append(
super.SynapseGraphY[layer].append(0)
super.MaxSynapse[layer] = max(
max(super.SynapseGraphY[layer]) if super.SynapseGraphY[layer] else 0, super.MaxSynapse[layer])
graphs[layer].append(
go.Scatter(
x=list(super.InfoGraphX),
y=list([norm(i, max(super.SynapseGraphY))
for i in super.SynapseGraphY]),
y=list([norm(i, max(super.SynapseGraphY[layer]))
for i in super.SynapseGraphY[layer]]),
fill='tozeroy' if len(
generalGraphFilter) == 1 else 'none',
line=dict(color="rgb(255, 127, 14)"),
name='Synapses update',
mode='lines+markers',
text=list(super.xAxisLabel),
customdata=list(super.SynapseGraphY),
hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxSynapse)))
customdata=list(super.SynapseGraphY[layer]),
hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxSynapse[layer])))
if(f == "Potentials"):
if data != None and data[i] != None:
super.PotentialGraphY.append(data[i])
if(layer not in super.PotentialGraphY):
super.PotentialGraphY[layer] = deque(maxlen=100)
super.MaxPotential[layer] = 0
if data != None and layer in data[i]:
super.PotentialGraphY[layer].append(data[i][layer]["potential"])
else:
super.PotentialGraphY.append(0)
super.PotentialGraphY[layer].append(0)
super.MaxPotential = max(
max(super.PotentialGraphY) if super.PotentialGraphY else 0, super.MaxPotential)
graphs.append(
super.MaxPotential[layer] = max(
max(super.PotentialGraphY[layer]) if super.PotentialGraphY[layer] else 0, super.MaxPotential[layer])
graphs[layer].append(
go.Scatter(
x=list(super.InfoGraphX),
y=list([norm(i, super.MaxPotential)
for i in super.PotentialGraphY]),
y=list([norm(i, super.MaxPotential[layer])
for i in super.PotentialGraphY[layer]]),
fill='tozeroy' if len(
generalGraphFilter) == 1 else 'none',
line=dict(color="rgb(44, 160, 44)"),
name="Neuron's potential update",
mode='lines+markers',
text=list(super.xAxisLabel),
customdata=list(super.PotentialGraphY),
hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxPotential)))
customdata=list(super.PotentialGraphY[layer]),
hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxPotential[layer])))
i += 1
if(g.Labels != None):
......@@ -141,8 +149,8 @@ class callbacks():
else:
super.LossGraphY.append(None)
fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=[
[{'rowspan': 2}], [None], [{'rowspan': 3}], [None], [None]])
fig = make_subplots(rows=1+(len(graphs)*1), cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=
[[{'rowspan': 1}]] +[[{'rowspan': 1}] for l in graphs])
fig.add_trace(
go.Scatter(x=list(super.InfoGraphX), y=list(super.LossGraphY), mode='lines',
......@@ -159,7 +167,7 @@ class callbacks():
row=1, col=1
)
for graph in graphs:
for key,graph in graphs.items:
fig.add_trace(
graph, row=3, col=1)
......@@ -182,11 +190,14 @@ class callbacks():
annotations=annotations)
else:
fig = make_subplots(
rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05)
rows=len(graphs), cols=1, shared_xaxes=True, vertical_spacing=0.05)
for graph in graphs:
l = 1
for key,graphL in graphs.items():
for graph in graphL:
fig.add_trace(
graph, row=1, col=1)
graph, row=l, col=1)
l +=1
fig['layout'].update(
# TODO: change axis when there is no Loss
......@@ -197,16 +208,12 @@ class callbacks():
super.InfoGraphX) if super.InfoGraphX else 0],
#rangeslider={'visible': True,'autorange': True},
# showticklabels=False,
tickvals=list(super.InfoGraphX),
),
yaxis=dict(
range=[0, 105]
),
tickvals=list(super.InfoGraphX)),
yaxis=dict(range=[0, 105]),
showlegend=True,
uirevision='no reset of zoom',
margin={'l': 0, 'r': 0, 't': 30, 'b': 25},
annotations=annotations,
)
annotations=annotations)
return fig
except Exception as e:
......@@ -269,8 +276,8 @@ class callbacks():
[Input("vis-update", "n_intervals"),
Input("btn-back", "n_clicks"), Input("btn-next", "n_clicks")],
[State("vis-slider", "value"), State("interval", "value"),
State("GeneralGraphFilter", "value"), State("clear", "children")])
def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, clearGraphs):
State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State("clear", "children")])
def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, clearGraphs):
""" This is the callback function. It is called each step.
Args:
......@@ -280,6 +287,7 @@ class callbacks():
sliderValue (int): value of the slider
updateInterval (int): update Interval (s)
generalGraphFilter (list): actual filter of GeneralGraph visualization
generalLayerFilter (list): selected layers for GeneralGraph visualization
clearGraphs (boolean): a dash state to pass information to all visualization about clearing content (if needed)
Returns:
......@@ -311,6 +319,12 @@ class callbacks():
super.generalGraphFilterOld = generalGraphFilter
super.clearData()
clearGraphs = not clearGraphs
if (super.generalLayerFilterOld != generalLayerFilter):
super.generalLayerFilterOld = generalLayerFilter
super.clearData()
clearGraphs = not clearGraphs
if abs(super.oldSliderValue - sliderValue) > 2:
super.clearData()
clearGraphs = not clearGraphs
......@@ -400,14 +414,15 @@ class callbacks():
@app.callback(
[Output("general-graph", "figure")
], [Input("v-step", "children")],
[State("interval", "value"), State("GeneralGraphFilter", "value"), State('general-graph-switch', 'on')])
def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalGraphSwitchIsOn):
[State("interval", "value"), State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State('general-graph-switch', 'on')])
def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, generalGraphSwitchIsOn):
""" Update the general graph.
Args:
sliderValue (int): value of the slider
updateInterval (int): update Interval (s)
generalGraphFilter (list): actual filter of GeneralGraph visualization
generalLayerFilter (list): selected layers for GeneralGraph visualization
generalGraphSwitchIsOn (bool): general graph switch value
Raises:
......@@ -422,14 +437,14 @@ class callbacks():
if(not super.visStopped):
generalData = GeneralModuleData(
int(sliderValue)*float(updateInterval), generalGraphFilter)
int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)
if generalData == [None]:
generalGraph = processGeneralGraph(
None, sliderValue, generalGraphFilter)
None, sliderValue, generalGraphFilter, generalLayerFilter)
else:
generalGraph = processGeneralGraph(
generalData, sliderValue, generalGraphFilter)
generalData, sliderValue, generalGraphFilter, generalLayerFilter)
return [generalGraph]
......@@ -438,9 +453,9 @@ class callbacks():
raise PreventUpdate
else:
generalData = GeneralModuleData(
int(sliderValue)*float(updateInterval), generalGraphFilter)
int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)
generalGraph = processGeneralGraph(
generalData, int(sliderValue), generalGraphFilter)
generalData, int(sliderValue), generalGraphFilter, generalLayerFilter)
return [generalGraph]
else:
raise PreventUpdate
......@@ -556,13 +571,13 @@ class callbacks():
"""
return (data * 100)/Max if Max != 0 else data
def GeneralModuleData(timestamp, filter):
def GeneralModuleData(timestamp, filter, layers):
""" Returns a list of graphs based on the specified filter.
Args:
timestamp (int): timestamp value
filter (array): array of selected filters
layers (array): array of selected layers
Returns:
list of graphs
"""
......@@ -570,20 +585,19 @@ class callbacks():
for f in filter:
if f == "Spikes":
if res == []:
res = [getSpike(timestamp, g.updateInterval)]
res = [getSpike(timestamp, g.updateInterval,layers)]
else:
res.append(getSpike(timestamp, g.updateInterval))
res.append(getSpike(timestamp, g.updateInterval,layers))
if f == "Synapses":
if res == []:
res = [getSynapse(timestamp, g.updateInterval)]
res = [getSynapse(timestamp, g.updateInterval,layers)]
else:
res.append(getSynapse(timestamp, g.updateInterval))
res.append(getSynapse(timestamp, g.updateInterval,layers))
if f == "Potentials":
if res == []:
res = [getPotential(timestamp, g.updateInterval)]
res = [getPotential(timestamp, g.updateInterval,layers)]
else:
res.append(getPotential(
timestamp, g.updateInterval))
res.append(getPotential(timestamp, g.updateInterval,layers))
# get loss value
if res == []:
......@@ -635,74 +649,87 @@ class callbacks():
return [L, Max]
def getSpike(timestamp, interval):
def getSpike(timestamp, interval, layer):
""" Get spikes activity in a given interval.
Args:
timestamp (int): timestamp value
interval (int): interval value
layers (array): array of selected layers
Returns:
array contains spikes
"""
# MongoDB---------------------
col = pymongo.collection.Collection(g.db, 'spikes')
spikes = col.find(
{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count()
spikes = col.aggregate([
{"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}},
{"$group": {"_id": "$i.L","spikes": {"$sum":1}}},{"$sort": {"_id": 1}}
])
# ----------------------------
# ToJson----------------------
spikes = loads(dumps(spikes))
# ----------------------------
spikes = {s["_id"]:s for s in spikes}
if not spikes:
return None
return spikes
def getSynapse(timestamp, interval):
def getSynapse(timestamp, interval, layer):
""" Get syanpse activity in a given interval.
Args:
timestamp (int): timestamp value
interval (int): interval value
layers (array): array of selected layers
Returns:
array contains synapses activity
"""
# MongoDB---------------------
col = pymongo.collection.Collection(g.db, 'synapseWeight')
synapse = col.find(
{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count()
synapse = col.aggregate([
{"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"L": {'$in': layer}}]}},
{"$group": {"_id": "$L","synapseUpdate": {"$sum":1}}},{"$sort": {"_id": 1}}
])
# ----------------------------
# ToJson----------------------
synapse = loads(dumps(synapse))
# ----------------------------
synapse = {s["_id"]:s for s in synapse}
if not synapse:
return None
return synapse
def getPotential(timestamp, interval):
def getPotential(timestamp, interval, layer):
""" Get potential activity in a given interval.
Args:
timestamp (int): timestamp value
interval (int): interval value
layers (array): array of selected layers
Returns:
array contains potential activity
"""
# MongoDB---------------------
col = pymongo.collection.Collection(g.db, 'potential')
potential = col.find(
{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count()
potential = col.aggregate([
{"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"L": {'$in': layer}}]}},
{"$group": {"_id": "$L","potential": {"$sum":1}}},{"$sort": {"_id": 1}}
])
# ----------------------------
# ToJson----------------------
potential = loads(dumps(potential))
# ----------------------------
potential = {p["_id"]:p for p in potential}
if not potential:
return None
......
......@@ -17,16 +17,17 @@ class layout():
# InfoGraph Axis -------------------------------------------------
oldSliderValue = 0
generalGraphFilterOld = []
generalLayerFilterOld = []
xAxisLabel = deque(maxlen=100)
InfoGraphX = deque(maxlen=100)
SpikeGraphY = deque(maxlen=100)
SynapseGraphY = deque(maxlen=100)
PotentialGraphY = deque(maxlen=100)
SpikeGraphY = dict()
SynapseGraphY = dict()
PotentialGraphY = dict()
LossGraphY = deque(maxlen=100)
MaxSpike = 0
MaxPotential = 0
MaxSynapse = 0
MaxSpike = dict()
MaxPotential = dict()
MaxSynapse = dict()
# LabelPie Data --------------------------------------------------
Label = [[], []]
Max = 0
......@@ -69,9 +70,9 @@ class layout():
self.SynapseGraphY.clear()
self.xAxisLabel.clear()
self.Label.clear()
self.MaxPotential = 0
self.MaxSpike = 0
self.MaxSynapse = 0
self.MaxPotential = dict()
self.MaxSpike = dict()
self.MaxSynapse = dict()
self.Max = 0
self.Nodes = []
self.Edges = []
......@@ -128,13 +129,22 @@ class layout():
size=30,
color="#28a745",
style={"marginLeft": "10px"}
), dcc.Dropdown(
),
# Graphs filter
dcc.Dropdown(
id='GeneralGraphFilter',
options=[{'label': "Spikes", 'value': "Spikes"}, {'label': "Synapses update", 'value': "Synapses"}, {
'label': "Neurons potential update", 'value': "Potentials"}],
value=['Spikes'],
multi=True,
style={'width': '80%', "marginLeft": "10px", "textAlign": "start"})], className="row", style={"paddingLeft": "20px"})
style={'width': '50%', "marginLeft": "10px", "textAlign": "start"}),
# Layers filter
dcc.Dropdown(
id='GeneralLayerFilter',
options=[{'label': str(i), 'value': str(i)} for i in (i for i in g.Layer_Neuron if i != "Input")],
value=[str(i) for i in (i for i in g.Layer_Neuron if i != "Input")],
multi=True,
style={'width': '50%', "marginLeft": "5px", "textAlign": "start"})], className="row", style={"paddingLeft": "20px"})
], className="col-12")
], className="row"),
html.Div([dcc.Graph(id='general-graph', animate=False, config={"displaylogo": False})])], className="col-lg-9 col-sm-12 col-xs-12" if(g.Labels != None) else "col-lg-12 col-sm-12 col-xs-12"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment