Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from a2a.client.errors import A2AClientInvalidStateError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.extensions.base import Extension
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
Expand Down Expand Up @@ -41,6 +42,12 @@ def __init__(
self._card = card
self._config = config
self._transport = transport
self._extensions: list[Extension] = []

def install_extension(self, extension: Extension) -> None:
"""Installs an extension on the client."""
extension.install(self)
self._extensions.append(extension)

async def send_message(
self,
Expand All @@ -61,6 +68,9 @@ async def send_message(
Yields:
An async iterator of `ClientEvent` or a final `Message` response.
"""
for extension in self._extensions:
extension.on_client_message(request)

config = MessageSendConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
Expand Down
6 changes: 6 additions & 0 deletions src/a2a/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""A2A extensions."""

from .base import Extension

Check failure on line 3 in src/a2a/extensions/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TID252)

src/a2a/extensions/__init__.py:3:1: TID252 Prefer absolute imports over relative imports
from . import common, trace

Check failure on line 4 in src/a2a/extensions/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TID252)

src/a2a/extensions/__init__.py:4:1: TID252 Prefer absolute imports over relative imports

Check failure on line 4 in src/a2a/extensions/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TID252)

src/a2a/extensions/__init__.py:4:1: TID252 Prefer absolute imports over relative imports

Check failure on line 4 in src/a2a/extensions/__init__.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/extensions/__init__.py:3:1: I001 Import block is un-sorted or un-formatted

__all__ = ['Extension', 'common', 'trace']
26 changes: 26 additions & 0 deletions src/a2a/extensions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

Check failure on line 3 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (I001)

src/a2a/extensions/base.py:1:1: I001 Import block is un-sorted or un-formatted

if TYPE_CHECKING:
from a2a.client.client import A2AClient

Check failure on line 6 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

"A2AClient" is unknown import symbol (reportAttributeAccessIssue)
from a2a.server.server import A2AServer


class Extension:
"""Base class for all extensions."""

def __init__(self, **kwargs: Any) -> None:
...

def on_client_message(self, message: Any) -> None:
"""Called when a message is sent from the client."""
Comment on lines +16 to +17
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better type safety, the message parameter in on_client_message should be typed as Message instead of Any. Based on the usage in base_client.py and default_request_handler.py, this parameter is always an instance of a2a.types.Message.

To make this work, you'll also need to add from a2a.types import Message inside the if TYPE_CHECKING: block at the top of the file to avoid circular imports.

Suggested change
def on_client_message(self, message: Any) -> None:
"""Called when a message is sent from the client."""
def on_client_message(self, message: Message) -> None:

...

Check failure on line 18 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:18:9: PIE790 Unnecessary `...` literal

def on_server_message(self, message: Any) -> None:
"""Called when a message is received by the server."""
Comment on lines +20 to +21
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better type safety, the message parameter in on_server_message should be typed as Message instead of Any. Based on the usage in base_client.py and default_request_handler.py, this parameter is always an instance of a2a.types.Message.

To make this work, you'll also need to add from a2a.types import Message inside the if TYPE_CHECKING: block at the top of the file to avoid circular imports.

Suggested change
def on_server_message(self, message: Any) -> None:
"""Called when a message is received by the server."""
def on_server_message(self, message: Message) -> None:

...

Check failure on line 22 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:22:9: PIE790 Unnecessary `...` literal

def install(self, client_or_server: A2AClient | A2AServer) -> None:
"""Called when the extension is installed on a client or server."""
...

Check failure on line 26 in src/a2a/extensions/base.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (PIE790)

src/a2a/extensions/base.py:26:9: PIE790 Unnecessary `...` literal
81 changes: 81 additions & 0 deletions src/a2a/extensions/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from datetime import datetime

Check failure on line 3 in src/a2a/extensions/trace.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TC003)

src/a2a/extensions/trace.py:3:22: TC003 Move standard library import `datetime.datetime` into a type-checking block
from enum import Enum
from typing import Any

