Skip to content

Commit 13fd561

Browse files
author
Rajesh Ramamoorthy
committed
feat: Persistent Task Store and Distributed event bus implementation for A2A server capability
1 parent 8b647bd commit 13fd561

14 files changed

Lines changed: 3256 additions & 1 deletion

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ test_venv/
1010
coverage.xml
1111
.nox
1212
spec.json
13+
.idea

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3737
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
3838
signing = ["PyJWT>=2.0.0"]
3939
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
40+
aws = ["aioboto3>=13.0.0"]
4041

4142
sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
4243

@@ -47,6 +48,7 @@ all = [
4748
"a2a-sdk[grpc]",
4849
"a2a-sdk[telemetry]",
4950
"a2a-sdk[signing]",
51+
"a2a-sdk[aws]",
5052
]
5153

5254
[project.urls]

src/a2a/server/events/__init__.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Event handling components for the A2A server."""
22

3+
import logging
4+
35
from a2a.server.events.event_consumer import EventConsumer
46
from a2a.server.events.event_queue import Event, EventQueue
57
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
@@ -10,12 +12,74 @@
1012
)
1113

1214

15+
logger = logging.getLogger(__name__)

try:
    from a2a.server.events.distributed_event_queue import (
        DistributedEventQueue,  # type: ignore
    )
    from a2a.server.events.queue_lifecycle_manager import (
        QueueLifecycleManager,  # type: ignore
        QueueProvisionResult,  # type: ignore
    )
    from a2a.server.events.sns_queue_manager import (
        SnsQueueManager,  # type: ignore
    )
except ImportError as e:
    _original_aws_error = e
    logger.debug(
        'AWS distributed event components not loaded. '
        'Install the aws extra to enable them. Error: %s',
        e,
    )

    def _make_aws_placeholder(name: str) -> type:
        """Builds a stand-in class that raises ImportError on instantiation.

        Replaces four near-identical hand-written placeholder classes with a
        single factory; each generated class carries the public name of the
        missing AWS component and chains the original import failure.

        Args:
            name: The public name of the unavailable AWS component.

        Returns:
            A class whose ``__init__`` always raises ``ImportError``.
        """

        def __init__(self, *args, **kwargs):
            raise ImportError(
                f'To use {name}, install the aws extra: '
                '\'pip install "a2a-sdk[aws]"\''
            ) from _original_aws_error

        return type(
            name,
            (),
            {
                '__doc__': 'Placeholder when aws extra is not installed.',
                '__init__': __init__,
            },
        )

    DistributedEventQueue = _make_aws_placeholder('DistributedEventQueue')  # type: ignore
    SnsQueueManager = _make_aws_placeholder('SnsQueueManager')  # type: ignore
    QueueLifecycleManager = _make_aws_placeholder('QueueLifecycleManager')  # type: ignore
    QueueProvisionResult = _make_aws_placeholder('QueueProvisionResult')  # type: ignore
1373
# Public API of the events package. The AWS-backed names resolve to
# import-error placeholders when the aws extra is not installed (see the
# try/except fallback above).
__all__ = [
    'DistributedEventQueue',
    'Event',
    'EventConsumer',
    'EventQueue',
    'InMemoryQueueManager',
    'NoTaskQueue',
    'QueueLifecycleManager',
    'QueueManager',
    'QueueProvisionResult',
    'SnsQueueManager',
    'TaskQueueExists',
]
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""DistributedEventQueue — EventQueue with SNS fan-out for multi-instance A2A."""
2+
3+
import asyncio
4+
import json
5+
import logging
6+
7+
from collections.abc import Awaitable, Callable
8+
from typing import Any
9+
10+
from a2a.server.events.event_queue import (
11+
DEFAULT_MAX_QUEUE_SIZE,
12+
Event,
13+
EventQueue,
14+
)
15+
from a2a.types import (
16+
Message,
17+
Task,
18+
TaskArtifactUpdateEvent,
19+
TaskStatusUpdateEvent,
20+
)
21+
22+
23+
# Module-level logger, one per module per stdlib logging convention.
logger = logging.getLogger(__name__)

