diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 2da8ddfd7..1c3866047 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -23,20 +23,43 @@ async def execute( return once the agent's execution for this request is complete or yields control (e.g., enters an input-required state). - TODO: Document request lifecycle and AgentExecutor responsibilities: - - Should not close the event_queue. - - Guarantee single execution per request (no concurrent execution). - - Throwing exception will result in TaskState.TASK_STATE_ERROR (CHECK!) - - Once call is completed it should not access context or event_queue - - Before completing the call it SHOULD update task status to terminal or interrupted state. - - Explain AUTH_REQUIRED workflow. - - Explain INPUT_REQUIRED workflow. - - Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc) - - Explain if execute can wait for cancel and if cancel can wait for execute. - - Explain behaviour of streaming / not-immediate when execute() returns in active state. - - Possible workflows: - - Enqueue a SINGLE Message object - - Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent. + Request Lifecycle & AgentExecutor Responsibilities: + - **Concurrency**: The framework guarantees single execution per request; + `execute()` will not be called concurrently for the same request context. + - **Exception Handling**: Unhandled exceptions raised by `execute()` will be + caught by the framework and result in the task transitioning to + `TaskState.TASK_STATE_ERROR`. + - **Post-Completion**: Once `execute()` completes (returns or raises), the + executor must not access the `context` or `event_queue` anymore. + - **Terminal States**: Before completing the call normally, the executor + SHOULD publish a `TaskStatusUpdateEvent` to transition the task to a + terminal state (e.g., `TASK_STATE_COMPLETED`) or an interrupted state + (`TASK_STATE_INPUT_REQUIRED` or `TASK_STATE_AUTH_REQUIRED`). + - **Interrupted Workflows**: + - `TASK_STATE_INPUT_REQUIRED`: The executor publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_INPUT_REQUIRED` and returns to yield control. + The request will resume once user input is provided. + - `TASK_STATE_AUTH_REQUIRED`: There are in-bound and out-of-bound auth models. + In both scenarios, the agent publishes a `TaskStatusUpdateEvent` with + `TaskState.TASK_STATE_AUTH_REQUIRED`. + - In-bound: The agent should return from `execute()`. The framework will + call `execute()` again once the user response is received. + - Out-of-bound: The agent should not return from `execute()`. It should wait + for the out-of-band auth provider to complete the authentication and then + continue execution. + + - **Cancellation Workflow**: When a cancellation request is received, the + async task running `execute()` is cancelled (raising an `asyncio.CancelledError`), + and `cancel()` is explicitly called by the framework. + + Allowed Workflows: + - Immediate response: Enqueue a SINGLE `Message` object. + - Asynchronous/Long-running: Enqueue a `Task` object, perform work, and emit + multiple `TaskStatusUpdateEvent` / `TaskArtifactUpdateEvent` objects over time. + + Note that the framework waits with response to the send_message request with + `return_immediately=True` parameter until the first event (Message or Task) + is enqueued by AgentExecutor. Args: context: The request context containing the message, task ID, etc. @@ -53,10 +76,6 @@ async def cancel( in the context and publish a `TaskStatusUpdateEvent` with state `TaskState.TASK_STATE_CANCELED` to the `event_queue`. - TODO: Document cancelation workflow. - - What if TaskState.TASK_STATE_CANCELED is not set by cancel() ? - - How it can interact with execute() ? - Args: context: The request context containing the task ID to cancel. event_queue: The queue to publish the cancellation status update to. diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index cee15bfcb..c50622e5c 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -113,6 +113,22 @@ def agent_card(): ) +def get_task_id(event): + if event.HasField('task'): + return event.task.id + if event.HasField('status_update'): + return event.status_update.task_id + assert False, f'Event {event} has no task_id' + + +def get_task_context_id(event): + if event.HasField('task'): + return event.task.context_id + if event.HasField('status_update'): + return event.status_update.context_id + assert False, f'Event {event} has no context_id' + + def get_state(event): if event.HasField('task'): return event.task.status.state @@ -1265,6 +1281,93 @@ async def cancel( ) +# Scenario: Auth required and in channel unblocking +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_auth_required_in_channel(use_legacy, streaming): + class AuthAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + message = context.message + if message and message.parts and message.parts[0].text == 'start': + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus( + state=TaskState.TASK_STATE_AUTH_REQUIRED + ), + ) + ) + elif ( + message + and message.parts + and message.parts[0].text == 'credentials' + ): + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + context_id=context.context_id, + status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), + ) + ) + else: + raise ValueError(f'Unexpected message {message}') + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(AuthAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + msg1 = Message( + message_id='msg-start', role=Role.ROLE_USER, parts=[Part(text='start')] + ) + + it = client.send_message( + SendMessageRequest( + message=msg1, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + events1 = [event async for event in it] + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_AUTH_REQUIRED, + ] + task_id = get_task_id(events1[0]) + context_id = get_task_context_id(events1[0]) + + # Now send another message with credentials + msg2 = Message( + task_id=task_id, + context_id=context_id, + message_id='msg-creds', + role=Role.ROLE_USER, + parts=[Part(text='credentials')], + ) + + it2 = client.send_message( + SendMessageRequest( + message=msg2, + configuration=SendMessageConfiguration(return_immediately=False), + ) + ) + + assert [get_state(event) async for event in it2] == [ + TaskState.TASK_STATE_COMPLETED, + ] + + # Scenario: Parallel subscribe attach detach # Migrated from: test_parallel_subscribe_attach_detach in test_handler_comparison @pytest.mark.timeout(5.0)