Skip to content

Commit 2a57166

Browse files
committed
fix: add tenant support to handle_authenticated_agent_card and add more tests
1 parent b0fdc69 commit 2a57166

2 files changed

Lines changed: 130 additions & 39 deletions

File tree

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ async def handle_authenticated_agent_card(
186186
card_to_serve = self.agent_card
187187

188188
if self.extended_card_modifier:
189-
context = self._context_builder.build(request)
189+
context = self._build_call_context(request)
190190
card_to_serve = await maybe_await(
191191
self.extended_card_modifier(card_to_serve, context)
192192
)

tests/server/apps/rest/test_rest_tenant.py

Lines changed: 129 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
from unittest.mock import MagicMock
33
from fastapi import FastAPI
44
from httpx import ASGITransport, AsyncClient
5-
from google.protobuf import json_format
65

76
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
87
from a2a.server.request_handlers.request_handler import RequestHandler
98
from a2a.types.a2a_pb2 import (
109
AgentCard,
10+
ListTaskPushNotificationConfigsResponse,
11+
ListTasksResponse,
1112
Message,
12-
Role,
1313
Part,
14-
SendMessageRequest,
15-
SendMessageConfiguration,
14+
Role,
15+
Task,
16+
TaskPushNotificationConfig,
1617
)
1718

1819

@@ -22,78 +23,168 @@ async def agent_card() -> AgentCard:
2223
mock_agent_card.url = 'http://mockurl.com'
2324
mock_capabilities = MagicMock()
2425
mock_capabilities.streaming = False
26+
mock_capabilities.push_notifications = True
27+
mock_capabilities.extended_agent_card = True
2528
mock_agent_card.capabilities = mock_capabilities
2629
return mock_agent_card
2730

2831

2932
@pytest.fixture
3033
async def request_handler() -> RequestHandler:
3134
handler = MagicMock(spec=RequestHandler)
32-
# Return a default response so the test doesn't crash on return value expectation
35+
# Setup default return values for all handlers
3336
handler.on_message_send.return_value = Message(
3437
message_id='test',
3538
role=Role.ROLE_AGENT,
3639
parts=[Part(text='response message')],
3740
)
41+
handler.on_cancel_task.return_value = Task(id='1')
42+
handler.on_get_task.return_value = Task(id='1')
43+
handler.on_list_tasks.return_value = ListTasksResponse()
44+
handler.on_create_task_push_notification_config.return_value = (
45+
TaskPushNotificationConfig()
46+
)
47+
handler.on_get_task_push_notification_config.return_value = (
48+
TaskPushNotificationConfig()
49+
)
50+
handler.on_list_task_push_notification_configs.return_value = (
51+
ListTaskPushNotificationConfigsResponse()
52+
)
53+
handler.on_delete_task_push_notification_config.return_value = None
3854
return handler
3955

4056

57+
@pytest.fixture
58+
async def extended_card_modifier() -> MagicMock:
59+
modifier = MagicMock()
60+
modifier.return_value = AgentCard()
61+
return modifier
62+
63+
4164
@pytest.fixture
4265
async def app(
43-
agent_card: AgentCard, request_handler: RequestHandler
66+
agent_card: AgentCard,
67+
request_handler: RequestHandler,
68+
extended_card_modifier: MagicMock,
4469
) -> FastAPI:
45-
return A2ARESTFastAPIApplication(agent_card, request_handler).build(
46-
agent_card_url='/well-known/agent.json', rpc_url=''
47-
)
70+
return A2ARESTFastAPIApplication(
71+
agent_card,
72+
request_handler,
73+
extended_card_modifier=extended_card_modifier,
74+
).build(agent_card_url='/well-known/agent.json', rpc_url='')
4875

4976

5077
@pytest.fixture
5178
async def client(app: FastAPI) -> AsyncClient:
5279
return AsyncClient(transport=ASGITransport(app=app), base_url='http://test')
5380

5481

82+
@pytest.mark.parametrize(
83+
'path_template, method, handler_method_name, json_body',
84+
[
85+
('/message:send', 'POST', 'on_message_send', {'message': {}}),
86+
('/tasks/1:cancel', 'POST', 'on_cancel_task', None),
87+
('/tasks/1', 'GET', 'on_get_task', None),
88+
('/tasks', 'GET', 'on_list_tasks', None),
89+
(
90+
'/tasks/1/pushNotificationConfigs/p1',
91+
'GET',
92+
'on_get_task_push_notification_config',
93+
None,
94+
),
95+
(
96+
'/tasks/1/pushNotificationConfigs/p1',
97+
'DELETE',
98+
'on_delete_task_push_notification_config',
99+
None,
100+
),
101+
(
102+
'/tasks/1/pushNotificationConfigs',
103+
'POST',
104+
'on_create_task_push_notification_config',
105+
{'config': {'url': 'http://foo'}},
106+
),
107+
(
108+
'/tasks/1/pushNotificationConfigs',
109+
'GET',
110+
'on_list_task_push_notification_configs',
111+
None,
112+
),
113+
],
114+
)
55115
@pytest.mark.anyio
56-
async def test_tenant_extraction_from_path(
57-
client: AsyncClient, request_handler: MagicMock
116+
async def test_tenant_extraction_parametrized(
117+
client: AsyncClient,
118+
request_handler: MagicMock,
119+
extended_card_modifier: MagicMock,
120+
path_template: str,
121+
method: str,
122+
handler_method_name: str,
123+
json_body: dict | None,
58124
) -> None:
59-
request = SendMessageRequest(
60-
message=Message(),
61-
configuration=SendMessageConfiguration(),
62-
)
125+
"""Test tenant extraction for standard REST endpoints."""
126+
# Test with tenant
127+
tenant = 'my-tenant'
128+
tenant_path = f'/{tenant}{path_template}'
63129

64-
# Test with tenant in URL
65-
tenant_id = 'my-tenant-123'
66-
response = await client.post(
67-
f'/{tenant_id}/message:send', json=json_format.MessageToDict(request)
68-
)
130+
response = await client.request(method, tenant_path, json=json_body)
69131
response.raise_for_status()
70132

71-
# Verify handler was called
72-
assert request_handler.on_message_send.called
133+
# Verify handler call
134+
handler_mock = getattr(request_handler, handler_method_name)
135+
136+
assert handler_mock.called
137+
args, _ = handler_mock.call_args
138+
context = args[1]
139+
assert context.tenant == tenant
140+
141+
# Reset mock for non-tenant test
142+
handler_mock.reset_mock()
143+
144+
# Test without tenant
145+
response = await client.request(method, path_template, json=json_body)
146+
response.raise_for_status()
73147

74-
# Verify call context has tenant
75-
args, _ = request_handler.on_message_send.call_args
76-
# args[0] is the request proto, args[1] is the ServerCallContext
148+
# Verify context.tenant == ""
149+
assert handler_mock.called
150+
args, _ = handler_mock.call_args
77151
context = args[1]
78-
assert context.tenant == tenant_id
152+
assert context.tenant == ''
79153

80154

81155
@pytest.mark.anyio
82-
async def test_no_tenant_extraction(
83-
client: AsyncClient, request_handler: MagicMock
156+
async def test_tenant_extraction_extended_agent_card(
157+
client: AsyncClient,
158+
extended_card_modifier: MagicMock,
84159
) -> None:
85-
request = SendMessageRequest(
86-
message=Message(),
87-
configuration=SendMessageConfiguration(),
88-
)
160+
"""Test tenant extraction specifically for extendedAgentCard endpoint.
89161
90-
# Test without tenant in URL
91-
response = await client.post(
92-
'/message:send', json=json_format.MessageToDict(request)
93-
)
162+
This verifies that `extended_card_modifier` receives the correct context
163+
including the tenant, confirming that `_build_call_context` is used correctly.
164+
"""
165+
# Test with tenant
166+
tenant = 'my-tenant'
167+
tenant_path = f'/{tenant}/extendedAgentCard'
168+
169+
response = await client.get(tenant_path)
170+
response.raise_for_status()
171+
172+
# Verify extended_card_modifier called with tenant context
173+
assert extended_card_modifier.called
174+
args, _ = extended_card_modifier.call_args
175+
# args[0] is card_to_serve, args[1] is context
176+
context = args[1]
177+
assert context.tenant == tenant
178+
179+
# Reset mock for non-tenant test
180+
extended_card_modifier.reset_mock()
181+
182+
# Test without tenant
183+
response = await client.get('/extendedAgentCard')
94184
response.raise_for_status()
95185

96-
# Verify call context has empty string tenant (default)
97-
args, _ = request_handler.on_message_send.call_args
186+
# Verify extended_card_modifier called with empty tenant context
187+
assert extended_card_modifier.called
188+
args, _ = extended_card_modifier.call_args
98189
context = args[1]
99190
assert context.tenant == ''

0 commit comments

Comments
 (0)