Skip to content

Commit 60e9fed

Browse files
committed
refactor: replace CallContextBuilder with functional UserBuilder pattern for server request authentication
1 parent e1fb3f3 commit 60e9fed

12 files changed

Lines changed: 129 additions & 162 deletions

File tree

src/a2a/compat/v0_3/grpc_handler.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
from a2a.server.context import ServerCallContext
2424
from a2a.server.request_handlers.grpc_handler import (
2525
_ERROR_CODE_MAP,
26-
CallContextBuilder,
27-
DefaultCallContextBuilder,
26+
GrpcUserBuilder,
27+
build_grpc_server_call_context,
28+
default_grpc_user_builder,
2829
)
2930
from a2a.server.request_handlers.request_handler import RequestHandler
3031
from a2a.types.a2a_pb2 import AgentCard
@@ -44,7 +45,7 @@ def __init__(
4445
self,
4546
agent_card: AgentCard,
4647
request_handler: RequestHandler,
47-
context_builder: CallContextBuilder | None = None,
48+
user_builder: GrpcUserBuilder | None = None,
4849
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
4950
| None = None,
5051
):
@@ -54,14 +55,14 @@ def __init__(
5455
agent_card: The AgentCard describing the agent's capabilities (v1.0).
5556
request_handler: The underlying `RequestHandler` instance to
5657
delegate requests to.
57-
context_builder: The CallContextBuilder object. If none the
58-
DefaultCallContextBuilder is used.
58+
user_builder: Optional custom user builder to extract user from the
59+
gRPC context.
5960
card_modifier: An optional callback to dynamically modify the public
6061
agent card before it is served.
6162
"""
6263
self.agent_card = agent_card
6364
self.handler03 = RequestHandler03(request_handler=request_handler)
64-
self.context_builder = context_builder or DefaultCallContextBuilder()
65+
self.user_builder = user_builder or default_grpc_user_builder
6566
self.card_modifier = card_modifier
6667

6768
async def _handle_unary(
@@ -72,7 +73,9 @@ async def _handle_unary(
7273
) -> TResponse:
7374
"""Centralized error handling and context management for unary calls."""
7475
try:
75-
server_context = self.context_builder.build(context)
76+
server_context = build_grpc_server_call_context(
77+
context, self.user_builder
78+
)
7679
result = await handler_func(server_context)
7780
self._set_extension_metadata(context, server_context)
7881
except A2AError as e:
@@ -88,7 +91,9 @@ async def _handle_stream(
8891
) -> AsyncIterable[TResponse]:
8992
"""Centralized error handling and context management for streaming calls."""
9093
try:
91-
server_context = self.context_builder.build(context)
94+
server_context = build_grpc_server_call_context(
95+
context, self.user_builder
96+
)
9297
async for item in handler_func(server_context):
9398
yield item
9499
self._set_extension_metadata(context, server_context)

src/a2a/compat/v0_3/jsonrpc_adapter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from starlette.requests import Request
1212

1313
from a2a.server.request_handlers.request_handler import RequestHandler
14-
from a2a.server.routes import CallContextBuilder
1514
from a2a.types.a2a_pb2 import AgentCard
1615

1716
_package_starlette_installed = True
@@ -38,6 +37,11 @@
3837
from a2a.server.jsonrpc_models import (
3938
JSONRPCError as CoreJSONRPCError,
4039
)
40+
from a2a.server.routes.common import (
41+
UserBuilder,
42+
build_server_call_context,
43+
default_user_builder,
44+
)
4145
from a2a.utils import constants
4246
from a2a.utils.errors import ExtendedAgentCardNotConfiguredError
4347
from a2a.utils.helpers import maybe_await, validate_version
@@ -67,7 +71,7 @@ def __init__( # noqa: PLR0913
6771
agent_card: 'AgentCard',
6872
http_handler: 'RequestHandler',
6973
extended_agent_card: 'AgentCard | None' = None,
70-
context_builder: 'CallContextBuilder | None' = None,
74+
user_builder: 'UserBuilder | None' = None,
7175
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
7276
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
7377
):
@@ -78,7 +82,7 @@ def __init__( # noqa: PLR0913
7882
self.handler = RequestHandler03(
7983
request_handler=http_handler,
8084
)
81-
self._context_builder = context_builder
85+
self._user_builder = user_builder or default_user_builder
8286

8387
def supports_method(self, method: str) -> bool:
8488
"""Returns True if the v0.3 adapter supports the given method name."""
@@ -126,10 +130,8 @@ async def handle_request(
126130
CoreInvalidRequestError(data=str(e)),
127131
)
128132

129-
call_context = (
130-
self._context_builder.build(request)
131-
if self._context_builder
132-
else ServerCallContext()
133+
call_context = build_server_call_context(
134+
request, self._user_builder
133135
)
134136
call_context.tenant = (
135137
getattr(specific_request.params, 'tenant', '')

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@
3434
from a2a.compat.v0_3 import conversions
3535
from a2a.compat.v0_3.rest_handler import REST03Handler
3636
from a2a.server.context import ServerCallContext
37-
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
37+
from a2a.server.routes.common import (
38+
UserBuilder,
39+
build_server_call_context,
40+
default_user_builder,
41+
)
3842
from a2a.utils.error_handlers import (
3943
rest_error_handler,
4044
rest_stream_error_handler,
@@ -60,7 +64,7 @@ def __init__( # noqa: PLR0913
6064
agent_card: 'AgentCard',
6165
http_handler: 'RequestHandler',
6266
extended_agent_card: 'AgentCard | None' = None,
63-
context_builder: 'CallContextBuilder | None' = None,
67+
user_builder: 'UserBuilder | None' = None,
6468
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
6569
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
6670
):
@@ -71,15 +75,15 @@ def __init__( # noqa: PLR0913
7175
self.handler = REST03Handler(
7276
agent_card=agent_card, request_handler=http_handler
7377
)
74-
self._context_builder = context_builder or DefaultCallContextBuilder()
78+
self._user_builder = user_builder or default_user_builder
7579

7680
@rest_error_handler
7781
async def _handle_request(
7882
self,
7983
method: 'Callable[[Request, ServerCallContext], Awaitable[Any]]',
8084
request: Request,
8185
) -> Response:
82-
call_context = self._context_builder.build(request)
86+
call_context = build_server_call_context(request, self._user_builder)
8387
response = await method(request, call_context)
8488
return JSONResponse(content=response)
8589

@@ -96,7 +100,7 @@ async def _handle_streaming_request(
96100
message=f'Failed to pre-consume request body: {e}'
97101
) from e
98102

99-
call_context = self._context_builder.build(request)
103+
call_context = build_server_call_context(request, self._user_builder)
100104

101105
async def event_generator(
102106
stream: AsyncIterable[Any],

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# ruff: noqa: N802
22
import logging
33

4-
from abc import ABC, abstractmethod
54
from collections.abc import AsyncIterable, Awaitable, Callable
65
from typing import TypeVar
76

@@ -24,7 +23,7 @@
2423
import a2a.types.a2a_pb2_grpc as a2a_grpc
2524

2625
from a2a import types
27-
from a2a.auth.user import UnauthenticatedUser
26+
from a2a.auth.user import UnauthenticatedUser, User
2827
from a2a.extensions.common import (
2928
HTTP_EXTENSION_HEADER,
3029
get_requested_extensions,
@@ -41,15 +40,12 @@
4140

4241
logger = logging.getLogger(__name__)
4342

44-
# For now we use a trivial wrapper on the grpc context object
43+
GrpcUserBuilder = Callable[[grpc.aio.ServicerContext], User]
4544

4645

47-
class CallContextBuilder(ABC):
48-
"""A class for building ServerCallContexts using the Starlette Request."""
49-
50-
@abstractmethod
51-
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
52-
"""Builds a ServerCallContext from a gRPC Request."""
46+
def default_grpc_user_builder(context: grpc.aio.ServicerContext) -> User:
47+
"""Default strategy for creating a User from a gRPC context."""
48+
return UnauthenticatedUser()
5349

5450

5551
def _get_metadata_value(
@@ -67,20 +63,19 @@ def _get_metadata_value(
6763
]
6864

6965

70-
class DefaultCallContextBuilder(CallContextBuilder):
71-
"""A default implementation of CallContextBuilder."""
72-
73-
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
74-
"""Builds the ServerCallContext."""
75-
user = UnauthenticatedUser()
76-
state = {'grpc_context': context}
77-
return ServerCallContext(
78-
user=user,
79-
state=state,
80-
requested_extensions=get_requested_extensions(
81-
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
82-
),
83-
)
66+
def build_grpc_server_call_context(
67+
context: grpc.aio.ServicerContext, user_builder: GrpcUserBuilder
68+
) -> ServerCallContext:
69+
"""Builds a ServerCallContext from a gRPC ServicerContext."""
70+
user = user_builder(context)
71+
state = {'grpc_context': context}
72+
return ServerCallContext(
73+
user=user,
74+
state=state,
75+
requested_extensions=get_requested_extensions(
76+
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
77+
),
78+
)
8479

8580

8681
_ERROR_CODE_MAP = {
@@ -110,7 +105,7 @@ def __init__(
110105
self,
111106
agent_card: AgentCard,
112107
request_handler: RequestHandler,
113-
context_builder: CallContextBuilder | None = None,
108+
user_builder: GrpcUserBuilder | None = None,
114109
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
115110
| None = None,
116111
):
@@ -120,14 +115,14 @@ def __init__(
120115
agent_card: The AgentCard describing the agent's capabilities.
121116
request_handler: The underlying `RequestHandler` instance to
122117
delegate requests to.
123-
context_builder: The CallContextBuilder object. If none the
124-
DefaultCallContextBuilder is used.
118+
user_builder: Optional custom user builder to extract user from the
119+
gRPC context.
125120
card_modifier: An optional callback to dynamically modify the public
126121
agent card before it is served.
127122
"""
128123
self.agent_card = agent_card
129124
self.request_handler = request_handler
130-
self.context_builder = context_builder or DefaultCallContextBuilder()
125+
self.user_builder = user_builder or default_grpc_user_builder
131126
self.card_modifier = card_modifier
132127

133128
async def _handle_unary(
@@ -451,6 +446,8 @@ def _build_call_context(
451446
context: grpc.aio.ServicerContext,
452447
request: message.Message,
453448
) -> ServerCallContext:
454-
server_context = self.context_builder.build(context)
449+
server_context = build_grpc_server_call_context(
450+
context, self.user_builder
451+
)
455452
server_context.tenant = getattr(request, 'tenant', '')
456453
return server_context

src/a2a/server/routes/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
"""A2A Routes."""
22

33
from a2a.server.routes.agent_card_routes import create_agent_card_routes
4-
from a2a.server.routes.common import (
5-
CallContextBuilder,
6-
DefaultCallContextBuilder,
7-
)
4+
from a2a.server.routes.common import UserBuilder
85
from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes
96
from a2a.server.routes.rest_routes import create_rest_routes
107

118

129
__all__ = [
13-
'CallContextBuilder',
14-
'DefaultCallContextBuilder',
10+
'UserBuilder',
1511
'create_agent_card_routes',
1612
'create_jsonrpc_routes',
1713
'create_rest_routes',

src/a2a/server/routes/common.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,58 @@
1-
from abc import ABC, abstractmethod
1+
from collections.abc import Callable
22

3-
from starlette.authentication import BaseUser
43
from starlette.requests import Request
54

6-
from a2a.auth.user import UnauthenticatedUser
7-
from a2a.auth.user import User as A2AUser
5+
from a2a.auth.user import UnauthenticatedUser, User
86
from a2a.extensions.common import (
97
HTTP_EXTENSION_HEADER,
108
get_requested_extensions,
119
)
1210
from a2a.server.context import ServerCallContext
1311

1412

15-
class StarletteUserProxy(A2AUser):
16-
"""Adapts the Starlette User class to the A2A user representation."""
13+
UserBuilder = Callable[[Request], User]
1714

18-
def __init__(self, user: BaseUser):
19-
self._user = user
2015

21-
@property
22-
def is_authenticated(self) -> bool:
23-
"""Returns whether the current user is authenticated."""
24-
return self._user.is_authenticated
16+
def default_user_builder(request: Request) -> User:
17+
"""Default strategy for creating an A2AUser from a Starlette Request."""
18+
if 'user' in request.scope:
2519

26-
@property
27-
def user_name(self) -> str:
28-
"""Returns the user name of the current user."""
29-
return self._user.display_name
20+
class BaseUser(User):
21+
@property
22+
def is_authenticated(self) -> bool:
23+
return request.user.is_authenticated
3024

25+
@property
26+
def user_name(self) -> str:
27+
return request.user.display_name
3128

32-
class CallContextBuilder(ABC):
33-
"""A class for building ServerCallContexts using the Starlette Request."""
29+
return BaseUser()
30+
return UnauthenticatedUser()
3431

35-
@abstractmethod
36-
def build(self, request: Request) -> ServerCallContext:
37-
"""Builds a ServerCallContext from a Starlette Request."""
3832

33+
def build_server_call_context(
34+
request: Request, user_builder: UserBuilder
35+
) -> ServerCallContext:
36+
"""Builds a ServerCallContext from a Starlette Request.
3937
40-
class DefaultCallContextBuilder(CallContextBuilder):
41-
"""A default implementation of CallContextBuilder."""
38+
Args:
39+
request: The incoming Starlette Request object.
40+
user_builder: Optional custom user builder.
4241
43-
def build(self, request: Request) -> ServerCallContext:
44-
"""Builds a ServerCallContext from a Starlette Request.
42+
Returns:
43+
A ServerCallContext instance populated with user and state.
44+
"""
45+
user = user_builder(request)
4546

46-
Args:
47-
request: The incoming Starlette Request object.
47+
state = {}
48+
if 'auth' in request.scope:
49+
state['auth'] = request.auth
50+
state['headers'] = dict(request.headers)
4851

49-
Returns:
50-
A ServerCallContext instance populated with user and state
51-
information from the request.
52-
"""
53-
user: A2AUser = UnauthenticatedUser()
54-
state = {}
55-
if 'user' in request.scope:
56-
user = StarletteUserProxy(request.user)
57-
state['auth'] = request.auth
58-
state['headers'] = dict(request.headers)
59-
return ServerCallContext(
60-
user=user,
61-
state=state,
62-
requested_extensions=get_requested_extensions(
63-
request.headers.getlist(HTTP_EXTENSION_HEADER)
64-
),
65-
)
52+
return ServerCallContext(
53+
user=user,
54+
state=state,
55+
requested_extensions=get_requested_extensions(
56+
request.headers.getlist(HTTP_EXTENSION_HEADER)
57+
),
58+
)

0 commit comments

Comments
 (0)