Roni Goldshmidt committed
Commit dc5408d · 1 Parent(s): 154bb23

Initial leaderboard setup

Files changed (2)
  1. .ipynb_checkpoints/app-checkpoint.py +132 -205
  2. app.py +132 -205
.ipynb_checkpoints/app-checkpoint.py CHANGED
(Changes identical to the app.py diff below; this file is the Jupyter auto-saved checkpoint copy of app.py.)
app.py CHANGED
@@ -9,40 +9,15 @@ import io
9
  import os
10
  import base64
11
 
12
- # Page config
13
- st.set_page_config(
14
- page_title="Nexar Driving Leaderboard",
15
- page_icon="nexar_logo.png",
16
- layout="wide"
17
- )
18
 
19
- # Custom styling
20
- st.markdown("""
21
- <style>
22
- .main { padding: 2rem; }
23
- .stTabs [data-baseweb="tab-list"] { gap: 8px; }
24
- .stTabs [data-baseweb="tab"] {
25
- padding: 8px 16px;
26
- border-radius: 4px;
27
- }
28
- .metric-card {
29
- background-color: #f8f9fa;
30
- padding: 20px;
31
- border-radius: 10px;
32
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
33
- }
34
- </style>
35
- """, unsafe_allow_html=True)
36
-
37
- # Header
38
- col1, col2 = st.columns([0.16, 0.84])
39
- with col1:
40
- st.image("nexar_logo.png", width=600)
41
- with col2:
42
- st.title("Driving Leaderboard")
43
 
44
  # Data loading function
45
- @st.cache_data(experimental_allow_widgets=True)
46
  def load_data(directory='results', labels_filename='Labels.csv'):
47
  labels_path = os.path.join(directory, labels_filename)
48
  df_labels = pd.read_csv(labels_path)
@@ -58,15 +33,7 @@ def load_data(directory='results', labels_filename='Labels.csv'):
58
  model_comparison = ModelComparison(evaluators)
59
  return model_comparison
60
 
61
- # Initialize session state
62
- if 'model_comparison' not in st.session_state:
63
- st.session_state.model_comparison = load_data()
64
- st.session_state.leaderboard_df = st.session_state.model_comparison.transform_to_leaderboard()
65
- st.session_state.combined_df = st.session_state.model_comparison.combined_df
66
-
67
- # Create tabs
68
- tab1, tab2, tab3, tab4 = st.tabs(["📈 Leaderboard", "📊 Class Performance", "🔍 Detailed Metrics", "⚖️ Model Comparison"])
69
-
70
  def style_dataframe(df, highlight_first_column=True, show_progress_bars=True):
71
  numeric_cols = df.select_dtypes(include=['float64']).columns
72
 
@@ -110,8 +77,70 @@ def style_dataframe(df, highlight_first_column=True, show_progress_bars=True):
110
  ])
111
  return styled
112
 
113
- # Tab 1: Leaderboard
114
- with tab1:
115
  st.subheader("Model Performance Leaderboard")
116
 
117
  sort_col = st.selectbox(
@@ -121,11 +150,7 @@ with tab1:
121
  )
122
 
123
  sorted_df = st.session_state.leaderboard_df.sort_values(by=sort_col, ascending=False)
124
-
125
- st.dataframe(
126
- style_dataframe(sorted_df),
127
- use_container_width=True,
128
- )
129
 
130
  metrics = ['F1 Score', 'Precision', 'Recall']
131
  selected_metric = st.selectbox("Select Metric for Category Analysis:", metrics)
@@ -151,11 +176,10 @@ with tab1:
151
 
152
  st.plotly_chart(fig, use_container_width=True)
153
 
154
- # Tab 2: Class Performance
155
- with tab2:
156
  st.subheader("Class-level Performance")
157
  categories = st.session_state.combined_df['Category'].unique()
158
-
159
  col1, col2, col3 = st.columns(3)
160
  with col1:
161
  selected_category = st.selectbox(
@@ -170,23 +194,26 @@ with tab2:
170
  key='class_metric'
171
  )
172
  with col3:
 
173
  selected_models = st.multiselect(
174
  "Select Models:",
175
- st.session_state.combined_df['Model'].unique(),
176
- default=st.session_state.combined_df['Model'].unique()
 
177
  )
178
-
179
- # Create a consistent color mapping for all models
180
  plotly_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
181
- model_colors = {model: plotly_colors[i % len(plotly_colors)] for i, model in enumerate(sorted(st.session_state.combined_df['Model'].unique()))}
182
 
 
183
  class_data = st.session_state.combined_df[
184
  (st.session_state.combined_df['Category'] == selected_category) &
185
  (~st.session_state.combined_df['Class'].str.contains('Overall')) &
186
  (st.session_state.combined_df['Model'].isin(selected_models))
187
  ]
188
 
189
- # Bar chart with consistent colors
190
  fig = px.bar(
191
  class_data,
192
  x='Class',
@@ -198,13 +225,11 @@ with tab2:
198
  )
199
  st.plotly_chart(fig, use_container_width=True)
200
 
201
- # Calculate how many columns we need (aim for about 4-5 models per row)
 
202
  models_per_row = 4
203
  num_rows = (len(selected_models) + models_per_row - 1) // models_per_row
204
 
205
- st.markdown("### Select Models to Display:")
206
-
207
- # Create toggles for models using st.columns
208
  for row in range(num_rows):
209
  cols = st.columns(models_per_row)
210
  for col_idx in range(models_per_row):
@@ -212,50 +237,44 @@ with tab2:
212
  if model_idx < len(selected_models):
213
  model = selected_models[model_idx]
214
  container = cols[col_idx].container()
215
-
216
- # Get the consistent color for this model
217
  color = model_colors[model]
218
 
219
- # Initialize toggle state if needed
220
- toggle_key = f"toggle_{model}"
221
- if toggle_key not in st.session_state:
222
- st.session_state[toggle_key] = True
223
-
224
- # Create colored legend item with HTML
225
  container.markdown(
226
  f"""
227
- <div style='display: flex; align-items: center; margin-bottom: -40px; pointer-events: none;'>
228
  <span style='display: inline-block; width: 12px; height: 12px; background-color: {color}; border-radius: 50%; margin-right: 8px;'></span>
229
  </div>
230
  """,
231
  unsafe_allow_html=True
232
  )
233
 
234
- # Create the checkbox without reassigning to session state
235
- container.checkbox(
236
- f" {model}", # Add some spacing to account for the circle
237
- value=st.session_state[toggle_key],
238
- key=toggle_key # Use toggle_key directly as the key
 
 
239
  )
240
 
241
- # Individual Precision-Recall plots for each class
 
242
  unique_classes = class_data['Class'].unique()
243
  num_classes = len(unique_classes)
 
 
244
 
245
- # Calculate number of rows needed (3 plots per row)
246
- num_rows = (num_classes + 2) // 3 # Using ceiling division
247
-
248
- # Create plots row by row
249
- for row in range(num_rows):
250
- cols = st.columns(3)
251
- for col_idx in range(3):
252
- class_idx = row * 3 + col_idx
253
  if class_idx < num_classes:
254
  current_class = unique_classes[class_idx]
255
 
256
- # Filter data based on visible models
257
  visible_models = [model for model in selected_models
258
- if st.session_state[f"toggle_{model}"]]
259
 
260
  class_specific_data = class_data[
261
  (class_data['Class'] == current_class) &
@@ -269,18 +288,16 @@ with tab2:
269
  color='Model',
270
  title=f'Precision vs Recall: {current_class}',
271
  height=300,
272
- color_discrete_map=model_colors # Use consistent colors
273
  )
274
 
275
- # Update layout for better visibility
276
  fig.update_layout(
277
  xaxis_range=[0, 1],
278
  yaxis_range=[0, 1],
279
  margin=dict(l=40, r=40, t=40, b=40),
280
- showlegend=False # Hide individual legends
281
  )
282
 
283
- # Add diagonal reference line
284
  fig.add_trace(
285
  go.Scatter(
286
  x=[0, 1],
@@ -293,74 +310,53 @@ with tab2:
293
 
294
  cols[col_idx].plotly_chart(fig, use_container_width=True)
295
 
296
- # Tab 3: Detailed Metrics
297
- with tab3:
298
  st.subheader("Detailed Metrics Analysis")
299
 
300
  selected_model = st.selectbox(
301
  "Select Model for Detailed Analysis:",
302
- st.session_state.combined_df['Model'].unique()
 
303
  )
304
 
305
  model_data = st.session_state.combined_df[
306
  st.session_state.combined_df['Model'] == selected_model
307
  ]
308
 
309
- # Create metrics tables
310
  st.markdown("### Performance Metrics by Category")
311
-
312
- # Get unique categories and relevant classes for each category
313
  categories = model_data['Category'].unique()
314
  metrics = ['F1 Score', 'Precision', 'Recall']
315
 
316
- # Process data for each category
317
  for category in categories:
318
  st.markdown(f"#### {category}")
319
-
320
- # Filter data for this category
321
  category_data = model_data[model_data['Category'] == category].copy()
322
 
323
- # Create a clean table for this category
324
- category_metrics = pd.DataFrame()
325
-
326
- # Get classes for this category (excluding 'Overall' prefix)
327
  classes = category_data[~category_data['Class'].str.contains('Overall')]['Class'].unique()
328
-
329
- # Add the overall metric for this category
330
  overall_data = category_data[category_data['Class'].str.contains('Overall')]
331
 
332
- # Initialize the DataFrame with classes as index
333
  category_metrics = pd.DataFrame(index=classes)
334
-
335
- # Add metrics columns
336
  for metric in metrics:
337
- # Add class-specific metrics
338
  class_metrics = {}
339
  for class_name in classes:
340
  class_data = category_data[category_data['Class'] == class_name]
341
  if not class_data.empty:
342
  class_metrics[class_name] = class_data[metric].iloc[0]
343
-
344
  category_metrics[metric] = pd.Series(class_metrics)
345
 
346
- # Add overall metrics as a separate row
347
  if not overall_data.empty:
348
  overall_row = pd.DataFrame({
349
  metric: [overall_data[metric].iloc[0]] for metric in metrics
350
  }, index=['Overall'])
351
  category_metrics = pd.concat([overall_row, category_metrics])
352
 
353
- # Display the table
354
- styled_metrics = style_dataframe(category_metrics.round(4))
355
- st.dataframe(styled_metrics, use_container_width=True)
356
-
357
- # Add spacing between categories
358
  st.markdown("---")
359
 
360
  # Export functionality
361
  st.markdown("### Export Data")
362
-
363
- # Prepare export data
364
  export_data = pd.DataFrame()
365
  for category in categories:
366
  category_data = model_data[model_data['Category'] == category].copy()
@@ -372,7 +368,6 @@ with tab3:
372
  ).round(4)
373
  export_data = pd.concat([export_data, category_metrics])
374
 
375
- # Create download button
376
  csv = export_data.to_csv().encode()
377
  st.download_button(
378
  "Download Detailed Metrics",
@@ -382,31 +377,25 @@ with tab3:
382
  key='download-csv'
383
  )
384
 
385
- # Tab 4: Model Comparison
386
- with tab4:
387
  st.header("Model Comparison Analysis")
388
 
389
- # Create two columns for model selection
390
  col1, col2 = st.columns(2)
391
-
392
- # Model selection dropdown menus
393
  with col1:
394
  model1 = st.selectbox(
395
  "Select First Model:",
396
  st.session_state.combined_df['Model'].unique(),
397
- key='model1'
398
  )
399
 
400
  with col2:
401
- # Filter out the first selected model from options
402
  available_models = [m for m in st.session_state.combined_df['Model'].unique() if m != model1]
403
  model2 = st.selectbox(
404
  "Select Second Model:",
405
  available_models,
406
- key='model2'
407
  )
408
 
409
- # Category selection
410
  selected_category = st.selectbox(
411
  "Select Category for Comparison:",
412
  st.session_state.combined_df['Category'].unique(),
@@ -423,26 +412,19 @@ with tab4:
423
  (st.session_state.combined_df['Model'] == model2) &
424
  (st.session_state.combined_df['Category'] == selected_category)
425
  ]
426
-
427
- # Define metrics list
428
- metrics = ['F1 Score', 'Precision', 'Recall']
429
 
430
- # Create comparison tables section
431
  st.subheader("Detailed Metrics Comparison")
 
432
 
433
- # Create a table for each metric
434
  for metric in metrics:
435
  st.markdown(f"#### {metric} Comparison")
436
-
437
- # Prepare data for the metric table
438
  metric_data = []
 
439
  for class_name in model1_data['Class'].unique():
440
- # Get values for both models
441
- m1_value = model1_data[model1_data['Class'] == class_name][metric].iloc[0]
442
- m2_value = model2_data[model2_data['Class'] == class_name][metric].iloc[0]
443
  diff = m1_value - m2_value
444
 
445
- # Add to comparison data
446
  metric_data.append({
447
  'Class': class_name,
448
  model1: m1_value,
@@ -450,92 +432,46 @@ with tab4:
450
  'Difference': diff
451
  })
452
 
453
- # Create DataFrame for the metric
454
  metric_df = pd.DataFrame(metric_data)
455
-
456
- # Style the table
457
- def style_metric_table(df):
458
- return df.style\
459
- .format({
460
- model1: '{:.2f}%',
461
- model2: '{:.2f}%',
462
- 'Difference': '{:+.2f}%'
463
- })\
464
- .background_gradient(
465
- cmap='RdYlGn',
466
- subset=['Difference'],
467
- vmin=-10,
468
- vmax=10
469
- )\
470
- .set_properties(**{
471
- 'text-align': 'center',
472
- 'padding': '10px',
473
- 'border': '1px solid #dee2e6'
474
- })\
475
- .set_table_styles([
476
- {'selector': 'th', 'props': [
477
- ('background-color', '#4a90e2'),
478
- ('color', 'white'),
479
- ('font-weight', 'bold'),
480
- ('text-align', 'center'),
481
- ('padding', '10px')
482
- ]}
483
- ])
484
-
485
- # Display the styled table
486
-
487
- st.dataframe(
488
- style_dataframe(metric_df),
489
- use_container_width=True,
490
- )
491
- # Add visual separator
492
  st.markdown("---")
493
 
494
- # Visualizations section
495
  st.subheader("Visual Performance Analysis")
496
-
497
- # Metric selector for bar chart
498
  selected_metric = st.selectbox(
499
  "Select Metric for Comparison:",
500
  metrics,
501
- key='compare_metric'
502
  )
503
 
504
- # Prepare data for bar chart
505
  comparison_data = pd.DataFrame()
506
-
507
- # Get data for both models
508
  for idx, (model_name, model_data) in enumerate([(model1, model1_data), (model2, model2_data)]):
509
- # Filter out Overall classes and select relevant columns
510
  model_metrics = model_data[~model_data['Class'].str.contains('Overall', na=False)][['Class', selected_metric]]
511
  model_metrics = model_metrics.rename(columns={selected_metric: model_name})
512
 
513
- # Merge with existing data or create new DataFrame
514
  if idx == 0:
515
  comparison_data = model_metrics
516
  else:
517
  comparison_data = comparison_data.merge(model_metrics, on='Class', how='outer')
518
 
519
- # Create bar chart
520
  fig_bar = go.Figure()
521
 
522
- # Add bars for first model
523
  fig_bar.add_trace(go.Bar(
524
  name=model1,
525
  x=comparison_data['Class'],
526
- y=comparison_data[model1],
527
  marker_color='rgb(55, 83, 109)'
528
  ))
529
 
530
- # Add bars for second model
531
  fig_bar.add_trace(go.Bar(
532
  name=model2,
533
  x=comparison_data['Class'],
534
- y=comparison_data[model2],
535
  marker_color='rgb(26, 118, 255)'
536
  ))
537
 
538
- # Update bar chart layout
539
  fig_bar.update_layout(
540
  title=f"{selected_metric} Comparison by Class",
541
  xaxis_title="Class",
@@ -552,23 +488,19 @@ with tab4:
552
  )
553
  )
554
 
555
- # Display bar chart
556
  st.plotly_chart(fig_bar, use_container_width=True)
557
 
558
- # Create Precision-Recall scatter plot
559
  st.markdown("#### Precision-Recall Analysis")
560
 
561
- # Filter data for scatter plot
562
  model1_scatter = model1_data[~model1_data['Class'].str.contains('Overall', na=False)]
563
  model2_scatter = model2_data[~model2_data['Class'].str.contains('Overall', na=False)]
564
 
565
- # Create scatter plot
566
  fig_scatter = go.Figure()
567
 
568
- # Add scatter points for first model
569
  fig_scatter.add_trace(go.Scatter(
570
- x=model1_scatter['Precision']*100,
571
- y=model1_scatter['Recall']*100,
572
  mode='markers+text',
573
  name=model1,
574
  text=model1_scatter['Class'],
@@ -576,10 +508,9 @@ with tab4:
576
  marker=dict(size=10)
577
  ))
578
 
579
- # Add scatter points for second model
580
  fig_scatter.add_trace(go.Scatter(
581
- x=model2_scatter['Precision']*100,
582
- y=model2_scatter['Recall']*100,
583
  mode='markers+text',
584
  name=model2,
585
  text=model2_scatter['Class'],
@@ -587,7 +518,6 @@ with tab4:
587
  marker=dict(size=10)
588
  ))
589
 
590
- # Add reference line
591
  fig_scatter.add_trace(go.Scatter(
592
  x=[0, 100],
593
  y=[0, 100],
@@ -596,7 +526,6 @@ with tab4:
596
  showlegend=False
597
  ))
598
 
599
- # Update scatter plot layout
600
  fig_scatter.update_layout(
601
  title="Precision vs Recall Analysis by Class",
602
  xaxis_title="Precision (%)",
@@ -613,10 +542,8 @@ with tab4:
613
  )
614
  )
615
 
616
- # Display scatter plot
617
  st.plotly_chart(fig_scatter, use_container_width=True)
618
 
619
-
620
  # Footer
621
  st.markdown("---")
622
  st.markdown("Dashboard created for model evaluation and comparison")
 
9
  import os
10
  import base64
11
 
12
+ # Initialize session state
13
+ if 'active_tab' not in st.session_state:
14
+ st.session_state.active_tab = "📈 Leaderboard"
 
 
 
15
 
16
+ if 'toggle_states' not in st.session_state:
17
+ st.session_state.toggle_states = {}
18
 
19
  # Data loading function
20
+ @st.cache_data
21
  def load_data(directory='results', labels_filename='Labels.csv'):
22
  labels_path = os.path.join(directory, labels_filename)
23
  df_labels = pd.read_csv(labels_path)
 
33
  model_comparison = ModelComparison(evaluators)
34
  return model_comparison
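Note: the new decorator call drops the `experimental_allow_widgets=True` argument and uses the bare `@st.cache_data`, which caches a function's return value keyed on its arguments. A minimal sketch of that pattern, reusing the `results/Labels.csv` default from `load_data` (the helper name here is illustrative):

```python
import pandas as pd
import streamlit as st

@st.cache_data  # result is reused across reruns until the arguments or function body change
def load_csv(path: str) -> pd.DataFrame:
    return pd.read_csv(path)

df = load_csv("results/Labels.csv")  # loaded once, then served from the cache
```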
35
 
36
+ # Helper functions for styling
37
  def style_dataframe(df, highlight_first_column=True, show_progress_bars=True):
38
  numeric_cols = df.select_dtypes(include=['float64']).columns
39
 
 
77
  ])
78
  return styled
79
 
80
+ # Toggle state management
81
+ def get_toggle_state(model_name):
82
+ key = f"toggle_{model_name}"
83
+ if key not in st.session_state.toggle_states:
84
+ st.session_state.toggle_states[key] = True
85
+ return st.session_state.toggle_states[key]
86
+
87
+ def set_toggle_state(model_name, value):
88
+ key = f"toggle_{model_name}"
89
+ st.session_state.toggle_states[key] = value
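Note: `get_toggle_state`/`set_toggle_state` wrap a plain dict stored in `st.session_state` so model-visibility choices survive reruns. A minimal sketch of one way to keep that dict in sync with a checkbox; the commit itself wires it through `on_change`/`args`, and the model name below is hypothetical:

```python
import streamlit as st

if "toggle_states" not in st.session_state:
    st.session_state.toggle_states = {}

def get_toggle_state(model_name):
    # default every model to visible on first access
    return st.session_state.toggle_states.setdefault(f"toggle_{model_name}", True)

def set_toggle_state(model_name, value):
    st.session_state.toggle_states[f"toggle_{model_name}"] = value

model = "example-model"  # hypothetical name for illustration
checked = st.checkbox(model, value=get_toggle_state(model), key=f"vis_{model}")
set_toggle_state(model, checked)  # store the widget's current value for later lookups
```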
90
+
91
+ # Page configuration
92
+ st.set_page_config(
93
+ page_title="Nexar Driving Leaderboard",
94
+ page_icon="nexar_logo.png",
95
+ layout="wide"
96
+ )
97
+
98
+ # Custom styling
99
+ st.markdown("""
100
+ <style>
101
+ .main { padding: 2rem; }
102
+ .stTabs [data-baseweb="tab-list"] { gap: 8px; }
103
+ .stTabs [data-baseweb="tab"] {
104
+ padding: 8px 16px;
105
+ border-radius: 4px;
106
+ }
107
+ .metric-card {
108
+ background-color: #f8f9fa;
109
+ padding: 20px;
110
+ border-radius: 10px;
111
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
112
+ }
113
+ </style>
114
+ """, unsafe_allow_html=True)
115
+
116
+ # Header
117
+ col1, col2 = st.columns([0.16, 0.84])
118
+ with col1:
119
+ st.image("nexar_logo.png", width=600)
120
+ with col2:
121
+ st.title("Driving Leaderboard")
122
+
123
+ # Initialize data in session state
124
+ if 'model_comparison' not in st.session_state:
125
+ st.session_state.model_comparison = load_data()
126
+ st.session_state.leaderboard_df = st.session_state.model_comparison.transform_to_leaderboard()
127
+ st.session_state.combined_df = st.session_state.model_comparison.combined_df
128
+
129
+ # Tab callback
130
+ def handle_tab_change(tab_name):
131
+ st.session_state.active_tab = tab_name
132
+
133
+ # Define tab names
134
+ tab_names = ["📈 Leaderboard", "📊 Class Performance", "🔍 Detailed Metrics", "⚖️ Model Comparison"]
135
+
136
+ # Create tabs
137
+ selected_tab = st.radio("", tab_names, key="tab_selector",
138
+ horizontal=True, label_visibility="collapsed",
139
+ index=tab_names.index(st.session_state.active_tab))
140
+ handle_tab_change(selected_tab)
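Note: replacing `st.tabs` with a horizontal `st.radio` plus an `active_tab` entry in `st.session_state` keeps the selected view stable across reruns, since every widget interaction re-executes the script from the top. A condensed sketch of the pattern, with labels shortened for illustration:

```python
import streamlit as st

TABS = ["Leaderboard", "Class Performance", "Detailed Metrics", "Model Comparison"]

if "active_tab" not in st.session_state:
    st.session_state.active_tab = TABS[0]

# The radio returns the chosen label on every rerun; seeding `index` from
# session state keeps the same "tab" open after other widgets change.
choice = st.radio("Navigation", TABS, horizontal=True,
                  label_visibility="collapsed",
                  index=TABS.index(st.session_state.active_tab))
st.session_state.active_tab = choice

if st.session_state.active_tab == "Leaderboard":
    st.subheader("Model Performance Leaderboard")
```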
141
+
142
+ # Content based on selected tab
143
+ if st.session_state.active_tab == "📈 Leaderboard":
144
  st.subheader("Model Performance Leaderboard")
145
 
146
  sort_col = st.selectbox(
 
150
  )
151
 
152
  sorted_df = st.session_state.leaderboard_df.sort_values(by=sort_col, ascending=False)
153
+ st.dataframe(style_dataframe(sorted_df), use_container_width=True)
 
 
 
 
154
 
155
  metrics = ['F1 Score', 'Precision', 'Recall']
156
  selected_metric = st.selectbox("Select Metric for Category Analysis:", metrics)
 
176
 
177
  st.plotly_chart(fig, use_container_width=True)
178
 
179
+ elif st.session_state.active_tab == "📊 Class Performance":
 
180
  st.subheader("Class-level Performance")
181
  categories = st.session_state.combined_df['Category'].unique()
182
+ metrics = ['F1 Score', 'Precision', 'Recall']
183
  col1, col2, col3 = st.columns(3)
184
  with col1:
185
  selected_category = st.selectbox(
 
194
  key='class_metric'
195
  )
196
  with col3:
197
+ all_models = sorted(st.session_state.combined_df['Model'].unique())
198
  selected_models = st.multiselect(
199
  "Select Models:",
200
+ all_models,
201
+ default=all_models,
202
+ key='selected_models'
203
  )
204
+
205
+ # Create consistent color mapping
206
  plotly_colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A', '#19D3F3', '#FF6692', '#B6E880', '#FF97FF', '#FECB52']
207
+ model_colors = {model: plotly_colors[i % len(plotly_colors)] for i, model in enumerate(all_models)}
208
 
209
+ # Filter data
210
  class_data = st.session_state.combined_df[
211
  (st.session_state.combined_df['Category'] == selected_category) &
212
  (~st.session_state.combined_df['Class'].str.contains('Overall')) &
213
  (st.session_state.combined_df['Model'].isin(selected_models))
214
  ]
215
 
216
+ # Bar chart
217
  fig = px.bar(
218
  class_data,
219
  x='Class',
 
225
  )
226
  st.plotly_chart(fig, use_container_width=True)
227
 
228
+ # Model toggles
229
+ st.markdown("### Model Visibility Controls")
230
  models_per_row = 4
231
  num_rows = (len(selected_models) + models_per_row - 1) // models_per_row
232
 
 
 
 
233
  for row in range(num_rows):
234
  cols = st.columns(models_per_row)
235
  for col_idx in range(models_per_row):
 
237
  if model_idx < len(selected_models):
238
  model = selected_models[model_idx]
239
  container = cols[col_idx].container()
 
 
240
  color = model_colors[model]
241
 
242
+ # Create colored indicator
 
 
 
 
 
243
  container.markdown(
244
  f"""
245
+ <div style='display: flex; align-items: center; margin-bottom: -40px;'>
246
  <span style='display: inline-block; width: 12px; height: 12px; background-color: {color}; border-radius: 50%; margin-right: 8px;'></span>
247
  </div>
248
  """,
249
  unsafe_allow_html=True
250
  )
251
 
252
+ # Toggle checkbox
253
+ value = container.checkbox(
254
+ f" {model}",
255
+ value=get_toggle_state(model),
256
+ key=f"vis_{model}",
257
+ on_change=set_toggle_state,
258
+ args=(model, not get_toggle_state(model))
259
  )
260
 
261
+ # Precision-Recall plots
262
+ st.markdown("### Precision-Recall Analysis by Class")
263
  unique_classes = class_data['Class'].unique()
264
  num_classes = len(unique_classes)
265
+ plots_per_row = 3
266
+ num_plot_rows = (num_classes + plots_per_row - 1) // plots_per_row
267
 
268
+ for row in range(num_plot_rows):
269
+ cols = st.columns(plots_per_row)
270
+ for col_idx in range(plots_per_row):
271
+ class_idx = row * plots_per_row + col_idx
 
 
 
 
272
  if class_idx < num_classes:
273
  current_class = unique_classes[class_idx]
274
 
275
+ # Get visible models
276
  visible_models = [model for model in selected_models
277
+ if get_toggle_state(model)]
278
 
279
  class_specific_data = class_data[
280
  (class_data['Class'] == current_class) &
 
288
  color='Model',
289
  title=f'Precision vs Recall: {current_class}',
290
  height=300,
291
+ color_discrete_map=model_colors
292
  )
293
 
 
294
  fig.update_layout(
295
  xaxis_range=[0, 1],
296
  yaxis_range=[0, 1],
297
  margin=dict(l=40, r=40, t=40, b=40),
298
+ showlegend=False
299
  )
300
 
 
301
  fig.add_trace(
302
  go.Scatter(
303
  x=[0, 1],
 
310
 
311
  cols[col_idx].plotly_chart(fig, use_container_width=True)
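Note: both the model toggles and the per-class plots above use the same layout idiom: ceiling division to work out how many rows of `st.columns` are needed, then an index guard for a partially filled last row. A small self-contained sketch with placeholder items:

```python
import streamlit as st

items = ["car", "truck", "bus", "bike", "pedestrian"]  # placeholder labels
per_row = 3
rows = (len(items) + per_row - 1) // per_row  # ceiling division

for r in range(rows):
    cols = st.columns(per_row)
    for c in range(per_row):
        i = r * per_row + c
        if i < len(items):          # last row may have fewer items than columns
            cols[c].write(items[i])
```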
312
 
313
+ elif st.session_state.active_tab == "🔍 Detailed Metrics":
 
314
  st.subheader("Detailed Metrics Analysis")
315
 
316
  selected_model = st.selectbox(
317
  "Select Model for Detailed Analysis:",
318
+ st.session_state.combined_df['Model'].unique(),
319
+ key='detailed_model'
320
  )
321
 
322
  model_data = st.session_state.combined_df[
323
  st.session_state.combined_df['Model'] == selected_model
324
  ]
325
 
 
326
  st.markdown("### Performance Metrics by Category")
 
 
327
  categories = model_data['Category'].unique()
328
  metrics = ['F1 Score', 'Precision', 'Recall']
329
 
 
330
  for category in categories:
331
  st.markdown(f"#### {category}")
 
 
332
  category_data = model_data[model_data['Category'] == category].copy()
333
 
334
+ # Get classes excluding Overall
 
 
 
335
  classes = category_data[~category_data['Class'].str.contains('Overall')]['Class'].unique()
 
 
336
  overall_data = category_data[category_data['Class'].str.contains('Overall')]
337
 
338
+ # Create metrics DataFrame
339
  category_metrics = pd.DataFrame(index=classes)
 
 
340
  for metric in metrics:
 
341
  class_metrics = {}
342
  for class_name in classes:
343
  class_data = category_data[category_data['Class'] == class_name]
344
  if not class_data.empty:
345
  class_metrics[class_name] = class_data[metric].iloc[0]
 
346
  category_metrics[metric] = pd.Series(class_metrics)
347
 
348
+ # Add overall metrics
349
  if not overall_data.empty:
350
  overall_row = pd.DataFrame({
351
  metric: [overall_data[metric].iloc[0]] for metric in metrics
352
  }, index=['Overall'])
353
  category_metrics = pd.concat([overall_row, category_metrics])
354
 
355
+ st.dataframe(style_dataframe(category_metrics.round(4)), use_container_width=True)
 
 
 
 
356
  st.markdown("---")
357
 
358
  # Export functionality
359
  st.markdown("### Export Data")
 
 
360
  export_data = pd.DataFrame()
361
  for category in categories:
362
  category_data = model_data[model_data['Category'] == category].copy()
 
368
  ).round(4)
369
  export_data = pd.concat([export_data, category_metrics])
370
 
 
371
  csv = export_data.to_csv().encode()
372
  st.download_button(
373
  "Download Detailed Metrics",
 
377
  key='download-csv'
378
  )
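Note: the export block above follows the usual Streamlit CSV-download pattern: serialize the DataFrame, encode it to bytes, and hand it to `st.download_button`. A minimal sketch with stand-in data; the `file_name` and `mime` arguments are assumptions, since the diff elides that part of the call:

```python
import pandas as pd
import streamlit as st

export_data = pd.DataFrame({"Class": ["car", "truck"], "F1 Score": [0.91, 0.78]})  # stand-in rows
csv = export_data.to_csv().encode()

st.download_button(
    "Download Detailed Metrics",
    csv,
    file_name="detailed_metrics.csv",  # assumed name, not taken from the diff
    mime="text/csv",
    key="download-csv",
)
```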
379
 
380
+ elif st.session_state.active_tab == "⚖️ Model Comparison":
 
381
  st.header("Model Comparison Analysis")
382
 
 
383
  col1, col2 = st.columns(2)
 
 
384
  with col1:
385
  model1 = st.selectbox(
386
  "Select First Model:",
387
  st.session_state.combined_df['Model'].unique(),
388
+ key='compare_model1'
389
  )
390
 
391
  with col2:
 
392
  available_models = [m for m in st.session_state.combined_df['Model'].unique() if m != model1]
393
  model2 = st.selectbox(
394
  "Select Second Model:",
395
  available_models,
396
+ key='compare_model2'
397
  )
398
 
 
399
  selected_category = st.selectbox(
400
  "Select Category for Comparison:",
401
  st.session_state.combined_df['Category'].unique(),
 
412
  (st.session_state.combined_df['Model'] == model2) &
413
  (st.session_state.combined_df['Category'] == selected_category)
414
  ]
 
 
 
415
 
 
416
  st.subheader("Detailed Metrics Comparison")
417
+ metrics = ['F1 Score', 'Precision', 'Recall']
418
 
 
419
  for metric in metrics:
420
  st.markdown(f"#### {metric} Comparison")
 
 
421
  metric_data = []
422
+
423
  for class_name in model1_data['Class'].unique():
424
+ m1_value = model1_data[model1_data['Class'] == class_name][metric].iloc[0] * 100
425
+ m2_value = model2_data[model2_data['Class'] == class_name][metric].iloc[0] * 100
 
426
  diff = m1_value - m2_value
427
 
 
428
  metric_data.append({
429
  'Class': class_name,
430
  model1: m1_value,
 
432
  'Difference': diff
433
  })
434
 
 
435
  metric_df = pd.DataFrame(metric_data)
436
+ st.dataframe(style_dataframe(metric_df), use_container_width=True)
437
  st.markdown("---")
438
 
439
+ # Visual comparison
440
  st.subheader("Visual Performance Analysis")
 
 
441
  selected_metric = st.selectbox(
442
  "Select Metric for Comparison:",
443
  metrics,
444
+ key='visual_compare_metric'
445
  )
446
 
447
+ # Prepare data for visualization
448
  comparison_data = pd.DataFrame()
 
 
449
  for idx, (model_name, model_data) in enumerate([(model1, model1_data), (model2, model2_data)]):
 
450
  model_metrics = model_data[~model_data['Class'].str.contains('Overall', na=False)][['Class', selected_metric]]
451
  model_metrics = model_metrics.rename(columns={selected_metric: model_name})
452
 
 
453
  if idx == 0:
454
  comparison_data = model_metrics
455
  else:
456
  comparison_data = comparison_data.merge(model_metrics, on='Class', how='outer')
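Note: the comparison table is assembled by renaming each model's metric column and joining on `Class` with `how='outer'`, so classes present for only one model are kept (with NaN for the other) rather than dropped. A small illustration with made-up scores:

```python
import pandas as pd

# Made-up per-class scores for two models
m1 = pd.DataFrame({"Class": ["car", "truck"], "F1 Score": [0.91, 0.78]})
m2 = pd.DataFrame({"Class": ["car", "bus"], "F1 Score": [0.88, 0.65]})

comparison = (
    m1.rename(columns={"F1 Score": "model_a"})
      .merge(m2.rename(columns={"F1 Score": "model_b"}), on="Class", how="outer")
)
# -> car (0.91, 0.88), truck (0.78, NaN), bus (NaN, 0.65)
print(comparison)
```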
457
 
458
+ # Bar chart
459
  fig_bar = go.Figure()
460
 
 
461
  fig_bar.add_trace(go.Bar(
462
  name=model1,
463
  x=comparison_data['Class'],
464
+ y=comparison_data[model1] * 100,
465
  marker_color='rgb(55, 83, 109)'
466
  ))
467
 
 
468
  fig_bar.add_trace(go.Bar(
469
  name=model2,
470
  x=comparison_data['Class'],
471
+ y=comparison_data[model2] * 100,
472
  marker_color='rgb(26, 118, 255)'
473
  ))
474
 
 
475
  fig_bar.update_layout(
476
  title=f"{selected_metric} Comparison by Class",
477
  xaxis_title="Class",
 
488
  )
489
  )
490
 
 
491
  st.plotly_chart(fig_bar, use_container_width=True)
492
 
493
+ # Precision-Recall Analysis
494
  st.markdown("#### Precision-Recall Analysis")
495
 
 
496
  model1_scatter = model1_data[~model1_data['Class'].str.contains('Overall', na=False)]
497
  model2_scatter = model2_data[~model2_data['Class'].str.contains('Overall', na=False)]
498
 
 
499
  fig_scatter = go.Figure()
500
 
 
501
  fig_scatter.add_trace(go.Scatter(
502
+ x=model1_scatter['Precision'] * 100,
503
+ y=model1_scatter['Recall'] * 100,
504
  mode='markers+text',
505
  name=model1,
506
  text=model1_scatter['Class'],
 
508
  marker=dict(size=10)
509
  ))
510
 
 
511
  fig_scatter.add_trace(go.Scatter(
512
+ x=model2_scatter['Precision'] * 100,
513
+ y=model2_scatter['Recall'] * 100,
514
  mode='markers+text',
515
  name=model2,
516
  text=model2_scatter['Class'],
 
518
  marker=dict(size=10)
519
  ))
520
 
 
521
  fig_scatter.add_trace(go.Scatter(
522
  x=[0, 100],
523
  y=[0, 100],
 
526
  showlegend=False
527
  ))
528
 
 
529
  fig_scatter.update_layout(
530
  title="Precision vs Recall Analysis by Class",
531
  xaxis_title="Precision (%)",
 
542
  )
543
  )
544
 
 
545
  st.plotly_chart(fig_scatter, use_container_width=True)
546
 
 
547
  # Footer
548
  st.markdown("---")
549
  st.markdown("Dashboard created for model evaluation and comparison")