# Spaces: Running
| import inspect | |
| import logging | |
| import uuid | |
| from datetime import date, datetime, time | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional, Set, Type, Union, get_type_hints | |
| from browser_use.controller.registry.views import ActionModel | |
| from langchain.tools import BaseTool | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from pydantic import BaseModel, Field, create_model | |
| from pydantic.v1 import BaseModel, Field | |
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
async def setup_mcp_client_and_tools(mcp_server_config: Dict[str, Any]) -> Optional[MultiServerMCPClient]:
    """Create a ``MultiServerMCPClient`` for the configured servers and enter its async context.

    Despite the function name, no tools are fetched or filtered here; only the
    client instance is returned.

    Args:
        mcp_server_config: Mapping of MCP server names to their connection
            settings. A config wrapped in a top-level ``"mcpServers"`` key is
            unwrapped first.

    Returns:
        The client once ``await client.__aenter__()`` has completed, or ``None``
        when no server configuration is given or setup raises (the exception is
        logged, not propagated). The client is not closed here; the caller owns
        its lifetime.
    """
    logger.info("Initializing MultiServerMCPClient...")
    if not mcp_server_config:
        logger.error("No MCP server configuration provided.")
        return None

    servers = mcp_server_config
    if "mcpServers" in servers:
        servers = servers["mcpServers"]
    # An empty ``{"mcpServers": {}}`` wrapper is as unusable as an empty config.
    if not servers:
        logger.error("No MCP server configuration provided.")
        return None

    try:
        client = MultiServerMCPClient(servers)
        # Entered manually (not via ``async with``) because the started client is
        # returned to the caller.
        # NOTE(review): some langchain-mcp-adapters releases may not support the
        # async context-manager protocol; confirm against the pinned version.
        await client.__aenter__()
    except Exception as e:
        logger.error(f"Failed to setup MCP client or fetch tools: {e}", exc_info=True)
        return None
    return client
# JSON-schema validation keywords and the Pydantic v2 ``Field`` kwargs that enforce them.
_JSON_SCHEMA_CONSTRAINTS = (
    ('minimum', 'ge'),
    ('maximum', 'le'),
    ('exclusiveMinimum', 'gt'),
    ('exclusiveMaximum', 'lt'),
    ('multipleOf', 'multiple_of'),
    ('minLength', 'min_length'),
    ('maxLength', 'max_length'),
    ('minItems', 'min_length'),
    ('maxItems', 'max_length'),
    ('pattern', 'pattern'),
)


def _tool_args_json_schema(args_schema: Any) -> Optional[Dict[str, Any]]:
    """Return a tool's ``args_schema`` as a JSON-schema dict, or None if none is usable."""
    if args_schema is None or isinstance(args_schema, dict):
        return args_schema
    # LangChain also allows a Pydantic model class here (v2 or v1 API).
    if hasattr(args_schema, 'model_json_schema'):
        return args_schema.model_json_schema()
    if hasattr(args_schema, 'schema'):
        return args_schema.schema()
    return None


def _param_fields_from_json_schema(json_schema: Dict[str, Any], tool_name: str) -> Dict[str, Any]:
    """Translate a JSON object schema's ``properties`` into ``create_model`` field definitions."""
    # The module-level ``Field`` is pydantic.v1's (it shadows the v2 import); the
    # parameter model is built with Pydantic v2 ``create_model``, so use v2 ``Field``.
    from pydantic import Field

    required_fields: Set[str] = set(json_schema.get('required') or [])
    params: Dict[str, Any] = {}
    for prop_name, prop_details in (json_schema.get('properties') or {}).items():
        field_type = resolve_type(prop_details, f"{tool_name}_{prop_name}")
        is_required = prop_name in required_fields
        # ``...`` marks the field as required for Pydantic.
        field_kwargs: Dict[str, Any] = {
            'default': prop_details.get('default', ... if is_required else None),
        }
        description = prop_details.get('description', '')
        if description:
            field_kwargs['description'] = description
        for schema_key, field_key in _JSON_SCHEMA_CONSTRAINTS:
            value = prop_details.get(schema_key)
            # Draft-4 schemas use booleans for exclusiveMinimum/Maximum; that form
            # carries no bound of its own and is not translated.
            if value is None or isinstance(value, bool):
                continue
            field_kwargs[field_key] = value
        params[prop_name] = (field_type, Field(**field_kwargs))
    return params


