Skip to content

Commit 7bd55a5

Browse files
fix(grpc): normalize extension metadata header key to lowercase
1 parent 957e92b commit 7bd55a5

4 files changed

Lines changed: 42 additions & 27 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,9 @@ def _get_grpc_metadata(
6565
) -> list[tuple[str, str]] | None:
6666
"""Creates gRPC metadata for extensions."""
6767
if extensions is not None:
68-
return [(HTTP_EXTENSION_HEADER, ','.join(extensions))]
68+
return [(HTTP_EXTENSION_HEADER.lower(), ','.join(extensions))]
6969
if self.extensions is not None:
70-
return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))]
70+
return [(HTTP_EXTENSION_HEADER.lower(), ','.join(self.extensions))]
7171
return None
7272

7373
@classmethod

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _get_metadata_value(
5656
md = context.invocation_metadata
5757
raw_values: list[str | bytes] = []
5858
if isinstance(md, Metadata):
59-
raw_values = md.get_all(key)
59+
raw_values = md.get_all(key.lower())
6060
elif isinstance(md, Sequence):
6161
lower_key = key.lower()
6262
raw_values = [e for (k, e) in md if k.lower() == lower_key]
@@ -417,7 +417,7 @@ def _set_extension_metadata(
417417
if server_context.activated_extensions:
418418
context.set_trailing_metadata(
419419
[
420-
(HTTP_EXTENSION_HEADER, e)
420+
(HTTP_EXTENSION_HEADER.lower(), e)
421421
for e in sorted(server_context.activated_extensions)
422422
]
423423
)

tests/client/transports/test_grpc_client.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def test_send_message_task_response(
202202
_, kwargs = mock_grpc_stub.SendMessage.call_args
203203
assert kwargs['metadata'] == [
204204
(
205-
HTTP_EXTENSION_HEADER,
205+
HTTP_EXTENSION_HEADER.lower(),
206206
'https://example.com/test-ext/v3',
207207
)
208208
]
@@ -228,7 +228,7 @@ async def test_send_message_message_response(
228228
_, kwargs = mock_grpc_stub.SendMessage.call_args
229229
assert kwargs['metadata'] == [
230230
(
231-
HTTP_EXTENSION_HEADER,
231+
HTTP_EXTENSION_HEADER.lower(),
232232
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
233233
)
234234
]
@@ -283,7 +283,7 @@ async def test_send_message_streaming( # noqa: PLR0913
283283
_, kwargs = mock_grpc_stub.SendStreamingMessage.call_args
284284
assert kwargs['metadata'] == [
285285
(
286-
HTTP_EXTENSION_HEADER,
286+
HTTP_EXTENSION_HEADER.lower(),
287287
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
288288
)
289289
]
@@ -313,7 +313,7 @@ async def test_get_task(
313313
),
314314
metadata=[
315315
(
316-
HTTP_EXTENSION_HEADER,
316+
HTTP_EXTENSION_HEADER.lower(),
317317
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
318318
)
319319
],
@@ -338,7 +338,7 @@ async def test_get_task_with_history(
338338
),
339339
metadata=[
340340
(
341-
HTTP_EXTENSION_HEADER,
341+
HTTP_EXTENSION_HEADER.lower(),
342342
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
343343
)
344344
],
@@ -363,7 +363,7 @@ async def test_cancel_task(
363363

364364
mock_grpc_stub.CancelTask.assert_awaited_once_with(
365365
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'),
366-
metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')],
366+
metadata=[(HTTP_EXTENSION_HEADER.lower(), 'https://example.com/test-ext/v3')],
367367
)
368368
assert response.status.state == TaskState.canceled
369369

@@ -395,7 +395,7 @@ async def test_set_task_callback_with_valid_task(
395395
),
396396
metadata=[
397397
(
398-
HTTP_EXTENSION_HEADER,
398+
HTTP_EXTENSION_HEADER.lower(),
399399
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
400400
)
401401
],
@@ -458,7 +458,7 @@ async def test_get_task_callback_with_valid_task(
458458
),
459459
metadata=[
460460
(
461-
HTTP_EXTENSION_HEADER,
461+
HTTP_EXTENSION_HEADER.lower(),
462462
'https://example.com/test-ext/v1,https://example.com/test-ext/v2',
463463
)
464464
],
@@ -506,27 +506,27 @@ async def test_get_task_callback_with_invalid_task(
506506
(
507507
['ext1'],
508508
None,
509-
[(HTTP_EXTENSION_HEADER, 'ext1')],
509+
[(HTTP_EXTENSION_HEADER.lower(), 'ext1')],
510510
), # Case 2: Initial, No input
511511
(
512512
None,
513513
['ext2'],
514-
[(HTTP_EXTENSION_HEADER, 'ext2')],
514+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
515515
), # Case 3: No initial, Input
516516
(
517517
['ext1'],
518518
['ext2'],
519-
[(HTTP_EXTENSION_HEADER, 'ext2')],
519+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2')],
520520
), # Case 4: Initial, Input (override)
521521
(
522522
['ext1'],
523523
['ext2', 'ext3'],
524-
[(HTTP_EXTENSION_HEADER, 'ext2,ext3')],
524+
[(HTTP_EXTENSION_HEADER.lower(), 'ext2,ext3')],
525525
), # Case 5: Initial, Multiple inputs (override)
526526
(
527527
['ext1', 'ext2'],
528528
['ext3'],
529-
[(HTTP_EXTENSION_HEADER, 'ext3')],
529+
[(HTTP_EXTENSION_HEADER.lower(), 'ext3')],
530530
), # Case 6: Multiple initial, Single input (override)
531531
],
532532
)
@@ -540,3 +540,18 @@ def test_get_grpc_metadata(
540540
grpc_transport.extensions = initial_extensions
541541
metadata = grpc_transport._get_grpc_metadata(input_extensions)
542542
assert metadata == expected_metadata
543+
544+
def test_get_grpc_metadata_uses_lowercase_header_key(
545+
grpc_transport: GrpcTransport,
546+
) -> None:
547+
"""Test gRPC metadata header key is always lowercase."""
548+
# Regression: gRPC rejects non-lowercase metadata keys
549+
metadata = grpc_transport._get_grpc_metadata(['ext1'])
550+
assert metadata is not None
551+
key, _ = metadata[0]
552+
assert key == key.lower()
553+
554+
metadata = grpc_transport._get_grpc_metadata()
555+
assert metadata is not None
556+
key, _ = metadata[0]
557+
assert key == key.lower()

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,8 @@ async def test_send_message_with_extensions(
350350
mock_grpc_context: AsyncMock,
351351
) -> None:
352352
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
353-
(HTTP_EXTENSION_HEADER, 'foo'),
354-
(HTTP_EXTENSION_HEADER, 'bar'),
353+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
354+
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
355355
)
356356

357357
def side_effect(request, context: ServerCallContext):
@@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext):
379379
mock_grpc_context.set_trailing_metadata.call_args.args[0]
380380
)
381381
assert set(called_metadata) == {
382-
(HTTP_EXTENSION_HEADER, 'foo'),
383-
(HTTP_EXTENSION_HEADER, 'baz'),
382+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
383+
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
384384
}
385385

386386
async def test_send_message_with_comma_separated_extensions(
@@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions(
390390
mock_grpc_context: AsyncMock,
391391
) -> None:
392392
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
393-
(HTTP_EXTENSION_HEADER, 'foo ,, bar,'),
394-
(HTTP_EXTENSION_HEADER, 'baz , bar'),
393+
(HTTP_EXTENSION_HEADER.lower(), 'foo ,, bar,'),
394+
(HTTP_EXTENSION_HEADER.lower(), 'baz , bar'),
395395
)
396396
mock_request_handler.on_message_send.return_value = types.Message(
397397
message_id='1',
@@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions(
415415
mock_grpc_context: AsyncMock,
416416
) -> None:
417417
mock_grpc_context.invocation_metadata = grpc.aio.Metadata(
418-
(HTTP_EXTENSION_HEADER, 'foo'),
419-
(HTTP_EXTENSION_HEADER, 'bar'),
418+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
419+
(HTTP_EXTENSION_HEADER.lower(), 'bar'),
420420
)
421421

422422
async def side_effect(request, context: ServerCallContext):
@@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext):
450450
mock_grpc_context.set_trailing_metadata.call_args.args[0]
451451
)
452452
assert set(called_metadata) == {
453-
(HTTP_EXTENSION_HEADER, 'foo'),
454-
(HTTP_EXTENSION_HEADER, 'baz'),
453+
(HTTP_EXTENSION_HEADER.lower(), 'foo'),
454+
(HTTP_EXTENSION_HEADER.lower(), 'baz'),
455455
}

0 commit comments

Comments
 (0)