"""
uv add -r requirements.txt
uv run -- inspect eval queries.py --model ollama/deepseek-r1 --limit 20
uv run -- inspect view
"""
import json

from inspect_ai import task, Task
from inspect_ai.dataset import csv_dataset, FieldSpec
from inspect_ai.model import get_model
from inspect_ai.scorer import accuracy, scorer, Score, CORRECT, INCORRECT, match
from inspect_ai.solver import system_message, generate, solver
from inspect_ai.util import resource

from utils import json_completion


@task
def validate():
    return eval_task(scorer=match("any"))  # alternative scorer: validate_scorer()

@task
def critique():
return eval_task(scorer=critique_scorer())


# shared task implementation parameterized by scorer
def eval_task(scorer):
# read dataset
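    # sms.csv is assumed to have two columns: "input" (the SMS text) and
    # "target" (the gold label, e.g. "spam" or "ham"), per the FieldSpec below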
dataset = csv_dataset(
csv_file="sms.csv",
sample_fields=FieldSpec(
input="input",
target="target"
),
shuffle=True
)
# create eval task
return Task(
dataset=dataset,
plan=[
system_message("spam detector to determine spam or ham based on SMS."),
prompt_with_schema(),
generate()
],
scorer=scorer
)


@solver
def prompt_with_schema():
prompt_template = resource("prompt.txt")
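    # prompt.txt is expected to contain a {{prompt}} placeholder that is
    # substituted with the SMS text for each sample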
async def solve(state, generate):
# build the prompt
        state.user_prompt.text = prompt_template.replace(
            "{{prompt}}", state.input  # alternatively: state.user_prompt.text
        )
return state
return solve


@scorer(metrics=[accuracy()])
def validate_scorer():
    async def score(state, target):
        # extract the predicted label from the model completion
        answer = json_completion(state.output.completion).strip()
        # compare the prediction against the gold label
        value = CORRECT if answer == target.text else INCORRECT
        # return score w/ the answer that was extracted
        return Score(value=value, answer=answer)
    return score


@scorer(metrics=[accuracy()])
def critique_scorer(model="ollama/deepscaler"):
    async def score(state, target):
        # build the critic prompt; critique.txt is expected to contain
        # {{prompt}} and {{answer}} placeholders
        answer = state.output.completion.strip()
        critic_prompt = resource("critique.txt").replace(
            "{{prompt}}", state.input  # alternatively: state.user_prompt.text
        ).replace(
            "{{answer}}", answer
        )
        # run the critique
        result = await get_model(model).generate(critic_prompt)
        try:
            # the critic is expected to reply with JSON that includes a
            # "critique" field; note the value is still judged against the
            # gold target, with the critique kept as the explanation
            parsed = json.loads(json_completion(result.completion))
            value = CORRECT if target.text == answer else INCORRECT
            explanation = parsed["critique"]
        except (json.JSONDecodeError, KeyError):
            value = INCORRECT
            explanation = f"JSON parsing error:\n{result.completion}"
        # return value and explanation (critique text)
        return Score(answer=answer, value=value, explanation=explanation)
    return score
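

# A minimal programmatic entry point (a sketch; the intended workflow is the
# `inspect eval` CLI shown in the module docstring). The model and limit
# values simply mirror the CLI example above and are assumptions.
if __name__ == "__main__":
    from inspect_ai import eval
    eval(validate(), model="ollama/deepseek-r1", limit=20)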