Select Git revision
callbacks.py
-
Hammouda Elbez authored
callbacks.py 46.69 KiB
""" This module defines the callbacks class, which registers the Dash callbacks.
Dash callbacks are responsible for updating the graphs at each step.
"""
from collections import deque
import pymongo
import traceback
from bson.json_util import dumps
from bson.json_util import loads
import plotly.graph_objects as go
from dash import (no_update, Input, Output, State, ALL, callback_context)
from plotly.subplots import make_subplots
from src.templates.callbacksOp import callbacksOp
class callbacks(callbacksOp):
""" Callbacks class
Registers the Dash callbacks of the visualization on the given app and
defines the helper functions and MongoDB queries those callbacks use.
"""
def __init__(self, super, app, g):
""" Initialize the callbacks and register them on the app.
Args:
super : object holding the shared visualization state (xAxisLabel,
InfoGraphX, SpikeGraphY, AccumulatedSpikes2D, ...) and helpers such as
clearData() and toMatrix(); inside __init__ this parameter shadows the
builtin super()
app : Dash application; the callbacks are registered with app.callback
g (Global_Var): reference to access global variables
"""
# NOTE(review): the indentation of this file was lost in this copy; the
# nesting of the definitions below (inside __init__ and the guard/try) is
# inferred from the control flow and should be confirmed against the
# original source.
# ------------------------------------------------------------
# Graph build functions
# ------------------------------------------------------------
# to prevent creating duplicate callbacks next time: g.checkExistance is
# presumably True when the app already has the "vis-slider" callbacks
if not g.checkExistance(app, "vis-slider"):
try:
def processGeneralGraph(data, sliderValue, generalGraphFilter, generalLayerFilter):
    """ Create general graph components.

    Records the current step in the shared history kept on ``super`` (x values,
    x labels, per-layer metric deques, running maxima and loss values) and
    builds the general Plotly figure from that history.

    Args:
        data (list): results of GeneralModuleData: one entry per recognised
            metric ("Spikes", "Synapses", "Potentials") in generalGraphFilter
            order, followed by the loss value. None is treated as "no data".
        sliderValue (int): value of the slider
        generalGraphFilter (list): actual filter of GeneralGraph visualization
        generalLayerFilter (list): selected layers for GeneralGraph visualization
    Returns:
        general graph content (Plotly figure), or None when building it failed
        (the traceback is printed)
    """
    # Per metric: (key in the query result, trace name, line colour, hover prefix).
    metricStyles = {
        "Spikes": ("spikes", "Spikes", "rgb(31, 119, 180)",
                   "%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
        "Synapses": ("synapseUpdate", "Synapses activity", "rgb(255, 127, 14)",
                     "<span style='color:white;'>%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
        "Potentials": ("potential", "Neuron's potential", "rgb(44, 160, 44)",
                       "%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
    }

    def _metricState(metric):
        """Return the (history deques, running maxima) dicts stored for a metric."""
        if metric == "Spikes":
            return super.SpikeGraphY, super.MaxSpike
        if metric == "Synapses":
            return super.SynapseGraphY, super.MaxSynapse
        return super.PotentialGraphY, super.MaxPotential

    X = 0
    Xlabel = ""
    try:
        if sliderValue is not None:
            Xlabel = "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(
                g.updateInterval, sliderValue+1)+"]"
        # Drop the initial [[]] placeholder before recording the first label.
        if len(super.xAxisLabel) == 1 and super.xAxisLabel[0] == []:
            super.xAxisLabel.clear()
        super.xAxisLabel.append(Xlabel)
        if len(super.InfoGraphX) > 0:
            X = super.InfoGraphX[-1]+1
        super.InfoGraphX.append(X)

        filters = generalGraphFilter or []
        layers = generalLayerFilter or []
        graphs = {layer: [] for layer in layers}
        annotations = []
        for layer in layers:
            # GeneralModuleData only stores results for recognised metrics, so
            # the position in `data` advances only when such a metric is handled.
            dataIndex = 0
            for metric in filters:
                if metric not in metricStyles:
                    continue
                dataKey, traceName, color, hoverPrefix = metricStyles[metric]
                history, maxima = _metricState(metric)
                entry = data[dataIndex] if data is not None and dataIndex < len(data) else None
                dataIndex += 1
                if layer not in history:
                    history[layer] = deque(maxlen=100)
                    maxima[layer] = 0
                history[layer].append(
                    entry[layer][dataKey] if entry is not None and layer in entry else 0)
                # Every metric is scaled against its running maximum, so the y
                # axis is a percentage of the "Max" shown in the hover text.
                maxima[layer] = max(maxima[layer], max(history[layer]))
                graphs[layer].append(
                    go.Scatter(
                        x=list(super.InfoGraphX),
                        y=[norm(value, maxima[layer]) for value in history[layer]],
                        fill='tozeroy' if len(filters) == 1 else 'none',
                        line=dict(color=color),
                        name=traceName + ('['+layer+']' if len(layers) > 1 else ''),
                        mode='lines+markers',
                        text=list(super.xAxisLabel),
                        customdata=list(history[layer]),
                        hovertemplate=hoverPrefix+str(maxima[layer])))

        if len(graphs) == 0:
            # No layer selected: return an empty placeholder figure.
            fig = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=[[{'rowspan': 1}]])
            fig.add_trace(
                go.Scatter(x=list(), y=list(), mode='lines',
                           line={"color": "#dc3545",
                                 "dash": "dot",
                                 "width": 2},
                           ), row=1, col=1)
            return fig

        if g.finalLabels is not None:
            loss = data[-1] if data else None
            super.LossGraphY.append(round(loss, 2) if loss is not None else None)
            # Rows 1-2 hold the loss chart, then each layer gets a 3-row subplot.
            fig = make_subplots(rows=2+(len(graphs)*3), cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=
                                [[{'rowspan': 2}], [None]] + ([[{'rowspan': 3}], [None], [None]] * len(graphs)))
            fig.add_trace(
                go.Scatter(x=list(super.InfoGraphX), y=list(super.LossGraphY), mode='lines',
                           text=[
                               str(t)+' %' for t in list(super.LossGraphY)],
                           hoverinfo='text',
                           connectgaps=False,
                           xaxis='x2',
                           line={"color": "#dc3545",
                                 "dash": "dot",
                                 "width": 2},
                           name='Loss'
                           ),
                row=1, col=1
            )
            firstRow, rowsPerLayer = 3, 3
        else:
            fig = make_subplots(
                rows=len(graphs), cols=1, shared_xaxes=True, vertical_spacing=0.05)
            firstRow, rowsPerLayer = 1, 1
        # Layer k starts at row firstRow + k*rowsPerLayer, matching the specs above.
        for position, layerTraces in enumerate(graphs.values()):
            for trace in layerTraces:
                fig.add_trace(trace, row=firstRow + position*rowsPerLayer, col=1)
        # The last subplot starts at row 3*n (with loss) or n (without).
        fig.update_xaxes(title_text="Step", row=(len(graphs) * 3) if g.finalLabels is not None else len(graphs), col=1)
        fig['layout'].update(
            yaxis=dict(range=[0, 105]),
            showlegend=True,
            uirevision='no reset of zoom',
            margin={'l': 0, 'r': 0, 't': 30, 'b': 25},
            annotations=annotations)
        return fig
    except Exception:
        print("processGeneralGraph " + traceback.format_exc())
        return None
def processLabelInfoTreemap(data):
    """ Build the treemap of the inputs seen in the current step.

    Args:
        data (list): [label counts dict, value shown as "Processed Inputs"]
            as returned by getNetworkInput, or None for an empty step
    Returns:
        dict with the treemap 'data' and 'layout', or None if an error
        occurred (the traceback is printed)
    """
    try:
        if data is None:
            # Empty step: clear the labels but keep the last total (super.Max).
            super.Label = dict()
        else:
            super.Max = data[1]
            super.Label = data[0]
        # super.Label may still be a list (presumably its initial []); treat
        # an empty list as "no labels".
        hasLabels = super.Label != []
        labelNames = list(super.Label.keys()) if hasLabels else []
        labelCounts = list(super.Label.values()) if hasLabels else []
        treemap = go.Treemap(
            labels=labelNames,
            textfont_size=20,
            textposition="middle center",
            parents=[""] * len(super.Label),
            values=labelCounts,
            hovertemplate='%{value}',
            name='Label',
            marker=dict(
                colors=list(labelNames),
                colorscale='RdBu',))
        layout = go.Layout(
            xaxis_type='category',
            title_text='Processed Inputs: <b>' + str(super.Max)+'</b>',
            uirevision='no reset of zoom',
            margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
        )
        return {'data': [treemap], 'layout': layout}
    except Exception:
        print("processLabelInfoTreemap:"+ traceback.format_exc())
# ----------------------------------------------------
# Callbacks
# ----------------------------------------------------
# Main Callback:
# 1- updates progress bar.
# 2- Listener for control buttons.
# 3- trigger update in all existing visualizations.
# NOTE(review): the indentation of this block was lost in this copy; the
# nesting described below (e.g. whether L283's assignment is inside the jump
# check, and that the label/return after the next-button branch applies to
# both buttons) is inferred and should be confirmed against the original.
@app.callback(
[Output("text", "children"), Output("vis-slider", "value"), Output("vis-slider", "step"), Output("vis-slider", "max"), Output("v-step", "children"),
Output("clear", "children"), Output("interval", "value")],
[Input("vis-update", "n_intervals"),
Input("btn-back", "n_clicks"), Input("btn-next", "n_clicks")],
[State("vis-slider", "value"), State("interval", "value"),
State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State("clear", "children")])
def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, clearGraphs):
""" Main step callback: advance or move the slider and refresh the step label.
It is triggered by the 'vis-update' interval and by the back/next buttons.
Args:
visUpdateInterval : n_intervals of 'vis-update' (only used as a trigger)
backButton (int): number of clicks on the back button (only used as a trigger)
nextButton (int): number of clicks on the next button (only used as a trigger)
sliderValue (int): value of the slider
updateInterval : requested update interval (s); converted with float()
generalGraphFilter (list): actual filter of GeneralGraph visualization
generalLayerFilter (list): selected layers for GeneralGraph visualization
clearGraphs (boolean): a dash state to pass information to all visualization about clearing content (if needed); it is toggled every time super.clearData() is called
Returns:
[label text, slider value, slider step (1), slider max (g.stepMax),
v-step value, clearGraphs, g.updateInterval]; None if an error occurred
(the traceback is printed)
"""
try:
# id of the component that triggered this call ("vis-update", "btn-back" or "btn-next")
context = callback_context.triggered[0]['prop_id'].split('.')[
0]
# update interval value if changed
# (the new value is clamped to [0.005, 180] and the stored data is cleared)
if(g.updateInterval != float(updateInterval)):
super.clearData()
clearGraphs = not clearGraphs
if(float(updateInterval) < 0.005):
g.updateInterval = 0.005
elif (float(updateInterval) > 180.0):
g.updateInterval = 180.0
else:
g.updateInterval = float(updateInterval)
# update slider value when interval changed
# (keep the same relative position within the new number of steps)
# NOTE(review): raises ZeroDivisionError if g.stepMax is 0 here.
sliderValue = sliderValue / g.stepMax
sliderValue = int(
sliderValue * (int(g.Max/g.updateInterval)+1))
g.stepMax = int(g.Max/g.updateInterval)+1
# A different metric or layer selection invalidates the stored history.
if (super.generalGraphFilterOld != generalGraphFilter):
super.generalGraphFilterOld = generalGraphFilter
super.clearData()
clearGraphs = not clearGraphs
if (super.generalLayerFilterOld != generalLayerFilter):
super.generalLayerFilterOld = generalLayerFilter
super.clearData()
clearGraphs = not clearGraphs
# A jump of more than 2 steps (presumably the user moving the slider)
# also clears the stored history.
if abs(super.oldSliderValue - sliderValue) > 2:
super.clearData()
clearGraphs = not clearGraphs
super.oldSliderValue = sliderValue
# Back/next buttons move the slider by one step.
if "btn" in context:
if context == "btn-back":
if sliderValue > 0:
sliderValue = sliderValue - 1
super.oldSliderValue = sliderValue
super.clearData()
clearGraphs = not clearGraphs
else:
# 'value' is None when the button has not actually been clicked.
if context == "btn-next" and callback_context.triggered[0]['value'] != None:
if(sliderValue < g.stepMax):
sliderValue = sliderValue + 1
super.label = "[ "+g.getLabelTime(g.updateInterval, sliderValue)+" , "+g.getLabelTime(
g.updateInterval, sliderValue+1)+" ]"
# NOTE(review): both branches return the same list.
if(super.visStopped):
return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]
else:
return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]
else:
# Interval tick: advance one step while playing and flag the stop
# once the step time reaches g.Max.
if(not super.visStopped):
sliderValue = sliderValue + 1
super.oldSliderValue = sliderValue
super.label = "[ "+g.getLabelTime(g.updateInterval, sliderValue)+" , "+g.getLabelTime(
g.updateInterval, sliderValue+1)+" ]"
if sliderValue*g.updateInterval >= g.Max:
super.visStopped = True
else:
super.visStopped = False
return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]
else:
# Stopped: only refresh the label for the current step.
super.label = "[ "+g.getLabelTime(g.updateInterval, sliderValue)+" , "+g.getLabelTime(
g.updateInterval, sliderValue+1)+" ]"
return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]
except Exception:
print("progress:" + traceback.format_exc())
# Callback that keeps the play/stop button in sync with the playback state
@app.callback(
    [Output("btnControle", "children"), Output("btnControle", "className"), Output("vis-update", "disabled")],
    [Input("vis-update", "n_intervals"), Input("btnControle", "n_clicks")],
    [State("vis-slider", "value"), State("btnControle", "children")])
