Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 201 additions & 55 deletions itk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import base64
import logging
import os
import signal
import uuid

import grpc
import httpx
import uvicorn

from fastapi import FastAPI
from typing import Any

from pyproto import instruction_pb2

from a2a.client import ClientConfig, create_client
from a2a.client import Client, ClientConfig, create_client
from a2a.client.errors import A2AClientError
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.agent_execution import AgentExecutor, RequestContext
Expand All @@ -31,20 +34,25 @@
InMemoryPushNotificationConfigStore,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.context import ServerCallContext
from a2a.types import a2a_pb2_grpc
from a2a.types.a2a_pb2 import (
AgentCapabilities,
AgentCard,
AgentInterface,
CancelTaskRequest,
Message,
Part,
SendMessageRequest,
SubscribeToTaskRequest,
Task,
TaskState,
TaskStatus,
TaskPushNotificationConfig,
)
from a2a.utils import TransportProtocol
from a2a.utils.errors import TaskNotCancelableError

Check failure on line 54 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (F401)

itk/main.py:54:30: F401 `a2a.utils.errors.TaskNotCancelableError` imported but unused help: Remove unused import: `a2a.utils.errors.TaskNotCancelableError`
from a2a.server.tasks.push_notification_sender import PushNotificationEvent

Check failure on line 55 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (F401)

itk/main.py:55:55: F401 `a2a.server.tasks.push_notification_sender.PushNotificationEvent` imported but unused help: Remove unused import: `a2a.server.tasks.push_notification_sender.PushNotificationEvent`

log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
log_level = getattr(logging, log_level_str, logging.INFO)
Expand Down Expand Up @@ -73,7 +81,7 @@
# Some clients might send it as base64 in text part
raw = base64.b64decode(part.text)
inst.ParseFromString(raw)
except Exception:
except Exception: # noqa: BLE001

Check failure on line 84 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (RUF100)

