import re
import traceback

from datatypes import ParseError, StepOutput, TaskState
from tasks.base import Task

from openhands.controller.state.state import State


class SimplifiedEnv:
    """Minimal environment that grades an agent's proposed solutions.

    Each call to :meth:`step` treats the language-model message as a
    solution proposal, checks it against the task, enforces the proposal
    and iteration budgets from ``task_config``, and records the outcome on
    ``self.task_state``.
    """

    # Feedback returned when no answer can be extracted from a message.
    # The <execute>/<solution> delimiters are spelled out because the parser
    # below only recognises answers wrapped in <solution> tags.
    INVALID_INPUT_MESSAGE = (
        "I don't understand your input. \n"
        'If you want to execute code, please use <execute> YOUR_CODE_HERE </execute>.\n'
        'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>. \n'
    )

    # A bare ``(.*?)`` would match the empty string at every position, so the
    # answer must be anchored between explicit tags. DOTALL lets an answer
    # span several lines.
    _SOLUTION_PATTERN = re.compile(r'<solution>(.*?)</solution>', re.DOTALL)

    def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]):
        """Bind the environment to an agent state, a task and its budgets.

        Args:
            agent_state: Controller state of the agent; ``iteration`` is read
                from it and ``propose_solution_count`` is read/written on it.
            task: Task whose ``success(answer)`` decides correctness.
            task_config: Must provide ``'max_iterations'`` and
                ``'max_propose_solution'``.
        """
        self.agent_state = agent_state
        self.task = task

        agent_action_count = {
            # Resume the proposal counter if a previous environment already
            # stored one on the agent state; otherwise start from zero.
            'propose_solution': getattr(
                self.agent_state, 'propose_solution_count', 0
            ),
            'use_tool': 0,
            'invalid_action': 0,
        }
        self.task_state = TaskState(agent_action_count=agent_action_count)

        self.task_config = task_config

    def step(self, lm_message: str):
        """Process one agent message and return the updated task state.

        Args:
            lm_message: Raw text produced by the language model.

        Returns:
            ``self.task_state`` after grading, budget checks and logging.
        """
        observation = self.handle_propose_solution(lm_message)

        self.check_max_iteration()

        proposals_used = self.task_state.agent_action_count['propose_solution']
        # (remaining iterations, remaining solution proposals)
        turn_info = (
            self.task_config['max_iterations'] - self.agent_state.iteration,
            self.task_config['max_propose_solution'] - proposals_used,
        )

        output = StepOutput(
            observation=observation,
            success=self.task_state.success,
            turn_info=turn_info,
        )

        # Persist the counter on the agent state so that a later environment
        # built from the same state continues counting (see __init__).
        self.agent_state.propose_solution_count = proposals_used
        self.log_output(output)
        return self.task_state

    def handle_propose_solution(self, lm_message) -> str | None:
        """Count a solution proposal and check it against the task.

        On success this sets ``self.task_state.finished`` and
        ``self.task_state.success`` to ``True``.

        Args:
            lm_message: Raw text produced by the language model.

        Returns:
            ``INVALID_INPUT_MESSAGE`` if no answer could be parsed, the
            formatted traceback if parsing or checking raised any other
            exception, and ``None`` otherwise (correct or incorrect answer).
        """
        # The attempt is counted even when the message cannot be parsed.
        self.task_state.agent_action_count['propose_solution'] += 1
        try:
            parsed = self.parse_propose_solution(lm_message)
            task_success = self.check_task_success(parsed['answer'])
            if task_success:
                # Marking the state finished is what terminates the episode;
                # the caller still builds and logs the step output afterwards.
                self.task_state.finished = True
                self.task_state.success = True
                self.task_state.terminate_reason = 'task_success'
        except ParseError:
            return SimplifiedEnv.INVALID_INPUT_MESSAGE
        except Exception:
            # Deliberate catch-all: a failure inside the task checker is
            # reported back as the observation instead of aborting the run.
            return traceback.format_exc()
        return None

    def parse_propose_solution(self, lm_message: str) -> dict:
        """Extract the proposed answer from a language-model message.

        Every ``<solution>...</solution>`` span is stripped of surrounding
        whitespace; non-empty spans are joined with newlines.

        Args:
            lm_message: Raw text produced by the language model.

        Returns:
            ``{'answer': <extracted text>}``.

        Raises:
            ParseError: If the message holds no non-empty solution span.
        """
        fragments = [
            fragment.strip()
            for fragment in self._SOLUTION_PATTERN.findall(lm_message)
        ]
        # Drop blank spans so that empty tags are rejected rather than
        # graded as an answer made only of newlines.
        answer = '\n'.join(fragment for fragment in fragments if fragment)
        if not answer:
            raise ParseError('No answer found.')

        return {'answer': answer}

    def log_output(self, output: StepOutput) -> None:
        """Store ``output`` as the latest step output on the task state.

        Nothing is recorded once the episode is finished, so
        ``latest_output`` then keeps whatever it held before this call.
        """
        if self.task_state.finished:
            return

        content = output.to_str()
        self.task_state.latest_output = output.to_dict()
        self.task_state.latest_output['content'] = content

    def check_task_success(self, answer: str) -> bool:
        """Return whether ``answer`` solves the task, as judged by the task."""
        return self.task.success(answer)

    def check_max_iteration(self):
        """Terminate the episode if a proposal or iteration budget is spent.

        Sets ``self.task_state.finished = True`` (with ``success = False``)
        when either budget is exhausted. The proposal budget is checked
        first, so it determines ``terminate_reason`` when both are spent.
        """
        if self.task_state.finished:
            # Already terminated (e.g. task success): keep that outcome.
            return

        proposals_used = self.task_state.agent_action_count['propose_solution']
        if proposals_used >= self.task_config['max_propose_solution']:
            self.task_state.finished = True
            self.task_state.success = False
            self.task_state.terminate_reason = 'max_propose_steps'
        elif self.agent_state.iteration >= self.task_config['max_iterations']:
            self.task_state.finished = True
            self.task_state.success = False
            self.task_state.terminate_reason = 'max_iterations'