Skip to content
Snippets Groups Projects
Select Git revision
  • 1e8a2fe8a7677815e5db1f27dd55f2ac2e7ee21c
  • master default protected
2 results

layout.py

Blame
  • layout.py 21.20 KiB
    """ Create a dash layout for the module
    """
    
    import importlib
    import traceback
    import math
    import dash_daq as daq
    from collections import deque
    import dash_cytoscape as cyto
    from .callbacks import callbacks
    from dash import dcc
    from dash import html
    import dash_bootstrap_components as dbc
    from src.templates.layoutOp import layoutOp
    
    class layout(layoutOp):
        """ Main Dash layout of the application: "General" tab (network
        information, general/label graphs, 2D view), one tab per active
        module, playback controls and step slider.

        NOTE(review): every container below is a *class* attribute, so it is
        shared by all instances of ``layout`` unless an instance rebinds it.
        ``clearData`` mutates several of them in place with ``.clear()``.
        """
        # InfoGraph Axis -------------------------------------------------
        # Per-view state reset by ``clearData``; presumably also read/updated
        # by the callbacks module — TODO confirm.
        oldSliderValue = 0
        generalGraphFilterOld = []
        generalLayerFilterOld = []
        xAxisLabel = deque(maxlen=100)  # sliding window: only the last 100 entries are kept
        InfoGraphX = deque(maxlen=100)
        SpikeGraphY = dict()
        SynapseGraphY = dict()
        PotentialGraphY = dict()
        LossGraphY = deque(maxlen=100)

        MaxSpike = dict()
        MaxPotential = dict()
        MaxSynapse = dict()
        AccumulatedSpikes2D = []  # rebound by clearData to a dict {layer: [0, ...]}
        Spikes2D = dict()  # rebound by clearData to the cytoscape element list from generate2DView
        # LabelPie Data --------------------------------------------------
        Label = [[], []]  # NOTE(review): clearData empties this to [] rather than restoring [[], []]
        Max = 0
        # ----------------------------------------------------------------
        tabs = []  # NOTE(review): Vis builds a local ``tabs`` list; this attribute is not referenced in the code shown here
        label = " "
        visStopped = True
        tabLoaded = False
        resetGraphs = False
        # ----------------------------------------------------------------
    
        # 2D view --------------------------------------------------------
    
        def generate2DView(self, g):
            """ Generates the cytoscape elements of the 2D view of the neural network

            For every entry of ``g.LayersNeuronsInfo`` one compound (parent)
            node is emitted for the layer, followed by one child node per
            neuron. Neurons are placed on a grid ``int(sqrt(neuronNbr))``
            columns wide, with 70 px between grid cells.

            Args:
                g (Global_Var): reference to access global variables; only
                    ``g.LayersNeuronsInfo`` (dicts with ``"layer"`` and
                    ``"neuronNbr"`` keys) is read

            Returns:
                list of cytoscape element dicts. Layer nodes carry
                ``spiked == -1`` and no ``parent``. Neuron nodes have id
                ``layer + index``, ``parent == layer``, ``spiked == 0.0`` and
                ``spikes == 0``.
            """
            nodes = []
            for layerInfo in g.LayersNeuronsInfo:
                name = layerInfo["layer"]
                count = layerInfo["neuronNbr"]

                # Layer node: compound parent grouping its neurons
                nodes.append({'data': {'id': name, 'label': name, 'spiked': -1}})

                # Grid width computed once per layer instead of twice per neuron.
                # The guard keeps math.sqrt away from empty/negative counts (no neurons then).
                columns = int(math.sqrt(count)) if count > 0 else 1
                for i in range(count):
                    nodes.append({
                        'data': {'id': name + str(i), 'label': str(i), 'parent': name, 'spiked': 0.0, 'spikes': 0},
                        'position': {'x': (i % columns) * 70, 'y': (i // columns) * 70},
                        'height': 20,
                        'width': 20})
            return nodes
        
        def toMatrix(self, l):
            """ 1D array to 2D

            Splits ``l`` into consecutive rows of ``int(sqrt(len(l)))`` items.
            When ``len(l)`` is not a perfect square, the last row is shorter.

            Args:
                l (sequence): flat, sliceable sequence of values

            Returns:
                list of row slices of ``l``; ``[]`` for an empty input
            """
            n = int(math.sqrt(len(l)))
            # An empty input gives a row width of 0, and range(..., 0) would raise
            # "ValueError: range() arg 3 must not be zero".
            if n == 0:
                return []
            return [l[i:i + n] for i in range(0, len(l), n)]
        
        def clearData(self):
            """ Clear the data when moved forward or backward for more than one step

            Empties the graph buffers and per-graph maxima, rebuilds the 2D
            view elements from ``self.g`` and resets the accumulated spike
            counters to ``{layer: [0] * neuronsInLayer}`` for every layer of
            ``self.g.Layer_Neuron`` except "Input".
            """
            self.InfoGraphX.clear()
            self.SpikeGraphY.clear()
            self.LossGraphY.clear()
            self.SynapseGraphY.clear()
            self.PotentialGraphY.clear()
            self.xAxisLabel.clear()
            # NOTE(review): this leaves Label == [] rather than the initial [[], []] shape
            self.Label.clear()
            self.MaxPotential.clear()
            self.MaxSpike.clear()
            self.MaxSynapse.clear()
            self.Spikes2D = self.generate2DView(self.g)

            # Count neuron nodes per layer in a single pass over the elements.
            # Layer nodes are marked with spiked == -1 and have no "parent".
            neuronsPerLayer = {}
            for node in self.Spikes2D:
                data = node["data"]
                if data["spiked"] != -1:
                    neuronsPerLayer[data["parent"]] = neuronsPerLayer.get(data["parent"], 0) + 1
            self.AccumulatedSpikes2D = {layer: [0] * neuronsPerLayer.get(layer, 0)
                                        for layer in self.g.Layer_Neuron if layer != "Input"}
            self.Max = 0
    
        def Vis(self):
            """ Build the Dash page of the application and register its callbacks

            Steps, in order:
              1. reset the cached visualisation data (``clearData``);
              2. build the "General" tab: a collapsible network information
                 card plus the "General information" and "2D view" sub-tabs;
              3. append one tab per other entry of ``self.g.modules`` (a module
                 whose layout fails to load is reported on stdout and skipped);
              4. assign the page to ``self.app.layout`` and load the callbacks.

            Takes no arguments: the Dash app is read from ``self.app`` and the
            global variables from ``self.g``. Returns ``None``.
            """
            self.clearData()

            # Layers offered in the filters ("Input" is excluded everywhere)
            layers = [i for i in self.g.Layer_Neuron if i != "Input"]

            # Tabs content -----------------------------------------------
            info_vis = dbc.Card(
                dbc.CardBody([
                    self._infoCard(),
                    dcc.Tabs([self._generalTab(layers), self._view2DTab(layers)],
                             id="tabinfo", value="General information"),
                ]),
                className="mt-3")

            # VS2N Tabs: "General" first, then the active modules
            tabs = [dcc.Tab(info_vis, label="General")]
            tabs.extend(self._moduleTabs())

            # App layout
            self.app.layout = self._appLayout(tabs)

            # load callbacks
            callbacks(self, self.app, self.g)

        def _infoCard(self):
            """ Collapsible "Information" card holding the network summary table """
            def row(*cells):
                # Cells alternate name, value, name, value; names are rendered bold
                return html.Tr([html.Td(cell, style={"width": "25%", "fontWeight": "500"} if k % 2 == 0 else {"width": "25%"})
                                for k, cell in enumerate(cells)])

            tablecontent = [
                row("Neurons", self.g.NeuronsNbr, "Layers", self.g.LayersNbr),
                row("Input", self.g.Input, "Dataset", self.g.Dataset),
                row("Simulation Date", self.g.Date, "Accuracy", str(self.g.Accuracy)+" %"),
            ]
            return dbc.Card([
                dbc.CardHeader(
                    dbc.Button(
                        "Information",
                        color="none",
                        id="group-info-toggle",
                        style={"width": "100%", "height": "100%", "padding": "10px"}
                    ), style={"padding": "0px", }),
                dbc.Collapse(
                    dbc.CardBody([
                        dbc.Table(html.Tbody(tablecontent, id="TableBody"), striped=True, bordered=True, responsive=True)]),
                    id="collapse-info",
                )])

        def _generalTab(self, layers):
            """ "General information" tab: general graph with its filters and,
            when ``self.g.labelsExistance`` is truthy, the inputs/label graph

            Args:
                layers (list): layer names offered in the layer filter
            """
            labels = self.g.labelsExistance

            filters = html.Div([
                # `on` is a boolean prop (was the string 'True')
                daq.PowerButton(
                    id="general-graph-switch",
                    on=True,
                    size=30,
                    color="#28a745",
                    style={"marginLeft": "10px"}
                ),
                html.P("Graphs: ", style={"textAlign": "start", "marginLeft": "10px", "marginTop": "4px"}),
                # Graphs filter
                dcc.Dropdown(
                    id='GeneralGraphFilter',
                    options=[{'label': "Spikes", 'value': "Spikes"},
                             {'label': "Synapses activity", 'value': "Synapses"},
                             {'label': "Neurons potential", 'value': "Potentials"}],
                    value=["Spikes", "Synapses", "Potentials"],
                    multi=True,
                    style={"minWidth": "20%", "marginLeft": "5px", "textAlign": "start"}),
                # Layers filter
                html.P("Layers: ", style={"textAlign": "start", "marginLeft": "20px", "marginTop": "4px"}),
                dcc.Dropdown(
                    id='GeneralLayerFilter',
                    options=[{'label': str(i), 'value': str(i)} for i in layers],
                    value=[str(i) for i in layers],
                    multi=True,
                    style={"minWidth": "20%", "marginLeft": "5px", "textAlign": "start"})
            ], className="d-flex", style={"paddingLeft": "20px", 'width': '100%'})

            general = html.Div([
                html.Div([html.Div([filters], className="col-12")], className="d-flex"),
                html.Div([dcc.Graph(id='general-graph', config={"displaylogo": False})])
            ], className="col-lg-9 col-sm-12 col-xs-12" if labels else "col-lg-12 col-sm-12 col-xs-12")

            if labels:
                labelGraph = html.Div([
                    html.Div([
                        daq.PowerButton(
                            id="label-graph-switch",
                            on=True,
                            size=30,
                            color="#28a745",
                            style={"marginLeft": "20px"}
                        ),
                        html.P("Inputs", style={"textAlign": "start", "marginLeft": "10px", "marginTop": "4px"})], className="d-flex"),
                    dcc.Graph(id='label-graph', config={"displaylogo": False})], className="col-lg-3 col-sm-12 col-xs-12")
            else:
                # Empty placeholder child kept as in the original layout
                labelGraph = []

            return dcc.Tab(dbc.Card(
                dbc.CardBody([html.Div([general, labelGraph], className="row")], style={"padding": "5px"})),
                label="General information", value="General information")

        def _perLayerTabs(self, graphType, valuePrefix):
            """ One sub-tab per non-"Input" layer, each holding a pattern-matching graph

            Args:
                graphType (str): ``type`` part of the graph id ``{"type", "index": layer}``
                valuePrefix (str): tab values are ``valuePrefix + position in
                    self.g.Layer_Neuron``; the tab with value ``valuePrefix + "1"``
                    is selected by default
            """
            return dcc.Tabs([dcc.Tab(dbc.Card(dbc.CardBody([
                dcc.Graph(id={"type": graphType, "index": i}, config={"displaylogo": False})
            ])), label=i, value=valuePrefix + str(x)) for x, i in enumerate(self.g.Layer_Neuron) if i != "Input"],
                value=valuePrefix + "1")

        @staticmethod
        def _view2DStylesheet():
            """ Cytoscape stylesheet: neurons get redder and larger as ``spiked`` grows

            Later rules override earlier ones, so each ``spiked < t`` rule
            narrows the previous range. Layer (compound) nodes, marked with
            ``spiked == -1``, get a neutral background.
            """
            # (selector, background colour, node size in px)
            shades = [
                ('[spiked <= 1.0]', 'rgb(227,70,70)', 45),
                ('[spiked < 0.8]', 'rgb(227,100,100)', 40),
                ('[spiked < 0.6]', 'rgb(227,130,130)', 35),
                ('[spiked < 0.4]', 'rgb(227,160,160)', 30),
                ('[spiked < 0.2]', 'rgb(227,190,190)', 25),
                ('[spiked = 0.0]', 'rgb(199,197,197)', 20),
            ]
            return ([{'selector': 'node', 'style': {'label': 'data(label)'}}]
                    + [{'selector': selector, 'style': {'background-color': colour, 'height': size, 'width': size}}
                       for selector, colour, size in shades]
                    + [{'selector': '[spiked = -1]', 'style': {'background-color': 'rgb(227,227,227)'}}])

        def _view2DTab(self, layers):
            """ "2D view" tab: layer filter, accumulated-spikes heatmaps, 2D network
            space (cytoscape) and spikes activity per input

            Args:
                layers (list): layer names offered in the layer filter
            """
            layerFilter = html.Div([
                # Layers filter
                html.P("Layers: ", style={"textAlign": "start", "marginRight": "10px", "marginTop": "4px"}),
                dcc.Dropdown(
                    id='2DViewLayerFilter',
                    options=[{'label': str(i), 'value': str(i)} for i in layers],
                    value=[str(i) for i in layers],
                    multi=True,
                    style={"minWidth": "80%", "textAlign": "start"}),
            ], style={"textAlign": "start", }, className="d-flex col-lg-12 col-sm-12 col-xs-12")

            # Accumulated Spikes HeatMap
            heatmaps = html.Div([
                html.Div([html.P("Accumulated Spikes", style={"margin": "0px"})]),
                self._perLayerTabs("2DView-heatmap", '2Dview-'),
            ], style={"textAlign": "start", }, className="col-lg-3 col-sm-12 col-xs-12")

            # 2D space of the network (elements built by clearData/generate2DView)
            space = html.Div([
                html.Div([html.P("2D Space", style={"margin": "0px"})]),
                html.Div([
                    cyto.Cytoscape(
                        id='cytoscape-compound',
                        layout={'name': 'preset'},
                        boxSelectionEnabled=False,
                        style={'width': '100%', 'height': '100%'},
                        stylesheet=self._view2DStylesheet(),
                        elements=self.Spikes2D
                    )], style={'position': 'absolute', 'width': '100%', 'height': '100%', 'zIndex': 999,
                               "background": "rgba(68, 71, 99, 0.05)"}),
                html.P(id="spikes_info", style={"padding": "8px", "margin": "0px"})
            ], style={"height": "50vh", "textAlign": "start", "padding": "0px", "marginBottom": "12px"},
                className="col-lg-6 col-sm-12 col-xs-12")

            # Spikes activity per input
            activity = html.Div([
                html.Div([html.P("Spikes Activity Per Input", style={"margin": "0px"})]),
                self._perLayerTabs("SpikesActivityPerInput", 'SpikesActivityPerInput-'),
            ], style={"textAlign": "start", }, className="col-lg-3 col-sm-12 col-xs-12")

            return dcc.Tab(dbc.Card(
                dbc.CardBody([layerFilter, heatmaps, space, activity], className="row")),
                label="2D view", value="2Dview")

        def _moduleTabs(self):
            """ One tab per entry of ``self.g.modules`` other than "General"

            Each module is imported as ``src.Modules.<name>.layout`` and its
            ``layout().load(app, g)`` result becomes the tab content.
            """
            tabs = []
            for m in [i for i in self.g.modules if i != "General"]:
                try:
                    tabs.append(dcc.Tab(importlib.import_module(
                        ".Modules."+m+".layout", package="src").layout().load(self.app, self.g), label=m))
                except Exception:
                    # Best effort: a broken module must not prevent the rest of the UI from loading
                    print("Tabs appending:"+traceback.format_exc())
            return tabs

        def _appLayout(self, tabs):
            """ Page skeleton: navbar, tabs, playback controls, step slider and hidden stores

            Args:
                tabs (list): ``dcc.Tab`` components shown in the main ``tabs`` container
            """
            def numberField(title, field):
                # Labelled numeric input of the control bar
                return html.Div([
                    html.Div([html.Span(title, className="input-group-text")], className="input-group-prepend"),
                    field
                ], className="input-group col-md-12 col-sm-12 col-lg-4", style={"height": "38px", "paddingTop": "12px"})

            navbar = html.Nav(className="navbar sticky-top navbar-dark", children=[
                dbc.Row([
                    html.Div([html.A("VS2N", className="navbar-brand", href="#")], className="col-12",
                             style={"marginRight": "0px", "marginLeft": "0px", "paddingRight": "0px", "paddingLeft": "0px"}),
                    html.Div([html.A("Information", target="_blank", rel="noopener noreferrer",
                                     href="https://gitlab.univ-lille.fr/bioinsp/VS2N",
                                     style={"color": "rgb(217, 220, 255)"})],
                             className="col-sm-2 col-lg-2 align-self-center",
                             style={"position": "fixed", "textAlign": "start"}),
                    html.Div([html.A("Log Out", href="/logout", style={"color": "rgb(217, 220, 255)"})],
                             className="col-sm-2 col-lg-2 align-self-center",
                             style={"position": "fixed", "right": "0px", "textAlign": "end"})
                ], className="col-12", style={"marginRight": "0px", "marginLeft": "0px", "paddingRight": "0px", "paddingLeft": "0px"})
            ], style={"background": "rgb(68, 71, 99)"})

            # Back / Start / Next buttons
            playback = html.Div([
                dbc.Button(html.I(className="fa-solid fa-angle-left"), id="btn-back", className="btn btn-default",
                           style={"marginTop": "12px", "fontWeight": "500", "backgroundColor": "rgb(68, 71, 99)"}),
                dbc.Button("Start", id="btnControle", className="btn btn-success",
                           style={"marginTop": "12px", "marginLeft": "5px", "width": "100px", "fontWeight": "500"}),
                dbc.Button(html.I(className="fa-solid fa-angle-right"), id="btn-next", className="btn btn-default",
                           style={"marginTop": "12px", "marginLeft": "5px", "fontWeight": "500", "backgroundColor": "rgb(68, 71, 99)"})
            ], className="col-md-4 col-sm-12 col-lg-4")

            speed = numberField("Update Speed (s)", dbc.Input(
                type="number", id="speed", value=1, min=0.25, max=5, step=0.25,
                style={"width": "30%", "textAlign": "center"}))
            interval = numberField("Update Interval (s)", dbc.Input(
                type="number", id="interval", value=self.g.updateInterval, min=0.005, max=180, step=0.005,
                style={"width": "30%", "textAlign": "center"}))

            slider = dbc.Col([
                html.P(id="text", style={"marginLeft": "25px"}),
                dcc.Slider(
                    id='vis-slider',
                    min=0,
                    max=self.g.stepMax,
                    marks=None,
                    step=1,
                    value=0,
                )], style={"textAlign": "start", "padding": "5px", "paddingTop": "12px"})

            return html.Div([
                navbar,
                dbc.Container([
                    dcc.Interval(id="vis-update", interval=1000, disabled=True),
                    dcc.Tabs(tabs, id="tabs"),
                    # Control Layout
                    dbc.Row([playback, speed, interval], className="d-flex justify-content-center"),
                    slider],
                    id="main_vis",
                    className="container-fluid col-12 p-2",
                    style={"textAlign": "center", "height": "100vh", "alignItems": "center"},
                ),
                # Hidden stores shared with the callbacks
                html.Div(id='v-step', children="0", style={'display': 'none'}),
                html.Div(id='clear', children="False", style={'display': 'none'})])