def progressButton(visUpdateInterval, playButton, sliderValue, playButtonText):
    """ Keep the play/stop button and the step timer in sync.

    Triggered by every tick of 'vis-update' and by clicks on the button.

    Args:
        visUpdateInterval: n_intervals of 'vis-update' (only used as a trigger)
        playButton (int): number of clicks on the start/stop button (only used as a trigger)
        sliderValue (int): value of the slider
        playButtonText (str): current text of the button ("Start" or "Stop")
    Returns:
        [button text, button CSS class, whether 'vis-update' is disabled], or
        no_update when nothing has to change
    """
    triggeredBy = callback_context.triggered[0]['prop_id'].split('.')[0]
    lastStep = int(g.stepMax)
    if triggeredBy != "btnControle":
        # Timer tick: switch to the stopped state once the last step is reached.
        if sliderValue < lastStep:
            return no_update
        super.visStopped = True
        return ["Start", "btn btn-success", True]
    if playButtonText == "Start":
        if sliderValue >= lastStep:
            # Nothing left to play.
            super.visStopped = True
            return ["Start", "btn btn-success", True]
        super.visStopped = False
        return ["Stop", "btn btn-danger", False]
    # "Stop" was clicked.
    if sliderValue > lastStep:
        return no_update
    super.visStopped = True
    return ["Start", "btn btn-success", True]
