File size: 3,549 Bytes
6a0e448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# ruff: noqa: E501
import json
import re
from functools import cached_property
from typing import Literal, Optional

from proxy_lite.agents import AgentConfigTypes, Agents, BaseAgent
from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.history import (
    MessageHistory,
    MessageLabel,
    SystemMessage,
)
from proxy_lite.solvers.solver_base import BaseSolver, BaseSolverConfig, Solvers
from proxy_lite.tools import ReturnValueTool, Tool

# NOTE(review): not referenced in the code shown here; presumably used as the
# tool-result text after a browser action — confirm at the call sites.
WEB_TOOL_TURN = "The action has been attempted in the computer."


@Solvers.register_solver_config("simple")
class SimpleSolverConfig(BaseSolverConfig):
    """Configuration for the solver registered under the name ``"simple"``."""

    # Registry name; the Literal type restricts it to "simple".
    name: Literal["simple"] = "simple"
    # Agent configuration; SimpleSolver uses ``agent.name`` to look up the agent class
    # and passes this object as the agent's ``config``.
    agent: AgentConfigTypes


def _extract_tag(text: str, tag: str) -> Optional[str]:
    """Return the stripped contents of the first ``<tag>...</tag>`` span in *text*.

    Matching is non-greedy and spans newlines (``re.DOTALL``). Returns ``None``
    when no such span exists, so callers can distinguish "absent" from
    "present but empty".
    """
    match = re.search(rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def _parse_return_value(raw_arguments: str):
    """Decode a ``return_value`` tool call's JSON arguments and return its ``value`` field.

    If the first ``json.loads`` yields a string (the arguments were JSON-encoded
    twice), it is decoded once more before ``value`` is read.

    Raises:
        json.JSONDecodeError: if the arguments are not valid JSON.
        KeyError / TypeError: if the decoded payload is not a mapping with a ``value`` key.
    """
    arguments = json.loads(raw_arguments)
    if isinstance(arguments, str):
        arguments = json.loads(arguments)
    return arguments["value"]


@Solvers.register_solver("simple")
class SimpleSolver(BaseSolver):
    """Single-agent solver: feed each observation to the agent and relay its tool calls.

    The agent is offered a ``ReturnValueTool`` in addition to the environment's
    tools. When the agent calls ``return_value``, the solver marks itself
    complete and returns the supplied value as the action text.
    """

    task: Optional[str] = None
    # Set once the agent's ``return_value`` call has been parsed.
    complete: bool = False

    @cached_property
    def tools(self) -> list[Tool]:
        """Tools given to the agent: ``ReturnValueTool`` first, then the environment tools.

        Reads ``self.env_tools`` (assigned in ``initialise``). The list is cached on
        first access, so later reassignment of ``env_tools`` is not reflected.
        """
        return [ReturnValueTool()] + self.env_tools

    @cached_property
    def agent(self) -> BaseAgent:
        """Instantiate (once) the agent class registered under ``config.agent.name``."""
        self.logger.debug(f"Tools: {self.tools}")
        return Agents.get(self.config.agent.name)(
            config=self.config.agent,
            env_tools=self.tools,
        )

    @property
    def history(self) -> MessageHistory:
        """The agent's message history with its system prompt prepended (rebuilt on each access)."""
        return MessageHistory(
            messages=[SystemMessage.from_media(text=self.agent.system_prompt)] + self.agent.history.messages,
        )

    async def initialise(self, task: str, env_tools: list[Tool], env_info: str) -> None:
        """Store the task and environment tools, then send the task to the agent.

        ``env_tools`` must be assigned before ``self.agent`` is first accessed,
        because building the agent reads ``self.tools``, which reads ``env_tools``.
        ``env_info`` is not used by this solver.
        """
        self.env_tools = env_tools
        self.task = task
        self.agent.receive_user_message(
            text=f"Task: {task}",
            label=MessageLabel.USER_INPUT,
        )
        self.logger.debug(f"Initialised with task: {task}")

    async def act(self, observation: Observation) -> Action:
        """Send the observation to the agent and turn its reply into an ``Action``.

        If the reply contains a ``return_value`` tool call, the solver is marked
        complete and an action with no tool calls and the returned value as text
        is produced. Otherwise the reply's ``<observation>`` and ``<thinking>``
        sections are streamed to the logger and the reply's tool calls are
        returned in the action.
        """
        self.agent.receive_user_message(
            image=observation.state.image,
            text=observation.state.text,
            label=MessageLabel.SCREENSHOT,
            is_base64=True,
        )

        message = await self.agent.generate_output(use_tool=True)

        self.logger.debug(f"Assistant message generated: {message}")

        # NOTE(review): `or []` guards against tool_calls being None — confirm the message type.
        tool_calls = message.tool_calls or []

        # Parse the return_value call itself (not blindly the first call), so the
        # right arguments are used even when other tool calls precede it.
        return_call = next(
            (call for call in tool_calls if call.function["name"] == "return_value"),
            None,
        )
        if return_call is not None:
            return_value = _parse_return_value(return_call.function["arguments"])
            # Mark complete only after the value was parsed successfully.
            self.complete = True
            return Action(tool_calls=[], text=return_value)

        # Guard against an empty content list or a None text on the first item.
        text_content = (message.content[0].text if message.content else None) or ""

        observation_content = _extract_tag(text_content, "observation") or ""

        self.logger.info("🌐 [bold blue]Observation:[/]")
        await self.logger.stream_message(observation_content)

        # Fall back to the whole reply only when there is no <thinking> section.
        thinking_content = _extract_tag(text_content, "thinking")
        if thinking_content is None:
            thinking_content = text_content

        self.logger.info("🧠 [bold purple]Thinking:[/]")
        await self.logger.stream_message(thinking_content)

        return Action(tool_calls=tool_calls, text=text_content)

    async def is_complete(self, observation: Observation) -> bool:
        """Return True once ``return_value`` was handled or the observation reports termination."""
        env_terminated = observation.terminated
        return self.complete or env_terminated