From f0599bc9e6b26ef36008e4af4d81e58998017510 Mon Sep 17 00:00:00 2001
From: Matias Mennecart <matias.mennecart@icloud.com>
Date: Mon, 25 Nov 2024 23:28:13 +0100
Subject: [PATCH] Fix issues while reloading data and fix issue with knn

---
 .../sae/classification/controller/KNNController.java | 11 ++++++++---
 .../sae/classification/knn/distance/Distance.java    | 12 ++++++++++++
 .../sae/classification/view/DataStageView.java       |  2 +-
 .../sae/classification/view/MainStageView.java       |  2 +-
 4 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
index 8e05aa8..def9b32 100644
--- a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
+++ b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
@@ -42,7 +42,10 @@ public class KNNController {
                         1,
                         1));
 
+        kEntry.getValueFactory().setValue(ClassificationModel.getClassificationModel().getK());
+
         algoSelector.getItems().addAll("Euclidienne", "Euclidienne Normalisée", "Manhattan", "Manhattan Normalisée");
+        algoSelector.setValue(Distance.getDistanceName(ClassificationModel.getClassificationModel().getDistance()));
     }
 
 
@@ -69,24 +72,26 @@ public class KNNController {
                     //List<LoadableData> datasShuffle = new ArrayList<>(List.copyOf(model.getDatas()));
                    // Collections.shuffle(datasShuffle);
                     MethodKNN.updateModel(model.getDatas());
-
+                    Distance dist = Distance.getByName(algoSelector.getValue());
 
                     updateProgress(1, 3);
                     updateMessage("Recherche du meilleur K");
 
-                    int bestK = MethodKNN.bestK(model.getDatas(), model.getDistance());
+                    int bestK = MethodKNN.bestK(model.getDatas(), dist);
 
 
                     updateMessage("Test de robustesse");
                     updateProgress(2, 3);
 
 
-                    double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, model.getDistance(), 0.2);
+                    double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, dist, 0.2);
                     model.setKOptimal(bestK);
 
                     updateMessage("Affichage du resultat");
                     updateProgress(2.5, 3);
 
+                    model.setDistance(dist);
+
                     HBox hBox = new HBox();
                     Label label = new Label("Best K: " + bestK + " robustesse : " + robustesse);
                     hBox.getChildren().add(label);
diff --git a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
index 261f27e..24fd379 100644
--- a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
+++ b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
@@ -21,4 +21,16 @@ public interface Distance {
 
     }
 
+    static String getDistanceName(Distance distance){
+        if (distance instanceof DistanceEuclidienneNormalisee) {
+            return "Euclidienne Normalisee";
+        }else if (distance instanceof DistanceManhattan){
+            return "Manhattan";
+        }else if (distance instanceof DistanceManhattanNormalisee){
+            return "ManhattanNormalisee";
+        }else {
+            return "Euclidienne";
+        }
+    }
+
 }
diff --git a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
index 25242aa..3900cb1 100644
--- a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
+++ b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
@@ -123,7 +123,7 @@ public class DataStageView extends DataVisualizationView implements Observer {
                     if(editSerie == null){
                         editSerie = new ScatterChart.Series<Double, Double>();
                     }
-                    if(data.getClassification().equals("undefined")) {
+                    if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) {
                         nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller);
                     }
                     dataPoint.setNode(nodePoint);
diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
index fb9de2f..2907867 100644
--- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
+++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
@@ -123,7 +123,7 @@ public class MainStageView extends DataVisualizationView implements Observer {
                     if(editSerie == null){
                         editSerie = new ScatterChart.Series<Double, Double>();
                     }
-                    if(data.getClassification().equals("undefined")) {
+                    if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) {
                         nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller);
                     }
 
-- 
GitLab