"""
Spam vs. ham SMS classification evals for Inspect AI.

Install dependencies:

    uv add -r requirements.txt

Run an eval against a local Ollama model (first 20 samples only):

    uv run -- inspect eval queries.py --model ollama/deepseek-r1 --limit 20

View the eval logs:

    uv run -- inspect view
"""

import json

from inspect_ai import task, Task
from inspect_ai.dataset import csv_dataset, FieldSpec
from inspect_ai.model import get_model
from inspect_ai.scorer import accuracy, scorer, Score, CORRECT, INCORRECT
from inspect_ai.solver import system_message, generate, solver
from inspect_ai.util import resource

from utils import json_completion
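
# NOTE: json_completion() comes from the local utils module. It is assumed to
# strip any markdown code fences (e.g. ```json ... ```) from a model completion
# so that the remaining text can be compared against the target or parsed with
# json.loads(). A minimal sketch under that assumption (hypothetical, for
# reference only):
#
#   def json_completion(completion: str) -> str:
#       completion = completion.strip()
#       if completion.startswith("```json"):
#           completion = completion[len("```json"):]
#       elif completion.startswith("```"):
#           completion = completion[len("```"):]
#       if completion.endswith("```"):
#           completion = completion[:-len("```")]
#       return completion.strip()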


@task
def validate():
    # score by exact match of the (fence-stripped) completion against the target
    return eval_task(scorer=validate_scorer())


@task
def critique():
    # score by having a second "critic" model review each answer
    return eval_task(scorer=critique_scorer())
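

# Either task can be selected individually on the command line, e.g.:
#
#   uv run -- inspect eval queries.py@validate --model ollama/deepseek-r1
#   uv run -- inspect eval queries.py@critique --model ollama/deepseek-r1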


def eval_task(scorer):
    # each sample pairs an SMS message (input) with its expected label (target)
    dataset = csv_dataset(
        csv_file="sms.csv",
        sample_fields=FieldSpec(
            input="input",
            target="target"
        ),
        shuffle=True
    )

    return Task(
        dataset=dataset,
        plan=[
            system_message(
                "You are a spam detector. Classify each SMS message as "
                "either spam or ham."
            ),
            prompt_with_schema(),
            generate()
        ],
        scorer=scorer
    )
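

# sms.csv is assumed to have an "input" column holding the SMS text and a
# "target" column holding the expected label ("spam" or "ham"), e.g.:
#
#   input,target
#   "WINNER!! You have been selected for a free prize, call now",spam
#   "Are we still on for lunch tomorrow?",ham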


@solver
def prompt_with_schema():

    # load the prompt template (resource() reads the contents of a file or URL)
    prompt_template = resource("prompt.txt")

    async def solve(state, generate):
        # substitute the SMS message into the prompt template
        state.user_prompt.text = prompt_template.replace(
            "{{prompt}}", state.input
        )
        return state

    return solve
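

# prompt.txt is assumed to be a prompt template containing a {{prompt}}
# placeholder for the SMS text. Because validate_scorer() compares the
# fence-stripped completion directly against the target label, the template
# presumably asks the model to answer with just "spam" or "ham", e.g.:
#
#   Classify the following SMS message as spam or ham.
#   Answer with exactly one word: spam or ham.
#
#   {{prompt}}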


@scorer(metrics=[accuracy()])
def validate_scorer():

    async def score(state, target):
        # strip any code fences from the completion and compare the remaining
        # text against the expected label
        query = json_completion(state.output.completion).strip()
        if query == target.text:
            value = CORRECT
        else:
            value = INCORRECT

        return Score(value=value, answer=query)

    return score


@scorer(metrics=[accuracy()])
def critique_scorer(model="ollama/deepscaler"):

    async def score(state, target):
        # build the critique prompt from the original input and the model's answer
        query = state.output.completion.strip()
        critic_prompt = resource("critique.txt").replace(
            "{{prompt}}", state.input
        ).replace(
            "{{answer}}", query
        )

        # ask the critic model to review the answer; the score itself is an
        # exact match against the target, with the critique kept as explanation
        result = await get_model(model).generate(critic_prompt)
        try:
            parsed = json.loads(json_completion(result.completion))
            value = CORRECT if target.text == query else INCORRECT
            explanation = parsed["critique"]
        except (json.JSONDecodeError, KeyError):
            value = INCORRECT
            explanation = f"JSON parsing error:\n{result.completion}"

        return Score(answer=query, value=value, explanation=explanation)

    return score
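

# critique.txt is assumed to be a prompt template with {{prompt}} and {{answer}}
# placeholders that asks the critic model to respond with JSON containing at
# least a "critique" field (only "critique" is read above), e.g.:
#
#   {"critique": "The label is consistent with the content of the message."}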