# Wire-format type tags: ``_EVENT_TYPE`` marks a serialised task event,
# ``_CLOSE_TYPE`` marks a graceful queue-close signal broadcast across
# instances.
_CLOSE_TYPE = 'close'
_EVENT_TYPE = 'event'

# Map of ``kind`` discriminator → concrete Pydantic model class, used by
# ``decode_event`` to rehydrate the correct Event type from a wire payload.
_KIND_TO_TYPE: dict[str, type[Event]] = {
    'message': Message,
    'task': Task,
    'artifact-update': TaskArtifactUpdateEvent,
    'status-update': TaskStatusUpdateEvent,
}
37+
38+
def _serialise_event(
    event: Event,
    task_id: str,
    instance_id: str,
) -> str:
    """Serialises an event into the SNS wire-format JSON string.

    Args:
        event: The event to serialise.
        task_id: The task ID this event belongs to.
        instance_id: The originating instance ID (for dedup).

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    # Round-trip through model_dump_json so pydantic converts any
    # non-JSON-native field values before the event is embedded.
    event_data: dict[str, Any] = json.loads(event.model_dump_json())
    envelope: dict[str, Any] = {
        'instance_id': instance_id,
        'task_id': task_id,
        'type': _EVENT_TYPE,
        'event_kind': event.kind,
        'event_data': event_data,
    }
    return json.dumps(envelope)
61+
62+
63+
def _serialise_close(task_id: str, instance_id: str) -> str:
    """Serialises a close signal into the SNS wire-format JSON string.

    Args:
        task_id: The task ID whose queue is being closed.
        instance_id: The originating instance ID.

    Returns:
        A JSON string suitable for use as an SNS ``Message`` payload.
    """
    # Key order matches _serialise_event for a consistent wire format.
    envelope: dict[str, Any] = {
        'instance_id': instance_id,
        'task_id': task_id,
        'type': _CLOSE_TYPE,
    }
    return json.dumps(envelope)
79+
80+
81+
def deserialise_wire_message(
    raw: str,
) -> dict[str, Any]:
    """Parses a raw SNS/SQS wire-format JSON string.

    Args:
        raw: The raw JSON string from an SQS message body.

    Returns:
        The parsed wire-format dictionary. The caller is responsible for
        routing based on the ``type`` field (``'event'`` or ``'close'``).

    Raises:
        ValueError: If the JSON is malformed or the ``type`` field is absent.
    """
    try:
        parsed: dict[str, Any] = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f'Malformed wire message: {raw!r}') from exc
    # Every well-formed wire message carries a 'type' routing tag.
    if 'type' in parsed:
        return parsed
    raise ValueError(f"Wire message missing 'type' field: {parsed!r}")
103+
104+
105+
def decode_event(msg: dict[str, Any]) -> Event | None:
    """Decodes an event from a parsed wire-format dictionary.

    Args:
        msg: A parsed wire-format dictionary with ``event_kind`` and
            ``event_data`` fields.

    Returns:
        The decoded Event, or ``None`` if the ``kind`` is unrecognised.
    """
    kind, data = msg.get('event_kind'), msg.get('event_data')
    if kind is None or data is None:
        logger.warning('Wire message missing event_kind or event_data: %s', msg)
        return None
    model_cls = _KIND_TO_TYPE.get(kind)
    if model_cls is None:
        logger.warning('Unknown event kind in wire message: %s', kind)
        return None
    # model_validate rebuilds the concrete pydantic event from the plain
    # dict payload produced by _serialise_event.
    return model_cls.model_validate(data)
125+
126+
127+
class DistributedEventQueue(EventQueue):
    """EventQueue subclass that publishes events to SNS for multi-instance delivery.

    When ``enqueue_event`` is called by an agent handler, the event is:

    1. Enqueued locally (for the current instance's SSE stream), **and**
    2. Published asynchronously to SNS (for fan-out to all other instances).

    When the SQS poller on a remote instance receives the SNS notification, it
    calls ``enqueue_local`` directly — bypassing SNS re-publication — to avoid
    infinite broadcast loops.

    Args:
        publish_fn: Async callable ``(message: str) -> None`` that publishes
            the serialised wire message to SNS. Provided by
            :class:`SnsQueueManager` and injected at construction time.
        task_id: The task ID this queue serves.
        instance_id: The unique ID of the local instance (used for dedup).
        max_queue_size: Maximum number of events to buffer locally.
            Defaults to ``DEFAULT_MAX_QUEUE_SIZE``.
    """

    def __init__(
        self,
        publish_fn: Callable[[str], Awaitable[None]],
        task_id: str,
        instance_id: str,
        *,
        max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
    ) -> None:
        """Initialises the DistributedEventQueue."""
        super().__init__(max_queue_size=max_queue_size)
        self._publish_fn = publish_fn
        self._task_id = task_id
        self._instance_id = instance_id
        # Strong references to in-flight publish tasks. The event loop keeps
        # only weak references to tasks, so a bare fire-and-forget
        # ``asyncio.create_task`` result can be garbage-collected before it
        # runs, silently dropping the SNS publish. Retaining the reference
        # here (and discarding it on completion) prevents that.
        self._pending_publishes: set[asyncio.Task] = set()
        logger.debug(
            'DistributedEventQueue initialised (task_id=%s, instance=%s).',
            task_id,
            instance_id,
        )

    def _spawn_publish(self, coro: Awaitable[None]) -> None:
        """Schedules a publish coroutine, holding a reference until done.

        Args:
            coro: The publish coroutine to run in the background.
        """
        task = asyncio.create_task(coro)
        self._pending_publishes.add(task)
        task.add_done_callback(self._pending_publishes.discard)

    async def enqueue_event(self, event: Event) -> None:
        """Enqueues the event locally and publishes it to SNS.

        The SNS publish runs as a background task so that local delivery is
        never delayed by network I/O.

        Args:
            event: The event to enqueue and broadcast.
        """
        await super().enqueue_event(event)
        self._spawn_publish(self._publish_event(event))

    async def enqueue_local(self, event: Event) -> None:
        """Enqueues an event locally without re-publishing to SNS.

        Called by the SQS poller when delivering a remote event to this
        instance. Using this method prevents the event from being
        re-broadcast back to SNS, which would create an infinite loop.

        Args:
            event: The event received from the SQS queue.
        """
        await super().enqueue_event(event)

    async def close(self, immediate: bool = False) -> None:
        """Closes the queue locally and publishes a close signal to SNS.

        The close signal allows other instances to also close their local
        queues for the same task, ensuring clean shutdown across the cluster.

        Args:
            immediate: If ``True``, discard buffered events immediately
                rather than waiting for them to drain.
        """
        if not self.is_closed():
            self._spawn_publish(self._publish_close())
        await super().close(immediate=immediate)

    async def _publish_event(self, event: Event) -> None:
        """Background coroutine: serialises and publishes one event.

        Publish failures are logged rather than raised so a transient SNS
        outage never breaks local event delivery.

        Args:
            event: The event to publish.
        """
        try:
            message = _serialise_event(event, self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Event published to SNS (task_id=%s, kind=%s).',
                self._task_id,
                event.kind,
            )
        except Exception:
            logger.exception(
                'Failed to publish event to SNS (task_id=%s).', self._task_id
            )

    async def _publish_close(self) -> None:
        """Background coroutine: publishes the close signal to SNS."""
        try:
            message = _serialise_close(self._task_id, self._instance_id)
            await self._publish_fn(message)
            logger.debug(
                'Close signal published to SNS (task_id=%s).', self._task_id
            )
        except Exception:
            logger.exception(
                'Failed to publish close signal to SNS (task_id=%s).',
                self._task_id,
            )
)

0 commit comments

Comments
 (0)