Skip to content
Snippets Groups Projects
Commit e64b1aea authored by Hammouda Elbez's avatar Hammouda Elbez :computer:
Browse files

2DView: added reset when moving backward

parent 8b0f1931
No related branches found
No related tags found
1 merge request!26Custom 2d view
...@@ -546,10 +546,10 @@ class callbacks(callbacksOp): ...@@ -546,10 +546,10 @@ class callbacks(callbacksOp):
# Callback to handle the 2D view spiking visualization # Callback to handle the 2D view spiking visualization
@app.callback( @app.callback(
Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),Output("StoredData", "data"), Output("cytoscape-compound", "elements"),Output("cytoscape-compound", "layout"), Output('spikes_info', 'children'),Output('2DView-heatmap','figure'),
Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'), Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'),
State("interval", "value"),State('cytoscape-compound', 'elements'),State("StoredData", "data")) State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, StoredData): def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter):
""" Function called each step to update the 2D view """ Function called each step to update the 2D view
Args: Args:
...@@ -558,16 +558,19 @@ class callbacks(callbacksOp): ...@@ -558,16 +558,19 @@ class callbacks(callbacksOp):
visUpdateInterval : interval instance that will cause this function to be called each step visUpdateInterval : interval instance that will cause this function to be called each step
mouseOverNodeData : contains data of the hovered node mouseOverNodeData : contains data of the hovered node
elements : nodes description elements : nodes description
StoredData : data stored for shared access
heatmapData : heatmap data heatmapData : heatmap data
Layer2DViewFilter : selected layers
Returns: Returns:
if information tab should be opened or closed if information tab should be opened or closed
""" """
try: try:
elements = StoredData[0] elements = super.Spikes2D
if dash.callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]: if dash.callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]:
for element in elements[1:]:
element["data"]["spiked"] = 0
element["data"]["spikes"] = 0
spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True) spikes = getSpike(int(sliderValue)*float(updateInterval), g.updateInterval,["Layer1"],True)
if spikes: if spikes:
maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes]) maxSpike = max([list(list(s.values())[0].values())[0] for s in spikes])
...@@ -579,14 +582,18 @@ class callbacks(callbacksOp): ...@@ -579,14 +582,18 @@ class callbacks(callbacksOp):
if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])): if (element["data"]["id"] == "Layer1_"+str(list(list(spike.values())[0].keys())[0])) and (element["data"]["label"] == str(list(list(spike.values())[0].keys())[0])):
element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2) element["data"]["spiked"] = round(list(list(spike.values())[0].values())[0] / maxSpike,2)
element["data"]["spikes"] = list(list(spike.values())[0].values())[0] element["data"]["spikes"] = list(list(spike.values())[0].values())[0]
StoredData[1][i] += list(list(spike.values())[0].values())[0] super.AccumulatedSpikes2D[i] += list(list(spike.values())[0].values())[0]
i+=1 i+=1
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
else: else:
try: try:
return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] matrix = super.toMatrix(super.AccumulatedSpikes2D,5)
indices = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D))],5)
return [elements,{'name': 'grid','animate': False},f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}", {"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
except Exception as e: except Exception as e:
return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = super.toMatrix(StoredData[1],5), zsmooth= 'best', colorscale= 'Reds',showscale= False, hovertemplate=('Neuron: %{x}+%{y}*5 <br>''Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}},StoredData] return [elements,{'name': 'grid','animate': False},[],{"data":[go.Heatmap(z = matrix, zsmooth= 'best', colorscale= 'Reds',showscale= False, customdata = indices, hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'))],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 10, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}}]
except Exception as e: except Exception as e:
print("animation2DViewController:" + str(e)) print("animation2DViewController:" + str(e))
......
...@@ -28,11 +28,12 @@ class layout(layoutOp): ...@@ -28,11 +28,12 @@ class layout(layoutOp):
MaxSpike = dict() MaxSpike = dict()
MaxPotential = dict() MaxPotential = dict()
MaxSynapse = dict() MaxSynapse = dict()
AccumulatedSpikes2D = dict()
Spikes2D = dict()
# LabelPie Data -------------------------------------------------- # LabelPie Data --------------------------------------------------
Label = [[], []] Label = [[], []]
Max = 0 Max = 0
# ---------------------------------------------------------------- # ----------------------------------------------------------------
Nodes = []
tabs = [] tabs = []
label = " " label = " "
visStopped = True visStopped = True
...@@ -61,8 +62,9 @@ class layout(layoutOp): ...@@ -61,8 +62,9 @@ class layout(layoutOp):
def toMatrix(self, l,n): def toMatrix(self, l,n):
""" 1D array to 2D """ 1D array to 2D
""" """
return [l[i:i+n] for i in range(0, len(l), n)] Matrix = [l[i:i+n] for i in range(0, len(l), n)]
return Matrix
def clearData(self): def clearData(self):
""" Clear the data when moved forward or backward for more than one step """ Clear the data when moved forward or backward for more than one step
""" """
...@@ -73,11 +75,12 @@ class layout(layoutOp): ...@@ -73,11 +75,12 @@ class layout(layoutOp):
self.PotentialGraphY.clear() self.PotentialGraphY.clear()
self.xAxisLabel.clear() self.xAxisLabel.clear()
self.Label.clear() self.Label.clear()
self.MaxPotential = dict() self.MaxPotential.clear()
self.MaxSpike = dict() self.MaxSpike.clear()
self.MaxSynapse = dict() self.MaxSynapse.clear()
self.Spikes2D = self.generate2DView(self.g)
self.AccumulatedSpikes2D = [0 for n in self.Spikes2D if n["data"]["spiked"] != -1]
self.Max = 0 self.Max = 0
self.Nodes = []
def Vis(self): def Vis(self):
""" Create layer components """ Create layer components
...@@ -101,9 +104,6 @@ class layout(layoutOp): ...@@ -101,9 +104,6 @@ class layout(layoutOp):
"Accuracy", style={"width": "25%", "fontWeight": "500"}), html.Td(str(self.g.Accuracy)+" %", style={"width": "25%"})]) "Accuracy", style={"width": "25%", "fontWeight": "500"}), html.Td(str(self.g.Accuracy)+" %", style={"width": "25%"})])
] ]
# Generate 2D View -------------------------------------------
self.Nodes = self.generate2DView(self.g)
# Tabs content ----------------------------------------------- # Tabs content -----------------------------------------------
info_vis = dbc.Card( info_vis = dbc.Card(
dbc.CardBody([dbc.Card([dbc.CardHeader( dbc.CardBody([dbc.Card([dbc.CardHeader(
...@@ -189,7 +189,7 @@ class layout(layoutOp): ...@@ -189,7 +189,7 @@ class layout(layoutOp):
, ,
html.Div( html.Div(
[ [
dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]), #dcc.Store(id="StoredData",data=[self.Nodes,[0 for n in self.Nodes if n["data"]["spiked"] != -1]]),
cyto.Cytoscape( cyto.Cytoscape(
id='cytoscape-compound', id='cytoscape-compound',
layout={'name': 'grid','animate': False}, layout={'name': 'grid','animate': False},
...@@ -255,7 +255,7 @@ class layout(layoutOp): ...@@ -255,7 +255,7 @@ class layout(layoutOp):
} }
} }
], ],
elements=self.Nodes elements=self.Spikes2D
), ),
html.P(id="spikes_info", style={"padding": "8px"}) html.P(id="spikes_info", style={"padding": "8px"})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment