import os
import pickle
import tempfile

import pandas as pd
import gradio as gr
import plotly.express as px
from datetime import datetime
from huggingface_hub import HfApi
from apscheduler.schedulers.background import BackgroundScheduler

from utils import (
    KEY_TO_CATEGORY_NAME,
    CAT_NAME_TO_EXPLANATION,
    download_latest_data_from_space,
    get_constants,
    update_release_date_mapping,
    format_data,
    get_trendlines,
    find_crossover_point,
)

###################
### Initialize scheduler
###################


def restart_space():
    HfApi(token=os.getenv("HF_TOKEN", None)).restart_space(
        repo_id="andrewrreed/closed-vs-open-arena-elo"
    )
    print(f"Space restarted on {datetime.now()}")


# restart the space every day at 9am
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "cron", day_of_week="mon-sun", hour=7, minute=0)
scheduler.start()

###################
### Load Data
###################

# gather ELO data
latest_elo_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="pkl"
)

with open(latest_elo_file_local, "rb") as fin:
    elo_results = pickle.load(fin)

# TO-DO: need to also include vision
elo_results = elo_results["text"]

arena_dfs = {}
for k in KEY_TO_CATEGORY_NAME.keys():
    if k not in elo_results:
        continue
    arena_dfs[KEY_TO_CATEGORY_NAME[k]] = elo_results[k]["leaderboard_table_df"]

# gather open llm leaderboard data
latest_leaderboard_file_local = download_latest_data_from_space(
    repo_id="lmsys/chatbot-arena-leaderboard", file_type="csv"
)
leaderboard_df = pd.read_csv(latest_leaderboard_file_local)

# load release date mapping data
release_date_mapping = pd.read_json("release_date_mapping.json", orient="records")

###################
### Prepare Data
###################

# update release date mapping with new models
# check for new models in ELO data
new_model_keys_to_add = [
    model
    for model in arena_dfs["Overall"].index.to_list()
    if model not in release_date_mapping["key"].to_list()
]
if new_model_keys_to_add:
    release_date_mapping = update_release_date_mapping(
        new_model_keys_to_add, leaderboard_df, release_date_mapping
    )

# merge leaderboard data with ELO data
merged_dfs = {}
for k, v in arena_dfs.items():
    merged_dfs[k] = (
        pd.merge(arena_dfs[k], leaderboard_df, left_index=True, right_on="key")
        .sort_values("rating", ascending=False)
        .reset_index(drop=True)
    )

# add release dates into the merged data
for k, v in merged_dfs.items():
    merged_dfs[k] = pd.merge(
        merged_dfs[k], release_date_mapping[["key", "Release Date"]], on="key"
    )

# format dataframes
merged_dfs = {k: format_data(v) for k, v in merged_dfs.items()}

# get constants
min_elo_score, max_elo_score, upper_models_per_month = get_constants(merged_dfs)
date_updated = elo_results["full"]["last_updated_datetime"].split(" ")[0]
orgs = merged_dfs["Overall"].Organization.unique().tolist()

###################
### Build and Plot Data
###################


def get_data_split(dfs, set_name):
    df = dfs[set_name].copy(deep=True)
    return df.reset_index(drop=True)


def clean_df_for_display(df):

    df = df.loc[
        :,
        [
            "Model",
            "rating",
            "MMLU",
            "MT-bench (score)",
            "Release Date",
            "Organization",
            "License",
            "Link",
        ],
    ].rename(columns={"rating": "ELO Score", "MT-bench (score)": "MT-Bench"})

    df["Release Date"] = df["Release Date"].astype(str)
    df.sort_values("ELO Score", ascending=False, inplace=True)
    df.reset_index(drop=True, inplace=True)
    return df


def filter_df(min_score, max_models_per_month, set_selector, org_selector):
    df = get_data_split(merged_dfs, set_name=set_selector)

    # filter data
    filtered_df = df[
        (df["rating"] >= min_score) & (df["Organization"].isin(org_selector))
    ]

    filtered_df = (
        filtered_df.groupby(["Month-Year", "License"], group_keys=False)
        .apply(lambda x: x.nlargest(max_models_per_month, "rating"))
        .reset_index(drop=True)
    )

    return filtered_df


def build_plot(toggle_annotations, filtered_df):

    # construct plot
    custom_colors = {"Open LLM": "#ff7f0e", "Proprietary LLM": "#1f77b4"}
    fig = px.scatter(
        filtered_df,
        x="Release Date",
        y="rating",
        color="License",
        hover_name="Model",
        hover_data=["Organization", "License", "Link"],
        trendline="ols",
        title=f"Open vs Proprietary LLMs by LMSYS Arena ELO Score<br>(as of {date_updated})",
        labels={"rating": "Arena ELO", "Release Date": "Release Date"},
        height=700,
        template="plotly_dark",
        color_discrete_map=custom_colors,
    )

    fig.update_layout(
        plot_bgcolor="rgba(0,0,0,0)",  # Set background color to transparent
        paper_bgcolor="rgba(0,0,0,0)",  # Set paper (plot) background color to transparent
        title={"x": 0.5},
    )

    fig.update_traces(marker=dict(size=10, opacity=0.6))

    # calculate days until crossover
    trend1, trend2 = get_trendlines(fig)
    crossover = find_crossover_point(
        b1=trend1[0], m1=trend1[1], b2=trend2[0], m2=trend2[1]
    )
    days_til_crossover = (
        pd.to_datetime(crossover, unit="s") - pd.Timestamp.today()
    ).days

    # add annotation with number of models and days til crossover
    fig.add_annotation(
        xref="paper",
        yref="paper",
        x=0.01,
        y=1.13,
        text=(
            f"<b>Number of models:</b> {len(filtered_df)}<br>"
            f"<b>Days until crossover:</b> {days_til_crossover}"
        ),
        showarrow=False,
        font=dict(size=16, color="white"),  # Increased font size
        bgcolor="rgba(0,0,0,0.7)",  # Darker background for better contrast
        bordercolor="rgba(255,255,255,0.3)",  # Light border
        borderwidth=2,
        borderpad=8,  # More padding
        align="left",
    )

    if toggle_annotations:
        # get the points to annotate (only the highest rated model per month per license)
        idx_to_annotate = filtered_df.groupby(["Month-Year", "License"])[
            "rating"
        ].idxmax()
        points_to_annotate_df = filtered_df.loc[idx_to_annotate]

        for i, row in points_to_annotate_df.iterrows():
            fig.add_annotation(
                x=row["Release Date"],
                y=row["rating"],
                text=row["Model"],
                showarrow=True,
                arrowhead=0,
            )

    return fig, clean_df_for_display(filtered_df)


