From 27f633bd24bf772f2761a711799ce06656e0c07c Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Wed, 6 Aug 2025 17:52:58 -0700 Subject: [PATCH 1/4] feat: add traceability extension support --- src/a2a/client/base_client.py | 10 +++ src/a2a/extensions/__init__.py | 6 ++ src/a2a/extensions/base.py | 26 ++++++ src/a2a/extensions/trace.py | 81 +++++++++++++++++++ .../default_request_handler.py | 12 ++- tests/extensions/test_trace.py | 63 +++++++++++++++ tests/extensions/test_trace_extension.py | 43 ++++++++++ uv.lock | 4 +- 8 files changed, 242 insertions(+), 3 deletions(-) create mode 100644 src/a2a/extensions/__init__.py create mode 100644 src/a2a/extensions/base.py create mode 100644 src/a2a/extensions/trace.py create mode 100644 tests/extensions/test_trace.py create mode 100644 tests/extensions/test_trace_extension.py diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index f4a8d03de..009c0188b 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -11,6 +11,7 @@ from a2a.client.errors import A2AClientInvalidStateError from a2a.client.middleware import ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.base import Extension from a2a.types import ( AgentCard, GetTaskPushNotificationConfigParams, @@ -41,6 +42,12 @@ def __init__( self._card = card self._config = config self._transport = transport + self._extensions: list[Extension] = [] + + def install_extension(self, extension: Extension) -> None: + """Installs an extension on the client.""" + extension.install(self) + self._extensions.append(extension) async def send_message( self, @@ -61,6 +68,9 @@ async def send_message( Yields: An async iterator of `ClientEvent` or a final `Message` response. """ + for extension in self._extensions: + extension.on_client_message(request) + config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, diff --git a/src/a2a/extensions/__init__.py b/src/a2a/extensions/__init__.py new file mode 100644 index 000000000..edaad13dc --- /dev/null +++ b/src/a2a/extensions/__init__.py @@ -0,0 +1,6 @@ +"""A2A extensions.""" + +from .base import Extension +from . import common, trace + +__all__ = ['Extension', 'common', 'trace'] diff --git a/src/a2a/extensions/base.py b/src/a2a/extensions/base.py new file mode 100644 index 000000000..28dbfef76 --- /dev/null +++ b/src/a2a/extensions/base.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from a2a.client.client import A2AClient + from a2a.server.server import A2AServer + + +class Extension: + """Base class for all extensions.""" + + def __init__(self, **kwargs: Any) -> None: + ... + + def on_client_message(self, message: Any) -> None: + """Called when a message is sent from the client.""" + ... + + def on_server_message(self, message: Any) -> None: + """Called when a message is received by the server.""" + ... + + def install(self, client_or_server: A2AClient | A2AServer) -> None: + """Called when the extension is installed on a client or server.""" + ... diff --git a/src/a2a/extensions/trace.py b/src/a2a/extensions/trace.py new file mode 100644 index 000000000..d39da4895 --- /dev/null +++ b/src/a2a/extensions/trace.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any + +from a2a._base import A2ABaseModel +from a2a.extensions.base import Extension + + +class CallTypeEnum(str, Enum): + """The type of the operation a step represents.""" + + AGENT = 'AGENT' + TOOL = 'TOOL' + + +class ToolInvocation(A2ABaseModel): + """A tool invocation.""" + + tool_name: str + parameters: dict[str, Any] + + +class AgentInvocation(A2ABaseModel): + """An agent invocation.""" + + agent_url: str + agent_name: str + requests: dict[str, Any] + response_trace: ResponseTrace | None = None + + +class StepAction(A2ABaseModel): + """The action of a step.""" + + tool_invocation: ToolInvocation | None = None + agent_invocation: AgentInvocation | None = None + + +class Step(A2ABaseModel): + """A single operation within a trace.""" + + step_id: str + trace_id: str + parent_step_id: str | None = None + call_type: CallTypeEnum + step_action: StepAction + cost: int | None = None + total_tokens: int | None = None + additional_attributes: dict[str, str] | None = None + latency: int | None = None + start_time: datetime + end_time: datetime + + +class ResponseTrace(A2ABaseModel): + """A trace message that contains a collection of spans.""" + + trace_id: str + steps: list[Step] + + +class TraceExtension(Extension): + """An extension for traceability.""" + + def on_client_message(self, message: Any) -> None: + """Appends trace information to the message.""" + # This is a placeholder implementation. + if message.metadata is None: + message.metadata = {} + message.metadata['trace'] = 'client-trace' + + def on_server_message(self, message: Any) -> None: + """Processes trace information from the message.""" + # This is a placeholder implementation. + if hasattr(message, 'metadata') and 'trace' in message.metadata: + print(f"Received trace: {message.metadata['trace']}") + + +AgentInvocation.model_rebuild() \ No newline at end of file diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 2549d087a..6dff2dc85 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -2,7 +2,7 @@ import logging from collections.abc import AsyncGenerator -from typing import cast +from typing import Any, cast from a2a.server.agent_execution import ( AgentExecutor, @@ -26,6 +26,7 @@ TaskManager, TaskStore, ) +from a2a.extensions.base import Extension from a2a.types import ( DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, @@ -101,6 +102,12 @@ def __init__( # noqa: PLR0913 # TODO: Likely want an interface for managing this, like AgentExecutionManager. self._running_agents = {} self._running_agents_lock = asyncio.Lock() + self._extensions: list[Extension] = [] + + def install_extension(self, extension: Extension, server: Any) -> None: + """Installs an extension on the server.""" + extension.install(server) + self._extensions.append(extension) async def on_get_task( self, @@ -182,6 +189,9 @@ async def _setup_message_execution( Returns: A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) """ + for extension in self._extensions: + extension.on_server_message(params.message) + # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.task_id, diff --git a/tests/extensions/test_trace.py b/tests/extensions/test_trace.py new file mode 100644 index 000000000..850d80073 --- /dev/null +++ b/tests/extensions/test_trace.py @@ -0,0 +1,63 @@ +from datetime import datetime, timezone + +from a2a.extensions.trace import ( + AgentInvocation, + CallTypeEnum, + ResponseTrace, + Step, + StepAction, + ToolInvocation, +) + + +def test_trace_serialization(): + start_time = datetime(2025, 3, 15, 12, 0, 0, tzinfo=timezone.utc) + end_time = datetime(2025, 3, 15, 12, 0, 0, 250000, tzinfo=timezone.utc) + + trace = ResponseTrace( + trace_id='trace-example-12345', + steps=[ + Step( + step_id='step-1-agent', + trace_id='trace-example-1234p', + call_type=CallTypeEnum.AGENT, + step_action=StepAction( + agent_invocation=AgentInvocation( + agent_name='weather_agent', + agent_url='http://google3/some/agent/url', + requests={ + 'user_prompt': "What's the weather in Paris and what should I wear?" + }, + ) + ), + cost=150, + total_tokens=75, + additional_attributes={'user_country': 'US'}, + latency=250, + start_time=start_time, + end_time=end_time, + ), + Step( + step_id='step-2-tool', + trace_id='trace-example-12345', + parent_step_id='step-1-agent', + call_type=CallTypeEnum.TOOL, + step_action=StepAction( + tool_invocation=ToolInvocation( + tool_name='google_map_api_tool', + parameters={'location': 'Paris, FR'}, + ) + ), + cost=50, + total_tokens=20, + latency=100, + start_time=start_time, + end_time=end_time, + ), + ], + ) + + trace_dict = trace.model_dump(mode='json') + deserialized_trace = ResponseTrace.model_validate(trace_dict) + + assert trace == deserialized_trace diff --git a/tests/extensions/test_trace_extension.py b/tests/extensions/test_trace_extension.py new file mode 100644 index 000000000..a66687851 --- /dev/null +++ b/tests/extensions/test_trace_extension.py @@ -0,0 +1,43 @@ +from unittest.mock import Mock + +import pytest + +from a2a.client.base_client import BaseClient +from a2a.extensions.trace import TraceExtension +from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler +from a2a.types import Message, TextPart + + +@pytest.mark.asyncio +async def test_trace_extension(): + client = BaseClient(card=Mock(), config=Mock(), transport=Mock(), consumers=[], middleware=[]) + server_handler = DefaultRequestHandler( + agent_executor=Mock(), + task_store=Mock(), + ) + + trace_extension = TraceExtension() + client.install_extension(trace_extension) + server_handler.install_extension(trace_extension, server=Mock()) + + message = Message( + message_id='test_message', + role='user', + parts=[TextPart(text='Hello, world!')], + ) + + # Simulate client sending a message + for extension in client._extensions: + extension.on_client_message(message) + + assert 'trace' in message.metadata + assert message.metadata['trace'] == 'client-trace' + + # Simulate server receiving a message + for extension in server_handler._extensions: + extension.on_server_message(message) + + # Check that the server-side handler was called + # (in this case, it just prints a message) + # We can't easily check the output of print, so we'll just + # assume it worked if no exceptions were raised. diff --git a/uv.lock b/uv.lock index cbc718ad9..9c88995ca 100644 --- a/uv.lock +++ b/uv.lock @@ -69,7 +69,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, - { name = "fastapi", specifier = ">=0.116.1" }, + { name = "fastapi", specifier = ">=0.95.0" }, { name = "google-api-core", specifier = ">=1.26.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "grpcio-reflection", marker = "extra == 'grpc'", specifier = ">=1.7.0" }, @@ -78,7 +78,7 @@ requires-dist = [ { name = "httpx-sse", specifier = ">=0.4.0" }, { name = "opentelemetry-api", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, { name = "opentelemetry-sdk", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, - { name = "protobuf", specifier = "==5.29.5" }, + { name = "protobuf", specifier = ">=5.29.5" }, { name = "pydantic", specifier = ">=2.11.3" }, { name = "sqlalchemy", extras = ["aiomysql", "aiosqlite", "asyncio", "postgresql-asyncpg"], marker = "extra == 'sql'", specifier = ">=2.0.0" }, { name = "sqlalchemy", extras = ["aiomysql", "asyncio"], marker = "extra == 'mysql'", specifier = ">=2.0.0" }, From 51026f4ed9e34a193de0e3db38192560b62febb5 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Thu, 7 Aug 2025 10:31:52 -0700 Subject: [PATCH 2/4] Add recording steps and hook to the request handler --- src/a2a/client/base_client.py | 88 ++++++++++++++----- src/a2a/extensions/trace.py | 81 +++++++++++++++-- .../server/agent_execution/agent_executor.py | 7 +- .../default_request_handler.py | 39 +++++++- tests/extensions/debug_trace.py | 51 +++++++++++ tests/extensions/simple_trace_test.py | 22 +++++ tests/extensions/test_full_trace_extension.py | 58 ++++++++++++ tests/extensions/test_trace_extension.py | 9 +- .../test_default_request_handler.py | 12 ++- 9 files changed, 327 insertions(+), 40 deletions(-) create mode 100644 tests/extensions/debug_trace.py create mode 100644 tests/extensions/simple_trace_test.py create mode 100644 tests/extensions/test_full_trace_extension.py diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 009c0188b..9389c9490 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -1,4 +1,5 @@ from collections.abc import AsyncIterator +from typing import cast from a2a.client.client import ( Client, @@ -12,6 +13,12 @@ from a2a.client.middleware import ClientCallInterceptor from a2a.client.transports.base import ClientTransport from a2a.extensions.base import Extension +from a2a.extensions.trace import ( + AgentInvocation, + CallTypeEnum, + StepAction, + TraceExtension, +) from a2a.types import ( AgentCard, GetTaskPushNotificationConfigParams, @@ -68,9 +75,31 @@ async def send_message( Yields: An async iterator of `ClientEvent` or a final `Message` response. """ + trace_extension: TraceExtension | None = None for extension in self._extensions: + if isinstance(extension, TraceExtension): + trace_extension = cast(TraceExtension, extension) extension.on_client_message(request) + step = None + if trace_extension: + trace_id = request.metadata.get('trace', {}).get('trace_id') + parent_step_id = request.metadata.get('trace', {}).get( + 'parent_step_id' + ) + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=CallTypeEnum.AGENT, + step_action=StepAction( + agent_invocation=AgentInvocation( + agent_url=self._card.url, + agent_name=self._card.name, + requests=request.model_dump(mode='json'), + ) + ), + ) + config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, @@ -82,33 +111,44 @@ async def send_message( ) params = MessageSendParams(message=request, configuration=config) - if not self._config.streaming or not self._card.capabilities.streaming: - response = await self._transport.send_message( + try: + if ( + not self._config.streaming + or not self._card.capabilities.streaming + ): + response = await self._transport.send_message( + params, context=context + ) + result = ( + (response, None) + if isinstance(response, Task) + else response + ) + await self.consume(result, self._card) + yield result + return + + tracker = ClientTaskManager() + stream = self._transport.send_message_streaming( params, context=context ) - result = ( - (response, None) if isinstance(response, Task) else response - ) - await self.consume(result, self._card) - yield result - return - tracker = ClientTaskManager() - stream = self._transport.send_message_streaming(params, context=context) - - first_event = await anext(stream) - # The response from a server may be either exactly one Message or a - # series of Task updates. Separate out the first message for special - # case handling, which allows us to simplify further stream processing. - if isinstance(first_event, Message): - await self.consume(first_event, self._card) - yield first_event - return - - yield await self._process_response(tracker, first_event) - - async for event in stream: - yield await self._process_response(tracker, event) + first_event = await anext(stream) + # The response from a server may be either exactly one Message or a + # series of Task updates. Separate out the first message for special + # case handling, which allows us to simplify further stream processing. + if isinstance(first_event, Message): + await self.consume(first_event, self._card) + yield first_event + return + + yield await self._process_response(tracker, first_event) + + async for event in stream: + yield await self._process_response(tracker, event) + finally: + if trace_extension and step: + trace_extension.end_step(step.step_id) async def _process_response( self, diff --git a/src/a2a/extensions/trace.py b/src/a2a/extensions/trace.py index d39da4895..dfa4be01f 100644 --- a/src/a2a/extensions/trace.py +++ b/src/a2a/extensions/trace.py @@ -1,6 +1,8 @@ from __future__ import annotations -from datetime import datetime +import time +import uuid +from datetime import datetime, timezone from enum import Enum from typing import Any @@ -51,7 +53,7 @@ class Step(A2ABaseModel): additional_attributes: dict[str, str] | None = None latency: int | None = None start_time: datetime - end_time: datetime + end_time: datetime | None = None class ResponseTrace(A2ABaseModel): @@ -64,18 +66,81 @@ class ResponseTrace(A2ABaseModel): class TraceExtension(Extension): """An extension for traceability.""" + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.traces: dict[str, ResponseTrace] = {} + self._current_steps: dict[str, Step] = {} + + def _generate_id(self, prefix: str) -> str: + return f'{prefix}-{uuid.uuid4()}' + + def start_trace(self) -> ResponseTrace: + """Starts a new trace.""" + trace_id = self._generate_id('trace') + trace = ResponseTrace(trace_id=trace_id, steps=[]) + self.traces[trace_id] = trace + return trace + + def start_step( + self, + trace_id: str, + parent_step_id: str | None, + call_type: CallTypeEnum, + step_action: StepAction, + ) -> Step: + """Starts a new step.""" + step_id = self._generate_id('step') + step = Step( + step_id=step_id, + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=call_type, + step_action=step_action, + start_time=datetime.now(timezone.utc), + ) + self._current_steps[step_id] = step + return step + + def end_step( + self, + step_id: str, + cost: int | None = None, + total_tokens: int | None = None, + additional_attributes: dict[str, str] | None = None, + ) -> None: + """Ends a step.""" + if step_id not in self._current_steps: + return + + step = self._current_steps.pop(step_id) + step.end_time = datetime.now(timezone.utc) + step.latency = int( + (step.end_time - step.start_time).total_seconds() * 1000 + ) + step.cost = cost + step.total_tokens = total_tokens + step.additional_attributes = additional_attributes + + if step.trace_id in self.traces: + self.traces[step.trace_id].steps.append(step) + def on_client_message(self, message: Any) -> None: """Appends trace information to the message.""" - # This is a placeholder implementation. + trace = self.start_trace() if message.metadata is None: message.metadata = {} - message.metadata['trace'] = 'client-trace' + message.metadata['trace'] = trace.model_dump(mode='json') def on_server_message(self, message: Any) -> None: """Processes trace information from the message.""" - # This is a placeholder implementation. - if hasattr(message, 'metadata') and 'trace' in message.metadata: - print(f"Received trace: {message.metadata['trace']}") + if ( + hasattr(message, 'metadata') + and message.metadata is not None + and 'trace' in message.metadata + ): + trace_data = message.metadata['trace'] + trace = ResponseTrace.model_validate(trace_data) + self.traces[trace.trace_id] = trace -AgentInvocation.model_rebuild() \ No newline at end of file +AgentInvocation.model_rebuild() diff --git a/src/a2a/server/agent_execution/agent_executor.py b/src/a2a/server/agent_execution/agent_executor.py index 38be9c11c..eabe191d1 100644 --- a/src/a2a/server/agent_execution/agent_executor.py +++ b/src/a2a/server/agent_execution/agent_executor.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import Any from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue @@ -13,7 +14,10 @@ class AgentExecutor(ABC): @abstractmethod async def execute( - self, context: RequestContext, event_queue: EventQueue + self, + context: RequestContext, + event_queue: EventQueue, + request_handler: Any, ) -> None: """Execute the agent's logic for a given request context. @@ -26,6 +30,7 @@ async def execute( Args: context: The request context containing the message, task ID, etc. event_queue: The queue to publish events to. + request_handler: The request handler that is executing the agent. """ @abstractmethod diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 6dff2dc85..6b7978bc0 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -27,6 +27,12 @@ TaskStore, ) from a2a.extensions.base import Extension +from a2a.extensions.trace import ( + CallTypeEnum, + StepAction, + ToolInvocation, + TraceExtension, +) from a2a.types import ( DeleteTaskPushNotificationConfigParams, GetTaskPushNotificationConfigParams, @@ -176,9 +182,40 @@ async def _run_event_stream( request: The request context for the agent. queue: The event queue for the agent to publish to. """ - await self.agent_executor.execute(request, queue) + await self.agent_executor.execute(request, queue, self) await queue.close() + async def handle_tool_call( + self, + trace_id: str, + parent_step_id: str, + tool_name: str, + parameters: dict[str, Any], + ) -> None: + """Handles a tool call from the agent executor.""" + trace_extension: TraceExtension | None = None + for extension in self._extensions: + if isinstance(extension, TraceExtension): + trace_extension = cast(TraceExtension, extension) + + if not trace_extension: + return + + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=parent_step_id, + call_type=CallTypeEnum.TOOL, + step_action=StepAction( + tool_invocation=ToolInvocation( + tool_name=tool_name, + parameters=parameters, + ) + ), + ) + # In a real implementation, you would execute the tool here. + # For this example, we'll just end the step immediately. + trace_extension.end_step(step.step_id) + async def _setup_message_execution( self, params: MessageSendParams, diff --git a/tests/extensions/debug_trace.py b/tests/extensions/debug_trace.py new file mode 100644 index 000000000..b6511106e --- /dev/null +++ b/tests/extensions/debug_trace.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +import sys +from unittest.mock import Mock + +from a2a.client.base_client import BaseClient +from a2a.extensions.trace import TraceExtension +from a2a.types import Message, TextPart, Role, Part + +def debug_trace(): + print("Starting trace debug...") + + # Create the extension + trace_extension = TraceExtension() + + # Create a trace directly to see its structure + trace = trace_extension.start_trace() + print(f"Direct trace object: {trace}") + print(f"Direct trace dict: {trace.model_dump(mode='json')}") + + # Create a message + message = Message( + message_id='test_message', + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], + ) + + print(f"Initial message metadata: {message.metadata}") + + # Call the extension method + trace_extension.on_client_message(message) + + print(f"After extension metadata: {message.metadata}") + + if message.metadata and 'trace' in message.metadata: + trace_data = message.metadata['trace'] + print(f"Trace data type: {type(trace_data)}") + print(f"Trace data: {trace_data}") + + if isinstance(trace_data, dict): + print(f"Trace data keys: {list(trace_data.keys())}") + if 'trace_id' in trace_data: + print(f"Found trace_id: {trace_data['trace_id']}") + else: + print("trace_id not found in trace data") + else: + print("Trace data is not a dict") + else: + print("No trace data found in metadata") + +if __name__ == "__main__": + debug_trace() diff --git a/tests/extensions/simple_trace_test.py b/tests/extensions/simple_trace_test.py new file mode 100644 index 000000000..ee1849f65 --- /dev/null +++ b/tests/extensions/simple_trace_test.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +# Simple test to check ResponseTrace serialization +from a2a.extensions.trace import TraceExtension, ResponseTrace + +# Create extension and trace +ext = TraceExtension() +trace = ext.start_trace() + +print("Trace object:", trace) +print("Trace type:", type(trace)) +print("Trace fields:", trace.__dict__) +print("Model dump:", trace.model_dump(mode='json')) + +# Test creating trace data like in the extension +if True: # message.metadata is None + metadata = {} +metadata['trace'] = trace.model_dump(mode='json') + +print("Metadata:", metadata) +print("Trace in metadata:", metadata['trace']) +print("Keys in trace:", metadata['trace'].keys() if isinstance(metadata['trace'], dict) else "not a dict") diff --git a/tests/extensions/test_full_trace_extension.py b/tests/extensions/test_full_trace_extension.py new file mode 100644 index 000000000..2a664e8f5 --- /dev/null +++ b/tests/extensions/test_full_trace_extension.py @@ -0,0 +1,58 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from a2a.extensions.trace import TraceExtension +from a2a.types import Message, Part, Role, TextPart + + +@pytest.mark.asyncio +async def test_full_trace_extension(): + trace_extension = TraceExtension() + + # Test the trace extension directly + message = Message( + message_id='test_message', + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], + ) + + # Simulate client sending a message - creates trace + trace_extension.on_client_message(message) + + # Verify trace was created and stored in metadata + assert 'trace' in message.metadata + trace_data = message.metadata['trace'] + assert 'traceId' in trace_data + trace_id = trace_data['traceId'] + + # Simulate server receiving a message - loads trace + trace_extension.on_server_message(message) + + # Verify trace was loaded into extension + assert trace_id in trace_extension.traces + trace = trace_extension.traces[trace_id] + assert len(trace.steps) == 0 # Initially no steps + + # Simulate a tool call being made + from a2a.extensions.trace import StepAction, ToolInvocation, CallTypeEnum + step_action = StepAction(tool_invocation=ToolInvocation( + tool_name='test_tool', + parameters={'param1': 'value1'} + )) + + step = trace_extension.start_step( + trace_id=trace_id, + parent_step_id=None, + call_type=CallTypeEnum.TOOL, + step_action=step_action + ) + + # End the step + trace_extension.end_step(step.step_id) + + # Verify the trace + assert len(trace_extension.traces) == 1 + trace = trace_extension.traces[trace_id] + assert len(trace.steps) == 1 + assert trace.steps[0].call_type == CallTypeEnum.TOOL diff --git a/tests/extensions/test_trace_extension.py b/tests/extensions/test_trace_extension.py index a66687851..a32329f12 100644 --- a/tests/extensions/test_trace_extension.py +++ b/tests/extensions/test_trace_extension.py @@ -5,7 +5,7 @@ from a2a.client.base_client import BaseClient from a2a.extensions.trace import TraceExtension from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler -from a2a.types import Message, TextPart +from a2a.types import Message, TextPart, Part, Role @pytest.mark.asyncio @@ -22,8 +22,8 @@ async def test_trace_extension(): message = Message( message_id='test_message', - role='user', - parts=[TextPart(text='Hello, world!')], + role=Role.user, + parts=[Part(TextPart(text='Hello, world!'))], ) # Simulate client sending a message @@ -31,7 +31,8 @@ async def test_trace_extension(): extension.on_client_message(message) assert 'trace' in message.metadata - assert message.metadata['trace'] == 'client-trace' + # The trace_id field is serialized as traceId due to camelCase alias generator + assert isinstance(message.metadata['trace']['traceId'], str) # Simulate server receiving a message for extension in server_handler._extensions: diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 6cb21662c..0ea8380ea 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -2,6 +2,7 @@ import logging import time +from typing import Any from unittest.mock import ( AsyncMock, MagicMock, @@ -57,7 +58,9 @@ class DummyAgentExecutor(AgentExecutor): - async def execute(self, context: RequestContext, event_queue: EventQueue): + async def execute( + self, context: RequestContext, event_queue: EventQueue, request_handler: Any + ): task_updater = TaskUpdater( event_queue, context.task_id, context.context_id ) @@ -584,7 +587,12 @@ async def test_on_message_send_task_id_mismatch(): class HelloAgentExecutor(AgentExecutor): - async def execute(self, context: RequestContext, event_queue: EventQueue): + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + request_handler: Any, + ): task = context.current_task if not task: assert context.message is not None, ( From 22e68754c7159b4893122d5e26172ed649bcf83c Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Thu, 7 Aug 2025 10:58:35 -0700 Subject: [PATCH 3/4] Update imports --- src/a2a/extensions/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/a2a/extensions/__init__.py b/src/a2a/extensions/__init__.py index edaad13dc..86a77acfa 100644 --- a/src/a2a/extensions/__init__.py +++ b/src/a2a/extensions/__init__.py @@ -1,6 +1,6 @@ """A2A extensions.""" -from .base import Extension -from . import common, trace +from a2a.extensions.base import Extension +from a2a.extensions import common, trace __all__ = ['Extension', 'common', 'trace'] From 300489370f5b280132da07f3b3418cd18a9443c2 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Thu, 7 Aug 2025 11:01:54 -0700 Subject: [PATCH 4/4] resolve conflicts on uv.lock --- uv.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uv.lock b/uv.lock index 9c88995ca..664cb2e04 100644 --- a/uv.lock +++ b/uv.lock @@ -69,7 +69,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "cryptography", marker = "extra == 'encryption'", specifier = ">=43.0.0" }, - { name = "fastapi", specifier = ">=0.95.0" }, + { name = "fastapi", marker = "extra == 'http-server'", specifier = ">=0.115.2" }, { name = "google-api-core", specifier = ">=1.26.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.60" }, { name = "grpcio-reflection", marker = "extra == 'grpc'", specifier = ">=1.7.0" },