# Callback to handle general graph content
@app.callback(
    [Output("general-graph", "figure")
     ], [Input("v-step", "children")],
    [State("interval", "value"), State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State('general-graph-switch', 'on')])
def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, generalGraphSwitchIsOn):
    """ Update the general graph.

    Args:
        sliderValue (int): value of the slider
        updateInterval (int): update Interval (s)
        generalGraphFilter (list): actual filter of GeneralGraph visualization
        generalLayerFilter (list): selected layers for GeneralGraph visualization
        generalGraphSwitchIsOn (bool): general graph switch value
    Returns:
        [figure] with general information on the network activity, or
        no_update when the graph is switched off, the step is already the
        last one plotted, the (stopped) slider is past g.stepMax, or the
        figure could not be built
    """
    if not generalGraphSwitchIsOn:
        return no_update
    # Skip a step whose label is already the last one plotted (avoids duplicate points).
    if len(super.xAxisLabel) > 0 and "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(g.updateInterval, sliderValue+1)+"]" == super.xAxisLabel[-1]:
        return no_update
    if super.visStopped:
        if sliderValue > g.stepMax:
            return no_update
        stepValue = int(sliderValue)
    else:
        stepValue = sliderValue
    generalData = GeneralModuleData(
        int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)
    # Pass the result through unchanged, even when it is [None]: the loss is
    # read from its last element, which does not exist on None.
    generalGraph = processGeneralGraph(
        generalData, stepValue, generalGraphFilter, generalLayerFilter)
    if generalGraph is None:
        # processGeneralGraph printed the error; keep the current figure.
        return no_update
    return [generalGraph]
