Roni Goldshmidt committed
Commit · dc5408d · 1 Parent(s): 154bb23

Initial leaderboard setup

Files changed:
- .ipynb_checkpoints/app-checkpoint.py  +132 -205
- app.py  +132 -205
.ipynb_checkpoints/app-checkpoint.py
CHANGED

@@ -9,40 +9,15 @@ import io
 import os
 import base64
 
-#
-st.set_page_config(
-
-    page_icon="nexar_logo.png",
-    layout="wide"
-)
+# Initialize session state
+if 'active_tab' not in st.session_state:
+    st.session_state.active_tab = "π Leaderboard"
 
-
-st.markdown("""
-<style>
-.main { padding: 2rem; }
-.stTabs [data-baseweb="tab-list"] { gap: 8px; }
-.stTabs [data-baseweb="tab"] {
-    padding: 8px 16px;
-    border-radius: 4px;
-}
-.metric-card {
-    background-color: #f8f9fa;
-    padding: 20px;
-    border-radius: 10px;
-    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
-}
-</style>
-""", unsafe_allow_html=True)
-
-# Header
-col1, col2 = st.columns([0.16, 0.84])
-with col1:
-    st.image("nexar_logo.png", width=600)
-with col2:
-    st.title("Driving Leaderboard")
+if 'toggle_states' not in st.session_state:
+    st.session_state.toggle_states = {}
 
 # Data loading function
-@st.cache_data
+@st.cache_data
 def load_data(directory='results', labels_filename='Labels.csv'):
     labels_path = os.path.join(directory, labels_filename)
     df_labels = pd.read_csv(labels_path)
@@ -58,15 +33,7 @@ def load_data(directory='results', labels_filename='Labels.csv'):
     model_comparison = ModelComparison(evaluators)
     return model_comparison
 
-#
-if 'model_comparison' not in st.session_state:
-    st.session_state.model_comparison = load_data()
-    st.session_state.leaderboard_df = st.session_state.model_comparison.transform_to_leaderboard()
-    st.session_state.combined_df = st.session_state.model_comparison.combined_df
-
-# Create tabs
-tab1, tab2, tab3, tab4 = st.tabs(["π Leaderboard", "π Class Performance", "π Detailed Metrics", "βοΈ Model Comparison"])
-
+# Helper functions for styling
 def style_dataframe(df, highlight_first_column=True, show_progress_bars=True):
     numeric_cols = df.select_dtypes(include=['float64']).columns
 
@@ -110,8 +77,70 @@ def style_dataframe(df, highlight_first_column=True, show_progress_bars=True):
     ])
     return styled
 
-#
-
+# Toggle state management
+def get_toggle_state(model_name):
+    key = f"toggle_{model_name}"
+    if key not in st.session_state.toggle_states:
+        st.session_state.toggle_states[key] = True
+    return st.session_state.toggle_states[key]
+
+def set_toggle_state(model_name, value):
+    key = f"toggle_{model_name}"
+    st.session_state.toggle_states[key] = value
+
+# Page configuration
+st.set_page_config(
+    page_title="Nexar Driving Leaderboard",
+    page_icon="nexar_logo.png",
+    layout="wide"
+)
+
+# Custom styling
+st.markdown("""
+<style>
+.main { padding: 2rem; }
+.stTabs [data-baseweb="tab-list"] { gap: 8px; }
+.stTabs [data-baseweb="tab"] {
+    padding: 8px 16px;
+    border-radius: 4px;
+}
+.metric-card {
+    background-color: #f8f9fa;
+    padding: 20px;
+    border-radius: 10px;
+    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+}
+</style>
+""", unsafe_allow_html=True)
+
+# Header
+col1, col2 = st.columns([0.16, 0.84])
+with col1:
+    st.image("nexar_logo.png", width=600)
+with col2:
+    st.title("Driving Leaderboard")
+
+# Initialize data in session state
+if 'model_comparison' not in st.session_state:
+    st.session_state.model_comparison = load_data()
+    st.session_state.leaderboard_df = st.session_state.model_comparison.transform_to_leaderboard()
+    st.session_state.combined_df = st.session_state.model_comparison.combined_df
+
+# Tab callback
+def handle_tab_change(tab_name):
+    st.session_state.active_tab = tab_name
+
+# Define tab names
+tab_names = ["π Leaderboard", "π Class Performance", "π Detailed Metrics", "βοΈ Model Comparison"]
+
+# Create tabs
+selected_tab = st.radio("", tab_names, key="tab_selector",
+                        horizontal=True, label_visibility="collapsed",
+                        index=tab_names.index(st.session_state.active_tab))
+handle_tab_change(selected_tab)
+
+# Content based on selected tab
+if st.session_state.active_tab == "π Leaderboard":
     st.subheader("Model Performance Leaderboard")
 
     sort_col = st.selectbox(
@@ -121,11 +150,7 @@ with tab1:
     )
 
     sorted_df = st.session_state.leaderboard_df.sort_values(by=sort_col, ascending=False)
-
-    st.dataframe(
-        style_dataframe(sorted_df),
-        use_container_width=True,
-    )
+    st.dataframe(style_dataframe(sorted_df), use_container_width=True)
 
     metrics = ['F1 Score', 'Precision', 'Recall']
     selected_metric = st.selectbox("Select Metric for Category Analysis:", metrics)
@@ -151,11 +176,10 @@ with tab1:
 
     st.plotly_chart(fig, use_container_width=True)
 
-
-with tab2:
+elif st.session_state.active_tab == "π Class Performance":
     st.subheader("Class-level Performance")
     categories = st.session_state.combined_df['Category'].unique()
-
+    metrics = ['F1 Score', 'Precision', 'Recall']
     col1, col2, col3 = st.columns(3)
     with col1:
         selected_category = st.selectbox(
@@ -170,23 +194,26 @@ with tab2:
             key='class_metric'
         )
     with col3:
+        all_models = sorted(st.session_state.combined_df['Model'].unique())
         selected_models = st.multiselect(
             "Select Models:",
-
-            default=
+            all_models,
+            default=all_models,
+            key='selected_models'
         )
-
-    # Create
+
+    # Create consistent color mapping
     plotly_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
-    model_colors = {model: plotly_colors[i % len(plotly_colors)] for i, model in enumerate(
+    model_colors = {model: plotly_colors[i % len(plotly_colors)] for i, model in enumerate(all_models)}
 
+    # Filter data
     class_data = st.session_state.combined_df[
         (st.session_state.combined_df['Category'] == selected_category) &
         (~st.session_state.combined_df['Class'].str.contains('Overall')) &
         (st.session_state.combined_df['Model'].isin(selected_models))
     ]
 
-    # Bar chart
+    # Bar chart
     fig = px.bar(
         class_data,
         x='Class',
@@ -198,13 +225,11 @@ with tab2:
     )
     st.plotly_chart(fig, use_container_width=True)
 
-    #
+    # Model toggles
+    st.markdown("### Model Visibility Controls")
     models_per_row = 4
     num_rows = (len(selected_models) + models_per_row - 1) // models_per_row
 
-    st.markdown("### Select Models to Display:")
-
-    # Create toggles for models using st.columns
     for row in range(num_rows):
         cols = st.columns(models_per_row)
         for col_idx in range(models_per_row):
@@ -212,50 +237,44 @@ with tab2:
             if model_idx < len(selected_models):
                 model = selected_models[model_idx]
                 container = cols[col_idx].container()
-
-                # Get the consistent color for this model
                 color = model_colors[model]
 
-                #
-                toggle_key = f"toggle_{model}"
-                if toggle_key not in st.session_state:
-                    st.session_state[toggle_key] = True
-
-                # Create colored legend item with HTML
+                # Create colored indicator
                 container.markdown(
                     f"""
-                    <div style='display: flex; align-items: center; margin-bottom: -40px;
+                    <div style='display: flex; align-items: center; margin-bottom: -40px;'>
                     <span style='display: inline-block; width: 12px; height: 12px; background-color: {color}; border-radius: 50%; margin-right: 8px;'></span>
                     </div>
                     """,
                     unsafe_allow_html=True
                 )
 
-                #
-                container.checkbox(
-                    f" {model}",
-                    value=
-                    key=
+                # Toggle checkbox
+                value = container.checkbox(
+                    f" {model}",
+                    value=get_toggle_state(model),
+                    key=f"vis_{model}",
+                    on_change=set_toggle_state,
+                    args=(model, not get_toggle_state(model))
                 )
 
-    #
+    # Precision-Recall plots
+    st.markdown("### Precision-Recall Analysis by Class")
     unique_classes = class_data['Class'].unique()
     num_classes = len(unique_classes)
+    plots_per_row = 3
+    num_plot_rows = (num_classes + plots_per_row - 1) // plots_per_row
 
-
-
-
-
-    for row in range(num_rows):
-        cols = st.columns(3)
-        for col_idx in range(3):
-            class_idx = row * 3 + col_idx
+    for row in range(num_plot_rows):
+        cols = st.columns(plots_per_row)
+        for col_idx in range(plots_per_row):
+            class_idx = row * plots_per_row + col_idx
             if class_idx < num_classes:
                 current_class = unique_classes[class_idx]
 
-                #
                 visible_models = [model for model in selected_models
-                                  if 
+                                  if get_toggle_state(model)]
 
                 class_specific_data = class_data[
                     (class_data['Class'] == current_class) &
@@ -269,18 +288,16 @@ with tab2:
                     color='Model',
                     title=f'Precision vs Recall: {current_class}',
                     height=300,
-                    color_discrete_map=model_colors
+                    color_discrete_map=model_colors
                 )
 
-                # Update layout for better visibility
                 fig.update_layout(
                     xaxis_range=[0, 1],
                     yaxis_range=[0, 1],
                     margin=dict(l=40, r=40, t=40, b=40),
-                    showlegend=False
+                    showlegend=False
                 )
 
-                # Add diagonal reference line
                 fig.add_trace(
                     go.Scatter(
                         x=[0, 1],
@@ -293,74 +310,53 @@ with tab2:
 
                 cols[col_idx].plotly_chart(fig, use_container_width=True)
 
-
-with tab3:
+elif st.session_state.active_tab == "π Detailed Metrics":
     st.subheader("Detailed Metrics Analysis")
 
     selected_model = st.selectbox(
         "Select Model for Detailed Analysis:",
-        st.session_state.combined_df['Model'].unique()
+        st.session_state.combined_df['Model'].unique(),
+        key='detailed_model'
     )
 
     model_data = st.session_state.combined_df[
         st.session_state.combined_df['Model'] == selected_model
     ]
 
-    # Create metrics tables
     st.markdown("### Performance Metrics by Category")
-
-    # Get unique categories and relevant classes for each category
     categories = model_data['Category'].unique()
     metrics = ['F1 Score', 'Precision', 'Recall']
 
-    # Process data for each category
     for category in categories:
         st.markdown(f"#### {category}")
-
-        # Filter data for this category
         category_data = model_data[model_data['Category'] == category].copy()
 
-        #
-        category_metrics = pd.DataFrame()
-
-        # Get classes for this category (excluding 'Overall' prefix)
+        # Get classes excluding Overall
         classes = category_data[~category_data['Class'].str.contains('Overall')]['Class'].unique()
-
-        # Add the overall metric for this category
         overall_data = category_data[category_data['Class'].str.contains('Overall')]
 
-        #
+        # Create metrics DataFrame
         category_metrics = pd.DataFrame(index=classes)
-
-        # Add metrics columns
         for metric in metrics:
-            # Add class-specific metrics
             class_metrics = {}
             for class_name in classes:
                 class_data = category_data[category_data['Class'] == class_name]
                 if not class_data.empty:
                     class_metrics[class_name] = class_data[metric].iloc[0]
-
             category_metrics[metric] = pd.Series(class_metrics)
 
-        # Add overall metrics
+        # Add overall metrics
         if not overall_data.empty:
             overall_row = pd.DataFrame({
                 metric: [overall_data[metric].iloc[0]] for metric in metrics
             }, index=['Overall'])
             category_metrics = pd.concat([overall_row, category_metrics])
 
-
-        styled_metrics = style_dataframe(category_metrics.round(4))
-        st.dataframe(styled_metrics, use_container_width=True)
-
-        # Add spacing between categories
+        st.dataframe(style_dataframe(category_metrics.round(4)), use_container_width=True)
         st.markdown("---")
 
     # Export functionality
     st.markdown("### Export Data")
-
-    # Prepare export data
     export_data = pd.DataFrame()
     for category in categories:
        category_data = model_data[model_data['Category'] == category].copy()
@@ -372,7 +368,6 @@ with tab3:
     ).round(4)
     export_data = pd.concat([export_data, category_metrics])
 
-    # Create download button
     csv = export_data.to_csv().encode()
     st.download_button(
         "Download Detailed Metrics",
@@ -382,31 +377,25 @@ with tab3:
         key='download-csv'
     )
 
-
-with tab4:
+elif st.session_state.active_tab == "βοΈ Model Comparison":
     st.header("Model Comparison Analysis")
 
-    # Create two columns for model selection
     col1, col2 = st.columns(2)
-
-    # Model selection dropdown menus
     with col1:
         model1 = st.selectbox(
             "Select First Model:",
             st.session_state.combined_df['Model'].unique(),
-            key='
+            key='compare_model1'
         )
 
     with col2:
-        # Filter out the first selected model from options
         available_models = [m for m in st.session_state.combined_df['Model'].unique() if m != model1]
         model2 = st.selectbox(
             "Select Second Model:",
             available_models,
-            key='
+            key='compare_model2'
         )
 
-    # Category selection
     selected_category = st.selectbox(
         "Select Category for Comparison:",
         st.session_state.combined_df['Category'].unique(),
@@ -423,26 +412,19 @@ with tab4:
         (st.session_state.combined_df['Model'] == model2) &
         (st.session_state.combined_df['Category'] == selected_category)
     ]
-
-    # Define metrics list
-    metrics = ['F1 Score', 'Precision', 'Recall']
 
-    # Create comparison tables section
     st.subheader("Detailed Metrics Comparison")
+    metrics = ['F1 Score', 'Precision', 'Recall']
 
-    # Create a table for each metric
     for metric in metrics:
         st.markdown(f"#### {metric} Comparison")
-
-        # Prepare data for the metric table
         metric_data = []
+
         for class_name in model1_data['Class'].unique():
-
-
-            m2_value = model2_data[model2_data['Class'] == class_name][metric].iloc[0]
+            m1_value = model1_data[model1_data['Class'] == class_name][metric].iloc[0] * 100
+            m2_value = model2_data[model2_data['Class'] == class_name][metric].iloc[0] * 100
             diff = m1_value - m2_value
 
-            # Add to comparison data
             metric_data.append({
                 'Class': class_name,
                 model1: m1_value,
@@ -450,92 +432,46 @@ with tab4:
                 'Difference': diff
             })
 
-        # Create DataFrame for the metric
         metric_df = pd.DataFrame(metric_data)
-
-        # Style the table
-        def style_metric_table(df):
-            return df.style\
-                .format({
-                    model1: '{:.2f}%',
-                    model2: '{:.2f}%',
-                    'Difference': '{:+.2f}%'
-                })\
-                .background_gradient(
-                    cmap='RdYlGn',
-                    subset=['Difference'],
-                    vmin=-10,
-                    vmax=10
-                )\
-                .set_properties(**{
-                    'text-align': 'center',
-                    'padding': '10px',
-                    'border': '1px solid #dee2e6'
-                })\
-                .set_table_styles([
-                    {'selector': 'th', 'props': [
-                        ('background-color', '#4a90e2'),
-                        ('color', 'white'),
-                        ('font-weight', 'bold'),
-                        ('text-align', 'center'),
-                        ('padding', '10px')
-                    ]}
-                ])
-
-        # Display the styled table
-
-        st.dataframe(
-            style_dataframe(metric_df),
-            use_container_width=True,
-        )
-        # Add visual separator
+        st.dataframe(style_dataframe(metric_df), use_container_width=True)
         st.markdown("---")
 
-    #
+    # Visual comparison
     st.subheader("Visual Performance Analysis")
-
-    # Metric selector for bar chart
     selected_metric = st.selectbox(
         "Select Metric for Comparison:",
         metrics,
-        key='
+        key='visual_compare_metric'
     )
 
-    # Prepare data for
+    # Prepare data for visualization
     comparison_data = pd.DataFrame()
-
-    # Get data for both models
     for idx, (model_name, model_data) in enumerate([(model1, model1_data), (model2, model2_data)]):
-        # Filter out Overall classes and select relevant columns
         model_metrics = model_data[~model_data['Class'].str.contains('Overall', na=False)][['Class', selected_metric]]
         model_metrics = model_metrics.rename(columns={selected_metric: model_name})
 
-        # Merge with existing data or create new DataFrame
         if idx == 0:
             comparison_data = model_metrics
         else:
             comparison_data = comparison_data.merge(model_metrics, on='Class', how='outer')
 
-    #
+    # Bar chart
     fig_bar = go.Figure()
 
-    # Add bars for first model
     fig_bar.add_trace(go.Bar(
         name=model1,
         x=comparison_data['Class'],
-        y=comparison_data[model1],
+        y=comparison_data[model1] * 100,
         marker_color='rgb(55, 83, 109)'
     ))
 
-    # Add bars for second model
     fig_bar.add_trace(go.Bar(
         name=model2,
         x=comparison_data['Class'],
-        y=comparison_data[model2],
+        y=comparison_data[model2] * 100,
         marker_color='rgb(26, 118, 255)'
     ))
 
-    # Update bar chart layout
     fig_bar.update_layout(
         title=f"{selected_metric} Comparison by Class",
         xaxis_title="Class",
@@ -552,23 +488,19 @@ with tab4:
         )
     )
 
-    # Display bar chart
     st.plotly_chart(fig_bar, use_container_width=True)
 
-    #
+    # Precision-Recall Analysis
     st.markdown("#### Precision-Recall Analysis")
 
-    # Filter data for scatter plot
     model1_scatter = model1_data[~model1_data['Class'].str.contains('Overall', na=False)]
     model2_scatter = model2_data[~model2_data['Class'].str.contains('Overall', na=False)]
 
-    # Create scatter plot
     fig_scatter = go.Figure()
 
-    # Add scatter points for first model
     fig_scatter.add_trace(go.Scatter(
-        x=model1_scatter['Precision']*100,
-        y=model1_scatter['Recall']*100,
+        x=model1_scatter['Precision'] * 100,
+        y=model1_scatter['Recall'] * 100,
         mode='markers+text',
         name=model1,
         text=model1_scatter['Class'],
@@ -576,10 +508,9 @@ with tab4:
         marker=dict(size=10)
     ))
 
-    # Add scatter points for second model
     fig_scatter.add_trace(go.Scatter(
-        x=model2_scatter['Precision']*100,
-        y=model2_scatter['Recall']*100,
+        x=model2_scatter['Precision'] * 100,
+        y=model2_scatter['Recall'] * 100,
         mode='markers+text',
         name=model2,
         text=model2_scatter['Class'],
@@ -587,7 +518,6 @@ with tab4:
         marker=dict(size=10)
     ))
 
-    # Add reference line
     fig_scatter.add_trace(go.Scatter(
         x=[0, 100],
         y=[0, 100],
@@ -596,7 +526,6 @@ with tab4:
         showlegend=False
     ))
 
-    # Update scatter plot layout
     fig_scatter.update_layout(
         title="Precision vs Recall Analysis by Class",
         xaxis_title="Precision (%)",
@@ -613,10 +542,8 @@ with tab4:
         )
     )
 
-    # Display scatter plot
     st.plotly_chart(fig_scatter, use_container_width=True)
 
-
 # Footer
 st.markdown("---")
 st.markdown("Dashboard created for model evaluation and comparison")
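The behavioral core of this change is the tab handling: the original st.tabs(...) call is replaced by a horizontal st.radio whose selection is stored in st.session_state.active_tab, so the active tab survives the reruns that widgets inside a tab trigger. A minimal, self-contained sketch of that pattern follows; the tab labels, keys, and placeholder sections are illustrative, not taken from the commit.

import streamlit as st

TAB_NAMES = ["Leaderboard", "Class Performance"]  # hypothetical tab labels

# Keep the active tab across reruns in session state.
if "active_tab" not in st.session_state:
    st.session_state.active_tab = TAB_NAMES[0]

# A horizontal radio acts as the tab bar; its index is restored from session state.
selected = st.radio(
    "Tabs",
    TAB_NAMES,
    index=TAB_NAMES.index(st.session_state.active_tab),
    horizontal=True,
    label_visibility="collapsed",
    key="tab_selector",
)
st.session_state.active_tab = selected

# Render content for whichever tab is active.
if st.session_state.active_tab == "Leaderboard":
    st.subheader("Leaderboard")
elif st.session_state.active_tab == "Class Performance":
    st.subheader("Class Performance")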
app.py
CHANGED

(The app.py diff is identical to the .ipynb_checkpoints/app-checkpoint.py diff shown above: +132 -205.)
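Both files also add get_toggle_state / set_toggle_state helpers backed by st.session_state.toggle_states, with one checkbox per model controlling which models appear in the precision-recall plots. Below is a minimal sketch of that toggle pattern with hypothetical model names and keys; here the callback reads the checkbox's new value back from st.session_state, which is one common way to keep the stored flag in sync with the widget.

import streamlit as st

if "toggle_states" not in st.session_state:
    st.session_state.toggle_states = {}

def get_toggle_state(name):
    # Default every item to visible the first time it is seen.
    return st.session_state.toggle_states.setdefault(f"toggle_{name}", True)

def set_toggle_state(name, value):
    st.session_state.toggle_states[f"toggle_{name}"] = value

models = ["model_a", "model_b"]  # hypothetical model names
for model in models:
    st.checkbox(
        model,
        value=get_toggle_state(model),
        key=f"vis_{model}",
        # The callback runs before the rerun; read the widget's updated value
        # from session state instead of freezing it into args at render time.
        on_change=lambda m=model: set_toggle_state(m, st.session_state[f"vis_{m}"]),
    )

visible = [m for m in models if get_toggle_state(m)]
st.write("Visible models:", visible)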