@@ -113,6 +113,22 @@ def agent_card():
113113 )
114114
115115
116+ def get_task_id (event ):
117+ if event .HasField ('task' ):
118+ return event .task .id
119+ if event .HasField ('status_update' ):
120+ return event .status_update .task_id
121+ assert False , f'Event { event } has no task_id'
122+
123+
124+ def get_task_context_id (event ):
125+ if event .HasField ('task' ):
126+ return event .task .context_id
127+ if event .HasField ('status_update' ):
128+ return event .status_update .context_id
129+ assert False , f'Event { event } has no context_id'
130+
131+
116132def get_state (event ):
117133 if event .HasField ('task' ):
118134 return event .task .status .state
@@ -1265,6 +1281,93 @@ async def cancel(
12651281 )
12661282
12671283
1284+ # Scenario: Auth required and in channel unblocking
1285+ @pytest .mark .timeout (2.0 )
1286+ @pytest .mark .asyncio
1287+ @pytest .mark .parametrize ('use_legacy' , [False , True ], ids = ['v2' , 'legacy' ])
1288+ @pytest .mark .parametrize (
1289+ 'streaming' , [False , True ], ids = ['blocking' , 'streaming' ]
1290+ )
1291+ async def test_scenario_auth_required_in_channel (use_legacy , streaming ):
1292+ class AuthAgent (AgentExecutor ):
1293+ async def execute (
1294+ self , context : RequestContext , event_queue : EventQueue
1295+ ):
1296+ message = context .message
1297+ if message and message .parts and message .parts [0 ].text == 'start' :
1298+ await event_queue .enqueue_event (
1299+ TaskStatusUpdateEvent (
1300+ task_id = context .task_id ,
1301+ context_id = context .context_id ,
1302+ status = TaskStatus (
1303+ state = TaskState .TASK_STATE_AUTH_REQUIRED
1304+ ),
1305+ )
1306+ )
1307+ elif (
1308+ message
1309+ and message .parts
1310+ and message .parts [0 ].text == 'credentials'
1311+ ):
1312+ await event_queue .enqueue_event (
1313+ TaskStatusUpdateEvent (
1314+ task_id = context .task_id ,
1315+ context_id = context .context_id ,
1316+ status = TaskStatus (state = TaskState .TASK_STATE_COMPLETED ),
1317+ )
1318+ )
1319+ else :
1320+ raise ValueError (f'Unexpected message { message } ' )
1321+
1322+ async def cancel (
1323+ self , context : RequestContext , event_queue : EventQueue
1324+ ):
1325+ pass
1326+
1327+ handler = create_handler (AuthAgent (), use_legacy )
1328+ client = await create_client (
1329+ handler , agent_card = agent_card (), streaming = streaming
1330+ )
1331+
1332+ msg1 = Message (
1333+ message_id = 'msg-start' , role = Role .ROLE_USER , parts = [Part (text = 'start' )]
1334+ )
1335+
1336+ it = client .send_message (
1337+ SendMessageRequest (
1338+ message = msg1 ,
1339+ configuration = SendMessageConfiguration (return_immediately = False ),
1340+ )
1341+ )
1342+
1343+ events1 = [event async for event in it ]
1344+ assert [get_state (event ) for event in events1 ] == [
1345+ TaskState .TASK_STATE_AUTH_REQUIRED ,
1346+ ]
1347+ task_id = get_task_id (events1 [0 ])
1348+ context_id = get_task_context_id (events1 [0 ])
1349+
1350+ # Now send another message with credentials
1351+ msg2 = Message (
1352+ task_id = task_id ,
1353+ context_id = context_id ,
1354+ message_id = 'msg-creds' ,
1355+ role = Role .ROLE_USER ,
1356+ parts = [Part (text = 'credentials' )],
1357+ )
1358+
1359+ it2 = client .send_message (
1360+ SendMessageRequest (
1361+ message = msg2 ,
1362+ configuration = SendMessageConfiguration (return_immediately = False ),
1363+ )
1364+ )
1365+
1366+ assert [get_state (event ) async for event in it2 ] == [
1367+ TaskState .TASK_STATE_COMPLETED ,
1368+ ]
1369+
1370+
12681371# Scenario: Parallel subscribe attach detach
12691372# Migrated from: test_parallel_subscribe_attach_detach in test_handler_comparison
12701373@pytest .mark .timeout (5.0 )
0 commit comments