Spaces:
Sleeping
Sleeping
"""Test Tracer classes.""" | |
from __future__ import annotations | |
import json | |
from datetime import datetime | |
from typing import Tuple | |
from unittest.mock import patch | |
from uuid import UUID, uuid4 | |
import pytest | |
from freezegun import freeze_time | |
from langchain.callbacks.tracers.langchain import LangChainTracer | |
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession | |
from langchain.schema import LLMResult | |
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a") | |
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad") | |
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer: | |
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id") | |
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com") | |
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo") | |
tracer = LangChainTracer() | |
return tracer | |
# Mock a sample TracerSession object | |
def sample_tracer_session_v2() -> TracerSession: | |
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID) | |
def sample_runs() -> Tuple[Run, Run, Run]: | |
llm_run = Run( | |
id="57a08cc4-73d2-4236-8370-549099d07fad", | |
name="llm_run", | |
execution_order=1, | |
child_execution_order=1, | |
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad", | |
start_time=datetime.utcnow(), | |
end_time=datetime.utcnow(), | |
session_id=1, | |
inputs={"prompts": []}, | |
outputs=LLMResult(generations=[[]]).dict(), | |
serialized={}, | |
extra={}, | |
run_type=RunTypeEnum.llm, | |
) | |
chain_run = Run( | |
id="57a08cc4-73d2-4236-8371-549099d07fad", | |
name="chain_run", | |
execution_order=1, | |
start_time=datetime.utcnow(), | |
end_time=datetime.utcnow(), | |
child_execution_order=1, | |
serialized={}, | |
inputs={}, | |
outputs={}, | |
child_runs=[llm_run], | |
extra={}, | |
run_type=RunTypeEnum.chain, | |
) | |
tool_run = Run( | |
id="57a08cc4-73d2-4236-8372-549099d07fad", | |
name="tool_run", | |
execution_order=1, | |
child_execution_order=1, | |
inputs={"input": "test"}, | |
start_time=datetime.utcnow(), | |
end_time=datetime.utcnow(), | |
outputs=None, | |
serialized={}, | |
child_runs=[], | |
extra={}, | |
run_type=RunTypeEnum.tool, | |
) | |
return llm_run, chain_run, tool_run | |
def test_persist_run( | |
lang_chain_tracer_v2: LangChainTracer, | |
sample_tracer_session_v2: TracerSession, | |
sample_runs: Tuple[Run, Run, Run], | |
) -> None: | |
"""Test that persist_run method calls requests.post once per method call.""" | |
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( | |
"langchain.callbacks.tracers.langchain.requests.get" | |
) as get: | |
post.return_value.raise_for_status.return_value = None | |
lang_chain_tracer_v2.session = sample_tracer_session_v2 | |
for run in sample_runs: | |
lang_chain_tracer_v2.run_map[str(run.id)] = run | |
for run in sample_runs: | |
lang_chain_tracer_v2._end_trace(run) | |
assert post.call_count == 3 | |
assert get.call_count == 0 | |
def test_persist_run_with_example_id( | |
lang_chain_tracer_v2: LangChainTracer, | |
sample_tracer_session_v2: TracerSession, | |
sample_runs: Tuple[Run, Run, Run], | |
) -> None: | |
"""Test the example ID is assigned only to the parent run and not the children.""" | |
example_id = uuid4() | |
llm_run, chain_run, tool_run = sample_runs | |
chain_run.child_runs = [tool_run] | |
tool_run.child_runs = [llm_run] | |
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch( | |
"langchain.callbacks.tracers.langchain.requests.get" | |
) as get: | |
post.return_value.raise_for_status.return_value = None | |
lang_chain_tracer_v2.session = sample_tracer_session_v2 | |
lang_chain_tracer_v2.example_id = example_id | |
lang_chain_tracer_v2._persist_run(chain_run) | |
assert post.call_count == 3 | |
assert get.call_count == 0 | |
posted_data = [ | |
json.loads(call_args[1]["data"]) for call_args in post.call_args_list | |
] | |
assert posted_data[0]["id"] == str(chain_run.id) | |
assert posted_data[0]["reference_example_id"] == str(example_id) | |
assert posted_data[1]["id"] == str(tool_run.id) | |
assert not posted_data[1].get("reference_example_id") | |
assert posted_data[2]["id"] == str(llm_run.id) | |
assert not posted_data[2].get("reference_example_id") | |