Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions itk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
InMemoryPushNotificationConfigStore,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.context import ServerCallContext
from a2a.types import a2a_pb2_grpc
from a2a.types.a2a_pb2 import (
AgentCapabilities,
Expand Down Expand Up @@ -339,7 +338,6 @@ async def main_async(http_port: int, grpc_port: int) -> None:
push_sender = BasePushNotificationSender(
httpx_client=httpx.AsyncClient(),
config_store=push_config_store,
context=ServerCallContext(),
)

handler = DefaultRequestHandler(
Expand Down
27 changes: 20 additions & 7 deletions src/a2a/server/tasks/base_push_notification_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,39 @@ def __init__(
self,
httpx_client: httpx.AsyncClient,
config_store: PushNotificationConfigStore,
context: ServerCallContext,
context: ServerCallContext | None = None,
) -> None:
"""Initializes the BasePushNotificationSender.

Args:
httpx_client: An async HTTP client instance to send notifications.
config_store: A PushNotificationConfigStore instance to retrieve configurations.
context: The `ServerCallContext` that this push notification is produced under.
config_store: A PushNotificationConfigStore instance to
retrieve configurations.
context: Deprecated and ignored. Accepted only for
backward compatibility with 1.0 callers that constructed
the sender with a (typically dummy) ServerCallContext.
Pass None (the default) in new code. A non-None
value logs a deprecation warning and is otherwise
ignored.
"""
if context is not None:
logger.warning(
'BasePushNotificationSender no longer uses the context '
'parameter; it is accepted only for backward compatibility '
'with 1.0 and will be removed in a future major version. '
'Push notifications now fan out across all owners via '
'PushNotificationConfigStore.get_info_for_dispatch; the '
'caller identity is not carried into dispatch. Drop the '
'context argument from the constructor call.'
)
self._client = httpx_client
self._config_store = config_store
self._call_context: ServerCallContext = context

async def send_notification(
self, task_id: str, event: PushNotificationEvent
) -> None:
"""Sends a push notification for an event if configuration exists."""
push_configs = await self._config_store.get_info(
task_id, self._call_context
)
push_configs = await self._config_store.get_info_for_dispatch(task_id)
if not push_configs:
return

Expand Down
46 changes: 33 additions & 13 deletions src/a2a/server/tasks/database_push_notification_config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


try:
from sqlalchemy import Table, and_, delete, select
from sqlalchemy import ColumnElement, Table, and_, delete, select
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
Expand Down Expand Up @@ -304,21 +304,14 @@ async def set_info(
owner,
)

async def get_info(
async def _select_configs(
self,
task_id: str,
context: ServerCallContext,
*predicates: 'ColumnElement[bool]',
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task, for the given owner."""
"""Loads configs matching the given predicates and decodes them."""
await self._ensure_initialized()
owner = self.owner_resolver(context)
async with self.async_session_maker() as session:
stmt = select(self.config_model).where(
and_(
self.config_model.task_id == task_id,
self.config_model.owner == owner,
)
)
stmt = select(self.config_model).where(and_(*predicates))
result = await session.execute(stmt)
models = result.scalars().all()

Expand All @@ -331,10 +324,37 @@ async def get_info(
'Could not deserialize push notification config for task %s, config %s, owner %s',
model.task_id,
model.config_id,
owner,
model.owner,
)
return configs

async def get_info(
self,
task_id: str,
context: ServerCallContext,
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task, for the given owner.

Used by the user-callable read endpoints.
"""
owner = self.owner_resolver(context)
return await self._select_configs(
self.config_model.task_id == task_id,
self.config_model.owner == owner,
)

async def get_info_for_dispatch(
Comment thread
sokoliva marked this conversation as resolved.
self,
task_id: str,
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task, across all owners.

Used by the push-notification dispatch path.
"""
return await self._select_configs(
self.config_model.task_id == task_id,
)

async def delete_info(
self,
task_id: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,29 @@ async def get_info(
task_id: str,
context: ServerCallContext,
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task from memory, for the given owner."""
"""Retrieves all push notification configurations for a task from memory, for the given owner.

Used by the user-callable read endpoints.
"""
owner = self.owner_resolver(context)
async with self.lock:
owner_infos = self._get_owner_push_notification_infos(owner)
return list(owner_infos.get(task_id, []))

async def get_info_for_dispatch(
self,
task_id: str,
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task across all owners.

Used by the push-notification dispatch path.
"""
async with self.lock:
results: list[TaskPushNotificationConfig] = []
for all_configs in self._push_notification_infos.values():
results.extend(all_configs.get(task_id, []))
return results
Comment thread
sokoliva marked this conversation as resolved.

async def delete_info(
self,
task_id: str,
Expand Down
46 changes: 45 additions & 1 deletion src/a2a/server/tasks/push_notification_config_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import logging

from abc import ABC, abstractmethod

from a2a.server.context import ServerCallContext
from a2a.types.a2a_pb2 import TaskPushNotificationConfig


logger = logging.getLogger(__name__)


class PushNotificationConfigStore(ABC):
"""Interface for storing and retrieving push notification configurations for tasks."""

Expand All @@ -22,7 +27,46 @@ async def get_info(
task_id: str,
context: ServerCallContext,
) -> list[TaskPushNotificationConfig]:
"""Retrieves the push notification configuration for a task."""
"""Retrieves push notification configurations for a task, scoped to the caller.

This is the user-callable read path. Implementations MUST return
only configurations owned by the caller (as resolved from
context).
"""

async def get_info_for_dispatch(
self,
task_id: str,
) -> list[TaskPushNotificationConfig]:
"""Retrieves all push notification configurations for a task, across all owners.

This is the internal read path used by the push-notification
dispatch loop. Implementations SHOULD override this method to
return every configuration registered for task_id regardless of
which user registered it. Authorization already happened at
registration time and the dispatch path fires every registered
webhook for the task.

The default implementation falls back to calling get_info with
a synthetic empty ServerCallContext. This preserves 1.0
behavior for subclasses that have not implemented the override
but is INCORRECT for any deployment with multiple owners: the
empty context resolves to the empty-string owner partition and
returns no configs (silently dropping every notification). A
warning is logged on every call to flag the misconfiguration.
Custom subclasses MUST override this method to deliver
notifications correctly in multi-owner deployments.
"""
logger.warning(
'%s does not override '
'PushNotificationConfigStore.get_info_for_dispatch; falling back '
'to a context-less get_info call which silently drops '
'notifications in any deployment with multiple owners. Override '
'get_info_for_dispatch to return all configs for task_id across '
'every owner.',
type(self).__name__,
)
return await self.get_info(task_id, ServerCallContext())

@abstractmethod
async def delete_info(
Expand Down
88 changes: 84 additions & 4 deletions tests/e2e/push_notifications/agent_app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import httpx

from fastapi import FastAPI
from starlette.applications import Starlette
from starlette.requests import Request

from a2a.auth.user import UnauthenticatedUser, User
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.context import ServerCallContext
from a2a.server.events import EventQueue
from starlette.applications import Starlette
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.routes import create_agent_card_routes
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.routes import create_agent_card_routes
from a2a.server.routes.common import DefaultServerCallContextBuilder
from a2a.server.routes.rest_routes import create_rest_routes
from a2a.server.tasks import (
BasePushNotificationSender,
InMemoryPushNotificationConfigStore,
Expand All @@ -30,6 +33,9 @@
)


_TEST_USER_HEADER = 'x-test-user'


def test_agent_card(url: str) -> AgentCard:
"""Returns an agent card for the test agent."""
return AgentCard(
Expand Down Expand Up @@ -151,11 +157,85 @@ def create_agent_app(
push_sender=BasePushNotificationSender(
httpx_client=notification_client,
config_store=push_config_store,
context=ServerCallContext(),
),
)
rest_routes = create_rest_routes(request_handler=handler)
agent_card_routes = create_agent_card_routes(
agent_card=card, card_url='/.well-known/agent-card.json'
)
return Starlette(routes=[*rest_routes, *agent_card_routes])


class _NamedTestUser(User):
"""Authenticated test user identified by ``user_name``."""

def __init__(self, user_name: str) -> None:
self._user_name = user_name

@property
def is_authenticated(self) -> bool:
return True

@property
def user_name(self) -> str:
return self._user_name


class _HeaderUserContextBuilder(DefaultServerCallContextBuilder):
"""Builds a ServerCallContext whose user is read from a request header."""

def build_user(self, request: Request) -> User:
user_name = request.headers.get(_TEST_USER_HEADER)
if user_name:
return _NamedTestUser(user_name)
return UnauthenticatedUser()


def create_multi_user_agent_app(
url: str, notification_client: httpx.AsyncClient
) -> Starlette:
"""Creates a multi-user variant of the test agent app.

Differences from create_agent_app:

- Identity is read from the x-test-user header on each request
via _HeaderUserContextBuilder. Multiple authenticated
users (e.g. alice, bob) can therefore call the same
server.
- The InMemoryTaskStore uses a constant owner resolver, so
every authenticated user has access to every task.
- The InMemoryPushNotificationConfigStore keeps the default
per-user owner resolver, so each registrar's configs live in their
own owner partition; this exercises cross-owner aggregation in
get_info_for_dispatch.
"""
# Shared task visibility: any authenticated user can see any task.
task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared')

# Per-user push-config partitioning (the default).
push_config_store = InMemoryPushNotificationConfigStore()

card = test_agent_card(url)
extended_card = test_agent_card(url)
extended_card.name = 'Test Agent Extended'

handler = DefaultRequestHandler(
agent_executor=TestAgentExecutor(),
task_store=task_store,
agent_card=card,
extended_agent_card=extended_card,
push_config_store=push_config_store,
push_sender=BasePushNotificationSender(
httpx_client=notification_client,
config_store=push_config_store,
),
)

rest_routes = create_rest_routes(
request_handler=handler,
context_builder=_HeaderUserContextBuilder(),
)
agent_card_routes = create_agent_card_routes(
agent_card=card, card_url='/.well-known/agent-card.json'
)
return Starlette(routes=[*rest_routes, *agent_card_routes])
Loading
Loading