Skip to content

Commit 5cd06bd

Browse files
committed
Add extra EventQueue between AgentExecutor and subscribers.
1 parent 966745c commit 5cd06bd

8 files changed

Lines changed: 80 additions & 91 deletions

File tree

run_test_debug.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/a2a/server/agent_execution/active_task.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
from a2a.server.agent_execution.agent_executor import AgentExecutor
1717
from a2a.server.context import ServerCallContext
18-
from a2a.server.events.event_queue_v2 import Event, EventQueue
1918
from a2a.server.tasks.push_notification_sender import (
2019
PushNotificationSender,
2120
)
2221
from a2a.server.tasks.task_manager import TaskManager
2322

24-
2523
from a2a.server.events.event_queue_v2 import (
2624
AsyncQueue,
25+
Event,
26+
EventQueueSource,
2727
QueueShutDown,
2828
_create_async_queue,
2929
)
@@ -48,6 +48,10 @@
4848
TaskState.TASK_STATE_FAILED,
4949
TaskState.TASK_STATE_REJECTED,
5050
}
51+
INTERRUPTED_TASK_STATES = {
52+
TaskState.TASK_STATE_AUTH_REQUIRED,
53+
TaskState.TASK_STATE_INPUT_REQUIRED,
54+
}
5155

5256

