Spaces:
Running
Running
| import os | |
| import random | |
| import time | |
| import numpy as np | |
| import gradio as gr | |
| import pandas as pd | |
| import zipfile | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from huggingface_hub import HfApi, snapshot_download | |
| from datasets import load_dataset | |
| from src.utils import load_all_data | |
| from src.md import ABOUT_TEXT, TOP_TEXT, SUBMIT_TEXT | |
| from src.css import custom_css | |
# Hugging Face Hub client used by restart_space().
api = HfApi()

# Access token read from the environment; None when the variable is unset.
COLLAB_TOKEN = os.getenv("COLLAB_TOKEN")

# Hub repository id of the evaluation set, and the local directories used for
# the downloaded dataset and for the evaluation results.
eval_set_repo_id = "KwaiVGI/VideoGen-RewardBench"
eval_set_dir = "dataset"
eval_results_dir = "evals"
def restart_space():
    """
    Restart the Hugging Face Space that hosts this app.

    The Space id is read from the SPACE_ID environment variable when it is
    set, so a renamed or duplicated Space restarts itself. When the variable
    is missing or empty, eval_set_repo_id is used, as before.
    """
    space_id = os.environ.get("SPACE_ID") or eval_set_repo_id
    api.restart_space(repo_id=space_id, token=COLLAB_TOKEN)
# Background color applied to each 'Model Type' label in the leaderboard table.
color_map = {
    "Generative": "#7497db",
    "Custom Classifiers": "#E8ECF2",
    "Seq. Classifiers": "#ffcd75",
    "DPO": "#75809c",
}
| def color_model_type_column(df, color_map): | |
| """ | |
| Apply color to the 'Model Type' column of the DataFrame based on a given color mapping. | |
| Parameters: | |
| df (pd.DataFrame): The DataFrame containing the 'Model Type' column. | |
| color_map (dict): A dictionary mapping model types to colors. | |
| Returns: | |
| pd.Styler: The styled DataFrame. | |
| """ | |
| # Function to apply color based on the model type | |
| def apply_color(val): | |
| color = color_map.get(val, "default") # Default color if not specified in color_map | |
| return f'background-color: {color}' | |
| # Format for different columns | |
| format_dict = {col: "{:.2f}" for col in df.columns if col not in ['Avg.', 'Model', 'Model Type']} | |
| format_dict['Avg.'] = "{:.2f}" | |
| format_dict[''] = "{:d}" | |
| return df.style.applymap(apply_color, subset=['Model Type']).format(format_dict, na_rep='') | |
def regex_table(dataframe, regex, filter_button, style=True):
    """
    Filter the leaderboard by model name and by the selected filter labels.

    Parameters:
    dataframe (pd.DataFrame): Table with 'Model' and 'Model Type' columns.
    regex (str): Comma-separated patterns. A row is kept when its 'Model'
        value matches any of them (case-insensitive). Empty entries are
        ignored, and an empty search keeps every row with a non-null model.
        If the combined pattern is not a valid regular expression, the
        entries are matched as literal text instead.
    filter_button (list | str): Selected filter labels. Rows whose model type
        is not selected are dropped, as are columns of an unselected
        "w/o Ties" / "w/ Ties" group. Any other type skips this filtering.
    style (bool): Unused; kept for backward compatibility.

    Returns:
    pd.Styler: The matching rows with a leading '' column holding the row
    position, styled by color_model_type_column.
    """
    import re

    # Split regex statement by comma, trim whitespace, and drop empty entries:
    # an empty alternative (e.g. from a trailing comma) would match every row.
    patterns = [part.strip() for part in (regex or "").split(",")]
    patterns = [part for part in patterns if part]
    # Join the list into a single regex pattern with '|' acting as OR
    combined_regex = "|".join(patterns)
    try:
        re.compile(combined_regex)
    except re.error:
        # A half-typed pattern such as "(" should filter, not raise.
        combined_regex = "|".join(re.escape(part) for part in patterns)

    if isinstance(filter_button, (list, str)):
        for model_type in ("Seq. Classifiers", "Custom Classifiers", "Generative"):
            if model_type not in filter_button:
                # regex=False: the label is literal text, so '.' is not a wildcard.
                unselected = dataframe["Model Type"].str.contains(
                    model_type, case=False, regex=False, na=False
                )
                dataframe = dataframe[~unselected]
        for column_group in ("w/o Ties", "w/ Ties"):
            if column_group not in filter_button:
                dataframe = dataframe[[col for col in dataframe.columns if column_group not in col]]

    # Filter the dataframe such that 'Model' contains any of the regex patterns
    matches = dataframe["Model"].str.contains(combined_regex, case=False, na=False)
    # reset_index returns a new frame, so the insert below does not write into
    # a slice of the caller's DataFrame.
    data = dataframe[matches].reset_index(drop=True)
    data.insert(0, "", range(len(data)))
    return color_model_type_column(data, color_map)
