From a6359a937f2514c23463bfa0a1b5f8298e3335f3 Mon Sep 17 00:00:00 2001 From: Paulchen Date: Thu, 26 Jun 2025 21:12:09 +0900 Subject: [PATCH 1/4] Format `.vscode/settings.json` --- .vscode/settings.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 65f0291dd..580dca87d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,7 @@ { - "python.testing.pytestArgs": ["tests"], + "python.testing.pytestArgs": [ + "tests" + ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "editor.formatOnSave": true, From f8f536e18f7dcdd2aca6ca1db1311f0e75f47f6a Mon Sep 17 00:00:00 2001 From: Nathan Brake <33383515+njbrake@users.noreply.github.com> Date: Thu, 26 Jun 2025 09:17:00 -0400 Subject: [PATCH 2/4] fix: send notifications on message not streaming (#219) # Description The proposed fix, if the team does want push notifications to be supported in a non-streaming setup Fixes #218 --- .../default_request_handler.py | 168 ++++++++---------- .../test_default_request_handler.py | 16 +- 2 files changed, 94 insertions(+), 90 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 4ef889add..ff86a069d 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -55,6 +55,7 @@ TaskState.rejected, } + @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): """Default request handler for all incoming requests. @@ -168,16 +169,17 @@ async def _run_event_stream( await self.agent_executor.execute(request, queue) await queue.close() - async def on_message_send( + async def _setup_message_execution( self, params: MessageSendParams, context: ServerCallContext | None = None, - ) -> Message | Task: - """Default handler for 'message/send' interface (non-streaming). + ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: + """Common setup logic for both streaming and non-streaming message handling. - Starts the agent execution for the message and waits for the final - result (Task or Message). + Returns: + A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) """ + # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.taskId, context_id=params.message.contextId, @@ -185,6 +187,7 @@ async def on_message_send( initial_message=params.message, ) task: Task | None = await task_manager.get_task() + if task: if task.status.state in TERMINAL_TASK_STATES: raise ServerError( @@ -206,6 +209,8 @@ async def on_message_send( await self._push_notifier.set_info( task.id, params.configuration.pushNotificationConfig ) + + # Build request context request_context = await self._request_context_builder.build( params=params, task_id=task.id if task else None, @@ -222,13 +227,49 @@ async def on_message_send( result_aggregator = ResultAggregator(task_manager) # TODO: to manage the non-blocking flows. producer_task = asyncio.create_task( - self._run_event_stream( - request_context, - queue, - ) + self._run_event_stream(request_context, queue) ) await self._register_producer(task_id, producer_task) + return task_manager, task_id, queue, result_aggregator, producer_task + + def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: + """Validates that agent-generated task ID matches the expected task ID.""" + if task_id != event_task_id: + logger.error( + f'Agent generated task_id={event_task_id} does not match the RequestContext task_id={task_id}.' + ) + raise ServerError( + InternalError(message='Task ID mismatch in agent response') + ) + + async def _send_push_notification_if_needed( + self, task_id: str, result_aggregator: ResultAggregator + ) -> None: + """Sends push notification if configured and task is available.""" + if self._push_notifier and task_id: + latest_task = await result_aggregator.current_result + if isinstance(latest_task, Task): + await self._push_notifier.send_notification(latest_task) + + async def on_message_send( + self, + params: MessageSendParams, + context: ServerCallContext | None = None, + ) -> Message | Task: + """Default handler for 'message/send' interface (non-streaming). + + Starts the agent execution for the message and waits for the final + result (Task or Message). + """ + ( + task_manager, + task_id, + queue, + result_aggregator, + producer_task, + ) = await self._setup_message_execution(params, context) + consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) @@ -241,13 +282,13 @@ async def on_message_send( if not result: raise ServerError(error=InternalError()) - if isinstance(result, Task) and task_id != result.id: - logger.error( - f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.' - ) - raise ServerError( - InternalError(message='Task ID mismatch in agent response') - ) + if isinstance(result, Task): + self._validate_task_id_match(task_id, result.id) + + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) + except Exception as e: logger.error(f'Agent execution failed. Error: {e}') raise @@ -272,85 +313,34 @@ async def on_message_send_stream( Starts the agent execution and yields events as they are produced by the agent. """ - task_manager = TaskManager( - task_id=params.message.taskId, - context_id=params.message.contextId, - task_store=self.task_store, - initial_message=params.message, - ) - task: Task | None = await task_manager.get_task() - - if task: - if task.status.state in TERMINAL_TASK_STATES: - raise ServerError( - error=InvalidParamsError( - message=f'Task {task.id} is in terminal state: {task.status.state}' - ) - ) - - task = task_manager.update_with_message(params.message, task) - if self.should_add_push_info(params): - assert isinstance(self._push_notifier, PushNotifier) - assert isinstance( - params.configuration, MessageSendConfiguration - ) - assert isinstance( - params.configuration.pushNotificationConfig, - PushNotificationConfig, - ) - await self._push_notifier.set_info( - task.id, params.configuration.pushNotificationConfig - ) - else: - queue = EventQueue() - result_aggregator = ResultAggregator(task_manager) - request_context = await self._request_context_builder.build( - params=params, - task_id=task.id if task else None, - context_id=params.message.contextId, - task=task, - context=context, - ) - - task_id = cast('str', request_context.task_id) - queue = await self._queue_manager.create_or_tap(task_id) - producer_task = asyncio.create_task( - self._run_event_stream( - request_context, - queue, - ) - ) - await self._register_producer(task_id, producer_task) + ( + task_manager, + task_id, + queue, + result_aggregator, + producer_task, + ) = await self._setup_message_execution(params, context) try: consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): - if task_id != event.id: - logger.error( - f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' - ) - raise ServerError( - InternalError( - message='Task ID mismatch in agent response' - ) - ) - - if ( - self._push_notifier - and params.configuration - and params.configuration.pushNotificationConfig - ): - await self._push_notifier.set_info( - task_id, - params.configuration.pushNotificationConfig, - ) - - if self._push_notifier and task_id: - latest_task = await result_aggregator.current_result - if isinstance(latest_task, Task): - await self._push_notifier.send_notification(latest_task) + self._validate_task_id_match(task_id, event.id) + + if ( + self._push_notifier + and params.configuration + and params.configuration.pushNotificationConfig + ): + await self._push_notifier.set_info( + task_id, + params.configuration.pushNotificationConfig, + ) + + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) yield event finally: await self._cleanup_producer(producer_task, task_id) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6f67c0f8c..dd713752c 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -361,6 +361,15 @@ async def test_on_message_send_with_push_notification(): False, ) + # Mock the current_result property to return the final task result + async def get_current_result(): + return final_task_result + + # Configure the 'current_result' property on the type of the mock instance + type(mock_result_aggregator_instance).current_result = PropertyMock( + return_value=get_current_result() + ) + with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -380,6 +389,9 @@ async def test_on_message_send_with_push_notification(): ) mock_push_notifier.set_info.assert_awaited_once_with(task_id, push_config) + mock_push_notifier.send_notification.assert_awaited_once_with( + final_task_result + ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -1139,12 +1151,14 @@ async def consume_stream(): texts = [p.root.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2'] + TERMINAL_TASK_STATES = { TaskState.completed, TaskState.canceled, TaskState.failed, TaskState.rejected, -} +} + @pytest.mark.asyncio @pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES) From 6cb7b2b95c508c04a2d9ed616fb5916dbad3d84d Mon Sep 17 00:00:00 2001 From: Shingo OKAWA Date: Fri, 27 Jun 2025 03:18:04 +0900 Subject: [PATCH 3/4] chore(format): run `nox -s format` (#247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `nox -s format` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes N/A 🦕 This PR runs `nox -s format` to fix previously missed formatting issues. Signed-off-by: Shingo OKAWA --- tests/server/tasks/test_task_manager.py | 5 +++-- tests/server/tasks/test_task_updater.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index a99ad7013..952979710 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -127,6 +127,7 @@ async def test_save_task_event_artifact_update( updated_task.artifacts = [new_artifact] mock_task_store.save.assert_called_once_with(updated_task) + @pytest.mark.asyncio async def test_save_task_event_metadata_update( task_manager: TaskManager, mock_task_store: AsyncMock @@ -134,8 +135,8 @@ async def test_save_task_event_metadata_update( """Test saving an updated metadata for an existing task.""" initial_task = Task(**MINIMAL_TASK) mock_task_store.get.return_value = initial_task - new_metadata = {"meta_key_test": "meta_value_test"} - + new_metadata = {'meta_key_test': 'meta_value_test'} + event = TaskStatusUpdateEvent( taskId=MINIMAL_TASK['id'], contextId=MINIMAL_TASK['contextId'], diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index 28647e748..2a8d5deb6 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -150,10 +150,9 @@ async def test_add_artifact_generates_id( assert event.lastChunk == None - @pytest.mark.asyncio @pytest.mark.parametrize( - "append_val, last_chunk_val", + 'append_val, last_chunk_val', [ (False, False), (True, True), @@ -166,14 +165,17 @@ async def test_add_artifact_with_append_last_chunk( ): """Test add_artifact with append and last_chunk flags.""" await task_updater.add_artifact( - parts=sample_parts, artifact_id="id1", append=append_val, last_chunk=last_chunk_val + parts=sample_parts, + artifact_id='id1', + append=append_val, + last_chunk=last_chunk_val, ) event_queue.enqueue_event.assert_called_once() event = event_queue.enqueue_event.call_args[0][0] assert isinstance(event, TaskArtifactUpdateEvent) - assert event.artifact.artifactId == "id1" + assert event.artifact.artifactId == 'id1' assert event.artifact.parts == sample_parts assert event.append == append_val assert event.lastChunk == last_chunk_val From bc3a82a015d01b3cb3464f668e328b5c177af11b Mon Sep 17 00:00:00 2001 From: Paulchen Date: Tue, 1 Jul 2025 00:23:45 +0900 Subject: [PATCH 4/4] Add missing settings upsi --- .vscode/settings.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.vscode/settings.json b/.vscode/settings.json index 580dca87d..0f968e252 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -14,4 +14,10 @@ } }, "ruff.importStrategy": "fromEnvironment", + "files.insertFinalNewline": true, + "files.trimFinalNewlines": false, + "files.trimTrailingWhitespace": false, + "editor.rulers": [ + 80 + ] }