5357
class _RequestCompleted:
@@ -76,11 +80,10 @@ class ActiveTask:
7680
permanently ceased execution and closed its queues.
7781
"""
7882

79-
def __init__( # noqa: PLR0913
83+
def __init__(
8084
self,
8185
agent_executor: AgentExecutor,
8286
task_id: str,
83-
event_queue: EventQueue,
8487
task_manager: TaskManager,
8588
push_sender: PushNotificationSender | None = None,
8689
on_cleanup: Callable[[ActiveTask], None] | None = None,
@@ -90,8 +93,6 @@ def __init__( # noqa: PLR0913
9093
Args:
9194
agent_executor: The executor to run the agent logic (producer).
9295
task_id: The unique identifier of the task being managed.
93-
event_queue: The queue for events produced by the agent. Acts as the pipe
94-
between the producer and consumer tasks.
9596
task_manager: The manager for task state and database persistence.
9697
push_sender: Optional sender for out-of-band push notifications.
9798
on_cleanup: Optional callback triggered when the task is fully finished
@@ -101,7 +102,10 @@ def __init__( # noqa: PLR0913
101102
# --- Core Dependencies ---
102103
self._agent_executor = agent_executor
103104
self._task_id = task_id
104-
self._event_queue = event_queue
105+
self._event_queue_agent = EventQueueSource()
106+
self._event_queue_subscribers = EventQueueSource(
107+
create_default_sink=False
108+
)
105109
self._task_manager = task_manager
106110
self._push_sender = push_sender
107111
self._on_cleanup = on_cleanup
@@ -284,7 +288,7 @@ async def _run_producer(self) -> None:
284288

285289
try:
286290
await self._agent_executor.execute(
287-
request_context, self._event_queue
291+
request_context, self._event_queue_agent
288292
)
289293
logger.debug(
290294
'Producer[%s]: Execution finished successfully',
@@ -301,7 +305,7 @@ async def _run_producer(self) -> None:
301305
self._task_id,
302306
)
303307
# TODO: Hide from external consumers
304-
await self._event_queue.enqueue_event(
308+
await self._event_queue_agent.enqueue_event(
305309
cast('Event', _RequestCompleted(request_id))
306310
)
307311
self._request_queue.task_done()
@@ -319,7 +323,8 @@ async def _run_producer(self) -> None:
319323
self._exception = e
320324
finally:
321325
self._request_queue.shutdown(immediate=True)
322-
await self._event_queue.close(immediate=False)
326+
await self._event_queue_agent.close(immediate=False)
327+
await self._event_queue_subscribers.close(immediate=False)
323328
finally:
324329
logger.debug('Producer[%s]: Completed', self._task_id)
325330

@@ -345,7 +350,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
345350
'Consumer[%s]: Waiting for event',
346351
self._task_id,
347352
)
348-
event = await self._event_queue.dequeue_event()
353+
event = await self._event_queue_agent.dequeue_event()
349354
logger.debug(
350355
'Consumer[%s]: Dequeued event %s',
351356
self._task_id,
@@ -373,15 +378,14 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
373378

374379
# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
375380
res = await self._task_manager.get_task()
376-
is_interrupted = res and res.status.state in (
377-
TaskState.TASK_STATE_AUTH_REQUIRED,
378-
TaskState.TASK_STATE_INPUT_REQUIRED,
381+
is_interrupted = (
382+
res
383+
and res.status.state
384+
in INTERRUPTED_TASK_STATES
379385
)
380-
is_terminal = res and res.status.state in (
381-
TaskState.TASK_STATE_COMPLETED,
382-
TaskState.TASK_STATE_CANCELED,
383-
TaskState.TASK_STATE_FAILED,
384-
TaskState.TASK_STATE_REJECTED,
386+
is_terminal = (
387+
res
388+
and res.status.state in TERMINAL_TASK_STATES
385389
)
386390

387391
# If we hit a breakpoint or terminal state, lock in the result.
@@ -427,9 +431,11 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
427431
await self._push_sender.send_notification(
428432
self._task_id, event
429433
)
430-
431434
finally:
432-
self._event_queue.task_done()
435+
await self._event_queue_subscribers.enqueue_event(
436+
event
437+
)
438+
self._event_queue_agent.task_done()
433439
except QueueShutDown:
434440
logger.debug(
435441
'Consumer[%s]: Event queue shut down', self._task_id
@@ -482,7 +488,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
482488
self._reference_count,
483489
)
484490

485-
tapped_queue = await self._event_queue.tap()
491+
tapped_queue = await self._event_queue_subscribers.tap()
486492
request_id = await self.enqueue_request(request) if request else None
487493

488494
try:
@@ -562,17 +568,14 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
562568
)
563569

564570
async with self._lock:
565-
# if self._producer_task and not self._producer_task.done():
566571
if not self._is_finished.is_set() and self._producer_task:
567572
logger.debug(
568573
'Cancel[%s]: Cancelling producer task', self._task_id
569574
)
570-
# We do NOT await self._agent_executor.cancel here
571-
# because it might take a while and we want to block on `_is_finished.wait()`
572575
self._producer_task.cancel()
573576
try:
574577
await self._agent_executor.cancel(
575-
request_context, self._event_queue
578+
request_context, self._event_queue_agent
576579
)
577580
except Exception as e:
578581
logger.exception(

src/a2a/server/agent_execution/active_task_registry.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from a2a.server.tasks.task_store import TaskStore
1414

1515
from a2a.server.agent_execution.active_task import ActiveTask
16-
from a2a.server.events.event_queue_v2 import EventQueueSource
1716
from a2a.server.tasks.task_manager import TaskManager
1817

1918

@@ -48,7 +47,6 @@ async def get_or_create(
4847
if task_id in self._active_tasks:
4948
return self._active_tasks[task_id]
5049

51-
event_queue = EventQueueSource()
5250
task_manager = TaskManager(
5351
task_id=task_id,
5452
context_id=context_id,
@@ -60,7 +58,6 @@ async def get_or_create(
6058
active_task = ActiveTask(
6159
agent_executor=self._agent_executor,
6260
task_id=task_id,
63-
event_queue=event_queue,
6461
task_manager=task_manager,
6562
push_sender=self._push_sender,
6663
on_cleanup=self._on_active_task_cleanup,
@@ -75,6 +72,7 @@ async def get_or_create(
7572

7673
def _on_active_task_cleanup(self, active_task: ActiveTask) -> None:
7774
"""Called by ActiveTask when it's finished and has no subscribers."""
75+
logger.debug('Active task %s cleanup scheduled', active_task.task_id)
7876
task = asyncio.create_task(self._remove_task(active_task.task_id))
7977
self._cleanup_tasks.add(task)
8078
task.add_done_callback(self._cleanup_tasks.discard)

