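"""Evaluation entry point for the surf-spot-finder agent.

Runs the agent (or reuses an existing telemetry file) and scores the run against a
YAML-defined test case using checkpoint, hypothesis-answer, and direct-answer evaluators.
"""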
import json
import os
import sys
from textwrap import dedent
from typing import Any, Dict, List, Optional
from loguru import logger
from fire import Fire
import pandas as pd
from surf_spot_finder.config import (
Config,
)
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
from surf_spot_finder.evaluation.evaluators import (
CheckpointEvaluator,
QuestionAnsweringSquadEvaluator,
HypothesisEvaluator,
)
from surf_spot_finder.evaluation.test_case import TestCase
from any_agent import load_agent, run_agent
from any_agent.tracing import get_tracer_provider, setup_tracing

logger.remove()
logger = logger.opt(ansi=True)
logger.add(sys.stdout, colorize=True, format="{message}")


def run(test_case: TestCase, agent_config_path: str) -> str:
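    """Run the agent for the given test case and return the path to the telemetry/tracing file it produced."""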
input_data = test_case.input
logger.info("Loading config")
config = Config.from_yaml(agent_config_path)
config.location = input_data.location
config.date = input_data.date
config.max_driving_hours = input_data.max_driving_hours
logger.info("Setting up tracing")
tracer_provider, tracing_path = get_tracer_provider(project_name="surf-spot-finder")
setup_tracing(tracer_provider, config.framework)
logger.info(f"Loading {config.framework} agent")
logger.info(f"{config.managed_agents}")
agent = load_agent(
framework=config.framework,
main_agent=config.main_agent,
managed_agents=config.managed_agents,
)
query = config.input_prompt_template.format(
LOCATION=config.location,
MAX_DRIVING_HOURS=config.max_driving_hours,
DATE=config.date,
)
logger.info(f"Running agent with query:\n{query}")
run_agent(agent, query)
logger.success("Done!")
    return tracing_path


def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> None:
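    """Score the recorded agent run against the test case.

    Loads the telemetry JSON, extracts the hypothesis answer, runs the checkpoint,
    hypothesis, and direct-answer evaluators, logs a colorized summary, and appends
    a result row to ``test_case.output_path``.
    """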
    # Load the recorded telemetry (a list of span dictionaries) from the JSON file
    with open(telemetry_path, "r") as f:
        telemetry: List[Dict[str, Any]] = json.load(f)
logger.info(f"Telemetry loaded from {telemetry_path}")
agent_framework = TelemetryProcessor.determine_agent_framework(telemetry)
# Extract the final answer from the telemetry
processor = TelemetryProcessor.create(agent_framework)
hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
# Checkpoint evaluation
checkpoint_evaluator = CheckpointEvaluator(model=test_case.llm_judge)
checkpoint_results = checkpoint_evaluator.evaluate(
telemetry=telemetry,
checkpoints=test_case.checkpoints,
processor=processor,
)
# Hypothesis answer evaluation
hypothesis_evaluator = HypothesisEvaluator(model=test_case.llm_judge)
hypothesis_answer_results = hypothesis_evaluator.evaluate(
hypothesis_final_answer=hypothesis_answer,
ground_truth_answer_dict=test_case.ground_truth,
ground_truth_checkpoints=test_case.final_answer_criteria,
)
    # Direct answer evaluation against the ground truth (only when one is provided)
if test_case.ground_truth:
direct_evaluator = QuestionAnsweringSquadEvaluator()
direct_results = direct_evaluator.evaluate(
hypothesis_answer=hypothesis_answer,
ground_truth_answer=test_case.ground_truth,
)
else:
direct_results = []
# Combine all results
verification_results = (
checkpoint_results + hypothesis_answer_results + direct_results
)
# Summarize results
output_message = ""
output_message += (
f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>\n"""
)
failed_checks = [r for r in verification_results if not r.passed]
passed_checks = [r for r in verification_results if r.passed]
missed_points = sum([r.points for r in failed_checks])
won_points = sum([r.points for r in passed_checks])
if passed_checks:
for check in passed_checks:
message = dedent(
f"""
<green>Passed:
- {check.criteria}
- {check.reason}</green>"""
)
output_message += message + "\n"
if failed_checks:
for check in failed_checks:
message = dedent(
f"""
<red>Failed:
- {check.criteria}
- {check.reason}</red>"""
)
output_message += message + "\n"
else:
output_message += "<green>All checkpoints passed!</green>\n"
output_message += f"<green>Passed checkpoints: {len(passed_checks)}</green>\n"
output_message += f"<red>Failed checkpoints: {len(failed_checks)}</red>\n"
output_message += "<green>=====================================</green>\n"
output_message += (
f"<green>Score: {won_points}/{won_points + missed_points}</green>\n"
)
output_message += "<green>=====================================</green>\n"
logger.info(output_message)
    # Append this run's results to the output file, creating it if it does not yet exist
if os.path.exists(test_case.output_path):
df = pd.read_json(test_case.output_path, orient="records", lines=True)
else:
df = pd.DataFrame()
df = pd.concat(
[
df,
pd.DataFrame(
[
{
"test_case_path": test_case.test_case_path,
"output_message": output_message,
"telemetry_path": telemetry_path,
"hypothesis_answer": hypothesis_answer,
"passed_checks": len(passed_checks),
"failed_checks": len(failed_checks),
"score": round(
won_points / (won_points + missed_points) * 100, 2
),
}
]
),
]
)
    df.to_json(test_case.output_path, orient="records", lines=True)


def evaluate(
test_case_path: str,
    agent_config_path: Optional[str] = None,
telemetry_path: Optional[str] = None,
) -> None:
"""
Evaluate agent performance using either a provided telemetry file or by running the agent.
Args:
telemetry_path: Optional path to an existing telemetry file. If not provided,
the agent will be run to generate one.
"""
test_case = TestCase.from_yaml(test_case_path=test_case_path)
if telemetry_path is None:
logger.info(
"No telemetry path provided. Running agent to generate telemetry..."
)
assert (
agent_config_path is not None
), "Agent config path must be provided if running agent"
telemetry_path = run(test_case, agent_config_path)
else:
logger.info(f"Using provided telemetry file: {telemetry_path}")
logger.info(
"For this to work, the telemetry file must align with the test case.",
)
    evaluate_telemetry(test_case, telemetry_path)


def main():
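    """CLI entry point; exposes ``evaluate`` through python-fire.

    Example invocation (hypothetical script name, adjust to how this module is
    installed or exposed in your project):

        python evaluate.py path/to/test_case.yaml --agent_config_path config.yaml
    """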
    Fire(evaluate)


if __name__ == "__main__":
main()