David de la Iglesia Castro commited on
Commit
f53eca9
·
unverified ·
1 Parent(s): 8aa7d2f

7 make tools configurable (#21)

Browse files

* Make tools configurable via arguments.

- Make `mcp` optional, not used by default.
- Update examples to use `tools`.

* Update pyproject.tooml optional deps

* Make openai and smolagents imports lazy

* Update api.md with sections

* fix tests

* Add missing arize

.github/workflows/tests.yaml CHANGED
@@ -28,7 +28,7 @@ jobs:
28
  cache: "pip"
29
 
30
  - name: Install
31
- run: pip install -e '.[openai,tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
 
28
  cache: "pip"
29
 
30
  - name: Install
31
+ run: pip install -e '.[arize,smolagents,openai,tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
docs/api.md CHANGED
@@ -4,17 +4,30 @@
4
 
5
  ::: surf_spot_finder.config.Config
6
 
 
 
7
  ::: surf_spot_finder.agents.RUNNERS
8
 
9
  ::: surf_spot_finder.agents.openai
10
 
11
  ::: surf_spot_finder.agents.smolagents
12
 
 
 
13
  ::: surf_spot_finder.tools.openmeteo
 
14
  ::: surf_spot_finder.tools.openstreetmap
15
 
 
 
 
 
 
 
16
  ::: surf_spot_finder.tracing
17
 
 
 
18
  ::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
19
 
20
  ::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
 
4
 
5
  ::: surf_spot_finder.config.Config
6
 
7
+ ## Agents
8
+
9
  ::: surf_spot_finder.agents.RUNNERS
10
 
11
  ::: surf_spot_finder.agents.openai
12
 
13
  ::: surf_spot_finder.agents.smolagents
14
 
15
+ ## Tools
16
+
17
  ::: surf_spot_finder.tools.openmeteo
18
+
19
  ::: surf_spot_finder.tools.openstreetmap
20
 
21
+ ::: surf_spot_finder.tools.user_interaction
22
+
23
+ ::: surf_spot_finder.tools.web_browsing
24
+
25
+ ## Tracing
26
+
27
  ::: surf_spot_finder.tracing
28
 
29
+ ## Prompts
30
+
31
  ::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
32
 
33
  ::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
examples/openai_single_agent.yaml CHANGED
@@ -3,4 +3,7 @@ date: 2025-03-22 12:00
3
  max_driving_hours: 2
4
  model_id: o3-mini
5
  agent_type: openai
 
 
 
6
  # input_prompt_template:
 
3
  max_driving_hours: 2
4
  model_id: o3-mini
5
  agent_type: openai
6
+ tools:
7
+ - "surf_spot_finder.tools.search_web"
8
+ - "surf_spot_finder.tools.visit_webpage"
9
  # input_prompt_template:
examples/openai_single_agent_vertical.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: o3-mini
5
+ agent_type: openai
6
+ tools:
7
+ - "surf_spot_finder.tools.driving_hours_to_meters"
8
+ - "surf_spot_finder.tools.get_area_lat_lon"
9
+ - "surf_spot_finder.tools.get_surfing_spots"
10
+ - "surf_spot_finder.tools.get_wave_forecast"
11
+ - "surf_spot_finder.tools.get_wind_forecast"
12
+ - "surf_spot_finder.tools.search_web"
13
+ - "surf_spot_finder.tools.show_plan"
14
+ - "surf_spot_finder.tools.visit_webpage"
15
+ # input_prompt_template:
examples/smolagents_single_agent.yaml CHANGED
@@ -1,7 +1,7 @@
1
  location: Pontevedra
2
  date: 2025-03-22 12:00
3
  max_driving_hours: 2
4
- model_id: openai/gpt-3.5-turbo
5
  api_key_var: OPENAI_API_KEY
6
  agent_type: smolagents
7
  # input_prompt_template:
 
1
  location: Pontevedra
2
  date: 2025-03-22 12:00
3
  max_driving_hours: 2
4
+ model_id: openai/o3-mini
5
  api_key_var: OPENAI_API_KEY
6
  agent_type: smolagents
7
  # input_prompt_template:
examples/smolagents_single_agent_mcp.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: openai/gpt-3.5-turbo
5
+ api_key_var: OPENAI_API_KEY
6
+ agent_type: smolagents
7
+ tools:
8
+ - "smolagents.DuckDuckGoSearchTool"
9
+ - "mcp/fetch"
10
+
11
+ # input_prompt_template:
mkdocs.yml CHANGED
@@ -50,3 +50,4 @@ plugins:
50
  python:
51
  options:
52
  show_root_heading: true
 
 
50
  python:
51
  options:
52
  show_root_heading: true
53
+ heading_level: 3
pyproject.toml CHANGED
@@ -9,20 +9,28 @@ license = {text = "Apache-2.0"}
9
  requires-python = ">=3.10"
10
  dynamic = ["version"]
11
  dependencies = [
12
- "arize-phoenix>=8.12.1",
13
  "fire",
14
  "loguru",
15
- "mcp==1.3.0",
 
16
  "pydantic",
17
- "smolagents[litellm,mcp,telemetry]>=1.10.0",
18
  ]
19
 
20
  [project.optional-dependencies]
 
 
 
 
 
21
  openai = [
22
  "openai-agents",
23
  "openinference-instrumentation-openai"
24
  ]
25
 
 
 
 
 
26
  demo = [
27
  "gradio",
28
  "spaces"
@@ -41,7 +49,7 @@ tests = [
41
  ]
42
 
43
  # TODO maybe we don't want to keep this, or we want to swap this to Lumigator SDK
44
- tracing = [
45
  "arize-phoenix>=8.12.1",
46
  ]
47
 
 
9
  requires-python = ">=3.10"
10
  dynamic = ["version"]
11
  dependencies = [
 
12
  "fire",
13
  "loguru",
14
+ "opentelemetry-exporter-otlp",
15
+ "opentelemetry-sdk",
16
  "pydantic",
 
17
  ]
18
 
19
  [project.optional-dependencies]
20
+ smolagents = [
21
+ "smolagents[litellm]>=1.10.0",
22
+ "openinference-instrumentation-smolagents>=0.1.4"
23
+ ]
24
+
25
  openai = [
26
  "openai-agents",
27
  "openinference-instrumentation-openai"
28
  ]
29
 
30
+ mcp = [
31
+ "mcp==1.3.0",
32
+ ]
33
+
34
  demo = [
35
  "gradio",
36
  "spaces"
 
49
  ]
50
 
51
  # TODO maybe we don't want to keep this, or we want to swap this to Lumigator SDK
52
+ arize = [
53
  "arize-phoenix>=8.12.1",
54
  ]
55
 
src/surf_spot_finder/agents/openai.py CHANGED
@@ -1,85 +1,27 @@
 
1
  import os
2
  from typing import Optional
3
 
4
- from agents import (
5
- Agent,
6
- AsyncOpenAI,
7
- OpenAIChatCompletionsModel,
8
- Runner,
9
- RunResult,
10
- function_tool,
11
- )
12
  from loguru import logger
13
- from smolagents import (
14
- DuckDuckGoSearchTool,
15
- VisitWebpageTool,
16
- FinalAnswerTool,
17
- )
18
-
19
 
20
  from surf_spot_finder.prompts.openai import (
21
  SINGLE_AGENT_SYSTEM_PROMPT,
22
  MULTI_AGENT_SYSTEM_PROMPT,
23
  )
24
- from surf_spot_finder.tools.openmeteo import get_wave_forecast, get_wind_forecast
25
- from surf_spot_finder.tools.openstreetmap import (
26
- driving_hours_to_meters,
27
- get_area_lat_lon,
28
- get_surfing_places,
29
- )
30
-
31
- driving_hours_to_meters = function_tool(driving_hours_to_meters)
32
- get_area_lat_lon = function_tool(get_area_lat_lon)
33
- get_surfing_places = function_tool(get_surfing_places)
34
- get_wave_forecast = function_tool(get_wave_forecast)
35
- get_wind_forecast = function_tool(get_wind_forecast)
36
-
37
-
38
- @function_tool
39
- def search_web(query: str) -> str:
40
- """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
41
-
42
- Args:
43
- query: The search query to perform.
44
- """
45
- logger.debug(f"Calling search_web: {query}")
46
- search_tool = DuckDuckGoSearchTool()
47
- return search_tool.forward(query)
48
-
49
-
50
- @function_tool
51
- def visit_webpage(url: str) -> str:
52
- """Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages.
53
-
54
- Args:
55
- url: The url of the webpage to visit.
56
- """
57
- logger.debug(f"Calling visit_webpage: {url}")
58
- visit_tool = VisitWebpageTool()
59
- return visit_tool.forward(url)
60
-
61
-
62
- @function_tool
63
- def final_answer(answer: str) -> str:
64
- """Provides a final answer to the given problem.
65
-
66
- Args:
67
- answer: The answer to the problem.
68
- """
69
- logger.debug("Calling final_answer")
70
- final_answer_tool = FinalAnswerTool()
71
- return final_answer_tool.forward(answer)
72
-
73
 
74
- @function_tool
75
- def user_verification(query: str) -> str:
76
- """Asks user to verify the given `query`.
 
 
 
 
 
 
77
 
78
- Args:
79
- query: The question that requires verification.
80
- """
81
- logger.debug("Calling user_verification")
82
- return input(f"{query} => Type your answer here:")
83
 
84
 
85
  @logger.catch(reraise=True)
@@ -89,7 +31,8 @@ def run_openai_agent(
89
  name: str = "surf-spot-finder",
90
  instructions: Optional[str] = SINGLE_AGENT_SYSTEM_PROMPT,
91
  api_key_var: Optional[str] = None,
92
- base_url: Optional[str] = None,
 
93
  ) -> RunResult:
94
  """Runs an OpenAI agent with the given prompt and configuration.
95
 
@@ -109,19 +52,36 @@ def run_openai_agent(
109
  api_key_var (Optional[str], optional): The name of the environment variable
110
  containing the OpenAI API key. If provided, along with `base_url`, an
111
  external OpenAI client will be used. Defaults to None.
112
- base_url (Optional[str], optional): The base URL for the OpenAI API.
113
  Required if `api_key_var` is provided to use an external OpenAI client.
114
  Defaults to None.
115
 
 
116
  Returns:
117
  RunResult: A RunResult object containing the output of the agent run.
118
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
119
  """
120
-
121
- if api_key_var and base_url:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  external_client = AsyncOpenAI(
123
  api_key=os.environ[api_key_var],
124
- base_url=base_url,
125
  )
126
  agent = Agent(
127
  name=name,
@@ -130,22 +90,14 @@ def run_openai_agent(
130
  model=model_id,
131
  openai_client=external_client,
132
  ),
133
- tools=[search_web, visit_webpage],
134
  )
135
  else:
136
  agent = Agent(
137
  model=model_id,
138
  instructions=instructions,
139
  name=name,
140
- tools=[
141
- search_web,
142
- visit_webpage,
143
- get_area_lat_lon,
144
- get_surfing_places,
145
- get_wave_forecast,
146
- get_wind_forecast,
147
- driving_hours_to_meters,
148
- ],
149
  )
150
  result = Runner.run_sync(agent, prompt)
151
  logger.info(result.final_output)
@@ -158,6 +110,7 @@ def run_openai_multi_agent(
158
  prompt: str,
159
  name: str = "surf-spot-finder",
160
  instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
 
161
  ) -> RunResult:
162
  """Runs multiple OpenAI agents orchestrated by a main agent.
163
 
@@ -179,25 +132,36 @@ def run_openai_multi_agent(
179
  RunResult: A RunResult object containing the output of the agent run.
180
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
181
  """
 
 
 
 
 
 
 
 
 
 
 
182
  user_verification_agent = Agent(
183
  model=model_id,
184
- instructions="Display the current output to the user, then ask for verification.",
185
  name="user-verification-agent",
186
- tools=[user_verification],
187
  )
188
 
189
  search_web_agent = Agent(
190
  model=model_id,
191
- instructions="Find relevant information about the provided task by combining web searches with visiting webpages.",
192
  name="search-web-agent",
193
- tools=[search_web, visit_webpage],
194
  )
195
 
196
  communication_agent = Agent(
197
  model=model_id,
198
- instructions=None,
199
  name="communication-agent",
200
- tools=[final_answer],
201
  )
202
 
203
  main_agent = Agent(
 
1
+ import importlib
2
  import os
3
  from typing import Optional
4
 
 
 
 
 
 
 
 
 
5
  from loguru import logger
 
 
 
 
 
 
6
 
7
  from surf_spot_finder.prompts.openai import (
8
  SINGLE_AGENT_SYSTEM_PROMPT,
9
  MULTI_AGENT_SYSTEM_PROMPT,
10
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ try:
13
+ from agents import (
14
+ Agent,
15
+ AsyncOpenAI,
16
+ OpenAIChatCompletionsModel,
17
+ Runner,
18
+ RunResult,
19
+ function_tool,
20
+ )
21
 
22
+ agents_available = True
23
+ except ImportError:
24
+ agents_available = None
 
 
25
 
26
 
27
  @logger.catch(reraise=True)
 
31
  name: str = "surf-spot-finder",
32
  instructions: Optional[str] = SINGLE_AGENT_SYSTEM_PROMPT,
33
  api_key_var: Optional[str] = None,
34
+ api_base: Optional[str] = None,
35
+ tools: Optional[list[str]] = None,
36
  ) -> RunResult:
37
  """Runs an OpenAI agent with the given prompt and configuration.
38
 
 
52
  api_key_var (Optional[str], optional): The name of the environment variable
53
  containing the OpenAI API key. If provided, along with `base_url`, an
54
  external OpenAI client will be used. Defaults to None.
55
+ api_base (Optional[str], optional): The base URL for the OpenAI API.
56
  Required if `api_key_var` is provided to use an external OpenAI client.
57
  Defaults to None.
58
 
59
+
60
  Returns:
61
  RunResult: A RunResult object containing the output of the agent run.
62
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
63
  """
64
+ if not agents_available:
65
+ raise ImportError("You need to `pip install openai-agents` to use this agent")
66
+
67
+ if tools is None:
68
+ tools = [
69
+ "surf_spot_finder.tools.search_web",
70
+ "surf_spot_finder.tools.visit_webpage",
71
+ ]
72
+
73
+ imported_tools = []
74
+ for tool in tools:
75
+ module, func = tool.rsplit(".", 1)
76
+ module = importlib.import_module(module)
77
+ tool = getattr(module, func)
78
+ imported_tools.append(function_tool(tool))
79
+
80
+ logger.info(f"Imported tools: {imported_tools}")
81
+ if api_key_var and api_base:
82
  external_client = AsyncOpenAI(
83
  api_key=os.environ[api_key_var],
84
+ base_url=api_base,
85
  )
86
  agent = Agent(
87
  name=name,
 
90
  model=model_id,
91
  openai_client=external_client,
92
  ),
93
+ tools=imported_tools,
94
  )
95
  else:
96
  agent = Agent(
97
  model=model_id,
98
  instructions=instructions,
99
  name=name,
100
+ tools=imported_tools,
 
 
 
 
 
 
 
 
101
  )
102
  result = Runner.run_sync(agent, prompt)
103
  logger.info(result.final_output)
 
110
  prompt: str,
111
  name: str = "surf-spot-finder",
112
  instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
113
+ **kwargs,
114
  ) -> RunResult:
115
  """Runs multiple OpenAI agents orchestrated by a main agent.
116
 
 
132
  RunResult: A RunResult object containing the output of the agent run.
133
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
134
  """
135
+ if not agents_available:
136
+ raise ImportError("You need to `pip install openai-agents` to use this agent")
137
+
138
+ from surf_spot_finder.tools import (
139
+ ask_user_verification,
140
+ show_final_answer,
141
+ show_plan,
142
+ search_web,
143
+ visit_webpage,
144
+ )
145
+
146
  user_verification_agent = Agent(
147
  model=model_id,
148
+ instructions="Interact with the user by showing information and asking for verification.",
149
  name="user-verification-agent",
150
+ tools=[function_tool(ask_user_verification), function_tool(show_plan)],
151
  )
152
 
153
  search_web_agent = Agent(
154
  model=model_id,
155
+ instructions="Find relevant information about the provided task by using your tools.",
156
  name="search-web-agent",
157
+ tools=[function_tool(search_web), function_tool(visit_webpage)],
158
  )
159
 
160
  communication_agent = Agent(
161
  model=model_id,
162
+ instructions="Communicate the final answer to the user.",
163
  name="communication-agent",
164
+ tools=[function_tool(show_final_answer)],
165
  )
166
 
167
  main_agent = Agent(
src/surf_spot_finder/agents/smolagents.py CHANGED
@@ -1,17 +1,22 @@
 
1
  import os
2
  from typing import Optional
3
 
4
  from loguru import logger
5
 
6
- from smolagents import (
7
- CodeAgent,
8
- DuckDuckGoSearchTool,
9
- LiteLLMModel,
10
- ToolCollection,
11
- )
12
- from mcp import StdioServerParameters
13
  from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
14
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @logger.catch(reraise=True)
17
  def run_smolagent(
@@ -19,6 +24,7 @@ def run_smolagent(
19
  prompt: str,
20
  api_key_var: Optional[str] = None,
21
  api_base: Optional[str] = None,
 
22
  ) -> CodeAgent:
23
  """
24
  Create and configure a Smolagents CodeAgent with the specified model.
@@ -39,6 +45,26 @@ def run_smolagent(
39
  >>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
40
  >>> agent.run("Find surf spots near San Diego")
41
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  model = LiteLLMModel(
44
  model_id=model_id,
@@ -46,24 +72,31 @@ def run_smolagent(
46
  api_key=os.environ[api_key_var] if api_key_var else None,
47
  )
48
 
49
- # We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
50
- # or at https://glama.ai/mcp/servers
51
- # or at https://smithery.ai/
52
- server_parameters = StdioServerParameters(
53
- command="docker",
54
- args=["run", "-i", "--rm", "mcp/fetch"],
55
- env={**os.environ},
56
- )
57
- # https://huggingface.co/docs/smolagents/v1.10.0/en/reference/tools#smolagents.ToolCollection.from_mcp
58
- with ToolCollection.from_mcp(server_parameters) as tool_collection:
 
 
 
 
 
 
 
 
 
 
 
59
  agent = CodeAgent(
60
- tools=[
61
- *tool_collection.tools,
62
- DuckDuckGoSearchTool(),
63
- ],
64
  prompt_templates={"system_prompt": SYSTEM_PROMPT},
65
  model=model,
66
- add_base_tools=False, # Turn this on if you want to let it run python code as it sees fit
67
  )
68
  agent.run(prompt)
69
 
 
1
+ import importlib
2
  import os
3
  from typing import Optional
4
 
5
  from loguru import logger
6
 
 
 
 
 
 
 
 
7
  from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
8
 
9
+ try:
10
+ from smolagents import (
11
+ CodeAgent,
12
+ LiteLLMModel,
13
+ ToolCollection,
14
+ )
15
+
16
+ smolagents_available = True
17
+ except ImportError:
18
+ smolagents_available = None
19
+
20
 
21
  @logger.catch(reraise=True)
22
  def run_smolagent(
 
24
  prompt: str,
25
  api_key_var: Optional[str] = None,
26
  api_base: Optional[str] = None,
27
+ tools: Optional[list[str]] = None,
28
  ) -> CodeAgent:
29
  """
30
  Create and configure a Smolagents CodeAgent with the specified model.
 
45
  >>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
46
  >>> agent.run("Find surf spots near San Diego")
47
  """
48
+ if not smolagents_available:
49
+ raise ImportError("You need to `pip install smolagents` to use this agent")
50
+
51
+ if tools is None:
52
+ tools = [
53
+ "smolagents.DuckDuckGoSearchTool",
54
+ "smolagents.VisitWebpageTool",
55
+ "smolagents.PythonInterpreterTool",
56
+ ]
57
+
58
+ imported_tools = []
59
+ mcp_tool = None
60
+ for tool in tools:
61
+ if "mcp" in tool:
62
+ mcp_tool = tool
63
+ else:
64
+ module, func = tool.rsplit(".", 1)
65
+ module = importlib.import_module(module)
66
+ tool = getattr(module, func)
67
+ imported_tools.append(tool())
68
 
69
  model = LiteLLMModel(
70
  model_id=model_id,
 
72
  api_key=os.environ[api_key_var] if api_key_var else None,
73
  )
74
 
75
+ if mcp_tool:
76
+ from mcp import StdioServerParameters
77
+
78
+ # We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
79
+ # or at https://glama.ai/mcp/servers
80
+ # or at https://smithery.ai/
81
+ server_parameters = StdioServerParameters(
82
+ command="docker",
83
+ args=["run", "-i", "--rm", mcp_tool],
84
+ env={**os.environ},
85
+ )
86
+ # https://huggingface.co/docs/smolagents/v1.10.0/en/reference/tools#smolagents.ToolCollection.from_mcp
87
+ with ToolCollection.from_mcp(server_parameters) as tool_collection:
88
+ agent = CodeAgent(
89
+ tools=imported_tools + tool_collection.tools,
90
+ prompt_templates={"system_prompt": SYSTEM_PROMPT},
91
+ model=model,
92
+ add_base_tools=False, # Turn this on if you want to let it run python code as it sees fit
93
+ )
94
+ agent.run(prompt)
95
+ else:
96
  agent = CodeAgent(
97
+ tools=imported_tools,
 
 
 
98
  prompt_templates={"system_prompt": SYSTEM_PROMPT},
99
  model=model,
 
100
  )
101
  agent.run(prompt)
102
 
src/surf_spot_finder/cli.py CHANGED
@@ -24,6 +24,7 @@ def find_surf_spot(
24
  input_prompt_template: str = INPUT_PROMPT,
25
  json_tracer: bool = True,
26
  api_base: Optional[str] = None,
 
27
  from_config: Optional[str] = None,
28
  ):
29
  """Find the best surf spot based on the given criteria.
@@ -82,6 +83,9 @@ def find_surf_spot(
82
  MAX_DRIVING_HOURS=config.max_driving_hours,
83
  DATE=config.date,
84
  ),
 
 
 
85
  )
86
 
87
 
 
24
  input_prompt_template: str = INPUT_PROMPT,
25
  json_tracer: bool = True,
26
  api_base: Optional[str] = None,
27
+ tools: Optional[list[dict]] = None,
28
  from_config: Optional[str] = None,
29
  ):
30
  """Find the best surf spot based on the given criteria.
 
83
  MAX_DRIVING_HOURS=config.max_driving_hours,
84
  DATE=config.date,
85
  ),
86
+ api_base=config.api_base,
87
+ api_key_var=config.api_key_var,
88
+ tools=config.tools,
89
  )
90
 
91
 
src/surf_spot_finder/config.py CHANGED
@@ -30,3 +30,4 @@ class Config(BaseModel):
30
  api_key_var: Optional[str] = None
31
  json_tracer: bool = True
32
  api_base: Optional[str] = None
 
 
30
  api_key_var: Optional[str] = None
31
  json_tracer: bool = True
32
  api_base: Optional[str] = None
33
+ tools: Optional[list[str]] = None
src/surf_spot_finder/tools/__init__.py CHANGED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .openmeteo import get_wave_forecast, get_wind_forecast
2
+ from .openstreetmap import driving_hours_to_meters, get_area_lat_lon, get_surfing_spots
3
+ from .user_interaction import show_final_answer, show_plan, ask_user_verification
4
+ from .web_browsing import search_web, visit_webpage
5
+
6
+ __all__ = [
7
+ driving_hours_to_meters,
8
+ get_area_lat_lon,
9
+ get_surfing_spots,
10
+ get_wave_forecast,
11
+ get_wind_forecast,
12
+ search_web,
13
+ show_final_answer,
14
+ show_plan,
15
+ ask_user_verification,
16
+ visit_webpage,
17
+ ]
src/surf_spot_finder/tools/openstreetmap.py CHANGED
@@ -59,10 +59,10 @@ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
59
  )
60
 
61
 
62
- def get_surfing_places(
63
  lat: float, lon: float, radius: int
64
  ) -> list[tuple[str, tuple[float, float]]]:
65
- """Get surfing places around a given latitude and longitude.
66
 
67
  Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
68
 
 
59
  )
