Skip to content
Closed
62 changes: 62 additions & 0 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from sse_starlette.sse import EventSourceResponse
from starlette.applications import Starlette
from starlette.authentication import BaseUser
from starlette.datastructures import URL
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand All @@ -70,6 +71,7 @@
from sse_starlette.sse import EventSourceResponse
from starlette.applications import Starlette
from starlette.authentication import BaseUser
from starlette.datastructures import URL
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand Down Expand Up @@ -486,6 +488,58 @@ async def event_generator(
headers=headers,
)

def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
"""Modifies Agent's RPC URL based on the AgentCard request.

Args:
agent_card (AgentCard): Original AgentCard
request (Request): AgentCard request
"""
rpc_url = URL(agent_card.url)
rpc_path = rpc_url.path
port = None
if 'X-Forwarded-Host' in request.headers:
host = request.headers['X-Forwarded-Host']
else:
host = request.url.hostname
port = request.url.port

if 'X-Forwarded-Proto' in request.headers:
scheme = request.headers['X-Forwarded-Proto']
port = None
else:
scheme = request.url.scheme
if not scheme:
scheme = 'http'
if ':' in host: # type: ignore
comps = host.rsplit(':', 1) # type: ignore
host = comps[0]
port = comps[1]
Comment thread
vladkol marked this conversation as resolved.
Outdated

# Handle URL maps,
# e.g. "agents/my-agent/.well-known/agent-card.json"
if 'X-Forwarded-Path' in request.headers:
forwarded_path = request.headers['X-Forwarded-Path'].strip()
if (
forwarded_path
and request.url.path != forwarded_path
and forwarded_path.endswith(request.url.path)
):
# "agents/my-agent" for "agents/my-agent/.well-known/agent-card.json"
extra_path = forwarded_path[: -len(request.url.path)]
new_path = extra_path + rpc_path
# If original path was just "/",
# we remove trailing "/" in the the extended one
if len(new_path) > 1 and rpc_path == '/':
new_path = new_path.rstrip('/')
rpc_path = new_path

agent_card.url = str(
rpc_url.replace(
hostname=host, port=port, scheme=scheme, path=rpc_path
)
)
Comment thread
vladkol marked this conversation as resolved.
Outdated

async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
"""Handles GET requests for the agent card endpoint.

Expand All @@ -502,8 +556,12 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse:
)

card_to_serve = self.agent_card
rpc_url = card_to_serve.url
if self.card_modifier:
card_to_serve = self.card_modifier(card_to_serve)
# If agent's RPC URL was not modified, we build it dynamically.
if rpc_url == card_to_serve.url:
self._modify_rpc_url(card_to_serve, request)

return JSONResponse(
card_to_serve.model_dump(
Expand All @@ -528,13 +586,17 @@ async def _handle_get_authenticated_extended_agent_card(

card_to_serve = self.extended_agent_card

rpc_url = card_to_serve.url if card_to_serve else None
if self.extended_card_modifier:
context = self._context_builder.build(request)
# If no base extended card is provided, pass the public card to the modifier
base_card = card_to_serve if card_to_serve else self.agent_card
card_to_serve = self.extended_card_modifier(base_card, context)

if card_to_serve:
# If agent's RPC URL was not modified, we build it dynamically.
if rpc_url == card_to_serve.url:
self._modify_rpc_url(card_to_serve, request)
return JSONResponse(
card_to_serve.model_dump(
exclude_none=True,
Expand Down
58 changes: 57 additions & 1 deletion src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

if TYPE_CHECKING:
from sse_starlette.sse import EventSourceResponse
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

Expand All @@ -15,6 +16,7 @@
else:
try:
from sse_starlette.sse import EventSourceResponse
from starlette.datastructures import URL
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

Expand Down Expand Up @@ -119,7 +121,8 @@ async def handle_get_agent_card(self, request: Request) -> JSONResponse:
A JSONResponse containing the agent card data.
"""
# The public agent card is a direct serialization of the agent_card
# provided at initialization.
# provided at initialization except for the RPC URL.
self._modify_rpc_url(self.agent_card, request)
return JSONResponse(
self.agent_card.model_dump(mode='json', exclude_none=True)
)
Expand All @@ -145,10 +148,63 @@ async def handle_authenticated_agent_card(
message='Authenticated card not supported'
)
)
self._modify_rpc_url(self.agent_card, request)
return JSONResponse(
self.agent_card.model_dump(mode='json', exclude_none=True)
)

def _modify_rpc_url(self, agent_card: AgentCard, request: Request):
"""Modifies Agent's RPC URL based on the AgentCard request.

Args:
agent_card (AgentCard): Original AgentCard
request (Request): AgentCard request
"""
rpc_url = URL(agent_card.url)
rpc_path = rpc_url.path
port = None
if 'X-Forwarded-Host' in request.headers:
host = request.headers['X-Forwarded-Host']
else:
host = request.url.hostname
port = request.url.port

if 'X-Forwarded-Proto' in request.headers:
scheme = request.headers['X-Forwarded-Proto']
port = None
else:
scheme = request.url.scheme
if not scheme:
scheme = 'http'
if ':' in host: # type: ignore
comps = host.rsplit(':', 1) # type: ignore
host = comps[0]
port = comps[1]
Comment thread
vladkol marked this conversation as resolved.
Outdated

