"""
This program helps us explore models' responses to the benchmark. It is a web
app that displays the following:

1. A list of benchmark items loaded from puzzles_cleaned.csv. The list shows
   the columns ID, challenge, and answer.
2. When we select a puzzle from the list, we see the transcript, Explanation,
   and Editor's Note in textboxes. (Scrollable since they can be long.)
3. The list in (1) also has a column for each model, marked ✅ or ❌ to indicate
   whether the model's response is correct (❓ when the model has no response).
   We load the model responses
   from results.duckdb. That file has a table called completions with
   columns 'prompt_id', 'parent_dir', and 'completion'. The prompt_id can be
   joined with ID from puzzles_cleaned.csv. The parent_dir is the model name.
   The completion is the model response, which we compare with the answer from
   puzzles_cleaned.csv using the SQL function check_answer (called in
   build_table below).
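4. Finally, clicking a cell in the list shows that puzzle's details together
   with the completion from the model whose column was clicked (clicking a
   non-model column shows the first model in sorted order). A "Show Thoughts"
   checkbox controls whether text before "</think>" is included.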

Note that not every model has a response for every puzzle.
"""
import gradio as gr
import pandas as pd
import numpy as np
from metrics import load_results_sample_one_only, accuracy_by_model_and_time
import metrics
from pathlib import Path

def get_model_response(prompt_id, model_name):
    """Look up one model's completion for one prompt in the ``sampled`` table.

    Returns the ``completion`` value of the first matching row, or ``None``
    when ``model_name`` has no row for ``prompt_id``.
    """
    query = f"""
        SELECT completion FROM sampled
        WHERE prompt_id = {prompt_id} AND parent_dir = '{model_name}'
    """
    first_row = conn.sql(query).fetchone()
    if not first_row:
        return None
    return first_row[0]

def display_puzzle(puzzle_id):
    """Fetch the fields shown in the detail pane for one challenge.

    Returns the tuple ``(challenge, answer, transcript, Explanation,
    "Editor's Notes")`` for the ``challenges`` row whose ``ID`` equals
    ``puzzle_id``, or five ``None`` values when no such row exists.
    """
    query = f"""
        SELECT challenge, answer, transcript, Explanation, "Editor's Notes"
        FROM challenges
        WHERE ID = {puzzle_id}
    """
    return conn.sql(query).fetchone() or (None,) * 5

def display_model_response(puzzle_id, model_name, show_thoughts):
    """Return the text to show for ``model_name``'s answer to ``puzzle_id``.

    When ``show_thoughts`` is false, only the text after the last
    ``</think>`` marker is kept (the whole completion if there is no marker).
    When it is true, the whole completion is kept. The result is stripped of
    surrounding whitespace. If the model has no completion, a placeholder
    message is returned instead.
    """
    response = get_model_response(puzzle_id, model_name)
    if response is None:
        return "No response from this model."
    if show_thoughts:
        return response.strip()
    # rpartition puts the text after the last marker in the tail, or the
    # entire string there when the marker does not occur.
    _, _, final_answer = response.rpartition("</think>")
    return final_answer.strip()


conn = load_results_sample_one_only()

# Every distinct parent_dir (model name) in the sampled table, sorted by name.
model_names = sorted(
    row[0] for row in conn.sql("SELECT DISTINCT parent_dir FROM sampled").fetchall()
)
# Display-only variants with the "completions-" prefix removed.
cleaned_model_names = [model.replace("completions-", "") for model in model_names]


def build_table():
    """Build the per-puzzle results table with answer/verdict columns per model.

    Runs one aggregate query over ``challenges`` LEFT JOIN ``sampled`` (joined
    on ``c.ID = r.prompt_id`` and grouped per challenge).

    The resulting DataFrame has these columns, in order:

    - ``ID`` and ``challenge``.
    - ``answer``, passed through the ``wrap_text`` SQL function.
    - For every model in ``model_names``, a ``<MODEL>_answer`` column (the
      completion) and a ``<MODEL>_ok`` column (the ``check_answer`` result).
    - ``challenge_clipped``, from the ``clip_text`` SQL function.

    ``<MODEL>`` is the model name with ``-`` replaced by ``_``. The ``_ok``
    columns are then rendered as ✅ when the value equals 1, ❌ when it equals
    0, and ❓ otherwise (e.g. the model has no completion for that puzzle).

    Returns:
        ``(joined_df, model_correct_columns)`` where ``model_correct_columns``
        lists the ``<MODEL>_ok`` column names in ``model_names`` order.
    """

    def quote_identifier(name):
        # Double-quote aliases so model names containing characters such as
        # "." still yield valid column names. DuckDB keeps the alias text as-is.
        return '"' + name.replace('"', '""') + '"'

    def quote_literal(value):
        # Escape single quotes so any model name is a valid SQL string literal.
        return "'" + value.replace("'", "''") + "'"

    select_items = ["c.ID", "c.challenge", "wrap_text(c.answer, 40) AS answer"]
    model_correct_columns = []
    for model in model_names:
        normalized_model_name = model.replace("-", "_")
        is_model = f"r.parent_dir = {quote_literal(model)}"
        select_items.append(
            f"MAX(CASE WHEN {is_model} THEN r.completion ELSE NULL END) "
            f"AS {quote_identifier(normalized_model_name + '_answer')}"
        )
        select_items.append(
            f"MAX(CASE WHEN {is_model} THEN check_answer(r.completion, c.answer) ELSE NULL END) "
            f"AS {quote_identifier(normalized_model_name + '_ok')}"
        )
        model_correct_columns.append(normalized_model_name + "_ok")
    select_items.append("clip_text(c.challenge, 40) AS challenge_clipped")

    # Join the select list explicitly so no stray trailing comma can appear
    # before FROM (str.rstrip(",") cannot remove a comma that is followed by
    # whitespace).
    query = (
        "SELECT " + ",\n       ".join(select_items) + "\n"
        "FROM challenges c\n"
        "LEFT JOIN sampled r ON c.ID = r.prompt_id\n"
        "GROUP BY c.ID, c.challenge, c.answer"
    )

    joined_df = conn.sql(query).fetchdf()

    def verdict_emoji(value):
        # Check for missing values first: pd.NA cannot be used in a boolean
        # context, so "pd.NA == 1" must not reach an if-statement.
        if pd.isna(value):
            return "❓"
        if value == 1:
            return "✅"
        if value == 0:
            return "❌"
        return "❓"

    # Replace each model's verdict column with its emoji rendering.
    for column in model_correct_columns:
        joined_df[column] = joined_df[column].apply(verdict_emoji)

    return joined_df, model_correct_columns


joined_df, model_correct_columns = build_table()

# Table shown in the UI, sorted by puzzle ID. Columns are ID, the clipped
# challenge, the answer, then one verdict column per model headed by the
# model name without its "completions-" prefix.
relabelled_df = (
    joined_df[["ID", "challenge_clipped", "answer", *model_correct_columns]]
    .rename(
        columns={
            "challenge_clipped": "Challenge",
            "answer": "Answer",
            **{
                model.replace("-", "_") + "_ok": model.replace("completions-", "")
                for model in model_names
            },
        }
    )
    .sort_values(by="ID")
)

# Column position in relabelled_df -> model name. Positions 0-2 are the
# ID/Challenge/Answer columns, so model columns start at 3.
model_columns = dict(enumerate(model_names, start=3))

valid_model_indices = list(model_columns)
# Model shown initially and whenever a non-model column is clicked.
default_model = model_names[0]

def summary_view():
    """Render a line plot of each model's accuracy per year."""
    per_year = accuracy_by_model_and_time(conn).to_df()
    per_year["model"] = per_year["model"].apply(
        lambda raw_name: raw_name.replace("completions-", "")
    )
    # Cast to str; otherwise Gradio renders a year such as 2020 as "2,020.0".
    per_year["year"] = per_year["year"].astype(str)
    per_year = per_year.rename(
        columns={"model": "Model", "year": "Year", "accuracy": "Accuracy"}
    )
    gr.LinePlot(
        per_year,
        x="Year",
        y="Accuracy",
        color="Model",
        title="Model Accuracy Over Time",
        y_lim=[0, 1],
        x_label="Year",
        y_label="Accuracy",
    )


def accuracy_by_completion_length(models=None):
    """Plot cumulative accuracy against maximum completion length per model.

    Args:
        models: Optional mapping from ``parent_dir`` value to legend label.
            When omitted, R1, Gemini2, QWQ 32B and Sonnet 3.7 ET are plotted.
            Each entry is passed to
            ``metrics.r1_accuracy_by_completion_length``. The resulting frames
            are stacked, in mapping order, with a ``model`` column holding
            the label.
    """
    if models is None:
        models = {
            "completions-r1": "R1",
            "completions-gemini2": "Gemini2",
            "completions-qwen32b": "QWQ 32B",
            "completions-claude-3-7-sonnet-20250219": "Sonnet 3.7 ET",
        }

    frames = []
    for parent_dir, label in models.items():
        frame = metrics.r1_accuracy_by_completion_length(conn, parent_dir).to_df()
        frame["model"] = label
        frames.append(frame)
    completions = pd.concat(frames)

    # NOTE(review): dividing by 3.2 presumably converts a character count into
    # an approximate token count (the x-axis is labelled in tokens) — confirm
    # against metrics.r1_accuracy_by_completion_length.
    completions["length"] = completions["length"] / 3.2

    with gr.Blocks(fill_height=True):
        gr.LinePlot(
            completions,
            x="length",
            y="cumulative_accuracy",
            title="Accuracy by Maximum Completion Length",
            x_label="Max Response Length (tokens)",
            # Values are fractions (y_lim is [0, 1]), not percentages.
            y_label="Accuracy",
            x_lim=[0, 32_768],
            y_lim=[0, 1],
            color="model",
        )

def all_challenges_view():
    """Render the "All Challenges" tab: the results table plus a detail pane.

    Clicking a table cell stores the clicked row's puzzle ``ID``. If the cell
    is in a model column, it also stores that model; any other column selects
    ``default_model``. Whenever the selected puzzle, the selected model, or
    the "Show Thoughts" checkbox changes, these boxes are refreshed:
    challenge, answer, transcript, explanation, editor's note and model
    response.
    """
    # Using "markdown" as the datatype makes Gradio interpret newlines.
    puzzle_list = gr.DataFrame(
        value=relabelled_df,
        datatype=["number", "str", "markdown", *["str"] * len(model_correct_columns)],
    )
    with gr.Row(scale=2):
        model_name = gr.State(value=default_model)
        challenge_id = gr.State(value=0)
        show_thoughts = gr.State(value=False)
        with gr.Column():
            challenge = gr.Textbox(label="Challenge", interactive=False)
            answer = gr.Textbox(label="Answer", interactive=False)
            explanation = gr.Textbox(label="Explanation", interactive=False)
            editors_note = gr.Textbox(label="Editor's Note", interactive=False)

        with gr.Column():
            show_thoughts_checkbox = gr.Checkbox(label="Show Thoughts", value=False)
            # Copy the checkbox's value into the state rather than flipping
            # the state on each change, so the two can never disagree.
            show_thoughts_checkbox.change(
                fn=lambda checked: checked,
                inputs=[show_thoughts_checkbox],
                outputs=[show_thoughts],
            )
            model_response = gr.Textbox(label="Model Response", interactive=False)
        transcript = gr.Textbox(label="Transcript", interactive=False)

    def select_table_item(evt: gr.SelectData):
        row_index = evt.index[0]
        column_index = evt.index[1]
        # Map the positional row in relabelled_df to the puzzle's ID.
        selected_id = relabelled_df.iloc[row_index]["ID"]
        selected_model = model_columns.get(column_index, default_model)
        return (selected_model, selected_id)

    def update_puzzle(challenge_id: int, model_name: str, show_thoughts: bool):
        return (
            *display_puzzle(challenge_id),
            gr.Textbox(
                value=display_model_response(challenge_id, model_name, show_thoughts),
                label=model_name,
            ),
        )

    puzzle_list.select(
        fn=select_table_item,
        inputs=[],
        outputs=[model_name, challenge_id],
    )

    # Any change to the selection or to the thoughts toggle refreshes details.
    detail_outputs = [challenge, answer, transcript, explanation, editors_note, model_response]
    for trigger in (model_name, challenge_id, show_thoughts):
        trigger.change(
            fn=update_puzzle,
            inputs=[challenge_id, model_name, show_thoughts],
            outputs=detail_outputs,
        )

    

def overview_view():
    """Render the "Overview" tab: the README body and a per-model accuracy table.

    Everything in README.md up to and including the second ``---`` marker
    (presumably a front-matter header) is dropped, and the rest is shown as
    Markdown. If the file has fewer than two ``---`` markers, it is shown
    whole.
    """
    with gr.Blocks(fill_height=True):
        with gr.Row():
            readme_text = Path("README.md").read_text(encoding="utf-8")
            # maxsplit=2 keeps any later "---" (e.g. a horizontal rule in the
            # body) together with the text that follows it.
            parts = readme_text.split("---", 2)
            if len(parts) == 3:
                readme_text = parts[2]
            gr.Markdown(readme_text)
        with gr.Row():
            gr.DataFrame(metrics.accuracy_by_model(conn).to_df())


def create_interface():
    """Assemble the tabbed Gradio app from the view builders and launch it."""
    tab_builders = (
        ("Overview", overview_view),
        ("All Challenges", all_challenges_view),
        ("Accuracy Over Time", summary_view),
        ("Reasoning Length Analysis", accuracy_by_completion_length),
    )
    with gr.Blocks() as demo:
        with gr.Tabs():
            for tab_title, build_tab in tab_builders:
                with gr.TabItem(tab_title):
                    build_tab()
    demo.launch()


if __name__ == "__main__":
    create_interface()