Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ import re
|
|
8 |
|
9 |
def load_leaderboard():
|
10 |
# Load validation / test CSV files
|
11 |
-
|
12 |
results_csv_files = {
|
13 |
'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
|
14 |
'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
|
@@ -32,6 +31,7 @@ def load_leaderboard():
|
|
32 |
dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
|
33 |
bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
|
34 |
bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
|
|
|
35 |
|
36 |
# Clean up dataframes
|
37 |
remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
|
@@ -88,6 +88,7 @@ def load_leaderboard():
|
|
88 |
|
89 |
return merged_dataframes
|
90 |
|
|
|
91 |
REGEX_PREFIX = "re:"
|
92 |
|
93 |
def auto_match(pattern, text):
|
@@ -99,36 +100,34 @@ def auto_match(pattern, text):
|
|
99 |
except re.error:
|
100 |
# If it's an invalid regex, return False
|
101 |
return False
|
102 |
-
|
103 |
# Check if it's a wildcard pattern
|
104 |
elif any(char in pattern for char in ['*', '?']):
|
105 |
return fnmatch.fnmatch(text.lower(), pattern.lower())
|
106 |
-
|
107 |
# If not regex or wildcard, use fuzzy matching
|
108 |
else:
|
109 |
return fuzz.partial_ratio(
|
110 |
pattern.lower(), text.lower(), score_cutoff=90) > 0
|
111 |
|
112 |
-
|
113 |
def filter_leaderboard(df, model_name, sort_by):
|
114 |
if not model_name:
|
115 |
return df.sort_values(by=sort_by, ascending=False)
|
116 |
-
|
117 |
mask = df['model'].apply(lambda x: auto_match(model_name, x))
|
118 |
filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
|
119 |
-
|
120 |
-
return filtered_df
|
121 |
|
|
|
122 |
|
123 |
-
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
124 |
selected_color = 'orange'
|
125 |
-
|
126 |
fig = px.scatter(
|
127 |
df,
|
128 |
x=x_axis,
|
129 |
y=y_axis,
|
130 |
-
log_x=
|
131 |
-
log_y=
|
132 |
hover_data=['model'],
|
133 |
trendline='ols',
|
134 |
trendline_options=dict(log_x=True, log_y=True),
|
@@ -144,19 +143,18 @@ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
|
144 |
legend_labels[False] = f'{model_filter or "all models"}'
|
145 |
else:
|
146 |
legend_labels[False] = f'{model_filter or "all models"}'
|
147 |
-
|
148 |
# Update legend
|
149 |
for trace in fig.data:
|
150 |
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
151 |
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
152 |
-
|
153 |
fig.update_layout(
|
154 |
showlegend=True,
|
155 |
legend_title_text='Model Selection'
|
156 |
)
|
157 |
-
|
158 |
-
return fig
|
159 |
|
|
|
160 |
|
161 |
# Load the leaderboard data
|
162 |
merged_dataframes = load_leaderboard()
|
@@ -217,7 +215,6 @@ def update_leaderboard_and_plot(
|
|
217 |
|
218 |
return display_df, fig
|
219 |
|
220 |
-
|
221 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
222 |
gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
|
223 |
gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
|
@@ -265,4 +262,4 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
265 |
log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
266 |
update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
267 |
|
268 |
-
app.launch()
|
|
|
8 |
|
9 |
def load_leaderboard():
|
10 |
# Load validation / test CSV files
|
|
|
11 |
results_csv_files = {
|
12 |
'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
|
13 |
'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
|
|
|
31 |
dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
|
32 |
bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
|
33 |
bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
|
34 |
+
print(bench_dataframes.keys())
|
35 |
|
36 |
# Clean up dataframes
|
37 |
remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
|
|
|
88 |
|
89 |
return merged_dataframes
|
90 |
|
91 |
+
|
92 |
REGEX_PREFIX = "re:"
|
93 |
|
94 |
def auto_match(pattern, text):
|
|
|
100 |
except re.error:
|
101 |
# If it's an invalid regex, return False
|
102 |
return False
|
103 |
+
|
104 |
# Check if it's a wildcard pattern
|
105 |
elif any(char in pattern for char in ['*', '?']):
|
106 |
return fnmatch.fnmatch(text.lower(), pattern.lower())
|
107 |
+
|
108 |
# If not regex or wildcard, use fuzzy matching
|
109 |
else:
|
110 |
return fuzz.partial_ratio(
|
111 |
pattern.lower(), text.lower(), score_cutoff=90) > 0
|
112 |
|
|
|
113 |
def filter_leaderboard(df, model_name, sort_by):
|
114 |
if not model_name:
|
115 |
return df.sort_values(by=sort_by, ascending=False)
|
116 |
+
|
117 |
mask = df['model'].apply(lambda x: auto_match(model_name, x))
|
118 |
filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
|
|
|
|
|
119 |
|
120 |
+
return filtered_df
|
121 |
|
122 |
+
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x, log_y):
|
123 |
selected_color = 'orange'
|
124 |
+
|
125 |
fig = px.scatter(
|
126 |
df,
|
127 |
x=x_axis,
|
128 |
y=y_axis,
|
129 |
+
log_x=log_x,
|
130 |
+
log_y=log_y,
|
131 |
hover_data=['model'],
|
132 |
trendline='ols',
|
133 |
trendline_options=dict(log_x=True, log_y=True),
|
|
|
143 |
legend_labels[False] = f'{model_filter or "all models"}'
|
144 |
else:
|
145 |
legend_labels[False] = f'{model_filter or "all models"}'
|
146 |
+
|
147 |
# Update legend
|
148 |
for trace in fig.data:
|
149 |
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
150 |
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
151 |
+
|
152 |
fig.update_layout(
|
153 |
showlegend=True,
|
154 |
legend_title_text='Model Selection'
|
155 |
)
|
|
|
|
|
156 |
|
157 |
+
return fig
|
158 |
|
159 |
# Load the leaderboard data
|
160 |
merged_dataframes = load_leaderboard()
|
|
|
215 |
|
216 |
return display_df, fig
|
217 |
|
|
|
218 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
219 |
gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
|
220 |
gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
|
|
|
262 |
log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
263 |
update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
264 |
|
265 |
+
app.launch(debug=True)
|