From b8a9609930f4913e078dfe53b0c19410704cf1d0 Mon Sep 17 00:00:00 2001 From: Hammouda Elbez <hammouda.elbez@univ-lille.fr> Date: Mon, 31 Jul 2023 15:48:51 +0200 Subject: [PATCH] General module updated --- src/Modules/General/callbacks.py | 22 +++++++++++++--------- src/Modules/General/spark.py | 2 +- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/Modules/General/callbacks.py b/src/Modules/General/callbacks.py index 7c1e5ab..a7fc92c 100755 --- a/src/Modules/General/callbacks.py +++ b/src/Modules/General/callbacks.py @@ -231,7 +231,6 @@ class callbacks(callbacksOp): treemap visualization """ try: - if data != None: super.Max = data[1] super.Label = data[0] @@ -435,7 +434,7 @@ class callbacks(callbacksOp): content of the graph that contains general information on the network activity """ if generalGraphSwitchIsOn: - labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval) + labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval, False) if labelData != None: labelData = (labelData[1] // InputPerEpoch)+1 if (labelData[1] % InputPerEpoch) != 0 else (labelData[1] // InputPerEpoch) else: @@ -490,7 +489,7 @@ class callbacks(callbacksOp): if labelGraphSwitchIsOn: labelData = getNetworkInput( - int(sliderValue)*float(updateInterval), g.updateInterval) + int(sliderValue)*float(updateInterval), g.updateInterval, False) if(not super.visStopped): if labelData == [None]: @@ -581,7 +580,7 @@ class callbacks(callbacksOp): elements = super.generate2DView(g,Layer2DViewFilter) matrix = {} indices = {} - labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval) + labelData = getNetworkInput(int(sliderValue)*float(updateInterval), g.updateInterval, True) if callback_context.triggered[0]['prop_id'].split('.')[0] in ["v-step","vis-update"]: super.SpikesActivityPerInput = {i["layer"]:[[0 for _ in range(g.ClassNbr)] for _ in range(g.Layer_Neuron[i["layer"]])] for i in g.LayersNeuronsInfo} for element in elements: @@ -707,7 +706,7 @@ class callbacks(callbacksOp): def make_SpikeActivityPerInput(layer,dataLabel,Layer2DViewFilter): fig = make_subplots(rows=5, cols=1, shared_xaxes=True, vertical_spacing=0.08, specs=[[{'rowspan': 4}],[None],[None],[None],[{'rowspan': 1}]]) - fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer] if layer in Layer2DViewFilter else [], colorscale= 'Reds' if sum([item for row in super.SpikesActivityPerInput[layer] for item in row]) > 0 else "gray", zmin=0, hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1) + fig.add_trace(go.Heatmap(z=super.SpikesActivityPerInput[layer] if layer in Layer2DViewFilter else [], colorscale= 'Reds', zmin=0, hovertemplate=('Class: %{x} <br>Neuron: %{y} <br>Spikes: %{z} <extra></extra>'),xgap=4,ygap=4),row=1, col=1) fig.update_layout({"xaxis":dict(title="Class",tickmode="array",zeroline = False,tickvals=[i for i in range(g.ClassNbr)]),"yaxis":dict(title="Neuron",zeroline = False),"margin":{'l': 0, 'r': 0, 't': 5, 'b': 0},"uirevision":'no reset of zoom'}) if (dataLabel == None) or (layer not in Layer2DViewFilter): fig.add_trace(go.Bar(x=[],y=[],hovertemplate=('Label: %{x} <br>Nbr: %{y} <extra></extra>')),row=5, col=1) @@ -719,12 +718,13 @@ class callbacks(callbacksOp): # MongoDB operations # --------------------------------------------------------- - def getNetworkInput(timestamp, interval): + def getNetworkInput(timestamp, interval, full): """ Get network input for a given timestamp and interval. Args: timestamp (int): timestamp value interval (int): interval value + full (bool): get info for all Inputs or just the existing ones in that period Returns: array contains totale processed inputs and current inputs @@ -751,10 +751,14 @@ class callbacks(callbacksOp): if not labels: return None - L = dict({i: 0 for i in range(g.ClassNbr)}) + if (full): + L = dict({i: 0 for i in range(g.ClassNbr)}) - for l in labels: - L[l["_id"]] = l["C"] + for l in labels: + L[l["_id"]] = l["C"] + else: + L = dict({l["_id"]: l["C"] for l in labels}) + return [L, Max] def getSpike(timestamp, interval, layer, perNeuron): diff --git a/src/Modules/General/spark.py b/src/Modules/General/spark.py index f07bc32..5a9ba8c 100755 --- a/src/Modules/General/spark.py +++ b/src/Modules/General/spark.py @@ -65,7 +65,7 @@ class spark(sparkOp): if ('labels' in self.g.db.list_collection_names()): self.g.labelsExistance = True M = max(M, pymongo.collection.Collection(self.g.db, 'labels').find_one(sort=[("T", -1)])["T"]) - self.g.ClassNbr = int(max(M, pymongo.collection.Collection(self.g.db, 'labels').find_one(sort=[("L", -1)])["L"]))+1 + self.g.ClassNbr = int(pymongo.collection.Collection(self.g.db, 'labels').find_one(sort=[("L", -1)])["L"])+1 else: print("No labels") -- GitLab