Spaces:
Running
Running
David de la Iglesia Castro
commited on
Add `langchain` agent (#22)
Browse files* Add `langchain` agent
* Add langchain to api.md
* fix(tests): Add missing `langchain`.
.github/workflows/tests.yaml
CHANGED
@@ -28,7 +28,7 @@ jobs:
|
|
28 |
cache: "pip"
|
29 |
|
30 |
- name: Install
|
31 |
-
run: pip install -e '.[arize,smolagents,openai,tests]'
|
32 |
|
33 |
- name: Run tests
|
34 |
run: pytest -v tests
|
|
|
28 |
cache: "pip"
|
29 |
|
30 |
- name: Install
|
31 |
+
run: pip install -e '.[arize,langchain,smolagents,openai,tests]'
|
32 |
|
33 |
- name: Run tests
|
34 |
run: pytest -v tests
|
docs/api.md
CHANGED
@@ -8,6 +8,8 @@
|
|
8 |
|
9 |
::: surf_spot_finder.agents.RUNNERS
|
10 |
|
|
|
|
|
11 |
::: surf_spot_finder.agents.openai
|
12 |
|
13 |
::: surf_spot_finder.agents.smolagents
|
|
|
8 |
|
9 |
::: surf_spot_finder.agents.RUNNERS
|
10 |
|
11 |
+
::: surf_spot_finder.agents.langchain
|
12 |
+
|
13 |
::: surf_spot_finder.agents.openai
|
14 |
|
15 |
::: surf_spot_finder.agents.smolagents
|
examples/langchain_single_agent.yaml
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: o3-mini
|
5 |
+
agent_type: langchain
|
6 |
+
tools:
|
7 |
+
- "surf_spot_finder.tools.search_web"
|
8 |
+
- "surf_spot_finder.tools.visit_webpage"
|
9 |
+
# input_prompt_template:
|
pyproject.toml
CHANGED
@@ -17,6 +17,11 @@ dependencies = [
|
|
17 |
]
|
18 |
|
19 |
[project.optional-dependencies]
|
|
|
|
|
|
|
|
|
|
|
20 |
smolagents = [
|
21 |
"smolagents[litellm]>=1.10.0",
|
22 |
"openinference-instrumentation-smolagents>=0.1.4"
|
|
|
17 |
]
|
18 |
|
19 |
[project.optional-dependencies]
|
20 |
+
langchain = [
|
21 |
+
"langchain",
|
22 |
+
"langgraph",
|
23 |
+
"openinference-instrumentation-langchain"
|
24 |
+
]
|
25 |
smolagents = [
|
26 |
"smolagents[litellm]>=1.10.0",
|
27 |
"openinference-instrumentation-smolagents>=0.1.4"
|
src/surf_spot_finder/agents/__init__.py
CHANGED
@@ -1,7 +1,9 @@
|
|
|
|
1 |
from .openai import run_openai_agent, run_openai_multi_agent
|
2 |
from .smolagents import run_smolagent
|
3 |
|
4 |
RUNNERS = {
|
|
|
5 |
"openai": run_openai_agent,
|
6 |
"smolagents": run_smolagent,
|
7 |
"openai_multi_agent": run_openai_multi_agent,
|
|
|
1 |
+
from .langchain import run_lanchain_agent
|
2 |
from .openai import run_openai_agent, run_openai_multi_agent
|
3 |
from .smolagents import run_smolagent
|
4 |
|
5 |
RUNNERS = {
|
6 |
+
"langchain": run_lanchain_agent,
|
7 |
"openai": run_openai_agent,
|
8 |
"smolagents": run_smolagent,
|
9 |
"openai_multi_agent": run_openai_multi_agent,
|
src/surf_spot_finder/agents/langchain.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import importlib
|
3 |
+
|
4 |
+
from loguru import logger
|
5 |
+
|
6 |
+
try:
|
7 |
+
from langchain.chat_models import init_chat_model
|
8 |
+
from langchain_core.messages import HumanMessage
|
9 |
+
from langchain_core.tools import BaseTool, tool
|
10 |
+
from langgraph.checkpoint.memory import MemorySaver
|
11 |
+
from langgraph.prebuilt import create_react_agent
|
12 |
+
|
13 |
+
langchain_available = True
|
14 |
+
except ImportError:
|
15 |
+
langchain_available = False
|
16 |
+
|
17 |
+
|
18 |
+
@logger.catch(reraise=True)
|
19 |
+
def run_lanchain_agent(
|
20 |
+
model_id: str, prompt: str, tools: list[str] | None = None, **kwargs
|
21 |
+
):
|
22 |
+
"""Runs an langchain ReAct agent with the given prompt and configuration.
|
23 |
+
|
24 |
+
Uses [create_react_agent](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent).
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model_id: The ID of the model to use.
|
28 |
+
See [init_chat_model](https://python.langchain.com/api_reference/langchain/chat_models/langchain.chat_models.base.init_chat_model.html).
|
29 |
+
prompt: The prompt to be given to the agent.
|
30 |
+
"""
|
31 |
+
if not langchain_available:
|
32 |
+
raise ImportError(
|
33 |
+
"You need to `pip install langchain langgraph` to use this agent"
|
34 |
+
)
|
35 |
+
|
36 |
+
if tools is None:
|
37 |
+
tools = [
|
38 |
+
"surf_spot_finder.tools.search_web",
|
39 |
+
"surf_spot_finder.tools.visit_webpage",
|
40 |
+
]
|
41 |
+
|
42 |
+
imported_tools = []
|
43 |
+
for imported_tool in tools:
|
44 |
+
module, func = imported_tool.rsplit(".", 1)
|
45 |
+
module = importlib.import_module(module)
|
46 |
+
imported_tool = getattr(module, func)
|
47 |
+
if inspect.isclass(imported_tool):
|
48 |
+
imported_tool = imported_tool()
|
49 |
+
if not isinstance(imported_tool, BaseTool):
|
50 |
+
imported_tool = tool(imported_tool)
|
51 |
+
imported_tools.append((imported_tool))
|
52 |
+
|
53 |
+
model = init_chat_model(model_id)
|
54 |
+
agent = create_react_agent(
|
55 |
+
model=model,
|
56 |
+
tools=imported_tools,
|
57 |
+
checkpointer=MemorySaver(),
|
58 |
+
)
|
59 |
+
for step in agent.stream(
|
60 |
+
{"messages": [HumanMessage(content=prompt)]},
|
61 |
+
{"configurable": {"thread_id": "abc123"}},
|
62 |
+
stream_mode="values",
|
63 |
+
):
|
64 |
+
step["messages"][-1].pretty_print()
|
src/surf_spot_finder/tracing.py
CHANGED
@@ -104,3 +104,7 @@ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
|
|
104 |
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
105 |
|
106 |
SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
|
|
|
|
|
|
|
|
|
|
104 |
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
105 |
|
106 |
SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
|
107 |
+
elif agent_type == "langchain":
|
108 |
+
from openinference.instrumentation.langchain import LangChainInstrumentor
|
109 |
+
|
110 |
+
LangChainInstrumentor().instrument(tracer_provider=tracer_provider)
|
tests/unit/agents/test_unit_langchain.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from unittest.mock import patch, MagicMock
|
2 |
+
|
3 |
+
from surf_spot_finder.agents.langchain import (
|
4 |
+
run_lanchain_agent,
|
5 |
+
)
|
6 |
+
from surf_spot_finder.tools import (
|
7 |
+
search_web,
|
8 |
+
visit_webpage,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def test_run_langchain_agent_default():
|
13 |
+
model_mock = MagicMock()
|
14 |
+
create_mock = MagicMock()
|
15 |
+
agent_mock = MagicMock()
|
16 |
+
create_mock.return_value = agent_mock
|
17 |
+
memory_mock = MagicMock()
|
18 |
+
tool_mock = MagicMock()
|
19 |
+
|
20 |
+
with (
|
21 |
+
patch("surf_spot_finder.agents.langchain.create_react_agent", create_mock),
|
22 |
+
patch("surf_spot_finder.agents.langchain.init_chat_model", model_mock),
|
23 |
+
patch("surf_spot_finder.agents.langchain.MemorySaver", memory_mock),
|
24 |
+
patch("surf_spot_finder.agents.langchain.tool", tool_mock),
|
25 |
+
):
|
26 |
+
run_lanchain_agent("gpt-4o", "Test prompt")
|
27 |
+
model_mock.assert_called_once_with("gpt-4o")
|
28 |
+
create_mock.assert_called_once_with(
|
29 |
+
model=model_mock.return_value,
|
30 |
+
tools=[tool_mock(search_web), tool_mock(visit_webpage)],
|
31 |
+
checkpointer=memory_mock.return_value,
|
32 |
+
)
|
33 |
+
agent_mock.stream.assert_called_once()
|