analist commited on
Commit
42fa5c8
·
verified ·
1 Parent(s): a9345e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -89
app.py CHANGED
@@ -110,7 +110,7 @@ from sklearn.linear_model import LogisticRegression
110
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
111
  import shap
112
 
113
- # Configuration de la page et du thème
114
  st.set_page_config(
115
  page_title="ML Model Interpreter",
116
  layout="wide",
@@ -120,25 +120,15 @@ st.set_page_config(
120
  # CSS personnalisé
121
  st.markdown("""
122
  <style>
123
- /* Couleurs principales */
124
- :root {
125
- --primary-blue: #1E88E5;
126
- --light-blue: #90CAF9;
127
- --dark-blue: #0D47A1;
128
- --white: #FFFFFF;
129
- }
130
-
131
- /* En-tête principal */
132
  .main-header {
133
- color: var(--dark-blue);
134
  text-align: center;
135
  padding: 1rem;
136
- background: linear-gradient(90deg, var(--white) 0%, var(--light-blue) 50%, var(--white) 100%);
137
  border-radius: 10px;
138
  margin-bottom: 2rem;
139
  }
140
 
141
- /* Carte pour les métriques */
142
  .metric-card {
143
  background-color: white;
144
  padding: 1.5rem;
@@ -147,44 +137,21 @@ st.markdown("""
147
  margin-bottom: 1rem;
148
  }
149
 
150
- /* Style pour les sous-titres */
151
  .sub-header {
152
- color: var(--primary-blue);
153
- border-bottom: 2px solid var(--light-blue);
154
  padding-bottom: 0.5rem;
155
  margin-bottom: 1rem;
156
  }
157
 
158
- /* Style pour les valeurs de métriques */
159
  .metric-value {
160
  font-size: 1.5rem;
161
  font-weight: bold;
162
- color: var(--primary-blue);
163
- }
164
-
165
- /* Style pour la barre latérale */
166
- .sidebar .sidebar-content {
167
- background-color: var(--white);
168
- }
169
-
170
- /* Style pour les boutons */
171
- .stButton > button {
172
- background-color: var(--primary-blue);
173
- color: white;
174
- border-radius: 5px;
175
- border: none;
176
- padding: 0.5rem 1rem;
177
  }
178
 
179
- /* Style pour les sliders */
180
- .stSlider > div > div {
181
- background-color: var(--light-blue);
182
- }
183
-
184
- /* Style pour les selectbox */
185
- .stSelectbox > div > div {
186
- background-color: white;
187
- border: 1px solid var(--light-blue);
188
  }
189
  </style>
190
  """, unsafe_allow_html=True)
@@ -197,57 +164,80 @@ def custom_metric_card(title, value, prefix=""):
197
  </div>
198
  """
199
 
200
- def plot_with_style(fig):
201
- # Style matplotlib
202
- plt.style.use('seaborn')
203
- fig.patch.set_facecolor('#FFFFFF')
204
  for ax in fig.axes:
205
  ax.set_facecolor('#F8F9FA')
206
- ax.grid(True, linestyle='--', alpha=0.7)
207
  ax.spines['top'].set_visible(False)
208
  ax.spines['right'].set_visible(False)
209
- return fig
210
-
211
- # [Fonctions load_data et train_models restent identiques]
212
 
213
  def plot_model_performance(results):
214
  metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
215
  fig, axes = plt.subplots(1, 2, figsize=(15, 6))
216
-
217
- # Configuration du style
218
- plt.style.use('seaborn')
219
- colors = ['#1E88E5', '#90CAF9', '#0D47A1', '#42A5F5']
220
 
221
  # Training metrics
222
  train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics]
223
  for model in results.keys()}
224
  train_df = pd.DataFrame(train_data, index=metrics)
225
- train_df.plot(kind='bar', ax=axes[0], title='Performance d\'Entraînement',
226
- color=colors)
227
  axes[0].set_ylim(0, 1)
228
 
229
  # Test metrics
230
  test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics]
231
  for model in results.keys()}
232
  test_df = pd.DataFrame(test_data, index=metrics)
233
- test_df.plot(kind='bar', ax=axes[1], title='Performance de Test',
234
- color=colors)
235
  axes[1].set_ylim(0, 1)
236
 
237
  # Style des graphiques
238
  for ax in axes:
239
- ax.set_facecolor('#F8F9FA')
240
- ax.grid(True, linestyle='--', alpha=0.7)
241
- ax.spines['top'].set_visible(False)
242
- ax.spines['right'].set_visible(False)
243
  plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
 
244
 
245
  plt.tight_layout()
246
  return fig
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  def app():
249
- # En-tête principal avec style personnalisé
250
- st.markdown('<h1 class="main-header">Interpréteur de Modèles ML</h1>', unsafe_allow_html=True)
251
 
