diff --git a/DevEfficace/rapport.md b/DevEfficace/rapport.md new file mode 100644 index 0000000000000000000000000000000000000000..57e2c8bb503e8153ea05fd21d650bed7563c12c6 --- /dev/null +++ b/DevEfficace/rapport.md @@ -0,0 +1,129 @@ +# Algorithme K-NN - Rapport R3.02 (EN COURS - NON RELU) + +### Groupe H-4 + +- [MENNECART Matias](mailto:matias.mennecart.etu@univ-lille.fr) +- [DEBUYSER Hugo](mailto:hugo.debuyser.etu@univ-lille.fr) +- [ANTOINE Maxence](mailto:maxence.antoine.etu@univ-lille.fr) +- [DEKEISER Matisse](mailto:matisse.dekeiser.etu@univ-lille.fr) +- [DESMONS Hugo](mailto:hugo.desmons.etu@univ-lille.fr) + + +--- + +## Implémentation de K-NN + +*Une description de votre implémentation de l’algorithme k-NN : classe implémentant l’algorithme, méthode(s) de cette classe implémentant le calcul de la distance, traitement de la normalisation, méthode(s) de cette classe implémentant la classification, méthode(s) évaluant la robustesse. N’hésitez +pas à mettre en avant l’efficacité de ces méthodes (approprié pour un grand volume de données, normalisation +efficace des distances).* + +Nous avons implémenté l'algorithme K-NN dans une classe MethodKNN. Nous avons choisis d'utiliser des methodes statiques pour faciliter l'utilisation de cette algorithme. +Cette classe contient 5 methodes permettant d'implementer l'algorithme K-NN et ses differentes fonctionnalités. + +> public static void updateModel(List<LoadableData> datas) + +Cette premiere methode permet de mettre a jour les données de l'algorithme. Ainsi il faut passer en parametre les données sur lesquelles ont souhaite travailler. +Cette methode va calculer les valeurs max et min ainsi que l'amplitude de chaque attribue des données passé en parametre. +Cette methode necessite 2 parcours, 1 sur les données, puis un sur les min et les max de chaque attribue pour calculer l'amplitude. Il est donc conseillé de n'utiliser cette methode qu'une fois par jeu de données (toute facon le resultat serait le meme) +C'est pour cela que cette methode n'est pas directement appelé dans les methodes ci-dessous mais qu'elle doit explicitement être appelé aupparavant. + +> public static List<LoadableData> kVoisins(List<LoadableData> datas, LoadableData data, int k, Distance distance) + +Cette methode a pour objectif de recupéré les k voisins les plus proches d'un donnée parmis un jeu de données et selon une distance. +Elle prends donc en parametre le jeu de données, la données pour laquelle on souhaite obtenir les voisins, le nombre de voisins souhaités ainsi que la distance avec laquelle les calculs doivent etres effectués. +Cette methode n'effectue qu'un seul parcours de boucle et calcul pour chacune des données sa distance avec la donnée passée en parametre. Les calculs de distance sont definis dans l'objet implémentant l'interface Distance passé en parametre. +Si il s'agit d'une distance normalisée la normalisation s'effectue au moment de la recherche des voisins. + +>public static double robustesse(List<LoadableData> datas, int k, Distance distance, double testPart) + + +Cette methode a pour objectif d'évaluer la robustesse de l'algorithme sur un jeu de donné avec un K donné et une distance donnée. Elle prends en parametre le jeu de données sur lequel effectué l'evalutation, +le k a tester, la distance a utiliser ainsi que un pourcentage correspondant a la partie des données qui sera utilisé afin de tester les données. Ex avec testPart=0.2, 80% des données serviront de données de reférence et 20% seront utilisé pour tester la +validité de l'algorithme. Cette effectue une validation croisée (Voir #Validation Croisée) + + +> public static int bestK(List<LoadableData> datas, Distance distance) + +Cette methode a pour objectif de rechercher le meilleur K possible pour un jeu de données et une distance données. +Elle va tester la totalité des K impair compris entre 1 et racine caré du nombre de données. Cette valeur max permet d'eviter que le K choisit soit trop grand et +fausse les résultats. Elle test donc la robustesse de chaque K et renvoie le K ayant la meilleure robustesse. Cette methode parcours le jeu de donnée Kmax*(1/testPart) fois, ou testPart correspond au pourcentage des données utilisé comme valeurs de test. + + +--- + +## Validation croisée + + +#### Rappel de la methode de validation croisée + +La validation croisée est une méthode d'évaluation utilisée pour mesurer la performance d'un modèle en le testant sur des données qu'il n'a pas utilisées pour l'entraînement. Dans ce cas précis : + +Les données sont divisées en plusieurs parties égales (appelées folds ou partitions). +À chaque itération, une des partitions est utilisée comme jeu de test, tandis que les autres servent pour l'entraînement. +Les résultats des tests sont cumulés pour calculer un score global. +Cette méthode permet de minimiser le biais d'évaluation en utilisant toutes les données tour à tour pour l'entraînement et le test. + +#### Implémentation de la validation croisée + + +La méthode effectue une validation croisée en divisant les données en plusieurs partitions. À chaque itération : + +Une partie des données sert de jeu de test. +Le reste des données sert de jeu d'entraînement. + +Voici les étapes principales : + +- Calcul du nombre d'itérations : +Le nombre d'itérations est déterminé par +1/testPart. Par exemple, si testPart = 0.1, la méthode effectue 10 itérations, si testPart = 0.2, la méthode effectue 5 iterations, etc. + +- Division des données : +Pour chaque itération i, une sous-liste correspondant au pourcentage testPart (par exemple, 10 % des données) est extraite et utilisée comme jeu de test (testData). +Le reste des données sert de jeu d'entraînement (trainingData). + +- Estimation des classes : +Chaque élément du jeu de test est classé à l’aide de la fonction MethodKNN.estimateClass(trainingData, l, k, distance) où trainingData correspond au reste des données. +Si la classe prédite correspond à la classe réelle (retournée par l.getClassification()), un compteur (totalFind) est incrémenté. + +- Calcul du taux de réussite : +Après chaque itération, le taux de réussite est calculé totalFind/totalTry et ajouté à la variable taux. +- Renvoie du taux total moyen de reussite: Enfin, on divise la variable taux par le nombre d'itération afin d'obtenir un taux de reussite moyen. +--- + +## Choix du meilleur K + + +Pour obtenir le meilleur K, on appel la methode bestK(List<LoadableData> datas, Distance distance) decrite plus haut. On obtient un K optimal, puis on appel la methode robustesse(...) avec le k trouvé plutot comme parametre. + +En appliquant cette methode voici les resultats que nous avons obtenue avec: + +##### Iris + + + +| Distance \ K | 1 | 3 | 5 | 7 | 9 | 11 | K choisit | +|---------------------------------|-------|-------|-------|-------|-------|-------|----| +| Distance Euclidienne | 0.96 | 0.966 | 0.96 | 0.98 | 0.98 | 0.98 | 5 | +| Distance Euclidienne Normalisée | 0.946 | 0.946 | 0.96 | 0.96 | 0.96 | 0.96 | 7 | +| Distance Manhattan | 0.953 | 0.946 | 0.946 | 0.96 | 0.96 | 0.953 | 7 | +| Distance Manhattan Normalisée | 0.946 | 0.96 | 0.946 | 0.953 | 0.953 | 0.953 | 3 | + +On obtient donc un taux de reussiste plutôt élevé. A chaque fois l'algorithme choisit le K avec le plus haut taux de reussite. En cas d'égalité, il choisit le plus petit K parmis les égalités. + +##### Pokemon + +| Distance \ K | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 | 17 | 19 | 21 | K choisit | +|---------------------------------|-------|-------|-------|-------|-------|-------|----|----|----|----|----|----| +| Distance Euclidienne | 0.243 | 0.229 | 0.239 | 0.235 | 0.247 | 0.251 | 0.237 | 0.225 | 0.215 | 0.205 | 0.2 | 11 | +| Distance Euclidienne Normalisée | 0.211 | 0.229 | 0.251 | 0.245 | 0.245 | 0.239 | 0.245 | 0.243 | 0.237 | 0.239 | 0.225 | 5 | +| Distance Manhattan | 0.231 | 0.235 | 0.239 | 0.239 | 0.241 | 0.239 | 0.237 | 0.235 | 0.233 | 0.207 | 0.201 | 9 | +| Distance Manhattan Normalisée | 0.178 | 0.188 | 0.2 | 0.215 | 0.205 | 0.203 | 0.194 | 0.190 | 0.184 | 0.180 | 0.190 | 7 | +--- + +ajoyter un commentaire sur les resultats + +## Efficacité + +Comme expliqué pour chaque methode dans la partie Implementation de l'algorithme, nous avons chercher a minimiser le nombre de parcours du fichier de données et plus generalement le nombre de boucle. +L'algorithme necessite une List qui sera donnée en parametre, il est donc libre a la personne qui l'utilise de fournir l'implementation de List qu'il souhaite. De plus, pour le calcul des parametres du jeu de données ( +amplitude,valeur minimale, valeur maximal) nous avons utiliser un tableau de double afin de limiter les performances \ No newline at end of file diff --git a/res/data/error_iris.csv b/res/data/error_iris.csv new file mode 100644 index 0000000000000000000000000000000000000000..350d1471cdab74a6393233a8413ab6bd45cc5cb7 --- /dev/null +++ b/res/data/error_iris.csv @@ -0,0 +1,6 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +"abc","xyz","pqr","uvw","Setosa" +"def","lmn","stu","ijk","Setosa" +"ghi","opq","vwx","rst","Setosa" +"jkl","mno","yza","bcd","Setosa" +"qrs","tuv","efg","hij","Setosa" \ No newline at end of file diff --git a/res/stages/data-view-stage.fxml b/res/stages/data-view-stage.fxml index 315cb877303a7d975b4ac5cf5e19bf23b8e212bf..9f17a8fd5c75ee133bb950bbb475f88f00c6548f 100644 --- a/res/stages/data-view-stage.fxml +++ b/res/stages/data-view-stage.fxml @@ -1,30 +1,36 @@ <?xml version="1.0" encoding="UTF-8"?> -<?import javafx.geometry.*?> -<?import javafx.scene.*?> -<?import javafx.scene.chart.*?> -<?import javafx.scene.control.*?> -<?import javafx.scene.image.*?> -<?import javafx.scene.layout.*?> -<?import javafx.scene.text.*?> -<?import javafx.stage.*?> +<?import javafx.geometry.Insets?> +<?import javafx.scene.Scene?> +<?import javafx.scene.chart.NumberAxis?> +<?import javafx.scene.chart.ScatterChart?> +<?import javafx.scene.control.Button?> +<?import javafx.scene.control.Label?> +<?import javafx.scene.control.ListView?> +<?import javafx.scene.image.Image?> +<?import javafx.scene.image.ImageView?> +<?import javafx.scene.layout.AnchorPane?> +<?import javafx.scene.layout.HBox?> +<?import javafx.scene.layout.VBox?> +<?import javafx.scene.text.Font?> +<?import javafx.stage.Stage?> -<Stage fx:id="stage" xmlns="http://javafx.com/javafx/17.0.12" xmlns:fx="http://javafx.com/fxml/1" fx:controller="fr.univlille.sae.classification.controller.DataStageController"> +<Stage fx:id="stage" xmlns="http://javafx.com/javafx/21" xmlns:fx="http://javafx.com/fxml/1" fx:controller="fr.univlille.sae.classification.controller.DataStageController"> <scene> <Scene> <AnchorPane prefHeight="502.0" prefWidth="999.0"> <children> - <VBox prefHeight="513.0" prefWidth="999.0"> + <VBox prefHeight="562.0" prefWidth="999.0"> <children> <Label alignment="CENTER" prefHeight="52.0" prefWidth="999.0" style="-fx-background-color: #105561;" text="Vue de classification de données" textFill="WHITE"> <font> <Font name="System Bold" size="20.0" /> </font></Label> - <HBox prefHeight="463.0" prefWidth="999.0"> + <HBox prefHeight="511.0" prefWidth="999.0"> <children> <VBox prefHeight="459.0" prefWidth="762.0"> <children> - <HBox alignment="TOP_CENTER" prefHeight="462.0" prefWidth="762.0" spacing="5.0"> + <HBox alignment="TOP_CENTER" prefHeight="462.0" prefWidth="760.0" spacing="5.0"> <children> <AnchorPane prefHeight="509.0" prefWidth="688.0"> <children> @@ -60,6 +66,18 @@ </Button> </children> </HBox> + <VBox prefHeight="65.0" prefWidth="801.0"> + <children> + <HBox fx:id="legend" alignment="CENTER" prefHeight="58.0" prefWidth="762.0" spacing="10.0"> + <opaqueInsets> + <Insets /> + </opaqueInsets> + <padding> + <Insets left="2.0" right="2.0" /> + </padding> + </HBox> + </children> + </VBox> </children> </VBox> <VBox prefHeight="470.0" prefWidth="238.0"> diff --git a/res/stages/main-stage.fxml b/res/stages/main-stage.fxml index e867ac83924ee95ffef3b2b01e4cf2d69183734f..1e6fdc9b6536a7d946ca517c86d9cf002de7367b 100644 --- a/res/stages/main-stage.fxml +++ b/res/stages/main-stage.fxml @@ -9,7 +9,7 @@ <?import javafx.scene.text.*?> <?import javafx.stage.*?> -<Stage fx:id="stage" xmlns="http://javafx.com/javafx/17.0.12" xmlns:fx="http://javafx.com/fxml/1" fx:controller="fr.univlille.sae.classification.controller.MainStageController"> +<Stage fx:id="stage" xmlns="http://javafx.com/javafx/17.0.2-ea" xmlns:fx="http://javafx.com/fxml/1" fx:controller="fr.univlille.sae.classification.controller.MainStageController"> <scene> <Scene> <AnchorPane prefHeight="535.0" prefWidth="922.0"> @@ -37,12 +37,19 @@ <NumberAxis fx:id="ordAxe" prefHeight="354.0" prefWidth="54.0" side="LEFT" stylesheets="@../css/style.css" /> </yAxis> </ScatterChart> - <Label fx:id="AxesSelected" alignment="CENTER" layoutX="73.0" layoutY="152.0" prefHeight="38.0" prefWidth="600.0"> + <Label fx:id="AxesSelected" alignment="CENTER" layoutX="-2.0" layoutY="3.0" prefHeight="380.0" prefWidth="682.0"> <font> <Font size="21.0" /> </font> </Label> - <VBox layoutY="345.0" prefHeight="78.0" prefWidth="678.0" /> + <VBox layoutY="345.0" prefHeight="78.0" prefWidth="678.0"> + <children> + <VBox fx:id="legend" alignment="CENTER" prefHeight="100.0" prefWidth="200.0" spacing="10.0"> + <padding> + <Insets left="2.0" right="2.0" /> + </padding> + </VBox> + </children></VBox> </children> <HBox.margin> <Insets left="10.0" /> diff --git a/src/main/java/fr/univlille/sae/classification/ClassificationApp.java b/src/main/java/fr/univlille/sae/classification/ClassificationApp.java index 145eb3a499cc9b13ec62058508c4933b03f4c0b2..c3c9cd7cab7f88c31ab2a550418d84daef27a770 100644 --- a/src/main/java/fr/univlille/sae/classification/ClassificationApp.java +++ b/src/main/java/fr/univlille/sae/classification/ClassificationApp.java @@ -7,9 +7,19 @@ import javafx.stage.Stage; import java.io.IOException; +/** + * Classe principale pour l'application de classification. + * Cette classe initialise et lance l'interface graphique de l'application. + */ public class ClassificationApp extends Application { + /** + * Point d'entrée principal pour l'initialisation de l'interface utilisateur. + * Cette méthode configure la vue principale en utilisant une instance du modèle + * de classification, puis affiche la fenêtre principale. + * @param stage la fenêtre principale de l'application. + */ public void start(Stage stage) throws IOException { ClassificationModel model = ClassificationModel.getClassificationModel(); MainStageView view = new MainStageView(model); @@ -17,7 +27,11 @@ public class ClassificationApp extends Application { view.show(); } - // Ouvre l'application + /** + * Point d'entrée principal de l'application. + * Cette méthode lance l'application JavaFX. + * @param args les arguments de ligne de commande. + */ public static void main(String[] args) { Application.launch(args); } diff --git a/src/main/java/fr/univlille/sae/classification/Main.java b/src/main/java/fr/univlille/sae/classification/Main.java index c1bafba0e15b9d93c8067cca55fecbe14473c900..84adac9c52050badc03a6fd071244f7d2e781d39 100644 --- a/src/main/java/fr/univlille/sae/classification/Main.java +++ b/src/main/java/fr/univlille/sae/classification/Main.java @@ -6,8 +6,14 @@ import javafx.application.Application; import javafx.stage.Stage; import java.io.IOException; - +/** + * Cette classe redirige l'exécution vers la classe principale de l'application, + */ public class Main { + /** + * Point d'entrée principal de l'application. + * @param args les arguments de ligne de commande. + */ public static void main(String[] args) { ClassificationApp.main(args); } diff --git a/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java b/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java index 21c7e41752c372ff4c8c1e47add3812f136d0e00..9a220bdf4e1c24f8db37dad2a63785c1c25873de 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java @@ -90,6 +90,7 @@ public class AddDataController { else if (attrValue instanceof Boolean) { ChoiceBox<String> choiceBox = new ChoiceBox<>(); choiceBox.getItems().addAll("VRAI", "FAUX"); + choiceBox.setValue("VRAI"); hbox.getChildren().add(choiceBox); components.add(choiceBox); } @@ -124,6 +125,7 @@ public class AddDataController { return null; }).toArray(); + System.out.println(Arrays.toString(values)); ClassificationModel.getClassificationModel().ajouterDonnee(values); }catch (IllegalArgumentException e){ Alert alert = new Alert(Alert.AlertType.ERROR); diff --git a/src/main/java/fr/univlille/sae/classification/controller/DataStageController.java b/src/main/java/fr/univlille/sae/classification/controller/DataStageController.java index 94c7e97cc0273ca8d7e2ea6367e0e7cc9554a1f3..edc1926630b2d6e48ec257ef103af2249d61d715 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/DataStageController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/DataStageController.java @@ -15,152 +15,43 @@ import java.io.IOException; /** * Controlleur pour le FXML data-view-stage, pour gérer la vue supplémentaire */ -public class DataStageController { +public class DataStageController extends DataVisualizationController{ @FXML Stage stage; - @FXML - ScatterChart scatterChart; - @FXML - Label AxesSelected; + @FXML ListView PointInfo; - /** - * DataStageView associé au controlleur - */ - private DataStageView dataStageView; - private double initialX; - private double initialY; - private double initialLowerBoundX; - private double initialUpperBoundX; - private double initialLowerBoundY; - private double initialUpperBoundY; + public void initialize() { setupZoom(); setupDrag(); } - /** - * Ouvrir les paramètres des axes de la vue - */ - public void openAxesSetting(){ - AxesSettingsView axesSettingsView = new AxesSettingsView(ClassificationModel.getClassificationModel(), stage, dataStageView); - axesSettingsView.show(); - } + /** * Associe la dataStageView associer à la classe * @param dataStageView */ public void setDataStageView (DataStageView dataStageView) { - this.dataStageView = dataStageView; + this.view = dataStageView; } - /** - * Renvoie la grille associé à la classe - * @return grille de la classe - */ - public ScatterChart getScatterChart() { - return this.scatterChart; - } - /** - * Attribut une valeur à l'axe de la grille - * @param texte Valeur de l'axe - */ - public void setAxesSelected(String texte) { - this.AxesSelected.setText(texte); - } - public void setAxesSelectedDisable(){ - this.AxesSelected.setDisable(true); - } + public ListView getPointInfo(){ return this.PointInfo; }; - private void setupZoom() { - NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); - NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); - - scatterChart.setOnScroll(event -> { - xAxis.setAutoRanging(false); - yAxis.setAutoRanging(false); - - double delta = event.getDeltaY(); - double mouseX = event.getSceneX(); - double mouseY = event.getSceneY(); - - double chartX = xAxis.sceneToLocal(mouseX, mouseY).getX(); - double chartY = yAxis.sceneToLocal(mouseX, mouseY).getY(); - - double zoomFactor; - if (delta > 0) { - zoomFactor = 0.90; - } else { - zoomFactor = 1.05; - } - - double xLower = xAxis.getLowerBound(); - double xUpper = xAxis.getUpperBound(); - double yLower = yAxis.getLowerBound(); - double yUpper = yAxis.getUpperBound(); - double rangeX = xUpper - xLower; - double rangeY = yUpper - yLower; - - double newRangeX = rangeX * zoomFactor; - double newRangeY = rangeY * zoomFactor; - - xAxis.setLowerBound(xLower + (chartX / xAxis.getWidth()) * (rangeX - newRangeX)); - xAxis.setUpperBound(xUpper - ((xAxis.getWidth() - chartX) / xAxis.getWidth()) * (rangeX - newRangeX)); - - yAxis.setLowerBound(yLower + ((yAxis.getHeight() - chartY) / yAxis.getHeight()) * (rangeY - newRangeY)); - yAxis.setUpperBound(yUpper - (chartY / yAxis.getHeight()) * (rangeY - newRangeY)); - }); - - xAxis.setAutoRanging(true); - yAxis.setAutoRanging(true); - } - - private void setupDrag() { - scatterChart.setOnMousePressed(event -> { - initialX = event.getSceneX(); - initialY = event.getSceneY(); - initialLowerBoundX = ((NumberAxis) scatterChart.getXAxis()).getLowerBound(); - initialUpperBoundX = ((NumberAxis) scatterChart.getXAxis()).getUpperBound(); - initialLowerBoundY = ((NumberAxis) scatterChart.getYAxis()).getLowerBound(); - initialUpperBoundY = ((NumberAxis) scatterChart.getYAxis()).getUpperBound(); - }); - - NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); - NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); - - scatterChart.setOnMouseDragged(event -> { - xAxis.setAutoRanging(false); - yAxis.setAutoRanging(false); - double deltaX = event.getSceneX() - initialX; - double deltaY = event.getSceneY() - initialY; - - double newLowerBoundX = initialLowerBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); - double newUpperBoundX = initialUpperBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); - double newLowerBoundY = initialLowerBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); - double newUpperBoundY = initialUpperBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); - - xAxis.setLowerBound(newLowerBoundX); - xAxis.setUpperBound(newUpperBoundX); - yAxis.setLowerBound(newLowerBoundY); - yAxis.setUpperBound(newUpperBoundY); - }); - xAxis.setAutoRanging(true); - yAxis.setAutoRanging(true); - } } diff --git a/src/main/java/fr/univlille/sae/classification/controller/DataVisualizationController.java b/src/main/java/fr/univlille/sae/classification/controller/DataVisualizationController.java new file mode 100644 index 0000000000000000000000000000000000000000..5ef74cc7e39ae4b8cb3dd04ad4ae9900740f78e6 --- /dev/null +++ b/src/main/java/fr/univlille/sae/classification/controller/DataVisualizationController.java @@ -0,0 +1,159 @@ +package fr.univlille.sae.classification.controller; + +import fr.univlille.sae.classification.model.ClassificationModel; +import fr.univlille.sae.classification.view.AxesSettingsView; +import fr.univlille.sae.classification.view.DataVisualizationView; +import javafx.fxml.FXML; +import javafx.scene.chart.NumberAxis; +import javafx.scene.chart.ScatterChart; +import javafx.scene.control.Label; +import javafx.scene.layout.HBox; +import javafx.scene.layout.VBox; +import javafx.stage.Stage; + +public abstract class DataVisualizationController { + + + @FXML + Stage stage; + + @FXML + Label AxesSelected; + + @FXML + VBox legend; + + + @FXML + ScatterChart scatterChart; + + protected double initialX; + protected double initialY; + protected double initialLowerBoundX; + protected double initialUpperBoundX; + protected double initialLowerBoundY; + protected double initialUpperBoundY; + + + protected DataVisualizationView view; + + + protected void setupZoom() { + NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); + NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); + + scatterChart.setOnScroll(event -> { + xAxis.setAutoRanging(false); + yAxis.setAutoRanging(false); + + double delta = event.getDeltaY(); + double mouseX = event.getSceneX(); + double mouseY = event.getSceneY(); + + double chartX = xAxis.sceneToLocal(mouseX, mouseY).getX(); + double chartY = yAxis.sceneToLocal(mouseX, mouseY).getY(); + + double zoomFactor; + if (delta > 0) { + zoomFactor = 0.90; + } else { + zoomFactor = 1.05; + } + + double xLower = xAxis.getLowerBound(); + double xUpper = xAxis.getUpperBound(); + double yLower = yAxis.getLowerBound(); + double yUpper = yAxis.getUpperBound(); + + double rangeX = xUpper - xLower; + double rangeY = yUpper - yLower; + + double newRangeX = rangeX * zoomFactor; + double newRangeY = rangeY * zoomFactor; + + xAxis.setLowerBound(xLower + (chartX / xAxis.getWidth()) * (rangeX - newRangeX)); + xAxis.setUpperBound(xUpper - ((xAxis.getWidth() - chartX) / xAxis.getWidth()) * (rangeX - newRangeX)); + + yAxis.setLowerBound(yLower + ((yAxis.getHeight() - chartY) / yAxis.getHeight()) * (rangeY - newRangeY)); + yAxis.setUpperBound(yUpper - (chartY / yAxis.getHeight()) * (rangeY - newRangeY)); + }); + + xAxis.setAutoRanging(true); + yAxis.setAutoRanging(true); + } + + + protected void setupDrag() { + scatterChart.setOnMousePressed(event -> { + initialX = event.getSceneX(); + initialY = event.getSceneY(); + initialLowerBoundX = ((NumberAxis) scatterChart.getXAxis()).getLowerBound(); + initialUpperBoundX = ((NumberAxis) scatterChart.getXAxis()).getUpperBound(); + initialLowerBoundY = ((NumberAxis) scatterChart.getYAxis()).getLowerBound(); + initialUpperBoundY = ((NumberAxis) scatterChart.getYAxis()).getUpperBound(); + }); + + NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); + NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); + + scatterChart.setOnMouseDragged(event -> { + xAxis.setAutoRanging(false); + yAxis.setAutoRanging(false); + double deltaX = event.getSceneX() - initialX; + double deltaY = event.getSceneY() - initialY; + + double newLowerBoundX = initialLowerBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); + double newUpperBoundX = initialUpperBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); + double newLowerBoundY = initialLowerBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); + double newUpperBoundY = initialUpperBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); + + xAxis.setLowerBound(newLowerBoundX); + xAxis.setUpperBound(newUpperBoundX); + yAxis.setLowerBound(newLowerBoundY); + yAxis.setUpperBound(newUpperBoundY); + }); + xAxis.setAutoRanging(true); + yAxis.setAutoRanging(true); + } + + /** + * Ouvrir les paramètres des axes de la vue + */ + public void openAxesSetting(){ + AxesSettingsView axesSettingsView = new AxesSettingsView(ClassificationModel.getClassificationModel(), stage, view); + axesSettingsView.show(); + } + + + /** + * Renvoie la grille associé à la classe + * @return grille de la classe + */ + public ScatterChart getScatterChart() { + return this.scatterChart; + } + + + /** + * Attribut une valeur à l'axe de la grille + * @param texte Valeur de l'axe + */ + public void setAxesSelected(String texte) { + this.AxesSelected.setText(texte); + } + + public void setAxesSelectedDisable(){ + this.AxesSelected.setDisable(true); + } + + + + public void loadLegend(VBox vBox) { + this.legend.getChildren().clear(); + this.legend.getChildren().addAll(vBox.getChildren()); + } + + + + +} diff --git a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java index def9b3273b96a1a2870403485fa2466c807f98a0..aed754ea35b2858b197d619a6436712911466db6 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/KNNController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/KNNController.java @@ -8,9 +8,11 @@ import fr.univlille.sae.classification.model.LoadableData; import javafx.collections.ObservableList; import javafx.concurrent.Task; import javafx.fxml.FXML; +import javafx.geometry.Pos; import javafx.scene.Scene; import javafx.scene.control.*; import javafx.scene.layout.HBox; +import javafx.scene.layout.VBox; import javafx.stage.Stage; import java.util.ArrayList; @@ -37,10 +39,11 @@ public class KNNController { @FXML public void initialize() { + int max = (int) Math.sqrt(ClassificationModel.getClassificationModel().getDatas().size()); kEntry.setValueFactory(new SpinnerValueFactory.IntegerSpinnerValueFactory(1, - (int) Math.sqrt(ClassificationModel.getClassificationModel().getDatas().size()), + (max%2 == 0) ? max-1 : max, 1, - 1)); + 2)); kEntry.getValueFactory().setValue(ClassificationModel.getClassificationModel().getK()); @@ -59,13 +62,10 @@ public class KNNController { }else { // Calcul du K Optimal: - HBox hBox = new HBox(); - Task<Scene> knnTask = new Task<>() { @Override protected Scene call() throws Exception { - System.out.println("Call call()"); updateProgress(0, 3); updateMessage("Préparation des données "); @@ -104,30 +104,32 @@ public class KNNController { return scene; } }; - + VBox vBox = new VBox(); ProgressBar pBar = new ProgressBar(); pBar.progressProperty().bind(knnTask.progressProperty()); Label statusLabel = new Label(); statusLabel.textProperty().bind(knnTask.messageProperty()); - - hBox.getChildren().addAll(statusLabel, pBar); + vBox.alignmentProperty().setValue(Pos.CENTER); + vBox.getChildren().addAll( pBar, statusLabel); Stage stageLoad = new Stage(); - Scene scene = new Scene(hBox); + Scene scene = new Scene(vBox); + stageLoad.setTitle("Alogirhme K-NN"); + stageLoad.setMinWidth(300); stageLoad.setScene(scene); stageLoad.show(); Stage stageFinished = new Stage(); - + stageFinished.setTitle("Algorithme K-NN - results"); knnTask.setOnSucceeded(e -> { stageLoad.close(); stageFinished.setScene(knnTask.getValue()); stageFinished.show(); }); - knnTask.run(); - //new Thread(knnTask).start(); + + new Thread(knnTask).start(); diff --git a/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java b/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java index 014af783059a9eb43473293a24c3d30af44c39c2..4baf15fcc240ad98712f30d4d42a49855513be51 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/LoadDataController.java @@ -1,5 +1,7 @@ package fr.univlille.sae.classification.controller; +import com.opencsv.exceptions.CsvException; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import fr.univlille.sae.classification.model.ClassificationModel; import fr.univlille.sae.classification.model.DataType; import javafx.fxml.FXML; @@ -34,6 +36,7 @@ public class LoadDataController { @FXML public void initialize() { fileType.getItems().addAll(DataType.values()); + fileType.setValue(DataType.values()[0]); } /** @@ -72,7 +75,15 @@ public class LoadDataController { } ClassificationModel.getClassificationModel().setType(typeChoisi); - ClassificationModel.getClassificationModel().loadData(file); + try { + ClassificationModel.getClassificationModel().loadData(file); + }catch (RuntimeException | CsvRequiredFieldEmptyException e) { + Alert alert = new Alert(Alert.AlertType.ERROR); + alert.initOwner(stage); + alert.setHeaderText(e.toString()); + alert.setContentText("Le chargement du fichier à echoué, veuillez reessayer !"); + alert.showAndWait(); + } stage.close(); } } diff --git a/src/main/java/fr/univlille/sae/classification/controller/MainStageController.java b/src/main/java/fr/univlille/sae/classification/controller/MainStageController.java index d18c02d73d9952ef1fb300ddde0731b560ace749..bed2fe9ea1cbc0a630a8026a45faea8404b04698 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/MainStageController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/MainStageController.java @@ -12,31 +12,22 @@ import javafx.stage.Stage; import java.io.IOException; -public class MainStageController { +public class MainStageController extends DataVisualizationController{ + - @FXML - Stage stage; @FXML Button classifyData; - @FXML - ScatterChart scatterChart; - @FXML - Label AxesSelected; + @FXML ListView PointInfo; + private MainStageView mainStageView; - private double initialX; - private double initialY; - private double initialLowerBoundX; - private double initialUpperBoundX; - private double initialLowerBoundY; - private double initialUpperBoundY; public void initialize() { setupZoom(); @@ -81,6 +72,7 @@ public class MainStageController { */ public void setMainStageView(MainStageView mainStageView) { this.mainStageView = mainStageView; + this.view = mainStageView; } /** @@ -109,17 +101,7 @@ public class MainStageController { return this.scatterChart; } - /** - * Attribue une valeur à l'axe de la grille. - * @param texte Valeur de l'axe à afficher sur l'interface. - */ - public void setAxesSelected(String texte) { - this.AxesSelected.setText(texte); - } - public void setAxesSelectedDisable(){ - this.AxesSelected.setDisable(true); - } /** * Renvoie le bouton de classification de données. @@ -133,82 +115,6 @@ public class MainStageController { return this.PointInfo; }; - private void setupZoom() { - NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); - NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); - - scatterChart.setOnScroll(event -> { - xAxis.setAutoRanging(false); - yAxis.setAutoRanging(false); - - double delta = event.getDeltaY(); - double mouseX = event.getSceneX(); - double mouseY = event.getSceneY(); - double chartX = xAxis.sceneToLocal(mouseX, mouseY).getX(); - double chartY = yAxis.sceneToLocal(mouseX, mouseY).getY(); - - double zoomFactor; - if (delta > 0) { - zoomFactor = 0.90; - } else { - zoomFactor = 1.05; - } - - double xLower = xAxis.getLowerBound(); - double xUpper = xAxis.getUpperBound(); - double yLower = yAxis.getLowerBound(); - double yUpper = yAxis.getUpperBound(); - - double rangeX = xUpper - xLower; - double rangeY = yUpper - yLower; - - double newRangeX = rangeX * zoomFactor; - double newRangeY = rangeY * zoomFactor; - - xAxis.setLowerBound(xLower + (chartX / xAxis.getWidth()) * (rangeX - newRangeX)); - xAxis.setUpperBound(xUpper - ((xAxis.getWidth() - chartX) / xAxis.getWidth()) * (rangeX - newRangeX)); - - yAxis.setLowerBound(yLower + ((yAxis.getHeight() - chartY) / yAxis.getHeight()) * (rangeY - newRangeY)); - yAxis.setUpperBound(yUpper - (chartY / yAxis.getHeight()) * (rangeY - newRangeY)); - }); - - xAxis.setAutoRanging(true); - yAxis.setAutoRanging(true); - } - - - private void setupDrag() { - scatterChart.setOnMousePressed(event -> { - initialX = event.getSceneX(); - initialY = event.getSceneY(); - initialLowerBoundX = ((NumberAxis) scatterChart.getXAxis()).getLowerBound(); - initialUpperBoundX = ((NumberAxis) scatterChart.getXAxis()).getUpperBound(); - initialLowerBoundY = ((NumberAxis) scatterChart.getYAxis()).getLowerBound(); - initialUpperBoundY = ((NumberAxis) scatterChart.getYAxis()).getUpperBound(); - }); - - NumberAxis xAxis = (NumberAxis) scatterChart.getXAxis(); - NumberAxis yAxis = (NumberAxis) scatterChart.getYAxis(); - - scatterChart.setOnMouseDragged(event -> { - xAxis.setAutoRanging(false); - yAxis.setAutoRanging(false); - double deltaX = event.getSceneX() - initialX; - double deltaY = event.getSceneY() - initialY; - - double newLowerBoundX = initialLowerBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); - double newUpperBoundX = initialUpperBoundX - deltaX * (xAxis.getUpperBound() - xAxis.getLowerBound()) / scatterChart.getWidth(); - double newLowerBoundY = initialLowerBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); - double newUpperBoundY = initialUpperBoundY + deltaY * (yAxis.getUpperBound() - yAxis.getLowerBound()) / scatterChart.getHeight(); - - xAxis.setLowerBound(newLowerBoundX); - xAxis.setUpperBound(newUpperBoundX); - yAxis.setLowerBound(newLowerBoundY); - yAxis.setUpperBound(newUpperBoundY); - }); - xAxis.setAutoRanging(true); - yAxis.setAutoRanging(true); - } } diff --git a/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java b/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java index ce462a0b5c12e72f1ac2d9de29bed32aee1006a3..686303ccf2d4e88c048c40d37c40ebe7719370b4 100644 --- a/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java +++ b/src/main/java/fr/univlille/sae/classification/knn/MethodKNN.java @@ -1,5 +1,6 @@ package fr.univlille.sae.classification.knn; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import fr.univlille.sae.classification.knn.distance.*; import fr.univlille.sae.classification.model.ClassificationModel; import fr.univlille.sae.classification.model.DataType; @@ -12,7 +13,6 @@ import java.util.*; public class MethodKNN { - private static final Random random = new Random(); public static String path = System.getProperty("user.dir") + File.separator + "res" + File.separator; @@ -20,22 +20,29 @@ public class MethodKNN { public static double[] minData; public static double[] maxData; - public MethodKNN(ClassificationModel model) { + private MethodKNN() { - updateModel(model.getDatas()); - - } + } + /** + * Permet de mettre a jour les données de l'algorithme. Recalcul les amplitudes et les min/max des données + * @param datas Les données sur lequel l'algorithme doit travailler + */ public static void updateModel(List<LoadableData> datas) { if(datas.isEmpty()) return; - minData = new double[datas.get(0).getAttributes().length]; - maxData = new double[datas.get(0).getAttributes().length]; - amplitude = new double[datas.get(0).getAttributes().length]; + + int numAttributes = datas.get(0).getAttributes().length; + minData = new double[numAttributes]; + maxData = new double[numAttributes]; + amplitude = new double[numAttributes]; + + for(LoadableData l :datas) { - for(int i = 0; i<l.getAttributes().length; i++) { - if(l.getAttributes()[i] < minData[i]) minData[i] = l.getAttributes()[i]; - if(l.getAttributes()[i] > maxData[i]) maxData[i] = l.getAttributes()[i]; + double[] attributes = l.getAttributes(); + for(int i = 0; i<numAttributes; i++) { + if(attributes[i] < minData[i]) minData[i] = attributes[i]; + if(attributes[i] > maxData[i]) maxData[i] = attributes[i]; } } @@ -44,6 +51,15 @@ public class MethodKNN { } } + /** + * Permet de recuperer les K-voisins les plus proches d'une données dans un jeu de données + * en fonction d'une Distance. + * @param datas Le jeu de données + * @param data La donnée avec laquelle calculer la distance + * @param k Le nombre de voisins a recupérer + * @param distance + * @return + */ public static List<LoadableData> kVoisins(List<LoadableData> datas, LoadableData data, int k, Distance distance) { // On recupere toutes les données @@ -65,45 +81,49 @@ public class MethodKNN { // On recupere les K voisions de data. List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, data, k, distance); - System.out.println("Neighbours: " + kVoisins); - - // System.out.println("Neighbours found : " + kVoisins); // On compte le nombre de représentation de chaque class parmis les voisins + // Et on récupere la plus présente + Map<String, Integer> classOfNeighbours = new HashMap<>(); + String currentClass = kVoisins.get(0).getClassification(); + + + for(LoadableData voisin : kVoisins) { int newValue = ((classOfNeighbours.get(voisin.getClassification()) == null) ? 0 : classOfNeighbours.get(voisin.getClassification()) )+ 1; classOfNeighbours.put(voisin.getClassification(), newValue); - } - // On recupere la classe la plus repésenté parmis les voisins (au hasard si egalité entre 2) - String currentClass = kVoisins.get(0).getClassification(); - for(String classification : classOfNeighbours.keySet()) { - if(classOfNeighbours.get(classification) > classOfNeighbours.get(currentClass)) { - currentClass = classification; - }else if (classOfNeighbours.get(classification).equals(classOfNeighbours.get(currentClass))) { - if(random.nextInt(2) == 1) currentClass = classification; + // si la classe est plus presente que la classe acutelemnt majoritaire, on change la classe majoritaire. + // Si il y'a egalité alors on garde la premiere trouvé + if(classOfNeighbours.get(voisin.getClassification()) > classOfNeighbours.get(currentClass)) { + currentClass = voisin.getClassification(); } + } - // System.out.println("Estimate class = " + currentClass); return currentClass; } public static int bestK(List<LoadableData> datas, Distance distance) { + // On borne le K pour eviter de trouver un K trop grand int maxK = (int) (Math.sqrt(datas.size())); System.out.println("Max k: " + maxK); - Map<Integer, Double> results = new HashMap<>(); + int betK = 1; + + Map<Integer, Double> results = new LinkedHashMap<>(); // Pour chaque valeur impaire possible de K, on calcul la robustesse (le taux de reussite) de l'algorithme. for(int i =1; i<maxK; i = i +2) { results.put(i, robustesse(datas, i, distance, 0.2)); + // On modifie le meilleur k si le taux est superieur au K precedent + // Si egalité, on garde le premier trouvé + if(results.get(i) > results.get(betK)) betK = i; } - System.out.println(results); + System.out.println("Results: " + results); - // On return le K ayant le meilleur taux de reussite ( ou l'un des K si egalités). - return Collections.max(results.entrySet(), Map.Entry.comparingByValue()).getKey(); + return betK; } @@ -114,7 +134,7 @@ public class MethodKNN { double taux = 0; - for(int i = 0; i<(int)1/testPart; i++) { + for(int i = 0; i<1/testPart; i++) { int totalFind = 0; int totalTry = 0; @@ -132,10 +152,7 @@ public class MethodKNN { for(LoadableData l : testData) { totalTry++; String baseClass = l.getClassification(); - // System.out.println("Base class : " + baseClass); - // System.out.println("Base data: " + l); if(baseClass.equals(MethodKNN.estimateClass(trainingData,l, k, distance))) totalFind++; - } @@ -150,14 +167,14 @@ public class MethodKNN { return taux/(1/testPart); } - public static void main(String[] args) { + public static void main(String[] args) throws CsvRequiredFieldEmptyException { //Test de la robustesse et du meillleur K ClassificationModel model = ClassificationModel.getClassificationModel(); - model.setType(DataType.IRIS); - model.loadData(new File(path+"data/iris.csv")); + model.setType(DataType.POKEMON); + model.loadData(new File(path+"data/pokemon_train.csv")); MethodKNN.updateModel(model.getDatas()); System.out.println(); @@ -179,7 +196,7 @@ public class MethodKNN { - } + } diff --git a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java index 0c5291ca16dadd29e930f2346bec1219feb9c5f5..403e01cbbfdc0b1b94d65de4ec9196c80d43878a 100644 --- a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java +++ b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java @@ -1,6 +1,8 @@ package fr.univlille.sae.classification.model; import com.opencsv.bean.CsvToBeanBuilder; +import com.opencsv.exceptions.CsvBadConverterException; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import fr.univlille.sae.classification.knn.MethodKNN; import fr.univlille.sae.classification.knn.distance.Distance; import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; @@ -79,8 +81,10 @@ public class ClassificationModel extends Observable { * Charge les données à partir d'un fichier CSV. * @param file fichier contenant les données à charger. */ - public void loadData(File file) { + public void loadData(File file) throws CsvRequiredFieldEmptyException, CsvBadConverterException { try { + this.dataToClass.clear(); + this.datas = new CsvToBeanBuilder<LoadableData>(Files.newBufferedReader(file.toPath())) .withSeparator(',') .withType(type.getClazz()) @@ -117,7 +121,7 @@ public class ClassificationModel extends Observable { public void classifierDonnee(LoadableData data) { if(dataToClass.get(data) != null && dataToClass.get(data)) return; this.dataToClass.remove(data); - data.setClassification(MethodKNN.estimateClass(datas, data, kOptimal, distance)); + data.setClassification(MethodKNN.estimateClass(datas, data, k, distance)); notifyObservers(data); dataToClass.put(data, true); } @@ -137,6 +141,7 @@ public class ClassificationModel extends Observable { public void setDistance(Distance distance) { this.distance = distance; + this.kOptimal = 0; } public Distance getDistance() { diff --git a/src/main/java/fr/univlille/sae/classification/model/Iris.java b/src/main/java/fr/univlille/sae/classification/model/Iris.java index e716c1062f71cbcec4606f807ba67e670367ed22..49ff5b8834288e96cd703e23baa887be1fcc92d7 100644 --- a/src/main/java/fr/univlille/sae/classification/model/Iris.java +++ b/src/main/java/fr/univlille/sae/classification/model/Iris.java @@ -13,13 +13,13 @@ import java.util.Map; */ public class Iris extends LoadableData { - @CsvBindByName(column = "sepal.length") + @CsvBindByName(column = "sepal.length", required = true) private double sepalLength; - @CsvBindByName(column = "sepal.width") + @CsvBindByName(column = "sepal.width", required = true) private double sepalWidth; - @CsvBindByName(column = "petal.length") + @CsvBindByName(column = "petal.length", required = true) private double petalLength; - @CsvBindByName(column = "petal.width") + @CsvBindByName(column = "petal.width", required = true) private double petalWidth; @CsvBindByName(column = "variety") private String variety; @@ -109,6 +109,7 @@ public class Iris extends LoadableData { return petalLength; } + @Override public double[] getAttributes() { return new double[]{sepalLength, sepalWidth, petalLength, petalWidth} ; @@ -120,23 +121,8 @@ public class Iris extends LoadableData { } - /** - * Renvoie la couleur associée à la variété de l'Iris. - * @return couleur correspondant à la variété. - */ - @Override - public Color getColor() { - switch (this.variety) { - case "Setosa": - return Color.RED; - case "Versicolor": - return Color.BLUE; - case "Virginica": - return Color.GREEN; - default: - return Color.BLACK; // Couleur par défaut si la variété est inconnue - } - } + + /** * Renvoie les noms des attributs de l'Iris. diff --git a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java index 2f08a23358a1a43266f5364065cf800b1f048524..d1b6696c9a3b53870870913fb953e819b3f1a413 100644 --- a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java +++ b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java @@ -2,6 +2,10 @@ package fr.univlille.sae.classification.model; import javafx.scene.paint.Color; +import java.util.Map; +import java.util.HashMap; +import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -12,10 +16,13 @@ public abstract class LoadableData { private static Set<String> classificationTypes; + private static Map<String, Color> classification = new HashMap<>() ; + /** * Constructeur par défaut. */ protected LoadableData() { + } /** @@ -32,14 +39,64 @@ public abstract class LoadableData { return classificationTypes; } + public static Map<String, Color> getClassifications() { + return classification; + } + /** * Définit les types de classification disponibles. * @param classificationTypes ensemble de types de classification à définir. */ public static void setClassificationTypes(Set<String> classificationTypes) { LoadableData.classificationTypes = classificationTypes; + LoadableData.classification.clear(); + int nbOfColors = classificationTypes.size() + 1; + int nb = 0; + for(String s : classificationTypes) { + // Génération de couleurs avec une plage évitant le blanc + + LoadableData.classification.put(s, getColor(nb++, nbOfColors)); + } + + LoadableData.classification.put("undefined", getColor(nb,nbOfColors)); + } + + + private static Color getColor(int nb, int totalColors) { + // Ratio pour répartir les couleurs uniformément + double ratio = (double) nb / (double) totalColors; + + // Utilisation de fonctions trigonométriques pour des transitions douces + double red = 0.5 + 0.4 * Math.sin(2 * Math.PI * ratio); // Oscille entre 0.1 et 0.9 + double green = 0.5 + 0.4 * Math.sin(2 * Math.PI * ratio + Math.PI / 3); // Décalage de phase + double blue = 0.5 + 0.4 * Math.sin(2 * Math.PI * ratio + 2 * Math.PI / 3); // Décalage de phase + + // Réduction de la luminosité pour éviter le blanc et gris clair + double maxComponent = Math.max(red, Math.max(green, blue)); + if (maxComponent > 0.8) { + red *= 0.8 / maxComponent; + green *= 0.8 / maxComponent; + blue *= 0.8 / maxComponent; + } + + // Conversion en objet Color + return Color.color(red, green, blue); + } + + /* private static Color getColor(int i) { + double ratio = (double) i / classificationTypes.size(); + + // Réduire les composantes pour éviter les tons clairs + double red = 0.2 + 0.6 * ratio; // Entre 0.2 et 0.8 + double green = 0.8 - 0.6 * ratio; // Entre 0.8 et 0.2 + double blue = 0.5 + 0.3 * Math.sin(ratio * Math.PI); // Entre 0.5 et 0.8 + + return Color.color(red, green, blue); } + + */ + /** * Définit la classification de l'objet. * @param classification classification à définir. @@ -48,12 +105,6 @@ public abstract class LoadableData { public abstract Map<String, Object> getAttributesNames(); - /** - * Renvoie la couleur associée à l'objet. - * @return couleur correspondant à la classification de l'objet. - */ - public abstract Color getColor(); - public abstract double[] getAttributes(); diff --git a/src/main/java/fr/univlille/sae/classification/model/PointFactory.java b/src/main/java/fr/univlille/sae/classification/model/PointFactory.java index bddf6995eb5e17c5e48b7de573ec1fa00dbac038..27eba5e9a8c138c2bf0ef639b73bfeada563ed68 100644 --- a/src/main/java/fr/univlille/sae/classification/model/PointFactory.java +++ b/src/main/java/fr/univlille/sae/classification/model/PointFactory.java @@ -32,7 +32,7 @@ public class PointFactory { data = new Pokemon(coords); } else if (coords.length == 11) { - data = new Pokemon((String) coords[0], (Integer) coords[1], (Integer) coords[2], (Double) coords[3], (Integer) coords[4], (Integer) coords[5], (Integer) coords[6], (Integer) coords[7], (Integer) coords[8], "", "", (Double) coords[9], (Boolean) coords[10]); + data = new Pokemon((String) coords[0], (Integer) coords[1], (Integer) coords[2], (Double) coords[3], (Integer) coords[4], (Integer) coords[5], (Integer) coords[6], (Integer) coords[7], (Integer) coords[8], "undefined", "", (Double) coords[9], (Boolean) coords[10]); } break; default: diff --git a/src/main/java/fr/univlille/sae/classification/model/Pokemon.java b/src/main/java/fr/univlille/sae/classification/model/Pokemon.java index fb4d7c44c419fcbaf6eb6e9dc9bb2710c614f8f7..7ac5fad96ac847d5199b969014b9c3f852fa4e79 100644 --- a/src/main/java/fr/univlille/sae/classification/model/Pokemon.java +++ b/src/main/java/fr/univlille/sae/classification/model/Pokemon.java @@ -1,13 +1,11 @@ package fr.univlille.sae.classification.model; import com.opencsv.bean.CsvBindByName; -import javafx.scene.paint.Color; - import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; -public class Pokemon extends LoadableData { +public class Pokemon extends LoadableData{ // name,attack,base_egg_steps,capture_rate,defense,experience_growth,hp,sp_attack,sp_defense,type1,type2,speed,is_legendary // Swablu,40,5120,255.0,60,600000,45,75,50,normal,flying,1.2,False @@ -51,7 +49,11 @@ public class Pokemon extends LoadableData { this.hp = hp; this.spAttack = spAttack; this.spDefense = spDefense; - this.type1 = type1; + if(type1 == null || type1.isEmpty()) { + this.type1 = "undefined"; + }else { + this.type1 = type1; + } this.type2 = type2; this.speed = speed; this.isLegendary = isLegendary; @@ -112,54 +114,10 @@ public class Pokemon extends LoadableData { return attrNames; } - /** - * Renvoie la couleur associée à l'objet. - * - * @return couleur correspondant à la classification de l'objet. - */ - @Override - public Color getColor() { - switch (this.type1) { - case "normal": - return Color.LIGHTGREY; - case "grass": - return Color.GREEN; - case "electric": - return Color.YELLOW; - case "bug": - return Color.GREENYELLOW; - case "psychic": - return Color.PLUM; - case "poison": - return Color.PURPLE; - case "steel": - return Color.SILVER; - case "dragon": - return Color.WHITE; - case "flying": - return Color.SKYBLUE; - case "water": - return Color.BLUE; - case "rock": - return Color.SIENNA; - case "fire": - return Color.RED; - case "fairy": - return Color.PINK; - case "fighting": - return Color.FIREBRICK; - case "ice": - return Color.DARKTURQUOISE; - case "ghost": - return Color.DARKMAGENTA; - case "dark": - return Color.GREY; - case "ground": - return Color.KHAKI; - default: - return Color.BLACK; // Couleur par défaut si la variété est inconnue - } - } + + + + @Override public double[] getAttributes() { @@ -169,7 +127,7 @@ public class Pokemon extends LoadableData { @Override public String[] getStringAttributes() { - return new String[]{name, type2, String.valueOf(isLegendary)}; + return new String[]{type2, String.valueOf(isLegendary)}; } @Override diff --git a/src/main/java/fr/univlille/sae/classification/utils/Observer.java b/src/main/java/fr/univlille/sae/classification/utils/Observer.java index e17a1a34c046e7a215b5afb60b6e712240813021..0f7f2081b45365e98f517f4ec7fc366582e3c6cf 100644 --- a/src/main/java/fr/univlille/sae/classification/utils/Observer.java +++ b/src/main/java/fr/univlille/sae/classification/utils/Observer.java @@ -1,9 +1,28 @@ package fr.univlille.sae.classification.utils; -public interface Observer { +/** + * Interface pour implémenter le modèle Observateur. + * Cette interface définit les méthodes que les classes doivent implémenter pour agir + * comme des observateurs dans le cadre du modèle Observateur/Observé. + * Les observateurs sont notifiés des changements d'état des objets observés + * via les méthodes `update'. + */ +public interface Observer { + /** + * Méthode appelée pour notifier l'observateur qu'un changement s'est produit + * dans l'objet observé. + * @param observable l'objet observé qui a subi un changement. + */ void update(Observable observable); + + /** + * Méthode appelée pour notifier l'observateur qu'un changement s'est produit + * dans l'objet observé, avec des données supplémentaires. + * @param observable l'objet observé qui a subi un changement. + * @param data des informations supplémentaires concernant le changement. + */ void update(Observable observable, Object data); diff --git a/src/main/java/fr/univlille/sae/classification/utils/ViewUtil.java b/src/main/java/fr/univlille/sae/classification/utils/ViewUtil.java index b01cc53fa32108c26b3aa4b5db9d95d785b5c069..8613467602476d7429b0085881ba871cfac2c9dd 100644 --- a/src/main/java/fr/univlille/sae/classification/utils/ViewUtil.java +++ b/src/main/java/fr/univlille/sae/classification/utils/ViewUtil.java @@ -2,14 +2,25 @@ package fr.univlille.sae.classification.utils; import fr.univlille.sae.classification.controller.DataStageController; import fr.univlille.sae.classification.controller.MainStageController; +import fr.univlille.sae.classification.model.ClassificationModel; import fr.univlille.sae.classification.model.LoadableData; +import javafx.geometry.Pos; +import javafx.scene.chart.ScatterChart; import javafx.scene.chart.XYChart; import javafx.scene.control.ContextMenu; +import javafx.scene.control.Label; import javafx.scene.control.MenuItem; +import javafx.scene.layout.HBox; +import javafx.scene.layout.VBox; import javafx.scene.paint.Color; +import javafx.scene.shape.Circle; +import javafx.scene.shape.Rectangle; import javafx.scene.shape.Shape; import javafx.stage.Stage; +import java.util.HashMap; +import java.util.Map; + /** * Classe utilitaire pour la gestion des vues. */ @@ -23,7 +34,9 @@ public class ViewUtil { */ public static Shape getForm(LoadableData dataLoaded, Shape form, Object controller) { try { - form.setFill(dataLoaded.getColor()); + Color color = LoadableData.getClassifications().get(dataLoaded.getClassification()); + + form.setFill(color); form.setOnMouseClicked(e -> { if (controller instanceof DataStageController) { @@ -44,4 +57,57 @@ public class ViewUtil { return form; } + + public static VBox loadLegend() { + //Color + + Map<String, Color> colors = new HashMap<>(Map.copyOf(LoadableData.getClassifications())); + Rectangle rectangle = new Rectangle(10, 10); + rectangle.setFill(colors.remove("undefined")); + Label label = new Label("undefined"); + VBox legend = new VBox(); + legend.setAlignment(Pos.CENTER); + HBox line = new HBox(); + line.setSpacing(10); + line.setAlignment(Pos.CENTER); + + HBox tempHBox = new HBox(); + tempHBox.getChildren().addAll(rectangle, label); + line.getChildren().add(tempHBox); + + String[] colorsString = colors.keySet().toArray(new String[0]); + for(int i = 0 ; i < colorsString.length ; i+= 7) { + for(int j = 0 ; i+j < colorsString.length && j < i+7 ; j++) { + if(j%7 == 0 && i != 0 ) { + legend.getChildren().add(line); + line = new HBox(); + line.setSpacing(10); + line.setAlignment(Pos.CENTER); + } + + tempHBox = new HBox(); + label = new Label(colorsString[i+j]); + rectangle = new Rectangle(10, 10); + rectangle.setFill(colors.get(colorsString[i+j])); + tempHBox.getChildren().addAll(rectangle, label); + line.getChildren().add(tempHBox); + + + if(colorsString.length < 7) legend.getChildren().add(line); + +/** + for(String s : colors.keySet()) { + Circle c = new Circle(5); + c.setFill(colors.get(s)); + label = new Label(s); + tempHBox = new HBox(); + tempHBox.getChildren().addAll(c, label); + + hbox.getChildren().add(tempHBox); + } + */ + + return legend; + } + } diff --git a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java index 973c2088541cc0b848e649fe3482c9e6703ef087..68d70c468c480b1ac178c38e901cfddf6f8ee524 100644 --- a/src/main/java/fr/univlille/sae/classification/view/DataStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/DataStageView.java @@ -1,6 +1,7 @@ package fr.univlille.sae.classification.view; import fr.univlille.sae.classification.controller.DataStageController; +import fr.univlille.sae.classification.controller.MainStageController; import fr.univlille.sae.classification.model.ClassificationModel; import fr.univlille.sae.classification.model.DataType; import fr.univlille.sae.classification.model.Iris; @@ -13,6 +14,7 @@ import javafx.fxml.FXMLLoader; import javafx.scene.Node; import javafx.scene.chart.ScatterChart; import javafx.scene.chart.XYChart; +import javafx.scene.layout.HBox; import javafx.scene.shape.Circle; import javafx.scene.shape.Rectangle; import javafx.stage.Stage; @@ -31,10 +33,10 @@ import java.util.Map; */ public class DataStageView extends DataVisualizationView implements Observer { - private ClassificationModel model; - private DataStageController controller; - private Map<String, ScatterChart.Series<Double, Double>> serieList; + + + private XYChart.Series series1; private XYChart.Series series2; @@ -48,9 +50,9 @@ public class DataStageView extends DataVisualizationView implements Observer { * @param model le modèle de classification utilisé pour gérer les données. */ public DataStageView(ClassificationModel model) { - super(); - this.serieList = new HashMap<String, ScatterChart.Series<Double, Double>>(); - this.model = model; + super(model); + + this.series1 = new XYChart.Series(); this.series2 = new XYChart.Series(); this.series3 = new XYChart.Series(); @@ -76,10 +78,11 @@ public class DataStageView extends DataVisualizationView implements Observer { root.setResizable(false); root.setTitle("SAE3.3 - Logiciel de classification"); root.show(); + controller = (MainStageController) controller; controller = loader.getController(); - controller.setDataStageView(this); + ((DataStageController) controller).setDataStageView(this); scatterChart = controller.getScatterChart(); - + scatterChart.setLegendVisible(false); scatterChart.getData().addAll(series4, series1, series2, series3); controller.setAxesSelected("Aucun fichier sélectionné"); @@ -92,120 +95,13 @@ public class DataStageView extends DataVisualizationView implements Observer { } } - /** - * Met à jour l'affichage des données en fonction des changements dans le modèle. - * @param observable modèle observé. - */ - @Override - public void update(Observable observable) { - try { - if (scatterChart == null || !(observable instanceof ClassificationModel)) { - System.err.println("Erreur de mise à jour."); - return; - } - - scatterChart.getData().clear(); - serieList.clear(); - - if (actualX == null && actualY == null) { - controller.setAxesSelected("Aucuns axes sélectionnés"); - } else { - controller.setAxesSelected(""); - controller.setAxesSelectedDisable(); - - List<LoadableData> points = new ArrayList<>(model.getDatas()); - points.addAll(model.getDataToClass().keySet()); - for (LoadableData data : points) { - Object xValue = data.getAttributesNames().get(actualX); - Object yValue = data.getAttributesNames().get(actualY); - - Double x = 0.0; - if (xValue instanceof Integer) { - x = ((Integer) xValue).doubleValue(); - } else if (xValue instanceof Double) { - x = (Double) xValue; - } - - Double y = 0.0; - if (yValue instanceof Integer) { - y = ((Integer) yValue).doubleValue(); - } else if (yValue instanceof Double) { - y = (Double) yValue; - } - - ScatterChart.Data<Double, Double> dataPoint = new ScatterChart.Data<>(x, y); - - Node nodePoint = ViewUtil.getForm(data, new Circle(5), controller); - - ScatterChart.Series<Double, Double> editSerie = serieList.get(data.getClassification()); - if(editSerie == null){ - editSerie = new ScatterChart.Series<Double, Double>(); - } - if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { - nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); - } - dataPoint.setNode(nodePoint); - editSerie.getData().add(dataPoint); - serieList.put(data.getClassification(), editSerie); - } - - for(String serie : serieList.keySet()) { - serieList.get(serie).setName(serie); - } - scatterChart.getData().addAll(serieList.values()); - } - } catch (Exception e) { - System.err.println("Erreur de mise à jour : " + e.getMessage()); - } - } - - /** - * Met à jour l'affichage en ajoutant un nouveau point de données. - * @param observable modèle observé. - * @param data point de données à ajouter. - */ - @Override - public void update(Observable observable, Object data) { - try { - if (scatterChart == null || !(observable instanceof ClassificationModel)) { - System.err.println("Erreur de mise à jour."); - return; - } - LoadableData newData = (LoadableData) data; - if (actualX == null || actualY == null) { - controller.setAxesSelected("Aucuns axes sélectionnés"); - return; - } - Object attrX = newData.getAttributesNames().get(actualX); - Object attrY = newData.getAttributesNames().get(actualY); - if (attrX instanceof Integer) { - attrX = ((Integer) attrX).doubleValue(); - } - if (attrY instanceof Integer) { - attrY = ((Integer) attrY).doubleValue(); - } - XYChart.Data<Double, Double> dataPoint = new XYChart.Data<>( - (Double) attrX, - (Double) attrY - ); - - dataPoint.setNode(ViewUtil.getForm(newData, new Rectangle(10, 10), controller)); - if (!scatterChart.getData().isEmpty()) { - series4.getData().add(dataPoint); - series4.setName("indéfini"); - scatterChart.getData().add(series4); - } - } catch (Exception e) { - System.err.println("Erreur de mise à jour : " + e.getMessage()); - } - } /** * Renvoie le contrôleur associé à cette vue. * @return contrôleur de la vue. */ public DataStageController getController() { - return controller; + return (DataStageController) controller; } /** diff --git a/src/main/java/fr/univlille/sae/classification/view/DataVisualizationView.java b/src/main/java/fr/univlille/sae/classification/view/DataVisualizationView.java index c3ad5af65602dcf8a469e619d393a5955429e562..0ee1a0140fd6e25ea2dba391e677ffcafbe6679a 100644 --- a/src/main/java/fr/univlille/sae/classification/view/DataVisualizationView.java +++ b/src/main/java/fr/univlille/sae/classification/view/DataVisualizationView.java @@ -1,6 +1,23 @@ package fr.univlille.sae.classification.view; +import fr.univlille.sae.classification.controller.DataStageController; +import fr.univlille.sae.classification.controller.DataVisualizationController; +import fr.univlille.sae.classification.controller.MainStageController; +import fr.univlille.sae.classification.model.ClassificationModel; +import fr.univlille.sae.classification.model.LoadableData; +import fr.univlille.sae.classification.utils.Observable; +import fr.univlille.sae.classification.utils.ViewUtil; +import javafx.scene.Node; import javafx.scene.chart.ScatterChart; +import javafx.scene.chart.XYChart; +import javafx.scene.layout.HBox; +import javafx.scene.layout.VBox; +import javafx.scene.shape.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * Classe abstraite représentant une vue de visualisation des données. @@ -8,14 +25,27 @@ import javafx.scene.chart.ScatterChart; */ public abstract class DataVisualizationView { + public DataVisualizationController controller; + private ScatterChart.Series series1; + private ScatterChart.Series series2; + private ScatterChart.Series series3; + private ScatterChart.Series series4; protected String actualX; protected String actualY; protected ScatterChart scatterChart; + + private Map<String, ScatterChart.Series<Double, Double>> serieList; + public ClassificationModel model; /** - * Constructeur par défaut. + * Constructeur pour initialiser la vue de données. + * @param model le modèle de classification utilisé pour gérer les données. */ - protected DataVisualizationView() {} + protected DataVisualizationView(ClassificationModel model) { + this.serieList = new HashMap<String, ScatterChart.Series<Double, Double>>(); + this.model = model; + this.series4 = new XYChart.Series(); + } /** * Renvoie le nom de l'axe X actuel. @@ -61,4 +91,137 @@ public abstract class DataVisualizationView { * Méthode abstraite à implémenter pour recharger les données de la vue. */ public abstract void reload(); + + + /** + * Met à jour l'affichage des données en fonction des changements dans le modèle. + * @param observable modèle observé. + */ + public void update(Observable observable) { + try { + if (scatterChart == null || !(observable instanceof ClassificationModel)) { + System.err.println("Erreur de mise à jour."); + return; + } + + scatterChart.getData().clear(); + serieList.clear(); + + if (actualX == null && actualY == null) { + controller.setAxesSelected("Aucuns axes sélectionnés"); + } else { + controller.setAxesSelected(""); + controller.setAxesSelectedDisable(); + + List<LoadableData> points = new ArrayList<>(model.getDatas()); + points.addAll(model.getDataToClass().keySet()); + for (LoadableData data : points) { + Object xValue = data.getAttributesNames().get(actualX); + Object yValue = data.getAttributesNames().get(actualY); + + + + double x = 0; + if(xValue instanceof Number) { + x = ((Number) xValue).doubleValue(); + } + double y = 0; + if(yValue instanceof Number) { + y = ((Number) yValue).doubleValue(); + } + /** + Double x = 0.0; + if (xValue instanceof Integer) { + x = ((Integer) xValue).doubleValue(); + } else if (xValue instanceof Double) { + x = (Double) xValue; + } + + Double y = 0.0; + if (yValue instanceof Integer) { + y = ((Integer) yValue).doubleValue(); + } else if (yValue instanceof Double) { + y = (Double) yValue; + } + **/ + ScatterChart.Data<Double, Double> dataPoint = new ScatterChart.Data<>(x, y); + + Node nodePoint = ViewUtil.getForm(data, new Circle(5), controller); + + ScatterChart.Series<Double, Double> editSerie = serieList.get(data.getClassification()); + if(editSerie == null){ + editSerie = new ScatterChart.Series<Double, Double>(); + } + if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { + nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); + } + + dataPoint.setNode(nodePoint); + editSerie.getData().add(dataPoint); + serieList.put(data.getClassification(), editSerie); + } + + for(String serie : serieList.keySet()) { + serieList.get(serie).setName(serie); + } + scatterChart.getData().addAll(serieList.values()); + + + VBox vBox = ViewUtil.loadLegend(); + controller.loadLegend(vBox); + } + + + } catch (Exception e) { + System.err.println("Erreur de mise à jour : " + e.getMessage()); + } + } + + /** + * Met à jour l'affichage en ajoutant un nouveau point de données. + * @param observable modèle observé. + * @param data point de données à ajouter. + */ + public void update(Observable observable, Object data) { + try { + if (scatterChart == null || !(observable instanceof ClassificationModel)) { + System.err.println("Erreur de mise à jour."); + return; + } + + + + LoadableData newData = (LoadableData) data; + if (actualX == null || actualY == null) { + controller.setAxesSelected("Aucuns axes sélectionnés"); + return; + } + Object attrX = newData.getAttributesNames().get(actualX); + Object attrY = newData.getAttributesNames().get(actualY); + if (attrX instanceof Integer) { + attrX = ((Integer) attrX).doubleValue(); + } + if (attrY instanceof Integer) { + attrY = ((Integer) attrY).doubleValue(); + } + XYChart.Data<Double, Double> dataPoint = new XYChart.Data<>( + (Double) attrX, + (Double) attrY + ); + + dataPoint.setNode(ViewUtil.getForm(newData, new Rectangle(10, 10), controller)); + if (!scatterChart.getData().isEmpty()) { + series4.getData().add(dataPoint); + series4.setName("indéfini"); + scatterChart.getData().add(series4); + } + + + controller.loadLegend(ViewUtil.loadLegend()); + } catch (Exception e) { + System.err.println("Erreur de mise à jour : " + e.getMessage()); + } + } + + } diff --git a/src/main/java/fr/univlille/sae/classification/view/KNNView.java b/src/main/java/fr/univlille/sae/classification/view/KNNView.java index 31cd2982be04ff01e20ba30e7d61a17a087b39d2..5df40577e89bfd2bda585d59c9f94ab00fe2e2cb 100644 --- a/src/main/java/fr/univlille/sae/classification/view/KNNView.java +++ b/src/main/java/fr/univlille/sae/classification/view/KNNView.java @@ -26,6 +26,9 @@ public class KNNView { this.owner = owner; } + /** + * + */ public void show() { FXMLLoader loader = new FXMLLoader(); URL fxmlFileUrl = getClass().getClassLoader().getResource("stages"+ File.separator+"k-NN-stage.fxml"); diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java index d87d6d9deccfdc8aa22aa0150bee66b6f69d0817..36661733bc711b9951119c79df26b58938cdbedb 100644 --- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java +++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java @@ -8,9 +8,12 @@ import fr.univlille.sae.classification.utils.Observer; import fr.univlille.sae.classification.utils.ViewUtil; import javafx.fxml.FXMLLoader; import javafx.scene.Node; +import javafx.scene.chart.NumberAxis; import javafx.scene.chart.ScatterChart; import javafx.scene.chart.XYChart; import javafx.scene.control.*; +import javafx.scene.layout.HBox; +import javafx.scene.layout.VBox; import javafx.scene.shape.*; import javafx.stage.Stage; @@ -24,12 +27,9 @@ import java.util.*; */ public class MainStageView extends DataVisualizationView implements Observer { - private ClassificationModel model; - private MainStageController controller; private Stage root; - private Map<String, ScatterChart.Series<Double, Double>> serieList; private ScatterChart.Series series1; private ScatterChart.Series series2; @@ -38,16 +38,17 @@ public class MainStageView extends DataVisualizationView implements Observer { /** * Constructeur de la vue principale. + * * @param model modèle de classification à utiliser. */ public MainStageView(ClassificationModel model) { - super(); - this.serieList = new HashMap<String, ScatterChart.Series<Double, Double>>(); + super(model); + this.series1 = new ScatterChart.Series(); this.series2 = new ScatterChart.Series(); this.series3 = new ScatterChart.Series(); this.series4 = new ScatterChart.Series(); - this.model = model; + model.attach(this); } @@ -85,8 +86,10 @@ public class MainStageView extends DataVisualizationView implements Observer { } }); + + controller = (MainStageController) controller; controller = loader.getController(); - controller.setMainStageView(this); + ((MainStageController) controller).setMainStageView(this); scatterChart = controller.getScatterChart(); //scatterChart.getData().addAll(series1, series2, series3, series4); controller.setAxesSelected("Aucun fichier sélectionné"); @@ -96,115 +99,23 @@ public class MainStageView extends DataVisualizationView implements Observer { } } - @Override - public void update(Observable observable) { - try { - if (scatterChart == null || !(observable instanceof ClassificationModel)) { - System.err.println("Erreur de mise à jour."); - return; - } - - scatterChart.getData().clear(); - serieList.clear(); - - if (actualX == null && actualY == null) { - controller.setAxesSelected("Aucuns axes sélectionnés"); - } else { - controller.setAxesSelected(""); - controller.setAxesSelectedDisable(); - - List<LoadableData> points = new ArrayList<>(model.getDatas()); - points.addAll(model.getDataToClass().keySet()); - for (LoadableData data : points) { - Object xValue = data.getAttributesNames().get(actualX); - Object yValue = data.getAttributesNames().get(actualY); - - Double x = 0.0; - if (xValue instanceof Integer) { - x = ((Integer) xValue).doubleValue(); - } else if (xValue instanceof Double) { - x = (Double) xValue; - } - - Double y = 0.0; - if (yValue instanceof Integer) { - y = ((Integer) yValue).doubleValue(); - } else if (yValue instanceof Double) { - y = (Double) yValue; - } - - ScatterChart.Data<Double, Double> dataPoint = new ScatterChart.Data<>(x, y); - - Node nodePoint = ViewUtil.getForm(data, new Circle(5), controller); - - ScatterChart.Series<Double, Double> editSerie = serieList.get(data.getClassification()); - if(editSerie == null){ - editSerie = new ScatterChart.Series<Double, Double>(); - } - if(data.getClassification().equals("undefined") || model.getDataToClass().containsKey(data)) { - nodePoint = ViewUtil.getForm(data, new Rectangle(10,10), controller); - } - - dataPoint.setNode(nodePoint); - editSerie.getData().add(dataPoint); - serieList.put(data.getClassification(), editSerie); - } - - for(String serie : serieList.keySet()) { - serieList.get(serie).setName(serie); - } - scatterChart.getData().addAll(serieList.values()); - } - } catch (Exception e) { - System.err.println("Erreur de mise à jour : " + e.getMessage()); - } - } - - @Override - public void update(Observable observable, Object data) { - try { - if (scatterChart == null || !(observable instanceof ClassificationModel)) { - System.err.println("Erreur de mise à jour."); - return; - } - - - - LoadableData newData = (LoadableData) data; - if (actualX == null || actualY == null) { - controller.setAxesSelected("Aucuns axes sélectionnés"); - return; - } - Object attrX = newData.getAttributesNames().get(actualX); - Object attrY = newData.getAttributesNames().get(actualY); - if (attrX instanceof Integer) { - attrX = ((Integer) attrX).doubleValue(); - } - if (attrY instanceof Integer) { - attrY = ((Integer) attrY).doubleValue(); - } - XYChart.Data<Double, Double> dataPoint = new XYChart.Data<>( - (Double) attrX, - (Double) attrY - ); - - dataPoint.setNode(ViewUtil.getForm(newData, new Rectangle(10, 10), controller)); - if (!scatterChart.getData().isEmpty()) { - series4.getData().add(dataPoint); - series4.setName("indéfini"); - scatterChart.getData().add(series4); - } - } catch (Exception e) { - System.err.println("Erreur de mise à jour : " + e.getMessage()); - } - } + /** + * Retourne le contrôleur principal de la scène. + * + * @return le contrôleur principal de la scène en tant qu'instance. + */ public MainStageController getController() { - return controller; + return (MainStageController) controller; } + /** + * Recharge les données nécessaires à partir du modèle de classification. + * Cette méthode met à jour l'état en fonction des données actuelles + */ @Override public void reload() { this.update(ClassificationModel.getClassificationModel()); } } + diff --git a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java index e0873a5f3a78475d65d874d2e7d40ab6788a9f0a..824be5f4bcf220de6165f03d342622ef2316bb37 100644 --- a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java +++ b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java @@ -40,7 +40,7 @@ class IrisTest { @Test void getColor() { - assertEquals(Color.RED, iris.getColor()); + assertEquals(Color.RED, iris.getClassifications().get(iris.getClassification())); } @Test