Skip to content

Commit ac279fe

Browse files
committed
fix: address review feedback and add tests
- Move context ownership warning to __init__ (was unreachable) - Remove unreachable owner-is-None guard in _check_context_ownership - Change _BLOCKED_NETWORKS to tuple for immutability - Restore dropped docstrings in card_resolver and default_request_handler - Fix non-ASCII chars in comments (use -> and --) - Add tests/utils/test_url_validation.py with 26 SSRF validation tests
1 parent 4e72411 commit ac279fe

6 files changed

Lines changed: 217 additions & 40 deletions

File tree

src/a2a/client/card_resolver.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ def __init__(
4646
base_url: str,
4747
agent_card_path: str = AGENT_CARD_WELL_KNOWN_PATH,
4848
) -> None:
49+
"""Initializes the A2ACardResolver.
50+
51+
Args:
52+
httpx_client: An async HTTP client instance (e.g., httpx.AsyncClient).
53+
base_url: The base URL of the agent's host.
54+
agent_card_path: The path to the agent card endpoint, relative to the base URL.
55+
"""
4956
self.base_url = base_url.rstrip('/')
5057
self.agent_card_path = agent_card_path.lstrip('/')
5158
self.httpx_client = httpx_client
@@ -56,6 +63,27 @@ async def get_agent_card(
5663
http_kwargs: dict[str, Any] | None = None,
5764
signature_verifier: Callable[[AgentCard], None] | None = None,
5865
) -> AgentCard:
66+
"""Fetches an agent card from a specified path relative to the base_url.
67+
68+
If relative_card_path is None, it defaults to the resolver's configured
69+
agent_card_path (for the public agent card).
70+
71+
Args:
72+
relative_card_path: Optional path to the agent card endpoint,
73+
relative to the base URL. If None, uses the default public
74+
agent card path. Use `'/'` for an empty path.
75+
http_kwargs: Optional dictionary of keyword arguments to pass to the
76+
underlying httpx.get request.
77+
signature_verifier: A callable used to verify the agent card's signatures.
78+
79+
Returns:
80+
An `AgentCard` object representing the agent's capabilities.
81+
82+
Raises:
83+
A2AClientHTTPError: If an HTTP error occurs during the request.
84+
A2AClientJSONError: If the response body cannot be decoded as JSON,
85+
validated against the AgentCard schema, or fails SSRF URL validation.
86+
"""
5987
if not relative_card_path:
6088
path_segment = self.agent_card_path
6189
else:
@@ -77,7 +105,7 @@ async def get_agent_card(
77105
)
78106
agent_card = AgentCard.model_validate(agent_card_data)
79107

