Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
from pathlib import Path | |
from src.json_leaderboard import create_leaderboard_df | |
from src.about import ( | |
CITATION_BUTTON_TEXT, | |
INTRODUCTION_TEXT, | |
LINKS_AND_INFO, | |
TITLE, | |
) | |
from src.display.css_html_js import custom_css | |
# Fixed columns: always displayed, and always placed first in every
# leaderboard table (the column filters below never let users hide them).
FIXED_COLUMNS = [
    "Model Name (clickable)",
    "Release Date",
    "HF Model",
    "Open Source",
]
def get_json_df():
    """Return the leaderboard DataFrame.

    The data file is ``leaderboard_data.json`` in the same directory as this
    source file; its path is passed (as a string) to
    ``create_leaderboard_df``, whose result is returned unchanged.
    """
    data_file = Path(__file__).parent.joinpath("leaderboard_data.json")
    return create_leaderboard_df(str(data_file))
# Extract top-level categories and their sub-columns.
def extract_categories_and_subs(df, *, fixed_columns=None):
    """Group score columns into categories and their sub-columns.

    A category starts at any column whose name ends with ``-Overall``; the
    category name is that column name with the trailing suffix removed (only
    the suffix -- an inner "-Overall" is kept). Every column that follows it
    belongs to that category until the next ``-Overall`` column, a fixed
    column, or the ``"Overall"`` column is reached.

    Args:
        df: DataFrame whose column order defines the grouping.
        fixed_columns: Columns that never start or belong to a category.
            Defaults to the module-level ``FIXED_COLUMNS``.

    Returns:
        ``{category: {"overall": <category column>, "subs": [sub columns]}}``
        in column order. If two columns yield the same category name, the
        later one wins. Non-string column names are treated as ordinary
        (sub-)columns instead of raising ``AttributeError``.
    """
    if fixed_columns is None:
        fixed_columns = FIXED_COLUMNS
    suffix = "-Overall"
    skip_cols = set(fixed_columns) | {"Overall"}

    category_dict = {}
    current_subs = None  # sub-column list of the category being filled, if any
    for col in df.columns:
        if col in skip_cols:
            # A fixed/"Overall" column terminates the current category.
            current_subs = None
        elif isinstance(col, str) and col.endswith(suffix):
            # Slice off only the trailing suffix; str.replace would also
            # strip any "-Overall" appearing earlier in the name.
            current_subs = []
            category_dict[col[: -len(suffix)]] = {"overall": col, "subs": current_subs}
        elif current_subs is not None:
            current_subs.append(col)
    return category_dict
# Column filter: fixed columns first, then the user's selection, kept in the
# DataFrame's own column order.
def filtered_leaderboard(df, selected_columns, *, fixed_columns=None):
    """Return ``df`` restricted to the fixed columns plus the selected ones.

    Args:
        df: Source leaderboard DataFrame.
        selected_columns: Column names chosen by the user; ``None`` or empty
            means "fixed columns only". Their order is ignored -- the result
            follows ``df``'s column order.
        fixed_columns: Columns always shown first, in this order. Defaults to
            the module-level ``FIXED_COLUMNS``. Fixed columns that ``df`` does
            not contain are skipped instead of raising ``KeyError``.

    Returns:
        A DataFrame holding the fixed columns followed by the selected ones.
    """
    if fixed_columns is None:
        fixed_columns = FIXED_COLUMNS
    present_fixed = [col for col in fixed_columns if col in df.columns]
    # Sets make the membership tests O(1) instead of scanning lists per column.
    wanted = set(selected_columns or [])
    fixed_set = set(fixed_columns)
    extra = [col for col in df.columns if col in wanted and col not in fixed_set]
    return df[present_fixed + extra]
# Gradio change handlers. Each receives the checkbox selection plus the
# DataFrame stored in the gr.State it is wired to, and returns the filtered view.
def update_leaderboard_overall(selected_cols, df_overall):
    """Filter the overall-tab DataFrame down to ``selected_cols``."""
    return filtered_leaderboard(df=df_overall, selected_columns=selected_cols)


