diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py index b9284571abd2bfa61a22c3aacf6d75846e33f0c5..dbc14658484d51f8d5137779ce964522d2ed6a04 100644 --- a/src/Modules/General/callbacks.py +++ b/src/Modules/General/callbacks.py @@ -3,6 +3,7 @@ Dash callbacks are the responsible on updating graphs each step. """ +from collections import deque import dash import pymongo from bson.json_util import dumps @@ -32,13 +33,14 @@ class callbacks(): try: - def processGeneralGraph(data, sliderValue, generalGraphFilter): + def processGeneralGraph(data, sliderValue, generalGraphFilter, generalLayerFilter): """ Create general graph components. Args: data (array): data to be presented in the graph sliderValue (int): value of the slider generalGraphFilter (list): actual filter of GeneralGraph visualization + generalLayerFilter (list): selected layers for GeneralGraph visualization Returns: general graph content @@ -46,7 +48,6 @@ class callbacks(): X = 0 Xlabel = "" try: - if sliderValue != None: Xlabel = "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime( g.updateInterval, sliderValue+1)+"]" @@ -59,80 +60,87 @@ class callbacks(): X = super.InfoGraphX[-1]+1 super.InfoGraphX.append(X) - - i = 0 - graphs = [] + + graphs = {l:[] for l in generalLayerFilter} annotations = [] - - for f in generalGraphFilter: - if(f == "Spikes"): - - if data != None and data[i] != None: - super.SpikeGraphY.append(data[i]) - else: - super.SpikeGraphY.append(0) - - super.MaxSpike = max(super.MaxSpike, max( - super.SpikeGraphY) if super.SpikeGraphY else 0) - graphs.append( - go.Scatter( - x=list(super.InfoGraphX), - y=list([norm(i, super.MaxSpike) - for i in super.SpikeGraphY]), - fill='tozeroy' if len( - generalGraphFilter) == 1 else 'none', - line=dict(color="rgb(31, 119, 180)"), - name='Spikes', - mode='lines+markers', - text=list(super.xAxisLabel), - customdata=list(super.SpikeGraphY), - hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str( - super.MaxSpike), - )) - - if(f == "Synapses"): - if data != None and data[i] != None: - super.SynapseGraphY.append(data[i]) - else: - super.SynapseGraphY.append(0) - - super.MaxSynapse = max( - max(super.SynapseGraphY) if super.SynapseGraphY else 0, super.MaxSynapse) - graphs.append( - go.Scatter( - x=list(super.InfoGraphX), - y=list([norm(i, max(super.SynapseGraphY)) - for i in super.SynapseGraphY]), - fill='tozeroy' if len( - generalGraphFilter) == 1 else 'none', - line=dict(color="rgb(255, 127, 14)"), - name='Synapses update', - mode='lines+markers', - text=list(super.xAxisLabel), - customdata=list(super.SynapseGraphY), - hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxSynapse))) - if(f == "Potentials"): - if data != None and data[i] != None: - super.PotentialGraphY.append(data[i]) - else: - super.PotentialGraphY.append(0) - - super.MaxPotential = max( - max(super.PotentialGraphY) if super.PotentialGraphY else 0, super.MaxPotential) - graphs.append( - go.Scatter( - x=list(super.InfoGraphX), - y=list([norm(i, super.MaxPotential) - for i in super.PotentialGraphY]), - fill='tozeroy' if len( - generalGraphFilter) == 1 else 'none', - line=dict(color="rgb(44, 160, 44)"), - name="Neuron's potential update", - mode='lines+markers', - text=list(super.xAxisLabel), - customdata=list(super.PotentialGraphY), - hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxPotential))) - i += 1 + for layer in generalLayerFilter: + i = 0 + for f in generalGraphFilter: + if(f == "Spikes"): + if(layer not in super.SpikeGraphY): + super.SpikeGraphY[layer] = deque(maxlen=100) + super.MaxSpike[layer] = 0 + if data != None and layer in data[i]: + super.SpikeGraphY[layer].append(data[i][layer]["spikes"]) + else: + super.SpikeGraphY[layer].append(0) + + super.MaxSpike[layer] = max(super.MaxSpike[layer], max( + super.SpikeGraphY[layer]) if super.SpikeGraphY[layer] else 0) + + graphs[layer].append( + go.Scatter( + x=list(super.InfoGraphX), + y=list([norm(i, super.MaxSpike[layer]) + for i in super.SpikeGraphY[layer]]), + fill='tozeroy' if len( + generalGraphFilter) == 1 else 'none', + line=dict(color="rgb(31, 119, 180)"), + name='Spikes', + mode='lines+markers', + text=list(super.xAxisLabel), + customdata=list(super.SpikeGraphY[layer]), + hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str( + super.MaxSpike[layer]))) + + if(f == "Synapses"): + if(layer not in super.SynapseGraphY): + super.SynapseGraphY[layer] = deque(maxlen=100) + super.MaxSynapse[layer] = 0 + if data != None and layer in data[i]: + super.SynapseGraphY[layer].append(data[i][layer]["synapseUpdate"]) + else: + super.SynapseGraphY[layer].append(0) + super.MaxSynapse[layer] = max( + max(super.SynapseGraphY[layer]) if super.SynapseGraphY[layer] else 0, super.MaxSynapse[layer]) + graphs[layer].append( + go.Scatter( + x=list(super.InfoGraphX), + y=list([norm(i, max(super.SynapseGraphY[layer])) + for i in super.SynapseGraphY[layer]]), + fill='tozeroy' if len( + generalGraphFilter) == 1 else 'none', + line=dict(color="rgb(255, 127, 14)"), + name='Synapses update', + mode='lines+markers', + text=list(super.xAxisLabel), + customdata=list(super.SynapseGraphY[layer]), + hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxSynapse[layer]))) + if(f == "Potentials"): + if(layer not in super.PotentialGraphY): + super.PotentialGraphY[layer] = deque(maxlen=100) + super.MaxPotential[layer] = 0 + if data != None and layer in data[i]: + super.PotentialGraphY[layer].append(data[i][layer]["potential"]) + else: + super.PotentialGraphY[layer].append(0) + + super.MaxPotential[layer] = max( + max(super.PotentialGraphY[layer]) if super.PotentialGraphY[layer] else 0, super.MaxPotential[layer]) + graphs[layer].append( + go.Scatter( + x=list(super.InfoGraphX), + y=list([norm(i, super.MaxPotential[layer]) + for i in super.PotentialGraphY[layer]]), + fill='tozeroy' if len( + generalGraphFilter) == 1 else 'none', + line=dict(color="rgb(44, 160, 44)"), + name="Neuron's potential update", + mode='lines+markers', + text=list(super.xAxisLabel), + customdata=list(super.PotentialGraphY[layer]), + hovertemplate="%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "+str(super.MaxPotential[layer]))) + i += 1 if(g.Labels != None): @@ -141,8 +149,8 @@ class callbacks(): else: super.LossGraphY.append(None) - fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=[ - [{'rowspan': 2}], [None], [{'rowspan': 3}], [None], [None]]) + fig = make_subplots(rows=1+(len(graphs)*1), cols=1, shared_xaxes=True, vertical_spacing=0.05, specs= + [[{'rowspan': 1}]] +[[{'rowspan': 1}] for l in graphs]) fig.add_trace( go.Scatter(x=list(super.InfoGraphX), y=list(super.LossGraphY), mode='lines', @@ -159,7 +167,7 @@ class callbacks(): row=1, col=1 ) - for graph in graphs: + for key,graph in graphs.items: fig.add_trace( graph, row=3, col=1) @@ -182,11 +190,14 @@ class callbacks(): annotations=annotations) else: fig = make_subplots( - rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05) + rows=len(graphs), cols=1, shared_xaxes=True, vertical_spacing=0.05) - for graph in graphs: - fig.add_trace( - graph, row=1, col=1) + l = 1 + for key,graphL in graphs.items(): + for graph in graphL: + fig.add_trace( + graph, row=l, col=1) + l +=1 fig['layout'].update( # TODO: change axis when there is no Loss @@ -197,16 +208,12 @@ class callbacks(): super.InfoGraphX) if super.InfoGraphX else 0], #rangeslider={'visible': True,'autorange': True}, # showticklabels=False, - tickvals=list(super.InfoGraphX), - ), - yaxis=dict( - range=[0, 105] - ), + tickvals=list(super.InfoGraphX)), + yaxis=dict(range=[0, 105]), showlegend=True, uirevision='no reset of zoom', margin={'l': 0, 'r': 0, 't': 30, 'b': 25}, - annotations=annotations, - ) + annotations=annotations) return fig except Exception as e: @@ -269,8 +276,8 @@ class callbacks(): [Input("vis-update", "n_intervals"), Input("btn-back", "n_clicks"), Input("btn-next", "n_clicks")], [State("vis-slider", "value"), State("interval", "value"), - State("GeneralGraphFilter", "value"), State("clear", "children")]) - def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, clearGraphs): + State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State("clear", "children")]) + def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, clearGraphs): """ This is the callback function. It is called each step. Args: @@ -280,6 +287,7 @@ class callbacks(): sliderValue (int): value of the slider updateInterval (int): update Interval (s) generalGraphFilter (list): actual filter of GeneralGraph visualization + generalLayerFilter (list): selected layers for GeneralGraph visualization clearGraphs (boolean): a dash state to pass information to all visualization about clearing content (if needed) Returns: @@ -311,6 +319,12 @@ class callbacks(): super.generalGraphFilterOld = generalGraphFilter super.clearData() clearGraphs = not clearGraphs + + if (super.generalLayerFilterOld != generalLayerFilter): + super.generalLayerFilterOld = generalLayerFilter + super.clearData() + clearGraphs = not clearGraphs + if abs(super.oldSliderValue - sliderValue) > 2: super.clearData() clearGraphs = not clearGraphs @@ -400,14 +414,15 @@ class callbacks(): @app.callback( [Output("general-graph", "figure") ], [Input("v-step", "children")], - [State("interval", "value"), State("GeneralGraphFilter", "value"), State('general-graph-switch', 'on')]) - def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalGraphSwitchIsOn): + [State("interval", "value"), State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State('general-graph-switch', 'on')]) + def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, generalGraphSwitchIsOn): """ Update the general graph. Args: sliderValue (int): value of the slider updateInterval (int): update Interval (s) generalGraphFilter (list): actual filter of GeneralGraph visualization + generalLayerFilter (list): selected layers for GeneralGraph visualization generalGraphSwitchIsOn (bool): general graph switch value Raises: @@ -422,14 +437,14 @@ class callbacks(): if(not super.visStopped): generalData = GeneralModuleData( - int(sliderValue)*float(updateInterval), generalGraphFilter) + int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter) if generalData == [None]: generalGraph = processGeneralGraph( - None, sliderValue, generalGraphFilter) + None, sliderValue, generalGraphFilter, generalLayerFilter) else: generalGraph = processGeneralGraph( - generalData, sliderValue, generalGraphFilter) + generalData, sliderValue, generalGraphFilter, generalLayerFilter) return [generalGraph] @@ -438,9 +453,9 @@ class callbacks(): raise PreventUpdate else: generalData = GeneralModuleData( - int(sliderValue)*float(updateInterval), generalGraphFilter) + int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter) generalGraph = processGeneralGraph( - generalData, int(sliderValue), generalGraphFilter) + generalData, int(sliderValue), generalGraphFilter, generalLayerFilter) return [generalGraph] else: raise PreventUpdate @@ -556,13 +571,13 @@ class callbacks(): """ return (data * 100)/Max if Max != 0 else data - def GeneralModuleData(timestamp, filter): + def GeneralModuleData(timestamp, filter, layers): """ Returns a list of graphs based on the specified filter. Args: timestamp (int): timestamp value filter (array): array of selected filters - + layers (array): array of selected layers Returns: list of graphs """ @@ -570,20 +585,19 @@ class callbacks(): for f in filter: if f == "Spikes": if res == []: - res = [getSpike(timestamp, g.updateInterval)] + res = [getSpike(timestamp, g.updateInterval,layers)] else: - res.append(getSpike(timestamp, g.updateInterval)) + res.append(getSpike(timestamp, g.updateInterval,layers)) if f == "Synapses": if res == []: - res = [getSynapse(timestamp, g.updateInterval)] + res = [getSynapse(timestamp, g.updateInterval,layers)] else: - res.append(getSynapse(timestamp, g.updateInterval)) + res.append(getSynapse(timestamp, g.updateInterval,layers)) if f == "Potentials": if res == []: - res = [getPotential(timestamp, g.updateInterval)] + res = [getPotential(timestamp, g.updateInterval,layers)] else: - res.append(getPotential( - timestamp, g.updateInterval)) + res.append(getPotential(timestamp, g.updateInterval,layers)) # get loss value if res == []: @@ -635,74 +649,87 @@ class callbacks(): return [L, Max] - def getSpike(timestamp, interval): + def getSpike(timestamp, interval, layer): """ Get spikes activity in a given interval. Args: timestamp (int): timestamp value interval (int): interval value + layers (array): array of selected layers Returns: array contains spikes """ # MongoDB--------------------- col = pymongo.collection.Collection(g.db, 'spikes') - spikes = col.find( - {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count() + spikes = col.aggregate([ + {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}}, + {"$group": {"_id": "$i.L","spikes": {"$sum":1}}},{"$sort": {"_id": 1}} + ]) + # ---------------------------- # ToJson---------------------- spikes = loads(dumps(spikes)) # ---------------------------- - + spikes = {s["_id"]:s for s in spikes} if not spikes: return None return spikes - def getSynapse(timestamp, interval): + def getSynapse(timestamp, interval, layer): """ Get syanpse activity in a given interval. Args: timestamp (int): timestamp value interval (int): interval value + layers (array): array of selected layers Returns: array contains synapses activity """ # MongoDB--------------------- col = pymongo.collection.Collection(g.db, 'synapseWeight') - synapse = col.find( - {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count() + synapse = col.aggregate([ + {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"L": {'$in': layer}}]}}, + {"$group": {"_id": "$L","synapseUpdate": {"$sum":1}}},{"$sort": {"_id": 1}} + ]) + # ---------------------------- # ToJson---------------------- synapse = loads(dumps(synapse)) # ---------------------------- + synapse = {s["_id"]:s for s in synapse} if not synapse: return None return synapse - def getPotential(timestamp, interval): + def getPotential(timestamp, interval, layer): """ Get potential activity in a given interval. Args: timestamp (int): timestamp value interval (int): interval value + layers (array): array of selected layers Returns: array contains potential activity """ # MongoDB--------------------- col = pymongo.collection.Collection(g.db, 'potential') - potential = col.find( - {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}).count() + potential = col.aggregate([ + {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"L": {'$in': layer}}]}}, + {"$group": {"_id": "$L","potential": {"$sum":1}}},{"$sort": {"_id": 1}} + ]) # ---------------------------- # ToJson---------------------- potential = loads(dumps(potential)) # ---------------------------- + potential = {p["_id"]:p for p in potential} if not potential: return None diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py index aa3a208057486e06839aea93691282bc8c91ade7..c18e30197ea55761092b11e0df9d41d788709277 100644 --- a/src/Modules/General/layout.py +++ b/src/Modules/General/layout.py @@ -17,16 +17,17 @@ class layout(): # InfoGraph Axis ------------------------------------------------- oldSliderValue = 0 generalGraphFilterOld = [] + generalLayerFilterOld = [] xAxisLabel = deque(maxlen=100) InfoGraphX = deque(maxlen=100) - SpikeGraphY = deque(maxlen=100) - SynapseGraphY = deque(maxlen=100) - PotentialGraphY = deque(maxlen=100) + SpikeGraphY = dict() + SynapseGraphY = dict() + PotentialGraphY = dict() LossGraphY = deque(maxlen=100) - MaxSpike = 0 - MaxPotential = 0 - MaxSynapse = 0 + MaxSpike = dict() + MaxPotential = dict() + MaxSynapse = dict() # LabelPie Data -------------------------------------------------- Label = [[], []] Max = 0 @@ -69,9 +70,9 @@ class layout(): self.SynapseGraphY.clear() self.xAxisLabel.clear() self.Label.clear() - self.MaxPotential = 0 - self.MaxSpike = 0 - self.MaxSynapse = 0 + self.MaxPotential = dict() + self.MaxSpike = dict() + self.MaxSynapse = dict() self.Max = 0 self.Nodes = [] self.Edges = [] @@ -128,13 +129,22 @@ class layout(): size=30, color="#28a745", style={"marginLeft": "10px"} - ), dcc.Dropdown( + ), + # Graphs filter + dcc.Dropdown( id='GeneralGraphFilter', options=[{'label': "Spikes", 'value': "Spikes"}, {'label': "Synapses update", 'value': "Synapses"}, { 'label': "Neurons potential update", 'value': "Potentials"}], value=['Spikes'], multi=True, - style={'width': '80%', "marginLeft": "10px", "textAlign": "start"})], className="row", style={"paddingLeft": "20px"}) + style={'width': '50%', "marginLeft": "10px", "textAlign": "start"}), + # Layers filter + dcc.Dropdown( + id='GeneralLayerFilter', + options=[{'label': str(i), 'value': str(i)} for i in (i for i in g.Layer_Neuron if i != "Input")], + value=[str(i) for i in (i for i in g.Layer_Neuron if i != "Input")], + multi=True, + style={'width': '50%', "marginLeft": "5px", "textAlign": "start"})], className="row", style={"paddingLeft": "20px"}) ], className="col-12") ], className="row"), html.Div([dcc.Graph(id='general-graph', animate=False, config={"displaylogo": False})])], className="col-lg-9 col-sm-12 col-xs-12" if(g.Labels != None) else "col-lg-12 col-sm-12 col-xs-12"),