80-
# ---- FIX: A2A-SSRF-01 validate card.url before returning ----
108+
# ---- FIX: A2A-SSRF-01 -- validate card.url before returning ----
81109
# Without this check, any caller who controls the card endpoint
82110
# can redirect all subsequent RPC calls to an internal address.
83111
try:

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
Root cause of vulnerability:
77
_setup_message_execution() uses params.message.context_id directly without
88
any ownership check. An attacker who knows a victim's contextId can send a
9-
new task under that context task_manager.get_task() returns None for the
9+
new task under that context -- task_manager.get_task() returns None for the
1010
new task_id, so the original task-level check is never reached.
1111
1212
Fix design:
13-
DefaultRequestHandler maintains a _context_owners dict (context_id owner)
13+
DefaultRequestHandler maintains a _context_owners dict (context_id -> owner)
1414
in memory. When a get_caller_id extractor is configured:
1515
1. On first message for a context_id: record caller as owner.
1616
2. On subsequent messages for same context_id: verify caller matches owner.
17-
If get_caller_id is None (default): no ownership tracking backward compatible.
17+
If get_caller_id is None (default): no ownership tracking -- backward compatible.
1818
1919
Target file: src/a2a/server/request_handlers/default_request_handler.py
2020
"""
@@ -118,7 +118,7 @@ def __init__( # noqa: PLR0913
118118
fingerprint). When provided, the handler tracks which caller
119119
created each contextId and rejects messages from different
120120
callers attempting to join that context (A2A-INJ-01 fix).
121-
If None (default), no ownership tracking is performed
121+
If None (default), no ownership tracking is performed --
122122
backward compatible with existing deployments.
123123
124124
Example::
@@ -147,8 +147,15 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None:
147147
)
148148
# ---- NEW (fix for A2A-INJ-01) ----
149149
self._get_caller_id: CallerIdExtractor | None = get_caller_id
150-
# Maps context_id owner identity; populated on first message per context.
150+
# Maps context_id -> owner identity; populated on first message per context.
151151
self._context_owners: dict[str, str] = {}
152+
if get_caller_id is None:
153+
logger.warning(
154+
'DefaultRequestHandler initialized without get_caller_id: '
155+
'context ownership is not enforced. Cross-user context injection '
156+
'(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id '
157+
'extractor to enable ownership checks.'
158+
)
152159
# ----------------------------------
153160
self._running_agents = {}
154161
self._running_agents_lock = asyncio.Lock()
@@ -168,7 +175,10 @@ async def on_get_task(
168175
async def on_cancel_task(
169176
self, params: TaskIdParams, context: ServerCallContext | None = None
170177
) -> Task | None:
171-
"""Default handler for 'tasks/cancel'."""
178+
"""Default handler for 'tasks/cancel'.
179+
180+
Attempts to cancel the task managed by the `AgentExecutor`.
181+
"""
172182
task: Task | None = await self.task_store.get(params.id, context)
173183
if not task:
174184
raise ServerError(error=TaskNotFoundError())
@@ -225,6 +235,12 @@ async def on_cancel_task(
225235
async def _run_event_stream(
226236
self, request: RequestContext, queue: EventQueue
227237
) -> None:
238+
"""Runs the agent's `execute` method and closes the queue afterwards.
239+
240+
Args:
241+
request: The request context for the agent.
242+
queue: The event queue for the agent to publish to.
243+
"""
228244
await self.agent_executor.execute(request, queue)
229245
await queue.close()
230246

@@ -236,32 +252,13 @@ def _check_context_ownership(
236252
"""Enforce context ownership when get_caller_id is configured.
237253
238254
Called before any message is processed for an existing context_id.
255+
Only invoked when context_id is already present in _context_owners,
256+
which guarantees _get_caller_id is not None and owner is not None.
239257
Raises ServerError(InvalidParamsError) if the caller does not own
240258
the context.
241259
"""
242-
if self._get_caller_id is None:
243-
# Ownership tracking not configured — log warning and allow.
244-
# Operators should configure get_caller_id in production.
245-
logger.warning(
246-
'Context ownership not enforced for context_id=%s: '
247-
'no get_caller_id configured on DefaultRequestHandler. '
248-
'This allows cross-user context injection (A2A-INJ-01 / CWE-639). '
249-
'Provide a get_caller_id extractor to enable ownership checks.',
250-
context_id,
251-
)
252-
return
253-
254-
caller = self._get_caller_id(context)
255-
owner = self._context_owners.get(context_id)
256-
257-
if owner is None:
258-
# Context exists in the store but ownership was not recorded
259-
# (e.g. created before this patch was deployed). Skip check.
260-
logger.debug(
261-
'context_id=%s has no recorded owner; skipping ownership check.',
262-
context_id,
263-
)
264-
return
260+
caller = self._get_caller_id(context) # type: ignore[misc]
261+
owner = self._context_owners[context_id]
265262

266263
if caller is None:
267264
raise ServerError(
@@ -308,10 +305,10 @@ async def _setup_message_execution(
308305
) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]:
309306
context_id = params.message.context_id
310307

311-
# ---- FIX: A2A-INJ-01 enforce context ownership BEFORE task lookup ----
308+
# ---- FIX: A2A-INJ-01 -- enforce context ownership BEFORE task lookup ----
312309
# The check must happen at context_id level, not task level. An attacker
313310
# who sends a new task_id under an existing context_id would otherwise
314-
# bypass a task-level check (get_task() returns None check never runs).
311+
# bypass a task-level check (get_task() returns None -> check never runs).
315312
if context_id and context_id in self._context_owners:
316313
self._check_context_ownership(context_id, context)
317314
# -----------------------------------------------------------------------
@@ -396,7 +393,11 @@ async def on_message_send(
396393
params: MessageSendParams,
397394
context: ServerCallContext | None = None,
398395
) -> Message | Task:
399-
"""Default handler for 'message/send' (non-streaming)."""
396+
"""Default handler for 'message/send' interface (non-streaming).
397+
398+
Starts the agent execution for the message and waits for the final
399+
result (Task or Message).
400+
"""
400401
(
401402
_task_manager,
402403
task_id,
@@ -461,7 +462,11 @@ async def on_message_send_stream(
461462
params: MessageSendParams,
462463
context: ServerCallContext | None = None,
463464
) -> AsyncGenerator[Event]:
464-
"""Default handler for 'message/stream' (streaming)."""
465+
"""Default handler for 'message/stream' (streaming).
466+
467+
Starts the agent execution and yields events as they are produced
468+
by the agent.
469+
"""
465470
(
466471
_task_manager,
467472
task_id,

src/a2a/utils/url_validation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,27 @@
2121
# Networks that must never be reachable via a resolved AgentCard URL.
2222
# Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other
2323
# IANA-reserved blocks that have no legitimate use as public agent endpoints.
24-
_BLOCKED_NETWORKS: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = [
24+
_BLOCKED_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = (
2525
# Loopback
2626
ipaddress.ip_network('127.0.0.0/8'),
2727
ipaddress.ip_network('::1/128'),
2828
# RFC 1918 private ranges
2929
ipaddress.ip_network('10.0.0.0/8'),
3030
ipaddress.ip_network('172.16.0.0/12'),
3131
ipaddress.ip_network('192.168.0.0/16'),
32-
# Link-local covers AWS/GCP/Azure/OCI IMDS (169.254.169.254)
32+
# Link-local -- covers AWS/GCP/Azure/OCI IMDS (169.254.169.254)
3333
ipaddress.ip_network('169.254.0.0/16'),
3434
ipaddress.ip_network('fe80::/10'),
35-
# IPv6 unique local (ULA) equivalent of RFC 1918 for IPv6
35+
# IPv6 unique local (ULA) -- equivalent of RFC 1918 for IPv6
3636
ipaddress.ip_network('fc00::/7'),
37-
# Shared address space (RFC 6598 carrier-grade NAT)
37+
# Shared address space (RFC 6598 -- carrier-grade NAT)
3838
ipaddress.ip_network('100.64.0.0/10'),
3939
# Other IANA reserved / unroutable
4040
ipaddress.ip_network('0.0.0.0/8'),
4141
ipaddress.ip_network('192.0.0.0/24'),
4242
ipaddress.ip_network('198.18.0.0/15'),
4343
ipaddress.ip_network('240.0.0.0/4'),
44-
]
44+
)
4545

4646

4747
class A2ASSRFValidationError(ValueError):
@@ -56,7 +56,7 @@ def validate_agent_card_url(url: str) -> None:
5656
1. URL must be parseable and non-empty.
5757
2. Scheme must be ``http`` or ``https``.
5858
3. Hostname must be present and non-empty.
59-
4. The hostname must resolve to a publicly routable IP address it must
59+
4. The hostname must resolve to a publicly routable IP address -- it must
6060
not resolve to a loopback, private, link-local, or otherwise reserved
6161
address (SSRF / IMDS protection).
6262

tests/client/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""conftest.py for tests/client/
2+
3+
Patches out SSRF DNS validation so that card resolver and transport tests can
4+
use test hostnames (localhost, testserver, example.com) without real DNS
5+
lookups. The validate_agent_card_url function is tested directly in
6+
tests/utils/test_url_validation.py.
7+
8+
Target: tests/client/conftest.py
9+
"""
10+
11+
import pytest
12+
from unittest.mock import patch
13+
14+
15+
@pytest.fixture(autouse=True)
16+
def bypass_ssrf_url_validation():
17+
"""Bypass DNS-based SSRF validation for all tests in tests/client/.
18+
19+
Tests here mock HTTP transports and use synthetic hostnames that do not
20+
resolve to real IP addresses. SSRF URL validation is exercised by its own
21+
dedicated test suite in tests/utils/test_url_validation.py.
22+
"""
23+
with patch('a2a.client.card_resolver.validate_agent_card_url'):
24+
yield

