11from collections .abc import AsyncGenerator , AsyncIterator , Callable
2- from typing import Any
32
43from a2a .client .client import (
54 Client ,
6- ClientCallContext ,
75 ClientConfig ,
86 ClientEvent ,
97 Consumer ,
108)
119from a2a .client .client_task_manager import ClientTaskManager
12- from a2a .client .middleware import ClientCallInterceptor
10+ from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
1311from a2a .client .transports .base import ClientTransport
1412from a2a .types .a2a_pb2 import (
1513 AgentCard ,
1614 CancelTaskRequest ,
17- CreateTaskPushNotificationConfigRequest ,
1815 DeleteTaskPushNotificationConfigRequest ,
1916 GetExtendedAgentCardRequest ,
2017 GetTaskPushNotificationConfigRequest ,
2320 ListTaskPushNotificationConfigsResponse ,
2421 ListTasksRequest ,
2522 ListTasksResponse ,
26- Message ,
27- SendMessageConfiguration ,
2823 SendMessageRequest ,
2924 StreamResponse ,
3025 SubscribeToTaskRequest ,
@@ -51,12 +46,9 @@ def __init__(
5146
5247 async def send_message (
5348 self ,
54- request : Message ,
49+ request : SendMessageRequest ,
5550 * ,
56- configuration : SendMessageConfiguration | None = None ,
5751 context : ClientCallContext | None = None ,
58- request_metadata : dict [str , Any ] | None = None ,
59- extensions : list [str ] | None = None ,
6052 ) -> AsyncIterator [ClientEvent ]:
6153 """Sends a message to the agent.
6254
@@ -66,35 +58,15 @@ async def send_message(
6658
6759 Args:
6860 request: The message to send to the agent.
69- configuration: Optional per-call overrides for message sending behavior.
70- context: The client call context.
71- request_metadata: Extensions Metadata attached to the request.
72- extensions: List of extensions to be activated.
61+ context: Optional client call context.
7362
7463 Yields:
7564 An async iterator of `ClientEvent`
7665 """
77- config = SendMessageConfiguration (
78- accepted_output_modes = self ._config .accepted_output_modes ,
79- blocking = not self ._config .polling ,
80- push_notification_config = (
81- self ._config .push_notification_configs [0 ]
82- if self ._config .push_notification_configs
83- else None
84- ),
85- )
86-
87- if configuration :
88- config .MergeFrom (configuration )
89- config .blocking = configuration .blocking
90-
91- send_message_request = SendMessageRequest (
92- message = request , configuration = config , metadata = request_metadata
93- )
94-
66+ self ._apply_client_config (request )
9567 if not self ._config .streaming or not self ._card .capabilities .streaming :
9668 response = await self ._transport .send_message (
97- send_message_request , context = context , extensions = extensions
69+ request , context = context
9870 )
9971
10072 # In non-streaming case we convert to a StreamResponse so that the
@@ -116,11 +88,29 @@ async def send_message(
11688 return
11789
11890 stream = self ._transport .send_message_streaming (
119- send_message_request , context = context , extensions = extensions
91+ request , context = context
12092 )
12193 async for client_event in self ._process_stream (stream ):
12294 yield client_event
12395
96+ def _apply_client_config (self , request : SendMessageRequest ) -> None :
97+ if not request .configuration .blocking and self ._config .polling :
98+ request .configuration .blocking = not self ._config .polling
99+ if (
100+ not request .configuration .HasField ('task_push_notification_config' )
101+ and self ._config .push_notification_configs
102+ ):
103+ request .configuration .task_push_notification_config .CopyFrom (
104+ self ._config .push_notification_configs [0 ]
105+ )
106+ if (
107+ not request .configuration .accepted_output_modes
108+ and self ._config .accepted_output_modes
109+ ):
110+ request .configuration .accepted_output_modes .extend (
111+ self ._config .accepted_output_modes
112+ )
113+
124114 async def _process_stream (
125115 self , stream : AsyncIterator [StreamResponse ]
126116 ) -> AsyncGenerator [ClientEvent ]:
@@ -147,21 +137,17 @@ async def get_task(
147137 request : GetTaskRequest ,
148138 * ,
149139 context : ClientCallContext | None = None ,
150- extensions : list [str ] | None = None ,
151140 ) -> Task :
152141 """Retrieves the current state and history of a specific task.
153142
154143 Args:
155144 request: The `GetTaskRequest` object specifying the task ID.
156- context: The client call context.
157- extensions: List of extensions to be activated.
145+ context: Optional client call context.
158146
159147 Returns:
160148 A `Task` object representing the current state of the task.
161149 """
162- return await self ._transport .get_task (
163- request , context = context , extensions = extensions
164- )
150+ return await self ._transport .get_task (request , context = context )
165151
166152 async def list_tasks (
167153 self ,
@@ -177,118 +163,104 @@ async def cancel_task(
177163 request : CancelTaskRequest ,
178164 * ,
179165 context : ClientCallContext | None = None ,
180- extensions : list [str ] | None = None ,
181166 ) -> Task :
182167 """Requests the agent to cancel a specific task.
183168
184169 Args:
185170 request: The `CancelTaskRequest` object specifying the task ID.
186- context: The client call context.
187- extensions: List of extensions to be activated.
171+ context: Optional client call context.
188172
189173 Returns:
190174 A `Task` object containing the updated task status.
191175 """
192- return await self ._transport .cancel_task (
193- request , context = context , extensions = extensions
194- )
176+ return await self ._transport .cancel_task (request , context = context )
195177
196178 async def create_task_push_notification_config (
197179 self ,
198- request : CreateTaskPushNotificationConfigRequest ,
180+ request : TaskPushNotificationConfig ,
199181 * ,
200182 context : ClientCallContext | None = None ,
201- extensions : list [str ] | None = None ,
202183 ) -> TaskPushNotificationConfig :
203184 """Sets or updates the push notification configuration for a specific task.
204185
205186 Args:
206187 request: The `TaskPushNotificationConfig` object with the new configuration.
207- context: The client call context.
208- extensions: List of extensions to be activated.
188+ context: Optional client call context.
209189
210190 Returns:
211191 The created or updated `TaskPushNotificationConfig` object.
212192 """
213193 return await self ._transport .create_task_push_notification_config (
214- request , context = context , extensions = extensions
194+ request , context = context
215195 )
216196
217197 async def get_task_push_notification_config (
218198 self ,
219199 request : GetTaskPushNotificationConfigRequest ,
220200 * ,
221201 context : ClientCallContext | None = None ,
222- extensions : list [str ] | None = None ,
223202 ) -> TaskPushNotificationConfig :
224203 """Retrieves the push notification configuration for a specific task.
225204
226205 Args:
227206 request: The `GetTaskPushNotificationConfigParams` object specifying the task.
228- context: The client call context.
229- extensions: List of extensions to be activated.
207+ context: Optional client call context.
230208
231209 Returns:
232210 A `TaskPushNotificationConfig` object containing the configuration.
233211 """
234212 return await self ._transport .get_task_push_notification_config (
235- request , context = context , extensions = extensions
213+ request , context = context
236214 )
237215
238216 async def list_task_push_notification_configs (
239217 self ,
240218 request : ListTaskPushNotificationConfigsRequest ,
241219 * ,
242220 context : ClientCallContext | None = None ,
243- extensions : list [str ] | None = None ,
244221 ) -> ListTaskPushNotificationConfigsResponse :
245222 """Lists push notification configurations for a specific task.
246223
247224 Args:
248225 request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
249- context: The client call context.
250- extensions: List of extensions to be activated.
226+ context: Optional client call context.
251227
252228 Returns:
253229 A `ListTaskPushNotificationConfigsResponse` object.
254230 """
255231 return await self ._transport .list_task_push_notification_configs (
256- request , context = context , extensions = extensions
232+ request , context = context
257233 )
258234
259235 async def delete_task_push_notification_config (
260236 self ,
261237 request : DeleteTaskPushNotificationConfigRequest ,
262238 * ,
263239 context : ClientCallContext | None = None ,
264- extensions : list [str ] | None = None ,
265240 ) -> None :
266241 """Deletes the push notification configuration for a specific task.
267242
268243 Args:
269244 request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
270- context: The client call context.
271- extensions: List of extensions to be activated.
245+ context: Optional client call context.
272246 """
273247 await self ._transport .delete_task_push_notification_config (
274- request , context = context , extensions = extensions
248+ request , context = context
275249 )
276250
277251 async def subscribe (
278252 self ,
279253 request : SubscribeToTaskRequest ,
280254 * ,
281255 context : ClientCallContext | None = None ,
282- extensions : list [str ] | None = None ,
283256 ) -> AsyncIterator [ClientEvent ]:
284257 """Resubscribes to a task's event stream.
285258
286259 This is only available if both the client and server support streaming.
287260
288261 Args:
289262 request: Parameters to identify the task to resubscribe to.
290- context: The client call context.
291- extensions: List of extensions to be activated.
263+ context: Optional client call context.
292264
293265 Yields:
294266 An async iterator of `ClientEvent` objects.
@@ -304,9 +276,7 @@ async def subscribe(
304276 # Note: resubscribe can only be called on an existing task. As such,
305277 # we should never see Message updates, despite the typing of the service
306278 # definition indicating it may be possible.
307- stream = self ._transport .subscribe (
308- request , context = context , extensions = extensions
309- )
279+ stream = self ._transport .subscribe (request , context = context )
310280 async for client_event in self ._process_stream (stream ):
311281 yield client_event
312282
@@ -315,7 +285,6 @@ async def get_extended_agent_card(
315285 request : GetExtendedAgentCardRequest ,
316286 * ,
317287 context : ClientCallContext | None = None ,
318- extensions : list [str ] | None = None ,
319288 signature_verifier : Callable [[AgentCard ], None ] | None = None ,
320289 ) -> AgentCard :
321290 """Retrieves the agent's card.
@@ -325,8 +294,7 @@ async def get_extended_agent_card(
325294
326295 Args:
327296 request: The `GetExtendedAgentCardRequest` object specifying the request.
328- context: The client call context.
329- extensions: List of extensions to be activated.
297+ context: Optional client call context.
330298 signature_verifier: A callable used to verify the agent card's signatures.
331299
332300 Returns:
@@ -335,9 +303,10 @@ async def get_extended_agent_card(
335303 card = await self ._transport .get_extended_agent_card (
336304 request ,
337305 context = context ,
338- extensions = extensions ,
339- signature_verifier = signature_verifier ,
340306 )
307+ if signature_verifier :
308+ signature_verifier (card )
309+
341310 self ._card = card
342311 return card
343312
0 commit comments