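"""Run and evaluate a surf-spot-finder agent on a YAML-defined test case.

The evaluate entry point either reuses an existing telemetry file or runs the
agent to generate one, then scores the resulting trace with checkpoint,
hypothesis-answer, and question-answering evaluators and appends a summary
record to the test case's output file (JSON lines).
"""
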
import json
import os
import sys
from textwrap import dedent
from typing import Any, Dict, List, Optional
from loguru import logger
from fire import Fire
import pandas as pd
from surf_spot_finder.config import (
    Config,
)
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
from surf_spot_finder.evaluation.evaluators import (
    CheckpointEvaluator,
    QuestionAnsweringSquadEvaluator,
    HypothesisEvaluator,
)
from surf_spot_finder.evaluation.test_case import TestCase
from any_agent import load_agent, run_agent
from any_agent.tracing import get_tracer_provider, setup_tracing

logger.remove()
logger = logger.opt(colors=True)
logger.add(sys.stdout, colorize=True, format="{message}")


def run(test_case: TestCase, agent_config_path: str) -> str:
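    """Run the agent for a single test case and return the telemetry path.

    The agent config is loaded from YAML and its location, date, and
    max_driving_hours fields are overridden with the test case inputs; tracing
    is set up, the agent is run on the rendered prompt, and the path to the
    generated trace file is returned for later evaluation.
    """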
    input_data = test_case.input

    logger.info("Loading config")
    config = Config.from_yaml(agent_config_path)
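    # Override the loaded config with the inputs defined by the test case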
    config.location = input_data.location
    config.date = input_data.date
    config.max_driving_hours = input_data.max_driving_hours
    logger.info("Setting up tracing")
    tracer_provider, tracing_path = get_tracer_provider(project_name="surf-spot-finder")
    setup_tracing(tracer_provider, config.framework)

    logger.info(f"Loading {config.framework} agent")
    logger.info(f"Managed agents: {config.managed_agents}")
    agent = load_agent(
        framework=config.framework,
        main_agent=config.main_agent,
        managed_agents=config.managed_agents,
    )

    query = config.input_prompt_template.format(
        LOCATION=config.location,
        MAX_DRIVING_HOURS=config.max_driving_hours,
        DATE=config.date,
    )
    logger.info(f"Running agent with query:\n{query}")
    run_agent(agent, query)

    logger.success("Done!")

    return tracing_path


def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> None:
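    """Score a single agent run from its recorded telemetry.

    Runs the checkpoint, hypothesis-answer and (when ground truth is available)
    SQuAD-style direct-answer evaluators over the trace, logs a colorized
    summary, and appends one JSON-lines record to test_case.output_path.
    """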
    # Load the recorded telemetry (a list of dictionaries)
    with open(telemetry_path, "r") as f:
        telemetry: List[Dict[str, Any]] = json.loads(f.read())
    logger.info(f"Telemetry loaded from {telemetry_path}")

    agent_framework = TelemetryProcessor.determine_agent_framework(telemetry)

    # Extract the final answer from the telemetry
    processor = TelemetryProcessor.create(agent_framework)
    hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)

    # Checkpoint evaluation
    checkpoint_evaluator = CheckpointEvaluator(model=test_case.llm_judge)
    checkpoint_results = checkpoint_evaluator.evaluate(
        telemetry=telemetry,
        checkpoints=test_case.checkpoints,
        processor=processor,
    )

    # Hypothesis answer evaluation
    hypothesis_evaluator = HypothesisEvaluator(model=test_case.llm_judge)
    hypothesis_answer_results = hypothesis_evaluator.evaluate(
        hypothesis_final_answer=hypothesis_answer,
        ground_truth_answer_dict=test_case.ground_truth,
        ground_truth_checkpoints=test_case.final_answer_criteria,
    )

    # Direct answer evaluation against the ground truth (SQuAD-style QA metrics)
    if test_case.ground_truth:
        direct_evaluator = QuestionAnsweringSquadEvaluator()
        direct_results = direct_evaluator.evaluate(
            hypothesis_answer=hypothesis_answer,
            ground_truth_answer=test_case.ground_truth,
        )
    else:
        direct_results = []
    # Combine all results
    verification_results = (
        checkpoint_results + hypothesis_answer_results + direct_results
    )
    # Summarize results
    output_message = ""
    output_message += (
        f"<yellow>Hypothesis final answer extracted: {hypothesis_answer}</yellow>\n"
    )
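    # Tally the points won and missed across all verification results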
    failed_checks = [r for r in verification_results if not r.passed]
    passed_checks = [r for r in verification_results if r.passed]
    missed_points = sum(r.points for r in failed_checks)
    won_points = sum(r.points for r in passed_checks)
    if passed_checks:
        for check in passed_checks:
            message = dedent(
                f"""
                <green>Passed:
                - {check.criteria}
                - {check.reason}</green>"""
            )
            output_message += message + "\n"
    if failed_checks:
        for check in failed_checks:
            message = dedent(
                f"""
                <red>Failed:
                - {check.criteria}
                - {check.reason}</red>"""
            )
            output_message += message + "\n"
    else:
        output_message += "<green>All checkpoints passed!</green>\n"
    output_message += f"<green>Passed checkpoints: {len(passed_checks)}</green>\n"
    output_message += f"<red>Failed checkpoints: {len(failed_checks)}</red>\n"
    output_message += "<green>=====================================</green>\n"
    output_message += (
        f"<green>Score: {won_points}/{won_points + missed_points}</green>\n"
    )
    output_message += "<green>=====================================</green>\n"
    logger.info(output_message)
    # Append this run's summary to test_case.output_path (JSON lines), creating the file if needed
    if os.path.exists(test_case.output_path):
        df = pd.read_json(test_case.output_path, orient="records", lines=True)
    else:
        df = pd.DataFrame()
    df = pd.concat(
        [
            df,
            pd.DataFrame(
                [
                    {
                        "test_case_path": test_case.test_case_path,
                        "output_message": output_message,
                        "telemetry_path": telemetry_path,
                        "hypothesis_answer": hypothesis_answer,
                        "passed_checks": len(passed_checks),
                        "failed_checks": len(failed_checks),
                        # max(..., 1) avoids ZeroDivisionError when no checkpoints carry points
                        "score": round(
                            won_points / max(won_points + missed_points, 1) * 100, 2
                        ),
                    }
                ]
            ),
        ]
    )
    df.to_json(test_case.output_path, orient="records", lines=True)


def evaluate(
    test_case_path: str,
    agent_config_path: Optional[str] = None,
    telemetry_path: Optional[str] = None,
) -> None:
    """
    Evaluate agent performance using either a provided telemetry file or by running the agent.

    Args:
        test_case_path: Path to the YAML file defining the test case.
        agent_config_path: Path to the agent config YAML. Required when no
                           telemetry file is provided, because the agent must be run.
        telemetry_path: Optional path to an existing telemetry file. If not provided,
                        the agent will be run to generate one.
    """
    test_case = TestCase.from_yaml(test_case_path=test_case_path)

    if telemetry_path is None:
        logger.info(
            "No telemetry path provided. Running agent to generate telemetry..."
        )
        assert (
            agent_config_path is not None
        ), "Agent config path must be provided if running agent"
        telemetry_path = run(test_case, agent_config_path)
    else:
        logger.info(f"Using provided telemetry file: {telemetry_path}")
        logger.info(
            "For this to work, the telemetry file must align with the test case.",
        )

    evaluate_telemetry(test_case, telemetry_path)


def main():
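    # Fire exposes `evaluate` as a CLI; illustrative invocation (paths are placeholders):
    #   python evaluate.py path/to/test_case.yaml --agent_config_path=path/to/agent.yaml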
    Fire(evaluate)


if __name__ == "__main__":
    main()