Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/a2a/client/service_parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from collections.abc import Callable
from typing import TypeAlias

from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
get_requested_extensions,
)


ServiceParameters: TypeAlias = dict[str, str]
Expand Down Expand Up @@ -44,17 +47,18 @@ def create_from(


def with_a2a_extensions(extensions: list[str]) -> ServiceParametersUpdate:
"""Create a ServiceParametersUpdate that adds A2A extensions.
"""Create a ServiceParametersUpdate that merges A2A extension URIs.

Args:
extensions: List of extension strings.

Returns:
A function that updates ServiceParameters with the extensions header.
Unions the supplied URIs with any already present in the A2A-Extensions
parameter, deduplicating and emitting them in sorted order. Repeated
calls accumulate rather than overwrite.
"""

def update(parameters: ServiceParameters) -> None:
if extensions:
parameters[HTTP_EXTENSION_HEADER] = ','.join(extensions)
if not extensions:
return
existing = parameters.get(HTTP_EXTENSION_HEADER, '')
merged = sorted(get_requested_extensions([existing, *extensions]))
parameters[HTTP_EXTENSION_HEADER] = ','.join(merged)

return update
53 changes: 53 additions & 0 deletions tests/client/test_service_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for a2a.client.service_parameters module."""

from a2a.client.service_parameters import (
ServiceParametersFactory,
with_a2a_extensions,
)
from a2a.extensions.common import HTTP_EXTENSION_HEADER


def test_with_a2a_extensions_merges_dedupes_and_sorts():
"""Repeated calls accumulate; duplicates collapse; output is sorted."""
parameters = ServiceParametersFactory.create(
[
with_a2a_extensions(['ext-c', 'ext-a']),
with_a2a_extensions(['ext-b', 'ext-a']),
]
)

assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'


def test_with_a2a_extensions_merges_existing_header_value():
"""Pre-existing comma-separated header values are parsed and merged."""
parameters = ServiceParametersFactory.create_from(
{HTTP_EXTENSION_HEADER: 'ext-a, ext-b'},
[with_a2a_extensions(['ext-c'])],
)

assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'


def test_with_a2a_extensions_empty_is_noop():
"""An empty extensions list leaves the header untouched / absent."""
parameters = ServiceParametersFactory.create(
[
with_a2a_extensions(['ext-a']),
with_a2a_extensions([]),
]
)

assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a'
assert HTTP_EXTENSION_HEADER not in ServiceParametersFactory.create(
[with_a2a_extensions([])]
)


def test_with_a2a_extensions_normalizes_input_strings():
"""Input strings are split on commas and stripped, like header values."""
parameters = ServiceParametersFactory.create(
[with_a2a_extensions(['ext-a, ext-b', ' ext-c '])]
)

assert parameters[HTTP_EXTENSION_HEADER] == 'ext-a,ext-b,ext-c'
Loading