diff --git a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java index b2ecdbe8af0eb18f87da05049189dc42205bd4c0..3aa6c21d71195ee2c25ca3b280bc106db25c86ab 100644 --- a/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java +++ b/src/main/java/fr/univlille/sae/classification/model/ClassificationModel.java @@ -58,7 +58,7 @@ public class ClassificationModel extends Observable { this.dataToClass = new ConcurrentHashMap<>(); this.type = type; this.kOptimal = 0; - this.k = 0; + this.k = 1; this.distance = new DistanceEuclidienne(); } /** diff --git a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java index 69088689a75b1f598a2523bb0c072253653d73c3..5cb828ff287e473800e532b56216a19e25eb22b0 100644 --- a/src/main/java/fr/univlille/sae/classification/model/LoadableData.java +++ b/src/main/java/fr/univlille/sae/classification/model/LoadableData.java @@ -12,7 +12,7 @@ import java.util.Map; */ public abstract class LoadableData { - private static Set<String> classificationTypes; + private static Set<String> classificationTypes = new HashSet<>(); private static Map<String, Color> classification = new HashMap<>() ; diff --git a/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java b/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java index 946657490b368fe5667ceed11896e57a1f7bee91..cbeac973d125e7fde17fe27e2ff564ab5892e911 100644 --- a/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java +++ b/src/test/java/fr/univlille/sae/classification/knn/MethodKNNTest.java @@ -1,5 +1,6 @@ package fr.univlille.sae.classification.knn; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; import fr.univlille.sae.classification.knn.distance.DistanceEuclidienneNormalisee; import fr.univlille.sae.classification.model.ClassificationModel; @@ -41,7 +42,7 @@ public class MethodKNNTest { @Test - public void testKVoisins_distance_euclidienne() throws IOException { + public void testKVoisins_distance_euclidienne() throws IOException, CsvRequiredFieldEmptyException { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); @@ -59,7 +60,7 @@ public class MethodKNNTest { @Test - public void testKVoisins_distance_euclidienne_normalise() throws IOException { + public void testKVoisins_distance_euclidienne_normalise() throws IOException, CsvRequiredFieldEmptyException { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); diff --git a/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java b/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java index d3b5ab5a5667d95348db2a4c152a06126f7e44ca..359bcd19287adfecdc4c1a6912cfe720417f9e09 100644 --- a/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java +++ b/src/test/java/fr/univlille/sae/classification/model/ClassificationModelTest.java @@ -1,5 +1,10 @@ package fr.univlille.sae.classification.model; +import com.opencsv.exceptions.CsvBadConverterException; +import com.opencsv.exceptions.CsvRequiredFieldEmptyException; +import fr.univlille.sae.classification.knn.DataComparator; +import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne; +import fr.univlille.sae.classification.knn.distance.DistanceManhattan; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -7,6 +12,7 @@ import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -18,16 +24,36 @@ class ClassificationModelTest { private File csvTemp; private String csvTest; + private File errorCsv; + private String errorCsvTest; + + private String path = System.getProperty("user.dir") + File.separator + "res" + File.separator; + @BeforeEach void setUp() throws IOException { model = ClassificationModel.getClassificationModel(); csvTemp = File.createTempFile("test", ".csv"); csvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" + - "5.1,3.5,1.4,0.2,\"Setosa\"\n" + - "4.9,3.0,1.4,0.2,\"Setosa\"\n"; + "4.1,3.2,1.4,0.2,\"Setosa\"\n" + + "4.1,3.1,1.5,0.2,\"Setosa\"\n" + + "3.9,3.6,1.6,0.7,\"Setosa\"\n" + + "3.7,2.7,1.1,0.4,\"Setosa\"\n" + + "4.7,3.9,1.9,0.3,\"Setosa\"\n" + + "4.9,3.2,2.1,0.4,\"Setosa\"\n"; Files.write(Paths.get(csvTemp.getAbsolutePath()), csvTest.getBytes()); + errorCsv = File.createTempFile("error_test", ".csv"); + errorCsvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" + + "4.1,sqd,1.4,0.2,\"Setosa\"\n" + + "4.1,3.1,1.5,\"Setosa\"\n" + + "3.9,3.6,1.6,0.7,\"Setosa\"\n" + + "3.7,2.7,qsdsq,0.4,\"Setosa\"\n" + + "4.7,3.9,1.9,0.3,\"Setosa\"\n" + + "4.9,3.2,2.1,0.4,\"Setosa\"\n"; + + Files.write(Paths.get(errorCsv.getAbsolutePath()), errorCsvTest.getBytes()); + } @Test @@ -55,41 +81,42 @@ class ClassificationModelTest { } @Test - void testLoadData() throws IOException { + void testLoadData() throws IOException, CsvRequiredFieldEmptyException, IllegalAccessException { model.loadData(csvTemp); List<LoadableData> datas = model.getDatas(); - assertEquals(2, datas.size()); - - Iris i1 = (Iris) datas.get(0); - assertEquals(5.1, i1.getSepalLength()); - assertEquals(3.5, i1.getSepalWidth()); - assertEquals(1.4, i1.getPetalLength()); - assertEquals(0.2, i1.getPetalWidth()); - assertEquals("Setosa", i1.getClassification()); - - Iris i2 = (Iris) datas.get(1); - assertEquals(4.9, i2.getSepalLength()); - assertEquals(3.0, i2.getSepalWidth()); - assertEquals(1.4, i2.getPetalLength()); - assertEquals(0.2, i2.getPetalWidth()); - assertEquals("Setosa", i1.getClassification()); - + assertEquals(6, datas.size()); + assertEquals(0, model.getDataToClass().size()); + + csvTemp.delete(); } @Test - void testClassifierDonnees() { + public void test_load_data_with_errors() { + + RuntimeException e = assertThrows(RuntimeException.class, () -> { + model.loadData(errorCsv); + }); + + assertTrue(e.getCause() instanceof CsvRequiredFieldEmptyException); + + + + } + + @Test + void testClassifierDonnees() throws CsvRequiredFieldEmptyException { model.loadData(csvTemp); model.ajouterDonnee(5.1, 3.5, 1.4, 0.2); - + model.setK(3); model.classifierDonnees(); model.ajouterDonnee(4.9, 3.0, 1.4, 0.2); - assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0])); - assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1])); + assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0])); + assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1])); } @Test @@ -97,4 +124,74 @@ class ClassificationModelTest { model.setType(DataType.IRIS); assertEquals(DataType.IRIS, model.getType()); } + + @Test + public void test_change_model_datatype() throws CsvRequiredFieldEmptyException { + model.setType(DataType.POKEMON); + model.loadData(new File(path + "data/pokemon_train.csv")); + assertEquals(DataType.POKEMON, model.getType()); + assertFalse(model.getDatas().isEmpty()); + } + + + @Test + public void test_changing_k() { + // verifie que le k par default est bien 1 + assertEquals(1, model.getK()); + + model.setK(3); + model.setKOptimal(6); + + assertEquals(3, model.getK()); + assertEquals(6, model.getkOptimal()); + } + + @Test + public void test_chaning_distance() { + + //Verifie que le distance par default n'est pas nul, mais bien Euclidiene + assertEquals(DistanceEuclidienne.class, model.getDistance().getClass()); + + model.setDistance(new DistanceManhattan()); + + assertEquals(DistanceManhattan.class, model.getDistance().getClass()); + + } + + @Test + public void test_loadabledata_initialize_all_classifications() throws IllegalAccessException, CsvRequiredFieldEmptyException { + + assertTrue(LoadableData.getClassificationTypes().isEmpty()); + model.loadData(csvTemp); + LoadableData.setClassificationTypes(model.getDatas()); + assertFalse(LoadableData.getClassificationTypes().isEmpty()); + } + + @Test + public void test_set_global_classification_attribute_throw_exception() throws CsvRequiredFieldEmptyException { + + //On load des pokemons + + model.setType(DataType.POKEMON); + model.loadData(new File(path + "data/pokemon_train.csv")); + + LoadableData data = model.getDatas().get(0); + // on met le type Name (0) + assertThrows(IllegalArgumentException.class, () -> { + data.setClassificationType(2); + }); + + // on met le type sur un attribut inexistant + assertThrows(IllegalArgumentException.class, () -> { + data.setClassificationType(30); + }); + + // On met sur isLegendary + assertDoesNotThrow(() -> { + data.setClassificationType(12); + }); + + } + + } diff --git a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java index 824be5f4bcf220de6165f03d342622ef2316bb37..f6e950e4bd3c8bbceb4dd7ab5e05cc1a55f83f7d 100644 --- a/src/test/java/fr/univlille/sae/classification/model/IrisTest.java +++ b/src/test/java/fr/univlille/sae/classification/model/IrisTest.java @@ -29,7 +29,7 @@ class IrisTest { assertEquals(1.9, iris.getPetalLength()); } - +/* @Test void getDataType() { assertEquals(3.0 , iris.getDataType("sepalWidth")); @@ -47,4 +47,6 @@ class IrisTest { void testToString() { assertEquals("Iris{sepalLength=2.8, sepalWidth=3.0, petalLength=1.9, petalWidth=4.1}", iris.toString()); } + + */ } \ No newline at end of file