# Download the evaluation set repository into the local dataset directory.
repo = snapshot_download(
    local_dir=eval_set_dir,
    repo_id=eval_set_repo_id,
    # `token` is the supported keyword; `use_auth_token` is its deprecated alias.
    token=COLLAB_TOKEN,
    tqdm_class=None,
    etag_timeout=30,
    repo_type="dataset",
)

# Unpack the bundled videos into the same directory.
with zipfile.ZipFile(os.path.join(eval_set_dir, "videos.zip"), "r") as zip_ref:
    zip_ref.extractall(eval_set_dir)

# Leaderboard rows, best average first.
rewardbench_data = load_all_data(eval_results_dir).sort_values(by="Avg.", ascending=False)
# Gradio column datatypes for the leaderboard tables.
# NOTE(review): presumably index, model link, model type, then numeric scores;
# the length looks one longer than that layout needs — confirm against the
# columns returned by load_all_data.
col_types_rewardbench = ["number"] + ["markdown"] + ["str"] + ["number"] * (len(rewardbench_data.columns) - 1)

# for showing random samples
eval_set = pd.read_csv(os.path.join(eval_set_dir, "videogen-rewardbench.csv"))
subsets = list(eval_set["prompt"].unique())
# N=20
# if len(subsets) > N:
#     random.seed(time.time())
#     subsets = random.sample(subsets, N)
def random_sample(selected_prompts):
    """
    Pick one random evaluation pair for a prompt and format it for display.

    Parameters:
    selected_prompts (str): Prompt text compared against the 'prompt' column
        of eval_set.

    Returns:
    tuple: (markdown_text, video_a, video_b). When no row matches (which is
    also the case when nothing is selected), the markdown slot carries a
    notice and both video slots are None, so all three outputs wired to the
    button are always filled.
    """

    def preference(label):
        # 'A' / 'B' name the preferred video; anything else is shown as a tie.
        if label == "A":
            return "A>B"
        if label == "B":
            return "A<B"
        return "A=B"

    # Filter the eval_set based on the selected prompt
    filtered_data = eval_set[eval_set["prompt"] == selected_prompts]
    if filtered_data.empty:
        return "No data available for the selected prompt(s).", None, None

    # Randomly select a sample from the filtered data. No fixed seed: a seed
    # derived from the current second repeats the same row for every click
    # made within that second.
    sample = filtered_data.sample(n=1).iloc[0]

    # Prepare the markdown text with the required fields
    verdicts = " | ".join(preference(sample[key]) for key in ("VQ", "MQ", "TA", "Overall"))
    markdown_text = "".join(
        [
            f"**Prompt**: {sample['prompt']}\n\n\n",
            "**Preference**: \n",
            "| **Visual Quality** | **Motion Quality** | **Text Alignment** | **Overall** | **A_model** | **B_model** |\n",
            "|:------------------:|:------------------:|:------------------:|:-----------:|:-----------:|:-----------:|\n",
            f"| {verdicts} | {sample['A_model']} | {sample['B_model']} |\n",
        ]
    )

    # Load and display videos from path_A and path_B
    video_a = gr.Video(value=os.path.join(eval_set_dir, sample["path_A"]))
    video_b = gr.Video(value=os.path.join(eval_set_dir, sample["path_B"]))
    return markdown_text, video_a, video_b
