Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def add_activated_extension(self, uri: str) -> None:
if self._call_context:
self._call_context.activated_extensions.add(uri)

@property
def tenant(self) -> str:
"""The tenant associated with this request."""
return self._call_context.tenant if self._call_context else ''

@property
def requested_extensions(self) -> set[str]:
"""Extensions that the client requested to activate."""
Expand Down
54 changes: 54 additions & 0 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ async def _handle_request(
request: Request,
) -> Response:
call_context = self._context_builder.build(request)
if 'tenant' in request.path_params:
call_context.tenant = request.path_params['tenant']
Comment thread
sokoliva marked this conversation as resolved.
Outdated

response = await method(request, call_context)
return JSONResponse(content=response)

Expand All @@ -131,6 +134,8 @@ async def _handle_streaming_request(
) from e

call_context = self._context_builder.build(request)
if 'tenant' in request.path_params:
call_context.tenant = request.path_params['tenant']
Comment thread
sokoliva marked this conversation as resolved.
Outdated
Comment thread
sokoliva marked this conversation as resolved.
Outdated

async def event_generator(
stream: AsyncIterable[Any],
Expand Down Expand Up @@ -250,10 +255,59 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
('/tasks', 'GET'): functools.partial(
self._handle_request, self.handler.list_tasks
),
# Tenant prefixed routes
('/{tenant}/message:send', 'POST'): functools.partial(
self._handle_request,
self.handler.on_message_send,
),
('/{tenant}/message:stream', 'POST'): functools.partial(
self._handle_streaming_request,
self.handler.on_message_send_stream,
),
('/{tenant}/tasks/{id}:cancel', 'POST'): functools.partial(
self._handle_request, self.handler.on_cancel_task
),
('/{tenant}/tasks/{id}:subscribe', 'GET'): functools.partial(
self._handle_streaming_request,
self.handler.on_subscribe_to_task,
),
('/{tenant}/tasks/{id}', 'GET'): functools.partial(
self._handle_request, self.handler.on_get_task
),
(
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
'GET',
): functools.partial(
self._handle_request, self.handler.get_push_notification
),
(
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
'DELETE',
): functools.partial(
self._handle_request, self.handler.delete_push_notification
),
(
'/{tenant}/tasks/{id}/pushNotificationConfigs',
'POST',
): functools.partial(
self._handle_request, self.handler.set_push_notification
),
(
'/{tenant}/tasks/{id}/pushNotificationConfigs',
'GET',
): functools.partial(
self._handle_request, self.handler.list_push_notifications
),
('/{tenant}/tasks', 'GET'): functools.partial(
self._handle_request, self.handler.list_tasks
),
Comment thread
sokoliva marked this conversation as resolved.
Outdated
Comment thread
sokoliva marked this conversation as resolved.
Outdated
}
if self.agent_card.capabilities.extended_agent_card:
routes[('/extendedAgentCard', 'GET')] = functools.partial(
self._handle_request, self.handle_authenticated_agent_card
)
routes[('/{tenant}/extendedAgentCard', 'GET')] = functools.partial(
self._handle_request, self.handle_authenticated_agent_card
)
Comment thread
sokoliva marked this conversation as resolved.

return routes
1 change: 1 addition & 0 deletions src/a2a/server/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ class ServerCallContext(BaseModel):

state: State = Field(default={})
user: User = Field(default=UnauthenticatedUser())
tenant: str = Field(default='')
requested_extensions: set[str] = Field(default_factory=set)
activated_extensions: set[str] = Field(default_factory=set)
99 changes: 99 additions & 0 deletions tests/server/apps/rest/test_rest_tenant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
from unittest.mock import MagicMock
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from google.protobuf import json_format

from a2a.server.apps.rest.fastapi_app import A2ARESTFastAPIApplication
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types.a2a_pb2 import (
AgentCard,
Message,
Role,
Part,
SendMessageRequest,
SendMessageConfiguration,
)


@pytest.fixture
async def agent_card() -> AgentCard:
mock_agent_card = MagicMock(spec=AgentCard)
mock_agent_card.url = 'http://mockurl.com'
mock_capabilities = MagicMock()
mock_capabilities.streaming = False
mock_agent_card.capabilities = mock_capabilities
return mock_agent_card


@pytest.fixture
async def request_handler() -> RequestHandler:
handler = MagicMock(spec=RequestHandler)
# Return a default response so the test doesn't crash on return value expectation
handler.on_message_send.return_value = Message(
message_id='test',
role=Role.ROLE_AGENT,
parts=[Part(text='response message')],
)
return handler


@pytest.fixture
async def app(
agent_card: AgentCard, request_handler: RequestHandler
) -> FastAPI:
return A2ARESTFastAPIApplication(agent_card, request_handler).build(
agent_card_url='/well-known/agent.json', rpc_url=''
)


@pytest.fixture
async def client(app: FastAPI) -> AsyncClient:
return AsyncClient(transport=ASGITransport(app=app), base_url='http://test')


@pytest.mark.anyio
async def test_tenant_extraction_from_path(
client: AsyncClient, request_handler: MagicMock
) -> None:
request = SendMessageRequest(
message=Message(),
configuration=SendMessageConfiguration(),
)

# Test with tenant in URL
tenant_id = 'my-tenant-123'
response = await client.post(
f'/{tenant_id}/message:send', json=json_format.MessageToDict(request)
)
response.raise_for_status()

# Verify handler was called
assert request_handler.on_message_send.called

# Verify call context has tenant
args, _ = request_handler.on_message_send.call_args
# args[0] is the request proto, args[1] is the ServerCallContext
context = args[1]
assert context.tenant == tenant_id


@pytest.mark.anyio
async def test_no_tenant_extraction(
client: AsyncClient, request_handler: MagicMock
) -> None:
request = SendMessageRequest(
message=Message(),
configuration=SendMessageConfiguration(),
)

# Test without tenant in URL
response = await client.post(
'/message:send', json=json_format.MessageToDict(request)
)
response.raise_for_status()

# Verify call context has empty string tenant (default)
args, _ = request_handler.on_message_send.call_args
context = args[1]
assert context.tenant == ''