diff --git a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java index 8e05aa8222aa209f0f12a00fc21c23b087c122fb..def9b3273b96a1a2870403485fa2466c807f98a0 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java @@ -42,7 +42,10 @@ public class KNNController { 1, 1)); + kEntry.getValueFactory().setValue(ClassificationModel.getClassificationModel().getK()); + algoSelector.getItems().addAll("Euclidienne", "Euclidienne Normalisée", "Manhattan", "Manhattan Normalisée"); + algoSelector.setValue(Distance.getDistanceName(ClassificationModel.getClassificationModel().getDistance())); } @@ -69,24 +72,26 @@ public class KNNController { //List<LoadableData> datasShuffle = new ArrayList<>(List.copyOf(model.getDatas())); // Collections.shuffle(datasShuffle); MethodKNN.updateModel(model.getDatas()); - + Distance dist = Distance.getByName(algoSelector.getValue()); updateProgress(1, 3); updateMessage("Recherche du meilleur K"); - int bestK = MethodKNN.bestK(model.getDatas(), model.getDistance()); + int bestK = MethodKNN.bestK(model.getDatas(), dist); updateMessage("Test de robustesse"); updateProgress(2, 3); - double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, model.getDistance(), 0.2); + double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, dist, 0.2); model.setKOptimal(bestK); updateMessage("Affichage du resultat"); updateProgress(2.5, 3); + model.setDistance(dist); + HBox hBox = new HBox(); Label label = new Label("Best K: " + bestK + " robustesse : " + robustesse); hBox.getChildren().add(label); diff --git a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java index 261f27e09c5c4cf0c85d417eb1ddd0056322bd3a..24fd3796dd99cc4eb343b15cc679a407b90eac05 100644 --- a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java +++ b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java @@ -21,4 +21,16 @@ public interface Distance { } + static String getDistanceName(Distance distance){ + if (distance instanceof DistanceEuclidienneNormalisee) { + return "Euclidienne Normalisee"; + }else if (distance instanceof DistanceManhattan){ + return "Manhattan"; + }else if (distance instanceof DistanceManhattanNormalisee){ + return "ManhattanNormalisee"; + }else { + return "Euclidienne"; + } + } + } diff --git a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java index 25242aad61036660c1edb2005ff78a2218aab280..3900cb19a9230759b77b6784b434b75c4b345dea 100644 --- a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java @@ -123,7 +123,7 @@ public class DataStageView extends DataVisualizationView implements Observer { if(editSerie == null){ editSerie = new ScatterChart.Series<Double, Double>(); } - if(data.getClassification().equals("undefined")) { + if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); } dataPoint.setNode(nodePoint); diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java index fb9de2f3dac5f2d18ab5686b0e972e2160c4aae5..2907867b5b85b139bb6d3cfbcf6c95313fa9f699 100644 --- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java @@ -123,7 +123,7 @@ public class MainStageView extends DataVisualizationView implements Observer { if(editSerie == null){ editSerie = new ScatterChart.Series<Double, Double>(); } - if(data.getClassification().equals("undefined")) { + if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); }