Skip to content
Snippets Groups Projects
Commit bf3c78c8 authored by Simon Majorczyk's avatar Simon Majorczyk
Browse files

Merge branch 'Simon' into dev

parents 63230eb4 76aea6e6
No related branches found
No related tags found
No related merge requests found
from flask import Flask, request, jsonify, render_template from flask import Flask, render_template, request, jsonify
import pickle
import pandas as pd import pandas as pd
from sklearn.preprocessing import StandardScaler import pickle
import numpy as np import os
app = Flask(__name__) app = Flask(__name__)
# Charger le modèle # Charger les modèles et scalers
with open('random_forest_model_binaire.pkl', 'rb') as model_file: model_paths = {
rf = pickle.load(model_file) "binaire": "random_forest_model_binaire.pkl",
"sup0": "random_forest_model_sup0.pkl"
# Charger le scaler entraîné }
with open('scaler_binaire.pkl', 'rb') as scaler_file: scaler_paths = {
scaler = pickle.load(scaler_file) "binaire": "scaler_binaire.pkl",
"sup0": "scaler_sup0.pkl"
}
models = {}
scalers = {}
for key in model_paths:
with open(model_paths[key], "rb") as model_file:
models[key] = pickle.load(model_file)
with open(scaler_paths[key], "rb") as scaler_file:
scalers[key] = pickle.load(scaler_file)
# Charger les features
features_path = "features.txt"
if os.path.exists(features_path):
with open(features_path, "r") as f:
features = f.read().splitlines()
else:
features = []
@app.route('/') @app.route('/')
def home(): def index():
return render_template('index-glob.html') return render_template('index-glob.html', features=features)
@app.route('/predict', methods=['POST']) @app.route('/predict', methods=['POST'])
def predict(): def predict():
return process_prediction(request.form, '/predict') return predict_with_model("binaire")
@app.route('/predict_sup0', methods=['POST']) @app.route('/predict_sup0', methods=['POST'])
def predict_sup0(): def predict_sup0():
return process_prediction(request.form, '/predict_sup0') return predict_with_model("sup0")
def process_prediction(form_data, endpoint):
data = form_data.to_dict()
if 'name' in data:
data['nb_caracteres_sans_espaces'] = len(data['name'].replace(" ", ""))
if 'artists' in data:
data['nb_artistes'] = data['artists'].count(',') + 1
data['featuring'] = int(data['nb_artistes'] > 1)
if 'duration_ms' in data:
duration_ms = float(data['duration_ms'])
data['duree_minute'] = float(f"{int(duration_ms // 60000)}.{int((duration_ms % 60000) // 1000):02d}")
if 'year' in data:
year = int(data['year'])
data['categorie_annee'] = 3 if year < 1954 else 2 if year < 2002 else 1
if 'tempo' in data:
tempo = float(data['tempo'])
if 40 <= tempo < 60:
data['categorie_tempo'] = 1
elif 60 <= tempo < 66:
data['categorie_tempo'] = 2
elif 66 <= tempo < 76:
data['categorie_tempo'] = 3
elif 76 <= tempo < 108:
data['categorie_tempo'] = 4
elif 108 <= tempo < 120:
data['categorie_tempo'] = 5
elif 120 <= tempo < 163:
data['categorie_tempo'] = 6
elif 163 <= tempo < 200:
data['categorie_tempo'] = 7
elif 200 <= tempo <= 208:
data['categorie_tempo'] = 8
else:
data['categorie_tempo'] = 9
# Supprimer les clés non utilisées directement def predict_with_model(model_key):
data.pop('name', None)
data.pop('artists', None)
data.pop('duration_ms', None)
# Convertir les valeurs en float si possible
for key in data:
try: try:
data[key] = float(data[key]) data = request.get_json()
except ValueError: input_data = [float(data[feature]) for feature in features]
pass df_input = pd.DataFrame([input_data], columns=features)
df_scaled = scalers[model_key].transform(df_input)
expected_features = ['year', 'acousticness', 'danceability', 'energy', 'explicit', prediction = models[model_key].predict(df_scaled)
'instrumentalness', 'key', 'liveness', 'loudness', 'mode', return jsonify({"prediction": int(prediction[0])})
'speechiness', 'tempo', 'valence', 'nb_caracteres_sans_espaces', except Exception as e:
'nb_artistes', 'featuring', 'duree_minute', 'categorie_annee', 'categorie_tempo'] return jsonify({"error": str(e)})
input_data = pd.DataFrame([[data.get(key, 0) for key in expected_features]], columns=expected_features)
missing_cols = [col for col in expected_features if col not in input_data.columns]
if missing_cols:
return jsonify({'error': f'Missing features: {missing_cols}'}), 400
input_data_scaled = scaler.transform(input_data)
predictions = rf.predict(input_data_scaled)
return jsonify({'predictions': int(predictions[0])})
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=True) app.run(debug=True)
...@@ -9,37 +9,42 @@ ...@@ -9,37 +9,42 @@
font-family: Arial, sans-serif; font-family: Arial, sans-serif;
margin: 0; margin: 0;
display: flex; display: flex;
background-color: #121212;
color: white;
} }
/* Sidebar */ /* Sidebar */
.sidebar { .sidebar {
width: 200px; width: 220px;
background-color: #333; background-color: #181818;
color: white;
height: 100vh; height: 100vh;
padding-top: 20px; padding-top: 20px;
display: flex; display: flex;
flex-direction: column; flex-direction: column;
align-items: center; align-items: center;
} }
.sidebar img {
width: 120px;
margin-bottom: 20px;
}
.sidebar button { .sidebar button {
width: 80%; width: 80%;
padding: 10px; padding: 10px;
margin: 10px 0; margin: 10px 0;
border: none; border: none;
background: #444; background: #1DB954;
color: white; color: black;
cursor: pointer; cursor: pointer;
font-size: 16px; font-size: 16px;
text-align: center; text-align: center;
border-radius: 20px;
} }
.sidebar button:hover { .sidebar button:hover {
background: #555; background: #1ed760;
} }
.content { .content {
flex-grow: 1; flex-grow: 1;
padding: 20px; padding: 20px;
} }
/* Cacher les sections par défaut */
.tab-content { .tab-content {
display: none; display: none;
} }
...@@ -49,7 +54,11 @@ ...@@ -49,7 +54,11 @@
form { form {
max-width: 600px; max-width: 600px;
margin: auto; margin: auto;
background: #282828;
padding: 20px;
border-radius: 10px;
} }
label { label {
display: block; display: block;
margin-top: 10px; margin-top: 10px;
...@@ -61,23 +70,25 @@ ...@@ -61,23 +70,25 @@
margin-top: 5px; margin-top: 5px;
border: 1px solid #ccc; border: 1px solid #ccc;
border-radius: 4px; border-radius: 4px;
background-color: #333;
color: white;
} }
button { button {
margin-top: 20px; margin-top: 20px;
padding: 10px 20px; padding: 10px 20px;
background-color: #4CAF50; background-color: #1DB954;
color: white; color: black;
border: none; border: none;
border-radius: 4px; border-radius: 20px;
cursor: pointer; cursor: pointer;
} }
button:hover { button:hover {
background-color: #45a049; background-color: #1ed760;
} }
#result { #result {
margin-top: 20px; margin-top: 20px;
font-size: 1.2em; font-size: 1.2em;
color: #555; color: #1DB954;
} }
</style> </style>
</head> </head>
...@@ -85,6 +96,7 @@ ...@@ -85,6 +96,7 @@
<!-- Sidebar pour naviguer entre les onglets --> <!-- Sidebar pour naviguer entre les onglets -->
<div class="sidebar"> <div class="sidebar">
<img src="https://upload.wikimedia.org/wikipedia/commons/2/26/Spotify_logo_with_text.svg" alt="Spotify Logo">
<button onclick="showTab('tab1')">Prédiction Standard</button> <button onclick="showTab('tab1')">Prédiction Standard</button>
<button onclick="showTab('tab2')">Prédiction (>0)</button> <button onclick="showTab('tab2')">Prédiction (>0)</button>
</div> </div>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment