Spaces:
Running
Running
David de la Iglesia Castro
commited on
enh(tools): Add `wrappers` module. (#23)
Browse filesAllows to import and reuse same tools for different frameworks.
Add `smolagents_single_agent_vertical` and `langchain_single_agent_vertical`
- examples/langchain_single_agent_vertical.yaml +14 -0
- examples/smolagents_single_agent_vertical.yaml +17 -0
- src/surf_spot_finder/agents/openai.py +3 -7
- src/surf_spot_finder/agents/smolagents.py +5 -8
- src/surf_spot_finder/tools/openmeteo.py +10 -10
- src/surf_spot_finder/tools/openstreetmap.py +10 -10
- src/surf_spot_finder/tools/wrappers.py +41 -0
examples/langchain_single_agent_vertical.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: o3-mini
|
5 |
+
agent_type: langchain
|
6 |
+
tools:
|
7 |
+
- "surf_spot_finder.tools.driving_hours_to_meters"
|
8 |
+
- "surf_spot_finder.tools.get_area_lat_lon"
|
9 |
+
- "surf_spot_finder.tools.get_surfing_spots"
|
10 |
+
- "surf_spot_finder.tools.get_wave_forecast"
|
11 |
+
- "surf_spot_finder.tools.get_wind_forecast"
|
12 |
+
- "surf_spot_finder.tools.search_web"
|
13 |
+
- "surf_spot_finder.tools.visit_webpage"
|
14 |
+
# input_prompt_template:
|
examples/smolagents_single_agent_vertical.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: openai/o3-mini
|
5 |
+
api_key_var: OPENAI_API_KEY
|
6 |
+
agent_type: smolagents
|
7 |
+
tools:
|
8 |
+
- "surf_spot_finder.tools.driving_hours_to_meters"
|
9 |
+
- "surf_spot_finder.tools.get_area_lat_lon"
|
10 |
+
- "surf_spot_finder.tools.get_surfing_spots"
|
11 |
+
- "surf_spot_finder.tools.get_wave_forecast"
|
12 |
+
- "surf_spot_finder.tools.get_wind_forecast"
|
13 |
+
- "surf_spot_finder.tools.search_web"
|
14 |
+
- "surf_spot_finder.tools.visit_webpage"
|
15 |
+
- "smolagents.PythonInterpreterTool"
|
16 |
+
- "smolagents.FinalAnswerTool"
|
17 |
+
# input_prompt_template:
|
src/surf_spot_finder/agents/openai.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import importlib
|
2 |
import os
|
3 |
from typing import Optional
|
4 |
|
@@ -8,6 +7,8 @@ from surf_spot_finder.prompts.openai import (
|
|
8 |
SINGLE_AGENT_SYSTEM_PROMPT,
|
9 |
MULTI_AGENT_SYSTEM_PROMPT,
|
10 |
)
|
|
|
|
|
11 |
|
12 |
try:
|
13 |
from agents import (
|
@@ -70,12 +71,7 @@ def run_openai_agent(
|
|
70 |
"surf_spot_finder.tools.visit_webpage",
|
71 |
]
|
72 |
|
73 |
-
imported_tools =
|
74 |
-
for tool in tools:
|
75 |
-
module, func = tool.rsplit(".", 1)
|
76 |
-
module = importlib.import_module(module)
|
77 |
-
tool = getattr(module, func)
|
78 |
-
imported_tools.append(function_tool(tool))
|
79 |
|
80 |
if api_key_var and api_base:
|
81 |
external_client = AsyncOpenAI(
|
|
|
|
|
1 |
import os
|
2 |
from typing import Optional
|
3 |
|
|
|
7 |
SINGLE_AGENT_SYSTEM_PROMPT,
|
8 |
MULTI_AGENT_SYSTEM_PROMPT,
|
9 |
)
|
10 |
+
from surf_spot_finder.tools.wrappers import import_and_wrap_tools, wrap_tool_openai
|
11 |
+
|
12 |
|
13 |
try:
|
14 |
from agents import (
|
|
|
71 |
"surf_spot_finder.tools.visit_webpage",
|
72 |
]
|
73 |
|
74 |
+
imported_tools = import_and_wrap_tools(tools, wrap_tool_openai)
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
if api_key_var and api_base:
|
77 |
external_client = AsyncOpenAI(
|
src/surf_spot_finder/agents/smolagents.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
-
import importlib
|
2 |
import os
|
3 |
from typing import Optional
|
4 |
|
5 |
from loguru import logger
|
6 |
|
7 |
from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
|
|
|
|
|
8 |
|
9 |
try:
|
10 |
from smolagents import (
|
11 |
CodeAgent,
|
12 |
LiteLLMModel,
|
13 |
-
ToolCollection,
|
14 |
)
|
15 |
|
16 |
smolagents_available = True
|
@@ -55,16 +55,12 @@ def run_smolagent(
|
|
55 |
"smolagents.PythonInterpreterTool",
|
56 |
]
|
57 |
|
58 |
-
imported_tools = []
|
59 |
mcp_tool = None
|
60 |
for tool in tools:
|
61 |
if "mcp" in tool:
|
62 |
mcp_tool = tool
|
63 |
-
|
64 |
-
|
65 |
-
module = importlib.import_module(module)
|
66 |
-
tool = getattr(module, func)
|
67 |
-
imported_tools.append(tool())
|
68 |
|
69 |
model = LiteLLMModel(
|
70 |
model_id=model_id,
|
@@ -74,6 +70,7 @@ def run_smolagent(
|
|
74 |
|
75 |
if mcp_tool:
|
76 |
from mcp import StdioServerParameters
|
|
|
77 |
|
78 |
# We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
|
79 |
# or at https://glama.ai/mcp/servers
|
|
|
|
|
1 |
import os
|
2 |
from typing import Optional
|
3 |
|
4 |
from loguru import logger
|
5 |
|
6 |
from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
|
7 |
+
from surf_spot_finder.tools.wrappers import import_and_wrap_tools, wrap_tool_smolagents
|
8 |
+
|
9 |
|
10 |
try:
|
11 |
from smolagents import (
|
12 |
CodeAgent,
|
13 |
LiteLLMModel,
|
|
|
14 |
)
|
15 |
|
16 |
smolagents_available = True
|
|
|
55 |
"smolagents.PythonInterpreterTool",
|
56 |
]
|
57 |
|
|
|
58 |
mcp_tool = None
|
59 |
for tool in tools:
|
60 |
if "mcp" in tool:
|
61 |
mcp_tool = tool
|
62 |
+
tools.remove(tool)
|
63 |
+
imported_tools = import_and_wrap_tools(tools, wrap_tool_smolagents)
|
|
|
|
|
|
|
64 |
|
65 |
model = LiteLLMModel(
|
66 |
model_id=model_id,
|
|
|
70 |
|
71 |
if mcp_tool:
|
72 |
from mcp import StdioServerParameters
|
73 |
+
from smolagents import ToolCollection
|
74 |
|
75 |
# We could easily use any of the MCPs at https://github.com/modelcontextprotocol/servers
|
76 |
# or at https://glama.ai/mcp/servers
|
src/surf_spot_finder/tools/openmeteo.py
CHANGED
@@ -35,13 +35,13 @@ def get_wave_forecast(lat: float, lon: float, date: str | None = None) -> list[d
|
|
35 |
- sea_level_height_msl (meters)
|
36 |
|
37 |
Args:
|
38 |
-
lat
|
39 |
-
lon
|
40 |
-
date
|
41 |
If not provided, all data (default to 6 days forecast) will be returned.
|
42 |
|
43 |
Returns:
|
44 |
-
|
45 |
Example output:
|
46 |
|
47 |
```json
|
@@ -81,19 +81,19 @@ def get_wind_forecast(lat: float, lon: float, date: str | None = None) -> list[d
|
|
81 |
- wind_speed (meters per second)
|
82 |
|
83 |
Args:
|
84 |
-
lat
|
85 |
-
lon
|
86 |
-
date
|
87 |
If not provided, all data (default to 6 days forecast) will be returned.
|
88 |
|
89 |
Returns:
|
90 |
-
|
91 |
Example output:
|
92 |
|
93 |
```json
|
94 |
[
|
95 |
-
{"time": "2025-03-18T22:00", "
|
96 |
-
{"time": "2025-03-18T23:00", "
|
97 |
]
|
98 |
```
|
99 |
"""
|
|
|
35 |
- sea_level_height_msl (meters)
|
36 |
|
37 |
Args:
|
38 |
+
lat: Latitude of the location.
|
39 |
+
lon: Longitude of the location.
|
40 |
+
date: Date to filter by in any valid ISO 8601 format.
|
41 |
If not provided, all data (default to 6 days forecast) will be returned.
|
42 |
|
43 |
Returns:
|
44 |
+
Hourly data for wave forecast.
|
45 |
Example output:
|
46 |
|
47 |
```json
|
|
|
81 |
- wind_speed (meters per second)
|
82 |
|
83 |
Args:
|
84 |
+
lat: Latitude of the location.
|
85 |
+
lon: Longitude of the location.
|
86 |
+
date: Date to filter by in any valid ISO 8601 format.
|
87 |
If not provided, all data (default to 6 days forecast) will be returned.
|
88 |
|
89 |
Returns:
|
90 |
+
Hourly data for wind forecast.
|
91 |
Example output:
|
92 |
|
93 |
```json
|
94 |
[
|
95 |
+
{"time": "2025-03-18T22:00", "wind_direction": 196, "wind_speed": 9.6},
|
96 |
+
{"time": "2025-03-18T23:00", "wind_direction": 183, "wind_speed": 7.9},
|
97 |
]
|
98 |
```
|
99 |
"""
|
src/surf_spot_finder/tools/openstreetmap.py
CHANGED
@@ -8,10 +8,10 @@ def get_area_lat_lon(area_name: str) -> tuple[float, float]:
|
|
8 |
Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
|
9 |
|
10 |
Args:
|
11 |
-
area_name
|
12 |
|
13 |
Returns:
|
14 |
-
|
15 |
"""
|
16 |
response = requests.get(
|
17 |
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
|
@@ -27,10 +27,10 @@ def driving_hours_to_meters(driving_hours: int) -> int:
|
|
27 |
|
28 |
|
29 |
Args:
|
30 |
-
driving_hours
|
31 |
|
32 |
Returns:
|
33 |
-
|
34 |
"""
|
35 |
return driving_hours * 70 * 1000
|
36 |
|
@@ -39,7 +39,7 @@ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
|
|
39 |
"""Get the latitude and longitude of the center of a bounding box.
|
40 |
|
41 |
Args:
|
42 |
-
bounds
|
43 |
|
44 |
```json
|
45 |
{
|
@@ -51,7 +51,7 @@ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
|
|
51 |
```
|
52 |
|
53 |
Returns:
|
54 |
-
|
55 |
"""
|
56 |
return (
|
57 |
(bounds["minlat"] + bounds["maxlat"]) / 2,
|
@@ -67,12 +67,12 @@ def get_surfing_spots(
|
|
67 |
Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
|
68 |
|
69 |
Args:
|
70 |
-
lat
|
71 |
-
lon
|
72 |
-
radius
|
73 |
|
74 |
Returns:
|
75 |
-
|
76 |
"""
|
77 |
overpass_url = "https://overpass-api.de/api/interpreter"
|
78 |
query = "[out:json];("
|
|
|
8 |
Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
|
9 |
|
10 |
Args:
|
11 |
+
area_name: The name of the area.
|
12 |
|
13 |
Returns:
|
14 |
+
The area found.
|
15 |
"""
|
16 |
response = requests.get(
|
17 |
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
|
|
|
27 |
|
28 |
|
29 |
Args:
|
30 |
+
driving_hours: The driving hours.
|
31 |
|
32 |
Returns:
|
33 |
+
The distance in meters.
|
34 |
"""
|
35 |
return driving_hours * 70 * 1000
|
36 |
|
|
|
39 |
"""Get the latitude and longitude of the center of a bounding box.
|
40 |
|
41 |
Args:
|
42 |
+
bounds: The bounding box.
|
43 |
|
44 |
```json
|
45 |
{
|
|
|
51 |
```
|
52 |
|
53 |
Returns:
|
54 |
+
The latitude and longitude of the center.
|
55 |
"""
|
56 |
return (
|
57 |
(bounds["minlat"] + bounds["maxlat"]) / 2,
|
|
|
67 |
Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
|
68 |
|
69 |
Args:
|
70 |
+
lat: The latitude.
|
71 |
+
lon: The longitude.
|
72 |
+
radius: The radius in meters.
|
73 |
|
74 |
Returns:
|
75 |
+
The surfing places found.
|
76 |
"""
|
77 |
overpass_url = "https://overpass-api.de/api/interpreter"
|
78 |
query = "[out:json];("
|
src/surf_spot_finder/tools/wrappers.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import importlib
|
3 |
+
from collections.abc import Callable
|
4 |
+
|
5 |
+
|
6 |
+
def import_and_wrap_tools(tools: list[str], wrapper: Callable) -> list[Callable]:
|
7 |
+
imported_tools = []
|
8 |
+
for tool in tools:
|
9 |
+
module, func = tool.rsplit(".", 1)
|
10 |
+
module = importlib.import_module(module)
|
11 |
+
imported_tool = getattr(module, func)
|
12 |
+
if inspect.isclass(imported_tool):
|
13 |
+
imported_tool = imported_tool()
|
14 |
+
imported_tools.append(wrapper(imported_tool))
|
15 |
+
return imported_tools
|
16 |
+
|
17 |
+
|
18 |
+
def wrap_tool_openai(tool):
|
19 |
+
from agents import function_tool, FunctionTool
|
20 |
+
|
21 |
+
if not isinstance(tool, FunctionTool):
|
22 |
+
return function_tool(tool)
|
23 |
+
return tool
|
24 |
+
|
25 |
+
|
26 |
+
def wrap_tool_langchain(tool):
|
27 |
+
from langchain_core.tools import BaseTool
|
28 |
+
from langchain_core.tools import tool as langchain_tool
|
29 |
+
|
30 |
+
if not isinstance(tool, BaseTool):
|
31 |
+
return langchain_tool(tool)
|
32 |
+
return tool
|
33 |
+
|
34 |
+
|
35 |
+
def wrap_tool_smolagents(tool):
|
36 |
+
from smolagents import Tool, tool as smolagents_tool
|
37 |
+
|
38 |
+
if not isinstance(tool, Tool):
|
39 |
+
return smolagents_tool(tool)
|
40 |
+
|
41 |
+
return tool
|