Skip to content

Commit 51c84f7

Browse files
committed
Cosmetics
1 parent ec2248a commit 51c84f7

5 files changed

Lines changed: 184 additions & 181 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
5656
details = e.details()
5757
if isinstance(details, str) and ': ' in details:
5858
error_type_name, error_message = details.split(': ', 1)
59-
# TODO(#723): Resolving imports by name is a temporary hack until proper error handling structure is added in #723.
59+
# TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723.
6060
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type_name)
6161
if exception_cls:
6262
raise exception_cls(error_message) from e

src/a2a/client/transports/rest.py

Lines changed: 171 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -63,64 +63,6 @@ def __init__(
6363
self._needs_extended_card = agent_card.capabilities.extended_agent_card
6464
self.extensions = extensions
6565

66-
async def _apply_interceptors(
67-
self,
68-
request_payload: dict[str, Any],
69-
http_kwargs: dict[str, Any] | None,
70-
context: ClientCallContext | None,
71-
) -> tuple[dict[str, Any], dict[str, Any]]:
72-
final_http_kwargs = http_kwargs or {}
73-
final_request_payload = request_payload
74-
# TODO: Implement interceptors for other transports
75-
return final_request_payload, final_http_kwargs
76-
77-
def _get_http_args(
78-
self, context: ClientCallContext | None
79-
) -> dict[str, Any] | None:
80-
return context.state.get('http_kwargs') if context else None
81-
82-
async def _prepare_send_message(
83-
self,
84-
request: SendMessageRequest,
85-
context: ClientCallContext | None,
86-
extensions: list[str] | None = None,
87-
) -> tuple[dict[str, Any], dict[str, Any]]:
88-
payload = MessageToDict(request)
89-
modified_kwargs = update_extension_header(
90-
self._get_http_args(context),
91-
extensions if extensions is not None else self.extensions,
92-
)
93-
payload, modified_kwargs = await self._apply_interceptors(
94-
payload,
95-
modified_kwargs,
96-
context,
97-
)
98-
return payload, modified_kwargs
99-
100-
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
101-
"""Handles HTTP status errors and raises the appropriate A2AError."""
102-
try:
103-
error_data = e.response.json()
104-
error_type = error_data.get('type')
105-
message = error_data.get('message', str(e))
106-
107-
if isinstance(error_type, str):
108-
# TODO(#723): Resolving imports by name is a temporary hack until proper error handling structure is added in #723.
109-
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type)
110-
if exception_cls:
111-
raise exception_cls(message) from e
112-
except (json.JSONDecodeError, ValueError):
113-
pass
114-
115-
# Fallback mappings for status codes if 'type' is missing or unknown
116-
status_code = e.response.status_code
117-
if status_code == httpx.codes.NOT_FOUND:
118-
raise MethodNotFoundError(
119-
f'Resource not found: {e.request.url}'
120-
) from e
121-
122-
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
123-
12466
async def send_message(
12567
self,
12668
request: SendMessageRequest,
@@ -152,94 +94,13 @@ async def send_message_streaming(
15294
request, context, extensions
15395
)
15496

155-
modified_kwargs.setdefault('timeout', None)
156-
157-
try:
158-
async with aconnect_sse(
159-
self.httpx_client,
160-
'POST',
161-
f'{self.url}/v1/message:stream',
162-
json=payload,
163-
**modified_kwargs,
164-
) as event_source:
165-
try:
166-
event_source.response.raise_for_status()
167-
async for sse in event_source.aiter_sse():
168-
event: StreamResponse = Parse(
169-
sse.data, StreamResponse()
170-
)
171-
yield event
172-
except httpx.HTTPStatusError as e:
173-
self._handle_http_error(e)
174-
except SSEError as e:
175-
raise A2AClientError(
176-
f'Invalid SSE response or protocol error: {e}'
177-
) from e
178-
except httpx.TimeoutException as e:
179-
raise A2AClientError('Client Request timed out') from e
180-
except httpx.RequestError as e:
181-
raise A2AClientError(f'Network communication error: {e}') from e
182-
except json.JSONDecodeError as e:
183-
raise A2AClientError(f'JSON Decode Error: {e}') from e
184-
185-
async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
186-
try:
187-
response = await self.httpx_client.send(request)
188-
response.raise_for_status()
189-
return response.json()
190-
except httpx.TimeoutException as e:
191-
raise A2AClientError('Client Request timed out') from e
192-
except httpx.HTTPStatusError as e:
193-
self._handle_http_error(e)
194-
except json.JSONDecodeError as e:
195-
raise A2AClientError(f'JSON Decode Error: {e}') from e
196-
except httpx.RequestError as e:
197-
raise A2AClientError(f'Network communication error: {e}') from e
198-
199-
async def _send_post_request(
200-
self,
201-
target: str,
202-
rpc_request_payload: dict[str, Any],
203-
http_kwargs: dict[str, Any] | None = None,
204-
) -> dict[str, Any]:
205-
return await self._send_request(
206-
self.httpx_client.build_request(
207-
'POST',
208-
f'{self.url}{target}',
209-
json=rpc_request_payload,
210-
**(http_kwargs or {}),
211-
)
212-
)
213-
214-
async def _send_get_request(
215-
self,
216-
target: str,
217-
query_params: dict[str, str],
218-
http_kwargs: dict[str, Any] | None = None,
219-
) -> dict[str, Any]:
220-
return await self._send_request(
221-
self.httpx_client.build_request(
222-
'GET',
223-
f'{self.url}{target}',
224-
params=query_params,
225-
**(http_kwargs or {}),
226-
)
227-
)
228-
229-
async def _send_delete_request(
230-
self,
231-
target: str,
232-
query_params: dict[str, Any],
233-
http_kwargs: dict[str, Any] | None = None,
234-
) -> dict[str, Any]:
235-
return await self._send_request(
236-
self.httpx_client.build_request(
237-
'DELETE',
238-
f'{self.url}{target}',
239-
params=query_params,
240-
**(http_kwargs or {}),
241-
)
242-
)
97+
async for event in self._send_stream_request(
98+
'POST',
99+
'/v1/message:stream',
100+
http_kwargs=modified_kwargs,
101+
json=payload,
102+
):
103+
yield event
243104

244105
async def get_task(
245106
self,
@@ -309,7 +170,7 @@ async def cancel_task(
309170
payload = MessageToDict(request)
310171
modified_kwargs = update_extension_header(
311172
self._get_http_args(context),
312-
extensions if extensions not in (None, []) else self.extensions,
173+
extensions if extensions is not None else self.extensions,
313174
)
314175
payload, modified_kwargs = await self._apply_interceptors(
315176
payload,
@@ -450,35 +311,13 @@ async def subscribe(
450311
self._get_http_args(context),
451312
extensions if extensions is not None else self.extensions,
452313
)
453-
modified_kwargs.setdefault('timeout', None)
454314

455-
try:
456-
async with aconnect_sse(
457-
self.httpx_client,
458-
'GET',
459-
f'{self.url}/v1/tasks/{request.id}:subscribe',
460-
**modified_kwargs,
461-
) as event_source:
462-
try:
463-
async for sse in event_source.aiter_sse():
464-
if not sse.data:
465-
continue
466-
event: StreamResponse = Parse(
467-
sse.data, StreamResponse()
468-
)
469-
yield event
470-
except httpx.HTTPStatusError as e:
471-
self._handle_http_error(e)
472-
except SSEError as e:
473-
raise A2AClientError(
474-
f'Invalid SSE response or protocol error: {e}'
475-
) from e
476-
except httpx.TimeoutException as e:
477-
raise A2AClientError('Client Request timed out') from e
478-
except httpx.RequestError as e:
479-
raise A2AClientError(f'Network communication error: {e}') from e
480-
except json.JSONDecodeError as e:
481-
raise A2AClientError(f'JSON Decode Error: {e}') from e
315+
async for event in self._send_stream_request(
316+
'GET',
317+
f'/v1/tasks/{request.id}:subscribe',
318+
http_kwargs=modified_kwargs,
319+
):
320+
yield event
482321

483322
async def get_extended_agent_card(
484323
self,
@@ -519,6 +358,163 @@ async def close(self) -> None:
519358
"""Closes the httpx client."""
520359
await self.httpx_client.aclose()
521360

361+
async def _apply_interceptors(
362+
self,
363+
request_payload: dict[str, Any],
364+
http_kwargs: dict[str, Any] | None,
365+
context: ClientCallContext | None,
366+
) -> tuple[dict[str, Any], dict[str, Any]]:
367+
final_http_kwargs = http_kwargs or {}
368+
final_request_payload = request_payload
369+
# TODO: Implement interceptors for other transports
370+
return final_request_payload, final_http_kwargs
371+
372+
def _get_http_args(
373+
self, context: ClientCallContext | None
374+
) -> dict[str, Any] | None:
375+
return context.state.get('http_kwargs') if context else None
376+
377+
async def _prepare_send_message(
378+
self,
379+
request: SendMessageRequest,
380+
context: ClientCallContext | None,
381+
extensions: list[str] | None = None,
382+
) -> tuple[dict[str, Any], dict[str, Any]]:
383+
payload = MessageToDict(request)
384+
modified_kwargs = update_extension_header(
385+
self._get_http_args(context),
386+
extensions if extensions is not None else self.extensions,
387+
)
388+
payload, modified_kwargs = await self._apply_interceptors(
389+
payload,
390+
modified_kwargs,
391+
context,
392+
)
393+
return payload, modified_kwargs
394+
395+
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
396+
"""Handles HTTP status errors and raises the appropriate A2AError."""
397+
try:
398+
error_data = e.response.json()
399+
error_type = error_data.get('type')
400+
message = error_data.get('message', str(e))
401+
402+
if isinstance(error_type, str):
403+
# TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723.
404+
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type)
405+
if exception_cls:
406+
raise exception_cls(message) from e
407+
except (json.JSONDecodeError, ValueError):
408+
pass
409+
410+
# Fallback mappings for status codes if 'type' is missing or unknown
411+
status_code = e.response.status_code
412+
if status_code == httpx.codes.NOT_FOUND:
413+
raise MethodNotFoundError(
414+
f'Resource not found: {e.request.url}'
415+
) from e
416+
417+
raise A2AClientError(f'HTTP Error {status_code}: {e}') from e
418+
419+
async def _send_stream_request(
420+
self,
421+
method: str,
422+
target: str,
423+
http_kwargs: dict[str, Any] | None = None,
424+
**kwargs: Any,
425+
) -> AsyncGenerator[StreamResponse]:
426+
final_kwargs = dict(http_kwargs or {})
427+
final_kwargs.update(kwargs)
428+
final_kwargs.setdefault('timeout', None)
429+
430+
try:
431+
async with aconnect_sse(
432+
self.httpx_client,
433+
method,
434+
f'{self.url}{target}',
435+
**final_kwargs,
436+
) as event_source:
437+
try:
438+
event_source.response.raise_for_status()
439+
async for sse in event_source.aiter_sse():
440+
if not sse.data:
441+
continue
442+
event: StreamResponse = Parse(
443+
sse.data, StreamResponse()
444+
)
445+
yield event
446+
except httpx.HTTPStatusError as e:
447+
self._handle_http_error(e)
448+
except SSEError as e:
449+
raise A2AClientError(
450+
f'Invalid SSE response or protocol error: {e}'
451+
) from e
452+
except httpx.TimeoutException as e:
453+
raise A2AClientError('Client Request timed out') from e
454+
except httpx.RequestError as e:
455+
raise A2AClientError(f'Network communication error: {e}') from e
456+
except json.JSONDecodeError as e:
457+
raise A2AClientError(f'JSON Decode Error: {e}') from e
458+
459+
async def _send_request(self, request: httpx.Request) -> dict[str, Any]:
460+
try:
461+
response = await self.httpx_client.send(request)
462+
response.raise_for_status()
463+
return response.json()
464+
except httpx.TimeoutException as e:
465+
raise A2AClientError('Client Request timed out') from e
466+
except httpx.HTTPStatusError as e:
467+
self._handle_http_error(e)
468+
except json.JSONDecodeError as e:
469+
raise A2AClientError(f'JSON Decode Error: {e}') from e
470+
except httpx.RequestError as e:
471+
raise A2AClientError(f'Network communication error: {e}') from e
472+
473+
async def _send_post_request(
474+
self,
475+
target: str,
476+
rpc_request_payload: dict[str, Any],
477+
http_kwargs: dict[str, Any] | None = None,
478+
) -> dict[str, Any]:
479+
return await self._send_request(
480+
self.httpx_client.build_request(
481+
'POST',
482+
f'{self.url}{target}',
483+
json=rpc_request_payload,
484+
**(http_kwargs or {}),
485+
)
486+
)
487+
488+
async def _send_get_request(
489+
self,
490+
target: str,
491+
query_params: dict[str, str],
492+
http_kwargs: dict[str, Any] | None = None,
493+
) -> dict[str, Any]:
494+
return await self._send_request(
495+
self.httpx_client.build_request(
496+
'GET',
497+
f'{self.url}{target}',
498+
params=query_params,
499+
**(http_kwargs or {}),
500+
)
501+
)
502+
503+
async def _send_delete_request(
504+
self,
505+
target: str,
506+
query_params: dict[str, Any],
507+
http_kwargs: dict[str, Any] | None = None,
508+
) -> dict[str, Any]:
509+
return await self._send_request(
510+
self.httpx_client.build_request(
511+
'DELETE',
512+
f'{self.url}{target}',
513+
params=query_params,
514+
**(http_kwargs or {}),
515+
)
516+
)
517+
522518

523519
def _model_to_query_params(instance: Message) -> dict[str, str]:
524520
data = MessageToDict(instance, preserving_proto_field_name=True)

tests/client/test_errors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,3 @@ def test_raising_base_error(self) -> None:
2323
raise A2AClientError('Generic client error')
2424

2525
assert str(excinfo.value) == 'Generic client error'
26-

0 commit comments

Comments
 (0)