Skip to content

Commit b5c3e78

Browse files
committed
DefaultRequestHandlerV2: Unification of on_message methods.
1 parent a669521 commit b5c3e78

7 files changed

Lines changed: 369 additions & 215 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 149 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import uuid
77

8-
from typing import TYPE_CHECKING, cast
8+
from typing import TYPE_CHECKING, Any, cast
99

1010
from a2a.server.agent_execution.context import RequestContext
1111

@@ -56,6 +56,12 @@
5656
}
5757

5858

59+
class _RequestStarted:
60+
def __init__(self, request_id: uuid.UUID, request_context: RequestContext):
61+
self.request_id = request_id
62+
self.request_context = request_context
63+
64+
5965
class _RequestCompleted:
6066
def __init__(self, request_id: uuid.UUID):
6167
self.request_id = request_id
@@ -199,25 +205,13 @@ async def start(
199205
logger.debug('TASK (start): %s', task)
200206

201207
if task:
208+
self._task_created.set()
202209
if task.status.state in TERMINAL_TASK_STATES:
203210
raise InvalidParamsError(
204211
message=f'Task {task.id} is in terminal state: {task.status.state}'
205212
)
206-
else:
207-
if not create_task_if_missing:
208-
raise TaskNotFoundError
209-
210-
# New task. Create and save it so it's not "missing" if queried immediately
211-
# (especially important for return_immediately=True)
212-
if self._task_manager.context_id is None:
213-
raise ValueError('Context ID is required for new tasks')
214-
task = self._task_manager._init_task_obj(
215-
self._task_id,
216-
self._task_manager.context_id,
217-
)
218-
await self._task_manager.save_task_event(task)
219-
if self._push_sender:
220-
await self._push_sender.send_notification(task.id, task)
213+
elif not create_task_if_missing:
214+
raise TaskNotFoundError
221215

222216
except Exception:
223217
logger.debug(
@@ -253,72 +247,72 @@ async def _run_producer(self) -> None:
253247
Runs as a detached asyncio.Task. Safe to cancel.
254248
"""
255249
logger.debug('Producer[%s]: Started', self._task_id)
250+
request_context = None
256251
try:
257-
active = True
258-
while active:
252+
while True:
259253
(
260254
request_context,
261255
request_id,
262256
) = await self._request_queue.get()
263257
await self._request_lock.acquire()
264258
# TODO: Should we create task manager every time?
265259
self._task_manager._call_context = request_context.call_context
260+
266261
request_context.current_task = (
267262
await self._task_manager.get_task()
268263
)
269264

270-
message = request_context.message
271-
if message:
272-
request_context.current_task = (
273-
self._task_manager.update_with_message(
274-
message,
275-
cast('Task', request_context.current_task),
276-
)
277-
)
278-
await self._task_manager.save_task_event(
279-
request_context.current_task
280-
)
281-
self._task_created.set()
282265
logger.debug(
283266
'Producer[%s]: Executing agent task %s',
284267
self._task_id,
285268
request_context.current_task,
286269
)
287270

288271
try:
272+
await self._event_queue_agent.enqueue_event(
273+
cast(
274+
'Event',
275+
_RequestStarted(request_id, request_context),
276+
)
277+
)
278+
289279
await self._agent_executor.execute(
290280
request_context, self._event_queue_agent
291281
)
292282
logger.debug(
293283
'Producer[%s]: Execution finished successfully',
294284
self._task_id,
295285
)
296-
except QueueShutDown:
297-
logger.debug(
298-
'Producer[%s]: Request queue shut down', self._task_id
299-
)
300-
raise
301-
except asyncio.CancelledError:
302-
logger.debug('Producer[%s]: Cancelled', self._task_id)
303-
raise
304-
except Exception as e:
305-
logger.exception(
306-
'Producer[%s]: Execution failed',
307-
self._task_id,
308-
)
309-
async with self._lock:
310-
await self._mark_task_as_failed(e)
311-
active = False
312286
finally:
313287
logger.debug(
314288
'Producer[%s]: Enqueuing request completed event',
315289
self._task_id,
316290
)
317-
# TODO: Hide from external consumers
318291
await self._event_queue_agent.enqueue_event(
319292
cast('Event', _RequestCompleted(request_id))
320293
)
321294
self._request_queue.task_done()
295+
except asyncio.CancelledError:
296+
logger.debug('Producer[%s]: Cancelled', self._task_id)
297+
298+
except QueueShutDown:
299+
logger.debug('Producer[%s]: Queue shut down', self._task_id)
300+
301+
except Exception as e:
302+
logger.exception(
303+
'Producer[%s]: Execution failed',
304+
self._task_id,
305+
)
306+
# Create task and mark as failed.
307+
if request_context:
308+
await self._task_manager.ensure_task_id(
309+
self._task_id,
310+
request_context.context_id or '',
311+
)
312+
self._task_created.set()
313+
async with self._lock:
314+
await self._mark_task_as_failed(e)
315+
322316
finally:
323317
self._request_queue.shutdown(immediate=True)
324318
await self._event_queue_agent.close(immediate=False)
@@ -338,6 +332,10 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
338332
`_is_finished`, unblocking all global subscribers and wait() calls.
339333
"""
340334
logger.debug('Consumer[%s]: Started', self._task_id)
335+
task_mode = None
336+
message_to_save = None
337+
# TODO: Make helper methods
338+
# TODO: Support Task enqueue
341339
try:
342340
try:
343341
try:
@@ -347,6 +345,7 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
347345
'Consumer[%s]: Waiting for event',
348346
self._task_id,
349347
)
348+
new_task = None
350349
event = await self._event_queue_agent.dequeue_event()
351350
logger.debug(
352351
'Consumer[%s]: Dequeued event %s',
@@ -361,24 +360,78 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
361360
self._task_id,
362361
)
363362
self._request_lock.release()
363+
elif isinstance(event, _RequestStarted):
364+
logger.debug(
365+
'Consumer[%s]: Request started',
366+
self._task_id,
367+
)
368+
message_to_save = event.request_context.message
369+
364370
elif isinstance(event, Message):
371+
if task_mode is not None:
372+
if task_mode:
373+
logger.error(
374+
'Received Message() object in task mode.'
375+
)
376+
else:
377+
logger.error(
378+
'Multiple Message() objects received.'
379+
)
380+
task_mode = False
365381
logger.debug(
366382
'Consumer[%s]: Setting result to Message: %s',
367383
self._task_id,
368384
event,
369385
)
370386
else:
387+
if task_mode is False:
388+
logger.error(
389+
'Received %s in message mode.',
390+
type(event).__name__,
391+
)
392+
393+
if isinstance(event, Task):
394+
new_task = event
395+
await self._task_manager.save_task_event(
396+
new_task
397+
)
398+
else:
399+
new_task = (
400+
await self._task_manager.ensure_task_id(
401+
self._task_id,
402+
event.context_id,
403+
)
404+
)
405+
406+
if message_to_save is not None:
407+
new_task = self._task_manager.update_with_message(
408+
message_to_save,
409+
new_task,
410+
)
411+
await (
412+
self._task_manager.save_task_event(
413+
new_task
414+
)
415+
)
416+
message_to_save = None
417+
418+
task_mode = True
371419
# Save structural events (like TaskStatusUpdate) to DB.
372-
# TODO: Create task manager every time ?
420+
373421
self._task_manager.context_id = event.context_id
374-
await self._task_manager.process(event)
422+
if not isinstance(event, Task):
423+
await self._task_manager.process(event)
424+
425+
self._task_created.set()
375426

376427
# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
377428
new_task = await self._task_manager.get_task()
378429
if new_task is None:
379430
raise RuntimeError(
380431
f'Task {self.task_id} not found'
381432
)
433+
if isinstance(event, Task):
434+
event = new_task
382435
is_interrupted = (
383436
new_task.status.state
384437
in INTERRUPTED_TASK_STATES
@@ -432,8 +485,23 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
432485
self._task_id, event
433486
)
434487
finally:
488+
if new_task is not None:
489+
new_task_copy = Task()
490+
new_task_copy.CopyFrom(new_task)
491+
new_task = new_task_copy
492+
if isinstance(event, Task):
493+
new_task_copy = Task()
494+
new_task_copy.CopyFrom(event)
495+
event = new_task_copy
496+
497+
logger.debug(
498+
'Consumer[%s]: Enqueuing\nEvent: %s\nNew Task: %s\n',
499+
self._task_id,
500+
event,
501+
new_task,
502+
)
435503
await self._event_queue_subscribers.enqueue_event(
436-
event
504+
cast('Any', (event, new_task))
437505
)
438506
self._event_queue_agent.task_done()
439507
except QueueShutDown:
@@ -459,6 +527,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
459527
*,
460528
request: RequestContext | None = None,
461529
include_initial_task: bool = False,
530+
replace_status_update_with_task: bool = False,
462531
) -> AsyncGenerator[Event, None]:
463532
"""Creates a queue tap and yields events as they are produced.
464533
@@ -506,9 +575,25 @@ async def subscribe( # noqa: PLR0912, PLR0915
506575

507576
# Wait for next event or task completion
508577
try:
509-
event = await asyncio.wait_for(
578+
dequeued = await asyncio.wait_for(
510579
tapped_queue.dequeue_event(), timeout=0.1
511580
)
581+
event, updated_task = cast('Any', dequeued)
582+
logger.debug(
583+
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
584+
self._task_id,
585+
event,
586+
updated_task,
587+
)
588+
if replace_status_update_with_task and isinstance(
589+
event, TaskStatusUpdateEvent
590+
):
591+
logger.debug(
592+
'Subscriber[%s]: Replacing TaskStatusUpdateEvent with Task: %s',
593+
self._task_id,
594+
updated_task,
595+
)
596+
event = updated_task
512597
if self._exception:
513598
raise self._exception from None
514599
if isinstance(event, _RequestCompleted):
@@ -522,6 +607,12 @@ async def subscribe( # noqa: PLR0912, PLR0915
522607
)
523608
return
524609
continue
610+
elif isinstance(event, _RequestStarted):
611+
logger.debug(
612+
'Subscriber[%s]: Request started',
613+
self._task_id,
614+
)
615+
continue
525616
except (asyncio.TimeoutError, TimeoutError):
526617
if self._is_finished.is_set():
527618
if self._exception:
@@ -545,7 +636,7 @@ async def subscribe( # noqa: PLR0912, PLR0915
545636
# Evaluate if this was the last subscriber on a finished task.
546637
await self._maybe_cleanup()
547638

548-
async def cancel(self, call_context: ServerCallContext) -> Task | Message:
639+
async def cancel(self, call_context: ServerCallContext) -> Task:
549640
"""Cancels the running active task.
550641
551642
Concurrency Guarantee:
@@ -558,11 +649,11 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
558649
# TODO: Conflicts with call_context on the pending request.
559650
self._task_manager._call_context = call_context
560651

