11from unittest .mock import AsyncMock , MagicMock , patch
22
33import grpc
4+ import grpc .aio
45import pytest
56
67from a2a import types
@@ -22,7 +23,6 @@ def mock_request_handler() -> AsyncMock:
2223def mock_grpc_context () -> AsyncMock :
2324 context = AsyncMock (spec = grpc .aio .ServicerContext )
2425 context .abort = AsyncMock ()
25- context .invocation_metadata = MagicMock ()
2626 context .set_trailing_metadata = MagicMock ()
2727 return context
2828
@@ -286,18 +286,15 @@ async def test_abort_context_error_mapping(
286286
287287@pytest .mark .asyncio
288288class TestGrpcExtensions :
289- @patch (
290- 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
291- )
292289 async def test_send_message_with_extensions (
293290 self ,
294- mock_build ,
295291 grpc_handler : GrpcHandler ,
296292 mock_request_handler : AsyncMock ,
297293 mock_grpc_context : AsyncMock ,
298294 ):
299- mock_build .return_value = ServerCallContext (
300- requested_extensions = {'foo' , 'bar' }
295+ mock_grpc_context .invocation_metadata = grpc .aio .Metadata (
296+ (HTTP_EXTENSION_HEADER , 'foo' ),
297+ (HTTP_EXTENSION_HEADER , 'bar' ),
301298 )
302299
303300 def side_effect (request , context : ServerCallContext ):
@@ -321,21 +318,48 @@ def side_effect(request, context: ServerCallContext):
321318 assert call_context .requested_extensions == {'foo' , 'bar' }
322319
323320 mock_grpc_context .set_trailing_metadata .assert_called_once ()
324- called_metadata = mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
325- assert set (called_metadata ) == {(HTTP_EXTENSION_HEADER , 'foo' ), (HTTP_EXTENSION_HEADER , 'baz' )}
321+ called_metadata = (
322+ mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
323+ )
324+ assert set (called_metadata ) == {
325+ (HTTP_EXTENSION_HEADER , 'foo' ),
326+ (HTTP_EXTENSION_HEADER , 'baz' ),
327+ }
328+
329+ async def test_send_message_with_comma_separated_extensions (
330+ self ,
331+ grpc_handler : GrpcHandler ,
332+ mock_request_handler : AsyncMock ,
333+ mock_grpc_context : AsyncMock ,
334+ ):
335+ mock_grpc_context .invocation_metadata = grpc .aio .Metadata (
336+ (HTTP_EXTENSION_HEADER , 'foo ,, bar,' ),
337+ (HTTP_EXTENSION_HEADER , 'baz , bar' ),
338+ )
339+ mock_request_handler .on_message_send .return_value = types .Message (
340+ messageId = '1' ,
341+ role = types .Role .agent ,
342+ parts = [types .TextPart (text = 'test' )],
343+ )
344+
345+ await grpc_handler .SendMessage (
346+ a2a_pb2 .SendMessageRequest (), mock_grpc_context
347+ )
348+
349+ mock_request_handler .on_message_send .assert_awaited_once ()
350+ call_context = mock_request_handler .on_message_send .call_args [0 ][1 ]
351+ assert isinstance (call_context , ServerCallContext )
352+ assert call_context .requested_extensions == {'foo' , 'bar' , 'baz' }
326353
327- @patch (
328- 'a2a.server.request_handlers.grpc_handler.DefaultCallContextBuilder.build'
329- )
330354 async def test_send_streaming_message_with_extensions (
331355 self ,
332- mock_build ,
333356 grpc_handler : GrpcHandler ,
334357 mock_request_handler : AsyncMock ,
335358 mock_grpc_context : AsyncMock ,
336359 ):
337- mock_build .return_value = ServerCallContext (
338- requested_extensions = {'foo' , 'bar' }
360+ mock_grpc_context .invocation_metadata = grpc .aio .Metadata (
361+ (HTTP_EXTENSION_HEADER , 'foo' ),
362+ (HTTP_EXTENSION_HEADER , 'bar' ),
339363 )
340364
341365 async def side_effect (request , context : ServerCallContext ):
@@ -365,5 +389,10 @@ async def side_effect(request, context: ServerCallContext):
365389 assert call_context .requested_extensions == {'foo' , 'bar' }
366390
367391 mock_grpc_context .set_trailing_metadata .assert_called_once ()
368- called_metadata = mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
369- assert set (called_metadata ) == {(HTTP_EXTENSION_HEADER , 'foo' ), (HTTP_EXTENSION_HEADER , 'baz' )}
392+ called_metadata = (
393+ mock_grpc_context .set_trailing_metadata .call_args .args [0 ]
394+ )
395+ assert set (called_metadata ) == {
396+ (HTTP_EXTENSION_HEADER , 'foo' ),
397+ (HTTP_EXTENSION_HEADER , 'baz' ),
398+ }
0 commit comments