File size: 3,063 Bytes
172f931
 
 
a92220b
172f931
a92220b
172f931
a92220b
 
 
172f931
 
a92220b
172f931
 
 
 
 
 
a92220b
 
 
 
 
 
172f931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92220b
 
 
 
 
172f931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import re
import duckdb
import textwrap
from typing import List, Tuple

def _parse_answer(text: str) -> List[List[str]]:
    """
    Parse an expected-answer string into alternative sets of required phrases.

    The text is lowercased and split on ";" into alternatives. Each alternative
    is split on "," and "-->" into phrases. Within a phrase only word
    characters (letters, digits, underscore) are kept: the words are joined by
    single spaces and every other character is dropped.

    Put differently: adjacent words form a phrase that must be present
    literally, while commas and arrows separate distinct phrases that may be
    present in any order.

    Phrases that contain no words (e.g. from "a,,b" or a trailing ",") are
    discarded, and so are alternatives left with no phrases (e.g. from a
    trailing ";"). An empty phrase is a substring of every string, so keeping
    one would let its alternative match any completion.

    Returns a list of alternatives, each a non-empty list of phrases. The list
    is empty when the text contains no words at all.
    """
    result = []
    for alternative in text.lower().split(";"):
        groups = re.split(r'-->|,', alternative)
        phrases = [" ".join(re.findall(r'\b\w+\b', group)) for group in groups]
        # Drop empty phrases; they would match unconditionally.
        phrases = [phrase for phrase in phrases if phrase]
        if phrases:
            result.append(phrases)
    return result

def _answer_without_thoughts(completion: str) -> str:
    """
    Strip a "<think> ... </think>" reasoning section from a completion.

    A completion is treated as having a reasoning section only when "<think>"
    occurs within its first 200 characters; otherwise it is returned as-is.
    When it does, the text after the last "</think>" is returned with
    surrounding whitespace stripped. If "</think>" never appears, "" is
    returned (presumably the model never finished reasoning, so there is no
    answer to grade).
    """
    has_thoughts = completion[:200].find("<think>") != -1
    if not has_thoughts:
        return completion

    _, closing_tag, after_thoughts = completion.rpartition("</think>")
    # rpartition yields an empty separator when the tag is absent.
    return after_thoughts.strip() if closing_tag else ""

def _check_answer(completion: str, answer: str) -> bool:
    """
    Return True if the completion contains every phrase of at least one
    alternative in ``answer`` (see ``_parse_answer`` for the answer format).

    Any "<think>...</think>" reasoning section is removed first (via
    ``_answer_without_thoughts``). Capitalization and punctuation are ignored
    on both sides. The completion is normalized the same way as the answer
    phrases: lowercased and reduced to its words joined by single spaces. So
    the answer "St. Louis" matches a completion that writes "St. Louis",
    "st-louis" or "st louis". Matching is by substring, so a phrase may also
    match inside a longer word.
    """
    visible = _answer_without_thoughts(completion).lower()
    # Normalize exactly like _parse_answer normalizes each phrase. Otherwise a
    # phrase with punctuation removed ("st louis") could never be found in a
    # completion that keeps it ("st. louis").
    normalized = " ".join(re.findall(r'\b\w+\b', visible))
    for answer_phrases in _parse_answer(answer):
        # "" is a substring of everything; ignore it so an alternative cannot
        # match vacuously.
        phrases = [phrase for phrase in answer_phrases if phrase]
        if phrases and all(phrase in normalized for phrase in phrases):
            return True
    return False


def _clip_text(text: str, width: int) -> str:
    """
    Return ``text`` unchanged if it has at most ``width`` characters.
    Otherwise return its first ``width`` characters followed by "...".
    """
    if len(text) > width:
        return text[:width] + "..."
    return text

def _wrap_text(text: str, width: int) -> str:
    """
    Re-flow ``text`` into lines wrapped at ``width`` columns (textwrap
    defaults) and join them with newlines.
    """
    wrapper = textwrap.TextWrapper(width=width)
    return "\n".join(wrapper.wrap(text))

def load_results():
    """
    Open an in-memory DuckDB session set up for analysing results.

    - Attaches ``results.duckdb`` read-only under the catalog name ``results``.
    - Loads ``puzzles_cleaned.csv`` into the table ``challenges``.
    - Registers ``check_answer``, ``clip_text`` and ``wrap_text`` as SQL
      functions backed by ``_check_answer``, ``_clip_text`` and ``_wrap_text``.

    Returns the open connection; the caller owns it and should close it.
    If any setup step fails, the connection is closed and the error re-raised.
    """
    conn = duckdb.connect(":memory:")
    try:
        # READ_ONLY: this script only reads results. A read-write ATTACH would
        # silently create an empty 'results.duckdb' when the file is missing,
        # deferring the failure to a confusing "table not found" later on.
        conn.execute("ATTACH DATABASE 'results.duckdb' AS results (READ_ONLY)")
        conn.execute("CREATE TABLE challenges as SELECT * FROM 'puzzles_cleaned.csv'")
        conn.create_function("check_answer", _check_answer)
        conn.create_function("clip_text", _clip_text)
        conn.create_function("wrap_text", _wrap_text)
    except Exception:
        # Don't leak a half-initialized connection.
        conn.close()
        raise
    return conn

def accuracy_by_model(conn):
    """
    Print a per-model accuracy table. Each row has the model (the
    ``parent_dir`` of ``results.completions``), the number of completions,
    how many pass ``check_answer`` against the matching ``challenges.answer``,
    and the accuracy rounded to two decimals.

    Rows are ordered by accuracy (highest first), then by model name, so the
    report is stable between runs.

    ``conn`` is expected to come from ``load_results()``, which attaches the
    ``results`` database, creates ``challenges`` and registers the
    ``check_answer`` SQL function.
    """
    model_accuracies = conn.sql("""
        WITH AnswerCheck AS (
            SELECT
                c.parent_dir AS model,
                COUNT(*) AS total,
                COALESCE(
                    SUM(CAST(check_answer(c.completion, ch.answer) AS INTEGER)),
                    0
                ) AS correct
            FROM
                results.completions AS c
            JOIN
                challenges AS ch
            ON
                c.prompt_id = ch.ID
            GROUP BY
                c.parent_dir
        )
        SELECT
            model,
            total,
            correct,
            -- Explicit DOUBLE so the ratio is never integer division.
            ROUND(CAST(correct AS DOUBLE) / total, 2) AS accuracy
        FROM
            AnswerCheck
        ORDER BY
            accuracy DESC,
            model
    """)

    print(model_accuracies)

if __name__ == "__main__":
    # Script entry point: build the in-memory DuckDB session (attached results
    # database, challenges table, SQL helper functions), then print per-model
    # accuracy.
    conn = load_results()
    accuracy_by_model(conn)