From f0599bc9e6b26ef36008e4af4d81e58998017510 Mon Sep 17 00:00:00 2001
From: Matias Mennecart <matias.mennecart@icloud.com>
Date: Mon, 25 Nov 2024 23:28:13 +0100
Subject: [PATCH] Fix issues while reloading data and fix issue with knn
---
.../sae/classification/controller/KNNController.java | 11 ++++++++---
.../sae/classification/knn/distance/Distance.java | 12 ++++++++++++
.../sae/classification/view/DataStageView.java | 2 +-
.../sae/classification/view/MainStageView.java | 2 +-
4 files changed, 22 insertions(+), 5 deletions(-)
diff --git a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
index 8e05aa8..def9b32 100644
--- a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
+++ b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java
@@ -42,7 +42,10 @@ public class KNNController {
1,
1));
+ kEntry.getValueFactory().setValue(ClassificationModel.getClassificationModel().getK());
+
algoSelector.getItems().addAll("Euclidienne", "Euclidienne Normalisée", "Manhattan", "Manhattan Normalisée");
+ algoSelector.setValue(Distance.getDistanceName(ClassificationModel.getClassificationModel().getDistance()));
}
@@ -69,24 +72,26 @@ public class KNNController {
//List<LoadableData> datasShuffle = new ArrayList<>(List.copyOf(model.getDatas()));
// Collections.shuffle(datasShuffle);
MethodKNN.updateModel(model.getDatas());
-
+ Distance dist = Distance.getByName(algoSelector.getValue());
updateProgress(1, 3);
updateMessage("Recherche du meilleur K");
- int bestK = MethodKNN.bestK(model.getDatas(), model.getDistance());
+ int bestK = MethodKNN.bestK(model.getDatas(), dist);
updateMessage("Test de robustesse");
updateProgress(2, 3);
- double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, model.getDistance(), 0.2);
+ double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, dist, 0.2);
model.setKOptimal(bestK);
updateMessage("Affichage du resultat");
updateProgress(2.5, 3);
+ model.setDistance(dist);
+
HBox hBox = new HBox();
Label label = new Label("Best K: " + bestK + " robustesse : " + robustesse);
hBox.getChildren().add(label);
diff --git a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
index 261f27e..24fd379 100644
--- a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
+++ b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java
@@ -21,4 +21,16 @@ public interface Distance {
}
+ static String getDistanceName(Distance distance){
+ if (distance instanceof DistanceEuclidienneNormalisee) {
+ return "Euclidienne Normalisee";
+ }else if (distance instanceof DistanceManhattan){
+ return "Manhattan";
+ }else if (distance instanceof DistanceManhattanNormalisee){
+ return "ManhattanNormalisee";
+ }else {
+ return "Euclidienne";
+ }
+ }
+
}
diff --git a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
index 25242aa..3900cb1 100644
--- a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
+++ b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java
@@ -123,7 +123,7 @@ public class DataStageView extends DataVisualizationView implements Observer {
if(editSerie == null){
editSerie = new ScatterChart.Series<Double, Double>();
}
- if(data.getClassification().equals("undefined")) {
+ if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) {
nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller);
}
dataPoint.setNode(nodePoint);
diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
index fb9de2f..2907867 100644
--- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
+++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
@@ -123,7 +123,7 @@ public class MainStageView extends DataVisualizationView implements Observer {
if(editSerie == null){
editSerie = new ScatterChart.Series<Double, Double>();
}
- if(data.getClassification().equals("undefined")) {
+ if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) {
nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller);
}
--
GitLab