561-
task = await self.get_task()
652+
task = await self._task_manager.get_task()
562653
request_context = RequestContext(
563654
call_context=call_context,
564655
task_id=self._task_id,
565-
context_id=task.context_id,
656+
context_id=task.context_id if task else None,
566657
task=task,
567658
)
568659

@@ -591,7 +682,10 @@ async def cancel(self, call_context: ServerCallContext) -> Task | Message:
591682
)
592683

593684
await self._is_finished.wait()
594-
return await self.get_task()
685+
task = await self._task_manager.get_task()
686+
if not task:
687+
raise RuntimeError('Task should have been created')
688+
return task
595689

596690
async def _maybe_cleanup(self) -> None:
597691
"""Triggers cleanup if task is finished and has no subscribers.

src/a2a/server/agent_execution/agent_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ async def execute(
3434
- Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc)
3535
- Explain if execute can wait for cancel and if cancel can wait for execute.
3636
- Explain behaviour of streaming / not-immediate when execute() returns in active state.
37+
- Possible workflows:
38+
- Enqueue a SINGLE Message object
39+
- Enqueue TaskStatusUpdateEvent (TASK_STATE_SUBMITTED or TASK_STATE_REJECTED) and continue with TaskStatusUpdateEvent / TaskArtifactUpdateEvent.
3740
3841
Args:
3942
context: The request context containing the message, task ID, etc.

0 commit comments

Comments
 (0)