David de la Iglesia Castro commited on
Commit
c61646b
·
unverified ·
1 Parent(s): de37bdf

Implement agent using `openai-agents-python` (#9)

Browse files

* enh(agents): Add `RUNNERS` registry.

Update cli to accept agent_type arg.

* enh(tracing): Decouple into `get_trace_provider` and `setup_tracing`.

Pass `agent_type` to `setup_tracing` in order to use the right instrumentor.

* feat(agents): Add basic agent using openai-agents-python.

* fix(tests): Install openai deps

* Address PR comments.

- Drop Literal in favor of explicit and centralized validation.
- Add docstring and types to `setup_tracing` .

* fix(docs): Add RUNNERS to API Reference.

.github/workflows/tests.yaml CHANGED
@@ -28,7 +28,7 @@ jobs:
28
  cache: "pip"
29
 
30
  - name: Install
31
- run: pip install -e '.[tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
 
28
  cache: "pip"
29
 
30
  - name: Install
31
+ run: pip install -e '.[openai,tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
docs/api.md CHANGED
@@ -2,6 +2,10 @@
2
 
3
  ::: surf_spot_finder.config.Config
4
 
 
 
 
 
5
  ::: surf_spot_finder.agents.smolagents
6
 
7
  ::: surf_spot_finder.tracing
 
2
 
3
  ::: surf_spot_finder.config.Config
4
 
5
+ ::: surf_spot_finder.agents.RUNNERS
6
+
7
+ ::: surf_spot_finder.agents.openai
8
+
9
  ::: surf_spot_finder.agents.smolagents
10
 
11
  ::: surf_spot_finder.tracing
pyproject.toml CHANGED
@@ -18,6 +18,11 @@ dependencies = [
18
  ]
19
 
20
  [project.optional-dependencies]
 
 
 
 
 
21
  demo = [
22
  "gradio",
23
  "spaces"
 
18
  ]
19
 
20
  [project.optional-dependencies]
21
+ openai = [
22
+ "openai-agents",
23
+ "openinference-instrumentation-openai"
24
+ ]
25
+
26
  demo = [
27
  "gradio",
28
  "spaces"
src/surf_spot_finder/agents/__init__.py CHANGED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .openai import run_openai_agent
2
+ from .smolagents import run_smolagent
3
+
4
+ RUNNERS = {
5
+ "openai": run_openai_agent,
6
+ "smolagents": run_smolagent,
7
+ }
8
+
9
+
10
+ def validate_agent_type(value) -> str:
11
+ if value not in RUNNERS:
12
+ raise ValueError(f"agent_type must be one of {RUNNERS.keys()}")
src/surf_spot_finder/agents/openai.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional, TYPE_CHECKING
3
+
4
+ from loguru import logger
5
+
6
+ if TYPE_CHECKING:
7
+ from agents import RunResult
8
+
9
+
10
+ @logger.catch(reraise=True)
11
+ def run_openai_agent(
12
+ model_id: str,
13
+ prompt: str,
14
+ name: str = "surf-spot-finder",
15
+ instructions: Optional[str] = None,
16
+ api_key_var: Optional[str] = None,
17
+ base_url: Optional[str] = None,
18
+ ) -> "RunResult":
19
+ """Runs an OpenAI agent with the given prompt and configuration.
20
+
21
+ It leverages the 'agents' library to create and manage the agent
22
+ execution.
23
+
24
+ See https://openai.github.io/openai-agents-python/ref/agent/ for more details.
25
+
26
+
27
+ Args:
28
+ model_id (str): The ID of the OpenAI model to use (e.g., "gpt4o").
29
+ See https://platform.openai.com/docs/api-reference/models.
30
+ prompt (str): The prompt to be given to the agent.
31
+ name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
32
+ instructions (Optional[str], optional): Initial instructions to give the agent.
33
+ Defaults to None.
34
+ api_key_var (Optional[str], optional): The name of the environment variable
35
+ containing the OpenAI API key. If provided, along with `base_url`, an
36
+ external OpenAI client will be used. Defaults to None.
37
+ base_url (Optional[str], optional): The base URL for the OpenAI API.
38
+ Required if `api_key_var` is provided to use an external OpenAI client.
39
+ Defaults to None.
40
+
41
+ Returns:
42
+ RunResult: A RunResult object containing the output of the agent run.
43
+ See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
44
+ """
45
+ from agents import (
46
+ Agent,
47
+ AsyncOpenAI,
48
+ OpenAIChatCompletionsModel,
49
+ Runner,
50
+ WebSearchTool,
51
+ )
52
+
53
+ if api_key_var and base_url:
54
+ external_client = AsyncOpenAI(
55
+ api_key=os.environ[api_key_var],
56
+ base_url=base_url,
57
+ )
58
+ agent = Agent(
59
+ name=name,
60
+ instructions=instructions,
61
+ model=OpenAIChatCompletionsModel(
62
+ model=model_id,
63
+ openai_client=external_client,
64
+ ),
65
+ tools=[WebSearchTool()],
66
+ )
67
+ else:
68
+ agent = Agent(
69
+ model=model_id,
70
+ instructions=instructions,
71
+ name=name,
72
+ tools=[WebSearchTool()],
73
+ )
74
+ result = Runner.run_sync(agent, prompt)
75
+ logger.info(result.final_output)
76
+ return result
src/surf_spot_finder/agents/smolagents.py CHANGED
@@ -16,7 +16,9 @@ def run_smolagent(
16
  ) -> "CodeAgent":
17
  """
18
  Create and configure a Smolagents CodeAgent with the specified model.
 
19
  See https://docs.litellm.ai/docs/providers for details on available LiteLLM providers.
 
20
  Args:
21
  model_id (str): Model identifier using LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet')
22
  prompt (str): Prompt to provide to the model
@@ -36,13 +38,6 @@ def run_smolagent(
36
  LiteLLMModel,
37
  ToolCollection,
38
  )
39
-
40
- model = LiteLLMModel(
41
- model_id=model_id,
42
- api_base=api_base if api_base else None,
43
- api_key=os.environ[api_key_var] if api_key_var else None,
44
- )
45
-
46
  from mcp import StdioServerParameters
47
 
48
  model = LiteLLMModel(
 
16
  ) -> "CodeAgent":
17
  """
18
  Create and configure a Smolagents CodeAgent with the specified model.
19
+
20
  See https://docs.litellm.ai/docs/providers for details on available LiteLLM providers.
21
+
22
  Args:
23
  model_id (str): Model identifier using LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet')
24
  prompt (str): Prompt to provide to the model
 
38
  LiteLLMModel,
39
  ToolCollection,
40
  )
 
 
 
 
 
 
 
41
  from mcp import StdioServerParameters
42
 
43
  model = LiteLLMModel(
src/surf_spot_finder/cli.py CHANGED
@@ -7,8 +7,8 @@ from surf_spot_finder.config import (
7
  Config,
8
  DEFAULT_PROMPT,
9
  )
10
- from surf_spot_finder.agents.smolagents import run_smolagent
11
- from surf_spot_finder.tracing import setup_tracing
12
 
13
 
14
  @logger.catch(reraise=True)
@@ -17,6 +17,7 @@ def find_surf_spot(
17
  date: str,
18
  max_driving_hours: int,
19
  model_id: str,
 
20
  api_key_var: Optional[str] = None,
21
  prompt: str = DEFAULT_PROMPT,
22
  json_tracer: bool = True,
@@ -28,6 +29,7 @@ def find_surf_spot(
28
  date=date,
29
  max_driving_hours=max_driving_hours,
30
  model_id=model_id,
 
31
  api_key_var=api_key_var,
32
  prompt=prompt,
33
  json_tracer=json_tracer,
@@ -35,13 +37,14 @@ def find_surf_spot(
35
  )
36
 
37
  logger.info("Setting up tracing")
38
- setup_tracing(project_name="surf-spot-finder", json_tracer=config.json_tracer)
 
 
 
39
 
40
- logger.info("Running agent")
41
- run_smolagent(
42
  model_id=config.model_id,
43
- api_key_var=config.api_key_var,
44
- api_base=config.api_base,
45
  prompt=config.prompt.format(
46
  LOCATION=config.location,
47
  MAX_DRIVING_HOURS=config.max_driving_hours,
 
7
  Config,
8
  DEFAULT_PROMPT,
9
  )
10
+ from surf_spot_finder.agents import RUNNERS
11
+ from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
12
 
13
 
14
  @logger.catch(reraise=True)
 
17
  date: str,
18
  max_driving_hours: int,
19
  model_id: str,
20
+ agent_type: str = "smolagents",
21
  api_key_var: Optional[str] = None,
22
  prompt: str = DEFAULT_PROMPT,
23
  json_tracer: bool = True,
 
29
  date=date,
30
  max_driving_hours=max_driving_hours,
31
  model_id=model_id,
32
+ agent_type=agent_type,
33
  api_key_var=api_key_var,
34
  prompt=prompt,
35
  json_tracer=json_tracer,
 
37
  )
38
 
39
  logger.info("Setting up tracing")
40
+ tracer_provider = get_tracer_provider(
41
+ project_name="surf-spot-finder", json_tracer=config.json_tracer
42
+ )
43
+ setup_tracing(tracer_provider, config.agent_type)
44
 
45
+ logger.info(f"Running {config.agent_type} agent")
46
+ RUNNERS[config.agent_type](
47
  model_id=config.model_id,
 
 
48
  prompt=config.prompt.format(
49
  LOCATION=config.location,
50
  MAX_DRIVING_HOURS=config.max_driving_hours,
src/surf_spot_finder/config.py CHANGED
@@ -21,12 +21,20 @@ def validate_prompt(value) -> str:
21
  return value
22
 
23
 
 
 
 
 
 
 
 
24
  class Config(BaseModel):
25
  prompt: Annotated[str, AfterValidator(validate_prompt)]
26
  location: str
27
  max_driving_hours: PositiveInt
28
  date: FutureDatetime
29
  model_id: str
 
30
  api_key_var: Optional[str] = None
31
  json_tracer: bool = True
32
  api_base: Optional[str] = None
 
21
  return value
22
 
23
 
24
+ def validate_agent_type(value) -> str:
25
+ from surf_spot_finder.agents import validate_agent_type
26
+
27
+ validate_agent_type(value)
28
+ return value
29
+
30
+
31
  class Config(BaseModel):
32
  prompt: Annotated[str, AfterValidator(validate_prompt)]
33
  location: str
34
  max_driving_hours: PositiveInt
35
  date: FutureDatetime
36
  model_id: str
37
+ agent_type: Annotated[str, AfterValidator(validate_agent_type)]
38
  api_key_var: Optional[str] = None
39
  json_tracer: bool = True
40
  api_base: Optional[str] = None
src/surf_spot_finder/tracing.py CHANGED
@@ -1,8 +1,7 @@
1
- from datetime import datetime
2
  import os
 
3
 
4
  from opentelemetry import trace
5
- from openinference.instrumentation.smolagents import SmolagentsInstrumentor
6
  from opentelemetry.sdk.trace import TracerProvider
7
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
8
  from opentelemetry.sdk.trace.export import SpanExporter
@@ -24,9 +23,9 @@ class JsonFileSpanExporter(SpanExporter):
24
  pass
25
 
26
 
27
- def setup_tracing(project_name: str, json_tracer: bool) -> TracerProvider:
28
  """
29
- Set up tracing configuration based on the selected mode.
30
 
31
  Args:
32
  project_name: Name of the project for tracing
@@ -52,6 +51,27 @@ def setup_tracing(project_name: str, json_tracer: bool) -> TracerProvider:
52
  else:
53
  tracer_provider = register(project_name=project_name)
54
 
55
- SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
56
-
57
  return tracer_provider
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from datetime import datetime
3
 
4
  from opentelemetry import trace
 
5
  from opentelemetry.sdk.trace import TracerProvider
6
  from opentelemetry.sdk.trace.export import SimpleSpanProcessor
7
  from opentelemetry.sdk.trace.export import SpanExporter
 
23
  pass
24
 
25
 
26
+ def get_tracer_provider(project_name: str, json_tracer: bool) -> TracerProvider:
27
  """
28
+ Create a tracer_provider based on the selected mode.
29
 
30
  Args:
31
  project_name: Name of the project for tracing
 
51
  else:
52
  tracer_provider = register(project_name=project_name)
53
 
 
 
54
  return tracer_provider
55
+
56
+
57
+ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
58
+ """Setup tracing for `agent_type` by instrumenting `trace_provider`.
59
+
60
+ Args:
61
+ tracer_provider (TracerProvider): The configured tracer provider from
62
+ [get_tracer_provider][surf_spot_finder.tracing.get_tracer_provider].
63
+ agent_type (str): The type of agent being used.
64
+ Must be one of the supported types in [RUNNERS][surf_spot_finder.agents.RUNNERS].
65
+ """
66
+ from surf_spot_finder.agents import validate_agent_type
67
+
68
+ validate_agent_type(agent_type)
69
+
70
+ if agent_type == "openai":
71
+ from openinference.instrumentation.openai import OpenAIInstrumentor
72
+
73
+ OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
74
+ elif agent_type == "smolagents":
75
+ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
76
+
77
+ SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
tests/unit/agents/test_unit_openai.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ from unittest.mock import patch, MagicMock
4
+
5
+ from surf_spot_finder.agents.openai import run_openai_agent
6
+
7
+
8
+ @pytest.fixture
9
+ def mock_agents_module():
10
+ agents_mocks = {
11
+ name: MagicMock()
12
+ for name in (
13
+ "Agent",
14
+ "AsyncOpenAI",
15
+ "OpenAIChatCompletionsModel",
16
+ "Runner",
17
+ "WebSearchTool",
18
+ )
19
+ }
20
+ with patch.dict(
21
+ "sys.modules",
22
+ {
23
+ "agents": MagicMock(**agents_mocks),
24
+ },
25
+ ):
26
+ yield agents_mocks
27
+
28
+
29
+ def test_run_openai_agent_default(mock_agents_module):
30
+ run_openai_agent("gpt-4o", "Test prompt")
31
+ mock_agents_module["Agent"].assert_called_once_with(
32
+ model="gpt-4o",
33
+ instructions=None,
34
+ name="surf-spot-finder",
35
+ tools=[
36
+ mock_agents_module["WebSearchTool"].return_value,
37
+ ],
38
+ )
39
+
40
+
41
+ def test_run_openai_agent_base_url_and_api_key_var(mock_agents_module):
42
+ with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
43
+ run_openai_agent(
44
+ "gpt-4o", "Test prompt", base_url="FOO", api_key_var="TEST_API_KEY"
45
+ )
46
+ mock_agents_module["AsyncOpenAI"].assert_called_once_with(
47
+ api_key="test-key-12345",
48
+ base_url="FOO",
49
+ )
50
+ mock_agents_module["OpenAIChatCompletionsModel"].assert_called_once()
51
+
52
+
53
+ def test_run_smolagent_environment_error():
54
+ """Test that passing a bad api_key_var throws an error"""
55
+ with patch.dict(os.environ, {}, clear=True):
56
+ with pytest.raises(KeyError, match="MISSING_KEY"):
57
+ run_openai_agent(
58
+ "test-model", "Test prompt", base_url="FOO", api_key_var="MISSING_KEY"
59
+ )