Skip to content

Commit e3236ce

Browse files
committed
Updates
1 parent c7a29a9 commit e3236ce

2 files changed

Lines changed: 172 additions & 3 deletions

File tree

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,9 @@ async def on_message_send( # noqa: D102
279279
result = event
280280
# Do NOT break here as Message is supposed to be the only
281281
# event in "Message-only" interaction.
282-
# We rely on AgentExecutor here to simplify wrong implementations detection.
282+
# ActiveTask consumer (see active_task.py) validates the event
283+
# stream and raises InvalidAgentResponseError if more events are
284+
# pushed after a Message.
283285

284286
if result is None:
285287
logger.debug('Missing result for task %s', request_context.task_id)
@@ -315,8 +317,12 @@ async def on_message_send_stream( # noqa: D102
315317
request=request_context,
316318
include_initial_task=False,
317319
):
318-
# Do NOT break here as we rely on AgentExecutor to yield control
319-
# to simplify wrong implementations detection.
320+
# Do NOT break here as we rely on AgentExecutor to yield control.
321+
# ActiveTask consumer (see active_task.py) validates the event
322+
# stream and raises InvalidAgentResponseError on misbehaving agents:
323+
# - an event after a Message
324+
# - Message after entering task mode
325+
# - an event after a terminal state
320326
if isinstance(event, Task):
321327
self._validate_task_id_match(task_id, event.id)
322328
yield apply_history_length(event, params.configuration)

tests/server/request_handlers/test_default_request_handler_v2.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from a2a.types import (
3030
InternalError,
31+
InvalidAgentResponseError,
3132
InvalidParamsError,
3233
TaskNotFoundError,
3334
PushNotificationNotSupportedError,
@@ -1244,3 +1245,165 @@ async def test_on_message_send_with_push_notification():
12441245
push_store.set_info.assert_awaited_once_with(
12451246
result.id, push_config, context
12461247
)
1248+
1249+
1250+
class MultipleMessagesAgentExecutor(AgentExecutor):
1251+
"""Misbehaving agent that yields more than one Message."""
1252+
1253+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1254+
await event_queue.enqueue_event(
1255+
new_text_message('first', role=Role.ROLE_AGENT)
1256+
)
1257+
await event_queue.enqueue_event(
1258+
new_text_message('second', role=Role.ROLE_AGENT)
1259+
)
1260+
1261+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1262+
pass
1263+
1264+
1265+
class MessageAfterTaskEventAgentExecutor(AgentExecutor):
1266+
"""Misbehaving agent that yields a task-mode event then a Message."""
1267+
1268+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1269+
task = new_task_from_user_message(context.message)
1270+
await event_queue.enqueue_event(task)
1271+
updater = TaskUpdater(event_queue, task.id, task.context_id)
1272+
await updater.update_status(TaskState.TASK_STATE_WORKING)
1273+
await event_queue.enqueue_event(
1274+
new_text_message('stray message', role=Role.ROLE_AGENT)
1275+
)
1276+
1277+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1278+
pass
1279+
1280+
1281+
class TaskEventAfterMessageAgentExecutor(AgentExecutor):
1282+
"""Misbehaving agent that yields a Message and then a task-mode event."""
1283+
1284+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1285+
await event_queue.enqueue_event(
1286+
new_text_message('only message', role=Role.ROLE_AGENT)
1287+
)
1288+
await event_queue.enqueue_event(
1289+
TaskStatusUpdateEvent(
1290+
task_id=str(context.task_id or ''),
1291+
context_id=str(context.context_id or ''),
1292+
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
1293+
)
1294+
)
1295+
1296+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1297+
pass
1298+
1299+
1300+
class EventAfterTerminalStateAgentExecutor(AgentExecutor):
1301+
"""Misbehaving agent that yields an event after reaching a terminal state."""
1302+
1303+
async def execute(self, context: RequestContext, event_queue: EventQueue):
1304+
task = new_task_from_user_message(context.message)
1305+
await event_queue.enqueue_event(task)
1306+
updater = TaskUpdater(event_queue, task.id, task.context_id)
1307+
await updater.complete()
1308+
await event_queue.enqueue_event(
1309+
new_text_message('after terminal', role=Role.ROLE_AGENT)
1310+
)
1311+
1312+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
1313+
pass
1314+
1315+
1316+
@pytest.mark.asyncio
1317+
@pytest.mark.timeout(1)
1318+
async def test_on_message_send_stream_rejects_multiple_messages():
1319+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1320+
second Message after the first one (see comment in on_message_send_stream)."""
1321+
request_handler = DefaultRequestHandlerV2(
1322+
agent_executor=MultipleMessagesAgentExecutor(),
1323+
task_store=InMemoryTaskStore(),
1324+
agent_card=create_default_agent_card(),
1325+
)
1326+
params = SendMessageRequest(
1327+
message=Message(
1328+
role=Role.ROLE_USER,
1329+
message_id='msg_multi_stream',
1330+
parts=[Part(text='Hi')],
1331+
)
1332+
)
1333+
with pytest.raises(InvalidAgentResponseError):
1334+
async for _ in request_handler.on_message_send_stream(
1335+
params, create_server_call_context()
1336+
):
1337+
pass
1338+
1339+
1340+
@pytest.mark.asyncio
1341+
@pytest.mark.timeout(1)
1342+
async def test_on_message_send_stream_rejects_message_after_task_event():
1343+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1344+
Message after entering task mode (see comment in on_message_send_stream)."""
1345+
request_handler = DefaultRequestHandlerV2(
1346+
agent_executor=MessageAfterTaskEventAgentExecutor(),
1347+
task_store=InMemoryTaskStore(),
1348+
agent_card=create_default_agent_card(),
1349+
)
1350+
params = SendMessageRequest(
1351+
message=Message(
1352+
role=Role.ROLE_USER,
1353+
message_id='msg_after_task_stream',
1354+
parts=[Part(text='Hi')],
1355+
)
1356+
)
1357+
with pytest.raises(InvalidAgentResponseError):
1358+
async for _ in request_handler.on_message_send_stream(
1359+
params, create_server_call_context()
1360+
):
1361+
pass
1362+
1363+
1364+
@pytest.mark.asyncio
1365+
@pytest.mark.timeout(1)
1366+
async def test_on_message_send_stream_rejects_task_event_after_message():
1367+
"""Stream surfaces InvalidAgentResponseError when the agent yields a
1368+
task-mode event after a Message (see comment in on_message_send_stream)."""
1369+
request_handler = DefaultRequestHandlerV2(
1370+
agent_executor=TaskEventAfterMessageAgentExecutor(),
1371+
task_store=InMemoryTaskStore(),
1372+
agent_card=create_default_agent_card(),
1373+
)
1374+
params = SendMessageRequest(
1375+
message=Message(
1376+
role=Role.ROLE_USER,
1377+
message_id='msg_then_task_stream',
1378+
parts=[Part(text='Hi')],
1379+
)
1380+
)
1381+
with pytest.raises(InvalidAgentResponseError):
1382+
async for _ in request_handler.on_message_send_stream(
1383+
params, create_server_call_context()
1384+
):
1385+
pass
1386+
1387+
1388+
@pytest.mark.asyncio
1389+
@pytest.mark.timeout(1)
1390+
async def test_on_message_send_stream_rejects_event_after_terminal_state():
1391+
"""Stream surfaces InvalidAgentResponseError when the agent yields an event
1392+
after reaching a terminal state (see comment in on_message_send_stream)."""
1393+
request_handler = DefaultRequestHandlerV2(
1394+
agent_executor=EventAfterTerminalStateAgentExecutor(),
1395+
task_store=InMemoryTaskStore(),
1396+
agent_card=create_default_agent_card(),
1397+
)
1398+
params = SendMessageRequest(
1399+
message=Message(
1400+
role=Role.ROLE_USER,
1401+
message_id='msg_after_terminal_stream',
1402+
parts=[Part(text='Hi')],
1403+
)
1404+
)
1405+
with pytest.raises(InvalidAgentResponseError):
1406+
async for _ in request_handler.on_message_send_stream(
1407+
params, create_server_call_context()
1408+
):
1409+
pass

0 commit comments

Comments
 (0)