File size: 7,870 Bytes
172f931
 
 
a92220b
457470f
479b4ac
 
 
f9fdde4
 
 
 
 
479b4ac
 
 
 
 
 
 
172f931
a92220b
172f931
a92220b
 
 
172f931
 
a92220b
172f931
 
 
 
 
c073751
479b4ac
a92220b
 
 
5771e25
a92220b
 
172f931
 
c073751
 
 
 
172f931
f9fdde4
172f931
 
 
 
 
f9fdde4
 
 
 
172f931
c073751
314576f
5771e25
479b4ac
a92220b
 
314576f
 
f9fdde4
a92220b
f9fdde4
a92220b
172f931
f9fdde4
172f931
 
 
 
 
 
 
 
457470f
89f9030
 
 
 
 
 
172f931
 
 
 
 
479b4ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314576f
457470f
 
 
 
 
 
314576f
457470f
 
 
 
 
 
314576f
 
 
 
 
 
 
 
 
 
 
457470f
314576f
457470f
 
314576f
 
 
 
457470f
 
 
 
 
172f931
457470f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172f931
 
 
479b4ac
 
172f931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457470f
 
 
 
c073751
 
457470f
 
 
 
172f931
 
457470f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import re
import duckdb
import textwrap
from typing import List, Tuple
import argparse
import unicodedata
import re


import diskcache as dc

# Disk-backed cache (diskcache) opened at "answer_cache". _check_answer uses
# it to memoize verdicts keyed by the raw (completion, answer) pair.
cache = dc.Cache("answer_cache")

def normalize_text(text: str) -> str:
    """Return *text* with accents removed, lowercased, and outer whitespace stripped.

    The string is decomposed with NFKD so accented letters split into a base
    letter followed by combining marks (e.g. "é" becomes "e" + U+0301); the
    combining marks are then discarded.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    # combining() returns 0 for every character that is not a combining mark.
    without_marks = "".join(
        ch for ch in decomposed if unicodedata.combining(ch) == 0
    )
    return without_marks.lower().strip()


def _parse_answer(text: str) -> List[List[str]]:
    """
    Parse an answer key into alternatives, each a list of required phrases.

    The text is lowercased and accent-stripped (see normalize_text). ";"
    separates alternatives. Within an alternative, "," and arrows (">",
    "->", "-->", optionally preceded by an en dash) separate phrases that
    may appear in any order. Within a phrase only the word tokens are kept,
    joined by single spaces, so adjacent words form a phrase that must be
    present literally and all other characters are dropped.

    Phrases that contain no word characters, and alternatives left with no
    phrases, are omitted. An empty phrase would otherwise be searched for as
    a bare word boundary, which matches any completion containing a word,
    so a stray trailing ";" or "," in the key would accept every answer.

    Returns a list of alternatives; each alternative is a non-empty list of
    non-empty phrases. The list is empty when the key has no word characters.
    """
    # Lowercase before normalizing, exactly as the key has always been
    # processed; normalize_text lowercases again after stripping accents.
    text = normalize_text(text.lower())
    alternatives = []
    for alternative in text.split(";"):
        phrases = []
        for group in re.split(r'–?-?-?>|,', alternative):
            phrase = " ".join(re.findall(r'\b\w+\b', group))
            if phrase:
                phrases.append(phrase)
        if phrases:
            alternatives.append(phrases)
    return alternatives

def _answer_without_thoughts(completion: str) -> str:
    """
    Return *completion* with the model's "thinking" text removed.

    Two shapes are handled:

    1. Well-formed ``<think>...</think>`` blocks are deleted wherever they
       occur. The match spans newlines and any characters, including "<",
       so reasoning such as "1 < 2" no longer leaves part of the thoughts
       behind.
    2. If a closing ``</think>`` remains without an opening tag (qwen
       sometimes omits ``<think>``), everything up to and including the
       last such tag is discarded.

    The result is stripped of leading and trailing whitespace. Text without
    any think tags is returned unchanged apart from that stripping.
    """
    # Non-greedy so that each block ends at its own closing tag and text
    # between separate blocks is preserved.
    completion = re.sub(r"<think>.*?</think>", "", completion, flags=re.DOTALL)
    # rpartition yields the whole string as the tail when no tag is left.
    completion = completion.rpartition("</think>")[2]
    return completion.strip()



def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, accents, and punctuation.

    The completion matches when, for at least one alternative produced by
    _parse_answer(answer), every phrase of that alternative occurs in the
    cleaned completion on word boundaries.

    Verdicts are memoized in the module-level ``cache`` under the raw
    (completion, answer) pair. The key does not reflect the matching logic,
    so cached verdicts survive changes to it; clear the cache after
    modifying how answers are checked.
    """
    key = (completion, answer)
    # Stored values are always booleans, so None unambiguously means a miss.
    cached = cache.get(key)
    if cached is not None:
        return cached

    # Normalize BEFORE stripping punctuation, mirroring _parse_answer. Doing
    # it afterwards turned decomposed accents (combining marks are not \w)
    # into spaces that split words, and let NFKD introduce new punctuation
    # that was never removed.
    candidate = normalize_text(_answer_without_thoughts(completion).lower())
    candidate = candidate.replace("**", "")
    # Punctuation becomes a space, aligning with _parse_answer's ' '.join.
    candidate = re.sub(r'[^\w\s]', ' ', candidate)
    # Collapse runs of (Unicode) whitespace to a single space.
    candidate = re.sub(r'\s+', ' ', candidate)

    matched = any(
        all(
            re.search(rf'\b{re.escape(phrase)}\b', candidate)
            for phrase in answer_phrases
        )
        for answer_phrases in _parse_answer(answer)
    )
    cache[key] = matched
    return matched


def _clip_text(text: str, width: int) -> str:
    """Cut *text* down to its first *width* characters, appending "..." if anything was removed."""
    if len(text) > width:
        return text[:width] + "..."
    return text

def _wrap_text(text: str, width: int) -> str:
    """Re-flow *text* into lines of at most *width* characters, joined by newlines."""
    wrapped_lines = textwrap.wrap(text, width=width)
    return "\n".join(wrapped_lines)

def load_results():
    """
    Open an in-memory DuckDB connection prepared for the accuracy queries.

    The connection has 'results.duckdb' attached read-only as ``results``,
    a ``challenges`` table loaded from 'puzzles_cleaned.csv' without the rows
    whose Warnings column contains "(E)", and the SQL functions
    check_answer, clip_text and wrap_text registered.
    """
    db = duckdb.connect(":memory:")
    db.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")

    create_challenges = """
        CREATE TABLE challenges AS 
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """
    db.execute(create_challenges)

    # Expose the Python helpers to SQL under their public names.
    sql_functions = (
        ("check_answer", _check_answer),
        ("clip_text", _clip_text),
        ("wrap_text", _wrap_text),
    )
    for sql_name, python_function in sql_functions:
        db.create_function(sql_name, python_function)
    return db

def load_results_sample_one_only():
    """
    Like load_results, but additionally build a ``sampled`` table.

    ``sampled`` keeps one row of results.completions per (parent_dir, prompt)
    pair: the row numbered first when each pair's rows are ordered by
    prompt_id. The ``challenges`` table and the check_answer, clip_text and
    wrap_text SQL functions are set up as well.
    """
    db = duckdb.connect(":memory:")
    db.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")

    db.execute("""
    CREATE TABLE sampled AS
    WITH numbered AS (
        SELECT *,
            ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt ORDER BY prompt_id) AS rn
        FROM results.completions
    )
    SELECT prompt_id, parent_dir, prompt, completion
    FROM numbered
    WHERE rn = 1;
    """).fetchall()

    create_challenges = """
        CREATE TABLE challenges AS 
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """
    db.execute(create_challenges)

    # Expose the Python helpers to SQL under their public names.
    sql_functions = (
        ("check_answer", _check_answer),
        ("clip_text", _clip_text),
        ("wrap_text", _wrap_text),
    )
    for sql_name, python_function in sql_functions:
        db.create_function(sql_name, python_function)
    return db

