Skip to content

Commit f865bf1

Browse files
committed
fix(server)!: deliver push notifications to all registrars across owners
1 parent a470bae commit f865bf1

12 files changed

Lines changed: 970 additions & 67 deletions

itk/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
InMemoryPushNotificationConfigStore,
3232
)
3333
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
34-
from a2a.server.context import ServerCallContext
3534
from a2a.types import a2a_pb2_grpc
3635
from a2a.types.a2a_pb2 import (
3736
AgentCapabilities,
@@ -339,7 +338,6 @@ async def main_async(http_port: int, grpc_port: int) -> None:
339338
push_sender = BasePushNotificationSender(
340339
httpx_client=httpx.AsyncClient(),
341340
config_store=push_config_store,
342-
context=ServerCallContext(),
343341
)
344342

345343
handler = DefaultRequestHandler(

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from google.protobuf.json_format import MessageToDict
77

8-
from a2a.server.context import ServerCallContext
98
from a2a.server.tasks.push_notification_config_store import (
109
PushNotificationConfigStore,
1110
)
@@ -27,26 +26,21 @@ def __init__(
2726
self,
2827
httpx_client: httpx.AsyncClient,
2928
config_store: PushNotificationConfigStore,
30-
context: ServerCallContext,
3129
) -> None:
3230
"""Initializes the BasePushNotificationSender.
3331
3432
Args:
3533
httpx_client: An async HTTP client instance to send notifications.
3634
config_store: A PushNotificationConfigStore instance to retrieve configurations.
37-
context: The `ServerCallContext` that this push notification is produced under.
3835
"""
3936
self._client = httpx_client
4037
self._config_store = config_store
41-
self._call_context: ServerCallContext = context
4238

4339
async def send_notification(
4440
self, task_id: str, event: PushNotificationEvent
4541
) -> None:
4642
"""Sends a push notification for an event if configuration exists."""
47-
push_configs = await self._config_store.get_info(
48-
task_id, self._call_context
49-
)
43+
push_configs = await self._config_store.get_info_for_dispatch(task_id)
5044
if not push_configs:
5145
return
5246

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,10 @@ async def get_info(
309309
task_id: str,
310310
context: ServerCallContext,
311311
) -> list[TaskPushNotificationConfig]:
312-
"""Retrieves all push notification configurations for a task, for the given owner."""
312+
"""Retrieves all push notification configurations for a task, for the given owner.
313+
314+
Used by the user-callable read endpoints.
315+
"""
313316
await self._ensure_initialized()
314317
owner = self.owner_resolver(context)
315318
async with self.async_session_maker() as session:
@@ -335,6 +338,35 @@ async def get_info(
335338
)
336339
return configs
337340

341+
async def get_info_for_dispatch(
342+
self,
343+
task_id: str,
344+
) -> list[TaskPushNotificationConfig]:
345+
"""Retrieves all push notification configurations for a task, across all owners.
346+
347+
Used by the push-notification dispatch path.
348+
"""
349+
await self._ensure_initialized()
350+
async with self.async_session_maker() as session:
351+
stmt = select(self.config_model).where(
352+
self.config_model.task_id == task_id
353+
)
354+
result = await session.execute(stmt)
355+
models = result.scalars().all()
356+
357+
configs = []
358+
for model in models:
359+
try:
360+
configs.append(self._from_orm(model))
361+
except ValueError: # noqa: PERF203
362+
logger.exception(
363+
'Could not deserialize push notification config for task %s, config %s, owner %s',
364+
model.task_id,
365+
model.config_id,
366+
model.owner,
367+
)
368+
return configs
369+
338370
async def delete_info(
339371
self,
340372
task_id: str,

src/a2a/server/tasks/inmemory_push_notification_config_store.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,29 @@ async def get_info(
7272
task_id: str,
7373
context: ServerCallContext,
7474
) -> list[TaskPushNotificationConfig]:
75-
"""Retrieves all push notification configurations for a task from memory, for the given owner."""
75+
"""Retrieves all push notification configurations for a task from memory, for the given owner.
76+
77+
Used by the user-callable read endpoints.
78+
"""
7679
owner = self.owner_resolver(context)
7780
async with self.lock:
7881
owner_infos = self._get_owner_push_notification_infos(owner)
7982
return list(owner_infos.get(task_id, []))
8083

84+
async def get_info_for_dispatch(
85+
self,
86+
task_id: str,
87+
) -> list[TaskPushNotificationConfig]:
88+
"""Retrieves all push notification configurations for a task across all owners.
89+
90+
Used by the push-notification dispatch path.
91+
"""
92+
async with self.lock:
93+
results: list[TaskPushNotificationConfig] = []
94+
for all_configs in self._push_notification_infos.values():
95+
results.extend(all_configs.get(task_id, []))
96+
return results
97+
8198
async def delete_info(
8299
self,
83100
task_id: str,

src/a2a/server/tasks/push_notification_config_store.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,26 @@ async def get_info(
2222
task_id: str,
2323
context: ServerCallContext,
2424
) -> list[TaskPushNotificationConfig]:
25-
"""Retrieves the push notification configuration for a task."""
25+
"""Retrieves push notification configurations for a task, scoped to the caller.
26+
27+
This is the user-callable read path. Implementations MUST return
28+
only configurations owned by the caller (as resolved from
29+
context).
30+
"""
31+
32+
@abstractmethod
33+
async def get_info_for_dispatch(
34+
self,
35+
task_id: str,
36+
) -> list[TaskPushNotificationConfig]:
37+
"""Retrieves all push notification configurations for a task, across all owners.
38+
39+
This is the internal read path used by the push-notification
40+
dispatch loop. Implementations MUST return every configuration
41+
registered for task_id regardless of which user registered
42+
it. Authorization already happened at registration time and
43+
the dispatch path fires every registered webhook for the task.
44+
"""
2645

2746
@abstractmethod
2847
async def delete_info(

tests/e2e/push_notifications/agent_app.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import httpx
22

33
from fastapi import FastAPI
4+
from starlette.applications import Starlette
5+
from starlette.requests import Request
46

7+
from a2a.auth.user import UnauthenticatedUser, User
58
from a2a.server.agent_execution import AgentExecutor, RequestContext
69
from a2a.server.context import ServerCallContext
710
from a2a.server.events import EventQueue
8-
from starlette.applications import Starlette
9-
from a2a.server.routes.rest_routes import create_rest_routes
10-
from a2a.server.routes import create_agent_card_routes
1111
from a2a.server.request_handlers import DefaultRequestHandler
12+
from a2a.server.routes import create_agent_card_routes
13+
from a2a.server.routes.common import DefaultServerCallContextBuilder
14+
from a2a.server.routes.rest_routes import create_rest_routes
1215
from a2a.server.tasks import (
1316
BasePushNotificationSender,
1417
InMemoryPushNotificationConfigStore,
@@ -30,6 +33,9 @@
3033
)
3134

3235

36+
_TEST_USER_HEADER = 'x-test-user'
37+
38+
3339
def test_agent_card(url: str) -> AgentCard:
3440
"""Returns an agent card for the test agent."""
3541
return AgentCard(
@@ -151,11 +157,85 @@ def create_agent_app(
151157
push_sender=BasePushNotificationSender(
152158
httpx_client=notification_client,
153159
config_store=push_config_store,
154-
context=ServerCallContext(),
155160
),
156161
)
157162
rest_routes = create_rest_routes(request_handler=handler)
158163
agent_card_routes = create_agent_card_routes(
159164
agent_card=card, card_url='/.well-known/agent-card.json'
160165
)
161166
return Starlette(routes=[*rest_routes, *agent_card_routes])
167+
168+
169+
class _NamedTestUser(User):
170+
"""Authenticated test user identified by ``user_name``."""
171+
172+
def __init__(self, user_name: str) -> None:
173+
self._user_name = user_name
174+
175+
@property
176+
def is_authenticated(self) -> bool:
177+
return True
178+
179+
@property
180+
def user_name(self) -> str:
181+
return self._user_name
182+
183+
184+
class _HeaderUserContextBuilder(DefaultServerCallContextBuilder):
185+
"""Builds a ServerCallContext whose user is read from a request header."""
186+
187+
def build_user(self, request: Request) -> User:
188+
user_name = request.headers.get(_TEST_USER_HEADER)
189+
if user_name:
190+
return _NamedTestUser(user_name)
191+
return UnauthenticatedUser()
192+
193+
194+
def create_multi_user_agent_app(
195+
url: str, notification_client: httpx.AsyncClient
196+
) -> Starlette:
197+
"""Creates a multi-user variant of the test agent app.
198+
199+
Differences from create_agent_app:
200+
201+
- Identity is read from the x-test-user header on each request
202+
via _HeaderUserContextBuilder. Multiple authenticated
203+
users (e.g. alice, bob) can therefore call the same
204+
server.
205+
- The InMemoryTaskStore uses a constant owner resolver, so
206+
every authenticated user has access to every task.
207+
- The InMemoryPushNotificationConfigStore keeps the default
208+
per-user owner resolver, so each registrar's configs live in their
209+
own owner partition; this exercises cross-owner aggregation in
210+
get_info_for_dispatch.
211+
"""
212+
# Shared task visibility: any authenticated user can see any task.
213+
task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared')
214+
215+
# Per-user push-config partitioning (the default).
216+
push_config_store = InMemoryPushNotificationConfigStore()
217+
218+
card = test_agent_card(url)
219+
extended_card = test_agent_card(url)
220+
extended_card.name = 'Test Agent Extended'
221+
222+
handler = DefaultRequestHandler(
223+
agent_executor=TestAgentExecutor(),
224+
task_store=task_store,
225+
agent_card=card,
226+
extended_agent_card=extended_card,
227+
push_config_store=push_config_store,
228+
push_sender=BasePushNotificationSender(
229+
httpx_client=notification_client,
230+
config_store=push_config_store,
231+
),
232+
)
233+
234+
rest_routes = create_rest_routes(
235+
request_handler=handler,
236+
context_builder=_HeaderUserContextBuilder(),
237+
)
238+
agent_card_routes = create_agent_card_routes(
239+
agent_card=card, card_url='/.well-known/agent-card.json'
240+
)
241+
return Starlette(routes=[*rest_routes, *agent_card_routes])

0 commit comments

Comments
 (0)