Skip to content

Commit fa2863b

Browse files
committed
feat: add tenant to ServerCallContext, add tenant-prefixed routes for REST endpoints
and introduce tenant extraction from REST API paths
1 parent 5b354e4 commit fa2863b

4 files changed

Lines changed: 159 additions & 0 deletions

File tree

src/a2a/server/agent_execution/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ def add_activated_extension(self, uri: str) -> None:
160160
if self._call_context:
161161
self._call_context.activated_extensions.add(uri)
162162

163+
@property
164+
def tenant(self) -> str:
165+
"""The tenant associated with this request."""
166+
return self._call_context.tenant if self._call_context else ''
167+
163168
@property
164169
def requested_extensions(self) -> set[str]:
165170
"""Extensions that the client requested to activate."""

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ async def _handle_request(
111111
request: Request,
112112
) -> Response:
113113
call_context = self._context_builder.build(request)
114+
if 'tenant' in request.path_params:
115+
call_context.tenant = request.path_params['tenant']
116+
114117
response = await method(request, call_context)
115118
return JSONResponse(content=response)
116119

@@ -131,6 +134,8 @@ async def _handle_streaming_request(
131134
) from e
132135

133136
call_context = self._context_builder.build(request)
137+
if 'tenant' in request.path_params:
138+
call_context.tenant = request.path_params['tenant']
134139

135140
async def event_generator(
136141
stream: AsyncIterable[Any],
@@ -250,10 +255,59 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
250255
('/tasks', 'GET'): functools.partial(
251256
self._handle_request, self.handler.list_tasks
252257
),
258+
# Tenant prefixed routes
259+
('/{tenant}/message:send', 'POST'): functools.partial(
260+
self._handle_request,
261+
self.handler.on_message_send,
262+
),
263+
('/{tenant}/message:stream', 'POST'): functools.partial(
264+
self._handle_streaming_request,
265+
self.handler.on_message_send_stream,
266+
),
267+
('/{tenant}/tasks/{id}:cancel', 'POST'): functools.partial(
268+
self._handle_request, self.handler.on_cancel_task
269+
),
270+
('/{tenant}/tasks/{id}:subscribe', 'GET'): functools.partial(
271+
self._handle_streaming_request,
272+
self.handler.on_subscribe_to_task,
273+
),
274+
('/{tenant}/tasks/{id}', 'GET'): functools.partial(
275+
self._handle_request, self.handler.on_get_task
276+
),
277+
(
278+
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
279+
'GET',
280+
): functools.partial(
281+
self._handle_request, self.handler.get_push_notification
282+
),
283+
(
284+
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
285+
'DELETE',
286+
): functools.partial(
287+
self._handle_request, self.handler.delete_push_notification
288+
),
289+
(
290+
'/{tenant}/tasks/{id}/pushNotificationConfigs',
291+
'POST',
292+
): functools.partial(
293+
self._handle_request, self.handler.set_push_notification
294+
),
295+
(
296+
'/{tenant}/tasks/{id}/pushNotificationConfigs',
297+
'GET',
298+
): functools.partial(
299+
self._handle_request, self.handler.list_push_notifications
300+
),
301+
('/{tenant}/tasks', 'GET'): functools.partial(
302+
self._handle_request, self.handler.list_tasks
303+
),
253304
}
254305
if self.agent_card.capabilities.extended_agent_card:
255306
routes[('/extendedAgentCard', 'GET')] = functools.partial(
256307
self._handle_request, self.handle_authenticated_agent_card
257308
)
309+
routes[('/{tenant}/extendedAgentCard', 'GET')] = functools.partial(
310+
self._handle_request, self.handle_authenticated_agent_card
311+
)
258312

259313
return routes

src/a2a/server/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,6 @@ class ServerCallContext(BaseModel):
2121

2222
state: State = Field(default={})
2323
user: User = Field(default=UnauthenticatedUser())
24+
tenant: str = Field(default='')
2425
requested_extensions: set[str] = Field(default_factory=set)
2526
activated_extensions: set[str] = Field(default_factory=set)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from fastapi import FastAPI
4+
from httpx import ASGITransport, AsyncClient
5+
from google.protobuf import json_format
6+
7+
from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
8+
from a2a.server.request_handlers.request_handler import RequestHandler
9+
from a2a.types.a2a_pb2 import (
10+
AgentCard,
11+
Message,
12+
Role,
13+
Part,
14+
SendMessageRequest,
15+
SendMessageConfiguration,
16+
)
17+
18+
19+
@pytest.fixture
20+
async def agent_card() -> AgentCard:
21+
mock_agent_card = MagicMock(spec=AgentCard)
22+
mock_agent_card.url = 'http://mockurl.com'
23+
mock_capabilities = MagicMock()
24+
mock_capabilities.streaming = False
25+
mock_agent_card.capabilities = mock_capabilities
26+
return mock_agent_card
27+
28+
29+
@pytest.fixture
30+
async def request_handler() -> RequestHandler:
31+
handler = MagicMock(spec=RequestHandler)
32+
# Return a default response so the test doesn't crash on return value expectation
33+
handler.on_message_send.return_value = Message(
34+
message_id='test',
35+
role=Role.ROLE_AGENT,
36+
parts=[Part(text='response message')],
37+
)
38+
return handler
39+
40+
41+
@pytest.fixture
42+
async def app(
43+
agent_card: AgentCard, request_handler: RequestHandler
44+
) -> FastAPI:
45+
return A2ARESTFastAPIApplication(agent_card, request_handler).build(
46+
agent_card_url='/well-known/agent.json', rpc_url=''
47+
)
48+
49+
50+
@pytest.fixture
51+
async def client(app: FastAPI) -> AsyncClient:
52+
return AsyncClient(transport=ASGITransport(app=app), base_url='http://test')
53+
54+
55+
@pytest.mark.anyio
56+
async def test_tenant_extraction_from_path(
57+
client: AsyncClient, request_handler: MagicMock
58+
) -> None:
59+
request = SendMessageRequest(
60+
message=Message(),
61+
configuration=SendMessageConfiguration(),
62+
)
63+
64+
# Test with tenant in URL
65+
tenant_id = 'my-tenant-123'
66+
response = await client.post(
67+
f'/{tenant_id}/message:send', json=json_format.MessageToDict(request)
68+
)
69+
response.raise_for_status()
70+
71+
# Verify handler was called
72+
assert request_handler.on_message_send.called
73+
74+
# Verify call context has tenant
75+
args, _ = request_handler.on_message_send.call_args
76+
# args[0] is the request proto, args[1] is the ServerCallContext
77+
context = args[1]
78+
assert context.tenant == tenant_id
79+
80+
81+
@pytest.mark.anyio
82+
async def test_no_tenant_extraction(
83+
client: AsyncClient, request_handler: MagicMock
84+
) -> None:
85+
request = SendMessageRequest(
86+
message=Message(),
87+
configuration=SendMessageConfiguration(),
88+
)
89+
90+
# Test without tenant in URL
91+
response = await client.post(
92+
'/message:send', json=json_format.MessageToDict(request)
93+
)
94+
response.raise_for_status()
95+
96+
# Verify call context has empty string tenant (default)
97+
args, _ = request_handler.on_message_send.call_args
98+
context = args[1]
99+
assert context.tenant == ''

0 commit comments

Comments
 (0)