Spaces:
Sleeping
Sleeping
File size: 1,785 Bytes
cfd3735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
"""Test base tool child implementations."""
import inspect
import re
from typing import List, Type
import pytest
from langchain.tools.base import BaseTool
from langchain.tools.gmail.base import GmailBaseTool
from langchain.tools.playwright.base import BaseBrowserTool
def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
    """Collect the concrete tool classes found anywhere below ``cls``.

    The subclass tree is walked depth-first via ``__subclasses__()``. A
    subclass is reported unless it has a truthy ``__abstract__`` attribute,
    its name starts with an underscore, or it is one of the explicitly
    excluded base classes. Descendants of an excluded subclass are still
    visited and may be reported themselves.

    Args:
        cls: The class whose subclass tree is searched.

    Returns:
        The accepted subclasses, each parent listed before its descendants.
    """
    # Abstract but not recognized
    excluded = {BaseBrowserTool, GmailBaseTool}

    found: List[Type[BaseTool]] = []
    for child in cls.__subclasses__():
        rejected = (
            getattr(child, "__abstract__", None)
            or child.__name__.startswith("_")
            or child in excluded
        )
        if not rejected:
            found.append(child)
        # Recurse regardless of the verdict above: a rejected class can
        # still have concrete descendants.
        found += get_non_abstract_subclasses(child)
    return found
@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool))  # type: ignore
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
    """Test that tools defined in this repo accept a run manager argument.

    For each overridden ``_run`` / ``_arun`` the signature must contain a
    ``run_manager`` parameter that defaults to ``None`` and is annotated with
    the matching callback manager type (sync for ``_run``, async for
    ``_arun``).
    """
    # This wouldn't be necessary if the BaseTool had a strict API.
    # Compare against BaseTool._run (not _arun) so the guard really detects
    # whether the subclass supplies its own synchronous implementation.
    if cls._run is not BaseTool._run:
        params = inspect.signature(cls._run).parameters
        assert "run_manager" in params
        # Negative *lookbehind*: the sync manager name must not be preceded by
        # "Async". A lookahead here would be vacuous and would let
        # AsyncCallbackManagerForToolRun pass the sync check.
        pattern = re.compile(r"(?<!Async)CallbackManagerForToolRun")
        assert bool(re.search(pattern, str(params["run_manager"].annotation)))
        assert params["run_manager"].default is None

    if cls._arun is not BaseTool._arun:
        params = inspect.signature(cls._arun).parameters
        assert "run_manager" in params
        assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation)
        assert params["run_manager"].default is None
|