Spaces:
Running
Running
"""Test the LangChain+ client.""" | |
import uuid | |
from datetime import datetime | |
from io import BytesIO | |
from typing import Any, Dict, List, Optional, Union | |
from unittest import mock | |
import pytest | |
from langchain.base_language import BaseLanguageModel | |
from langchain.callbacks.tracers.langchain import LangChainTracer | |
from langchain.callbacks.tracers.schemas import TracerSession | |
from langchain.chains.base import Chain | |
from langchain.client.langchain import ( | |
LangChainPlusClient, | |
_get_link_stem, | |
_is_localhost, | |
) | |
from langchain.client.models import Dataset, Example | |
# Fixed, naive timestamp and tenant id shared by every fixture in this module,
# so the objects the tests build have deterministic field values.
_CREATED_AT = datetime(year=2015, month=1, day=1)
_TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
# The test takes ``api_url`` / ``expected_url`` as arguments, so it must be
# parametrized; without this decorator pytest reports
# "fixture 'api_url' not found" instead of running it.
# NOTE(review): expected values assume _get_link_stem keeps only scheme + host
# (dropping port and path) — confirm against langchain.client.langchain.
@pytest.mark.parametrize(
    "api_url, expected_url",
    [
        ("http://localhost:8000", "http://localhost"),
        ("http://www.example.com", "http://www.example.com"),
        (
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
            "https://hosted-1234-23qwerty.f.234.foobar.gateway.dev",
        ),
        ("https://www.langchain.com/path/to/nowhere", "https://www.langchain.com"),
    ],
)
def test_link_split(api_url: str, expected_url: str) -> None:
    """Test the link splitting handles both localhost and deployed urls."""
    assert _get_link_stem(api_url) == expected_url
def test_is_localhost() -> None:
    """Loopback and wildcard addresses count as local; a public host does not."""
    local_urls = (
        "http://localhost:8000",
        "http://127.0.0.1:8000",
        "http://0.0.0.0:8000",
    )
    for url in local_urls:
        assert _is_localhost(url)

    remote_url = "http://example.com:8000"
    assert not _is_localhost(remote_url)
def test_validate_api_key_if_hosted() -> None:
    """A non-localhost API URL without a key is rejected; localhost is accepted.

    ``_get_seeded_tenant_id`` is replaced by a stub returning the fixed test
    tenant so that constructing the client does not depend on that lookup.
    """

    def fake_seeded_tenant(api_url: str, api_key: Optional[str]) -> str:
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=fake_seeded_tenant
    ):
        with pytest.raises(ValueError, match="API key must be provided"):
            LangChainPlusClient(api_url="http://www.example.com")

        local_client = LangChainPlusClient(api_url="http://localhost:8000")
        assert local_client.api_url == "http://localhost:8000"
        assert local_client.api_key is None
def test_headers() -> None:
    """The client's headers carry a bearer token only when an API key is set."""

    def fake_seeded_tenant(api_url: str, api_key: Optional[str]) -> str:
        # Stub for the tenant-id lookup; always yields the fixed test tenant.
        return _TENANT_ID

    with mock.patch.object(
        LangChainPlusClient, "_get_seeded_tenant_id", new=fake_seeded_tenant
    ):
        keyed = LangChainPlusClient(api_url="http://localhost:8000", api_key="123")
        assert keyed._headers == {"authorization": "Bearer 123"}

        anonymous = LangChainPlusClient(api_url="http://localhost:8000")
        assert anonymous._headers == {}
