Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from a2a.client.transports.jsonrpc import JsonRpcTransport
from a2a.client.transports.rest import RestTransport
from a2a.client.transports.tenant_decorator import TenantTransportDecorator
from a2a.compat.v0_3.versions import is_legacy_version
from a2a.types.a2a_pb2 import (
AgentCapabilities,
AgentCard,
Expand Down Expand Up @@ -111,7 +112,7 @@ def jsonrpc_transport_producer(
else PROTOCOL_VERSION_CURRENT
)

if ClientFactory._is_legacy_version(version):
if is_legacy_version(version):
from a2a.compat.v0_3.jsonrpc_transport import ( # noqa: PLC0415
CompatJsonRpcTransport,
)
Expand Down Expand Up @@ -150,7 +151,7 @@ def rest_transport_producer(
else PROTOCOL_VERSION_CURRENT
)

if ClientFactory._is_legacy_version(version):
if is_legacy_version(version):
from a2a.compat.v0_3.rest_transport import ( # noqa: PLC0415
CompatRestTransport,
)
Expand Down Expand Up @@ -197,7 +198,7 @@ def grpc_transport_producer(
)

if (
ClientFactory._is_legacy_version(version)
is_legacy_version(version)
and CompatGrpcTransport is not None
):
return CompatGrpcTransport.create(card, url, config)
Expand All @@ -215,21 +216,6 @@ def grpc_transport_producer(
grpc_transport_producer,
)

@staticmethod
def _is_legacy_version(version: str | None) -> bool:
"""Determines if the given version is a legacy protocol version (>=0.3 and <1.0)."""
if not version:
return False
try:
v = Version(version)
return (
Version(PROTOCOL_VERSION_0_3)
<= v
< Version(PROTOCOL_VERSION_1_0)
)
except InvalidVersion:
return False

@staticmethod
def _find_best_interface(
interfaces: list[AgentInterface],
Expand Down
39 changes: 22 additions & 17 deletions src/a2a/compat/v0_3/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from google.protobuf.json_format import MessageToDict, ParseDict

from a2a.compat.v0_3 import types as types_v03
from a2a.compat.v0_3.versions import is_legacy_version
from a2a.server.models import PushNotificationConfigModel, TaskModel
from a2a.types import a2a_pb2 as pb2_v10
from a2a.utils import constants, errors


_COMPAT_TO_CORE_TASK_STATE: dict[types_v03.TaskState, Any] = {
Expand Down Expand Up @@ -676,7 +678,7 @@ def to_core_agent_interface(
return pb2_v10.AgentInterface(
url=compat_interface.url,
protocol_binding=compat_interface.transport,
protocol_version='0.3.0', # Defaulting for legacy
protocol_version=constants.PROTOCOL_VERSION_0_3, # Defaulting for legacy
)


Expand Down Expand Up @@ -857,7 +859,8 @@ def to_core_agent_card(compat_card: types_v03.AgentCard) -> pb2_v10.AgentCard:
primary_interface = pb2_v10.AgentInterface(
url=compat_card.url,
protocol_binding=compat_card.preferred_transport or 'JSONRPC',
protocol_version=compat_card.protocol_version or '0.3.0',
protocol_version=compat_card.protocol_version
or constants.PROTOCOL_VERSION_0_3,
)
core_card.supported_interfaces.append(primary_interface)

Expand Down Expand Up @@ -918,21 +921,23 @@ def to_core_agent_card(compat_card: types_v03.AgentCard) -> pb2_v10.AgentCard:
def to_compat_agent_card(core_card: pb2_v10.AgentCard) -> types_v03.AgentCard:
# Map supported interfaces back to legacy layout
"""Convert agent card to v0.3 compat type."""
primary_interface = (
core_card.supported_interfaces[0]
if core_card.supported_interfaces
else pb2_v10.AgentInterface(
url='', protocol_binding='JSONRPC', protocol_version='0.3.0'
compat_interfaces = [
interface
for interface in core_card.supported_interfaces
if (
(not interface.protocol_version)
or is_legacy_version(interface.protocol_version)
)
)
additional_interfaces = (
[
to_compat_agent_interface(i)
for i in core_card.supported_interfaces[1:]
]
if len(core_card.supported_interfaces) > 1
else None
)
]
if not compat_interfaces:
raise errors.VersionNotSupportedError(
'AgentCard must have at least one interface with compatible protocol version.'
)

primary_interface = compat_interfaces[0]
additional_interfaces = [
to_compat_agent_interface(i) for i in compat_interfaces[1:]
]

compat_cap = to_compat_agent_capabilities(core_card.capabilities)
supports_authenticated_extended_card = (
Expand All @@ -948,7 +953,7 @@ def to_compat_agent_card(core_card: pb2_v10.AgentCard) -> types_v03.AgentCard:
url=primary_interface.url,
preferred_transport=primary_interface.protocol_binding,
protocol_version=primary_interface.protocol_version,
additional_interfaces=additional_interfaces,
additional_interfaces=additional_interfaces or None,
provider=to_compat_agent_provider(core_card.provider)
if core_card.HasField('provider')
else None,
Expand Down
18 changes: 18 additions & 0 deletions src/a2a/compat/v0_3/versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Utility functions for protocol version comparison and validation."""

from packaging.version import InvalidVersion, Version

from a2a.utils.constants import PROTOCOL_VERSION_0_3, PROTOCOL_VERSION_1_0


def is_legacy_version(version: str | None) -> bool:
"""Determines if the given version is a legacy protocol version (>=0.3 and <1.0)."""
if not version:
return False
try:
v = Version(version)
return (
Version(PROTOCOL_VERSION_0_3) <= v < Version(PROTOCOL_VERSION_1_0)
)
except InvalidVersion:
return False
7 changes: 5 additions & 2 deletions src/a2a/server/request_handlers/response_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ def agent_card_to_dict(card: AgentCard) -> dict[str, Any]:
"""Convert AgentCard to dict and inject backward compatibility fields."""
result = MessageToDict(card)

compat_card = to_compat_agent_card(card)
compat_dict = compat_card.model_dump(exclude_none=True)
try:
compat_card = to_compat_agent_card(card)
compat_dict = compat_card.model_dump(exclude_none=True)
except VersionNotSupportedError:
compat_dict = {}

# Do not include supportsAuthenticatedExtendedCard if false
if not compat_dict.get('supportsAuthenticatedExtendedCard'):
Expand Down
Empty file.
11 changes: 2 additions & 9 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from httpx_sse import EventSource, ServerSentEvent

from a2a.client import create_text_message_object
from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
Expand Down Expand Up @@ -162,7 +163,6 @@ async def test_send_message_with_timeout_context(
self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock
):
"""Test that send_message passes context timeout to build_request."""
from a2a.client.client import ClientCallContext

client = RestTransport(
httpx_client=mock_httpx_client,
Expand Down Expand Up @@ -258,8 +258,6 @@ async def test_send_message_with_default_extensions(
mock_response.status_code = 200
mock_httpx_client.send.return_value = mock_response

from a2a.client.client import ClientCallContext

context = ClientCallContext(
service_parameters={
'X-A2A-Extensions': 'https://example.com/test-ext/v1,https://example.com/test-ext/v2'
Expand Down Expand Up @@ -302,8 +300,6 @@ async def test_send_message_streaming_with_new_extensions(
mock_event_source
)

from a2a.client.client import ClientCallContext

context = ClientCallContext(
service_parameters={
'X-A2A-Extensions': 'https://example.com/test-ext/v2'
Expand Down Expand Up @@ -404,8 +400,6 @@ async def test_get_card_with_extended_card_support_with_extensions(

request = GetExtendedAgentCardRequest()

from a2a.client.client import ClientCallContext

context = ClientCallContext(
service_parameters={HTTP_EXTENSION_HEADER: extensions_str}
)
Expand All @@ -419,7 +413,6 @@ async def test_get_card_with_extended_card_support_with_extensions(
await client.get_extended_agent_card(request, context=context)

mock_execute_request.assert_called_once()
# _execute_request(method, target, tenant, context)
call_args = mock_execute_request.call_args
assert (
call_args[1].get('context') == context or call_args[0][3] == context
Expand Down Expand Up @@ -694,7 +687,7 @@ async def test_rest_get_task_prepend_empty_tenant(
)
@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
async def test_rest_streaming_methods_prepend_tenant(
async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913
self,
mock_aconnect_sse,
method_name,
Expand Down
26 changes: 24 additions & 2 deletions tests/compat/v0_3/test_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from a2a.server.models import PushNotificationConfigModel, TaskModel
from cryptography.fernet import Fernet
from a2a.types import a2a_pb2 as pb2_v10
from a2a.utils.errors import VersionNotSupportedError


def test_text_part_conversion():
Expand Down Expand Up @@ -986,7 +987,7 @@ def test_security_scheme_mtls_minimal():
def test_agent_interface_conversion():
v03_int = types_v03.AgentInterface(url='http', transport='JSONRPC')
v10_expected = pb2_v10.AgentInterface(
url='http', protocol_binding='JSONRPC', protocol_version='0.3.0'
url='http', protocol_binding='JSONRPC', protocol_version='0.3'
)
v10_int = to_core_agent_interface(v03_int)
assert v10_int == v10_expected
Expand Down Expand Up @@ -1131,7 +1132,7 @@ def test_agent_card_conversion():
url='u1', protocol_binding='JSONRPC', protocol_version='0.3.0'
),
pb2_v10.AgentInterface(
url='u2', protocol_binding='HTTP', protocol_version='0.3.0'
url='u2', protocol_binding='HTTP', protocol_version='0.3'
),
]
)
Expand Down Expand Up @@ -2014,3 +2015,24 @@ def test_push_notification_config_persistence_conversion_with_encryption():
assert v10_restored.id == v10_config.id
assert v10_restored.url == v10_config.url
assert v10_restored.token == v10_config.token


def test_to_compat_agent_card_unsupported_version():
card = pb2_v10.AgentCard(
name='Modern Agent',
description='Only supports 1.0',
version='1.0.0',
supported_interfaces=[
pb2_v10.AgentInterface(
url='http://grpc.v10.com',
protocol_binding='GRPC',
protocol_version='1.0.0',
),
],
capabilities=pb2_v10.AgentCapabilities(),
)
with pytest.raises(
VersionNotSupportedError,
match='AgentCard must have at least one interface with compatible protocol version.',
):
to_compat_agent_card(card)
10 changes: 9 additions & 1 deletion tests/compat/v0_3/test_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def sample_agent_card() -> a2a_pb2.AgentCard:
name='Test Agent',
description='A test agent',
version='1.0.0',
supported_interfaces=[
a2a_pb2.AgentInterface(
url='http://jsonrpc.v03.com',
protocol_binding='JSONRPC',
protocol_version='0.3',
),
],
)


Expand Down Expand Up @@ -434,8 +441,9 @@ async def test_get_agent_card_success(
expected_res = a2a_v0_3_pb2.AgentCard(
name='Test Agent',
description='A test agent',
url='http://jsonrpc.v03.com',
version='1.0.0',
protocol_version='0.3.0',
protocol_version='0.3',
preferred_transport='JSONRPC',
capabilities=a2a_v0_3_pb2.AgentCapabilities(),
)
Expand Down
4 changes: 1 addition & 3 deletions tests/compat/v0_3/test_rest_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,7 @@ async def test_compat_rest_transport_subscribe_post_405_get_405_fails(

async def mock_stream(method, path, context=None, json=None):
method_count[method] = method_count.get(method, 0) + 1
if method == 'POST':
assert json is None
elif method == 'GET':
if method in {'POST', 'GET'}:
assert json is None
# To make it an async generator even when it raises
if False:
Expand Down
27 changes: 27 additions & 0 deletions tests/compat/v0_3/test_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Tests for version utility functions."""

import pytest

from a2a.compat.v0_3.versions import is_legacy_version


@pytest.mark.parametrize(
'version, expected',
[
('0.3', True),
('0.3.0', True),
('0.9', True),
('0.9.9', True),
('1.0', False),
('1.0.0', False),
('1.1', False),
('0.2', False),
('0.2.9', False),
(None, False),
('', False),
('invalid', False),
('v0.3', True),
],
)
def test_is_legacy_version(version, expected):
assert is_legacy_version(version) == expected
Empty file.
Loading
Loading