Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added samples/__init__.py
Empty file.
125 changes: 125 additions & 0 deletions samples/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import argparse
import asyncio
import os
import signal
import uuid

from typing import Any

import grpc
import httpx

from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState


async def _handle_stream(
stream: Any, current_task_id: str | None
) -> str | None:
async for event, task in stream:
if not task:
continue
if not current_task_id:
current_task_id = task.id

if event:
if event.HasField('status_update'):
state_name = TaskState.Name(event.status_update.status.state)
print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
if event.status_update.status.HasField('message'):
for part in event.status_update.status.message.parts:
if part.text:
print(part.text, end=' ')
print()

if (
event.status_update.status.state
== TaskState.TASK_STATE_COMPLETED
):
current_task_id = None
print('--- Task Completed ---')

elif event.HasField('artifact_update'):
print(
f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:',
end=' ',
)
for part in event.artifact_update.artifact.parts:
if part.text:
print(part.text, end=' ')
print()

return current_task_id


async def main() -> None:
"""Run the A2A terminal client."""
parser = argparse.ArgumentParser(description='A2A Terminal Client')
parser.add_argument(
'--url', default='http://127.0.0.1:41241', help='Agent base URL'
)
parser.add_argument(
'--transport',
default=None,
help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)',
)
args = parser.parse_args()

config = ClientConfig()
if args.transport:
config.supported_protocol_bindings = [args.transport]

print(
f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})'
)

async with httpx.AsyncClient() as httpx_client:
resolver = A2ACardResolver(httpx_client, args.url)
card = await resolver.get_agent_card()
print('\n✓ Agent Card Found:')
print(f' Name: {card.name}')

client = await ClientFactory.connect(card, client_config=config)

actual_transport = getattr(client, '_transport', client)
print(f' Picked Transport: {actual_transport.__class__.__name__}')

print('\nConnected! Send a message or type /quit to exit.')

current_task_id = None
current_context_id = str(uuid.uuid4())

while True:
try:
loop = asyncio.get_running_loop()
user_input = await loop.run_in_executor(None, input, 'You: ')
except KeyboardInterrupt:
break

if user_input.lower() in ('/quit', '/exit'):
break
if not user_input.strip():
continue

message = Message(
role=Role.ROLE_USER,
message_id=str(uuid.uuid4()),
parts=[Part(text=user_input)],
task_id=current_task_id,
context_id=current_context_id,
)

request = SendMessageRequest(message=message)

try:
stream = client.send_message(request)
current_task_id = await _handle_stream(stream, current_task_id)
except (httpx.RequestError, grpc.RpcError) as e:
print(f'Error communicating with agent: {e}')

await client.close()


if __name__ == '__main__':
signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0))
asyncio.run(main())
245 changes: 245 additions & 0 deletions samples/hello_world_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import asyncio
import contextlib
import logging

import grpc
import uvicorn

from fastapi import FastAPI

from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
from a2a.server.agent_execution.agent_executor import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
from a2a.server.events.event_queue import EventQueue
from a2a.server.request_handlers import GrpcHandler
from a2a.server.request_handlers.default_request_handler import (
DefaultRequestHandler,
)
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.tasks.task_updater import TaskUpdater
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentInterface,
AgentProvider,
AgentSkill,
Part,
a2a_pb2_grpc,
)


logger = logging.getLogger(__name__)


class SampleAgentExecutor(AgentExecutor):
"""Sample agent executor logic similar to the a2a-js sample."""

def __init__(self) -> None:
self.running_tasks: set[str] = set()

async def cancel(
self, context: RequestContext, event_queue: EventQueue
) -> None:
"""Cancels a task."""
task_id = context.task_id
if task_id in self.running_tasks:
self.running_tasks.remove(task_id)

updater = TaskUpdater(
event_queue=event_queue,
task_id=task_id or '',
context_id=context.context_id or '',
)
await updater.cancel()

async def execute(
self, context: RequestContext, event_queue: EventQueue
) -> None:
"""Executes a task inline."""
user_message = context.message
task_id = context.task_id
context_id = context.context_id

if not user_message or not task_id or not context_id:
return

self.running_tasks.add(task_id)

logger.info(
'[SampleAgentExecutor] Processing message %s for task %s (context: %s)',
user_message.message_id,
task_id,
context_id,
)

updater = TaskUpdater(
event_queue=event_queue,
task_id=task_id,
context_id=context_id,
)

working_message = updater.new_agent_message(
parts=[Part(text='Processing your question...')]
)
await updater.start_work(message=working_message)

query = context.get_user_input()

