forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathredis_event_queue.py
More file actions
221 lines (187 loc) · 8.16 KB
/
redis_event_queue.py
File metadata and controls
221 lines (187 loc) · 8.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""Redis-backed EventQueue implementation using Redis Streams."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
try:
import redis.asyncio as aioredis # type: ignore
from redis.exceptions import RedisError # type: ignore
except ImportError: # pragma: no cover - optional dependency
aioredis = None # type: ignore
RedisError = Exception # type: ignore
from a2a.server.events.event_queue import EventQueue
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from a2a.server.events.event_queue import Event
from pydantic import ValidationError
from a2a.types import (
Message,
Task,
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
)
from a2a.utils.telemetry import SpanKind, trace_class
logger = logging.getLogger(__name__)
class RedisNotAvailableError(RuntimeError):
    """Signal that the optional ``redis.asyncio`` dependency is missing."""
# Map the event-type name stored in the stream entry back to the pydantic
# model used to rehydrate it on dequeue.
_TYPE_MAP = {
    cls.__name__: cls
    for cls in (Message, Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent)
}
@trace_class(kind=SpanKind.SERVER)
class RedisEventQueue(EventQueue):
    """Redis-native EventQueue backed by a Redis Stream.

    This implementation does not rely on in-memory queue structures. Each
    instance manages its own read cursor (``_last_id``). :meth:`tap` returns
    a new ``RedisEventQueue`` pointing at the same stream but positioned at
    the current tail, so it receives only future events.
    """

    def __init__(
        self,
        task_id: str,
        redis_client: Any,
        stream_prefix: str = 'a2a:task',
        maxlen: int | None = None,
        read_block_ms: int = 500,
    ) -> None:
        """Bind a queue to the stream ``{stream_prefix}:{task_id}``.

        Args:
            task_id: Task identifier used to derive the stream key.
            redis_client: An async redis client (or a fake in tests).
            stream_prefix: Prefix used to build the stream key.
            maxlen: Optional cap forwarded to XADD trimming.
            read_block_ms: Milliseconds a blocking read waits for data.

        Raises:
            RedisNotAvailableError: If ``redis.asyncio`` is not installed
                and no client instance was supplied.
        """
        # Allow passing a custom redis client (e.g. a fake in tests).
        if aioredis is None and redis_client is None:
            raise RedisNotAvailableError('redis.asyncio is not available')
        self._task_id = task_id
        self._redis = redis_client
        self._stream_key = f'{stream_prefix}:{task_id}'
        self._maxlen = maxlen
        self._read_block_ms = read_block_ms
        # By default a normal queue should start at the beginning so it can
        # consume existing entries. Taps will explicitly start at the tail.
        self._last_id = '0-0'
        self._is_closed = False
        # Intentionally no super().__init__(): this class is Redis-native
        # and must not allocate the base class's in-memory queue structures.

    async def enqueue_event(self, event: Event) -> None:
        """Serialize and append an event to the Redis stream.

        Failures are logged and swallowed (best-effort delivery); enqueues
        on a closed queue are ignored with a warning.
        """
        if self._is_closed:
            logger.warning('Attempt to enqueue to closed RedisEventQueue')
            return
        # Store payload as a JSON string to avoid client-specific mapping
        # behaviour when reading back from the stream.
        payload = {
            'type': type(event).__name__,
            'payload': event.json(),
        }
        kwargs: dict[str, Any] = {}
        if self._maxlen:
            kwargs['maxlen'] = self._maxlen
        try:
            await self._redis.xadd(self._stream_key, payload, **kwargs)
        except RedisError:
            # Deliberate best-effort: log and keep going rather than crash
            # the producer on a transient redis failure.
            logger.exception('Failed to XADD event to redis stream')

    @staticmethod
    def _normalize_fields(fields: Any) -> dict[str, object]:
        """Decode a stream entry's field mapping into str keys/values.

        Redis clients may return bytes for both keys and values; values that
        are not valid UTF-8 are kept as raw bytes.
        """
        norm: dict[str, object] = {}
        for k, v in fields.items():
            key = k.decode('utf-8') if isinstance(k, (bytes, bytearray)) else k
            if isinstance(v, (bytes, bytearray)):
                try:
                    val: object = v.decode('utf-8')
                except UnicodeDecodeError:
                    val = v
            else:
                val = v
            norm[str(key)] = val
        return norm

    async def dequeue_event(self, no_wait: bool = False) -> Event | Any:
        """Read one event from the Redis stream respecting no_wait semantics.

        Args:
            no_wait: When True, do not block; raise ``asyncio.QueueEmpty``
                immediately if nothing is pending.

        Returns:
            A parsed pydantic model for known event types, or the raw
            payload for unknown types.

        Raises:
            asyncio.QueueEmpty: If the queue is closed, a CLOSE tombstone is
                read, or no entry is available within the wait window.
            ValueError: If a known event type fails model validation.
        """
        if self._is_closed:
            raise asyncio.QueueEmpty('Queue is closed')
        # redis-py XREAD semantics: block=None issues a non-blocking read,
        # while block=0 blocks FOREVER. no_wait therefore must map to None,
        # not 0, or the "don't wait" path would hang indefinitely.
        block = None if no_wait else self._read_block_ms
        # Keep reading until we find a parseable payload or a CLOSE tombstone.
        while True:
            try:
                result = await self._redis.xread(
                    {self._stream_key: self._last_id}, block=block, count=1
                )
            except RedisError:
                logger.exception('Failed to XREAD from redis stream')
                raise
            if not result:
                raise asyncio.QueueEmpty
            _, entries = result[0]
            entry_id, fields = entries[0]
            # Advance the cursor first so a malformed entry is never re-read.
            self._last_id = entry_id
            try:
                norm = self._normalize_fields(fields)
            except Exception:
                # Defensive: if normalization fails, skip this entry and continue
                logger.debug('RedisEventQueue.dequeue_event: failed to normalize entry fields, skipping %s', entry_id)
                continue
            evt_type = norm.get('type')
            # Handle tombstone/close message published by close().
            if evt_type == 'CLOSE':
                self._is_closed = True
                raise asyncio.QueueEmpty('Queue closed')
            raw_payload = norm.get('payload')
            if raw_payload is None:
                # Missing payload — likely due to key mismatch or malformed
                # entry. Skip and read the next entry instead of returning
                # None to callers.
                logger.debug('RedisEventQueue.dequeue_event: skipping entry %s with missing payload', entry_id)
                continue
            # If payload is a JSON string, parse it; otherwise, use as-is.
            if isinstance(raw_payload, str):
                try:
                    data = json.loads(raw_payload)
                except json.JSONDecodeError:
                    data = raw_payload
            else:
                data = raw_payload
            model = _TYPE_MAP.get(evt_type)
            if model:
                try:
                    return model.parse_obj(data)
                except ValidationError as exc:
                    logger.exception('Failed to parse event payload into model')
                    raise ValueError(f'Failed to parse event of type {evt_type}') from exc
            # Unknown type — return raw data for flexibility.
            logger.debug('Unknown event type: %s, returning raw payload', evt_type)
            return data

    def task_done(self) -> None:  # streams do not require task_done semantics
        """No-op for Redis streams (kept for API compatibility)."""

    def tap(self) -> EventQueue:
        """Return a new RedisEventQueue that starts at the stream tail."""
        q = RedisEventQueue(
            task_id=self._task_id,
            redis_client=self._redis,
            stream_prefix=self._stream_key.rsplit(':', 1)[0],
            maxlen=self._maxlen,
            read_block_ms=self._read_block_ms,
        )
        # Set the tap's cursor to the current last entry id so it receives
        # only events appended after this point. The `streams` attribute is
        # an internal of the fake client used in tests — TODO confirm; real
        # clients fall through to '$' (Redis' "only new entries" cursor).
        try:
            entries = getattr(self._redis, 'streams', {}).get(self._stream_key, [])
            q._last_id = entries[-1][0] if entries else '0-0'
        except (AttributeError, KeyError, IndexError, TypeError):
            q._last_id = '$'
        return q

    async def close(self, immediate: bool = False) -> None:
        """Mark the queue closed and publish a tombstone entry for readers.

        Args:
            immediate: Accepted for API compatibility; not used here.
        """
        # Mark this instance closed up front so is_closed() reflects the
        # caller's intent and enqueue_event() stops accepting events even if
        # the redis writes below fail. Tapped readers still drain remaining
        # entries until they hit the CLOSE tombstone.
        self._is_closed = True
        try:
            await self._redis.set(f'{self._stream_key}:closed', '1')
            await self._redis.xadd(self._stream_key, {'type': 'CLOSE'})
        except RedisError:
            logger.exception('Failed to write close marker to redis')

    def is_closed(self) -> bool:
        """Return True if this queue has been closed (close() called)."""
        return self._is_closed

    async def clear_events(self, clear_child_queues: bool = True) -> None:
        """Attempt to remove the underlying redis stream (best-effort)."""
        try:
            await self._redis.delete(self._stream_key)
        except RedisError:
            logger.exception('Failed to delete redis stream during clear_events')