Spaces:
Sleeping
Sleeping
File size: 4,582 Bytes
cfd3735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
"""Test Tracer classes."""
from __future__ import annotations
import json
from datetime import datetime
from typing import Tuple
from unittest.mock import patch
from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.schema import LLMResult
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracer()
return tracer
# Mock a sample TracerSession object
@pytest.fixture
def sample_tracer_session_v2() -> TracerSession:
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
@freeze_time("2023-01-01")
@pytest.fixture
def sample_runs() -> Tuple[Run, Run, Run]:
llm_run = Run(
id="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
session_id=1,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type=RunTypeEnum.llm,
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
child_execution_order=1,
serialized={},
inputs={},
outputs={},
child_runs=[llm_run],
extra={},
run_type=RunTypeEnum.chain,
)
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
inputs={"input": "test"},
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
outputs=None,
serialized={},
child_runs=[],
extra={},
run_type=RunTypeEnum.tool,
)
return llm_run, chain_run, tool_run
def test_persist_run(
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test that persist_run method calls requests.post once per method call."""
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
for run in sample_runs:
lang_chain_tracer_v2.run_map[str(run.id)] = run
for run in sample_runs:
lang_chain_tracer_v2._end_trace(run)
assert post.call_count == 3
assert get.call_count == 0
def test_persist_run_with_example_id(
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test the example ID is assigned only to the parent run and not the children."""
example_id = uuid4()
llm_run, chain_run, tool_run = sample_runs
chain_run.child_runs = [tool_run]
tool_run.child_runs = [llm_run]
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
lang_chain_tracer_v2.example_id = example_id
lang_chain_tracer_v2._persist_run(chain_run)
assert post.call_count == 3
assert get.call_count == 0
posted_data = [
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
]
assert posted_data[0]["id"] == str(chain_run.id)
assert posted_data[0]["reference_example_id"] == str(example_id)
assert posted_data[1]["id"] == str(tool_run.id)
assert not posted_data[1].get("reference_example_id")
assert posted_data[2]["id"] == str(llm_run.id)
assert not posted_data[2].get("reference_example_id")
|