Skip to content

Commit 0f8f9a9

Browse files
committed
Merge branch '1.0-dev' of https://github.com/a2aproject/a2a-python into database-compatibility-1-0
2 parents 3e44044 + 2e2d431 commit 0f8f9a9

57 files changed

Lines changed: 1168 additions & 1804 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

buf.gen.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
version: v2
33
inputs:
44
- git_repo: https://github.com/a2aproject/A2A.git
5-
ref: 1997c9d63058ca0b89361a7d6e508f4641a6f68b
5+
ref: main
66
subdir: specification
77
managed:
88
enabled: true

src/a2a/client/base_client.py

Lines changed: 43 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
from collections.abc import AsyncGenerator, AsyncIterator, Callable
2-
from typing import Any
32

43
from a2a.client.client import (
54
Client,
6-
ClientCallContext,
75
ClientConfig,
86
ClientEvent,
97
Consumer,
108
)
119
from a2a.client.client_task_manager import ClientTaskManager
12-
from a2a.client.middleware import ClientCallInterceptor
10+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1311
from a2a.client.transports.base import ClientTransport
1412
from a2a.types.a2a_pb2 import (
1513
AgentCard,
1614
CancelTaskRequest,
17-
CreateTaskPushNotificationConfigRequest,
1815
DeleteTaskPushNotificationConfigRequest,
1916
GetExtendedAgentCardRequest,
2017
GetTaskPushNotificationConfigRequest,
@@ -23,8 +20,6 @@
2320
ListTaskPushNotificationConfigsResponse,
2421
ListTasksRequest,
2522
ListTasksResponse,
26-
Message,
27-
SendMessageConfiguration,
2823
SendMessageRequest,
2924
StreamResponse,
3025
SubscribeToTaskRequest,
@@ -51,12 +46,9 @@ def __init__(
5146

5247
async def send_message(
5348
self,
54-
request: Message,
49+
request: SendMessageRequest,
5550
*,
56-
configuration: SendMessageConfiguration | None = None,
5751
context: ClientCallContext | None = None,
58-
request_metadata: dict[str, Any] | None = None,
59-
extensions: list[str] | None = None,
6052
) -> AsyncIterator[ClientEvent]:
6153
"""Sends a message to the agent.
6254
@@ -66,35 +58,15 @@ async def send_message(
6658
6759
Args:
6860
request: The message to send to the agent.
69-
configuration: Optional per-call overrides for message sending behavior.
70-
context: The client call context.
71-
request_metadata: Extensions Metadata attached to the request.
72-
extensions: List of extensions to be activated.
61+
context: Optional client call context.
7362
7463
Yields:
7564
An async iterator of `ClientEvent`
7665
"""
77-
config = SendMessageConfiguration(
78-
accepted_output_modes=self._config.accepted_output_modes,
79-
blocking=not self._config.polling,
80-
push_notification_config=(
81-
self._config.push_notification_configs[0]
82-
if self._config.push_notification_configs
83-
else None
84-
),
85-
)
86-
87-
if configuration:
88-
config.MergeFrom(configuration)
89-
config.blocking = configuration.blocking
90-
91-
send_message_request = SendMessageRequest(
92-
message=request, configuration=config, metadata=request_metadata
93-
)
94-
66+
self._apply_client_config(request)
9567
if not self._config.streaming or not self._card.capabilities.streaming:
9668
response = await self._transport.send_message(
97-
send_message_request, context=context, extensions=extensions
69+
request, context=context
9870
)
9971

10072
# In non-streaming case we convert to a StreamResponse so that the
@@ -116,11 +88,29 @@ async def send_message(
11688
return
11789

11890
stream = self._transport.send_message_streaming(
119-
send_message_request, context=context, extensions=extensions
91+
request, context=context
12092
)
12193
async for client_event in self._process_stream(stream):
12294
yield client_event
12395

96+
def _apply_client_config(self, request: SendMessageRequest) -> None:
97+
if not request.configuration.blocking and self._config.polling:
98+
request.configuration.blocking = not self._config.polling
99+
if (
100+
not request.configuration.HasField('task_push_notification_config')
101+
and self._config.push_notification_configs
102+
):
103+
request.configuration.task_push_notification_config.CopyFrom(
104+
self._config.push_notification_configs[0]
105+
)
106+
if (
107+
not request.configuration.accepted_output_modes
108+
and self._config.accepted_output_modes
109+
):
110+
request.configuration.accepted_output_modes.extend(
111+
self._config.accepted_output_modes
112+
)
113+
124114
async def _process_stream(
125115
self, stream: AsyncIterator[StreamResponse]
126116
) -> AsyncGenerator[ClientEvent]:
@@ -147,21 +137,17 @@ async def get_task(
147137
request: GetTaskRequest,
148138
*,
149139
context: ClientCallContext | None = None,
150-
extensions: list[str] | None = None,
151140
) -> Task:
152141
"""Retrieves the current state and history of a specific task.
153142
154143
Args:
155144
request: The `GetTaskRequest` object specifying the task ID.
156-
context: The client call context.
157-
extensions: List of extensions to be activated.
145+
context: Optional client call context.
158146
159147
Returns:
160148
A `Task` object representing the current state of the task.
161149
"""
162-
return await self._transport.get_task(
163-
request, context=context, extensions=extensions
164-
)
150+
return await self._transport.get_task(request, context=context)
165151

166152
async def list_tasks(
167153
self,
@@ -177,118 +163,104 @@ async def cancel_task(
177163
request: CancelTaskRequest,
178164
*,
179165
context: ClientCallContext | None = None,
180-
extensions: list[str] | None = None,
181166
) -> Task:
182167
"""Requests the agent to cancel a specific task.
183168
184169
Args:
185170
request: The `CancelTaskRequest` object specifying the task ID.
186-
context: The client call context.
187-
extensions: List of extensions to be activated.
171+
context: Optional client call context.
188172
189173
Returns:
190174
A `Task` object containing the updated task status.
191175
"""
192-
return await self._transport.cancel_task(
193-
request, context=context, extensions=extensions
194-
)
176+
return await self._transport.cancel_task(request, context=context)
195177

196178
async def create_task_push_notification_config(
197179
self,
198-
request: CreateTaskPushNotificationConfigRequest,
180+
request: TaskPushNotificationConfig,
199181
*,
200182
context: ClientCallContext | None = None,
201-
extensions: list[str] | None = None,
202183
) -> TaskPushNotificationConfig:
203184
"""Sets or updates the push notification configuration for a specific task.
204185
205186
Args:
206187
request: The `TaskPushNotificationConfig` object with the new configuration.
207-
context: The client call context.
208-
extensions: List of extensions to be activated.
188+
context: Optional client call context.
209189
210190
Returns:
211191
The created or updated `TaskPushNotificationConfig` object.
212192
"""
213193
return await self._transport.create_task_push_notification_config(
214-
request, context=context, extensions=extensions
194+
request, context=context
215195
)
216196

217197
async def get_task_push_notification_config(
218198
self,
219199
request: GetTaskPushNotificationConfigRequest,
220200
*,
221201
context: ClientCallContext | None = None,
222-
extensions: list[str] | None = None,
223202
) -> TaskPushNotificationConfig:
224203
"""Retrieves the push notification configuration for a specific task.
225204
226205
Args:
227206
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
228-
context: The client call context.
229-
extensions: List of extensions to be activated.
207+
context: Optional client call context.
230208
231209
Returns:
232210
A `TaskPushNotificationConfig` object containing the configuration.
233211
"""
234212
return await self._transport.get_task_push_notification_config(
235-
request, context=context, extensions=extensions
213+
request, context=context
236214
)
237215

238216
async def list_task_push_notification_configs(
239217
self,
240218
request: ListTaskPushNotificationConfigsRequest,
241219
*,
242220
context: ClientCallContext | None = None,
243-
extensions: list[str] | None = None,
244221
) -> ListTaskPushNotificationConfigsResponse:
245222
"""Lists push notification configurations for a specific task.
246223
247224
Args:
248225
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
249-
context: The client call context.
250-
extensions: List of extensions to be activated.
226+
context: Optional client call context.
251227
252228
Returns:
253229
A `ListTaskPushNotificationConfigsResponse` object.
254230
"""
255231
return await self._transport.list_task_push_notification_configs(
256-
request, context=context, extensions=extensions
232+
request, context=context
257233
)
258234

259235
async def delete_task_push_notification_config(
260236
self,
261237
request: DeleteTaskPushNotificationConfigRequest,
262238
*,
263239
context: ClientCallContext | None = None,
264-
extensions: list[str] | None = None,
265240
) -> None:
266241
"""Deletes the push notification configuration for a specific task.
267242
268243
Args:
269244
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
270-
context: The client call context.
271-
extensions: List of extensions to be activated.
245+
context: Optional client call context.
272246
"""
273247
await self._transport.delete_task_push_notification_config(
274-
request, context=context, extensions=extensions
248+
request, context=context
275249
)
276250

277251
async def subscribe(
278252
self,
279253
request: SubscribeToTaskRequest,
280254
*,
281255
context: ClientCallContext | None = None,
282-
extensions: list[str] | None = None,
283256
) -> AsyncIterator[ClientEvent]:
284257
"""Resubscribes to a task's event stream.
285258
286259
This is only available if both the client and server support streaming.
287260
288261
Args:
289262
request: Parameters to identify the task to resubscribe to.
290-
context: The client call context.
291-
extensions: List of extensions to be activated.
263+
context: Optional client call context.
292264
293265
Yields:
294266
An async iterator of `ClientEvent` objects.
@@ -304,9 +276,7 @@ async def subscribe(
304276
# Note: resubscribe can only be called on an existing task. As such,
305277
# we should never see Message updates, despite the typing of the service
306278
# definition indicating it may be possible.
307-
stream = self._transport.subscribe(
308-
request, context=context, extensions=extensions
309-
)
279+
stream = self._transport.subscribe(request, context=context)
310280
async for client_event in self._process_stream(stream):
311281
yield client_event
312282

@@ -315,7 +285,6 @@ async def get_extended_agent_card(
315285
request: GetExtendedAgentCardRequest,
316286
*,
317287
context: ClientCallContext | None = None,
318-
extensions: list[str] | None = None,
319288
signature_verifier: Callable[[AgentCard], None] | None = None,
320289
) -> AgentCard:
321290
"""Retrieves the agent's card.
@@ -325,8 +294,7 @@ async def get_extended_agent_card(
325294
326295
Args:
327296
request: The `GetExtendedAgentCardRequest` object specifying the request.
328-
context: The client call context.
329-
extensions: List of extensions to be activated.
297+
context: Optional client call context.
330298
signature_verifier: A callable used to verify the agent card's signatures.
331299
332300
Returns:
@@ -335,9 +303,10 @@ async def get_extended_agent_card(
335303
card = await self._transport.get_extended_agent_card(
336304
request,
337305
context=context,
338-
extensions=extensions,
339-
signature_verifier=signature_verifier,
340306
)
307+
if signature_verifier:
308+
signature_verifier(card)
309+
341310
self._card = card
342311
return card
343312

0 commit comments

Comments
 (0)