Skip to content

Commit 3627c06

Browse files
author
Rajesh Ramamoorthy
committed
feat: Addressed PR review comments and fixed spelling and linting issues
1 parent 13fd561 commit 3627c06

6 files changed

Lines changed: 105 additions & 72 deletions

File tree

src/a2a/server/events/distributed_event_queue.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,17 @@
3535
}
3636

3737

38-
def _serialise_event(
38+
def _serialize_event(
3939
event: Event,
4040
task_id: str,
4141
instance_id: str,
4242
) -> str:
43-
"""Serialises an event into the SNS wire-format JSON string.
43+
"""Serializes an event into the SNS wire-format JSON string.
4444
4545
Args:
46-
event: The event to serialise.
46+
event: The event to serialize.
4747
task_id: The task ID this event belongs to.
48-
instance_id: The originating instance ID (for dedup).
48+
instance_id: The originating instance ID (for deduplication).
4949
5050
Returns:
5151
A JSON string suitable for use as an SNS ``Message`` payload.
@@ -55,13 +55,13 @@ def _serialise_event(
5555
'task_id': task_id,
5656
'type': _EVENT_TYPE,
5757
'event_kind': event.kind,
58-
'event_data': json.loads(event.model_dump_json()),
58+
'event_data': event.model_dump(mode='json'),
5959
}
6060
return json.dumps(payload)
6161

6262

63-
def _serialise_close(task_id: str, instance_id: str) -> str:
64-
"""Serialises a close signal into the SNS wire-format JSON string.
63+
def _serialize_close(task_id: str, instance_id: str) -> str:
64+
"""Serializes a close signal into the SNS wire-format JSON string.
6565
6666
Args:
6767
task_id: The task ID whose queue is being closed.
@@ -78,7 +78,7 @@ def _serialise_close(task_id: str, instance_id: str) -> str:
7878
return json.dumps(payload)
7979

8080

81-
def deserialise_wire_message(
81+
def deserialize_wire_message(
8282
raw: str,
8383
) -> dict[str, Any]:
8484
"""Parses a raw SNS/SQS wire-format JSON string.
@@ -110,7 +110,7 @@ def decode_event(msg: dict[str, Any]) -> Event | None:
110110
``event_data`` fields.
111111
112112
Returns:
113-
The decoded Event, or ``None`` if the ``kind`` is unrecognised.
113+
The decoded Event, or ``None`` if the ``kind`` is unrecognized.
114114
"""
115115
kind = msg.get('event_kind')
116116
event_data = msg.get('event_data')
@@ -138,10 +138,11 @@ class DistributedEventQueue(EventQueue):
138138
139139
Args:
140140
publish_fn: Async callable ``(message: str) -> None`` that publishes
141-
the serialised wire message to SNS. Provided by
141+
the serialized wire message to SNS. Provided by
142142
:class:`SnsQueueManager` and injected at construction time.
143143
task_id: The task ID this queue serves.
144-
instance_id: The unique ID of the local instance (used for dedup).
144+
instance_id: The unique ID of the local instance (used for
145+
deduplication of self-published messages).
145146
max_queue_size: Maximum number of events to buffer locally.
146147
Defaults to ``DEFAULT_MAX_QUEUE_SIZE``.
147148
"""
@@ -154,13 +155,13 @@ def __init__(
154155
*,
155156
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
156157
) -> None:
157-
"""Initialises the DistributedEventQueue."""
158+
"""Initializes the DistributedEventQueue."""
158159
super().__init__(max_queue_size=max_queue_size)
159160
self._publish_fn = publish_fn
160161
self._task_id = task_id
161162
self._instance_id = instance_id
162163
logger.debug(
163-
'DistributedEventQueue initialised (task_id=%s, instance=%s).',
164+
'DistributedEventQueue initialized (task_id=%s, instance=%s).',
164165
task_id,
165166
instance_id,
166167
)
@@ -204,13 +205,13 @@ async def close(self, immediate: bool = False) -> None:
204205
await super().close(immediate=immediate)
205206

206207
async def _publish_event(self, event: Event) -> None:
207-
"""Fire-and-forget coroutine: serialises and publishes one event.
208+
"""Fire-and-forget coroutine: serializes and publishes one event.
208209
209210
Args:
210211
event: The event to publish.
211212
"""
212213
try:
213-
message = _serialise_event(event, self._task_id, self._instance_id)
214+
message = _serialize_event(event, self._task_id, self._instance_id)
214215
await self._publish_fn(message)
215216
logger.debug(
216217
'Event published to SNS (task_id=%s, kind=%s).',
@@ -225,7 +226,7 @@ async def _publish_event(self, event: Event) -> None:
225226
async def _publish_close(self) -> None:
226227
"""Fire-and-forget coroutine: publishes the close signal to SNS."""
227228
try:
228-
message = _serialise_close(self._task_id, self._instance_id)
229+
message = _serialize_close(self._task_id, self._instance_id)
229230
await self._publish_fn(message)
230231
logger.debug(
231232
'Close signal published to SNS (task_id=%s).', self._task_id

src/a2a/server/events/queue_lifecycle_manager.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,20 @@
11
"""SQS queue lifecycle manager for per-instance ECS auto-scaling support."""
22

3+
from __future__ import annotations
4+
35
import json
46
import logging
57
import uuid
68

79
from dataclasses import dataclass, field
10+
from typing import TYPE_CHECKING
811

912

10-
logger = logging.getLogger(__name__)
13+
if TYPE_CHECKING:
14+
import aioboto3
1115

12-
# SQS policy template granting SNS permission to send messages to the queue.
13-
_SNS_POLICY_TEMPLATE = json.dumps(
14-
{
15-
'Version': '2012-10-17',
16-
'Statement': [
17-
{
18-
'Effect': 'Allow',
19-
'Principal': {'Service': 'sns.amazonaws.com'},
20-
'Action': 'SQS:SendMessage',
21-
'Resource': '{queue_arn}',
22-
'Condition': {
23-
'ArnEquals': {'aws:SourceArn': '{topic_arn}'},
24-
},
25-
}
26-
],
27-
}
28-
)
16+
17+
logger = logging.getLogger(__name__)
2918

3019

3120
@dataclass
@@ -77,7 +66,7 @@ class QueueLifecycleManager:
7766
queue_name_prefix: str = 'a2a-instance'
7867
instance_id: str = field(default_factory=lambda: str(uuid.uuid4()))
7968
region_name: str = 'us-east-1'
80-
session: object | None = field(default=None, repr=False)
69+
session: aioboto3.Session | None = field(default=None, repr=False)
8170

8271
# Set after provision(); read via properties.
8372
_provision_result: QueueProvisionResult | None = field(
@@ -163,7 +152,7 @@ async def provision(self) -> QueueProvisionResult:
163152
self.region_name,
164153
)
165154

166-
async with self.session.client( # type: ignore[union-attr]
155+
async with self.session.client(
167156
'sqs', region_name=self.region_name
168157
) as sqs:
169158
# Step 1: Create the SQS queue.
@@ -179,10 +168,27 @@ async def provision(self) -> QueueProvisionResult:
179168
queue_arn: str = attr_resp['Attributes']['QueueArn']
180169
logger.debug('SQS queue ARN: %s.', queue_arn)
181170

182-
# Step 3: Attach SNS → SQS access policy.
183-
policy = _SNS_POLICY_TEMPLATE.replace(
184-
'{queue_arn}', queue_arn
185-
).replace('{topic_arn}', self.topic_arn)
171+
# Step 3: Attach SNS → SQS access policy. ARN values are
172+
# set as dict entries so json.dumps() escapes them safely,
173+
# preventing any injection via crafted ARN strings.
174+
policy = json.dumps(
175+
{
176+
'Version': '2012-10-17',
177+
'Statement': [
178+
{
179+
'Effect': 'Allow',
180+
'Principal': {'Service': 'sns.amazonaws.com'},
181+
'Action': 'SQS:SendMessage',
182+
'Resource': queue_arn,
183+
'Condition': {
184+
'ArnEquals': {
185+
'aws:SourceArn': self.topic_arn,
186+
},
187+
},
188+
}
189+
],
190+
}
191+
)
186192
await sqs.set_queue_attributes(
187193
QueueUrl=queue_url,
188194
Attributes={'Policy': policy},
@@ -192,7 +198,7 @@ async def provision(self) -> QueueProvisionResult:
192198
# Step 4: Subscribe the SQS queue to the SNS topic.
193199
subscription_arn = ''
194200
try:
195-
async with self.session.client( # type: ignore[union-attr]
201+
async with self.session.client(
196202
'sns', region_name=self.region_name
197203
) as sns:
198204
sub_resp = await sns.subscribe(
@@ -210,7 +216,7 @@ async def provision(self) -> QueueProvisionResult:
210216
queue_url,
211217
)
212218
try:
213-
async with self.session.client( # type: ignore[union-attr]
219+
async with self.session.client(
214220
'sqs', region_name=self.region_name
215221
) as sqs:
216222
await sqs.delete_queue(QueueUrl=queue_url)
@@ -251,7 +257,7 @@ async def teardown(self) -> None:
251257

252258
# Step 1: Unsubscribe from SNS (best-effort).
253259
try:
254-
async with self.session.client( # type: ignore[union-attr]
260+
async with self.session.client(
255261
'sns', region_name=self.region_name
256262
) as sns:
257263
await sns.unsubscribe(SubscriptionArn=result.subscription_arn)
@@ -266,7 +272,7 @@ async def teardown(self) -> None:
266272

267273
# Step 2: Delete the SQS queue.
268274
try:
269-
async with self.session.client( # type: ignore[union-attr]
275+
async with self.session.client(
270276
'sqs', region_name=self.region_name
271277
) as sqs:
272278
await sqs.delete_queue(QueueUrl=result.queue_url)
@@ -278,7 +284,7 @@ async def teardown(self) -> None:
278284
# Async context manager
279285
# ------------------------------------------------------------------
280286

281-
async def __aenter__(self) -> 'QueueLifecycleManager': # noqa: PYI034
287+
async def __aenter__(self) -> QueueLifecycleManager: # noqa: PYI034
282288
"""Provisions resources and returns ``self``."""
283289
await self.provision()
284290
return self

src/a2a/server/events/sns_queue_manager.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
"""SnsQueueManager — distributed QueueManager using SNS/SQS fan-out."""
22

3+
from __future__ import annotations
4+
35
import asyncio
46
import json
57
import logging
68
import uuid
79

8-
from typing import Any
10+
from typing import TYPE_CHECKING, Any
11+
12+
13+
if TYPE_CHECKING:
14+
import aioboto3
15+
16+
from a2a.server.events.event_queue import EventQueue
917

1018
from a2a.server.events.distributed_event_queue import (
1119
DistributedEventQueue,
1220
decode_event,
13-
deserialise_wire_message,
21+
deserialize_wire_message,
1422
)
15-
from a2a.server.events.event_queue import EventQueue
1623
from a2a.server.events.queue_manager import (
1724
NoTaskQueue,
1825
QueueManager,
@@ -78,12 +85,12 @@ def __init__( # noqa: PLR0913
7885
*,
7986
instance_id: str | None = None,
8087
region_name: str = 'us-east-1',
81-
session: object | None = None,
88+
session: aioboto3.Session | None = None,
8289
poll_interval_seconds: float = 1.0,
8390
max_messages: int = 10,
8491
visibility_timeout_seconds: int = 30,
8592
) -> None:
86-
"""Initialises the SnsQueueManager."""
93+
"""Initializes the SnsQueueManager."""
8794
try:
8895
import aioboto3 # noqa: PLC0415
8996
except ImportError as exc:
@@ -157,7 +164,7 @@ async def stop(self) -> None:
157164
# ------------------------------------------------------------------
158165

159166
async def add(self, task_id: str, queue: EventQueue) -> None:
160-
"""Adds a pre-existing EventQueue for *task_id*.
167+
"""Adds an already-created EventQueue for *task_id*.
161168
162169
Raises:
163170
TaskQueueExists: If a queue for *task_id* already exists.
@@ -233,12 +240,12 @@ async def create_or_tap(self, task_id: str) -> EventQueue:
233240
# ------------------------------------------------------------------
234241

235242
async def _sns_publish(self, message: str) -> None:
236-
"""Publishes a serialised wire message to the SNS topic.
243+
"""Publishes a serialized wire message to the SNS topic.
237244
238245
Args:
239246
message: JSON string in the distributed wire format.
240247
"""
241-
async with self._session.client( # type: ignore[union-attr]
248+
async with self._session.client(
242249
'sns', region_name=self._region_name
243250
) as sns:
244251
await sns.publish(TopicArn=self._topic_arn, Message=message)
@@ -261,7 +268,7 @@ async def _poll_loop(self) -> None:
261268
self._instance_id,
262269
self._sqs_queue_url,
263270
)
264-
async with self._session.client( # type: ignore[union-attr]
271+
async with self._session.client(
265272
'sqs', region_name=self._region_name
266273
) as sqs:
267274
while not self._stop_event.is_set():
@@ -332,27 +339,37 @@ async def _handle_sqs_message(self, sqs_msg: dict[str, Any]) -> None: # noqa: P
332339
Args:
333340
sqs_msg: A single SQS message dictionary from ReceiveMessage.
334341
"""
342+
# Use message ID in warnings to avoid logging potentially sensitive
343+
# message body content (PII in Task/Message payloads).
344+
msg_id = sqs_msg.get('MessageId', '<no-id>')
335345
body_str = sqs_msg.get('Body', '{}')
346+
336347
try:
337348
body: dict[str, Any] = json.loads(body_str)
338349
except json.JSONDecodeError:
339-
logger.warning('SQS message body is not valid JSON: %s', body_str)
350+
logger.warning(
351+
'SQS message body is not valid JSON (msg_id=%s).', msg_id
352+
)
340353
return
341354

342355
# Unwrap SNS notification envelope if present.
343356
if body.get('Type') == 'Notification':
344357
inner_str = body.get('Message', '{}')
345358
try:
346-
wire_msg = deserialise_wire_message(inner_str)
359+
wire_msg = deserialize_wire_message(inner_str)
347360
except ValueError:
348-
logger.warning('Malformed inner SNS message: %s', inner_str)
361+
logger.warning(
362+
'Malformed inner SNS message (msg_id=%s).', msg_id
363+
)
349364
return
350365
else:
351366
# Raw delivery — body itself is the wire message.
352367
try:
353-
wire_msg = deserialise_wire_message(body_str)
368+
wire_msg = deserialize_wire_message(body_str)
354369
except ValueError:
355-
logger.warning('Malformed raw SQS message body: %s', body_str)
370+
logger.warning(
371+
'Malformed raw SQS message body (msg_id=%s).', msg_id
372+
)
356373
return
357374

358375
# Deduplicate: ignore messages we published ourselves.

0 commit comments

Comments
 (0)