Skip to content

Commit 3ca8b02

Browse files
committed
feat: use StreamResponse as push notifications payload
Fixes #678
1 parent 427a75b commit 3ca8b02

10 files changed

Lines changed: 167 additions & 65 deletions

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from a2a.server.request_handlers.request_handler import RequestHandler
2222
from a2a.server.tasks import (
2323
PushNotificationConfigStore,
24+
PushNotificationEvent,
2425
PushNotificationSender,
2526
ResultAggregator,
2627
TaskManager,
@@ -310,13 +311,15 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None:
310311
)
311312

312313
async def _send_push_notification_if_needed(
313-
self, task_id: str, result_aggregator: ResultAggregator
314+
self, task_id: str, event: Event
314315
) -> None:
315-
"""Sends push notification if configured and task is available."""
316-
if self._push_sender and task_id:
317-
latest_task = await result_aggregator.current_result
318-
if isinstance(latest_task, Task):
319-
await self._push_sender.send_notification(latest_task)
316+
"""Sends push notification if configured."""
317+
if (
318+
self._push_sender
319+
and task_id
320+
and isinstance(event, PushNotificationEvent)
321+
):
322+
await self._push_sender.send_notification(task_id, event)
320323

321324
async def on_message_send(
322325
self,
@@ -346,10 +349,8 @@ async def on_message_send(
346349
interrupted_or_non_blocking = False
347350
try:
348351
# Create async callback for push notifications
349-
async def push_notification_callback() -> None:
350-
await self._send_push_notification_if_needed(
351-
task_id, result_aggregator
352-
)
352+
async def push_notification_callback(event: Event) -> None:
353+
await self._send_push_notification_if_needed(task_id, event)
353354

354355
(
355356
result,
@@ -384,8 +385,6 @@ async def push_notification_callback() -> None:
384385
result, params.configuration.history_length
385386
)
386387

387-
await self._send_push_notification_if_needed(task_id, result_aggregator)
388-
389388
return result
390389

391390
async def on_message_send_stream(
@@ -413,9 +412,7 @@ async def on_message_send_stream(
413412
if isinstance(event, Task):
414413
self._validate_task_id_match(task_id, event.id)
415414

416-
await self._send_push_notification_if_needed(
417-
task_id, result_aggregator
418-
)
415+
await self._send_push_notification_if_needed(task_id, event)
419416
yield event
420417
except (asyncio.CancelledError, GeneratorExit):
421418
# Client disconnected: continue consuming and persisting events in the background

src/a2a/server/tasks/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from a2a.server.tasks.push_notification_config_store import (
1313
PushNotificationConfigStore,
1414
)
15-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
15+
from a2a.server.tasks.push_notification_sender import (
16+
PushNotificationEvent,
17+
PushNotificationSender,
18+
)
1619
from a2a.server.tasks.result_aggregator import ResultAggregator
1720
from a2a.server.tasks.task_manager import TaskManager
1821
from a2a.server.tasks.task_store import TaskStore

src/a2a/server/tasks/base_push_notification_sender.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from a2a.server.tasks.push_notification_config_store import (
99
PushNotificationConfigStore,
1010
)
11-
from a2a.server.tasks.push_notification_sender import PushNotificationSender
12-
from a2a.types.a2a_pb2 import PushNotificationConfig, StreamResponse, Task
11+
from a2a.server.tasks.push_notification_sender import (
12+
PushNotificationEvent,
13+
PushNotificationSender,
14+
)
15+
from a2a.types.a2a_pb2 import PushNotificationConfig
16+
from a2a.utils.proto_utils import to_stream_response
1317

1418

1519
logger = logging.getLogger(__name__)
@@ -32,44 +36,50 @@ def __init__(
3236
self._client = httpx_client
3337
self._config_store = config_store
3438

35-
async def send_notification(self, task: Task) -> None:
36-
"""Sends a push notification for a task if configuration exists."""
37-
push_configs = await self._config_store.get_info(task.id)
39+
async def send_notification(
40+
self, task_id: str, event: PushNotificationEvent
41+
) -> None:
42+
"""Sends a push notification for an event if configuration exists."""
43+
push_configs = await self._config_store.get_info(task_id)
3844
if not push_configs:
3945
return
4046

4147
awaitables = [
42-
self._dispatch_notification(task, push_info)
48+
self._dispatch_notification(event, push_info, task_id)
4349
for push_info in push_configs
4450
]
4551
results = await asyncio.gather(*awaitables)
4652

4753
if not all(results):
4854
logger.warning(
49-
'Some push notifications failed to send for task_id=%s', task.id
55+
'Some push notifications failed to send for task_id=%s', task_id
5056
)
5157

5258
async def _dispatch_notification(
53-
self, task: Task, push_info: PushNotificationConfig
59+
self,
60+
event: PushNotificationEvent,
61+
push_info: PushNotificationConfig,
62+
task_id: str,
5463
) -> bool:
5564
url = push_info.url
5665
try:
5766
headers = None
5867
if push_info.token:
5968
headers = {'X-A2A-Notification-Token': push_info.token}
69+
6070
response = await self._client.post(
6171
url,
62-
json=MessageToDict(StreamResponse(task=task)),
72+
json=MessageToDict(to_stream_response(event)),
6373
headers=headers,
6474
)
6575
response.raise_for_status()
6676
logger.info(
67-
'Push-notification sent for task_id=%s to URL: %s', task.id, url
77+
'Push-notification sent for task_id=%s to URL: %s', task_id, url
6878
)
6979
except Exception:
7080
logger.exception(
7181
'Error sending push-notification for task_id=%s to URL: %s.',
72-
task.id,
82+
task_id,
7383
url,
7484
)
7585
return False
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
from abc import ABC, abstractmethod
22

3-
from a2a.types.a2a_pb2 import Task
3+
from a2a.types.a2a_pb2 import (
4+
Task,
5+
TaskArtifactUpdateEvent,
6+
TaskStatusUpdateEvent,
7+
)
8+
9+
10+
PushNotificationEvent = Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
411

512

613
class PushNotificationSender(ABC):
714
"""Interface for sending push notifications for tasks."""
815

916
@abstractmethod
10-
async def send_notification(self, task: Task) -> None:
17+
async def send_notification(
18+
self, task_id: str, event: PushNotificationEvent
19+
) -> None:
1120
"""Sends a push notification containing the latest task state."""

src/a2a/server/tasks/result_aggregator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def consume_and_break_on_interrupt(
9898
self,
9999
consumer: EventConsumer,
100100
blocking: bool = True,
101-
event_callback: Callable[[], Awaitable[None]] | None = None,
101+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
102102
) -> tuple[Task | Message | None, bool]:
103103
"""Processes the event stream until completion or an interruptible state is encountered.
104104
@@ -131,6 +131,9 @@ async def consume_and_break_on_interrupt(
131131
return event, False
132132
await self.task_manager.process(event)
133133

134+
if event_callback:
135+
await event_callback(event)
136+
134137
should_interrupt = False
135138
is_auth_required = (
136139
isinstance(event, Task | TaskStatusUpdateEvent)
@@ -169,7 +172,7 @@ async def consume_and_break_on_interrupt(
169172
async def _continue_consuming(
170173
self,
171174
event_stream: AsyncIterator[Event],
172-
event_callback: Callable[[], Awaitable[None]] | None = None,
175+
event_callback: Callable[[Event], Awaitable[None]] | None = None,
173176
) -> None:
174177
"""Continues processing an event stream in a background task.
175178
@@ -183,4 +186,4 @@ async def _continue_consuming(
183186
async for event in event_stream:
184187
await self.task_manager.process(event)
185188
if event_callback:
186-
await event_callback()
189+
await event_callback(event)

tests/e2e/push_notifications/notifications_app.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
class Notification(BaseModel):
1313
"""Encapsulates default push notification data."""
1414

15-
task: dict[str, Any]
15+
event: dict[str, Any]
1616
token: str
1717

1818

@@ -36,20 +36,33 @@ async def add_notification(request: Request):
3636
try:
3737
json_data = await request.json()
3838
stream_response = ParseDict(json_data, StreamResponse())
39-
if not stream_response.HasField('task'):
40-
raise HTTPException(
41-
status_code=400, detail='Missing task in StreamResponse'
42-
)
43-
task = stream_response.task
39+
40+
task_id = None
41+
if stream_response.HasField('task'):
42+
task_id = stream_response.task.id
43+
elif stream_response.HasField('status_update'):
44+
task_id = stream_response.status_update.task_id
45+
elif stream_response.HasField('artifact_update'):
46+
task_id = stream_response.artifact_update.task_id
47+
48+
if not task_id:
49+
# Ignore events without task_id (e.g. Message) for now, or log them?
50+
# For tests, we just want to ensure we don't 400 on valid updates.
51+
# If we return 200 but don't store, the test waiting for n=2 might timeout if n=2 expected stored.
52+
# But Message usually accompanies a Task update.
53+
return {'status': 'ignored_no_task_id'}
54+
4455
except Exception as e:
4556
raise HTTPException(status_code=400, detail=str(e))
4657

4758
async with store_lock:
48-
if task.id not in store:
49-
store[task.id] = []
50-
store[task.id].append(
59+
if task_id not in store:
60+
store[task_id] = []
61+
store[task_id].append(
5162
Notification(
52-
task=MessageToDict(task, preserving_proto_field_name=True),
63+
event=MessageToDict(
64+
stream_response, preserving_proto_field_name=True
65+
),
5366
token=token,
5467
)
5568
)

tests/e2e/push_notifications/test_default_push_notification_support.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,22 @@ async def test_notification_triggering_with_in_message_config_e2e(
139139
notifications = await wait_for_n_notifications(
140140
http_client,
141141
f'{notifications_server}/{task.id}/notifications',
142-
n=1,
142+
n=2,
143143
)
144144
assert notifications[0].token == token
145-
# Notification.task is a dict from proto serialization
146-
assert notifications[0].task['id'] == task.id
147-
assert notifications[0].task['status']['state'] == 'TASK_STATE_COMPLETED'
145+
146+
# Verify exactly two consecutive events: SUBMITTED -> COMPLETED
147+
assert len(notifications) == 2
148+
149+
# 1. First event: SUBMITTED (Task)
150+
event0 = notifications[0].event
151+
state0 = event0['task'].get('status', {}).get('state')
152+
assert state0 == 'TASK_STATE_SUBMITTED'
153+
154+
# 2. Second event: COMPLETED (TaskStatusUpdateEvent)
155+
event1 = notifications[1].event
156+
state1 = event1['status_update'].get('status', {}).get('state')
157+
assert state1 == 'TASK_STATE_COMPLETED'
148158

149159

150160
@pytest.mark.asyncio
@@ -220,10 +230,12 @@ async def test_notification_triggering_after_config_change_e2e(
220230
f'{notifications_server}/{task.id}/notifications',
221231
n=1,
222232
)
223-
# Notification.task is a dict from proto serialization
224-
assert notifications[0].task['id'] == task.id
225-
assert notifications[0].task['status']['state'] == 'TASK_STATE_COMPLETED'
226233
assert notifications[0].token == token
234+
# It should be a status update or task update
235+
event = notifications[0].event
236+
state = event['status_update'].get('status', {}).get('state', '')
237+
238+
assert state == 'TASK_STATE_COMPLETED'
227239

228240

229241
async def wait_for_n_notifications(

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import contextlib
33
import logging
4+
import uuid
45
import time
56

67
from unittest.mock import (
@@ -695,8 +696,11 @@ async def mock_consume_and_break_on_interrupt(
695696
'event_callback should not be None'
696697
)
697698

699+
# Simulate background processing invoking the callback
700+
await event_callback_received(final_task)
701+
698702
# Verify that the push notification was sent with the final task
699-
mock_push_sender.send_notification.assert_called_with(final_task)
703+
mock_push_sender.send_notification.assert_called_with(task_id, final_task)
700704

701705
# Verify that the push notification config was stored
702706
mock_push_notification_store.set_info.assert_awaited_once_with(
@@ -1412,8 +1416,12 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs):
14121416

14131417
# 2. send_notification called for each task event yielded by aggregator
14141418
assert mock_push_sender.send_notification.await_count == 2
1415-
mock_push_sender.send_notification.assert_any_await(event1_task_update)
1416-
mock_push_sender.send_notification.assert_any_await(event2_final_task)
1419+
mock_push_sender.send_notification.assert_any_await(
1420+
task_id, event1_task_update
1421+
)
1422+
mock_push_sender.send_notification.assert_any_await(
1423+
task_id, event2_final_task
1424+
)
14171425

14181426
mock_agent_executor.execute.assert_awaited_once()
14191427

tests/server/tasks/test_inmemory_push_notifications.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ async def test_send_notification_success(self) -> None:
156156
mock_response.status_code = 200
157157
self.mock_httpx_client.post.return_value = mock_response
158158

159-
await self.notifier.send_notification(task_data) # Pass only task_data
159+
await self.notifier.send_notification(task_id, task_data)
160160

161161
self.mock_httpx_client.post.assert_awaited_once()
162162
called_args, called_kwargs = self.mock_httpx_client.post.call_args
@@ -183,7 +183,7 @@ async def test_send_notification_with_token_success(self) -> None:
183183
mock_response.status_code = 200
184184
self.mock_httpx_client.post.return_value = mock_response
185185

186-
await self.notifier.send_notification(task_data) # Pass only task_data
186+
await self.notifier.send_notification(task_id, task_data)
187187

188188
self.mock_httpx_client.post.assert_awaited_once()
189189
called_args, called_kwargs = self.mock_httpx_client.post.call_args
@@ -205,7 +205,7 @@ async def test_send_notification_no_config(self) -> None:
205205
task_id = 'task_send_no_config'
206206
task_data = create_sample_task(task_id=task_id)
207207

208-
await self.notifier.send_notification(task_data) # Pass only task_data
208+
await self.notifier.send_notification(task_id, task_data)
209209

210210
self.mock_httpx_client.post.assert_not_called()
211211

@@ -229,7 +229,7 @@ async def test_send_notification_http_status_error(
229229
self.mock_httpx_client.post.side_effect = http_error
230230

231231
# The method should catch the error and log it, not re-raise
232-
await self.notifier.send_notification(task_data) # Pass only task_data
232+
await self.notifier.send_notification(task_id, task_data)
233233

234234
self.mock_httpx_client.post.assert_awaited_once()
235235
mock_logger.exception.assert_called_once()
@@ -251,7 +251,7 @@ async def test_send_notification_request_error(
251251
request_error = httpx.RequestError('Network issue', request=MagicMock())
252252
self.mock_httpx_client.post.side_effect = request_error
253253

254-
await self.notifier.send_notification(task_data) # Pass only task_data
254+
await self.notifier.send_notification(task_id, task_data)
255255

256256
self.mock_httpx_client.post.assert_awaited_once()
257257
mock_logger.exception.assert_called_once()
@@ -281,7 +281,7 @@ async def test_send_notification_with_auth(
281281
mock_response.status_code = 200
282282
self.mock_httpx_client.post.return_value = mock_response
283283

284-
await self.notifier.send_notification(task_data) # Pass only task_data
284+
await self.notifier.send_notification(task_id, task_data)
285285

286286
self.mock_httpx_client.post.assert_awaited_once()
287287
called_args, called_kwargs = self.mock_httpx_client.post.call_args

0 commit comments

Comments
 (0)