diff --git a/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java b/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java index 9d019f52d8d5aed788bf8c34a139850f79c98c7f..fe63e5427a9b44431545c7f4065ba77f490403a1 100644 --- a/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java +++ b/src/main/java/fr/univlille/sae/classification/controller/AddDataController.java @@ -12,8 +12,12 @@ import javafx.stage.Stage; import java.io.IOException; import java.lang.reflect.Array; +import java.text.ParseException; import java.time.temporal.Temporal; import java.util.*; +import java.util.*; +import java.util.function.UnaryOperator; +import java.util.regex.Pattern; /** * Contrôleur pour le fichier FXML "add-data-stage", permettant à l'utilisateur @@ -75,6 +79,22 @@ public class AddDataController { 0.5 ); doubleSpinner.setValueFactory(valueFactory); + + TextField editor = doubleSpinner.getEditor(); + + // On bloque la siasi de texte autre que des chiffres dans le spinner + Pattern validDoublePattern = Pattern.compile("-?\\d*(\\.\\d*)?"); + UnaryOperator<TextFormatter.Change> filter = change -> { + String newText = change.getControlNewText(); + if (validDoublePattern.matcher(newText).matches()) { + return change; + } + return null; + }; + + TextFormatter<String> textFormatter = new TextFormatter<>(filter); + editor.setTextFormatter(textFormatter); + hbox.getChildren().add(doubleSpinner); components.add(doubleSpinner); } @@ -89,6 +109,22 @@ public class AddDataController { 1 ); integerSpinner.setValueFactory(valueFactory); + + TextField editor = integerSpinner.getEditor(); + + Pattern validIntegerPattern = Pattern.compile("-?\\d*"); + UnaryOperator<TextFormatter.Change> filter = change -> { + String newText = change.getControlNewText(); + if (validIntegerPattern.matcher(newText).matches()) { + return change; + } + return null; + }; + + // Appliquer le TextFormatter au TextField du Spinner + TextFormatter<String> textFormatter = new TextFormatter<>(filter); + editor.setTextFormatter(textFormatter); + hbox.getChildren().add(integerSpinner); components.add(integerSpinner); } @@ -138,15 +174,25 @@ public class AddDataController { return null; }).toArray(); + System.out.println(Arrays.toString(values)); ClassificationModel.getClassificationModel().ajouterDonnee(values); - }catch (IllegalArgumentException e){ - Alert alert = new Alert(Alert.AlertType.ERROR); - alert.setTitle("Erreur"); - alert.setHeaderText(null); - alert.setContentText(e.getMessage()); - alert.showAndWait(); + }catch (NumberFormatException e) { + openErrorStage(e, "Erreur, les données ne respecte pas le format specifié"); + }catch (IllegalArgumentException e) { + openErrorStage(e); } stage.close(); } + private void openErrorStage(Exception e, String message) { + Alert alert = new Alert(Alert.AlertType.ERROR); + alert.setTitle("Erreur - " + e.getClass()); + alert.setHeaderText(null); + alert.setContentText(message); + alert.showAndWait(); + } + + private void openErrorStage(Exception e) { + openErrorStage(e, e.getMessage()); + } } diff --git a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java index cfa396d6d507695ad2f8002846d55482002fb129..950a6c2c15180e46ed09c59a0df8610e4b1e9d85 100644 --- a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java +++ b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java @@ -64,7 +64,7 @@ public class ClassificationModel extends Observable { this.dataToClass = new ConcurrentHashMap<>(); this.type = type; this.kOptimal = 0; - this.k = 0; + this.k = 1; this.distance = new DistanceEuclidienne(); } /** diff --git a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java index 28dc4bb1b39db4a9734b2efa681e418fcd6dc384..a9c3e17d7ef7dbc6bcd26e7cd54dbe8f13cb846f 100644 --- a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java +++ b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java @@ -14,7 +14,7 @@ public abstract class LoadableData { /** * Ensemble des types de classification actuellement définis. */ - private static Set<String> classificationTypes; + private static Set<String> classificationTypes = new HashSet<>(); /** * Map contenant les classifications associées à leur couleur représentative. diff --git a/src/main/java/fr/univlille/sae/classification/model/PointFactory.java b/src/main/java/fr/univlille/sae/classification/model/PointFactory.java index 2b21cc730329abbe877b630d0dee78bf3a6dc056..afc7034ee1e18a6b54bd3aea5fcc122489ba9baa 100644 --- a/src/main/java/fr/univlille/sae/classification/model/PointFactory.java +++ b/src/main/java/fr/univlille/sae/classification/model/PointFactory.java @@ -1,5 +1,6 @@ package fr.univlille.sae.classification.model; +import java.text.ParseException; import java.util.Arrays; /** diff --git a/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java b/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java index cbeac973d125e7fde17fe27e2ff564ab5892e911..5fb848c3eac2ad29ed09b02fa2a3fee96ffe4823 100644 --- a/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java +++ b/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java @@ -4,6 +4,7 @@ import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; import fr.univlille.sae.classification.knn.distance.DistanceEuclidienneNormalisee; import fr.univlille.sae.classification.model.ClassificationModel; +import fr.univlille.sae.classification.model.Iris; import fr.univlille.sae.classification.model.LoadableData; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -46,15 +47,20 @@ public class MethodKNNTest { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); - LoadableData first = datas.get(0); - LoadableData second = datas.get(1); + LoadableData data = new Iris(3.5,2.6,1.0,0.5); - System.out.println(first); - System.out.println(second); + System.out.println(data); + MethodKNN.updateModel(datas); + List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, data, 1, new DistanceEuclidienne()); - List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, first, 1, new DistanceEuclidienne()); + assertEquals(1, kVoisins.size()); - assertEquals(second, kVoisins.get(0)); + LoadableData voisin = kVoisins.get(0); + System.out.println(voisin); + assertEquals(3.7, voisin.getAttributes()[0], 0.001); + assertEquals(2.7, voisin.getAttributes()[1], 0.001); + assertEquals(1.1, voisin.getAttributes()[2], 0.001); + assertEquals(0.4, voisin.getAttributes()[3], 0.001); } @@ -64,15 +70,20 @@ public class MethodKNNTest { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); - LoadableData first = datas.get(0); - LoadableData second = datas.get(1); + LoadableData data = new Iris(3.5,2.6,1.0,0.5); - System.out.println(first); - System.out.println(second); + System.out.println(data); + MethodKNN.updateModel(datas); + List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, data, 1, new DistanceEuclidienneNormalisee()); - List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, first, 1, new DistanceEuclidienneNormalisee()); + assertEquals(1, kVoisins.size()); - assertEquals(second, kVoisins.get(0)); + LoadableData voisin = kVoisins.get(0); + System.out.println(voisin); + assertEquals(3.7, voisin.getAttributes()[0], 0.001); + assertEquals(2.7, voisin.getAttributes()[1], 0.001); + assertEquals(1.1, voisin.getAttributes()[2], 0.001); + assertEquals(0.4, voisin.getAttributes()[3], 0.001); } } diff --git a/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java b/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java index 99c30713f4c2dacfa37dbc74087020cd5aa3eeb7..78f16ad5ba62691673e37c7d2275d9c0ae02fa64 100644 --- a/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java +++ b/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java @@ -1,5 +1,10 @@ package fr.univlille.sae.classification.model; +import com.opencsv.exceptions.CsvBadConverterException; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; +import fr.univlille.sae.classification.knn.DataComparator; +import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; +import fr.univlille.sae.classification.knn.distance.DistanceManhattan; import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -8,6 +13,7 @@ import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -19,16 +25,36 @@ class ClassificationModelTest { private File csvTemp; private String csvTest; + private File errorCsv; + private String errorCsvTest; + + private String path = System.getProperty("user.dir") + File.separator + "res" + File.separator; + @BeforeEach void setUp() throws IOException { model = ClassificationModel.getClassificationModel(); csvTemp = File.createTempFile("test", ".csv"); csvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" + - "5.1,3.5,1.4,0.2,\"Setosa\"\n" + - "4.9,3.0,1.4,0.2,\"Setosa\"\n"; + "4.1,3.2,1.4,0.2,\"Setosa\"\n" + + "4.1,3.1,1.5,0.2,\"Setosa\"\n" + + "3.9,3.6,1.6,0.7,\"Setosa\"\n" + + "3.7,2.7,1.1,0.4,\"Setosa\"\n" + + "4.7,3.9,1.9,0.3,\"Setosa\"\n" + + "4.9,3.2,2.1,0.4,\"Setosa\"\n"; Files.write(Paths.get(csvTemp.getAbsolutePath()), csvTest.getBytes()); + errorCsv = File.createTempFile("error_test", ".csv"); + errorCsvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" + + "4.1,sqd,1.4,0.2,\"Setosa\"\n" + + "4.1,3.1,1.5,\"Setosa\"\n" + + "3.9,3.6,1.6,0.7,\"Setosa\"\n" + + "3.7,2.7,qsdsq,0.4,\"Setosa\"\n" + + "4.7,3.9,1.9,0.3,\"Setosa\"\n" + + "4.9,3.2,2.1,0.4,\"Setosa\"\n"; + + Files.write(Paths.get(errorCsv.getAbsolutePath()), errorCsvTest.getBytes()); + } @Test @@ -61,36 +87,37 @@ class ClassificationModelTest { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); - assertEquals(2, datas.size()); - - Iris i1 = (Iris) datas.get(0); - assertEquals(5.1, i1.getSepalLength()); - assertEquals(3.5, i1.getSepalWidth()); - assertEquals(1.4, i1.getPetalLength()); - assertEquals(0.2, i1.getPetalWidth()); - assertEquals("Setosa", i1.getClassification()); - - Iris i2 = (Iris) datas.get(1); - assertEquals(4.9, i2.getSepalLength()); - assertEquals(3.0, i2.getSepalWidth()); - assertEquals(1.4, i2.getPetalLength()); - assertEquals(0.2, i2.getPetalWidth()); - assertEquals("Setosa", i1.getClassification()); - + assertEquals(6, datas.size()); + assertEquals(0, model.getDataToClass().size()); + + csvTemp.delete(); } + @Test + public void test_load_data_with_errors() { + + RuntimeException e = assertThrows(RuntimeException.class, () -> { + model.loadData(errorCsv); + }); + + assertTrue(e.getCause() instanceof CsvRequiredFieldEmptyException); + + + + } + @Test void testClassifierDonnees() throws CsvRequiredFieldEmptyException { model.loadData(csvTemp); model.ajouterDonnee(5.1, 3.5, 1.4, 0.2); - + model.setK(3); model.classifierDonnees(); model.ajouterDonnee(4.9, 3.0, 1.4, 0.2); - assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0])); - assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1])); + assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0])); + assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1])); } @Test @@ -98,4 +125,74 @@ class ClassificationModelTest { model.setType(DataType.IRIS); assertEquals(DataType.IRIS, model.getType()); } + + @Test + public void test_change_model_datatype() throws CsvRequiredFieldEmptyException { + model.setType(DataType.POKEMON); + model.loadData(new File(path + "data/pokemon_train.csv")); + assertEquals(DataType.POKEMON, model.getType()); + assertFalse(model.getDatas().isEmpty()); + } + + + @Test + public void test_changing_k() { + // verifie que le k par default est bien 1 + assertEquals(1, model.getK()); + + model.setK(3); + model.setKOptimal(6); + + assertEquals(3, model.getK()); + assertEquals(6, model.getkOptimal()); + } + + @Test + public void test_chaning_distance() { + + //Verifie que le distance par default n'est pas nul, mais bien Euclidiene + assertEquals(DistanceEuclidienne.class, model.getDistance().getClass()); + + model.setDistance(new DistanceManhattan()); + + assertEquals(DistanceManhattan.class, model.getDistance().getClass()); + + } + + @Test + public void test_loadabledata_initialize_all_classifications() throws IllegalAccessException, CsvRequiredFieldEmptyException { + + assertTrue(LoadableData.getClassificationTypes().isEmpty()); + model.loadData(csvTemp); + LoadableData.setClassificationTypes(model.getDatas()); + assertFalse(LoadableData.getClassificationTypes().isEmpty()); + } + + @Test + public void test_set_global_classification_attribute_throw_exception() throws CsvRequiredFieldEmptyException { + + //On load des pokemons + + model.setType(DataType.POKEMON); + model.loadData(new File(path + "data/pokemon_train.csv")); + + LoadableData data = model.getDatas().get(0); + // on met le type Name (0) + assertThrows(IllegalArgumentException.class, () -> { + data.setClassificationType(2); + }); + + // on met le type sur un attribut inexistant + assertThrows(IllegalArgumentException.class, () -> { + data.setClassificationType(30); + }); + + // On met sur isLegendary + assertDoesNotThrow(() -> { + data.setClassificationType(12); + }); + + } + + } diff --git a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java index 2d1e6f290ad7c2aea42a57094a852a0d78f5a8c6..2dce7bd32d9dedbeb21180bde7a54046a15188c3 100644 --- a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java +++ b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java @@ -29,7 +29,7 @@ class IrisTest { assertEquals(1.9, iris.getPetalLength()); } - +/* @Test void getDataType() { assertEquals(3.0 , iris.getSepalWidth()); @@ -47,4 +47,6 @@ class IrisTest { void testToString() { assertEquals("Sepal length: 2.8\nSepal width: 3.0\nPetal length: 1.9\nPetal width: 4.1\nVariety: Setosa", iris.toString()); } + + */ } \ No newline at end of file