11# ruff: noqa: N802
22import logging
33
4- from abc import ABC , abstractmethod
54from collections .abc import AsyncIterable , Awaitable , Callable
65from typing import TypeVar
76
2423import a2a .types .a2a_pb2_grpc as a2a_grpc
2524
2625from a2a import types
27- from a2a .auth .user import UnauthenticatedUser
26+ from a2a .auth .user import UnauthenticatedUser , User
2827from a2a .extensions .common import (
2928 HTTP_EXTENSION_HEADER ,
3029 get_requested_extensions ,
4140
4241logger = logging .getLogger (__name__ )
4342
44- # For now we use a trivial wrapper on the grpc context object
43+ GrpcUserBuilder = Callable [[ grpc . aio . ServicerContext ], User ]
4544
4645
47- class CallContextBuilder (ABC ):
48- """A class for building ServerCallContexts using the Starlette Request."""
49-
50- @abstractmethod
51- def build (self , context : grpc .aio .ServicerContext ) -> ServerCallContext :
52- """Builds a ServerCallContext from a gRPC Request."""
46+ def default_grpc_user_builder (context : grpc .aio .ServicerContext ) -> User :
47+ """Default strategy for creating a User from a gRPC context."""
48+ return UnauthenticatedUser ()
5349
5450
5551def _get_metadata_value (
@@ -67,20 +63,19 @@ def _get_metadata_value(
6763 ]
6864
6965
70- class DefaultCallContextBuilder (CallContextBuilder ):
71- """A default implementation of CallContextBuilder."""
72-
73- def build (self , context : grpc .aio .ServicerContext ) -> ServerCallContext :
74- """Builds the ServerCallContext."""
75- user = UnauthenticatedUser ()
76- state = {'grpc_context' : context }
77- return ServerCallContext (
78- user = user ,
79- state = state ,
80- requested_extensions = get_requested_extensions (
81- _get_metadata_value (context , HTTP_EXTENSION_HEADER )
82- ),
83- )
66+ def build_grpc_server_call_context (
67+ context : grpc .aio .ServicerContext , user_builder : GrpcUserBuilder
68+ ) -> ServerCallContext :
69+ """Builds a ServerCallContext from a gRPC ServicerContext."""
70+ user = user_builder (context )
71+ state = {'grpc_context' : context }
72+ return ServerCallContext (
73+ user = user ,
74+ state = state ,
75+ requested_extensions = get_requested_extensions (
76+ _get_metadata_value (context , HTTP_EXTENSION_HEADER )
77+ ),
78+ )
8479
8580
8681_ERROR_CODE_MAP = {
@@ -110,7 +105,7 @@ def __init__(
110105 self ,
111106 agent_card : AgentCard ,
112107 request_handler : RequestHandler ,
113- context_builder : CallContextBuilder | None = None ,
108+ user_builder : GrpcUserBuilder | None = None ,
114109 card_modifier : Callable [[AgentCard ], Awaitable [AgentCard ] | AgentCard ]
115110 | None = None ,
116111 ):
@@ -120,14 +115,14 @@ def __init__(
120115 agent_card: The AgentCard describing the agent's capabilities.
121116 request_handler: The underlying `RequestHandler` instance to
122117 delegate requests to.
123- context_builder: The CallContextBuilder object. If none the
124- DefaultCallContextBuilder is used .
118+ user_builder: Optional custom user builder to extract user from the
119+ gRPC context .
125120 card_modifier: An optional callback to dynamically modify the public
126121 agent card before it is served.
127122 """
128123 self .agent_card = agent_card
129124 self .request_handler = request_handler
130- self .context_builder = context_builder or DefaultCallContextBuilder ()
125+ self .user_builder = user_builder or default_grpc_user_builder
131126 self .card_modifier = card_modifier
132127
133128 async def _handle_unary (
@@ -451,6 +446,8 @@ def _build_call_context(
451446 context : grpc .aio .ServicerContext ,
452447 request : message .Message ,
453448 ) -> ServerCallContext :
454- server_context = self .context_builder .build (context )
449+ server_context = build_grpc_server_call_context (
450+ context , self .user_builder
451+ )
455452 server_context .tenant = getattr (request , 'tenant' , '' )
456453 return server_context
0 commit comments