Select Git revision
test_gnt.cpp
Forked from
coa-2024 / coa-tp4-graphs
Source project has a limited visibility.
-
Maxime Lalisse authoredMaxime Lalisse authored
callbacks.py 41.86 KiB
""" This module defines the Dash callbacks.
Dash callbacks are responsible for updating the graphs at each step.
"""
from collections import deque
import dash
import pymongo
import numpy as np
from bson.json_util import dumps
from bson.json_util import loads
import plotly.graph_objects as go
from dash.exceptions import PreventUpdate
from plotly.subplots import make_subplots
from dash.dependencies import Input, Output, State, MATCH
from src.templates.callbacksOp import callbacksOp
class callbacks(callbacksOp):
""" Callbacks class
"""
def __init__(self, super, app, g):
""" Initialize the callback .
Args:
app : Flask app
g (Global_Var): reference to access global variables
"""
# ------------------------------------------------------------
# Graph build functions
# ------------------------------------------------------------
# to prevent creating duplicate callbacks next time
if not g.checkExistance(app, "vis-slider"):
try:
def processGeneralGraph(data, sliderValue, generalGraphFilter, generalLayerFilter):
    """ Create general graph components.

    Args:
        data (list): one entry per selected filter (a ``{layer: record}``
            dict or None), followed by the loss value, as built by
            GeneralModuleData; None when there is nothing to plot
        sliderValue (int): value of the slider
        generalGraphFilter (list): actual filter of GeneralGraph visualization
        generalLayerFilter (list): selected layers for GeneralGraph visualization

    Returns:
        general graph content (plotly figure), or None when an exception
        was caught and logged
    """
    # filter name -> (series attribute, running-max attribute, record key,
    #                 line colour, legend name, hover-template prefix)
    traceSpecs = {
        "Spikes": ("SpikeGraphY", "MaxSpike", "spikes",
                   "rgb(31, 119, 180)", 'Spikes', ""),
        "Synapses": ("SynapseGraphY", "MaxSynapse", "synapseUpdate",
                     "rgb(255, 127, 14)", 'Synapses activity',
                     "<span style='color:white;'>"),
        "Potentials": ("PotentialGraphY", "MaxPotential", "potential",
                       "rgb(44, 160, 44)", 'Neuron\'s potential', ""),
    }

    def _valueAt(position, layer, key):
        """ Return the value stored for `layer` in data[position], 0 if absent. """
        entry = data[position] if data is not None else None
        if entry is not None and layer in entry:
            return entry[layer][key]
        return 0

    def _activityTrace(position, layer, filterName):
        """ Record the new sample of one layer/filter and build its trace. """
        seriesAttr, maxAttr, key, color, title, hoverPrefix = traceSpecs[filterName]
        seriesByLayer = getattr(super, seriesAttr)
        maxByLayer = getattr(super, maxAttr)
        if layer not in seriesByLayer:
            seriesByLayer[layer] = deque(maxlen=100)
            maxByLayer[layer] = 0
        series = seriesByLayer[layer]
        series.append(_valueAt(position, layer, key))
        maxByLayer[layer] = max(maxByLayer[layer], max(series))
        # Every series is scaled by its running maximum, the same value
        # that is shown in the hover text.
        peak = maxByLayer[layer]
        return go.Scatter(
            x=list(super.InfoGraphX),
            y=[norm(value, peak) for value in series],
            fill='tozeroy' if len(generalGraphFilter) == 1 else 'none',
            line=dict(color=color),
            name=title + ('['+layer+']' if len(generalLayerFilter) > 1 else ''),
            mode='lines+markers',
            text=list(super.xAxisLabel),
            customdata=list(series),
            hovertemplate=hoverPrefix +
            "%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(peak))

    X = 0
    Xlabel = ""
    try:
        if sliderValue is not None:
            Xlabel = "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(
                g.updateInterval, sliderValue+1)+"]"
        # Drop the placeholder entry before storing the first real label.
        if len(super.xAxisLabel) == 1 and super.xAxisLabel[0] == []:
            super.xAxisLabel.clear()
        super.xAxisLabel.append(Xlabel)
        if len(super.InfoGraphX) > 0:
            X = super.InfoGraphX[-1]+1
        super.InfoGraphX.append(X)

        graphs = {layer: [] for layer in generalLayerFilter}
        for layer in generalLayerFilter:
            # data[position] matches the position of the filter in the list
            for position, filterName in enumerate(generalGraphFilter):
                if filterName in traceSpecs:
                    graphs[layer].append(
                        _activityTrace(position, layer, filterName))

        if len(graphs) != 0:
            if g.finalLabels is not None:
                lossValue = data[-1] if data is not None else None
                super.LossGraphY.append(
                    round(lossValue, 2) if lossValue is not None else None)
                # Two rows for the loss curve, then three rows per layer.
                fig = make_subplots(
                    rows=2+(len(graphs)*3), cols=1, shared_xaxes=True, vertical_spacing=0.05,
                    specs=[[{'rowspan': 2}], [None]] + ([[{'rowspan': 3}], [None], [None]] * len(graphs)))
                fig.add_trace(
                    go.Scatter(x=list(super.InfoGraphX), y=list(super.LossGraphY), mode='lines',
                               text=[str(t)+' %' for t in list(super.LossGraphY)],
                               hoverinfo='text',
                               connectgaps=False,
                               xaxis='x2',
                               line={"color": "#dc3545",
                                     "dash": "dot",
                                     "width": 2},
                               name='Loss'),
                    row=1, col=1)
                # Layer subplots start at rows 3, 6, 9, ... (3-row spans).
                firstRow, rowStep = 3, 3
            else:
                fig = make_subplots(
                    rows=len(graphs), cols=1, shared_xaxes=True, vertical_spacing=0.05)
                firstRow, rowStep = 1, 1

            for offset, traces in enumerate(graphs.values()):
                for trace in traces:
                    fig.add_trace(trace, row=firstRow + offset*rowStep, col=1)

            lastRow = firstRow + (len(graphs)-1)*rowStep
            fig.update_xaxes(title_text="Step", row=lastRow, col=1)
            fig['layout'].update(
                yaxis=dict(range=[0, 105]),
                showlegend=True,
                uirevision='no reset of zoom',
                margin={'l': 0, 'r': 0, 't': 30, 'b': 25},
                annotations=[])
        else:
            # No layer selected: return an empty placeholder figure.
            fig = make_subplots(rows=1, cols=1, shared_xaxes=True,
                                vertical_spacing=0.05, specs=[[{'rowspan': 1}]])
            fig.add_trace(
                go.Scatter(x=list(), y=list(), mode='lines',
                           line={"color": "#dc3545",
                                 "dash": "dot",
                                 "width": 2},
                           ), row=1, col=1)
        return fig
    except Exception as e:
        print("processGeneralGraph "+str(e))
