David de la Iglesia Castro commited on
Commit
a5c72af
·
unverified ·
1 Parent(s): f53eca9

Add `langchain` agent (#22)

Browse files

* Add `langchain` agent

* Add langchain to api.md

* fix(tests): Add missing `langchain`.

.github/workflows/tests.yaml CHANGED
@@ -28,7 +28,7 @@ jobs:
28
  cache: "pip"
29
 
30
  - name: Install
31
- run: pip install -e '.[arize,smolagents,openai,tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
 
28
  cache: "pip"
29
 
30
  - name: Install
31
+ run: pip install -e '.[arize,langchain,smolagents,openai,tests]'
32
 
33
  - name: Run tests
34
  run: pytest -v tests
docs/api.md CHANGED
@@ -8,6 +8,8 @@
8
 
9
  ::: surf_spot_finder.agents.RUNNERS
10
 
 
 
11
  ::: surf_spot_finder.agents.openai
12
 
13
  ::: surf_spot_finder.agents.smolagents
 
8
 
9
  ::: surf_spot_finder.agents.RUNNERS
10
 
11
+ ::: surf_spot_finder.agents.langchain
12
+
13
  ::: surf_spot_finder.agents.openai
14
 
15
  ::: surf_spot_finder.agents.smolagents
examples/langchain_single_agent.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: o3-mini
5
+ agent_type: langchain
6
+ tools:
7
+ - "surf_spot_finder.tools.search_web"
8
+ - "surf_spot_finder.tools.visit_webpage"
9
+ # input_prompt_template:
pyproject.toml CHANGED
@@ -17,6 +17,11 @@ dependencies = [
17
  ]
18
 
19
  [project.optional-dependencies]
 
 
 
 
 
20
  smolagents = [
21
  "smolagents[litellm]>=1.10.0",
22
  "openinference-instrumentation-smolagents>=0.1.4"
 
17
  ]
18
 
19
  [project.optional-dependencies]
20
+ langchain = [
21
+ "langchain",
22
+ "langgraph",
23
+ "openinference-instrumentation-langchain"
24
+ ]
25
  smolagents = [
26
  "smolagents[litellm]>=1.10.0",
27
  "openinference-instrumentation-smolagents>=0.1.4"
src/surf_spot_finder/agents/__init__.py CHANGED
@@ -1,7 +1,9 @@
 
1
  from .openai import run_openai_agent, run_openai_multi_agent
2
  from .smolagents import run_smolagent
3
 
4
  RUNNERS = {
 
5
  "openai": run_openai_agent,
6
  "smolagents": run_smolagent,
7
  "openai_multi_agent": run_openai_multi_agent,
 
1
+ from .langchain import run_lanchain_agent
2
  from .openai import run_openai_agent, run_openai_multi_agent
3
  from .smolagents import run_smolagent
4
 
5
  RUNNERS = {
6
+ "langchain": run_lanchain_agent,
7
  "openai": run_openai_agent,
8
  "smolagents": run_smolagent,
9
  "openai_multi_agent": run_openai_multi_agent,
src/surf_spot_finder/agents/langchain.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import importlib
3
+
4
+ from loguru import logger
5
+
6
+ try:
7
+ from langchain.chat_models import init_chat_model
8
+ from langchain_core.messages import HumanMessage
9
+ from langchain_core.tools import BaseTool, tool
10
+ from langgraph.checkpoint.memory import MemorySaver
11
+ from langgraph.prebuilt import create_react_agent
12
+
13
+ langchain_available = True
14
+ except ImportError:
15
+ langchain_available = False
16
+
17
+
18
+ @logger.catch(reraise=True)
19
+ def run_lanchain_agent(
20
+ model_id: str, prompt: str, tools: list[str] | None = None, **kwargs
21
+ ):
22
+ """Runs an langchain ReAct agent with the given prompt and configuration.
23
+
24
+ Uses [create_react_agent](https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent).
25
+
26
+ Args:
27
+ model_id: The ID of the model to use.
28
+ See [init_chat_model](https://python.langchain.com/api_reference/langchain/chat_models/langchain.chat_models.base.init_chat_model.html).
29
+ prompt: The prompt to be given to the agent.
30
+ """
31
+ if not langchain_available:
32
+ raise ImportError(
33
+ "You need to `pip install langchain langgraph` to use this agent"
34
+ )
35
+
36
+ if tools is None:
37
+ tools = [
38
+ "surf_spot_finder.tools.search_web",
39
+ "surf_spot_finder.tools.visit_webpage",
40
+ ]
41
+
42
+ imported_tools = []
43
+ for imported_tool in tools:
44
+ module, func = imported_tool.rsplit(".", 1)
45
+ module = importlib.import_module(module)
46
+ imported_tool = getattr(module, func)
47
+ if inspect.isclass(imported_tool):
48
+ imported_tool = imported_tool()
49
+ if not isinstance(imported_tool, BaseTool):
50
+ imported_tool = tool(imported_tool)
51
+ imported_tools.append((imported_tool))
52
+
53
+ model = init_chat_model(model_id)
54
+ agent = create_react_agent(
55
+ model=model,
56
+ tools=imported_tools,
57
+ checkpointer=MemorySaver(),
58
+ )
59
+ for step in agent.stream(
60
+ {"messages": [HumanMessage(content=prompt)]},
61
+ {"configurable": {"thread_id": "abc123"}},
62
+ stream_mode="values",
63
+ ):
64
+ step["messages"][-1].pretty_print()
src/surf_spot_finder/tracing.py CHANGED
@@ -104,3 +104,7 @@ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
104
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
105
 
106
  SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
 
 
 
 
 
104
  from openinference.instrumentation.smolagents import SmolagentsInstrumentor
105
 
106
  SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
107
+ elif agent_type == "langchain":
108
+ from openinference.instrumentation.langchain import LangChainInstrumentor
109
+
110
+ LangChainInstrumentor().instrument(tracer_provider=tracer_provider)
tests/unit/agents/test_unit_langchain.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import patch, MagicMock
2
+
3
+ from surf_spot_finder.agents.langchain import (
4
+ run_lanchain_agent,
5
+ )
6
+ from surf_spot_finder.tools import (
7
+ search_web,
8
+ visit_webpage,
9
+ )
10
+
11
+
12
+ def test_run_langchain_agent_default():
13
+ model_mock = MagicMock()
14
+ create_mock = MagicMock()
15
+ agent_mock = MagicMock()
16
+ create_mock.return_value = agent_mock
17
+ memory_mock = MagicMock()
18
+ tool_mock = MagicMock()
19
+
20
+ with (
21
+ patch("surf_spot_finder.agents.langchain.create_react_agent", create_mock),
22
+ patch("surf_spot_finder.agents.langchain.init_chat_model", model_mock),
23
+ patch("surf_spot_finder.agents.langchain.MemorySaver", memory_mock),
24
+ patch("surf_spot_finder.agents.langchain.tool", tool_mock),
25
+ ):
26
+ run_lanchain_agent("gpt-4o", "Test prompt")
27
+ model_mock.assert_called_once_with("gpt-4o")
28
+ create_mock.assert_called_once_with(
29
+ model=model_mock.return_value,
30
+ tools=[tool_mock(search_web), tool_mock(visit_webpage)],
31
+ checkpointer=memory_mock.return_value,
32
+ )
33
+ agent_mock.stream.assert_called_once()