from a2a._base import A2ABaseModel
from a2a.extensions.base import Extension


class CallTypeEnum(str, Enum):
"""The type of the operation a step represents."""

AGENT = 'AGENT'
TOOL = 'TOOL'


class ToolInvocation(A2ABaseModel):
"""A tool invocation."""

tool_name: str
parameters: dict[str, Any]


class AgentInvocation(A2ABaseModel):
"""An agent invocation."""

agent_url: str
agent_name: str
requests: dict[str, Any]
response_trace: ResponseTrace | None = None


class StepAction(A2ABaseModel):
"""The action of a step."""

tool_invocation: ToolInvocation | None = None
agent_invocation: AgentInvocation | None = None


class Step(A2ABaseModel):
"""A single operation within a trace."""

step_id: str
trace_id: str
parent_step_id: str | None = None
call_type: CallTypeEnum
step_action: StepAction
cost: int | None = None
total_tokens: int | None = None
additional_attributes: dict[str, str] | None = None
latency: int | None = None
start_time: datetime
end_time: datetime


class ResponseTrace(A2ABaseModel):
"""A trace message that contains a collection of spans."""

trace_id: str
steps: list[Step]


class TraceExtension(Extension):
"""An extension for traceability."""

def on_client_message(self, message: Any) -> None:
"""Appends trace information to the message."""
# This is a placeholder implementation.
if message.metadata is None:
message.metadata = {}
message.metadata['trace'] = 'client-trace'