# Number of leaderboard rows, interpolated into the header text.
total_models = len(rewardbench_data)

with gr.Blocks(css=custom_css) as app:
    # Header: intro text in the wider column, logo in the narrower one.
    with gr.Row():
        with gr.Column(scale=7):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=3):
            gr.Markdown("""
<img src="https://i.postimg.cc/rpMSzBnV/logo.png" style="width:800px;" alt="Logo">
""")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 VideoGen-RewardBench Leaderboard"):
            # Search box and filter checkboxes.
            with gr.Row():
                with gr.Column(scale=4):
                    search_1 = gr.Textbox(
                        label="Model Search (delimit with , )",
                        placeholder="Model Search (delimit with , )",
                        show_label=False,
                    )
                with gr.Column(scale=6):
                    filter_labels = ["Seq. Classifiers", "Custom Classifiers", "Generative", "w/o Ties", "w/ Ties"]
                    model_types_1 = gr.CheckboxGroup(
                        filter_labels,
                        value=list(filter_labels),
                        label="Model Types",
                        show_label=False,
                    )
            with gr.Row():
                # reference data: the invisible table is the input that the
                # filter callbacks read from.
                rewardbench_table_hidden = gr.Dataframe(
                    rewardbench_data,
                    datatype=col_types_rewardbench,
                    headers=rewardbench_data.columns.tolist(),
                    visible=False,
                )
                # Visible table, initially rendered with an empty search.
                initial_table = regex_table(
                    rewardbench_data.copy(),
                    "",
                    ["Seq. Classifiers", "Custom Classifiers", "Generative", "Others", "w/o Ties", "w/ Ties"],
                )
                rewardbench_table = gr.Dataframe(
                    initial_table,
                    datatype=col_types_rewardbench,
                    headers=rewardbench_data.columns.tolist(),
                    elem_id="rewardbench_dataframe_avg",
                    # height=1000,
                )
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)

        with gr.TabItem("📤 How to Submit"):
            with gr.Row():
                gr.Markdown(SUBMIT_TEXT)

        with gr.TabItem("🔍 Dataset Viewer"):
            with gr.Row():
                # loads one sample
                gr.Markdown("""## Random Dataset Sample Viewer""")
                subset_selector = gr.Dropdown(subsets, label="Subset", value=None, multiselect=False)
                button = gr.Button("Show Random Sample")
            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")
            with gr.Row():
                video_a_display = gr.Video()
                video_b_display = gr.Video()
            button.click(
                fn=random_sample,
                inputs=[subset_selector],
                outputs=[sample_display, video_a_display, video_b_display],
            )

    # Re-filter the visible table whenever the search text or the checkboxes change.
    filter_inputs = [rewardbench_table_hidden, search_1, model_types_1]
    search_1.change(regex_table, inputs=filter_inputs, outputs=rewardbench_table)
    model_types_1.change(regex_table, inputs=filter_inputs, outputs=rewardbench_table)

    with gr.Row():
        with gr.Accordion("📚 Citation", open=False):
            citation_button = gr.Textbox(
                value=r"""@article{liu2025improving,
title={Improving Video Generation with Human Feedback},
author={Liu, Jie and Liu, Gongye and Liang, Jiajun and Yuan, Ziyang and Liu, Xiaokun and Zheng, Mingwu and Wu, Xiele and Wang, Qiulin and Qin, Wenyu and Xia, Menghan and others},
journal={arXiv preprint arXiv:2501.13918},
year={2025}
}""",
                lines=5,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )
# Schedule restart_space to run every 1800 seconds (30 minutes).
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", seconds=1800)
scheduler.start()

# Configure the request queue, then serve the app.
queued_app = app.queue(default_concurrency_limit=40)
queued_app.launch()