Skip to content
Snippets Groups Projects
Commit 4b807b2e authored by Matias Mennecart's avatar Matias Mennecart
Browse files

add and edit some tests

parent 02d29a97
Branches
Tags
No related merge requests found
......@@ -58,7 +58,7 @@ public class ClassificationModel extends Observable {
this.dataToClass = new ConcurrentHashMap<>();
this.type = type;
this.kOptimal = 0;
this.k = 0;
this.k = 1;
this.distance = new DistanceEuclidienne();
}
/**
......
......@@ -12,7 +12,7 @@ import java.util.Map;
*/
public abstract class LoadableData {
private static Set<String> classificationTypes;
private static Set<String> classificationTypes = new HashSet<>();
private static Map<String, Color> classification = new HashMap<>() ;
......
package fr.univlille.sae.classification.knn;
import com.opencsv.exceptions.CsvRequiredFieldEmptyException;
import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne;
import fr.univlille.sae.classification.knn.distance.DistanceEuclidienneNormalisee;
import fr.univlille.sae.classification.model.ClassificationModel;
......@@ -41,7 +42,7 @@ public class MethodKNNTest {
@Test
public void testKVoisins_distance_euclidienne() throws IOException {
public void testKVoisins_distance_euclidienne() throws IOException, CsvRequiredFieldEmptyException {
model.loadData(csvTemp);
List<LoadableData> datas = model.getDatas();
......@@ -59,7 +60,7 @@ public class MethodKNNTest {
@Test
public void testKVoisins_distance_euclidienne_normalise() throws IOException {
public void testKVoisins_distance_euclidienne_normalise() throws IOException, CsvRequiredFieldEmptyException {
model.loadData(csvTemp);
List<LoadableData> datas = model.getDatas();
......
package fr.univlille.sae.classification.model;
import com.opencsv.exceptions.CsvBadConverterException;
import com.opencsv.exceptions.CsvRequiredFieldEmptyException;
import fr.univlille.sae.classification.knn.DataComparator;
import fr.univlille.sae.classification.knn.distance.DistanceEuclidienne;
import fr.univlille.sae.classification.knn.distance.DistanceManhattan;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
......@@ -7,6 +12,7 @@ import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import java.util.Map;
......@@ -18,16 +24,36 @@ class ClassificationModelTest {
private File csvTemp;
private String csvTest;
private File errorCsv;
private String errorCsvTest;
private String path = System.getProperty("user.dir") + File.separator + "res" + File.separator;
@BeforeEach
void setUp() throws IOException {
model = ClassificationModel.getClassificationModel();
csvTemp = File.createTempFile("test", ".csv");
csvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" +
"5.1,3.5,1.4,0.2,\"Setosa\"\n" +
"4.9,3.0,1.4,0.2,\"Setosa\"\n";
"4.1,3.2,1.4,0.2,\"Setosa\"\n" +
"4.1,3.1,1.5,0.2,\"Setosa\"\n" +
"3.9,3.6,1.6,0.7,\"Setosa\"\n" +
"3.7,2.7,1.1,0.4,\"Setosa\"\n" +
"4.7,3.9,1.9,0.3,\"Setosa\"\n" +
"4.9,3.2,2.1,0.4,\"Setosa\"\n";
Files.write(Paths.get(csvTemp.getAbsolutePath()), csvTest.getBytes());
errorCsv = File.createTempFile("error_test", ".csv");
errorCsvTest = "\"sepal.length\",\"sepal.width\",\"petal.length\",\"petal.width\",\"variety\"\n" +
"4.1,sqd,1.4,0.2,\"Setosa\"\n" +
"4.1,3.1,1.5,\"Setosa\"\n" +
"3.9,3.6,1.6,0.7,\"Setosa\"\n" +
"3.7,2.7,qsdsq,0.4,\"Setosa\"\n" +
"4.7,3.9,1.9,0.3,\"Setosa\"\n" +
"4.9,3.2,2.1,0.4,\"Setosa\"\n";
Files.write(Paths.get(errorCsv.getAbsolutePath()), errorCsvTest.getBytes());
}
@Test
......@@ -55,41 +81,42 @@ class ClassificationModelTest {
}
@Test
void testLoadData() throws IOException {
void testLoadData() throws IOException, CsvRequiredFieldEmptyException, IllegalAccessException {
model.loadData(csvTemp);
List<LoadableData> datas = model.getDatas();
assertEquals(2, datas.size());
Iris i1 = (Iris) datas.get(0);
assertEquals(5.1, i1.getSepalLength());
assertEquals(3.5, i1.getSepalWidth());
assertEquals(1.4, i1.getPetalLength());
assertEquals(0.2, i1.getPetalWidth());
assertEquals("Setosa", i1.getClassification());
Iris i2 = (Iris) datas.get(1);
assertEquals(4.9, i2.getSepalLength());
assertEquals(3.0, i2.getSepalWidth());
assertEquals(1.4, i2.getPetalLength());
assertEquals(0.2, i2.getPetalWidth());
assertEquals("Setosa", i1.getClassification());
assertEquals(6, datas.size());
assertEquals(0, model.getDataToClass().size());
csvTemp.delete();
}
@Test
void testClassifierDonnees() {
public void test_load_data_with_errors() {
RuntimeException e = assertThrows(RuntimeException.class, () -> {
model.loadData(errorCsv);
});
assertTrue(e.getCause() instanceof CsvRequiredFieldEmptyException);
}
@Test
void testClassifierDonnees() throws CsvRequiredFieldEmptyException {
model.loadData(csvTemp);
model.ajouterDonnee(5.1, 3.5, 1.4, 0.2);
model.setK(3);
model.classifierDonnees();
model.ajouterDonnee(4.9, 3.0, 1.4, 0.2);
assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0]));
assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1]));
assertEquals(true, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[0]));
assertEquals(false, model.getDataToClass().get(model.getDataToClass().keySet().toArray()[1]));
}
@Test
......@@ -97,4 +124,74 @@ class ClassificationModelTest {
model.setType(DataType.IRIS);
assertEquals(DataType.IRIS, model.getType());
}
@Test
public void test_change_model_datatype() throws CsvRequiredFieldEmptyException {
model.setType(DataType.POKEMON);
model.loadData(new File(path + "data/pokemon_train.csv"));
assertEquals(DataType.POKEMON, model.getType());
assertFalse(model.getDatas().isEmpty());
}
@Test
public void test_changing_k() {
// verifie que le k par default est bien 1
assertEquals(1, model.getK());
model.setK(3);
model.setKOptimal(6);
assertEquals(3, model.getK());
assertEquals(6, model.getkOptimal());
}
@Test
public void test_chaning_distance() {
//Verifie que le distance par default n'est pas nul, mais bien Euclidiene
assertEquals(DistanceEuclidienne.class, model.getDistance().getClass());
model.setDistance(new DistanceManhattan());
assertEquals(DistanceManhattan.class, model.getDistance().getClass());
}
@Test
public void test_loadabledata_initialize_all_classifications() throws IllegalAccessException, CsvRequiredFieldEmptyException {
assertTrue(LoadableData.getClassificationTypes().isEmpty());
model.loadData(csvTemp);
LoadableData.setClassificationTypes(model.getDatas());
assertFalse(LoadableData.getClassificationTypes().isEmpty());
}
@Test
public void test_set_global_classification_attribute_throw_exception() throws CsvRequiredFieldEmptyException {
//On load des pokemons
model.setType(DataType.POKEMON);
model.loadData(new File(path + "data/pokemon_train.csv"));
LoadableData data = model.getDatas().get(0);
// on met le type Name (0)
assertThrows(IllegalArgumentException.class, () -> {
data.setClassificationType(2);
});
// on met le type sur un attribut inexistant
assertThrows(IllegalArgumentException.class, () -> {
data.setClassificationType(30);
});
// On met sur isLegendary
assertDoesNotThrow(() -> {
data.setClassificationType(12);
});
}
}
......@@ -29,7 +29,7 @@ class IrisTest {
assertEquals(1.9, iris.getPetalLength());
}
/*
@Test
void getDataType() {
assertEquals(3.0 , iris.getDataType("sepalWidth"));
......@@ -47,4 +47,6 @@ class IrisTest {
void testToString() {
assertEquals("Iris{sepalLength=2.8, sepalWidth=3.0, petalLength=1.9, petalWidth=4.1}", iris.toString());
}
*/
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment