"""
uv add -r requirements.txt
uv run -- inspect eval queries.py --model ollama/deepseek-r1 --limit 20
uv run -- inspect view
"""
import json
from inspect_ai import task, Task
from inspect_ai.dataset import csv_dataset, FieldSpec
from inspect_ai.model import get_model
from inspect_ai.scorer import accuracy, scorer, Score, CORRECT, INCORRECT, match
from inspect_ai.solver import system_message, generate, solver
from inspect_ai.util import resource
from utils import is_valid, json_completion
from typing import Literal
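

# two tasks sharing one implementation: "validate" scores the model answer
# directly against the target; "critique" asks a second model to critique it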
@task
def validate():
    return eval_task(scorer=match("any"))  # alternatively: validate_scorer()


@task
def critique():
    return eval_task(scorer=critique_scorer())


# shared task implementation, parameterized by scorer
def eval_task(scorer):
    # read dataset
    dataset = csv_dataset(
        csv_file="sms.csv",
        sample_fields=FieldSpec(
            input="input",
            target="target"
        ),
        shuffle=True
    )

    # create eval task
    return Task(
        dataset=dataset,
        plan=[
            system_message("You are a spam detector: classify each SMS message as spam or ham."),
            prompt_with_schema(),
            generate()
        ],
        scorer=scorer
    )
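

# solver that substitutes the sample input into the prompt template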
@solver
def prompt_with_schema():
    prompt_template = resource("prompt.txt")

    async def solve(state, generate):
        # build the prompt
        state.user_prompt.text = prompt_template.replace(
            "{{prompt}}", state.input  # or: state.user_prompt.text
        )
        return state

    return solve
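

# scorer that checks the extracted answer for an exact match with the target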
@scorer(metrics=[accuracy()])
def validate_scorer():
    async def score(state, target):
        # check for valid query
        query = json_completion(state.output.completion).strip()
        if query == target.text:
            value = CORRECT
        else:
            value = INCORRECT
        # return score w/ query that was extracted
        return Score(value=value, answer=query)

    return score
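

# scorer that has a critic model review the answer and uses its critique as the explanation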
@scorer(metrics=[accuracy()])
def critique_scorer(model="ollama/deepscaler"):
    async def score(state, target):
        # build the critic prompt
        query = state.output.completion.strip()
        critic_prompt = resource("critique.txt").replace(
            "{{prompt}}", state.input  # or: state.user_prompt.text
        ).replace(
            "{{answer}}", query
        )

        # run the critique
        result = await get_model(model).generate(critic_prompt)
        try:
            parsed = json.loads(json_completion(result.completion))
            value = CORRECT if target.text == query else INCORRECT
            explanation = parsed["critique"]
        except (json.JSONDecodeError, KeyError):
            value = INCORRECT
            explanation = f"JSON parsing error:\n{result.completion}"

        # return value and explanation (critique text)
        return Score(answer=query, value=value, explanation=explanation)

    return score