From e64b1aea7a48b8c837ed2594dec91ff2e77b9548 Mon Sep 17 00:00:00 2001 From: Hammouda Elbez <hammouda.elbez@univ-lille.fr> Date: Mon, 13 Mar 2023 11:32:35 +0100 Subject: [PATCH] 2DView: added reset when moving backward --- src/Modules/General/callbacks.py | 25 ++++++++++++++++--------- src/Modules/General/layout.py | 24 ++++++++++++------------ 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py index 2b26e3e..35206d7 100755 --- a/src/Modules/General/callbacks.py +++ b/src/Modules/General/callbacks.py @@ -546,10 +546,10 @@ class callbacks(callbacksOp): # Callback to handle the 2D view spiking visualization @app.callback( - Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"), + Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'), Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'), - State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data")) - def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, StoredData): + State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value")) + def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter): """ Function called each step to update the 2D view Args: @@ -558,16 +558,19 @@ class callbacks(callbacksOp): visUpdateInterval : interval instance that will cause this function to be called each step mouseOverNodeData : contains data of the hovered node elements : nodes description - StoredData : data stored for shared access heatmapData : heatmap data + Layer2DViewFilter : selected layers Returns: if information tab should be opened or closed """ try: - elements = StoredData[0] + elements = super.Spikes2D if dash.callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]: + for element in elements[1:]: + element["data"]["spiked"] = 0 + element["data"]["spikes"] = 0 spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True) if spikes: maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes]) @@ -579,14 +582,18 @@ class callbacks(callbacksOp): if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])): element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2) element["data"]["spikes"] = list(list(spike.values())[0].values())[0] - StoredData[1][i] += list(list(spike.values())[0].values())[0] + super.AccumulatedSpikes2D[i] += list(list(spike.values())[0].values())[0] i+=1 - return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] + matrix = super.toMatrix(super.AccumulatedSpikes2D,5) + indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5) + return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}] else: try: - return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] + matrix = super.toMatrix(super.AccumulatedSpikes2D,5) + indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5) + return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}] except Exception as e: - return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] + return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}] except Exception as e: print("animation2DViewController:" + str(e)) diff --git a/src/Modules/General/layout.py b/src/Modules/General/layout.py index 9f6d242..fc29291 100755 --- a/src/Modules/General/layout.py +++ b/src/Modules/General/layout.py @@ -28,11 +28,12 @@ class layout(layoutOp): MaxSpike = dict() MaxPotential = dict() MaxSynapse = dict() + AccumulatedSpikes2D = dict() + Spikes2D = dict() # LabelPie Data -------------------------------------------------- Label = [[], []] Max = 0 # ---------------------------------------------------------------- - Nodes = [] tabs = [] label = " " visStopped = True @@ -61,8 +62,9 @@ class layout(layoutOp): def toMatrix(self, l,n): """ 1D array to 2D """ - return [l[i:i+n] for i in range(0, len(l), n)] - + Matrix = [l[i:i+n] for i in range(0, len(l), n)] + return Matrix + def clearData(self): """ Clear the data when moved forward or backward for more than one step """ @@ -73,11 +75,12 @@ class layout(layoutOp): self.PotentialGraphY.clear() self.xAxisLabel.clear() self.Label.clear() - self.MaxPotential = dict() - self.MaxSpike = dict() - self.MaxSynapse = dict() + self.MaxPotential.clear() + self.MaxSpike.clear() + self.MaxSynapse.clear() + self.Spikes2D = self.generate2DView(self.g) + self.AccumulatedSpikes2D = [0 for n in self.Spikes2D if n["data"]["spiked"] != -1] self.Max = 0 - self.Nodes = [] def Vis(self): """ Create layer components @@ -101,9 +104,6 @@ class layout(layoutOp): "Accuracy", style={"width": "25%", "fontWeight": "500"}), html.Td(str(self.g.Accuracy)+" %", style={"width": "25%"})]) ] - # Generate 2D View ------------------------------------------- - self.Nodes = self.generate2DView(self.g) - # Tabs content ----------------------------------------------- info_vis = dbc.Card( dbc.CardBody([dbc.Card([dbc.CardHeader( @@ -189,7 +189,7 @@ class layout(layoutOp): , html.Div( [ - dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]), + #dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]), cyto.Cytoscape( id='cytoscape-compound', layout={'name': 'grid','animate': False}, @@ -255,7 +255,7 @@ class layout(layoutOp): } } ], - elements=self.Nodes + elements=self.Spikes2D ), html.P(id="spikes_info", style={"padding": "8px"}) -- GitLab