Spaces:
Running
Running
"""Test base tool child implementations.""" | |
import inspect | |
import re | |
from typing import List, Type | |
import pytest | |
from langchain.tools.base import BaseTool | |
from langchain.tools.gmail.base import GmailBaseTool | |
from langchain.tools.playwright.base import BaseBrowserTool | |
def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
    """Collect the descendants of ``cls`` that should be checked, depth-first.

    The whole subclass tree below ``cls`` is walked in pre-order: each class
    is visited before its own subclasses, and siblings are visited in
    ``__subclasses__()`` order. A class is included unless one of these holds:

    * it has a truthy ``__abstract__`` attribute;
    * its name starts with an underscore;
    * it is one of the hand-listed bases below.

    Excluded classes are still descended into, so their subclasses can be
    collected.

    Args:
        cls: Root class whose descendants are collected (``cls`` itself is
            never included).

    Returns:
        The collected classes, in depth-first pre-order.
    """
    # Abstract in practice, but not flagged via ``__abstract__``, so they are
    # excluded by hand.
    excluded = {BaseBrowserTool, GmailBaseTool}
    found: List[Type[BaseTool]] = []
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they pop in their original ``__subclasses__()`` order, which preserves
    # the pre-order traversal.
    pending = list(reversed(cls.__subclasses__()))
    while pending:
        candidate = pending.pop()
        if not (
            getattr(candidate, "__abstract__", None)
            or candidate.__name__.startswith("_")
            or candidate in excluded
        ):
            found.append(candidate)
        pending.extend(reversed(candidate.__subclasses__()))
    return found
@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool))  # type: ignore
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
    """Test that tools defined in this repo accept a run manager argument.

    Runs once for every class returned by
    ``get_non_abstract_subclasses(BaseTool)``. The checks apply to each of
    ``_run`` and ``_arun`` that the class overrides (relative to
    ``BaseTool``). The override must:

    * declare a ``run_manager`` parameter;
    * default that parameter to ``None``;
    * annotate it with the matching callback manager. That is the sync
      ``CallbackManagerForToolRun`` (not the async one) for ``_run``, and
      ``AsyncCallbackManagerForToolRun`` for ``_arun``.

    Args:
        cls: The tool class under test, supplied by ``parametrize``.
    """
    # This wouldn't be necessary if the BaseTool had a strict API.
    # Compare against the *sync* base method: comparing ``_run`` with
    # ``BaseTool._arun`` would always be true and never skip anything.
    if cls._run is not BaseTool._run:
        params = inspect.signature(cls._run).parameters
        assert "run_manager" in params
        # Negative *lookbehind*: reject "AsyncCallbackManagerForToolRun".
        # A lookahead at this position would be a no-op, because the text
        # at the match start is always "Callback...".
        pattern = re.compile(r"(?<!Async)CallbackManagerForToolRun")
        assert pattern.search(str(params["run_manager"].annotation))
        assert params["run_manager"].default is None

    if cls._arun is not BaseTool._arun:
        params = inspect.signature(cls._arun).parameters
        assert "run_manager" in params
        annotation = str(params["run_manager"].annotation)
        assert "AsyncCallbackManagerForToolRun" in annotation
        assert params["run_manager"].default is None