diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index ecdc0cfef..dbc196fc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -71,9 +71,6 @@ logger = logging.getLogger(__name__) -# TODO: cleanup context_id management - - @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandlerV2(RequestHandler): """Default request handler for all incoming requests.""" diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index 15b0e3a20..bf1c6d3af 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -2158,3 +2158,76 @@ async def cancel( ): async for _ in it: pass + + +# Scenario: Task context_id and task_id visibility +@pytest.mark.timeout(2.0) +@pytest.mark.asyncio +@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy']) +@pytest.mark.parametrize( + 'streaming', [False, True], ids=['blocking', 'streaming'] +) +async def test_scenario_context_id_visibility(use_legacy, streaming): + seen_contexts = [] + + class ContextIdAgent(AgentExecutor): + async def execute( + self, context: RequestContext, event_queue: EventQueue + ): + seen_contexts.append((context.task_id, context.context_id)) + task = new_task_from_user_message(context.message) + task.status.state = TaskState.TASK_STATE_COMPLETED + await event_queue.enqueue_event(task) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ): + pass + + handler = create_handler(ContextIdAgent(), use_legacy) + client = await create_client( + handler, agent_card=agent_card(), streaming=streaming + ) + + # 1. Create a new task (empty context_id, empty task_id) + msg1 = Message( + message_id='test-msg-1', + role=Role.ROLE_USER, + parts=[Part(text='start task 1')], + ) + it1 = client.send_message(SendMessageRequest(message=msg1)) + events1 = [event async for event in it1] + + agent_task1_id, agent_context1_id = seen_contexts[0] + client_task1_id = get_task_id(events1[-1]) + client_context1_id = get_task_context_id(events1[-1]) + + # Verify that agent can see non-empty context_id and task_id + assert agent_task1_id != '' + assert agent_context1_id != '' + + # Verify for task1 that context_id and task_id visible by agent and client are the same. + assert agent_task1_id == client_task1_id + assert agent_context1_id == client_context1_id + + # 2. Create a new task with context_id from task_1: task_2 + msg2 = Message( + message_id='test-msg-2', + context_id=client_context1_id, + role=Role.ROLE_USER, + parts=[Part(text='start task 2')], + ) + it2 = client.send_message(SendMessageRequest(message=msg2)) + events2 = [event async for event in it2] + + agent_task2_id, agent_context2_id = seen_contexts[1] + client_task2_id = get_task_id(events2[-1]) + client_context2_id = get_task_context_id(events2[-1]) + + # Verify for task2 that context_id and task_id visible by agent and client are the same. + assert agent_task2_id == client_task2_id + assert agent_context2_id == client_context2_id + + # Verify that agent can see context_id the same as in task_1 and different task id + assert agent_context2_id == agent_context1_id + assert agent_task2_id != agent_task1_id