66Root cause of vulnerability:
77 _setup_message_execution() uses params.message.context_id directly without
88 any ownership check. An attacker who knows a victim's contextId can send a
9- new task under that context — task_manager.get_task() returns None for the
9+ new task under that context -- task_manager.get_task() returns None for the
1010 new task_id, so the original task-level check is never reached.
1111
1212Fix design:
13- DefaultRequestHandler maintains a _context_owners dict (context_id → owner)
13+ DefaultRequestHandler maintains a _context_owners dict (context_id -> owner)
1414 in memory. When a get_caller_id extractor is configured:
1515 1. On first message for a context_id: record caller as owner.
1616 2. On subsequent messages for same context_id: verify caller matches owner.
17- If get_caller_id is None (default): no ownership tracking — backward compatible.
17+ If get_caller_id is None (default): no ownership tracking -- backward compatible.
1818
1919Target file: src/a2a/server/request_handlers/default_request_handler.py
2020"""
@@ -118,7 +118,7 @@ def __init__( # noqa: PLR0913
118118 fingerprint). When provided, the handler tracks which caller
119119 created each contextId and rejects messages from different
120120 callers attempting to join that context (A2A-INJ-01 fix).
121- If None (default), no ownership tracking is performed —
121+ If None (default), no ownership tracking is performed --
122122 backward compatible with existing deployments.
123123
124124 Example::
@@ -147,8 +147,15 @@ def get_caller_id(ctx: ServerCallContext | None) -> str | None:
147147 )
148148 # ---- NEW (fix for A2A-INJ-01) ----
149149 self ._get_caller_id : CallerIdExtractor | None = get_caller_id
150- # Maps context_id → owner identity; populated on first message per context.
150+ # Maps context_id -> owner identity; populated on first message per context.
151151 self ._context_owners : dict [str , str ] = {}
152+ if get_caller_id is None :
153+ logger .warning (
154+ 'DefaultRequestHandler initialized without get_caller_id: '
155+ 'context ownership is not enforced. Cross-user context injection '
156+ '(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id '
157+ 'extractor to enable ownership checks.'
158+ )
152159 # ----------------------------------
153160 self ._running_agents = {}
154161 self ._running_agents_lock = asyncio .Lock ()
@@ -168,7 +175,10 @@ async def on_get_task(
168175 async def on_cancel_task (
169176 self , params : TaskIdParams , context : ServerCallContext | None = None
170177 ) -> Task | None :
171- """Default handler for 'tasks/cancel'."""
178+ """Default handler for 'tasks/cancel'.
179+
180+ Attempts to cancel the task managed by the `AgentExecutor`.
181+ """
172182 task : Task | None = await self .task_store .get (params .id , context )
173183 if not task :
174184 raise ServerError (error = TaskNotFoundError ())
@@ -225,6 +235,12 @@ async def on_cancel_task(
225235 async def _run_event_stream (
226236 self , request : RequestContext , queue : EventQueue
227237 ) -> None :
238+ """Runs the agent's `execute` method and closes the queue afterwards.
239+
240+ Args:
241+ request: The request context for the agent.
242+ queue: The event queue for the agent to publish to.
243+ """
228244 await self .agent_executor .execute (request , queue )
229245 await queue .close ()
230246
@@ -236,32 +252,13 @@ def _check_context_ownership(
236252 """Enforce context ownership when get_caller_id is configured.
237253
238254 Called before any message is processed for an existing context_id.
255+ Only invoked when context_id is already present in _context_owners,
256+ which guarantees _get_caller_id is not None and owner is not None.
239257 Raises ServerError(InvalidParamsError) if the caller does not own
240258 the context.
241259 """
242- if self ._get_caller_id is None :
243- # Ownership tracking not configured — log warning and allow.
244- # Operators should configure get_caller_id in production.
245- logger .warning (
246- 'Context ownership not enforced for context_id=%s: '
247- 'no get_caller_id configured on DefaultRequestHandler. '
248- 'This allows cross-user context injection (A2A-INJ-01 / CWE-639). '
249- 'Provide a get_caller_id extractor to enable ownership checks.' ,
250- context_id ,
251- )
252- return
253-
254- caller = self ._get_caller_id (context )
255- owner = self ._context_owners .get (context_id )
256-
257- if owner is None :
258- # Context exists in the store but ownership was not recorded
259- # (e.g. created before this patch was deployed). Skip check.
260- logger .debug (
261- 'context_id=%s has no recorded owner; skipping ownership check.' ,
262- context_id ,
263- )
264- return
260+ caller = self ._get_caller_id (context ) # type: ignore[misc]
261+ owner = self ._context_owners [context_id ]
265262
266263 if caller is None :
267264 raise ServerError (
@@ -308,10 +305,10 @@ async def _setup_message_execution(
308305 ) -> tuple [TaskManager , str , EventQueue , ResultAggregator , asyncio .Task ]:
309306 context_id = params .message .context_id
310307
311- # ---- FIX: A2A-INJ-01 — enforce context ownership BEFORE task lookup ----
308+ # ---- FIX: A2A-INJ-01 -- enforce context ownership BEFORE task lookup ----
312309 # The check must happen at context_id level, not task level. An attacker
313310 # who sends a new task_id under an existing context_id would otherwise
314- # bypass a task-level check (get_task() returns None → check never runs).
311+ # bypass a task-level check (get_task() returns None -> check never runs).
315312 if context_id and context_id in self ._context_owners :
316313 self ._check_context_ownership (context_id , context )
317314 # -----------------------------------------------------------------------
@@ -396,7 +393,11 @@ async def on_message_send(
396393 params : MessageSendParams ,
397394 context : ServerCallContext | None = None ,
398395 ) -> Message | Task :
399- """Default handler for 'message/send' (non-streaming)."""
396+ """Default handler for 'message/send' interface (non-streaming).
397+
398+ Starts the agent execution for the message and waits for the final
399+ result (Task or Message).
400+ """
400401 (
401402 _task_manager ,
402403 task_id ,
@@ -461,7 +462,11 @@ async def on_message_send_stream(
461462 params : MessageSendParams ,
462463 context : ServerCallContext | None = None ,
463464 ) -> AsyncGenerator [Event ]:
464- """Default handler for 'message/stream' (streaming)."""
465+ """Default handler for 'message/stream' (streaming).
466+
467+ Starts the agent execution and yields events as they are produced
468+ by the agent.
469+ """
465470 (
466471 _task_manager ,
467472 task_id ,
0 commit comments