Skip to content

Commit bd8917f

Browse files
committed
DefaultRequestHandlerV2: Unification of on_message methods.
1 parent a669521 commit bd8917f

7 files changed

Lines changed: 354 additions & 215 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 133 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import uuid
77

8-
from typing import TYPE_CHECKING, cast
8+
from typing import TYPE_CHECKING, Any, cast
99

1010
from a2a.server.agent_execution.context import RequestContext
1111

@@ -56,6 +56,12 @@
5656
}
5757

5858

59+
class _RequestStarted:
60+
def __init__(self, request_id: uuid.UUID, request_context: RequestContext):
61+
self.request_id = request_id
62+
self.request_context = request_context
63+
64+
5965
class _RequestCompleted:
6066
def __init__(self, request_id: uuid.UUID):
6167
self.request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205
logger.debug('TASK (start): %s', task)
200206

201207
if task:
208+
self._task_created.set()
202209
if task.status.state in TERMINAL_TASK_STATES:
203210
raise InvalidParamsError(
204211
message=f'Task {task.id} is in terminal state: {task.status.state}'
205212
)
206-
else:
207-
if not create_task_if_missing:
208-
raise TaskNotFoundError
209-
210-
# New task. Create and save it so it's not "missing" if queried immediately
211-
# (especially important for return_immediately=True)
212-
if self._task_manager.context_id is None:
213-
raise ValueError('Context ID is required for new tasks')
214-
task = self._task_manager._init_task_obj(
215-
self._task_id,
216-
self._task_manager.context_id,
217-
)
218-
await self._task_manager.save_task_event(task)
219-
if self._push_sender:
220-
await self._push_sender.send_notification(task.id, task)
213+
elif not create_task_if_missing:
214+
raise TaskNotFoundError
221215

222216
except Exception:
223217
logger.debug(
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247
Runs as a detached asyncio.Task. Safe to cancel.
254248
"""
255249
logger.debug('Producer[%s]: Started', self._task_id)
250+
request_context = None
256251
try:
257-
active = True
258-
while active:
252+
while True:
259253
(
260254
request_context,
261255
request_id,
262256
) = await self._request_queue.get()
263257
await self._request_lock.acquire()
264258
# TODO: Should we create task manager every time?
265259
self._task_manager._call_context = request_context.call_context
260+
266261
request_context.current_task = (
267262
await self._task_manager.get_task()
268263
)
269264

270-
message = request_context.message
271-
if message:
272-
request_context.current_task = (
273-
self._task_manager.update_with_message(
274-
message,
275-
cast('Task', request_context.current_task),
276-
)
277-
)
278-
await self._task_manager.save_task_event(
279-
request_context.current_task
280-
)
281-
self._task_created.set()
282265
logger.debug(
283266
'Producer[%s]: Executing agent task %s',
284267
self._task_id,
285268
request_context.current_task,
286269
)
287270

288271
try:
272+
await self._event_queue_agent.enqueue_event(
273+
cast(
274+
'Event',
275+
_RequestStarted(request_id, request_context),
276+
)
277+
)
278+
289279
await self._agent_executor.execute(
290280
request_context, self._event_queue_agent
291281
)
292282
logger.debug(
293283
'Producer[%s]: Execution finished successfully',
294284
self._task_id,
295285
)
296-
except QueueShutDown:
297-
logger.debug(
298-
'Producer[%s]: Request queue shut down', self._task_id
299-
)
300-
raise
301-
except asyncio.CancelledError:
302-
logger.debug('Producer[%s]: Cancelled', self._task_id)
303-
raise
304-
except Exception as e:
305-
logger.exception(
306-
'Producer[%s]: Execution failed',
307-
self._task_id,
308-
)
309-
async with self._lock:
310-
await self._mark_task_as_failed(e)
311-
active = False
312286
finally:
313287
logger.debug(
314288
'Producer[%s]: Enqueuing request completed event',
315289
self._task_id,
316290
)
317-
# TODO: Hide from external consumers
318291
await self._event_queue_agent.enqueue_event(
319292
cast('Event', _RequestCompleted(request_id))
320293
)
321294
self._request_queue.task_done()
295+
except asyncio.CancelledError:
296+
logger.debug('Producer[%s]: Cancelled', self._task_id)
297+
298+
except QueueShutDown:
299+
logger.debug('Producer[%s]: Queue shut down', self._task_id)
300+
301+
except Exception as e:
302+
logger.exception(
303+
'Producer[%s]: Execution failed',
304+
self._task_id,
305+
)
306+
# Create task and mark as failed.
307+
if request_context:
308+
await self._task_manager.ensure_task_id(
309+
self._task_id,
310+
request_context.context_id or '',
311+
)
312+
self._task_created.set()
313+
async with self._lock:
314+
await self._mark_task_as_failed(e)
315+
322316
finally:
323317
self._request_queue.shutdown(immediate=True)
324318
await self._event_queue_agent.close(immediate=False)
@@ -338,6 +332,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332
`_is_finished`, unblocking all global subscribers and wait() calls.
339333
"""
340334
logger.debug('Consumer[%s]: Started', self._task_id)
335+
task_mode = None
336+
message_to_save = None
341337
try:
342338
try:
343339
try:
@@ -347,6 +343,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347343
'Consumer[%s]: Waiting for event',
348344
self._task_id,
349345
)
346+
new_task = None
350347
event = await self._event_queue_agent.dequeue_event()
351348
logger.debug(
352349
'Consumer[%s]: Dequeued event %s',
@@ -361,15 +358,60 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361358
self._task_id,
362359
)
363360
self._request_lock.release()
361+
elif isinstance(event, _RequestStarted):
362+
logger.debug(
363+
'Consumer[%s]: Request started',
364+
self._task_id,
365+
)
366+
message_to_save = event.request_context.message
367+
364368
elif isinstance(event, Message):
369+
if task_mode is not None:
370+
if task_mode:
371+
logger.error(
372+
'Received Message() object in task mode.'
373+
)
374+
else:
375+
logger.error(
376+
'Multiple Message() objects received.'
377+
)
378+
task_mode = False
365379
logger.debug(
366380
'Consumer[%s]: Setting result to Message: %s',
367381
self._task_id,
368382
event,
369383
)
370384
else:
385+
if task_mode is False:
386+
logger.error(
387+
'Received %s in message mode.',
388+
type(event).__name__,
389+
)
390+
391+
new_task = (
392+
await self._task_manager.ensure_task_id(
393+
self._task_id,
394+
event.context_id,
395+
)
396+
)
397+
398+
if message_to_save is not None:
399+
new_task = (
400+
self._task_manager.update_with_message(
401+
message_to_save,
402+
new_task,
403+
)
404+
)
405+
await self._task_manager.save_task_event(
406+
new_task
407+
)
408+
message_to_save = None
409+
410+
self._task_created.set()
411+
412+
task_mode = True
371413
# Save structural events (like TaskStatusUpdate) to DB.
372-
# TODO: Create task manager every time ?
414+
373415
self._task_manager.context_id = event.context_id
374416
await self._task_manager.process(event)
375417

