David de la Iglesia Castro commited on
Commit
ebeb860
·
unverified ·
1 Parent(s): bf33614

enh(tools): Add `wrappers` module. (#23)

Browse files

Allows to import and reuse same tools for different frameworks.

Add `smolagents_single_agent_vertical` and `langchain_single_agent_vertical`

examples/langchain_single_agent_vertical.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: o3-mini
5
+ agent_type: langchain
6
+ tools:
7
+ - "surf_spot_finder.tools.driving_hours_to_meters"
8
+ - "surf_spot_finder.tools.get_area_lat_lon"
9
+ - "surf_spot_finder.tools.get_surfing_spots"
10
+ - "surf_spot_finder.tools.get_wave_forecast"
11
+ - "surf_spot_finder.tools.get_wind_forecast"
12
+ - "surf_spot_finder.tools.search_web"
13
+ - "surf_spot_finder.tools.visit_webpage"
14
+ # input_prompt_template:
examples/smolagents_single_agent_vertical.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: openai/o3-mini
5
+ api_key_var: OPENAI_API_KEY
6
+ agent_type: smolagents
7
+ tools:
8
+ - "surf_spot_finder.tools.driving_hours_to_meters"
9
+ - "surf_spot_finder.tools.get_area_lat_lon"
10
+ - "surf_spot_finder.tools.get_surfing_spots"
11
+ - "surf_spot_finder.tools.get_wave_forecast"
12
+ - "surf_spot_finder.tools.get_wind_forecast"
13
+ - "surf_spot_finder.tools.search_web"
14
+ - "surf_spot_finder.tools.visit_webpage"
15
+ - "smolagents.PythonInterpreterTool"
16
+ - "smolagents.FinalAnswerTool"
17
+ # input_prompt_template:
src/surf_spot_finder/agents/openai.py CHANGED
@@ -1,4 +1,3 @@
1
- import importlib
2
  import os
3
  from typing import Optional
4
 
@@ -8,6 +7,8 @@ from surf_spot_finder.prompts.openai import (
8
  SINGLE_AGENT_SYSTEM_PROMPT,
9
  MULTI_AGENT_SYSTEM_PROMPT,
10
  )
 
 
11
 
12
  try:
13
  from agents import (
@@ -70,12 +71,7 @@ def run_openai_agent(
70
  "surf_spot_finder.tools.visit_webpage",
71
  ]
72
 
73
- imported_tools = []
74
- for tool in tools:
75
- module, func = tool.rsplit(".", 1)
76
- module = importlib.import_module(module)
77
- tool = getattr(module, func)
78
- imported_tools.append(function_tool(tool))
79
 
80
  if api_key_var and api_base:
81
  external_client = AsyncOpenAI(
 
 
1
  import os
2
  from typing import Optional
3
 
 
7
  SINGLE_AGENT_SYSTEM_PROMPT,
8
  MULTI_AGENT_SYSTEM_PROMPT,
9
  )
10
+ from surf_spot_finder.tools.wrappers import import_and_wrap_tools, wrap_tool_openai
11
+
12
 
13
  try:
14
  from agents import (
 
71
  "surf_spot_finder.tools.visit_webpage",
72
  ]
73
 
74
+ imported_tools = import_and_wrap_tools(tools, wrap_tool_openai)
 
 
 
 
 
75
 
76
  if api_key_var and api_base:
77
  external_client = AsyncOpenAI(
src/surf_spot_finder/agents/smolagents.py CHANGED
@@ -1,16 +1,16 @@
1
- import importlib
2
  import os
3
  from typing import Optional
4
 
5
  from loguru import logger
6
 
7
  from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
 
 
8
 
9
  try:
10
  from smolagents import (
11
  CodeAgent,
12
  LiteLLMModel,
13
- ToolCollection,
14
  )
15
 
16
  smolagents_available = True
@@ -55,16 +55,12 @@ def run_smolagent(
55
  "smolagents.PythonInterpreterTool",
56
  ]
57
 
58
- imported_tools = []
59
  mcp_tool = None
60
  for tool in tools:
61
  if "mcp" in tool:
62
  mcp_tool = tool
63
- else:
64
- module, func = tool.rsplit(".", 1)
65
- module = importlib.import_module(module)
66
- tool = getattr(module, func)
67
- imported_tools.append(tool())
68
 
69
  model = LiteLLMModel(
70
  model_id=model_id,
@@ -74,6 +70,7 @@ def run_smolagent(
74
 
75
  if mcp_tool:
76
  from mcp import StdioServerParameters
 
77
 
78
  # We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
79
  # or at https://glama.ai/mcp/servers
 
 
1
  import os
2
  from typing import Optional
3
 
4
  from loguru import logger
5
 
6
  from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
7
+ from surf_spot_finder.tools.wrappers import import_and_wrap_tools, wrap_tool_smolagents
8
+
9
 
10
  try:
11
  from smolagents import (
12
  CodeAgent,
13
  LiteLLMModel,
 
14
  )
15
 
16
  smolagents_available = True
 
55
  "smolagents.PythonInterpreterTool",
56
  ]
57
 
 
58
  mcp_tool = None
59
  for tool in tools:
60
  if "mcp" in tool:
61
  mcp_tool = tool
62
+ tools.remove(tool)
63
+ imported_tools = import_and_wrap_tools(tools, wrap_tool_smolagents)
 
 
 
64
 
65
  model = LiteLLMModel(
66
  model_id=model_id,
 
70
 
71
  if mcp_tool:
72
  from mcp import StdioServerParameters
73
+ from smolagents import ToolCollection
74
 
75
  # We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
76
  # or at https://glama.ai/mcp/servers
src/surf_spot_finder/tools/openmeteo.py CHANGED
@@ -35,13 +35,13 @@ def get_wave_forecast(lat: float, lon: float, date: str | None = None) -> list[d
35
  - sea_level_height_msl (meters)
36
 
37
  Args:
38
- lat (float): Latitude of the location.
39
- lon (float): Longitude of the location.
40
- date (str | None): Date to filter by in any valid ISO 8601 format.
41
  If not provided, all data (default to 6 days forecast) will be returned.
42
 
43
  Returns:
44
- list[dict]: Hourly data for wave forecast.
45
  Example output:
46
 
47
  ```json
@@ -81,19 +81,19 @@ def get_wind_forecast(lat: float, lon: float, date: str | None = None) -> list[d
81
  - wind_speed (meters per second)
82
 
83
  Args:
84
- lat (float): Latitude of the location.
85
- lon (float): Longitude of the location.
86
- date (str | None): Date to filter by in any valid ISO 8601 format.
87
  If not provided, all data (default to 6 days forecast) will be returned.
88
 
89
  Returns:
90
- list[dict]: Hourly data for wind forecast.
91
  Example output:
92
 
93
  ```json
94
  [
95
- {"time": "2025-03-18T22:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.45, "sea_level_height_msl": -1.27},
96
- {"time": "2025-03-18T23:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.35, "sea_level_height_msl": -1.35},
97
  ]
98
  ```
99
  """
 
35
  - sea_level_height_msl (meters)
36
 
37
  Args:
38
+ lat: Latitude of the location.
39
+ lon: Longitude of the location.
40
+ date: Date to filter by in any valid ISO 8601 format.
41
  If not provided, all data (default to 6 days forecast) will be returned.
42
 
43
  Returns:
44
+ Hourly data for wave forecast.
45
  Example output:
46
 
47
  ```json
 
81
  - wind_speed (meters per second)
82
 
83
  Args:
84
+ lat: Latitude of the location.
85
+ lon: Longitude of the location.
86
+ date: Date to filter by in any valid ISO 8601 format.
87
  If not provided, all data (default to 6 days forecast) will be returned.
88
 
89
  Returns:
90
+ Hourly data for wind forecast.
91
  Example output:
92
 
93
  ```json
94
  [
95
+ {"time": "2025-03-18T22:00", "wind_direction": 196, "wind_speed": 9.6},
96
+ {"time": "2025-03-18T23:00", "wind_direction": 183, "wind_speed": 7.9},
97
  ]
98
  ```
99
  """
src/surf_spot_finder/tools/openstreetmap.py CHANGED
@@ -8,10 +8,10 @@ def get_area_lat_lon(area_name: str) -> tuple[float, float]:
8
  Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
9
 
10
  Args:
11
- area_name (str): The name of the area.
12
 
13
  Returns:
14
- dict: The area found.
15
  """
16
  response = requests.get(
17
  f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
@@ -27,10 +27,10 @@ def driving_hours_to_meters(driving_hours: int) -> int:
27
 
28
 
29
  Args:
30
- driving_hours (int): The driving hours.
31
 
32
  Returns:
33
- int: The distance in meters.
34
  """
35
  return driving_hours * 70 * 1000
36
 
@@ -39,7 +39,7 @@ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
39
  """Get the latitude and longitude of the center of a bounding box.
40
 
41
  Args:
42
- bounds (dict): The bounding box.
43
 
44
  ```json
45
  {
@@ -51,7 +51,7 @@ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
51
  ```
52
 
53
  Returns:
54
- tuple: The latitude and longitude of the center.
55
  """
56
  return (
57
  (bounds["minlat"] + bounds["maxlat"]) / 2,
@@ -67,12 +67,12 @@ def get_surfing_spots(
67
  Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
68
 
69
  Args:
70
- lat (float): The latitude.
71
- lon (float): The longitude.
72
- radius (int): The radius in meters.
73
 
74
  Returns:
75
- dict: The surfing places found.
76
  """
77
  overpass_url = "https://overpass-api.de/api/interpreter"
78
  query = "[out:json];("
 
8
  Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
9
 
10
  Args:
11
+ area_name: The name of the area.
12
 
13
  Returns:
14
+ The area found.
15
  """
16
  response = requests.get(
17
  f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
 
27
 
28
 
29
  Args:
30
+ driving_hours: The driving hours.
31
 
32
  Returns:
33
+ The distance in meters.
34
  """
35
  return driving_hours * 70 * 1000
36
 
 
39
  """Get the latitude and longitude of the center of a bounding box.
40
 
41
  Args:
42
+ bounds: The bounding box.
43
 
44
  ```json
45
  {
 
51
  ```
52
 
53
  Returns:
54
+ The latitude and longitude of the center.
55
  """
56
  return (
57
  (bounds["minlat"] + bounds["maxlat"]) / 2,
 
67
  Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
68
 
69
  Args:
70
+ lat: The latitude.
71
+ lon: The longitude.
72
+ radius: The radius in meters.
73
 
74
  Returns:
75
+ The surfing places found.
76
  """
77
  overpass_url = "https://overpass-api.de/api/interpreter"
78
  query = "[out:json];("
src/surf_spot_finder/tools/wrappers.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import importlib
3
+ from collections.abc import Callable
4
+
5
+
6
+ def import_and_wrap_tools(tools: list[str], wrapper: Callable) -> list[Callable]:
7
+ imported_tools = []
8
+ for tool in tools:
9
+ module, func = tool.rsplit(".", 1)
10
+ module = importlib.import_module(module)
11
+ imported_tool = getattr(module, func)
12
+ if inspect.isclass(imported_tool):
13
+ imported_tool = imported_tool()
14
+ imported_tools.append(wrapper(imported_tool))
15
+ return imported_tools
16
+
17
+
18
+ def wrap_tool_openai(tool):
19
+ from agents import function_tool, FunctionTool
20
+
21
+ if not isinstance(tool, FunctionTool):
22
+ return function_tool(tool)
23
+ return tool
24
+
25
+
26
+ def wrap_tool_langchain(tool):
27
+ from langchain_core.tools import BaseTool
28
+ from langchain_core.tools import tool as langchain_tool
29
+
30
+ if not isinstance(tool, BaseTool):
31
+ return langchain_tool(tool)
32
+ return tool
33
+
34
+
35
+ def wrap_tool_smolagents(tool):
36
+ from smolagents import Tool, tool as smolagents_tool
37
+
38
+ if not isinstance(tool, Tool):
39
+ return smolagents_tool(tool)
40
+
41
+ return tool