import argparse
import re
import textwrap
import unicodedata
from typing import List, Tuple

import diskcache as dc
import duckdb

# Persistent on-disk memo of _check_answer verdicts (directory "answer_cache").
cache = dc.Cache("answer_cache")

# Part of every cache key used by _check_answer. Bump it whenever the grading
# logic (_check_answer / _parse_answer / _answer_without_thoughts /
# normalize_text) changes, so verdicts produced by an older version that are
# still stored on disk are not reused.
_CHECK_ANSWER_CACHE_VERSION = 2


def normalize_text(text: str) -> str:
    """Normalize text to remove accents, convert to lowercase, and strip spaces.

    Uses NFKD, so compatibility characters (ligatures, full-width forms, ...)
    are also decomposed before combining marks are dropped.
    """
    # Decompose letters with accents (e.g. "é" -> "e" + U+0301).
    text = unicodedata.normalize("NFKD", text)
    # Drop the combining marks (the diacritics) left over by the decomposition.
    text = "".join(c for c in text if not unicodedata.combining(c))
    return text.lower().strip()


def _parse_answer(text: str) -> List[List[str]]:
    """Parse an expected answer into alternatives of required phrases.

    The text is lowercased and accent-normalized (via ``normalize_text``).
    Then ";" separates alternatives. Within each alternative, "," and arrows
    (">", "->", "-->", "–>") separate phrases that may appear in any order.
    Each phrase keeps only its word characters, joined by single spaces.
    Adjacent words therefore form a phrase that must be present literally,
    and all other characters are dropped.

    A phrase may come out as "" (e.g. for a trailing ";" or an empty comma
    group); callers must ignore such phrases.

    Returns:
        One list of phrases per alternative.
    """
    text = normalize_text(text.lower())
    result = []
    for alternative in text.split(";"):
        groups = re.split(r"–?-?-?>|,", alternative)
        result.append([" ".join(re.findall(r"\b\w+\b", group)) for group in groups])
    return result


def _answer_without_thoughts(completion: str) -> str:
    """Return *completion* with its reasoning ("thinking") text removed.

    Two passes, each followed by stripping surrounding whitespace:

    1. Remove every ``<think>...</think>`` span. The match is non-greedy and
       spans newlines, so thoughts containing ``<`` or several think blocks
       are handled.
    2. Remove everything up to the last remaining ``</think>``. This covers
       completions whose opening ``<think>`` tag is missing (Qwen sometimes
       omits it).

    A completion with no ``</think>`` tag is returned unchanged apart from
    the whitespace stripping.
    """
    completion = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL).strip()
    # Qwen sometimes misses <think>: drop everything up to the closing tag.
    completion = re.sub(r".*</think>", "", completion, flags=re.DOTALL).strip()
    return completion


