diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 1043a7d72..d5387a047 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -5,17 +5,17 @@ import httpx import pytest import pytest_asyncio +from starlette.applications import Starlette from a2a.client.base_client import BaseClient from a2a.client.client import ClientConfig from a2a.client.client_factory import ClientFactory from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.routes.rest_routes import create_rest_routes -from starlette.applications import Starlette -from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes from a2a.server.events import EventQueue from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager -from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler +from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import TaskUpdater from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( @@ -37,7 +37,7 @@ TaskState, a2a_pb2_grpc, ) -from a2a.utils import TransportProtocol +from a2a.utils import TransportProtocol, new_task from a2a.utils.errors import InvalidParamsError @@ -69,7 +69,9 @@ def assert_events_match(events, expected_events): events, expected_events, strict=True ): assert event.HasField(expected_type) - if expected_type == 'status_update': + if expected_type == 'task': + assert event.task.status.state == expected_val + elif expected_type == 'status_update': assert event.status_update.status.state == expected_val elif expected_type == 'artifact_update': if expected_val is not None: @@ -83,26 +85,30 @@ def assert_events_match(events, expected_events): class MockAgentExecutor(AgentExecutor): async def execute(self, context: RequestContext, event_queue: EventQueue): - task_updater = TaskUpdater( - event_queue, - context.task_id, - context.context_id, - ) user_input = context.get_user_input() - is_input_required_resumption = ( - context.current_task is not None - and context.current_task.status.state - == TaskState.TASK_STATE_INPUT_REQUIRED - ) - - if not is_input_required_resumption: - await task_updater.update_status( - TaskState.TASK_STATE_SUBMITTED, - message=task_updater.new_agent_message( - [Part(text='task submitted')] - ), + # Direct message response (no task created). + if user_input.startswith('Message:'): + await event_queue.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id='direct-reply-1', + parts=[Part(text=f'Direct reply to: {user_input}')], + ) ) + return + + # Task-based response. + task = context.current_task + if not task: + task = new_task(context.message) + await event_queue.enqueue_event(task) + + task_updater = TaskUpdater( + event_queue, + task.id, + task.context_id, + ) await task_updater.update_status( TaskState.TASK_STATE_WORKING, @@ -168,7 +174,8 @@ class ClientSetup(NamedTuple): @pytest.fixture def base_e2e_setup(agent_card): task_store = InMemoryTaskStore() - handler = DefaultRequestHandler( + # TODO(https://github.com/a2aproject/a2a-python/issues/869): Use DefaultRequestHandler once it's fixed + handler = LegacyRequestHandler( agent_executor=MockAgentExecutor(), task_store=task_store, agent_card=agent_card, @@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups): response.task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('artifact_update', [('test-artifact', 'artifact content')]), ('status_update', TaskState.TASK_STATE_COMPLETED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert_history_matches( task.history, [ (Role.ROLE_USER, 'Run dummy agent!'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups): ) ] response = events[0] - task_id = response.status_update.task_id + task_id = response.task.id get_request = GetTaskRequest(id=task_id) retrieved_task = await client.get_task(request=get_request) @@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups): retrieved_task.history, [ (Role.ROLE_USER, 'Test Get Task'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups): ) ) ) - expected_task_ids.append(response.status_update.task_id) + expected_task_ids.append(response.task.id) list_request = ListTasksRequest(page_size=page_size) @@ -514,13 +518,13 @@ async def test_end_to_end_input_required(transport_setups): assert_events_match( events, [ - ('status_update', TaskState.TASK_STATE_SUBMITTED), + ('task', TaskState.TASK_STATE_SUBMITTED), ('status_update', TaskState.TASK_STATE_WORKING), ('status_update', TaskState.TASK_STATE_INPUT_REQUIRED), ], ) - task_id = events[0].status_update.task_id + task_id = events[0].task.id task = await client.get_task(request=GetTaskRequest(id=task_id)) assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED @@ -528,7 +532,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), ], ) @@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups): task.history, [ (Role.ROLE_USER, 'Need input'), - (Role.ROLE_AGENT, 'task submitted'), (Role.ROLE_AGENT, 'task working'), (Role.ROLE_AGENT, 'Please provide input'), (Role.ROLE_USER, 'Here is the input'), @@ -681,3 +683,78 @@ async def test_end_to_end_subscribe_validation_error( assert {e['field'] for e in errors} == {'id'} await client.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'streaming', + [ + pytest.param(False, id='blocking'), + pytest.param(True, id='streaming'), + ], +) +async def test_end_to_end_direct_message(transport_setups, streaming): + """Test that an executor can return a direct Message without creating a Task.""" + client = transport_setups.client + client._config.streaming = streaming + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct', + parts=[Part(text='Message: Hello agent')], + ) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest(message=message_to_send) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Hello agent', + ) + + +@pytest.mark.asyncio +async def test_end_to_end_direct_message_return_immediately(transport_setups): + """Test that return_immediately still returns the Message for direct replies. + + When the executor responds with a direct Message, the response is + inherently immediate -- there is no async task to defer to. The client + should receive the Message regardless of the return_immediately flag. + """ + client = transport_setups.client + client._config.streaming = False + + message_to_send = Message( + role=Role.ROLE_USER, + message_id='msg-direct-return-immediately', + parts=[Part(text='Message: Quick question')], + ) + configuration = SendMessageConfiguration(return_immediately=True) + + events = [ + event + async for event in client.send_message( + request=SendMessageRequest( + message=message_to_send, configuration=configuration + ) + ) + ] + + assert len(events) == 1 + response = events[0] + assert response.HasField('message') + assert not response.HasField('task') + assert_message_matches( + response.message, + Role.ROLE_AGENT, + 'Direct reply to: Message: Quick question', + )