import json
import re
from pathlib import Path

import gradio as gr
import pandas as pd

TITLE = """
<h1 align="center">LLM Leaderboard for H4 Models</h1>
"""

DESCRIPTION = """
Evaluation of H4 and community models across a diverse range of benchmarks from [LightEval](https://github.com/huggingface/lighteval). All scores are reported as accuracy.
"""

BENCHMARKS_TO_SKIP = ["math", "mini_math"]


def get_leaderboard_df(merge_values: bool = True):
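    """Build the leaderboard DataFrame from the LightEval JSON files under `eval_results/`.

    When `merge_values` is True, duplicate evals of the same model are collapsed by keeping
    the highest score per benchmark.
    """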
    filepaths = list(Path("eval_results").rglob("*.json"))
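    # Filepaths are assumed to follow eval_results/{org}/{model}/{revision}/{task}/{timestamp}.json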
    # Parse filepaths to get unique models
    models = set()
    for filepath in filepaths:
        path_parts = Path(filepath).parts
        model_revision = "_".join(path_parts[1:4])
        models.add(model_revision)
    # Initialize DataFrame
    df = pd.DataFrame(index=list(models))
    # Extract data from each file and populate the DataFrame
    for filepath in filepaths:
        path_parts = Path(filepath).parts
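        # Keep only the date portion of the timestamp embedded in the filename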
        date = filepath.stem.split("_")[-1][:-3].split("T")[0]
        model_revision = "_".join(path_parts[1:4]) + "_" + date
        task = path_parts[4]
        df.loc[model_revision, "Date"] = date
        # Skip benchmarks that we don't want to include in the leaderboard
        if task.lower() in BENCHMARKS_TO_SKIP:
            continue
        with open(filepath, "r") as file:
            data = json.load(file)
            first_result_key = next(iter(data["results"]))  # gets the first key in `results`
            # TruthfulQA has two metrics, so we need to pick the `mc2` one that's reported on the leaderboard
            if task.lower() == "truthfulqa":
                value = data["results"][first_result_key]["truthfulqa_mc2"]
            # IFEval has several metrics but we report just the prompt-loose-acc one
            elif task.lower() == "ifeval":
                value = data["results"][first_result_key]["prompt_level_loose_acc"]
            # MMLU has several metrics but we report just the average one
            elif task.lower() == "mmlu":
                value = [v["acc"] for k, v in data["results"].items() if "_average" in k.lower()][0]
            # HellaSwag and ARC report acc_norm
            elif task.lower() in ["hellaswag", "arc"]:
                value = data["results"][first_result_key]["acc_norm"]
            # BBH has several metrics but we report just the average one
            elif task.lower() == "bbh":
                if "all" in data["results"]:
                    value = data["results"]["all"]["acc"]
                else:
                    value = -100  # sentinel value for missing aggregate results
            # AGIEval reports acc_norm
            elif task.lower() == "agieval":
                value = data["results"]["all"]["acc_norm"]
            # MATH reports qem
            elif task.lower() in ["math", "math_v2", "aimo_kaggle"]:
                value = data["results"]["all"]["qem"]
            else:
                first_metric_key = next(
                    iter(data["results"][first_result_key])
                )  # gets the first key in the first result
                value = data["results"][first_result_key][first_metric_key]  # gets the value of the first metric
            # For mini_math_v2 we report 5 metrics, one for each level, and store each as a separate column in the dataframe
            if task.lower() in ["mini_math_v2"]:
                for k, v in data["results"].items():
                    if k != "all":
                        level = k.split("|")[1].split(":")[-1]
                        value = v["qem"]
                        df.loc[model_revision, f"{task}_{level}"] = value
            # For the kaggle_pot evals we report N metrics, one for each prompt, and store each as a separate column in the dataframe
            elif task.lower() in ["aimo_kaggle_medium_pot", "aimo_kaggle_hard_pot"]:
                for k, v in data["results"].items():
                    if k != "all" and "_average" not in k:
                        version = k.split("|")[1].split(":")[-1]
                        value = v["qem"] if "qem" in v else v["score"]
                        df.loc[model_revision, f"{task}_{version}"] = value
            # The kaggle_tora evals report accuracy as a percentage, so we divide by 100
            elif task.lower() in [
                "aimo_tora_eval_kaggle_medium",
                "aimo_tora_eval_kaggle_hard",
                "aimo_kaggle_fast_eval_hard",
                "aimo_kaggle_tora_medium",
                "aimo_kaggle_tora_hard",
            ]:
                for k, v in data["results"].items():
                    value = float(v["qem"]) / 100.0
                    df.loc[model_revision, f"{task}"] = value
            # For AlpacaEval we report the base win rate and the length-controlled one
            elif task.lower() == "alpaca_eval":
                value = data["results"][first_result_key]["win_rate"]
                df.loc[model_revision, "Alpaca_eval"] = value / 100.0
                value = data["results"][first_result_key]["length_controlled_winrate"]
                df.loc[model_revision, "Alpaca_eval_lc"] = value / 100.0
            else:
                df.loc[model_revision, task] = float(value)
    # Drop rows where every benchmark score is NaN (the Date column is ignored)
    df = df.dropna(how="all", axis=0, subset=[c for c in df.columns if c != "Date"])
    # Trim minimath column names
    df.columns = [c.replace("_level_", "_l") for c in df.columns]
    # Trim AIMO column names
    df.columns = [c.replace("aimo_", "") for c in df.columns]
    df.insert(loc=1, column="Average", value=df.mean(axis=1, numeric_only=True))
    # Convert all values to percentage
    df[df.select_dtypes(include=["number"]).columns] *= 100.0
    df = df.sort_values(by=["Average"], ascending=False)
    df = df.reset_index().rename(columns={"index": "Model"}).round(2)
    # Strip off date from model name
    df["Model"] = df["Model"].apply(lambda x: x.rsplit("_", 1)[0])
    if merge_values:
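        # Collapse duplicate evals of the same model by keeping the max score per benchmark, then recompute the average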
        merged_df = df.drop(["Date", "Average"], axis=1).groupby("Model").max().reset_index()
        merged_df.insert(loc=0, column="Average", value=merged_df.mean(axis=1, numeric_only=True))
        df = df[["Model", "Date"]].merge(merged_df, on="Model", how="left")
        df.drop_duplicates(subset=["Model"], inplace=True)
        df = df.sort_values(by=["Average"], ascending=False).round(2)
    return df


def refresh(merge_values: bool = True):
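    """Rebuild the leaderboard, optionally merging duplicate evals of the same model."""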
    return get_leaderboard_df(merge_values)


def update_table(search_query):
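    """Filter the leaderboard to models matching the search query; separate multiple terms with `;`."""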
    df = get_leaderboard_df()
    if search_query:
        search_terms = [term.strip().lower() for term in search_query.split(";")]
        # Escape the terms so characters like `.` or `+` in model names are matched literally
        pattern = "|".join(re.escape(term) for term in search_terms)
        df = df[df["Model"].str.lower().str.contains(pattern, regex=True)]
        # Drop any columns which are all NaN
        df = df.dropna(how="all", axis=1)
    return df


def filter_columns(cols):
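    """Restrict the leaderboard to the selected benchmark columns and recompute the average over them."""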
    index_cols = list(leaderboard_df.columns[:2])  # "Model" and "Date"
    new_cols = index_cols + cols
    df = get_leaderboard_df()[new_cols].copy()
    # Drop rows where all selected benchmark scores are NaN
    df = df.dropna(how="all", axis=0, subset=[c for c in df.columns if c in cols])
    # Recompute the average over the selected columns
    df["Average"] = df.mean(axis=1, numeric_only=True)
    return df


leaderboard_df = get_leaderboard_df()

demo = gr.Blocks()
with demo:
    gr.HTML(TITLE)
    with gr.Column():
        gr.Markdown(DESCRIPTION, elem_classes="markdown-text")
        with gr.Row():
            search_bar = gr.Textbox(placeholder="Search for your model...", show_label=False)
            merge_values = gr.Checkbox(
                value=True,
                label="Merge evals",
                info="Merge evals for the same model. If there are duplicates, we display the largest one.",
            )
        with gr.Row():
            cols_bar = gr.CheckboxGroup(
                choices=[c for c in leaderboard_df.columns[2:] if c != "Average"],
                show_label=False,
                info="Select columns to display",
            )
        with gr.Group():
            leaderboard_table = gr.Dataframe(
                value=leaderboard_df,
                wrap=True,
                height=1000,
                column_widths=[400, 110] + [(220 + len(c)) for c in leaderboard_df.columns[2:]],
            )
        with gr.Row():
            refresh_button = gr.Button("Refresh")
    cols_bar.change(filter_columns, inputs=[cols_bar], outputs=[leaderboard_table])
    merge_values.change(refresh, inputs=[merge_values], outputs=[leaderboard_table])
    search_bar.submit(update_table, inputs=[search_bar], outputs=[leaderboard_table])
    refresh_button.click(refresh, inputs=[merge_values], outputs=[leaderboard_table])
demo.launch()