Skip to content
Closed
10 changes: 10 additions & 0 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.request_utils import update_card_rpc_url_from_request
from a2a.types import (
A2AError,
A2ARequest,
Expand Down Expand Up @@ -89,6 +90,7 @@
Request = Any
JSONResponse = Any
Response = Any
URL = Any
HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any

MAX_CONTENT_LENGTH = 1_000_000
Expand Down Expand Up @@ -568,8 +570,12 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
)

card_to_serve = self.agent_card
rpc_url = card_to_serve.url
if self.card_modifier:
card_to_serve = self.card_modifier(card_to_serve)
# If agent's RPC URL was not modified, we build it dynamically.
if rpc_url == card_to_serve.url:
update_card_rpc_url_from_request(card_to_serve, request)

return JSONResponse(
card_to_serve.model_dump(
Expand All @@ -594,13 +600,17 @@ async def _handle_get_authenticated_extended_agent_card(

card_to_serve = self.extended_agent_card

rpc_url = card_to_serve.url if card_to_serve else None
if self.extended_card_modifier:
context = self._context_builder.build(request)
# If no base extended card is provided, pass the public card to the modifier
base_card = card_to_serve if card_to_serve else self.agent_card
card_to_serve = self.extended_card_modifier(base_card, context)

if card_to_serve:
# If agent's RPC URL was not modified, we build it dynamically.
if rpc_url == card_to_serve.url:
update_card_rpc_url_from_request(card_to_serve, request)
return JSONResponse(
card_to_serve.model_dump(
exclude_none=True,
Expand Down
7 changes: 7 additions & 0 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.request_handlers.rest_handler import RESTHandler
from a2a.server.request_utils import update_card_rpc_url_from_request
from a2a.types import AgentCard, AuthenticatedExtendedCardNotConfiguredError
from a2a.utils.error_handlers import (
rest_error_handler,
Expand Down Expand Up @@ -145,8 +146,11 @@ async def handle_get_agent_card(
A JSONResponse containing the agent card data.
"""
card_to_serve = self.agent_card
rpc_url = card_to_serve.url
if self.card_modifier:
card_to_serve = self.card_modifier(card_to_serve)
if rpc_url == card_to_serve.url:
update_card_rpc_url_from_request(card_to_serve, request)

return card_to_serve.model_dump(mode='json', exclude_none=True)

Expand Down Expand Up @@ -175,12 +179,15 @@ async def handle_authenticated_agent_card(

if not card_to_serve:
card_to_serve = self.agent_card
rpc_url = card_to_serve.url

if self.extended_card_modifier:
context = self._context_builder.build(request)
# If no base extended card is provided, pass the public card to the modifier
base_card = card_to_serve if card_to_serve else self.agent_card
card_to_serve = self.extended_card_modifier(base_card, context)
if rpc_url == card_to_serve.url:
update_card_rpc_url_from_request(card_to_serve, request)

return card_to_serve.model_dump(mode='json', exclude_none=True)

Expand Down
73 changes: 73 additions & 0 deletions src/a2a/server/request_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import TYPE_CHECKING, Any

from a2a.types import AgentCard


if TYPE_CHECKING:
from starlette.datastructures import URL
from starlette.requests import Request

_package_starlette_installed = True
else:
try:
from starlette.datastructures import URL
from starlette.requests import Request

_package_starlette_installed = True
except ImportError:
_package_starlette_installed = False
URL = Any
Request = Any


def update_card_rpc_url_from_request(
agent_card: AgentCard, request: Request
) -> None:
"""Modifies Agent's RPC URL based on the AgentCard request.

Args:
agent_card (AgentCard): Original AgentCard
request (Request): AgentCard request
"""
rpc_url = URL(agent_card.url)
rpc_path = rpc_url.path
port = None
if 'X-Forwarded-Host' in request.headers:
host = request.headers['X-Forwarded-Host']
else:
host = request.url.hostname or rpc_url.hostname or 'localhost'
port = request.url.port

if 'X-Forwarded-Proto' in request.headers:
scheme = request.headers['X-Forwarded-Proto']
port = None
else:
scheme = request.url.scheme
if not scheme:
scheme = 'http'
if ':' in host: # type: ignore
comps = host.rsplit(':', 1) # type: ignore
host = comps[0]
port = int(comps[1]) if comps[1] else port

# Handle URL maps,
# e.g. "agents/my-agent/.well-known/agent-card.json"
if 'X-Forwarded-Path' in request.headers:
forwarded_path = request.headers['X-Forwarded-Path'].strip()
if (
forwarded_path
and request.url.path != forwarded_path
and forwarded_path.endswith(request.url.path)
):
# "agents/my-agent" for "agents/my-agent/.well-known/agent-card.json"
extra_path = forwarded_path[: -len(request.url.path)]
new_path = extra_path + rpc_path
# If original path was just "/",
# we remove trailing "/" in the extended one
if len(new_path) > 1 and rpc_path == '/':
new_path = new_path.rstrip('/')
rpc_path = new_path

agent_card.url = str(
rpc_url.replace(hostname=host, port=port, scheme=scheme, path=rpc_path)
)
3 changes: 3 additions & 0 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio

from collections.abc import AsyncGenerator
from typing import NamedTuple
from unittest.mock import ANY, AsyncMock
Expand All @@ -7,6 +8,7 @@
import httpx
import pytest
import pytest_asyncio

from grpc.aio import Channel

from a2a.client.transports import JsonRpcTransport, RestTransport
Expand Down Expand Up @@ -36,6 +38,7 @@
TransportProtocol,
)


# --- Test Constants ---

TASK_FROM_STREAM = Task(
Expand Down
61 changes: 61 additions & 0 deletions tests/server/apps/jsonrpc/test_jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
RequestHandler,
) # For mock spec
from a2a.types import (
AgentCapabilities,
AgentCard,
Message,
MessageSendParams,
Expand All @@ -36,6 +37,7 @@
SendMessageSuccessResponse,
TextPart,
)
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH


# --- StarletteUserProxy Tests ---
Expand Down Expand Up @@ -356,5 +358,64 @@ def side_effect(request, context: ServerCallContext):
}


class TestAgentCardHandler:
@pytest.fixture
def agent_card(self):
return AgentCard(
name='APIKeyAgent',
description='An agent that uses API Key auth.',
url='http://localhost:8000',
version='1.0.0',
capabilities=AgentCapabilities(),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
skills=[],
)

def test_agent_card_url_rewriting(
self,
agent_card: AgentCard,
):
"""
Tests that the A2AStarletteApplication endpoint correctly handles Agent URL rewriting.
"""
handler = AsyncMock()
app_instance = A2AStarletteApplication(agent_card, handler)
client = TestClient(
app_instance.build(), base_url='https://my-agents.com:5000'
)

response = client.get(AGENT_CARD_WELL_KNOWN_PATH)
response.raise_for_status()
assert response.json()['url'] == 'https://my-agents.com:5000'

response = client.get(
AGENT_CARD_WELL_KNOWN_PATH,
headers={
'X-Forwarded-Host': 'my-great-agents.com:5678',
'X-Forwarded-Proto': 'http',
'X-Forwarded-Path': '/agents/my-agent'
+ AGENT_CARD_WELL_KNOWN_PATH,
},
)
assert (
response.json()['url']
== 'http://my-great-agents.com:5678/agents/my-agent'
)

client = TestClient(
app_instance.build(
agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
),
base_url='https://my-mighty-agents.com',
)

response = client.get('/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH)
assert (
response.json()['url']
== 'https://my-mighty-agents.com/agents/my-agent'
)


if __name__ == '__main__':
pytest.main([__file__])
68 changes: 67 additions & 1 deletion tests/server/apps/rest/test_rest_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from typing import Any
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest

Expand All @@ -15,6 +15,7 @@
from a2a.server.apps.rest.rest_adapter import RESTAdapter
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types import (
AgentCapabilities,
AgentCard,
Message,
Part,
Expand All @@ -24,6 +25,7 @@
TaskStatus,
TextPart,
)
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -222,5 +224,69 @@ async def test_send_message_success_task(
assert expected_response == actual_response


class TestAgentCardHandler:
@pytest.fixture
def agent_card(self):
return AgentCard(
name='APIKeyAgent',
description='An agent that uses API Key auth.',
url='http://localhost:8000',
version='1.0.0',
capabilities=AgentCapabilities(),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
skills=[],
)

@pytest.mark.anyio
async def test_agent_card_url_rewriting(
self,
agent_card: AgentCard,
):
"""
Tests that the REST endpoint correctly handles Agent URL rewriting.
"""
app_instance = A2ARESTFastAPIApplication(agent_card, AsyncMock())
app = app_instance.build(agent_card_url=AGENT_CARD_WELL_KNOWN_PATH)
client = AsyncClient(
transport=ASGITransport(app=app),
base_url='https://my-agents.com:5000',
)

response = await client.get(AGENT_CARD_WELL_KNOWN_PATH)
response.raise_for_status()
assert response.json()['url'] == 'https://my-agents.com:5000'

response = await client.get(
AGENT_CARD_WELL_KNOWN_PATH,
headers={
'X-Forwarded-Host': 'my-great-agents.com:5678',
'X-Forwarded-Proto': 'http',
'X-Forwarded-Path': '/agents/my-agent'
+ AGENT_CARD_WELL_KNOWN_PATH,
},
)
assert (
response.json()['url']
== 'http://my-great-agents.com:5678/agents/my-agent'
)

app = app_instance.build(
agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
)
client = AsyncClient(
transport=ASGITransport(app=app),
base_url='https://my-mighty-agents.com',
)

response = await client.get(
'/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
)
assert (
response.json()['url']
== 'https://my-mighty-agents.com/agents/my-agent'
)


if __name__ == '__main__':
pytest.main([__file__])
Loading