|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from openhands.controller.agent_controller import AgentController
|
|
from openhands.events import EventSource
|
|
from openhands.events.action import CmdRunAction, MessageAction
|
|
from openhands.events.observation import CmdOutputObservation
|
|
|
|
|
|
@pytest.fixture
def mock_event_stream():
    """Provide a mocked event stream that looks empty.

    ``get_events()`` returns a fresh empty list and ``get_latest_event_id()``
    returns 0; every other attribute is an auto-created ``MagicMock``.
    """
    fake_stream = MagicMock()
    fake_stream.configure_mock(
        **{
            'get_events.return_value': [],
            'get_latest_event_id.return_value': 0,
        }
    )
    return fake_stream
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Provide a mocked agent whose ``llm`` and ``llm.config`` are explicit mocks."""
    agent_double = MagicMock()
    llm_double = MagicMock()
    # Attach the LLM mock before giving it a config, same order as building
    # ``agent.llm`` then ``agent.llm.config``.
    agent_double.llm = llm_double
    llm_double.config = MagicMock()
    return agent_double
|
|
|
|
|
|
class TestTruncation:
    """Tests for ``AgentController._apply_conversation_window`` and for carrying
    its ``start_id`` / ``truncation_id`` bookkeeping over to a new controller."""

    @staticmethod
    def _make_controller(event_stream, agent):
        """Build the headless, non-confirming controller every test here uses."""
        return AgentController(
            agent=agent,
            event_stream=event_stream,
            max_iterations=10,
            sid='test_truncation',
            confirmation_mode=False,
            headless_mode=True,
        )

    @staticmethod
    def _command_pair(command, content, action_id):
        """Return a ``CmdRunAction`` with id ``action_id`` and its
        ``CmdOutputObservation`` with id ``action_id + 1``.

        The observation's ``_cause`` points back at the action, so the pair
        can be matched up by id.
        """
        action = CmdRunAction(command=command)
        action._id = action_id
        observation = CmdOutputObservation(
            command=command, content=content, command_id=action_id
        )
        observation._id = action_id + 1
        observation._cause = action_id
        return action, observation

    def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
        """Truncation keeps the first user message, records start/truncation
        ids, and never keeps an observation without its causing action."""
        controller = self._make_controller(mock_event_stream, mock_agent)

        first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        cmd1, obs1 = self._command_pair('ls', 'file1.txt', 2)
        cmd2, obs2 = self._command_pair('pwd', '/home', 4)

        events = [first_msg, cmd1, obs1, cmd2, obs2]

        truncated = controller._apply_conversation_window(events)

        # First message plus at least one action/observation pair.
        assert len(truncated) >= 3
        assert truncated[0] == first_msg
        assert controller.state.start_id == first_msg._id
        assert controller.state.truncation_id is not None

        # Every kept observation must be preceded by the action that caused it.
        for pos, event in enumerate(truncated[1:], start=1):
            if isinstance(event, CmdOutputObservation):
                assert any(e._id == event._cause for e in truncated[:pos])

    def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
        """Calling ``_apply_conversation_window`` on a multi-message history
        shrinks it and places the truncation point after the start message."""
        controller = self._make_controller(mock_event_stream, mock_agent)

        first_msg = MessageAction(content='Start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        agent_msg = MessageAction(
            content='What task would you like me to perform?', wait_for_response=True
        )
        agent_msg._source = EventSource.AGENT
        agent_msg._id = 2

        user_response = MessageAction(
            content='Please list all files and show me current directory',
            wait_for_response=False,
        )
        user_response._source = EventSource.USER
        user_response._id = 3

        cmd1, obs1 = self._command_pair('ls', 'file1.txt', 4)

        history = [first_msg, agent_msg, user_response, cmd1, obs1]
        # Keep the mocked stream consistent with the in-memory history.
        mock_event_stream.get_events.return_value = list(history)
        controller.state.history = list(history)
        original_history_len = len(controller.state.history)

        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        assert len(controller.state.history) < original_history_len
        assert controller.state.start_id == first_msg._id
        assert controller.state.truncation_id is not None
        assert controller.state.truncation_id > controller.state.start_id

    def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
        """A new controller given the saved ids and the truncated events (via
        the mocked stream) sees the same truncated history.

        NOTE(review): the restore is simulated by assigning state directly; this
        does not exercise any restore logic inside ``AgentController`` itself.
        """
        controller = self._make_controller(mock_event_stream, mock_agent)

        first_msg = MessageAction(content='Start task', wait_for_response=False)
        first_msg._source = EventSource.USER
        first_msg._id = 1

        events = [first_msg]
        for i in range(5):
            # Action ids 2, 4, 6, ... and observation ids 3, 5, 7, ...: every
            # event gets a unique, increasing id, as in a real stream.
            cmd, obs = self._command_pair(f'cmd{i}', f'output{i}', 2 * i + 2)
            events.extend([cmd, obs])

        controller.state.history = events.copy()
        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        saved_start_id = controller.state.start_id
        saved_truncation_id = controller.state.truncation_id
        saved_history_len = len(controller.state.history)
        assert saved_truncation_id is not None

        mock_event_stream.get_events.return_value = controller.state.history

        new_controller = self._make_controller(mock_event_stream, mock_agent)
        new_controller.state.start_id = saved_start_id
        new_controller.state.truncation_id = saved_truncation_id
        new_controller.state.history = mock_event_stream.get_events()

        assert len(new_controller.state.history) == saved_history_len
        assert new_controller.state.history[0] == first_msg
        assert new_controller.state.start_id == saved_start_id
        assert new_controller.state.truncation_id == saved_truncation_id
        # Apart from the preserved first message, nothing older than the
        # truncation point should have survived.
        assert all(
            e._id >= saved_truncation_id for e in new_controller.state.history[1:]
        )
|
|
|