Spaces:
Running
Running
David de la Iglesia Castro
commited on
Implement agent using `openai-agents-python` (#9)
Browse files* enh(agents): Add `RUNNERS` registry.
Update cli to accept agent_type arg.
* enh(tracing): Decouple into `get_trace_provider` and `setup_tracing`.
Pass `agent_type` to `setup_tracing` in order to use the right instrumentor.
* feat(agents): Add basic agent using openai-agents-python.
* fix(tests): Install openai deps
* Address PR comments.
- Drop Literal in favor of explicit and centralized validation.
- Add docstring and types to `setup_tracing` .
* fix(docs): Add RUNNERS to API Reference.
- .github/workflows/tests.yaml +1 -1
- docs/api.md +4 -0
- pyproject.toml +5 -0
- src/surf_spot_finder/agents/__init__.py +12 -0
- src/surf_spot_finder/agents/openai.py +76 -0
- src/surf_spot_finder/agents/smolagents.py +2 -7
- src/surf_spot_finder/cli.py +10 -7
- src/surf_spot_finder/config.py +8 -0
- src/surf_spot_finder/tracing.py +26 -6
- tests/unit/agents/test_unit_openai.py +59 -0
.github/workflows/tests.yaml
CHANGED
@@ -28,7 +28,7 @@ jobs:
|
|
28 |
cache: "pip"
|
29 |
|
30 |
- name: Install
|
31 |
-
run: pip install -e '.[tests]'
|
32 |
|
33 |
- name: Run tests
|
34 |
run: pytest -v tests
|
|
|
28 |
cache: "pip"
|
29 |
|
30 |
- name: Install
|
31 |
+
run: pip install -e '.[openai,tests]'
|
32 |
|
33 |
- name: Run tests
|
34 |
run: pytest -v tests
|
docs/api.md
CHANGED
@@ -2,6 +2,10 @@
|
|
2 |
|
3 |
::: surf_spot_finder.config.Config
|
4 |
|
|
|
|
|
|
|
|
|
5 |
::: surf_spot_finder.agents.smolagents
|
6 |
|
7 |
::: surf_spot_finder.tracing
|
|
|
2 |
|
3 |
::: surf_spot_finder.config.Config
|
4 |
|
5 |
+
::: surf_spot_finder.agents.RUNNERS
|
6 |
+
|
7 |
+
::: surf_spot_finder.agents.openai
|
8 |
+
|
9 |
::: surf_spot_finder.agents.smolagents
|
10 |
|
11 |
::: surf_spot_finder.tracing
|
pyproject.toml
CHANGED
@@ -18,6 +18,11 @@ dependencies = [
|
|
18 |
]
|
19 |
|
20 |
[project.optional-dependencies]
|
|
|
|
|
|
|
|
|
|
|
21 |
demo = [
|
22 |
"gradio",
|
23 |
"spaces"
|
|
|
18 |
]
|
19 |
|
20 |
[project.optional-dependencies]
|
21 |
+
openai = [
|
22 |
+
"openai-agents",
|
23 |
+
"openinference-instrumentation-openai"
|
24 |
+
]
|
25 |
+
|
26 |
demo = [
|
27 |
"gradio",
|
28 |
"spaces"
|
src/surf_spot_finder/agents/__init__.py
CHANGED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .openai import run_openai_agent
|
2 |
+
from .smolagents import run_smolagent
|
3 |
+
|
4 |
+
RUNNERS = {
|
5 |
+
"openai": run_openai_agent,
|
6 |
+
"smolagents": run_smolagent,
|
7 |
+
}
|
8 |
+
|
9 |
+
|
10 |
+
def validate_agent_type(value) -> str:
|
11 |
+
if value not in RUNNERS:
|
12 |
+
raise ValueError(f"agent_type must be one of {RUNNERS.keys()}")
|
src/surf_spot_finder/agents/openai.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Optional, TYPE_CHECKING
|
3 |
+
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
from agents import RunResult
|
8 |
+
|
9 |
+
|
10 |
+
@logger.catch(reraise=True)
|
11 |
+
def run_openai_agent(
|
12 |
+
model_id: str,
|
13 |
+
prompt: str,
|
14 |
+
name: str = "surf-spot-finder",
|
15 |
+
instructions: Optional[str] = None,
|
16 |
+
api_key_var: Optional[str] = None,
|
17 |
+
base_url: Optional[str] = None,
|
18 |
+
) -> "RunResult":
|
19 |
+
"""Runs an OpenAI agent with the given prompt and configuration.
|
20 |
+
|
21 |
+
It leverages the 'agents' library to create and manage the agent
|
22 |
+
execution.
|
23 |
+
|
24 |
+
See https://openai.github.io/openai-agents-python/ref/agent/ for more details.
|
25 |
+
|
26 |
+
|
27 |
+
Args:
|
28 |
+
model_id (str): The ID of the OpenAI model to use (e.g., "gpt4o").
|
29 |
+
See https://platform.openai.com/docs/api-reference/models.
|
30 |
+
prompt (str): The prompt to be given to the agent.
|
31 |
+
name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
|
32 |
+
instructions (Optional[str], optional): Initial instructions to give the agent.
|
33 |
+
Defaults to None.
|
34 |
+
api_key_var (Optional[str], optional): The name of the environment variable
|
35 |
+
containing the OpenAI API key. If provided, along with `base_url`, an
|
36 |
+
external OpenAI client will be used. Defaults to None.
|
37 |
+
base_url (Optional[str], optional): The base URL for the OpenAI API.
|
38 |
+
Required if `api_key_var` is provided to use an external OpenAI client.
|
39 |
+
Defaults to None.
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
RunResult: A RunResult object containing the output of the agent run.
|
43 |
+
See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
|
44 |
+
"""
|
45 |
+
from agents import (
|
46 |
+
Agent,
|
47 |
+
AsyncOpenAI,
|
48 |
+
OpenAIChatCompletionsModel,
|
49 |
+
Runner,
|
50 |
+
WebSearchTool,
|
51 |
+
)
|
52 |
+
|
53 |
+
if api_key_var and base_url:
|
54 |
+
external_client = AsyncOpenAI(
|
55 |
+
api_key=os.environ[api_key_var],
|
56 |
+
base_url=base_url,
|
57 |
+
)
|
58 |
+
agent = Agent(
|
59 |
+
name=name,
|
60 |
+
instructions=instructions,
|
61 |
+
model=OpenAIChatCompletionsModel(
|
62 |
+
model=model_id,
|
63 |
+
openai_client=external_client,
|
64 |
+
),
|
65 |
+
tools=[WebSearchTool()],
|
66 |
+
)
|
67 |
+
else:
|
68 |
+
agent = Agent(
|
69 |
+
model=model_id,
|
70 |
+
instructions=instructions,
|
71 |
+
name=name,
|
72 |
+
tools=[WebSearchTool()],
|
73 |
+
)
|
74 |
+
result = Runner.run_sync(agent, prompt)
|
75 |
+
logger.info(result.final_output)
|
76 |
+
return result
|
src/surf_spot_finder/agents/smolagents.py
CHANGED
@@ -16,7 +16,9 @@ def run_smolagent(
|
|
16 |
) -> "CodeAgent":
|
17 |
"""
|
18 |
Create and configure a Smolagents CodeAgent with the specified model.
|
|
|
19 |
See https://docs.litellm.ai/docs/providers for details on available LiteLLM providers.
|
|
|
20 |
Args:
|
21 |
model_id (str): Model identifier using LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet')
|
22 |
prompt (str): Prompt to provide to the model
|
@@ -36,13 +38,6 @@ def run_smolagent(
|
|
36 |
LiteLLMModel,
|
37 |
ToolCollection,
|
38 |
)
|
39 |
-
|
40 |
-
model = LiteLLMModel(
|
41 |
-
model_id=model_id,
|
42 |
-
api_base=api_base if api_base else None,
|
43 |
-
api_key=os.environ[api_key_var] if api_key_var else None,
|
44 |
-
)
|
45 |
-
|
46 |
from mcp import StdioServerParameters
|
47 |
|
48 |
model = LiteLLMModel(
|
|
|
16 |
) -> "CodeAgent":
|
17 |
"""
|
18 |
Create and configure a Smolagents CodeAgent with the specified model.
|
19 |
+
|
20 |
See https://docs.litellm.ai/docs/providers for details on available LiteLLM providers.
|
21 |
+
|
22 |
Args:
|
23 |
model_id (str): Model identifier using LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet')
|
24 |
prompt (str): Prompt to provide to the model
|
|
|
38 |
LiteLLMModel,
|
39 |
ToolCollection,
|
40 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
from mcp import StdioServerParameters
|
42 |
|
43 |
model = LiteLLMModel(
|
src/surf_spot_finder/cli.py
CHANGED
@@ -7,8 +7,8 @@ from surf_spot_finder.config import (
|
|
7 |
Config,
|
8 |
DEFAULT_PROMPT,
|
9 |
)
|
10 |
-
from surf_spot_finder.agents
|
11 |
-
from surf_spot_finder.tracing import setup_tracing
|
12 |
|
13 |
|
14 |
@logger.catch(reraise=True)
|
@@ -17,6 +17,7 @@ def find_surf_spot(
|
|
17 |
date: str,
|
18 |
max_driving_hours: int,
|
19 |
model_id: str,
|
|
|
20 |
api_key_var: Optional[str] = None,
|
21 |
prompt: str = DEFAULT_PROMPT,
|
22 |
json_tracer: bool = True,
|
@@ -28,6 +29,7 @@ def find_surf_spot(
|
|
28 |
date=date,
|
29 |
max_driving_hours=max_driving_hours,
|
30 |
model_id=model_id,
|
|
|
31 |
api_key_var=api_key_var,
|
32 |
prompt=prompt,
|
33 |
json_tracer=json_tracer,
|
@@ -35,13 +37,14 @@ def find_surf_spot(
|
|
35 |
)
|
36 |
|
37 |
logger.info("Setting up tracing")
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
-
logger.info("Running agent")
|
41 |
-
|
42 |
model_id=config.model_id,
|
43 |
-
api_key_var=config.api_key_var,
|
44 |
-
api_base=config.api_base,
|
45 |
prompt=config.prompt.format(
|
46 |
LOCATION=config.location,
|
47 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
|
|
7 |
Config,
|
8 |
DEFAULT_PROMPT,
|
9 |
)
|
10 |
+
from surf_spot_finder.agents import RUNNERS
|
11 |
+
from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
|
12 |
|
13 |
|
14 |
@logger.catch(reraise=True)
|
|
|
17 |
date: str,
|
18 |
max_driving_hours: int,
|
19 |
model_id: str,
|
20 |
+
agent_type: str = "smolagents",
|
21 |
api_key_var: Optional[str] = None,
|
22 |
prompt: str = DEFAULT_PROMPT,
|
23 |
json_tracer: bool = True,
|
|
|
29 |
date=date,
|
30 |
max_driving_hours=max_driving_hours,
|
31 |
model_id=model_id,
|
32 |
+
agent_type=agent_type,
|
33 |
api_key_var=api_key_var,
|
34 |
prompt=prompt,
|
35 |
json_tracer=json_tracer,
|
|
|
37 |
)
|
38 |
|
39 |
logger.info("Setting up tracing")
|
40 |
+
tracer_provider = get_tracer_provider(
|
41 |
+
project_name="surf-spot-finder", json_tracer=config.json_tracer
|
42 |
+
)
|
43 |
+
setup_tracing(tracer_provider, config.agent_type)
|
44 |
|
45 |
+
logger.info(f"Running {config.agent_type} agent")
|
46 |
+
RUNNERS[config.agent_type](
|
47 |
model_id=config.model_id,
|
|
|
|
|
48 |
prompt=config.prompt.format(
|
49 |
LOCATION=config.location,
|
50 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
src/surf_spot_finder/config.py
CHANGED
@@ -21,12 +21,20 @@ def validate_prompt(value) -> str:
|
|
21 |
return value
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
class Config(BaseModel):
|
25 |
prompt: Annotated[str, AfterValidator(validate_prompt)]
|
26 |
location: str
|
27 |
max_driving_hours: PositiveInt
|
28 |
date: FutureDatetime
|
29 |
model_id: str
|
|
|
30 |
api_key_var: Optional[str] = None
|
31 |
json_tracer: bool = True
|
32 |
api_base: Optional[str] = None
|
|
|
21 |
return value
|
22 |
|
23 |
|
24 |
+
def validate_agent_type(value) -> str:
|
25 |
+
from surf_spot_finder.agents import validate_agent_type
|
26 |
+
|
27 |
+
validate_agent_type(value)
|
28 |
+
return value
|
29 |
+
|
30 |
+
|
31 |
class Config(BaseModel):
|
32 |
prompt: Annotated[str, AfterValidator(validate_prompt)]
|
33 |
location: str
|
34 |
max_driving_hours: PositiveInt
|
35 |
date: FutureDatetime
|
36 |
model_id: str
|
37 |
+
agent_type: Annotated[str, AfterValidator(validate_agent_type)]
|
38 |
api_key_var: Optional[str] = None
|
39 |
json_tracer: bool = True
|
40 |
api_base: Optional[str] = None
|
src/surf_spot_finder/tracing.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
-
from datetime import datetime
|
2 |
import os
|
|
|
3 |
|
4 |
from opentelemetry import trace
|
5 |
-
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
6 |
from opentelemetry.sdk.trace import TracerProvider
|
7 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
8 |
from opentelemetry.sdk.trace.export import SpanExporter
|
@@ -24,9 +23,9 @@ class JsonFileSpanExporter(SpanExporter):
|
|
24 |
pass
|
25 |
|
26 |
|
27 |
-
def
|
28 |
"""
|
29 |
-
|
30 |
|
31 |
Args:
|
32 |
project_name: Name of the project for tracing
|
@@ -52,6 +51,27 @@ def setup_tracing(project_name: str, json_tracer: bool) -> TracerProvider:
|
|
52 |
else:
|
53 |
tracer_provider = register(project_name=project_name)
|
54 |
|
55 |
-
SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
|
56 |
-
|
57 |
return tracer_provider
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
from datetime import datetime
|
3 |
|
4 |
from opentelemetry import trace
|
|
|
5 |
from opentelemetry.sdk.trace import TracerProvider
|
6 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
7 |
from opentelemetry.sdk.trace.export import SpanExporter
|
|
|
23 |
pass
|
24 |
|
25 |
|
26 |
+
def get_tracer_provider(project_name: str, json_tracer: bool) -> TracerProvider:
|
27 |
"""
|
28 |
+
Create a tracer_provider based on the selected mode.
|
29 |
|
30 |
Args:
|
31 |
project_name: Name of the project for tracing
|
|
|
51 |
else:
|
52 |
tracer_provider = register(project_name=project_name)
|
53 |
|
|
|
|
|
54 |
return tracer_provider
|
55 |
+
|
56 |
+
|
57 |
+
def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
|
58 |
+
"""Setup tracing for `agent_type` by instrumenting `trace_provider`.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
tracer_provider (TracerProvider): The configured tracer provider from
|
62 |
+
[get_tracer_provider][surf_spot_finder.tracing.get_tracer_provider].
|
63 |
+
agent_type (str): The type of agent being used.
|
64 |
+
Must be one of the supported types in [RUNNERS][surf_spot_finder.agents.RUNNERS].
|
65 |
+
"""
|
66 |
+
from surf_spot_finder.agents import validate_agent_type
|
67 |
+
|
68 |
+
validate_agent_type(agent_type)
|
69 |
+
|
70 |
+
if agent_type == "openai":
|
71 |
+
from openinference.instrumentation.openai import OpenAIInstrumentor
|
72 |
+
|
73 |
+
OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
|
74 |
+
elif agent_type == "smolagents":
|
75 |
+
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
76 |
+
|
77 |
+
SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
|
tests/unit/agents/test_unit_openai.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pytest
|
3 |
+
from unittest.mock import patch, MagicMock
|
4 |
+
|
5 |
+
from surf_spot_finder.agents.openai import run_openai_agent
|
6 |
+
|
7 |
+
|
8 |
+
@pytest.fixture
|
9 |
+
def mock_agents_module():
|
10 |
+
agents_mocks = {
|
11 |
+
name: MagicMock()
|
12 |
+
for name in (
|
13 |
+
"Agent",
|
14 |
+
"AsyncOpenAI",
|
15 |
+
"OpenAIChatCompletionsModel",
|
16 |
+
"Runner",
|
17 |
+
"WebSearchTool",
|
18 |
+
)
|
19 |
+
}
|
20 |
+
with patch.dict(
|
21 |
+
"sys.modules",
|
22 |
+
{
|
23 |
+
"agents": MagicMock(**agents_mocks),
|
24 |
+
},
|
25 |
+
):
|
26 |
+
yield agents_mocks
|
27 |
+
|
28 |
+
|
29 |
+
def test_run_openai_agent_default(mock_agents_module):
|
30 |
+
run_openai_agent("gpt-4o", "Test prompt")
|
31 |
+
mock_agents_module["Agent"].assert_called_once_with(
|
32 |
+
model="gpt-4o",
|
33 |
+
instructions=None,
|
34 |
+
name="surf-spot-finder",
|
35 |
+
tools=[
|
36 |
+
mock_agents_module["WebSearchTool"].return_value,
|
37 |
+
],
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def test_run_openai_agent_base_url_and_api_key_var(mock_agents_module):
|
42 |
+
with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
|
43 |
+
run_openai_agent(
|
44 |
+
"gpt-4o", "Test prompt", base_url="FOO", api_key_var="TEST_API_KEY"
|
45 |
+
)
|
46 |
+
mock_agents_module["AsyncOpenAI"].assert_called_once_with(
|
47 |
+
api_key="test-key-12345",
|
48 |
+
base_url="FOO",
|
49 |
+
)
|
50 |
+
mock_agents_module["OpenAIChatCompletionsModel"].assert_called_once()
|
51 |
+
|
52 |
+
|
53 |
+
def test_run_smolagent_environment_error():
|
54 |
+
"""Test that passing a bad api_key_var throws an error"""
|
55 |
+
with patch.dict(os.environ, {}, clear=True):
|
56 |
+
with pytest.raises(KeyError, match="MISSING_KEY"):
|
57 |
+
run_openai_agent(
|
58 |
+
"test-model", "Test prompt", base_url="FOO", api_key_var="MISSING_KEY"
|
59 |
+
)
|