Nathan Brake committed on
Commit
515cbf5
·
unverified ·
1 Parent(s): c211e28

Optionally Evaluate Cases after generating trace (#57)

Browse files
README.md CHANGED
@@ -49,7 +49,7 @@ pip install -e . # Install root project dependencies
49
  ### 3️⃣ Run
50
 
51
  ```bash
52
- surf-spot-finder
53
  ```
54
 
55
  ## How it Works
 
49
  ### 3️⃣ Run
50
 
51
  ```bash
52
+ surf-spot-finder examples/single_agent_with_tools.yaml
53
  ```
54
 
55
  ## How it Works
examples/single_agent_with_tools.yaml CHANGED
@@ -8,3 +8,24 @@ main_agent:
8
  - "surf_spot_finder.tools.get_wind_forecast"
9
  - "any_agent.tools.search_web"
10
  - "any_agent.tools.visit_webpage"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  - "surf_spot_finder.tools.get_wind_forecast"
9
  - "any_agent.tools.search_web"
10
  - "any_agent.tools.visit_webpage"
11
+
12
+
13
+ evaluation_cases:
14
+ - llm_judge: openai/gpt-4.1-mini
15
+ checkpoints:
16
+ - criteria: "Check if the agent used the get_surfing_spots tool and it succeeded, and that the tool was used before the get_wave_forecast and get_wind_forecast tools"
17
+ points: 1
18
+ - criteria: "Check if the agent used the get_wave_forecast tool and it succeeded"
19
+ points: 1
20
+ - criteria: "Check if the agent used the get_wind_forecast tool and it succeeded"
21
+ points: 1
22
+ - criteria: "Check if the agent used the get_area_lat_lon tool and it succeeded"
23
+ points: 1
24
+ - criteria: "Check if the agent used the driving_hours_to_meters tool to convert the driving hours to meters and it succeeded"
25
+ points: 1
26
+ - criteria: "Check if the final answer contains any description about the weather at the chosen location"
27
+ points: 1
28
+ - criteria: "Check if the final answer contains one of the surf spots found by a call of the get_surfing_spots tool"
29
+ points: 1
30
+ - criteria: "Check that the agent completed in fewer than 10 steps"
31
+ points: 1
pyproject.toml CHANGED
@@ -9,7 +9,7 @@ license = {text = "Apache-2.0"}
9
  requires-python = ">=3.11"
10
  dynamic = ["version"]
11
  dependencies = [
12
- "any-agent[all]",
13
  "fire",
14
  "pydantic",
15
  "pyyaml",
 
9
  requires-python = ">=3.11"
10
  dynamic = ["version"]
11
  dependencies = [
12
+ "any-agent[all]>=0.12.2",
13
  "fire",
14
  "pydantic",
15
  "pyyaml",
src/surf_spot_finder/cli.py CHANGED
@@ -3,8 +3,10 @@ import os
3
  from pathlib import Path
4
 
5
  from any_agent import AgentFramework, AnyAgent, TracingConfig
 
6
  from fire import Fire
7
  from any_agent.logging import logger
 
8
 
9
  from surf_spot_finder.config import (
10
  Config,
@@ -27,7 +29,7 @@ async def find_surf_spot(
27
  if config_file is None:
28
  config = Config.from_dict({})
29
  else:
30
- logger.info(f"Loading {config_file}")
31
  config = Config.from_yaml(config_file)
32
 
33
  if not config.main_agent.instructions:
@@ -36,8 +38,8 @@ async def find_surf_spot(
36
  elif config.framework == AgentFramework.OPENAI:
37
  config.main_agent.instructions = SINGLE_AGENT_SYSTEM_PROMPT
38
 
39
- logger.info(f"Loading {config.framework} agent")
40
- logger.info(f"{config.managed_agents}")
41
  agent = await AnyAgent.create_async(
42
  agent_framework=config.framework,
43
  agent_config=config.main_agent,
@@ -50,10 +52,10 @@ async def find_surf_spot(
50
  MAX_DRIVING_HOURS=config.max_driving_hours,
51
  DATE=config.date,
52
  )
53
- logger.info(f"Running agent with query:\n{query}")
54
  agent_trace = await agent.run_async(query)
55
 
56
- logger.info(f"Final output from agent:\n{agent_trace.final_output}")
57
 
58
  # dump the trace in the "output" directory
59
  output_dir = "output"
@@ -63,6 +65,38 @@ async def find_surf_spot(
63
  with open(file_path, "w") as f:
64
  f.write(agent_trace.model_dump_json(indent=2))
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def main():
68
  Fire(find_surf_spot)
 
3
  from pathlib import Path
4
 
5
  from any_agent import AgentFramework, AnyAgent, TracingConfig
6
+ from any_agent.evaluation.schemas import TraceEvaluationResult
7
  from fire import Fire
8
  from any_agent.logging import logger
9
+ from any_agent.evaluation import evaluate
10
 
11
  from surf_spot_finder.config import (
12
  Config,
 
29
  if config_file is None:
30
  config = Config.from_dict({})
31
  else:
32
+ logger.info("Loading %s", config_file)
33
  config = Config.from_yaml(config_file)
34
 
35
  if not config.main_agent.instructions:
 
38
  elif config.framework == AgentFramework.OPENAI:
39
  config.main_agent.instructions = SINGLE_AGENT_SYSTEM_PROMPT
40
 
41
+ logger.info("Loading %s agent", config.framework)
42
+ logger.info("Managed agents: %s", config.managed_agents)
43
  agent = await AnyAgent.create_async(
44
  agent_framework=config.framework,
45
  agent_config=config.main_agent,
 
52
  MAX_DRIVING_HOURS=config.max_driving_hours,
53
  DATE=config.date,
54
  )
55
+ logger.info("Running agent with query:\n%s", query)
56
  agent_trace = await agent.run_async(query)
57
 
58
+ logger.info("Final output from agent:\n%s", agent_trace.final_output)
59
 
60
  # dump the trace in the "output" directory
61
  output_dir = "output"
 
65
  with open(file_path, "w") as f:
66
  f.write(agent_trace.model_dump_json(indent=2))
67
 
68
+ if config.evaluation_cases is not None:
69
+ results = []
70
+ logger.info("Found evaluation cases, running trace evaluation")
71
+ for i, case in enumerate(config.evaluation_cases):
72
+ logger.info("Evaluating case: %s", case)
73
+ result: TraceEvaluationResult = evaluate(
74
+ evaluation_case=case,
75
+ trace=agent_trace,
76
+ agent_framework=config.framework,
77
+ )
78
+ for list_of_checkpoints in [
79
+ result.checkpoint_results,
80
+ result.direct_results,
81
+ result.hypothesis_answer_results,
82
+ ]:
83
+ for checkpoint in list_of_checkpoints:
84
+ msg = (
85
+ f"Checkpoint: {checkpoint.criteria}\n"
86
+ f"\tPassed: {checkpoint.passed}\n"
87
+ f"\tReason: {checkpoint.reason}\n"
88
+ f"\tScore: {'%d/%d' % (checkpoint.points, checkpoint.points) if checkpoint.passed else '0/%d' % checkpoint.points}"
89
+ )
90
+ logger.info(msg)
91
+ logger.info("==========================")
92
+ logger.info("Overall Score: %d%%", 100 * result.score)
93
+ logger.info("==========================")
94
+ results.append(result)
95
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
96
+ file_path = Path(output_dir) / f"{timestamp}_eval_case_{i}.json"
97
+ with open(file_path, "w") as f:
98
+ f.write(result.model_dump_json(indent=2))
99
+
100
 
101
  def main():
102
  Fire(find_surf_spot)
src/surf_spot_finder/config.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import AfterValidator, BaseModel, ConfigDict, FutureDatetime, Posi
8
  import yaml
9
  from rich.prompt import Prompt
10
  from any_agent.logging import logger
 
11
  import geocoder
12
  from litellm.litellm_core_utils.get_llm_provider_logic import (
13
  get_llm_provider,
@@ -36,7 +37,7 @@ def ask_framework() -> AgentFramework:
36
  [f"{i}: {framework}" for i, framework in enumerate(frameworks)]
37
  )
38
  prompt = f"Select the agent framework to use:\n{frameworks_str}\n"
39
- choice = Prompt.ask(prompt, default="3")
40
  try:
41
  choice = int(choice)
42
  if choice < 0 or choice >= len(frameworks):
@@ -148,6 +149,8 @@ class Config(BaseModel):
148
  main_agent: AgentConfig
149
  managed_agents: list[AgentConfig] | None = None
150
 
 
 
151
  @classmethod
152
  def from_dict(cls, data: dict) -> "Config":
153
  """
@@ -212,6 +215,7 @@ class Config(BaseModel):
212
  data["date"] = date_picker()
213
  else:
214
  logger.info(f"Using date {data['date']}")
 
215
  return cls(**data)
216
 
217
  @classmethod
 
8
  import yaml
9
  from rich.prompt import Prompt
10
  from any_agent.logging import logger
11
+ from any_agent.evaluation import EvaluationCase
12
  import geocoder
13
  from litellm.litellm_core_utils.get_llm_provider_logic import (
14
  get_llm_provider,
 
37
  [f"{i}: {framework}" for i, framework in enumerate(frameworks)]
38
  )
39
  prompt = f"Select the agent framework to use:\n{frameworks_str}\n"
40
+ choice = Prompt.ask(prompt, default="0")
41
  try:
42
  choice = int(choice)
43
  if choice < 0 or choice >= len(frameworks):
 
149
  main_agent: AgentConfig
150
  managed_agents: list[AgentConfig] | None = None
151
 
152
+ evaluation_cases: list[EvaluationCase] | None = None
153
+
154
  @classmethod
155
  def from_dict(cls, data: dict) -> "Config":
156
  """
 
215
  data["date"] = date_picker()
216
  else:
217
  logger.info(f"Using date {data['date']}")
218
+
219
  return cls(**data)
220
 
221
  @classmethod