import json
import os
import sys
from textwrap import dedent
from typing import Any, Dict, List, Optional
from loguru import logger
from fire import Fire
import pandas as pd
from surf_spot_finder.config import (
Config,
)
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
from surf_spot_finder.evaluation.evaluators import (
CheckpointEvaluator,
QuestionAnsweringSquadEvaluator,
HypothesisEvaluator,
)
from surf_spot_finder.evaluation.test_case import TestCase
from any_agent import load_agent, run_agent
from any_agent.tracing import get_tracer_provider, setup_tracing
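# Reconfigure loguru: drop the default handler and emit bare, colorized messages to stdout.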
logger.remove()
logger = logger.opt(colors=True)
logger.add(sys.stdout, colorize=True, format="{message}")
def run(test_case: TestCase, agent_config_path: str) -> str:
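    """Run the agent for the test case and return the path to the generated trace.

    The agent config is loaded from `agent_config_path`, then overridden with the
    test case inputs (location, date, max driving hours) before the agent runs.
    """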
input_data = test_case.input
logger.info("Loading config")
config = Config.from_yaml(agent_config_path)
config.location = input_data.location
config.date = input_data.date
config.max_driving_hours = input_data.max_driving_hours
logger.info("Setting up tracing")
tracer_provider, tracing_path = get_tracer_provider(project_name="surf-spot-finder")
setup_tracing(tracer_provider, config.framework)
logger.info(f"Loading {config.framework} agent")
    logger.info(f"Managed agents: {config.managed_agents}")
agent = load_agent(
framework=config.framework,
main_agent=config.main_agent,
managed_agents=config.managed_agents,
)
query = config.input_prompt_template.format(
LOCATION=config.location,
MAX_DRIVING_HOURS=config.max_driving_hours,
DATE=config.date,
)
logger.info(f"Running agent with query:\n{query}")
run_agent(agent, query)
logger.success("Done!")
return tracing_path
def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> None:
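    """Evaluate a telemetry trace against the test case and record the results.

    Loads the telemetry JSON, extracts the hypothesis answer, runs the checkpoint,
    hypothesis, and (when ground truth is available) direct QA evaluations, logs a
    summary, and appends a results row to `test_case.output_path` as JSON Lines.
    """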
    # Load the exported telemetry spans from the JSON file
    with open(telemetry_path) as f:
        telemetry: List[Dict[str, Any]] = json.load(f)
logger.info(f"Telemetry loaded from {telemetry_path}")
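    # Infer which agent framework produced the trace so the matching processor is used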
agent_framework = TelemetryProcessor.determine_agent_framework(telemetry)
# Extract the final answer from the telemetry
processor = TelemetryProcessor.create(agent_framework)
hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
# Checkpoint evaluation
checkpoint_evaluator = CheckpointEvaluator(model=test_case.llm_judge)
checkpoint_results = checkpoint_evaluator.evaluate(
telemetry=telemetry,
checkpoints=test_case.checkpoints,
processor=processor,
)
# Hypothesis answer evaluation
hypothesis_evaluator = HypothesisEvaluator(model=test_case.llm_judge)
hypothesis_answer_results = hypothesis_evaluator.evaluate(
hypothesis_final_answer=hypothesis_answer,
ground_truth_answer_dict=test_case.ground_truth,
ground_truth_checkpoints=test_case.final_answer_criteria,
)
    # Direct answer evaluation against the ground truth
if test_case.ground_truth:
direct_evaluator = QuestionAnsweringSquadEvaluator()
direct_results = direct_evaluator.evaluate(
hypothesis_answer=hypothesis_answer,
ground_truth_answer=test_case.ground_truth,
)
else:
direct_results = []
# Combine all results
verification_results = (
checkpoint_results + hypothesis_answer_results + direct_results
)
# Summarize results
    output_message = f"Hypothesis Final answer extracted: {hypothesis_answer}\n"
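    # Tally passed and failed checks and the points each carries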
failed_checks = [r for r in verification_results if not r.passed]
passed_checks = [r for r in verification_results if r.passed]
    missed_points = sum(r.points for r in failed_checks)
    won_points = sum(r.points for r in passed_checks)
if passed_checks:
for check in passed_checks:
message = dedent(
f"""
Passed:
- {check.criteria}
- {check.reason}"""
)
output_message += message + "\n"
if failed_checks:
for check in failed_checks:
message = dedent(
f"""
Failed:
- {check.criteria}
- {check.reason}"""
)
output_message += message + "\n"
else:
output_message += "All checkpoints passed!\n"
output_message += f"Passed checkpoints: {len(passed_checks)}\n"
output_message += f"Failed checkpoints: {len(failed_checks)}\n"
output_message += "=====================================\n"
output_message += (
f"Score: {won_points}/{won_points + missed_points}\n"
)
output_message += "=====================================\n"
logger.info(output_message)
    # Load existing results if the output file already exists, otherwise start fresh
if os.path.exists(test_case.output_path):
df = pd.read_json(test_case.output_path, orient="records", lines=True)
else:
df = pd.DataFrame()
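    # Append this run's summary as a new row to the results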
df = pd.concat(
[
df,
            pd.DataFrame(
                [
                    {
                        "test_case_path": test_case.test_case_path,
                        "output_message": output_message,
                        "telemetry_path": telemetry_path,
                        "hypothesis_answer": hypothesis_answer,
                        "passed_checks": len(passed_checks),
                        "failed_checks": len(failed_checks),
                        # Guard against division by zero when no points are at stake
                        "score": round(
                            won_points / (won_points + missed_points) * 100, 2
                        )
                        if (won_points + missed_points) > 0
                        else 0.0,
                    }
                ]
            ),
]
)
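    # Persist results in JSON Lines format so rows from repeated runs accumulate in one file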
df.to_json(test_case.output_path, orient="records", lines=True)
def evaluate(
    test_case_path: str,
    agent_config_path: Optional[str] = None,
    telemetry_path: Optional[str] = None,
) -> None:
    """
    Evaluate agent performance using either a provided telemetry file or by running the agent.

    Args:
        test_case_path: Path to the YAML file defining the test case
            (inputs, checkpoints, ground truth, and output path).
        agent_config_path: Path to the agent YAML config. Required when no
            telemetry file is provided, since the agent must be run to
            generate one.
        telemetry_path: Optional path to an existing telemetry file. If not
            provided, the agent will be run to generate one.
    """
test_case = TestCase.from_yaml(test_case_path=test_case_path)
if telemetry_path is None:
logger.info(
"No telemetry path provided. Running agent to generate telemetry..."
)
        if agent_config_path is None:
            raise ValueError("Agent config path must be provided if running agent")
telemetry_path = run(test_case, agent_config_path)
else:
logger.info(f"Using provided telemetry file: {telemetry_path}")
logger.info(
"For this to work, the telemetry file must align with the test case.",
)
evaluate_telemetry(test_case, telemetry_path)
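# Example invocation (script name and paths below are illustrative; Fire maps
# keyword arguments to CLI flags):
#   python evaluate.py --test_case_path=test_case.yaml --agent_config_path=agent_config.yaml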
def main():
Fire(evaluate)
if __name__ == "__main__":
main()