"""Test Tracer classes."""
from __future__ import annotations

from datetime import datetime
from typing import List, Optional, Union
from uuid import uuid4

import pytest
from freezegun import freeze_time

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.langchain_v1 import (
    ChainRun,
    LangChainTracerV1,
    LLMRun,
    ToolRun,
    TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.schema import LLMResult

TEST_SESSION_ID = 2023

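# Session helpers reused by FakeTracer below (and patched onto LangChainTracerV1
# in _persist_run); they pin every session to TEST_SESSION_ID so the expected
# runs can hard-code the session id.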
def load_session(session_name: str) -> TracerSessionV1:
    """Load a tracing session."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name=session_name, start_time=datetime.utcnow()
    )


def new_session(name: Optional[str] = None) -> TracerSessionV1:
    """Create a new tracing session."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name=name or "default", start_time=datetime.utcnow()
    )


def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
    """Persist a tracing session."""
    return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})


def load_default_session() -> TracerSessionV1:
    """Load a tracing session."""
    return TracerSessionV1(
        id=TEST_SESSION_ID, name="default", start_time=datetime.utcnow()
    )


@pytest.fixture
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
    monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
    monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
    tracer = LangChainTracerV1()
    return tracer

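# FakeTracer records every persisted run in memory so tests can compare whole
# run trees. New-style Run objects are first converted to V1 runs via a
# LangChainTracerV1 constructed under stubbed environment variables.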
class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution."""

    def __init__(self) -> None:
        """Initialize the tracer."""
        super().__init__()
        self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []

    def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
        """Persist a run."""
        if isinstance(run, Run):
            with pytest.MonkeyPatch().context() as m:
                m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
                m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
                m.setenv("LANGCHAIN_API_KEY", "foo")
                tracer = LangChainTracerV1()
                tracer.load_default_session = load_default_session  # type: ignore
                run = tracer._convert_to_v1_run(run)
        self.runs.append(run)

    def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
        """Persist a tracing session."""
        return _persist_session(session)

    def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
        """Create a new tracing session."""
        return new_session(name)

    def load_session(self, session_name: str) -> TracerSessionV1:
        """Load a tracing session."""
        return load_session(session_name)

    def load_default_session(self) -> TracerSessionV1:
        """Load a tracing session."""
        return load_default_session()

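# The tests below freeze time so the datetime.utcnow() timestamps baked into
# the expected runs match the timestamps recorded by the tracer callbacks.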
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
    """Test tracer on an LLM run."""
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]

@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
    """Test tracer on a Chat Model run."""
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chat_model"},
        prompts=[""],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()
    tracer.new_session()
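    # Chat model callbacks are routed through a CallbackManager (with the
    # tracer registered as a handler) instead of invoking the tracer directly.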
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(
        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
    )
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]

@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
    """Test tracer on an LLM run without a start."""
    tracer = FakeTracer()
    tracer.new_session()
    with pytest.raises(TracerException):
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())


@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
    """Test the tracer with multiple runs."""
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()
    tracer.new_session()
    num_runs = 10
    for _ in range(num_runs):
        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run] * num_runs


@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
    """Test tracer on a Chain run."""
    uuid = uuid4()
    compare_run = ChainRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_end(outputs={}, run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
    """Test tracer on a Tool run."""
    uuid = uuid4()
    compare_run = ToolRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output="test",
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=None,
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_end("test", run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
    """Test tracer on a nested run."""
    tracer = FakeTracer()
    tracer.new_session()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    for _ in range(10):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_tool_end("test", run_id=tool_uuid)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_chain_end(outputs={}, run_id=chain_uuid)
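    # Expected tree: execution_order follows start order (chain=1, tool=2,
    # llm1=3, llm2=4); child_execution_order is the highest execution_order
    # reached within each run's subtree.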
    compare_run = ChainRun(
        uuid=str(chain_uuid),
        error=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=4,
        serialized={"name": "chain"},
        inputs={},
        outputs={},
        session_id=TEST_SESSION_ID,
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                uuid=str(tool_uuid),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=3,
                serialized={"name": "tool"},
                tool_input="test",
                output="test",
                action="{'name': 'tool'}",
                session_id=TEST_SESSION_ID,
                error=None,
                child_chain_runs=[],
                child_tool_runs=[],
                child_llm_runs=[
                    LLMRun(
                        uuid=str(llm_uuid1),
                        parent_uuid=str(tool_uuid),
                        error=None,
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=3,
                        child_execution_order=3,
                        serialized={"name": "llm"},
                        prompts=[],
                        response=LLMResult(generations=[[]]),
                        session_id=TEST_SESSION_ID,
                    )
                ],
            ),
        ],
        child_llm_runs=[
            LLMRun(
                uuid=str(llm_uuid2),
                parent_uuid=str(chain_uuid),
                error=None,
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=4,
                serialized={"name": "llm"},
                prompts=[],
                response=LLMResult(generations=[[]]),
                session_id=TEST_SESSION_ID,
            ),
        ],
    )
    assert tracer.runs[0] == compare_run
    assert tracer.runs == [compare_run] * 10


@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
    """Test tracer on an LLM run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = LLMRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "llm"},
        prompts=[],
        response=None,
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
    """Test tracer on a Chain run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = ChainRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "chain"},
        inputs={},
        outputs=None,
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
    tracer.on_chain_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
    """Test tracer on a Tool run with an error."""
    exception = Exception("test")
    uuid = uuid4()
    compare_run = ToolRun(
        uuid=str(uuid),
        parent_uuid=None,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=1,
        serialized={"name": "tool"},
        tool_input="test",
        output=None,
        action="{'name': 'tool'}",
        session_id=TEST_SESSION_ID,
        error=repr(exception),
    )
    tracer = FakeTracer()
    tracer.new_session()
    tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
    tracer.on_tool_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]


@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
    """Test tracer on a nested run with an error."""
    exception = Exception("test")
    tracer = FakeTracer()
    tracer.new_session()
    chain_uuid = uuid4()
    tool_uuid = uuid4()
    llm_uuid1 = uuid4()
    llm_uuid2 = uuid4()
    llm_uuid3 = uuid4()
    for _ in range(3):
        tracer.on_chain_start(
            serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid1,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid2,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
        tracer.on_tool_start(
            serialized={"name": "tool"},
            input_str="test",
            run_id=tool_uuid,
            parent_run_id=chain_uuid,
        )
        tracer.on_llm_start(
            serialized={"name": "llm"},
            prompts=[],
            run_id=llm_uuid3,
            parent_run_id=tool_uuid,
        )
        tracer.on_llm_error(exception, run_id=llm_uuid3)
        tracer.on_tool_error(exception, run_id=tool_uuid)
        tracer.on_chain_error(exception, run_id=chain_uuid)
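    # Start order is chain=1, llm1=2, llm2=3, tool=4, llm3=5; the failing llm3,
    # its parent tool, and the chain each record repr(exception) as their error.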
    compare_run = ChainRun(
        uuid=str(chain_uuid),
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        extra={},
        execution_order=1,
        child_execution_order=5,
        serialized={"name": "chain"},
        session_id=TEST_SESSION_ID,
        error=repr(exception),
        inputs={},
        outputs=None,
        child_llm_runs=[
            LLMRun(
                uuid=str(llm_uuid1),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=2,
                child_execution_order=2,
                serialized={"name": "llm"},
                session_id=TEST_SESSION_ID,
                error=None,
                prompts=[],
                response=LLMResult(generations=[[]], llm_output=None),
            ),
            LLMRun(
                uuid=str(llm_uuid2),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=3,
                child_execution_order=3,
                serialized={"name": "llm"},
                session_id=TEST_SESSION_ID,
                error=None,
                prompts=[],
                response=LLMResult(generations=[[]], llm_output=None),
            ),
        ],
        child_chain_runs=[],
        child_tool_runs=[
            ToolRun(
                uuid=str(tool_uuid),
                parent_uuid=str(chain_uuid),
                start_time=datetime.utcnow(),
                end_time=datetime.utcnow(),
                extra={},
                execution_order=4,
                child_execution_order=5,
                serialized={"name": "tool"},
                session_id=TEST_SESSION_ID,
                error=repr(exception),
                tool_input="test",
                output=None,
                action="{'name': 'tool'}",
                child_llm_runs=[
                    LLMRun(
                        uuid=str(llm_uuid3),
                        parent_uuid=str(tool_uuid),
                        start_time=datetime.utcnow(),
                        end_time=datetime.utcnow(),
                        extra={},
                        execution_order=5,
                        child_execution_order=5,
                        serialized={"name": "llm"},
                        session_id=TEST_SESSION_ID,
                        error=repr(exception),
                        prompts=[],
                        response=None,
                    )
                ],
                child_chain_runs=[],
                child_tool_runs=[],
            ),
        ],
    )
    assert tracer.runs == [compare_run] * 3


@pytest.fixture
def sample_tracer_session_v1() -> TracerSessionV1:
    return TracerSessionV1(id=2, name="Sample session")

@freeze_time("2023-01-01")
def test_convert_run(
    lang_chain_tracer_v1: LangChainTracerV1,
    sample_tracer_session_v1: TracerSessionV1,
) -> None:
    """Test converting a run to a V1 run."""
    llm_run = Run(
        id="57a08cc4-73d2-4236-8370-549099d07fad",
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        session_id=TEST_SESSION_ID,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]).dict(),
        serialized={},
        extra={},
        run_type=RunTypeEnum.llm,
    )
    chain_run = Run(
        id="57a08cc4-73d2-4236-8371-549099d07fad",
        name="chain_run",
        execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        child_runs=[llm_run],
        extra={},
        run_type=RunTypeEnum.chain,
    )
    tool_run = Run(
        id="57a08cc4-73d2-4236-8372-549099d07fad",
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        inputs={"input": "test"},
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        outputs=None,
        serialized={},
        child_runs=[],
        extra={},
        run_type=RunTypeEnum.tool,
    )
    expected_llm_run = LLMRun(
        uuid="57a08cc4-73d2-4236-8370-549099d07fad",
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        session_id=2,
        prompts=[],
        response=LLMResult(generations=[[]]),
        serialized={},
        extra={},
    )
    expected_chain_run = ChainRun(
        uuid="57a08cc4-73d2-4236-8371-549099d07fad",
        name="chain_run",
        execution_order=1,
        child_execution_order=1,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        session_id=2,
        serialized={},
        inputs={},
        outputs={},
        child_llm_runs=[expected_llm_run],
        child_chain_runs=[],
        child_tool_runs=[],
        extra={},
    )
    expected_tool_run = ToolRun(
        uuid="57a08cc4-73d2-4236-8372-549099d07fad",
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        session_id=2,
        start_time=datetime.utcnow(),
        end_time=datetime.utcnow(),
        tool_input="test",
        action="{}",
        serialized={},
        child_llm_runs=[],
        child_chain_runs=[],
        child_tool_runs=[],
        extra={},
    )
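    # _convert_to_v1_run takes the session id from the tracer's current
    # session, so the sample session (id=2) is attached before converting.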
    lang_chain_tracer_v1.session = sample_tracer_session_v1
    converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
    converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
    converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)

    assert isinstance(converted_llm_run, LLMRun)
    assert isinstance(converted_chain_run, ChainRun)
    assert isinstance(converted_tool_run, ToolRun)

    assert converted_llm_run == expected_llm_run
    assert converted_tool_run == expected_tool_run
    assert converted_chain_run == expected_chain_run