# ``mock_post`` must be injected by a patch decorator; without it pytest fails
# with "fixture 'mock_post' not found".
# NOTE(review): patch target assumes the client module calls ``requests.post``
# for the upload — confirm against langchain.client.langchain.
@mock.patch("langchain.client.langchain.requests.post")
def test_upload_csv(mock_post: mock.Mock) -> None:
    """upload_csv builds a Dataset from the JSON body of the mocked post response."""
    mock_response = mock.Mock()
    dataset_id = str(uuid.uuid4())
    example_1 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "1"},
        outputs={"output": "2"},
        dataset_id=dataset_id,
    )
    example_2 = Example(
        id=str(uuid.uuid4()),
        created_at=_CREATED_AT,
        inputs={"input": "3"},
        outputs={"output": "4"},
        dataset_id=dataset_id,
    )
    # Payload the mocked endpoint "returns"; mirrors the CSV rows below.
    mock_response.json.return_value = {
        "id": dataset_id,
        "name": "test.csv",
        "description": "Test dataset",
        "owner_id": "the owner",
        "created_at": _CREATED_AT,
        "examples": [example_1, example_2],
        "tenant_id": _TENANT_ID,
    }
    mock_post.return_value = mock_response

    client = LangChainPlusClient(
        api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
    )
    csv_file = ("test.csv", BytesIO(b"input,output\n1,2\n3,4\n"))
    dataset = client.upload_csv(
        csv_file, "Test dataset", input_keys=["input"], output_keys=["output"]
    )

    assert dataset.id == uuid.UUID(dataset_id)
    assert dataset.name == "test.csv"
    assert dataset.description == "Test dataset"
# Without an asyncio marker (or ``asyncio_mode = auto``), pytest does not run
# ``async def`` tests: it skips them with a warning, so this test never ran.
@pytest.mark.asyncio
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
    """arun_on_dataset runs every example and returns results keyed by example id.

    Dataset reads, example listing, per-example execution and tracer session
    setup are all patched out, so only the fan-out/aggregation is exercised.
    """
    dataset = Dataset(
        id=uuid.uuid4(),
        name="test",
        description="Test dataset",
        owner_id="owner",
        created_at=_CREATED_AT,
        tenant_id=_TENANT_ID,
    )
    uuids = [
        "0c193153-2309-4704-9a47-17aee4fb25c8",
        "0d11b5fd-8e66-4485-b696-4b55155c0c05",
        "90d696f0-f10d-4fd0-b88b-bfee6df08b84",
        "4ce2c6d8-5124-4c0c-8292-db7bdebcf167",
        "7b5a524c-80fa-4960-888e-7d380f9a11ee",
    ]
    # Example i maps input str(2i + 1) -> output str(2i + 2): "1"->"2", ..., "9"->"10".
    examples = [
        Example(
            id=example_id,
            created_at=_CREATED_AT,
            inputs={"input": str(2 * i + 1)},
            outputs={"output": str(2 * i + 2)},
            dataset_id=str(uuid.uuid4()),
        )
        for i, example_id in enumerate(uuids)
    ]

    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
        return dataset

    def mock_list_examples(*args: Any, **kwargs: Any) -> List[Example]:
        return examples

    async def mock_arun_chain(
        example: Example,
        tracer: Any,
        llm_or_chain: Union[BaseLanguageModel, Chain],
        n_repetitions: int,
    ) -> List[Dict[str, Any]]:
        # One deterministic result per repetition, tagged with the example id.
        return [
            {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
        ]

    def mock_ensure_session(self: Any, *args: Any, **kwargs: Any) -> TracerSession:
        return TracerSession(name="test_session", tenant_id=_TENANT_ID, id=uuid.uuid4())

    with mock.patch.object(
        LangChainPlusClient, "read_dataset", new=mock_read_dataset
    ), mock.patch.object(
        LangChainPlusClient, "list_examples", new=mock_list_examples
    ), mock.patch.object(
        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
    ), mock.patch.object(
        LangChainTracer, "ensure_session", new=mock_ensure_session
    ):
        monkeypatch.setenv("LANGCHAIN_TENANT_ID", _TENANT_ID)
        client = LangChainPlusClient(
            api_url="http://localhost:8000", api_key="123", tenant_id=_TENANT_ID
        )
        chain = mock.MagicMock()
        num_repetitions = 3
        results = await client.arun_on_dataset(
            dataset_name="test",
            llm_or_chain_factory=lambda: chain,
            concurrency_level=2,
            session_name="test_session",
            num_repetitions=num_repetitions,
        )
        expected = {
            uuid_: [
                {"result": f"Result for example {uuid.UUID(uuid_)}"}
                for _ in range(num_repetitions)
            ]
            for uuid_ in uuids
        }
        assert results == expected