agent_reply_text = self._parse_input(query)
await asyncio.sleep(1)

if task_id not in self.running_tasks:
return

await updater.add_artifact(
parts=[Part(text=agent_reply_text)],
name='response',
last_chunk=True,
)
await updater.complete()

logger.info(
'[SampleAgentExecutor] Task %s finished with state: completed',
task_id,
)

def _parse_input(self, query: str) -> str:
if not query:
return 'Hello! Please provide a message for me to respond to.'

ql = query.lower()
if 'hello' in ql or 'hi' in ql:
return 'Hello World! Nice to meet you!'
if 'how are you' in ql:
return (
"I'm doing great! Thanks for asking. How can I help you today?"
)
if 'goodbye' in ql or 'bye' in ql:
return 'Goodbye! Have a wonderful day!'
return f"Hello World! You said: '{query}'. Thanks for your message!"


async def serve(
host: str = '127.0.0.1',
port: int = 41241,
grpc_port: int = 50051,
compat_grpc_port: int = 50052,
) -> None:
"""Run the Sample Agent server with mounted JSON-RPC, HTTP+JSON and gRPC transports."""
agent_card = AgentCard(
name='Sample Agent',
description='A sample agent to test the stream functionality.',
provider=AgentProvider(
organization='A2A Samples', url='https://example.com'
),
version='1.0.0',
capabilities=AgentCapabilities(
streaming=True, push_notifications=False
),
default_input_modes=['text'],
default_output_modes=['text', 'task-status'],
skills=[
AgentSkill(
id='sample_agent',
name='Sample Agent',
description='Say hi.',
tags=['sample'],
examples=['hi'],
input_modes=['text'],
output_modes=['text', 'task-status'],
)
],
supported_interfaces=[
AgentInterface(
protocol_binding='GRPC',
protocol_version='1.0',
url=f'{host}:{grpc_port}',
),
AgentInterface(
protocol_binding='GRPC',
protocol_version='0.3',
url=f'{host}:{compat_grpc_port}',
),
AgentInterface(
protocol_binding='JSONRPC',
protocol_version='1.0',
url=f'http://{host}:{port}/a2a/jsonrpc/',
),
AgentInterface(
protocol_binding='JSONRPC',
protocol_version='0.3',
url=f'http://{host}:{port}/a2a/jsonrpc/',
),
AgentInterface(
protocol_binding='HTTP+JSON',
protocol_version='1.0',
url=f'http://{host}:{port}/a2a/rest/',
),
AgentInterface(
protocol_binding='HTTP+JSON',
protocol_version='0.3',
url=f'http://{host}:{port}/a2a/rest/',
),
],
)

task_store = InMemoryTaskStore()
request_handler = DefaultRequestHandler(
agent_executor=SampleAgentExecutor(), task_store=task_store
)

rest_app_builder = A2ARESTFastAPIApplication(
agent_card=agent_card,
http_handler=request_handler,
enable_v0_3_compat=True,
)
rest_app = rest_app_builder.build()

jsonrpc_app_builder = A2AFastAPIApplication(
agent_card=agent_card,
http_handler=request_handler,
enable_v0_3_compat=True,
)

app = FastAPI()
jsonrpc_app_builder.add_routes_to_app(app, rpc_url='/a2a/jsonrpc/')
app.mount('/a2a/rest', rest_app)

grpc_server = grpc.aio.server()
grpc_server.add_insecure_port(f'{host}:{grpc_port}')
servicer = GrpcHandler(agent_card, request_handler)
a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server)

compat_grpc_server = grpc.aio.server()
compat_grpc_server.add_insecure_port(f'{host}:{compat_grpc_port}')
compat_servicer = CompatGrpcHandler(agent_card, request_handler)
a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(
compat_servicer, compat_grpc_server
)

config = uvicorn.Config(app, host=host, port=port)
uvicorn_server = uvicorn.Server(config)

logger.info('Starting Sample Agent servers:')
logger.info(' - HTTP on http://%s:%s', host, port)
logger.info(' - gRPC on %s:%s', host, grpc_port)
logger.info(' - gRPC (v0.3 compat) on %s:%s', host, compat_grpc_port)
logger.info(
'Agent Card available at http://%s:%s/.well-known/agent-card.json',
host,
port,
)

await asyncio.gather(
grpc_server.start(),
compat_grpc_server.start(),
uvicorn_server.serve(),
)


if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
with contextlib.suppress(KeyboardInterrupt):
Comment thread
ishymko marked this conversation as resolved.
asyncio.run(serve())
Loading