252
  # Load data
253
  X_train, y_train, X_test, y_test, feature_names = load_data()
@@ -257,9 +247,11 @@ def app():
257
  with st.spinner("🔄 Entraînement des modèles en cours..."):
258
  st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
259
 
260
- # Sidebar avec style personnalisé
261
  with st.sidebar:
262
- st.markdown('<h2 style="color: #1E88E5;">Navigation</h2>', unsafe_allow_html=True)
 
 
263
  selected_model = st.selectbox(
264
  "📊 Sélectionnez un modèle",
265
  list(st.session_state.model_results.keys())
@@ -277,31 +269,43 @@ def app():
277
 
278
  current_model = st.session_state.model_results[selected_model]['model']
279
 
280
- # Container principal avec padding
281
- main_container = st.container()
282
- with main_container:
283
- if page == "Performance des modèles":
284
- st.markdown('<h2 class="sub-header">Performance des modèles</h2>', unsafe_allow_html=True)
285
-
286
- # Graphiques de performance
287
- performance_fig = plot_model_performance(st.session_state.model_results)
288
- st.pyplot(plot_with_style(performance_fig))
289
-
290
- # Métriques détaillées dans des cartes
291
- st.markdown('<h3 class="sub-header">Métriques détaillées</h3>', unsafe_allow_html=True)
292
- col1, col2 = st.columns(2)
293
-
294
- with col1:
295
- st.markdown('<h4 style="color: #1E88E5;">Entraînement</h4>', unsafe_allow_html=True)
296
- for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
297
- st.markdown(custom_metric_card(metric.capitalize(), value), unsafe_allow_html=True)
298
-
299
- with col2:
300
- st.markdown('<h4 style="color: #1E88E5;">Test</h4>', unsafe_allow_html=True)
301
- for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
302
- st.markdown(custom_metric_card(metric.capitalize(), value), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
303
 
304
- # [Le reste des sections avec style adapté...]
 
 
 
305
 
306
  if __name__ == "__main__":
307
  app()
 
110
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
111
  import shap
112
 
113
+ # Configuration de la page
114
  st.set_page_config(
115
  page_title="ML Model Interpreter",
116
  layout="wide",
 
120
  # CSS personnalisé
121
  st.markdown("""
122
  <style>
 
 
 
 
 
 
 
 
 
123
  .main-header {
124
+ color: #0D47A1;
125
  text-align: center;
126
  padding: 1rem;
127
+ background: linear-gradient(90deg, #FFFFFF 0%, #90CAF9 50%, #FFFFFF 100%);
128
  border-radius: 10px;
129
  margin-bottom: 2rem;
130
  }
131
 
 
132
  .metric-card {
133
  background-color: white;
134
  padding: 1.5rem;
 
137
  margin-bottom: 1rem;
138
  }
139
 
 
140
  .sub-header {
141
+ color: #1E88E5;
142
+ border-bottom: 2px solid #90CAF9;
143
  padding-bottom: 0.5rem;
144
  margin-bottom: 1rem;
145
  }
146
 
 
147
  .metric-value {
148
  font-size: 1.5rem;
149
  font-weight: bold;
150
+ color: #1E88E5;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  }
152
 
153
+ div[data-testid="stMetricValue"] {
154
+ color: #1E88E5;
 
 
 
 
 
 
 
155
  }
156
  </style>
157
  """, unsafe_allow_html=True)
 
164
  </div>
165
  """
166
 
167
+ def set_plot_style(fig):
168
+ """Configure le style des graphiques"""
169
+ colors = ['#1E88E5', '#90CAF9', '#0D47A1', '#42A5F5']
 
170
  for ax in fig.axes:
171
  ax.set_facecolor('#F8F9FA')
172
+ ax.grid(True, linestyle='--', alpha=0.3, color='#666666')
173
  ax.spines['top'].set_visible(False)
174
  ax.spines['right'].set_visible(False)
175
+ ax.tick_params(axis='both', colors='#666666')
176
+ ax.set_axisbelow(True)
177
+ return fig, colors
178
 
179
  def plot_model_performance(results):
180
  metrics = ['accuracy', 'f1', 'precision', 'recall', 'roc_auc']
181
  fig, axes = plt.subplots(1, 2, figsize=(15, 6))
182
+ fig, colors = set_plot_style(fig)
 
 
 
183
 
184
  # Training metrics
185
  train_data = {model: [results[model]['train_metrics'][metric] for metric in metrics]
186
  for model in results.keys()}
187
  train_df = pd.DataFrame(train_data, index=metrics)
188
+ train_df.plot(kind='bar', ax=axes[0], color=colors)
189
+ axes[0].set_title('Performance d\'Entraînement', color='#0D47A1', pad=20)
190
  axes[0].set_ylim(0, 1)
191
 
192
  # Test metrics
193
  test_data = {model: [results[model]['test_metrics'][metric] for metric in metrics]
194
  for model in results.keys()}
195
  test_df = pd.DataFrame(test_data, index=metrics)
196
+ test_df.plot(kind='bar', ax=axes[1], color=colors)
197
+ axes[1].set_title('Performance de Test', color='#0D47A1', pad=20)
198
  axes[1].set_ylim(0, 1)
199
 
200
  # Style des graphiques
201
  for ax in axes:
 
 
 
 
202
  plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
203
+ ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
204
 
205
  plt.tight_layout()
206
  return fig
207
 
208
+ def plot_feature_importance(model, feature_names, model_type):
209
+ fig, ax = plt.subplots(figsize=(10, 6))
210
+ fig, colors = set_plot_style(fig)
211
+
212
+ if model_type in ["Decision Tree", "Random Forest", "Gradient Boost"]:
213
+ importance = model.feature_importances_
214
+ elif model_type == "Logistic Regression":
215
+ importance = np.abs(model.coef_[0])
216
+
217
+ importance_df = pd.DataFrame({
218
+ 'feature': feature_names,
219
+ 'importance': importance
220
+ }).sort_values('importance', ascending=True)
221
+
222
+ ax.barh(importance_df['feature'], importance_df['importance'],
223
+ color='#1E88E5', alpha=0.8)
224
+ ax.set_title("Importance des Caractéristiques", color='#0D47A1', pad=20)
225
+
226
+ return fig
227
+
228
+ def plot_correlation_matrix(data):
229
+ fig, ax = plt.subplots(figsize=(10, 8))
230
+ fig, _ = set_plot_style(fig)
231
+
232
+ sns.heatmap(data.corr(), annot=True, cmap='coolwarm', center=0,
233
+ ax=ax, fmt='.2f', square=True)
234
+ ax.set_title("Matrice de Corrélation", color='#0D47A1', pad=20)
235
+
236
+ return fig
237
+
238
  def app():
239
+ st.markdown('<h1 class="main-header">Interpréteur de Modèles ML</h1>',
240
+ unsafe_allow_html=True)
241
 
242
  # Load data
243
  X_train, y_train, X_test, y_test, feature_names = load_data()
 
247
  with st.spinner("🔄 Entraînement des modèles en cours..."):
248
  st.session_state.model_results = train_models(X_train, y_train, X_test, y_test)
249
 
250
+ # Sidebar
251
  with st.sidebar:
252
+ st.markdown('<h2 style="color: #1E88E5;">Navigation</h2>',
253
+ unsafe_allow_html=True)
254
+
255
  selected_model = st.selectbox(
256
  "📊 Sélectionnez un modèle",
257
  list(st.session_state.model_results.keys())
 
269
 
270
  current_model = st.session_state.model_results[selected_model]['model']
271
 
272
+ # Main content
273
+ if page == "Performance des modèles":
274
+ st.markdown('<h2 class="sub-header">Performance des modèles</h2>',
275
+ unsafe_allow_html=True)
276
+
277
+ performance_fig = plot_model_performance(st.session_state.model_results)
278
+ st.pyplot(performance_fig)
279
+
280
+ st.markdown('<h3 class="sub-header">Métriques détaillées</h3>',
281
+ unsafe_allow_html=True)
282
+
283
+ col1, col2 = st.columns(2)
284
+ with col1:
285
+ st.markdown('<h4 style="color: #1E88E5;">Entraînement</h4>',
286
+ unsafe_allow_html=True)
287
+ for metric, value in st.session_state.model_results[selected_model]['train_metrics'].items():
288
+ st.markdown(custom_metric_card(metric.capitalize(), value),
289
+ unsafe_allow_html=True)
290
+
291
+ with col2:
292
+ st.markdown('<h4 style="color: #1E88E5;">Test</h4>',
293
+ unsafe_allow_html=True)
294
+ for metric, value in st.session_state.model_results[selected_model]['test_metrics'].items():
295
+ st.markdown(custom_metric_card(metric.capitalize(), value),
296
+ unsafe_allow_html=True)
297
+
298
+ elif page == "Analyse des caractéristiques":
299
+ st.markdown('<h2 class="sub-header">Analyse des caractéristiques</h2>',
300
+ unsafe_allow_html=True)
301
+
302
+ importance_fig = plot_feature_importance(current_model, feature_names, selected_model)
303
+ st.pyplot(importance_fig)
304
 
305
+ st.markdown('<h3 class="sub-header">Corrélations</h3>',
306
+ unsafe_allow_html=True)
307
+ corr_fig = plot_correlation_matrix(X_train)
308
+ st.pyplot(corr_fig)
309
 
310
  if __name__ == "__main__":
311
  app()