import re
import duckdb
import textwrap
from typing import List, Tuple
import argparse
import unicodedata
import re


import diskcache as dc

# Persistent on-disk memo of answer-check verdicts, stored in the
# "answer_cache" directory (a relative path). Entries survive across runs, so
# cached verdicts must be invalidated whenever the grading logic changes.
cache = dc.Cache("answer_cache")

def normalize_text(text: str) -> str:
    """Return *text* lowercased, trimmed, and with diacritics removed.

    The string is NFKD-decomposed, so accented letters split into a base
    letter plus combining marks (e.g. "é" -> "e" + U+0301). Compatibility
    characters are also folded (e.g. the "fi" ligature -> "fi"). The
    combining marks are then discarded before lowercasing and stripping
    surrounding whitespace.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    base_chars = filter(lambda ch: not unicodedata.combining(ch), decomposed)
    return "".join(base_chars).lower().strip()


def _parse_answer(text: str) -> List[List[str]]:
    """
    Parse a reference answer into alternatives, each a list of required phrases.

    The text is first passed through ``normalize_text`` (lowercased,
    diacritics removed, surrounding whitespace stripped). ";" separates
    alternatives.

    Within an alternative, "," and arrows separate phrases. The arrow pattern
    is ">" optionally preceded by an en dash and up to two hyphens, so "->",
    "-->", "–>" and a bare ">" all split. Each phrase is reduced to its runs
    of word characters (letters, digits, underscore) joined by single spaces;
    all other characters are dropped.

    Adjacent words therefore form a phrase that must be present literally.
    Distinct phrases may appear in any order.

    Phrases that end up empty are discarded. This happens, for example, with
    a trailing "," or ";" or a punctuation-only segment. Alternatives with no
    phrases left are dropped. An empty phrase would otherwise become an empty
    word-boundary pattern in ``_check_answer``, which matches any completion
    containing a word character.

    Returns the list of alternatives. It is empty when the answer contains
    no words at all.
    """
    text = normalize_text(text)  # also lowercases
    result = []
    for alternative in text.split(";"):
        phrases = [
            " ".join(re.findall(r"\b\w+\b", group))
            for group in re.split(r"–?-?-?>|,", alternative)
        ]
        # Drop empty phrases; they would accept every completion.
        phrases = [phrase for phrase in phrases if phrase]
        if phrases:
            result.append(phrases)
    return result

def _answer_without_thoughts(completion: str) -> str:
    completion = re.sub(r"(<think>)?[^<]*<\/think>", "", completion).strip()
    completion = re.sub(r".*</think>", "", completion).strip() #because qwen sometimes misses <think>
    return completion



def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, accents, and
    punctuation.

    ``answer`` is parsed by ``_parse_answer`` into ";"-separated
    alternatives, each a list of phrases. The completion is correct if, for
    at least one alternative, every phrase occurs in it on word boundaries.

    The completion is normalized the same way answers are:
    - reasoning is removed by ``_answer_without_thoughts``;
    - markdown bold markers ("**") are dropped;
    - ``normalize_text`` is applied (lowercase, strip diacritics);
    - the result is reduced to its runs of word characters joined by single
      spaces.

    Verdicts are memoized in the on-disk ``cache``, which persists across
    runs. The key carries a version tag. Bump it whenever the grading logic
    (here or in the helpers this calls) changes; otherwise verdicts computed
    by older logic keep being returned.
    """
    grader_version = "check_answer/v2"
    key = (grader_version, completion, answer)
    cached = cache.get(key)  # single lookup; only bools are ever stored
    if cached is not None:
        return cached

    text = _answer_without_thoughts(completion)
    text = text.replace("**", "")
    # Normalize BEFORE extracting words, mirroring _parse_answer. Otherwise
    # combining marks / compatibility characters are treated as punctuation
    # before they are folded away.
    text = normalize_text(text)
    text = " ".join(re.findall(r"\w+", text))

    result = False
    for answer_phrases in _parse_answer(answer):
        # An empty phrase would become r"\b\b", which matches any text
        # containing a word character and so would accept every completion.
        phrases = [phrase for phrase in answer_phrases if phrase]
        if phrases and all(
            re.search(rf"\b{re.escape(phrase)}\b", text) for phrase in phrases
        ):
            result = True
            break
    cache[key] = result
    return result


def _clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."

def _wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)

def load_results():
    """
    Open an in-memory DuckDB connection prepared for the grading queries.

    The returned connection has:
    - ``results``: 'results.duckdb' attached read-only;
    - ``challenges``: the rows of 'puzzles_cleaned.csv' whose ``Warnings``
      column is NULL or does not contain "(E)";
    - the Python UDFs ``check_answer``, ``clip_text`` and ``wrap_text``.
    """
    conn = duckdb.connect(":memory:")
    setup_statements = (
        "ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)",
        """
        CREATE TABLE challenges AS 
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """,
    )
    for statement in setup_statements:
        conn.execute(statement)

    udfs = {
        "check_answer": _check_answer,
        "clip_text": _clip_text,
        "wrap_text": _wrap_text,
    }
    for udf_name, udf in udfs.items():
        conn.create_function(udf_name, udf)
    return conn

