From f2b41d08feb25d0d2fe838499a9d874d96ecab81 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Tue, 14 Apr 2026 10:57:06 +0000 Subject: [PATCH] Additonal tests for request handlers. --- src/a2a/server/agent_execution/active_task.py | 34 +- .../cross_version/client_server/server_0_3.py | 10 +- .../cross_version/client_server/server_1_0.py | 8 +- .../integration/test_copying_observability.py | 9 +- tests/integration/test_scenarios.py | 613 ++++++++++++++---- .../agent_execution/test_active_task.py | 201 +----- .../test_default_request_handler_v2.py | 18 +- 7 files changed, 561 insertions(+), 332 deletions(-) diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index db7bb5146..5479a38c1 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -36,6 +36,7 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ( + InvalidAgentResponseError, InvalidParamsError, TaskNotFoundError, ) @@ -370,13 +371,12 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 elif isinstance(event, Message): if task_mode is not None: if task_mode: - logger.error( - 'Received Message() object in task mode.' - ) - else: - logger.error( - 'Multiple Message() objects received.' + raise InvalidAgentResponseError( + 'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.' ) + raise InvalidAgentResponseError( + 'Multiple Message objects received.' + ) task_mode = False logger.debug( 'Consumer[%s]: Setting result to Message: %s', @@ -385,9 +385,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 ) else: if task_mode is False: - logger.error( - 'Received %s in message mode.', - type(event).__name__, + raise InvalidAgentResponseError( + f'Received {type(event).__name__} in message mode. Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.' ) if isinstance(event, Task): @@ -408,6 +407,18 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 # Initial task should already contain the message. message_to_save = None else: + if ( + isinstance(event, TaskStatusUpdateEvent) + and not self._task_created.is_set() + ): + task = ( + await self._task_manager.get_task() + ) + if task is None: + raise InvalidAgentResponseError( + f'Agent should enqueue Task before {type(event).__name__} event' + ) + new_task = ( await self._task_manager.ensure_task_id( self._task_id, @@ -434,8 +445,6 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 if not isinstance(event, Task): await self._task_manager.process(event) - self._task_created.set() - # Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states new_task = await self._task_manager.get_task() if new_task is None: @@ -496,6 +505,9 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912 await self._push_sender.send_notification( self._task_id, event ) + + self._task_created.set() + finally: if new_task is not None: new_task_copy = Task() diff --git a/tests/integration/cross_version/client_server/server_0_3.py b/tests/integration/cross_version/client_server/server_0_3.py index 7bd5f7e75..875cbb1ca 100644 --- a/tests/integration/cross_version/client_server/server_0_3.py +++ b/tests/integration/cross_version/client_server/server_0_3.py @@ -38,7 +38,7 @@ from starlette.requests import Request from starlette.concurrency import iterate_in_threadpool import time - +from a2a.utils.task import new_task from server_common import CustomLoggingMiddleware @@ -48,12 +48,18 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + + task = new_task(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.working + await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.submitted) await task_updater.update_status(TaskState.working) text = '' diff --git a/tests/integration/cross_version/client_server/server_1_0.py b/tests/integration/cross_version/client_server/server_1_0.py index e11b1d69d..1ed15ea45 100644 --- a/tests/integration/cross_version/client_server/server_1_0.py +++ b/tests/integration/cross_version/client_server/server_1_0.py @@ -28,6 +28,7 @@ from a2a.utils import TransportProtocol from server_common import CustomLoggingMiddleware from google.protobuf.struct_pb2 import Struct, Value +from a2a.utils.task import new_task class MockAgentExecutor(AgentExecutor): @@ -36,12 +37,17 @@ def __init__(self): async def execute(self, context: RequestContext, event_queue: EventQueue): print(f'SERVER: execute called for task {context.task_id}') + task = new_task(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + task_updater = TaskUpdater( event_queue, context.task_id, context.context_id, ) - await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) await task_updater.update_status(TaskState.TASK_STATE_WORKING) text = '' diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py index d5171097a..0e362b7d6 100644 --- a/tests/integration/test_copying_observability.py +++ b/tests/integration/test_copying_observability.py @@ -26,6 +26,7 @@ TaskState, ) from a2a.utils import TransportProtocol +from a2a.utils.task import new_task class MockMutatingAgentExecutor(AgentExecutor): @@ -42,6 +43,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue): if user_input == 'Init task': # Explicitly save status change to ensure task exists with some state + task = new_task(context.message) + task.id = context.task_id + task.context_id = context.context_id + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + await task_updater.update_status( TaskState.TASK_STATE_WORKING, message=task_updater.new_agent_message( @@ -153,6 +160,7 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): ] event = events[-1] + assert event.HasField('status_update') task_id = event.status_update.task_id # 2. Second message to mutate it @@ -162,7 +170,6 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): task_id=task_id, parts=[Part(text='Update task without saving it')], ) - _ = [ event async for event in client.send_message( diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index c50622e5c..4f7959ed7 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -1,5 +1,6 @@ import asyncio import collections +import contextlib import logging from typing import Any @@ -47,10 +48,12 @@ TaskStatusUpdateEvent, ) from a2a.utils import TransportProtocol +from a2a.utils.task import new_task from a2a.utils.errors import ( InvalidParamsError, TaskNotCancelableError, TaskNotFoundError, + InvalidAgentResponseError, ) @@ -246,13 +249,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -277,7 +276,11 @@ async def cancel( event async for event in client.send_message(SendMessageRequest(message=msg)) ] - assert [event.status_update.status.state for event in events] == [ + task, status_update = events + assert task.HasField('task') + assert status_update.HasField('status_update') + + assert [get_state(event) for event in events] == [ TaskState.TASK_STATE_WORKING, TaskState.TASK_STATE_COMPLETED, ] @@ -291,13 +294,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -350,13 +349,9 @@ class DummyAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) async def cancel( self, context: RequestContext, event_queue: EventQueue @@ -393,11 +388,9 @@ async def cancel( (event,) = [event async for event in it] if streaming: - assert event.HasField('status_update') - task_id = event.status_update.task_id - assert ( - event.status_update.status.state == TaskState.TASK_STATE_COMPLETED - ) + assert event.HasField('task') + task_id = event.task.id + validate_state(event, TaskState.TASK_STATE_COMPLETED) else: assert event.HasField('task') task_id = event.task.id @@ -485,13 +478,9 @@ class ErrorAfterAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() raise ValueError('TEST_ERROR_IN_EXECUTE') @@ -515,7 +504,7 @@ async def cancel( if streaming: res = await it.__anext__() - assert res.status_update.status.state == TaskState.TASK_STATE_WORKING + validate_state(res, TaskState.TASK_STATE_WORKING) continue_event.set() else: @@ -554,13 +543,9 @@ class ErrorCancelAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await hang_event.wait() @@ -614,13 +599,9 @@ class ErrorAfterAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() raise ValueError('TEST_ERROR_IN_EXECUTE') @@ -744,13 +725,9 @@ class DummyCancelAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await hang_event.wait() @@ -812,13 +789,9 @@ class ComplexAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await working_event.wait() @@ -931,13 +904,9 @@ async def execute( ) return - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) started_event.set() await continue_event.wait() await event_queue.enqueue_event( @@ -1059,13 +1028,9 @@ class ImmediateAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -1120,27 +1085,17 @@ async def execute( ): message = context.message if message and message.parts and message.parts[0].text == 'start': - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_INPUT_REQUIRED - ), - ) - ) + task = new_task(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) elif ( message and message.parts and message.parts[0].text == 'here is input' ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) + task = new_task(message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) else: raise ValueError('Unexpected message') @@ -1209,13 +1164,9 @@ class AuthAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskStatusUpdateEvent( task_id=context.task_id, @@ -1295,15 +1246,9 @@ async def execute( ): message = context.message if message and message.parts and message.parts[0].text == 'start': - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_AUTH_REQUIRED - ), - ) - ) + task = new_task(message) + task.status.state = TaskState.TASK_STATE_AUTH_REQUIRED + await event_queue.enqueue_event(task) elif ( message and message.parts @@ -1316,6 +1261,7 @@ async def execute( status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), ) ) + else: raise ValueError(f'Unexpected message {message}') @@ -1380,13 +1326,9 @@ class EmitAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - context_id=context.context_id, - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ) + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) phases = [ ('trigger_phase_1', 'artifact_1'), @@ -1602,6 +1544,9 @@ class ArtifactAgent(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ): + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) await event_queue.enqueue_event( TaskArtifactUpdateEvent( task_id=context.task_id, @@ -1724,7 +1669,7 @@ async def cancel( configuration=SendMessageConfiguration(return_immediately=False), ) ) - events = [event async for event in it] + _ = [event async for event in it] (final_task,) = (await client.list_tasks(ListTasksRequest())).tasks @@ -1744,4 +1689,438 @@ async def cancel( if record.levelname == 'ERROR' and 'Ignoring task replacement' in record.message ] + assert len(error_logs) == 1 + + +# Scenario: Task restoration - terminal state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_first', + [False, True], + ids=['no_subscribe_first', 'subscribe_first'], +) +async def test_restore_task_terminal_state( + use_legacy, streaming, subscribe_first +): + class TerminalAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + task_id = get_task_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_COMPLETED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler( + TerminalAgent(), use_legacy, task_store=task_store + ) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == TaskState.TASK_STATE_COMPLETED + + if subscribe_first and streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + msg2 = Message( + task_id=task_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='message to completed task')], + ) + + with pytest.raises(Exception, match=r'terminal state'): + async for _ in client2.send_message(SendMessageRequest(message=msg2)): + pass + + if streaming: + with pytest.raises( + Exception, + match=r'terminal state', + ): + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + pass + + +# Scenario: Task restoration - user input required state +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize( + 'subscribe_mode', + ['none', 'drop', 'listen'], + ids=['no_sub', 'sub_drop', 'sub_listen'], +) +async def test_restore_task_input_required_state( + use_legacy, streaming, subscribe_mode +): + class InputAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + task = new_task(message) + task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED + await event_queue.enqueue_event(task) + elif message and message.parts and message.parts[0].text == 'input': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + task_store = InMemoryTaskStore() + handler1 = create_handler(InputAgent(), use_legacy, task_store=task_store) + client1 = await create_client( + handler1, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='test-msg-1', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + it1 = client1.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events1 = [event async for event in it1] + + task_id = get_task_id(events1[-1]) + context_id = get_task_context_id(events1[-1]) + + await wait_for_state( + client1, task_id, expected_states={TaskState.TASK_STATE_INPUT_REQUIRED} + ) + + # Restore task in a new handler (simulating server restart) + handler2 = create_handler(InputAgent(), use_legacy, task_store=task_store) + client2 = await create_client( + handler2, agent_card=agent_card(), streaming=streaming + ) + + restored_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert restored_task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # Subscription logic based on mode + listen_task = None + if streaming: + if subscribe_mode == 'drop': + # Subscribing and dropping immediately (cancelling the generator) + async for _ in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + break + elif subscribe_mode == 'listen': + sub_started_event = asyncio.Event() + + async def listen_to_end(): + res = [] + async for ev in client2.subscribe( + SubscribeToTaskRequest(id=task_id) + ): + res.append(ev) + sub_started_event.set() + return res + + listen_task = asyncio.create_task(listen_to_end()) + # Wait for subscription to establish and yield the initial task event + await asyncio.wait_for(sub_started_event.wait(), timeout=1.0) + + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='test-msg-2', + role=Role.ROLE_USER, + parts=[Part(text='input')], + ) + + it2 = client2.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + events2 = [event async for event in it2] + + if streaming: + assert ( + events2[-1].status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ) + else: + assert events2[-1].task.status.state == TaskState.TASK_STATE_COMPLETED + + if listen_task: + if use_legacy and streaming: + # Error: Legacy handler does not properly manage subscriptions for restored tasks + with pytest.raises(TaskNotFoundError): + await listen_task + else: + listen_events = await listen_task + # The first event is the initial task state (INPUT_REQUIRED), the last should be COMPLETED + assert ( + get_state(listen_events[-1]) == TaskState.TASK_STATE_COMPLETED + ) + + final_task = await client2.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + + +# Scenario 20: Create initial task with new_task +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +@pytest.mark.parametrize('initial_task_type', ['new_task', 'status_update']) +async def test_scenario_initial_task_types( + use_legacy, streaming, initial_task_type +): + started_event = asyncio.Event() + continue_event = asyncio.Event() + + class InitialTaskAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + if initial_task_type == 'new_task': + # Create with new_task + task = new_task(context.message) + task.status.state = TaskState.TASK_STATE_WORKING + await event_queue.enqueue_event(task) + else: + # Create with status update (illegal in v2) + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + ) + + started_event.set() + await continue_event.wait() + + await event_queue.enqueue_event( + TaskArtifactUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + artifact=Artifact( + artifact_id='art-1', parts=[Part(text='artifact data')] + ), + ) + ) + + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(InitialTaskAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg, + configuration=SendMessageConfiguration( + return_immediately=streaming + ), + ) + ) + + if streaming: + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent should enqueue Task before TaskStatusUpdateEvent event', + ): + await it.__anext__() + + # End of the test. + return + + res = await it.__anext__() + if initial_task_type == 'status_update' and use_legacy: + # First message has to be a Task. + assert res.HasField('status_update') + + # End of the test. + return + + assert res.HasField('task') + task_id = get_task_id(res) + + await asyncio.wait_for(started_event.wait(), timeout=1.0) + + # Start subscription + sub = client.subscribe(SubscribeToTaskRequest(id=task_id)) + + # first subscriber receives current task state (WORKING) + first_event = await sub.__anext__() + assert first_event.HasField('task') + + continue_event.set() + + events = [first_event] + [event async for event in sub] + else: + # blocking + async def release_agent(): + await started_event.wait() + continue_event.set() + + release_task = asyncio.create_task(release_agent()) + if initial_task_type == 'status_update' and not use_legacy: + with pytest.raises( + InvalidAgentResponseError, + match='Agent should enqueue Task before TaskStatusUpdateEvent event', + ): + events = [event async for event in it] + # End of the test + return + else: + events = [event async for event in it] + await release_task + + if streaming: + task, artifact_update, status_update = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_WORKING) + assert artifact_update.artifact_update.artifact.artifact_id == 'art-1' + assert status_update.HasField('status_update') + validate_state(status_update, TaskState.TASK_STATE_COMPLETED) + else: + (task,) = events + assert task.HasField('task') + validate_state(task, TaskState.TASK_STATE_COMPLETED) + (artifact,) = task.task.artifacts + assert artifact.artifact_id == 'art-1' + task_id = task.task.id + + (final_task_from_list,) = ( + await client.list_tasks(ListTasksRequest(include_artifacts=True)) + ).tasks + assert len(final_task_from_list.artifacts) > 0 + assert final_task_from_list.artifacts[0].artifact_id == 'art-1' + + final_task = await client.get_task(GetTaskRequest(id=task_id)) + assert final_task.status.state == TaskState.TASK_STATE_COMPLETED + assert len(final_task.artifacts) > 0 + assert final_task.artifacts[0].artifact_id == 'art-1' + + +# Scenario 23: Invalid Agent Response - Task followed by Message +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_23_invalid_response_task_message(use_legacy, streaming): + class TaskMessageAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + await event_queue.enqueue_event(new_task(context.message)) + await event_queue.enqueue_event( + Message(message_id='m1', parts=[Part(text='m1')]) + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(TaskMessageAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg = Message( + message_id='test-msg', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message(SendMessageRequest(message=msg)) + + if use_legacy: + # Legacy: no error. + async for _ in it: + pass + else: + with pytest.raises( + InvalidAgentResponseError, + match='Received Message object in task mode', + ): + async for _ in it: + pass diff --git a/tests/server/agent_execution/test_active_task.py b/tests/server/agent_execution/test_active_task.py index 3a4a24ff6..2d74a59d9 100644 --- a/tests/server/agent_execution/test_active_task.py +++ b/tests/server/agent_execution/test_active_task.py @@ -19,8 +19,11 @@ TaskState, TaskStatus, TaskStatusUpdateEvent, + Role, + Part, ) from a2a.utils.errors import InvalidParamsError +from a2a.utils.task import new_task logger = logging.getLogger(__name__) @@ -71,51 +74,6 @@ async def active_task( push_sender=push_sender, ) - @pytest.mark.asyncio - async def test_active_task_lifecycle( - self, - active_task: ActiveTask, - agent_executor: Mock, - request_context: Mock, - task_manager: Mock, - ) -> None: - """Test the basic lifecycle of an ActiveTask.""" - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - await active_task.enqueue_request(request_context) - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - # Wait for the task to finish - events = [e async for e in active_task.subscribe()] - result = next(e for e in events if isinstance(e, Message)) - - assert isinstance(result, Message) - assert result.message_id == 'm1' - assert active_task.task_id == 'test-task-id' - @pytest.mark.asyncio async def test_active_task_already_started( self, active_task: ActiveTask, request_context: Mock @@ -132,36 +90,6 @@ async def test_active_task_already_started( ) assert active_task._producer_task is not None - @pytest.mark.asyncio - async def test_active_task_subscribe( - self, - active_task: ActiveTask, - agent_executor: Mock, - request_context: Mock, - ) -> None: - """Test subscribing to events from an ActiveTask.""" - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event(Message(message_id='m2')) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - - await active_task.enqueue_request(request_context) - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - events = [] - async for event in active_task.subscribe(): - events.append(event) - if len(events) == 2: - break - - assert len(events) == 2 - assert events[0].message_id == 'm1' - assert events[1].message_id == 'm2' - @pytest.mark.asyncio async def test_active_task_cancel( self, @@ -355,59 +283,6 @@ async def execute_mock(req, q): push_sender.send_notification.assert_called() - @pytest.mark.asyncio - async def test_active_task_cleanup( - self, - agent_executor: Mock, - task_manager: Mock, - request_context: Mock, - ) -> None: - """Test that the cleanup callback is called.""" - on_cleanup = Mock() - active_task = ActiveTask( - agent_executor=agent_executor, - task_id='test-task-id', - task_manager=task_manager, - on_cleanup=on_cleanup, - ) - - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - async for _ in active_task.subscribe(request=request_context): - pass - - # Wait for consumer thread to finish and call cleanup - for _ in range(20): - if on_cleanup.called: - break - await asyncio.sleep(0.05) - - on_cleanup.assert_called_once_with(active_task) - @pytest.mark.asyncio async def test_active_task_consumer_failure( self, @@ -894,76 +769,6 @@ async def test_active_task_maybe_cleanup_not_finished( await active_task._maybe_cleanup() on_cleanup.assert_not_called() - @pytest.mark.asyncio - async def test_active_task_maybe_cleanup_with_subscribers( - self, - agent_executor: Mock, - task_manager: Mock, - push_sender: Mock, - request_context: Mock, - ) -> None: - """Test that cleanup is not called if there are subscribers.""" - on_cleanup = Mock() - active_task = ActiveTask( - agent_executor=agent_executor, - task_id='test-task-id', - task_manager=task_manager, - push_sender=push_sender, - on_cleanup=on_cleanup, - ) - - # Mock execute to finish immediately - async def execute_mock(req, q): - await q.enqueue_event(Message(message_id='m1')) - await q.enqueue_event( - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ) - - agent_executor.execute = AsyncMock(side_effect=execute_mock) - task_manager.get_task.side_effect = [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - ] + [ - Task( - id='test-task-id', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ) - ] * 10 - - # 1. Start a subscriber before task finishes - gen = active_task.subscribe() - # Start the generator to increment reference count - msg_task = asyncio.create_task(gen.__anext__()) - - # 2. Start the task and wait for it to finish - await active_task.start( - call_context=ServerCallContext(), create_task_if_missing=True - ) - - async for _ in active_task.subscribe(request=request_context): - pass - - # Give the consumer loop a moment to set _is_finished - await asyncio.sleep(0.1) - - # Ensure we got the message - assert (await msg_task).message_id == 'm1' - - # At this point, task is finished, but we still have a subscriber (gen). - # _maybe_cleanup was called by consumer loop, but should have done nothing. - on_cleanup.assert_not_called() - - # 3. Close the subscriber - await gen.aclose() - - # Now cleanup should be triggered - on_cleanup.assert_called_once_with(active_task) - @pytest.mark.asyncio async def test_active_task_subscribe_exception_already_set( self, active_task: ActiveTask diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index d48b82461..cf35d6376 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -53,6 +53,7 @@ TaskPushNotificationConfig, TaskState, TaskStatus, + TaskStatusUpdateEvent, ) from a2a.utils import new_agent_text_message, new_task @@ -68,11 +69,15 @@ def create_default_agent_card(): class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): + if context.message: + await event_queue.enqueue_event(new_task(context.message)) + task_updater = TaskUpdater( event_queue, str(context.task_id or ''), str(context.context_id or ''), ) + async for i in self._run(): parts = [Part(text=f'Event {i}')] try: @@ -569,8 +574,15 @@ async def consume_stream(): elapsed = time.perf_counter() - start assert len(events) == 3 assert elapsed < 0.5 - texts = [p.text for e in events for p in e.status.message.parts] - assert texts == ['Event 0', 'Event 1', 'Event 2'] + task, event0, event1 = events + assert isinstance(task, Task) + assert task.history[0].parts[0].text == 'How are you?' + + assert isinstance(event0, TaskStatusUpdateEvent) + assert event0.status.message.parts[0].text == 'Event 0' + + assert isinstance(event1, TaskStatusUpdateEvent) + assert event1.status.message.parts[0].text == 'Event 1' @pytest.mark.asyncio @@ -951,6 +963,8 @@ class HelloWorldAgentExecutor(AgentExecutor): async def execute( self, context: RequestContext, event_queue: EventQueue ) -> None: + if context.message: + await event_queue.enqueue_event(new_task(context.message)) updater = TaskUpdater( event_queue, task_id=context.task_id or str(uuid.uuid4()),