Spaces:
Running
Running
Nathan Brake
committed on
Optionally Evaluate Cases after generating trace (#57)
Browse files- README.md +1 -1
- examples/single_agent_with_tools.yaml +21 -0
- pyproject.toml +1 -1
- src/surf_spot_finder/cli.py +39 -5
- src/surf_spot_finder/config.py +5 -1
README.md
CHANGED
@@ -49,7 +49,7 @@ pip install -e . # Install root project dependencies
|
|
49 |
### 3️⃣ Run
|
50 |
|
51 |
```bash
|
52 |
-
surf-spot-finder
|
53 |
```
|
54 |
|
55 |
## How it Works
|
|
|
49 |
### 3️⃣ Run
|
50 |
|
51 |
```bash
|
52 |
+
surf-spot-finder examples/single_agent_with_tools.yaml
|
53 |
```
|
54 |
|
55 |
## How it Works
|
examples/single_agent_with_tools.yaml
CHANGED
@@ -8,3 +8,24 @@ main_agent:
|
|
8 |
- "surf_spot_finder.tools.get_wind_forecast"
|
9 |
- "any_agent.tools.search_web"
|
10 |
- "any_agent.tools.visit_webpage"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
- "surf_spot_finder.tools.get_wind_forecast"
|
9 |
- "any_agent.tools.search_web"
|
10 |
- "any_agent.tools.visit_webpage"
|
11 |
+
|
12 |
+
|
13 |
+
evaluation_cases:
|
14 |
+
- llm_judge: openai/gpt-4.1-mini
|
15 |
+
checkpoints:
|
16 |
+
- criteria: "Check if the agent used the get_surfing_spots tool and it succeeded, and that the tool was used before the get_wave_forecast and get_wind_forecast tools"
|
17 |
+
points: 1
|
18 |
+
- criteria: "Check if the agent used the get_wave_forecast tool and it succeeded"
|
19 |
+
points: 1
|
20 |
+
- criteria: "Check if the agent used the get_wind_forecast tool and it succeeded"
|
21 |
+
points: 1
|
22 |
+
- criteria: "Check if the agent used the get_area_lat_lon tool and it succeeded"
|
23 |
+
points: 1
|
24 |
+
- criteria: "Check if the agent used the driving_hours_to_meters tool to convert the driving hours to meters and it succeeded"
|
25 |
+
points: 1
|
26 |
+
- criteria: "Check if the final answer contains any description about the weather at the chosen location"
|
27 |
+
points: 1
|
28 |
+
- criteria: "Check if the final answer contains one of the surf spots found by a call of the get_surfing_spots tool"
|
29 |
+
points: 1
|
30 |
+
- criteria: "Check that the agent completed in fewer than 10 steps"
|
31 |
+
points: 1
|
pyproject.toml
CHANGED
@@ -9,7 +9,7 @@ license = {text = "Apache-2.0"}
|
|
9 |
requires-python = ">=3.11"
|
10 |
dynamic = ["version"]
|
11 |
dependencies = [
|
12 |
-
"any-agent[all]",
|
13 |
"fire",
|
14 |
"pydantic",
|
15 |
"pyyaml",
|
|
|
9 |
requires-python = ">=3.11"
|
10 |
dynamic = ["version"]
|
11 |
dependencies = [
|
12 |
+
"any-agent[all]>=0.12.2",
|
13 |
"fire",
|
14 |
"pydantic",
|
15 |
"pyyaml",
|
src/surf_spot_finder/cli.py
CHANGED
@@ -3,8 +3,10 @@ import os
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
from any_agent import AgentFramework, AnyAgent, TracingConfig
|
|
|
6 |
from fire import Fire
|
7 |
from any_agent.logging import logger
|
|
|
8 |
|
9 |
from surf_spot_finder.config import (
|
10 |
Config,
|
@@ -27,7 +29,7 @@ async def find_surf_spot(
|
|
27 |
if config_file is None:
|
28 |
config = Config.from_dict({})
|
29 |
else:
|
30 |
-
logger.info(
|
31 |
config = Config.from_yaml(config_file)
|
32 |
|
33 |
if not config.main_agent.instructions:
|
@@ -36,8 +38,8 @@ async def find_surf_spot(
|
|
36 |
elif config.framework == AgentFramework.OPENAI:
|
37 |
config.main_agent.instructions = SINGLE_AGENT_SYSTEM_PROMPT
|
38 |
|
39 |
-
logger.info(
|
40 |
-
logger.info(
|
41 |
agent = await AnyAgent.create_async(
|
42 |
agent_framework=config.framework,
|
43 |
agent_config=config.main_agent,
|
@@ -50,10 +52,10 @@ async def find_surf_spot(
|
|
50 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
51 |
DATE=config.date,
|
52 |
)
|
53 |
-
logger.info(
|
54 |
agent_trace = await agent.run_async(query)
|
55 |
|
56 |
-
logger.info(
|
57 |
|
58 |
# dump the trace in the "output" directory
|
59 |
output_dir = "output"
|
@@ -63,6 +65,38 @@ async def find_surf_spot(
|
|
63 |
with open(file_path, "w") as f:
|
64 |
f.write(agent_trace.model_dump_json(indent=2))
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
def main():
|
68 |
Fire(find_surf_spot)
|
|
|
3 |
from pathlib import Path
|
4 |
|
5 |
from any_agent import AgentFramework, AnyAgent, TracingConfig
|
6 |
+
from any_agent.evaluation.schemas import TraceEvaluationResult
|
7 |
from fire import Fire
|
8 |
from any_agent.logging import logger
|
9 |
+
from any_agent.evaluation import evaluate
|
10 |
|
11 |
from surf_spot_finder.config import (
|
12 |
Config,
|
|
|
29 |
if config_file is None:
|
30 |
config = Config.from_dict({})
|
31 |
else:
|
32 |
+
logger.info("Loading %s", config_file)
|
33 |
config = Config.from_yaml(config_file)
|
34 |
|
35 |
if not config.main_agent.instructions:
|
|
|
38 |
elif config.framework == AgentFramework.OPENAI:
|
39 |
config.main_agent.instructions = SINGLE_AGENT_SYSTEM_PROMPT
|
40 |
|
41 |
+
logger.info("Loading %s agent", config.framework)
|
42 |
+
logger.info("Managed agents: %s", config.managed_agents)
|
43 |
agent = await AnyAgent.create_async(
|
44 |
agent_framework=config.framework,
|
45 |
agent_config=config.main_agent,
|
|
|
52 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
53 |
DATE=config.date,
|
54 |
)
|
55 |
+
logger.info("Running agent with query:\n%s", query)
|
56 |
agent_trace = await agent.run_async(query)
|
57 |
|
58 |
+
logger.info("Final output from agent:\n%s", agent_trace.final_output)
|
59 |
|
60 |
# dump the trace in the "output" directory
|
61 |
output_dir = "output"
|
|
|
65 |
with open(file_path, "w") as f:
|
66 |
f.write(agent_trace.model_dump_json(indent=2))
|
67 |
|
68 |
+
if config.evaluation_cases is not None:
|
69 |
+
results = []
|
70 |
+
logger.info("Found evaluation cases, running trace evaluation")
|
71 |
+
for i, case in enumerate(config.evaluation_cases):
|
72 |
+
logger.info("Evaluating case: %s", case)
|
73 |
+
result: TraceEvaluationResult = evaluate(
|
74 |
+
evaluation_case=case,
|
75 |
+
trace=agent_trace,
|
76 |
+
agent_framework=config.framework,
|
77 |
+
)
|
78 |
+
for list_of_checkpoints in [
|
79 |
+
result.checkpoint_results,
|
80 |
+
result.direct_results,
|
81 |
+
result.hypothesis_answer_results,
|
82 |
+
]:
|
83 |
+
for checkpoint in list_of_checkpoints:
|
84 |
+
msg = (
|
85 |
+
f"Checkpoint: {checkpoint.criteria}\n"
|
86 |
+
f"\tPassed: {checkpoint.passed}\n"
|
87 |
+
f"\tReason: {checkpoint.reason}\n"
|
88 |
+
f"\tScore: {'%d/%d' % (checkpoint.points, checkpoint.points) if checkpoint.passed else '0/%d' % checkpoint.points}"
|
89 |
+
)
|
90 |
+
logger.info(msg)
|
91 |
+
logger.info("==========================")
|
92 |
+
logger.info("Overall Score: %d%%", 100 * result.score)
|
93 |
+
logger.info("==========================")
|
94 |
+
results.append(result)
|
95 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
96 |
+
file_path = Path(output_dir) / f"{timestamp}_eval_case_{i}.json"
|
97 |
+
with open(file_path, "w") as f:
|
98 |
+
f.write(result.model_dump_json(indent=2))
|
99 |
+
|
100 |
|
101 |
def main():
|
102 |
Fire(find_surf_spot)
|
src/surf_spot_finder/config.py
CHANGED
@@ -8,6 +8,7 @@ from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, Posi
|
|
8 |
import yaml
|
9 |
from rich.prompt import Prompt
|
10 |
from any_agent.logging import logger
|
|
|
11 |
import geocoder
|
12 |
from litellm.litellm_core_utils.get_llm_provider_logic import (
|
13 |
get_llm_provider,
|
@@ -36,7 +37,7 @@ def ask_framework() -> AgentFramework:
|
|
36 |
[f"{i}: {framework}" for i, framework in enumerate(frameworks)]
|
37 |
)
|
38 |
prompt = f"Select the agent framework to use:\n{frameworks_str}\n"
|
39 |
-
choice = Prompt.ask(prompt, default="
|
40 |
try:
|
41 |
choice = int(choice)
|
42 |
if choice < 0 or choice >= len(frameworks):
|
@@ -148,6 +149,8 @@ class Config(BaseModel):
|
|
148 |
main_agent: AgentConfig
|
149 |
managed_agents: list[AgentConfig] | None = None
|
150 |
|
|
|
|
|
151 |
@classmethod
|
152 |
def from_dict(cls, data: dict) -> "Config":
|
153 |
"""
|
@@ -212,6 +215,7 @@ class Config(BaseModel):
|
|
212 |
data["date"] = date_picker()
|
213 |
else:
|
214 |
logger.info(f"Using date {data['date']}")
|
|
|
215 |
return cls(**data)
|
216 |
|
217 |
@classmethod
|
|
|
8 |
import yaml
|
9 |
from rich.prompt import Prompt
|
10 |
from any_agent.logging import logger
|
11 |
+
from any_agent.evaluation import EvaluationCase
|
12 |
import geocoder
|
13 |
from litellm.litellm_core_utils.get_llm_provider_logic import (
|
14 |
get_llm_provider,
|
|
|
37 |
[f"{i}: {framework}" for i, framework in enumerate(frameworks)]
|
38 |
)
|
39 |
prompt = f"Select the agent framework to use:\n{frameworks_str}\n"
|
40 |
+
choice = Prompt.ask(prompt, default="0")
|
41 |
try:
|
42 |
choice = int(choice)
|
43 |
if choice < 0 or choice >= len(frameworks):
|
|
|
149 |
main_agent: AgentConfig
|
150 |
managed_agents: list[AgentConfig] | None = None
|
151 |
|
152 |
+
evaluation_cases: list[EvaluationCase] | None = None
|
153 |
+
|
154 |
@classmethod
|
155 |
def from_dict(cls, data: dict) -> "Config":
|
156 |
"""
|
|
|
215 |
data["date"] = date_picker()
|
216 |
else:
|
217 |
logger.info(f"Using date {data['date']}")
|
218 |
+
|
219 |
return cls(**data)
|
220 |
|
221 |
@classmethod
|