File size: 7,162 Bytes
246d201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig, LLMConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_agent():
"""Create a properly configured mock agent with all required nested attributes"""
# Create the base mocks
agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)
# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.max_message_chars = 1000
# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
return agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's no state to restore"""
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(sid='test-session', file_store=file_store)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
# Create a mock EventStream with no events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 0
# Inject the mock event stream into the session
session.event_stream = mock_event_stream
# Create a spy on set_initial_state
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None
def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Patch AgentController and State.restore_from_session to fail
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
), patch(
'openhands.controller.state.state.State.restore_from_session',
side_effect=Exception('No state found'),
):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
)
# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
# Verify set_initial_state was called once with None as state
assert session.controller.set_initial_state_call_count == 1
assert session.controller.test_initial_state is None
assert session.controller.state.max_iterations == 10
assert session.controller.agent.name == 'test-agent'
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's a state to restore"""
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(sid='test-session', file_store=file_store)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=Runtime)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
# Create a mock EventStream with some events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 5 # Indicate some events exist
# Inject the mock event stream into the session
session.event_stream = mock_event_stream
# Create a mock restored state
mock_restored_state = MagicMock(spec=State)
mock_restored_state.start_id = -1
mock_restored_state.end_id = -1
mock_restored_state.truncation_id = -1
mock_restored_state.max_iterations = 5
# Create a spy on set_initial_state by subclassing AgentController
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None
def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Patch AgentController and State.restore_from_session to succeed
with patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
), patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
), patch(
'openhands.controller.state.state.State.restore_from_session',
return_value=mock_restored_state,
):
await session.start(
runtime_name='test-runtime',
config=AppConfig(),
agent=mock_agent,
max_iterations=10,
)
# Verify set_initial_state was called once with the restored state
assert session.controller.set_initial_state_call_count == 1
# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
assert session.controller.test_initial_state is mock_restored_state
assert session.controller.state is mock_restored_state
assert session.controller.state.max_iterations == 5
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
assert session.controller.state.truncation_id == -1
|