Skip to content

Commit b0f2033

Browse files
committed
Refactor transport request methods to use explicit json and params keyword arguments and streamline http_kwargs passing.
1 parent 4c23416 commit b0f2033

3 files changed

Lines changed: 26 additions & 45 deletions

File tree

src/a2a/client/transports/jsonrpc.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ async def _send_request(
349349
context: ClientCallContext | None = None,
350350
) -> dict[str, Any]:
351351
http_kwargs = self._get_http_args(context)
352+
352353
request = self.httpx_client.build_request(
353354
'POST', self.url, json=payload, **(http_kwargs or {})
354355
)
@@ -358,22 +359,16 @@ async def _send_stream_request(
358359
self,
359360
rpc_request_payload: dict[str, Any],
360361
context: ClientCallContext | None = None,
361-
**kwargs: Any,
362362
) -> AsyncGenerator[StreamResponse]:
363363
http_kwargs = self._get_http_args(context)
364-
final_kwargs = dict(http_kwargs or {})
365-
final_kwargs.update(kwargs)
366-
headers = dict(self.httpx_client.headers.items())
367-
headers.update(final_kwargs.get('headers', {}))
368-
final_kwargs['headers'] = headers
369364

370365
async for sse_data in send_http_stream_request(
371366
self.httpx_client,
372367
'POST',
373368
self.url,
374369
None,
375370
json=rpc_request_payload,
376-
**final_kwargs,
371+
**http_kwargs,
377372
):
378373
json_rpc_response = JSONRPC20Response.from_json(sse_data)
379374
if json_rpc_response.error:

