"""Convert OpenHands actions and observations into Invariant trace elements."""

from typing import Union

from pydantic import BaseModel, Field

from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
    Action,
    ChangeAgentStateAction,
    MessageAction,
    NullAction,
)
from openhands.events.event import EventSource
from openhands.events.observation import (
    AgentStateChangedObservation,
    NullObservation,
    Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput

TraceElement = Union[Message, ToolCall, ToolOutput, Function]


def get_next_id(trace: list[TraceElement]) -> str:
    """Return the smallest positive integer (as a string) not used as a ToolCall id.

    Args:
        trace: The trace built so far; only its ``ToolCall`` elements are inspected.

    Returns:
        The first of ``'1'``, ``'2'``, ... that is not the id of any ``ToolCall``
        in ``trace``.
    """
    # A set makes each membership test O(1) instead of scanning a list.
    used_ids = {el.id for el in trace if isinstance(el, ToolCall)}
    # Among len(used_ids) + 1 candidates at least one must be free.
    for i in range(1, len(used_ids) + 2):
        candidate = str(i)
        if candidate not in used_ids:
            return candidate
    # Not reachable given the bound above; kept as a defensive fallback.
    return '1'


def get_last_id(
    trace: list[TraceElement],
) -> str | None:
    """Return the id of the most recent ``ToolCall`` in ``trace``.

    Args:
        trace: The trace built so far.

    Returns:
        The id of the last ``ToolCall`` element, or ``None`` when the trace
        contains no ``ToolCall``.
    """
    for el in reversed(trace):
        if isinstance(el, ToolCall):
            return el.id
    return None


def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
    """Translate a single action into the trace elements it contributes.

    ``trace`` is only read (to pick a fresh tool-call id); the new elements are
    returned rather than appended.

    Args:
        trace: The trace built so far.
        action: The action to translate.

    Returns:
        * a single ``Message`` for a ``MessageAction`` (role ``'user'`` when the
          action's source is ``EventSource.USER``, otherwise ``'assistant'``);
        * an empty list for ``NullAction`` / ``ChangeAgentStateAction``;
        * for any other action exposing a non-``None`` ``action`` attribute, an
          optional assistant ``Message`` carrying its ``thought`` followed by a
          ``ToolCall``;
        * an empty list (after logging an error) otherwise.
    """
    inv_trace: list[TraceElement] = []
    # isinstance (rather than exact type equality) so that subclasses of these
    # action types are classified the same way as their base class.
    if isinstance(action, MessageAction):
        role = 'user' if action.source == EventSource.USER else 'assistant'
        inv_trace.append(Message(role=role, content=action.content))
    elif isinstance(action, (NullAction, ChangeAgentStateAction)):
        # These actions contribute nothing to the trace.
        pass
    elif hasattr(action, 'action') and action.action is not None:
        event_dict = event_to_dict(action)
        args = event_dict.get('args', {})
        # The thought is emitted as its own assistant message, so it is removed
        # from the tool-call arguments.
        thought = args.pop('thought', None)
        function = Function(name=action.action, arguments=args)
        if thought is not None:
            inv_trace.append(Message(role='assistant', content=thought))
        inv_trace.append(
            ToolCall(id=get_next_id(trace), type='function', function=function)
        )
    else:
        logger.error(f'Unknown action type: {type(action)}')
    return inv_trace


def parse_observation(
    trace: list[TraceElement], obs: Observation
) -> list[TraceElement]:
    """Translate a single observation into the trace elements it contributes.

    Args:
        trace: The trace built so far; used to find the id of the latest
            ``ToolCall`` so the output can reference it.
        obs: The observation to translate.

    Returns:
        * an empty list for ``NullObservation`` / ``AgentStateChangedObservation``;
        * a single ``ToolOutput`` for any other observation with non-``None``
          ``content``; its ``tool_call_id`` is the latest ``ToolCall`` id, or
          ``None`` when the trace has no ``ToolCall``;
        * an empty list (after logging an error) otherwise.
    """
    if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
        return []
    if hasattr(obs, 'content') and obs.content is not None:
        return [
            ToolOutput(
                role='tool', content=obs.content, tool_call_id=get_last_id(trace)
            )
        ]
    logger.error(f'Unknown observation type: {type(obs)}')
    return []


def parse_element(
    trace: list[TraceElement], element: Action | Observation
) -> list[TraceElement]:
    """Dispatch ``element`` to ``parse_action`` or ``parse_observation``.

    Args:
        trace: The trace built so far.
        element: An action or an observation.

    Returns:
        The trace elements produced for ``element``.
    """
    if isinstance(element, Action):
        return parse_action(trace, element)
    return parse_observation(trace, element)


def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
    """Build a full trace from ``(action, observation)`` pairs.

    Each pair is processed in order, action first, against the trace
    accumulated so far, so tool-call ids stay consistent across pairs.

    Args:
        trace: Ordered ``(action, observation)`` pairs.

    Returns:
        The accumulated list of trace elements.
    """
    inv_trace: list[TraceElement] = []
    for action, obs in trace:
        inv_trace.extend(parse_action(inv_trace, action))
        inv_trace.extend(parse_observation(inv_trace, obs))
    return inv_trace


class InvariantState(BaseModel):
    """Mutable container that accumulates trace elements."""

    # Elements accumulated so far; every instance starts with its own empty list.
    trace: list[TraceElement] = Field(default_factory=list)

    def add_action(self, action: Action) -> None:
        """Append the trace elements produced by ``action``."""
        self.trace.extend(parse_action(self.trace, action))

    def add_observation(self, obs: Observation) -> None:
        """Append the trace elements produced by ``obs``."""
        self.trace.extend(parse_observation(self.trace, obs))

    def concatenate(self, other: 'InvariantState') -> None:
        """Append all of ``other``'s trace elements to this state's trace."""
        self.trace.extend(other.trace)