def processLabelInfoTreemap(data):
    """ Build the Treemap figure describing the inputs being processed.

    Args:
        data (array): ``[labels, total]`` where ``labels`` maps each label to
            its count and ``total`` is the number of processed inputs, or
            None when no input is available

    Returns:
        dict with the 'data' and 'layout' of the treemap, or None when an
        exception was caught and logged
    """
    try:
        if data is None:
            super.Label = dict()
        else:
            super.Max = data[1]
            super.Label = data[0]

        hasLabels = super.Label != []
        labelNames = list(super.Label.keys()) if hasLabels else []
        labelCounts = list(super.Label.values()) if hasLabels else []

        treemap = go.Treemap(
            labels=labelNames,
            textfont_size=20,
            textposition="middle center",
            # every label sits at the root of the treemap
            parents=[""] * len(super.Label),
            values=labelCounts,
            hovertemplate='%{value}',
            name='Label',
            marker=dict(colors=list(labelNames), colorscale='RdBu',))
        layout = go.Layout(
            xaxis_type='category',
            title_text='Processed Inputs: <b>' + str(super.Max)+'</b>',
            uirevision='no reset of zoom',
            margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
        )
        return {'data': [treemap], 'layout': layout}
    except Exception as e:
        print("processLabelInfoTreemap:"+str(e))
# ----------------------------------------------------
# Callbacks
# ----------------------------------------------------
# Main Callback:
# 1- updates progress bar.
# 2- Listener for control buttons.
# 3- trigger update in all existing visualizations.
@app.callback(
    [Output("text", "children"), Output("vis-slider", "value"), Output("vis-slider", "step"), Output("vis-slider", "max"), Output("v-step", "children"),
     Output("clear", "children"), Output("interval", "value")],
    [Input("vis-update", "n_intervals"),
     Input("btn-back", "n_clicks"), Input("btn-next", "n_clicks")],
    [State("vis-slider", "value"), State("interval", "value"),
     State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State("clear", "children")])