set_dark_mode = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.sky,
        secondary_hue=gr.themes.colors.green,
        # spacing_size=gr.themes.sizes.spacing_sm,
        text_size=gr.themes.sizes.text_sm,
        font=[
            gr.themes.GoogleFont("Open Sans"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
    ),
    js=set_dark_mode,
) as demo:
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 650px; margin: auto;">
            <h1 style="font-weight: 900; margin-top: 5px;">🔬 Progress Tracker: Open vs. Proprietary LLMs 🔬</h1>
            <p style="text-align: left; margin-top: 30px; margin-bottom: 30px; line-height: 20px;">
            This app visualizes the progress of proprietary and open-source LLMs over time as scored by the <a href="https://leaderboard.lmsys.org/">LMSYS Chatbot Arena</a>.
            The idea is inspired by <a href="https://www.linkedin.com/posts/maxime-labonne_arena-elo-graph-updated-with-new-models-activity-7187062633735368705-u2jB">this great work</a> 
            from <a href="https://huggingface.co/mlabonne/">Maxime Labonne</a>, and is intended to stay up-to-date as new models are released and evaluated.
            <div style="text-align: left;">
            <strong>Plot info:</strong>
            <br>
            <ul style="padding-left: 20px;">
                <li> The ELO score (y-axis) is a measure of the relative strength of a model based on its performance against other models in the arena. </li>
                <li> The Release Date (x-axis) corresponds to when the model was first publicly released or when its ELO results were first reported (for ease of automated updates). </li>
                <li> Trend lines are based on Ordinary Least Squares (OLS) regression and adjust based on the filter criteria. </li>
            <ul>
            </div>
            </p>
        </div>
        """
    )
    with gr.Group():
        with gr.Row(variant="compact"):
            set_selector = gr.Dropdown(
                choices=list(CAT_NAME_TO_EXPLANATION.keys()),
                label="Select Category",
                value="Overall",
                info="Select the category to visualize",
            )
            min_score = gr.Slider(
                minimum=min_elo_score,
                maximum=max_elo_score,
                value=min_elo_score,
                step=50,
                label="Minimum ELO Score",
                info="Filter out low scoring models",
            )
            max_models_per_month = gr.Slider(
                value=3,
                minimum=1,
                maximum=upper_models_per_month,
                step=1,
                label="Max Models per Month (per License)",
                info="Limit to N best models per month per license to reduce clutter",
            )
            toggle_annotations = gr.Radio(
                choices=[True, False],
                label="Overlay Best Model Name",
                value=True,
                info="Toggle to overlay the name of the best model per month per license",
            )
        with gr.Row(variant="compact"):
            with gr.Accordion("More options", open=False):
                org_selector = gr.Dropdown(
                    choices=sorted(orgs),
                    label="Filter by Organization",
                    value=sorted(orgs),
                    multiselect=True,
                    info="Limit organizations included in plot",
                )

    # Show plot
    filtered_df = gr.State()
    with gr.Group():
        with gr.Tab("Plot"):
            plot = gr.Plot(show_label=False)
        with gr.Tab("Raw Data"):
            display_df = gr.DataFrame(interactive=False)
            with gr.Row():
                # Empty column to push download controls to the right
                with gr.Column(scale=3):
                    gr.Markdown("")
                # Download controls on the right
                with gr.Column(scale=1):
                    download_button = gr.Button(
                        "📥 Download Data", variant="secondary", size="sm"
                    )
                    csv_file = gr.File(label="Download CSV", visible=False)

    def create_download(df):
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".csv"
        ) as temp_file:
            df.to_csv(temp_file.name, index=False)
        return temp_file.name, gr.update(visible=True)

    download_button.click(
        fn=create_download,
        inputs=[display_df],
        outputs=[csv_file, csv_file],
    )

    demo.load(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    min_score.change(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    max_models_per_month.change(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    toggle_annotations.change(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    set_selector.change(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    org_selector.change(
        fn=filter_df,
        inputs=[min_score, max_models_per_month, set_selector, org_selector],
        outputs=filtered_df,
    ).then(
        fn=build_plot,
        inputs=[toggle_annotations, filtered_df],
        outputs=[plot, display_df],
    )

    gr.Markdown(
        """
                <div style="text-align: center; max-width: 650px; margin: auto;">
                <p style="margin-top: 40px;"> If you have any questions, feel free to open a discussion or <a href="https://twitter.com/andrewrreed">reach out to me on social</a>. </p>
                </p>
                </div>
                """
    )

demo.launch()