Nathan Brake commited on
Commit
2c681d3
·
unverified ·
1 Parent(s): 5d6330b

Refactor agent loading and execution to utilize AnyAgent class (#45)

Browse files

* Refactor agent loading and execution to utilize AnyAgent class for improved framework compatibility

* cli fix

* Update CLI to set instructions based on agent framework.

* lint

examples/langchain_single_agent_user_confirmation.yaml CHANGED
@@ -7,12 +7,13 @@ input_prompt_template: |
7
  in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
8
  Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
9
  and then confirm the final selection with him.
 
 
10
 
11
  framework: langchain
12
 
13
  main_agent:
14
- model_id: gpt-4o
15
- # model_id: ollama/llama3.1:latest
16
  api_key_var: OPENAI_API_KEY
17
  tools:
18
  - "surf_spot_finder.tools.driving_hours_to_meters"
@@ -21,3 +22,17 @@ main_agent:
21
  - "surf_spot_finder.tools.get_wave_forecast"
22
  - "surf_spot_finder.tools.get_wind_forecast"
23
  - "any_agent.tools.send_console_message"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  in a {MAX_DRIVING_HOURS} hour driving radius, at {DATE}?
8
  Find a few options and then discuss it with David de la Iglesia Castro. You should recommend him some choices,
9
  and then confirm the final selection with him.
10
+ Once he gives the final selection, save a detailed description of the weather at the chosen location into a file
11
+ named "final_answer.txt". Also save a file called "history.txt" which has a list of your thought process in the choice.
12
 
13
  framework: langchain
14
 
15
  main_agent:
16
+ model_id: openai/gpt-4o
 
17
  api_key_var: OPENAI_API_KEY
18
  tools:
19
  - "surf_spot_finder.tools.driving_hours_to_meters"
 
22
  - "surf_spot_finder.tools.get_wave_forecast"
23
  - "surf_spot_finder.tools.get_wind_forecast"
24
  - "any_agent.tools.send_console_message"
25
+ - command: "docker"
26
+ args:
27
+ - "run"
28
+ - "-i"
29
+ - "--rm"
30
+ - "--mount"
31
+ - "type=bind,src=/tmp/surf-spot-finder,dst=/projects"
32
+ - "mcp/filesystem"
33
+ - "/projects"
34
+ tools:
35
+ - "read_file"
36
+ - "write_file"
37
+ - "directory_tree"
38
+ - "list_allowed_directories"
examples/smolagents_single_agent_user_confirmation.yaml CHANGED
@@ -13,7 +13,6 @@ framework: smolagents
13
 
14
  main_agent:
15
  model_id: openai/gpt-4o
16
- # model_id: ollama/llama3.1:latest
17
  api_key_var: OPENAI_API_KEY
18
  tools:
19
  - "surf_spot_finder.tools.driving_hours_to_meters"
 
13
 
14
  main_agent:
15
  model_id: openai/gpt-4o
 
16
  api_key_var: OPENAI_API_KEY
17
  tools:
18
  - "surf_spot_finder.tools.driving_hours_to_meters"
src/surf_spot_finder/cli.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import yaml
2
  from pathlib import Path
3
 
@@ -7,9 +8,11 @@ from loguru import logger
7
  from surf_spot_finder.config import (
8
  Config,
9
  )
10
- from any_agent import load_agent, run_agent
11
  from any_agent.tracing import get_tracer_provider, setup_tracing
12
 
 
 
 
13
 
14
  @logger.catch(reraise=True)
15
  def find_surf_spot(
@@ -25,6 +28,12 @@ def find_surf_spot(
25
  logger.info(f"Loading {config_file}")
26
  config = Config.model_validate(yaml.safe_load(Path(config_file).read_text()))
27
 
 
 
 
 
 
 
28
  logger.info("Setting up tracing")
29
  tracer_provider, tracing_path = get_tracer_provider(
30
  project_name="surf-spot-finder", agent_framework=config.framework
@@ -33,9 +42,9 @@ def find_surf_spot(
33
 
34
  logger.info(f"Loading {config.framework} agent")
35
  logger.info(f"{config.managed_agents}")
36
- agent = load_agent(
37
- framework=config.framework,
38
- main_agent=config.main_agent,
39
  managed_agents=config.managed_agents,
40
  )
41
 
@@ -45,7 +54,7 @@ def find_surf_spot(
45
  DATE=config.date,
46
  )
47
  logger.info(f"Running agent with query:\n{query}")
48
- run_agent(agent, query)
49
 
50
  logger.success("Done!")
51
 
 
1
+ from any_agent import AgentFramework, AnyAgent
2
  import yaml
3
  from pathlib import Path
4
 
 
8
  from surf_spot_finder.config import (
9
  Config,
10
  )
 
11
  from any_agent.tracing import get_tracer_provider, setup_tracing
12
 
13
+ from surf_spot_finder.instructions.openai import SINGLE_AGENT_SYSTEM_PROMPT
14
+ from surf_spot_finder.instructions.smolagents import SYSTEM_PROMPT
15
+
16
 
17
  @logger.catch(reraise=True)
18
  def find_surf_spot(
 
28
  logger.info(f"Loading {config_file}")
29
  config = Config.model_validate(yaml.safe_load(Path(config_file).read_text()))
30
 
31
+ if not config.main_agent.instructions:
32
+ if config.main_agent.agent_framework == AgentFramework.SMOLAGENTS:
33
+ config.main_agent.instructions = SYSTEM_PROMPT
34
+ elif config.main_agent.agent_framework == AgentFramework.OPENAI:
35
+ config.main_agent.instructions = SINGLE_AGENT_SYSTEM_PROMPT
36
+
37
  logger.info("Setting up tracing")
38
  tracer_provider, tracing_path = get_tracer_provider(
39
  project_name="surf-spot-finder", agent_framework=config.framework
 
42
 
43
  logger.info(f"Loading {config.framework} agent")
44
  logger.info(f"{config.managed_agents}")
45
+ agent = AnyAgent.create(
46
+ agent_framework=config.framework,
47
+ agent_config=config.main_agent,
48
  managed_agents=config.managed_agents,
49
  )
50
 
 
54
  DATE=config.date,
55
  )
56
  logger.info(f"Running agent with query:\n{query}")
57
+ agent.run(query)
58
 
59
  logger.success("Done!")
60
 
src/surf_spot_finder/evaluation/evaluate.py CHANGED
@@ -16,7 +16,7 @@ from surf_spot_finder.evaluation.evaluators import (
16
  HypothesisEvaluator,
17
  )
18
  from surf_spot_finder.evaluation.test_case import TestCase
19
- from any_agent import load_agent, run_agent
20
  from any_agent.tracing import get_tracer_provider, setup_tracing
21
 
22
  logger.remove()
@@ -40,9 +40,9 @@ def run(test_case: TestCase, agent_config_path: str) -> str:
40
 
41
  logger.info(f"Loading {config.framework} agent")
42
  logger.info(f"{config.managed_agents}")
43
- agent = load_agent(
44
- framework=config.framework,
45
- main_agent=config.main_agent,
46
  managed_agents=config.managed_agents,
47
  )
48
 
@@ -52,7 +52,7 @@ def run(test_case: TestCase, agent_config_path: str) -> str:
52
  DATE=config.date,
53
  )
54
  logger.info(f"Running agent with query:\n{query}")
55
- run_agent(agent, query)
56
 
57
  logger.success("Done!")
58
 
 
16
  HypothesisEvaluator,
17
  )
18
  from surf_spot_finder.evaluation.test_case import TestCase
19
+ from any_agent import AnyAgent
20
  from any_agent.tracing import get_tracer_provider, setup_tracing
21
 
22
  logger.remove()
 
40
 
41
  logger.info(f"Loading {config.framework} agent")
42
  logger.info(f"{config.managed_agents}")
43
+ agent = AnyAgent.create(
44
+ agent_framework=config.framework,
45
+ agent_config=config.main_agent,
46
  managed_agents=config.managed_agents,
47
  )
48
 
 
52
  DATE=config.date,
53
  )
54
  logger.info(f"Running agent with query:\n{query}")
55
+ agent.run(query)
56
 
57
  logger.success("Done!")
58