# Callback that refreshes the label treemap on every step
@app.callback(
    [Output("label-graph", "figure")
     ], [Input("v-step", "children")],
    [State("interval", "value"), State('label-graph-switch', 'on')])
def progressLabelGraph(sliderValue, updateInterval, labelGraphSwitchIsOn):
    """ Update the label treemap graph.

    Args:
        sliderValue (int): value of the slider
        updateInterval (int): update Interval (s)
        labelGraphSwitchIsOn (bool): label graph switch value
    Returns:
        [treemap] describing the network input of the step, or no_update
        when the label graph is switched off. While playing, it also sets
        super.visStopped according to whether the step reached g.Max.
    """
    if not labelGraphSwitchIsOn:
        return no_update
    labelData = getNetworkInput(
        int(sliderValue)*float(updateInterval), g.updateInterval)
    if super.visStopped:
        return [processLabelInfoTreemap(labelData)]
    labelInfoTreemap = processLabelInfoTreemap(None if labelData == [None] else labelData)
    # Playing: flag the end of the recording once the step time reaches g.Max.
    super.visStopped = bool(sliderValue*g.updateInterval >= g.Max)
    return [labelInfoTreemap]
# Callback to update the speed of visualization
@app.callback(
    [Output("vis-update", "interval"), Output("speed", "value")],
    [Input("speed", "value")])