def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, clearGraphs):
    """ This is the callback function. It is called each step.

    Args:
        visUpdateInterval : interval instance that will cause this function to be called each step
        backButton (int): number of clicks on the back button
        nextButton (int): number of clicks on the next button
        sliderValue (int): value of the slider
        updateInterval (int): update Interval (s)
        generalGraphFilter (list): actual filter of GeneralGraph visualization
        generalLayerFilter (list): selected layers for GeneralGraph visualization
        clearGraphs (boolean): a dash state to pass information to all visualization about clearing content (if needed)

    Raises:
        PreventUpdate: when an unexpected error was caught and logged, so
            that no output is changed

    Returns:
        array of outputs that are selected in the callback
    """
    try:
        trigger = dash.callback_context.triggered[0]
        context = trigger['prop_id'].split('.')[0]
        # Several reasons may ask for a reset in the same call; the flag is
        # flipped only once at the end, otherwise two reasons would cancel
        # each other out and the dependent views would not be cleared.
        resetNeeded = False

        # update interval value if changed
        requestedInterval = float(updateInterval)
        if g.updateInterval != requestedInterval:
            super.clearData()
            resetNeeded = True
            # keep the interval inside [0.005, 180] seconds
            g.updateInterval = min(max(requestedInterval, 0.005), 180.0)
            # keep the slider at the same relative position
            newStepMax = int(g.Max/g.updateInterval)+1
            sliderValue = int((sliderValue / g.stepMax) * newStepMax)
            g.stepMax = newStepMax

        if super.generalGraphFilterOld != generalGraphFilter:
            super.generalGraphFilterOld = generalGraphFilter
            super.clearData()
            resetNeeded = True

        if super.generalLayerFilterOld != generalLayerFilter:
            super.generalLayerFilterOld = generalLayerFilter
            super.clearData()
            resetNeeded = True

        # the slider moved by more than two steps since the last call
        if abs(super.oldSliderValue - sliderValue) > 2:
            super.clearData()
            resetNeeded = True
        super.oldSliderValue = sliderValue

        if "btn" in context:
            if context == "btn-back":
                if sliderValue > 0:
                    sliderValue = sliderValue - 1
                    super.oldSliderValue = sliderValue
                    super.clearData()
                    resetNeeded = True
            elif context == "btn-next" and trigger['value'] is not None:
                if sliderValue < g.stepMax:
                    sliderValue = sliderValue + 1
        elif not super.visStopped:
            # periodic tick while playing: advance one step and stop
            # once the end of the recorded data is reached
            sliderValue = sliderValue + 1
            super.oldSliderValue = sliderValue
            super.visStopped = bool(sliderValue*g.updateInterval >= g.Max)

        super.label = "[ "+g.getLabelTime(g.updateInterval, sliderValue)+" , "+g.getLabelTime(
            g.updateInterval, sliderValue+1)+" ]"
        if resetNeeded:
            clearGraphs = not clearGraphs
        return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]
    except Exception as e:
        print("progress:" + str(e))
        # Returning None would not match the seven declared outputs.
        raise PreventUpdate from e
# Callback to controle play/stop button
@app.callback(
    [Output("btnControle", "children"), Output("btnControle", "className"),Output("vis-update", "disabled")],
    [Input("vis-update", "n_intervals"),Input("btnControle", "n_clicks")],[State("vis-slider", "value"),State("btnControle", "children")])
def progressButton(visUpdateInterval,playButton,sliderValue,playButtonText):
    """ Keep the play/stop button in sync with the visualization state.

    Called on every tick of the interval and on every click on the button.

    Args:
        visUpdateInterval : interval instance that will cause this function to be called each step
        playButton (int): number of clicks on the start/stop button
        sliderValue (int): value of the slider
        playButtonText (String): text on the start/stop button

    Raises:
        PreventUpdate: when the button must stay as it is

    Returns:
        [button text, button css class, whether the interval is disabled]
    """
    triggerId = dash.callback_context.triggered[0]['prop_id'].split('.')[0]
    lastStep = int(g.stepMax)
    buttonClicked = triggerId == "btnControle"

    # "Start" pressed before the end: resume playing.
    if buttonClicked and playButtonText == "Start" and lastStep > sliderValue:
        super.visStopped = False
        return ["Stop", "btn btn-danger", False]

    if buttonClicked and playButtonText != "Start":
        # "Stop" pressed
        mustStop = lastStep >= sliderValue
    else:
        # periodic tick, or "Start" pressed at the end of the data
        mustStop = lastStep <= sliderValue

    if not mustStop:
        raise PreventUpdate
    super.visStopped = True
    return ["Start", "btn btn-success", True]
