analist committed
Commit 96fbc4d · verified
1 Parent(s): 8c5fb8a

Update app.py

Files changed (1)
  1. app.py +112 -180
app.py CHANGED
@@ -96,216 +96,148 @@ def plot_feature_importance(model, feature_names, model_type):
     plt.title(f"Feature Importance - {model_type}")
     return plt.gcf()
 
-import streamlit as st
-import pandas as pd
-import numpy as np
-import matplotlib.pyplot as plt
-from sklearn.tree import plot_tree, export_text
-import seaborn as sns
-from sklearn.preprocessing import LabelEncoder
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.tree import DecisionTreeClassifier
-from sklearn.ensemble import GradientBoostingClassifier
-from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
-import shap
-
-# Configuration de la page
-st.set_page_config(
-    page_title="ML Model Interpreter",
-    layout="wide",
-    initial_sidebar_state="expanded"
-)
-
-# CSS personnalisé
-st.markdown("""
-    <style>
-    .main-header {
-        color: #0D47A1;
-        text-align: center;
-        padding: 1rem;
-        background: linear-gradient(90deg, #FFFFFF 0%, #90CAF9 50%, #FFFFFF 100%);
-        border-radius: 10px;
-        margin-bottom: 2rem;
-    }
-
-    .metric-card {
-        background-color: white;
-        padding: 1.5rem;
-        border-radius: 10px;
-        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
-        margin-bottom: 1rem;
-    }
-
-    .sub-header {
-        color: #1E88E5;
-        border-bottom: 2px solid #90CAF9;
-        padding-bottom: 0.5rem;
-        margin-bottom: 1rem;
-    }
-
-    .metric-value {
-        font-size: 1.5rem;
-        font-weight: bold;
-        color: #1E88E5;
-    }
-
-    div[data-testid="stMetricValue"] {
-        color: #1E88E5;
-    }
-    </style>
-""", unsafe_allow_html=True)
-
-def custom_metric_card(title, value, prefix=""):
-    return f"""
-        <div class="metric-card">
-            <h3 style="color: #1E88E5; margin-bottom: 0.5rem;">{title}</h3>
-            <p class="metric-value">{prefix}{value:.4f}</p>
-        </div>
-    """
-
-def set_plot_style(fig):
-    """Configure le style des graphiques"""
-    colors = ['#1E88E5', '#90CAF9', '#0D47A1', '#42A5F5']
-    for ax in fig.axes:
-        ax.set_facecolor('#F8F9FA')
-        ax.grid(True, linestyle='--', alpha=0.3, color='#666666')
-        ax.spines['top'].set_visible(False)
-        ax.spines['right'].set_visible(False)
-        ax.tick_params(axis='both', colors='#666666')
-        ax.set_axisbelow(True)
-    return fig, colors
-
-def plot_model_performance(results):
-    metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
-    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
-    fig, colors = set_plot_style(fig)
-
-    # Training metrics
-    train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics]
-                  for model in results.keys()}
-    train_df = pd.DataFrame(train_data, index=metrics)
-    train_df.plot(kind='bar', ax=axes[0], color=colors)
-    axes[0].set_title('Performance d\'Entraînement', color='#0D47A1', pad=20)
-    axes[0].set_ylim(0, 1)
-
-    # Test metrics
-    test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics]
-                 for model in results.keys()}
-    test_df = pd.DataFrame(test_data, index=metrics)
-    test_df.plot(kind='bar', ax=axes[1], color=colors)
-    axes[1].set_title('Performance de Test', color='#0D47A1', pad=20)
-    axes[1].set_ylim(0, 1)
-
-    # Style des graphiques
-    for ax in axes:
-        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
-        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
-
-    plt.tight_layout()
-    return fig
-
-def plot_feature_importance(model, feature_names, model_type):
-    fig, ax = plt.subplots(figsize=(10, 6))
-    fig, colors = set_plot_style(fig)
-
-    if model_type in ["Decision Tree", "Random Forest", "Gradient Boost"]:
-        importance = model.feature_importances_
-    elif model_type == "Logistic Regression":
-        importance = np.abs(model.coef_[0])
-
-    importance_df = pd.DataFrame({
-        'feature': feature_names,
-        'importance': importance
-    }).sort_values('importance', ascending=True)
-
-    ax.barh(importance_df['feature'], importance_df['importance'],
-            color='#1E88E5', alpha=0.8)
-    ax.set_title("Importance des Caractéristiques", color='#0D47A1', pad=20)
-
-    return fig
-
-def plot_correlation_matrix(data):
-    fig, ax = plt.subplots(figsize=(10, 8))
-    fig, _ = set_plot_style(fig)
-
-    sns.heatmap(data.corr(), annot=True, cmap='coolwarm', center=0,
-                ax=ax, fmt='.2f', square=True)
-    ax.set_title("Matrice de Corrélation", color='#0D47A1', pad=20)
-
-    return fig
-
 def app():
-    st.markdown('<h1 class="main-header">Interpréteur de Modèles ML</h1>',
-                unsafe_allow_html=True)
+    st.title("Interpréteur de Modèles ML")
 
     # Load data
     X_train, y_train, X_test, y_test, feature_names = load_data()
 
     # Train models if not in session state
     if 'model_results' not in st.session_state:
-        with st.spinner("🔄 Entraînement des modèles en cours..."):
+        with st.spinner("Entraînement des modèles en cours..."):
             st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
 
     # Sidebar
-    with st.sidebar:
-        st.markdown('<h2 style="color: #1E88E5;">Navigation</h2>',
-                    unsafe_allow_html=True)
-
-        selected_model = st.selectbox(
-            "📊 Sélectionnez un modèle",
-            list(st.session_state.model_results.keys())
-        )
-
-        st.markdown('<hr style="margin: 1rem 0;">', unsafe_allow_html=True)
-
-        page = st.radio(
-            "📑 Sélectionnez une section",
-            ["Performance des modèles",
-             "Interprétation du modèle",
-             "Analyse des caractéristiques",
-             "Simulateur de prédictions"]
-        )
+    st.sidebar.title("Navigation")
+    selected_model = st.sidebar.selectbox(
+        "Sélectionnez un modèle",
+        list(st.session_state.model_results.keys())
+    )
+
+    page = st.sidebar.radio(
+        "Sélectionnez une section",
+        ["Performance des modèles",
+         "Interprétation du modèle",
+         "Analyse des caractéristiques",
+         "Simulateur de prédictions"]
+    )
 
     current_model = st.session_state.model_results[selected_model]['model']
 
-    # Main content
+    # Performance des modèles
     if page == "Performance des modèles":
-        st.markdown('<h2 class="sub-header">Performance des modèles</h2>',
-                    unsafe_allow_html=True)
+        st.header("Performance des modèles")
 
+        # Plot global performance comparison
+        st.subheader("Comparaison des performances")
         performance_fig = plot_model_performance(st.session_state.model_results)
         st.pyplot(performance_fig)
 
-        st.markdown('<h3 class="sub-header">Métriques détaillées</h3>',
-                    unsafe_allow_html=True)
-
+        # Detailed metrics for selected model
+        st.subheader(f"Métriques détaillées - {selected_model}")
         col1, col2 = st.columns(2)
+
         with col1:
-            st.markdown('<h4 style="color: #1E88E5;">Entraînement</h4>',
-                        unsafe_allow_html=True)
+            st.write("Métriques d'entraînement:")
             for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
-                st.markdown(custom_metric_card(metric.capitalize(), value),
-                            unsafe_allow_html=True)
+                st.write(f"{metric}: {value:.4f}")
 
         with col2:
-            st.markdown('<h4 style="color: #1E88E5;">Test</h4>',
-                        unsafe_allow_html=True)
+            st.write("Métriques de test:")
            for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
-                st.markdown(custom_metric_card(metric.capitalize(), value),
-                            unsafe_allow_html=True)
+                st.write(f"{metric}: {value:.4f}")
 
+    # Interprétation du modèle
+    elif page == "Interprétation du modèle":
+        st.header(f"Interprétation du modèle - {selected_model}")
+
+        if selected_model in ["Decision Tree", "Random Forest"]:
+            if selected_model == "Decision Tree":
+                st.subheader("Visualisation de l'arbre")
+                max_depth = st.slider("Profondeur maximale à afficher", 1, 5, 3)
+                fig, ax = plt.subplots(figsize=(20, 10))
+                plot_tree(current_model, feature_names=list(feature_names),
+                          max_depth=max_depth, filled=True, rounded=True)
+                st.pyplot(fig)
+
+            st.subheader("Règles de décision importantes")
+            if selected_model == "Decision Tree":
+                st.text(export_text(current_model, feature_names=list(feature_names)))
+
+        # SHAP values for all models
+        st.subheader("SHAP Values")
+        with st.spinner("Calcul des valeurs SHAP en cours..."):
+            explainer = shap.TreeExplainer(current_model) if selected_model != "Logistic Regression" \
+                        else shap.LinearExplainer(current_model, X_train)
+            shap_values = explainer.shap_values(X_train[:100])  # Using first 100 samples for speed
+
+            fig, ax = plt.subplots(figsize=(10, 6))
+            shap.summary_plot(shap_values, X_train[:100], feature_names=list(feature_names),
+                              show=False)
+            st.pyplot(fig)
+
+    # Analyse des caractéristiques
     elif page == "Analyse des caractéristiques":
-        st.markdown('<h2 class="sub-header">Analyse des caractéristiques</h2>',
-                    unsafe_allow_html=True)
+        st.header("Analyse des caractéristiques")
 
+        # Feature importance
+        st.subheader("Importance des caractéristiques")
         importance_fig = plot_feature_importance(current_model, feature_names, selected_model)
         st.pyplot(importance_fig)
 
-        st.markdown('<h3 class="sub-header">Corrélations</h3>',
-                    unsafe_allow_html=True)
-        corr_fig = plot_correlation_matrix(X_train)
-        st.pyplot(corr_fig)
+        # Feature correlation
+        st.subheader("Matrice de corrélation")
+        corr_matrix = X_train.corr()
+        fig, ax = plt.subplots(figsize=(10, 8))
+        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0)
+        st.pyplot(fig)
+
+    # Simulateur de prédictions
+    else:
+        st.header("Simulateur de prédictions")
+
+        input_values = {}
+        for feature in feature_names:
+            if X_train[feature].dtype == 'object':
+                input_values[feature] = st.selectbox(
+                    f"Sélectionnez {feature}",
+                    options=X_train[feature].unique()
+                )
+            else:
+                input_values[feature] = st.slider(
+                    f"Valeur pour {feature}",
+                    float(X_train[feature].min()),
+                    float(X_train[feature].max()),
+                    float(X_train[feature].mean())
+                )
+
+        if st.button("Prédire"):
+            input_df = pd.DataFrame([input_values])
+
+            prediction = current_model.predict_proba(input_df)
+
+            st.write("Probabilités prédites:")
+            st.write({f"Classe {i}": f"{prob:.2%}" for i, prob in enumerate(prediction[0])})
+
+            if selected_model == "Decision Tree":
+                st.subheader("Chemin de décision")
+                node_indicator = current_model.decision_path(input_df)
+                leaf_id = current_model.apply(input_df)
+
+                node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
+
+                rules = []
+                for node_id in node_index:
+                    if node_id != leaf_id[0]:
+                        threshold = current_model.tree_.threshold[node_id]
+                        feature = feature_names[current_model.tree_.feature[node_id]]
+                        if input_df.iloc[0][feature] <= threshold:
+                            rules.append(f"{feature} ≤ {threshold:.2f}")
+                        else:
+                            rules.append(f"{feature} > {threshold:.2f}")
+
+                for rule in rules:
+                    st.write(rule)
 
 if __name__ == "__main__":
     app()
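
Note: the decision-path display added in the "Simulateur de prédictions" section can be tried outside Streamlit. A minimal sketch, assuming scikit-learn and pandas; the iris data and the variable names below are illustrative and not part of app.py:

import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Fit a small tree on a toy dataset and pick one sample to explain.
X, y = load_iris(return_X_y=True, as_frame=True)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
input_df = X.iloc[[0]]

node_indicator = model.decision_path(input_df)  # sparse matrix of nodes visited by the sample
leaf_id = model.apply(input_df)                 # id of the leaf the sample ends up in
node_index = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]

rules = []
for node_id in node_index:
    if node_id == leaf_id[0]:
        continue  # the leaf itself carries no split test
    feature = X.columns[model.tree_.feature[node_id]]
    threshold = model.tree_.threshold[node_id]
    op = "<=" if input_df.iloc[0][feature] <= threshold else ">"
    rules.append(f"{feature} {op} {threshold:.2f}")

print("\n".join(rules))

Each printed rule corresponds to one internal node on the sample's path through the tree, which is what the app writes out with st.write for the selected Decision Tree.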