def _param_fields_from_signature(run_method: Any) -> Dict[str, Any]:
    """Infer ``create_model`` field definitions from the parameters of a tool's ``_run``."""
    try:
        type_hints = get_type_hints(run_method)
    except Exception:
        # e.g. unresolvable forward references: fall back to raw signature annotations.
        type_hints = {}

    params: Dict[str, Any] = {}
    for name, param in inspect.signature(run_method).parameters.items():
        # ``run_manager`` is injected by LangChain rather than supplied by the caller,
        # and *args / **kwargs cannot be expressed as model fields.
        if name in ('self', 'run_manager') or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        annotation = type_hints.get(name, param.annotation)
        if annotation is inspect.Parameter.empty:
            annotation = Any
        # No default in the signature -> required field (``...``).
        default = ... if param.default is inspect.Parameter.empty else param.default
        params[name] = (annotation, default)
    return params


def create_tool_param_model(tool: BaseTool) -> Type[BaseModel]:
    """Create a Pydantic parameter model (an ``ActionModel`` subclass) for a LangChain tool.

    The model is named ``f"{tool.name}_parameters"``. Fields come from
    ``tool.args_schema`` when it is set: either a JSON-schema dict, or a Pydantic
    model class, which is converted to its JSON schema first. Otherwise fields are
    inferred from the signature and type hints of ``tool._run``.

    Args:
        tool: The LangChain tool to describe.

    Returns:
        A dynamically created model class deriving from ``ActionModel``.
    """
    json_schema = _tool_args_json_schema(tool.args_schema)
    if json_schema is not None:
        params = _param_fields_from_json_schema(json_schema, tool.name)
    else:
        params = _param_fields_from_signature(tool._run)
    return create_model(
        f'{tool.name}_parameters',
        __base__=ActionModel,
        **params,  # type: ignore
    )
def _enum_from_values(values: List[Any], prefix: str) -> Any:
    """Build an ``Enum`` named ``f"{prefix}_Enum"`` holding ``values``; ``str`` if there are none."""
    members: Dict[str, Any] = {}
    for i, value in enumerate(values):
        name = f"VALUE_{i}"
        if isinstance(value, str):
            candidate = value.upper().replace(' ', '_').replace('-', '_')
            # Leading-underscore names are reserved (_sunder_) or ignored (__dunder__)
            # by Enum, and two values may normalise to the same name; such values
            # keep the positional fallback name instead.
            if candidate.isidentifier() and not candidate.startswith('_') and candidate not in members:
                name = candidate
        # The positional fallback can still clash with an earlier value such as "value_1".
        while name in members:
            name += '_'
        members[name] = value
    if not members:
        return str
    return Enum(f"{prefix}_Enum", members)


def _build_nested_model(model_name: str, entries: List[Any]) -> Any:
    """Create a Pydantic v2 model from ``(name, schema, is_required, type_prefix)`` entries."""
    # The module-level ``Field`` is pydantic.v1's (it shadows the v2 import); the
    # model is built with Pydantic v2 ``create_model``, so use the matching v2 ``Field``.
    from pydantic import Field, create_model

    fields: Dict[str, Any] = {}
    for name, details, is_required, type_prefix in entries:
        field_type = resolve_type(details, type_prefix)
        meta = details if isinstance(details, dict) else {}
        # ``...`` marks the field as required for Pydantic.
        field_kwargs: Dict[str, Any] = {'default': meta.get('default', ... if is_required else None)}
        description = meta.get('description', '')
        if description:
            field_kwargs['description'] = description
        # A property repeated across allOf parts: the later definition wins.
        fields[name] = (field_type, Field(**field_kwargs))
    return create_model(model_name, **fields)


