-
Notifications
You must be signed in to change notification settings - Fork 429
feat: add traceability extension support #387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
27f633b
b6002fe
51026f4
22e6875
3004893
31471b8
7244431
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| """A2A extensions.""" | ||
|
|
||
| from .base import Extension | ||
| from . import common, trace | ||
|
Check failure on line 4 in src/a2a/extensions/__init__.py
|
||
|
|
||
| __all__ = ['Extension', 'common', 'trace'] | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,26 @@ | ||||||||
| from __future__ import annotations | ||||||||
|
|
||||||||
| from typing import TYPE_CHECKING, Any | ||||||||
|
|
||||||||
| if TYPE_CHECKING: | ||||||||
| from a2a.client.client import A2AClient | ||||||||
| from a2a.server.server import A2AServer | ||||||||
|
|
||||||||
|
|
||||||||
| class Extension: | ||||||||
| """Base class for all extensions.""" | ||||||||
|
|
||||||||
| def __init__(self, **kwargs: Any) -> None: | ||||||||
| ... | ||||||||
|
|
||||||||
| def on_client_message(self, message: Any) -> None: | ||||||||
| """Called when a message is sent from the client.""" | ||||||||
| ... | ||||||||
|
|
||||||||
| def on_server_message(self, message: Any) -> None: | ||||||||
| """Called when a message is received by the server.""" | ||||||||
|
Comment on lines
+20
to
+21
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For better type safety, the To make this work, you'll also need to add
Suggested change
|
||||||||
| ... | ||||||||
|
|
||||||||
| def install(self, client_or_server: A2AClient | A2AServer) -> None: | ||||||||
| """Called when the extension is installed on a client or server.""" | ||||||||
| ... | ||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,81 @@ | ||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| from datetime import datetime | ||||||||||
| from enum import Enum | ||||||||||
| from typing import Any | ||||||||||
|
|
||||||||||
| from a2a._base import A2ABaseModel | ||||||||||
| from a2a.extensions.base import Extension | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class CallTypeEnum(str, Enum): | ||||||||||
| """The type of the operation a step represents.""" | ||||||||||
|
|
||||||||||
| AGENT = 'AGENT' | ||||||||||
| TOOL = 'TOOL' | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ToolInvocation(A2ABaseModel): | ||||||||||
| """A tool invocation.""" | ||||||||||
|
|
||||||||||
| tool_name: str | ||||||||||
| parameters: dict[str, Any] | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class AgentInvocation(A2ABaseModel): | ||||||||||
| """An agent invocation.""" | ||||||||||
|
|
||||||||||
| agent_url: str | ||||||||||
| agent_name: str | ||||||||||
| requests: dict[str, Any] | ||||||||||
| response_trace: ResponseTrace | None = None | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class StepAction(A2ABaseModel): | ||||||||||
| """The action of a step.""" | ||||||||||
|
|
||||||||||
| tool_invocation: ToolInvocation | None = None | ||||||||||
| agent_invocation: AgentInvocation | None = None | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class Step(A2ABaseModel): | ||||||||||
| """A single operation within a trace.""" | ||||||||||
|
|
||||||||||
| step_id: str | ||||||||||
| trace_id: str | ||||||||||
| parent_step_id: str | None = None | ||||||||||
| call_type: CallTypeEnum | ||||||||||
| step_action: StepAction | ||||||||||
| cost: int | None = None | ||||||||||
| total_tokens: int | None = None | ||||||||||
| additional_attributes: dict[str, str] | None = None | ||||||||||
| latency: int | None = None | ||||||||||
| start_time: datetime | ||||||||||
| end_time: datetime | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ResponseTrace(A2ABaseModel): | ||||||||||
| """A trace message that contains a collection of spans.""" | ||||||||||
|
|
||||||||||
| trace_id: str | ||||||||||
| steps: list[Step] | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class TraceExtension(Extension): | ||||||||||
| """An extension for traceability.""" | ||||||||||
|
|
||||||||||
| def on_client_message(self, message: Any) -> None: | ||||||||||
| """Appends trace information to the message.""" | ||||||||||
| # This is a placeholder implementation. | ||||||||||
| if message.metadata is None: | ||||||||||
| message.metadata = {} | ||||||||||
| message.metadata['trace'] = 'client-trace' | ||||||||||
|
|
||||||||||
| def on_server_message(self, message: Any) -> None: | ||||||||||
| """Processes trace information from the message.""" | ||||||||||
| # This is a placeholder implementation. | ||||||||||
| if hasattr(message, 'metadata') and 'trace' in message.metadata: | ||||||||||
| print(f"Received trace: {message.metadata['trace']}") | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a potential
Suggested change
|
||||||||||
|
|
||||||||||
|
|
||||||||||
| AgentInvocation.model_rebuild() | ||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |||||
| import logging | ||||||
|
|
||||||
| from collections.abc import AsyncGenerator | ||||||
| from typing import cast | ||||||
| from typing import Any, cast | ||||||
|
|
||||||
| from a2a.server.agent_execution import ( | ||||||
| AgentExecutor, | ||||||
|
|
@@ -26,6 +26,7 @@ | |||||
| TaskManager, | ||||||
| TaskStore, | ||||||
| ) | ||||||
| from a2a.extensions.base import Extension | ||||||
| from a2a.types import ( | ||||||
| DeleteTaskPushNotificationConfigParams, | ||||||
| GetTaskPushNotificationConfigParams, | ||||||
|
|
@@ -101,6 +102,12 @@ def __init__( # noqa: PLR0913 | |||||
| # TODO: Likely want an interface for managing this, like AgentExecutionManager. | ||||||
| self._running_agents = {} | ||||||
| self._running_agents_lock = asyncio.Lock() | ||||||
| self._extensions: list[Extension] = [] | ||||||
|
|
||||||
| def install_extension(self, extension: Extension, server: Any) -> None: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| """Installs an extension on the server.""" | ||||||
| extension.install(server) | ||||||
| self._extensions.append(extension) | ||||||
|
|
||||||
| async def on_get_task( | ||||||
| self, | ||||||
|
|
@@ -182,6 +189,9 @@ async def _setup_message_execution( | |||||
| Returns: | ||||||
| A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) | ||||||
| """ | ||||||
| for extension in self._extensions: | ||||||
| extension.on_server_message(params.message) | ||||||
|
|
||||||
| # Create task manager and validate existing task | ||||||
| task_manager = TaskManager( | ||||||
| task_id=params.message.task_id, | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,63 @@ | ||||||
| from datetime import datetime, timezone | ||||||
|
|
||||||
| from a2a.extensions.trace import ( | ||||||
| AgentInvocation, | ||||||
| CallTypeEnum, | ||||||
| ResponseTrace, | ||||||
| Step, | ||||||
| StepAction, | ||||||
| ToolInvocation, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def test_trace_serialization(): | ||||||
| start_time = datetime(2025, 3, 15, 12, 0, 0, tzinfo=timezone.utc) | ||||||
| end_time = datetime(2025, 3, 15, 12, 0, 0, 250000, tzinfo=timezone.utc) | ||||||
|
|
||||||
| trace = ResponseTrace( | ||||||
| trace_id='trace-example-12345', | ||||||
| steps=[ | ||||||
| Step( | ||||||
| step_id='step-1-agent', | ||||||
| trace_id='trace-example-1234p', | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| call_type=CallTypeEnum.AGENT, | ||||||
| step_action=StepAction( | ||||||
| agent_invocation=AgentInvocation( | ||||||
| agent_name='weather_agent', | ||||||
| agent_url='http://google3/some/agent/url', | ||||||
| requests={ | ||||||
| 'user_prompt': "What's the weather in Paris and what should I wear?" | ||||||
| }, | ||||||
| ) | ||||||
| ), | ||||||
| cost=150, | ||||||
| total_tokens=75, | ||||||
| additional_attributes={'user_country': 'US'}, | ||||||
| latency=250, | ||||||
| start_time=start_time, | ||||||
| end_time=end_time, | ||||||
| ), | ||||||
| Step( | ||||||
| step_id='step-2-tool', | ||||||
| trace_id='trace-example-12345', | ||||||
| parent_step_id='step-1-agent', | ||||||
| call_type=CallTypeEnum.TOOL, | ||||||
| step_action=StepAction( | ||||||
| tool_invocation=ToolInvocation( | ||||||
| tool_name='google_map_api_tool', | ||||||
| parameters={'location': 'Paris, FR'}, | ||||||
| ) | ||||||
| ), | ||||||
| cost=50, | ||||||
| total_tokens=20, | ||||||
| latency=100, | ||||||
| start_time=start_time, | ||||||
| end_time=end_time, | ||||||
| ), | ||||||
| ], | ||||||
| ) | ||||||
|
|
||||||
| trace_dict = trace.model_dump(mode='json') | ||||||
| deserialized_trace = ResponseTrace.model_validate(trace_dict) | ||||||
|
|
||||||
| assert trace == deserialized_trace | ||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,43 @@ | ||||||||||||||||||||||||||||
| from unittest.mock import Mock | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| from a2a.client.base_client import BaseClient | ||||||||||||||||||||||||||||
| from a2a.extensions.trace import TraceExtension | ||||||||||||||||||||||||||||
| from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler | ||||||||||||||||||||||||||||
| from a2a.types import Message, TextPart | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| @pytest.mark.asyncio | ||||||||||||||||||||||||||||
| async def test_trace_extension(): | ||||||||||||||||||||||||||||
| client = BaseClient(card=Mock(), config=Mock(), transport=Mock(), consumers=[], middleware=[]) | ||||||||||||||||||||||||||||
| server_handler = DefaultRequestHandler( | ||||||||||||||||||||||||||||
| agent_executor=Mock(), | ||||||||||||||||||||||||||||
| task_store=Mock(), | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| trace_extension = TraceExtension() | ||||||||||||||||||||||||||||
| client.install_extension(trace_extension) | ||||||||||||||||||||||||||||
| server_handler.install_extension(trace_extension, server=Mock()) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| message = Message( | ||||||||||||||||||||||||||||
| message_id='test_message', | ||||||||||||||||||||||||||||
| role='user', | ||||||||||||||||||||||||||||
| parts=[TextPart(text='Hello, world!')], | ||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Simulate client sending a message | ||||||||||||||||||||||||||||
| for extension in client._extensions: | ||||||||||||||||||||||||||||
| extension.on_client_message(message) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| assert 'trace' in message.metadata | ||||||||||||||||||||||||||||
| assert message.metadata['trace'] == 'client-trace' | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Simulate server receiving a message | ||||||||||||||||||||||||||||
| for extension in server_handler._extensions: | ||||||||||||||||||||||||||||
| extension.on_server_message(message) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| # Check that the server-side handler was called | ||||||||||||||||||||||||||||
| # (in this case, it just prints a message) | ||||||||||||||||||||||||||||
| # We can't easily check the output of print, so we'll just | ||||||||||||||||||||||||||||
| # assume it worked if no exceptions were raised. | ||||||||||||||||||||||||||||
|
Comment on lines
+38
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test for the server-side message handling is currently implicit, assuming it works if no exception is raised. You can make this test more explicit and robust by mocking the
Suggested change
|
||||||||||||||||||||||||||||
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For better type safety, the
messageparameter inon_client_messageshould be typed asMessageinstead ofAny. Based on the usage inbase_client.pyanddefault_request_handler.py, this parameter is always an instance ofa2a.types.Message.To make this work, you'll also need to add
from a2a.types import Messageinside theif TYPE_CHECKING:block at the top of the file to avoid circular imports.