Spaces:
Running
Running
File size: 7,870 Bytes
172f931 a92220b 457470f 479b4ac f9fdde4 479b4ac 172f931 a92220b 172f931 a92220b 172f931 a92220b 172f931 c073751 479b4ac a92220b 5771e25 a92220b 172f931 c073751 172f931 f9fdde4 172f931 f9fdde4 172f931 c073751 314576f 5771e25 479b4ac a92220b 314576f f9fdde4 a92220b f9fdde4 a92220b 172f931 f9fdde4 172f931 457470f 89f9030 172f931 479b4ac 314576f 457470f 314576f 457470f 314576f 457470f 314576f 457470f 314576f 457470f 172f931 457470f 172f931 479b4ac 172f931 457470f c073751 457470f 172f931 457470f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import re
import duckdb
import textwrap
from typing import List, Tuple
import argparse
import unicodedata
import re
import diskcache as dc
# Persistent on-disk memo of answer-check verdicts (used by _check_answer),
# stored under the relative directory "answer_cache". Opened at import time.
cache = dc.Cache(directory="answer_cache")
def normalize_text(text: str) -> str:
    """Return *text* lowercased, trimmed, and with diacritics removed.

    NFKD normalization splits accented letters into a base letter plus
    combining marks (e.g. "é" -> "e" + U+0301) and also applies compatibility
    decompositions (e.g. "½" -> "1⁄2"); the combining marks are then dropped.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    without_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return without_marks.lower().strip()
def _parse_answer(text: str) -> List[List[str]]:
    """
    Parse an expected answer into alternatives, each a list of phrases.

    The text is first normalized with ``normalize_text`` (lowercased, accents
    removed, trimmed). ";" then separates alternatives. Within each
    alternative, "," and arrows (">", "->", "-->", and en-dash variants such
    as "–>") separate phrases. Within each phrase only word characters
    (letters, digits, underscore) are kept, joined by single spaces.

    Another way to describe this is that we interpret adjacent words as
    phrases that must be present literally. However, comma and arrow separate
    distinct phrases that may be present in any order. All other characters
    are dropped.

    Phrases that end up empty (e.g. from a trailing "," or a group made only
    of punctuation) are discarded, and an alternative left with no phrases
    (e.g. from a trailing ";") is omitted. An empty phrase would match
    almost any completion when searched as ``\\b\\b``, so keeping it would
    make the answer accept nearly everything. An answer with no words at all
    therefore yields ``[]``.
    """
    # normalize_text already lowercases, so no separate .lower() is needed.
    text = normalize_text(text)
    result = []
    for alternative in text.split(";"):
        groups = re.split(r'–?-?-?>|,', alternative)
        phrases = [" ".join(re.findall(r'\b\w+\b', group)) for group in groups]
        phrases = [phrase for phrase in phrases if phrase]
        if phrases:
            result.append(phrases)
    return result
def _answer_without_thoughts(completion: str) -> str:
completion = re.sub(r"(<think>)?[^<]*<\/think>", "", completion).strip()
completion = re.sub(r".*</think>", "", completion).strip() #because qwen sometimes misses <think>
return completion
def _check_answer(completion: str, answer: str) -> bool:
    """
    Check that all the phrases that must appear in the answer appear in the
    completion. We ignore "thoughts", capitalization, accents, and punctuation.

    The answer is parsed with ``_parse_answer`` into alternatives. The check
    succeeds if, for at least one alternative, every phrase occurs in the
    cleaned completion as a whole-word (``\\b``-delimited) match.

    Verdicts are memoized in the on-disk ``cache``. The key carries a version
    tag so verdicts produced by an older version of this cleaning/matching
    logic are not silently reused after the logic changes.
    """
    key = ("check_answer/v2", completion, answer)
    cached = cache.get(key)
    if cached is not None:
        return cached

    text = _answer_without_thoughts(completion)
    # Normalize (NFKD, drop combining marks, lowercase) BEFORE touching
    # punctuation, the same order _parse_answer uses. Otherwise combining
    # marks (decomposed input, or "İ".lower()) are turned into spaces and
    # split words, and compatibility characters such as "½" keep punctuation,
    # so completion and answer would be cleaned differently.
    text = normalize_text(text)
    # Drop markdown bold markers without splitting the word they wrap.
    text = text.replace("**", "")
    # Punctuation -> space, aligning with _parse_answer's " ".join of words.
    text = re.sub(r'[^\w\s]', ' ', text)
    # Collapse runs of (Unicode) whitespace to one space, as in _parse_answer.
    text = re.sub(r'\s+', ' ', text)

    verdict = any(
        all(re.search(rf'\b{re.escape(phrase)}\b', text) for phrase in answer_phrases)
        for answer_phrases in _parse_answer(answer)
    )
    cache[key] = verdict
    return verdict
def _clip_text(text: str, width: int) -> str:
return text if len(text) <= width else text[:width] + "..."
def _wrap_text(text: str, width: int) -> str:
return textwrap.fill(text, width=width)
def load_results():
    """
    Open an in-memory DuckDB session prepared for the accuracy reports.

    - attaches ``results.duckdb`` read-only under the alias ``results``;
    - creates table ``challenges`` from ``puzzles_cleaned.csv``, keeping rows
      whose ``Warnings`` is NULL or does not contain "(E)";
    - registers ``check_answer``, ``clip_text`` and ``wrap_text`` as SQL
      functions backed by the Python helpers of the same purpose.

    Returns:
        The open DuckDB connection.
    """
    session = duckdb.connect(":memory:")
    session.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
    challenges_without_errors = """
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """
    session.execute(challenges_without_errors)
    sql_functions = {
        "check_answer": _check_answer,
        "clip_text": _clip_text,
        "wrap_text": _wrap_text,
    }
    for sql_name, py_func in sql_functions.items():
        session.create_function(sql_name, py_func)
    return session
def load_results_sample_one_only():
    """
    Open an in-memory DuckDB session with one sampled completion per prompt.

    - attaches ``results.duckdb`` read-only under the alias ``results``;
    - creates table ``sampled`` with one row of ``results.completions`` per
      (parent_dir, prompt) pair: the row ranked first by ascending
      ``prompt_id`` (ROW_NUMBER() ... ORDER BY prompt_id, rn = 1);
    - creates table ``challenges`` from ``puzzles_cleaned.csv``, keeping rows
      whose ``Warnings`` is NULL or does not contain "(E)";
    - registers ``check_answer``, ``clip_text`` and ``wrap_text`` as SQL
      functions.

    Returns:
        The open DuckDB connection.
    """
    session = duckdb.connect(":memory:")
    session.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")

    first_completion_per_prompt = """
        CREATE TABLE sampled AS
        WITH numbered AS (
            SELECT *,
                ROW_NUMBER() OVER (PARTITION BY parent_dir, prompt ORDER BY prompt_id) AS rn
            FROM results.completions
        )
        SELECT prompt_id, parent_dir, prompt, completion
        FROM numbered
        WHERE rn = 1;
    """
    challenges_without_errors = """
        CREATE TABLE challenges AS
        SELECT * FROM 'puzzles_cleaned.csv'
        WHERE Warnings IS NULL OR Warnings NOT LIKE '%(E)%'
    """

    session.execute(first_completion_per_prompt).fetchall()
    session.execute(challenges_without_errors)
    for sql_name, py_func in (
        ("check_answer", _check_answer),
        ("clip_text", _clip_text),
        ("wrap_text", _wrap_text),
    ):
        session.create_function(sql_name, py_func)
    return session
def r1_accuracy_by_completion_length(conn, model_name):
    """
    Cumulative accuracy of one model's completions as a function of length.

    Despite the function's name, any model directory can be selected, not
    only R1. For the completions whose ``parent_dir`` equals *model_name*:

    1. compute each completion's length in characters and whether it is
       correct (via the ``check_answer`` SQL function);
    2. sort by length;
    3. compute the running count of correct responses and divide it by the
       total number of that model's completions.

    The window is ``SUM(...) OVER (ORDER BY length)``, so rows with equal
    length are peers and share the same cumulative value.

    Args:
        conn: DuckDB connection with the ``results`` database attached, a
            ``challenges`` table, and the ``check_answer`` function
            registered (as set up by ``load_results``).
        model_name: value of ``results.completions.parent_dir`` to select.

    Returns:
        A DuckDB relation with columns ``length``, ``cumulative_correct`` and
        ``cumulative_accuracy``, ordered by ``length``.
    """
    # Quote the name as a SQL string literal: doubling single quotes stops a
    # name containing "'" from terminating the literal early and breaking
    # (or rewriting) the query.
    model_literal = str(model_name).replace("'", "''")
    r1_completions = conn.sql(f"""
        WITH LengthsAndCorrectness AS (
            SELECT
                LENGTH(results.completion) AS length,
                CAST(check_answer(results.completion, challenges.answer) AS INT32) AS correct
            FROM results.completions results JOIN challenges
                ON results.prompt_id = challenges.ID
            WHERE results.parent_dir = '{model_literal}'
        ),
        TotalItems AS (
            SELECT COUNT(*) as total_count
            FROM LengthsAndCorrectness
        ),
        CumulativeCorrect AS (
            SELECT
                length,
                SUM(correct) OVER (ORDER BY length) as cumulative_correct
            FROM LengthsAndCorrectness
        )
        SELECT
            length,
            cumulative_correct,
            CAST(cumulative_correct AS FLOAT) / total_count AS cumulative_accuracy
        FROM CumulativeCorrect, TotalItems
        ORDER BY length
    """)
    return r1_completions
def accuracy_by_model_and_time(conn):
    """
    Per-model, per-year accuracy over the challenges.

    The year comes from ``EXTRACT(YEAR FROM CAST(date AS DATE))`` on the
    ``challenges`` table. Each completion row is weighted by its
    ``results.count`` column, the same weighting ``accuracy_by_model`` uses.
    Summing this report over years therefore gives the same totals as that
    report instead of counting each row once.

    Args:
        conn: DuckDB connection with the ``results`` database attached, a
            ``challenges`` table, and the ``check_answer`` function
            registered (as set up by ``load_results``).

    Returns:
        A DuckDB relation with columns ``model``, ``year``, ``total``,
        ``correct`` and ``accuracy`` (rounded to 2 decimals; NULL when
        ``total`` is 0), ordered by model then year.
    """
    model_accuracies = conn.sql("""
        WITH ChallengesWithDates AS (
            SELECT
                ID,
                answer,
                EXTRACT(YEAR FROM CAST(date AS DATE)) AS year
            FROM
                challenges
        ),
        DateAnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                dates.year,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, dates.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                ChallengesWithDates dates
            ON
                results.prompt_id = dates.ID
            GROUP BY
                results.parent_dir,
                dates.year
        )
        SELECT
            model,
            year,
            total,
            correct,
            ROUND(CAST(correct AS DOUBLE) / NULLIF(total, 0), 2) AS accuracy
        FROM
            DateAnswerCheck
        ORDER BY
            model,
            year
    """)
    return model_accuracies
def accuracy_by_model(conn):
    """
    Per-model accuracy over the challenges, weighted by ``results.count``.

    Each completion row contributes ``count`` to ``total`` and, when
    ``check_answer`` accepts it, ``count`` to ``correct``. Rows are ordered
    by model so the printed report is deterministic; without ORDER BY,
    GROUP BY output order is unspecified.

    Args:
        conn: DuckDB connection with the ``results`` database attached, a
            ``challenges`` table, and the ``check_answer`` function
            registered (as set up by ``load_results``).

    Returns:
        A DuckDB relation with columns ``model``, ``total``, ``correct`` and
        ``accuracy`` (rounded to 2 decimals; NULL when ``total`` is 0).
    """
    return conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                results.parent_dir AS model,
                SUM(results.count) AS total,
                SUM(results.count * CAST(check_answer(results.completion, challenges.answer) AS INTEGER)) AS correct
            FROM
                results.completions results
            JOIN
                challenges challenges
            ON
                results.prompt_id = challenges.ID
            GROUP BY
                results.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            ROUND(CAST(correct AS DOUBLE) / NULLIF(total, 0), 2) AS accuracy
        FROM
            AnswerCheck
        ORDER BY
            model
    """)
def main():
    """
    Command-line entry point.

    Prints per-model accuracy by default, or per-model, per-year accuracy
    when ``--by-model-and-time`` is given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--by-model-and-time", action="store_true")
    options = cli.parse_args()
    conn = load_results()
    report = accuracy_by_model_and_time if options.by_model_and_time else accuracy_by_model
    print(report(conn))


if __name__ == "__main__":
    main()
|