Skip to content
Snippets Groups Projects
Select Git revision
  • 95febf21c8c8962d1803109749afb4ace7a9542d
  • master default protected
2 results

callbacks.py

Blame
  • callbacks.py 46.69 KiB
    """ This class contains Dash callbacks
    
        Dash callbacks are responsible for updating the graphs at each step.
    """
    
    from collections import deque
    import pymongo
    import traceback
    from bson.json_util import dumps
    from bson.json_util import loads
    import plotly.graph_objects as go
    from dash import (no_update, Input, Output, State, ALL, callback_context)
    from plotly.subplots import make_subplots
    from src.templates.callbacksOp import callbacksOp
    
    class callbacks(callbacksOp):
        """ Callbacks class
        """
    
        def __init__(self, super, app, g):
            """ Initialize the callback .
    
            Args:
                app : Flask app
                g (Global_Var): reference to access global variables
            """
            # ------------------------------------------------------------
            # Graph build functions
            # ------------------------------------------------------------
            # to prevent creating duplicate callbacks next time
            if not g.checkExistance(app, "vis-slider"):
    
                try:
    
                    def processGeneralGraph(data, sliderValue, generalGraphFilter, generalLayerFilter):
                        """ Create general graph components.

                        One trace is built per (selected layer, selected metric) pair and each
                        layer gets its own subplot. When g.finalLabels is set, a loss subplot
                        is drawn above the layer subplots.

                        Args:
                            data (array): data[i] holds the values for generalGraphFilter[i]
                                (a dict keyed by layer name, or None); the last element is read
                                as the loss value. None, or a list shorter than expected, is
                                accepted: missing metric entries are recorded as 0 and a missing
                                loss as None.
                            sliderValue (int): value of the slider
                            generalGraphFilter (list): actual filter of GeneralGraph visualization
                            generalLayerFilter (list): selected layers for GeneralGraph visualization

                        Returns:
                            general graph content (a plotly figure), or None if building it
                            raised an exception (the traceback is printed)
                        """
                        # Per-metric settings:
                        # (history attribute, running-max attribute, key inside data[i][layer],
                        #  line colour, trace name, hover template prefix)
                        metricSpecs = {
                            "Spikes": ("SpikeGraphY", "MaxSpike", "spikes", "rgb(31, 119, 180)",
                                       'Spikes',
                                       "%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
                            "Synapses": ("SynapseGraphY", "MaxSynapse", "synapseUpdate", "rgb(255, 127, 14)",
                                         'Synapses activity',
                                         "<span style='color:white;'>%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
                            "Potentials": ("PotentialGraphY", "MaxPotential", "potential", "rgb(44, 160, 44)",
                                           'Neuron\'s potential',
                                           "%{text} <br> <b>%{customdata}</b> <br> <b>Max</b> "),
                        }

                        def entryAt(index):
                            """Return data[index], or None when data is None or too short."""
                            try:
                                return data[index]
                            except (TypeError, IndexError):
                                return None

                        X = 0
                        Xlabel = ""
                        try:
                            if sliderValue is not None:
                                Xlabel = "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(
                                    g.updateInterval, sliderValue+1)+"]"

                            # xAxisLabel may start as a single [] placeholder: drop it before
                            # storing the first real label.
                            if len(super.xAxisLabel) == 1 and super.xAxisLabel[0] == []:
                                super.xAxisLabel.clear()
                            super.xAxisLabel.append(Xlabel)

                            if len(super.InfoGraphX) > 0:
                                X = super.InfoGraphX[-1]+1
                            super.InfoGraphX.append(X)

                            multiLayer = len(generalLayerFilter) > 1
                            fill = 'tozeroy' if len(generalGraphFilter) == 1 else 'none'
                            graphs = {layer: [] for layer in generalLayerFilter}
                            annotations = []

                            for layer in generalLayerFilter:
                                for i, f in enumerate(generalGraphFilter):
                                    if f not in metricSpecs:
                                        continue
                                    historyAttr, maxAttr, key, colour, traceName, hoverPrefix = metricSpecs[f]
                                    history = getattr(super, historyAttr)
                                    maxima = getattr(super, maxAttr)
                                    if layer not in history:
                                        # keep only the last 100 points per layer
                                        history[layer] = deque(maxlen=100)
                                        maxima[layer] = 0

                                    entry = entryAt(i)
                                    if entry is not None and layer in entry:
                                        history[layer].append(entry[layer][key])
                                    else:
                                        history[layer].append(0)

                                    # Running maximum: used both to normalise the curve and in
                                    # the hover text, for every metric alike.
                                    maxima[layer] = max(maxima[layer],
                                                        max(history[layer]) if history[layer] else 0)

                                    graphs[layer].append(
                                        go.Scatter(
                                            x=list(super.InfoGraphX),
                                            y=[norm(v, maxima[layer]) for v in history[layer]],
                                            fill=fill,
                                            line=dict(color=colour),
                                            name=traceName + ('['+layer+']' if multiLayer else ''),
                                            mode='lines+markers',
                                            text=list(super.xAxisLabel),
                                            customdata=list(history[layer]),
                                            hovertemplate=hoverPrefix + str(maxima[layer])))

                            if len(graphs) != 0:
                                hasLoss = g.finalLabels is not None

                                if hasLoss:
                                    loss = entryAt(-1)
                                    super.LossGraphY.append(round(loss, 2) if loss is not None else None)

                                    # Row 1 (spanning 2 rows) holds the loss curve; layer k
                                    # (0-based) then occupies 3 rows starting at row 3 + 3*k.
                                    fig = make_subplots(rows=2+(len(graphs)*3), cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=
                                    [[{'rowspan': 2}],[None]] + ([[{'rowspan': 3}],[None],[None]] * len(graphs)))

                                    fig.add_trace(
                                        go.Scatter(x=list(super.InfoGraphX), y=list(super.LossGraphY), mode='lines',
                                                   text=[str(t)+' %' for t in list(super.LossGraphY)],
                                                   hoverinfo='text',
                                                   connectgaps=False,
                                                   xaxis='x2',
                                                   line={"color": "#dc3545",
                                                         "dash": "dot",
                                                         "width": 2},
                                                   name='Loss'
                                                   ),
                                        row=1, col=1
                                    )
                                    for k, graphL in enumerate(graphs.values()):
                                        for graph in graphL:
                                            fig.add_trace(graph, row=3 + 3*k, col=1)

                                else:
                                    fig = make_subplots(
                                        rows=len(graphs), cols=1, shared_xaxes=True, vertical_spacing=0.05)

                                    for row, graphL in enumerate(graphs.values(), start=1):
                                        for graph in graphL:
                                            fig.add_trace(graph, row=row, col=1)

                                # the "Step" title goes on the bottom-most subplot
                                fig.update_xaxes(title_text="Step", row=(len(graphs) * 3) if hasLoss else len(graphs), col=1)

                                fig['layout'].update(
                                    yaxis=dict(range=[0, 105]),
                                    showlegend=True,
                                    uirevision='no reset of zoom',
                                    margin={'l': 0, 'r': 0, 't': 30, 'b': 25},
                                    annotations=annotations)

                            else:
                                # no layer selected: return an empty placeholder figure
                                fig = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05, specs=[[{'rowspan': 1}]])
                                fig.add_trace(
                                        go.Scatter(x=list(), y=list(), mode='lines',
                                            line={"color": "#dc3545",
                                                "dash": "dot",
                                                "width": 2},
                                        ),row=1, col=1)

                            return fig

                        except Exception:
                            print("processGeneralGraph " + traceback.format_exc())
    
                    def processLabelInfoTreemap(data):
                        """ Build the Treemap figure describing the network inputs.

                        Args:
                            data (array): None, or a sequence whose element 0 is a dict
                                (label -> value) stored in super.Label and whose element 1 is
                                stored in super.Max and shown as "Processed Inputs".

                        Returns:
                            dict with a 'data' list holding one go.Treemap and a 'layout', or
                            None if an exception was raised (the traceback is printed)
                        """
                        try:
                            if data is None:
                                super.Label = dict()
                            else:
                                super.Max = data[1]
                                super.Label = data[0]

                            current = super.Label
                            # an empty-list Label means "nothing to draw"
                            hasLabels = current != []
                            names = list(current.keys()) if hasLabels else []
                            amounts = list(current.values()) if hasLabels else []

                            treemap = go.Treemap(
                                labels=names,
                                textfont_size=20,
                                textposition="middle center",
                                parents=[""] * len(current),
                                values=amounts,
                                hovertemplate='%{value}',
                                name='Label',
                                marker=dict(colors=list(names), colorscale='RdBu'))
                            layout = go.Layout(
                                xaxis_type='category',
                                title_text='Processed Inputs: <b>' + str(super.Max) + '</b>',
                                uirevision='no reset of zoom',
                                margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
                            )
                            return {'data': [treemap], 'layout': layout}
                        except Exception:
                            print("processLabelInfoTreemap:"+ traceback.format_exc())
    
                    # ----------------------------------------------------
                    # Callbacks
                    # ----------------------------------------------------
                    # Main Callback:
                    # 1- updates progress bar.
                    # 2- Listener for control buttons.
                    # 3- trigger update in all existing visualizations.
                    @app.callback(
                        [Output("text", "children"), Output("vis-slider", "value"), Output("vis-slider", "step"), Output("vis-slider", "max"), Output("v-step", "children"),
                        Output("clear", "children"), Output("interval", "value")],
                        [Input("vis-update", "n_intervals"),
                         Input("btn-back", "n_clicks"), Input("btn-next", "n_clicks")],
                        [State("vis-slider", "value"), State("interval", "value"),
                         State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State("clear", "children")])
                    def progress(visUpdateInterval, backButton, nextButton, sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, clearGraphs):
                        """ Main step callback: advance/rewind the slider and signal graph resets.

                        Args:
                            visUpdateInterval : interval instance that will cause this function to be called each step
                            backButton (int): number of clicks on the back button
                            nextButton (int): number of clicks on the next button
                            sliderValue (int): value of the slider
                            updateInterval (int): update Interval (s)
                            generalGraphFilter (list): actual filter of GeneralGraph visualization
                            generalLayerFilter (list): selected layers for GeneralGraph visualization
                            clearGraphs (boolean): a dash state to pass information to all visualization about clearing content (if needed)

                        Returns:
                            [time-window label, slider value, slider step (1), slider max,
                             v-step value, clear toggle, clamped update interval], or None if an
                            exception was raised (the traceback is printed)
                        """
                        try:
                            context = callback_context.triggered[0]['prop_id'].split('.')[0]

                            # 'clear' works as a toggle signal. Record every reason to reset
                            # and flip it at most once at the end: flipping it once per reason
                            # would cancel out whenever two reasons occur in the same call.
                            mustClear = False

                            # update interval value if changed
                            if g.updateInterval != float(updateInterval):
                                super.clearData()
                                mustClear = True

                                # clamp the interval to [0.005, 180] seconds
                                g.updateInterval = min(max(float(updateInterval), 0.005), 180.0)

                                # keep the slider at the same relative position with the new step count
                                sliderValue = int(sliderValue / g.stepMax * (int(g.Max/g.updateInterval)+1))

                            g.stepMax = int(g.Max/g.updateInterval)+1

                            if super.generalGraphFilterOld != generalGraphFilter:
                                super.generalGraphFilterOld = generalGraphFilter
                                super.clearData()
                                mustClear = True

                            if super.generalLayerFilterOld != generalLayerFilter:
                                super.generalLayerFilterOld = generalLayerFilter
                                super.clearData()
                                mustClear = True

                            # a jump of more than 2 steps breaks the continuity of the curves
                            if abs(super.oldSliderValue - sliderValue) > 2:
                                super.clearData()
                                mustClear = True

                            super.oldSliderValue = sliderValue

                            if "btn" in context:
                                if context == "btn-back":
                                    if sliderValue > 0:
                                        sliderValue -= 1
                                        super.oldSliderValue = sliderValue
                                    super.clearData()
                                    mustClear = True
                                elif context == "btn-next" and callback_context.triggered[0]['value'] is not None:
                                    if sliderValue < g.stepMax:
                                        sliderValue += 1
                            elif not super.visStopped:
                                # playing: advance one step per interval tick
                                sliderValue += 1
                                super.oldSliderValue = sliderValue
                                # pause once the end of the recorded run is reached
                                super.visStopped = sliderValue*g.updateInterval >= g.Max

                            if mustClear:
                                clearGraphs = not clearGraphs

                            super.label = "[ "+g.getLabelTime(g.updateInterval, sliderValue)+" , "+g.getLabelTime(
                                g.updateInterval, sliderValue+1)+" ]"

                            return [super.label, sliderValue, 1, g.stepMax, sliderValue, clearGraphs, g.updateInterval]

                        except Exception:
                            print("progress:" + traceback.format_exc())
                    # Callback to control the play/stop button
                    @app.callback(
                        [Output("btnControle", "children"), Output("btnControle", "className"),Output("vis-update", "disabled")],
                        [Input("vis-update", "n_intervals"),Input("btnControle", "n_clicks")],[State("vis-slider", "value"),State("btnControle", "children")])
                    def progressButton(visUpdateInterval,playButton,sliderValue,playButtonText):
                        """ Keep the play/stop button and the step timer in sync with the playback.

                        Args:
                            visUpdateInterval : interval instance that will cause this function to be called each step
                            playButton (int): number of clicks on the start/stop button
                            sliderValue (int): value of the slider
                            playButtonText (String): text on the start/stop button

                        Returns:
                            [button text, button CSS class, whether the step timer is disabled],
                            or no_update when nothing has to change
                        """
                        trigger = callback_context.triggered[0]['prop_id'].split('.')[0]
                        stepCount = int(g.stepMax)
                        pausedLook = ["Start", "btn btn-success", True]

                        if trigger == "btnControle" and playButtonText != "Start":
                            # "Stop" clicked: pause unless the slider is already past the last step
                            if stepCount >= sliderValue:
                                super.visStopped = True
                                return pausedLook
                            return no_update

                        if stepCount <= sliderValue:
                            # end of the run reached (timer tick or "Start" click): force pause
                            super.visStopped = True
                            return pausedLook

                        if trigger == "btnControle":
                            # "Start" clicked with steps still left to play
                            super.visStopped = False
                            return ["Stop", "btn btn-danger", False]

                        return no_update
    
                    # Callback to handle general graph content
                    @app.callback(
                        [Output("general-graph", "figure")
                         ], [Input("v-step", "children")],
                        [State("interval", "value"), State("GeneralGraphFilter", "value"), State("GeneralLayerFilter", "value"), State('general-graph-switch', 'on')])
                    def progressGeneralGraph(sliderValue, updateInterval, generalGraphFilter, generalLayerFilter, generalGraphSwitchIsOn):
                        """ Update the general graph.

                        Args:
                            sliderValue (int): value of the slider
                            updateInterval (int): update Interval (s)
                            generalGraphFilter (list): actual filter of GeneralGraph visualization
                            generalLayerFilter (list): selected layers for GeneralGraph visualization
                            generalGraphSwitchIsOn (bool): general graph switch value

                        Returns:
                            single-element list with the figure built by processGeneralGraph, or
                            no_update when the switch is off, the step is already drawn, or the
                            paused slider is past the last step
                        """
                        if not generalGraphSwitchIsOn:
                            return no_update

                        # skip the step if its time window is the last one already plotted
                        alreadyDrawn = len(super.xAxisLabel) > 0 and (
                            "["+g.getLabelTime(g.updateInterval, sliderValue)+","+g.getLabelTime(g.updateInterval, sliderValue+1)+"]"
                            == super.xAxisLabel[-1])
                        if alreadyDrawn:
                            return no_update

                        paused = super.visStopped
                        if paused and sliderValue > g.stepMax:
                            return no_update

                        generalData = GeneralModuleData(
                            int(sliderValue)*float(updateInterval), generalGraphFilter, generalLayerFilter)

                        if paused:
                            graphInput, step = generalData, int(sliderValue)
                        else:
                            graphInput = None if generalData == [None] else generalData
                            step = sliderValue

                        return [processGeneralGraph(graphInput, step, generalGraphFilter, generalLayerFilter)]
    
                    # Callback to handle label graph content
                    @app.callback(
                        [Output("label-graph", "figure")
                         ], [Input("v-step", "children")],
                        [State("interval", "value"), State('label-graph-switch', 'on')])
                    def progressLabelGraph(sliderValue, updateInterval, labelGraphSwitchIsOn):
                        """ Update the label treemap graph.

                        Args:
                            sliderValue (int): value of the slider
                            updateInterval (int): update Interval (s)
                            labelGraphSwitchIsOn (bool): label graph switch value

                        Returns:
                            single-element list with the Treemap figure built by
                            processLabelInfoTreemap, or no_update when the switch is off
                        """
                        if not labelGraphSwitchIsOn:
                            return no_update

                        labelData = getNetworkInput(
                            int(sliderValue)*float(updateInterval), g.updateInterval)
                        # A [None] result means "no input data". processLabelInfoTreemap only
                        # handles None for that case (it indexes data[1] otherwise), so
                        # normalise it for both the playing and the paused path.
                        if labelData == [None]:
                            labelData = None

                        labelInfoTreemap = processLabelInfoTreemap(labelData)

                        # While playing, pause once the end of the recorded run is reached.
                        # visStopped is never reset to False here, so a Stop issued meanwhile
                        # by another callback is not undone.
                        if not super.visStopped and sliderValue*g.updateInterval >= g.Max:
                            super.visStopped = True

                        return [labelInfoTreemap]
    
                    # Callback to update the speed of visualization
                    @app.callback(
                        [Output("vis-update", "interval"),Output("speed", "value")],
                        [Input("speed", "value")])
                    def speedControle(speedValue):
                        """ Convert the speed input into the period of the 'vis-update' timer.

                        The value is clamped to [0.25, 120] and multiplied by 1000 for the
                        timer interval (assumed to be seconds -> milliseconds, as expected
                        by the Dash Interval component); the clamped value is written back to
                        the 'speed' input.

                        Args:
                            speedValue (float): actual speed value, None when the input is empty

                        Returns:
                            [timer interval (speedValue * 1000), clamped speed value], or
                            no_update when the input is empty
                        """
                        # An emptied numeric input arrives as None; comparing it would raise
                        # and the callback would return None for its two outputs.
                        if speedValue is None:
                            return no_update
                        try:
                            speedValue = min(max(speedValue, 0.25), 120)
                            return [speedValue * 1000, speedValue]
                        except Exception:
                            print("speedControle:" + traceback.format_exc())
    
                    # Callback to handle information tab (open or close)
                    @app.callback(
                        [Output("collapse-info", "is_open")],
                        [Input("group-info-toggle", "n_clicks")], [State("collapse-info", "is_open")])
                    def informationTab(nbrClicks, isTabOpen):
                        """ Open or close the information panel when its header is clicked.

                        Args:
                            nbrClicks (int): number of clicks on the tab
                            isTabOpen (bool): whether tab is open or not

                        Returns:
                            single-element list with the new open state: toggled when the
                            trigger carries a click count, unchanged otherwise
                        """
                        try:
                            wasClicked = callback_context.triggered[0]["value"] is not None
                            newState = (not isTabOpen) if wasClicked else isTabOpen
                            return [newState]
                        except Exception:
                            print("informationTabController:" + traceback.format_exc())
    
                    # Callback to handle the 2D view spiking visualization
                    def _layerMatrices(layers):
                        """ Reshape the accumulated spikes of the given layers for the 2D heatmaps.

                        Args:
                            layers (list): names of the layers to reshape

                        Returns:
                            (matrix, indices): dicts mapping each layer to super.toMatrix() of its
                            accumulated spike counts and of its neuron indices
                        """
                        matrix = {}
                        indices = {}
                        for layer in layers:
                            accumulated = super.AccumulatedSpikes2D[layer]
                            matrix[layer] = super.toMatrix(accumulated)
                            indices[layer] = super.toMatrix(list(range(len(accumulated))))
                        return matrix, indices

                    def _heatmapFigures(matrix, indices):
                        """ Build one heatmap figure per layer of super.AccumulatedSpikes2D.

                        Layers absent from matrix/indices (not selected in the filter) get an
                        empty heatmap.

                        Args:
                            matrix (dict): layer -> matrix of accumulated spikes
                            indices (dict): layer -> matrix of neuron indices

                        Returns:
                            list of figure dicts, ordered like super.AccumulatedSpikes2D
                        """
                        figures = []
                        for layer in super.AccumulatedSpikes2D:
                            heatmap = go.Heatmap(z = matrix.get(layer, []), colorscale= 'Reds', customdata = indices.get(layer, []), hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)
                            layout = {"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}
                            figures.append({"data": [heatmap], "layout": layout})
                        return figures

                    def _layerSpikeRecords(spikes, layer):
                        """ Extract the (neuron, input, count) records of one layer.

                        Args:
                            spikes (list): getSpike(..., perNeuron=True) output, whose entries are
                                shaped {layer: {neuron: {input: count}}}
                            layer (str): layer to keep

                        Returns:
                            list of (neuron, input, count) tuples, in the order of ``spikes``
                        """
                        records = []
                        for spike in spikes:
                            spikeLayer, perNeuron = next(iter(spike.items()))
                            if spikeLayer != layer:
                                continue
                            neuron, perInput = next(iter(perNeuron.items()))
                            inputLabel, count = next(iter(perInput.items()))
                            records.append((neuron, inputLabel, count))
                        return records

                    def _applyLayerSpikes(elements, layer, records):
                        """ Push the spikes of one layer into the 2D nodes and the shared accumulators.

                        getSpike() groups spikes per (neuron, input), so a neuron that spiked for
                        several inputs during the interval has several records. They are summed so
                        the node shows the neuron's total and "spiked" is that total normalised by
                        the busiest neuron of the layer.

                        Args:
                            elements (list): cytoscape elements, updated in place
                            layer (str): layer name
                            records (list): (neuron, input, count) tuples of this layer
                        """
                        totals = {}
                        perInput = {}
                        for neuron, inputLabel, count in records:
                            totals[neuron] = totals.get(neuron, 0) + count
                            perInput.setdefault(neuron, []).append((inputLabel, count))
                        if not totals:
                            return
                        maxSpike = max(totals.values())
                        # node ids are the layer name followed by the neuron number; the label check
                        # rejects a node of another layer whose id happens to be identical
                        # (hypothetically "L1"+"12" vs "L11"+"2")
                        neuronById = {layer + str(neuron): neuron for neuron in totals}
                        for element in elements:
                            data = element["data"]
                            if data['spiked'] == -1:
                                continue
                            neuron = neuronById.get(data["id"])
                            if neuron is None or data["label"] != str(neuron):
                                continue
                            total = totals[neuron]
                            data["spiked"] = round(total / maxSpike, 2)
                            data["spikes"] = total
                            super.AccumulatedSpikes2D[layer][int(data["label"])] += total
                            for inputLabel, count in perInput[neuron]:
                                super.SpikesActivityPerInput[layer][neuron][inputLabel] = count

                    @app.callback(
                        Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'),
                        Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
                        State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
                    def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
                        """ Update the 2D view on each step, or when a node is hovered.

                        On a step ("v-step" or "vis-update" trigger), the spikes of the interval
                        are fetched and pushed into the nodes, the accumulated 2D spikes and the
                        per-input activity. On any other trigger (hover), the current data are
                        re-rendered and only the hovered-node text changes.

                        Args:
                            visUpdateInterval : n_intervals of 'vis-update' (used only as a trigger)
                            sliderValue (int): current step
                            mouseOverNodeData : data of the hovered node (may be None)
                            updateInterval (int): update interval (s)
                            elements : cytoscape elements (replaced by super.Spikes2D)
                            Layer2DViewFilter : selected layers (may be None)

                        Returns:
                            [cytoscape elements, hovered-node text, 2D heatmap figures,
                            spikes-per-input figures], or no_update on error
                        """
                        try:
                            elements = super.Spikes2D
                            layers = Layer2DViewFilter or []
                            timestamp = int(sliderValue)*float(updateInterval)
                            labelData = getNetworkInput(timestamp, g.updateInterval)
                            trigger = callback_context.triggered[0]['prop_id'].split('.')[0]

                            if trigger in ["v-step","vis-update"]:
                                # new step: reset the per-step counters and the node colours
                                super.SpikesActivityPerInput = {layer: [[0 for _ in range(g.nbrClasses+1)] for _ in range(g.Layer_Neuron[layer])] for layer in g.Layer_Neuron if layer != "Input"}
                                for element in elements:
                                    if element["data"]['spiked'] != -1:
                                        element["data"]["spiked"] = 0
                                        element["data"]["spikes"] = 0

                                spikes = getSpike(timestamp, g.updateInterval, layers, True) or []
                                for layer in layers:
                                    _applyLayerSpikes(elements, layer, _layerSpikeRecords(spikes, layer))
                                info = []
                            else:
                                # hovered node: keep the data of the current step, refresh the text only
                                if mouseOverNodeData and 'spikes' in mouseOverNodeData:
                                    info = f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}"
                                else:
                                    info = ""

                            matrix, indices = _layerMatrices(layers)
                            heatmaps = _heatmapFigures(matrix, indices)
                            SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput]
                            return [elements, info, heatmaps, SpikesActivityPerInput]
                        except Exception:
                            print("animation2DViewController:" + traceback.format_exc())
                            # a bare None is not a valid value for this multi-output callback
                            return no_update
    
                except Exception:
                    print("Done loading:"+traceback.format_exc())
    
            try:
    
                # ---------------------------------------------------------
                # Helper functions
                # ---------------------------------------------------------
    
                def norm(data, Max):
                    """ Express data as a percentage of Max.

                    Args:
                        data (number or array): value(s) to scale
                        Max (number): reference value that maps to 100

                    Returns:
                        data * 100 / Max, or data unchanged when Max is 0
                    """
                    if Max == 0:
                        return data
                    return data * 100 / Max
    
                def GeneralModuleData(timestamp, filter, layers):
                    """ Collect the activity data requested by the GeneralGraph filter.

                    Args:
                        timestamp (int): start of the interval
                        filter (array): selected filters; "Spikes", "Synapses" and "Potentials"
                            are recognised, any other value is ignored
                        layers (array): selected layers

                    Returns:
                        list with one entry per recognised filter (in filter order), followed by
                        the loss value
                    """
                    fetchers = (
                        ("Spikes", lambda: getSpike(timestamp, g.updateInterval,layers,False)),
                        ("Synapses", lambda: getSynapse(timestamp, g.updateInterval,layers)),
                        ("Potentials", lambda: getPotential(timestamp, g.updateInterval,layers)),
                    )
                    collected = [fetch() for f in filter for name, fetch in fetchers if f == name]
                    # the loss value always comes last
                    collected.append(getLoss(timestamp, g.updateInterval))
                    return collected
                    
                def make_SpikeActivityPerInput(layer,dataLabel):
                    """ Build the spikes-per-input figure of a layer.

                    The top subplot (rows 1-4) is a heatmap of super.SpikesActivityPerInput[layer]
                    (class on x, neuron on y). The bottom subplot (row 5) is a bar chart of the
                    number of inputs per label given by dataLabel.

                    Args:
                        layer (str): key of super.SpikesActivityPerInput
                        dataLabel (list or None): getNetworkInput() result, [{label: count}, Max],
                            or None when no input was presented

                    Returns:
                        plotly Figure
                    """
                    classTicks = [i for i in range(g.nbrClasses+1)]
                    fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]])
                    fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer], colorscale= 'Reds',hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1)
                    fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=classTicks),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'})
                    if dataLabel is None:
                        labels, counts = [], []
                    else:
                        labels, counts = list(dataLabel[0].keys()), list(dataLabel[0].values())
                    fig.add_trace(go.Bar(x=labels,y=counts,hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1)
                    fig.update_xaxes(tickvals=classTicks)
                    return fig
                # ---------------------------------------------------------
                # MongoDB operations
                # ---------------------------------------------------------
    
                def getNetworkInput(timestamp, interval):
                    """ Count the inputs recorded in 'labels' in (timestamp, timestamp + interval].

                    Args:
                        timestamp (int): start of the interval (excluded)
                        interval (int): interval length

                    Returns:
                        [counts, Max] where counts maps each label ("L") to its number of records
                        and Max is the largest "G" value found (never below 0); None when the
                        interval holds no record
                    """
                    pipeline = [
                        {"$match": {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}},
                        {"$group": {"_id": "$L", "C": {"$sum": 1}, "G": {"$max": "$G"}}},
                        {"$sort": {"_id": 1}},
                    ]
                    # MongoDB --------------------------------------------
                    cursor = pymongo.collection.Collection(g.db, 'labels').aggregate(pipeline, allowDiskUse = True)
                    # ToJson ---------------------------------------------
                    groups = loads(dumps(cursor))
                    # ----------------------------------------------------
                    if not groups:
                        return None

                    Max = max([0] + [group["G"] for group in groups])
                    counts = {group["_id"]: group["C"] for group in groups}
                    return [counts, Max]
    
                def getSpike(timestamp, interval, layer, perNeuron):
                    """ Count the spikes of the given layers in (timestamp, timestamp + interval].

                    Args:
                        timestamp (int): start of the interval (excluded)
                        interval (int): interval length
                        layer (array): names of the selected layers
                        perNeuron (boolean): group per (layer, neuron, input) instead of per layer

                    Returns:
                        perNeuron: list of {layer: {neuron: {input: count}}}, in _id order
                        otherwise: dict {layer: {"_id": layer, "spikes": count}}
                        None when no spike matches
                    """
                    window = {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}
                    groupKey = {"L":"$i.L","N":"$i.N","Input":"$Input"} if perNeuron else "$i.L"
                    pipeline = [
                        {"$match": {"$and": [window,{"i.L": {'$in': layer}}]}},
                        {"$group": {"_id": groupKey,"spikes": {"$sum":1}}},
                        {"$sort": {"_id": 1}},
                    ]
                    # MongoDB---------------------
                    cursor = pymongo.collection.Collection(g.db, 'spikes').aggregate(pipeline, allowDiskUse = True)
                    # ToJson----------------------
                    rows = loads(dumps(cursor))
                    # ----------------------------
                    if perNeuron:
                        result = [{row["_id"]["L"]:{row["_id"]["N"]:{row["_id"]["Input"]:row["spikes"]}}} for row in rows]
                    else:
                        result = {row["_id"]: row for row in rows}
                    return result or None
    
                def getSynapse(timestamp, interval, layer):
                    """ Count the 'synapseWeight' records per layer in (timestamp, timestamp + interval].

                    Args:
                        timestamp (int): start of the interval (excluded)
                        interval (int): interval length
                        layer (array): names of the selected layers

                    Returns:
                        dict {layer: {"_id": layer, "synapseUpdate": count}}, or None when no
                        record matches
                    """
                    window = {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}
                    pipeline = [
                        {"$match": {"$and": [window,{"L": {'$in': layer}}]}},
                        {"$group": {"_id": "$L","synapseUpdate": {"$sum":1}}},
                        {"$sort": {"_id": 1}},
                    ]
                    # MongoDB---------------------
                    cursor = pymongo.collection.Collection(g.db, 'synapseWeight').aggregate(pipeline, allowDiskUse = True)
                    # ToJson----------------------
                    rows = loads(dumps(cursor))
                    # ----------------------------
                    perLayer = {row["_id"]: row for row in rows}
                    return perLayer or None
    
                def getPotential(timestamp, interval, layer):
                    """ Count the 'potential' records per layer in (timestamp, timestamp + interval].

                    Args:
                        timestamp (int): start of the interval (excluded)
                        interval (int): interval length
                        layer (array): names of the selected layers

                    Returns:
                        dict {layer: {"_id": layer, "potential": count}}, or None when no record
                        matches
                    """
                    window = {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}}
                    pipeline = [
                        {"$match": {"$and": [window,{"L": {'$in': layer}}]}},
                        {"$group": {"_id": "$L","potential": {"$sum":1}}},
                        {"$sort": {"_id": 1}},
                    ]
                    # MongoDB---------------------
                    cursor = pymongo.collection.Collection(g.db, 'potential').aggregate(pipeline, allowDiskUse = True)
                    # ToJson----------------------
                    rows = loads(dumps(cursor))
                    # ----------------------------
                    perLayer = {row["_id"]: row for row in rows}
                    return perLayer or None
    
                def getLoss(timestamp, interval):
                    """ Percentage of mislabelled spikes in (timestamp, timestamp + interval].

                    A spike counts as a loss when its neuron (N, L) appears in g.finalLabels
                    with a label different from the spike's "Input". A neuron listed several
                    times in g.finalLabels is checked against each of its entries.

                    Args:
                        timestamp (int): start of the interval (excluded)
                        interval (int): interval length

                    Returns:
                        loss in percent, rounded to 2 decimals and capped at 100; None when
                        g.finalLabels is None or the interval holds no spike
                    """
                    if g.finalLabels is None:
                        return None
                    # MongoDB---------------------
                    col = pymongo.collection.Collection(g.db, 'spikes')
                    # only the neuron id and the presented input are read below
                    spikes = col.find(
                        {"T": {'$gt': timestamp, '$lte': (timestamp+interval)}},
                        {"i.N": 1, "i.L": 1, "Input": 1})
                    # ----------------------------

                    # ToJson----------------------
                    spikes = loads(dumps(spikes))
                    # ----------------------------

                    if not spikes:
                        return None

                    # (neuron, layer) -> final labels, built once instead of scanning
                    # g.finalLabels for every spike
                    labelsByNeuron = {}
                    for entry in g.finalLabels:
                        labelsByNeuron.setdefault((str(entry["N"]), entry["L"]), []).append(str(entry["Label"]))

                    loss = 0
                    if labelsByNeuron:
                        for spike in spikes:
                            key = (str(spike["i"]["N"]), str(spike["i"]["L"]))
                            for label in labelsByNeuron.get(key, ()):
                                if label != str(spike["Input"]):
                                    loss += 1

                    return min(100,round((loss*100) / len(spikes), 2))
                # ---------------------------------------------------------------------
    
            except Exception:
                print("Helper functions and MongoDB operations: "+ traceback.format_exc())