diff --git a/src/main/java/fr/univlille/sae/classification/model/Iris.java b/src/main/java/fr/univlille/sae/classification/model/Iris.java
index 94e39725ac51f4eb727c22827f5da24c2b79d7a0..f71fe890c7154f3d7fae2ecffbc335d13db71fe6 100644
--- a/src/main/java/fr/univlille/sae/classification/model/Iris.java
+++ b/src/main/java/fr/univlille/sae/classification/model/Iris.java
@@ -62,7 +62,6 @@ public class Iris extends LoadableData{
         return petalLength;
     }
 
-
     public double getDataType(String axes){
         switch (axes){
             case "sepalWidth":
diff --git a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
index 78545eedf055ab4c96c3051096169b79efc20afc..236f9307faae9a97fceebaac53c1df836d1a76c7 100644
--- a/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
+++ b/src/main/java/fr/univlille/sae/classification/view/MainStageView.java
@@ -17,6 +17,8 @@ import javafx.scene.control.ContextMenu;
 import javafx.scene.control.MenuItem;
 import javafx.scene.paint.Color;
 import javafx.scene.shape.Circle;
+import javafx.scene.shape.Rectangle;
+import javafx.scene.shape.Shape;
 import javafx.stage.Stage;
 
 import java.io.File;
@@ -64,11 +66,13 @@ public class MainStageView extends DataVisualizationView implements Observer {
     public void update(Observable observable) {
         if(scatterChart == null) throw new IllegalStateException();
         if(!(observable instanceof ClassificationModel)) throw new IllegalStateException();
-        //on vide le nuage pour s'assurer que celui-ci est bien vide
         scatterChart.getData().clear();
 
         XYChart.Series series1 = new XYChart.Series();
-        series1.setName("Iris");
+        XYChart.Series series2 = new XYChart.Series();
+        XYChart.Series series3 = new XYChart.Series();
+
+
 
         //Jalon 1: on verifie que le type de donnée est bien IRIS
         if(model.getType() == DataType.IRIS) {
@@ -79,33 +83,39 @@ public class MainStageView extends DataVisualizationView implements Observer {
             }
             else{
                 controller.setAxesSelected("");
-                // On ajoute la serie au nuage
-                scatterChart.getData().add(series1);
 
-                //On recupere les données du model
                 List<LoadableData> points = new ArrayList<>(model.getDatas());
                 points.addAll(model.getDataToClass());
-                // on ajoute chaque point a la serie
                 for(LoadableData i : points) {
 
                     Iris iris = (Iris)i;
                     XYChart.Data<Double, Double> dataPoint = new XYChart.Data<>(iris.getDataType(actualX),
                             iris.getDataType(actualY));
 
-                    dataPoint.setNode(getCircle(iris));
+                    dataPoint.setNode(getForm(iris, new Circle(5)));
 
-                    series1.getData().add(dataPoint);
+                    if(iris.getClassification().equals("Setosa")){
+                        series1.getData().add(dataPoint);
+                    }else if(iris.getClassification().equals("Versicolor")){
+                        series2.getData().add(dataPoint);
+                    }else if(iris.getClassification().equals("Virginica")){
+                        series3.getData().add(dataPoint);
+                    }
 
                 }
+
+                series1.setName("Setosa");
+                series2.setName("Versicolor");
+                series3.setName("Virginica");
+
+                scatterChart.getData().addAll(series1, series2, series3);
             }
         }
     }
 
-
-    private Circle getCircle(Iris iris) {
-        Circle circle = new Circle(5);
-        circle.setFill(iris.getColor());
-        circle.setOnMouseClicked(e -> {
+    private Shape getForm(Iris iris, Shape form) {
+        form.setFill(iris.getColor());
+        form.setOnMouseClicked(e -> {
             ContextMenu contextMenu = new ContextMenu();
             for(String attributes : iris.getAttributesName()) {
                 contextMenu.getItems().add(new MenuItem(attributes + " : " + iris.getDataType(attributes)));
@@ -113,7 +123,7 @@ public class MainStageView extends DataVisualizationView implements Observer {
             contextMenu.show(root, e.getScreenX(), e.getScreenY());
         });
 
-        return  circle;
+        return form;
     }
 
 
@@ -133,7 +143,7 @@ public class MainStageView extends DataVisualizationView implements Observer {
                     iris.getDataType(actualY)
             );
 
-            dataPoint.setNode(getCircle(iris));
+            dataPoint.setNode(getForm(iris, new Rectangle(10, 10)));
             if (!scatterChart.getData().isEmpty()) {
                 XYChart.Series series = (XYChart.Series) scatterChart.getData().get(0);
                 series.getData().add(dataPoint);