diff --git a/itk/main.py b/itk/main.py index cc761d081..76c72e1c2 100644 --- a/itk/main.py +++ b/itk/main.py @@ -31,7 +31,6 @@ InMemoryPushNotificationConfigStore, ) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore -from a2a.server.context import ServerCallContext from a2a.types import a2a_pb2_grpc from a2a.types.a2a_pb2 import ( AgentCapabilities, @@ -339,7 +338,6 @@ async def main_async(http_port: int, grpc_port: int) -> None: push_sender = BasePushNotificationSender( httpx_client=httpx.AsyncClient(), config_store=push_config_store, - context=ServerCallContext(), ) handler = DefaultRequestHandler( diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4a4929e8f..ff9ca3ce5 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -27,26 +27,39 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, - context: ServerCallContext, + context: ServerCallContext | None = None, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. - config_store: A PushNotificationConfigStore instance to retrieve configurations. - context: The `ServerCallContext` that this push notification is produced under. + config_store: A PushNotificationConfigStore instance to + retrieve configurations. + context: Deprecated and ignored. Accepted only for + backward compatibility with 1.0 callers that constructed + the sender with a (typically dummy) ServerCallContext. + Pass None (the default) in new code. A non-None + value logs a deprecation warning and is otherwise + ignored. """ + if context is not None: + logger.warning( + 'BasePushNotificationSender no longer uses the context ' + 'parameter; it is accepted only for backward compatibility ' + 'with 1.0 and will be removed in a future major version. ' + 'Push notifications now fan out across all owners via ' + 'PushNotificationConfigStore.get_info_for_dispatch; the ' + 'caller identity is not carried into dispatch. Drop the ' + 'context argument from the constructor call.' + ) self._client = httpx_client self._config_store = config_store - self._call_context: ServerCallContext = context async def send_notification( self, task_id: str, event: PushNotificationEvent ) -> None: """Sends a push notification for an event if configuration exists.""" - push_configs = await self._config_store.get_info( - task_id, self._call_context - ) + push_configs = await self._config_store.get_info_for_dispatch(task_id) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 31cd676c8..d050de7cc 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -7,7 +7,7 @@ try: - from sqlalchemy import Table, and_, delete, select + from sqlalchemy import ColumnElement, Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -304,21 +304,14 @@ async def set_info( owner, ) - async def get_info( + async def _select_configs( self, - task_id: str, - context: ServerCallContext, + *predicates: 'ColumnElement[bool]', ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task, for the given owner.""" + """Loads configs matching the given predicates and decodes them.""" await self._ensure_initialized() - owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.config_model).where( - and_( - self.config_model.task_id == task_id, - self.config_model.owner == owner, - ) - ) + stmt = select(self.config_model).where(and_(*predicates)) result = await session.execute(stmt) models = result.scalars().all() @@ -331,10 +324,37 @@ async def get_info( 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, - owner, + model.owner, ) return configs + async def get_info( + self, + task_id: str, + context: ServerCallContext, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, for the given owner. + + Used by the user-callable read endpoints. + """ + owner = self.owner_resolver(context) + return await self._select_configs( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + Used by the push-notification dispatch path. + """ + return await self._select_configs( + self.config_model.task_id == task_id, + ) + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index d5b0a5b1f..19e35074a 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -72,12 +72,29 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves all push notification configurations for a task from memory, for the given owner.""" + """Retrieves all push notification configurations for a task from memory, for the given owner. + + Used by the user-callable read endpoints. + """ owner = self.owner_resolver(context) async with self.lock: owner_infos = self._get_owner_push_notification_infos(owner) return list(owner_infos.get(task_id, [])) + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task across all owners. + + Used by the push-notification dispatch path. + """ + async with self.lock: + results: list[TaskPushNotificationConfig] = [] + for all_configs in self._push_notification_infos.values(): + results.extend(all_configs.get(task_id, [])) + return results + async def delete_info( self, task_id: str, diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index 6b5b35245..e1e65c3fb 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,9 +1,14 @@ +import logging + from abc import ABC, abstractmethod from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import TaskPushNotificationConfig +logger = logging.getLogger(__name__) + + class PushNotificationConfigStore(ABC): """Interface for storing and retrieving push notification configurations for tasks.""" @@ -22,7 +27,46 @@ async def get_info( task_id: str, context: ServerCallContext, ) -> list[TaskPushNotificationConfig]: - """Retrieves the push notification configuration for a task.""" + """Retrieves push notification configurations for a task, scoped to the caller. + + This is the user-callable read path. Implementations MUST return + only configurations owned by the caller (as resolved from + context). + """ + + async def get_info_for_dispatch( + self, + task_id: str, + ) -> list[TaskPushNotificationConfig]: + """Retrieves all push notification configurations for a task, across all owners. + + This is the internal read path used by the push-notification + dispatch loop. Implementations SHOULD override this method to + return every configuration registered for task_id regardless of + which user registered it. Authorization already happened at + registration time and the dispatch path fires every registered + webhook for the task. + + The default implementation falls back to calling get_info with + a synthetic empty ServerCallContext. This preserves 1.0 + behavior for subclasses that have not implemented the override + but is INCORRECT for any deployment with multiple owners: the + empty context resolves to the empty-string owner partition and + returns no configs (silently dropping every notification). A + warning is logged on every call to flag the misconfiguration. + Custom subclasses MUST override this method to deliver + notifications correctly in multi-owner deployments. + """ + logger.warning( + '%s does not override ' + 'PushNotificationConfigStore.get_info_for_dispatch; falling back ' + 'to a context-less get_info call which silently drops ' + 'notifications in any deployment with multiple owners. Override ' + 'get_info_for_dispatch to return all configs for task_id across ' + 'every owner.', + type(self).__name__, + ) + return await self.get_info(task_id, ServerCallContext()) @abstractmethod async def delete_info( diff --git a/tests/e2e/push_notifications/agent_app.py b/tests/e2e/push_notifications/agent_app.py index bc95f6c37..9bb3a02fa 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -1,14 +1,17 @@ import httpx from fastapi import FastAPI +from starlette.applications import Starlette +from starlette.requests import Request +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue -from starlette.applications import Starlette -from a2a.server.routes.rest_routes import create_rest_routes -from a2a.server.routes import create_agent_card_routes from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.routes import create_agent_card_routes +from a2a.server.routes.common import DefaultServerCallContextBuilder +from a2a.server.routes.rest_routes import create_rest_routes from a2a.server.tasks import ( BasePushNotificationSender, InMemoryPushNotificationConfigStore, @@ -30,6 +33,9 @@ ) +_TEST_USER_HEADER = 'x-test-user' + + def test_agent_card(url: str) -> AgentCard: """Returns an agent card for the test agent.""" return AgentCard( @@ -151,7 +157,6 @@ def create_agent_app( push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, - context=ServerCallContext(), ), ) rest_routes = create_rest_routes(request_handler=handler) @@ -159,3 +164,78 @@ def create_agent_app( agent_card=card, card_url='/.well-known/agent-card.json' ) return Starlette(routes=[*rest_routes, *agent_card_routes]) + + +class _NamedTestUser(User): + """Authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +class _HeaderUserContextBuilder(DefaultServerCallContextBuilder): + """Builds a ServerCallContext whose user is read from a request header.""" + + def build_user(self, request: Request) -> User: + user_name = request.headers.get(_TEST_USER_HEADER) + if user_name: + return _NamedTestUser(user_name) + return UnauthenticatedUser() + + +def create_multi_user_agent_app( + url: str, notification_client: httpx.AsyncClient +) -> Starlette: + """Creates a multi-user variant of the test agent app. + + Differences from create_agent_app: + + - Identity is read from the x-test-user header on each request + via _HeaderUserContextBuilder. Multiple authenticated + users (e.g. alice, bob) can therefore call the same + server. + - The InMemoryTaskStore uses a constant owner resolver, so + every authenticated user has access to every task. + - The InMemoryPushNotificationConfigStore keeps the default + per-user owner resolver, so each registrar's configs live in their + own owner partition; this exercises cross-owner aggregation in + get_info_for_dispatch. + """ + # Shared task visibility: any authenticated user can see any task. + task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared') + + # Per-user push-config partitioning (the default). + push_config_store = InMemoryPushNotificationConfigStore() + + card = test_agent_card(url) + extended_card = test_agent_card(url) + extended_card.name = 'Test Agent Extended' + + handler = DefaultRequestHandler( + agent_executor=TestAgentExecutor(), + task_store=task_store, + agent_card=card, + extended_agent_card=extended_card, + push_config_store=push_config_store, + push_sender=BasePushNotificationSender( + httpx_client=notification_client, + config_store=push_config_store, + ), + ) + + rest_routes = create_rest_routes( + request_handler=handler, + context_builder=_HeaderUserContextBuilder(), + ) + agent_card_routes = create_agent_card_routes( + agent_card=card, card_url='/.well-known/agent-card.json' + ) + return Starlette(routes=[*rest_routes, *agent_card_routes]) diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index 35e4bbeb4..84fd14c9a 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -6,7 +6,7 @@ import pytest import pytest_asyncio -from .agent_app import create_agent_app +from .agent_app import create_agent_app, create_multi_user_agent_app from .notifications_app import Notification, create_notifications_app from .utils import ( create_app_process, @@ -21,9 +21,9 @@ ) from a2a.utils.constants import TransportProtocol from a2a.types.a2a_pb2 import ( + ListTaskPushNotificationConfigsRequest, Message, Part, - TaskPushNotificationConfig, Role, SendMessageConfiguration, SendMessageRequest, @@ -33,6 +33,9 @@ ) +_TEST_USER_HEADER = 'x-test-user' + + @pytest.fixture(scope='module') def notifications_server(): """ @@ -88,6 +91,40 @@ def agent_server(notifications_client: httpx.AsyncClient): process.join() +@pytest.fixture(scope='module') +def multi_user_agent_server(notifications_client: httpx.AsyncClient): + """Starts the multi-user variant of the test agent server. + + This variant reads identity from an x-test-user request header + and uses a TaskStore whose owner resolver returns a constant, so + every authenticated user can see every task. It runs on its own + port alongside the single-user agent_server fixture; the + notifications_server is shared (notifications include the + task_id and per-config token, so collisions are avoided). + """ + host = '127.0.0.1' + port = find_free_port() + url = f'http://{host}:{port}' + + process = create_app_process( + create_multi_user_agent_app(url, notifications_client), host, port + ) + process.start() + try: + wait_for_server_ready( + f'{url}/extendedAgentCard', + headers={'A2A-Version': '1.0', _TEST_USER_HEADER: 'health-check'}, + ) + except TimeoutError as e: + process.terminate() + raise e + + yield url + + process.terminate() + process.join() + + @pytest_asyncio.fixture(scope='function') async def http_client(): """An async client fixture for test functions.""" @@ -238,6 +275,272 @@ async def test_notification_triggering_after_config_change_e2e( assert notifications[0].token == token +@pytest.mark.asyncio +async def test_multi_registrar_fan_out_e2e( + notifications_server: str, + agent_server: str, + http_client: httpx.AsyncClient, +): + """Two pushNotificationConfigs registered for the same task both fire end-to-end. + + Exercises the dispatch fan-out across multiple registered configs + over the real wire: each registered URL must receive a POST with + its own token in the X-A2A-Notification-Token header. + """ + # Configure an A2A client without a per-message push notification config + # (we'll register configs explicitly after the task is created). + a2a_client = ClientFactory( + ClientConfig( + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create(minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON])) + + # Send an initial message that requires more input, so the task lingers + # long enough for us to register multiple push configs against it. + responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + message_id='multi-fanout-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + stream_response = responses[0] + assert stream_response.HasField('task') + task = stream_response.task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # Register two distinct push configs for the same task. Both share the + # same registrar (this client), but use different config ids, URLs, and + # tokens. Both must fire when the next event is dispatched. + token_a = uuid.uuid4().hex + token_b = uuid.uuid4().hex + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-a', + url=f'{notifications_server}/notifications', + token=token_a, + ) + ) + await a2a_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='registrar-b', + url=f'{notifications_server}/notifications', + token=token_b, + ) + ) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # Send a follow-up message that completes the task and triggers + # dispatch. Both registered configs must receive a POST. + responses = [ + response + async for response in a2a_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='multi-fanout-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + configuration=SendMessageConfiguration(), + ) + ) + ] + assert len(responses) == 1 + + # Expect 2 notifications: one COMPLETED event, fanned out to 2 configs. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + # Both tokens must appear exactly once. + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([token_a, token_b]) + + # Both notifications must carry the same COMPLETED event payload. + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + + +def _make_user_a2a_client(agent_server: str, user_name: str): + """Builds an A2A client that identifies as user_name on every request. + + Identity is conveyed via a default header on the underlying + httpx.AsyncClient; the multi-user agent app's context builder + reads that header to populate ServerCallContext.user. + """ + httpx_client = httpx.AsyncClient(headers={_TEST_USER_HEADER: user_name}) + return ClientFactory( + ClientConfig( + httpx_client=httpx_client, + supported_protocol_bindings=[TransportProtocol.HTTP_JSON], + ) + ).create( + minimal_agent_card(agent_server, [TransportProtocol.HTTP_JSON]) + ), httpx_client + + +@pytest.mark.asyncio +async def test_alice_and_bob_both_receive_notifications_on_shared_task_e2e( + notifications_server: str, + multi_user_agent_server: str, + http_client: httpx.AsyncClient, +): + """Alice registers a webhook; Bob registers a webhook; both fire end-to-end. + + 1. Alice creates a task (it lingers in INPUT_REQUIRED). + 2. Alice registers her own push config on the task. + 3. Bob (a different authenticated user, sharing access to the task) + registers his own push config on the same task. + 4. Bob (the dispatcher, *not* the registrar of Alice's webhook) + sends a follow-up message that completes the task. + 5. Both Alice's webhook and Bob's webhook receive a POST with their + own respective tokens. + + Regression guard for the design's central guarantee: subscriptions + fire on the registrar's behalf regardless of which user's action + triggered the event. A regression that re-introduced + dispatcher-context filtering on the dispatch path would drop one of + the two notifications. + """ + alice_client, alice_http = _make_user_a2a_client( + multi_user_agent_server, 'alice' + ) + bob_client, bob_http = _make_user_a2a_client(multi_user_agent_server, 'bob') + + try: + responses = [ + response + async for response in alice_client.send_message( + SendMessageRequest( + message=Message( + message_id='shared-task-init', + parts=[Part(text='How are you?')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + assert responses[0].HasField('task') + task = responses[0].task + assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED + + # 2. Alice registers her push config. + alice_token = uuid.uuid4().hex + await alice_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='alice-cfg', + url=f'{notifications_server}/notifications', + token=alice_token, + ) + ) + + # 3. Bob registers his push config on the same task. + bob_token = uuid.uuid4().hex + await bob_client.create_task_push_notification_config( + TaskPushNotificationConfig( + task_id=task.id, + id='bob-cfg', + url=f'{notifications_server}/notifications', + token=bob_token, + ) + ) + + # Sanity: the per-user listing endpoints are owner-scoped -- + # Alice does not see Bob's config and vice-versa, even though + # both can see the underlying task. + # + # The auto-registered empty config (see step 1 quirk note) lives + # in Alice's partition under ``id == task_id``, so Alice's + # listing contains ``{'alice-cfg', task.id}``; the key invariant + # is that neither listing contains the other user's id or + # token. + alice_configs = await alice_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + alice_ids = {c.id for c in alice_configs.configs} + assert 'alice-cfg' in alice_ids + assert 'bob-cfg' not in alice_ids + assert all(c.token != bob_token for c in alice_configs.configs) + + bob_configs = await bob_client.list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id=task.id) + ) + bob_ids = {c.id for c in bob_configs.configs} + assert 'bob-cfg' in bob_ids + assert 'alice-cfg' not in bob_ids + assert all(c.token != alice_token for c in bob_configs.configs) + + # Sanity: no notifications have fired yet. + response = await http_client.get( + f'{notifications_server}/{task.id}/notifications' + ) + assert response.status_code == 200 + assert len(response.json().get('notifications', [])) == 0 + + # 4. Bob sends the follow-up message that completes the task. + # Omit ``configuration`` for the same reason as step 1. + responses = [ + response + async for response in bob_client.send_message( + SendMessageRequest( + message=Message( + task_id=task.id, + message_id='shared-task-complete', + parts=[Part(text='Good')], + role=Role.ROLE_USER, + ), + ) + ) + ] + assert len(responses) == 1 + + # 5. Both Alice's and Bob's webhooks receive the COMPLETED event. + notifications = await wait_for_n_notifications( + http_client, + f'{notifications_server}/{task.id}/notifications', + n=2, + ) + + received_tokens = sorted(n.token for n in notifications) + assert received_tokens == sorted([alice_token, bob_token]) + + for notification in notifications: + state = ( + notification.event.get('status_update', {}) + .get('status', {}) + .get('state') + ) + assert state == 'TASK_STATE_COMPLETED' + finally: + await alice_http.aclose() + await bob_http.aclose() + + async def wait_for_n_notifications( http_client: httpx.AsyncClient, url: str, diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 5a2bf0446..0138045ae 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -14,7 +14,7 @@ import pytest -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import ( AgentExecutor, RequestContext, @@ -1590,7 +1590,6 @@ def __init__(self): async def execute( self, context: RequestContext, event_queue: EventQueue ): - updater = TaskUpdater( event_queue, cast('str', context.task_id), @@ -2977,3 +2976,171 @@ async def test_on_subscribe_to_task_unsupported(agent_card): # We need to exhaust the generator to trigger the decorator evaluation async for _ in request_handler.on_subscribe_to_task(params, context): pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_is_owner_scoped( + agent_card, +): + """Bob must not see Alice's configs via tasks/pushNotificationConfig/list. + + Both users have access to the shared task (the mocked TaskStore + returns it for any caller), but listing must only return the + caller's own configs. + """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + # Sanity: Bob's secret is not in the response. + assert all(c.token != 'bob-secret' for c in alice_listing.configs), ( + 'Listing for Alice must not expose Bob-owned tokens' + ) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs), ( + 'Listing for Bob must not expose Alice-owned tokens' + ) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_returns_empty_for_third_user( + agent_card, +): + """A third user with task access but no registered configs sees an empty list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + ), + _ctx('alice'), + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + carol_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + _ctx('carol'), + ) + ) + assert carol_listing.configs == [] + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped( + agent_card, +): + """Bob cannot fetch Alice's config by ID via tasks/pushNotificationConfig/get. + + Even when Bob can read the task and knows (or guesses) the + config_id, the handler must raise TaskNotFoundError because Alice's + config is not in Bob's owner partition. + """ + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = create_sample_task(task_id='shared-task') + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandler( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=agent_card, + ) + + # Alice can read her own config. + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + # Bob cannot, even guessing the exact config_id. + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index e35b8f720..3f33516d3 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -7,7 +7,7 @@ import pytest -from a2a.auth.user import UnauthenticatedUser +from a2a.auth.user import UnauthenticatedUser, User from a2a.server.agent_execution import ( RequestContextBuilder, AgentExecutor, @@ -1411,3 +1411,119 @@ async def test_on_message_send_stream_rejects_event_after_terminal_state(): params, create_server_call_context() ): pass + + +class _NamedUser(User): + """Minimal authenticated test user identified by ``user_name``.""" + + def __init__(self, user_name: str) -> None: + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _ctx(user_name: str) -> ServerCallContext: + return ServerCallContext(user=_NamedUser(user_name)) + + +@pytest.mark.asyncio +async def test_on_list_task_push_notification_configs_is_owner_scoped(): + """v2 handler: Bob must not see Alice's configs via .../list.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + bob_ctx = _ctx('bob') + + alice_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ) + bob_cfg = TaskPushNotificationConfig( + task_id='shared-task', + id='bob-cfg', + url='http://bob.example.com/cb', + token='bob-secret', + ) + await push_store.set_info('shared-task', alice_cfg, alice_ctx) + await push_store.set_info('shared-task', bob_cfg, bob_ctx) + + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_listing = ( + await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + alice_ctx, + ) + ) + assert {c.id for c in alice_listing.configs} == {'alice-cfg'} + assert all(c.token != 'bob-secret' for c in alice_listing.configs) + + bob_listing = await request_handler.on_list_task_push_notification_configs( + ListTaskPushNotificationConfigsRequest(task_id='shared-task'), + bob_ctx, + ) + assert {c.id for c in bob_listing.configs} == {'bob-cfg'} + assert all(c.token != 'alice-secret' for c in bob_listing.configs) + + +@pytest.mark.asyncio +async def test_on_get_task_push_notification_config_is_owner_scoped(): + """v2 handler: Bob cannot fetch Alice's config by ID via .../get.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_task_store.get.return_value = Task( + id='shared-task', context_id='ctx_1' + ) + + push_store = InMemoryPushNotificationConfigStore() + alice_ctx = _ctx('alice') + await push_store.set_info( + 'shared-task', + TaskPushNotificationConfig( + task_id='shared-task', + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-secret', + ), + alice_ctx, + ) + + request_handler = DefaultRequestHandlerV2( + agent_executor=MockAgentExecutor(), + task_store=mock_task_store, + push_config_store=push_store, + agent_card=create_default_agent_card(), + ) + + alice_view = await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + alice_ctx, + ) + assert alice_view.id == 'alice-cfg' + assert alice_view.token == 'alice-secret' + + with pytest.raises(TaskNotFoundError): + await request_handler.on_get_task_push_notification_config( + GetTaskPushNotificationConfigRequest( + task_id='shared-task', id='alice-cfg' + ), + _ctx('bob'), + ) diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index b13a5cf55..6608d49bf 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -727,6 +727,57 @@ async def test_owner_resource_scoping( await config_store.delete_info('task1', context=context_user2) +@pytest.mark.asyncio +async def test_get_info_for_dispatch_returns_all_owners( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """get_info_for_dispatch MUST return configs across all owners. + + The dispatch path has no caller identity (the originating request + has completed by the time notifications fire). Authorization + happened at registration time. The DB query must therefore filter + on task_id only, with no owner predicate. + """ + config_store = db_store_parameterized + + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + alice_cfg = TaskPushNotificationConfig( + id='alice-cfg', url='http://alice.example.com/cb' + ) + bob_cfg = TaskPushNotificationConfig( + id='bob-cfg', url='http://bob.example.com/cb' + ) + other_task_cfg = TaskPushNotificationConfig( + id='alice-other', url='http://alice.example.com/other' + ) + + await config_store.set_info('shared-task', alice_cfg, alice_ctx) + await config_store.set_info('shared-task', bob_cfg, bob_ctx) + # An unrelated config on a different task -- must NOT leak through. + await config_store.set_info('other-task', other_task_cfg, alice_ctx) + + dispatched = await config_store.get_info_for_dispatch('shared-task') + + assert {c.id for c in dispatched} == {'alice-cfg', 'bob-cfg'} + assert {c.url for c in dispatched} == { + 'http://alice.example.com/cb', + 'http://bob.example.com/cb', + } + + # Sanity: user-callable get_info remains owner-scoped on the same data. + alice_view = await config_store.get_info('shared-task', alice_ctx) + assert {c.id for c in alice_view} == {'alice-cfg'} + bob_view = await config_store.get_info('shared-task', bob_ctx) + assert {c.id for c in bob_view} == {'bob-cfg'} + + # Cleanup + await config_store.delete_info('shared-task', context=alice_ctx) + await config_store.delete_info('shared-task', context=bob_ctx) + await config_store.delete_info('other-task', context=alice_ctx) + + @pytest.mark.asyncio async def test_get_0_3_push_notification_config_detailed( db_store_parameterized: DatabasePushNotificationConfigStore, diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index d8b560aae..d23bcee05 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx + from google.protobuf.json_format import MessageToDict from a2a.auth.user import User @@ -14,9 +15,9 @@ InMemoryPushNotificationConfigStore, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, + TaskPushNotificationConfig, TaskState, TaskStatus, ) @@ -70,8 +71,7 @@ def setUp(self) -> None: self.notifier = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.config_store, - context=MINIMAL_CALL_CONTEXT, - ) # Corrected argument name + ) def test_constructor_stores_client(self) -> None: self.assertEqual(self.notifier._client, self.mock_httpx_client) @@ -428,5 +428,148 @@ async def test_owner_resource_scoping(self) -> None: await self.config_store.delete_info('task1', context=context_user2) +class TestPushNotificationDispatchAcrossOwners( + unittest.IsolatedAsyncioTestCase +): + """Dispatch-correctness tests for the registrar/dispatcher asymmetry. + + Push notifications must fire for any event on the task, regardless of + which user's action triggered the event. The dispatch path therefore + reads configs via get_info_for_dispatch (cross-owner), not + get_info (owner-scoped). + """ + + def setUp(self) -> None: + self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + self.mock_httpx_client.post.return_value = mock_response + + self.config_store = InMemoryPushNotificationConfigStore() + + self.sender = BasePushNotificationSender( + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + ) + + async def test_multi_registrar_fan_out(self) -> None: + """Three users registering distinct webhooks for the same task all fire.""" + users_and_urls = [ + ('alice', 'http://alice.example.com/cb', 'tok-alice'), + ('bob', 'http://bob.example.com/cb', 'tok-bob'), + ('carol', 'http://carol.example.com/cb', 'tok-carol'), + ] + for user_name, url, token in users_and_urls: + ctx = ServerCallContext(user=SampleUser(user_name=user_name)) + cfg = TaskPushNotificationConfig( + id=f'cfg-{user_name}', url=url, token=token + ) + await self.config_store.set_info('shared-task', cfg, ctx) + + await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + + self.assertEqual(self.mock_httpx_client.post.await_count, 3) + called_urls = { + call.args[0] for call in self.mock_httpx_client.post.call_args_list + } + self.assertEqual( + called_urls, + {url for _, url, _ in users_and_urls}, + ) + called_tokens = { + call.kwargs['headers']['X-A2A-Notification-Token'] + for call in self.mock_httpx_client.post.call_args_list + } + self.assertEqual( + called_tokens, + {token for _, _, token in users_and_urls}, + ) + + async def test_write_side_owner_isolation_preserved(self) -> None: + """Bob's ``delete_info`` against Alice's config is a no-op. + + After the no-op, Alice's config must still be: + (a) retrievable via the user-callable ``get_info`` for Alice, and + (b) returned by ``get_info_for_dispatch`` so that the + notification will still fire. + + Guards the write-side scoping that the design preserves + (see ยง9.3). + """ + alice_ctx = ServerCallContext(user=SampleUser(user_name='alice')) + bob_ctx = ServerCallContext(user=SampleUser(user_name='bob')) + + config = TaskPushNotificationConfig( + id='alice-cfg', + url='http://alice.example.com/cb', + token='alice-token', + ) + await self.config_store.set_info('shared-task', config, alice_ctx) + + # Bob attempts to delete Alice's config -- must be a no-op. + await self.config_store.delete_info( + 'shared-task', context=bob_ctx, config_id='alice-cfg' + ) + + # (a) Alice's user-callable view is unchanged. + alice_view = await self.config_store.get_info('shared-task', alice_ctx) + self.assertEqual(len(alice_view), 1) + self.assertEqual(alice_view[0].id, 'alice-cfg') + + # (b) Dispatch path still sees the config (notifications fire). + dispatched = await self.config_store.get_info_for_dispatch( + 'shared-task' + ) + self.assertEqual(len(dispatched), 1) + self.assertEqual(dispatched[0].id, 'alice-cfg') + self.assertEqual(dispatched[0].token, 'alice-token') + + # And end-to-end: the sender actually dispatches to Alice's URL. + await self.sender.send_notification( + 'shared-task', _create_sample_task(task_id='shared-task') + ) + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='shared-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + async def test_cross_user_dispatch_alice_registers_bob_triggers( + self, + ) -> None: + """Alice registers; Bob triggers; Alice's webhook receives the POST. + + The send_notification carries no identity, so there is no notion of + "who triggered this event" at the store layer. get_info_for_dispatch + returns Alice's config because Alice registered it. The fact that the + event was caused by Bob is not visible to (and not relevant for) the + dispatch path. + """ + alice_context = ServerCallContext(user=SampleUser(user_name='alice')) + config = _create_sample_push_config( + url='http://alice.example.com/cb', token='alice-token' + ) + await self.config_store.set_info('collab-task', config, alice_context) + + # No bob_context is passed anywhere -- the dispatch path never + # sees it. This is precisely the point: identity is not the + # dispatch path's concern. + await self.sender.send_notification( + 'collab-task', _create_sample_task(task_id='collab-task') + ) + + self.mock_httpx_client.post.assert_awaited_once_with( + 'http://alice.example.com/cb', + json=MessageToDict( + StreamResponse(task=_create_sample_task(task_id='collab-task')) + ), + headers={'X-A2A-Notification-Token': 'alice-token'}, + ) + + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index 783e1f413..22f904a2a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -6,40 +6,20 @@ from google.protobuf.json_format import MessageToDict -from a2a.auth.user import User -from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) from a2a.types.a2a_pb2 import ( - TaskPushNotificationConfig, StreamResponse, Task, TaskArtifactUpdateEvent, + TaskPushNotificationConfig, TaskState, TaskStatus, TaskStatusUpdateEvent, ) -class SampleUser(User): - """A test implementation of the User interface.""" - - def __init__(self, user_name: str): - self._user_name = user_name - - @property - def is_authenticated(self) -> bool: - return True - - @property - def user_name(self) -> str: - return self._user_name - - -MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) - - def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, @@ -66,7 +46,6 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, - context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -77,7 +56,7 @@ async def test_send_notification_success(self) -> None: task_id = 'task_send_success' task_data = _create_sample_task(task_id=task_id) config = _create_sample_push_config(url='http://notify.me/here') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -85,8 +64,8 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with right parameters @@ -103,7 +82,7 @@ async def test_send_notification_with_token_success(self) -> None: config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -111,8 +90,8 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_data.id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id ) # assert httpx_client post method got invoked with right parameters @@ -126,12 +105,12 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = _create_sample_task(task_id=task_id) - self.mock_config_store.get_info.return_value = [] + self.mock_config_store.get_info_for_dispatch.return_value = [] await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_not_called() @@ -142,7 +121,7 @@ async def test_send_notification_http_status_error( task_id = 'task_send_http_err' task_data = _create_sample_task(task_id=task_id) config = _create_sample_push_config(url='http://notify.me/http_error') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = MagicMock(spec=httpx.Response) mock_response.status_code = 404 @@ -154,8 +133,8 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -173,7 +152,10 @@ async def test_send_notification_multiple_configs(self) -> None: config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) - self.mock_config_store.get_info.return_value = [config1, config2] + self.mock_config_store.get_info_for_dispatch.return_value = [ + config1, + config2, + ] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -181,8 +163,8 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_id, task_data) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) @@ -207,7 +189,7 @@ async def test_send_notification_status_update_event(self) -> None: status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) config = _create_sample_push_config(url='http://notify.me/status') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -215,8 +197,8 @@ async def test_send_notification_status_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, @@ -231,7 +213,7 @@ async def test_send_notification_artifact_update_event(self) -> None: append=True, ) config = _create_sample_push_config(url='http://notify.me/artifact') - self.mock_config_store.get_info.return_value = [config] + self.mock_config_store.get_info_for_dispatch.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -239,8 +221,8 @@ async def test_send_notification_artifact_update_event(self) -> None: await self.sender.send_notification(task_id, event) - self.mock_config_store.get_info.assert_awaited_once_with( - task_id, MINIMAL_CALL_CONTEXT + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_id ) self.mock_httpx_client.post.assert_awaited_once_with( config.url,