Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ import re
|
|
| 8 |
|
| 9 |
def load_leaderboard():
|
| 10 |
# Load validation / test CSV files
|
| 11 |
-
|
| 12 |
results_csv_files = {
|
| 13 |
'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
|
| 14 |
'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
|
|
@@ -32,6 +31,7 @@ def load_leaderboard():
|
|
| 32 |
dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
|
| 33 |
bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
|
| 34 |
bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
|
|
|
|
| 35 |
|
| 36 |
# Clean up dataframes
|
| 37 |
remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
|
|
@@ -88,6 +88,7 @@ def load_leaderboard():
|
|
| 88 |
|
| 89 |
return merged_dataframes
|
| 90 |
|
|
|
|
| 91 |
REGEX_PREFIX = "re:"
|
| 92 |
|
| 93 |
def auto_match(pattern, text):
|
|
@@ -99,36 +100,34 @@ def auto_match(pattern, text):
|
|
| 99 |
except re.error:
|
| 100 |
# If it's an invalid regex, return False
|
| 101 |
return False
|
| 102 |
-
|
| 103 |
# Check if it's a wildcard pattern
|
| 104 |
elif any(char in pattern for char in ['*', '?']):
|
| 105 |
return fnmatch.fnmatch(text.lower(), pattern.lower())
|
| 106 |
-
|
| 107 |
# If not regex or wildcard, use fuzzy matching
|
| 108 |
else:
|
| 109 |
return fuzz.partial_ratio(
|
| 110 |
pattern.lower(), text.lower(), score_cutoff=90) > 0
|
| 111 |
|
| 112 |
-
|
| 113 |
def filter_leaderboard(df, model_name, sort_by):
|
| 114 |
if not model_name:
|
| 115 |
return df.sort_values(by=sort_by, ascending=False)
|
| 116 |
-
|
| 117 |
mask = df['model'].apply(lambda x: auto_match(model_name, x))
|
| 118 |
filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
|
| 119 |
-
|
| 120 |
-
return filtered_df
|
| 121 |
|
|
|
|
| 122 |
|
| 123 |
-
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
| 124 |
selected_color = 'orange'
|
| 125 |
-
|
| 126 |
fig = px.scatter(
|
| 127 |
df,
|
| 128 |
x=x_axis,
|
| 129 |
y=y_axis,
|
| 130 |
-
log_x=
|
| 131 |
-
log_y=
|
| 132 |
hover_data=['model'],
|
| 133 |
trendline='ols',
|
| 134 |
trendline_options=dict(log_x=True, log_y=True),
|
|
@@ -144,19 +143,18 @@ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
|
| 144 |
legend_labels[False] = f'{model_filter or "all models"}'
|
| 145 |
else:
|
| 146 |
legend_labels[False] = f'{model_filter or "all models"}'
|
| 147 |
-
|
| 148 |
# Update legend
|
| 149 |
for trace in fig.data:
|
| 150 |
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
| 151 |
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
| 152 |
-
|
| 153 |
fig.update_layout(
|
| 154 |
showlegend=True,
|
| 155 |
legend_title_text='Model Selection'
|
| 156 |
)
|
| 157 |
-
|
| 158 |
-
return fig
|
| 159 |
|
|
|
|
| 160 |
|
| 161 |
# Load the leaderboard data
|
| 162 |
merged_dataframes = load_leaderboard()
|
|
@@ -217,7 +215,6 @@ def update_leaderboard_and_plot(
|
|
| 217 |
|
| 218 |
return display_df, fig
|
| 219 |
|
| 220 |
-
|
| 221 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
| 222 |
gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
|
| 223 |
gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
|
|
@@ -265,4 +262,4 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
| 265 |
log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
| 266 |
update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
| 267 |
|
| 268 |
-
app.launch()
|
|
|
|
| 8 |
|
| 9 |
def load_leaderboard():
|
| 10 |
# Load validation / test CSV files
|
|
|
|
| 11 |
results_csv_files = {
|
| 12 |
'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
|
| 13 |
'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
|
|
|
|
| 31 |
dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
|
| 32 |
bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
|
| 33 |
bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
|
| 34 |
+
print(bench_dataframes.keys())
|
| 35 |
|
| 36 |
# Clean up dataframes
|
| 37 |
remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
|
|
|
|
| 88 |
|
| 89 |
return merged_dataframes
|
| 90 |
|
| 91 |
+
|
| 92 |
REGEX_PREFIX = "re:"
|
| 93 |
|
| 94 |
def auto_match(pattern, text):
|
|
|
|
| 100 |
except re.error:
|
| 101 |
# If it's an invalid regex, return False
|
| 102 |
return False
|
| 103 |
+
|
| 104 |
# Check if it's a wildcard pattern
|
| 105 |
elif any(char in pattern for char in ['*', '?']):
|
| 106 |
return fnmatch.fnmatch(text.lower(), pattern.lower())
|
| 107 |
+
|
| 108 |
# If not regex or wildcard, use fuzzy matching
|
| 109 |
else:
|
| 110 |
return fuzz.partial_ratio(
|
| 111 |
pattern.lower(), text.lower(), score_cutoff=90) > 0
|
| 112 |
|
|
|
|
| 113 |
def filter_leaderboard(df, model_name, sort_by):
|
| 114 |
if not model_name:
|
| 115 |
return df.sort_values(by=sort_by, ascending=False)
|
| 116 |
+
|
| 117 |
mask = df['model'].apply(lambda x: auto_match(model_name, x))
|
| 118 |
filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
|
|
|
|
|
|
|
| 119 |
|
| 120 |
+
return filtered_df
|
| 121 |
|
| 122 |
+
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x, log_y):
|
| 123 |
selected_color = 'orange'
|
| 124 |
+
|
| 125 |
fig = px.scatter(
|
| 126 |
df,
|
| 127 |
x=x_axis,
|
| 128 |
y=y_axis,
|
| 129 |
+
log_x=log_x,
|
| 130 |
+
log_y=log_y,
|
| 131 |
hover_data=['model'],
|
| 132 |
trendline='ols',
|
| 133 |
trendline_options=dict(log_x=True, log_y=True),
|
|
|
|
| 143 |
legend_labels[False] = f'{model_filter or "all models"}'
|
| 144 |
else:
|
| 145 |
legend_labels[False] = f'{model_filter or "all models"}'
|
| 146 |
+
|
| 147 |
# Update legend
|
| 148 |
for trace in fig.data:
|
| 149 |
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
| 150 |
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
| 151 |
+
|
| 152 |
fig.update_layout(
|
| 153 |
showlegend=True,
|
| 154 |
legend_title_text='Model Selection'
|
| 155 |
)
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
return fig
|
| 158 |
|
| 159 |
# Load the leaderboard data
|
| 160 |
merged_dataframes = load_leaderboard()
|
|
|
|
| 215 |
|
| 216 |
return display_df, fig
|
| 217 |
|
|
|
|
| 218 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
| 219 |
gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
|
| 220 |
gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
|
|
|
|
| 262 |
log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
| 263 |
update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
|
| 264 |
|
| 265 |
+
app.launch(debug=True)
|