itk/main.py:84:32: RUF100 Unused `noqa` directive (unused: `BLE001`) help: Remove unused `noqa` directive
logger.debug(
'Failed to parse instruction from binary part',
exc_info=True,
Expand All @@ -88,7 +96,7 @@
raw = base64.b64decode(part.text)
inst = instruction_pb2.Instruction()
inst.ParseFromString(raw)
except Exception:
except Exception: # noqa: BLE001

Check failure on line 99 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (RUF100)

itk/main.py:99:32: RUF100 Unused `noqa` directive (unused: `BLE001`) help: Remove unused `noqa` directive
logger.debug(
'Failed to parse instruction from text part', exc_info=True
)
Expand All @@ -98,6 +106,93 @@
return None


def _extract_text_from_event(event: Any) -> list[str]:
"""Extracts text parts from an event's message."""
if isinstance(event, tuple):
results = []
for item in event:
results.extend(_extract_text_from_event(item))
return results

message = None
if hasattr(event, 'HasField'):
if event.HasField('message'):
message = event.message
elif event.HasField('task') and event.task.status.HasField('message'):
message = event.task.status.message
elif event.HasField(
'status_update'
) and event.status_update.status.HasField('message'):
message = event.status_update.status.message

results = []
if message:
results.extend(part.text for part in message.parts if part.text)
return results


async def _handle_call_agent_with_resubscribe(

Check failure on line 134 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (PLR0912)

itk/main.py:134:11: PLR0912 Too many branches (15 > 12)
client: Client, request: SendMessageRequest
) -> list[str]:
"""Handles the send-disconnect-resubscribe flow."""
results = []
logger.info('Executing re-subscribe behavior')
agen = client.send_message(request)
task_id = None

async for event in agen:
logger.info('Event before disconnect: %s', event)
if event.HasField('task'):
task_id = event.task.id
elif event.HasField('status_update'):
task_id = event.status_update.task_id
break

await agen.aclose()
logger.info('Disconnected from task %s. Now re-subscribing.', task_id)

resub_agen = client.subscribe(SubscribeToTaskRequest(id=task_id))

task_obj = None
finished = False
async for event in resub_agen:
logger.info('Event after re-subscribe: %s', event)
if hasattr(event, 'task'):
task_obj = event.task
elif hasattr(event, 'HasField') and event.HasField('task'):
task_obj = event.task

Check failure on line 163 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (SIM114)

itk/main.py:160:9: SIM114 Combine `if` branches using logical `or` operator help: Combine `if` branches
Comment on lines +160 to +163
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic to extract task_obj from the event is potentially incorrect. For a protobuf message like StreamResponse, hasattr(event, 'task') will always be True, but this doesn't mean the task field is actually set in the message. Accessing event.task when the field is not set will give you a default Task instance, which is likely not what you want. The correct way to check for presence is event.HasField('task').

The elif condition is currently unreachable because the if condition will always be met for a StreamResponse object. I suggest simplifying this block to correctly check for the task field.

Suggested change
if hasattr(event, 'task'):
task_obj = event.task
elif hasattr(event, 'HasField') and event.HasField('task'):
task_obj = event.task
if event.HasField('task'):
task_obj = event.task


extracted_text = _extract_text_from_event(event)
for text in extracted_text:
processed_text = text.replace('task-finished', '')
results.append(processed_text)
if any('task-finished' in text for text in extracted_text):
logger.info(
'Received task-finished after re-subscribe, breaking loop.'
)
finished = True
break

if not results and task_obj and hasattr(task_obj, 'history'):
logger.info('Results empty after loop, reading from history.')
for msg in task_obj.history:
if msg.role == 'ROLE_AGENT' or msg.role == 'agent':

Check failure on line 179 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (PLR1714)

itk/main.py:179:16: PLR1714 Consider merging multiple comparisons: `msg.role in {'ROLE_AGENT', 'agent'}`. help: Merge multiple comparisons
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Comparing the role enum field with strings like 'ROLE_AGENT' or 'agent' is not reliable and is likely incorrect. Protobuf enums in Python are integers, so this comparison will likely always evaluate to false. You should compare msg.role with the actual enum member.

To do this, you'll need to import Role from a2a.types.a2a_pb2 at the top of the file.

Suggested change
if msg.role == 'ROLE_AGENT' or msg.role == 'agent':
if msg.role == Role.ROLE_AGENT:

for part in msg.parts:
if part.text:
results.append(part.text.replace('task-finished', ''))

Check failure on line 182 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (PERF401)

itk/main.py:182:25: PERF401 Use `list.extend` to create a transformed list help: Replace for loop with list.extend

if not finished:
logger.info('Canceling task %s after retrieval.', task_id)
try:
await client.cancel_task(CancelTaskRequest(id=task_id))
logger.info('Task cancelled successfully: %s', task_id)
except A2AClientError as e:
logger.error('Failed to cancel task %s: %s', task_id, str(e))

Check failure on line 190 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (TRY400)

itk/main.py:190:13: TRY400 Use `logging.exception` instead of `logging.error` help: Replace with `exception`
raise

return results


def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
"""Wraps an Instruction proto into an A2A Message."""
inst_bytes = inst.SerializeToString()
Expand Down Expand Up @@ -129,18 +224,22 @@
'GRPC': TransportProtocol.GRPC,
}

selected_transport = transport_map.get(call.transport.upper())
selected_transport = transport_map.get(
call.transport.upper(), TransportProtocol.JSONRPC
)
if selected_transport is None:
raise ValueError(f'Unsupported transport: {call.transport}')

config = ClientConfig()
config.httpx_client = httpx.AsyncClient(timeout=30.0)
config.grpc_channel_factory = grpc.aio.insecure_channel
config.supported_protocol_bindings = [selected_transport]
config.streaming = call.streaming or (
selected_transport == TransportProtocol.GRPC
)

if call.HasField('resubscribe') and not config.streaming:
raise ValueError('Re-subscription requires streaming to be enabled')

if call.HasField('push_notification'):
url = call.push_notification.url
if not url:
Expand All @@ -149,47 +248,48 @@
url = f'http://{url}'
config.push_notification_config = TaskPushNotificationConfig(
url=f'{url}/notifications',
token='itk-token', # noqa: S106
token='itk-token',

Check failure on line 251 in itk/main.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (S106)

itk/main.py:251:13: S106 Possible hardcoded password assigned to argument: "token"
)

try:
client = await create_client(
call.agent_card_uri,
client_config=config,
)
async with httpx.AsyncClient(timeout=30.0) as httpx_client:
config.httpx_client = httpx_client
try:
client = await create_client(
call.agent_card_uri,
client_config=config,
)

# Wrap nested instruction
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)
# Wrap nested instruction
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)

results = []
async for event in client.send_message(request):
# Event is streaming response and task
logger.info('Event: %s', event)
stream_resp = event

message = None
if stream_resp.HasField('message'):
message = stream_resp.message
elif stream_resp.HasField(
'task'
) and stream_resp.task.status.HasField('message'):
message = stream_resp.task.status.message
elif stream_resp.HasField(
'status_update'
) and stream_resp.status_update.status.HasField('message'):
message = stream_resp.status_update.status.message

