55import pytest
66
77from a2a import types
8- from a2a .extensions .common import HTTP_EXTENSION_HEADER
8+ from a2a .extensions .common import GRPC_EXTENSION_HEADER
99from a2a .grpc import a2a_pb2
1010from a2a .server .context import ServerCallContext
1111from a2a .server .request_handlers import GrpcHandler , RequestHandler
@@ -350,8 +350,8 @@ async def test_send_message_with_extensions(
350350 mock_grpc_context : AsyncMock ,
351351 ) -> None :
352352 mock_grpc_context .invocation_metadata .return_value = grpc .aio .Metadata (
353- (HTTP_EXTENSION_HEADER .lower (), 'foo' ),
354- (HTTP_EXTENSION_HEADER .lower (), 'bar' ),
353+ (GRPC_EXTENSION_HEADER .lower (), 'foo' ),
354+ (GRPC_EXTENSION_HEADER .lower (), 'bar' ),
355355 )
356356
357357 def side_effect (request , context : ServerCallContext ):
@@ -379,8 +379,8 @@ def side_effect(request, context: ServerCallContext):
379379 mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
380380 )
381381 assert set (called_metadata ) == {
382- (HTTP_EXTENSION_HEADER .lower (), 'foo' ),
383- (HTTP_EXTENSION_HEADER .lower (), 'baz' ),
382+ (GRPC_EXTENSION_HEADER .lower (), 'foo' ),
383+ (GRPC_EXTENSION_HEADER .lower (), 'baz' ),
384384 }
385385
386386 async def test_send_message_with_comma_separated_extensions (
@@ -390,8 +390,8 @@ async def test_send_message_with_comma_separated_extensions(
390390 mock_grpc_context : AsyncMock ,
391391 ) -> None :
392392 mock_grpc_context .invocation_metadata .return_value = grpc .aio .Metadata (
393- (HTTP_EXTENSION_HEADER .lower (), 'foo ,, bar,' ),
394- (HTTP_EXTENSION_HEADER .lower (), 'baz , bar' ),
393+ (GRPC_EXTENSION_HEADER .lower (), 'foo ,, bar,' ),
394+ (GRPC_EXTENSION_HEADER .lower (), 'baz , bar' ),
395395 )
396396 mock_request_handler .on_message_send .return_value = types .Message (
397397 message_id = '1' ,
@@ -415,8 +415,8 @@ async def test_send_streaming_message_with_extensions(
415415 mock_grpc_context : AsyncMock ,
416416 ) -> None :
417417 mock_grpc_context .invocation_metadata .return_value = grpc .aio .Metadata (
418- (HTTP_EXTENSION_HEADER .lower (), 'foo' ),
419- (HTTP_EXTENSION_HEADER .lower (), 'bar' ),
418+ (GRPC_EXTENSION_HEADER .lower (), 'foo' ),
419+ (GRPC_EXTENSION_HEADER .lower (), 'bar' ),
420420 )
421421
422422 async def side_effect (request , context : ServerCallContext ):
@@ -450,6 +450,6 @@ async def side_effect(request, context: ServerCallContext):
450450 mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
451451 )
452452 assert set (called_metadata ) == {
453- (HTTP_EXTENSION_HEADER .lower (), 'foo' ),
454- (HTTP_EXTENSION_HEADER .lower (), 'baz' ),
453+ (GRPC_EXTENSION_HEADER .lower (), 'foo' ),
454+ (GRPC_EXTENSION_HEADER .lower (), 'baz' ),
455455 }
0 commit comments