diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index e0b638879..736299a12 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -18,6 +18,7 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_utils import update_card_rpc_url_from_request from a2a.types import ( A2AError, A2ARequest, @@ -89,6 +90,7 @@ Request = Any JSONResponse = Any Response = Any + URL = Any HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any MAX_CONTENT_LENGTH = 1_000_000 @@ -568,8 +570,12 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: ) card_to_serve = self.agent_card + rpc_url = card_to_serve.url if self.card_modifier: card_to_serve = self.card_modifier(card_to_serve) + # If agent's RPC URL was not modified, we build it dynamically. + if rpc_url == card_to_serve.url: + update_card_rpc_url_from_request(card_to_serve, request) return JSONResponse( card_to_serve.model_dump( @@ -594,6 +600,7 @@ async def _handle_get_authenticated_extended_agent_card( card_to_serve = self.extended_agent_card + rpc_url = card_to_serve.url if card_to_serve else None if self.extended_card_modifier: context = self._context_builder.build(request) # If no base extended card is provided, pass the public card to the modifier @@ -601,6 +608,9 @@ async def _handle_get_authenticated_extended_agent_card( card_to_serve = self.extended_card_modifier(base_card, context) if card_to_serve: + # If agent's RPC URL was not modified, we build it dynamically. + if rpc_url == card_to_serve.url: + update_card_rpc_url_from_request(card_to_serve, request) return JSONResponse( card_to_serve.model_dump( exclude_none=True, diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 0860825bf..70e7b379f 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -34,6 +34,7 @@ from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.rest_handler import RESTHandler +from a2a.server.request_utils import update_card_rpc_url_from_request from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError from a2a.utils.error_handlers import ( rest_error_handler, @@ -145,8 +146,11 @@ async def handle_get_agent_card( A JSONResponse containing the agent card data. """ card_to_serve = self.agent_card + rpc_url = card_to_serve.url if self.card_modifier: card_to_serve = self.card_modifier(card_to_serve) + if rpc_url == card_to_serve.url: + update_card_rpc_url_from_request(card_to_serve, request) return card_to_serve.model_dump(mode='json', exclude_none=True) @@ -175,12 +179,15 @@ async def handle_authenticated_agent_card( if not card_to_serve: card_to_serve = self.agent_card + rpc_url = card_to_serve.url if self.extended_card_modifier: context = self._context_builder.build(request) # If no base extended card is provided, pass the public card to the modifier base_card = card_to_serve if card_to_serve else self.agent_card card_to_serve = self.extended_card_modifier(base_card, context) + if rpc_url == card_to_serve.url: + update_card_rpc_url_from_request(card_to_serve, request) return card_to_serve.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/server/request_utils.py b/src/a2a/server/request_utils.py new file mode 100644 index 000000000..cc7f9f05f --- /dev/null +++ b/src/a2a/server/request_utils.py @@ -0,0 +1,73 @@ +from typing import TYPE_CHECKING, Any + +from a2a.types import AgentCard + + +if TYPE_CHECKING: + from starlette.datastructures import URL + from starlette.requests import Request + + _package_starlette_installed = True +else: + try: + from starlette.datastructures import URL + from starlette.requests import Request + + _package_starlette_installed = True + except ImportError: + _package_starlette_installed = False + URL = Any + Request = Any + + +def update_card_rpc_url_from_request( + agent_card: AgentCard, request: Request +) -> None: + """Modifies Agent's RPC URL based on the AgentCard request. + + Args: + agent_card (AgentCard): Original AgentCard + request (Request): AgentCard request + """ + rpc_url = URL(agent_card.url) + rpc_path = rpc_url.path + port = None + if 'X-Forwarded-Host' in request.headers: + host = request.headers['X-Forwarded-Host'] + else: + host = request.url.hostname or rpc_url.hostname or 'localhost' + port = request.url.port + + if 'X-Forwarded-Proto' in request.headers: + scheme = request.headers['X-Forwarded-Proto'] + port = None + else: + scheme = request.url.scheme + if not scheme: + scheme = 'http' + if ':' in host: # type: ignore + comps = host.rsplit(':', 1) # type: ignore + host = comps[0] + port = int(comps[1]) if comps[1] else port + + # Handle URL maps, + # e.g. "agents/my-agent/.well-known/agent-card.json" + if 'X-Forwarded-Path' in request.headers: + forwarded_path = request.headers['X-Forwarded-Path'].strip() + if ( + forwarded_path + and request.url.path != forwarded_path + and forwarded_path.endswith(request.url.path) + ): + # "agents/my-agent" for "agents/my-agent/.well-known/agent-card.json" + extra_path = forwarded_path[: -len(request.url.path)] + new_path = extra_path + rpc_path + # If original path was just "/", + # we remove trailing "/" in the extended one + if len(new_path) > 1 and rpc_path == '/': + new_path = new_path.rstrip('/') + rpc_path = new_path + + agent_card.url = str( + rpc_url.replace(hostname=host, port=port, scheme=scheme, path=rpc_path) + ) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 88d4d3d11..323ff00fd 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import NamedTuple from unittest.mock import ANY, AsyncMock @@ -7,6 +8,7 @@ import httpx import pytest import pytest_asyncio + from grpc.aio import Channel from a2a.client.transports import JsonRpcTransport, RestTransport @@ -36,6 +38,7 @@ TransportProtocol, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( diff --git a/tests/server/apps/jsonrpc/test_jsonrpc_app.py b/tests/server/apps/jsonrpc/test_jsonrpc_app.py index 72da73772..855d80b0c 100644 --- a/tests/server/apps/jsonrpc/test_jsonrpc_app.py +++ b/tests/server/apps/jsonrpc/test_jsonrpc_app.py @@ -26,6 +26,7 @@ RequestHandler, ) # For mock spec from a2a.types import ( + AgentCapabilities, AgentCard, Message, MessageSendParams, @@ -36,6 +37,7 @@ SendMessageSuccessResponse, TextPart, ) +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH # --- StarletteUserProxy Tests --- @@ -356,5 +358,64 @@ def side_effect(request, context: ServerCallContext): } +class TestAgentCardHandler: + @pytest.fixture + def agent_card(self): + return AgentCard( + name='APIKeyAgent', + description='An agent that uses API Key auth.', + url='http://localhost:8000', + version='1.0.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + skills=[], + ) + + def test_agent_card_url_rewriting( + self, + agent_card: AgentCard, + ): + """ + Tests that the A2AStarletteApplication endpoint correctly handles Agent URL rewriting. + """ + handler = AsyncMock() + app_instance = A2AStarletteApplication(agent_card, handler) + client = TestClient( + app_instance.build(), base_url='https://my-agents.com:5000' + ) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + response.raise_for_status() + assert response.json()['url'] == 'https://my-agents.com:5000' + + response = client.get( + AGENT_CARD_WELL_KNOWN_PATH, + headers={ + 'X-Forwarded-Host': 'my-great-agents.com:5678', + 'X-Forwarded-Proto': 'http', + 'X-Forwarded-Path': '/agents/my-agent' + + AGENT_CARD_WELL_KNOWN_PATH, + }, + ) + assert ( + response.json()['url'] + == 'http://my-great-agents.com:5678/agents/my-agent' + ) + + client = TestClient( + app_instance.build( + agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH + ), + base_url='https://my-mighty-agents.com', + ) + + response = client.get('/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH) + assert ( + response.json()['url'] + == 'https://my-mighty-agents.com/agents/my-agent' + ) + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index c5ea89c40..4059647e5 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -1,7 +1,7 @@ import logging from typing import Any -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -15,6 +15,7 @@ from a2a.server.apps.rest.rest_adapter import RESTAdapter from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( + AgentCapabilities, AgentCard, Message, Part, @@ -24,6 +25,7 @@ TaskStatus, TextPart, ) +from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH logger = logging.getLogger(__name__) @@ -222,5 +224,69 @@ async def test_send_message_success_task( assert expected_response == actual_response +class TestAgentCardHandler: + @pytest.fixture + def agent_card(self): + return AgentCard( + name='APIKeyAgent', + description='An agent that uses API Key auth.', + url='http://localhost:8000', + version='1.0.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + skills=[], + ) + + @pytest.mark.anyio + async def test_agent_card_url_rewriting( + self, + agent_card: AgentCard, + ): + """ + Tests that the REST endpoint correctly handles Agent URL rewriting. + """ + app_instance = A2ARESTFastAPIApplication(agent_card, AsyncMock()) + app = app_instance.build(agent_card_url=AGENT_CARD_WELL_KNOWN_PATH) + client = AsyncClient( + transport=ASGITransport(app=app), + base_url='https://my-agents.com:5000', + ) + + response = await client.get(AGENT_CARD_WELL_KNOWN_PATH) + response.raise_for_status() + assert response.json()['url'] == 'https://my-agents.com:5000' + + response = await client.get( + AGENT_CARD_WELL_KNOWN_PATH, + headers={ + 'X-Forwarded-Host': 'my-great-agents.com:5678', + 'X-Forwarded-Proto': 'http', + 'X-Forwarded-Path': '/agents/my-agent' + + AGENT_CARD_WELL_KNOWN_PATH, + }, + ) + assert ( + response.json()['url'] + == 'http://my-great-agents.com:5678/agents/my-agent' + ) + + app = app_instance.build( + agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH + ) + client = AsyncClient( + transport=ASGITransport(app=app), + base_url='https://my-mighty-agents.com', + ) + + response = await client.get( + '/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH + ) + assert ( + response.json()['url'] + == 'https://my-mighty-agents.com/agents/my-agent' + ) + + if __name__ == '__main__': pytest.main([__file__])