rwightman (HF staff) committed
Commit bd26425 · verified · 1 Parent(s): 6316ef5

Update app.py


Add support for selecting between multiple benchmark sets. Add log-scale toggles for the x/y plot axes.

Files changed (1)
    app.py  (+93 -46)
app.py CHANGED
@@ -17,16 +17,19 @@ def load_leaderboard():
     }
 
     # Load benchmark CSV files
-    main_bench = 'amp-nchw-pt240-cu124-rtx4090'
     benchmark_csv_files = {
         'amp-nchw-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090.csv',
-        'amp-nhwc-pt210-cu121-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt210-cu121-rtx3090.csv',
-        'fp32-nchw-pt221-cpu-i9_10940x-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt221-cpu-i9_10940x-dynamo.csv',
+        'amp-nhwc-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx4090.csv',
+        'amp-nchw-pt240-cu124-rtx4090-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090-dynamo.csv',
+        'amp-nchw-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv',
+        'amp-nhwc-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv',
+        'fp32-nchw-pt240-cpu-i9_10940x-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i9_10940x-dynamo.csv',
+        'fp32-nchw-pt240-cpu-i7_12700h-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i7_12700h-dynamo.csv',
     }
-    # FIXME support selecting benchmark 'infer_samples_per_sec' / 'infer_step_time' from different benchmark files.
 
     dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
     bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
+    bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
     main_bench_dataframe = bench_dataframes[main_bench]
 
     # Clean up dataframes
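The new dict-comprehension guard silently drops any benchmark set whose CSV lacks an `infer_gmacs` column, which the TFLOP/s math in the next hunk depends on. A minimal standalone sketch (not part of the commit) of the same sanity check applied to one CSV before adding it to `benchmark_csv_files`; the required column set is an assumption based only on what this diff touches:

```python
# Standalone sketch: check a candidate benchmark CSV for the columns
# the merge and TFLOP/s calculation below rely on (assumed minimal set).
import pandas as pd

url = 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090.csv'
df = pd.read_csv(url)
required = {'infer_samples_per_sec', 'infer_gmacs'}  # columns visible in this diff
missing = required - set(df.columns)
print('ok' if not missing else f'missing: {sorted(missing)}')
```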
@@ -68,17 +71,31 @@ def load_leaderboard():
     other_columns = [col for col in result.columns if col not in first_columns and col != 'model_benchmark']
     result = result[first_columns + other_columns]
 
-    # Drop columns that are no longer needed / add too much noise
-    result.drop('arch_name', axis=1, inplace=True)
-    result.drop('crop_pct', axis=1, inplace=True)
-    result.drop('interpolation', axis=1, inplace=True)
-
-    result['highlighted'] = False
-
-    # Round numerical values
-    result = result.round(2)
-
-    return result
+    # Create fully merged dataframes for each benchmark set
+    merged_dataframes = {}
+    for bench_name, bench_df in bench_dataframes.items():
+        merged_df = pd.merge(result, bench_df, on=['arch_name', 'img_size'], how='left', suffixes=('', '_benchmark'))
+
+        # Calculate TFLOP/s
+        merged_df['infer_tflop_s'] = merged_df['infer_samples_per_sec'] * merged_df['infer_gmacs'] * 2 / 1000
+
+        # Reorder columns
+        first_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
+        other_columns = [col for col in merged_df.columns if col not in first_columns]
+        merged_df = merged_df[first_columns + other_columns].copy()
+
+        # Drop columns that are no longer needed / add too much noise
+        merged_df.drop('arch_name', axis=1, inplace=True)
+        merged_df.drop('crop_pct', axis=1, inplace=True)
+        merged_df.drop('interpolation', axis=1, inplace=True)
+        merged_df.drop('model_benchmark', axis=1, inplace=True)
+        merged_df['infer_usec_per_sample'] = 1e6 / merged_df.infer_samples_per_sec
+
+        merged_df['highlighted'] = False
+        merged_df = merged_df.round(2)
+        merged_dataframes[bench_name] = merged_df
+
+    return merged_dataframes
 
 
 REGEX_PREFIX = "re:"
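The `infer_tflop_s` column follows the usual 1 MAC = 2 FLOPs convention: GMACs per sample times two gives GFLOPs per sample, multiplying by samples/sec gives GFLOP/s, and dividing by 1000 lands in TFLOP/s. A quick worked sketch with hypothetical numbers:

```python
# Worked sketch of the infer_tflop_s formula above (hypothetical values).
infer_gmacs = 4.1               # GMACs per sample; 1 MAC = 2 FLOPs
infer_samples_per_sec = 1000.0
infer_tflop_s = infer_samples_per_sec * infer_gmacs * 2 / 1000
print(infer_tflop_s)            # 8.2 achieved TFLOP/s
print(1e6 / infer_samples_per_sec)  # 1000.0 usec/sample, as in the new column
```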
@@ -152,16 +169,26 @@ def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
 
 
 # Load the leaderboard data
-full_df = load_leaderboard()
+merged_dataframes = load_leaderboard()
 
 # Define the available columns for sorting and plotting
-sort_columns = ['avg_top1', 'avg_top5', 'infer_samples_per_sec', 'param_count', 'infer_gmacs', 'infer_macts', 'infer_tflop_s']
-plot_columns = ['infer_samples_per_sec', 'infer_gmacs', 'infer_macts', 'infer_tflop_s', 'param_count', 'avg_top1', 'avg_top5']
+sort_columns = ['avg_top1', 'avg_top5', 'imagenet_top1', 'imagenet_top5', 'infer_samples_per_sec', 'infer_usec_per_sample', 'param_count', 'infer_gmacs', 'infer_macts', 'infer_tflop_s']
+plot_columns = ['infer_samples_per_sec', 'infer_usec_per_sample', 'infer_gmacs', 'infer_macts', 'infer_tflop_s', 'param_count', 'avg_top1', 'avg_top5', 'imagenet_top1', 'imagenet_top5']
 
 DEFAULT_SEARCH = ""
 DEFAULT_SORT = "avg_top1"
 DEFAULT_X = "infer_samples_per_sec"
 DEFAULT_Y = "avg_top1"
+DEFAULT_BM = 'amp-nchw-pt240-cu124-rtx4090'
+
+
+def col_formatter(value, precision=None):
+    if isinstance(value, int):
+        return f'{value:d}'
+    elif isinstance(value, float):
+        return f'{value:.{precision}f}' if precision is not None else f'{value:g}'
+    return str(value)
+
 
 def update_leaderboard_and_plot(
         model_name=DEFAULT_SEARCH,
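`col_formatter` is a small helper for the per-column table formatting used further down: ints pass through, floats get either a fixed precision or `%g`-style trimming, and everything else is stringified. A quick behavior check, assuming the definition in the hunk above:

```python
# Behavior sketch for col_formatter as defined in the hunk above.
assert col_formatter(256) == '256'                     # ints pass through
assert col_formatter(81.2345, precision=2) == '81.23'  # fixed precision
assert col_formatter(256.0) == '256'                   # '%g' trims the trailing .0
assert col_formatter('resnet50') == 'resnet50'         # fallback: str()
```

The `%g` fallback matters because merged columns like batch size can arrive as floats after a left merge introduces NaNs.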
@@ -169,12 +196,17 @@ def update_leaderboard_and_plot(
         sort_by=DEFAULT_SORT,
         x_axis=DEFAULT_X,
         y_axis=DEFAULT_Y,
+        benchmark_selection=DEFAULT_BM,
+        log_x=True,
+        log_y=True,
 ):
-    filtered_df = filter_leaderboard(full_df, model_name, sort_by)
-
+    df = merged_dataframes[benchmark_selection].copy()
+
+    filtered_df = filter_leaderboard(df, model_name, sort_by)
+
     # Apply the highlight filter to the entire dataset so the output will be union (comparison) if the filters are disjoint
-    highlight_df = filter_leaderboard(full_df, highlight_name, sort_by) if highlight_name else None
-
+    highlight_df = filter_leaderboard(df, highlight_name, sort_by) if highlight_name else None
+
     # Combine filtered_df and highlight_df, removing duplicates
     if highlight_df is not None:
         combined_df = pd.concat([filtered_df, highlight_df]).drop_duplicates().reset_index(drop=True)
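The new `log_x`/`log_y` flags are threaded through to `create_scatter_plot`, whose body is outside this diff. A hypothetical sketch of how the flags could be applied, assuming the app plots with Plotly Express (one of the figure types `gr.Plot` accepts); the column names are the ones visible elsewhere in the diff:

```python
# Hypothetical sketch: create_scatter_plot's body is not shown in this diff.
# Plotly Express supports per-axis log scaling directly via log_x / log_y.
import plotly.express as px

def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x=True, log_y=True):
    fig = px.scatter(
        df, x=x_axis, y=y_axis,
        color='highlighted',    # column set in load_leaderboard()
        hover_name='model',
        log_x=log_x, log_y=log_y,
    )
    return fig
```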
@@ -182,10 +214,17 @@ def update_leaderboard_and_plot(
         combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
     else:
         combined_df = filtered_df
-
-    fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name)
-    display_df = combined_df.drop(columns=['highlighted'])
-    display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
+        combined_df['highlighted'] = False
+
+    fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name, log_x, log_y)
+    display_df = combined_df.drop(columns=['highlighted'])
+    display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(
+        {
+            'infer_batch_size': lambda x: col_formatter(x),  # Integer column
+        },
+        precision=2,
+    )
+
     return display_df, fig
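The reworked `.format(...)` call mixes a per-column formatter dict with a global `precision=2`; pandas `Styler.format` applies the dict entries to their columns and falls back to the precision everywhere else. A minimal standalone sketch of the pattern with toy data:

```python
# Minimal standalone sketch of per-column Styler formatting (toy data).
import pandas as pd

toy = pd.DataFrame({'infer_batch_size': [256.0, 1024.0], 'avg_top1': [81.234, 79.987]})
styled = toy.style.format(
    {'infer_batch_size': lambda v: f'{v:g}'},  # per-column override: no trailing .0
    precision=2,                               # global fallback for other columns
)
html = styled.to_html()  # batch sizes render as 256 / 1024, avg_top1 with 2 decimals
```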
@@ -193,39 +232,47 @@ with gr.Blocks(title="The timm Leaderboard") as app:
     gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
     gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
     gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
 
     with gr.Row():
         search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
         sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
 
     with gr.Row():
         highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")
 
     with gr.Row():
         x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
         y_axis = gr.Dropdown(choices=plot_columns, label="Y-axis", value=DEFAULT_Y)
 
+    with gr.Row():
+        benchmark_dropdown = gr.Dropdown(
+            choices=list(merged_dataframes.keys()),
+            label="Benchmark Selection",
+            value=DEFAULT_BM,
+        )
+
+    with gr.Row():
+        log_x = gr.Checkbox(label="Log scale X-axis", value=True)
+        log_y = gr.Checkbox(label="Log scale Y-axis", value=True)
+
     update_btn = gr.Button(value="Update", variant="primary")
 
     leaderboard = gr.Dataframe()
     plot = gr.Plot()
-
-    app.load(update_leaderboard_and_plot, outputs=[leaderboard, plot])
-
-    search_bar.submit(
-        update_leaderboard_and_plot,
-        inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
-        outputs=[leaderboard, plot]
-    )
-    highlight_bar.submit(
-        update_leaderboard_and_plot,
-        inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
-        outputs=[leaderboard, plot]
-    )
-    update_btn.click(
-        update_leaderboard_and_plot,
-        inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
-        outputs=[leaderboard, plot]
-    )
+
+    inputs = [search_bar, highlight_bar, sort_dropdown, x_axis, y_axis, benchmark_dropdown, log_x, log_y]
+    outputs = [leaderboard, plot]
+
+    app.load(update_leaderboard_and_plot, outputs=outputs)
+
+    search_bar.submit(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    highlight_bar.submit(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    sort_dropdown.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    x_axis.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    y_axis.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    benchmark_dropdown.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    log_x.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
+    update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
 
 app.launch()
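Every control now funnels into the same callback with one shared `inputs`/`outputs` pair, so any change to a filter, dropdown, or checkbox refreshes both the table and the plot. The repetition could be collapsed further; a condensed sketch of the same wiring (not part of the commit), assuming the components defined above:

```python
# Condensed sketch of the event wiring above; same behavior, fewer lines.
for component, event in [
    (search_bar, 'submit'), (highlight_bar, 'submit'),
    (sort_dropdown, 'change'), (x_axis, 'change'), (y_axis, 'change'),
    (benchmark_dropdown, 'change'), (log_x, 'change'), (log_y, 'change'),
    (update_btn, 'click'),
]:
    getattr(component, event)(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
```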
 