@@ -432,8 +474,19 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432474
self._task_id, event
433475
)
434476
finally:
477+
if new_task is not None:
478+
new_task_copy = Task()
479+
new_task_copy.CopyFrom(new_task)
480+
new_task = new_task_copy
481+
482+
logger.debug(
483+
'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n',
484+
self._task_id,
485+
event,
486+
new_task,
487+
)
435488
await self._event_queue_subscribers.enqueue_event(
436-
event
489+
cast('Any', (event, new_task))
437490
)
438491
self._event_queue_agent.task_done()
439492
except QueueShutDown:
@@ -459,6 +512,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459512
*,
460513
request: RequestContext | None = None,
461514
include_initial_task: bool = False,
515+
replace_status_update_with_task: bool = False,
462516
) -> AsyncGenerator[Event, None]:
463517
"""Creates a queue tap and yields events as they are produced.
464518
@@ -506,9 +560,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506560

507561
# Wait for next event or task completion
508562
try:
509-
event = await asyncio.wait_for(
563+
dequeued = await asyncio.wait_for(
510564
tapped_queue.dequeue_event(), timeout=0.1
511565
)
566+
event, updated_task = cast('Any', dequeued)
567+
logger.debug(
568+
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
569+
self._task_id,
570+
event,
571+
updated_task,
572+
)
573+
if replace_status_update_with_task and isinstance(
574+
event, TaskStatusUpdateEvent
575+
):
576+
logger.debug(
577+
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
578+
self._task_id,
579+
updated_task,
580+
)
581+
event = updated_task
512582
if self._exception:
513583
raise self._exception from None
514584
if isinstance(event, _RequestCompleted):
@@ -522,6 +592,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522592
)
523593
return
524594
continue
595+
elif isinstance(event, _RequestStarted):
596+
logger.debug(
597+
'Subscriber[%s]: Request started',
598+
self._task_id,
599+
)
600+
continue
525601
except (asyncio.TimeoutError, TimeoutError):
526602
if self._is_finished.is_set():
527603
if self._exception:
@@ -545,7 +621,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545621
# Evaluate if this was the last subscriber on a finished task.
546622
await self._maybe_cleanup()
547623

548-
async def cancel(self, call_context: ServerCallContext) -> Task | Message:
624+
async def cancel(self, call_context: ServerCallContext) -> Task:
549625
"""Cancels the running active task.
550626
551627
Concurrency Guarantee:
@@ -558,11 +634,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558634
# TODO: Conflicts with call_context on the pending request.
559635
self._task_manager._call_context = call_context
560636

561-
task = await self.get_task()
637+
task = await self._task_manager.get_task()
562638
request_context = RequestContext(
563639
call_context=call_context,
564640
task_id=self._task_id,
565-
context_id=task.context_id,
641+
context_id=task.context_id if task else None,
566642
task=task,
567643
)
568644

@@ -591,7 +667,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591667
)
592668

593669
await self._is_finished.wait()
594-
return await self.get_task()
670+
task = await self._task_manager.get_task()
671+
if not task:
672+
raise RuntimeError('Task should have been created')
673+
return task
595674

596675
async def _maybe_cleanup(self) -> None:
597676
"""Triggers cleanup if task is finished and has no subscribers.

src/a2a/server/agent_execution/agent_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ async def execute(
3434
- Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc)
3535
- Explain if execute can wait for cancel and if cancel can wait for execute.
3636
- Explain behaviour of streaming / not-immediate when execute() returns in active state.
37+
- Possible workflows:
38+
- Enqueue a SINGLE Message object
39+
- Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent.
3740
3841
Args:
3942
context: The request context containing the message, task ID, etc.

0 commit comments

Comments
 (0)