60
 
61
 
62
+ def get_surfing_spots(
63
  lat: float, lon: float, radius: int
64
  ) -> list[tuple[str, tuple[float, float]]]:
65
+ """Get surfing spots around a given latitude and longitude.
66
 
67
  Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
68
 
src/surf_spot_finder/tools/user_interaction.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from loguru import logger
2
+
3
+
4
+ def show_plan(plan: str) -> None:
5
+ """Show the current plan to the user.
6
+
7
+ Args:
8
+ plan: The current plan.
9
+ """
10
+ logger.info(f"Current plan: {plan}")
11
+ return plan
12
+
13
+
14
+ def show_final_answer(answer: str) -> None:
15
+ """Show the final answer to the user.
16
+
17
+ Args:
18
+ answer: The final answer.
19
+ """
20
+ logger.info(f"Final answer: {answer}")
21
+ return answer
22
+
23
+
24
+ def ask_user_verification(query: str) -> str:
25
+ """Asks user to verify the given `query`.
26
+
27
+ Args:
28
+ query: The question that requires verification.
29
+ """
30
+ return input(f"{query} => Type your answer here:")
src/surf_spot_finder/tools/web_browsing.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import requests
4
+ from duckduckgo_search import DDGS
5
+ from markdownify import markdownify
6
+ from requests.exceptions import RequestException
7
+
8
+
9
+ def _truncate_content(content: str, max_length: int) -> str:
10
+ if len(content) <= max_length:
11
+ return content
12
+ else:
13
+ return (
14
+ content[: max_length // 2]
15
+ + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
16
+ + content[-max_length // 2 :]
17
+ )
18
+
19
+
20
+ def search_web(query: str) -> str:
21
+ """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
22
+
23
+ Args:
24
+ query: The search query to perform.
25
+
26
+ Returns:
27
+ The top search results.
28
+ """
29
+ ddgs = DDGS()
30
+ results = ddgs.text(query, max_results=10)
31
+ return "\n".join(
32
+ f"[{result['title']}]({result['href']})\n{result['body']}" for result in results
33
+ )
34
+
35
+
36
+ def visit_webpage(url: str) -> str:
37
+ """Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages.
38
+
39
+ Args:
40
+ url: The url of the webpage to visit.
41
+ """
42
+ try:
43
+ response = requests.get(url)
44
+ response.raise_for_status()
45
+
46
+ markdown_content = markdownify(response.text).strip()
47
+
48
+ markdown_content = re.sub(r"\n{2,}", "\n", markdown_content)
49
+
50
+ return _truncate_content(markdown_content, 10000)
51
+ except RequestException as e:
52
+ return f"Error fetching the webpage: {str(e)}"
53
+ except Exception as e:
54
+ return f"An unexpected error occurred: {str(e)}"
src/surf_spot_finder/tracing.py CHANGED
@@ -6,7 +6,6 @@ from opentelemetry import trace
6
  from opentelemetry.sdk.trace import TracerProvider
7
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
8
  from opentelemetry.sdk.trace.export import SpanExporter
9
- from phoenix.otel import register
10
 
11
 
12
  class JsonFileSpanExporter(SpanExporter):
@@ -74,6 +73,8 @@ def get_tracer_provider(
74
  span_processor = SimpleSpanProcessor(json_file_exporter)
75
  tracer_provider.add_span_processor(span_processor)
76
  else:
 
 
77
  tracer_provider = register(
78
  project_name=project_name, set_global_tracer_provider=True
79
  )
 
6
  from opentelemetry.sdk.trace import TracerProvider
7
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
8
  from opentelemetry.sdk.trace.export import SpanExporter
 
9
 
10
 
11
  class JsonFileSpanExporter(SpanExporter):
 
73
  span_processor = SimpleSpanProcessor(json_file_exporter)
74
  tracer_provider.add_span_processor(span_processor)
75
  else:
76
+ from phoenix.otel import register
77
+
78
  tracer_provider = register(
79
  project_name=project_name, set_global_tracer_provider=True
80
  )
tests/unit/agents/test_unit_openai.py CHANGED
@@ -3,11 +3,14 @@ import pytest
3
  from unittest.mock import patch, MagicMock, ANY
4
 
5
  from surf_spot_finder.agents.openai import (
6
- final_answer,
7
  run_openai_agent,
8
  run_openai_multi_agent,
 
 
 
 
 
9
  search_web,
10
- user_verification,
11
  visit_webpage,
12
  )
13
  from surf_spot_finder.prompts.openai import (
@@ -46,7 +49,7 @@ def test_run_openai_agent_base_url_and_api_key_var():
46
  patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}),
47
  ):
48
  run_openai_agent(
49
- "gpt-4o", "Test prompt", base_url="FOO", api_key_var="TEST_API_KEY"
50
  )
51
  async_openai_mock.assert_called_once_with(
52
  api_key="test-key-12345",
@@ -59,37 +62,42 @@ def test_run_openai_environment_error():
59
  with patch.dict(os.environ, {}, clear=True):
60
  with pytest.raises(KeyError, match="MISSING_KEY"):
61
  run_openai_agent(
62
- "test-model", "Test prompt", base_url="FOO", api_key_var="MISSING_KEY"
63
  )
64
 
65
 
66
  def test_run_openai_multiagent():
67
  mock_agent = MagicMock()
 
68
 
69
  with (
70
  patch("surf_spot_finder.agents.openai.Agent", mock_agent),
71
  patch("surf_spot_finder.agents.openai.Runner", MagicMock()),
 
72
  ):
73
  run_openai_multi_agent("gpt-4o", "Test prompt")
74
  mock_agent.assert_any_call(
75
  model="gpt-4o",
76
- instructions="Display the current output to the user, then ask for verification.",
77
  name="user-verification-agent",
78
- tools=[user_verification],
 
 
 
79
  )
80
 
81
  mock_agent.assert_any_call(
82
  model="gpt-4o",
83
- instructions="Find relevant information about the provided task by combining web searches with visiting webpages.",
84
  name="search-web-agent",
85
- tools=[search_web, visit_webpage],
86
  )
87
 
88
  mock_agent.assert_any_call(
89
  model="gpt-4o",
90
- instructions=None,
91
  name="communication-agent",
92
- tools=[final_answer],
93
  )
94
 
95
  mock_agent.assert_any_call(
 
3
  from unittest.mock import patch, MagicMock, ANY
4
 
5
  from surf_spot_finder.agents.openai import (
 
6
  run_openai_agent,
7
  run_openai_multi_agent,
8
+ )
9
+ from surf_spot_finder.tools import (
10
+ show_final_answer,
11
+ show_plan,
12
+ ask_user_verification,
13
  search_web,
 
14
  visit_webpage,
15
  )
16
  from surf_spot_finder.prompts.openai import (
 
49
  patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}),
50
  ):
51
  run_openai_agent(
52
+ "gpt-4o", "Test prompt", api_base="FOO", api_key_var="TEST_API_KEY"
53
  )
54
  async_openai_mock.assert_called_once_with(
55
  api_key="test-key-12345",
 
62
  with patch.dict(os.environ, {}, clear=True):
63
  with pytest.raises(KeyError, match="MISSING_KEY"):
64
  run_openai_agent(
65
+ "test-model", "Test prompt", api_base="FOO", api_key_var="MISSING_KEY"
66
  )
67
 
68
 
69
  def test_run_openai_multiagent():
70
  mock_agent = MagicMock()
71
+ mock_function_tool = MagicMock()
72
 
73
  with (
74
  patch("surf_spot_finder.agents.openai.Agent", mock_agent),
75
  patch("surf_spot_finder.agents.openai.Runner", MagicMock()),
76
+ patch("surf_spot_finder.agents.openai.function_tool", mock_function_tool),
77
  ):
78
  run_openai_multi_agent("gpt-4o", "Test prompt")
79
  mock_agent.assert_any_call(
80
  model="gpt-4o",
81
+ instructions="Interact with the user by showing information and asking for verification.",
82
  name="user-verification-agent",
83
+ tools=[
84
+ mock_function_tool(show_plan),
85
+ mock_function_tool(ask_user_verification),
86
+ ],
87
  )
88
 
89
  mock_agent.assert_any_call(
90
  model="gpt-4o",
91
+ instructions="Find relevant information about the provided task by using your tools.",
92
  name="search-web-agent",
93
+ tools=[mock_function_tool(search_web), mock_function_tool(visit_webpage)],
94
  )
95
 
96
  mock_agent.assert_any_call(
97
  model="gpt-4o",
98
+ instructions="Communicate the final answer to the user.",
99
  name="communication-agent",
100
+ tools=[mock_function_tool(show_final_answer)],
101
  )
102
 
103
  mock_agent.assert_any_call(
tests/unit/agents/test_unit_smolagents.py CHANGED
@@ -11,31 +11,18 @@ def common_patches():
11
  litellm_model_mock = MagicMock()
12
  code_agent_mock = MagicMock()
13
  patch_context = contextlib.ExitStack()
14
- mock_tool_collection = MagicMock()
15
-
16
- mock_tool_collection.from_mcp.return_value.__enter__.return_value = (
17
- mock_tool_collection
18
- )
19
- mock_tool_collection.from_mcp.return_value.__exit__.return_value = None
20
- mock_tool_collection.tools = ["mock_tool"]
21
- patch_context.enter_context(
22
- patch("surf_spot_finder.agents.smolagents.StdioServerParameters", MagicMock())
23
- )
24
  patch_context.enter_context(
25
  patch("surf_spot_finder.agents.smolagents.CodeAgent", code_agent_mock)
26
  )
27
  patch_context.enter_context(
28
  patch("surf_spot_finder.agents.smolagents.LiteLLMModel", litellm_model_mock)
29
  )
30
- patch_context.enter_context(
31
- patch("surf_spot_finder.agents.smolagents.ToolCollection", mock_tool_collection)
32
- )
33
- yield patch_context, litellm_model_mock, code_agent_mock, mock_tool_collection
34
  patch_context.close()
35
 
36
 
37
  def test_run_smolagent_with_api_key_var(common_patches):
38
- patch_context, litellm_model_mock, code_agent_mock, *_ = common_patches
39
 
40
  with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
41
  run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
 
11
  litellm_model_mock = MagicMock()
12
  code_agent_mock = MagicMock()
13
  patch_context = contextlib.ExitStack()
 
 
 
 
 
 
 
 
 
 
14
  patch_context.enter_context(
15
  patch("surf_spot_finder.agents.smolagents.CodeAgent", code_agent_mock)
16
  )
17
  patch_context.enter_context(
18
  patch("surf_spot_finder.agents.smolagents.LiteLLMModel", litellm_model_mock)
19
  )
20
+ yield patch_context, litellm_model_mock, code_agent_mock
 
 
 
21
  patch_context.close()
22
 
23
 
24
  def test_run_smolagent_with_api_key_var(common_patches):
25
+ patch_context, litellm_model_mock, code_agent_mock = common_patches
26
 
27
  with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
28
  run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
tests/unit/test_unit_tracing.py CHANGED
@@ -14,7 +14,7 @@ def test_get_tracer_provider(tmp_path, json_tracer):
14
  with (
15
  patch("surf_spot_finder.tracing.trace", mock_trace),
16
  patch("surf_spot_finder.tracing.TracerProvider", mock_tracer_provider),
17
- patch("surf_spot_finder.tracing.register", mock_register),
18
  ):
19
  get_tracer_provider(
20
  project_name="test_project",
 
14
  with (
15
  patch("surf_spot_finder.tracing.trace", mock_trace),
16
  patch("surf_spot_finder.tracing.TracerProvider", mock_tracer_provider),
17
+ patch("phoenix.otel.register", mock_register),
18
  ):
19
  get_tracer_provider(
20
  project_name="test_project",
tests/unit/tools/test_unit_openstreetmap.py CHANGED
@@ -29,7 +29,7 @@ def test_get_lat_lon_center():
29
  assert lon == -2.5
30
 
31
 
32
- def test_get_surfing_places():
33
  with patch("requests.get") as mock_get:
34
  mock_response = MagicMock()
35
  mock_response.status_code = 200
@@ -66,7 +66,7 @@ def test_get_surfing_places():
66
  }
67
  mock_get.return_value = mock_response
68
 
69
- results = openstreetmap.get_surfing_places(lat=40.5, lon=-3.5, radius=10000)
70
  assert len(results) == 2
71
  assert results[0][0] == "Surf Spot 1"
72
  assert results[0][1] == (40.05, -2.95)
 
29
  assert lon == -2.5
30
 
31
 
32
+ def test_get_surfing_spots():
33
  with patch("requests.get") as mock_get:
34
  mock_response = MagicMock()
35
  mock_response.status_code = 200
 
66
  }
67
  mock_get.return_value = mock_response
68
 
69
+ results = openstreetmap.get_surfing_spots(lat=40.5, lon=-3.5, radius=10000)
70
  assert len(results) == 2
71
  assert results[0][0] == "Surf Spot 1"
72
  assert results[0][1] == (40.05, -2.95)