11import functools
2- import json
32import logging
4- import traceback
53
6- from collections .abc import AsyncGenerator , AsyncIterator , Awaitable , Callable
4+ from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
75from typing import Any
86
9- from pydantic import ValidationError
107from sse_starlette .sse import EventSourceResponse
118from starlette .requests import Request
12- from starlette .responses import JSONResponse
9+ from starlette .responses import JSONResponse , Response
1310
1411from a2a .server .apps .jsonrpc import (
1512 CallContextBuilder ,
1916from a2a .server .request_handlers .request_handler import RequestHandler
2017from a2a .server .request_handlers .rest_handler import RESTHandler
2118from a2a .types import (
22- A2AError ,
2319 AgentCard ,
24- InternalError ,
25- InvalidRequestError ,
26- JSONParseError ,
27- UnsupportedOperationError ,
20+ AuthenticatedExtendedCardNotConfiguredError ,
2821)
29- from a2a .utils .errors import MethodNotImplementedError
22+ from a2a .utils .error_handlers import (
23+ rest_error_handler ,
24+ rest_stream_error_handler ,
25+ )
26+ from a2a .utils .errors import ServerError
3027
3128
3229logger = logging .getLogger (__name__ )
@@ -61,88 +58,51 @@ def __init__(
6158 )
6259 self ._context_builder = context_builder or DefaultCallContextBuilder ()
6360
64- def _generate_error_response (self , error : A2AError ) -> JSONResponse :
65- """Creates a JSONResponse for an error.
66-
67- Logs the error based on its type.
68-
69- Args:
70- error: The Error object.
71-
72- Returns:
73- A `JSONResponse` object formatted as a JSON error response.
74- """
75- log_level = (
76- logging .ERROR
77- if isinstance (error , InternalError )
78- else logging .WARNING
79- )
80- logger .log (
81- log_level ,
82- 'Request Error: '
83- f"Code={ error .root .code } , Message='{ error .root .message } '"
84- f'{ ", Data=" + str (error .root .data ) if error .root .data else "" } ' ,
85- )
86- return JSONResponse (
87- f'{{"message": "{ error .root .message } "}}' ,
88- status_code = 500 ,
89- )
90-
91- def _handle_error (self , error : Exception ) -> JSONResponse :
92- traceback .print_exc ()
93- if isinstance (error , MethodNotImplementedError ):
94- return self ._generate_error_response (
95- A2AError (UnsupportedOperationError (message = error .message ))
96- )
97- if isinstance (error , json .decoder .JSONDecodeError ):
98- return self ._generate_error_response (
99- A2AError (JSONParseError (message = str (error )))
100- )
101- if isinstance (error , ValidationError ):
102- return self ._generate_error_response (
103- A2AError (InvalidRequestError (data = json .loads (error .json ()))),
104- )
105- logger .error (f'Unhandled exception: { error } ' )
106- return self ._generate_error_response (
107- A2AError (InternalError (message = str (error )))
108- )
109-
61+ @rest_error_handler
11062 async def _handle_request (
11163 self ,
112- method : Callable [[Request , ServerCallContext ], Awaitable [str ]],
64+ method : Callable [
65+ [Request , ServerCallContext ], Awaitable [dict [str , Any ]]
66+ ],
11367 request : Request ,
114- ) -> JSONResponse :
115- try :
116- call_context = self ._context_builder .build (request )
117- response = await method (request , call_context )
118- return JSONResponse (content = response )
119- except Exception as e :
120- return self ._handle_error (e )
68+ ) -> Response :
69+ call_context = self ._context_builder .build (request )
70+ response = await method (request , call_context )
71+ return JSONResponse (content = response )
12172
73+ @rest_error_handler
74+ async def _handle_list_request (
75+ self ,
76+ method : Callable [
77+ [Request , ServerCallContext ], Awaitable [list [dict [str , Any ]]]
78+ ],
79+ request : Request ,
80+ ) -> Response :
81+ call_context = self ._context_builder .build (request )
82+ response = await method (request , call_context )
83+ return JSONResponse (content = response )
84+
85+ @rest_stream_error_handler
12286 async def _handle_streaming_request (
12387 self ,
124- method : Callable [[Request , ServerCallContext ], AsyncIterator [str ]],
88+ method : Callable [
89+ [Request , ServerCallContext ], AsyncIterable [dict [str , Any ]]
90+ ],
12591 request : Request ,
12692 ) -> EventSourceResponse :
127- try :
128- call_context = self ._context_builder .build (request )
93+ call_context = self ._context_builder .build (request )
12994
130- async def event_generator (
131- stream : AsyncGenerator [ str ],
132- ) -> AsyncGenerator [dict [str , str ]]:
133- async for item in stream :
134- yield {'data' : item }
95+ async def event_generator (
96+ stream : AsyncIterable [ dict [ str , Any ] ],
97+ ) -> AsyncIterator [dict [str , dict [ str , Any ] ]]:
98+ async for item in stream :
99+ yield {'data' : item }
135100
136- return EventSourceResponse (
137- event_generator (method (request , call_context ))
138- )
139- except Exception as e :
140- # Since the stream has started, we can't return a JSONResponse.
141- # Instead, we runt the error handling logic (provides logging)
142- # and reraise the error and let server framework manage
143- self ._handle_error (e )
144- raise e
101+ return EventSourceResponse (
102+ event_generator (method (request , call_context ))
103+ )
145104
105+ @rest_error_handler
146106 async def _handle_get_agent_card (self , request : Request ) -> JSONResponse :
147107 """Handles GET requests for the agent card endpoint.
148108
@@ -158,6 +118,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
158118 self .agent_card .model_dump (mode = 'json' , exclude_none = True )
159119 )
160120
121+ @rest_error_handler
161122 async def handle_authenticated_agent_card (
162123 self , request : Request
163124 ) -> JSONResponse :
@@ -173,9 +134,10 @@ async def handle_authenticated_agent_card(
173134 A JSONResponse containing the authenticated card.
174135 """
175136 if not self .agent_card .supports_authenticated_extended_card :
176- return JSONResponse (
177- '{"detail": "Authenticated card not supported"}' ,
178- status_code = 404 ,
137+ raise ServerError (
138+ error = AuthenticatedExtendedCardNotConfiguredError (
139+ message = 'Authenticated card not supported'
140+ )
179141 )
180142 return JSONResponse (
181143 self .agent_card .model_dump (mode = 'json' , exclude_none = True )
@@ -226,10 +188,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
226188 '/v1/tasks/{id}/pushNotificationConfigs' ,
227189 'GET' ,
228190 ): functools .partial (
229- self ._handle_request , self .handler .list_push_notifications
191+ self ._handle_list_request , self .handler .list_push_notifications
230192 ),
231193 ('/v1/tasks' , 'GET' ): functools .partial (
232- self ._handle_request , self .handler .list_tasks
194+ self ._handle_list_request , self .handler .list_tasks
233195 ),
234196 }
235197 if self .agent_card .supports_authenticated_extended_card :
0 commit comments