Skip to content

Commit c7f4eb0

Browse files
committed
Gemini authored: refactor clients into BaseClient + ClientTransport
1 parent 0891716 commit c7f4eb0

11 files changed

Lines changed: 1352 additions & 2380 deletions

File tree

src/a2a/client/__init__.py

Lines changed: 17 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from a2a.client.client import Client, ClientConfig, ClientEvent, Consumer
1212
from a2a.client.client_factory import (
1313
ClientFactory,
14-
ClientProducer,
1514
minimal_agent_card,
1615
)
1716
from a2a.client.errors import (
@@ -21,77 +20,28 @@
2120
A2AClientTimeoutError,
2221
)
2322
from a2a.client.helpers import create_text_message_object
24-
from a2a.client.jsonrpc_client import (
25-
A2AClient,
26-
JsonRpcClient,
27-
JsonRpcTransportClient,
28-
NewJsonRpcClient,
29-
)
3023
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
31-
from a2a.client.rest_client import (
32-
NewRestfulClient,
33-
RestClient,
34-
RestTransportClient,
35-
)
3624

3725

3826
logger = logging.getLogger(__name__)
3927

40-
try:
41-
from a2a.client.grpc_client import (
42-
GrpcClient,
43-
GrpcTransportClient, # type: ignore
44-
NewGrpcClient,
45-
)
46-
except ImportError as e:
47-
_original_error = e
48-
logger.debug(
49-
'A2AGrpcClient not loaded. This is expected if gRPC dependencies are not installed. Error: %s',
50-
_original_error,
51-
)
52-
53-
class GrpcTransportClient: # type: ignore
54-
"""Placeholder for A2AGrpcClient when dependencies are not installed."""
55-
56-
def __init__(self, *args, **kwargs):
57-
raise ImportError(
58-
'To use A2AGrpcClient, its dependencies must be installed. '
59-
'You can install them with \'pip install "a2a-sdk[grpc]"\''
60-
) from _original_error
61-
finally:
62-
# For backward compatability define this alias. This will be deprecated in
63-
# a future release.
64-
A2AGrpcClient = GrpcTransportClient # type: ignore
65-
6628

6729
__all__ = [
68-
'A2ACardResolver',
69-
'A2AClient', # for backward compatability
70-
'A2AClientError',
71-
'A2AClientHTTPError',
72-
'A2AClientJSONError',
73-
'A2AClientTimeoutError',
74-
'A2AGrpcClient', # for backward compatability
75-
'AuthInterceptor',
76-
'Client',
77-
'ClientCallContext',
78-
'ClientCallInterceptor',
79-
'ClientConfig',
80-
'ClientEvent',
81-
'ClientFactory',
82-
'ClientProducer',
83-
'Consumer',
84-
'CredentialService',
85-
'GrpcClient',
86-
'GrpcTransportClient',
87-
'InMemoryContextCredentialStore',
88-
'JsonRpcClient',
89-
'JsonRpcTransportClient',
90-
'NewGrpcClient',
91-
'NewJsonRpcClient',
92-
'NewRestfulClient',
93-
'RestClient',
94-
'RestTransportClient',
95-
'create_text_message_object',
96-
'minimal_agent_card',
30+
"A2ACardResolver",
31+
"A2AClientError",
32+
"A2AClientHTTPError",
33+
"A2AClientJSONError",
34+
"A2AClientTimeoutError",
35+
"AuthInterceptor",
36+
"Client",
37+
"ClientCallContext",
38+
"ClientCallInterceptor",
39+
"ClientConfig",
40+
"ClientEvent",
41+
"ClientFactory",
42+
"Consumer",
43+
"CredentialService",
44+
"InMemoryContextCredentialStore",
45+
"create_text_message_object",
46+
"minimal_agent_card",
9747
]

