File size: 4,889 Bytes
246d201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import re
import traceback
from datatypes import ParseError, StepOutput, TaskState
from tasks.base import Task
from openhands.controller.state.state import State
class SimplifiedEnv:
    """Minimal environment that scores the solutions an agent proposes for a task.

    Every call to :meth:`step` treats the agent message as a solution
    proposal, checks it against the task, and updates ``self.task_state``
    (``finished``, ``success``, ``terminate_reason``, ``latest_output`` and
    the per-action counters).
    """

    INVALID_INPUT_MESSAGE = (
        "I don't understand your input. \n"
        'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
        'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>. \n'
    )

    # Captures the body of every <solution>...</solution> pair; DOTALL lets a
    # single solution span several lines.
    _SOLUTION_PATTERN = re.compile(r'<solution>(.*?)</solution>', re.DOTALL)

    def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]) -> None:
        """Bind the environment to an agent state, a task and its limits.

        Args:
            agent_state: Agent state object; ``iteration`` is read from it and
                ``propose_solution_count`` is read from / written to it.
            task: Task whose ``success(answer)`` method judges an answer.
            task_config: Limits; the keys ``'max_iterations'`` and
                ``'max_propose_solution'`` are read by this class.
        """
        self.agent_state = agent_state
        self.task = task
        agent_action_count = {
            'propose_solution': 0,
            'use_tool': 0,
            'invalid_action': 0,
        }
        # step() stores the running proposal count on agent_state; pick it up
        # again here so a freshly constructed env continues from that count.
        if hasattr(self.agent_state, 'propose_solution_count'):
            agent_action_count['propose_solution'] = (
                self.agent_state.propose_solution_count
            )
        self.task_state = TaskState(agent_action_count=agent_action_count)
        self.task_config = task_config

    def step(self, lm_message: str) -> TaskState:
        """Process one agent message as a solution proposal.

        Args:
            lm_message: Raw text produced by the agent.

        Returns:
            The (mutated) ``self.task_state``.
        """
        observation = self.handle_propose_solution(lm_message)
        self.check_max_iteration()
        # (remaining iterations, remaining solution proposals)
        turn_info = (
            self.task_config['max_iterations'] - self.agent_state.iteration,
            self.task_config['max_propose_solution']
            - self.task_state.agent_action_count['propose_solution'],
        )
        output = StepOutput(
            observation=observation,
            success=self.task_state.success,
            turn_info=turn_info,
        )
        # Persist the counter on the agent state so __init__ can restore it.
        self.agent_state.propose_solution_count = self.task_state.agent_action_count[
            'propose_solution'
        ]
        self.log_output(output)
        return self.task_state

    def handle_propose_solution(self, lm_message: str) -> str | None:
        """Parse a proposed solution and check it against the task.

        The ``'propose_solution'`` counter is incremented on every call, even
        when the message cannot be parsed. On success this sets
        ``self.task_state.finished``/``success`` to ``True`` and
        ``terminate_reason`` to ``'task_success'``.

        Args:
            lm_message: Raw text produced by the agent.

        Returns:
            ``INVALID_INPUT_MESSAGE`` if no solution could be parsed, a
            formatted traceback if parsing or checking raised any other
            exception, otherwise ``None`` (whether or not the answer was
            correct).
        """
        self.task_state.agent_action_count['propose_solution'] += 1
        try:
            parsed = self.parse_propose_solution(lm_message)
            task_success = self.check_task_success(parsed['answer'])
            if task_success:
                self.task_state.finished = True
                self.task_state.success = True
                self.task_state.terminate_reason = 'task_success'
                # NOTE: do not return early here; the caller still has to
                # build and log the step output. Setting finished = True is
                # what terminates the episode.
        except ParseError:
            return SimplifiedEnv.INVALID_INPUT_MESSAGE
        except Exception:
            # Deliberately broad: a failure in the checker is handed back as
            # the observation text instead of propagating to the caller.
            error_traceback = traceback.format_exc()
            return f'{error_traceback}'
        return None

    def parse_propose_solution(self, lm_message: str) -> dict:
        """Extract the proposed answer from ``<solution>`` tags.

        The contents of all ``<solution>...</solution>`` pairs are stripped
        and joined with newlines.

        Args:
            lm_message: Raw text produced by the agent.

        Returns:
            ``{'answer': <joined text>}``.

        Raises:
            ParseError: If there is no solution tag, or every tag is empty or
                whitespace-only.
        """
        answer = '\n'.join(
            segment.strip() for segment in self._SOLUTION_PATTERN.findall(lm_message)
        )
        # Several empty tags join to bare newlines, which is still "no answer";
        # comparing against '' alone would let that through.
        if not answer.strip():
            raise ParseError('No answer found.')
        return {'answer': answer}

    def log_output(self, output: StepOutput) -> None:
        """Store the step output on ``self.task_state.latest_output``.

        Nothing is stored once the episode is finished.
        NOTE(review): this means the output of the terminating step is never
        recorded here - confirm that is intended.
        """
        if self.task_state.finished:
            return
        content = output.to_str()
        self.task_state.latest_output = output.to_dict()
        self.task_state.latest_output['content'] = content

    def check_task_success(self, answer: str) -> bool:
        """Return the task's verdict on ``answer`` (delegates to ``task.success``)."""
        return self.task.success(answer)

    def check_max_iteration(self) -> None:
        """Terminate the episode if a proposal or iteration limit is reached.

        Does nothing when the episode is already finished (e.g. task
        success). Otherwise sets ``self.task_state.finished = True`` and
        ``success = False`` with ``terminate_reason`` set to
        ``'max_propose_steps'`` or ``'max_iterations'``. The proposal limit is
        tested first, so it wins when both limits are hit at once.
        """
        if self.task_state.finished:
            # ignore if the episode is already finished (e.g., task success)
            return
        if (
            # number of proposed solutions has reached the allowed maximum
            self.task_state.agent_action_count['propose_solution']
            >= self.task_config['max_propose_solution']
        ):
            self.task_state.finished = True
            self.task_state.success = False
            self.task_state.terminate_reason = 'max_propose_steps'
        elif self.agent_state.iteration >= self.task_config['max_iterations']:
            self.task_state.finished = True
            self.task_state.success = False
            self.task_state.terminate_reason = 'max_iterations'
|