22from unittest .mock import MagicMock
33from fastapi import FastAPI
44from httpx import ASGITransport , AsyncClient
5- from google .protobuf import json_format
65
76from a2a .server .apps .rest .fastapi_app import A2ARESTFastAPIApplication
87from a2a .server .request_handlers .request_handler import RequestHandler
98from a2a .types .a2a_pb2 import (
109 AgentCard ,
10+ ListTaskPushNotificationConfigsResponse ,
11+ ListTasksResponse ,
1112 Message ,
12- Role ,
1313 Part ,
14- SendMessageRequest ,
15- SendMessageConfiguration ,
14+ Role ,
15+ Task ,
16+ TaskPushNotificationConfig ,
1617)
1718
1819
@@ -22,78 +23,168 @@ async def agent_card() -> AgentCard:
2223 mock_agent_card .url = 'http://mockurl.com'
2324 mock_capabilities = MagicMock ()
2425 mock_capabilities .streaming = False
26+ mock_capabilities .push_notifications = True
27+ mock_capabilities .extended_agent_card = True
2528 mock_agent_card .capabilities = mock_capabilities
2629 return mock_agent_card
2730
2831
2932@pytest .fixture
3033async def request_handler () -> RequestHandler :
3134 handler = MagicMock (spec = RequestHandler )
32- # Return a default response so the test doesn't crash on return value expectation
35+ # Setup default return values for all handlers
3336 handler .on_message_send .return_value = Message (
3437 message_id = 'test' ,
3538 role = Role .ROLE_AGENT ,
3639 parts = [Part (text = 'response message' )],
3740 )
41+ handler .on_cancel_task .return_value = Task (id = '1' )
42+ handler .on_get_task .return_value = Task (id = '1' )
43+ handler .on_list_tasks .return_value = ListTasksResponse ()
44+ handler .on_create_task_push_notification_config .return_value = (
45+ TaskPushNotificationConfig ()
46+ )
47+ handler .on_get_task_push_notification_config .return_value = (
48+ TaskPushNotificationConfig ()
49+ )
50+ handler .on_list_task_push_notification_configs .return_value = (
51+ ListTaskPushNotificationConfigsResponse ()
52+ )
53+ handler .on_delete_task_push_notification_config .return_value = None
3854 return handler
3955
4056
57+ @pytest .fixture
58+ async def extended_card_modifier () -> MagicMock :
59+ modifier = MagicMock ()
60+ modifier .return_value = AgentCard ()
61+ return modifier
62+
63+
4164@pytest .fixture
4265async def app (
43- agent_card : AgentCard , request_handler : RequestHandler
66+ agent_card : AgentCard ,
67+ request_handler : RequestHandler ,
68+ extended_card_modifier : MagicMock ,
4469) -> FastAPI :
45- return A2ARESTFastAPIApplication (agent_card , request_handler ).build (
46- agent_card_url = '/well-known/agent.json' , rpc_url = ''
47- )
70+ return A2ARESTFastAPIApplication (
71+ agent_card ,
72+ request_handler ,
73+ extended_card_modifier = extended_card_modifier ,
74+ ).build (agent_card_url = '/well-known/agent.json' , rpc_url = '' )
4875
4976
5077@pytest .fixture
5178async def client (app : FastAPI ) -> AsyncClient :
5279 return AsyncClient (transport = ASGITransport (app = app ), base_url = 'http://test' )
5380
5481
82+ @pytest .mark .parametrize (
83+ 'path_template, method, handler_method_name, json_body' ,
84+ [
85+ ('/message:send' , 'POST' , 'on_message_send' , {'message' : {}}),
86+ ('/tasks/1:cancel' , 'POST' , 'on_cancel_task' , None ),
87+ ('/tasks/1' , 'GET' , 'on_get_task' , None ),
88+ ('/tasks' , 'GET' , 'on_list_tasks' , None ),
89+ (
90+ '/tasks/1/pushNotificationConfigs/p1' ,
91+ 'GET' ,
92+ 'on_get_task_push_notification_config' ,
93+ None ,
94+ ),
95+ (
96+ '/tasks/1/pushNotificationConfigs/p1' ,
97+ 'DELETE' ,
98+ 'on_delete_task_push_notification_config' ,
99+ None ,
100+ ),
101+ (
102+ '/tasks/1/pushNotificationConfigs' ,
103+ 'POST' ,
104+ 'on_create_task_push_notification_config' ,
105+ {'config' : {'url' : 'http://foo' }},
106+ ),
107+ (
108+ '/tasks/1/pushNotificationConfigs' ,
109+ 'GET' ,
110+ 'on_list_task_push_notification_configs' ,
111+ None ,
112+ ),
113+ ],
114+ )
55115@pytest .mark .anyio
56- async def test_tenant_extraction_from_path (
57- client : AsyncClient , request_handler : MagicMock
116+ async def test_tenant_extraction_parametrized (
117+ client : AsyncClient ,
118+ request_handler : MagicMock ,
119+ extended_card_modifier : MagicMock ,
120+ path_template : str ,
121+ method : str ,
122+ handler_method_name : str ,
123+ json_body : dict | None ,
58124) -> None :
59- request = SendMessageRequest (
60- message = Message (),
61- configuration = SendMessageConfiguration (),
62- )
125+ """Test tenant extraction for standard REST endpoints."""
126+ # Test with tenant
127+ tenant = 'my-tenant'
128+ tenant_path = f'/ { tenant } { path_template } '
63129
64- # Test with tenant in URL
65- tenant_id = 'my-tenant-123'
66- response = await client .post (
67- f'/{ tenant_id } /message:send' , json = json_format .MessageToDict (request )
68- )
130+ response = await client .request (method , tenant_path , json = json_body )
69131 response .raise_for_status ()
70132
71- # Verify handler was called
72- assert request_handler .on_message_send .called
133+ # Verify handler call
134+ handler_mock = getattr (request_handler , handler_method_name )
135+
136+ assert handler_mock .called
137+ args , _ = handler_mock .call_args
138+ context = args [1 ]
139+ assert context .tenant == tenant
140+
141+ # Reset mock for non-tenant test
142+ handler_mock .reset_mock ()
143+
144+ # Test without tenant
145+ response = await client .request (method , path_template , json = json_body )
146+ response .raise_for_status ()
73147
74- # Verify call context has tenant
75- args , _ = request_handler . on_message_send . call_args
76- # args[0] is the request proto, args[1] is the ServerCallContext
148+ # Verify context.tenant == ""
149+ assert handler_mock . called
150+ args , _ = handler_mock . call_args
77151 context = args [1 ]
78- assert context .tenant == tenant_id
152+ assert context .tenant == ''
79153
80154
81155@pytest .mark .anyio
82- async def test_no_tenant_extraction (
83- client : AsyncClient , request_handler : MagicMock
156+ async def test_tenant_extraction_extended_agent_card (
157+ client : AsyncClient ,
158+ extended_card_modifier : MagicMock ,
84159) -> None :
85- request = SendMessageRequest (
86- message = Message (),
87- configuration = SendMessageConfiguration (),
88- )
160+ """Test tenant extraction specifically for extendedAgentCard endpoint.
89161
90- # Test without tenant in URL
91- response = await client .post (
92- '/message:send' , json = json_format .MessageToDict (request )
93- )
162+ This verifies that `extended_card_modifier` receives the correct context
163+ including the tenant, confirming that `_build_call_context` is used correctly.
164+ """
165+ # Test with tenant
166+ tenant = 'my-tenant'
167+ tenant_path = f'/{ tenant } /extendedAgentCard'
168+
169+ response = await client .get (tenant_path )
170+ response .raise_for_status ()
171+
172+ # Verify extended_card_modifier called with tenant context
173+ assert extended_card_modifier .called
174+ args , _ = extended_card_modifier .call_args
175+ # args[0] is card_to_serve, args[1] is context
176+ context = args [1 ]
177+ assert context .tenant == tenant
178+
179+ # Reset mock for non-tenant test
180+ extended_card_modifier .reset_mock ()
181+
182+ # Test without tenant
183+ response = await client .get ('/extendedAgentCard' )
94184 response .raise_for_status ()
95185
96- # Verify call context has empty string tenant (default)
97- args , _ = request_handler .on_message_send .call_args
186+ # Verify extended_card_modifier called with empty tenant context
187+ assert extended_card_modifier .called
188+ args , _ = extended_card_modifier .call_args
98189 context = args [1 ]
99190 assert context .tenant == ''
0 commit comments