Skip to content
Snippets Groups Projects
Select Git revision
  • a343f7196c6d42644694372890a35b84b76bc934
  • master default protected
2 results

layout.py

Blame
  • layout.py 9.21 KiB
    """ Create a dash layout for the module
    """
    
    import pymongo
    from itertools import product
    from collections import deque
    import traceback
    from .callbacks import callbacks
    from bson.json_util import dumps
    from bson.json_util import loads
    import plotly.graph_objects as go
    from dash import dcc, html
    import dash_bootstrap_components as dbc
    from src.templates.layoutOp import layoutOp
    
    class layout(layoutOp):
        """ Layout class of the neuron visualisation module.

        Builds the "neuron-vis" Dash components, registers the module callbacks
        and provides the figure helpers / MongoDB queries used to display the
        spike statistics of the last (output) layer.
        """
        # Needed variables for the graphs --------------------------------
        # NOTE(review): these dicts are *class* attributes, so all instances of
        # ``layout`` share them. ``clearData`` mutates them in place (it never
        # rebinds them), which keeps that sharing intact - confirm it is intended.
        xAxisSpikeNbrGraph = dict()
        xAxisSpikeNbrLabel = dict()
        xAxisPotentialLabel = dict()
        xAxisPotentialGraph = dict()
        yAxisSpikeNbrGraph = dict()
        yAxisPotentialGraph = dict()

        # Maximum number of points kept per neuron in each buffer deque.
        _HISTORY_LENGTH = 100

        def clearData(self, indexes):
            """ Reset the per-neuron graph buffers.

            Used when the view moves forward or backward by more than one step,
            so the previously buffered points no longer apply.

            Args:
                indexes (list): ids of the neurons currently displayed; each gets
                    a fresh, empty, bounded deque in every buffer.
            """
            buffers = (self.xAxisPotentialGraph, self.xAxisPotentialLabel,
                       self.xAxisSpikeNbrGraph, self.xAxisSpikeNbrLabel,
                       self.yAxisSpikeNbrGraph, self.yAxisPotentialGraph)
            for buffer in buffers:
                # Clear in place rather than rebinding, so the shared dicts stay shared.
                buffer.clear()
                for index in indexes:
                    buffer[index] = deque(maxlen=self._HISTORY_LENGTH)

        def Vis(self):
            """ Create the layer components and register the module callbacks.

            Reads ``self.g`` (global variables) and passes ``self.app`` and
            ``self.g`` to ``callbacks``.

            Returns:
                dbc.Card: the Dash layer, or None when building it raised (the
                traceback is printed instead of propagating).
            """
            try:
                self.clearData([])
                if self.g.config.DEBUG:
                    print("neuron-vis")
                layer = dbc.Card(
                    dbc.CardBody(
                        [
                            html.Div(id="neuron-vis", children=[
                                # Global show based on selected layer
                                self._filterBar(),
                                self._graphsArea(),
                            ])
                        ], style={"textAlign": "center", "padding": "10px"}
                    ))

                # load callbacks
                callbacks(self, self.app, self.g)
                # Return the Layer
                return layer
            except Exception:
                # UI boundary: report the failure instead of crashing the whole app.
                print("NeuronLayer: " + traceback.format_exc())
                return None

        # ----------------------------------------------------------------
        # Helper functions
        # ----------------------------------------------------------------

        def _filterBar(self):
            """ Build the row with the layer/neuron dropdowns and the add button.

            Returns:
                html.Div: the filter row (flex layout).
            """
            return html.Div([
                dcc.Dropdown(
                    id='LayerFilterNeuron',
                    options=[{'label': str(info["layer"]), 'value': str(info["layer"])}
                             for info in self.g.LayersNeuronsInfo],
                    multi=False,
                    style={'width': '150px', "marginLeft": "10px", "textAlign": "start"}),
                dcc.Dropdown(
                    id='NeuronFilterNeuron',
                    options=[],
                    multi=False,
                    style={'width': '150px', "marginLeft": "10px", "textAlign": "start"}),
                dbc.Button(html.I(className="fa-solid fa-plus"), id="AddComponentNeuron", n_clicks=0, style={
                    "fontWeight": "500", "marginLeft": "20px", "height": "36px",
                    "backgroundColor": "rgb(68, 71, 99)", "borderColor": "rgb(68, 71, 99)"}),
                # Invisible ('display': 'none') state holders; presumably driven by
                # the callbacks - confirm in the callbacks module.
                html.Div(id='clear-Neuron', children="False", style={'display': 'none'}),
                html.Div(id='display-Neuron', children="False", style={'display': 'none'}),
            ], className="d-flex")

        def _graphsArea(self):
            """ Build the graphs container.

            When output labels exist (``self.g.finalLabels`` is not None) it holds
            the 3D spike-frequency graph and an initially empty
            "SpikePerNeuronNbr" graph; otherwise it is an empty container.

            Returns:
                html.Div: the "GraphsAreaNeuron" container.
            """
            areaId = {'type': "GraphsAreaNeuron"}
            if self.g.finalLabels is None:
                return html.Div(id=areaId, children=[],
                                style={"textAlign": "-webkit-center", "paddingTop": "10px"})

            outputNeurons = html.Div(id={'type': "OutputNeurons"}, children=[
                dcc.Graph(id="SpikePerNeuronFreq", figure=self.SpikePerNeuron3D(self.g),
                          config={"displaylogo": False}, className="col-6"),
                dcc.Graph(id="SpikePerNeuronNbr", config={"displaylogo": False}, className="col-6"),
            ], className="d-flex")
            return html.Div(id=areaId, children=[outputNeurons],
                            style={"textAlign": "-webkit-center", "paddingTop": "10px"})

        @staticmethod
        def _labelLookup(finalLabels):
            """ Index the final labels by neuron.

            Args:
                finalLabels (iterable): label records with "N", "L" and "Label" keys.

            Returns:
                dict: ``(int(N), L) -> raw Label``; a later record for the same
                neuron overrides an earlier one.
            """
            return {(int(lab["N"]), lab["L"]): lab["Label"] for lab in finalLabels}

        def SpikePerNeuron3D(self, g):
            """ Create the 3D spike-per-neuron view.

            Each class gets one Mesh3d trace (legend entry) and one Scatter3d
            trace (markers). Neurons without a label are left out so that the
            x/y/z coordinates of every point stay aligned.

            Args:
                g (Global_Var): reference to access global variables

            Returns:
                the 3D graph (a placeholder figure dict when there are no labels)
            """
            if g.finalLabels is None:
                return {'data': [],
                        'layout': {'margin': {'l': 0, 'r': 0, 't': 30, 'b': 0},
                                   'scene': {
                                       'xaxis_title': 'Neuron Id',
                                       'yaxis_title': 'Spike Frequency',
                                       'zaxis_title': 'Class'},
                                   'title': 'No Labels detected'}}

            # getSpikePerNeuron returns None for an empty collection.
            data = self.getSpikePerNeuron(g) or []
            labelOf = self._labelLookup(g.finalLabels)
            total = sum(entry["count"] for entry in data)

            # Group (neuron id, spike frequency, class) points by class, keeping
            # classes in first-appearance order.
            groups = {}
            for entry in data:
                key = (int(entry["i"]["N"]), entry["i"]["L"])
                if key not in labelOf:
                    continue
                frequency = entry["count"] / total if total else 0
                cls = int(labelOf[key])
                groups.setdefault(cls, []).append((entry["i"]["N"], frequency, cls))

            fig = go.Figure()
            for cls, mesh in groups.items():
                xs = [point[0] for point in mesh]
                ys = [point[1] for point in mesh]
                zs = [point[2] for point in mesh]
                fig.add_trace(go.Mesh3d(
                    x=xs,
                    y=[round(y, 4) for y in ys],
                    z=zs,
                    showlegend=True,
                    colorbar_title='z',
                    colorscale='rainbow',
                    opacity=0.7,
                    name=str(cls),
                    hovertemplate="Neuron id: %{x} <br>Spike frequency: %{y}<br>Class: %{z}"))

                fig.add_trace(go.Scatter3d(
                    x=xs,
                    y=ys,
                    z=zs,
                    showlegend=False,
                    marker_size=2,
                    mode='markers',
                    opacity=0.8,
                    name=str(cls),
                    hovertemplate="Neuron id: %{x} <br>Spike frequency: %{y}<br>Class: %{z}"))

            fig.update_layout(
                margin=dict(r=0, b=0, l=0, t=30),
                scene=dict(
                    xaxis_title='Neuron id',
                    yaxis_title='Spike frequency',
                    zaxis_title='Class'),
                title='Spike frequency per neuron of the output layer',
                title_x=0.5
            )
            return fig

        def SpikesSameClass(self, filteredClass, g):
            """ Returns a bar graph of the spike counts of the neurons of one class.

            Neurons are matched to labels the same way as in ``SpikePerNeuron3D``
            (integer neuron id + layer), and the class comparison is done on
            ``int`` values, matching the integer classes plotted on the 3D graph.

            Args:
                filteredClass (dict): information about the selected class; its
                    "z" entry is the class value
                g (Global_Var): reference to access global variables

            Returns:
                graph of filtered class neurons activity (an empty figure dict
                when nothing is selected or there are no labels)
            """
            if filteredClass is None or g.finalLabels is None:
                return {'data': [],
                        'layout': {'margin': {'l': 0, 'r': 0, 't': 30, 'b': 0}}}

            data = self.getSpikePerNeuron(g) or []
            labelOf = self._labelLookup(g.finalLabels)
            selected = int(filteredClass["z"])

            members = []
            for entry in data:
                key = (int(entry["i"]["N"]), entry["i"]["L"])
                if key in labelOf and int(labelOf[key]) == selected:
                    members.append(entry)

            xx = [entry["i"]["N"] for entry in members]
            yy = [entry["count"] for entry in members]
            positions = list(range(len(xx)))
            graph = {'data': [go.Bar(
                x=positions,
                y=yy,
                text=yy,
                hoverinfo='text',
                textposition='outside',
            )],
                'layout': {'margin': {'l': 60, 'r': 0, 't': 30, 'b': 30},
                           'xaxis': {'ticktext': xx, 'tickvals': positions, 'title': 'Neuron Id'},
                           'yaxis': {'title': 'Spike number'},
                           'uirevision': 'no reset of zoom',
                           'title': 'Spike number per neuron for class '+str(filteredClass["z"])}}
            return graph

        # ----------------------------------------------------------------
        # MongoDB operations
        # ----------------------------------------------------------------

        def getSpikePerNeuron(self, g):
            """ Get the total spikes per neuron of the last layer.

            Args:
                g (Global_Var): reference to access global variables

            Returns:
                list | None: documents of the 'SpikePerNeuron' collection whose
                ``["i"]["L"]`` equals ``g.LayersNeuronsInfo[-1]["layer"]``; None
                when the collection is empty (an empty list when it only holds
                other layers).
            """
            # MongoDB---------------------
            col = pymongo.collection.Collection(g.db, 'SpikePerNeuron')
            # ToJson: bson round-trip materialises the cursor into a list of dicts.
            spikePerNeuron = loads(dumps(col.find()))
            # ----------------------------

            if not spikePerNeuron:
                return None

            # Looked up only once, and only after the emptiness check above.
            lastLayer = g.LayersNeuronsInfo[-1]["layer"]
            return [info for info in spikePerNeuron if info["i"]["L"] == lastLayer]