diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 65ece313e..3ae5ad6fe 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -1,15 +1,18 @@ import logging +from collections.abc import Callable from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from fastapi import APIRouter, FastAPI, Request, Response + from fastapi.responses import JSONResponse _package_fastapi_installed = True else: try: from fastapi import APIRouter, FastAPI, Request, Response + from fastapi.responses import JSONResponse _package_fastapi_installed = True except ImportError: @@ -23,6 +26,7 @@ from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder from a2a.server.apps.rest.rest_adapter import RESTAdapter +from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -39,11 +43,17 @@ class A2ARESTFastAPIApplication: (SSE). """ - def __init__( + def __init__( # noqa: PLR0913 self, agent_card: AgentCard, http_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], AgentCard] | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], AgentCard + ] + | None = None, ): """Initializes the A2ARESTFastAPIApplication. @@ -56,6 +66,11 @@ def __init__( context_builder: The CallContextBuilder used to construct the ServerCallContext passed to the http_handler. If None, no ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. """ if not _package_fastapi_installed: raise ImportError( @@ -66,7 +81,10 @@ def __init__( self._adapter = RESTAdapter( agent_card=agent_card, http_handler=http_handler, + extended_agent_card=extended_agent_card, context_builder=context_builder, + card_modifier=card_modifier, + extended_card_modifier=extended_card_modifier, ) def build( @@ -95,7 +113,8 @@ def build( @router.get(f'{rpc_url}{agent_card_url}') async def get_agent_card(request: Request) -> Response: - return await self._adapter.handle_get_agent_card(request) + card = await self._adapter.handle_get_agent_card(request) + return JSONResponse(card) app.include_router(router) return app diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 898192854..0860825bf 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -52,11 +52,17 @@ class RESTAdapter: manages response generation including Server-Sent Events (SSE). """ - def __init__( + def __init__( # noqa: PLR0913 self, agent_card: AgentCard, http_handler: RequestHandler, + extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, + card_modifier: Callable[[AgentCard], AgentCard] | None = None, + extended_card_modifier: Callable[ + [AgentCard, ServerCallContext], AgentCard + ] + | None = None, ): """Initializes the RESTApplication. @@ -64,9 +70,16 @@ def __init__( agent_card: The AgentCard describing the agent's capabilities. http_handler: The handler instance responsible for processing A2A requests via http. + extended_agent_card: An optional, distinct AgentCard to be served + at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the ServerCallContext passed to the http_handler. If None, no ServerCallContext is passed. + card_modifier: An optional callback to dynamically modify the public + agent card before it is served. + extended_card_modifier: An optional callback to dynamically modify + the extended agent card before it is served. It receives the + call context. """ if not _package_starlette_installed: raise ImportError( @@ -75,9 +88,20 @@ def __init__( ' optional dependencies, `a2a-sdk[http-server]`.' ) self.agent_card = agent_card + self.extended_agent_card = extended_agent_card + self.card_modifier = card_modifier + self.extended_card_modifier = extended_card_modifier self.handler = RESTHandler( agent_card=agent_card, request_handler=http_handler ) + if ( + self.agent_card.supports_authenticated_extended_card + and self.extended_agent_card is None + and self.extended_card_modifier is None + ): + logger.error( + 'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.' + ) self._context_builder = context_builder or DefaultCallContextBuilder() @rest_error_handler @@ -108,26 +132,27 @@ async def event_generator( event_generator(method(request, call_context)) ) - @rest_error_handler - async def handle_get_agent_card(self, request: Request) -> JSONResponse: + async def handle_get_agent_card( + self, request: Request, call_context: ServerCallContext | None = None + ) -> dict[str, Any]: """Handles GET requests for the agent card endpoint. Args: request: The incoming Starlette Request object. + call_context: ServerCallContext Returns: A JSONResponse containing the agent card data. """ - # The public agent card is a direct serialization of the agent_card - # provided at initialization. - return JSONResponse( - self.agent_card.model_dump(mode='json', exclude_none=True) - ) + card_to_serve = self.agent_card + if self.card_modifier: + card_to_serve = self.card_modifier(card_to_serve) + + return card_to_serve.model_dump(mode='json', exclude_none=True) - @rest_error_handler async def handle_authenticated_agent_card( - self, request: Request - ) -> JSONResponse: + self, request: Request, call_context: ServerCallContext | None = None + ) -> dict[str, Any]: """Hook for per credential agent card response. If a dynamic card is needed based on the credentials provided in the request @@ -135,6 +160,7 @@ async def handle_authenticated_agent_card( Args: request: The incoming Starlette Request object. + call_context: ServerCallContext Returns: A JSONResponse containing the authenticated card. @@ -145,9 +171,18 @@ async def handle_authenticated_agent_card( message='Authenticated card not supported' ) ) - return JSONResponse( - self.agent_card.model_dump(mode='json', exclude_none=True) - ) + card_to_serve = self.extended_agent_card + + if not card_to_serve: + card_to_serve = self.agent_card + + if self.extended_card_modifier: + context = self._context_builder.build(request) + # If no base extended card is provided, pass the public card to the modifier + base_card = card_to_serve if card_to_serve else self.agent_card + card_to_serve = self.extended_card_modifier(base_card, context) + + return card_to_serve.model_dump(mode='json', exclude_none=True) def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: """Constructs a dictionary of API routes and their corresponding handlers. @@ -201,6 +236,8 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: ), } if self.agent_card.supports_authenticated_extended_card: - routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card + routes[('/v1/card', 'GET')] = functools.partial( + self._handle_request, self.handle_authenticated_agent_card + ) return routes diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 46907ee64..88d4d3d11 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,5 +1,4 @@ import asyncio - from collections.abc import AsyncGenerator from typing import NamedTuple from unittest.mock import ANY, AsyncMock @@ -8,7 +7,6 @@ import httpx import pytest import pytest_asyncio - from grpc.aio import Channel from a2a.client.transports import JsonRpcTransport, RestTransport @@ -38,7 +36,6 @@ TransportProtocol, ) - # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -130,7 +127,7 @@ def agent_card() -> AgentCard: default_input_modes=['text/plain'], default_output_modes=['text/plain'], preferred_transport=TransportProtocol.jsonrpc, - supports_authenticated_extended_card=True, + supports_authenticated_extended_card=False, additional_interfaces=[ AgentInterface( transport=TransportProtocol.http_json, url='http://testserver' @@ -709,9 +706,7 @@ async def test_http_transport_get_card( transport_setup_fixture ) transport = transport_setup.transport - - # The transport starts with a minimal card, get_card() fetches the full one - transport.agent_card.supports_authenticated_extended_card = True + # Get the base card. result = await transport.get_card() assert result.name == agent_card.name @@ -722,6 +717,33 @@ async def test_http_transport_get_card( await transport.close() +@pytest.mark.asyncio +async def test_http_transport_get_authenticated_card( + agent_card: AgentCard, + mock_request_handler: AsyncMock, +) -> None: + agent_card.supports_authenticated_extended_card = True + extended_agent_card = agent_card.model_copy(deep=True) + extended_agent_card.name = 'Extended Agent Card' + + app_builder = A2ARESTFastAPIApplication( + agent_card, + mock_request_handler, + extended_agent_card=extended_agent_card, + ) + app = app_builder.build() + httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + + transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card) + result = await transport.get_card() + assert result.name == extended_agent_card.name + assert transport.agent_card.name == extended_agent_card.name + assert transport._needs_extended_card is False + + if hasattr(transport, 'close'): + await transport.close() + + @pytest.mark.asyncio async def test_grpc_transport_get_card( grpc_server_and_handler: tuple[str, AsyncMock],