diff --git a/samples/__init__.py b/samples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/samples/cli.py b/samples/cli.py new file mode 100644 index 000000000..6a4597fa9 --- /dev/null +++ b/samples/cli.py @@ -0,0 +1,125 @@ +import argparse +import asyncio +import os +import signal +import uuid + +from typing import Any + +import grpc +import httpx + +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory +from a2a.types import Message, Part, Role, SendMessageRequest, TaskState + + +async def _handle_stream( + stream: Any, current_task_id: str | None +) -> str | None: + async for event, task in stream: + if not task: + continue + if not current_task_id: + current_task_id = task.id + + if event: + if event.HasField('status_update'): + state_name = TaskState.Name(event.status_update.status.state) + print(f'TaskStatusUpdate [state={state_name}]:', end=' ') + if event.status_update.status.HasField('message'): + for part in event.status_update.status.message.parts: + if part.text: + print(part.text, end=' ') + print() + + if ( + event.status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ): + current_task_id = None + print('--- Task Completed ---') + + elif event.HasField('artifact_update'): + print( + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', + end=' ', + ) + for part in event.artifact_update.artifact.parts: + if part.text: + print(part.text, end=' ') + print() + + return current_task_id + + +async def main() -> None: + """Run the A2A terminal client.""" + parser = argparse.ArgumentParser(description='A2A Terminal Client') + parser.add_argument( + '--url', default='http://127.0.0.1:41241', help='Agent base URL' + ) + parser.add_argument( + '--transport', + default=None, + help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)', + ) + args = parser.parse_args() + + config = ClientConfig() + if args.transport: + config.supported_protocol_bindings = [args.transport] + + print( + f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})' + ) + + async with httpx.AsyncClient() as httpx_client: + resolver = A2ACardResolver(httpx_client, args.url) + card = await resolver.get_agent_card() + print('\n✓ Agent Card Found:') + print(f' Name: {card.name}') + + client = await ClientFactory.connect(card, client_config=config) + + actual_transport = getattr(client, '_transport', client) + print(f' Picked Transport: {actual_transport.__class__.__name__}') + + print('\nConnected! Send a message or type /quit to exit.') + + current_task_id = None + current_context_id = str(uuid.uuid4()) + + while True: + try: + loop = asyncio.get_running_loop() + user_input = await loop.run_in_executor(None, input, 'You: ') + except KeyboardInterrupt: + break + + if user_input.lower() in ('/quit', '/exit'): + break + if not user_input.strip(): + continue + + message = Message( + role=Role.ROLE_USER, + message_id=str(uuid.uuid4()), + parts=[Part(text=user_input)], + task_id=current_task_id, + context_id=current_context_id, + ) + + request = SendMessageRequest(message=message) + + try: + stream = client.send_message(request) + current_task_id = await _handle_stream(stream, current_task_id) + except (httpx.RequestError, grpc.RpcError) as e: + print(f'Error communicating with agent: {e}') + + await client.close() + + +if __name__ == '__main__': + signal.signal(signal.SIGINT, lambda sig, frame: os._exit(0)) + asyncio.run(main()) diff --git a/samples/hello_world_agent.py b/samples/hello_world_agent.py new file mode 100644 index 000000000..38dfdf561 --- /dev/null +++ b/samples/hello_world_agent.py @@ -0,0 +1,245 @@ +import asyncio +import contextlib +import logging + +import grpc +import uvicorn + +from fastapi import FastAPI + +from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication +from a2a.server.events.event_queue import EventQueue +from a2a.server.request_handlers import GrpcHandler +from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, +) +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.task_updater import TaskUpdater +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + AgentProvider, + AgentSkill, + Part, + a2a_pb2_grpc, +) + + +logger = logging.getLogger(__name__) + + +class SampleAgentExecutor(AgentExecutor): + """Sample agent executor logic similar to the a2a-js sample.""" + + def __init__(self) -> None: + self.running_tasks: set[str] = set() + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + """Cancels a task.""" + task_id = context.task_id + if task_id in self.running_tasks: + self.running_tasks.remove(task_id) + + updater = TaskUpdater( + event_queue=event_queue, + task_id=task_id or '', + context_id=context.context_id or '', + ) + await updater.cancel() + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + """Executes a task inline.""" + user_message = context.message + task_id = context.task_id + context_id = context.context_id + + if not user_message or not task_id or not context_id: + return + + self.running_tasks.add(task_id) + + logger.info( + '[SampleAgentExecutor] Processing message %s for task %s (context: %s)', + user_message.message_id, + task_id, + context_id, + ) + + updater = TaskUpdater( + event_queue=event_queue, + task_id=task_id, + context_id=context_id, + ) + + working_message = updater.new_agent_message( + parts=[Part(text='Processing your question...')] + ) + await updater.start_work(message=working_message) + + query = context.get_user_input() + + agent_reply_text = self._parse_input(query) + await asyncio.sleep(1) + + if task_id not in self.running_tasks: + return + + await updater.add_artifact( + parts=[Part(text=agent_reply_text)], + name='response', + last_chunk=True, + ) + await updater.complete() + + logger.info( + '[SampleAgentExecutor] Task %s finished with state: completed', + task_id, + ) + + def _parse_input(self, query: str) -> str: + if not query: + return 'Hello! Please provide a message for me to respond to.' + + ql = query.lower() + if 'hello' in ql or 'hi' in ql: + return 'Hello World! Nice to meet you!' + if 'how are you' in ql: + return ( + "I'm doing great! Thanks for asking. How can I help you today?" + ) + if 'goodbye' in ql or 'bye' in ql: + return 'Goodbye! Have a wonderful day!' + return f"Hello World! You said: '{query}'. Thanks for your message!" + + +async def serve( + host: str = '127.0.0.1', + port: int = 41241, + grpc_port: int = 50051, + compat_grpc_port: int = 50052, +) -> None: + """Run the Sample Agent server with mounted JSON-RPC, HTTP+JSON and gRPC transports.""" + agent_card = AgentCard( + name='Sample Agent', + description='A sample agent to test the stream functionality.', + provider=AgentProvider( + organization='A2A Samples', url='https://example.com' + ), + version='1.0.0', + capabilities=AgentCapabilities( + streaming=True, push_notifications=False + ), + default_input_modes=['text'], + default_output_modes=['text', 'task-status'], + skills=[ + AgentSkill( + id='sample_agent', + name='Sample Agent', + description='Say hi.', + tags=['sample'], + examples=['hi'], + input_modes=['text'], + output_modes=['text', 'task-status'], + ) + ], + supported_interfaces=[ + AgentInterface( + protocol_binding='GRPC', + protocol_version='1.0', + url=f'{host}:{grpc_port}', + ), + AgentInterface( + protocol_binding='GRPC', + protocol_version='0.3', + url=f'{host}:{compat_grpc_port}', + ), + AgentInterface( + protocol_binding='JSONRPC', + protocol_version='1.0', + url=f'http://{host}:{port}/a2a/jsonrpc/', + ), + AgentInterface( + protocol_binding='JSONRPC', + protocol_version='0.3', + url=f'http://{host}:{port}/a2a/jsonrpc/', + ), + AgentInterface( + protocol_binding='HTTP+JSON', + protocol_version='1.0', + url=f'http://{host}:{port}/a2a/rest/', + ), + AgentInterface( + protocol_binding='HTTP+JSON', + protocol_version='0.3', + url=f'http://{host}:{port}/a2a/rest/', + ), + ], + ) + + task_store = InMemoryTaskStore() + request_handler = DefaultRequestHandler( + agent_executor=SampleAgentExecutor(), task_store=task_store + ) + + rest_app_builder = A2ARESTFastAPIApplication( + agent_card=agent_card, + http_handler=request_handler, + enable_v0_3_compat=True, + ) + rest_app = rest_app_builder.build() + + jsonrpc_app_builder = A2AFastAPIApplication( + agent_card=agent_card, + http_handler=request_handler, + enable_v0_3_compat=True, + ) + + app = FastAPI() + jsonrpc_app_builder.add_routes_to_app(app, rpc_url='/a2a/jsonrpc/') + app.mount('/a2a/rest', rest_app) + + grpc_server = grpc.aio.server() + grpc_server.add_insecure_port(f'{host}:{grpc_port}') + servicer = GrpcHandler(agent_card, request_handler) + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, grpc_server) + + compat_grpc_server = grpc.aio.server() + compat_grpc_server.add_insecure_port(f'{host}:{compat_grpc_port}') + compat_servicer = CompatGrpcHandler(agent_card, request_handler) + a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server( + compat_servicer, compat_grpc_server + ) + + config = uvicorn.Config(app, host=host, port=port) + uvicorn_server = uvicorn.Server(config) + + logger.info('Starting Sample Agent servers:') + logger.info(' - HTTP on http://%s:%s', host, port) + logger.info(' - gRPC on %s:%s', host, grpc_port) + logger.info(' - gRPC (v0.3 compat) on %s:%s', host, compat_grpc_port) + logger.info( + 'Agent Card available at http://%s:%s/.well-known/agent-card.json', + host, + port, + ) + + await asyncio.gather( + grpc_server.start(), + compat_grpc_server.start(), + uvicorn_server.serve(), + ) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + with contextlib.suppress(KeyboardInterrupt): + asyncio.run(serve())