def load_results_sample_one_only():
    """
    Open an in-memory DuckDB connection like ``load_results`` does, plus a
    ``sampled`` table.

    ``sampled`` keeps exactly one row of ``results.completions`` per
    (parent_dir, prompt), with columns prompt_id, parent_dir, prompt and
    completion. The chosen row is the first by ``prompt_id``. Ties are broken
    by ``completion`` text so the sample is identical across runs.

    The connection also has ``results`` ('results.duckdb', read-only), the
    filtered ``challenges`` table, and the UDFs ``check_answer``,
    ``clip_text`` and ``wrap_text``.

    NOTE(review): the accuracy queries in this module
    (``r1_accuracy_by_completion_length``, ``accuracy_by_model_and_time``,
    ``accuracy_by_model``) read ``results.completions`` directly, not
    ``sampled``. Callers wanting the one-sample view must query ``sampled``
    themselves.
    """
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")

    # ROW_NUMBER() over a non-unique ORDER BY picks an arbitrary row among
    # peers; `completion` is added as a tie-breaker to make the pick stable.
    conn.execute("""
    CREATE TABLE sampled AS
    WITH numbered AS (
        SELECT *,
            ROW_NUMBER() OVER (
                PARTITION BY parent_dir, prompt
                ORDER BY prompt_id, completion
            ) AS rn
        FROM results.completions
    )
    SELECT prompt_id, parent_dir, prompt, completion
    FROM numbered
    WHERE rn = 1;
    """)
    conn.execute("""
        CREATE TABLE challenges AS 
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """)
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)
    return conn

def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions, ordered by length.

    Despite the name, this works for any model directory, not only R1. It
    uses the completions in ``results.completions`` whose ``parent_dir``
    equals *model_name*, joined to ``challenges`` on ``prompt_id = ID``:

    1. Compute ``LENGTH(completion)`` and correctness (``check_answer`` as
       0/1) for each completion.
    2. Sort by length.
    3. Compute the running count of correct responses and divide it by the
       total number of that model's completions.

    Completions of equal length are window peers under the default frame,
    so they share the same cumulative values.

    Returns a DuckDB relation with columns ``length``,
    ``cumulative_correct`` and ``cumulative_accuracy``.
    """
    # Escape single quotes so a model name containing "'" cannot break (or
    # inject into) the SQL string literal below.
    model_literal = str(model_name).replace("'", "''")
    r1_completions = conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results JOIN challenges
            ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_literal}'
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
    return r1_completions


def accuracy_by_model_and_time(conn):
    """
    Accuracy of each model, broken down by the challenge's year.

    ``challenges.date`` is cast to DATE and its year extracted. Completions
    in ``results.completions`` are joined to challenges on
    ``prompt_id = ID`` and grouped by (model = ``parent_dir``, year).

    Returns a DuckDB relation ordered by model and year, with columns:
    - model;
    - year;
    - total: number of completion rows;
    - correct: rows passing ``check_answer``;
    - accuracy: correct / total, rounded to 2 decimals.

    NOTE(review): unlike ``accuracy_by_model``, rows here are not weighted
    by ``results.count``. Confirm which weighting is intended.
    """
    model_accuracies = conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM
                challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                ChallengesWithDates dates
            ON
                results.prompt_id = dates.ID
            GROUP BY
                results.parent_dir,
                dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            -- Explicit cast: floating-point division on every DuckDB version
            -- (releases before 0.8 truncated integer division).
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM
            DateAnswerCheck
        ORDER BY
            model,
            year
    """)

    return model_accuracies

def accuracy_by_model(conn):
    """
    Overall accuracy of each model.

    Completions in ``results.completions`` are joined to ``challenges`` on
    ``prompt_id = ID`` and grouped by model (``parent_dir``). Each row is
    weighted by ``results.count``:
    - total = SUM(count);
    - correct = SUM(count) over the rows that pass ``check_answer``.

    Returns a DuckDB relation ordered by model, with columns model, total,
    correct and accuracy (correct / total rounded to 2 decimals; NULL when
    total is 0).
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                challenges
            ON
                results.prompt_id = challenges.ID
            GROUP BY
                results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            -- Explicit cast: floating-point division on every DuckDB version
            -- (releases before 0.8 truncated integer division); NULLIF
            -- guards against a zero total.
            ROUND(CAST(correct AS DOUBLE) / NULLIF(total, 0), 2) AS accuracy
        FROM
            AnswerCheck
        -- GROUP BY output order is unspecified; sort for stable output.
        ORDER BY
            model
    """)

def main():
    """Command-line entry point.

    Prints per-model accuracy by default, or per-model-and-year accuracy
    when ``--by-model-and-time`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--by-model-and-time", action="store_true")
    args = parser.parse_args()
    conn = load_results()

    report = (
        accuracy_by_model_and_time if args.by_model_and_time else accuracy_by_model
    )
    print(report(conn))

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()