rwightman HF staff commited on
Commit
16d5f81
·
verified ·
1 Parent(s): f301419

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -8,7 +8,6 @@ import re
8
 
9
  def load_leaderboard():
10
  # Load validation / test CSV files
11
-
12
  results_csv_files = {
13
  'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
14
  'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
@@ -32,6 +31,7 @@ def load_leaderboard():
32
  dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
33
  bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
34
  bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
 
35
 
36
  # Clean up dataframes
37
  remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
@@ -88,6 +88,7 @@ def load_leaderboard():
88
 
89
  return merged_dataframes
90
 
 
91
  REGEX_PREFIX = "re:"
92
 
93
  def auto_match(pattern, text):
@@ -99,36 +100,34 @@ def auto_match(pattern, text):
99
  except re.error:
100
  # If it's an invalid regex, return False
101
  return False
102
-
103
  # Check if it's a wildcard pattern
104
  elif any(char in pattern for char in ['*', '?']):
105
  return fnmatch.fnmatch(text.lower(), pattern.lower())
106
-
107
  # If not regex or wildcard, use fuzzy matching
108
  else:
109
  return fuzz.partial_ratio(
110
  pattern.lower(), text.lower(), score_cutoff=90) > 0
111
 
112
-
113
  def filter_leaderboard(df, model_name, sort_by):
114
  if not model_name:
115
  return df.sort_values(by=sort_by, ascending=False)
116
-
117
  mask = df['model'].apply(lambda x: auto_match(model_name, x))
118
  filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
119
-
120
- return filtered_df
121
 
 
122
 
123
- def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
124
  selected_color = 'orange'
125
-
126
  fig = px.scatter(
127
  df,
128
  x=x_axis,
129
  y=y_axis,
130
- log_x=True,
131
- log_y=True,
132
  hover_data=['model'],
133
  trendline='ols',
134
  trendline_options=dict(log_x=True, log_y=True),
@@ -144,19 +143,18 @@ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
144
  legend_labels[False] = f'{model_filter or "all models"}'
145
  else:
146
  legend_labels[False] = f'{model_filter or "all models"}'
147
-
148
  # Update legend
149
  for trace in fig.data:
150
  if isinstance(trace.marker.color, str): # This is for the scatter traces
151
  trace.name = legend_labels.get(trace.marker.color == selected_color, '')
152
-
153
  fig.update_layout(
154
  showlegend=True,
155
  legend_title_text='Model Selection'
156
  )
157
-
158
- return fig
159
 
 
160
 
161
  # Load the leaderboard data
162
  merged_dataframes = load_leaderboard()
@@ -217,7 +215,6 @@ def update_leaderboard_and_plot(
217
 
218
  return display_df, fig
219
 
220
-
221
  with gr.Blocks(title="The timm Leaderboard") as app:
222
  gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
223
  gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
@@ -265,4 +262,4 @@ with gr.Blocks(title="The timm Leaderboard") as app:
265
  log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
266
  update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
267
 
268
- app.launch()
 
8
 
9
  def load_leaderboard():
10
  # Load validation / test CSV files
 
11
  results_csv_files = {
12
  'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
13
  'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
 
31
  dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
32
  bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
33
  bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
34
+ print(bench_dataframes.keys())
35
 
36
  # Clean up dataframes
37
  remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
 
88
 
89
  return merged_dataframes
90
 
91
+
92
  REGEX_PREFIX = "re:"
93
 
94
  def auto_match(pattern, text):
 
100
  except re.error:
101
  # If it's an invalid regex, return False
102
  return False
103
+
104
  # Check if it's a wildcard pattern
105
  elif any(char in pattern for char in ['*', '?']):
106
  return fnmatch.fnmatch(text.lower(), pattern.lower())
107
+
108
  # If not regex or wildcard, use fuzzy matching
109
  else:
110
  return fuzz.partial_ratio(
111
  pattern.lower(), text.lower(), score_cutoff=90) > 0
112
 
 
113
  def filter_leaderboard(df, model_name, sort_by):
114
  if not model_name:
115
  return df.sort_values(by=sort_by, ascending=False)
116
+
117
  mask = df['model'].apply(lambda x: auto_match(model_name, x))
118
  filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
 
 
119
 
120
+ return filtered_df
121
 
122
+ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x, log_y):
123
  selected_color = 'orange'
124
+
125
  fig = px.scatter(
126
  df,
127
  x=x_axis,
128
  y=y_axis,
129
+ log_x=log_x,
130
+ log_y=log_y,
131
  hover_data=['model'],
132
  trendline='ols',
133
  trendline_options=dict(log_x=True, log_y=True),
 
143
  legend_labels[False] = f'{model_filter or "all models"}'
144
  else:
145
  legend_labels[False] = f'{model_filter or "all models"}'
146
+
147
  # Update legend
148
  for trace in fig.data:
149
  if isinstance(trace.marker.color, str): # This is for the scatter traces
150
  trace.name = legend_labels.get(trace.marker.color == selected_color, '')
151
+
152
  fig.update_layout(
153
  showlegend=True,
154
  legend_title_text='Model Selection'
155
  )
 
 
156
 
157
+ return fig
158
 
159
  # Load the leaderboard data
160
  merged_dataframes = load_leaderboard()
 
215
 
216
  return display_df, fig
217
 
 
218
  with gr.Blocks(title="The timm Leaderboard") as app:
219
  gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
220
  gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
 
262
  log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
263
  update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
264
 
265
+ app.launch(debug=True)