Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@
logger = logging.getLogger(__name__)


# TODO: cleanup context_id management


@trace_class(kind=SpanKind.SERVER)
class DefaultRequestHandlerV2(RequestHandler):
"""Default request handler for all incoming requests."""
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/test_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,3 +2158,76 @@ async def cancel(
):
async for _ in it:
pass


# Scenario: Task context_id and task_id visibility
@pytest.mark.timeout(2.0)
@pytest.mark.asyncio
@pytest.mark.parametrize('use_legacy', [False, True], ids=['v2', 'legacy'])
@pytest.mark.parametrize(
'streaming', [False, True], ids=['blocking', 'streaming']
)
async def test_scenario_context_id_visibility(use_legacy, streaming):
seen_contexts = []

class ContextIdAgent(AgentExecutor):
async def execute(
self, context: RequestContext, event_queue: EventQueue
):
seen_contexts.append((context.task_id, context.context_id))
task = new_task_from_user_message(context.message)
task.status.state = TaskState.TASK_STATE_COMPLETED
await event_queue.enqueue_event(task)

async def cancel(
self, context: RequestContext, event_queue: EventQueue
):
pass

handler = create_handler(ContextIdAgent(), use_legacy)
client = await create_client(
handler, agent_card=agent_card(), streaming=streaming
)

# 1. Create a new task (empty context_id, empty task_id)
msg1 = Message(
message_id='test-msg-1',
role=Role.ROLE_USER,
parts=[Part(text='start task 1')],
)
it1 = client.send_message(SendMessageRequest(message=msg1))
events1 = [event async for event in it1]

agent_task1_id, agent_context1_id = seen_contexts[0]
client_task1_id = get_task_id(events1[-1])
client_context1_id = get_task_context_id(events1[-1])

# Verify that agent can see non-empty context_id and task_id
assert agent_task1_id != ''
assert agent_context1_id != ''

# Verify for task1 that context_id and task_id visible by agent and client are the same.
assert agent_task1_id == client_task1_id
assert agent_context1_id == client_context1_id

# 2. Create a new task with context_id from task_1: task_2
msg2 = Message(
message_id='test-msg-2',
context_id=client_context1_id,
role=Role.ROLE_USER,
parts=[Part(text='start task 2')],
)
it2 = client.send_message(SendMessageRequest(message=msg2))
events2 = [event async for event in it2]

agent_task2_id, agent_context2_id = seen_contexts[1]
client_task2_id = get_task_id(events2[-1])
client_context2_id = get_task_context_id(events2[-1])

# Verify for task2 that context_id and task_id visible by agent and client are the same.
assert agent_task2_id == client_task2_id
assert agent_context2_id == client_context2_id

# Verify that agent can see context_id the same as in task_1 and different task id
assert agent_context2_id == agent_context1_id
assert agent_task2_id != agent_task1_id
Loading