def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions, ordered by length.

    For the completions whose parent_dir equals *model_name*:
    1. We calculate completion length and correctness for each problem.
    2. We sort by length.
    3. We compute cumulative number of correct responses.

    Args:
        conn: connection exposing ``sql``, with ``results.completions``,
            ``challenges`` and the ``check_answer`` function available
            (as set up by load_results).
        model_name: value of results.parent_dir to select.

    Returns:
        The result of ``conn.sql`` for a query with columns length,
        cumulative_correct, and cumulative_accuracy (cumulative_correct
        divided by the model's total number of joined rows).
    """
    # The name is embedded in a SQL string literal; double any single quotes
    # so a name containing "'" cannot terminate the literal early.
    model_literal = model_name.replace("'", "''")
    return conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT 
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results JOIN  challenges
            ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_literal}'
        ),
        TotalItems AS (
            SELECT COUNT(*) as total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT 
                length,
                SUM(correct) OVER (ORDER BY length) as cumulative_correct,
            FROM LengthsAndCorrectness
        )

        SELECT 
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)


def accuracy_by_model_and_time(conn):
    """
    Report accuracy per model and per challenge year.

    Joins results.completions to challenges on prompt_id = ID, groups by
    parent_dir (reported as ``model``) and the year extracted from the
    challenge date, and returns the result of ``conn.sql`` with columns
    model, year, total, correct, and accuracy (rounded to 2 places),
    ordered by model then year.

    NOTE(review): unlike accuracy_by_model, rows here are not weighted by
    results.count -- confirm that is intended.
    """
    query = """
        WITH ChallengesWithDates AS (
            SELECT 
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM 
                challenges
        ),
        DateAnswerCheck AS (
            SELECT 
                results.parent_dir AS model,
                dates.year,
                COUNT(*) AS total,
                SUM(CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM 
                results.completions results
            JOIN 
                ChallengesWithDates dates
            ON 
                results.prompt_id = dates.ID
            GROUP BY 
                results.parent_dir,
                dates.year
        )
        SELECT 
            model,
            year,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM 
            DateAnswerCheck
        ORDER BY
            model,
            year
    """
    return conn.sql(query)

def accuracy_by_model(conn):
    """
    Report overall accuracy per model.

    Joins results.completions to challenges on prompt_id = ID and groups by
    parent_dir (reported as ``model``). Each row is weighted by its
    ``count`` column: total is the sum of counts and correct is the sum of
    counts for rows where check_answer succeeds. Returns the result of
    ``conn.sql`` with columns model, total, correct, and accuracy (rounded
    to 2 places).
    """
    query = """
        WITH AnswerCheck AS (
            SELECT 
                results.parent_dir AS model,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM 
                results.completions results
            JOIN 
                challenges challenges
            ON 
                results.prompt_id = challenges.ID
            GROUP BY 
                results.parent_dir
        )
        SELECT 
            model,
            total,
            correct,
            ROUND(correct / total, 2) AS accuracy
        FROM 
            AnswerCheck
    """
    relation = conn.sql(query)
    return relation

def main():
    """
    Command-line entry point: print an accuracy report.

    With --by-model-and-time the report is broken down by model and year;
    otherwise it is broken down by model only.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--by-model-and-time", action="store_true")
    options = cli.parse_args()

    conn = load_results()
    # Pick the report first, then run and print it in one place.
    report = accuracy_by_model_and_time if options.by_model_and_time else accuracy_by_model
    print(report(conn))

# Run the command-line report only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()