Spaces:
Running
Running
| import json | |
| import logging | |
| import pdb | |
| import traceback | |
| from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Type, TypeVar | |
| from PIL import Image, ImageDraw, ImageFont | |
| import os | |
| import base64 | |
| import io | |
| import asyncio | |
| import time | |
| import platform | |
| from browser_use.agent.prompts import SystemPrompt, AgentMessagePrompt | |
| from browser_use.agent.service import Agent | |
| from browser_use.agent.message_manager.utils import convert_input_messages, extract_json_from_model_output, \ | |
| save_conversation | |
| from browser_use.agent.views import ( | |
| ActionResult, | |
| AgentError, | |
| AgentHistory, | |
| AgentHistoryList, | |
| AgentOutput, | |
| AgentSettings, | |
| AgentState, | |
| AgentStepInfo, | |
| StepMetadata, | |
| ToolCallingMethod, | |
| ) | |
| from browser_use.agent.gif import create_history_gif | |
| from browser_use.browser.browser import Browser | |
| from browser_use.browser.context import BrowserContext | |
| from browser_use.browser.views import BrowserStateHistory | |
| from browser_use.controller.service import Controller | |
| from browser_use.telemetry.views import ( | |
| AgentEndTelemetryEvent, | |
| AgentRunTelemetryEvent, | |
| AgentStepTelemetryEvent, | |
| ) | |
| from browser_use.utils import time_execution_async | |
| from langchain_core.language_models.chat_models import BaseChatModel | |
| from langchain_core.messages import ( | |
| BaseMessage, | |
| HumanMessage, | |
| AIMessage | |
| ) | |
| from browser_use.browser.views import BrowserState, BrowserStateHistory | |
| from browser_use.agent.prompts import PlannerPrompt | |
| from json_repair import repair_json | |
| from src.utils.agent_state import AgentState | |
| from .custom_message_manager import CustomMessageManager, CustomMessageManagerSettings | |
| from .custom_views import CustomAgentOutput, CustomAgentStepInfo, CustomAgentState | |
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
# Type variable for the optional user context threaded through Controller[Context] and CustomAgent(context=...).
Context = TypeVar("Context")
class CustomAgent(Agent):
    """browser-use ``Agent`` customised for this project.

    Compared with the base class, this agent:
    - uses ``CustomAgentState``, ``CustomMessageManager`` and ``CustomAgentOutput``;
    - accumulates the model's ``important_contents`` into ``step_info.memory`` and
      collects "Extracted page" results, reporting them as the final result;
    - accepts ``add_infos`` hints and appends planner-LLM output to the latest
      state message.
    """

    def __init__(
            self,
            task: str,
            llm: BaseChatModel,
            add_infos: str = "",
            # Optional parameters
            browser: Browser | None = None,
            browser_context: BrowserContext | None = None,
            # NOTE: this default Controller() is evaluated once, when the class is
            # defined, and is therefore shared by every instance that omits it.
            controller: Controller[Context] = Controller(),
            # Initial agent run parameters
            sensitive_data: Optional[Dict[str, str]] = None,
            initial_actions: Optional[List[Dict[str, Dict[str, Any]]]] = None,
            # Cloud Callbacks
            register_new_step_callback: Callable[['BrowserState', 'AgentOutput', int], Awaitable[None]] | None = None,
            register_done_callback: Callable[['AgentHistoryList'], Awaitable[None]] | None = None,
            register_external_agent_status_raise_error_callback: Callable[[], Awaitable[bool]] | None = None,
            # Agent settings
            use_vision: bool = True,
            use_vision_for_planner: bool = False,
            save_conversation_path: Optional[str] = None,
            save_conversation_path_encoding: Optional[str] = 'utf-8',
            max_failures: int = 3,
            retry_delay: int = 10,
            system_prompt_class: Type[SystemPrompt] = SystemPrompt,
            agent_prompt_class: Type[AgentMessagePrompt] = AgentMessagePrompt,
            max_input_tokens: int = 128000,
            validate_output: bool = False,
            message_context: Optional[str] = None,
            generate_gif: bool | str = False,
            available_file_paths: Optional[list[str]] = None,
            include_attributes: list[str] = [
                'title',
                'type',
                'name',
                'role',
                'aria-label',
                'placeholder',
                'value',
                'alt',
                'aria-expanded',
                'data-date-format',
            ],
            max_actions_per_step: int = 10,
            tool_calling_method: Optional[ToolCallingMethod] = 'auto',
            page_extraction_llm: Optional[BaseChatModel] = None,
            planner_llm: Optional[BaseChatModel] = None,
            planner_interval: int = 1,  # Run planner every N steps
            # Inject state
            injected_agent_state: Optional[AgentState] = None,
            context: Context | None = None,
    ):
        """Create the agent.

        Args:
            task: Task description for the agent.
            llm: Chat model that chooses the next actions.
            add_infos: Extra hints; ``run`` passes them into ``CustomAgentStepInfo``.
            agent_prompt_class: Prompt class handed to ``CustomMessageManagerSettings``.

        Every other argument is forwarded unchanged to ``Agent.__init__``.
        Afterwards the base state is replaced by ``injected_agent_state`` (or a
        fresh ``CustomAgentState``) and the message manager by a
        ``CustomMessageManager``.
        """
        super().__init__(
            task=task,
            llm=llm,
            browser=browser,
            browser_context=browser_context,
            controller=controller,
            sensitive_data=sensitive_data,
            initial_actions=initial_actions,
            register_new_step_callback=register_new_step_callback,
            register_done_callback=register_done_callback,
            register_external_agent_status_raise_error_callback=register_external_agent_status_raise_error_callback,
            use_vision=use_vision,
            use_vision_for_planner=use_vision_for_planner,
            save_conversation_path=save_conversation_path,
            save_conversation_path_encoding=save_conversation_path_encoding,
            max_failures=max_failures,
            retry_delay=retry_delay,
            system_prompt_class=system_prompt_class,
            max_input_tokens=max_input_tokens,
            validate_output=validate_output,
            message_context=message_context,
            generate_gif=generate_gif,
            available_file_paths=available_file_paths,
            include_attributes=include_attributes,
            max_actions_per_step=max_actions_per_step,
            tool_calling_method=tool_calling_method,
            page_extraction_llm=page_extraction_llm,
            planner_llm=planner_llm,
            planner_interval=planner_interval,
            injected_agent_state=injected_agent_state,
            context=context,
        )
        self.state = injected_agent_state or CustomAgentState()
        self.add_infos = add_infos
        self._message_manager = CustomMessageManager(
            task=task,
            system_message=self.settings.system_prompt_class(
                self.available_actions,
                max_actions_per_step=self.settings.max_actions_per_step,
            ).get_system_message(),
            settings=CustomMessageManagerSettings(
                max_input_tokens=self.settings.max_input_tokens,
                include_attributes=self.settings.include_attributes,
                message_context=self.settings.message_context,
                sensitive_data=sensitive_data,
                available_file_paths=self.settings.available_file_paths,
                agent_prompt_class=agent_prompt_class
            ),
            state=self.state.message_manager_state,
        )

    def _log_response(self, response: CustomAgentOutput) -> None:
        """Log the model's evaluation, memory, thought, next goal and actions."""
        evaluation = response.current_state.evaluation_previous_goal
        if "Success" in evaluation:
            emoji = "✅"
        elif "Failed" in evaluation:
            emoji = "❌"
        else:
            emoji = "🤷"

        logger.info(f"{emoji} Eval: {evaluation}")
        logger.info(f"🧠 New Memory: {response.current_state.important_contents}")
        logger.info(f"🤔 Thought: {response.current_state.thought}")
        logger.info(f"🎯 Next Goal: {response.current_state.next_goal}")
        for i, action in enumerate(response.action):
            logger.info(
                f"🛠️ Action {i + 1}/{len(response.action)}: {action.model_dump_json(exclude_unset=True)}"
            )

    def _setup_action_models(self) -> None:
        """Build ``self.ActionModel`` from the controller registry and ``self.AgentOutput`` around it."""
        # Get the dynamic action model from controller's registry
        self.ActionModel = self.controller.registry.create_action_model()
        # Create output model with the dynamic actions
        self.AgentOutput = CustomAgentOutput.type_with_custom_actions(self.ActionModel)

    def update_step_info(
            self, model_output: CustomAgentOutput, step_info: Optional[CustomAgentStepInfo] = None
    ):
        """Advance ``step_info`` after a model response.

        Increments ``step_info.step_number`` and appends the model's
        ``important_contents`` to ``step_info.memory`` unless it is empty,
        contains the text "None", or is already present in the memory.
        Does nothing when ``step_info`` is None.
        """
        if step_info is None:
            return

        step_info.step_number += 1
        important_contents = model_output.current_state.important_contents
        if (
                important_contents
                and "None" not in important_contents
                and important_contents not in step_info.memory
        ):
            step_info.memory += important_contents + "\n"

        logger.info(f"🧠 All Memory: \n{step_info.memory}")

    @staticmethod
    def _first_text_part(content: Any) -> str:
        """Return the text to parse from an LLM message ``content``.

        String content is returned unchanged. For list (multi-part) content the
        first part is used; a dict part contributes its ``"text"`` value. An
        empty list yields "" (which then fails JSON parsing like any other
        unparseable reply).
        """
        if not isinstance(content, list):
            return content
        if not content:
            return ""
        first = content[0]
        if isinstance(first, dict):
            return str(first.get("text", ""))
        return str(first)

    async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
        """Ask the LLM for the next actions and parse them into ``self.AgentOutput``.

        The raw AI message is added to the message history before parsing.
        Markdown code fences are stripped and the JSON is repaired with
        ``repair_json`` before validation. At most
        ``settings.max_actions_per_step`` actions are kept.

        Raises:
            ValueError: if the reply cannot be parsed into ``self.AgentOutput``.
        """
        fixed_input_messages = self._convert_input_messages(input_messages)
        # Async call: a blocking invoke() here would stall the event loop for the
        # whole model round-trip.
        ai_message = await self.llm.ainvoke(fixed_input_messages)
        self.message_manager._add_message_with_tokens(ai_message)

        if hasattr(ai_message, "reasoning_content"):
            logger.info("🤯 Start Deep Thinking: ")
            logger.info(ai_message.reasoning_content)
            logger.info("🤯 End Deep Thinking")

        try:
            ai_content = self._first_text_part(ai_message.content)
            ai_content = ai_content.replace("```json", "").replace("```", "")
            ai_content = repair_json(ai_content)
            parsed_json = json.loads(ai_content)
            parsed: AgentOutput = self.AgentOutput(**parsed_json)
        except Exception as e:
            logger.exception('Could not parse model response')
            logger.debug(ai_message.content)
            raise ValueError('Could not parse response.') from e

        # cut the number of actions to max_actions_per_step if needed
        if len(parsed.action) > self.settings.max_actions_per_step:
            parsed.action = parsed.action[: self.settings.max_actions_per_step]

        self._log_response(parsed)
        return parsed

    async def _run_planner(self) -> Optional[str]:
        """Run the planner LLM over the message history and return its plan text.

        Returns None when no ``planner_llm`` is configured. When the latest
        message in the message manager is a ``HumanMessage``, the plan is also
        appended to its text so the acting LLM sees it.
        """
        # Skip planning if no planner_llm is set
        if not self.settings.planner_llm:
            return None

        # Planner system prompt followed by the full message history except the first message
        planner_messages = [
            PlannerPrompt(self.controller.registry.get_prompt_description()).get_system_message(),
            *self.message_manager.get_messages()[1:],
        ]

        if not self.settings.use_vision_for_planner and self.settings.use_vision:
            # The planner runs without vision: keep only the text of the last message.
            last_state_message = planner_messages[-1]
            if isinstance(last_state_message.content, list):
                new_msg = ''
                for part in last_state_message.content:
                    if isinstance(part, str):
                        new_msg += part
                    elif part.get('type') == 'text':
                        new_msg += part['text']
            else:
                new_msg = last_state_message.content
            planner_messages[-1] = HumanMessage(content=new_msg)

        # Get planner output
        response = await self.settings.planner_llm.ainvoke(planner_messages)
        plan = str(response.content)

        # Append the plan to the latest state message in the real history.
        last_state_message = self.message_manager.get_messages()[-1]
        if isinstance(last_state_message, HumanMessage):
            if isinstance(last_state_message.content, list):
                for part in last_state_message.content:
                    if isinstance(part, dict) and part.get('type') == 'text':
                        part['text'] += f"\nPlanning Agent outputs plans:\n {plan}\n"
            else:
                last_state_message.content += f"\nPlanning Agent outputs plans:\n {plan}\n "

        # Log reasoning regardless of whether the plan itself is valid JSON.
        if hasattr(response, "reasoning_content"):
            logger.info("🤯 Start Planning Deep Thinking: ")
            logger.info(response.reasoning_content)
            logger.info("🤯 End Planning Deep Thinking")

        try:
            plan_json = json.loads(plan.replace("```json", "").replace("```", ""))
            logger.info(f'📋 Plans:\n{json.dumps(plan_json, indent=4)}')
        except json.JSONDecodeError:
            logger.info(f'📋 Plans:\n{plan}')
        except Exception as e:
            logger.debug(f'Error parsing planning analysis: {e}')
            logger.info(f'📋 Plans: {plan}')

        return plan

    async def step(self, step_info: Optional[CustomAgentStepInfo] = None) -> None:
        """Execute one step of the task.

        Captures the browser state, adds it to the message history, optionally
        runs the planner, asks the LLM for actions, executes them and records a
        history item. Ordinary errors go through ``_handle_step_error``; an
        ``InterruptedError`` (pause/stop) ends the step early. Exceptions not
        derived from ``Exception`` (e.g. ``asyncio.CancelledError``) propagate.
        """
        logger.info(f"\n📍 Step {self.state.n_steps}")
        state = None
        model_output = None
        result: list[ActionResult] = []
        step_start_time = time.time()
        tokens = 0

        try:
            state = await self.browser_context.get_state()
            await self._raise_if_stopped_or_paused()

            self.message_manager.add_state_message(state, self.state.last_action, self.state.last_result, step_info,
                                                   self.settings.use_vision)

            # Run planner at specified intervals if planner is configured
            if self.settings.planner_llm and self.state.n_steps % self.settings.planner_interval == 0:
                await self._run_planner()
            input_messages = self.message_manager.get_messages()
            tokens = self._message_manager.state.history.current_tokens

            # Whether this step's state message has already been removed; the
            # error path must not remove a second (older) HumanMessage.
            state_message_removed = False
            try:
                model_output = await self.get_next_action(input_messages)
                self.update_step_info(model_output, step_info)
                self.state.n_steps += 1

                if self.register_new_step_callback:
                    await self.register_new_step_callback(state, model_output, self.state.n_steps)
                if self.settings.save_conversation_path:
                    target = self.settings.save_conversation_path + f'_{self.state.n_steps}.txt'
                    save_conversation(input_messages, model_output, target,
                                      self.settings.save_conversation_path_encoding)

                if self.model_name != "deepseek-reasoner":
                    # remove prev message
                    self.message_manager._remove_state_message_by_index(-1)
                    state_message_removed = True
                await self._raise_if_stopped_or_paused()
            except Exception:
                # model call failed, remove this step's state message (once)
                if not state_message_removed:
                    self.message_manager._remove_state_message_by_index(-1)
                raise

            result = await self.multi_act(model_output.action)
            for ret_ in result:
                if ret_.extracted_content and "Extracted page" in ret_.extracted_content:
                    # record every extracted page
                    if ret_.extracted_content[:100] not in self.state.extracted_content:
                        self.state.extracted_content += ret_.extracted_content
            self.state.last_result = result
            self.state.last_action = model_output.action

            if result and result[-1].is_done:
                # Prefer collected page extractions, fall back to the step memory;
                # keep the done action's own text if both are empty.
                if not self.state.extracted_content and step_info is not None:
                    self.state.extracted_content = step_info.memory
                if self.state.extracted_content:
                    result[-1].extracted_content = self.state.extracted_content
                logger.info(f"📄 Result: {result[-1].extracted_content}")

            self.state.consecutive_failures = 0

        except InterruptedError:
            logger.debug('Agent paused')
            self.state.last_result = [
                ActionResult(
                    error='The agent was paused - now continuing actions might need to be repeated',
                    include_in_memory=True
                )
            ]
            return
        except Exception as e:
            result = await self._handle_step_error(e)
            self.state.last_result = result

        finally:
            step_end_time = time.time()
            actions = [a.model_dump(exclude_unset=True) for a in model_output.action] if model_output else []
            self.telemetry.capture(
                AgentStepTelemetryEvent(
                    agent_id=self.state.agent_id,
                    step=self.state.n_steps,
                    actions=actions,
                    consecutive_failures=self.state.consecutive_failures,
                    step_error=[r.error for r in result if r.error] if result else ['No result'],
                )
            )
            # No `return` in this finally block: it would silently discard any
            # exception still propagating (e.g. task cancellation).
            if result and state:
                metadata = StepMetadata(
                    step_number=self.state.n_steps,
                    step_start_time=step_start_time,
                    step_end_time=step_end_time,
                    input_tokens=tokens,
                )
                self._make_history_item(model_output, state, result, metadata)

    async def run(self, max_steps: int = 100) -> AgentHistoryList:
        """Execute the task for at most ``max_steps`` steps and return the history.

        Stops early on too many consecutive failures, on a stop request, or when
        the history reports the task done (after optional output validation).
        If the step budget runs out, the last history result receives the
        collected extracted content (or the accumulated memory). Telemetry,
        browser cleanup and optional GIF generation always run at the end.
        """
        try:
            self._log_agent_run()

            # Execute initial actions if provided
            if self.initial_actions:
                result = await self.multi_act(self.initial_actions, check_for_new_elements=False)
                self.state.last_result = result

            step_info = CustomAgentStepInfo(
                task=self.task,
                add_infos=self.add_infos,
                step_number=1,
                max_steps=max_steps,
                memory="",
            )

            for step_index in range(max_steps):
                # Check if we should stop due to too many failures
                if self.state.consecutive_failures >= self.settings.max_failures:
                    logger.error(f'❌ Stopping due to {self.settings.max_failures} consecutive failures')
                    break

                # Check control flags before each step
                if self.state.stopped:
                    logger.info('Agent stopped')
                    break

                while self.state.paused:
                    await asyncio.sleep(0.2)  # Small delay to prevent CPU spinning
                    if self.state.stopped:  # Allow stopping while paused
                        break

                # A stop requested while paused ends the run instead of starting a step.
                if self.state.stopped:
                    logger.info('Agent stopped')
                    break

                await self.step(step_info)

                if self.state.history.is_done():
                    if self.settings.validate_output and step_index < max_steps - 1:
                        if not await self._validate_output():
                            continue

                    await self.log_completion()
                    break
            else:
                logger.info("❌ Failed to complete task in maximum steps")
                final_content = self.state.extracted_content or step_info.memory
                history_items = self.state.history.history
                # Guard: there may be no history item (or no result) to attach it to.
                if history_items and history_items[-1].result:
                    history_items[-1].result[-1].extracted_content = final_content

            return self.state.history

        finally:
            self.telemetry.capture(
                AgentEndTelemetryEvent(
                    agent_id=self.state.agent_id,
                    is_done=self.state.history.is_done(),
                    success=self.state.history.is_successful(),
                    steps=self.state.n_steps,
                    max_steps_reached=self.state.n_steps >= max_steps,
                    errors=self.state.history.errors(),
                    total_input_tokens=self.state.history.total_input_tokens(),
                    total_duration_seconds=self.state.history.total_duration_seconds(),
                )
            )

            if not self.injected_browser_context:
                await self.browser_context.close()

            if not self.injected_browser and self.browser:
                await self.browser.close()

            if self.settings.generate_gif:
                output_path: str = 'agent_history.gif'
                if isinstance(self.settings.generate_gif, str):
                    output_path = self.settings.generate_gif

                create_history_gif(task=self.task, history=self.state.history, output_path=output_path)