import fnmatch
import re

import gradio as gr
import pandas as pd
import plotly.express as px
from rapidfuzz import fuzz
def load_leaderboard():
    # Load validation / test CSV files
    pd.set_option('display.float_format', '{:.2f}'.format)
    results_csv_files = {
        'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
        'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
        'v2': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenetv2-matched-frequency.csv',
        'sketch': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-sketch.csv',
        'a': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-a.csv',
        'r': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-r.csv',
    }

    # Load benchmark CSV files
    benchmark_csv_files = {
        'amp-nchw-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090.csv',
        'amp-nhwc-pt240-cu124-rtx4090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx4090.csv',
        'amp-nchw-pt240-cu124-rtx4090-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx4090-dynamo.csv',
        'amp-nchw-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv',
        'amp-nhwc-pt240-cu124-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv',
        'fp32-nchw-pt240-cpu-i9_10940x-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i9_10940x-dynamo.csv',
        'fp32-nchw-pt240-cpu-i7_12700h-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt240-cpu-i7_12700h-dynamo.csv',
    }

    dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
    bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
    bench_dataframes = {name: df for name, df in bench_dataframes.items() if 'infer_gmacs' in df.columns}
    print(bench_dataframes.keys())

    # Clean up dataframes
    remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
    for df in dataframes.values():
        for col in remove_column_names:
            if col in df.columns:
                df.drop(columns=[col], inplace=True)

    # Rename / process results columns
    for name, df in dataframes.items():
        df.rename(columns={"top1": f"{name}_top1", "top5": f"{name}_top5"}, inplace=True)
        df['arch_name'] = df['model'].apply(lambda x: x.split('.')[0])

    # Process benchmark dataframes
    for name, df in bench_dataframes.items():
        df['arch_name'] = df['model']
        df.rename(columns={'infer_img_size': 'img_size'}, inplace=True)

    # Merge all result dataframes
    result = dataframes['imagenet']
    for name, df in dataframes.items():
        if name != 'imagenet':
            result = pd.merge(result, df, on=['arch_name', 'model', 'img_size', 'crop_pct', 'interpolation'], how='outer')

    # Calculate average scores
    top1_columns = [col for col in result.columns if col.endswith('_top1') and col != 'a_top1']
    top5_columns = [col for col in result.columns if col.endswith('_top5') and col != 'a_top5']
    result['avg_top1'] = result[top1_columns].mean(axis=1)
    result['avg_top5'] = result[top5_columns].mean(axis=1)

    # Create fully merged dataframes for each benchmark set
    merged_dataframes = {}
    for bench_name, bench_df in bench_dataframes.items():
        merged_df = pd.merge(result, bench_df, on=['arch_name', 'img_size'], how='left', suffixes=('', '_benchmark'))

        # Calculate TFLOP/s
        merged_df['infer_tflop_s'] = merged_df['infer_samples_per_sec'] * merged_df['infer_gmacs'] * 2 / 1000

        # Reorder columns
        first_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
        other_columns = [col for col in merged_df.columns if col not in first_columns]
        merged_df = merged_df[first_columns + other_columns].copy(deep=True)

        # Drop columns that are no longer needed / add too much noise
        merged_df.drop('arch_name', axis=1, inplace=True)
        merged_df.drop('crop_pct', axis=1, inplace=True)
        merged_df.drop('interpolation', axis=1, inplace=True)
        merged_df.drop('model_benchmark', axis=1, inplace=True)

        merged_df['infer_usec_per_sample'] = 1e6 / merged_df.infer_samples_per_sec
        merged_df['highlighted'] = False
        merged_df = merged_df.round(2)

        merged_dataframes[bench_name] = merged_df

    return merged_dataframes
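# Rough usage sketch (variable names below are illustrative, not part of the app): each entry of
# the returned dict is the full results table joined against one benchmark CSV, so something like
#   tables = load_leaderboard()
#   rtx4090 = tables['amp-nchw-pt240-cu124-rtx4090']
#   print(rtx4090[['model', 'img_size', 'avg_top1', 'infer_samples_per_sec', 'infer_tflop_s']].head())
# shows the columns the UI sorts and plots on. Note that avg_top1 / avg_top5 average the
# *_top1 / *_top5 columns while skipping the ImageNet-A ones, and infer_tflop_s is derived from
# GMACs assuming 2 FLOPs per MAC.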
REGEX_PREFIX = "re:"


def auto_match(pattern, text):
    # Check if it's a regex pattern (starts with 're:')
    if pattern.startswith(REGEX_PREFIX):
        regex_pattern = pattern[len(REGEX_PREFIX):].strip()
        try:
            return bool(re.match(regex_pattern, text, re.IGNORECASE))
        except re.error:
            # If it's an invalid regex, return False
            return False
    # Check if it's a wildcard pattern
    elif any(char in pattern for char in ['*', '?']):
        return fnmatch.fnmatch(text.lower(), pattern.lower())
    # If not regex or wildcard, use fuzzy matching
    else:
        return fuzz.partial_ratio(pattern.lower(), text.lower(), score_cutoff=90) > 0
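# Matching behaviour, for reference (illustrative examples, not executed by the app):
#   auto_match('resnet*', 'resnet50.a1_in1k')        -> True   (fnmatch wildcard)
#   auto_match('re:^vit_', 'vit_base_patch16_224')   -> True   ('re:' prefix, case-insensitive regex)
#   auto_match('efficientnet', 'tf_efficientnet_b0') -> True   (fuzzy partial match, cutoff 90)
#   auto_match('re:[', 'resnet50')                   -> False  (invalid regex is swallowed)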
def filter_leaderboard(df, model_name, sort_by):
    if not model_name:
        return df.sort_values(by=sort_by, ascending=False)

    mask = df['model'].apply(lambda x: auto_match(model_name, x))
    filtered_df = df[mask].sort_values(by=sort_by, ascending=False)
    return filtered_df
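# filter_leaderboard sorts even when no filter is given, so an empty search box still yields the
# full table ranked by the selected column. A quick sketch (illustrative variable names):
#   top_resnets = filter_leaderboard(df, 'resnet*', 'avg_top1')
# where df is one of the merged benchmark tables returned by load_leaderboard().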
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter, log_x=True, log_y=True):
    selected_color = 'orange'
    fig = px.scatter(
        df,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        hover_data=['model'],
        trendline='ols',
        trendline_options=dict(log_x=log_x, log_y=log_y),
        color='highlighted',
        color_discrete_map={True: selected_color, False: 'blue'},
        title=f'{y_axis} vs {x_axis}',
    )

    # Create legend labels
    legend_labels = {}
    if highlight_filter:
        legend_labels[True] = f'{highlight_filter}'
        legend_labels[False] = f'{model_filter or "all models"}'
    else:
        legend_labels[False] = f'{model_filter or "all models"}'

    # Update legend
    for trace in fig.data:
        if isinstance(trace.marker.color, str):  # This is for the scatter traces
            trace.name = legend_labels.get(trace.marker.color == selected_color, '')

    fig.update_layout(
        showlegend=True,
        legend_title_text='Model Selection',
    )
    return fig
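# Note: px.scatter with trendline='ols' needs statsmodels installed. The dataframe passed in must
# carry a boolean 'highlighted' column, a 'model' column for hover text, and the chosen x/y columns.
# Minimal sketch (illustrative):
#   fig = create_scatter_plot(some_df, 'infer_samples_per_sec', 'avg_top1', 'resnet*', None)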
# Load the leaderboard data
merged_dataframes = load_leaderboard()

# Define the available columns for sorting and plotting
sort_columns = ['avg_top1', 'avg_top5', 'imagenet_top1', 'imagenet_top5', 'infer_samples_per_sec', 'infer_usec_per_sample', 'param_count', 'infer_gmacs', 'infer_macts', 'infer_tflop_s']
plot_columns = ['infer_samples_per_sec', 'infer_usec_per_sample', 'infer_gmacs', 'infer_macts', 'infer_tflop_s', 'param_count', 'avg_top1', 'avg_top5', 'imagenet_top1', 'imagenet_top5']

DEFAULT_SEARCH = ""
DEFAULT_SORT = "avg_top1"
DEFAULT_X = "infer_samples_per_sec"
DEFAULT_Y = "avg_top1"
DEFAULT_BM = 'amp-nchw-pt240-cu124-rtx4090'


def col_formatter(value, precision=None):
    if isinstance(value, int):
        return f'{value:d}'
    elif isinstance(value, float):
        return f'{value:.{precision}f}' if precision is not None else f'{value:g}'
    return str(value)
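# col_formatter examples (illustrative): col_formatter(256) -> '256', col_formatter(3.14159, 2) -> '3.14',
# col_formatter(3.5) -> '3.5' (general float format when no precision is given).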
def update_leaderboard_and_plot(
    model_name=DEFAULT_SEARCH,
    highlight_name=None,
    sort_by=DEFAULT_SORT,
    x_axis=DEFAULT_X,
    y_axis=DEFAULT_Y,
    benchmark_selection=DEFAULT_BM,
    log_x=True,
    log_y=True,
):
    df = merged_dataframes[benchmark_selection].copy()
    filtered_df = filter_leaderboard(df, model_name, sort_by)

    # Apply the highlight filter to the full dataset so the result is the union of both filters;
    # this keeps comparisons working even when the two filters are disjoint.
    highlight_df = filter_leaderboard(df, highlight_name, sort_by) if highlight_name else None

    # Combine filtered_df and highlight_df, removing duplicates
    if highlight_df is not None:
        combined_df = pd.concat([filtered_df, highlight_df]).drop_duplicates().reset_index(drop=True)
        combined_df = combined_df.sort_values(by=sort_by, ascending=False)
        combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
    else:
        combined_df = filtered_df
        combined_df['highlighted'] = False

    fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name, log_x, log_y)

    display_df = combined_df.drop(columns=['highlighted'])
    display_df = display_df.style.apply(
        lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x],
        axis=1,
    ).format(
        {
            'infer_batch_size': lambda x: col_formatter(x),  # Integer column
        },
        precision=2,
    )

    return display_df, fig
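# update_leaderboard_and_plot returns a (styled dataframe, plotly figure) pair, matching the
# [leaderboard, plot] outputs wired up below. Calling it with no arguments reproduces the view
# shown on app load, e.g. (illustrative):
#   styled, fig = update_leaderboard_and_plot()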
with gr.Blocks(title="The timm Leaderboard") as app:
    gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
    gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
    gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")

    with gr.Row():
        search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
        sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)

    with gr.Row():
        highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")

    with gr.Row():
        x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
        y_axis = gr.Dropdown(choices=plot_columns, label="Y-axis", value=DEFAULT_Y)

    with gr.Row():
        benchmark_dropdown = gr.Dropdown(
            choices=list(merged_dataframes.keys()),
            label="Benchmark Selection",
            value=DEFAULT_BM,
        )

    with gr.Row():
        log_x = gr.Checkbox(label="Log scale X-axis", value=True)
        log_y = gr.Checkbox(label="Log scale Y-axis", value=True)

    update_btn = gr.Button(value="Update", variant="primary")

    leaderboard = gr.Dataframe()
    plot = gr.Plot()

    inputs = [search_bar, highlight_bar, sort_dropdown, x_axis, y_axis, benchmark_dropdown, log_x, log_y]
    outputs = [leaderboard, plot]

    app.load(update_leaderboard_and_plot, outputs=outputs)
    search_bar.submit(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    highlight_bar.submit(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    sort_dropdown.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    x_axis.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    y_axis.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    benchmark_dropdown.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    log_x.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    log_y.change(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)
    update_btn.click(update_leaderboard_and_plot, inputs=inputs, outputs=outputs)

app.launch()
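# To run locally, assuming the dependencies above (gradio, pandas, plotly, rapidfuzz, plus
# statsmodels for the OLS trendline) are installed, save the script and run it with Python, e.g.
# `python app.py` (the file name is the usual Spaces convention, not something this script
# enforces). app.launch(share=True) can be used instead to expose a temporary public link.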