Skip to content

Commit 12e595d

Browse files
committed
Fix an issue with ClientFactory not respecting transport URL, add tests
1 parent b32b0da commit 12e595d

4 files changed

Lines changed: 117 additions & 18 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333

3434
TransportProducer = Callable[
35-
[AgentCard, ClientConfig, list[ClientCallInterceptor]],
35+
[AgentCard, str, ClientConfig, list[ClientCallInterceptor]],
3636
ClientTransport,
3737
]
3838

@@ -68,28 +68,28 @@ def __init__(
6868
def _register_defaults(self) -> None:
6969
self.register(
7070
TransportProtocol.jsonrpc,
71-
lambda card, config, interceptors: JsonRpcTransport(
71+
lambda card, url, config, interceptors: JsonRpcTransport(
7272
config.httpx_client or httpx.AsyncClient(),
7373
card,
74-
card.url,
74+
url,
7575
interceptors,
7676
),
7777
)
7878
self.register(
7979
TransportProtocol.http_json,
80-
lambda card, config, interceptors: RestTransport(
80+
lambda card, url, config, interceptors: RestTransport(
8181
config.httpx_client or httpx.AsyncClient(),
8282
card,
83-
card.url,
83+
url,
8484
interceptors,
8585
),
8686
)
8787
if GrpcTransport:
8888
self.register(
8989
TransportProtocol.grpc,
90-
lambda card, config, interceptors: GrpcTransport(
90+
lambda card, url, config, interceptors: GrpcTransport(
9191
a2a_pb2_grpc.A2AServiceStub(
92-
config.grpc_channel_factory(card.url)
92+
config.grpc_channel_factory(url)
9393
),
9494
card,
9595
),
@@ -121,24 +121,30 @@ def create(
121121
If there is no valid matching of the client configuration with the
122122
server configuration, a `ValueError` is raised.
123123
"""
124-
server_set = [card.preferred_transport or TransportProtocol.jsonrpc]
124+
server_preferred = card.preferred_transport or TransportProtocol.jsonrpc
125+
server_set = {server_preferred: card.url}
125126
if card.additional_interfaces:
126-
server_set.extend([x.transport for x in card.additional_interfaces])
127+
server_set.update(
128+
{x.transport: x.url for x in card.additional_interfaces}
129+
)
127130
client_set = self._config.supported_transports or [
128131
TransportProtocol.jsonrpc
129132
]
130133
transport_protocol = None
134+
transport_url = None
131135
if self._config.use_client_preference:
132136
for x in client_set:
133137
if x in server_set:
134138
transport_protocol = x
139+
transport_url = server_set[x]
135140
break
136141
else:
137-
for x in server_set:
142+
for x, url in server_set.items():
138143
if x in client_set:
139144
transport_protocol = x
145+
transport_url = url
140146
break
141-
if not transport_protocol:
147+
if not transport_protocol or not transport_url:
142148
raise ValueError('no compatible transports found.')
143149
if transport_protocol not in self._registry:
144150
raise ValueError(f'no client available for {transport_protocol}')
@@ -148,7 +154,7 @@ def create(
148154
all_consumers.extend(consumers)
149155

150156
transport = self._registry[transport_protocol](
151-
card, self._config, interceptors or []
157+
card, transport_url, self._config, interceptors or []
152158
)
153159

154160
return BaseClient(

src/a2a/client/transports/jsonrpc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ def __init__(
6464
interceptors: list[ClientCallInterceptor] | None = None,
6565
):
6666
"""Initializes the JsonRpcTransport."""
67-
if agent_card:
68-
self.url = agent_card.url
69-
elif url:
67+
if url:
7068
self.url = url
69+
elif agent_card:
70+
self.url = agent_card.url
7171
else:
7272
raise ValueError('Must provide either agent_card or url')
7373

src/a2a/client/transports/rest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def __init__(
4545
interceptors: list[ClientCallInterceptor] | None = None,
4646
):
4747
"""Initializes the RestTransport."""
48-
if agent_card:
49-
self.url = agent_card.url
50-
elif url:
48+
if url:
5149
self.url = url
50+
elif agent_card:
51+
self.url = agent_card.url
5252
else:
5353
raise ValueError('Must provide either agent_card or url')
5454
if self.url.endswith('/'):
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Tests for the ClientFactory."""
2+
3+
import httpx
4+
import pytest
5+
6+
from a2a.client import ClientConfig, ClientFactory
7+
from a2a.client.transports import JsonRpcTransport, RestTransport
8+
from a2a.types import (
9+
AgentCard,
10+
AgentCapabilities,
11+
AgentInterface,
12+
TransportProtocol,
13+
)
14+
15+
16+
@pytest.fixture
17+
def base_agent_card() -> AgentCard:
18+
"""Provides a base AgentCard for tests."""
19+
return AgentCard(
20+
name="Test Agent",
21+
description="An agent for testing.",
22+
url="http://primary-url.com",
23+
version="1.0.0",
24+
capabilities=AgentCapabilities(),
25+
skills=[],
26+
default_input_modes=[],
27+
default_output_modes=[],
28+
preferred_transport=TransportProtocol.jsonrpc,
29+
)
30+
31+
32+
def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard):
33+
"""Verify that the factory selects the preferred transport by default."""
34+
config = ClientConfig(
35+
httpx_client=httpx.AsyncClient(),
36+
supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json],
37+
)
38+
factory = ClientFactory(config)
39+
client = factory.create(base_agent_card)
40+
41+
assert isinstance(client._transport, JsonRpcTransport)
42+
assert client._transport.url == "http://primary-url.com"
43+
44+
45+
def test_client_factory_selects_secondary_transport_url(base_agent_card: AgentCard):
46+
"""Verify that the factory selects the correct URL for a secondary transport."""
47+
base_agent_card.additional_interfaces = [
48+
AgentInterface(
49+
transport=TransportProtocol.http_json, url="http://secondary-url.com"
50+
)
51+
]
52+
# Client prefers REST, which is available as a secondary transport
53+
config = ClientConfig(
54+
httpx_client=httpx.AsyncClient(),
55+
supported_transports=[TransportProtocol.http_json, TransportProtocol.jsonrpc],
56+
use_client_preference=True,
57+
)
58+
factory = ClientFactory(config)
59+
client = factory.create(base_agent_card)
60+
61+
assert isinstance(client._transport, RestTransport)
62+
assert client._transport.url == "http://secondary-url.com"
63+
64+
65+
def test_client_factory_server_preference(base_agent_card: AgentCard):
66+
"""Verify that the factory respects server transport preference."""
67+
base_agent_card.preferred_transport = TransportProtocol.http_json
68+
base_agent_card.additional_interfaces = [
69+
AgentInterface(
70+
transport=TransportProtocol.jsonrpc, url="http://secondary-url.com"
71+
)
72+
]
73+
# Client supports both, but server prefers REST
74+
config = ClientConfig(
75+
httpx_client=httpx.AsyncClient(),
76+
supported_transports=[TransportProtocol.jsonrpc, TransportProtocol.http_json],
77+
)
78+
factory = ClientFactory(config)
79+
client = factory.create(base_agent_card)
80+
81+
assert isinstance(client._transport, RestTransport)
82+
assert client._transport.url == "http://primary-url.com"
83+
84+
85+
def test_client_factory_no_compatible_transport(base_agent_card: AgentCard):
86+
"""Verify that the factory raises an error if no compatible transport is found."""
87+
config = ClientConfig(
88+
httpx_client=httpx.AsyncClient(),
89+
supported_transports=[TransportProtocol.grpc],
90+
)
91+
factory = ClientFactory(config)
92+
with pytest.raises(ValueError, match="no compatible transports found"):
93+
factory.create(base_agent_card)

0 commit comments

Comments
 (0)