def update_leaderboard_cat(selected_cols, df_cat):
    """Filter a category-tab DataFrame down to ``selected_cols``."""
    return filtered_leaderboard(df=df_cat, selected_columns=selected_cols)
# Initialisation: load the data once and derive the column layout from it.
df = get_json_df()
ALL_COLUMNS_ORDERED = list(df.columns)
# Optional (user-toggleable) columns: every column except the fixed ones,
# in the DataFrame's own order.
optional_columns = [name for name in ALL_COLUMNS_ORDERED if name not in FIXED_COLUMNS]
categories = extract_categories_and_subs(df)
def _column_datatypes(columns):
    """Return Gradio ``datatype`` entries: "html" for the link columns, "str" otherwise."""
    html_columns = {"Model Name (clickable)", "HF Model"}
    return ["html" if col in html_columns else "str" for col in columns]


# Gradio interface
# NOTE(review): the captured Space page reports "Runtime error". Some keyword
# arguments used here (e.g. Blocks(css=...), Textbox(show_copy_button=...))
# changed across Gradio major versions -- verify them against the pinned
# gradio version if startup fails.
demo = gr.Blocks(css=custom_css, title="UniGenBench Leaderboard")
with demo:
    gr.HTML(TITLE)
    gr.HTML(LINKS_AND_INFO)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
    with gr.Tabs(elem_classes="tab-buttons"):
        # Overall leaderboard: all columns; everything but the fixed ones is toggleable.
        with gr.TabItem("🏅 Overall Leaderboard", elem_id="tab-overall"):
            selected_columns_overall = gr.CheckboxGroup(
                choices=optional_columns,
                label="Select additional columns to display",
                value=optional_columns,
            )
            # Build the initial table with the same filter the change handler
            # uses, so the column order (fixed columns first) and the
            # positional `datatype` list match before and after a toggle.
            overall_initial = filtered_leaderboard(df, optional_columns)
            leaderboard_table = gr.Dataframe(
                value=overall_initial,
                headers=list(overall_initial.columns),
                datatype=_column_datatypes(overall_initial.columns),
                interactive=False,
                wrap=False,
            )
            selected_columns_overall.change(
                fn=update_leaderboard_overall,
                inputs=[selected_columns_overall, gr.State(value=df)],
                outputs=leaderboard_table,
            )
        # One tab per category: its "-Overall" column followed by its sub-columns.
        for cat_name, info in categories.items():
            # HTML ids must not contain whitespace; collapse it into dashes.
            tab_id = "tab-" + "-".join(cat_name.split())
            with gr.TabItem(f"🏆 {cat_name}", elem_id=tab_id):
                cat_cols = [info["overall"]] + info["subs"]
                cat_df = df[FIXED_COLUMNS + cat_cols]
                optional_columns_cat = [col for col in cat_cols if col not in FIXED_COLUMNS]
                selected_columns_cat = gr.CheckboxGroup(
                    choices=optional_columns_cat,
                    label=f"Select additional columns for {cat_name}",
                    value=optional_columns_cat,
                )
                leaderboard_table_cat = gr.Dataframe(
                    value=cat_df,
                    headers=list(cat_df.columns),
                    datatype=_column_datatypes(cat_df.columns),
                    interactive=False,
                    wrap=False,
                )
                # Each tab gets its own gr.State holding its own cat_df, so
                # handlers do not share the loop variable.
                selected_columns_cat.change(
                    fn=update_leaderboard_cat,
                    inputs=[selected_columns_cat, gr.State(value=cat_df)],
                    outputs=leaderboard_table_cat,
                )
    # Citation
    with gr.Row():
        with gr.Column():
            gr.Markdown("## 📙 Citation")
            # TODO: link "UniGenBench" to its project page; the previous
            # markdown used an empty link target ("[UniGenBench]()").
            gr.Markdown("If you use UniGenBench in your research, please cite our work:")
            citation_textbox = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                elem_id="citation-textbox",
                show_label=False,
                interactive=False,
                lines=8,
                show_copy_button=True,
            )

if __name__ == "__main__":
    demo.launch()