Skip to content
Snippets Groups Projects
Commit c1b4b37f authored by Hammouda Elbez's avatar Hammouda Elbez :computer:
Browse files

2D view: adding ability to filter based on layer

parent a343f719
No related branches found
No related tags found
No related merge requests found
...@@ -544,17 +544,16 @@ class callbacks(callbacksOp): ...@@ -544,17 +544,16 @@ class callbacks(callbacksOp):
# Callback to handle the 2D view spiking visualization # Callback to handle the 2D view spiking visualization
@app.callback( @app.callback(
Output("cytoscape-compound", "elements"),Output('spikes_info', 'children'),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'), Output("cytoscape-compound", "elements"),Output({"index": ALL, "type": '2DView-heatmap'},'figure'),Output({"index": ALL, "type": 'SpikesActivityPerInput'},'figure'),
Input("vis-update", "n_intervals"),Input("v-step", "children"),Input('cytoscape-compound', 'mouseoverNodeData'), Input("vis-update", "n_intervals"),Input("v-step", "children"),
State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value")) State("interval", "value"),State('cytoscape-compound', 'elements'),State("2DViewLayerFilter", "value"))
def animation2DView(visUpdateInterval,sliderValue, mouseOverNodeData, updateInterval, elements, Layer2DViewFilter): def animation2DView(visUpdateInterval,sliderValue, updateInterval, elements, Layer2DViewFilter):
""" Function called each step to update the 2D view """ Function called each step to update the 2D view
Args: Args:
sliderValue (int): value of the slider sliderValue (int): value of the slider
updateInterval (int): update Interval (s) updateInterval (int): update Interval (s)
visUpdateInterval : interval instance that will cause this function to be called each step visUpdateInterval : interval instance that will cause this function to be called each step
mouseOverNodeData : contains data of the hovered node
elements : nodes description elements : nodes description
heatmapData : heatmap data heatmapData : heatmap data
Layer2DViewFilter : selected layers Layer2DViewFilter : selected layers
...@@ -563,7 +562,7 @@ class callbacks(callbacksOp): ...@@ -563,7 +562,7 @@ class callbacks(callbacksOp):
if information tab should be opened or closed if information tab should be opened or closed
""" """
try: try:
elements = super.Spikes2D elements = super.generate2DView(g,Layer2DViewFilter)
matrix = {} matrix = {}
indices = {} indices = {}
labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval) labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval)
...@@ -601,18 +600,16 @@ class callbacks(callbacksOp): ...@@ -601,18 +600,16 @@ class callbacks(callbacksOp):
matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer]) matrix[layer] = super.toMatrix(super.AccumulatedSpikes2D[layer])
indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))]) indices[layer] = super.toMatrix([i for i in range(0,len(super.AccumulatedSpikes2D[layer]))])
if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D): if len(Layer2DViewFilter) != len(super.AccumulatedSpikes2D):
for layer in super.AccumulatedSpikes2D: for layer in super.AccumulatedSpikes2D:
if layer not in matrix: if layer not in matrix:
matrix[layer] = [] matrix[layer] = []
indices[layer] = [] indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput] SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput]
return [elements,[],heatmaps,SpikesActivityPerInput] return [elements,heatmaps,SpikesActivityPerInput]
else: else:
try: try:
...@@ -626,10 +623,10 @@ class callbacks(callbacksOp): ...@@ -626,10 +623,10 @@ class callbacks(callbacksOp):
matrix[layer] = [] matrix[layer] = []
indices[layer] = [] indices[layer] = []
heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D] heatmaps = [{"data":[go.Heatmap(z = matrix[layer], colorscale= 'Reds', zmin=0, customdata = indices[layer], hovertemplate=('Neuron: %{customdata} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4)],"layout":{"xaxis":dict(showgrid = False, zeroline = False),"yaxis":dict(autorange = 'reversed',scaleanchor = 'x',showgrid = False, zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom', "hoverlabel_align": 'right'}} for layer in super.AccumulatedSpikes2D]
SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData) for layer in super.SpikesActivityPerInput] SpikesActivityPerInput = [make_SpikeActivityPerInput(layer,labelData, Layer2DViewFilter) for layer in super.SpikesActivityPerInput]
return [elements,f"Neuron {mouseOverNodeData['label']} : {mouseOverNodeData['spikes']}" if 'spikes' in mouseOverNodeData else "", heatmaps,SpikesActivityPerInput] return [elements,heatmaps,SpikesActivityPerInput]
except Exception: except Exception:
print("OnHover:"+traceback.format_exc()) print("OnHover:"+traceback.format_exc())
return no_update return no_update
...@@ -692,11 +689,11 @@ class callbacks(callbacksOp): ...@@ -692,11 +689,11 @@ class callbacks(callbacksOp):
res.append(getLoss(timestamp, g.updateInterval)) res.append(getLoss(timestamp, g.updateInterval))
return res return res
def make_SpikeActivityPerInput(layer,dataLabel): def make_SpikeActivityPerInput(layer,dataLabel,Layer2DViewFilter):
fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]]) fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]])
fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer], colorscale= 'Reds',hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1) fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer] if layer in Layer2DViewFilter else [], colorscale= 'Reds', zmin=0, hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1)
fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=[i for i in range(g.nbrClasses+1)]),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'}) fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=[i for i in range(g.nbrClasses+1)]),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'})
if dataLabel == None: if (dataLabel == None) or (layer not in Layer2DViewFilter):
fig.add_trace(go.Bar(x=[],y=[],hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1) fig.add_trace(go.Bar(x=[],y=[],hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1)
else: else:
fig.add_trace(go.Bar(x=list(dataLabel[0].keys()),y=list(dataLabel[0].values()),hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1) fig.add_trace(go.Bar(x=list(dataLabel[0].keys()),y=list(dataLabel[0].values()),hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1)
......
...@@ -46,15 +46,16 @@ class layout(layoutOp): ...@@ -46,15 +46,16 @@ class layout(layoutOp):
# 2D view -------------------------------------------------------- # 2D view --------------------------------------------------------
def generate2DView(self, g): def generate2DView(self, g, layers):
""" Generates a 2D View of the neural network """ Generates a 2D View of the neural network
Args: Args:
g (Global_Var): reference to access global variables g (Global_Var): reference to access global variables
layers: (array): list of layer names
""" """
Nodes = [] Nodes = []
# Create the neurones and layers # Create the neurones and layers
for L in g.LayersNeuronsInfo: for L in [l for l in g.LayersNeuronsInfo if l["layer"] in layers]:
Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}}) Nodes.append({'data': {'id': L["layer"], 'label': L["layer"], 'spiked': -1}})
for i in range(L["neuronNbr"]): for i in range(L["neuronNbr"]):
Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % int(math.sqrt(L["neuronNbr"]))) * 70, 'y': (i // int(math.sqrt(L["neuronNbr"]))) * 70},'height': 20,'width': 20}) Nodes.append({'data': {'id': L["layer"]+str(i), 'label': str(i), 'parent': L["layer"], 'spiked': 0.0, 'spikes': 0},'position': {'x': (i % int(math.sqrt(L["neuronNbr"]))) * 70, 'y': (i // int(math.sqrt(L["neuronNbr"]))) * 70},'height': 20,'width': 20})
...@@ -82,7 +83,7 @@ class layout(layoutOp): ...@@ -82,7 +83,7 @@ class layout(layoutOp):
self.MaxPotential.clear() self.MaxPotential.clear()
self.MaxSpike.clear() self.MaxSpike.clear()
self.MaxSynapse.clear() self.MaxSynapse.clear()
self.Spikes2D = self.generate2DView(self.g) self.Spikes2D = self.generate2DView(self.g,[str(i) for i in (i["layer"] for i in self.g.LayersNeuronsInfo)])
self.AccumulatedSpikes2D = {i["layer"]:[0 for n in self.Spikes2D if n["data"]["spiked"] != -1 and i["layer"] == n["data"]["parent"]] for i in self.g.LayersNeuronsInfo} self.AccumulatedSpikes2D = {i["layer"]:[0 for n in self.Spikes2D if n["data"]["spiked"] != -1 and i["layer"] == n["data"]["parent"]] for i in self.g.LayersNeuronsInfo}
self.SpikesActivityPerInput = {i["layer"]:[[0 for j in range(self.g.nbrClasses+1)] for _ in range(self.g.Layer_Neuron[i["layer"]])] for i in self.g.LayersNeuronsInfo} self.SpikesActivityPerInput = {i["layer"]:[[0 for j in range(self.g.nbrClasses+1)] for _ in range(self.g.Layer_Neuron[i["layer"]])] for i in self.g.LayersNeuronsInfo}
self.Max = 0 self.Max = 0
...@@ -269,7 +270,7 @@ class layout(layoutOp): ...@@ -269,7 +270,7 @@ class layout(layoutOp):
], ],
elements=self.Spikes2D elements=self.Spikes2D
)],style={'position': 'absolute','width': '100%','height': '100%','z-index': 999,"background": "rgba(68, 71, 99, 0.05)"}), )],style={'position': 'absolute','width': '100%','height': '100%','z-index': 999,"background": "rgba(68, 71, 99, 0.05)"}),
html.P(id="spikes_info", style={"padding": "8px","margin":"0px"}) #html.P(id="spikes_info", style={"padding": "8px","margin":"0px"})
], style={ "textAlign": "start", "padding": "0px"},className="col-lg-5 col-sm-12 col-xs-12"), ], style={ "textAlign": "start", "padding": "0px"},className="col-lg-5 col-sm-12 col-xs-12"),
...@@ -333,8 +334,8 @@ class layout(layoutOp): ...@@ -333,8 +334,8 @@ class layout(layoutOp):
html.Span("Update Interval (s)", html.Span("Update Interval (s)",
className="input-group-text") className="input-group-text")
], className="input-group-prepend"), ], className="input-group-prepend"),
dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.005, dbc.Input(type="number", id="interval", value=self.g.updateInterval, min=0.001,
max=180, step=0.005, style={"width": "30%", "textAlign": "center"}) max=180, step=0.001, style={"width": "30%", "textAlign": "center"})
], className="input-group col-md-12 col-sm-12 col-lg-4", style={"height": "38px", "paddingTop": "12px"}) ], className="input-group col-md-12 col-sm-12 col-lg-4", style={"height": "38px", "paddingTop": "12px"})
], className="d-flex justify-content-center"), dbc.Col( ], className="d-flex justify-content-center"), dbc.Col(
[ [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment