Skip to content

Commit 1464d89

Browse files
author
Krzysztof Dziedzic
committed
test: setup itk resubscribe tests
1 parent c0c6c08 commit 1464d89

2 files changed

Lines changed: 219 additions & 55 deletions

File tree

itk/main.py

Lines changed: 201 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import base64
44
import logging
55
import os
6+
import signal
67
import uuid
78

89
import grpc
910
import httpx
1011
import uvicorn
1112

1213
from fastapi import FastAPI
14+
from typing import Any
1315

1416
from pyproto import instruction_pb2
1517

16-
from a2a.client import ClientConfig, create_client
18+
from a2a.client import Client, ClientConfig, create_client
19+
from a2a.client.errors import A2AClientError
1720
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
1821
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
1922
from a2a.server.agent_execution import AgentExecutor, RequestContext
@@ -31,20 +34,25 @@
3134
InMemoryPushNotificationConfigStore,
3235
)
3336
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
37+
from a2a.server.context import ServerCallContext
3438
from a2a.types import a2a_pb2_grpc
3539
from a2a.types.a2a_pb2 import (
3640
AgentCapabilities,
3741
AgentCard,
3842
AgentInterface,
43+
CancelTaskRequest,
3944
Message,
4045
Part,
4146
SendMessageRequest,
47+
SubscribeToTaskRequest,
4248
Task,
4349
TaskState,
4450
TaskStatus,
4551
TaskPushNotificationConfig,
4652
)
4753
from a2a.utils import TransportProtocol
54+
from a2a.utils.errors import TaskNotCancelableError
55+
from a2a.server.tasks.push_notification_sender import PushNotificationEvent
4856

4957
log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
5058
log_level = getattr(logging, log_level_str, logging.INFO)
@@ -73,7 +81,7 @@ def extract_instruction(
7381
# Some clients might send it as base64 in text part
7482
raw = base64.b64decode(part.text)
7583
inst.ParseFromString(raw)
76-
except Exception:
84+
except Exception: # noqa: BLE001
7785
logger.debug(
7886
'Failed to parse instruction from binary part',
7987
exc_info=True,
@@ -88,7 +96,7 @@ def extract_instruction(
8896
raw = base64.b64decode(part.text)
8997
inst = instruction_pb2.Instruction()
9098
inst.ParseFromString(raw)
91-
except Exception:
99+
except Exception: # noqa: BLE001
92100
logger.debug(
93101
'Failed to parse instruction from text part', exc_info=True
94102
)
@@ -98,6 +106,93 @@ def extract_instruction(
98106
return None
99107

100108

109+
def _extract_text_from_event(event: Any) -> list[str]:
110+
"""Extracts text parts from an event's message."""
111+
if isinstance(event, tuple):
112+
results = []
113+
for item in event:
114+
results.extend(_extract_text_from_event(item))
115+
return results
116+
117+
message = None
118+
if hasattr(event, 'HasField'):
119+
if event.HasField('message'):
120+
message = event.message
121+
elif event.HasField('task') and event.task.status.HasField('message'):
122+
message = event.task.status.message
123+
elif event.HasField(
124+
'status_update'
125+
) and event.status_update.status.HasField('message'):
126+
message = event.status_update.status.message
127+
128+
results = []
129+
if message:
130+
results.extend(part.text for part in message.parts if part.text)
131+
return results
132+
133+
134+
async def _handle_call_agent_with_resubscribe(
135+
client: Client, request: SendMessageRequest
136+
) -> list[str]:
137+
"""Handles the send-disconnect-resubscribe flow."""
138+
results = []
139+
logger.info('Executing re-subscribe behavior')
140+
agen = client.send_message(request)
141+
task_id = None
142+
143+
async for event in agen:
144+
logger.info('Event before disconnect: %s', event)
145+
if event.HasField('task'):
146+
task_id = event.task.id
147+
elif event.HasField('status_update'):
148+
task_id = event.status_update.task_id
149+
break
150+
151+
await agen.aclose()
152+
logger.info('Disconnected from task %s. Now re-subscribing.', task_id)
153+
154+
resub_agen = client.subscribe(SubscribeToTaskRequest(id=task_id))
155+
156+
task_obj = None
157+
finished = False
158+
async for event in resub_agen:
159+
logger.info('Event after re-subscribe: %s', event)
160+
if hasattr(event, 'task'):
161+
task_obj = event.task
162+
elif hasattr(event, 'HasField') and event.HasField('task'):
163+
task_obj = event.task
164+
165+
extracted_text = _extract_text_from_event(event)
166+
for text in extracted_text:
167+
processed_text = text.replace('task-finished', '')
168+
results.append(processed_text)
169+
if any('task-finished' in text for text in extracted_text):
170+
logger.info(
171+
'Received task-finished after re-subscribe, breaking loop.'
172+
)
173+
finished = True
174+
break
175+
176+
if not results and task_obj and hasattr(task_obj, 'history'):
177+
logger.info('Results empty after loop, reading from history.')
178+
for msg in task_obj.history:
179+
if msg.role == 'ROLE_AGENT' or msg.role == 'agent':
180+
for part in msg.parts:
181+
if part.text:
182+
results.append(part.text.replace('task-finished', ''))
183+
184+
if not finished:
185+
logger.info('Canceling task %s after retrieval.', task_id)
186+
try:
187+
await client.cancel_task(CancelTaskRequest(id=task_id))
188+
logger.info('Task cancelled successfully: %s', task_id)
189+
except A2AClientError as e:
190+
logger.error('Failed to cancel task %s: %s', task_id, str(e))
191+
raise
192+
193+
return results
194+
195+
101196
def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
102197
"""Wraps an Instruction proto into an A2A Message."""
103198
inst_bytes = inst.SerializeToString()
@@ -129,18 +224,22 @@ async def handle_call_agent(
129224
'GRPC': TransportProtocol.GRPC,
130225
}
131226

132-
selected_transport = transport_map.get(call.transport.upper())
227+
selected_transport = transport_map.get(
228+
call.transport.upper(), TransportProtocol.JSONRPC
229+
)
133230
if selected_transport is None:
134231
raise ValueError(f'Unsupported transport: {call.transport}')
135232

136233
config = ClientConfig()
137-
config.httpx_client = httpx.AsyncClient(timeout=30.0)
138234
config.grpc_channel_factory = grpc.aio.insecure_channel
139235
config.supported_protocol_bindings = [selected_transport]
140236
config.streaming = call.streaming or (
141237
selected_transport == TransportProtocol.GRPC
142238
)
143239

240+
if call.HasField('resubscribe') and not config.streaming:
241+
raise ValueError('Re-subscription requires streaming to be enabled')
242+
144243
if call.HasField('push_notification'):
145244
url = call.push_notification.url
146245
if not url:
@@ -149,47 +248,48 @@ async def handle_call_agent(
149248
url = f'http://{url}'
150249
config.push_notification_config = TaskPushNotificationConfig(
151250
url=f'{url}/notifications',
152-
token='itk-token', # noqa: S106
251+
token='itk-token',
153252
)
154253

155-
try:
156-
client = await create_client(
157-
call.agent_card_uri,
158-
client_config=config,
159-
)
254+
async with httpx.AsyncClient(timeout=30.0) as httpx_client:
255+
config.httpx_client = httpx_client
256+
try:
257+
client = await create_client(
258+
call.agent_card_uri,
259+
client_config=config,
260+
)
160261

161-
# Wrap nested instruction
162-
nested_msg = wrap_instruction_to_request(call.instruction)
163-
request = SendMessageRequest(message=nested_msg)
262+
# Wrap nested instruction
263+
nested_msg = wrap_instruction_to_request(call.instruction)
264+
request = SendMessageRequest(message=nested_msg)
164265

165-
results = []
166-
async for event in client.send_message(request):
167-
# Event is streaming response and task
168-
logger.info('Event: %s', event)
169-
stream_resp = event
170-
171-
message = None
172-
if stream_resp.HasField('message'):
173-
message = stream_resp.message
174-
elif stream_resp.HasField(
175-
'task'
176-
) and stream_resp.task.status.HasField('message'):
177-
message = stream_resp.task.status.message
178-
elif stream_resp.HasField(
179-
'status_update'
180-
) and stream_resp.status_update.status.HasField('message'):
181-
message = stream_resp.status_update.status.message
182-
183-
if message:
184-
results.extend(part.text for part in message.parts if part.text)
185-
186-
except Exception as e:
187-
logger.exception('Failed to call outbound agent')
188-
raise RuntimeError(
189-
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
190-
) from e
191-
else:
192-
return results
266+
results = []
267+
268+
if call.HasField('resubscribe'):
269+
results.extend(
270+
await _handle_call_agent_with_resubscribe(client, request)
271+
)
272+
else:
273+
async for event in client.send_message(request):
274+
logger.info('Event: %s', event)
275+
results.extend(_extract_text_from_event(event))
276+
277+
except Exception as e:
278+
logger.exception('Failed to call outbound agent')
279+
raise RuntimeError(
280+
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
281+
) from e
282+
else:
283+
return results
284+
285+
286+
def _should_hold(inst: instruction_pb2.Instruction) -> bool:
287+
"""Recursively checks if any part of the instruction requests holding the task."""
288+
if inst.HasField('return_response') and inst.return_response.hold_task:
289+
return True
290+
if inst.HasField('steps'):
291+
return any(_should_hold(step) for step in inst.steps.instructions)
292+
return False
193293

194294

195295
async def handle_instruction(
@@ -245,23 +345,57 @@ async def execute(
245345
)
246346
return
247347

348+
should_hold_task = _should_hold(instruction)
349+
248350
try:
249351
logger.info('Instruction: %s', instruction)
250352
results = await handle_instruction(instruction)
353+
251354
response_text = '\n'.join(results)
252355
logger.info('Response: %s', response_text)
253-
await task_updater.update_status(
254-
TaskState.TASK_STATE_COMPLETED,
255-
message=task_updater.new_agent_message(
256-
[Part(text=response_text)]
257-
),
258-
)
259-
logger.info('Task %s completed', context.task_id)
356+
357+
if should_hold_task:
358+
logger.info('Holding task %s as requested', context.task_id)
359+
# Emitted event: response + task-finished
360+
logger.info(
361+
'Emitting response and task-finished for held task %s', context.task_id
362+
)
363+
await task_updater.update_status(
364+
TaskState.TASK_STATE_WORKING,
365+
message=task_updater.new_agent_message(
366+
[Part(text=response_text + '\ntask-finished')]
367+
),
368+
)
369+
await asyncio.sleep(2)
370+
371+
# Continue emitting "task-finished" every 2 seconds
372+
try:
373+
while True:
374+
logger.info(
375+
'Emitting periodic status update for held task %s',
376+
context.task_id,
377+
)
378+
await task_updater.update_status(
379+
TaskState.TASK_STATE_WORKING,
380+
message=None,
381+
)
382+
await asyncio.sleep(2)
383+
except asyncio.CancelledError:
384+
logger.info('Task %s cancelled', context.task_id)
385+
return
386+
else:
387+
await task_updater.update_status(
388+
TaskState.TASK_STATE_COMPLETED,
389+
message=task_updater.new_agent_message(
390+
[Part(text=response_text)]
391+
),
392+
)
393+
logger.info('Task %s completed', context.task_id)
260394
except Exception as e:
261395
logger.exception('Error during instruction handling')
262396
await task_updater.update_status(
263397
TaskState.TASK_STATE_FAILED,
264-
message=task_updater.new_agent_message([Part(text=str(e))]),
398+
message=None,
265399
)
266400

267401
async def cancel(
@@ -325,19 +459,19 @@ async def main_async(http_port: int, grpc_port: int) -> None:
325459
name='ITK v10 Agent',
326460
description='Python agent using SDK 1.0.',
327461
version='1.0.0',
328-
capabilities=AgentCapabilities(
329-
streaming=True, push_notifications=True, extended_agent_card=True
330-
),
462+
capabilities=AgentCapabilities(streaming=True),
331463
default_input_modes=['text/plain'],
332464
default_output_modes=['text/plain'],
333465
supported_interfaces=interfaces,
334466
)
335467

336468
task_store = InMemoryTaskStore()
337469
push_config_store = InMemoryPushNotificationConfigStore()
470+
httpx_client = httpx.AsyncClient()
338471
push_sender = BasePushNotificationSender(
339-
httpx_client=httpx.AsyncClient(),
472+
httpx_client=httpx_client,
340473
config_store=push_config_store,
474+
context=ServerCallContext(),
341475
)
342476

343477
handler = DefaultRequestHandler(
@@ -396,10 +530,22 @@ async def main_async(http_port: int, grpc_port: int) -> None:
396530
)
397531

398532
config = uvicorn.Config(
399-
app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower()
533+
app, host='127.0.0.1', port=http_port, log_level='info'
400534
)
401535
uvicorn_server = uvicorn.Server(config)
402536

537+
# Signal handling
538+
loop = asyncio.get_running_loop()
539+
540+
async def shutdown() -> None:
541+
logger.info('Shutting down...')
542+
uvicorn_server.should_exit = True
543+
await server.stop(5)
544+
await httpx_client.aclose()
545+
546+
for sig in (signal.SIGINT, signal.SIGTERM):
547+
loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))
548+
403549
await uvicorn_server.serve()
404550

405551

0 commit comments

Comments
 (0)