Roni Goldshmidt committed
Commit · 796891d · 1 Parent(s): 0306e3a
Initial leaderboard setup
Browse files
- README-Copy1.md +14 -0
- app.py +284 -0
- comparison.py +670 -0
- nexar_logo.png +0 -0
- requirements.txt +6 -0
- results/GPT-4o.csv +0 -0
- results/Gemini-2.0-flash-lite.csv +0 -0
- results/Gemini-2.0-flash.csv +0 -0
- results/Gemini-2.0-pro.csv +0 -0
- results/Labels.csv +0 -0
README-Copy1.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: Nexar Dashcam Leaderboard
+emoji: 🌖
+colorFrom: green
+colorTo: red
+sdk: streamlit
+sdk_version: 1.42.0
+app_file: app.py
+pinned: false
+license: mit
+short_description: Benchmarking driving event classification & visual insights
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,284 @@
+import streamlit as st
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from comparison import ModelEvaluator, ModelComparison
+import matplotlib.pyplot as plt
+import seaborn as sns
+import io
+import os
+import base64
+
+st.set_page_config(
+    page_title="Nexar Dashcam Leaderboard",
+    page_icon="nexar_logo.png",
+    layout="wide"
+)
+
+st.markdown("""
+    <style>
+    .main { padding: 2rem; }
+    .stTabs [data-baseweb="tab-list"] { gap: 8px; }
+    .stTabs [data-baseweb="tab"] {
+        padding: 8px 16px;
+        border-radius: 4px;
+    }
+    .metric-card {
+        background-color: #f8f9fa;
+        padding: 20px;
+        border-radius: 10px;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+    </style>
+""", unsafe_allow_html=True)
+
+col1, col2 = st.columns([0.15, 0.85])
+with col1:
+    st.image("nexar_logo.png", width=600)
+with col2:
+    st.title("Nexar Dashcam Leaderboard")
+
+@st.cache_data
+def load_data(directory='results', labels_filename='Labels.csv'):
+    labels_path = os.path.join(directory, labels_filename)
+    df_labels = pd.read_csv(labels_path)
+
+    evaluators = []
+
+    for filename in os.listdir(directory):
+        if filename.endswith('.csv') and filename != labels_filename:
+            model_name = os.path.splitext(filename)[0]
+            df_model = pd.read_csv(os.path.join(directory, filename))
+            evaluator = ModelEvaluator(df_labels, df_model, model_name)
+            evaluators.append(evaluator)
+
+    model_comparison = ModelComparison(evaluators)
+
+    return model_comparison
+
+if 'model_comparison' not in st.session_state:
+    st.session_state.model_comparison = load_data()
+    st.session_state.leaderboard_df = st.session_state.model_comparison.transform_to_leaderboard()
+    st.session_state.combined_df = st.session_state.model_comparison.combined_df
+
+tab1, tab2, tab3, tab4 = st.tabs([
+    "📈 Leaderboard",
+    "🎯 Category Analysis",
+    "📊 Class Performance",
+    "🔍 Detailed Metrics"
+])
+
+def style_dataframe(df):
+    numeric_cols = df.select_dtypes(include=['float64']).columns
+
+    def background_gradient(s):
+        normalized = (s - s.min()) / (s.max() - s.min())
+        normalized = normalized.fillna(0)  # Handle NaN values
+        return ['background: linear-gradient(90deg, rgba(52, 152, 219, 0.2) {}%, transparent {}%)'.format(
+            int(val * 100), int(val * 100)) for val in normalized]
+
+    def highlight_max(s):
+        is_max = s == s.max()
+        return ['font-weight: bold; color: #2ecc71' if v else '' for v in is_max]
+
+    styled = df.style\
+        .format({col: '{:.2f}%' for col in numeric_cols})\
+        .apply(background_gradient, subset=numeric_cols)\
+        .apply(highlight_max, subset=numeric_cols)\
+        .set_properties(**{
+            'background-color': '#f8f9fa',
+            'padding': '10px',
+            'border': '1px solid #dee2e6',
+            'text-align': 'center'
+        })\
+        .set_table_styles([
+            {'selector': 'th', 'props': [
+                ('background-color', '#4a90e2'),
+                ('color', 'white'),
+                ('font-weight', 'bold'),
+                ('padding', '10px'),
+                ('text-align', 'center')
+            ]},
+            {'selector': 'tr:hover', 'props': [
+                ('background-color', '#edf2f7')
+            ]}
+        ])
+
+    return styled
+
+with tab1:
+    st.subheader("Model Performance Leaderboard")
+
+    sort_col = st.selectbox(
+        "Sort by metric:",
+        options=[col for col in st.session_state.leaderboard_df.columns if col not in ['Rank', 'Model']],
+        key='leaderboard_sort'
+    )
+
+    sorted_df = st.session_state.leaderboard_df.sort_values(by=sort_col, ascending=False)
+    st.dataframe(
+        style_dataframe(sorted_df),
+        use_container_width=True,
+        height=400
+    )
+
+    # Category performance bar plot
+    metrics = ['F1 Score', 'Precision', 'Recall']
+    selected_metric = st.selectbox("Select Metric for Category Analysis:", metrics)
+
+    category_data = st.session_state.combined_df[
+        st.session_state.combined_df['Class'].str.contains('Overall')
+    ]
+
+    fig = px.bar(
+        category_data,
+        x='Category',
+        y=selected_metric,
+        color='Model',
+        barmode='group',
+        title=f'Category-level {selected_metric} by Model',
+    )
+
+    fig.update_layout(
+        xaxis_title="Category",
+        yaxis_title=selected_metric,
+        legend_title="Model"
+    )
+
+    st.plotly_chart(fig, use_container_width=True)
+
+with tab2:
+    st.subheader("Category-level Analysis")
+
+    categories = st.session_state.combined_df['Category'].unique()
+    selected_category = st.selectbox("Select Category:", categories)
+
+    col1, col2 = st.columns(2)
+
+    with col1:
+        category_data = st.session_state.combined_df[
+            st.session_state.combined_df['Class'].str.contains('Overall')
+        ]
+
+        fig = px.bar(
+            category_data,
+            x='Category',
+            y=selected_metric,
+            color='Model',
+            barmode='group',
+            title=f'{selected_metric} by Category'
+        )
+        st.plotly_chart(fig, use_container_width=True)
+
+    with col2:
+        cat_data = st.session_state.combined_df[
+            (st.session_state.combined_df['Category'] == selected_category) &
+            (~st.session_state.combined_df['Class'].str.contains('Overall'))
+        ]
+
+        fig = go.Figure()
+
+        for model in cat_data['Model'].unique():
+            model_data = cat_data[cat_data['Model'] == model]
+            fig.add_trace(go.Scatterpolar(
+                r=model_data[selected_metric],
+                theta=model_data['Class'],
+                name=model,
+                fill='toself'
+            ))
+
+        fig.update_layout(
+            polar=dict(
+                radialaxis=dict(
+                    visible=True,
+                    range=[0, 1]
+                )
+            ),
+            showlegend=True,
+            title=f'{selected_metric} Distribution for {selected_category}'
+        )
+        st.plotly_chart(fig, use_container_width=True)
+
+with tab3:
+    st.subheader("Class-level Performance")
+
+    col1, col2, col3 = st.columns(3)
+    with col1:
+        selected_category = st.selectbox(
+            "Select Category:",
+            categories,
+            key='class_category'
+        )
+    with col2:
+        selected_metric = st.selectbox(
+            "Select Metric:",
+            metrics,
+            key='class_metric'
+        )
+    with col3:
+        selected_models = st.multiselect(
+            "Select Models:",
+            st.session_state.combined_df['Model'].unique(),
+            default=st.session_state.combined_df['Model'].unique()
+        )
+
+    class_data = st.session_state.combined_df[
+        (st.session_state.combined_df['Category'] == selected_category) &
+        (~st.session_state.combined_df['Class'].str.contains('Overall')) &
+        (st.session_state.combined_df['Model'].isin(selected_models))
+    ]
+
+    fig = px.bar(
+        class_data,
+        x='Class',
+        y=selected_metric,
+        color='Model',
+        barmode='group',
+        title=f'{selected_metric} by Class for {selected_category}'
+    )
+    st.plotly_chart(fig, use_container_width=True)
+
+    fig = px.scatter(
+        class_data,
+        x='Precision',
+        y='Recall',
+        color='Model',
+        size='Support',
+        hover_data=['Class'],
+        title=f'Precision vs Recall for {selected_category}'
+    )
+    fig.update_traces(marker=dict(sizeref=2.*max(class_data['Support'])/40.**2))
+    st.plotly_chart(fig, use_container_width=True)
+
+with tab4:
+    st.subheader("Detailed Metrics Analysis")
+
+    selected_model = st.selectbox(
+        "Select Model for Detailed Analysis:",
+        st.session_state.combined_df['Model'].unique()
+    )
+
+    model_data = st.session_state.combined_df[
+        st.session_state.combined_df['Model'] == selected_model
+    ]
+
+    st.markdown("### Detailed Metrics Table")
+    detailed_metrics = model_data.pivot_table(
+        index='Category',
+        columns='Class',
+        values=['F1 Score', 'Precision', 'Recall']
+    ).round(4)
+
+    st.dataframe(style_dataframe(detailed_metrics), use_container_width=True)
+
+    csv = detailed_metrics.to_csv().encode()
+    st.download_button(
+        "Download Detailed Metrics",
+        csv,
+        f"detailed_metrics_{selected_model}.csv",
+        "text/csv",
+        key='download-csv'
+    )
+
+st.markdown("---")
+st.markdown("Dashboard created for model evaluation and comparison.")
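A note on the loader above: load_data treats every *.csv under results/ other than Labels.csv as one model's predictions, using the file stem as the model name, so adding a model to the leaderboard should amount to dropping one more CSV into that directory. The layout this commit implies (file names taken from the commit itself; the column contract is an assumption inferred from the 'id' merge key and the category list in comparison.py below):

results/
├── Labels.csv                  # ground truth: an 'id' column plus one column per category
├── GPT-4o.csv                  # one predictions file per model, same columns as Labels.csv
├── Gemini-2.0-flash.csv
├── Gemini-2.0-flash-lite.csv
└── Gemini-2.0-pro.csv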
comparison.py
ADDED
@@ -0,0 +1,670 @@
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, balanced_accuracy_score
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning, message="y_pred contains classes not in y_true")
+sns.set_style("whitegrid")
+
+class ModelEvaluator:
+    def __init__(self, df_labels, df_predictions, model_name):
+        """
+        Initialize the evaluator with ground truth labels and model predictions.
+        """
+        self.df_labels = df_labels
+        self.df_predictions = df_predictions
+        self.model_name = model_name
+        self.metrics_df = self.compute_metrics()
+
+    def merge_data(self):
+        """Merge ground truth labels with predictions based on 'id'."""
+        merged_df = pd.merge(self.df_labels, self.df_predictions, on='id', suffixes=('_true', '_pred'))
+        return merged_df
+
+    def compute_metrics(self):
+        """Compute precision, recall, F1-score, accuracy, and balanced accuracy for each class and category."""
+        merged_df = self.merge_data()
+        categories = ['main-event', 'location', 'zone', 'light-conditions', 'weather-conditions', 'vehicles-density']
+
+        results = []
+
+        for category in categories:
+            y_true = merged_df[f"{category}_true"].astype(str)
+            y_pred = merged_df[f"{category}_pred"].astype(str)
+
+            labels = sorted(set(y_true) | set(y_pred))
+
+            class_precisions = precision_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
+            class_recalls = recall_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
+            class_f1 = f1_score(y_true, y_pred, labels=labels, average=None, zero_division=0)
+
+            overall_precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
+            overall_recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
+            overall_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
+            overall_accuracy = accuracy_score(y_true, y_pred)
+            overall_balanced_acc = balanced_accuracy_score(y_true, y_pred)
+
+            for i, label in enumerate(labels):
+                results.append({
+                    "Model": self.model_name,
+                    "Category": category,
+                    "Class": label,
+                    "Precision": class_precisions[i],
+                    "Recall": class_recalls[i],
+                    "F1 Score": class_f1[i],
+                    "Accuracy": np.nan,
+                    "Balanced Acc.": np.nan,
+                    "Support": (y_true == label).sum()
+                })
+
+            results.append({
+                "Model": self.model_name,
+                "Category": category,
+                "Class": f"Overall ({category})",
+                "Precision": overall_precision,
+                "Recall": overall_recall,
+                "F1 Score": overall_f1,
+                "Accuracy": overall_accuracy,
+                "Balanced Acc.": overall_balanced_acc,
+                "Support": len(y_true)
+            })
+
+        df_res = pd.DataFrame(results)
+        return df_res.loc[df_res['Support']>0].reset_index(drop=True)
+
+    def get_metrics_df(self):
+        """Return the computed metrics DataFrame."""
+        return self.metrics_df
+
+
+class ModelComparison:
+    def __init__(self, evaluators):
+        """
+        Compare multiple models based on their evaluation results.
+
+        :param evaluators: List of ModelEvaluator instances.
+        """
+        self.evaluators = evaluators
+        self.combined_df = self.aggregate_metrics()
+
+    def aggregate_metrics(self):
+        """Merge evaluation metrics from multiple models into a single DataFrame."""
+        dfs = [evaluator.get_metrics_df() for evaluator in self.evaluators]
+        return pd.concat(dfs, ignore_index=True)
+
+    def plot_category_comparison(self, metric="F1 Score"):
+        """Compare models at the category level using a grouped bar chart with consistent styling."""
+        df = self.combined_df[self.combined_df['Class'].str.contains("Overall")]
+
+        plt.figure(figsize=(12, 6))
+        colors = sns.color_palette("Set2", len(df["Model"].unique()))  # Consistent palette
+
+        ax = sns.barplot(
+            data=df, x="Category", y=metric, hue="Model", palette=colors, edgecolor="black", alpha=0.85
+        )
+
+        plt.title(f"{metric} Comparison Across Categories", fontsize=14, fontweight="bold")
+        plt.ylim(0, 1)
+        plt.xticks(rotation=45, fontsize=12)
+        plt.yticks(fontsize=12)
+        plt.xlabel("Category", fontsize=12)
+        plt.ylabel(metric, fontsize=12)
+        plt.legend(title="Model", fontsize=11, loc="upper left")
+        plt.grid(axis="y", linestyle="--", alpha=0.6)
+
+        plt.tight_layout()
+        plt.show()
+
+
+    def plot_per_class_comparison(self, category, metric="F1 Score"):
+        """Compare models for a specific category across individual classes with a standardized design."""
+        df = self.combined_df[(self.combined_df["Category"] == category) & (~self.combined_df["Class"].str.contains("Overall"))]
+
+        plt.figure(figsize=(12, 6))
+        colors = sns.color_palette("Set2", len(df["Model"].unique()))  # Consistent palette
+
+        ax = sns.barplot(
+            data=df, x="Class", y=metric, hue="Model", palette=colors, edgecolor="black", alpha=0.85
+        )
+
+        plt.title(f"{metric} for {category} by Model", fontsize=14, fontweight="bold")
+        plt.ylim(0, 1)
+        plt.xticks(rotation=45, fontsize=12)
+        plt.yticks(fontsize=12)
+        plt.xlabel("Class", fontsize=12)
+        plt.ylabel(metric, fontsize=12)
+        plt.legend(title="Model", fontsize=11, loc="upper left")
+        plt.grid(axis="y", linestyle="--", alpha=0.6)
+
+        plt.tight_layout()
+        plt.show()
+
+    def plot_precision_recall_per_class(self, class_name=None):
+        """
+        Creates a grouped bar chart per class, displaying precision & recall side by side for all models.
+        Ensures a consistent design with plot_per_class_comparison and plot_category_comparison.
+
+        :param class_name: (str) If provided, only this class will be plotted. If None, all classes will be plotted.
+        """
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+        import numpy as np
+
+        sns.set_style("whitegrid")
+
+        # Determine which classes to plot
+        if class_name:
+            unique_classes = [class_name]
+        else:
+            unique_classes = self.combined_df["Class"].unique()
+
+        models = self.combined_df["Model"].unique()
+        num_models = len(models)
+
+        bar_width = 0.35  # Standardized width for better readability
+        spacing = 0  # No extra spacing to match other plots
+
+        colors = sns.color_palette("Set2", num_models)  # Consistent color palette
+
+        for class_name in unique_classes:
+            df_class = self.combined_df[self.combined_df["Class"] == class_name]
+
+            if df_class.empty:
+                print(f"No data available for class: {class_name}")
+                continue
+
+            plt.figure(figsize=(12, 6))
+
+            metrics = ["Precision", "Recall"]
+            x_indices = np.arange(len(metrics))  # X positions for metrics
+
+            for i, model in enumerate(models):
+                df_model = df_class[df_class["Model"] == model]
+
+                if df_model.empty:
+                    continue
+
+                precision = df_model["Precision"].values[0]
+                recall = df_model["Recall"].values[0]
+
+                # Plot bars for Precision and Recall with consistent style
+                plt.bar(
+                    x_indices + (i * bar_width),  # No spacing, perfectly aligned
+                    [precision, recall],
+                    width=bar_width,
+                    label=model,
+                    color=colors[i],
+                    alpha=0.85,
+                    edgecolor="black"  # Matches the other plots
+                )
+
+            plt.xlabel("Metric", fontsize=12)
+            plt.ylabel("Score", fontsize=12)
+            plt.title(f"Precision & Recall for Class: {class_name}", fontsize=14, fontweight="bold")
+
+            # Adjust x-tick positions to align properly
+            plt.xticks(x_indices + ((bar_width * (num_models - 1)) / 2), metrics, fontsize=12)
+
+            plt.ylim(0, 1)
+            plt.legend(title="Model", fontsize=11, loc="upper left")
+            plt.grid(axis="y", linestyle="--", alpha=0.6)
+
+            plt.tight_layout()
+            plt.show()
+
+    def plot_recall_trends(self, selected_models=None):
+        """
+        Plot recall trends per class across different models, sorted by recall values in descending order.
+
+        :param selected_models: List of model names to include in the plot. If None, all models in the dataset will be used.
+        """
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+        import numpy as np
+
+        sns.set_style("whitegrid")
+
+        # If no specific models are provided, use all available models in the dataset
+        if selected_models is None:
+            selected_models = self.combined_df["Model"].unique().tolist()
+
+        # Filter dataset to include only selected models
+        df_filtered = self.combined_df[self.combined_df["Model"].isin(selected_models)]
+        df_filtered_no_overall = df_filtered[~df_filtered["Class"].str.contains("Overall")]
+
+        # Sort by recall values in descending order
+        df_sorted = df_filtered_no_overall.sort_values(by="Recall", ascending=False)
+
+        plt.figure(figsize=(12, 6))
+        unique_classes = df_sorted["Class"].unique()
+
+        # Define colors for models
+        colors = dict(zip(selected_models, sns.color_palette("Set2", len(selected_models))))
+
+        # Connect corresponding classes across models with thin lines (drawn first)
+        for class_name in unique_classes:
+            class_data = df_sorted[df_sorted["Class"] == class_name]
+            if len(class_data) > 1:
+                plt.plot(
+                    class_data["Class"], class_data["Recall"],
+                    linestyle="-", alpha=0.5, color="gray", linewidth=1.5, zorder=1
+                )
+
+        # Plot scatter points **after** lines to ensure they are on top
+        for model in selected_models:
+            model_data = df_sorted[df_sorted["Model"] == model]
+            plt.scatter(
+                model_data["Class"], model_data["Recall"],
+                label=model, color=colors[model], edgecolor="black", s=120, alpha=1.0, zorder=2
+            )
+
+        plt.xlabel("Class", fontsize=12)
+        plt.ylabel("Recall", fontsize=12)
+        plt.xticks(rotation=45, ha="right", fontsize=12)
+        plt.yticks(fontsize=12)
+        plt.title("Recall per Class for Selected Models (Sorted by Recall)", fontsize=14, fontweight="bold")
+
+        # Move legend to the right
+        plt.legend(title="Model", fontsize=11, loc="upper right", bbox_to_anchor=(1.15, 1))
+
+        plt.grid(axis="y", linestyle="--", alpha=0.6)
+
+        plt.tight_layout()
+        plt.show()
+
+    def plot_metric(self, metric_name, figsize=(10, None), bar_height=0.8, palette="Set2", bar_spacing=0):
+        """
+        Creates a hierarchical visualization of metrics with category headers,
+        sorted by category-average descending. Ensures slight separation between model bars.
+        """
+        colors = sns.color_palette(palette, len(self.evaluators))
+        models = list(self.combined_df["Model"].unique())
+
+        df = self.combined_df.copy()
+        df = df.drop_duplicates(subset=['Category', 'Class', 'Model', metric_name])
+
+        # Calculate average support per class
+        avg_support = df.groupby(['Category', 'Class'])['Support'].mean().round().astype(int)
+
+        # Function to safely retrieve metric values
+        def safe_get_value(model, category, class_name):
+            mask = (
+                (df['Model'] == model) &
+                (df['Category'] == category) &
+                (df['Class'] == class_name)
+            )
+            values = df.loc[mask, metric_name]
+            return values.iloc[0] if not values.empty else np.nan
+
+        # Calculate category averages, excluding 'Global', and sort descending
+        df_no_global = df[df['Category'] != 'Global']
+        cat_avgs = df_no_global.groupby('Category', observed=False)[metric_name].mean()
+        cat_avgs = cat_avgs.sort_values(ascending=False)
+        categories_ordered = list(cat_avgs.index)
+
+        if 'Global' in df['Category'].unique():
+            categories_ordered.append('Global')
+
+        plot_data = []
+        yticks = []
+        ylabels = []
+        y_pos = 0
+        category_positions = {}
+
+        # Process each category and its classes
+        for category in categories_ordered:
+            if category == 'Global':
+                continue
+
+            category_data = df[df['Category'] == category]
+            overall_class_name = f"Overall ({category})"
+            mask_overall = category_data['Class'] == overall_class_name
+            category_data_overall = category_data[mask_overall]
+            category_data_regular = category_data[~mask_overall]
+
+            if not category_data_regular.empty:
+                class_means = category_data_regular.groupby('Class')[metric_name].mean()
+                class_means = class_means.sort_values(ascending=False)
+                sorted_regular_classes = list(class_means.index)
+            else:
+                sorted_regular_classes = []
+
+            # Add category header
+            category_start = y_pos
+            yticks.append(y_pos)
+            ylabels.append(category.upper())
+            y_pos += 1
+
+            # Add regular classes
+            for class_name in sorted_regular_classes:
+                values = {model: safe_get_value(model, category, class_name) for model in models}
+                if any(not np.isnan(v) for v in values.values()):
+                    plot_data.append({
+                        'category': category,
+                        'label': class_name,
+                        'y_pos': y_pos,
+                        'values': values,
+                        'is_category': False
+                    })
+                    support = avg_support.get((category, class_name), 0)
+                    yticks.append(y_pos)
+                    ylabels.append(f" {class_name} (n={support:,})")
+                    y_pos += 1
+
+            # Add overall class if exists
+            if not category_data_overall.empty:
+                values = {model: safe_get_value(model, category, overall_class_name) for model in models}
+                if any(not np.isnan(v) for v in values.values()):
+                    plot_data.append({
+                        'category': category,
+                        'label': overall_class_name,
+                        'y_pos': y_pos,
+                        'values': values,
+                        'is_category': False
+                    })
+                    support = avg_support.get((category, overall_class_name), 0)
+                    yticks.append(y_pos)
+                    ylabels.append(f" {overall_class_name} (n={support:,})")
+                    y_pos += 1
+
+            category_positions[category] = {
+                'start': category_start,
+                'end': y_pos - 1
+            }
+
+            y_pos += 0.5  # Spacing between categories
+
+        # Calculate dynamic figure height based on number of items
+        total_items = len(plot_data) + len(categories_ordered)
+        dynamic_height = max(6, total_items * 0.4)
+        if figsize[1] is None:
+            figsize = (figsize[0], dynamic_height)
+
+        # Plot the bars
+        bar_width = bar_height / len(models)  # No extra spacing
+
+        fig, ax = plt.subplots(figsize=figsize)
+
+        for category in categories_ordered:
+            if category == 'Global':
+                continue
+            cat_start = category_positions[category]['start'] - 0.4
+            cat_end = category_positions[category]['end'] + 0.4
+            ax.axhspan(cat_start, cat_end, color='lightgray', alpha=0.2, zorder=0)
+
+        for i, (model, color) in enumerate(zip(models, colors)):
+            positions = []
+            values = []
+            for item in plot_data:
+                if not item.get('is_category', False):
+                    positions.append(item['y_pos'] + (i - len(models)/2) * bar_width)
+                    values.append(item['values'].get(model, np.nan))
+
+            ax.barh(
+                positions, values, height=bar_width,
+                label=model, color=color, alpha=0.85, edgecolor="black"
+            )
+
+        # Main title
+        ax.set_title(f'{metric_name} Comparison Across Models', fontsize=16, fontweight='bold', pad=20)
+
+        # Adjust axis labels and formatting
+        ax.set_yticks(yticks)
+        ax.set_yticklabels(ylabels, fontsize=10)
+        ax.set_xlabel(metric_name, fontsize=12)
+        ax.grid(True, axis='x', linestyle="--", alpha=0.7)
+
+        # Invert y-axis to align properly
+        ax.invert_yaxis()
+        plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
+
+        # Adjust layout with tighter margins
+        plt.subplots_adjust(left=0.25, right=0.8, top=0.95, bottom=0.1)
+        plt.tight_layout()
+
+        return fig
+
+    def plot_precision_recall_for_category(self, category, palette="Set2"):
+        """
+        Creates a modernized Precision-Recall scatter plot for each class within a given category.
+        """
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+        import math
+        import numpy as np
+
+        # Set modern style
+        plt.rcParams['font.size'] = 12
+
+        # Filter data for the selected category
+        df = self.combined_df[self.combined_df["Category"] == category].copy()
+        if df.empty:
+            print(f"No data available for category: {category}")
+            return None
+
+        # Remove overall category-level rows
+        class_data = df[~df["Class"].str.contains("Overall")]
+
+        # Get unique models and classes
+        models = df["Model"].unique()
+        colors = dict(zip(models, sns.color_palette(palette, len(models))))
+        classes = sorted(class_data["Class"].unique())
+
+        # Determine grid size
+        cols = 2
+        rows = math.ceil(len(classes) / cols)
+
+        # Create figure with adjusted size
+        fig, axes = plt.subplots(rows, cols, figsize=(16, rows * 6))
+
+        # Set global title with better spacing
+        fig.suptitle(f'Precision-Recall Analysis for {category}',
+                     fontsize=20, fontweight='bold', y=1.02)
+
+        # Iterate over classes and create scatter plots
+        for i, class_name in enumerate(classes):
+            row, col = divmod(i, cols)
+            ax = axes[row, col] if rows > 1 else axes[col]  # Ensure indexing works for 1-row cases
+
+            # Create scatter plot
+            class_subset = class_data[class_data["Class"] == class_name]
+            sns.scatterplot(
+                data=class_subset,
+                x="Precision",
+                y="Recall",
+                hue="Model",
+                palette=colors,
+                ax=ax,
+                s=200,
+                alpha=0.85,
+                edgecolor="black"
+            )
+
+            # Add labels with lines for each point
+            for idx, row in class_subset.iterrows():
+                ax.annotate(
+                    row["Model"],
+                    (row["Precision"], row["Recall"]),
+                    xytext=(8, 8), textcoords='offset points',  # Adjusted to reduce overlap
+                    bbox=dict(facecolor='white', alpha=0.7),
+                    arrowprops=dict(
+                        arrowstyle='->',
+                        connectionstyle='arc3,rad=0.2',
+                        color='black'
+                    )
+                )
+
+            ax.set_title(f"Class: {class_name}", fontsize=16, fontweight="bold", pad=20)
+            ax.set_xlim(-0.05, 1.05)
+            ax.set_ylim(-0.05, 1.05)
+            ax.grid(True, linestyle="--", alpha=0.5)
+            ax.set_aspect("equal", adjustable="box")
+
+            # Remove legend as we now have direct labels
+            ax.get_legend().remove()
+
+            # Add labels
+            ax.set_xlabel("Precision", fontsize=14)
+            ax.set_ylabel("Recall", fontsize=14)
+
+        # Remove empty subplots if classes < grid size
+        for j in range(i + 1, rows * cols):
+            fig.delaxes(axes.flatten()[j])
+
+        # Adjust layout with better spacing
+        fig.subplots_adjust(top=0.92, bottom=0.08, left=0.08, right=0.92, hspace=0.35, wspace=0.3)
+
+        return fig
+
+    def plot_normalized_radar_chart(self, metric_name="F1 Score", exclude_categories=None, figsize=(12, 10), palette="Set2"):
+        """
+        Create a normalized radar chart comparing performance across different categories.
+        Each vertex is normalized independently based on its maximum value.
+        """
+        import numpy as np
+        import matplotlib.pyplot as plt
+        import seaborn as sns
+        from matplotlib.patches import Circle
+
+        sns.set_style("whitegrid")
+
+        # Copy data and filter exclusions
+        df = self.combined_df.copy()
+        if exclude_categories:
+            df = df[~df["Category"].isin(exclude_categories)]
+
+        # Get unique categories and models
+        categories = sorted(df["Category"].unique())
+        models = sorted(df["Model"].unique())
+
+        # Define colors for models
+        colors = dict(zip(models, sns.color_palette(palette, len(models))))
+
+        # Create figure
+        fig = plt.figure(figsize=figsize)
+        ax = plt.subplot(111, polar=True)
+
+        # Add subtle background circles
+        for radius in np.linspace(0, 1, 5):
+            circle = Circle((0, 0), radius, transform=ax.transData._b,
+                            fill=True, color='gray', alpha=0.03)
+            ax.add_artist(circle)
+
+        # Number of categories and angles
+        N = len(categories)
+        angles = np.linspace(0, 2 * np.pi, N, endpoint=False).tolist()
+        angles += angles[:1]  # Close the circle
+
+        # Get max values for each category
+        df_overall = df[df["Class"].str.contains("Overall")]
+        max_values = df_overall.groupby("Category")[metric_name].max().to_dict()
+
+        # Store normalized values for all models
+        normalized_values = {}
+
+        # Normalize values for each model
+        for model in models:
+            values = []
+            for cat in categories:
+                val = df_overall[(df_overall["Model"] == model) &
+                                 (df_overall["Category"] == cat)][metric_name].values
+                val = val[0] if len(val) > 0 else 0
+                norm_val = val / max_values[cat] if max_values[cat] > 0 else 0
+                values.append(norm_val)
+            normalized_values[model] = values + [values[0]]
+
+        # Plot each model with improved styling
+        for model, values in normalized_values.items():
+            color = colors[model]
+
+            # Add filled area with gradient
+            ax.fill(angles, values, color=color, alpha=0.15,
+                    edgecolor=color, linewidth=0.5)
+
+            # Add main line
+            ax.plot(angles, values,
+                    linewidth=2.5, linestyle='solid',
+                    label=model, color=color, alpha=0.85,
+                    zorder=5)
+
+        # Adjust axis settings
+        ax.set_theta_offset(np.pi / 2)
+        ax.set_theta_direction(-1)
+        ax.set_yticklabels([])
+
+        # Draw category labels
+        ax.set_thetagrids(np.degrees(angles[:-1]), categories,
+                          fontsize=12, fontweight="bold")
+
+        # Add scale labels with improved positioning
+        for i, (category, angle) in enumerate(zip(categories, angles[:-1])):
+            max_val = max_values[category]
+            scales = np.linspace(0, max_val, 5)
+
+            for j, scale in enumerate(scales):
+                radius = j/4
+
+                # Skip zero as we'll add it separately in the center
+                if radius > 0:
+                    ha = 'center'
+                    va = 'center'
+
+                    ax.text(angle, radius, f'{scale:.2f}',
+                            ha=ha, va=va,
+                            color='gray', fontsize=9, fontweight='bold')
+
+        # Add centered zero
+        ax.text(0, 0, '0.00',
+                ha='center', va='center',
+                color='gray', fontsize=9, fontweight='bold')
+
+        # Customize grid with softer lines
+        ax.grid(True, color='gray', alpha=0.3, linewidth=0.5)
+        ax.yaxis.grid(True, color='gray', alpha=0.3, linewidth=0.5)
+        ax.set_rticks(np.linspace(0, 1, 5))
+
+        # Add a subtle background color to the entire plot
+        ax.set_facecolor('#f8f9fa')
+
+        # Set title and legend
+        plt.title(f'Model Performance Radar Chart - {metric_name}',
+                  pad=20, fontsize=14, fontweight="bold")
+        plt.legend(title="Model", fontsize=11,
+                   loc="upper right", bbox_to_anchor=(1.3, 1))
+
+        # Adjust aspect ratio
+        ax.set_aspect('equal')
+
+        plt.tight_layout()
+        return fig
+
+    def transform_to_leaderboard(self):
+        # Remove the category-level "Overall" rows, keeping only class-level rows
+        df = self.combined_df.copy()
+        df = df[~df['Class'].str.contains('Overall', na=False)]
+
+        # Pivot the table so that each category gets its own set of columns
+        pivoted_df = df.pivot_table(
+            index='Model',
+            columns='Category',
+            values=['Precision', 'Recall', 'F1 Score', 'Accuracy'],
+            aggfunc='mean'  # Take mean in case of multiple entries
+        )
+
+        # Flatten the multi-level columns
+        pivoted_df.columns = ['_'.join(col).strip() for col in pivoted_df.columns.values]
+
+        # Calculate the average F1 score across all categories for ranking
+        pivoted_df['Average F1 Score'] = pivoted_df.filter(like='F1 Score').mean(axis=1)
+
+        # Move 'Average F1 Score' to be the first column after 'Model'
+        pivoted_df = pivoted_df.reset_index()
+        cols = ['Model', 'Average F1 Score'] + [col for col in pivoted_df.columns if col not in ['Model', 'Average F1 Score']]
+        pivoted_df = pivoted_df[cols]
+
+        # Rank models based on their average F1 Score
+        pivoted_df = pivoted_df.sort_values(by='Average F1 Score', ascending=False).reset_index(drop=True)
+        pivoted_df.insert(0, 'Rank', range(1, len(pivoted_df) + 1))
+
+        return pivoted_df
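For anyone reusing these classes outside the Streamlit app, a minimal smoke-test sketch. The toy frames assume the schema that merge_data and compute_metrics imply (a shared 'id' column plus the six hard-coded category columns); 'toy-model' and the single deliberate misclassification are illustrative only:

import pandas as pd
from comparison import ModelEvaluator, ModelComparison

categories = ['main-event', 'location', 'zone',
              'light-conditions', 'weather-conditions', 'vehicles-density']

# Ground truth: an 'id' key plus one string label per category column.
df_labels = pd.DataFrame({'id': [1, 2, 3],
                          **{c: ['a', 'b', 'a'] for c in categories}})

# Predictions: same schema, with one deliberate error on 'main-event'.
df_preds = df_labels.copy()
df_preds.loc[2, 'main-event'] = 'b'

evaluator = ModelEvaluator(df_labels, df_preds, 'toy-model')
print(evaluator.get_metrics_df().head())      # per-class and Overall rows

comparison = ModelComparison([evaluator])
print(comparison.transform_to_leaderboard())  # one ranked row per model

One detail worth noting: transform_to_leaderboard filters out the Overall rows, and Accuracy is only populated on those rows, so the all-NaN Accuracy columns are silently dropped by pivot_table and the leaderboard ends up with Precision, Recall and F1 columns only.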
nexar_logo.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+streamlit
+pandas
+plotly
+seaborn
+scikit-learn
+matplotlib
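With these dependencies installed (pip install -r requirements.txt), the Space should also run locally with the entry point declared in the frontmatter above (sdk: streamlit, app_file: app.py), i.e. streamlit run app.py from the repository root, provided the results/ CSVs below are in place.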
results/GPT-4o.csv
ADDED
The diff for this file is too large to render. See raw diff.
results/Gemini-2.0-flash-lite.csv
ADDED
The diff for this file is too large to render. See raw diff.
results/Gemini-2.0-flash.csv
ADDED
The diff for this file is too large to render. See raw diff.
results/Gemini-2.0-pro.csv
ADDED
The diff for this file is too large to render. See raw diff.
results/Labels.csv
ADDED
The diff for this file is too large to render. See raw diff.