def speedControle(speedValue):
    """ Clamp the requested speed and store it in the shared holder 'vis-update'.

    Args:
        speedValue (float): requested speed; None when the number input is
            empty or invalid
    Returns:
        [speedValue * 1000 for the 'vis-update' interval (dcc.Interval expects
        milliseconds), clamped speed value], or no_update when there is no
        usable value
    """
    # An emptied/invalid number input yields None; keep the current speed.
    if speedValue is None:
        return no_update
    try:
        speedValue = min(max(speedValue, 0.25), 120)
        return [speedValue * 1000, speedValue]
    except Exception:
        print("speedControle:" + traceback.format_exc())
        return no_update
# Callback that opens or closes the information tab
@app.callback(
    [Output("collapse-info", "is_open")],
    [Input("group-info-toggle", "n_clicks")], [State("collapse-info", "is_open")])
def informationTab(nbrClicks, isTabOpen):
    """ Toggle the information tab.

    Args:
        nbrClicks (int): number of clicks on the tab toggle (the trigger value
            is read from callback_context instead)
        isTabOpen (bool): current open state of the tab
    Returns:
        [new open state]: flipped when the trigger carries a value, unchanged
        otherwise; None if an error occurred (the traceback is printed)
    """
    try:
        triggerValue = callback_context.triggered[0]["value"]
        if triggerValue is None:
            return [isTabOpen]
        return [not isTabOpen]
    except Exception:
        print("informationTabController:" + traceback.format_exc())
