diff --git a/app_glob.py b/app_glob.py index 62f8c722f5500b48c6758cc2d9a1fadfdf75a40d..d5783a3a883acd5bf8d48a7eaf7f8ca494981cec 100644 --- a/app_glob.py +++ b/app_glob.py @@ -1,93 +1,59 @@ -from flask import Flask, request, jsonify, render_template -import pickle +from flask import Flask, render_template, request, jsonify import pandas as pd -from sklearn.preprocessing import StandardScaler -import numpy as np +import pickle +import os app = Flask(__name__) -# Charger le modèle -with open('random_forest_model_binaire.pkl', 'rb') as model_file: - rf = pickle.load(model_file) - -# Charger le scaler entraîné -with open('scaler_binaire.pkl', 'rb') as scaler_file: - scaler = pickle.load(scaler_file) +# Charger les modèles et scalers +model_paths = { + "binaire": "random_forest_model_binaire.pkl", + "sup0": "random_forest_model_sup0.pkl" +} +scaler_paths = { + "binaire": "scaler_binaire.pkl", + "sup0": "scaler_sup0.pkl" +} + +models = {} +scalers = {} + +for key in model_paths: + with open(model_paths[key], "rb") as model_file: + models[key] = pickle.load(model_file) + with open(scaler_paths[key], "rb") as scaler_file: + scalers[key] = pickle.load(scaler_file) + +# Charger les features +features_path = "features.txt" +if os.path.exists(features_path): + with open(features_path, "r") as f: + features = f.read().splitlines() +else: + features = [] @app.route('/') -def home(): - return render_template('index-glob.html') +def index(): + return render_template('index-glob.html', features=features) @app.route('/predict', methods=['POST']) def predict(): - return process_prediction(request.form, '/predict') + return predict_with_model("binaire") @app.route('/predict_sup0', methods=['POST']) def predict_sup0(): - return process_prediction(request.form, '/predict_sup0') - -def process_prediction(form_data, endpoint): - data = form_data.to_dict() - - if 'name' in data: - data['nb_caracteres_sans_espaces'] = len(data['name'].replace(" ", "")) - if 'artists' in data: - data['nb_artistes'] = data['artists'].count(',') + 1 - data['featuring'] = int(data['nb_artistes'] > 1) - if 'duration_ms' in data: - duration_ms = float(data['duration_ms']) - data['duree_minute'] = float(f"{int(duration_ms // 60000)}.{int((duration_ms % 60000) // 1000):02d}") - if 'year' in data: - year = int(data['year']) - data['categorie_annee'] = 3 if year < 1954 else 2 if year < 2002 else 1 - if 'tempo' in data: - tempo = float(data['tempo']) - if 40 <= tempo < 60: - data['categorie_tempo'] = 1 - elif 60 <= tempo < 66: - data['categorie_tempo'] = 2 - elif 66 <= tempo < 76: - data['categorie_tempo'] = 3 - elif 76 <= tempo < 108: - data['categorie_tempo'] = 4 - elif 108 <= tempo < 120: - data['categorie_tempo'] = 5 - elif 120 <= tempo < 163: - data['categorie_tempo'] = 6 - elif 163 <= tempo < 200: - data['categorie_tempo'] = 7 - elif 200 <= tempo <= 208: - data['categorie_tempo'] = 8 - else: - data['categorie_tempo'] = 9 - - # Supprimer les clés non utilisées directement - data.pop('name', None) - data.pop('artists', None) - data.pop('duration_ms', None) - - # Convertir les valeurs en float si possible - for key in data: - try: - data[key] = float(data[key]) - except ValueError: - pass - - expected_features = ['year', 'acousticness', 'danceability', 'energy', 'explicit', - 'instrumentalness', 'key', 'liveness', 'loudness', 'mode', - 'speechiness', 'tempo', 'valence', 'nb_caracteres_sans_espaces', - 'nb_artistes', 'featuring', 'duree_minute', 'categorie_annee', 'categorie_tempo'] - - input_data = pd.DataFrame([[data.get(key, 0) for key in expected_features]], columns=expected_features) - - missing_cols = [col for col in expected_features if col not in input_data.columns] - if missing_cols: - return jsonify({'error': f'Missing features: {missing_cols}'}), 400 - - input_data_scaled = scaler.transform(input_data) - predictions = rf.predict(input_data_scaled) - - return jsonify({'predictions': int(predictions[0])}) + return predict_with_model("sup0") + +def predict_with_model(model_key): + try: + data = request.get_json() + input_data = [float(data[feature]) for feature in features] + df_input = pd.DataFrame([input_data], columns=features) + df_scaled = scalers[model_key].transform(df_input) + prediction = models[model_key].predict(df_scaled) + return jsonify({"prediction": int(prediction[0])}) + except Exception as e: + return jsonify({"error": str(e)}) if __name__ == '__main__': app.run(debug=True) diff --git a/templates/index-glob.html b/templates/index-glob.html index cf39f737f5d6da8cd2793fe50881ee20b2b9a9cf..8efd8799f62aeed03f44e3e880d56ea6aabd2d2a 100644 --- a/templates/index-glob.html +++ b/templates/index-glob.html @@ -9,37 +9,42 @@ font-family: Arial, sans-serif; margin: 0; display: flex; + background-color: #121212; + color: white; } /* Sidebar */ .sidebar { - width: 200px; - background-color: #333; - color: white; + width: 220px; + background-color: #181818; height: 100vh; padding-top: 20px; display: flex; flex-direction: column; align-items: center; } + .sidebar img { + width: 120px; + margin-bottom: 20px; + } .sidebar button { width: 80%; padding: 10px; margin: 10px 0; border: none; - background: #444; - color: white; + background: #1DB954; + color: black; cursor: pointer; font-size: 16px; text-align: center; + border-radius: 20px; } .sidebar button:hover { - background: #555; + background: #1ed760; } .content { flex-grow: 1; padding: 20px; } - /* Cacher les sections par défaut */ .tab-content { display: none; } @@ -49,6 +54,9 @@ form { max-width: 600px; margin: auto; + background: #282828; + padding: 20px; + border-radius: 10px; } label { display: block; @@ -61,23 +69,25 @@ margin-top: 5px; border: 1px solid #ccc; border-radius: 4px; + background-color: #333; + color: white; } button { margin-top: 20px; padding: 10px 20px; - background-color: #4CAF50; - color: white; + background-color: #1DB954; + color: black; border: none; - border-radius: 4px; + border-radius: 20px; cursor: pointer; } button:hover { - background-color: #45a049; + background-color: #1ed760; } #result { margin-top: 20px; font-size: 1.2em; - color: #555; + color: #1DB954; } </style> </head> @@ -85,6 +95,7 @@ <!-- Sidebar pour naviguer entre les onglets --> <div class="sidebar"> + <img src="https://upload.wikimedia.org/wikipedia/commons/2/26/Spotify_logo_with_text.svg" alt="Spotify Logo"> <button onclick="showTab('tab1')">Prédiction Standard</button> <button onclick="showTab('tab2')">Prédiction (>0)</button> </div>