|
import asyncio
import inspect
import logging
import os
import pdb
from typing import Any, Awaitable, Callable, Dict, Optional, Type, TypeVar, Union

import pyperclip
from browser_use.agent.views import ActionModel, ActionResult
from browser_use.browser.context import BrowserContext
from browser_use.controller.registry.service import RegisteredAction, Registry
from browser_use.controller.service import Controller, DoneAction
from browser_use.controller.views import (
    ClickElementAction,
    DoneAction,
    ExtractPageContentAction,
    GoToUrlAction,
    InputTextAction,
    OpenTabAction,
    ScrollAction,
    SearchGoogleAction,
    SendKeysAction,
    SwitchTabAction,
)
from browser_use.utils import time_execution_async, time_execution_sync
from langchain_core.language_models.chat_models import BaseChatModel
from main_content_extractor import MainContentExtractor
from pydantic import BaseModel

from src.utils.mcp_client import create_tool_param_model, setup_mcp_client_and_tools
|
|
|
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)

# Generic type variable used to annotate the optional ``context`` argument
# accepted by ``CustomController.act``.
Context = TypeVar("Context")
|
|
|
|
|
class CustomController(Controller):
    """browser-use ``Controller`` extended with extra custom actions and MCP tools.

    In addition to whatever the base ``Controller`` registers, this class adds:

    * ``ask_for_assistant`` – forwards a question to ``ask_assistant_callback``.
    * ``upload_file`` – sets a local file on the file-input element found at a
      given element index.
    * MCP tools, registered by :meth:`setup_mcp_client` under the names
      ``mcp.<server_name>.<tool_name>``.
    """

    def __init__(
        self,
        exclude_actions: Optional[list[str]] = None,
        output_model: Optional[Type[BaseModel]] = None,
        ask_assistant_callback: Optional[
            Union[
                Callable[[str, BrowserContext], Dict[str, Any]],
                Callable[[str, BrowserContext], Awaitable[Dict[str, Any]]],
            ]
        ] = None,
    ):
        """Create the controller and register the custom actions.

        Args:
            exclude_actions: Forwarded to the base ``Controller`` (presumably the
                names of built-in actions to skip — confirm against browser-use).
                ``None`` means "no exclusions"; a fresh list is used per instance
                instead of a shared mutable default.
            output_model: Optional pydantic model forwarded to the base ``Controller``.
            ask_assistant_callback: Sync or async callable invoked by the
                ``ask_for_assistant`` action as ``callback(query, browser)``. Its
                (awaited) result must be a mapping with a ``'response'`` key.
        """
        super().__init__(
            exclude_actions=exclude_actions if exclude_actions is not None else [],
            output_model=output_model,
        )
        # Initialise instance state before registering the actions so the action
        # closures (which read these attributes) never see a half-built object.
        self.ask_assistant_callback = ask_assistant_callback
        self.mcp_client = None
        self.mcp_server_config = None
        self._register_custom_actions()

    def _register_custom_actions(self):
        """Register all custom browser actions on ``self.registry``."""

        @self.registry.action(
            "When executing tasks, prioritize autonomous completion. However, if you encounter a definitive blocker "
            "that prevents you from proceeding independently – such as needing credentials you don't possess, "
            "requiring subjective human judgment, needing a physical action performed, encountering complex CAPTCHAs, "
            "or facing limitations in your capabilities – you must request human assistance."
        )
        async def ask_for_assistant(query: str, browser: BrowserContext):
            # Relay the agent's question to the human through the configured callback.
            if not self.ask_assistant_callback:
                return ActionResult(extracted_content="Human cannot help you. Please try another way.",
                                    include_in_memory=True)

            user_response = self.ask_assistant_callback(query, browser)
            # Await based on the returned value rather than the callable's type, so
            # partials / objects with an async __call__ are handled as well as
            # plain `async def` callbacks.
            if inspect.isawaitable(user_response):
                user_response = await user_response

            msg = f"AI ask: {query}. User response: {user_response['response']}"
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

        @self.registry.action(
            'Upload file to interactive element with file path ',
        )
        async def upload_file(index: int, path: str, browser: BrowserContext, available_file_paths: list[str]):
            # Only files explicitly whitelisted for this run may be uploaded; guard
            # against a missing/empty whitelist instead of failing on `in None`.
            if not available_file_paths or path not in available_file_paths:
                return ActionResult(error=f'File path {path} is not available')

            if not os.path.exists(path):
                return ActionResult(error=f'File {path} does not exist')

            dom_el = await browser.get_dom_element_by_index(index)
            file_upload_dom_el = dom_el.get_file_upload_element() if dom_el is not None else None
            if file_upload_dom_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            file_upload_el = await browser.get_locate_element(file_upload_dom_el)
            if file_upload_el is None:
                msg = f'No file upload element found at index {index}'
                logger.info(msg)
                return ActionResult(error=msg)

            try:
                await file_upload_el.set_input_files(path)
            except Exception as e:
                # Report the failure to the agent as an action error instead of raising.
                msg = f'Failed to upload file to index {index}: {str(e)}'
                logger.info(msg)
                return ActionResult(error=msg)

            msg = f'Successfully uploaded file to index {index}'
            logger.info(msg)
            return ActionResult(extracted_content=msg, include_in_memory=True)

    # `act` is a coroutine function, so it needs the async timing decorator; the
    # sync variant would only time the creation of the coroutine object.
    @time_execution_async('--act')
    async def act(
        self,
        action: ActionModel,
        browser_context: Optional[BrowserContext] = None,
        #
        page_extraction_llm: Optional[BaseChatModel] = None,
        sensitive_data: Optional[Dict[str, str]] = None,
        available_file_paths: Optional[list[str]] = None,
        #
        context: Context | None = None,
    ) -> ActionResult:
        """Execute the first action in *action* whose parameters are not ``None``.

        Actions whose name starts with ``"mcp"`` are dispatched directly to the
        registered MCP tool via ``ainvoke``; everything else goes through
        ``self.registry.execute_action``.

        Returns:
            An ``ActionResult``: a ``str`` result is wrapped as ``extracted_content``,
            an ``ActionResult`` is returned unchanged, ``None`` (or no non-empty
            action at all) yields an empty ``ActionResult``.

        Raises:
            ValueError: if an ``mcp*`` action is not registered, or the action
                returns an unsupported result type.
        """
        for action_name, params in action.model_dump(exclude_unset=True).items():
            if params is None:
                continue

            if action_name.startswith("mcp"):
                logger.debug(f"Invoke MCP tool: {action_name}")
                registered = self.registry.registry.actions.get(action_name)
                if registered is None:
                    raise ValueError(f"MCP tool {action_name} is not registered")
                result = await registered.function.ainvoke(params)
            else:
                result = await self.registry.execute_action(
                    action_name,
                    params,
                    browser=browser_context,
                    page_extraction_llm=page_extraction_llm,
                    sensitive_data=sensitive_data,
                    available_file_paths=available_file_paths,
                    context=context,
                )

            if isinstance(result, str):
                return ActionResult(extracted_content=result)
            if isinstance(result, ActionResult):
                return result
            if result is None:
                return ActionResult()
            raise ValueError(f'Invalid action result type: {type(result)} of {result}')
        return ActionResult()

    async def setup_mcp_client(self, mcp_server_config: Optional[Dict[str, Any]] = None):
        """Store *mcp_server_config* and, if it is truthy, start the MCP client.

        The client returned by ``setup_mcp_client_and_tools`` is kept on
        ``self.mcp_client`` and its tools are registered as actions. With a falsy
        config nothing is started and any existing client is left untouched.
        """
        self.mcp_server_config = mcp_server_config
        if self.mcp_server_config:
            self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
            self.register_mcp_tools()

    def register_mcp_tools(self):
        """Register every tool of ``self.mcp_client`` as a registry action.

        Each tool becomes an action named ``mcp.<server_name>.<tool_name>`` whose
        parameter model is built by ``create_tool_param_model``. Logs a warning and
        does nothing when no client has been started.
        """
        if not self.mcp_client:
            logger.warning("MCP client not started.")
            return

        for server_name, tools in self.mcp_client.server_name_to_tools.items():
            for tool in tools:
                tool_name = f"mcp.{server_name}.{tool.name}"
                self.registry.registry.actions[tool_name] = RegisteredAction(
                    name=tool_name,
                    description=tool.description,
                    function=tool,
                    param_model=create_tool_param_model(tool),
                )
                logger.info(f"Add mcp tool: {tool_name}")
            logger.debug(f"Registered {len(tools)} mcp tools for {server_name}")

    async def close_mcp_client(self):
        """Shut down the MCP client (if any) and forget it.

        The reference is cleared before awaiting ``__aexit__`` so a repeated call
        does not try to exit an already-closed client. Tools previously added by
        ``register_mcp_tools`` are not removed from the registry.
        """
        if not self.mcp_client:
            return
        client, self.mcp_client = self.mcp_client, None
        await client.__aexit__(None, None, None)
|
|