# Callback to handle general graph content
@app.callback(
    [Output("general-graph", "figure")
     ], [Input("v-step", "children")],
    [State("interval", "value"), State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State('general-graph-switch', 'on')])
def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, generalGraphSwitchIsOn):
    """ Update the general graph.

    Args:
        sliderValue (int): value of the slider
        updateInterval (int): update Interval (s)
        generalGraphFilter (list): actual filter of GeneralGraph visualization
        generalLayerFilter (list): selected layers for GeneralGraph visualization
        generalGraphSwitchIsOn (bool): general graph switch value

    Raises:
        PreventUpdate: when the graph is switched off, when the current
            step is already the last one drawn, or when the visualization
            is stopped past the last step

    Returns:
        content of the graph that contains general information on the network activity
    """
    if not generalGraphSwitchIsOn:
        raise PreventUpdate

    # Nothing to do when this step is already the last one on the x axis.
    if len(super.xAxisLabel) > 0 and "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(g.updateInterval, sliderValue+1)+"]" == super.xAxisLabel[-1]:
        raise PreventUpdate

    isPlaying = not super.visStopped
    if not isPlaying and sliderValue > g.stepMax:
        raise PreventUpdate

    generalData = GeneralModuleData(
        int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)
    if isPlaying:
        # a result holding only an empty loss value is passed on as "no data"
        graphInput = None if generalData == [None] else generalData
        step = sliderValue
    else:
        graphInput = generalData
        step = int(sliderValue)

    return [processGeneralGraph(
        graphInput, step, generalGraphFilter, generalLayerFilter)]
# Callback to handle label graph content
@app.callback(
    [Output("label-graph", "figure")
     ], [Input("v-step", "children")],
    [State("interval", "value"), State('label-graph-switch', 'on')])
def progressLabelGraph(sliderValue, updateInterval, labelGraphSwitchIsOn):
    """ Update the label treemap graph.

    Args:
        sliderValue (int): value of the slider
        updateInterval (int): update Interval (s)
        labelGraphSwitchIsOn (bool): label graph switch value

    Raises:
        PreventUpdate: when the label graph is switched off

    Returns:
        content of the Treemap that contains information on the network input
    """
    if not labelGraphSwitchIsOn:
        raise PreventUpdate

    labelData = getNetworkInput(
        int(sliderValue)*float(updateInterval), g.updateInterval)

    if super.visStopped:
        return [processLabelInfoTreemap(labelData)]

    treemapInput = None if labelData == [None] else labelData
    labelInfoTreemap = processLabelInfoTreemap(treemapInput)
    # stop the visualization once the end of the recorded data is reached
    super.visStopped = bool(sliderValue*g.updateInterval >= g.Max)
    return [labelInfoTreemap]
# Callback to update the speed of visualization
@app.callback(
    [Output("vis-update", "interval"),Output("speed", "value")],
    [Input("speed", "value")])
def speedControle(speedValue):
    """ Store speed value in the shared holder 'vis-update'

    The speed is clamped to the range [0.25, 5] seconds.

    Args:
        speedValue (int): actual speed value

    Raises:
        PreventUpdate: when the speed is missing or not a number (e.g. the
            input field is empty), so that both outputs keep their value

    Returns:
        [interval in milliseconds, clamped speed value]
    """
    try:
        clampedSpeed = min(max(speedValue, 0.25), 5)
    except TypeError as e:
        print("speedControle:" + str(e))
        # Returning None would not match the two declared outputs.
        raise PreventUpdate from e
    return [clampedSpeed * 1000, clampedSpeed]
# Callback to handle information tab (open or close)
@app.callback(
    [Output("collapse-info", "is_open")],
    [Input("group-info-toggle", "n_clicks")], [State("collapse-info", "is_open")])
def informationTab(nbrClicks, isTabOpen):
    """ Open or close the information tab when its header is clicked.

    Args:
        nbrClicks (int): number of clicks on the tab
        isTabOpen (bool): whether tab is open or not

    Returns:
        one-element list with the new open state; the state is left
        unchanged when the trigger carries no value
    """
    try:
        wasClicked = dash.callback_context.triggered[0]["value"] is not None
        return [not isTabOpen] if wasClicked else [isTabOpen]
    except Exception as e:
        print("informationTabController:" + str(e))
# Callback to handle the 2D view spiking visualization
@app.callback(
    Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"),
    Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'tapNodeData'),
    State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data"))