# Handle URL maps,
# e.g. "agents/my-agent/.well-known/agent-card.json"
if 'X-Forwarded-Path' in request.headers:
forwarded_path = request.headers['X-Forwarded-Path'].strip()
if (
forwarded_path
and request.url.path != forwarded_path
and forwarded_path.endswith(request.url.path)
):
# "agents/my-agent" for "agents/my-agent/.well-known/agent-card.json"
extra_path = forwarded_path[: -len(request.url.path)]
new_path = extra_path + rpc_path
# If original path was just "/",
# we remove trailing "/" in the the extended one
if len(new_path) > 1 and rpc_path == '/':
new_path = new_path.rstrip('/')
rpc_path = new_path

agent_card.url = str(
rpc_url.replace(
hostname=host, port=port, scheme=scheme, path=rpc_path
)
)

def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
"""Constructs a dictionary of API routes and their corresponding handlers.

Expand Down
61 changes: 61 additions & 0 deletions tests/server/apps/jsonrpc/test_jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
RequestHandler,
) # For mock spec
from a2a.types import (
AgentCapabilities,
AgentCard,
Message,
MessageSendParams,
Expand All @@ -36,6 +37,7 @@
SendMessageSuccessResponse,
TextPart,
)
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH


# --- StarletteUserProxy Tests ---
Expand Down Expand Up @@ -356,5 +358,64 @@ def side_effect(request, context: ServerCallContext):
}


class TestAgentCardHandler:
@pytest.fixture
def agent_card(self):
return AgentCard(
name='APIKeyAgent',
description='An agent that uses API Key auth.',
url='http://localhost:8000',
version='1.0.0',
capabilities=AgentCapabilities(),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
skills=[],
)

def test_agent_card_url_rewriting(
self,
agent_card: AgentCard,
):
"""
Tests that the A2AStarletteApplication endpoint correctly handles Agent URL rewriting.
"""
handler = AsyncMock()
app_instance = A2AStarletteApplication(agent_card, handler)
client = TestClient(
app_instance.build(), base_url='https://my-agents.com:5000'
)

response = client.get(AGENT_CARD_WELL_KNOWN_PATH)
response.raise_for_status()
assert response.json()['url'] == 'https://my-agents.com:5000'

response = client.get(
AGENT_CARD_WELL_KNOWN_PATH,
headers={
'X-Forwarded-Host': 'my-great-agents.com:5678',
'X-Forwarded-Proto': 'http',
'X-Forwarded-Path': '/agents/my-agent'
+ AGENT_CARD_WELL_KNOWN_PATH,
},
)
assert (
response.json()['url']
== 'http://my-great-agents.com:5678/agents/my-agent'
)

client = TestClient(
app_instance.build(
agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
),
base_url='https://my-mighty-agents.com',
)

response = client.get('/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH)
assert (
response.json()['url']
== 'https://my-mighty-agents.com/agents/my-agent'
)


if __name__ == '__main__':
pytest.main([__file__])
68 changes: 67 additions & 1 deletion tests/server/apps/rest/test_rest_fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from typing import Any
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock

import pytest

Expand All @@ -15,6 +15,7 @@
from a2a.server.apps.rest.rest_adapter import RESTAdapter
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types import (
AgentCapabilities,
AgentCard,
Message,
Part,
Expand All @@ -24,6 +25,7 @@
TaskStatus,
TextPart,
)
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -222,5 +224,69 @@ async def test_send_message_success_task(
assert expected_response == actual_response


class TestAgentCardHandler:
@pytest.fixture
def agent_card(self):
return AgentCard(
name='APIKeyAgent',
description='An agent that uses API Key auth.',
url='http://localhost:8000',
version='1.0.0',
capabilities=AgentCapabilities(),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
skills=[],
)

@pytest.mark.anyio
async def test_agent_card_url_rewriting(
self,
agent_card: AgentCard,
):
"""
Tests that the REST endpoint correctly handles Agent URL rewriting.
"""
app_instance = A2ARESTFastAPIApplication(agent_card, AsyncMock())
app = app_instance.build(agent_card_url=AGENT_CARD_WELL_KNOWN_PATH)
client = AsyncClient(
transport=ASGITransport(app=app),
base_url='https://my-agents.com:5000',
)

response = await client.get(AGENT_CARD_WELL_KNOWN_PATH)
response.raise_for_status()
assert response.json()['url'] == 'https://my-agents.com:5000'

response = await client.get(
AGENT_CARD_WELL_KNOWN_PATH,
headers={
'X-Forwarded-Host': 'my-great-agents.com:5678',
'X-Forwarded-Proto': 'http',
'X-Forwarded-Path': '/agents/my-agent'
+ AGENT_CARD_WELL_KNOWN_PATH,
},
)
assert (
response.json()['url']
== 'http://my-great-agents.com:5678/agents/my-agent'
)

app = app_instance.build(
agent_card_url='/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
)
client = AsyncClient(
transport=ASGITransport(app=app),
base_url='https://my-mighty-agents.com',
)

response = await client.get(
'/agents/my-agent' + AGENT_CARD_WELL_KNOWN_PATH
)
assert (
response.json()['url']
== 'https://my-mighty-agents.com/agents/my-agent'
)


if __name__ == '__main__':
pytest.main([__file__])