From c1b4b37f9f9c22ce624bec95d610ef780712e0c7 Mon Sep 17 00:00:00 2001 From: Hammouda Elbez <hammouda.elbez@univ-lille.fr> Date: Fri, 26 May 2023 16:51:32 +0200 Subject: [PATCH] 2D view: adding ability to filter based on layer --- src/Modules/General/callbacks.py | 29 +++++++++++++---------------- src/Modules/General/layout.py | 13 +++++++------ 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py index b9e4258..75bd78d 100755 --- a/src/Modules/General/callbacks.py +++ b/src/Modules/General/callbacks.py @@ -544,17 +544,16 @@ class callbacks(callbacksOp): # Callback to handle the 2D view spiking visualization @app.callback( - Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'), - Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'), + Output("cytoscape-compound", "elements"),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'), + Input("vis-update", "n_intervals"),Input("v-step", "children"), State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value")) - def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter): + def animation2DView(visUpdateInterval,sliderValue, updateInterval, elements, Layer2DViewFilter): """ Function called each step to update the 2D view Args: sliderValue (int): value of the slider updateInterval (int): update Interval (s) visUpdateInterval : interval instance that will cause this function to be called each step - mouseOverNodeData : contains data of the hovered node elements : nodes description heatmapData : heatmap data Layer2DViewFilter : selected layers @@ -563,7 +562,7 @@ class callbacks(callbacksOp): if information tab should be opened or closed """ try: - elements = super.Spikes2D + elements = super.generate2DView(g,Layer2DViewFilter) matrix = {} indices = {} labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval) @@ -601,18 +600,16 @@ class callbacks(callbacksOp): matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer]) indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))]) - - if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D): for layer in super.AccumulatedSpikes2D: if layer not in matrix: matrix[layer] = [] indices[layer] = [] - heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] + heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] - SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput] - return [elements,[],heatmaps,SpikesActivityPerInput] + SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput] + return [elements,heatmaps,SpikesActivityPerInput] else: try: @@ -626,10 +623,10 @@ class callbacks(callbacksOp): matrix[layer] = [] indices[layer] = [] - heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] + heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] - SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput] - return [elements,f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}" if 'spikes' in mouseOverNodeData else "", heatmaps,SpikesActivityPerInput] + SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput] + return [elements,heatmaps,SpikesActivityPerInput] except Exception: print("OnHover:"+traceback.format_exc()) return no_update @@ -692,11 +689,11 @@ class callbacks(callbacksOp): res.append(getLoss(timestamp, g.updateInterval)) return res - def make_SpikeActivityPerInput(layer,dataLabel): + def make_SpikeActivityPerInput(layer,dataLabel,Layer2DViewFilter): fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]]) - fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer], colorscale= 'Reds',hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1) + fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer] if layer in Layer2DViewFilter else [], colorscale= 'Reds', zmin=0, hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1) fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=[i for i in range(g.nbrClasses+1)]),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'}) - if dataLabel == None: + if (dataLabel == None) or (layer not in Layer2DViewFilter): fig.add_trace(go.Bar(x=[],y=[],hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1) else: fig.add_trace(go.Bar(x=list(dataLabel[0].keys()),y=list(dataLabel[0].values()),hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1) diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py index 7cc7508..446c73d 100755 --- a/src/Modules/General/layout.py +++ b/src/Modules/General/layout.py @@ -46,15 +46,16 @@ class layout(layoutOp): # 2D view -------------------------------------------------------- - def generate2DView(self, g): + def generate2DView(self, g, layers): """ Generates a 2D View of the neural network Args: g (Global_Var): reference to access global variables + layers: (array): list of layer names """ Nodes = [] # Create the neurones and layers - for L in g.LayersNeuronsInfo: + for L in [l for l in g.LayersNeuronsInfo if l["layer"] in layers]: Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}}) for i in range(L["neuronNbr"]): Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % int(math.sqrt(L["neuronNbr"]))) * 70, 'y': (i // int(math.sqrt(L["neuronNbr"]))) * 70},'height': 20,'width': 20}) @@ -82,7 +83,7 @@ class layout(layoutOp): self.MaxPotential.clear() self.MaxSpike.clear() self.MaxSynapse.clear() - self.Spikes2D = self.generate2DView(self.g) + self.Spikes2D = self.generate2DView(self.g,[str(i) for i in (i["layer"] for i in self.g.LayersNeuronsInfo)]) self.AccumulatedSpikes2D = {i["layer"]:[0 for n in self.Spikes2D if n["data"]["spiked"] != -1 and i["layer"] == n["data"]["parent"]] for i in self.g.LayersNeuronsInfo} self.SpikesActivityPerInput = {i["layer"]:[[0 for j in range(self.g.nbrClasses+1)] for _ in range(self.g.Layer_Neuron[i["layer"]])] for i in self.g.LayersNeuronsInfo} self.Max = 0 @@ -269,7 +270,7 @@ class layout(layoutOp): ], elements=self.Spikes2D )],style={'position': 'absolute','width': '100%','height': '100%','z-index': 999,"background": "rgba(68, 71, 99, 0.05)"}), - html.P(id="spikes_info", style={"padding": "8px","margin":"0px"}) + #html.P(id="spikes_info", style={"padding": "8px","margin":"0px"}) ], style={ "textAlign": "start", "padding": "0px"},className="col-lg-5 col-sm-12 col-xs-12"), @@ -333,8 +334,8 @@ class layout(layoutOp): html.Span("Update Interval (s)", className="input-group-text") ], className="input-group-prepend"), - dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.005, - max=180, step=0.005, style={"width": "30%", "textAlign": "center"}) + dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.001, + max=180, step=0.001, style={"width": "30%", "textAlign": "center"}) ], className="input-group col-md-12 col-sm-12 col-lg-4", style={"height": "38px", "paddingTop": "12px"}) ], className="d-flex justify-content-center"), dbc.Col( [ -- GitLab