# Callback to handle the 2D view spiking visualization
# NOTE(review): indentation was lost in this copy; the nesting (e.g. whether
# the matrix/indices assignments run for every selected layer) is inferred
# and should be confirmed against the original source.
@app.callback(
Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'),
Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
""" Function called each step (and on node hover) to update the 2D view
Args:
visUpdateInterval : n_intervals of 'vis-update' (only used as a trigger)
sliderValue (int): value of the slider
mouseOverNodeData : data of the hovered cytoscape node
updateInterval (int): update Interval (s)
elements : cytoscape elements (ignored: replaced by super.Spikes2D)
Layer2DViewFilter : selected layers
Returns:
[cytoscape elements, text for 'spikes_info' ([] on step updates, hover
text otherwise), one heatmap figure per layer of
super.AccumulatedSpikes2D, one per-input activity figure per layer of
super.SpikesActivityPerInput]; no_update if the hover branch fails;
None if the step branch fails (tracebacks are printed)
"""
try:
elements = super.Spikes2D
matrix = {}
indices = {}
labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval)
# A step/interval trigger recomputes the spikes; any other trigger (the
# hover input) only re-renders from the accumulated data.
if callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]:
# Reset per-input activity: for every layer except "Input", one row per
# neuron and g.nbrClasses+1 columns.
super.SpikesActivityPerInput = {i:[[0 for j in range(g.nbrClasses+1)] for i in range(g.Layer_Neuron[i])] for i in g.Layer_Neuron if i != "Input"}
# Reset the per-step spike values of every node not marked spiked == -1
# (-1 presumably marks non-neuron nodes; confirm).
for element in elements:
if element["data"]['spiked'] != -1:
element["data"]["spiked"] = 0
element["data"]["spikes"] = 0
# perNeuron=True: list of {layer: {neuron: {input: count}}} dicts
spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,Layer2DViewFilter,True)
for layer in Layer2DViewFilter:
#neurons = [[0 for j in range(g.nbrClasses)] for i in range(g.Layer_Neuron[layer])]
# Spike counts of this layer; falls back to [] when spikes is None.
try:
layerSpikes = [list(list(list(s.values())[0].values())[0].values())[0] for s in spikes if list(s.keys())[0] == layer]
except Exception:
layerSpikes = []
if spikes and layerSpikes:
# maxSpike scales each node's 'spiked' value to [0, 1].
maxSpike = max(layerSpikes)
for spike in spikes:
if list(spike.keys())[0] == layer:
# update the spikes neurons
# NOTE(review): i is incremented but never read.
i = 0
# Linear scan over all nodes for each spike record: the node
# matches when its id is layer+neuron and its label is the neuron.
for element in elements:
if element["data"]['spiked'] != -1:
if (element["data"]["id"] == layer+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
spk = list(list(list(spike.values())[0].values())[0].values())[0]
element["data"]["spiked"] = round(spk / maxSpike,2)
element["data"]["spikes"] = spk
# Accumulate across steps for the heatmap and record the count
# at [neuron][input] for the per-input activity figure.
super.AccumulatedSpikes2D[layer][int(element["data"]["label"])] += spk
super.SpikesActivityPerInput[layer][list(list(spike.values())[0].keys())[0]][list(list(list(spike.values())[0].values())[0].keys())[0]] = spk
i+=1
# Heatmap matrix of the accumulated spikes and the matching neuron indices.
matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
# Unselected layers get empty heatmaps so one figure exists per layer
# of super.AccumulatedSpikes2D.
if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
for layer in super.AccumulatedSpikes2D:
if layer not in matrix:
matrix[layer] = []
indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput]
return [elements,[],heatmaps,SpikesActivityPerInput]
else:
# Hover: rebuild the figures from the accumulated data and show the
# hovered node's spike count. If mouseOverNodeData is None, the 'in'
# test raises and the except below returns no_update.
try:
for layer in Layer2DViewFilter:
matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
for layer in super.AccumulatedSpikes2D:
if layer not in matrix:
matrix[layer] = []
indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput]
return [elements,f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}" if 'spikes' in mouseOverNodeData else "", heatmaps,SpikesActivityPerInput]
except Exception:
print("OnHover:"+traceback.format_exc())
return no_update
except Exception:
print("animation2DViewController:" + traceback.format_exc())
# Errors raised while defining/registering the callbacks above are printed,
# not propagated.
except Exception:
print("Done loading:"+traceback.format_exc())
# The helpers below are referenced by the callbacks above through the
# enclosing __init__ scope.
try:
# ---------------------------------------------------------
# Helper functions
# ---------------------------------------------------------
def norm(data, Max):
    """ Express a value as a percentage of Max.

    Args:
        data (number): value to scale
        Max (number): reference maximum
    Returns:
        data * 100 / Max, or data unchanged when Max is 0
    """
    if Max == 0:
        return data
    return data * 100 / Max
