diff --git a/src/a2a/client/service_parameters.py b/src/a2a/client/service_parameters.py index cef250807..39fe79ce1 100644 --- a/src/a2a/client/service_parameters.py +++ b/src/a2a/client/service_parameters.py @@ -1,7 +1,10 @@ from collections.abc import Callable from typing import TypeAlias -from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, + get_requested_extensions, +) ServiceParameters: TypeAlias = dict[str, str] @@ -44,17 +47,18 @@ def create_from( def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate: - """Create a ServiceParametersUpdate that adds A2A extensions. + """Create a ServiceParametersUpdate that merges A2A extension URIs. - Args: - extensions: List of extension strings. - - Returns: - A function that updates ServiceParameters with the extensions header. + Unions the supplied URIs with any already present in the A2A-Extensions + parameter, deduplicating and emitting them in sorted order. Repeated + calls accumulate rather than overwrite. """ def update(parameters: ServiceParameters) -> None: - if extensions: - parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions) + if not extensions: + return + existing = parameters.get(HTTP_EXTENSION_HEADER, '') + merged = sorted(get_requested_extensions([existing, *extensions])) + parameters[HTTP_EXTENSION_HEADER] = ','.join(merged) return update diff --git a/tests/client/test_service_parameters.py b/tests/client/test_service_parameters.py new file mode 100644 index 000000000..fbabd9719 --- /dev/null +++ b/tests/client/test_service_parameters.py @@ -0,0 +1,53 @@ +"""Tests for a2a.client.service_parameters module.""" + +from a2a.client.service_parameters import ( + ServiceParametersFactory, + with_a2a_extensions, +) +from a2a.extensions.common import HTTP_EXTENSION_HEADER + + +def test_with_a2a_extensions_merges_dedupes_and_sorts(): + """Repeated calls accumulate; duplicates collapse; output is sorted.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-c', 'ext-a']), + with_a2a_extensions(['ext-b', 'ext-a']), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_merges_existing_header_value(): + """Pre-existing comma-separated header values are parsed and merged.""" + parameters = ServiceParametersFactory.create_from( + {HTTP_EXTENSION_HEADER: 'ext-a, ext-b'}, + [with_a2a_extensions(['ext-c'])], + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c' + + +def test_with_a2a_extensions_empty_is_noop(): + """An empty extensions list leaves the header untouched / absent.""" + parameters = ServiceParametersFactory.create( + [ + with_a2a_extensions(['ext-a']), + with_a2a_extensions([]), + ] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a' + assert HTTP_EXTENSION_HEADER not in ServiceParametersFactory.create( + [with_a2a_extensions([])] + ) + + +def test_with_a2a_extensions_normalizes_input_strings(): + """Input strings are split on commas and stripped, like header values.""" + parameters = ServiceParametersFactory.create( + [with_a2a_extensions(['ext-a, ext-b', ' ext-c '])] + ) + + assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'