import re | |
import duckdb | |
import textwrap | |
from typing import List, Tuple | |
import argparse | |
import unicodedata | |
import re | |
import diskcache as dc | |
# Persistent on-disk memo of answer-check verdicts. "answer_cache" is a
# relative path, so the cache directory is resolved against the current
# working directory and its contents survive between runs.
cache = dc.Cache(directory="answer_cache")
def normalize_text(text: str) -> str:
    """Return *text* lowercased, trimmed, and with diacritics removed.

    The string is first NFKD-decomposed, which splits accented letters into
    a base letter plus combining marks (e.g. "é" -> "e" + U+0301) and also
    expands compatibility characters (e.g. the ligature "ﬁ" -> "fi"). The
    combining marks are then discarded before lowercasing and stripping
    surrounding whitespace.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    without_marks = "".join(
        char for char in decomposed if not unicodedata.combining(char)
    )
    return without_marks.lower().strip()
def _parse_answer(text: str) -> List[List[str]]:
    """
    Parse a reference answer into alternatives, each a list of phrases.

    The text is first normalized with ``normalize_text`` (lowercased, accents
    removed). ``;`` separates alternative answers. Within an alternative,
    ``,`` and arrows (``>``, ``->``, ``-->``, optionally starting with an en
    dash) separate phrases that may appear in any order. Within a phrase only
    runs of word characters are kept, joined by single spaces. Adjacent words
    must therefore appear literally, and all other characters are dropped.

    Phrases that end up empty (e.g. from "a, b,") are discarded. Alternatives
    that end up with no phrases (e.g. from a trailing ";") are discarded too.
    An empty phrase would otherwise make its alternative match any completion.

    Example: "Red, Blue; Green --> Yellow" parses to
    [["red", "blue"], ["green", "yellow"]].
    """
    parsed = []
    # normalize_text already lowercases, so no separate .lower() is needed.
    for alternative in normalize_text(text).split(";"):
        phrases = []
        for group in re.split(r'–?-?-?>|,', alternative):
            phrase = " ".join(re.findall(r'\b\w+\b', group))
            if phrase:  # skip empty groups such as the one after a trailing ","
                phrases.append(phrase)
        if phrases:  # skip empty alternatives such as the one after a trailing ";"
            parsed.append(phrases)
    return parsed
def _answer_without_thoughts(completion: str) -> str:
    """
    Return ``completion`` with its reasoning ("thoughts") removed.

    Every ``<think>...</think>`` block is deleted. This works even when the
    thoughts span several lines or contain ``<`` (e.g. "x < y").

    If a closing ``</think>`` remains without an opener (Qwen sometimes omits
    ``<think>``), everything up to and including the last such tag is also
    dropped. The result is stripped of surrounding whitespace.
    """
    # Non-greedy + DOTALL: each block ends at its own closing tag and may
    # contain any character, including newlines and "<".
    text = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL)
    # Because qwen sometimes misses <think>: a stray closing tag marks the end
    # of the thoughts, so keep only what follows the last one.
    _, closing_tag, after = text.rpartition("</think>")
    if closing_tag:
        text = after
    return text.strip()
def _check_answer(completion: str, answer: str) -> bool:
    """
    Return True if ``completion`` contains every phrase of at least one
    alternative of the reference ``answer``.

    ``answer`` is parsed by ``_parse_answer`` into alternatives, each a list
    of phrases. The completion is cleaned to match that representation:

    - thoughts are removed;
    - the text is normalized with ``normalize_text`` (lowercase, accents
      stripped);
    - ``**`` markers are deleted;
    - other punctuation becomes spaces;
    - whitespace runs collapse to a single space.

    Each phrase must occur as whole words (on regex word boundaries), not
    merely as a substring, so "art" does not match inside "party". Empty
    phrases and alternatives are ignored.

    Verdicts are memoized in the module-level disk ``cache``.
    """
    # The disk cache outlives code changes, so the key carries a tag naming
    # the matching rules. Verdicts computed under earlier rules (e.g. plain
    # substring matching) are then not reused.
    key = ("whole-word-v1", completion, answer)
    cached = cache.get(key)
    if cached is not None:
        return cached

    # Normalize *before* stripping punctuation, as _parse_answer does.
    # Otherwise a decomposed accent (letter + combining mark) would become a
    # space and split the word apart.
    text = normalize_text(_answer_without_thoughts(completion))
    text = text.replace("**", "")
    # Punctuation -> space, mirroring the " ".join of word runs in _parse_answer.
    text = re.sub(r"[^\w\s]", " ", text)
    # Collapse consecutive (Unicode) whitespace to finish the alignment.
    text = re.sub(r"\s+", " ", text)

    verdict = False
    for alternative in _parse_answer(answer):
        phrases = [phrase for phrase in alternative if phrase]
        if phrases and all(
            re.search(rf"\b{re.escape(phrase)}\b", text) for phrase in phrases
        ):
            verdict = True
            break
    cache[key] = verdict
    return verdict
def _clip_text(text: str, width: int) -> str:
    """Truncate *text* to its first *width* characters, appending "..." if cut."""
    if len(text) > width:
        return text[:width] + "..."
    return text
def _wrap_text(text: str, width: int) -> str:
    """Re-flow *text* into newline-separated lines of at most *width* columns."""
    lines = textwrap.wrap(text, width=width)
    return "\n".join(lines)
def _connect_with_results():
    """Open an in-memory DuckDB connection with 'results.duckdb' attached
    read-only under the name ``results``."""
    conn = duckdb.connect(":memory:")
    conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    return conn


def _add_challenges_and_udfs(conn):
    """Create the ``challenges`` table and register the SQL helper UDFs.

    ``challenges`` holds the rows of 'puzzles_cleaned.csv' whose ``Warnings``
    column is NULL or does not contain "(E)". The UDFs ``check_answer``,
    ``clip_text`` and ``wrap_text`` wrap ``_check_answer``, ``_clip_text``
    and ``_wrap_text``.
    """
    conn.execute("""
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """)
    conn.create_function("check_answer", _check_answer)
    conn.create_function("clip_text", _clip_text)
    conn.create_function("wrap_text", _wrap_text)


def load_results():
    """Return a DuckDB connection prepared for the accuracy queries.

    The connection has ``results`` attached, the filtered ``challenges``
    table, and the ``check_answer`` / ``clip_text`` / ``wrap_text`` UDFs.
    """
    conn = _connect_with_results()
    _add_challenges_and_udfs(conn)
    return conn


def load_results_sample_one_only():
    """Like ``load_results``, plus a ``sampled`` table that keeps one
    completion per (parent_dir, prompt): the row ranked first by ``prompt_id``.
    """
    conn = _connect_with_results()
    # NOTE(review): if several rows in a partition share the same prompt_id,
    # which one is kept is not deterministic. Confirm that prompt_id is
    # unique per (parent_dir, prompt), or add a tie-breaker.
    # NOTE(review): the query functions in this file read
    # results.completions, not `sampled`; callers must query `sampled`
    # themselves to use the sample.
    conn.execute("""
        CREATE TABLE sampled AS
        WITH numbered AS (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt ORDER BY prompt_id) AS rn
            FROM results.completions
        )
        SELECT prompt_id, parent_dir, prompt, completion
        FROM numbered
        WHERE rn = 1;
    """)
    _add_challenges_and_udfs(conn)
    return conn
def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions, ordered by length.

    Despite the name, any model can be selected: ``model_name`` is matched
    against ``results.completions.parent_dir``.

    1. Compute the character length and correctness of every completion.
    2. Sort by length.
    3. For each length, compute the number of correct completions of that
       length or shorter, both as a count and as a fraction of all of the
       model's completions.

    Returns a DuckDB relation with columns ``length``, ``cumulative_correct``
    and ``cumulative_accuracy``. ``conn`` must provide the ``challenges``
    table and ``check_answer`` UDF (see ``load_results``).
    """
    # Double any single quotes so the name is always a valid SQL string literal.
    model_literal = str(model_name).replace("'", "''")
    return conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INTEGER) AS correct
            FROM results.completions results
            JOIN challenges
              ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_literal}'
        ),
        TotalItems AS (
            SELECT COUNT(*) AS total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                -- Default RANGE frame: completions of equal length share
                -- the same running total.
                SUM(correct) OVER (ORDER BY length) AS cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
def accuracy_by_model_and_time(conn):
    """
    Accuracy of every model, broken down by puzzle year.

    Returns a DuckDB relation with one row per (model, year), ordered by
    model and then year. The columns are:

    - ``model`` and ``year``;
    - ``total``: the number of completion rows checked;
    - ``correct``: the number judged correct by ``check_answer``;
    - ``accuracy``: ``correct / total`` rounded to two decimals.

    ``conn`` must provide the ``challenges`` table and ``check_answer`` UDF
    (see ``load_results``).
    """
    return conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                -- NOTE(review): the year expression was missing even though
                -- dates.year is used below. This assumes `challenges` has a
                -- `date` column that casts to DATE; confirm against
                -- puzzles_cleaned.csv.
                EXTRACT(YEAR FROM CAST("date" AS DATE)) AS year
            FROM challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                -- NOTE(review): accuracy_by_model weights rows by
                -- results.count, while this counts rows. Confirm which is
                -- intended.
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN ChallengesWithDates dates
              ON results.prompt_id = dates.ID
            GROUP BY results.parent_dir, dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            -- Cast so the division is fractional regardless of DuckDB version.
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM DateAnswerCheck
        ORDER BY model, year
    """)
def accuracy_by_model(conn):
    """
    Overall accuracy of every model.

    Each completion row is weighted by its ``results.count``. Returns a
    DuckDB relation with one row per model, ordered by model. The columns
    are:

    - ``model``;
    - ``total``: the summed counts;
    - ``correct``: the summed counts of rows judged correct by
      ``check_answer``;
    - ``accuracy``: ``correct / total`` rounded to two decimals.

    ``conn`` must provide the ``challenges`` table and ``check_answer`` UDF
    (see ``load_results``).
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM results.completions results
            JOIN challenges
              ON results.prompt_id = challenges.ID
            GROUP BY results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            -- Cast so the division is fractional regardless of DuckDB version.
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM AnswerCheck
        ORDER BY model
    """)
def main():
    """Print accuracy per model, or per model and year with --by-model-and-time."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--by-model-and-time", action="store_true")
    options = cli.parse_args()
    conn = load_results()
    report = (
        accuracy_by_model_and_time if options.by_model_and_time else accuracy_by_model
    )
    print(report(conn))


if __name__ == "__main__":
    main()