def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, accents, and punctuation.

    Returns True if at least one alternative of *answer* (see
    ``_parse_answer``) has all of its non-empty phrases present in the
    completion as whole-word matches. Verdicts are memoized in the on-disk
    ``cache``.
    """
    key = (_CHECK_ANSWER_CACHE_VERSION, completion, answer)
    cached = cache.get(key)  # single lookup; stored values are bools, never None
    if cached is not None:
        return cached

    # Normalize first, exactly as _parse_answer does for the answer, so that
    # both sides are tokenized from the same NFKD/lowercased text.
    text = normalize_text(_answer_without_thoughts(completion).lower())
    text = text.replace("**", "")
    # Punctuation -> space, aligning with _parse_answer's " ".join of \w+ runs.
    text = re.sub(r"[^\w\s]", " ", text)
    # Collapse runs of (Unicode) whitespace to finish aligning with _parse_answer.
    text = re.sub(r"\s+", " ", text).strip()

    result = False
    for answer_phrases in _parse_answer(answer):
        # Empty phrases would compile to r"\b\b" and match any completion
        # containing a word, so they are ignored; an alternative with no
        # phrases left cannot be satisfied.
        phrases = [phrase for phrase in answer_phrases if phrase]
        if phrases and all(
            re.search(rf"\b{re.escape(phrase)}\b", text) for phrase in phrases
        ):
            result = True
            break

    cache[key] = result
    return result


def _clip_text(text: str, width: int) -> str:
    """Return *text* truncated to *width* characters, with "..." appended if cut."""
    return text if len(text) <= width else text[:width] + "..."


def _wrap_text(text: str, width: int) -> str:
    """Return *text* wrapped to lines of at most *width* characters."""
    return textwrap.fill(text, width=width)


def _create_challenges_and_udfs(conn: duckdb.DuckDBPyConnection) -> None:
    """Create the ``challenges`` table and register the Python UDFs on *conn*.

    ``challenges`` holds the rows of 'puzzles_cleaned.csv' whose ``Warnings``
    is NULL or does not contain "(E)".
    """
    conn.execute("""
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """)
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)


def load_results() -> duckdb.DuckDBPyConnection:
    """Open an in-memory DuckDB connection with the results attached.

    'results.duckdb' is attached read-only as ``results``. The connection
    also gets the ``challenges`` table and the ``check_answer``,
    ``clip_text`` and ``wrap_text`` UDFs.
    """
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    _create_challenges_and_udfs(conn)
    return conn


def load_results_sample_one_only() -> duckdb.DuckDBPyConnection:
    """Like ``load_results``, plus a ``sampled`` table.

    ``sampled`` keeps one completion per (parent_dir, prompt): the one with
    the smallest prompt_id.
    """
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    conn.execute("""
        CREATE TABLE sampled AS
        WITH numbered AS (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt ORDER BY prompt_id) AS rn
            FROM results.completions
        )
        SELECT prompt_id, parent_dir, prompt, completion
        FROM numbered
        WHERE rn = 1;
    """)
    _create_challenges_and_udfs(conn)
    return conn


def r1_accuracy_by_completion_length(conn, model_name):
    """Cumulative accuracy of one model's completions, ordered by length.

    For the completions whose ``parent_dir`` equals *model_name* (originally
    used for the completions-r1 model):

    1. We calculate completion length and correctness for each problem.
    2. We sort by length.
    3. We compute the cumulative number of correct responses, and divide it
       by the total number of graded completions.

    The window uses the default frame (RANGE ... CURRENT ROW), so
    completions of equal length share the same cumulative value.

    Returns:
        A DuckDB relation with columns length, cumulative_correct and
        cumulative_accuracy.
    """
    # Escape single quotes so a model name containing "'" cannot break the
    # SQL string literal it is interpolated into.
    quoted_model = model_name.replace("'", "''")
    r1_completions = conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results
            JOIN challenges ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{quoted_model}'
        ),
        TotalItems AS (
            SELECT COUNT(*) as total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) as cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
    return r1_completions


def accuracy_by_model_and_time(conn):
    """Accuracy per model (``parent_dir``) and per challenge year.

    The year is extracted from the challenge's ``date`` column. Returns a
    DuckDB relation with columns model, year, total, correct and accuracy
    (correct / total, rounded to 2 decimals), ordered by model and year.
    """
    model_accuracies = conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT ID,
                   answer,
                   EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM challenges
        ),
        DateAnswerCheck AS (
            SELECT results.parent_dir AS model,
                   dates.year,
                   COUNT(*) AS total,
                   SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN ChallengesWithDates dates
            ON results.prompt_id = dates.ID
            GROUP BY results.parent_dir, dates.year
        )
        SELECT model, year, total, correct, ROUND(correct / total, 2) AS accuracy
        FROM DateAnswerCheck
        ORDER BY model, year
    """)
    return model_accuracies


def accuracy_by_model(conn):
    """Accuracy per model (``parent_dir``), weighting each row by ``results.count``.

    NOTE(review): assumes ``results.completions`` has a ``count`` column giving
    each row's multiplicity — confirm against the results schema.

    Returns a DuckDB relation with columns model, total, correct and accuracy
    (correct / total, rounded to 2 decimals).
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT results.parent_dir AS model,
                   SUM(results.count) AS total,
                   SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN challenges challenges
            ON results.prompt_id = challenges.ID
            GROUP BY results.parent_dir
        )
        SELECT model, total, correct, ROUND(correct / total, 2) AS accuracy
        FROM AnswerCheck
    """)


def main():
    """CLI entry point: print accuracy per model, or per model and year."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--by-model-and-time", action="store_true")
    args = parser.parse_args()
    conn = load_results()
    if args.by_model_and_time:
        print(accuracy_by_model_and_time(conn))
    else:
        print(accuracy_by_model(conn))


if __name__ == "__main__":
    main()