Select Git revision
-
Authored by Hammouda Elbez
layout.py 9.21 KiB
""" Create a dash layout for the module
"""
import pymongo
from itertools import product
from collections import deque
import traceback
from .callbacks import callbacks
from bson.json_util import dumps
from bson.json_util import loads
import plotly.graph_objects as go
from dash import dcc, html
import dash_bootstrap_components as dbc
from src.templates.layoutOp import layoutOp
class layout(layoutOp):
    """ Layout class of the neuron visualization module.

    Builds the Dash components of the "neuron-vis" view, the figures it
    displays, and the MongoDB query that feeds them.
    """
    # Needed variables for the graphs --------------------------------
    # Each dict maps a displayed neuron index to a bounded deque (see
    # clearData). NOTE(review): these are class attributes, so they are shared
    # by every `layout` instance; clearData() empties them in place rather
    # than rebinding them.
    xAxisSpikeNbrGraph = dict()
    xAxisSpikeNbrLabel = dict()
    xAxisPotentialLabel = dict()
    xAxisPotentialGraph = dict()
    yAxisSpikeNbrGraph = dict()
    yAxisPotentialGraph = dict()

    def clearData(self, indexes, maxlen=100):
        """ Clear the data when moved forward or backward for more than one step

        Every graph buffer is emptied in place, then gets one fresh, empty
        deque per displayed neuron.

        Args:
            indexes (List): existing neurons that are displayed
            maxlen (int): samples kept per neuron before the oldest ones are
                dropped (default 100)
        """
        buffers = (self.xAxisPotentialGraph, self.xAxisPotentialLabel,
                   self.xAxisSpikeNbrGraph, self.xAxisSpikeNbrLabel,
                   self.yAxisSpikeNbrGraph, self.yAxisPotentialGraph)
        for buffer in buffers:
            buffer.clear()
        # Iterate `indexes` only once so a one-shot iterable also works
        for index in indexes:
            for buffer in buffers:
                buffer[index] = deque(maxlen=maxlen)

    def Vis(self):
        """ Create the layer components of the neuron view

        Reads the global variables from ``self.g`` and registers the module
        callbacks on ``self.app``.

        Returns:
            dbc.Card: the Dash app layer, or None when building it raised
            (the traceback is printed)
        """
        try:
            self.clearData([])
            if self.g.config.DEBUG:
                print("neuron-vis")
            layerOptions = [{'label': str(info["layer"]), 'value': str(info["layer"])}
                            for info in self.g.LayersNeuronsInfo]
            if self.g.finalLabels is not None:
                # Labels are known: show the 3D spike-frequency view next to a
                # second graph that has no initial figure
                graphsArea = html.Div(id={'type': "GraphsAreaNeuron"}, children=[
                    html.Div(id={'type': "OutputNeurons"}, children=[
                        dcc.Graph(id="SpikePerNeuronFreq", figure=self.SpikePerNeuron3D(self.g),
                                  config={"displaylogo": False}, className="col-6"),
                        dcc.Graph(id="SpikePerNeuronNbr", config={"displaylogo": False},
                                  className="col-6")],
                        className="d-flex")],
                    style={"textAlign": "-webkit-center", "paddingTop": "10px"})
            else:
                graphsArea = html.Div(id={'type': "GraphsAreaNeuron"}, children=[],
                                      style={"textAlign": "-webkit-center", "paddingTop": "10px"})
            layer = dbc.Card(
                dbc.CardBody(
                    [
                        html.Div(id="neuron-vis", children=[
                            # Global show based on selected layer
                            html.Div([
                                dcc.Dropdown(
                                    id='LayerFilterNeuron',
                                    options=layerOptions,
                                    multi=False,
                                    style={'width': '150px', "marginLeft": "10px", "textAlign": "start"}),
                                dcc.Dropdown(
                                    id='NeuronFilterNeuron',
                                    options=[],
                                    multi=False,
                                    style={'width': '150px', "marginLeft": "10px", "textAlign": "start"}),
                                dbc.Button(html.I(className="fa-solid fa-plus"), id="AddComponentNeuron",
                                           n_clicks=0, style={
                                               "fontWeight": "500", "marginLeft": "20px", "height": "36px",
                                               "backgroundColor": "rgb(68, 71, 99)",
                                               "borderColor": "rgb(68, 71, 99)"}),
                                # Hidden state holders
                                html.Div(id='clear-Neuron', children="False", style={'display': 'none'}),
                                html.Div(id='display-Neuron', children="False", style={'display': 'none'})
                            ], className="d-flex"),
                            graphsArea])
                    ], style={"textAlign": "center", "padding": "10px"}
                ))
            # load callbacks
            callbacks(self, self.app, self.g)
            # Return the Layer
            return layer
        except Exception:
            print("NeuronLayer: " + traceback.format_exc())
            return None

    # ----------------------------------------------------------------
    # Helper functions
    # ----------------------------------------------------------------
    @staticmethod
    def _emptyNeuron3D(title):
        """ Return a trace-less placeholder figure for the 3D spike-per-neuron view

        Args:
            title (str): title shown in place of the graph
        Returns:
            figure dict with no data
        """
        return {'data': [],
                'layout': {'margin': {'l': 0, 'r': 0, 't': 30, 'b': 0},
                           'scene': {
                               'xaxis_title': 'Neuron Id',
                               'yaxis_title': 'Spike Frequency',
                               'zaxis_title': 'Class'},
                           'title': title}}

    @staticmethod
    def _labelIndex(finalLabels):
        """ Index the final labels by neuron

        Args:
            finalLabels (iterable of dict): entries with "L" (layer), "N"
                (neuron id) and "Label" keys
        Returns:
            dict mapping (layer, int(neuron id)) to the raw "Label" value; the
            first entry wins when a neuron appears more than once
        """
        index = {}
        for entry in finalLabels:
            index.setdefault((entry["L"], int(entry["N"])), entry["Label"])
        return index

    def SpikePerNeuron3D(self, g):
        """ Create the 3D spike per neuron view

        Each labelled neuron of the output layer becomes a point
        (neuron id, share of all spikes, class). One Mesh3d and one Scatter3d
        trace is drawn per class.

        Args:
            g (Global_Var): reference to access global variables
        Returns:
            the 3D graph (go.Figure), or a placeholder figure dict when there
            are no labels or no recorded spikes
        """
        if g.finalLabels is None:
            return self._emptyNeuron3D('No Labels detected')
        data = self.getSpikePerNeuron(g) or []
        total = sum(entry["count"] for entry in data)
        if total == 0:
            # Nothing recorded (or no documents): avoid dividing by zero
            return self._emptyNeuron3D('No spikes recorded')

        labelOf = self._labelIndex(g.finalLabels)
        # Build each point from ONE neuron so id, frequency and class always
        # belong together; group by class in order of first appearance
        groups = {}
        for entry in data:
            key = (entry["i"]["L"], int(entry["i"]["N"]))
            if key not in labelOf:
                continue  # unlabelled neuron: no class to plot it against
            label = int(labelOf[key])
            groups.setdefault(label, []).append(
                (entry["i"]["N"], entry["count"] / total, label))

        fig = go.Figure()
        for label, points in groups.items():
            xs = [point[0] for point in points]
            ys = [point[1] for point in points]
            zs = [point[2] for point in points]
            fig.add_trace(go.Mesh3d(
                x=xs,
                y=[round(y, 4) for y in ys],
                z=zs,
                showlegend=True,
                colorbar_title='z',
                colorscale='rainbow',
                opacity=0.7,
                name=str(label),
                hovertemplate="Neuron id: %{x} <br>Spike frequency: %{y}<br>Class: %{z}"))
            fig.add_trace(go.Scatter3d(
                x=xs,
                y=ys,
                z=zs,
                showlegend=False,
                marker_size=2,
                mode='markers',
                opacity=0.8,
                name=str(label),
                hovertemplate="Neuron id: %{x} <br>Spike frequency: %{y}<br>Class: %{z}"))
        fig.update_layout(
            margin=dict(r=0, b=0, l=0, t=30),
            scene=dict(
                xaxis_title='Neuron id',
                yaxis_title='Spike frequency',
                zaxis_title='Class'),
            title='Spike frequency per neuron of the output layer',
            title_x=0.5
        )
        return fig

    def SpikesSameClass(self, filteredClass, g):
        """ Returns a graph of neurons spikes activity from selected class.

        Args:
            filteredClass (dict): information about the selected class; its
                "z" entry is the class, compared as an int like the 3D view's
                class axis
            g (Global_Var): reference to access global variables
        Returns:
            graph (figure dict) of filtered class neurons activity
        """
        if filteredClass is None or g.finalLabels is None:
            return {'data': [],
                    'layout': {'margin': {'l': 0, 'r': 0, 't': 30, 'b': 0}}}
        selectedClass = int(filteredClass["z"])
        labelOf = self._labelIndex(g.finalLabels)
        xx = []  # neuron ids, used as tick labels
        yy = []  # spike counts
        for entry in self.getSpikePerNeuron(g) or []:
            key = (entry["i"]["L"], int(entry["i"]["N"]))
            if key in labelOf and int(labelOf[key]) == selectedClass:
                xx.append(entry["i"]["N"])
                yy.append(entry["count"])
        # Bars sit at positions 0..n-1 and are labelled with the neuron ids
        graph = {'data': [go.Bar(
            x=list(range(len(xx))),
            y=yy,
            text=yy,
            hoverinfo='text',
            textposition='outside',
        )],
            'layout': {'margin': {'l': 60, 'r': 0, 't': 30, 'b': 30},
                       'xaxis': {'ticktext': xx, 'tickvals': list(range(len(xx))), 'title': 'Neuron Id'},
                       'yaxis': {'title': 'Spike number'},
                       'uirevision': 'no reset of zoom',
                       'title': 'Spike number per neuron for class ' + str(filteredClass["z"])}}
        return graph

    # ----------------------------------------------------------------
    # MongoDB operations
    # ----------------------------------------------------------------
    def getSpikePerNeuron(self, g):
        """ Get total spikes per neuron of the output layer.

        Args:
            g (Global_Var): reference to access global variables
        Returns:
            list of 'SpikePerNeuron' documents whose i.L equals the layer of
            the last entry of g.LayersNeuronsInfo, or None when the collection
            is empty
        """
        # MongoDB---------------------
        col = pymongo.collection.Collection(g.db, 'SpikePerNeuron')
        # ToJson: round-trip through bson.json_util to materialise the cursor
        documents = loads(dumps(col.find()))
        # ----------------------------
        if not documents:
            return None
        outputLayer = g.LayersNeuronsInfo[-1]["layer"]
        return [doc for doc in documents if doc["i"]["L"] == outputLayer]