def on_server_message(self, message: Any) -> None:
"""Processes trace information from the message."""
# This is a placeholder implementation.
if hasattr(message, 'metadata') and 'trace' in message.metadata:
print(f"Received trace: {message.metadata['trace']}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a potential TypeError in on_server_message. The expression 'trace' in message.metadata will raise an exception if message.metadata is None. You should ensure message.metadata is not None before checking for a key within it.

Suggested change
if hasattr(message, 'metadata') and 'trace' in message.metadata:
print(f"Received trace: {message.metadata['trace']}")
if hasattr(message, 'metadata') and message.metadata and 'trace' in message.metadata:
print(f"Received trace: {message.metadata['trace']}")



AgentInvocation.model_rebuild()

Check failure on line 81 in src/a2a/extensions/trace.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (W292)

src/a2a/extensions/trace.py:81:32: W292 No newline at end of file
12 changes: 11 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

from collections.abc import AsyncGenerator
from typing import cast
from typing import Any, cast

from a2a.server.agent_execution import (
AgentExecutor,
Expand All @@ -26,6 +26,7 @@
TaskManager,
TaskStore,
)
from a2a.extensions.base import Extension
from a2a.types import (
DeleteTaskPushNotificationConfigParams,
GetTaskPushNotificationConfigParams,
Expand Down Expand Up @@ -101,6 +102,12 @@ def __init__( # noqa: PLR0913
# TODO: Likely want an interface for managing this, like AgentExecutionManager.
self._running_agents = {}
self._running_agents_lock = asyncio.Lock()
self._extensions: list[Extension] = []

def install_extension(self, extension: Extension, server: Any) -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The server parameter is typed as Any. For better type safety, it should be typed more specifically. Based on the Extension.install method signature, it should be A2AServer. To avoid circular imports, you can use a string forward reference: server: "A2AServer". You would then need to add from a2a.server.server import A2AServer inside a TYPE_CHECKING block at the top of the file.

Suggested change
def install_extension(self, extension: Extension, server: Any) -> None:
def install_extension(self, extension: Extension, server: "A2AServer") -> None:

"""Installs an extension on the server."""
extension.install(server)
self._extensions.append(extension)

async def on_get_task(
self,
Expand Down Expand Up @@ -182,6 +189,9 @@ async def _setup_message_execution(
Returns:
A tuple of (task_manager, task_id, queue, result_aggregator, producer_task)
"""
for extension in self._extensions:
extension.on_server_message(params.message)

# Create task manager and validate existing task
task_manager = TaskManager(
task_id=params.message.task_id,
Expand Down
63 changes: 63 additions & 0 deletions tests/extensions/test_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from datetime import datetime, timezone

from a2a.extensions.trace import (
AgentInvocation,
CallTypeEnum,
ResponseTrace,
Step,
StepAction,
ToolInvocation,
)


def test_trace_serialization():
start_time = datetime(2025, 3, 15, 12, 0, 0, tzinfo=timezone.utc)
end_time = datetime(2025, 3, 15, 12, 0, 0, 250000, tzinfo=timezone.utc)

trace = ResponseTrace(
trace_id='trace-example-12345',
steps=[
Step(
step_id='step-1-agent',
trace_id='trace-example-1234p',
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The trace_id for this Step is 'trace-example-1234p', which is different from the ResponseTrace's trace_id ('trace-example-12345'). While this doesn't break the current serialization test, it represents logically inconsistent test data which might be confusing for future readers. A step's trace_id should typically match the ID of the trace it belongs to.

Suggested change
trace_id='trace-example-1234p',
trace_id='trace-example-12345',

call_type=CallTypeEnum.AGENT,
step_action=StepAction(
agent_invocation=AgentInvocation(
agent_name='weather_agent',
agent_url='http://google3/some/agent/url',
requests={
'user_prompt': "What's the weather in Paris and what should I wear?"
},
)
),
cost=150,
total_tokens=75,
additional_attributes={'user_country': 'US'},
latency=250,
start_time=start_time,
end_time=end_time,
),
Step(
step_id='step-2-tool',
trace_id='trace-example-12345',
parent_step_id='step-1-agent',
call_type=CallTypeEnum.TOOL,
step_action=StepAction(
tool_invocation=ToolInvocation(
tool_name='google_map_api_tool',
parameters={'location': 'Paris, FR'},
)
),
cost=50,
total_tokens=20,
latency=100,
start_time=start_time,
end_time=end_time,
),
],
)

trace_dict = trace.model_dump(mode='json')
deserialized_trace = ResponseTrace.model_validate(trace_dict)

assert trace == deserialized_trace
43 changes: 43 additions & 0 deletions tests/extensions/test_trace_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import Mock

import pytest

from a2a.client.base_client import BaseClient
from a2a.extensions.trace import TraceExtension
from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler
from a2a.types import Message, TextPart


@pytest.mark.asyncio
async def test_trace_extension():
client = BaseClient(card=Mock(), config=Mock(), transport=Mock(), consumers=[], middleware=[])
server_handler = DefaultRequestHandler(
agent_executor=Mock(),
task_store=Mock(),
)

trace_extension = TraceExtension()
client.install_extension(trace_extension)
server_handler.install_extension(trace_extension, server=Mock())

message = Message(
message_id='test_message',
role='user',
parts=[TextPart(text='Hello, world!')],
)

# Simulate client sending a message
for extension in client._extensions:
extension.on_client_message(message)

assert 'trace' in message.metadata
assert message.metadata['trace'] == 'client-trace'

# Simulate server receiving a message
for extension in server_handler._extensions:
extension.on_server_message(message)

# Check that the server-side handler was called
# (in this case, it just prints a message)
# We can't easily check the output of print, so we'll just
# assume it worked if no exceptions were raised.
Comment on lines +38 to +44
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test for the server-side message handling is currently implicit, assuming it works if no exception is raised. You can make this test more explicit and robust by mocking the print function and asserting that it's called with the expected arguments. You will need to add import unittest.mock or similar to the top of the file.

Suggested change
for extension in server_handler._extensions:
extension.on_server_message(message)
# Check that the server-side handler was called
# (in this case, it just prints a message)
# We can't easily check the output of print, so we'll just
# assume it worked if no exceptions were raised.
with unittest.mock.patch('builtins.print') as mock_print:
for extension in server_handler._extensions:
extension.on_server_message(message)
# Check that the server-side handler was called
mock_print.assert_called_once_with("Received trace: client-trace")

4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading