|
import re
|
|
import traceback
|
|
|
|
from datatypes import ParseError, StepOutput, TaskState
|
|
from tasks.base import Task
|
|
|
|
from openhands.controller.state.state import State
|
|
|
|
|
|
class SimplifiedEnv:
    """Environment that treats every LM message as a proposed solution.

    Each call to ``step`` parses the message for ``<solution>`` tags, asks the
    task whether the extracted answer is correct, applies the proposal and
    iteration limits from ``task_config``, and returns the updated task state.
    """

    # Feedback returned to the agent when no solution can be parsed.
    INVALID_INPUT_MESSAGE = (
        "I don't understand your input. \n"
        'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
        'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>. \n'
    )

    def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]):
        """Bind the environment to an agent state, a task and its limits.

        The proposal counter is seeded from ``agent_state.propose_solution_count``
        when that attribute exists, and starts at zero otherwise.
        """
        self.agent_state = agent_state
        self.task = task
        self.task_config = task_config

        counters = {
            'propose_solution': getattr(agent_state, 'propose_solution_count', 0),
            'use_tool': 0,
            'invalid_action': 0,
        }
        self.task_state = TaskState(agent_action_count=counters)

    def step(self, lm_message: str):
        """Process one LM message and return the resulting task state.

        The message is evaluated as a proposed solution, the limits are
        checked, and the proposal count is written back to the agent state.
        The step output carries ``turn_info`` as the pair
        ``(max_iterations - iteration, max_propose_solution - proposals)``.
        """
        feedback = self.handle_propose_solution(lm_message)
        self.check_max_iteration()

        limits = self.task_config
        counters = self.task_state.agent_action_count
        iterations_left = limits['max_iterations'] - self.agent_state.iteration
        proposals_left = limits['max_propose_solution'] - counters['propose_solution']

        result = StepOutput(
            observation=feedback,
            success=self.task_state.success,
            turn_info=(iterations_left, proposals_left),
        )

        self.agent_state.propose_solution_count = counters['propose_solution']
        self.log_output(result)
        return self.task_state

    def handle_propose_solution(self, lm_message) -> str | None:
        """Count a proposal and check the proposed answer against the task.

        On success, ``self.task_state`` is marked finished and successful with
        terminate reason ``'task_success'``.

        Returns:
            ``INVALID_INPUT_MESSAGE`` when no solution can be parsed, the
            formatted traceback when any other exception is raised, and
            ``None`` otherwise.
        """
        # The counter is bumped before parsing, so unparsable messages are
        # counted as proposals too.
        self.task_state.agent_action_count['propose_solution'] += 1
        try:
            proposed = self.parse_propose_solution(lm_message)['answer']
            if self.check_task_success(proposed):
                state = self.task_state
                state.finished = True
                state.success = True
                state.terminate_reason = 'task_success'
        except ParseError:
            return SimplifiedEnv.INVALID_INPUT_MESSAGE
        except Exception:
            # Hand the traceback text back as the observation.
            return traceback.format_exc()
        return None

    def parse_propose_solution(self, lm_message: str) -> dict:
        """Extract the text of all ``<solution>`` tags from the message.

        Each tag's content is stripped and the pieces are joined with
        newlines.

        Returns:
            A dict with the joined text under the key ``'answer'``.

        Raises:
            ParseError: If the joined text is empty.
        """
        padded = '\n' + lm_message + '\n'
        pieces = re.findall(r'<solution>(.*?)</solution>', padded, re.DOTALL)
        answer = '\n'.join(piece.strip() for piece in pieces)
        if not answer:
            raise ParseError('No answer found.')
        return {'answer': answer}

    def log_output(self, output: StepOutput) -> None:
        """Store the step output on the task state unless the task is finished.

        The stored value is ``output.to_dict()`` with the string form of the
        output added under the key ``'content'``.
        """
        if not self.task_state.finished:
            rendered = output.to_str()
            record = output.to_dict()
            self.task_state.latest_output = record
            record['content'] = rendered

    def check_task_success(self, answer: str) -> bool:
        """Return the task's own verdict on the given answer."""
        verdict = self.task.success(answer)
        return verdict

    def check_max_iteration(self):
        """Terminate the task when a configured limit has been reached.

        Does nothing if the task is already finished. Otherwise marks
        ``self.task_state`` as finished and unsuccessful, with terminate
        reason ``'max_propose_steps'`` when the proposal limit is reached or
        ``'max_iterations'`` when the iteration limit is reached. The proposal
        limit is checked first.
        """
        state = self.task_state
        if state.finished:
            return

        limits = self.task_config
        proposals = state.agent_action_count['propose_solution']
        if proposals >= limits['max_propose_solution']:
            reason = 'max_propose_steps'
        elif self.agent_state.iteration >= limits['max_iterations']:
            reason = 'max_iterations'
        else:
            return

        state.finished = True
        state.success = False
        state.terminate_reason = reason
|
|
|