def resolve_type(prop_details: Dict[str, Any], prefix: str = "") -> Any:
    """Recursively resolve a JSON-schema fragment to a Python/Pydantic type.

    Checks are applied in order:

    1. ``$ref``: unsupported, so ``Any``.
    2. Formatted strings: ``datetime``, ``uuid.UUID``, and similar.
    3. ``enum``: a dynamically created ``Enum``.
    4. Typed arrays: ``List[...]``.
    5. Objects with ``properties``: a nested Pydantic model.
    6. ``oneOf``/``anyOf``: a ``Union``.
    7. ``allOf``: a composite model merging every part's properties.
    8. The plain ``type`` keyword, which defaults to ``"string"`` when absent.
       A list of types becomes a ``Union``, wrapped in ``Optional`` if it
       includes ``"null"``.

    Args:
        prop_details: JSON-schema fragment describing a single value.
        prefix: Name prefix for generated Enum/model classes, so nested generated
            types get distinct names.

    Returns:
        A type usable as a Pydantic field annotation.
    """
    # Boolean schemas (``true``/``false``) and other non-mapping fragments carry no type.
    if not isinstance(prop_details, dict):
        return Any

    if '$ref' in prop_details:
        # $ref resolution is not implemented; accept any value.
        return Any

    type_mapping = {
        'string': str,
        'integer': int,
        'number': float,
        'boolean': bool,
        'array': List,
        'object': Dict,
        'null': type(None),
    }
    declared_type = prop_details.get('type')

    if declared_type == 'string' and 'format' in prop_details:
        format_mapping = {
            'date-time': datetime,
            'date': date,
            'time': time,
            'email': str,
            'uri': str,
            'url': str,
            'uuid': uuid.UUID,
            'binary': bytes,
        }
        return format_mapping.get(prop_details['format'], str)

    if 'enum' in prop_details:
        return _enum_from_values(prop_details['enum'], prefix)

    if declared_type == 'array' and 'items' in prop_details:
        items = prop_details['items']
        if not isinstance(items, dict):
            # Tuple-form ``items`` (a list of schemas) is not modelled element-wise.
            return List
        return List[resolve_type(items, f"{prefix}_item")]  # type: ignore

    if declared_type == 'object' and 'properties' in prop_details:
        required = set(prop_details.get('required') or [])
        entries = [
            (name, details, name in required, f"{prefix}_{name}")
            for name, details in prop_details['properties'].items()
        ]
        return _build_nested_model(f"{prefix}_Model", entries)

    if 'oneOf' in prop_details or 'anyOf' in prop_details:
        # ``or []`` guards against an empty ``oneOf`` with no ``anyOf`` alongside it.
        variants = prop_details.get('oneOf') or prop_details.get('anyOf') or []
        union_types = [resolve_type(v, f"{prefix}_{i}") for i, v in enumerate(variants)]
        if union_types:
            return Union[tuple(union_types)]  # type: ignore
        return Any

    if 'allOf' in prop_details:
        parts = prop_details['allOf']
        # Every allOf part applies, so a property required by any part is required.
        required = set()
        for part in parts:
            required.update(part.get('required') or [])
        entries = [
            (name, details, name in required, f"{prefix}_allOf_{i}_{name}")
            for i, part in enumerate(parts)
            for name, details in (part.get('properties') or {}).items()
        ]
        if entries:
            return _build_nested_model(f"{prefix}_CompositeModel", entries)
        return Dict

    schema_type = prop_details.get('type', 'string')
    if isinstance(schema_type, list):
        # e.g. ["string", "integer", "null"] -> Optional[Union[str, int]]
        member_types = [type_mapping.get(t, Any) for t in schema_type if t != 'null']
        if not member_types:
            return Any
        combined = Union[tuple(member_types)]  # type: ignore
        return Optional[combined] if 'null' in schema_type else combined  # type: ignore
    return type_mapping.get(schema_type, Any)