def animation2DView(visUpdateInterval,sliderValue, tapNodeData, updateInterval, elements, StoredData):
    """ Function called each step to update the 2D view

    Args:
        visUpdateInterval : interval instance that will cause this function to be called each step
        sliderValue (int): value of the slider
        tapNodeData : contains data of the clicked node
        updateInterval (int): update Interval (s)
        elements : nodes description
        StoredData : data stored for shared access; index 0 holds the
            elements and index 1 the per-neuron values of the heatmap

    Returns:
        [elements, layout, info on the clicked neuron, heatmap figure,
        StoredData], or None when an exception was caught and logged
    """
    def _heatmapFigure():
        """ Build the heatmap figure from the stored per-neuron values. """
        return {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}}

    try:
        triggerId = dash.callback_context.triggered[0]['prop_id'].split('.')[0]

        if triggerId != "v-step":
            # Timer tick or click on a node: only the info text may change.
            try:
                neuronInfo = f"Neuron {tapNodeData['label']} : {tapNodeData['spikes']}"
            except Exception as e:
                # no node was clicked yet
                neuronInfo = []
            return [elements,{'name': 'grid','animate': False},neuronInfo,_heatmapFigure(),StoredData]

        elements = StoredData[0]
        spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True)
        if spikes:
            # each item has the form {layer: {neuron: number of spikes}}
            spikeCounts = [list(list(s.values())[0].values())[0] for s in spikes]
            maxSpike = max(spikeCounts)
            for spike, count in zip(spikes, spikeCounts):
                if list(spike.keys())[0] != "Layer1":
                    continue
                neuron = str(list(list(spike.values())[0].keys())[0])
                # update the spikes neurons
                for index, element in enumerate(elements[1:]):
                    nodeData = element["data"]
                    if (nodeData["id"] == "Layer1_"+neuron) and (nodeData["label"] == neuron):
                        nodeData["spiked"] = round(count / maxSpike,2)
                        nodeData["spikes"] = count
                        StoredData[1][index] = count
        return [elements,{'name': 'grid','animate': False},[],_heatmapFigure(),StoredData]
    except Exception as e:
        print("animation2DViewController:" + str(e))
except Exception as e:
print("Done loading:"+str(e))
try:
# ---------------------------------------------------------
# Helper functions
# ---------------------------------------------------------
def norm(data, Max):
    """ Express a value as a percentage of Max.

    Args:
        data: value to scale
        Max (int): value that corresponds to 100

    Returns:
        data scaled to a percentage of Max, or data unchanged when Max is 0
    """
    if Max == 0:
        # avoid dividing by zero
        return data
    return (data * 100)/Max
def GeneralModuleData(timestamp, filter, layers):
    """ Collect the data of every selected filter, followed by the loss.

    Args:
        timestamp (int): timestamp value
        filter (array): array of selected filters
        layers (array): array of selected layers

    Returns:
        list with one result per recognised filter, in the order of
        `filter`, and the loss value as last element
    """
    fetchers = {
        "Spikes": lambda: getSpike(timestamp, g.updateInterval,layers,False),
        "Synapses": lambda: getSynapse(timestamp, g.updateInterval,layers),
        "Potentials": lambda: getPotential(timestamp, g.updateInterval,layers),
    }
    res = [fetchers[name]() for name in filter if name in fetchers]
    # get loss value
    res.append(getLoss(timestamp, g.updateInterval))
    return res
# ---------------------------------------------------------
# MongoDB operations
# ---------------------------------------------------------
def getNetworkInput(timestamp, interval):
    """ Get network input for a given timestamp and interval.

    Args:
        timestamp (int): timestamp value
        interval (int): interval value

    Returns:
        ``[counts, total]`` where ``counts`` maps each label to the number
        of inputs in ]timestamp, timestamp+interval] and ``total`` is the
        highest "G" value found in that window; None when the window is
        empty
    """
    # MongoDB --------------------------------------------
    labelsCollection = pymongo.collection.Collection(g.db, 'labels')
    pipeline = [
        {"$match": {
            "T": {'$gt': timestamp, '$lte': (timestamp+interval)}
        }},
        {"$group": {"_id": "$L", "C": {"$sum": 1}, "G": {"$max": "$G"}}},
        {"$sort": {"_id": 1}}
    ]
    # ToJson ---------------------------------------------
    rows = loads(dumps(labelsCollection.aggregate(pipeline)))
    # ----------------------------------------------------
    if not rows:
        return None

    total = 0
    counts = {}
    for row in rows:
        total = max(total, row["G"])
        counts[row["_id"]] = counts.get(row["_id"], 0) + row["C"]
    return [counts, total]