src/a2a/client/base_client.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
from collections.abc import AsyncIterator
2+
3+
from a2a.client.client import (
4+
Client,
5+
ClientCallContext,
6+
ClientConfig,
7+
ClientEvent,
8+
Consumer,
9+
)
10+
from a2a.client.client_task_manager import ClientTaskManager
11+
from a2a.client.errors import A2AClientInvalidStateError
12+
from a2a.client.middleware import ClientCallInterceptor
13+
from a2a.client.transports.base import ClientTransport
14+
from a2a.types import (
15+
AgentCard,
16+
GetTaskPushNotificationConfigParams,
17+
Message,
18+
MessageSendConfiguration,
19+
MessageSendParams,
20+
Task,
21+
TaskArtifactUpdateEvent,
22+
TaskIdParams,
23+
TaskPushNotificationConfig,
24+
TaskQueryParams,
25+
TaskStatusUpdateEvent,
26+
)
27+
28+
29+
class BaseClient(Client):
30+
"""Base implementation of the A2A client, containing transport-independent logic."""
31+
32+
def __init__(
33+
self,
34+
card: AgentCard,
35+
config: ClientConfig,
36+
transport: ClientTransport,
37+
consumers: list[Consumer],
38+
middleware: list[ClientCallInterceptor],
39+
):
40+
super().__init__(consumers, middleware)
41+
self._card = card
42+
self._config = config
43+
self._transport = transport
44+
45+
async def send_message(
46+
self,
47+
request: Message,
48+
*,
49+
context: ClientCallContext | None = None,
50+
) -> AsyncIterator[ClientEvent | Message]:
51+
"""Sends a message to the agent.
52+
53+
This method handles both streaming and non-streaming (polling) interactions
54+
based on the client configuration and agent capabilities. It will yield
55+
events as they are received from the agent.
56+
57+
Args:
58+
request: The message to send to the agent.
59+
context: The client call context.
60+
61+
Yields:
62+
An async iterator of `ClientEvent` or a final `Message` response.
63+
"""
64+
config = MessageSendConfiguration(
65+
accepted_output_modes=self._config.accepted_output_modes,
66+
blocking=not self._config.polling,
67+
push_notification_config=(
68+
self._config.push_notification_configs[0]
69+
if self._config.push_notification_configs
70+
else None
71+
),
72+
)
73+
params = MessageSendParams(message=request, configuration=config)
74+
75+
if not self._config.streaming or not self._card.capabilities.streaming:
76+
response = await self._transport.send_message(params, context=context)
77+
result = (
78+
(response, None) if isinstance(response, Task) else response
79+
)
80+
await self.consume(result, self._card)
81+
yield result
82+
return
83+
84+
tracker = ClientTaskManager()
85+
stream = self._transport.send_message_streaming(params, context=context)
86+
87+
first_event = await anext(stream)
88+
if isinstance(first_event, Message):
89+
await self.consume(first_event, self._card)
90+
yield first_event
91+
return
92+
93+
yield await self._process_response(tracker, first_event)
94+
95+
async for event in stream:
96+
yield await self._process_response(tracker, event)
97+
98+
async def _process_response(
99+
self,
100+
tracker: ClientTaskManager,
101+
event: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent,
102+
) -> ClientEvent:
103+
if isinstance(event, Message):
104+
raise A2AClientInvalidStateError(
105+
"received a streamed Message from server after first response; this is not supported"
106+
)
107+
await tracker.process(event)
108+
task = tracker.get_task_or_raise()
109+
update = None if isinstance(event, Task) else event
110+
client_event = (task, update)
111+
await self.consume(client_event, self._card)
112+
return client_event
113+
114+
async def get_task(
115+
self,
116+
request: TaskQueryParams,
117+
*,
118+
context: ClientCallContext | None = None,
119+
) -> Task:
120+
"""Retrieves the current state and history of a specific task.
121+
122+
Args:
123+
request: The `TaskQueryParams` object specifying the task ID.
124+
context: The client call context.
125+
126+
Returns:
127+
A `Task` object representing the current state of the task.
128+
"""
129+
return await self._transport.get_task(request, context=context)
130+
131+
async def cancel_task(
132+
self,
133+
request: TaskIdParams,
134+
*,
135+
context: ClientCallContext | None = None,
136+
) -> Task:
137+
"""Requests the agent to cancel a specific task.
138+
139+
Args:
140+
request: The `TaskIdParams` object specifying the task ID.
141+
context: The client call context.
142+
143+
Returns:
144+
A `Task` object containing the updated task status.
145+
"""
146+
return await self._transport.cancel_task(request, context=context)
147+
148+
async def set_task_callback(
149+
self,
150+
request: TaskPushNotificationConfig,
151+
*,
152+
context: ClientCallContext | None = None,
153+
) -> TaskPushNotificationConfig:
154+
"""Sets or updates the push notification configuration for a specific task.
155+
156+
Args:
157+
request: The `TaskPushNotificationConfig` object with the new configuration.
158+
context: The client call context.
159+
160+
Returns:
161+
The created or updated `TaskPushNotificationConfig` object.
162+
"""
163+
return await self._transport.set_task_callback(request, context=context)
164+
165+
async def get_task_callback(
166+
self,
167+
request: GetTaskPushNotificationConfigParams,
168+
*,
169+
context: ClientCallContext | None = None,
170+
) -> TaskPushNotificationConfig:
171+
"""Retrieves the push notification configuration for a specific task.
172+
173+
Args:
174+
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
175+
context: The client call context.
176+
177+
Returns:
178+
A `TaskPushNotificationConfig` object containing the configuration.
179+
"""
180+
return await self._transport.get_task_callback(request, context=context)
181+
182+
async def resubscribe(
183+
self,
184+
request: TaskIdParams,
185+
*,
186+
context: ClientCallContext | None = None,
187+
) -> AsyncIterator[ClientEvent]:
188+
"""Resubscribes to a task's event stream.
189+
190+
This is only available if both the client and server support streaming.
191+
192+
Args:
193+
request: Parameters to identify the task to resubscribe to.
194+
context: The client call context.
195+
196+
Yields:
197+
An async iterator of `ClientEvent` objects.
198+
199+
Raises:
200+
NotImplementedError: If streaming is not supported by the client or server.
201+
"""
202+
if not self._config.streaming or not self._card.capabilities.streaming:
203+
raise NotImplementedError(
204+
"client and/or server do not support resubscription."
205+
)
206+
207+
tracker = ClientTaskManager()
208+
async for event in self._transport.resubscribe(request, context=context):
209+
yield await self._process_response(tracker, event)
210+
211+
async def get_card(
212+
self, *, context: ClientCallContext | None = None
213+
) -> AgentCard:
214+
"""Retrieves the agent's card.
215+
216+
This will fetch the authenticated card if necessary and update the
217+
client's internal state with the new card.
218+
219+
Args:
220+
context: The client call context.
221+
222+
Returns:
223+
The `AgentCard` for the agent.
224+
"""
225+
card = await self._transport.get_card(context=context)
226+
self._card = card
227+
return card
228+
229+
async def close(self) -> None:
230+
"""Closes the underlying transport."""
231+
await self._transport.close()

0 commit comments

Comments
 (0)