if message:
results.extend(part.text for part in message.parts if part.text)

except Exception as e:
logger.exception('Failed to call outbound agent')
raise RuntimeError(
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
) from e
else:
return results
results = []

if call.HasField('resubscribe'):
results.extend(
await _handle_call_agent_with_resubscribe(client, request)
)
else:
async for event in client.send_message(request):
logger.info('Event: %s', event)
results.extend(_extract_text_from_event(event))

except Exception as e:
logger.exception('Failed to call outbound agent')
raise RuntimeError(
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
) from e
else:
return results


def _should_hold(inst: instruction_pb2.Instruction) -> bool:
"""Recursively checks if any part of the instruction requests holding the task."""
if inst.HasField('return_response') and inst.return_response.hold_task:
return True
if inst.HasField('steps'):
return any(_should_hold(step) for step in inst.steps.instructions)
return False


async def handle_instruction(
Expand Down Expand Up @@ -245,23 +345,57 @@
)
return

should_hold_task = _should_hold(instruction)

try:
logger.info('Instruction: %s', instruction)
results = await handle_instruction(instruction)

response_text = '\n'.join(results)
logger.info('Response: %s', response_text)
await task_updater.update_status(
TaskState.TASK_STATE_COMPLETED,
message=task_updater.new_agent_message(
[Part(text=response_text)]
),
)
logger.info('Task %s completed', context.task_id)

if should_hold_task:
logger.info('Holding task %s as requested', context.task_id)
# Emitted event: response + task-finished
logger.info(
'Emitting response and task-finished for held task %s', context.task_id
)
await task_updater.update_status(
TaskState.TASK_STATE_WORKING,
message=task_updater.new_agent_message(
[Part(text=response_text + '\n' +'task-finished')]
),
)
await asyncio.sleep(2)

# Continue emitting "task-finished" every 2 seconds
try:
while True:
logger.info(
'Emitting periodic status update for held task %s',
context.task_id,
)
await task_updater.update_status(
TaskState.TASK_STATE_WORKING,
message=None,
)
await asyncio.sleep(2)
except asyncio.CancelledError:
logger.info('Task %s cancelled', context.task_id)
return
else:
await task_updater.update_status(
TaskState.TASK_STATE_COMPLETED,
message=task_updater.new_agent_message(
[Part(text=response_text)]
),
)
logger.info('Task %s completed', context.task_id)
except Exception as e:
logger.exception('Error during instruction handling')
await task_updater.update_status(
TaskState.TASK_STATE_FAILED,
message=task_updater.new_agent_message([Part(text=str(e))]),
message=None,
)

async def cancel(
Expand Down Expand Up @@ -325,19 +459,19 @@
name='ITK v10 Agent',
description='Python agent using SDK 1.0.',
version='1.0.0',
capabilities=AgentCapabilities(
streaming=True, push_notifications=True, extended_agent_card=True
),
capabilities=AgentCapabilities(streaming=True),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
supported_interfaces=interfaces,
)

task_store = InMemoryTaskStore()
push_config_store = InMemoryPushNotificationConfigStore()
httpx_client = httpx.AsyncClient()
push_sender = BasePushNotificationSender(
httpx_client=httpx.AsyncClient(),
httpx_client=httpx_client,
config_store=push_config_store,
context=ServerCallContext(),
)
Comment on lines 471 to 475
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The context parameter of BasePushNotificationSender is deprecated and will be removed in a future version. The constructor logs a warning when it is used. You can safely remove it from this call.

    push_sender = BasePushNotificationSender(
        httpx_client=httpx_client,
        config_store=push_config_store,
    )


handler = DefaultRequestHandler(
Expand Down Expand Up @@ -396,10 +530,22 @@
)

config = uvicorn.Config(
app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower()
app, host='127.0.0.1', port=http_port, log_level='info'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The log_level for uvicorn is now hardcoded to 'info'. This removes the ability to configure it using the ITK_LOG_LEVEL environment variable. It's better to restore the previous behavior to allow for flexible log level configuration during testing.

Suggested change
app, host='127.0.0.1', port=http_port, log_level='info'
app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower()

)
uvicorn_server = uvicorn.Server(config)

# Signal handling
loop = asyncio.get_running_loop()

async def shutdown() -> None:
logger.info('Shutting down...')
uvicorn_server.should_exit = True
await server.stop(5)
await httpx_client.aclose()

for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))

await uvicorn_server.serve()


Expand Down
Loading
Loading