Skip to content

Commit 3486783

Browse files
committed
DefaultRequestHandlerV2
1 parent 5d22186 commit 3486783

13 files changed

Lines changed: 4941 additions & 16 deletions

run_test_debug.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# We can just run pytest with -s to see stdout and add prints

src/a2a/server/agent_execution/active_task.py

Lines changed: 628 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import logging
5+
6+
from typing import TYPE_CHECKING
7+
8+
9+
if TYPE_CHECKING:
10+
from a2a.server.agent_execution.agent_executor import AgentExecutor
11+
from a2a.server.tasks.push_notification_sender import PushNotificationSender
12+
from a2a.server.tasks.task_store import TaskStore
13+
14+
from a2a.server.agent_execution.active_task import ActiveTask
15+
from a2a.server.context import ServerCallContext
16+
from a2a.server.events.event_queue_v2 import EventQueueSource
17+
from a2a.server.tasks.task_manager import TaskManager
18+
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class ActiveTaskRegistry:
24+
"""A registry for active ActiveTask instances."""
25+
26+
def __init__(
27+
self,
28+
agent_executor: AgentExecutor,
29+
task_store: TaskStore,
30+
push_sender: PushNotificationSender | None = None,
31+
):
32+
self._agent_executor = agent_executor
33+
self._task_store = task_store
34+
self._push_sender = push_sender
35+
self._active_tasks: dict[str, ActiveTask] = {}
36+
self._lock = asyncio.Lock()
37+
self._cleanup_tasks: set[asyncio.Task[None]] = set()
38+
39+
async def get_or_create(
40+
self,
41+
task_id: str,
42+
call_context: ServerCallContext,
43+
context_id: str | None = None,
44+
create_task_if_missing: bool = False,
45+
) -> ActiveTask:
46+
"""Retrieves an existing ActiveTask or creates a new one."""
47+
async with self._lock:
48+
if task_id in self._active_tasks:
49+
return self._active_tasks[task_id]
50+
51+
event_queue = EventQueueSource()
52+
task_manager = TaskManager(
53+
task_id=task_id,
54+
context_id=context_id,
55+
task_store=self._task_store,
56+
initial_message=None,
57+
context=call_context,
58+
)
59+
60+
active_task = ActiveTask(
61+
agent_executor=self._agent_executor,
62+
task_id=task_id,
63+
event_queue=event_queue,
64+
task_manager=task_manager,
65+
push_sender=self._push_sender,
66+
on_cleanup=self._on_active_task_cleanup,
67+
)
68+
self._active_tasks[task_id] = active_task
69+
70+
await active_task.start(
71+
call_context=call_context,
72+
create_task_if_missing=create_task_if_missing,
73+
)
74+
return active_task
75+
76+
def _on_active_task_cleanup(self, active_task: ActiveTask) -> None:
77+
"""Called by ActiveTask when it's finished and has no subscribers."""
78+
task = asyncio.create_task(self._remove_task(active_task.task_id))
79+
self._cleanup_tasks.add(task)
80+
task.add_done_callback(self._cleanup_tasks.discard)
81+
82+
async def _remove_task(self, task_id: str) -> None:
83+
async with self._lock:
84+
self._active_tasks.pop(task_id, None)
85+
logger.debug('Removed active task for %s from registry', task_id)
86+
87+
async def get(self, task_id: str) -> ActiveTask | None:
88+
"""Retrieves an existing task."""
89+
async with self._lock:
90+
return self._active_tasks.get(task_id)

src/a2a/server/agent_execution/agent_executor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from a2a.server.agent_execution.context import RequestContext
4-
from a2a.server.events.event_queue import EventQueue
4+
from a2a.server.events.event_queue_v2 import EventQueue
55

66

77
class AgentExecutor(ABC):
@@ -23,6 +23,18 @@ async def execute(
2323
return once the agent's execution for this request is complete or
2424
yields control (e.g., enters an input-required state).
2525
26+
TODO: Document request lifecycle and AgentExecutor responsibilities:
27+
- Should not close the event_queue.
28+
- Guarantee single execution per request (no concurrent execution).
29+
- Throwing exception will result in TaskState.TASK_STATE_ERROR (CHECK!)
30+
- Once call is completed it should not access context or event_queue
31+
- Before completing the call it SHOULD update task status to terminal or interrupted state.
32+
- Explain AUTH_REQUIRED workflow.
33+
- Explain INPUT_REQUIRED workflow.
34+
- Explain how cancelation work (executor task will be canceled, cancel() is called, order of calls, etc)
35+
- Explain if execute can wait for cancel and if cancel can wait for execute.
36+
- Explain behaviour of streaming / not-immediate when execute() returns in active state.
37+
2638
Args:
2739
context: The request context containing the message, task ID, etc.
2840
event_queue: The queue to publish events to.
@@ -38,6 +50,10 @@ async def cancel(
3850
in the context and publish a `TaskStatusUpdateEvent` with state
3951
`TaskState.TASK_STATE_CANCELED` to the `event_queue`.
4052
53+
TODO: Document cancelation workflow.
54+
- What if TaskState.TASK_STATE_CANCELED is not set by cancel() ?
55+
- How it can interact with execute() ?
56+
4157
Args:
4258
context: The request context containing the task ID to cancel.
4359
event_queue: The queue to publish the cancellation status update to.

src/a2a/server/agent_execution/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def current_task(self) -> Task | None:
120120
return self._current_task
121121

122122
@current_task.setter
123-
def current_task(self, task: Task) -> None:
123+
def current_task(self, task: Task | None) -> None:
124124
"""Sets the current task object."""
125125
self._current_task = task
126126

src/a2a/server/request_handlers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import logging
44

55
from a2a.server.request_handlers.default_request_handler import (
6-
DefaultRequestHandler,
6+
LegacyRequestHandler,
7+
)
8+
from a2a.server.request_handlers.default_request_handler_v2 import (
9+
DefaultRequestHandlerV2,
710
)
811
from a2a.server.request_handlers.request_handler import (
912
RequestHandler,
@@ -38,9 +41,13 @@ def __init__(self, *args, **kwargs):
3841
) from _original_error
3942

4043

44+
DefaultRequestHandler = DefaultRequestHandlerV2
45+
4146
__all__ = [
4247
'DefaultRequestHandler',
48+
'DefaultRequestHandlerV2',
4349
'GrpcHandler',
50+
'LegacyRequestHandler',
4451
'RequestHandler',
4552
'build_error_response',
4653
'prepare_response_object',

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474

7575

7676
@trace_class(kind=SpanKind.SERVER)
77-
class DefaultRequestHandler(RequestHandler):
77+
class LegacyRequestHandler(RequestHandler):
7878
"""Default request handler for all incoming requests.
7979
8080
This handler provides default implementations for all A2A JSON-RPC methods,

0 commit comments

Comments
 (0)