Skip to content

Commit 4c23416

Browse files
committed
refactor: use ClientCallContext for HTTP arguments in stream requests, convert rpc_request.data to dict, and adjust client blocking configuration logic.
1 parent eae38e9 commit 4c23416

3 files changed

Lines changed: 36 additions & 20 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ async def send_message(
9494
async for client_event in self._process_stream(stream):
9595
yield client_event
9696

97-
def _apply_client_config(self, request: SendMessageRequest):
97+
def _apply_client_config(self, request: SendMessageRequest) -> None:
9898
if not request.configuration.blocking and self._config.polling:
99-
request.configuration.blocking = self._config.polling
99+
request.configuration.blocking = not self._config.polling
100100
if (
101101
not request.configuration.HasField('push_notification_config')
102102
and self._config.push_notification_configs

src/a2a/client/transports/jsonrpc.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ async def send_message(
7777
params=json_format.MessageToDict(request),
7878
_id=str(uuid4()),
7979
)
80-
response_data = await self._send_request(rpc_request.data, context)
80+
response_data = await self._send_request(
81+
dict(rpc_request.data), context
82+
)
8183
json_rpc_response = JSONRPC20Response(**response_data)
8284
if json_rpc_response.error:
8385
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -99,7 +101,7 @@ async def send_message_streaming(
99101
_id=str(uuid4()),
100102
)
101103
async for event in self._send_stream_request(
102-
rpc_request.data,
104+
dict(rpc_request.data),
103105
context,
104106
):
105107
yield event
@@ -116,7 +118,9 @@ async def get_task(
116118
params=json_format.MessageToDict(request),
117119
_id=str(uuid4()),
118120
)
119-
response_data = await self._send_request(rpc_request.data, context)
121+
response_data = await self._send_request(
122+
dict(rpc_request.data), context
123+
)
120124
json_rpc_response = JSONRPC20Response(**response_data)
121125
if json_rpc_response.error:
122126
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -135,7 +139,9 @@ async def list_tasks(
135139
params=json_format.MessageToDict(request),
136140
_id=str(uuid4()),
137141
)
138-
response_data = await self._send_request(rpc_request.data, context)
142+
response_data = await self._send_request(
143+
dict(rpc_request.data), context
144+
)
139145
json_rpc_response = JSONRPC20Response(**response_data)
140146
if json_rpc_response.error:
141147
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -156,7 +162,9 @@ async def cancel_task(
156162
params=json_format.MessageToDict(request),
157163
_id=str(uuid4()),
158164
)
159-
response_data = await self._send_request(rpc_request.data, context)
165+
response_data = await self._send_request(
166+
dict(rpc_request.data), context
167+
)
160168
json_rpc_response = JSONRPC20Response(**response_data)
161169
if json_rpc_response.error:
162170
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -175,7 +183,9 @@ async def create_task_push_notification_config(
175183
params=json_format.MessageToDict(request),
176184
_id=str(uuid4()),
177185
)
178-
response_data = await self._send_request(rpc_request.data, context)
186+
response_data = await self._send_request(
187+
dict(rpc_request.data), context
188+
)
179189
json_rpc_response = JSONRPC20Response(**response_data)
180190
if json_rpc_response.error:
181191
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -196,7 +206,9 @@ async def get_task_push_notification_config(
196206
params=json_format.MessageToDict(request),
197207
_id=str(uuid4()),
198208
)
199-
response_data = await self._send_request(rpc_request.data, context)
209+
response_data = await self._send_request(
210+
dict(rpc_request.data), context
211+
)
200212
json_rpc_response = JSONRPC20Response(**response_data)
201213
if json_rpc_response.error:
202214
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -217,7 +229,9 @@ async def list_task_push_notification_configs(
217229
params=json_format.MessageToDict(request),
218230
_id=str(uuid4()),
219231
)
220-
response_data = await self._send_request(rpc_request.data, context)
232+
response_data = await self._send_request(
233+
dict(rpc_request.data), context
234+
)
221235
json_rpc_response = JSONRPC20Response(**response_data)
222236
if json_rpc_response.error:
223237
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -241,7 +255,9 @@ async def delete_task_push_notification_config(
241255
params=json_format.MessageToDict(request),
242256
_id=str(uuid4()),
243257
)
244-
response_data = await self._send_request(rpc_request.data, context)
258+
response_data = await self._send_request(
259+
dict(rpc_request.data), context
260+
)
245261
json_rpc_response = JSONRPC20Response(**response_data)
246262
if json_rpc_response.error:
247263
raise self._create_jsonrpc_error(json_rpc_response.error)
@@ -259,7 +275,7 @@ async def subscribe(
259275
_id=str(uuid4()),
260276
)
261277
async for event in self._send_stream_request(
262-
rpc_request.data,
278+
dict(rpc_request.data),
263279
context,
264280
):
265281
yield event
@@ -283,7 +299,7 @@ async def get_extended_agent_card(
283299
_id=str(uuid4()),
284300
)
285301
response_data = await self._send_request(
286-
rpc_request.data,
302+
dict(rpc_request.data),
287303
context,
288304
)
289305
json_rpc_response = JSONRPC20Response(**response_data)

src/a2a/client/transports/rest.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,12 @@ async def _send_stream_request(
337337
method: str,
338338
target: str,
339339
tenant: str,
340-
http_kwargs: dict[str, Any] | None = None,
340+
context: ClientCallContext | None = None,
341341
**kwargs: Any,
342342
) -> AsyncGenerator[StreamResponse]:
343-
final_kwargs = dict(http_kwargs or {})
344-
final_kwargs.update(kwargs)
345-
headers = dict(self.httpx_client.headers.items())
346-
headers.update(final_kwargs.get('headers', {}))
347-
final_kwargs['headers'] = headers
343+
http_kwargs = self._get_http_args(context)
344+
headers = http_kwargs.get('headers')
345+
timeout = http_kwargs.get('timeout', httpx.USE_CLIENT_DEFAULT)
348346

349347
path = self._get_path(target, tenant)
350348

@@ -353,7 +351,9 @@ async def _send_stream_request(
353351
method,
354352
f'{self.url}{path}',
355353
self._handle_http_error,
356-
**final_kwargs,
354+
headers=headers,
355+
timeout=timeout,
356+
**kwargs,
357357
):
358358
event: StreamResponse = Parse(sse_data, StreamResponse())
359359
yield event

0 commit comments

Comments
 (0)