Skip to content

Commit 7405dc7

Browse files
committed
refactor: Remove redundant JSON-RPC Pydantic types, use jsonrpc library directly
- Update JSONRPCHandler to return dict[str, Any] instead of Pydantic RootModels - Update response_helpers to build dicts with JSON-RPC 2.0 structure - Remove unused Pydantic response types from types module - Fix proto dependency loading in a2a_pb2.py - Update all tests to check dict responses instead of Pydantic models - Add TransportProtocol constants to utils module
1 parent 424dd7e commit 7405dc7

38 files changed

Lines changed: 822 additions & 1117 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"protobuf>=5.29.5",
1515
"google-api-core>=1.26.0",
1616
"json-rpc>=1.15.0",
17+
"googleapis-common-protos>=1.70.0",
1718
]
1819

1920
classifiers = [

src/a2a/client/auth/interceptor.py

Lines changed: 44 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,10 @@
33

44
from a2a.client.auth.credentials import CredentialService
55
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6-
from a2a.types.a2a_pb2 import (
7-
AgentCard,
8-
APIKeySecurityScheme,
9-
HTTPAuthSecurityScheme,
10-
MutualTlsSecurityScheme,
11-
OAuth2SecurityScheme,
12-
OpenIdConnectSecurityScheme,
13-
SecurityScheme,
14-
)
6+
from a2a.types.a2a_pb2 import AgentCard
157

168
logger = logging.getLogger(__name__)
179

18-
_SecuritySchemeValue = (
19-
APIKeySecurityScheme
20-
| HTTPAuthSecurityScheme
21-
| OAuth2SecurityScheme
22-
| OpenIdConnectSecurityScheme
23-
| MutualTlsSecurityScheme
24-
| None
25-
)
26-
27-
28-
def _get_security_scheme_value(scheme: SecurityScheme) -> _SecuritySchemeValue:
29-
"""Extract the actual security scheme from the oneof union."""
30-
which = scheme.WhichOneof('scheme')
31-
if which == 'api_key_security_scheme':
32-
return scheme.api_key_security_scheme
33-
if which == 'http_auth_security_scheme':
34-
return scheme.http_auth_security_scheme
35-
if which == 'oauth2_security_scheme':
36-
return scheme.oauth2_security_scheme
37-
if which == 'open_id_connect_security_scheme':
38-
return scheme.open_id_connect_security_scheme
39-
if which == 'mtls_security_scheme':
40-
return scheme.mtls_security_scheme
41-
return None
42-
4310

4411
class AuthInterceptor(ClientCallInterceptor):
4512
"""An interceptor that automatically adds authentication details to requests.
@@ -72,54 +39,53 @@ async def intercept(
7239
scheme_name, context
7340
)
7441
if credential and scheme_name in agent_card.security_schemes:
75-
scheme_def_union = agent_card.security_schemes.get(
76-
scheme_name
77-
)
78-
if not scheme_def_union:
79-
continue
80-
scheme_def = _get_security_scheme_value(scheme_def_union)
81-
if not scheme_def:
42+
scheme = agent_card.security_schemes.get(scheme_name)
43+
if not scheme:
8244
continue
8345

8446
headers = http_kwargs.get('headers', {})
8547

86-
match scheme_def:
87-
# Case 1a: HTTP Bearer scheme with an if guard
88-
case HTTPAuthSecurityScheme() if (
89-
scheme_def.scheme.lower() == 'bearer'
90-
):
91-
headers['Authorization'] = f'Bearer {credential}'
92-
logger.debug(
93-
"Added Bearer token for scheme '%s'.",
94-
scheme_name,
95-
)
96-
http_kwargs['headers'] = headers
97-
return request_payload, http_kwargs
98-
99-
# Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer
100-
case (
101-
OAuth2SecurityScheme()
102-
| OpenIdConnectSecurityScheme()
103-
):
104-
headers['Authorization'] = f'Bearer {credential}'
105-
logger.debug(
106-
"Added Bearer token for scheme '%s'.",
107-
scheme_name,
108-
)
109-
http_kwargs['headers'] = headers
110-
return request_payload, http_kwargs
111-
112-
# Case 2: API Key in Header
113-
case APIKeySecurityScheme() if (
114-
scheme_def.location.lower() == 'header'
115-
):
116-
headers[scheme_def.name] = credential
117-
logger.debug(
118-
"Added API Key Header for scheme '%s'.",
119-
scheme_name,
120-
)
121-
http_kwargs['headers'] = headers
122-
return request_payload, http_kwargs
48+
# HTTP Bearer authentication
49+
if (
50+
scheme.HasField('http_auth_security_scheme')
51+
and scheme.http_auth_security_scheme.scheme.lower()
52+
== 'bearer'
53+
):
54+
headers['Authorization'] = f'Bearer {credential}'
55+
logger.debug(
56+
"Added Bearer token for scheme '%s'.",
57+
scheme_name,
58+
)
59+
http_kwargs['headers'] = headers
60+
return request_payload, http_kwargs
61+
62+
# OAuth2 and OIDC schemes are implicitly Bearer
63+
if scheme.HasField(
64+
'oauth2_security_scheme'
65+
) or scheme.HasField('open_id_connect_security_scheme'):
66+
headers['Authorization'] = f'Bearer {credential}'
67+
logger.debug(
68+
"Added Bearer token for scheme '%s'.",
69+
scheme_name,
70+
)
71+
http_kwargs['headers'] = headers
72+
return request_payload, http_kwargs
73+
74+
# API Key in Header
75+
if (
76+
scheme.HasField('api_key_security_scheme')
77+
and scheme.api_key_security_scheme.location.lower()
78+
== 'header'
79+
):
80+
headers[scheme.api_key_security_scheme.name] = (
81+
credential
82+
)
83+
logger.debug(
84+
"Added API Key Header for scheme '%s'.",
85+
scheme_name,
86+
)
87+
http_kwargs['headers'] = headers
88+
return request_payload, http_kwargs
12389

12490
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
12591

src/a2a/client/errors.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from a2a.types.extras import A2AError, JSONRPCErrorResponse
5+
from a2a.types.extras import A2AError
66

77

88
class A2AClientError(Exception):
@@ -81,16 +81,11 @@ class A2AClientJSONRPCError(A2AClientError):
8181

8282
error: dict[str, Any] | A2AError
8383

84-
def __init__(self, error: JSONRPCErrorResponse | dict[str, Any]):
84+
def __init__(self, error: dict[str, Any] | A2AError):
8585
"""Initializes the A2AClientJsonRPCError.
8686
8787
Args:
88-
error: The JSON-RPC error object or dict from the jsonrpc library.
88+
error: The JSON-RPC error dict from the jsonrpc library, or A2AError object.
8989
"""
90-
if isinstance(error, dict):
91-
# Raw dict from jsonrpc library: {'code': ..., 'message': ...}
92-
self.error = error
93-
else:
94-
# JSONRPCErrorResponse object
95-
self.error = error.error
90+
self.error = error
9691
super().__init__(f'JSON-RPC Error {self.error}')

src/a2a/server/apps/jsonrpc/jsonrpc_app.py

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""JSON-RPC application for A2A server."""
2+
13
import contextlib
24
import json
35
import logging
@@ -8,7 +10,7 @@
810
from typing import TYPE_CHECKING, Any
911

1012
from google.protobuf.json_format import MessageToDict, ParseDict
11-
from pydantic import RootModel, ValidationError
13+
from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response
1214

1315
from a2a.auth.user import UnauthenticatedUser
1416
from a2a.auth.user import User as A2AUser
@@ -37,10 +39,7 @@
3739
InvalidParamsError,
3840
InvalidRequestError,
3941
JSONParseError,
40-
JSONRPCErrorResponse,
41-
JSONRPCRequest,
4242
MethodNotFoundError,
43-
SendStreamingMessageResponse,
4443
TaskResubscriptionRequest,
4544
UnsupportedOperationError,
4645
)
@@ -233,10 +232,8 @@ def _generate_error_response(
233232
Returns:
234233
A `JSONResponse` object formatted as a JSON-RPC error response.
235234
"""
236-
error_resp = JSONRPCErrorResponse(
237-
id=request_id,
238-
error=error,
239-
)
235+
error_dict = error.model_dump(exclude_none=True)
236+
error_resp = JSONRPC20Response(error=error_dict, _id=request_id)
240237

241238
log_level = (
242239
logging.ERROR
@@ -247,14 +244,14 @@ def _generate_error_response(
247244
log_level,
248245
"Request Error (ID: %s): Code=%s, Message='%s'%s",
249246
request_id,
250-
error_resp.error.code,
251-
error_resp.error.message,
252-
', Data=' + str(error_resp.error.data)
253-
if error_resp.error.data
247+
error_dict.get('code'),
248+
error_dict.get('message'),
249+
', Data=' + str(error_dict.get('data'))
250+
if error_dict.get('data')
254251
else '',
255252
)
256253
return JSONResponse(
257-
error_resp.model_dump(mode='json', exclude_none=True),
254+
error_resp.data,
258255
status_code=200,
259256
)
260257

@@ -274,7 +271,7 @@ def _allowed_content_length(self, request: Request) -> bool:
274271
return False
275272
return True
276273

277-
async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
274+
async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911, PLR0912
278275
"""Handles incoming POST requests to the main A2A endpoint.
279276
280277
Parses the request body as JSON, validates it against A2A request types,
@@ -313,17 +310,31 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911
313310
logger.debug('Request body: %s', body)
314311
# 1) Validate base JSON-RPC structure only (-32600 on failure)
315312
try:
316-
base_request = JSONRPCRequest.model_validate(body)
317-
except ValidationError as e:
313+
base_request = JSONRPC20Request.from_data(body)
314+
if not isinstance(base_request, JSONRPC20Request):
315+
# Batch requests are not supported
316+
return self._generate_error_response(
317+
request_id,
318+
InvalidRequestError(
319+
message='Batch requests are not supported'
320+
),
321+
)
322+
except Exception as e:
318323
logger.exception('Failed to validate base JSON-RPC request')
319324
return self._generate_error_response(
320325
request_id,
321-
InvalidRequestError(data=json.loads(e.json())),
326+
InvalidRequestError(data=str(e)),
322327
)
323328

324329
# 2) Route by method name; unknown -> -32601, known -> validate params (-32602 on failure)
325-
method = base_request.method
326-
request_id = base_request.id
330+
method: str | None = base_request.method
331+
request_id = base_request._id # noqa: SLF001
332+
333+
if not method:
334+
return self._generate_error_response(
335+
request_id,
336+
InvalidRequestError(message='Method is required'),
337+
)
327338

328339
model_class = self.METHOD_TO_MODEL.get(method)
329340
if not model_class:
@@ -483,33 +494,25 @@ async def _process_non_streaming_request(
483494
error = UnsupportedOperationError(
484495
message=f'Request type {type(request_obj).__name__} is unknown.'
485496
)
486-
handler_result = JSONRPCErrorResponse(
487-
id=request_id, error=error
488-
)
497+
return self._generate_error_response(request_id, error)
489498

490499
return self._create_response(context, handler_result)
491500

492501
def _create_response(
493502
self,
494503
context: ServerCallContext,
495-
handler_result: (
496-
AsyncGenerator[SendStreamingMessageResponse]
497-
| JSONRPCErrorResponse
498-
| RootModel[Any]
499-
),
504+
handler_result: AsyncGenerator[dict[str, Any]] | dict[str, Any],
500505
) -> Response:
501506
"""Creates a Starlette Response based on the result from the request handler.
502507
503508
Handles:
504509
- AsyncGenerator for Server-Sent Events (SSE).
505-
- JSONRPCErrorResponse for explicit errors returned by handlers.
506-
- Pydantic RootModels (like GetTaskResponse) containing success or error
507-
payloads.
510+
- Dict responses from handlers.
508511
509512
Args:
510513
context: The ServerCallContext provided to the request handler.
511514
handler_result: The result from a request handler method. Can be an
512-
async generator for streaming or a Pydantic model for non-streaming.
515+
async generator for streaming or a dict for non-streaming.
513516
514517
Returns:
515518
A Starlette JSONResponse or EventSourceResponse.
@@ -518,29 +521,19 @@ def _create_response(
518521
if exts := context.activated_extensions:
519522
headers[HTTP_EXTENSION_HEADER] = ', '.join(sorted(exts))
520523
if isinstance(handler_result, AsyncGenerator):
521-
# Result is a stream of SendStreamingMessageResponse objects
524+
# Result is a stream of dict objects
522525
async def event_generator(
523-
stream: AsyncGenerator[SendStreamingMessageResponse],
526+
stream: AsyncGenerator[dict[str, Any]],
524527
) -> AsyncGenerator[dict[str, str]]:
525528
async for item in stream:
526-
yield {'data': item.root.model_dump_json(exclude_none=True)}
529+
yield {'data': json.dumps(item)}
527530

528531
return EventSourceResponse(
529532
event_generator(handler_result), headers=headers
530533
)
531-
if isinstance(handler_result, JSONRPCErrorResponse):
532-
return JSONResponse(
533-
handler_result.model_dump(
534-
mode='json',
535-
exclude_none=True,
536-
),
537-
headers=headers,
538-
)
539534

540-
return JSONResponse(
541-
handler_result.root.model_dump(mode='json', exclude_none=True),
542-
headers=headers,
543-
)
535+
# handler_result is a dict (JSON-RPC response)
536+
return JSONResponse(handler_result, headers=headers)
544537

545538
async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
546539
"""Handles GET requests for the agent card endpoint.

0 commit comments

Comments
 (0)