File size: 4,582 Bytes
cfd3735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Test Tracer classes."""
from __future__ import annotations

import json
from datetime import datetime
from typing import Tuple
from unittest.mock import patch
from uuid import UUID, uuid4

import pytest
from freezegun import freeze_time

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.schema import LLMResult

# Fixed identifiers so the sample session/tenant IDs are identical on every run.
_SESSION_ID = UUID(hex="4fbf7c55-2727-4711-8964-d821ed4d4e2a")
_TENANT_ID = UUID(hex="57a08cc4-73d2-4236-8378-549099d07fad")


@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
    """Build a LangChainTracer after pointing its environment at fake settings.

    The variables are set through ``monkeypatch`` so pytest undoes them when
    the test using this fixture finishes.
    """
    fake_environment = {
        "LANGCHAIN_TENANT_ID": "test-tenant-id",
        "LANGCHAIN_ENDPOINT": "http://test-endpoint.com",
        "LANGCHAIN_API_KEY": "foo",
    }
    # Set in the same order as before; the tracer is built only afterwards so
    # it can observe all three values.
    for variable, value in fake_environment.items():
        monkeypatch.setenv(variable, value)
    return LangChainTracer()


@pytest.fixture
def sample_tracer_session_v2() -> TracerSession:
    """Return a real (not mocked) TracerSession with the fixed module-level IDs."""
    session = TracerSession(
        id=_SESSION_ID,
        name="Sample session",
        tenant_id=_TENANT_ID,
    )
    return session


@pytest.fixture
@freeze_time("2023-01-01")
def sample_runs() -> Tuple[Run, Run, Run]:
    """Return an ``(llm_run, chain_run, tool_run)`` triple of finished runs.

    ``freeze_time`` must sit *below* ``@pytest.fixture``. When stacked above
    it, the freezegun wrapper copies pytest's ``__pytest_wrapped__`` marker.
    Pytest then unwraps straight to the undecorated function, so the clock
    was never actually frozen. Newer pytest rejects that ordering outright.

    ``llm_run`` names ``chain_run`` as its parent (``parent_run_id``) and
    ``chain_run`` lists ``llm_run`` as its only child; ``tool_run`` has no
    parent and no children.
    """
    # Read the frozen clock once; every start/end timestamp below is identical.
    now = datetime.utcnow()
    llm_run = Run(
        id="57a08cc4-73d2-4236-8370-549099d07fad",
        name="llm_run",
        execution_order=1,
        child_execution_order=1,
        parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
        start_time=now,
        end_time=now,
        session_id=1,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]).dict(),
        serialized={},
        extra={},
        run_type=RunTypeEnum.llm,
    )
    chain_run = Run(
        id="57a08cc4-73d2-4236-8371-549099d07fad",
        name="chain_run",
        execution_order=1,
        start_time=now,
        end_time=now,
        child_execution_order=1,
        serialized={},
        inputs={},
        outputs={},
        child_runs=[llm_run],
        extra={},
        run_type=RunTypeEnum.chain,
    )

    tool_run = Run(
        id="57a08cc4-73d2-4236-8372-549099d07fad",
        name="tool_run",
        execution_order=1,
        child_execution_order=1,
        inputs={"input": "test"},
        start_time=now,
        end_time=now,
        outputs=None,
        serialized={},
        child_runs=[],
        extra={},
        run_type=RunTypeEnum.tool,
    )
    return llm_run, chain_run, tool_run


def test_persist_run(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Ending the three sample runs should issue three POSTs and no GETs.

    Every run is registered in the tracer's ``run_map`` before any of them is
    ended through ``_end_trace``; HTTP traffic is intercepted with mocks.
    """
    tracer = lang_chain_tracer_v2
    post_target = "langchain.callbacks.tracers.langchain.requests.post"
    get_target = "langchain.callbacks.tracers.langchain.requests.get"
    with patch(post_target) as mock_post:
        with patch(get_target) as mock_get:
            mock_post.return_value.raise_for_status.return_value = None
            tracer.session = sample_tracer_session_v2
            # Register all runs first so each one is known when any is ended.
            tracer.run_map.update({str(run.id): run for run in sample_runs})
            for finished in sample_runs:
                tracer._end_trace(finished)

            assert mock_post.call_count == 3
            assert mock_get.call_count == 0


def test_persist_run_with_example_id(
    lang_chain_tracer_v2: LangChainTracer,
    sample_tracer_session_v2: TracerSession,
    sample_runs: Tuple[Run, Run, Run],
) -> None:
    """Only the root run of a persisted tree carries ``reference_example_id``.

    The sample runs are re-linked into a chain -> tool -> llm tree. Persisting
    the chain must POST all three runs, root first, and the example ID may
    appear on the root payload only.
    """
    llm_run, chain_run, tool_run = sample_runs
    chain_run.child_runs = [tool_run]
    tool_run.child_runs = [llm_run]
    example_id = uuid4()

    tracer = lang_chain_tracer_v2
    post_target = "langchain.callbacks.tracers.langchain.requests.post"
    get_target = "langchain.callbacks.tracers.langchain.requests.get"
    with patch(post_target) as mock_post:
        with patch(get_target) as mock_get:
            mock_post.return_value.raise_for_status.return_value = None
            tracer.session = sample_tracer_session_v2
            tracer.example_id = example_id
            tracer._persist_run(chain_run)

            assert mock_post.call_count == 3
            assert mock_get.call_count == 0
            # Each recorded call unpacks to (args, kwargs); the body is in kwargs.
            payloads = [
                json.loads(kwargs["data"])
                for _, kwargs in mock_post.call_args_list
            ]
            expected_order = [str(chain_run.id), str(tool_run.id), str(llm_run.id)]
            assert [payload["id"] for payload in payloads] == expected_order
            root, *descendants = payloads
            assert root["reference_example_id"] == str(example_id)
            assert not any(
                child.get("reference_example_id") for child in descendants
            )