src/a2a/server/events/event_queue_v2.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ class EventQueueSource(EventQueue):
2828
in `_incoming_queue` and distributed to all child Sinks by a background dispatcher task.
2929
"""
3030

31-
def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
31+
def __init__(
32+
self,
33+
max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE,
34+
create_default_sink: bool = True,
35+
) -> None:
3236
"""Initializes the EventQueueSource."""
3337
if max_queue_size <= 0:
3438
raise ValueError('max_queue_size must be greater than 0')
@@ -41,10 +45,15 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
4145
self._is_closed = False
4246

4347
# Internal sink for backward compatibility
44-
self._default_sink = EventQueueSink(
45-
parent=self, max_queue_size=max_queue_size
46-
)
47-
self._sinks.add(self._default_sink)
48+
self._default_sink: EventQueueSink | None
49+
if create_default_sink:
50+
self._default_sink = EventQueueSink(
51+
parent=self, max_queue_size=max_queue_size
52+
)
53+
self._sinks.add(self._default_sink)
54+
else:
55+
self._default_sink = None
56+
4857
self._dispatcher_task = asyncio.create_task(self._dispatch_loop())
4958

5059
self._dispatcher_task_expected_to_cancel = False
@@ -54,6 +63,8 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
5463
@property
5564
def queue(self) -> AsyncQueue[Event]:
5665
"""Returns the underlying asyncio.Queue of the default sink."""
66+
if self._default_sink is None:
67+
raise ValueError('No default sink available.')
5768
return self._default_sink.queue
5869

5970
async def _dispatch_loop(self) -> None:
@@ -183,10 +194,14 @@ async def enqueue_event(self, event: Event) -> None:
183194

184195
async def dequeue_event(self) -> Event:
185196
"""Dequeues an event from the default internal sink queue."""
197+
if self._default_sink is None:
198+
raise ValueError('No default sink available.')
186199
return await self._default_sink.dequeue_event()
187200

188201
def task_done(self) -> None:
189202
"""Signals that a formerly enqueued task is complete via the default internal sink queue."""
203+
if self._default_sink is None:
204+
raise ValueError('No default sink available.')
190205
self._default_sink.task_done()
191206

192207
async def close(self, immediate: bool = False) -> None:

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
RequestContextBuilder,
1212
SimpleRequestContextBuilder,
1313
)
14+
from a2a.server.agent_execution.active_task import (
15+
INTERRUPTED_TASK_STATES,
16+
TERMINAL_TASK_STATES,
17+
)
1418
from a2a.server.agent_execution.active_task_registry import ActiveTaskRegistry
1519
from a2a.server.request_handlers.request_handler import (
1620
RequestHandler,
@@ -30,7 +34,6 @@
3034
SubscribeToTaskRequest,
3135
Task,
3236
TaskPushNotificationConfig,
33-
TaskState,
3437
TaskStatusUpdateEvent,
3538
)
3639
from a2a.utils.errors import (
@@ -63,12 +66,6 @@
6366

6467
logger = logging.getLogger(__name__)
6568

66-
TERMINAL_TASK_STATES = {
67-
TaskState.TASK_STATE_COMPLETED,
68-
TaskState.TASK_STATE_CANCELED,
69-
TaskState.TASK_STATE_FAILED,
70-
TaskState.TASK_STATE_REJECTED,
71-
}
7269

7370
# TODO: cleanup context_id management
7471

@@ -241,14 +238,7 @@ async def on_message_send( # noqa: D102
241238
return task
242239

243240
try:
244-
result_states = {
245-
TaskState.TASK_STATE_COMPLETED,
246-
TaskState.TASK_STATE_FAILED,
247-
TaskState.TASK_STATE_CANCELED,
248-
TaskState.TASK_STATE_REJECTED,
249-
TaskState.TASK_STATE_INPUT_REQUIRED,
250-
TaskState.TASK_STATE_AUTH_REQUIRED,
251-
}
241+
result_states = TERMINAL_TASK_STATES | INTERRUPTED_TASK_STATES
252242

253243
result = None
254244
async for event in active_task.subscribe(request=request_context):

tests/integration/test_scenarios.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from a2a.server.request_handlers.default_request_handler import (
2121
LegacyRequestHandler,
2222
)
23-
from a2a.server.routes import CallContextBuilder
23+
from a2a.server.request_handlers import GrpcServerCallContextBuilder
2424
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
2525
from a2a.types import a2a_pb2_grpc
2626
from a2a.types.a2a_pb2 import (
@@ -89,7 +89,7 @@ def user_name(self) -> str:
8989
return 'test-user'
9090

9191

92-
class MockCallContextBuilder(CallContextBuilder):
92+
class MockCallContextBuilder(GrpcServerCallContextBuilder):
9393
def build(self, request: Any) -> ServerCallContext:
9494
return ServerCallContext(
9595
user=MockUser(), state={'headers': {'a2a-version': '1.0'}}
@@ -996,6 +996,7 @@ async def cancel(
996996
await task1
997997
else:
998998
await task1
999+
9991000
await task2
10001001

10011002
# Consume remaining events if any

tests/server/agent_execution/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)