tests/integration/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""conftest.py for tests/integration/
2+
3+
Patches out SSRF DNS validation so that integration tests can use httpx
4+
TestClient's synthetic 'testserver' hostname in AgentCard.url without
5+
triggering real DNS resolution. The validate_agent_card_url function is
6+
tested directly in tests/utils/test_url_validation.py.
7+
8+
Target: tests/integration/conftest.py
9+
"""
10+
11+
import pytest
12+
from unittest.mock import patch
13+
14+
15+
@pytest.fixture(autouse=True)
16+
def bypass_ssrf_url_validation():
17+
"""Bypass DNS-based SSRF validation for all tests in tests/integration/.
18+
19+
Integration tests use httpx's TestClient which binds to the synthetic
20+
'testserver' hostname. This hostname cannot be resolved via DNS.
21+
SSRF URL validation is exercised by its own dedicated test suite in
22+
tests/utils/test_url_validation.py.
23+
"""
24+
with patch('a2a.client.card_resolver.validate_agent_card_url'):
25+
yield

tests/utils/test_url_validation.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Tests for a2a.utils.url_validation (A2A-SSRF-01 fix).
2+
3+
Target: tests/utils/test_url_validation.py
4+
"""
5+
6+
import pytest
7+
8+
from a2a.utils.url_validation import A2ASSRFValidationError, validate_agent_card_url
9+
10+
11+
class TestValidateAgentCardUrlScheme:
12+
"""URL scheme validation."""
13+
14+
@pytest.mark.parametrize('url', [
15+
'file:///etc/passwd',
16+
'gopher://internal/1',
17+
'ftp://files.example.com/secret',
18+
'dict://internal/',
19+
'ldap://ldap.example.com/',
20+
'',
21+
])
22+
def test_non_http_schemes_are_blocked(self, url):
23+
with pytest.raises(A2ASSRFValidationError):
24+
validate_agent_card_url(url)
25+
26+
@pytest.mark.parametrize('url', [
27+
'http://example.com/rpc',
28+
'https://example.com/rpc',
29+
'HTTP://EXAMPLE.COM/RPC',
30+
'HTTPS://EXAMPLE.COM/RPC',
31+
])
32+
def test_http_and_https_are_allowed(self, url):
33+
# Should not raise — only scheme + hostname check, DNS may vary
34+
# We only verify scheme acceptance here; real DNS tested separately.
35+
try:
36+
validate_agent_card_url(url)
37+
except A2ASSRFValidationError as exc:
38+
# Accept DNS resolution failure — scheme was accepted
39+
assert 'could not be resolved' in str(exc) or 'blocked network' in str(exc)
40+
41+
42+
class TestValidateAgentCardUrlPrivateIPs:
43+
"""Private / reserved IP range blocking."""
44+
45+
@pytest.mark.parametrize('url,label', [
46+
('http://127.0.0.1/rpc', 'loopback IPv4'),
47+
('http://127.1.2.3/rpc', 'loopback IPv4 (non-zero host)'),
48+
('http://[::1]/rpc', 'loopback IPv6'),
49+
('http://10.0.0.1/rpc', 'RFC 1918 10/8'),
50+
('http://10.255.255.255/rpc', 'RFC 1918 10/8 broadcast'),
51+
('http://172.16.0.1/rpc', 'RFC 1918 172.16/12'),
52+
('http://172.31.255.255/rpc', 'RFC 1918 172.31 (last in range)'),
53+
('http://192.168.1.1/rpc', 'RFC 1918 192.168/16'),
54+
('http://169.254.169.254/latest/meta-data/', 'AWS IMDS'),
55+
('http://169.254.0.1/rpc', 'link-local'),
56+
('http://100.64.0.1/rpc', 'shared address space RFC 6598'),
57+
])
58+
def test_private_addresses_are_blocked(self, url, label):
59+
with pytest.raises(A2ASSRFValidationError, match='blocked network'):
60+
validate_agent_card_url(url)
61+
62+
def test_public_ip_is_allowed(self):
63+
"""A routable public IP should not be blocked."""
64+
# 93.184.216.34 is example.com — guaranteed public
65+
try:
66+
validate_agent_card_url('http://93.184.216.34/rpc')
67+
except A2ASSRFValidationError as exc:
68+
# Only acceptable failure is DNS (not a blocked-network error)
69+
assert 'could not be resolved' in str(exc)
70+
pytest.skip('DNS not available in this environment')
71+
72+
73+
class TestValidateAgentCardUrlHostname:
74+
"""Hostname-level checks."""
75+
76+
def test_missing_hostname_is_blocked(self):
77+
with pytest.raises(A2ASSRFValidationError, match='no hostname'):
78+
validate_agent_card_url('http:///path')
79+
80+
def test_empty_url_is_blocked(self):
81+
with pytest.raises(A2ASSRFValidationError, match='must not be empty'):
82+
validate_agent_card_url('')
83+
84+
85+
class TestA2ASSRFValidationError:
86+
"""Exception type tests."""
87+
88+
def test_is_subclass_of_value_error(self):
89+
assert issubclass(A2ASSRFValidationError, ValueError)
90+
91+
def test_raises_with_descriptive_message(self):
92+
with pytest.raises(A2ASSRFValidationError) as exc_info:
93+
validate_agent_card_url('http://127.0.0.1/rpc')
94+
assert '127.0.0.1' in str(exc_info.value)
95+
assert 'CWE-918' in str(exc_info.value)

0 commit comments

Comments
 (0)