def getSpike(timestamp, interval, layer, perNeuron):
    """ Get spikes activity in a given interval.

    Args:
        timestamp (int): timestamp value
        interval (int): interval value
        layer (array): array of selected layers
        perNeuron (boolean): return global or perNeuron Spikes

    Returns:
        when perNeuron is set, a list of ``{layer: {neuron: spikes}}``
        items; otherwise a dict mapping each layer to its aggregated
        record. None when no spike was found.
    """
    # MongoDB---------------------
    spikesCollection = pymongo.collection.Collection(g.db, 'spikes')
    # only the grouping key differs between the two modes
    groupKey = {"L":"$i.L","N":"$i.N"} if perNeuron else "$i.L"
    cursor = spikesCollection.aggregate([
        {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}},
        {"$group": {"_id": groupKey,"spikes": {"$sum":1}}},
        {"$sort": {"_id": 1}}
    ])
    # ToJson----------------------
    rows = loads(dumps(cursor))
    # ----------------------------
    if perNeuron:
        result = [{row["_id"]["L"]:{row["_id"]["N"]:row["spikes"]}} for row in rows]
    else:
        result = {row["_id"]:row for row in rows}
    return result if result else None
def getSynapse(timestamp, interval, layer):
    """ Get synapse activity in a given interval.

    Args:
        timestamp (int): timestamp value
        interval (int): interval value
        layer (array): array of selected layers

    Returns:
        dict mapping each layer to its record, whose "synapseUpdate" field
        is the number of matching documents; None when there is none
    """
    # MongoDB---------------------
    timeWindow = {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}
    selectedLayers = {"L": {'$in': layer}}
    stages = [
        {"$match": {"$and": [timeWindow, selectedLayers]}},
        {"$group": {"_id": "$L","synapseUpdate": {"$sum":1}}},
        {"$sort": {"_id": 1}}
    ]
    cursor = pymongo.collection.Collection(g.db, 'synapseWeight').aggregate(stages)
    # ToJson----------------------
    records = loads(dumps(cursor))
    # ----------------------------
    byLayer = {record["_id"]:record for record in records}
    return byLayer if byLayer else None
def getPotential(timestamp, interval, layer):
    """ Get potential activity in a given interval.

    Args:
        timestamp (int): timestamp value
        interval (int): interval value
        layer (array): array of selected layers

    Returns:
        dict mapping each layer to its record, whose "potential" field is
        the number of matching documents; None when there is none
    """
    # MongoDB---------------------
    timeWindow = {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}
    selectedLayers = {"L": {'$in': layer}}
    stages = [
        {"$match": {"$and": [timeWindow, selectedLayers]}},
        # NOTE(review): this counts documents, it does not sum potential
        # values - confirm this is the intended measure.
        {"$group": {"_id": "$L","potential": {"$sum":1}}},
        {"$sort": {"_id": 1}}
    ]
    cursor = pymongo.collection.Collection(g.db, 'potential').aggregate(stages)
    # ToJson----------------------
    records = loads(dumps(cursor))
    # ----------------------------
    byLayer = {record["_id"]:record for record in records}
    return byLayer if byLayer else None
def getLoss(timestamp, interval):
    """ Get the loss over the spikes of a given interval.

    A spike counts as an error when it comes from a neuron listed in
    g.finalLabels whose label differs from the input of the spike.

    Args:
        timestamp (int): timestamp value
        interval (int): interval value

    Returns:
        percentage of errors over the spikes of the interval, capped at
        100; None when no final labels are set or no spike was found
    """
    if g.finalLabels is None:
        return None
    # MongoDB---------------------
    spikesCollection = pymongo.collection.Collection(g.db, 'spikes')
    cursor = spikesCollection.find(
        {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}})
    # ToJson----------------------
    spikes = loads(dumps(cursor))
    # ----------------------------
    if not spikes:
        return None

    errors = sum(
        1
        for spike in spikes
        for final in g.finalLabels
        if (str(spike["i"]["N"]) == str(final["N"]) and str(spike["i"]["L"]) == final["L"])
        and str(final["Label"]) != str(spike["Input"]))
    return min(100,round((errors*100) / len(spikes), 2))
# ---------------------------------------------------------------------
except Exception as e:
print("Helper functions and MongoDB operations: "+str(e))