def GeneralModuleData(timestamp, filter, layers):
    """ Query the data needed by the general graph.

    Args:
        timestamp (float): start of the time window
        filter (list): selected metrics ("Spikes", "Synapses", "Potentials");
            other values are ignored
        layers (list): selected layers
    Returns:
        list with one query result per recognised metric, in filter order,
        followed by the loss value from getLoss
    """
    queries = {
        "Spikes": lambda: getSpike(timestamp, g.updateInterval, layers, False),
        "Synapses": lambda: getSynapse(timestamp, g.updateInterval, layers),
        "Potentials": lambda: getPotential(timestamp, g.updateInterval, layers),
    }
    results = [queries[metric]() for metric in filter if metric in queries]
    # The loss is always the last element; processGeneralGraph reads data[-1].
    results.append(getLoss(timestamp, g.updateInterval))
    return results
def make_SpikeActivityPerInput(layer,dataLabel):
    """ Build the spike-activity-per-input figure of one layer.

    Args:
        layer (str): layer key in super.SpikesActivityPerInput
        dataLabel (list): [label counts dict, ...] as returned by
            getNetworkInput, or None
    Returns:
        Plotly figure: heatmap of spikes per neuron and class (rows 1-4)
        above a bar chart of the current input labels (row 5)
    """
    fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08,
                        specs=[[{'rowspan': 4}], [None], [None], [None], [{'rowspan': 1}]])
    fig.add_trace(
        go.Heatmap(z=super.SpikesActivityPerInput[layer], colorscale='Reds',
                   hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),
                   xgap=4, ygap=4),
        row=1, col=1)
    fig.update_layout({
        "xaxis": dict(title="Class", tickmode="array", zeroline=False,
                      tickvals=list(range(g.nbrClasses+1))),
        "yaxis": dict(title="Neuron", zeroline=False),
        "margin": {'l': 0, 'r': 0, 't': 5, 'b': 0},
        "uirevision": 'no reset of zoom'})
    if dataLabel is None:
        barLabels, barCounts = [], []
    else:
        barLabels, barCounts = list(dataLabel[0].keys()), list(dataLabel[0].values())
    fig.add_trace(
        go.Bar(x=barLabels, y=barCounts,
               hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),
        row=5, col=1)
    fig.update_xaxes(tickvals=list(range(g.nbrClasses+1)))
    return fig
# ---------------------------------------------------------
# MongoDB operations
# ---------------------------------------------------------
def getNetworkInput(timestamp, interval):
    """ Summarise the inputs presented to the network in a time window.

    Args:
        timestamp (float): start of the window (exclusive)
        interval (float): length of the window; documents with
            timestamp < T <= timestamp + interval are counted
    Returns:
        [counts, maxG] where counts maps each label (L) to its number of
        documents and maxG is the largest G value seen (at least 0);
        None when the window is empty
    """
    pipeline = [
        {"$match": {
            "T": {'$gt': timestamp, '$lte': (timestamp+interval)}
        }},
        {"$group": {"_id": "$L", "C": {"$sum": 1}, "G": {"$max": "$G"}}},
        {"$sort": {"_id": 1}}
    ]
    cursor = pymongo.collection.Collection(g.db, 'labels').aggregate(pipeline, allowDiskUse=True)
    # Materialise the cursor through BSON JSON (same conversion as elsewhere).
    groups = loads(dumps(cursor))
    if not groups:
        return None
    counts = {group["_id"]: group["C"] for group in groups}
    return [counts, max([0] + [group["G"] for group in groups])]
