Skip to content
Snippets Groups Projects
Commit b349661d authored by Matias Mennecart's avatar Matias Mennecart
Browse files

connect KNN to the interface

parent 95149918
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@
<children>
<HBox alignment="CENTER" prefHeight="92.0" prefWidth="364.0">
<children>
<ChoiceBox fx:id="AlgoSelector" prefWidth="150.0" stylesheets="@../css/style.css" />
<ChoiceBox fx:id="algoSelector" prefWidth="150.0" stylesheets="@../css/style.css" />
</children>
<padding>
<Insets top="20.0" />
......@@ -36,8 +36,8 @@
</HBox>
<HBox alignment="CENTER" prefHeight="36.0" prefWidth="364.0" spacing="20.0">
<children>
<TextField fx:id="KEntry" />
<Button fx:id="AutoK" mnemonicParsing="false" stylesheets="@../css/style.css" text="Attribution auto" textFill="WHITE">
<Spinner fx:id="kEntry" />
<Button fx:id="autoK" mnemonicParsing="false" onAction="#bestK" stylesheets="@../css/style.css" text="Attribution auto" textFill="WHITE">
<font>
<Font name="System Bold" size="13.0" />
</font>
......@@ -46,7 +46,7 @@
</HBox>
<HBox alignment="CENTER" prefHeight="59.0" prefWidth="364.0">
<children>
<Button fx:id="confirmK" mnemonicParsing="false" stylesheets="@../css/style.css" text="Valider" textFill="WHITE">
<Button fx:id="confirmK" onAction="#validate" mnemonicParsing="false" stylesheets="@../css/style.css" text="Valider" textFill="WHITE">
<font>
<Font name="System Bold" size="14.0" />
</font></Button>
......
package fr.univlille.sae.classification.controller;
import fr.univlille.sae.classification.knn.MethodKNN;
import fr.univlille.sae.classification.knn.distance.Distance;
import fr.univlille.sae.classification.knn.distance.DistanceManhattanNormalisee;
import fr.univlille.sae.classification.model.ClassificationModel;
import fr.univlille.sae.classification.model.LoadableData;
import javafx.collections.ObservableList;
import javafx.concurrent.Task;
import javafx.fxml.FXML;
import javafx.scene.control.Button;
import javafx.scene.control.ChoiceBox;
import javafx.scene.control.TextField;
import javafx.scene.Scene;
import javafx.scene.control.*;
import javafx.scene.layout.HBox;
import javafx.stage.Stage;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class KNNController {
@FXML
private Stage stage;
@FXML
ChoiceBox<String> AlgoSelector;
ChoiceBox<String> algoSelector;
@FXML
TextField KEntry;
Spinner<Integer> kEntry;
@FXML
Button AutoK;
Button autoK;
@FXML
Button confirmK;
@FXML
public void initialize() {
kEntry.setValueFactory(new SpinnerValueFactory.IntegerSpinnerValueFactory(1,
(int) Math.sqrt(ClassificationModel.getClassificationModel().getDatas().size()),
1,
1));
algoSelector.getItems().addAll("Euclidienne", "Euclidienne Normalisée", "Manhattan", "Manhattan Normalisée");
}
public void bestK() {
ClassificationModel model = ClassificationModel.getClassificationModel();
if(model.getkOptimal() > 0) {
// Le K Optimal à déja été calculé, il n'est pas necessaire de le recaculer.
kEntry.getValueFactory().setValue(model.getkOptimal());
}else {
// Calcul du K Optimal:
HBox hBox = new HBox();
Task<Scene> knnTask = new Task<>() {
@Override
protected Scene call() throws Exception {
System.out.println("Call call()");
updateProgress(0, 3);
updateMessage("Préparation des données ");
//List<LoadableData> datasShuffle = new ArrayList<>(List.copyOf(model.getDatas()));
// Collections.shuffle(datasShuffle);
MethodKNN.updateModel(model.getDatas());
updateProgress(1, 3);
updateMessage("Recherche du meilleur K");
int bestK = MethodKNN.bestK(model.getDatas(), model.getDistance());
updateMessage("Test de robustesse");
updateProgress(2, 3);
double robustesse = MethodKNN.robustesse( model.getDatas(), bestK, model.getDistance(), 0.2);
model.setKOptimal(bestK);
updateMessage("Affichage du resultat");
updateProgress(2.5, 3);
HBox hBox = new HBox();
Label label = new Label("Best K: " + bestK + " robustesse : " + robustesse);
hBox.getChildren().add(label);
Scene scene = new Scene(hBox);
kEntry.getValueFactory().setValue(bestK);
updateMessage("Finished");
updateProgress(3, 3);
return scene;
}
};
ProgressBar pBar = new ProgressBar();
pBar.progressProperty().bind(knnTask.progressProperty());
Label statusLabel = new Label();
statusLabel.textProperty().bind(knnTask.messageProperty());
hBox.getChildren().addAll(statusLabel, pBar);
Stage stageLoad = new Stage();
Scene scene = new Scene(hBox);
stageLoad.setScene(scene);
stageLoad.show();
Stage stageFinished = new Stage();
knnTask.setOnSucceeded(e -> {
stageLoad.close();
stageFinished.setScene(knnTask.getValue());
stageFinished.show();
});
knnTask.run();
//new Thread(knnTask).start();
}
}
public void validate() {
ClassificationModel model = ClassificationModel.getClassificationModel();
int k = kEntry.getValue();
Distance dist = Distance.getByName(algoSelector.getValue());
model.setDistance(dist);
model.setK(k);
model.classifierDonnees();
stage.close();
}
}
package fr.univlille.sae.classification.controller;
import fr.univlille.sae.classification.model.ClassificationModel;
import fr.univlille.sae.classification.model.DataType;
import javafx.fxml.FXML;
import javafx.scene.control.Alert;
import javafx.scene.control.Button;
......@@ -14,9 +15,11 @@ import java.io.IOException;
public class LoadDataController {
@FXML
Stage stage;
@FXML
TextField filePath;
......@@ -44,6 +47,7 @@ public class LoadDataController {
if(file != null) {
filePath.setText(file.getPath());
}
}
/**
......@@ -58,6 +62,7 @@ public class LoadDataController {
alert.initOwner(stage);
alert.setContentText("Le chargement du fichier à echoué, veuillez reessayer !");
alert.showAndWait();
openFileChooser();
return;
}
......
......@@ -65,6 +65,7 @@ public class MethodKNN {
// On recupere les K voisions de data.
List<LoadableData> kVoisins = MethodKNN.kVoisins(datas, data, k, distance);
System.out.println("Neighbours: " + kVoisins);
// System.out.println("Neighbours found : " + kVoisins);
......@@ -155,8 +156,8 @@ public class MethodKNN {
ClassificationModel model = ClassificationModel.getClassificationModel();
model.setType(DataType.POKEMON);
model.loadData(new File(path+"data/pokemon_train.csv"));
model.setType(DataType.IRIS);
model.loadData(new File(path+"data/iris.csv"));
MethodKNN.updateModel(model.getDatas());
System.out.println();
......@@ -168,11 +169,11 @@ public class MethodKNN {
System.out.println("Search best k");
// On cherche le meilleure K
int bestK = MethodKNN.bestK(datas, new DistanceEuclidienneNormalisee());
int bestK = MethodKNN.bestK(datas, new DistanceManhattanNormalisee());
System.out.println(bestK);
// Puis on clacul la robustesse avec le K trouvé
System.out.println(MethodKNN.robustesse( datas, bestK, new DistanceEuclidienneNormalisee(), 0.2));
System.out.println(MethodKNN.robustesse( datas, bestK, new DistanceManhattanNormalisee(), 0.2));
}
......
......@@ -7,4 +7,18 @@ public interface Distance {
double distance(LoadableData l1, LoadableData l2);
static Distance getByName(String name){
switch (name) {
case "Euclidienne Normalisée":
return new DistanceEuclidienneNormalisee();
case "Manhattan":
return new DistanceManhattan();
case "Manhattan Normalisée":
return new DistanceManhattanNormalisee();
default:
return new DistanceEuclidienne();
}
}
}
......@@ -14,7 +14,8 @@ public class DistanceEuclidienneNormalisee implements Distance{
double[] normaliseL1 = normalise(l1);
double[] normaliseL2 = normalise(l2);
for(int i = 0;i<normaliseL1.length;i++) {
total += Math.pow(normaliseL2[i] - normaliseL1[i], 2);
//total += Math.pow(normaliseL2[i] - normaliseL1[i], 2);
total += Math.pow((l2.getAttributes()[i] - l1.getAttributes()[i])/MethodKNN.amplitude[i], 2);
}
//A Check
for(int i = 0;i<l2.getStringAttributes().length;i++) {
......
......@@ -28,6 +28,7 @@ public class ClassificationModel extends Observable {
private Distance distance;
private int kOptimal;
private int k;
/**
* Renvoie une instance unique du modèle. Par défaut, le type de ce modèle est Iris.
......@@ -43,7 +44,7 @@ public class ClassificationModel extends Observable {
* Initialise le modèle avec le type de données Iris.
*/
private ClassificationModel() {
this(DataType.POKEMON);
this(DataType.IRIS);
}
/**
......@@ -54,7 +55,8 @@ public class ClassificationModel extends Observable {
this.datas = new ArrayList<>();
this.dataToClass = new ConcurrentHashMap<>();
this.type = type;
this.kOptimal = 1;
this.kOptimal = 0;
this.k = 0;
this.distance = new DistanceEuclidienne();
}
/**
......@@ -89,6 +91,8 @@ public class ClassificationModel extends Observable {
types.add(d.getClassification());
}
Collections.shuffle(datas);
LoadableData.setClassificationTypes(types);
notifyObservers();
} catch (IOException e) {
......@@ -112,7 +116,7 @@ public class ClassificationModel extends Observable {
*/
public void classifierDonnee(LoadableData data) {
if(dataToClass.get(data) != null && dataToClass.get(data)) return;
this.dataToClass.remove(data);
data.setClassification(MethodKNN.estimateClass(datas, data, kOptimal, distance));
notifyObservers(data);
dataToClass.put(data, true);
......@@ -143,6 +147,18 @@ public class ClassificationModel extends Observable {
return kOptimal;
}
public void setKOptimal(int kOptimal) {
this.kOptimal = kOptimal;
}
public int getK() {
return k;
}
public void setK(int k) {
this.k = k;
}
/**
* Renvoie la liste des données chargées.
* @return liste des données chargées.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment