import argparse
import re
import textwrap
import unicodedata
from typing import List

import duckdb

def normalize_text(text: str) -> str:
    """Normalize text to remove accents, convert to lowercase, and strip spaces."""
    text = unicodedata.normalize("NFKD", text)  # Decompose letters with accents (e.g., é → e + ́)
    text = "".join([c for c in text if not unicodedata.combining(c)])  # Remove diacritics
    text = text.lower().strip()  # Convert to lowercase and strip spaces
    return text
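
# Illustrative check (not in the original file): normalize_text("  Café ") decomposes
# the accented character, drops the combining mark, lowercases, and strips, giving "cafe".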

def _parse_answer(text: str) -> List[List[str]]:
    """
    Converts text to lowercase. Then interprets ";" as a separator between
    alternatives. Within each alternative, interprets "," and "-->" as separators
    for elements of a set. Within each set, drops all non-alphanumeric characters
    and returns that set.

    Another way to describe this is that we interpret adjacent words as
    phrases that must be present literally. However, comma and arrow separate
    distinct phrases that may be present in any order. All other characters
    are dropped.
    """
    text = normalize_text(text)
    alternatives = re.split(r';', text)
    result = []
    for alternative in alternatives:
        groups = re.split(r'–?-?-?>|,', alternative)
        result.append([" ".join(re.findall(r'\b\w+\b', group)) for group in groups])
    return result
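
# Illustrative example (assumed strings, not from the original): the separator regex
# accepts "-->", "->", ">", and en-dash arrows as well as commas, so
#     _parse_answer("Paris; France --> Europe")
# returns [["paris"], ["france", "europe"]]: a completion counts as correct if it
# contains "paris", or contains both "france" and "europe" in any order.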

def _answer_without_thoughts(completion: str) -> str:
    if "<think>" not in completion[:200]:
        return completion
    chunks = completion.split("</think>")
    if len(chunks) <= 1:
        return ""
    return chunks[-1].strip()
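
# Illustrative behavior (assumed example strings): for a completion such as
# "<think>reasoning...</think> final answer" this returns "final answer"; if a
# "<think>" tag opens near the start but is never closed, it returns "".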

def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, and punctuation.
    """
    completion = _answer_without_thoughts(completion).lower()
    completion = re.sub(r'[^\w\s]', ' ', completion)  # Replace punctuation with spaces, matching _parse_answer's " ".join
    completion = re.sub(r'\s+', ' ', completion)  # Collapse consecutive (Unicode) whitespace to finish aligning with _parse_answer
    completion = normalize_text(completion)
    alternative_answers = _parse_answer(answer)
    for answer_phrases in alternative_answers:
        # if all(phrase in completion for phrase in answer_phrases):
        if all(re.search(rf'\b{re.escape(phrase)}\b', completion) for phrase in answer_phrases):
            return True
    return False
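
# Illustrative examples (assumed strings, not from the original):
#     _check_answer("The capital is Paris.", "Paris; France --> Europe")  # True, via the "paris" alternative
#     _check_answer("Somewhere in Europe.", "France --> Europe")          # False, "france" is missing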

def _clip_text(text: str, width: int) -> str:
    return text if len(text) <= width else text[:width] + "..."


def _wrap_text(text: str, width: int) -> str:
    return textwrap.fill(text, width=width)

def load_results():
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    # Exclude puzzles whose Warnings field marks them with "(E)".
    conn.execute("""
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """)
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)
    return conn
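
# Hedged usage sketch: this assumes 'results.duckdb' and 'puzzles_cleaned.csv' are
# present in the working directory. For example:
#     conn = load_results()
#     conn.sql("SELECT COUNT(*) FROM challenges").show()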

def load_results_sample_one_only():
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    # Keep only the first completion for each (parent_dir, prompt) pair.
    query = """
        CREATE TABLE sampled AS
        WITH numbered AS (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt ORDER BY prompt_id) AS rn
            FROM results.completions
        )
        SELECT prompt_id, parent_dir, prompt, completion
        FROM numbered
        WHERE rn = 1;
    """
    conn.execute(query)
    conn.execute("""
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """)
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)
    return conn

def r1_accuracy_by_completion_length(conn, model_name):
    """
    For the responses from the completions-r1 model:
    1. We calculate completion length and correctness for each problem.
    2. We sort by length.
    3. We compute the cumulative number of correct responses.
    """
    r1_completions = conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results
            JOIN challenges ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_name}'
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
    return r1_completions
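
# Hedged usage sketch (the parent_dir value below is an assumption, not taken from the
# original): the returned relation can be materialized as a DataFrame for plotting, e.g.
#     df = r1_accuracy_by_completion_length(conn, "completions-r1").df()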

def accuracy_by_model_and_time(conn):
    model_accuracies = conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN ChallengesWithDates dates
              ON results.prompt_id = dates.ID
            GROUP BY results.parent_dir, dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM DateAnswerCheck
        ORDER BY model, year
    """)
    return model_accuracies

def accuracy_by_model(conn):
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN challenges challenges
              ON results.prompt_id = challenges.ID
            GROUP BY results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM AnswerCheck
    """)

def accuracy_by_model_only_one(conn):
    query = """
        WITH FirstResponses AS (
            SELECT
                parent_dir AS model,
                prompt_id,
                completion,
                count,
                ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt_id) AS rn
            FROM results.completions
            WHERE parent_dir = 'completions-r1_cursor_hosted' -- Only consider the completions-r1_cursor_hosted runs
        ),
        AnswerCheck AS (
            SELECT
                fr.model,
                SUM(fr.count) AS total,
                SUM(fr.count * CAST(check_answer(fr.completion, c.answer) AS INTEGER)) AS correct
            FROM FirstResponses fr
            JOIN challenges c ON fr.prompt_id = c.ID
            WHERE fr.rn = 1 -- Select only the first response per model per prompt
            GROUP BY fr.model
        )
        SELECT
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM AnswerCheck;
    """
    return conn.sql(query)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--by-model-and-time", action="store_true")
    args = parser.parse_args()

    conn = load_results_sample_one_only()
    # Sanity check: after sampling, every (parent_dir, prompt_id) pair should have
    # exactly one distinct completion.
    query = """
        SELECT parent_dir, prompt_id, COUNT(DISTINCT completion) AS completion_count
        FROM sampled
        GROUP BY parent_dir, prompt_id
        HAVING COUNT(DISTINCT completion) <> 1;
    """
    wrongones = conn.execute(query).fetchall()
    assert not wrongones, f"Found {len(wrongones)} prompts with not just one completion"

    if args.by_model_and_time:
        print(accuracy_by_model_and_time(conn))
    else:
        print(accuracy_by_model(conn))


if __name__ == "__main__":
    main()