diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py index 4aa190e2689ff57baf385cd0a58a90e1b3293ca2..7a967ec0bd0e9b7344f787b1e085fbae8af34976 100755 --- a/src/Modules/General/callbacks.py +++ b/src/Modules/General/callbacks.py @@ -6,6 +6,7 @@ from collections import deque import dash import pymongo +import numpy as np from bson.json_util import dumps from bson.json_util import loads import plotly.graph_objects as go @@ -374,6 +375,7 @@ class callbacks(callbacksOp): """ This is the callback function. It is called when play/stop button is clicked. Args: + visUpdateInterval : interval instance that will cause this function to be called each step playButton (int): number of clicks on the start/stop button sliderValue (int): value of the slider playButtonText (String): text on the start/stop button @@ -542,6 +544,52 @@ class callbacks(callbacksOp): except Exception as e: print("informationTabController:" + str(e)) + # Callback to handle the 2D view spiking visualization + @app.callback( + Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"), + Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'tapNodeData'), + State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data")) + def animation2DView(visUpdateInterval,sliderValue, tapNodeData, updateInterval, elements, StoredData): + """ Function called each step to update the 2D view + + Args: + sliderValue (int): value of the slider + updateInterval (int): update Interval (s) + visUpdateInterval : interval instance that will cause this function to be called each step + tapNodeData : contains data of the clickde node + elements : nodes description + StoredData : data stored for shared access + heatmapData : heatmap data + + + Returns: + if information tab should be opened or closed + """ + try: + if dash.callback_context.triggered[0]['prop_id'].split('.')[0] == "v-step": + elements = StoredData[0] + spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True) + if spikes: + maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes]) + for spike in spikes: + if list(spike.keys())[0] == "Layer1": + # update the spikes neurons + i = 0 + for element in elements[1:]: + if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])): + element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2) + element["data"]["spikes"] = list(list(spike.values())[0].values())[0] + StoredData[1][i] = list(list(spike.values())[0].values())[0] + i+=1 + return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData] + else: + try: + return [elements,{'name': 'grid','animate': False},f"Neuron {tapNodeData['label']} : {tapNodeData['spikes']}", {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData] + except Exception as e: + return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Portland')],"layout":{"yaxis":dict(autorange='reversed')}},StoredData] + except Exception as e: + print("animation2DViewController:" + str(e)) + except Exception as e: print("Done loading:"+str(e)) @@ -577,9 +625,9 @@ class callbacks(callbacksOp): for f in filter: if f == "Spikes": if res == []: - res = [getSpike(timestamp, g.updateInterval,layers)] + res = [getSpike(timestamp, g.updateInterval,layers,False)] else: - res.append(getSpike(timestamp, g.updateInterval,layers)) + res.append(getSpike(timestamp, g.updateInterval,layers,False)) if f == "Synapses": if res == []: res = [getSynapse(timestamp, g.updateInterval,layers)] @@ -641,30 +689,41 @@ class callbacks(callbacksOp): return [L, Max] - def getSpike(timestamp, interval, layer): + def getSpike(timestamp, interval, layer, perNeuron): """ Get spikes activity in a given interval. Args: timestamp (int): timestamp value interval (int): interval value layers (array): array of selected layers + perNeuron (boolean): return global or perNeuron Spikes Returns: array contains spikes """ # MongoDB--------------------- col = pymongo.collection.Collection(g.db, 'spikes') - spikes = col.aggregate([ + + if perNeuron: + spikes = col.aggregate([ + {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}}, + {"$group": {"_id": {"L":"$i.L","N":"$i.N"},"spikes": {"$sum":1}}},{"$sort": {"_id": 1}} + ]) + else: + spikes = col.aggregate([ {"$match": {"$and": [{"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},{"i.L": {'$in': layer}}]}}, {"$group": {"_id": "$i.L","spikes": {"$sum":1}}},{"$sort": {"_id": 1}} ]) - - # ---------------------------- # ToJson---------------------- spikes = loads(dumps(spikes)) + # ---------------------------- - spikes = {s["_id"]:s for s in spikes} + if perNeuron: + spikes = [{s["_id"]["L"]:{s["_id"]["N"]:s["spikes"]}} for s in spikes] + else: + spikes = {s["_id"]:s for s in spikes} + if not spikes: return None diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py index cf113d4501cc9b138ef1cda3913d4cd4e647e49e..8c861ff6fb0bedf79540b1cfcc0246cab57974f7 100755 --- a/src/Modules/General/layout.py +++ b/src/Modules/General/layout.py @@ -53,11 +53,14 @@ class layout(layoutOp): for L in g.LayersNeuronsInfo: Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}}) for i in range(L["neuronNbr"]): - Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 10}, - 'position': {'x': 25*i, 'y': 0}}) + Nodes.append({'classes': 'neuron', 'data': {'id': L["layer"]+"_"+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % 5) * 50, 'y': (i // 5) * 50}}) # Add connections - + + def toMatrix(self, l,n): + """ 1D array to 2D + """ + return [l[i:i+n] for i in range(0, len(l), n)] def clearData(self): """ Clear the data when moved forward or backward for more than one step @@ -165,30 +168,51 @@ class layout(layoutOp): dcc.Tab(dbc.Card( dbc.CardBody([ html.Div( - [ + [dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]), cyto.Cytoscape( id='cytoscape-compound', - responsive=True, - layout={'name': 'grid'}, + layout={'name': 'grid','animate': False}, style={'width': '100%', 'height': '100%'}, stylesheet=[ { 'selector': 'node', - 'style': {'content': 'data(label)'} - }, + 'style': {'label': 'data(label)'} + }, { - 'selector': '.layers', - 'style': {'width': 5} + 'selector': '[spiked <= 1.0]', + 'style': { + 'background-color': 'rgb(70,227,70)', + } }, { - 'selector': '.neurons', - 'style': {'line-style': 'dashed'} - }, + 'selector': '[spiked < 0.8]', + 'style': { + 'background-color': 'rgb(100,227,100)' + } + }, { - 'selector': '[spiked = 10]', + 'selector': '[spiked < 0.6]', 'style': { - 'background-color': 'rgb(180,180,180)' + 'background-color': 'rgb(130,227,130)' + } + }, + { + 'selector': '[spiked < 0.4]', + 'style': { + 'background-color': 'rgb(160,227,160)' + } + }, + { + 'selector': '[spiked < 0.2]', + 'style': { + 'background-color': 'rgb(190,227,190)' + } + }, + { + 'selector': '[spiked = 0.0]', + 'style': { + 'background-color': 'rgb(199,197,197)' } }, { @@ -199,11 +223,25 @@ class layout(layoutOp): } ], elements=self.Nodes - ) + ), + html.P(id="spikes_info", style={"padding": "8px"}) - ], style={"background": "rgb(227, 245, 251)", "height": "60vh", "textAlign": "start", "padding": "0px", "width":"70%"}), + ], style={"background": "rgb(227, 245, 251)", "height": "50vh", "textAlign": "start", "padding": "0px","paddingBottom":"12px", "width":"70%"}), - html.Div([], style={"height": "60vh", "textAlign": "start", "padding": "0px", "width":"30%"}) + html.Div([ + # Layers filter + dcc.Dropdown( + id='2DViewLayerFilter', + options=[{'label': str(i), 'value': str(i)} for i in ( + i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))], + value=[str(i) for i in ( + i for i in self.g.Layer_Neuron if ("Input" not in i and "pool" not in i))], + multi=True, + style={"minWidth": "20%", "textAlign": "start"}), + # HeatMap + dcc.Graph(id='2DView-heatmap', config={"displaylogo": False} + ), + ], style={"height": "50vh", "textAlign": "start", "padding": "8px", "width":"30%"}) ], className="row")), label="2D view", value="2Dview")], id="tabinfo", value="General information"),