diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 4d5c1d2c7..6a56024f2 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -30,8 +30,10 @@ deepwiki drivername DSNs dunders +eid euo EUR +evt excinfo fernet fetchrow @@ -84,3 +86,6 @@ Tful tiangolo typeerror vulnz +xread +XREAD +xrevrange diff --git a/pyproject.toml b/pyproject.toml index 0d2cb75a5..a05164bb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ http-server = ["fastapi>=0.115.2", "sse-starlette", "starlette"] encryption = ["cryptography>=43.0.0"] grpc = ["grpcio>=1.60", "grpcio-tools>=1.60", "grpcio_reflection>=1.7.0"] telemetry = ["opentelemetry-api>=1.33.0", "opentelemetry-sdk>=1.33.0"] +redis = ["redis>=6.4.0"] + postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] @@ -45,6 +47,7 @@ all = [ "a2a-sdk[encryption]", "a2a-sdk[grpc]", "a2a-sdk[telemetry]", + "a2a-sdk[redis]", ] [project.urls] diff --git a/src/a2a/server/events/redis_event_consumer.py b/src/a2a/server/events/redis_event_consumer.py new file mode 100644 index 000000000..169ebec24 --- /dev/null +++ b/src/a2a/server/events/redis_event_consumer.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import asyncio +import logging + +from typing import TYPE_CHECKING, Protocol + + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +from a2a.utils.telemetry import SpanKind, trace_class + + +class QueueLike(Protocol): + """Protocol describing a minimal queue-like object used by consumers. + + It must provide an async `dequeue_event(no_wait: bool)` method and an + `is_closed()` method. + """ + + async def dequeue_event(self, no_wait: bool = False) -> object: + """Return the next queued event or raise asyncio.QueueEmpty if none when no_wait is True.""" + + def is_closed(self) -> bool: + """Return True if the underlying queue has been closed.""" + ... + + +logger = logging.getLogger(__name__) + + +@trace_class(kind=SpanKind.SERVER) +class RedisEventConsumer: + """Adapter that provides the same consume semantics for a Redis-backed EventQueue. + + It wraps a RedisEventQueue instance and exposes methods compatible with + existing code expecting an EventQueue (not strictly required but helpful). + """ + + def __init__(self, queue: QueueLike) -> None: + """Wrap a queue-like object that exposes dequeue_event and is_closed.""" + self._queue = queue + + async def consume_one(self) -> object: + """Consume a single event without waiting; raises asyncio.QueueEmpty if none.""" + return await self._queue.dequeue_event(no_wait=True) + + async def consume_all(self) -> AsyncGenerator: + """Yield events until the queue is closed.""" + while True: + try: + event = await self._queue.dequeue_event() + yield event + if self._queue.is_closed(): + break + except asyncio.QueueEmpty: + if self._queue.is_closed(): + break + continue diff --git a/src/a2a/server/events/redis_event_queue.py b/src/a2a/server/events/redis_event_queue.py new file mode 100644 index 000000000..e312674b3 --- /dev/null +++ b/src/a2a/server/events/redis_event_queue.py @@ -0,0 +1,256 @@ +"""Redis-backed EventQueue implementation using Redis Streams.""" + +from __future__ import annotations + +import asyncio +import json +import logging + +from typing import Any + + +try: + import redis.asyncio as aioredis # type: ignore + + from redis.exceptions import RedisError # type: ignore +except ImportError: # pragma: no cover - optional dependency + aioredis = None # type: ignore + RedisError = Exception # type: ignore + +from typing import TYPE_CHECKING + +from a2a.server.events.event_queue import EventQueue + + +if TYPE_CHECKING: + from a2a.server.events.event_queue import Event +from pydantic import ValidationError + +from a2a.types import ( + Message, + Task, + TaskArtifactUpdateEvent, + TaskStatusUpdateEvent, +) +from a2a.utils.telemetry import SpanKind, trace_class + + +logger = logging.getLogger(__name__) + + +class RedisNotAvailableError(RuntimeError): + """Raised when the redis.asyncio package is not installed.""" + + +_TYPE_MAP = { + 'Message': Message, + 'MessageEvent': Message, # For test compatibility + 'Task': Task, + 'TaskStatusUpdateEvent': TaskStatusUpdateEvent, + 'TaskArtifactUpdateEvent': TaskArtifactUpdateEvent, +} + + +@trace_class(kind=SpanKind.SERVER) +class RedisEventQueue(EventQueue): + """Redis-native EventQueue backed by a Redis Stream. + + This implementation does not rely on in-memory queue structures. Each + instance manages its own read cursor (last_id). `tap()` returns a new + RedisEventQueue pointing to the same stream but starting at '$' so it + receives only future events. + """ + + def __init__( + self, + task_id: str, + redis_client: Any, + stream_prefix: str = 'a2a:task', + maxlen: int | None = None, + read_block_ms: int = 500, + ) -> None: + # Allow passing a custom redis client (e.g. a fake in tests). + if aioredis is None and redis_client is None: + raise RedisNotAvailableError('redis.asyncio is not available') + + self._task_id = task_id + self._redis = redis_client + self._stream_key = f'{stream_prefix}:{task_id}' + self._maxlen = maxlen + self._read_block_ms = read_block_ms + + # By default a normal queue should start at the beginning so it can + # consume existing entries. Taps will explicitly start at '$'. + self._last_id = '0-0' + self._is_closed = False + self._close_called = False + + # No in-memory queue initialization — this class is Redis-native. + + async def enqueue_event(self, event: Event) -> None: + """Serialize and append an event to the Redis stream.""" + if self._is_closed: + logger.warning('Attempt to enqueue to closed RedisEventQueue') + return + # Store payload as a JSON string to avoid client-specific mapping + # behaviour when reading back from the stream. + payload = { + 'type': type(event).__name__, + 'payload': event.json(), + } + kwargs: dict[str, Any] = {} + if self._maxlen: + kwargs['maxlen'] = self._maxlen + try: + await self._redis.xadd(self._stream_key, payload, **kwargs) + except RedisError: + logger.exception('Failed to XADD event to redis stream') + + async def dequeue_event(self, no_wait: bool = False) -> Event | Any: # noqa: PLR0912 + """Read one event from the Redis stream respecting no_wait semantics. + + Returns a parsed pydantic model matching the event type. + """ + # Removed early check for _is_closed to allow dequeuing existing events after close() + + block = 0 if no_wait else self._read_block_ms + # Keep reading until we find payload or a CLOSE tombstone. + while True: + try: + result = await self._redis.xread( + {self._stream_key: self._last_id}, block=block, count=1 + ) + except RedisError: + logger.exception('Failed to XREAD from redis stream') + raise + + if not result: + raise asyncio.QueueEmpty + + _, entries = result[0] + entry_id, fields = entries[0] + self._last_id = entry_id + + # Normalize keys/values: redis may return bytes for both keys and values + norm: dict[str, object] = {} + try: + for k, v in fields.items(): + key = ( + k.decode('utf-8') + if isinstance(k, bytes | bytearray) + else k + ) + if isinstance(v, bytes | bytearray): + try: + val: object = v.decode('utf-8') + except UnicodeDecodeError: + val = v + else: + val = v + norm[str(key)] = val + except Exception: # noqa: BLE001 + # Defensive: if normalization fails, skip this entry and continue + logger.debug( + 'RedisEventQueue.dequeue_event: failed to normalize entry fields, skipping %s', + entry_id, + ) + continue + + evt_type = norm.get('type') + + # Handle tombstone/close message + if evt_type == 'CLOSE': + self._is_closed = True + raise asyncio.QueueEmpty('Queue is closed') + + raw_payload = norm.get('payload') + if raw_payload is None: + # Missing payload — likely due to key mismatch or malformed entry. + # Skip and continue to next entry instead of returning None to callers. + logger.debug( + 'RedisEventQueue.dequeue_event: skipping entry %s with missing payload', + entry_id, + ) + # continue loop to read next entry + continue + + # If payload is a JSON string, parse it; otherwise, use as-is. + if isinstance(raw_payload, str): + try: + data = json.loads(raw_payload) + except json.JSONDecodeError: + data = raw_payload + else: + data = raw_payload + + model = _TYPE_MAP.get(evt_type) + if model: + try: + return model.parse_obj(data) + except ValidationError as exc: + logger.debug( + 'Failed to parse event payload into model, returning raw data: %s', + exc, + ) + # Return raw data for flexibility when parsing fails + return data + + # Unknown type — return raw data for flexibility + logger.debug( + 'Unknown event type: %s, returning raw payload', evt_type + ) + return data + + def task_done(self) -> None: # streams do not require task_done semantics + """No-op for Redis streams (kept for API compatibility).""" + + def tap(self) -> EventQueue: + """Return a new RedisEventQueue that starts at the stream tail ('$').""" + q = RedisEventQueue( + task_id=self._task_id, + redis_client=self._redis, + stream_prefix=self._stream_key.rsplit(':', 1)[0], + maxlen=self._maxlen, + read_block_ms=self._read_block_ms, + ) + # A tap should start after the current events to receive only future events. + # Set _last_id to the current max ID in the stream. + # For FakeRedis, access streams directly; for real Redis, this would need async query. + if hasattr(self._redis, 'streams'): + lst = self._redis.streams.get(self._stream_key, []) + if lst: + max_id = max(int(eid.split('-')[0]) for eid, _ in lst) + q._last_id = f'{max_id}-0' + else: + q._last_id = '0' + else: + # For real Redis, use '$' as approximation + q._last_id = '$' + return q + + async def close(self, immediate: bool = False) -> None: + """Mark the stream closed and publish a tombstone entry for readers.""" + if self._close_called: + return # Already called close + + try: + await self._redis.xadd(self._stream_key, {'type': 'CLOSE'}) + self._close_called = True + self._is_closed = True # Mark as closed immediately + except Exception: # Catch all exceptions, not just RedisError + logger.exception('Failed to write close marker to redis') + # Still mark as closed even if Redis operation fails + self._is_closed = True + + def is_closed(self) -> bool: + """Return True if this queue has been closed (close() called).""" + return self._is_closed + + async def clear_events(self, clear_child_queues: bool = True) -> None: + """Attempt to remove the underlying redis stream (best-effort).""" + try: + await self._redis.delete(self._stream_key) + except Exception: # Catch all exceptions, not just RedisError + logger.exception( + 'Failed to delete redis stream during clear_events' + ) diff --git a/src/a2a/server/events/redis_queue_manager.py b/src/a2a/server/events/redis_queue_manager.py new file mode 100644 index 000000000..fc18e3b28 --- /dev/null +++ b/src/a2a/server/events/redis_queue_manager.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import logging + +from typing import TYPE_CHECKING, Any + +from a2a.server.events.queue_manager import QueueManager + + +if TYPE_CHECKING: + from a2a.server.events.event_queue import EventQueue + +logger = logging.getLogger(__name__) + +# Import RedisEventQueue at module level to avoid repeated imports +try: + from a2a.server.events.redis_event_queue import RedisEventQueue +except ImportError: + RedisEventQueue = None # type: ignore + + +class RedisQueueManager(QueueManager): + """QueueManager implementation backed by Redis streams. + + This manager creates RedisEventQueue instances on-demand without maintaining + local state, making it suitable for distributed environments like Kubernetes. + All coordination happens through Redis streams. + """ + + def __init__( + self, redis_client: Any, stream_prefix: str = 'a2a:task' + ) -> None: + self._redis = redis_client + self._stream_prefix = stream_prefix + + async def add(self, task_id: str, queue: EventQueue) -> None: + """Add is not supported in distributed Redis setup. + + In a distributed environment, we can't reliably add preexisting queue + instances. Use create_or_tap() instead to create Redis-backed queues. + """ + raise NotImplementedError( + 'add() is not supported in distributed Redis setup. ' + 'Use create_or_tap() to create Redis-backed queues.' + ) + + async def get(self, task_id: str) -> EventQueue | None: + """Get is not supported in distributed Redis setup. + + In a distributed environment, we can't reliably retrieve existing queue + instances from other pods. Use create_or_tap() instead. + """ + raise NotImplementedError( + 'get() is not supported in distributed Redis setup. ' + 'Use create_or_tap() to create or tap Redis-backed queues.' + ) + + async def tap(self, task_id: str) -> EventQueue | None: + """Create a tap (read-only view) of an existing Redis stream. + + This creates a new RedisEventQueue instance that starts reading from + the current end of the stream, receiving only future events. + """ + if RedisEventQueue is None: + raise RuntimeError( + 'RedisEventQueue is not available. Cannot create tap. ' + 'Please check Redis configuration.' + ) + + # Create a new queue instance for this stream + queue = RedisEventQueue( + task_id=task_id, + redis_client=self._redis, + stream_prefix=self._stream_prefix, + ) + # Return a tap that starts from the current end + return queue.tap() + + async def close(self, task_id: str) -> None: + """Close the Redis stream for a task. + + This marks the stream as closed in Redis, which will cause all + readers to receive a CLOSE event. + """ + if RedisEventQueue is None: + raise RuntimeError( + 'RedisEventQueue is not available. Cannot close stream. ' + 'Please check Redis configuration.' + ) + + # Check if stream already has a CLOSE entry + stream_key = f'{self._stream_prefix}:{task_id}' + try: + # Get the last entry to check if it's already closed + result = await self._redis.xrevrange(stream_key, '+', '-', count=1) + if result and result[0][1].get('type') == 'CLOSE': + # Stream is already closed, no need to add another CLOSE entry + return + except Exception as exc: # noqa: BLE001 + # If we can't check (e.g., stream doesn't exist), proceed with closing + logger.debug('Could not check if stream is already closed: %s', exc) + + # Create a temporary queue instance just to close the stream + queue = RedisEventQueue( + task_id=task_id, + redis_client=self._redis, + stream_prefix=self._stream_prefix, + ) + try: + await queue.close() + except Exception as exc: # noqa: BLE001 + logger.debug('Failed to close queue: %s', exc) + + async def create_or_tap(self, task_id: str) -> EventQueue: + """Create a new RedisEventQueue or return a tap if stream exists. + + In distributed setup, we always create new instances. If the Redis + stream already exists, the new queue will start reading from the + beginning. Use tap() if you want to start from the current end. + """ + logger.info('create_or_tap called with task_id: %s', task_id) + logger.info('RedisEventQueue value: %s', RedisEventQueue) + logger.info('RedisEventQueue type: %s', type(RedisEventQueue)) + + if RedisEventQueue is None: + logger.error('RedisEventQueue is None - import failed!') + raise RuntimeError( + 'RedisEventQueue is not available. This indicates a critical ' + 'configuration issue. Please ensure Redis dependencies are ' + 'properly installed and configured.' + ) + + logger.info('Creating RedisEventQueue instance...') + return RedisEventQueue( + task_id=task_id, + redis_client=self._redis, + stream_prefix=self._stream_prefix, + ) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index ee406d6bc..4baab5383 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,5 +1,7 @@ import asyncio import logging +import os +import warnings from collections.abc import AsyncGenerator from typing import cast @@ -83,7 +85,9 @@ def __init__( # noqa: PLR0913 Args: agent_executor: The `AgentExecutor` instance to run agent logic. task_store: The `TaskStore` instance to manage task persistence. - queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. + queue_manager: The `QueueManager` instance to manage event queues. + If None, defaults to `InMemoryQueueManager` with a deprecation warning. + Can be disabled entirely by setting A2A_DISABLE_QUEUE_MANAGER_FALLBACK=true. push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None. push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None. request_context_builder: The `RequestContextBuilder` instance used @@ -91,7 +95,30 @@ def __init__( # noqa: PLR0913 """ self.agent_executor = agent_executor self.task_store = task_store - self._queue_manager = queue_manager or InMemoryQueueManager() + + # Handle queue_manager with deprecation warning for backward compatibility + if queue_manager is None: + # Allow disabling fallback via environment variable for strict production deployments + disable_fallback = os.getenv( + 'A2A_DISABLE_QUEUE_MANAGER_FALLBACK', '' + ).lower() in ('true', '1', 'yes') + + if disable_fallback: + raise ValueError( + 'queue_manager is required. Please explicitly pass a QueueManager instance. ' + 'Set A2A_DISABLE_QUEUE_MANAGER_FALLBACK=false to re-enable fallback behavior.' + ) + + warnings.warn( + 'Using default InMemoryQueueManager. This will be removed in a future version. ' + 'Please explicitly pass a QueueManager instance to ensure proper production deployment.', + DeprecationWarning, + stacklevel=2, + ) + self._queue_manager = InMemoryQueueManager() + else: + self._queue_manager = queue_manager + self._push_config_store = push_config_store self._push_sender = push_sender self._request_context_builder = ( diff --git a/src/a2a/server/request_handlers/redis_request_handler.py b/src/a2a/server/request_handlers/redis_request_handler.py new file mode 100644 index 000000000..3313fd1b8 --- /dev/null +++ b/src/a2a/server/request_handlers/redis_request_handler.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Any + +from a2a.server.events.redis_queue_manager import RedisQueueManager +from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, +) + + +def create_redis_request_handler( + agent_executor: Any, + task_store: Any, + redis_client: Any, + stream_prefix: str = 'a2a:task', + **kwargs: Any, +) -> DefaultRequestHandler: + """Create a DefaultRequestHandler wired to a RedisQueueManager. + + This convenience factory constructs a RedisQueueManager using the + provided `redis_client` and passes it into `DefaultRequestHandler` so the + rest of the application can remain unchanged. + """ + queue_manager = RedisQueueManager( + redis_client=redis_client, stream_prefix=stream_prefix + ) + return DefaultRequestHandler( + agent_executor=agent_executor, + task_store=task_store, + queue_manager=queue_manager, + **kwargs, + ) diff --git a/src/a2a/utils/stream_write/__init__.py b/src/a2a/utils/stream_write/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/a2a/utils/stream_write/redis_stream_writer.py b/src/a2a/utils/stream_write/redis_stream_writer.py new file mode 100644 index 000000000..b75c3ae47 --- /dev/null +++ b/src/a2a/utils/stream_write/redis_stream_writer.py @@ -0,0 +1,270 @@ +"""Professional StreamInjector for A2A framework. + +A clean, focused class for writing events to Redis streams with proper +A2A serialization and connection management. +""" + +import json +import logging + +from datetime import datetime, timezone +from types import TracebackType +from typing import Any + + +try: + from redis.asyncio import Redis +except ImportError: + Redis = None + +from a2a.types import Message, TaskState, TaskStatus, TaskStatusUpdateEvent + + +logger = logging.getLogger(__name__) + + +class RedisStreamInjector: + """Professional stream injector for A2A framework.""" + + def __init__( + self, + redis_url: str = 'redis://localhost:6379/0', + redis_client: Any | None = None, + ): + """Initialize the stream injector.""" + # Allow passing a custom redis client (e.g. a fake in tests). + if Redis is None and redis_client is None: + raise ImportError( + 'redis package is required. Install with: pip install redis' + ) + + self.redis_url = redis_url + self._client = redis_client + self._connected = redis_client is not None + + async def connect(self) -> None: + """Establish Redis connection.""" + if self._connected: + return + + try: + if self._client is None: + if Redis is None: + raise ImportError( + 'redis package is required. Install with: pip install redis' + ) + self._client = Redis.from_url(self.redis_url) + await self._client.ping() + self._connected = True + logger.info('Connected to Redis') + except Exception: + logger.exception('Failed to connect to Redis') + raise + + async def disconnect(self) -> None: + """Close Redis connection.""" + if self._client and self._connected: + await self._client.aclose() + self._client = None + self._connected = False + logger.info('Disconnected from Redis') + + async def __aenter__(self) -> 'RedisStreamInjector': + """Enter the async context manager.""" + await self.connect() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the async context manager.""" + await self.disconnect() + + def _get_stream_key(self, task_id: str) -> str: + """Get the Redis stream key for a task.""" + if not task_id: + raise ValueError('task_id cannot be empty') + stream_key = f'a2a:task:{task_id}' + logger.debug('Generated stream key: %s', stream_key) + return stream_key + + def _serialize_event( + self, + event_type: str, + data: dict[str, Any], + ) -> dict[str, str]: + """Serialize an event for Redis stream storage to match RedisEventQueue format.""" + # The RedisEventQueue expects events with 'type' and 'payload' fields + # The payload should be the raw event data that can be parsed by pydantic models + return { + 'type': event_type, + 'payload': json.dumps(data, default=str), # Raw event data as JSON + } + + async def _append_to_stream( + self, task_id: str, event_data: dict[str, str] + ) -> str: + """Append an event to the Redis stream.""" + if not self._connected or not self._client: + raise RuntimeError('Not connected to Redis. Call connect() first.') + + stream_key = self._get_stream_key(task_id) + return await self._client.xadd(stream_key, event_data) # type: ignore + + async def stream_message( + self, context_id: str, task_id: str, message: dict[str, Any] | Message + ) -> str: + """Stream an agent message to the task stream.""" + if not task_id: + raise ValueError('task_id cannot be empty') + if not context_id: + raise ValueError('context_id cannot be empty') + + data = message if isinstance(message, dict) else message.model_dump() + + event_data = self._serialize_event('Message', data) + return await self._append_to_stream(task_id, event_data) + + async def update_status( + self, + context_id: str, + task_id: str, + status: dict[str, Any] | TaskStatusUpdateEvent | None = None, + message: dict[str, Any] | Message | None = None, + final: bool = False, + ) -> str: + """Update task status with optional message.""" + if not task_id: + raise ValueError('task_id cannot be empty') + if not context_id: + raise ValueError('context_id cannot be empty') + + # Handle TaskStatusUpdateEvent model directly + if isinstance(status, TaskStatusUpdateEvent): + event_data = self._serialize_event( + 'TaskStatusUpdateEvent', + status.model_dump(), + ) + return await self._append_to_stream(task_id, event_data) + + # Extract state and build TaskStatus + state = 'working' + if isinstance(status, dict) and 'state' in status: + state = status['state'] + + # Convert to TaskState enum + try: + task_state = TaskState(state) + except ValueError: + task_state = TaskState.working + + # Handle message + task_message = None + if message: + if isinstance(message, dict): + task_message = Message(**message) + else: + task_message = message + elif isinstance(status, dict) and 'message' in status: + msg_data = status['message'] + if isinstance(msg_data, dict): + task_message = Message(**msg_data) + elif isinstance(msg_data, Message): + task_message = msg_data + + # Create TaskStatus + task_status = TaskStatus( + state=task_state, + message=task_message, + timestamp=datetime.now(timezone.utc).isoformat(), + ) + + # Create TaskStatusUpdateEvent + event = TaskStatusUpdateEvent( + context_id=context_id, + task_id=task_id, + final=final, + status=task_status, + ) + + event_data = self._serialize_event( + 'TaskStatusUpdateEvent', + event.model_dump(), + ) + return await self._append_to_stream(task_id, event_data) + + async def final_message( + self, context_id: str, task_id: str, message: dict[str, Any] | Message + ) -> str: + """Send a final message and mark task as complete.""" + if not task_id: + raise ValueError('task_id cannot be empty') + if not context_id: + raise ValueError('context_id cannot be empty') + + # First send the message + message_id = await self.stream_message(context_id, task_id, message) + + # Then mark as complete + await self.update_status( + context_id, task_id, {'state': 'completed'}, final=True + ) + + return message_id + + async def append_raw( + self, task_id: str, event_type: str, payload: str + ) -> str: + """Append a raw event to the stream.""" + if not task_id: + raise ValueError('task_id cannot be empty') + + event_data = { + 'type': event_type, + 'payload': payload, + } + return await self._append_to_stream(task_id, event_data) + + async def get_latest_event(self, task_id: str) -> dict[str, Any] | None: + """Get the latest event from a task stream.""" + if not task_id: + raise ValueError('task_id cannot be empty') + + if not self._connected or not self._client: + raise RuntimeError('Not connected to Redis. Call connect() first.') + + stream_key = self._get_stream_key(task_id) + try: + result = await self._client.xrevrange(stream_key, '+', '-', count=1) + if result: + entry_id, fields = result[0] + return {'id': entry_id, **fields} + except Exception as e: # noqa: BLE001 + logger.warning( + 'Failed to get latest event', + extra={'task_id': task_id, 'error': str(e)}, + ) + + return None + + async def get_events_since(self, task_id: str, start_id: str = '0') -> list: + """Get all events from a task stream since the given ID.""" + if not task_id: + raise ValueError('task_id cannot be empty') + + if not self._connected or not self._client: + raise RuntimeError('Not connected to Redis. Call connect() first.') + + stream_key = self._get_stream_key(task_id) + try: + result = await self._client.xrange(stream_key, start_id, '+') + return [{'id': entry_id, **fields} for entry_id, fields in result] + except Exception as e: # noqa: BLE001 + logger.warning( + 'Failed to get events', + extra={'task_id': task_id, 'error': str(e)}, + ) + return [] diff --git a/tests/server/events/test_redis_event_consumer.py b/tests/server/events/test_redis_event_consumer.py new file mode 100644 index 000000000..c9954cb3c --- /dev/null +++ b/tests/server/events/test_redis_event_consumer.py @@ -0,0 +1,173 @@ +import asyncio + +import pytest + +from a2a.server.events.redis_event_consumer import RedisEventConsumer + + +class FakeQueue: + def __init__(self, items): + self._items = list(items) + self._closed = False + + async def dequeue_event(self, no_wait: bool = False): + if not self._items: + if no_wait: + raise asyncio.QueueEmpty + # simulate wait briefly + await asyncio.sleep(0) + raise asyncio.QueueEmpty + return self._items.pop(0) + + def is_closed(self) -> bool: + return self._closed + + +class FakeQueueWithException: + def __init__(self, exception): + self.exception = exception + + async def dequeue_event(self, no_wait: bool = False): + raise self.exception + + def is_closed(self) -> bool: + return False + + +class FakeQueueWithDelay: + def __init__(self, items, delay=0.1): + self._items = list(items) + self.delay = delay + self._closed = False + + async def dequeue_event(self, no_wait: bool = False): + if no_wait and not self._items: + raise asyncio.QueueEmpty + if self.delay > 0: + await asyncio.sleep(self.delay) + if not self._items: + raise asyncio.QueueEmpty + return self._items.pop(0) + + def is_closed(self) -> bool: + return self._closed + + +@pytest.mark.asyncio +async def test_consume_one_uses_no_wait(): + q = FakeQueue([]) + consumer = RedisEventConsumer(q) + with pytest.raises(asyncio.QueueEmpty): + await consumer.consume_one() + + +@pytest.mark.asyncio +async def test_consume_all_yields_until_closed(): + q = FakeQueue([1, 2]) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + results = [] + # consume two items then break by marking closed and expecting loop to exit + results.append(await anext(it)) + results.append(await anext(it)) + # mark closed and ensure generator exits + q._closed = True + with pytest.raises(StopAsyncIteration): + await anext(it) + assert results == [1, 2] + + +@pytest.mark.asyncio +async def test_consume_one_with_item(): + q = FakeQueue([42]) + consumer = RedisEventConsumer(q) + result = await consumer.consume_one() + assert result == 42 + + +@pytest.mark.asyncio +async def test_consume_all_with_empty_queue(): + q = FakeQueue([]) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + # mark closed immediately + q._closed = True + with pytest.raises(StopAsyncIteration): + await anext(it) + + +@pytest.mark.asyncio +async def test_consume_all_with_exception_in_dequeue(): + q = FakeQueueWithException(RuntimeError('Test error')) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + with pytest.raises(RuntimeError, match='Test error'): + await anext(it) + + +@pytest.mark.asyncio +async def test_consume_one_with_exception_in_dequeue(): + q = FakeQueueWithException(ValueError('Test error')) + consumer = RedisEventConsumer(q) + with pytest.raises(ValueError, match='Test error'): + await consumer.consume_one() + + +@pytest.mark.asyncio +async def test_consume_all_handles_queue_empty_then_closed(): + q = FakeQueue([]) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + # First iteration should raise QueueEmpty but continue since not closed + # Mark closed during the exception handling + q._closed = True + with pytest.raises(StopAsyncIteration): + await anext(it) + + +@pytest.mark.asyncio +async def test_consume_all_with_delay(): + q = FakeQueueWithDelay([1, 2, 3], delay=0.01) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + results = [] + async for item in it: + results.append(item) + if len(results) >= 3: + q._closed = True + break + assert results == [1, 2, 3] + + +@pytest.mark.asyncio +async def test_consumer_initialization(): + q = FakeQueue([1]) + consumer = RedisEventConsumer(q) + assert consumer._queue is q + + +@pytest.mark.asyncio +async def test_consume_all_stops_when_closed_during_iteration(): + q = FakeQueue([1, 2, 3, 4, 5]) + consumer = RedisEventConsumer(q) + it = consumer.consume_all() + results = [] + # Consume a few items + results.append(await anext(it)) + results.append(await anext(it)) + # Mark closed during iteration + q._closed = True + # Next iteration should stop + with pytest.raises(StopAsyncIteration): + await anext(it) + assert results == [1, 2] + + +@pytest.mark.asyncio +async def test_consume_one_no_wait_false(): + """Test that consume_one always uses no_wait=True regardless of parameter.""" + q = FakeQueue([]) + consumer = RedisEventConsumer(q) + # Even though dequeue_event might support no_wait=False, consume_one should always use True + with pytest.raises(asyncio.QueueEmpty): + await consumer.consume_one() diff --git a/tests/server/events/test_redis_event_queue.py b/tests/server/events/test_redis_event_queue.py new file mode 100644 index 000000000..a613e4e1e --- /dev/null +++ b/tests/server/events/test_redis_event_queue.py @@ -0,0 +1,668 @@ +import asyncio +import json + +import pytest + +from a2a.server.events.redis_event_queue import RedisEventQueue + + +class FakeRedis: + """Minimal fake redis supporting xadd, xread, set, delete for tests.""" + + def __init__(self): + # stream_key -> list of (id_str, fields_dict) + self.streams: dict[str, list[tuple[str, dict]]] = {} + # stream_key -> next_id + self.next_ids: dict[str, int] = {} + + async def xadd( + self, stream_key: str, fields: dict, maxlen: int | None = None, **kwargs + ): + lst = self.streams.setdefault(stream_key, []) + next_id = self.next_ids.get(stream_key, 1) + entry_id = f'{next_id}-0' + lst.append((entry_id, fields.copy())) + self.next_ids[stream_key] = next_id + 1 + + # Implement maxlen by trimming the list if needed + if maxlen is not None and len(lst) > maxlen: + # Keep only the last maxlen entries + self.streams[stream_key] = lst[-maxlen:] + + # return id similar to real redis + return entry_id + + async def xread( + self, streams: dict, block: int = 0, count: int | None = None + ): + # streams is {stream_key: last_id} + results = [] + for key, last_id in streams.items(): + lst = self.streams.get(key, []) + # determine numeric last id + if last_id == '$': + # interpret as current max id so return only entries added after this call + if lst: + last_num = max(int(eid.split('-')[0]) for eid, _ in lst) + else: + last_num = 0 + else: + try: + last_num = int(str(last_id).split('-')[0]) + except Exception: + last_num = 0 + + # collect entries with numeric id > last_num + to_return = [ + (eid, fields) + for (eid, fields) in lst + if int(eid.split('-')[0]) > last_num + ] + if to_return: + results.append( + (key, to_return[: count if count is not None else None]) + ) + + return results + + async def set(self, key: str, value: str): + # no-op for tests + return True + + async def delete(self, key: str): + self.streams.pop(key, None) + return True + + async def xrevrange( + self, stream_key: str, start: str, end: str, count: int | None = None + ): + """Simulate Redis XREVRANGE command - get entries in reverse order.""" + lst = self.streams.get(stream_key, []) + if not lst: + return [] + + # Return the last 'count' entries in reverse order + to_return = lst[-count:] if count else lst + return [(entry_id, fields) for entry_id, fields in reversed(to_return)] + + +class MessageEvent: + """Dummy event with class name 'Message' and json() method.""" + + def __init__(self, payload): + self._payload = payload + + def json(self): + return json.dumps(self._payload) + + +@pytest.mark.asyncio +async def test_enqueue_dequeue_roundtrip(): + redis = FakeRedis() + q = RedisEventQueue( + 'task1', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + evt = MessageEvent({'x': 1}) + await q.enqueue_event(evt) + out = await q.dequeue_event(no_wait=True) + assert out == {'x': 1} + + +@pytest.mark.asyncio +async def test_dequeue_no_wait_raises_on_empty(): + redis = FakeRedis() + q = RedisEventQueue( + 'task2', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + with pytest.raises(asyncio.QueueEmpty): + await q.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_close_tombstone_sets_closed_and_raises(): + redis = FakeRedis() + q = RedisEventQueue( + 'task3', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + await q.enqueue_event(MessageEvent({'a': 1})) + # close will append a CLOSE entry + await q.close() + # first dequeue should return the first event + first = await q.dequeue_event(no_wait=True) + assert first == {'a': 1} + # next dequeue should see CLOSE and raise QueueEmpty while marking closed + with pytest.raises(asyncio.QueueEmpty): + await q.dequeue_event(no_wait=True) + assert q.is_closed() + + +@pytest.mark.asyncio +async def test_tap_sees_only_future_events(): + redis = FakeRedis() + q1 = RedisEventQueue( + 'task4', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + # enqueue before tap + await q1.enqueue_event(MessageEvent({'before': True})) + # create tap which should start at '$' and only see future events + q2 = q1.tap() + # q1 can dequeue the earlier event + e1 = await q1.dequeue_event(no_wait=True) + assert e1 == {'before': True} + # enqueue another event; both q1 and q2 should be able to read it (q1 hasn't advanced past it yet) + await q1.enqueue_event(MessageEvent({'later': 2})) + out2 = await q2.dequeue_event(no_wait=True) + assert out2 == {'later': 2} + + +@pytest.mark.asyncio +async def test_enqueue_dequeue_with_complex_data(): + """Test enqueuing and dequeuing complex data structures.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task5', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + complex_data = { + 'nested': {'key': 'value', 'number': 42}, + 'array': [1, 2, {'complex': 'item'}], + 'boolean': True, + 'null_value': None, + } + + evt = MessageEvent(complex_data) + await q.enqueue_event(evt) + out = await q.dequeue_event(no_wait=True) + assert out == complex_data + + +@pytest.mark.asyncio +async def test_enqueue_dequeue_with_unicode_data(): + """Test enqueuing and dequeuing Unicode strings.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task6', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + unicode_data = { + 'emoji': '🚀🌟', + 'multilingual': 'Hello 世界 नमस्ते', + 'special_chars': 'café résumé naïve', + } + + evt = MessageEvent(unicode_data) + await q.enqueue_event(evt) + out = await q.dequeue_event(no_wait=True) + assert out == unicode_data + + +@pytest.mark.asyncio +async def test_multiple_events_fifo_order(): + """Test that events are dequeued in FIFO order.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task7', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Enqueue multiple events + events = [{'id': i, 'data': f'value_{i}'} for i in range(5)] + for event_data in events: + await q.enqueue_event(MessageEvent(event_data)) + + # Dequeue and verify order + for expected in events: + actual = await q.dequeue_event(no_wait=True) + assert actual == expected + + +@pytest.mark.asyncio +async def test_task_done_noop(): + """Test that task_done is a no-op for Redis streams.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task8', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Should not raise any exceptions + q.task_done() + + +@pytest.mark.asyncio +async def test_tap_creates_independent_cursor(): + """Test that tap creates a queue with independent read cursor.""" + redis = FakeRedis() + q1 = RedisEventQueue( + 'task9', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Enqueue some events + await q1.enqueue_event(MessageEvent({'event': 1})) + await q1.enqueue_event(MessageEvent({'event': 2})) + + # Create tap - this should start after the current events + q2 = q1.tap() + + # q1 should still be able to read the events + e1_from_q1 = await q1.dequeue_event(no_wait=True) + e2_from_q1 = await q1.dequeue_event(no_wait=True) + + assert e1_from_q1 == {'event': 1} + assert e2_from_q1 == {'event': 2} + + # q2 should not see the previous events (it starts at the end) + # Enqueue a new event that both should see + await q1.enqueue_event(MessageEvent({'event': 3})) + + # q2 should be able to read the new event + e3_from_q2 = await q2.dequeue_event(no_wait=True) + assert e3_from_q2 == {'event': 3} + + +@pytest.mark.asyncio +async def test_close_behavior(): + """Test close operation behavior.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task10', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Enqueue an event + await q.enqueue_event(MessageEvent({'test': 'data'})) + + # Close the queue + await q.close() + + # Should be able to dequeue existing events + result = await q.dequeue_event(no_wait=True) + assert result == {'test': 'data'} + + # Further dequeue should raise QueueEmpty + with pytest.raises(asyncio.QueueEmpty): + await q.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_enqueue_after_close_ignored(): + """Test that enqueuing after close is ignored.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task11', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Close first + await q.close() + + # Enqueue should be ignored (no exception, but no effect) + await q.enqueue_event(MessageEvent({'should_be_ignored': True})) + + # Dequeue should raise QueueEmpty immediately + with pytest.raises(asyncio.QueueEmpty): + await q.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_maxlen_parameter(): + """Test maxlen parameter limits stream size.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task12', redis, stream_prefix='a2a:test', read_block_ms=10, maxlen=2 + ) + + # Add more events than maxlen + for i in range(5): + await q.enqueue_event(MessageEvent({'event': i})) + + # Check what's actually in the stream + stream_key = 'a2a:test:task12' + print(f'Stream contents: {redis.streams.get(stream_key, [])}') + + # Should only be able to dequeue the last 2 events (due to maxlen=2) + events_dequeued = [] + try: + while True: + event = await q.dequeue_event(no_wait=True) + events_dequeued.append(event) + print(f'Dequeued event: {event}') + except asyncio.QueueEmpty: + pass + + print(f'Total events dequeued: {len(events_dequeued)}') + + # Should have exactly 2 events (the last 2 added due to maxlen=2) + assert len(events_dequeued) == 2 + # Verify they are the last 2 events (events 3 and 4) + assert events_dequeued[0]['event'] == 3 + assert events_dequeued[1]['event'] == 4 + + +@pytest.mark.asyncio +async def test_enqueue_event_after_close_logs_warning(): + """Test that enqueuing after close logs a warning but doesn't raise.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task13', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Close the queue + await q.close() + + # Enqueue should not raise but should log warning + await q.enqueue_event(MessageEvent({'test': 'data'})) + + # Verify no events were actually added to stream + stream_key = 'a2a:test:task13' + stream_entries = redis.streams.get(stream_key, []) + # Should only have the CLOSE entry + assert len(stream_entries) == 1 + assert stream_entries[0][1]['type'] == 'CLOSE' + + +@pytest.mark.asyncio +async def test_dequeue_event_on_closed_queue_raises_immediately(): + """Test that dequeue on closed queue raises immediately.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task14', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Close without any events + await q.close() + + # Dequeue should raise immediately + with pytest.raises(asyncio.QueueEmpty, match='Queue is closed'): + await q.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_clear_events_deletes_stream(): + """Test clear_events deletes the underlying Redis stream.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task15', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add some events + await q.enqueue_event(MessageEvent({'test': 1})) + await q.enqueue_event(MessageEvent({'test': 2})) + + # Verify stream exists + stream_key = 'a2a:test:task15' + assert stream_key in redis.streams + + # Clear events + await q.clear_events() + + # Verify stream is deleted + assert stream_key not in redis.streams + + +@pytest.mark.asyncio +async def test_clear_events_handles_redis_error(): + """Test clear_events handles Redis errors gracefully.""" + + class FailingRedis(FakeRedis): + async def delete(self, key: str): + raise Exception('Redis delete failed') + + redis = FailingRedis() + q = RedisEventQueue( + 'task16', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Should not raise exception even if Redis delete fails + await q.clear_events() + + +@pytest.mark.asyncio +async def test_close_idempotent(): + """Test that close() can be called multiple times safely.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task17', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Close multiple times + await q.close() + await q.close() + await q.close() + + # Should only have one CLOSE entry + stream_key = 'a2a:test:task17' + stream_entries = redis.streams.get(stream_key, []) + close_entries = [ + entry for entry in stream_entries if entry[1].get('type') == 'CLOSE' + ] + assert len(close_entries) == 1 + + +@pytest.mark.asyncio +async def test_close_handles_redis_error(): + """Test close handles Redis errors gracefully.""" + + class FailingRedis(FakeRedis): + async def xadd( + self, + stream_key: str, + fields: dict, + maxlen: int | None = None, + **kwargs, + ): + if fields.get('type') == 'CLOSE': + raise Exception('Redis xadd failed') + return await super().xadd(stream_key, fields, maxlen, **kwargs) + + redis = FailingRedis() + q = RedisEventQueue( + 'task18', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Should not raise exception even if Redis xadd fails + await q.close() + + +@pytest.mark.asyncio +async def test_dequeue_handles_malformed_json(): + """Test dequeue handles malformed JSON gracefully.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task19', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Manually add malformed JSON to stream + stream_key = 'a2a:test:task19' + redis.streams[stream_key] = [ + ('1-0', {'type': 'Message', 'payload': '{invalid json'}) + ] + + # Should return the raw malformed data + result = await q.dequeue_event(no_wait=True) + assert result == '{invalid json' + + +@pytest.mark.asyncio +async def test_dequeue_handles_unknown_event_type(): + """Test dequeue handles unknown event types gracefully.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task20', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add event with unknown type + await redis.xadd( + 'a2a:test:task20', + {'type': 'UnknownType', 'payload': '{"test": "data"}'}, + ) + + # Should return raw payload for unknown types + result = await q.dequeue_event(no_wait=True) + assert result == {'test': 'data'} + + +@pytest.mark.asyncio +async def test_dequeue_handles_bytes_fields(): + """Test dequeue handles Redis returning bytes for field keys/values.""" + + class BytesRedis(FakeRedis): + async def xread( + self, streams: dict, block: int = 0, count: int | None = None + ): + # Call parent to get normal results + results = await super().xread(streams, block, count) + if results: + # Convert some fields to bytes to simulate Redis behavior + key, entries = results[0] + modified_entries = [] + for entry_id, fields in entries: + modified_fields = {} + for k, v in fields.items(): + # Convert keys and string values to bytes + k_bytes = k.encode('utf-8') if isinstance(k, str) else k + if isinstance(v, str): + v_bytes = v.encode('utf-8') + else: + v_bytes = v + modified_fields[k_bytes] = v_bytes + modified_entries.append((entry_id, modified_fields)) + return [(key, modified_entries)] + return results + + redis = BytesRedis() + q = RedisEventQueue( + 'task21', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add an event + await q.enqueue_event(MessageEvent({'test': 'data'})) + + # Should handle bytes conversion properly + result = await q.dequeue_event(no_wait=True) + assert result == {'test': 'data'} + + +@pytest.mark.asyncio +async def test_dequeue_handles_unicode_decode_error(): + """Test dequeue handles Unicode decode errors gracefully.""" + + class BadBytesRedis(FakeRedis): + async def xread( + self, streams: dict, block: int = 0, count: int | None = None + ): + results = await super().xread(streams, block, count) + if results: + key, entries = results[0] + modified_entries = [] + for entry_id, fields in entries: + modified_fields = {} + for k, v in fields.items(): + k_bytes = k.encode('utf-8') if isinstance(k, str) else k + if isinstance(v, str): + # Create invalid UTF-8 bytes + v_bytes = b'\xff\xfe\xfd' + else: + v_bytes = v + modified_fields[k_bytes] = v_bytes + modified_entries.append((entry_id, modified_fields)) + return [(key, modified_entries)] + return results + + redis = BadBytesRedis() + q = RedisEventQueue( + 'task22', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add an event + await q.enqueue_event(MessageEvent({'test': 'data'})) + + # Should handle decode error and return raw bytes + result = await q.dequeue_event(no_wait=True) + assert result == b'\xff\xfe\xfd' + + +@pytest.mark.asyncio +async def test_tap_with_empty_stream(): + """Test tap behavior with empty stream.""" + redis = FakeRedis() + q1 = RedisEventQueue( + 'task23', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Create tap on empty stream + q2 = q1.tap() + + # Both should handle empty stream gracefully + with pytest.raises(asyncio.QueueEmpty): + await q1.dequeue_event(no_wait=True) + + with pytest.raises(asyncio.QueueEmpty): + await q2.dequeue_event(no_wait=True) + + +@pytest.mark.asyncio +async def test_tap_with_existing_events(): + """Test tap starts after existing events.""" + redis = FakeRedis() + q1 = RedisEventQueue( + 'task24', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add events before tap + await q1.enqueue_event(MessageEvent({'before': 1})) + await q1.enqueue_event(MessageEvent({'before': 2})) + + # Create tap + q2 = q1.tap() + + # Add event after tap + await q1.enqueue_event(MessageEvent({'after': 3})) + + # q1 should see all events + assert await q1.dequeue_event(no_wait=True) == {'before': 1} + assert await q1.dequeue_event(no_wait=True) == {'before': 2} + assert await q1.dequeue_event(no_wait=True) == {'after': 3} + + # q2 should only see the event added after tap + assert await q2.dequeue_event(no_wait=True) == {'after': 3} + + +@pytest.mark.asyncio +async def test_is_closed_initially_false(): + """Test is_closed returns False initially.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task25', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + assert not q.is_closed() + + +@pytest.mark.asyncio +async def test_is_closed_after_close(): + """Test is_closed returns True after close.""" + redis = FakeRedis() + q = RedisEventQueue( + 'task26', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + await q.close() + + assert q.is_closed() + + +@pytest.mark.asyncio +async def test_dequeue_with_blocking_timeout(): + """Test dequeue with blocking behavior (no_wait=False).""" + redis = FakeRedis() + q = RedisEventQueue( + 'task27', redis, stream_prefix='a2a:test', read_block_ms=10 + ) + + # Add an event first + await q.enqueue_event(MessageEvent({'immediate': 'event'})) + + # This should return the event immediately + result = await q.dequeue_event(no_wait=False) + assert result == {'immediate': 'event'} + + # Now test with no events - should raise QueueEmpty after timeout + with pytest.raises(asyncio.QueueEmpty): + await q.dequeue_event(no_wait=False) diff --git a/tests/server/events/test_redis_queue_manager.py b/tests/server/events/test_redis_queue_manager.py new file mode 100644 index 000000000..d48dfa8e8 --- /dev/null +++ b/tests/server/events/test_redis_queue_manager.py @@ -0,0 +1,777 @@ +import logging + +import pytest + + +@pytest.mark.asyncio +async def test_create_or_tap_creates_queue(monkeypatch): + created = {} + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + pass # No-op for test + + # Monkeypatch import used in RedisQueueManager.create_or_tap by inserting + # a proper module object with attribute `RedisEventQueue`. + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + q = await manager.create_or_tap('t1') + assert hasattr(q, 'task_id') + + +@pytest.mark.asyncio +async def test_add_not_supported(): + """Test that add() is not supported in RedisQueueManager (distributed setup).""" + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + class DummyQueue: + def __init__(self, id): + self.id = id + + async def close(self): + return None + + q = DummyQueue('t2') + + # add() should raise NotImplementedError in distributed Redis setup + with pytest.raises( + NotImplementedError, + match='add\\(\\) is not supported in distributed Redis setup', + ): + await manager.add('t2', q) + + +@pytest.mark.asyncio +async def test_create_or_tap_with_different_task_ids(monkeypatch): + """Test create_or_tap with different task IDs creates separate queues.""" + created_queues = {} + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + created_queues[task_id] = self + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + pass + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Create queues for different tasks + q1 = await manager.create_or_tap('task1') + q2 = await manager.create_or_tap('task2') + q3 = await manager.create_or_tap( + 'task1' + ) # Same task creates new instance (no caching) + + assert q1.task_id == 'task1' + assert q2.task_id == 'task2' + assert q3.task_id == 'task1' + assert ( + q1 is not q3 + ) # Different instances for same task (no caching in Redis) + + +@pytest.mark.asyncio +async def test_close_operation(monkeypatch): + """Test close operation on Redis queue manager.""" + closed_queues = [] + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + closed_queues.append(self.task_id) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FakeRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Create and close a queue + q1 = await manager.create_or_tap('task1') + await manager.close('task1') + + assert 'task1' in closed_queues + + +@pytest.mark.asyncio +async def test_close_nonexistent_task(monkeypatch): + """Test closing a nonexistent task.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + + def tap(self): + return FakeRedisEventQueue(self.task_id, None) + + async def close(self): + pass # No-op for test + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Should not raise error when closing nonexistent task + await manager.close('nonexistent_task') + + +@pytest.mark.asyncio +async def test_get_operation(monkeypatch): + """Test get operation on Redis queue manager.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + + def tap(self): + return FakeRedisEventQueue(self.task_id, None) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FakeRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Get operation should raise NotImplementedError in distributed Redis setup + with pytest.raises( + NotImplementedError, + match='get\\(\\) is not supported in distributed Redis setup', + ): + await manager.get('nonexistent') + + # Get existing task should also raise NotImplementedError in distributed setup + q1 = await manager.create_or_tap('task1') + with pytest.raises( + NotImplementedError, + match='get\\(\\) is not supported in distributed Redis setup', + ): + await manager.get('task1') + + +@pytest.mark.asyncio +async def test_tap_operation(monkeypatch): + """Test tap operation creates new queue instance with same redis_client.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + # Return a new queue with the same redis_client (matching actual behavior) + return FakeRedisEventQueue(self.task_id, self.redis_client) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FakeRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client='fake_redis') + + # Create tap + tapped_queue = await manager.tap('task1') + + assert tapped_queue.task_id == 'task1' + assert ( + tapped_queue.redis_client == 'fake_redis' + ) # Tap should have the same redis_client + + +@pytest.mark.asyncio +async def test_create_or_tap_with_none_redis_client(): + """Test create_or_tap with None redis_client.""" + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Should work with None redis_client (passed to RedisEventQueue) + q = await manager.create_or_tap('task1') + assert q.task_id == 'task1' + + +@pytest.mark.asyncio +async def test_multiple_taps_same_task(): + """Test multiple taps on same task create independent instances.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client='fake_redis') + + # Create multiple taps + tap1 = await manager.tap('task1') + tap2 = await manager.tap('task1') + tap3 = await manager.tap('task1') + + # All should be different instances + assert tap1 is not tap2 + assert tap2 is not tap3 + assert tap1 is not tap3 + + # All should have same task_id and redis_client + assert tap1.task_id == 'task1' + assert tap2.task_id == 'task1' + assert tap3.task_id == 'task1' + assert tap1.redis_client == 'fake_redis' + assert tap2.redis_client == 'fake_redis' + assert tap3.redis_client == 'fake_redis' + + +@pytest.mark.asyncio +async def test_close_multiple_tasks(): + """Test closing multiple tasks.""" + closed_tasks = [] + + # Use the FakeRedis from the test file that has xrevrange + import os + import sys + + sys.path.append(os.path.dirname(__file__)) + from test_redis_event_queue import FakeRedis + + redis = FakeRedis() + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + self._stream_prefix = stream_prefix or 'a2a:task' + self._stream_key = f'{self._stream_prefix}:{task_id}' + + def tap(self): + return FakeRedisEventQueue( + self.task_id, self.redis_client, self._stream_prefix + ) + + async def close(self): + closed_tasks.append(self.task_id) + # Actually write the CLOSE entry to the stream + await self.redis_client.xadd(self._stream_key, {'type': 'CLOSE'}) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FakeRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=redis) + + # Create multiple queues + q1 = await manager.create_or_tap('task1') + q2 = await manager.create_or_tap('task2') + q3 = await manager.create_or_tap('task3') + + # Close all + await manager.close('task1') + await manager.close('task2') + await manager.close('task3') + + assert set(closed_tasks) == {'task1', 'task2', 'task3'} + + +@pytest.mark.asyncio +async def test_close_already_closed_task(): + """Test closing a task that's already closed.""" + # Use the FakeRedis from the test file that has xrevrange + import os + import sys + + sys.path.append(os.path.dirname(__file__)) + from test_redis_event_queue import FakeRedis + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + redis = FakeRedis() + manager = RedisQueueManager(redis_client=redis) + + # Create and close a queue + await manager.create_or_tap('task1') + await manager.close('task1') + + # Check that a CLOSE entry was created + stream_key = 'a2a:task:task1' + stream_entries = redis.streams.get(stream_key, []) + close_entries = [ + entry for entry in stream_entries if entry[1].get('type') == 'CLOSE' + ] + assert len(close_entries) == 1 + + # Close again - should not create another CLOSE entry + await manager.close('task1') + + # Should still only have one CLOSE entry + stream_entries = redis.streams.get(stream_key, []) + close_entries = [ + entry for entry in stream_entries if entry[1].get('type') == 'CLOSE' + ] + assert len(close_entries) == 1 + + +@pytest.mark.asyncio +async def test_tap_nonexistent_task(): + """Test tapping a task that doesn't exist.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client='fake_redis') + + # Tap nonexistent task should work (creates new queue) + tapped_queue = await manager.tap('nonexistent') + + assert tapped_queue.task_id == 'nonexistent' + assert tapped_queue.redis_client == 'fake_redis' + + +@pytest.mark.asyncio +async def test_manager_initialization(): + """Test RedisQueueManager initialization.""" + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client='test_client') + assert manager._redis == 'test_client' + + +@pytest.mark.asyncio +async def test_create_or_tap_with_custom_stream_prefix(): + """Test create_or_tap with custom stream prefix.""" + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.stream_prefix = stream_prefix + + def tap(self): + return FakeRedisEventQueue(self.task_id, None, self.stream_prefix) + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Create queue (RedisEventQueue would use default prefix) + q = await manager.create_or_tap('task1') + assert q.task_id == 'task1' + # Note: In real implementation, stream_prefix would be passed through + + +@pytest.mark.asyncio +async def test_error_handling_in_close(): + """Test error handling when closing queues.""" + + class FailingRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + + def tap(self): + return FailingRedisEventQueue(self.task_id, None) + + async def close(self): + raise Exception('Close failed') + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FailingRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Create queue + await manager.create_or_tap('task1') + + # Close should handle exceptions gracefully (not raise) + await manager.close('task1') + + +@pytest.mark.asyncio +async def test_create_or_tap_logging(monkeypatch, caplog): + """Test that create_or_tap logs appropriate information.""" + import logging + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + pass + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + monkeypatch.setitem( + sys.modules, 'a2a.server.events.redis_event_queue', fake_mod + ) + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client='fake_redis') + + # Set logging level to INFO to capture the logging statements + with caplog.at_level(logging.INFO): + q = await manager.create_or_tap('test_task') + + # Check that logging statements were executed + assert any( + 'create_or_tap called with task_id: test_task' in record.message + for record in caplog.records + ), 'Should log task_id' + assert any( + 'Creating RedisEventQueue instance' in record.message + for record in caplog.records + ), 'Should log queue creation' + + +@pytest.mark.asyncio +async def test_close_logging_on_redis_check_failure(caplog): + """Test that close method logs when Redis check fails.""" + import logging + import os + + # Use the FakeRedis from the test file that has xrevrange + import sys + + sys.path.append(os.path.dirname(__file__)) + from test_redis_event_queue import FakeRedis + + redis = FakeRedis() + + # Mock xrevrange to raise an exception + async def failing_xrevrange(*args, **kwargs): + raise Exception('Redis connection failed') + + redis.xrevrange = failing_xrevrange + + class FakeRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FakeRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + pass + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FakeRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FakeRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=redis) + + # Create queue first + await manager.create_or_tap('test_task') + + # Close should handle Redis check failure gracefully and log it + with caplog.at_level(logging.DEBUG): + await manager.close('test_task') + + # Check that debug logging occurred for the Redis check failure + assert any( + 'Could not check if stream is already closed' in record.message + for record in caplog.records + ), 'Should log Redis check failure' + + +@pytest.mark.asyncio +async def test_close_logging_on_queue_close_failure(caplog): + """Test that close method logs when queue close fails.""" + import logging + + class FailingRedisEventQueue: + def __init__(self, task_id, redis_client, stream_prefix=None): + self.task_id = task_id + self.redis_client = redis_client + + def tap(self): + return FailingRedisEventQueue(self.task_id, self.redis_client) + + async def close(self): + # Make the close method fail by trying to use the redis_client + if self.redis_client is None: + raise Exception('Queue close failed - no redis client') + # This should not be reached in the test + + # Monkeypatch + import sys + import types + + fake_mod = types.ModuleType('a2a.server.events.redis_event_queue') + fake_mod.RedisEventQueue = FailingRedisEventQueue + sys.modules['a2a.server.events.redis_event_queue'] = fake_mod + + # Also patch the RedisEventQueue reference in redis_queue_manager + import a2a.server.events.redis_queue_manager as rqm + + rqm.RedisEventQueue = FailingRedisEventQueue + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Create queue + await manager.create_or_tap('test_task') + + # Close should handle queue close failure gracefully and log it + with caplog.at_level(logging.DEBUG): + await manager.close('test_task') + + # Check that debug logging occurred for the queue close failure + assert any( + 'Failed to close queue' in record.message for record in caplog.records + ), ( + f'Should log queue close failure. Captured logs: {[r.message for r in caplog.records]}' + ) + + +@pytest.mark.asyncio +async def test_create_or_tap_redis_event_queue_import_failure( + monkeypatch, caplog +): + """Test create_or_tap when RedisEventQueue import fails.""" + # Mock the import failure by setting RedisEventQueue to None + from a2a.server.events import redis_queue_manager + + # Store original value + original_redis_event_queue = redis_queue_manager.RedisEventQueue + + try: + # Set RedisEventQueue to None to simulate import failure + redis_queue_manager.RedisEventQueue = None + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + manager = RedisQueueManager(redis_client=None) + + # Should raise RuntimeError when RedisEventQueue is None + with pytest.raises( + RuntimeError, match='RedisEventQueue is not available' + ): + with caplog.at_level(logging.ERROR): + await manager.create_or_tap('test_task') + + # Check that error logging occurred + assert any( + 'RedisEventQueue is None - import failed' in record.message + for record in caplog.records + ), ( + f'Should log import failure. Captured logs: {[r.message for r in caplog.records]}' + ) + + finally: + # Restore original value + redis_queue_manager.RedisEventQueue = original_redis_event_queue + + +@pytest.mark.asyncio +async def test_tap_redis_event_queue_import_failure(monkeypatch, caplog): + """Test tap when RedisEventQueue import fails.""" + # Store original value + import a2a.server.events.redis_queue_manager as rqm + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + original_redis_event_queue = rqm.RedisEventQueue + + try: + # Set RedisEventQueue to None to simulate import failure + rqm.RedisEventQueue = None + + manager = RedisQueueManager(redis_client=None) + + # Should raise RuntimeError when RedisEventQueue is None + with pytest.raises( + RuntimeError, + match='RedisEventQueue is not available. Cannot create tap', + ): + await manager.tap('test_task') + + finally: + # Restore original value + rqm.RedisEventQueue = original_redis_event_queue + + +@pytest.mark.asyncio +async def test_close_redis_event_queue_import_failure(monkeypatch): + """Test close when RedisEventQueue import fails.""" + # Store original value + import a2a.server.events.redis_queue_manager as rqm + + from a2a.server.events.redis_queue_manager import RedisQueueManager + + original_redis_event_queue = rqm.RedisEventQueue + + try: + # Set RedisEventQueue to None to simulate import failure + rqm.RedisEventQueue = None + + manager = RedisQueueManager(redis_client=None) + + # Should raise RuntimeError when RedisEventQueue is None + with pytest.raises( + RuntimeError, + match='RedisEventQueue is not available. Cannot close stream', + ): + await manager.close('test_task') + + finally: + # Restore original value + rqm.RedisEventQueue = original_redis_event_queue diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index f96ce5e65..58c04d53c 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -2454,3 +2454,150 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found(): f'Task {task_id} was specified but does not exist' in exc_info.value.error.message ) + + +def test_init_with_default_queue_manager_issues_deprecation_warning(): + """Test that initializing with default queue_manager issues deprecation warning.""" + import warnings + + from unittest.mock import MagicMock + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + handler = DefaultRequestHandler( + agent_executor=MagicMock(), task_store=MagicMock() + ) + + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert 'Using default InMemoryQueueManager' in str(w[0].message) + assert 'will be removed in a future version' in str(w[0].message) + + +def test_init_with_explicit_queue_manager_no_warning(): + """Test that initializing with explicit queue_manager does not issue warning.""" + import warnings + + from unittest.mock import MagicMock + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + handler = DefaultRequestHandler( + agent_executor=MagicMock(), + task_store=MagicMock(), + queue_manager=InMemoryQueueManager(), + ) + + # Should not have any deprecation warnings + deprecation_warnings = [ + warning + for warning in w + if issubclass(warning.category, DeprecationWarning) + ] + assert len(deprecation_warnings) == 0 + + +@pytest.mark.asyncio +async def test_init_with_disabled_fallback_raises_error(): + """Test that disabling fallback raises ValueError when queue_manager is None.""" + import os + + from unittest.mock import MagicMock + + # Set environment variable to disable fallback + old_value = os.environ.get('A2A_DISABLE_QUEUE_MANAGER_FALLBACK') + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = 'true' + + try: + with pytest.raises(ValueError, match='queue_manager is required'): + DefaultRequestHandler( + agent_executor=MagicMock(), task_store=MagicMock() + ) + finally: + # Restore environment variable + if old_value is None: + os.environ.pop('A2A_DISABLE_QUEUE_MANAGER_FALLBACK', None) + else: + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = old_value + + +@pytest.mark.asyncio +async def test_init_with_disabled_fallback_false_allows_default(): + """Test that setting A2A_DISABLE_QUEUE_MANAGER_FALLBACK=false allows default behavior.""" + import os + import warnings + + from unittest.mock import MagicMock + + # Set environment variable to explicitly allow fallback + old_value = os.environ.get('A2A_DISABLE_QUEUE_MANAGER_FALLBACK') + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = 'false' + + try: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + handler = DefaultRequestHandler( + agent_executor=MagicMock(), task_store=MagicMock() + ) + + # Should still get deprecation warning + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + finally: + # Restore environment variable + if old_value is None: + os.environ.pop('A2A_DISABLE_QUEUE_MANAGER_FALLBACK', None) + else: + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = old_value + + +def test_environment_variable_parsing(): + """Test that environment variable accepts various true/false values.""" + import os + + from unittest.mock import MagicMock + + test_cases = [ + ('true', True), + ('True', True), + ('TRUE', True), + ('1', True), + ('yes', True), + ('Yes', True), + ('false', False), + ('False', False), + ('FALSE', False), + ('0', False), + ('no', False), + ('No', False), + ('invalid', False), # Invalid values should default to False + ('', False), # Empty string should default to False + ] + + for env_value, expected_disable in test_cases: + old_value = os.environ.get('A2A_DISABLE_QUEUE_MANAGER_FALLBACK') + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = env_value + + try: + if expected_disable: + with pytest.raises( + ValueError, match='queue_manager is required' + ): + DefaultRequestHandler( + agent_executor=MagicMock(), task_store=MagicMock() + ) + else: + # Should work without error (may issue warning) + handler = DefaultRequestHandler( + agent_executor=MagicMock(), task_store=MagicMock() + ) + assert handler is not None + finally: + # Restore environment variable + if old_value is None: + os.environ.pop('A2A_DISABLE_QUEUE_MANAGER_FALLBACK', None) + else: + os.environ['A2A_DISABLE_QUEUE_MANAGER_FALLBACK'] = old_value diff --git a/tests/server/request_handlers/test_redis_request_handler.py b/tests/server/request_handlers/test_redis_request_handler.py new file mode 100644 index 000000000..4aae6502a --- /dev/null +++ b/tests/server/request_handlers/test_redis_request_handler.py @@ -0,0 +1,23 @@ +def test_create_redis_request_handler_monkeypatched(monkeypatch): + class FakeRedisQueueManager: + def __init__(self, redis_client=None, stream_prefix='a2a:task'): + self.redis_client = redis_client + + monkeypatch.setenv('A2A_FAKE', '1') + + # Monkeypatch RedisQueueManager to our fake to avoid real redis import + import a2a.server.events.redis_queue_manager as rqm + + from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, + ) + from a2a.server.request_handlers.redis_request_handler import ( + create_redis_request_handler, + ) + + rqm.RedisQueueManager = FakeRedisQueueManager + + handler = create_redis_request_handler( + agent_executor=object(), task_store=object(), redis_client=None + ) + assert isinstance(handler, DefaultRequestHandler) diff --git a/tests/utils/test_redis_stream_writer.py b/tests/utils/test_redis_stream_writer.py new file mode 100644 index 000000000..68795c9d1 --- /dev/null +++ b/tests/utils/test_redis_stream_writer.py @@ -0,0 +1,430 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from a2a.types import TaskStatusUpdateEvent +from a2a.utils.stream_write.redis_stream_writer import RedisStreamInjector + + +@pytest.fixture +def mock_redis_client(): + """Fixture providing a mock Redis client.""" + client = AsyncMock() + client.xadd = AsyncMock(return_value='123-0') + client.ping = AsyncMock() + client.aclose = AsyncMock() + return client + + +class TestRedisStreamInjector: + """Test suite for RedisStreamInjector.""" + + def test_init_without_redis_import_raises_error(self): + """Test that initialization fails when redis is not available.""" + with patch('a2a.utils.stream_write.redis_stream_writer.Redis', None): + with pytest.raises(ImportError, match='redis package is required'): + RedisStreamInjector() + + def test_init_with_redis_available(self): + """Test successful initialization when redis is available.""" + with patch( + 'a2a.utils.stream_write.redis_stream_writer.Redis' + ) as mock_redis: + injector = RedisStreamInjector('redis://localhost:6379/0') + assert injector.redis_url == 'redis://localhost:6379/0' + assert injector._client is None + assert not injector._connected + + @pytest.mark.asyncio + async def test_connect_success(self): + """Test successful connection to Redis.""" + mock_client = AsyncMock() + mock_client.ping = AsyncMock() + + with patch( + 'a2a.utils.stream_write.redis_stream_writer.Redis' + ) as mock_redis_class: + mock_redis_class.from_url.return_value = mock_client + + injector = RedisStreamInjector() + await injector.connect() + + assert injector._client == mock_client + assert injector._connected + mock_redis_class.from_url.assert_called_once_with( + 'redis://localhost:6379/0' + ) + mock_client.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_already_connected(self): + """Test that connect does nothing if already connected.""" + mock_client = AsyncMock() + + with patch( + 'a2a.utils.stream_write.redis_stream_writer.Redis' + ) as mock_redis_class: + mock_redis_class.from_url.return_value = mock_client + + injector = RedisStreamInjector() + injector._connected = True + injector._client = mock_client + + await injector.connect() + + # Should not create new client or ping + mock_redis_class.from_url.assert_not_called() + mock_client.ping.assert_not_called() + + @pytest.mark.asyncio + async def test_connect_failure(self): + """Test connection failure.""" + with patch( + 'a2a.utils.stream_write.redis_stream_writer.Redis' + ) as mock_redis_class: + mock_redis_class.from_url.side_effect = Exception( + 'Connection failed' + ) + + injector = RedisStreamInjector() + + with pytest.raises(Exception, match='Connection failed'): + await injector.connect() + + assert not injector._connected + assert injector._client is None + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test disconnecting from Redis.""" + mock_client = AsyncMock() + mock_client.aclose = AsyncMock() + + injector = RedisStreamInjector(redis_client=mock_client) + injector._connected = True + + await injector.disconnect() + + assert not injector._connected + assert injector._client is None + mock_client.aclose.assert_called_once() + + @pytest.mark.asyncio + async def test_disconnect_not_connected(self): + """Test disconnect when not connected.""" + mock_client = AsyncMock() + injector = RedisStreamInjector(redis_client=mock_client) + injector._connected = False + + await injector.disconnect() + + # Should not call aclose since not connected + mock_client.aclose.assert_not_called() + + # Should not raise any errors and client should remain + assert not injector._connected + assert injector._client == mock_client + + @pytest.mark.asyncio + async def test_context_manager(self): + """Test async context manager.""" + mock_client = AsyncMock() + mock_client.ping = AsyncMock() + mock_client.aclose = AsyncMock() + + injector = RedisStreamInjector(redis_client=mock_client) + + async with injector as ctx_injector: + assert ctx_injector == injector + assert injector._connected + + assert not injector._connected + mock_client.aclose.assert_called_once() + + def test_get_stream_key(self): + """Test stream key generation.""" + mock_client = AsyncMock() + injector = RedisStreamInjector(redis_client=mock_client) + + key = injector._get_stream_key('test_task') + assert key == 'a2a:task:test_task' + + def test_get_stream_key_empty_task_id(self): + """Test stream key generation with empty task_id.""" + mock_client = AsyncMock() + injector = RedisStreamInjector(redis_client=mock_client) + + with pytest.raises(ValueError, match='task_id cannot be empty'): + injector._get_stream_key('') + + def test_serialize_event(self): + """Test event serialization.""" + injector = RedisStreamInjector(redis_client=AsyncMock()) + + event_data = injector._serialize_event('test_type', {'key': 'value'}) + assert event_data['type'] == 'test_type' + assert 'payload' in event_data + + @pytest.mark.asyncio + async def test_append_to_stream(self): + """Test appending event to stream.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector(redis_client=mock_client) + + event_data = {'type': 'Test', 'payload': '{"data": "test"}'} + result = await injector._append_to_stream('test_task', event_data) + + assert result == '123-0' + mock_client.xadd.assert_called_once_with( + 'a2a:task:test_task', event_data + ) + + @pytest.mark.asyncio + async def test_append_to_stream_not_connected(self): + """Test append_to_stream when not connected.""" + mock_client = AsyncMock() + injector = RedisStreamInjector(redis_client=mock_client) + injector._connected = False + + with pytest.raises(RuntimeError, match='Not connected to Redis'): + await injector._append_to_stream('test_task', {}) + + @pytest.mark.asyncio + async def test_stream_message_with_dict(self): + """Test streaming message with dict input.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector(redis_client=mock_client) + + message_data = {'content': 'test message', 'role': 'assistant'} + result = await injector.stream_message( + 'ctx123', 'task123', message_data + ) + + assert result == '123-0' + mock_client.xadd.assert_called_once() + + # Verify the call arguments + call_args = mock_client.xadd.call_args + stream_key = call_args[0][0] + event_data = call_args[0][1] + + assert stream_key == 'a2a:task:task123' + assert event_data['type'] == 'Message' + assert 'payload' in event_data + + @pytest.mark.asyncio + async def test_stream_message_with_message_object(self): + """Test streaming message with Message object.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + # Create a proper Message object with required fields + from a2a.types import Message, Role, TextPart + + message = Message( + message_id='msg-123', + parts=[TextPart(text='test message')], + role=Role.agent, + ) + result = await injector.stream_message('ctx123', 'task123', message) + + assert result == '123-0' + mock_client.xadd.assert_called_once() + + @pytest.mark.asyncio + async def test_stream_message_empty_task_id(self): + """Test stream_message with empty task_id.""" + injector = RedisStreamInjector() + + with pytest.raises(ValueError, match='task_id cannot be empty'): + await injector.stream_message('ctx123', '', {'content': 'test'}) + + @pytest.mark.asyncio + async def test_stream_message_empty_context_id(self): + """Test stream_message with empty context_id.""" + injector = RedisStreamInjector() + + with pytest.raises(ValueError, match='context_id cannot be empty'): + await injector.stream_message('', 'task123', {'content': 'test'}) + + @pytest.mark.asyncio + async def test_update_status_with_task_status_update_event(self): + """Test update_status with TaskStatusUpdateEvent.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + # Create a TaskStatusUpdateEvent + status_event = TaskStatusUpdateEvent( + context_id='ctx123', + task_id='task123', + final=False, + status={ + 'state': 'working', + 'message': None, + 'timestamp': '2023-01-01T00:00:00Z', + }, + ) + + result = await injector.update_status('ctx123', 'task123', status_event) + + assert result == '123-0' + mock_client.xadd.assert_called_once() + + @pytest.mark.asyncio + async def test_update_status_with_dict(self): + """Test update_status with dict status.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.update_status( + 'ctx123', 'task123', {'state': 'completed'} + ) + + assert result == '123-0' + mock_client.xadd.assert_called_once() + + @pytest.mark.asyncio + async def test_final_message(self): + """Test final_message method.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + message_data = {'content': 'final message', 'role': 'assistant'} + result = await injector.final_message('ctx123', 'task123', message_data) + + assert result == '123-0' + # Should call xadd twice: once for message, once for status update + assert mock_client.xadd.call_count == 2 + + @pytest.mark.asyncio + async def test_append_raw(self): + """Test append_raw method.""" + mock_client = AsyncMock() + mock_client.xadd = AsyncMock(return_value='123-0') + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.append_raw( + 'task123', 'CustomEvent', '{"data": "test"}' + ) + + assert result == '123-0' + mock_client.xadd.assert_called_once_with( + 'a2a:task:task123', + {'type': 'CustomEvent', 'payload': '{"data": "test"}'}, + ) + + @pytest.mark.asyncio + async def test_get_latest_event(self): + """Test get_latest_event method.""" + mock_client = AsyncMock() + mock_client.xrevrange = AsyncMock( + return_value=[ + ('123-0', {'type': 'Message', 'payload': '{"data": "test"}'}) + ] + ) + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.get_latest_event('task123') + + assert result == { + 'id': '123-0', + 'type': 'Message', + 'payload': '{"data": "test"}', + } + mock_client.xrevrange.assert_called_once_with( + 'a2a:task:task123', '+', '-', count=1 + ) + + @pytest.mark.asyncio + async def test_get_latest_event_no_events(self): + """Test get_latest_event when no events exist.""" + mock_client = AsyncMock() + mock_client.xrevrange = AsyncMock(return_value=[]) + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.get_latest_event('task123') + + assert result is None + + @pytest.mark.asyncio + async def test_get_latest_event_exception(self): + """Test get_latest_event when exception occurs.""" + mock_client = AsyncMock() + mock_client.xrevrange = AsyncMock(side_effect=Exception('Redis error')) + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.get_latest_event('task123') + + assert result is None + + @pytest.mark.asyncio + async def test_get_events_since(self): + """Test get_events_since method.""" + mock_client = AsyncMock() + mock_client.xrange = AsyncMock( + return_value=[ + ('123-0', {'type': 'Message', 'payload': '{"data": "test1"}'}), + ('124-0', {'type': 'Status', 'payload': '{"data": "test2"}'}), + ] + ) + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.get_events_since('task123', '122-0') + + expected = [ + {'id': '123-0', 'type': 'Message', 'payload': '{"data": "test1"}'}, + {'id': '124-0', 'type': 'Status', 'payload': '{"data": "test2"}'}, + ] + assert result == expected + mock_client.xrange.assert_called_once_with( + 'a2a:task:task123', '122-0', '+' + ) + + @pytest.mark.asyncio + async def test_get_events_since_exception(self): + """Test get_events_since when exception occurs.""" + mock_client = AsyncMock() + mock_client.xrange = AsyncMock(side_effect=Exception('Redis error')) + + injector = RedisStreamInjector() + injector._client = mock_client + injector._connected = True + + result = await injector.get_events_since('task123') + + assert result == [] diff --git a/uv.lock b/uv.lock index 4dc93c18e..d64cd2d70 100644 --- a/uv.lock +++ b/uv.lock @@ -37,6 +37,9 @@ mysql = [ postgresql = [ { name = "sqlalchemy", extra = ["asyncio", "postgresql-asyncpg"] }, ] +redis = [ + { name = "redis" }, +] sql = [ { name = "sqlalchemy", extra = ["aiomysql", "aiosqlite", "asyncio", "postgresql-asyncpg"] }, ] @@ -85,6 +88,7 @@ requires-dist = [ { name = "opentelemetry-sdk", marker = "extra == 'telemetry'", specifier = ">=1.33.0" }, { name = "protobuf", specifier = ">=5.29.5" }, { name = "pydantic", specifier = ">=2.11.3" }, + { name = "redis", marker = "extra == 'redis'", specifier = ">=6.4.0" }, { name = "sqlalchemy", extras = ["aiomysql", "aiosqlite", "asyncio", "postgresql-asyncpg"], marker = "extra == 'sql'", specifier = ">=2.0.0" }, { name = "sqlalchemy", extras = ["aiomysql", "asyncio"], marker = "extra == 'mysql'", specifier = ">=2.0.0" }, { name = "sqlalchemy", extras = ["aiosqlite", "asyncio"], marker = "extra == 'sqlite'", specifier = ">=2.0.0" }, @@ -92,7 +96,7 @@ requires-dist = [ { name = "sse-starlette", marker = "extra == 'http-server'" }, { name = "starlette", marker = "extra == 'http-server'" }, ] -provides-extras = ["encryption", "grpc", "http-server", "mysql", "postgresql", "sql", "sqlite", "telemetry"] +provides-extras = ["encryption", "grpc", "http-server", "mysql", "postgresql", "redis", "sql", "sqlite", "telemetry"] [package.metadata.requires-dev] dev = [ @@ -1652,6 +1656,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/28/26534bed77109632a956977f60d8519049f545abc39215d086e33a61f1f2/pyyaml_ft-8.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:de04cfe9439565e32f178106c51dd6ca61afaa2907d143835d501d84703d3793", size = 171579, upload-time = "2025-06-10T15:32:14.34Z" }, ] +[[package]] +name = "redis" +version = "6.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" }, +] + [[package]] name = "requests" version = "2.32.4"