File size: 1,785 Bytes
cfd3735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Test base tool child implementations."""


import inspect
import re
from typing import List, Type

import pytest

from langchain.tools.base import BaseTool
from langchain.tools.gmail.base import GmailBaseTool
from langchain.tools.playwright.base import BaseBrowserTool


def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
    """Return the concrete subclasses of ``cls``, searched recursively.

    A subclass is included unless it sets a truthy ``__abstract__`` attribute,
    its name starts with ``_``, or it is one of the known abstract bases listed
    in ``to_skip``. Descendants of excluded classes are still searched.

    The result is in depth-first pre-order and lists each class at most once.
    A class reachable through several parents (multiple inheritance) would
    otherwise be collected once per path and yield duplicate test cases.
    """
    # Abstract in practice, but not recognized as such by the checks below.
    to_skip = {BaseBrowserTool, GmailBaseTool}
    subclasses: List[Type[BaseTool]] = []
    for subclass in cls.__subclasses__():
        if (
            not getattr(subclass, "__abstract__", None)
            and not subclass.__name__.startswith("_")
            and subclass not in to_skip
        ):
            subclasses.append(subclass)
        subclasses.extend(get_non_abstract_subclasses(subclass))
    # dict keys keep insertion order: this drops repeats while keeping each
    # class at its first (pre-order) position.
    return list(dict.fromkeys(subclasses))


@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool))  # type: ignore
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
    """Test that tools defined in this repo accept a run manager argument.

    For each of ``_run`` / ``_arun`` that ``cls`` overrides, the method must
    take a ``run_manager`` parameter that defaults to ``None``. Its annotation
    must mention the matching callback manager: ``CallbackManagerForToolRun``
    (not the async variant) for ``_run``, and
    ``AsyncCallbackManagerForToolRun`` for ``_arun``.
    """
    # This wouldn't be necessary if the BaseTool had a strict API.
    # Only check methods the subclass actually overrides: compare each one
    # against the *same-named* attribute on BaseTool.
    if cls._run is not BaseTool._run:
        params = inspect.signature(cls._run).parameters
        assert "run_manager" in params
        # The async manager's name ends with the sync one's, so a negative
        # lookbehind is needed to reject "AsyncCallbackManagerForToolRun".
        # A lookahead placed before the literal would never fail.
        sync_pattern = re.compile(r"(?<!Async)CallbackManagerForToolRun")
        annotation = str(params["run_manager"].annotation)
        assert sync_pattern.search(annotation) is not None
        assert params["run_manager"].default is None

    if cls._arun is not BaseTool._arun:
        params = inspect.signature(cls._arun).parameters
        assert "run_manager" in params
        assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation)
        assert params["run_manager"].default is None