Skip to content

Commit a9a4a1d

Browse files
committed
Add integration tests for context id preservation.
1 parent c0c6c08 commit a9a4a1d

2 files changed

Lines changed: 73 additions & 3 deletions

File tree

src/a2a/server/request_handlers/default_request_handler_v2.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@
7171
logger = logging.getLogger(__name__)
7272

7373

74-
# TODO: cleanup context_id management
75-
76-
7774
@trace_class(kind=SpanKind.SERVER)
7875
class DefaultRequestHandlerV2(RequestHandler):
7976
"""Default request handler for all incoming requests."""

tests/integration/test_scenarios.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,3 +2158,76 @@ async def cancel(
21582158
):
21592159
async for _ in it:
21602160
pass
2161+
2162+
2163+
# Scenario: Task context_id and task_id visibility
2164+
@pytest.mark.timeout(2.0)
2165+
@pytest.mark.asyncio
2166+
@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy'])
2167+
@pytest.mark.parametrize(
2168+
'streaming', [False, True], ids=['blocking', 'streaming']
2169+
)
2170+
async def test_scenario_context_id_visibility(use_legacy, streaming):
2171+
seen_contexts = []
2172+
2173+
class ContextIdAgent(AgentExecutor):
2174+
async def execute(
2175+
self, context: RequestContext, event_queue: EventQueue
2176+
):
2177+
seen_contexts.append((context.task_id, context.context_id))
2178+
task = new_task_from_user_message(context.message)
2179+
task.status.state = TaskState.TASK_STATE_COMPLETED
2180+
await event_queue.enqueue_event(task)
2181+
2182+
async def cancel(
2183+
self, context: RequestContext, event_queue: EventQueue
2184+
):
2185+
pass
2186+
2187+
handler = create_handler(ContextIdAgent(), use_legacy)
2188+
client = await create_client(
2189+
handler, agent_card=agent_card(), streaming=streaming
2190+
)
2191+
2192+
# 1. Create a new task (empty context_id, empty task_id)
2193+
msg1 = Message(
2194+
message_id='test-msg-1',
2195+
role=Role.ROLE_USER,
2196+
parts=[Part(text='start task 1')],
2197+
)
2198+
it1 = client.send_message(SendMessageRequest(message=msg1))
2199+
events1 = [event async for event in it1]
2200+
2201+
agent_task1_id, agent_context1_id = seen_contexts[0]
2202+
client_task1_id = get_task_id(events1[-1])
2203+
client_context1_id = get_task_context_id(events1[-1])
2204+
2205+
# Verify that agent can see non-empty context_id and task_id
2206+
assert agent_task1_id != ''
2207+
assert agent_context1_id != ''
2208+
2209+
# Verify for task1 that context_id and task_id visible by agent and client are the same.
2210+
assert agent_task1_id == client_task1_id
2211+
assert agent_context1_id == client_context1_id
2212+
2213+
# 2. Create a new task with context_id from task_1: task_2
2214+
msg2 = Message(
2215+
message_id='test-msg-2',
2216+
context_id=client_context1_id,
2217+
role=Role.ROLE_USER,
2218+
parts=[Part(text='start task 2')],
2219+
)
2220+
it2 = client.send_message(SendMessageRequest(message=msg2))
2221+
events2 = [event async for event in it2]
2222+
2223+
agent_task2_id, agent_context2_id = seen_contexts[1]
2224+
client_task2_id = get_task_id(events2[-1])
2225+
client_context2_id = get_task_context_id(events2[-1])
2226+
2227+
# Verify for task2 that context_id and task_id visible by agent and client are the same.
2228+
assert agent_task2_id == client_task2_id
2229+
assert agent_context2_id == client_context2_id
2230+
2231+
# Verify that agent can see context_id the same as in task_1 and different task id
2232+
assert agent_context2_id == agent_context1_id
2233+
assert agent_task2_id != agent_task1_id

0 commit comments

Comments
 (0)