Skip to content

Commit e57c8f7

Browse files
committed
Merge branch 'restful' into restful-2
2 parents 12e595d + a063a8e commit e57c8f7

9 files changed

Lines changed: 496 additions & 272 deletions

File tree

Gemini.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
**A2A specification:** https://a2a-protocol.org/latest/specification/
2+
3+
## Project frameworks
4+
- uv as package manager
5+
6+
## How to run all tests
7+
1. If dependencies are not installed install them using following command
8+
```
9+
uv sync --all-extras
10+
```
11+
12+
2. Run tests
13+
```
14+
uv run pytest
15+
```
16+
17+
## Other instructions
18+
1. Whenever writing python code, write types as well.
19+
2. After making the changes run ruff to check and fix the formatting issues
20+
```
21+
uv run ruff check --fix
22+
```
23+
3. Run mypy type checkers to check for type errors
24+
```
25+
uv run mypy
26+
```
27+
4. Run the unit tests to make sure that none of the unit tests are broken.

error_handlers.py

Whitespace-only changes.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }]
88
requires-python = ">=3.10"
99
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent", "Agent 2 Agent"]
1010
dependencies = [
11-
"fastapi>=0.115.2",
11+
"fastapi>=0.116.1",
1212
"httpx>=0.28.1",
1313
"httpx-sse>=0.4.0",
1414
"opentelemetry-api>=1.33.0",
@@ -93,6 +93,7 @@ dev = [
9393
"pyupgrade",
9494
"autoflake",
9595
"no_implicit_optional",
96+
"trio",
9697
]
9798

9899
[[tool.uv.index]]

src/a2a/server/apps/rest/rest_app.py

Lines changed: 49 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
import functools
2-
import json
32
import logging
4-
import traceback
53

6-
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
4+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
75
from typing import Any
86

9-
from pydantic import ValidationError
107
from sse_starlette.sse import EventSourceResponse
118
from starlette.requests import Request
12-
from starlette.responses import JSONResponse
9+
from starlette.responses import JSONResponse, Response
1310

1411
from a2a.server.apps.jsonrpc import (
1512
CallContextBuilder,
@@ -19,14 +16,14 @@
1916
from a2a.server.request_handlers.request_handler import RequestHandler
2017
from a2a.server.request_handlers.rest_handler import RESTHandler
2118
from a2a.types import (
22-
A2AError,
2319
AgentCard,
24-
InternalError,
25-
InvalidRequestError,
26-
JSONParseError,
27-
UnsupportedOperationError,
20+
AuthenticatedExtendedCardNotConfiguredError,
2821
)
29-
from a2a.utils.errors import MethodNotImplementedError
22+
from a2a.utils.error_handlers import (
23+
rest_error_handler,
24+
rest_stream_error_handler,
25+
)
26+
from a2a.utils.errors import ServerError
3027

3128

3229
logger = logging.getLogger(__name__)
@@ -61,88 +58,51 @@ def __init__(
6158
)
6259
self._context_builder = context_builder or DefaultCallContextBuilder()
6360

64-
def _generate_error_response(self, error: A2AError) -> JSONResponse:
65-
"""Creates a JSONResponse for an error.
66-
67-
Logs the error based on its type.
68-
69-
Args:
70-
error: The Error object.
71-
72-
Returns:
73-
A `JSONResponse` object formatted as a JSON error response.
74-
"""
75-
log_level = (
76-
logging.ERROR
77-
if isinstance(error, InternalError)
78-
else logging.WARNING
79-
)
80-
logger.log(
81-
log_level,
82-
'Request Error: '
83-
f"Code={error.root.code}, Message='{error.root.message}'"
84-
f'{", Data=" + str(error.root.data) if error.root.data else ""}',
85-
)
86-
return JSONResponse(
87-
f'{{"message": "{error.root.message}"}}',
88-
status_code=500,
89-
)
90-
91-
def _handle_error(self, error: Exception) -> JSONResponse:
92-
traceback.print_exc()
93-
if isinstance(error, MethodNotImplementedError):
94-
return self._generate_error_response(
95-
A2AError(UnsupportedOperationError(message=error.message))
96-
)
97-
if isinstance(error, json.decoder.JSONDecodeError):
98-
return self._generate_error_response(
99-
A2AError(JSONParseError(message=str(error)))
100-
)
101-
if isinstance(error, ValidationError):
102-
return self._generate_error_response(
103-
A2AError(InvalidRequestError(data=json.loads(error.json()))),
104-
)
105-
logger.error(f'Unhandled exception: {error}')
106-
return self._generate_error_response(
107-
A2AError(InternalError(message=str(error)))
108-
)
109-
61+
@rest_error_handler
11062
async def _handle_request(
11163
self,
112-
method: Callable[[Request, ServerCallContext], Awaitable[str]],
64+
method: Callable[
65+
[Request, ServerCallContext], Awaitable[dict[str, Any]]
66+
],
11367
request: Request,
114-
) -> JSONResponse:
115-
try:
116-
call_context = self._context_builder.build(request)
117-
response = await method(request, call_context)
118-
return JSONResponse(content=response)
119-
except Exception as e:
120-
return self._handle_error(e)
68+
) -> Response:
69+
call_context = self._context_builder.build(request)
70+
response = await method(request, call_context)
71+
return JSONResponse(content=response)
12172

73+
@rest_error_handler
74+
async def _handle_list_request(
75+
self,
76+
method: Callable[
77+
[Request, ServerCallContext], Awaitable[list[dict[str, Any]]]
78+
],
79+
request: Request,
80+
) -> Response:
81+
call_context = self._context_builder.build(request)
82+
response = await method(request, call_context)
83+
return JSONResponse(content=response)
84+
85+
@rest_stream_error_handler
12286
async def _handle_streaming_request(
12387
self,
124-
method: Callable[[Request, ServerCallContext], AsyncIterator[str]],
88+
method: Callable[
89+
[Request, ServerCallContext], AsyncIterable[dict[str, Any]]
90+
],
12591
request: Request,
12692
) -> EventSourceResponse:
127-
try:
128-
call_context = self._context_builder.build(request)
93+
call_context = self._context_builder.build(request)
12994

130-
async def event_generator(
131-
stream: AsyncGenerator[str],
132-
) -> AsyncGenerator[dict[str, str]]:
133-
async for item in stream:
134-
yield {'data': item}
95+
async def event_generator(
96+
stream: AsyncIterable[dict[str, Any]],
97+
) -> AsyncIterator[dict[str, dict[str, Any]]]:
98+
async for item in stream:
99+
yield {'data': item}
135100

136-
return EventSourceResponse(
137-
event_generator(method(request, call_context))
138-
)
139-
except Exception as e:
140-
# Since the stream has started, we can't return a JSONResponse.
141-
# Instead, we runt the error handling logic (provides logging)
142-
# and reraise the error and let server framework manage
143-
self._handle_error(e)
144-
raise e
101+
return EventSourceResponse(
102+
event_generator(method(request, call_context))
103+
)
145104

105+
@rest_error_handler
146106
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
147107
"""Handles GET requests for the agent card endpoint.
148108
@@ -158,6 +118,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
158118
self.agent_card.model_dump(mode='json', exclude_none=True)
159119
)
160120

121+
@rest_error_handler
161122
async def handle_authenticated_agent_card(
162123
self, request: Request
163124
) -> JSONResponse:
@@ -173,9 +134,10 @@ async def handle_authenticated_agent_card(
173134
A JSONResponse containing the authenticated card.
174135
"""
175136
if not self.agent_card.supports_authenticated_extended_card:
176-
return JSONResponse(
177-
'{"detail": "Authenticated card not supported"}',
178-
status_code=404,
137+
raise ServerError(
138+
error=AuthenticatedExtendedCardNotConfiguredError(
139+
message='Authenticated card not supported'
140+
)
179141
)
180142
return JSONResponse(
181143
self.agent_card.model_dump(mode='json', exclude_none=True)
@@ -226,10 +188,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
226188
'/v1/tasks/{id}/pushNotificationConfigs',
227189
'GET',
228190
): functools.partial(
229-
self._handle_request, self.handler.list_push_notifications
191+
self._handle_list_request, self.handler.list_push_notifications
230192
),
231193
('/v1/tasks', 'GET'): functools.partial(
232-
self._handle_request, self.handler.list_tasks
194+
self._handle_list_request, self.handler.list_tasks
233195
),
234196
}
235197
if self.agent_card.supports_authenticated_extended_card:

0 commit comments

Comments
 (0)