"""
uv add -r requirements.txt
uv run -- inspect eval queries.py --model ollama/deepseek-r1 --limit 20
uv run -- inspect view
"""

import json

from inspect_ai import task, Task
from inspect_ai.dataset import csv_dataset, FieldSpec
from inspect_ai.model import get_model
from inspect_ai.scorer import accuracy, scorer, Score, CORRECT, INCORRECT, match
from inspect_ai.solver import system_message, generate, solver
from inspect_ai.util import resource

from utils import json_completion

@task
def validate():
    return eval_task(scorer=match("any"))  # or: validate_scorer()


@task
def critique():
    return eval_task(scorer=critique_scorer())


# shared task implementation parameterized by scorer
def eval_task(scorer):

    # read dataset
    dataset = csv_dataset(
        csv_file="sms.csv",
        sample_fields=FieldSpec(
            input="input",
            target="target"
        ),
        shuffle=True
    )

    # create eval task
    return Task(
        dataset=dataset,
        plan=[
            system_message("You are a spam detector. Classify each SMS message as either spam or ham."),
            prompt_with_schema(),
            generate()
        ],
        scorer=scorer
    )
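
# The dataset is expected to provide "input" and "target" columns. A
# hypothetical sms.csv illustrating the assumed layout (the "spam"/"ham"
# labels are an assumption based on the system message, not confirmed here):
#
#   input,target
#   "WINNER!! Claim your free prize now, reply YES","spam"
#   "Are we still on for lunch tomorrow?","ham"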


@solver
def prompt_with_schema():

    prompt_template = resource("prompt.txt")

    async def solve(state, generate):
        # fill the prompt template with the sample input
        state.user_prompt.text = prompt_template.replace(
            "{{prompt}}", state.input
        )
        return state

    return solve
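
# prompt.txt is loaded via resource() above and must contain a {{prompt}}
# placeholder that the solver fills with the sample input. An illustrative
# template (the wording is an assumption, not the actual file contents):
#
#   Classify the following SMS message as spam or ham, answering in the
#   JSON format described by the schema below.
#
#   Message: {{prompt}}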


@scorer(metrics=[accuracy()])
def validate_scorer():

    async def score(state, target):

        # extract the model's answer and compare it to the target label
        answer = json_completion(state.output.completion).strip()
        value = CORRECT if answer == target.text else INCORRECT

        # return score w/ the answer that was extracted
        return Score(value=value, answer=answer)

    return score
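
# critique.txt is loaded via resource() below and must contain {{prompt}} and
# {{answer}} placeholders; the critic model is expected to reply with JSON
# containing a "critique" field. An illustrative template (the wording is an
# assumption, not the actual file contents):
#
#   An SMS classifier was given the following message:
#   {{prompt}}
#
#   It answered: {{answer}}
#
#   Critique the answer and reply with JSON: {"critique": "<your critique>"}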


@scorer(metrics=[accuracy()])
def critique_scorer(model="ollama/deepscaler"):

    async def score(state, target):

        # build the critic prompt from the template, filling in the original
        # input and the model's answer
        answer = state.output.completion.strip()
        critic_prompt = resource("critique.txt").replace(
            "{{prompt}}", state.input
        ).replace(
            "{{answer}}", answer
        )

        # run the critique with the critic model
        result = await get_model(model).generate(critic_prompt)
        try:
            parsed = json.loads(json_completion(result.completion))
            # the score value is still an exact match against the target;
            # the critic model supplies the explanation text
            value = CORRECT if answer == target.text else INCORRECT
            explanation = parsed["critique"]
        except (json.JSONDecodeError, KeyError):
            value = INCORRECT
            explanation = f"JSON parsing error:\n{result.completion}"

        # return value and explanation (critique text)
        return Score(answer=answer, value=value, explanation=explanation)

    return score
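

# The tasks above are normally run via the `inspect eval` CLI shown in the
# module docstring. A minimal sketch of running the same evaluation
# programmatically; the model name and limit simply mirror the docstring,
# and this block is an illustration rather than part of the original workflow.
if __name__ == "__main__":
    from inspect_ai import eval

    # run the match-based validation task against a local Ollama model
    eval(validate(), model="ollama/deepseek-r1", limit=20)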