def getSpike(timestamp, interval, layer, perNeuron):
    """ Count spikes recorded in the window (timestamp, timestamp + interval].

    Args:
        timestamp (float): start of the window (exclusive)
        interval (float): length of the window
        layer (list): names of the layers to include
        perNeuron (bool): count per (layer, neuron, input) when True,
            per layer otherwise
    Returns:
        perNeuron=True: list of {layer: {neuron: {input: count}}} dicts,
        ordered by the aggregation _id;
        perNeuron=False: dict mapping each layer to its aggregation document
        {"_id": layer, "spikes": count};
        None when no spike matches
    """
    matchStage = {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}, {"i.L": {'$in': layer}}]}}
    if perNeuron:
        groupKey = {"L": "$i.L", "N": "$i.N", "Input": "$Input"}
    else:
        groupKey = "$i.L"
    pipeline = [matchStage, {"$group": {"_id": groupKey, "spikes": {"$sum": 1}}}, {"$sort": {"_id": 1}}]
    cursor = pymongo.collection.Collection(g.db, 'spikes').aggregate(pipeline, allowDiskUse=True)
    documents = loads(dumps(cursor))
    if perNeuron:
        result = [{doc["_id"]["L"]: {doc["_id"]["N"]: {doc["_id"]["Input"]: doc["spikes"]}}} for doc in documents]
    else:
        result = {doc["_id"]: doc for doc in documents}
    return result if result else None
def getSynapse(timestamp, interval, layer):
    """ Count synapse-weight records per layer in a time window.

    Args:
        timestamp (float): start of the window (exclusive)
        interval (float): length of the window; records with
            timestamp < T <= timestamp + interval are counted
        layer (list): names of the layers to include
    Returns:
        dict mapping each layer to its aggregation document
        {"_id": layer, "synapseUpdate": count}, or None when nothing matches
    """
    stages = [
        {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}, {"L": {'$in': layer}}]}},
        {"$group": {"_id": "$L", "synapseUpdate": {"$sum": 1}}},
        {"$sort": {"_id": 1}},
    ]
    cursor = pymongo.collection.Collection(g.db, 'synapseWeight').aggregate(stages, allowDiskUse=True)
    records = loads(dumps(cursor))
    perLayer = {record["_id"]: record for record in records}
    return perLayer or None
def getPotential(timestamp, interval, layer):
    """ Count potential records per layer in a time window.

    Args:
        timestamp (float): start of the window (exclusive)
        interval (float): length of the window; records with
            timestamp < T <= timestamp + interval are counted
        layer (list): names of the layers to include
    Returns:
        dict mapping each layer to its aggregation document
        {"_id": layer, "potential": count}, or None when nothing matches
    """
    stages = [
        {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}, {"L": {'$in': layer}}]}},
        {"$group": {"_id": "$L", "potential": {"$sum": 1}}},
        {"$sort": {"_id": 1}},
    ]
    cursor = pymongo.collection.Collection(g.db, 'potential').aggregate(stages, allowDiskUse=True)
    records = loads(dumps(cursor))
    perLayer = {record["_id"]: record for record in records}
    return perLayer or None
def getLoss(timestamp, interval):
    """ Percentage of mislabelled spikes in a time window.

    A spike adds one to the loss for every entry of g.finalLabels that refers
    to the same neuron (N) and layer (L) but whose Label differs from the
    spike's Input.

    Args:
        timestamp (float): start of the window (exclusive)
        interval (float): length of the window
    Returns:
        loss in percent, rounded to 2 decimals and capped at 100; None when
        g.finalLabels is None or no spike falls in the window
    """
    if g.finalLabels is None:
        return None
    # MongoDB---------------------
    col = pymongo.collection.Collection(g.db, 'spikes')
    spikes = col.find(
        {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}})
    # ToJson----------------------
    spikes = loads(dumps(spikes))
    # ----------------------------
    if not spikes:
        return None
    # Index the final labels by (neuron, layer) once instead of scanning the
    # whole list for every spike. Entries are kept in lists so duplicated
    # (neuron, layer) entries each still contribute to the loss.
    labelsByNeuron = {}
    for finalLabel in g.finalLabels:
        labelsByNeuron.setdefault((str(finalLabel["N"]), finalLabel["L"]), []).append(finalLabel)
    loss = 0
    for spike in spikes:
        matches = labelsByNeuron.get((str(spike["i"]["N"]), str(spike["i"]["L"])))
        if not matches:
            continue
        presentedInput = str(spike["Input"])
        loss += sum(1 for finalLabel in matches if str(finalLabel["Label"]) != presentedInput)
    return min(100, round((loss*100) / len(spikes), 2))
# ---------------------------------------------------------------------
# Errors raised while defining the helper functions are printed, not
# propagated.
except Exception:
print("Helper functions and MongoDB operations: "+ traceback.format_exc())