Skip to content

Commit 185c0ef

Browse files
committed
fix: update with_a2a_extensions to append instead of overwriting
1 parent 934b595 commit 185c0ef

2 files changed

Lines changed: 96 additions & 4 deletions

File tree

src/a2a/client/service_parameters.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from collections.abc import Callable
22
from typing import TypeAlias
33

4-
from a2a.extensions.common import HTTP_EXTENSION_HEADER
4+
from a2a.extensions.common import (
5+
HTTP_EXTENSION_HEADER,
6+
get_requested_extensions,
7+
)
58

69

710
ServiceParameters: TypeAlias = dict[str, str]
@@ -46,15 +49,26 @@ def create_from(
4649
def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate:
4750
"""Create a ServiceParametersUpdate that adds A2A extensions.
4851
52+
Merges the supplied URIs with any extensions already present in the
53+
A2A-Extensions service parameter, deduplicating and producing a stable
54+
(sorted) order. Calling this multiple times in a chain accumulates the
55+
requested extensions instead of overwriting prior values.
56+
4957
Args:
50-
extensions: List of extension strings.
58+
extensions: List of extension URIs to advertise.
5159
5260
Returns:
5361
A function that updates ServiceParameters with the extensions header.
5462
"""
5563

5664
def update(parameters: ServiceParameters) -> None:
57-
if extensions:
58-
parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions)
65+
if not extensions:
66+
return
67+
existing = parameters.get(HTTP_EXTENSION_HEADER)
68+
merged = sorted(
69+
get_requested_extensions([existing] if existing else [])
70+
| set(extensions)
71+
)
72+
parameters[HTTP_EXTENSION_HEADER] = ','.join(merged)
5973

6074
return update
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""Tests for a2a.client.service_parameters module."""
2+
3+
from a2a.client.service_parameters import (
4+
ServiceParametersFactory,
5+
with_a2a_extensions,
6+
)
7+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
8+
9+
10+
def test_with_a2a_extensions_sets_header_when_empty():
11+
"""First call on empty parameters sets the joined URIs."""
12+
parameters = ServiceParametersFactory.create(
13+
[with_a2a_extensions(['ext-b', 'ext-a'])]
14+
)
15+
16+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b'
17+
18+
19+
def test_with_a2a_extensions_merges_disjoint_calls():
20+
"""A second call with disjoint URIs unions both sets."""
21+
parameters = ServiceParametersFactory.create(
22+
[
23+
with_a2a_extensions(['ext-a']),
24+
with_a2a_extensions(['ext-b']),
25+
]
26+
)
27+
28+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b'
29+
30+
31+
def test_with_a2a_extensions_deduplicates_overlapping():
32+
"""Overlapping URIs do not produce duplicates."""
33+
parameters = ServiceParametersFactory.create(
34+
[
35+
with_a2a_extensions(['ext-a', 'ext-b']),
36+
with_a2a_extensions(['ext-b', 'ext-c']),
37+
]
38+
)
39+
40+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'
41+
42+
43+
def test_with_a2a_extensions_empty_is_noop():
44+
"""Calling with an empty list leaves any existing header untouched."""
45+
parameters = ServiceParametersFactory.create(
46+
[
47+
with_a2a_extensions(['ext-a']),
48+
with_a2a_extensions([]),
49+
]
50+
)
51+
52+
assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a'
53+
54+
55+
def test_with_a2a_extensions_empty_does_not_create_header():
56+
"""Calling with an empty list on empty parameters adds nothing."""
57+
parameters = ServiceParametersFactory.create([with_a2a_extensions([])])
58+
59+
assert HTTP_EXTENSION_HEADER not in parameters
60+
61+
62+
def test_with_a2a_extensions_output_is_sorted():
63+
"""Output ordering is deterministic (sorted) regardless of input order."""
64+
parameters = ServiceParametersFactory.create(
65+
[with_a2a_extensions(['c', 'a', 'b'])]
66+
)
67+
68+
assert parameters[HTTP_EXTENSION_HEADER] == 'a,b,c'
69+
70+
71+
def test_with_a2a_extensions_merges_existing_header_value():
72+
"""Existing comma-separated header values are parsed and merged."""
73+
base = ServiceParametersFactory.create_from(
74+
{HTTP_EXTENSION_HEADER: 'ext-a, ext-b'},
75+
[with_a2a_extensions(['ext-c'])],
76+
)
77+
78+
assert base[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'

0 commit comments

Comments
 (0)