Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/a2a/client/transports/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ def _get_grpc_metadata(
extensions: list[str] | None = None,
) -> list[tuple[str, str]] | None:
"""Creates gRPC metadata for extensions."""
if extensions is not None:
return [(HTTP_EXTENSION_HEADER, ','.join(extensions))]
if self.extensions is not None:
return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))]
ext_to_use = extensions if extensions is not None else self.extensions
if ext_to_use is not None:
return [(HTTP_EXTENSION_HEADER.lower(), ','.join(ext_to_use))]
return None

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def _get_metadata_value(
) -> list[str]:
md = context.invocation_metadata
raw_values: list[str | bytes] = []
lower_key = key.lower()
if isinstance(md, Metadata):
raw_values = md.get_all(key)
raw_values = md.get_all(lower_key)
elif isinstance(md, Sequence):
lower_key = key.lower()
raw_values = [e for (k, e) in md if k.lower() == lower_key]
Comment thread
cchinchilla-dev marked this conversation as resolved.
return [e if isinstance(e, str) else e.decode('utf-8') for e in raw_values]

Expand Down Expand Up @@ -417,7 +417,7 @@ def _set_extension_metadata(
if server_context.activated_extensions:
context.set_trailing_metadata(
[
(HTTP_EXTENSION_HEADER, e)
(HTTP_EXTENSION_HEADER.lower(), e)
for e in sorted(server_context.activated_extensions)
]
)
47 changes: 34 additions & 13 deletions tests/client/transports/test_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def test_send_message_task_response(
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v3',
)
]
Expand All @@ -228,7 +228,7 @@ async def test_send_message_message_response(
_, kwargs = mock_grpc_stub.SendMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
]
Expand Down Expand Up @@ -283,7 +283,7 @@ async def test_send_message_streaming( # noqa: PLR0913
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
assert kwargs['metadata'] == [
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
]
Expand Down Expand Up @@ -313,7 +313,7 @@ async def test_get_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand All @@ -338,7 +338,7 @@ async def test_get_task_with_history(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand All @@ -363,7 +363,9 @@ async def test_cancel_task(

mock_grpc_stub.CancelTask.assert_awaited_once_with(
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'),
metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')],
metadata=[
(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')
],
)
assert response.status.state == TaskState.canceled

Expand Down Expand Up @@ -395,7 +397,7 @@ async def test_set_task_callback_with_valid_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand Down Expand Up @@ -458,7 +460,7 @@ async def test_get_task_callback_with_valid_task(
),
metadata=[
(
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER.lower(),
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
)
],
Expand Down Expand Up @@ -506,27 +508,27 @@ async def test_get_task_callback_with_invalid_task(
(
['ext1'],
None,
[(HTTP_EXTENSION_HEADER, 'ext1')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
), # Case 2: Initial, No input
(
None,
['ext2'],
[(HTTP_EXTENSION_HEADER, 'ext2')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
), # Case 3: No initial, Input
(
['ext1'],
['ext2'],
[(HTTP_EXTENSION_HEADER, 'ext2')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
), # Case 4: Initial, Input (override)
(
['ext1'],
['ext2', 'ext3'],
[(HTTP_EXTENSION_HEADER, 'ext2,ext3')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
), # Case 5: Initial, Multiple inputs (override)
(
['ext1', 'ext2'],
['ext3'],
[(HTTP_EXTENSION_HEADER, 'ext3')],
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
), # Case 6: Multiple initial, Single input (override)
],
)
Expand All @@ -540,3 +542,22 @@ def test_get_grpc_metadata(
grpc_transport.extensions = initial_extensions
metadata = grpc_transport._get_grpc_metadata(input_extensions)
assert metadata == expected_metadata


@pytest.mark.parametrize(
'test_extensions',
[
(['ext1']), # Test with explicit extensions
(None), # Test with transport's default extensions
],
)
def test_get_grpc_metadata_uses_lowercase_header_key(
grpc_transport: GrpcTransport,
test_extensions: list[str] | None,
) -> None:
"""Test gRPC metadata header key is always lowercase."""
# Regression: gRPC rejects non-lowercase metadata keys
metadata = grpc_transport._get_grpc_metadata(test_extensions)
assert metadata is not None
key, _ = metadata[0]
assert key == key.lower()
Comment thread
cchinchilla-dev marked this conversation as resolved.
Outdated
20 changes: 10 additions & 10 deletions tests/server/request_handlers/test_grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ async def test_send_message_with_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
)

def side_effect(request, context: ServerCallContext):
Expand Down Expand Up @@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext):
mock_grpc_context.set_trailing_metadata.call_args.args[0]
)
assert set(called_metadata) == {
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'baz'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
}

async def test_send_message_with_comma_separated_extensions(
Expand All @@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo ,, bar,'),
(HTTP_EXTENSION_HEADER, 'baz , bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'),
(HTTP_EXTENSION_HEADER.lower(), 'baz , bar'),
)
mock_request_handler.on_message_send.return_value = types.Message(
message_id='1',
Expand All @@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions(
mock_grpc_context: AsyncMock,
) -> None:
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'bar'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
)

async def side_effect(request, context: ServerCallContext):
Expand Down Expand Up @@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext):
mock_grpc_context.set_trailing_metadata.call_args.args[0]
)
assert set(called_metadata) == {
(HTTP_EXTENSION_HEADER, 'foo'),
(HTTP_EXTENSION_HEADER, 'baz'),
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
}
Loading