diff --git a/res/stages/k-NN-stage.fxml b/res/stages/k-NN-stage.fxml index f28f3265275f5d6bb763e210df6b2a82d376c85a..d85e37816d2a07190893f8454bbc57f4dc5c0871 100644 --- a/res/stages/k-NN-stage.fxml +++ b/res/stages/k-NN-stage.fxml @@ -16,7 +16,7 @@ <children> <HBox alignment="CENTER" prefHeight="92.0" prefWidth="364.0"> <children> - <ChoiceBox fx:id="AlgoSelector" prefWidth="150.0" stylesheets="@../css/style.css" /> + <ChoiceBox fx:id="algoSelector" prefWidth="150.0" stylesheets="@../css/style.css" /> </children> <padding> <Insets top="20.0" /> @@ -36,8 +36,8 @@ </HBox> <HBox alignment="CENTER" prefHeight="36.0" prefWidth="364.0" spacing="20.0"> <children> - <TextField fx:id="KEntry" /> - <Button fx:id="AutoK" mnemonicParsing="false" stylesheets="@../css/style.css" text="Attribution auto" textFill="WHITE"> + <Spinner fx:id="kEntry" /> + <Button fx:id="autoK" mnemonicParsing="false" onAction="#bestK" stylesheets="@../css/style.css" text="Attribution auto" textFill="WHITE"> <font> <Font name="System Bold" size="13.0" /> </font> @@ -46,7 +46,7 @@ </HBox> <HBox alignment="CENTER" prefHeight="59.0" prefWidth="364.0"> <children> - <Button fx:id="confirmK" mnemonicParsing="false" stylesheets="@../css/style.css" text="Valider" textFill="WHITE"> + <Button fx:id="confirmK" onAction="#validate" mnemonicParsing="false" stylesheets="@../css/style.css" text="Valider" textFill="WHITE"> <font> <Font name="System Bold" size="14.0" /> </font></Button> diff --git a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java index 90b6e21796f0d8f9bad210a23df1ea3e4a2c3447..def9b3273b96a1a2870403485fa2466c807f98a0 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java @@ -1,22 +1,156 @@ package fr.univlille.sae.classification.controller; +import fr.univlille.sae.classification.knn.MethodKNN; +import fr.univlille.sae.classification.knn.distance.Distance; +import fr.univlille.sae.classification.knn.distance.DistanceManhattanNormalisee; +import fr.univlille.sae.classification.model.ClassificationModel; +import fr.univlille.sae.classification.model.LoadableData; +import javafx.collections.ObservableList; +import javafx.concurrent.Task; import javafx.fxml.FXML; -import javafx.scene.control.Button; -import javafx.scene.control.ChoiceBox; -import javafx.scene.control.TextField; +import javafx.scene.Scene; +import javafx.scene.control.*; +import javafx.scene.layout.HBox; +import javafx.stage.Stage; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; public class KNNController { + + @FXML + private Stage stage; + @FXML - ChoiceBox<String> AlgoSelector; + ChoiceBox<String> algoSelector; @FXML - TextField KEntry; + Spinner<Integer> kEntry; @FXML - Button AutoK; + Button autoK; @FXML Button confirmK; + @FXML + public void initialize() { + kEntry.setValueFactory(new SpinnerValueFactory.IntegerSpinnerValueFactory(1, + (int) Math.sqrt(ClassificationModel.getClassificationModel().getDatas().size()), + 1, + 1)); + + kEntry.getValueFactory().setValue(ClassificationModel.getClassificationModel().getK()); + + algoSelector.getItems().addAll("Euclidienne", "Euclidienne Normalisée", "Manhattan", "Manhattan Normalisée"); + algoSelector.setValue(Distance.getDistanceName(ClassificationModel.getClassificationModel().getDistance())); + } + + + + public void bestK() { + ClassificationModel model = ClassificationModel.getClassificationModel(); + + if(model.getkOptimal() > 0) { + // Le K Optimal à déja été calculé, il n'est pas necessaire de le recaculer. + kEntry.getValueFactory().setValue(model.getkOptimal()); + }else { + // Calcul du K Optimal: + + HBox hBox = new HBox(); + + + Task<Scene> knnTask = new Task<>() { + @Override + protected Scene call() throws Exception { + System.out.println("Call call()"); + updateProgress(0, 3); + updateMessage("Préparation des données "); + + //List<LoadableData> datasShuffle = new ArrayList<>(List.copyOf(model.getDatas())); + // Collections.shuffle(datasShuffle); + MethodKNN.updateModel(model.getDatas()); + Distance dist = Distance.getByName(algoSelector.getValue()); + + updateProgress(1, 3); + updateMessage("Recherche du meilleur K"); + + int bestK = MethodKNN.bestK(model.getDatas(), dist); + + + updateMessage("Test de robustesse"); + updateProgress(2, 3); + + + double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, dist, 0.2); + model.setKOptimal(bestK); + + updateMessage("Affichage du resultat"); + updateProgress(2.5, 3); + + model.setDistance(dist); + + HBox hBox = new HBox(); + Label label = new Label("Best K: " + bestK + " robustesse : " + robustesse); + hBox.getChildren().add(label); + Scene scene = new Scene(hBox); + kEntry.getValueFactory().setValue(bestK); + + updateMessage("Finished"); + updateProgress(3, 3); + + return scene; + } + }; + + ProgressBar pBar = new ProgressBar(); + pBar.progressProperty().bind(knnTask.progressProperty()); + Label statusLabel = new Label(); + statusLabel.textProperty().bind(knnTask.messageProperty()); + + + hBox.getChildren().addAll(statusLabel, pBar); + Stage stageLoad = new Stage(); + Scene scene = new Scene(hBox); + + stageLoad.setScene(scene); + stageLoad.show(); + + Stage stageFinished = new Stage(); + + knnTask.setOnSucceeded(e -> { + stageLoad.close(); + stageFinished.setScene(knnTask.getValue()); + stageFinished.show(); + + }); + knnTask.run(); + //new Thread(knnTask).start(); + + + + + } + + } + + + public void validate() { + + ClassificationModel model = ClassificationModel.getClassificationModel(); + + + int k = kEntry.getValue(); + Distance dist = Distance.getByName(algoSelector.getValue()); + + model.setDistance(dist); + model.setK(k); + model.classifierDonnees(); + + stage.close(); + + } + } diff --git a/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java b/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java index 92ab843e9e31994f487136157144695e97c07c98..3019e0c6372bbcd2f3c7424f55cdbad5617d07e7 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java @@ -1,6 +1,7 @@ package fr.univlille.sae.classification.controller; import fr.univlille.sae.classification.model.ClassificationModel; +import fr.univlille.sae.classification.model.DataType; import javafx.fxml.FXML; import javafx.scene.control.Alert; import javafx.scene.control.Button; @@ -14,9 +15,11 @@ import java.io.IOException; public class LoadDataController { + @FXML Stage stage; + @FXML TextField filePath; @@ -44,6 +47,7 @@ public class LoadDataController { if(file != null) { filePath.setText(file.getPath()); } + } /** @@ -58,6 +62,7 @@ public class LoadDataController { alert.initOwner(stage); alert.setContentText("Le chargement du fichier à echoué, veuillez reessayer !"); alert.showAndWait(); + openFileChooser(); return; } diff --git a/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java b/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java index b4deb89f301bf6126622078a84bc839b8f5a8fb1..ce462a0b5c12e72f1ac2d9de29bed32aee1006a3 100644 --- a/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java +++ b/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java @@ -1,7 +1,6 @@ package fr.univlille.sae.classification.knn; -import fr.univlille.sae.classification.knn.distance.Distance; -import fr.univlille.sae.classification.knn.distance.DistanceEuclidienneNormalisee; +import fr.univlille.sae.classification.knn.distance.*; import fr.univlille.sae.classification.model.ClassificationModel; import fr.univlille.sae.classification.model.DataType; import fr.univlille.sae.classification.model.LoadableData; @@ -66,6 +65,7 @@ public class MethodKNN { // On recupere les K voisions de data. List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, data, k, distance); + System.out.println("Neighbours: " + kVoisins); // System.out.println("Neighbours found : " + kVoisins); @@ -91,8 +91,7 @@ public class MethodKNN { public static int bestK(List<LoadableData> datas, Distance distance) { - //ToDO Juste pour eviter d'avoir k = 35 je limite la taille max de K. Je vais chercher si y'a une methode particuliere pour limiter sa taille - int maxK = (int) (Math.sqrt(datas.size())/2 *2); + int maxK = (int) (Math.sqrt(datas.size())); System.out.println("Max k: " + maxK); Map<Integer, Double> results = new HashMap<>(); @@ -111,31 +110,44 @@ public class MethodKNN { public static double robustesse(List<LoadableData> datas, int k, Distance distance, double testPart) { - int totalFind = 0; - int totalTry = 0; - // On calcul la robusstesse en utilisant testPart% du fichier de base comme donnée a tester. - int partSize = (int) (datas.size() * testPart); - List<LoadableData> trainingData = new ArrayList<>(List.copyOf(datas.subList(0, datas.size()-partSize))); - List<LoadableData> testData = List.copyOf(datas.subList(datas.size()-partSize, datas.size())); + double taux = 0; - // On met a jour l'algo avec les nouvelles données (permet de re-calculer l'amplitude ainsi que les val max et min - updateModel(trainingData); + for(int i = 0; i<(int)1/testPart; i++) { - // On estime la classe chaque donnée de test, et on verifie si l'algo a bon - for(LoadableData l : testData) { - totalTry++; - String baseClass = l.getClassification(); - // System.out.println("Base class : " + baseClass); - // System.out.println("Base data: " + l); - if(baseClass.equals(MethodKNN.estimateClass(trainingData,l, k, distance))) totalFind++; + int totalFind = 0; + int totalTry = 0; + + // On calcul la robusstesse en utilisant testPart% du fichier de base comme donnée a tester. + int partSize = (int) (datas.size() * testPart); + List<LoadableData> testData = List.copyOf(datas.subList(i*partSize, (i*partSize)+partSize)); + List<LoadableData> trainingData = new ArrayList<>(List.copyOf(datas)); + trainingData.removeAll(testData); + + // On met a jour l'algo avec les nouvelles données (permet de re-calculer l'amplitude ainsi que les val max et min + updateModel(trainingData); + + // On estime la classe chaque donnée de test, et on verifie si l'algo a bon + for(LoadableData l : testData) { + totalTry++; + String baseClass = l.getClassification(); + // System.out.println("Base class : " + baseClass); + // System.out.println("Base data: " + l); + if(baseClass.equals(MethodKNN.estimateClass(trainingData,l, k, distance))) totalFind++; + + } + + + // On affiche le taux de reussite a chaque tour + System.out.println("total find: " +totalFind + " total try: " + totalTry); + taux += (totalFind/(double) totalTry); } - // On return le taux de reussite - System.out.println("total find: " +totalFind + " total try: " + totalTry); - return (totalFind/(double) totalTry); + + + return taux/(1/testPart); } public static void main(String[] args) { @@ -144,8 +156,8 @@ public class MethodKNN { ClassificationModel model = ClassificationModel.getClassificationModel(); - model.setType(DataType.POKEMON); - model.loadData(new File(path+"data/pokemon_train.csv")); + model.setType(DataType.IRIS); + model.loadData(new File(path+"data/iris.csv")); MethodKNN.updateModel(model.getDatas()); System.out.println(); @@ -153,14 +165,18 @@ public class MethodKNN { // On mélange les données pour tester sur differentes variétes car le fichier de base est trié. Collections.shuffle(datas); - System.out.println("Search best k"); + for(int i = 0; i<1; i++) { + System.out.println("Search best k"); - // On cherche le meilleure K - int bestK = MethodKNN.bestK(datas, new DistanceEuclidienneNormalisee()); - System.out.println(bestK); + // On cherche le meilleure K + int bestK = MethodKNN.bestK(datas, new DistanceManhattanNormalisee()); + System.out.println(bestK); + + // Puis on clacul la robustesse avec le K trouvé + System.out.println(MethodKNN.robustesse( datas, bestK, new DistanceManhattanNormalisee(), 0.2)); + + } - // Puis on clacul la robustesse avec le K trouvé - System.out.println(MethodKNN.robustesse( datas, bestK, new DistanceEuclidienneNormalisee(), 0.2)); } diff --git a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java index 9f4be22c476edafedb6efc223e33c5577b3d939d..24fd3796dd99cc4eb343b15cc679a407b90eac05 100644 --- a/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java +++ b/src/main/java/fr/univlille/sae/classification/knn/distance/Distance.java @@ -7,4 +7,30 @@ public interface Distance { double distance(LoadableData l1, LoadableData l2); + static Distance getByName(String name){ + switch (name) { + case "Euclidienne Normalisée": + return new DistanceEuclidienneNormalisee(); + case "Manhattan": + return new DistanceManhattan(); + case "Manhattan Normalisée": + return new DistanceManhattanNormalisee(); + default: + return new DistanceEuclidienne(); + } + + } + + static String getDistanceName(Distance distance){ + if (distance instanceof DistanceEuclidienneNormalisee) { + return "Euclidienne Normalisee"; + }else if (distance instanceof DistanceManhattan){ + return "Manhattan"; + }else if (distance instanceof DistanceManhattanNormalisee){ + return "ManhattanNormalisee"; + }else { + return "Euclidienne"; + } + } + } diff --git a/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceEuclidienneNormalisee.java b/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceEuclidienneNormalisee.java index 5c0b54966dcbbd86e72304ca18b88a7b6dbf10fc..9be1180ffd7a53d7b2b2d785596fdbfcea7dbefc 100644 --- a/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceEuclidienneNormalisee.java +++ b/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceEuclidienneNormalisee.java @@ -14,7 +14,8 @@ public class DistanceEuclidienneNormalisee implements Distance{ double[] normaliseL1 = normalise(l1); double[] normaliseL2 = normalise(l2); for(int i = 0;i<normaliseL1.length;i++) { - total += Math.pow(normaliseL2[i] - normaliseL1[i], 2); + //total += Math.pow(normaliseL2[i] - normaliseL1[i], 2); + total += Math.pow((l2.getAttributes()[i] - l1.getAttributes()[i])/MethodKNN.amplitude[i], 2); } //A Check for(int i = 0;i<l2.getStringAttributes().length;i++) { diff --git "a/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalis\303\251e.java" b/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalisee.java similarity index 89% rename from "src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalis\303\251e.java" rename to src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalisee.java index 96a3e15936d5e1c5ba940d5d7c6ca2cae2060439..50a678bf52e356b3388b9145b072340f61b72d17 100644 --- "a/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalis\303\251e.java" +++ b/src/main/java/fr/univlille/sae/classification/knn/distance/DistanceManhattanNormalisee.java @@ -4,7 +4,7 @@ import fr.univlille.sae.classification.knn.MethodKNN; import fr.univlille.sae.classification.model.LoadableData; -public class DistanceManhattanNormalisée implements Distance{ +public class DistanceManhattanNormalisee implements Distance{ @Override public double distance(LoadableData l1, LoadableData l2) { diff --git a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java index 9e975c9d65d331d6a39451cc0b4ebbd4ad51e76a..0c5291ca16dadd29e930f2346bec1219feb9c5f5 100644 --- a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java +++ b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java @@ -3,6 +3,8 @@ package fr.univlille.sae.classification.model; import com.opencsv.bean.CsvToBeanBuilder; import fr.univlille.sae.classification.knn.MethodKNN; import fr.univlille.sae.classification.knn.distance.Distance; +import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; +import fr.univlille.sae.classification.knn.distance.DistanceManhattan; import fr.univlille.sae.classification.utils.Observable; import java.io.File; @@ -26,6 +28,7 @@ public class ClassificationModel extends Observable { private Distance distance; private int kOptimal; + private int k; /** * Renvoie une instance unique du modèle. Par défaut, le type de ce modèle est Iris. @@ -41,7 +44,7 @@ public class ClassificationModel extends Observable { * Initialise le modèle avec le type de données Iris. */ private ClassificationModel() { - this(DataType.POKEMON); + this(DataType.IRIS); } /** @@ -52,6 +55,9 @@ public class ClassificationModel extends Observable { this.datas = new ArrayList<>(); this.dataToClass = new ConcurrentHashMap<>(); this.type = type; + this.kOptimal = 0; + this.k = 0; + this.distance = new DistanceEuclidienne(); } /** * Ajoute un point au nuage de points avec toutes les données de ce point. @@ -85,6 +91,8 @@ public class ClassificationModel extends Observable { types.add(d.getClassification()); } + Collections.shuffle(datas); + LoadableData.setClassificationTypes(types); notifyObservers(); } catch (IOException e) { @@ -108,11 +116,8 @@ public class ClassificationModel extends Observable { */ public void classifierDonnee(LoadableData data) { if(dataToClass.get(data) != null && dataToClass.get(data)) return; - List<String> classes = new ArrayList<>(LoadableData.getClassificationTypes()); - - - - data.setClassification(MethodKNN.estimateClass(datas, data, 1, distance)); + this.dataToClass.remove(data); + data.setClassification(MethodKNN.estimateClass(datas, data, kOptimal, distance)); notifyObservers(data); dataToClass.put(data, true); } @@ -142,6 +147,18 @@ public class ClassificationModel extends Observable { return kOptimal; } + public void setKOptimal(int kOptimal) { + this.kOptimal = kOptimal; + } + + public int getK() { + return k; + } + + public void setK(int k) { + this.k = k; + } + /** * Renvoie la liste des données chargées. * @return liste des données chargées. diff --git a/src/main/java/fr/univlille/sae/classification/model/Pokemon.java b/src/main/java/fr/univlille/sae/classification/model/Pokemon.java index 2c5549a851e22288c3f3e28cfb4c8fd3421ce2ae..1311287cdc4219dda769e5a2c94a2311dbddf904 100644 --- a/src/main/java/fr/univlille/sae/classification/model/Pokemon.java +++ b/src/main/java/fr/univlille/sae/classification/model/Pokemon.java @@ -200,7 +200,7 @@ public class Pokemon extends LoadableData{ @Override public String[] getStringAttributes() { - return new String[0]; + return new String[]{name, type2, String.valueOf(isLegendary)}; } @Override diff --git a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java index 08b8d8bdc84c1938ab0d21d91346d6af50ee0f1c..ca99bcfaa63020966c667cc8b2492830ba9aa75e 100644 --- a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java @@ -124,6 +124,9 @@ public class DataStageView extends DataVisualizationView implements Observer { if(editSerie == null){ editSerie = new ScatterChart.Series<Double, Double>(); } + if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { + nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); + } dataPoint.setNode(nodePoint); editSerie.getData().add(dataPoint); serieList.put(data.getClassification(), editSerie); @@ -133,7 +136,6 @@ public class DataStageView extends DataVisualizationView implements Observer { serieList.get(serie).setName(serie); } scatterChart.getData().addAll(serieList.values()); - scatterChart.setLegendVisible(true); } } catch (Exception e) { System.err.println("Erreur de mise à jour : " + e.getMessage()); diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java index 6b80a19d8507e2fe55ad38f6452471ff8e2c7bed..1bb2c5347c165156a014e5f517f65566bb54daad 100644 --- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java @@ -124,6 +124,10 @@ public class MainStageView extends DataVisualizationView implements Observer { if(editSerie == null){ editSerie = new ScatterChart.Series<Double, Double>(); } + if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { + nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); + } + dataPoint.setNode(nodePoint); editSerie.getData().add(dataPoint); serieList.put(data.getClassification(), editSerie); @@ -147,8 +151,7 @@ public class MainStageView extends DataVisualizationView implements Observer { return; } - scatterChart.getData().clear(); - serieList.clear(); + LoadableData newData = (LoadableData) data; if (actualX == null || actualY == null) {