src/a2a/client/transports/rest.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ async def send_message(
7575
'POST',
7676
'/message:send',
7777
request.tenant,
78-
MessageToDict(request),
7978
context=context,
79+
json=MessageToDict(request),
8080
)
8181
response: SendMessageResponse = ParseDict(
8282
response_data, SendMessageResponse()
@@ -90,14 +90,13 @@ async def send_message_streaming(
9090
context: ClientCallContext | None = None,
9191
) -> AsyncGenerator[StreamResponse]:
9292
"""Sends a streaming message request to the agent and yields responses as they arrive."""
93-
http_kwargs = self._get_http_args(context)
9493
payload = MessageToDict(request)
9594

9695
async for event in self._send_stream_request(
9796
'POST',
9897
'/message:stream',
9998
request.tenant,
100-
http_kwargs=http_kwargs,
99+
context=context,
101100
json=payload,
102101
):
103102
yield event
@@ -117,8 +116,8 @@ async def get_task(
117116
'GET',
118117
f'/tasks/{request.id}',
119118
request.tenant,
120-
params,
121119
context=context,
120+
params=params,
122121
)
123122
response: Task = ParseDict(response_data, Task())
124123
return response
@@ -134,8 +133,8 @@ async def list_tasks(
134133
'GET',
135134
'/tasks',
136135
request.tenant,
137-
_model_to_query_params(request),
138136
context=context,
137+
params=MessageToDict(request),
139138
)
140139
response: ListTasksResponse = ParseDict(
141140
response_data, ListTasksResponse()
@@ -153,8 +152,8 @@ async def cancel_task(
153152
'POST',
154153
f'/tasks/{request.id}:cancel',
155154
request.tenant,
156-
MessageToDict(request),
157155
context=context,
156+
json=MessageToDict(request),
158157
)
159158
response: Task = ParseDict(response_data, Task())
160159
return response
@@ -170,8 +169,8 @@ async def create_task_push_notification_config(
170169
'POST',
171170
f'/tasks/{request.task_id}/pushNotificationConfigs',
172171
request.tenant,
173-
MessageToDict(request),
174172
context=context,
173+
json=MessageToDict(request),
175174
)
176175
response: TaskPushNotificationConfig = ParseDict(
177176
response_data, TaskPushNotificationConfig()
@@ -195,8 +194,8 @@ async def get_task_push_notification_config(
195194
'GET',
196195
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
197196
request.tenant,
198-
params,
199197
context=context,
198+
params=params,
200199
)
201200
response: TaskPushNotificationConfig = ParseDict(
202201
response_data, TaskPushNotificationConfig()
@@ -218,8 +217,8 @@ async def list_task_push_notification_configs(
218217
'GET',
219218
f'/tasks/{request.task_id}/pushNotificationConfigs',
220219
request.tenant,
221-
params,
222220
context=context,
221+
params=params,
223222
)
224223
response: ListTaskPushNotificationConfigsResponse = ParseDict(
225224
response_data, ListTaskPushNotificationConfigsResponse()
@@ -243,8 +242,8 @@ async def delete_task_push_notification_config(
243242
'DELETE',
244243
f'/tasks/{request.task_id}/pushNotificationConfigs/{request.id}',
245244
request.tenant,
246-
params,
247245
context=context,
246+
params=params,
248247
)
249248

250249
async def subscribe(
@@ -254,13 +253,11 @@ async def subscribe(
254253
context: ClientCallContext | None = None,
255254
) -> AsyncGenerator[StreamResponse]:
256255
"""Reconnects to get task updates."""
257-
http_kwargs = self._get_http_args(context)
258-
259256
async for event in self._send_stream_request(
260257
'GET',
261258
f'/tasks/{request.id}:subscribe',
262259
request.tenant,
263-
http_kwargs=http_kwargs,
260+
context=context,
264261
):
265262
yield event
266263

@@ -278,7 +275,7 @@ async def get_extended_agent_card(
278275
return card
279276

280277
response_data = await self._execute_request(
281-
'GET', '/extendedAgentCard', request.tenant, {}, context
278+
'GET', '/extendedAgentCard', request.tenant, context=context
282279
)
283280
response: AgentCard = ParseDict(response_data, AgentCard())
284281

@@ -338,22 +335,19 @@ async def _send_stream_request(
338335
target: str,
339336
tenant: str,
340337
context: ClientCallContext | None = None,
341-
**kwargs: Any,
338+
*,
339+
json: dict[str, Any] | None = None,
342340
) -> AsyncGenerator[StreamResponse]:
343-
http_kwargs = self._get_http_args(context)
344-
headers = http_kwargs.get('headers')
345-
timeout = http_kwargs.get('timeout', httpx.USE_CLIENT_DEFAULT)
346-
347341
path = self._get_path(target, tenant)
342+
http_kwargs = self._get_http_args(context)
348343

349344
async for sse_data in send_http_stream_request(
350345
self.httpx_client,
351346
method,
352347
f'{self.url}{path}',
353348
self._handle_http_error,
354-
headers=headers,
355-
timeout=timeout,
356-
**kwargs,
349+
json=json,
350+
**http_kwargs,
357351
):
358352
event: StreamResponse = Parse(sse_data, StreamResponse())
359353
yield event
@@ -368,26 +362,20 @@ async def _execute_request(
368362
method: str,
369363
target: str,
370364
tenant: str,
371-
payload: dict[str, Any] | None = None,
372365
context: ClientCallContext | None = None,
366+
*,
367+
json: dict[str, Any] | None = None,
368+
params: dict[str, Any] | None = None,
373369
) -> dict[str, Any]:
374370
path = self._get_path(target, tenant)
375371
http_kwargs = self._get_http_args(context)
376-
payload = payload or {}
377-
378-
headers = http_kwargs.get('headers')
379-
timeout = http_kwargs.get('timeout', httpx.USE_CLIENT_DEFAULT)
380-
381-
json_payload = payload if method == 'POST' else None
382-
params = payload if method != 'POST' else None
383372

384373
request = self.httpx_client.build_request(
385374
method,
386375
f'{self.url}{path}',
387-
json=json_payload,
376+
json=json,
388377
params=params,
389-
headers=headers, # type: ignore[arg-type]
390-
timeout=timeout, # type: ignore[arg-type]
378+
**http_kwargs,
391379
)
392380
return await self._send_request(request)
393381

tests/client/transports/test_rest_client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,11 @@ async def test_get_card_with_extended_card_support_with_extensions(
363363
await client.get_extended_agent_card(request, context=context)
364364

365365
mock_execute_request.assert_called_once()
366-
# _execute_request(method, target, tenant, payload, context)
366+
# _execute_request(method, target, tenant, context)
367367
call_args = mock_execute_request.call_args
368-
assert (
369-
call_args[1].get('context') == context or call_args[0][4] == context
370-
)
368+
assert call_args[1].get('context') == context or call_args[0][3] == context
371369

372-
_context = call_args[1].get('context') or call_args[0][4]
370+
_context = call_args[1].get('context') or call_args[0][3]
373371
assert _context.service_parameters == {
374372
HTTP_EXTENSION_HEADER: extensions_str
375373
}

0 commit comments

Comments
 (0)