Skip to content

Commit 851a52c

Browse files
committed
test: add e2e client-server test
Tests basic functionality with real client and server with real handlers, only agent executor is provided in test as it'd be in a real usage.
1 parent b306e44 commit 851a52c

1 file changed

Lines changed: 309 additions & 0 deletions

File tree

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
import asyncio
2+
3+
from collections.abc import AsyncGenerator
4+
from typing import NamedTuple
5+
6+
import grpc
7+
import httpx
8+
import pytest
9+
import pytest_asyncio
10+
11+
from a2a.client.transports import (
12+
ClientTransport,
13+
GrpcTransport,
14+
JsonRpcTransport,
15+
RestTransport,
16+
)
17+
from a2a.server.agent_execution import AgentExecutor, RequestContext
18+
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
19+
from a2a.server.events import EventQueue
20+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
21+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
22+
from a2a.server.tasks import TaskUpdater
23+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
24+
from a2a.types import (
25+
AgentCapabilities,
26+
AgentCard,
27+
AgentInterface,
28+
GetTaskRequest,
29+
ListTasksRequest,
30+
Message,
31+
Part,
32+
Role,
33+
SendMessageConfiguration,
34+
SendMessageRequest,
35+
TaskState,
36+
a2a_pb2_grpc,
37+
)
38+
from a2a.utils import TRANSPORT_GRPC, TRANSPORT_HTTP_JSON, TRANSPORT_JSONRPC
39+
40+
41+
class MockAgentExecutor(AgentExecutor):
42+
async def execute(self, context: RequestContext, event_queue: EventQueue):
43+
task_updater = TaskUpdater(
44+
event_queue,
45+
context.task_id,
46+
context.context_id,
47+
)
48+
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
49+
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
50+
await task_updater.update_status(
51+
TaskState.TASK_STATE_COMPLETED,
52+
message=task_updater.new_agent_message([Part(text='done')]),
53+
)
54+
55+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
56+
pass
57+
58+
59+
@pytest.fixture
60+
def agent_card() -> AgentCard:
61+
return AgentCard(
62+
name='Integration Agent',
63+
description='Real in-memory integration testing.',
64+
version='1.0.0',
65+
capabilities=AgentCapabilities(
66+
streaming=True, push_notifications=False
67+
),
68+
skills=[],
69+
default_input_modes=['text/plain'],
70+
default_output_modes=['text/plain'],
71+
supported_interfaces=[
72+
AgentInterface(
73+
protocol_binding=TRANSPORT_HTTP_JSON,
74+
url='http://testserver',
75+
),
76+
AgentInterface(
77+
protocol_binding=TRANSPORT_JSONRPC,
78+
url='http://testserver',
79+
),
80+
AgentInterface(
81+
protocol_binding=TRANSPORT_GRPC,
82+
url='localhost:50051',
83+
),
84+
],
85+
)
86+
87+
88+
class TransportSetup(NamedTuple):
89+
"""Holds the transport and task_store for a given test."""
90+
91+
transport: ClientTransport
92+
task_store: InMemoryTaskStore
93+
94+
95+
@pytest.fixture
96+
def base_e2e_setup():
97+
task_store = InMemoryTaskStore()
98+
handler = DefaultRequestHandler(
99+
agent_executor=MockAgentExecutor(),
100+
task_store=task_store,
101+
queue_manager=InMemoryQueueManager(),
102+
)
103+
return task_store, handler
104+
105+
106+
@pytest.fixture
107+
def rest_setup(agent_card, base_e2e_setup) -> TransportSetup:
108+
task_store, handler = base_e2e_setup
109+
app_builder = A2ARESTFastAPIApplication(agent_card, handler)
110+
app = app_builder.build()
111+
httpx_client = httpx.AsyncClient(
112+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
113+
)
114+
transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card)
115+
return TransportSetup(
116+
transport=transport,
117+
task_store=task_store,
118+
)
119+
120+
121+
@pytest.fixture
122+
def jsonrpc_setup(agent_card, base_e2e_setup) -> TransportSetup:
123+
task_store, handler = base_e2e_setup
124+
app_builder = A2AFastAPIApplication(
125+
agent_card, handler, extended_agent_card=agent_card
126+
)
127+
app = app_builder.build()
128+
httpx_client = httpx.AsyncClient(
129+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
130+
)
131+
transport = JsonRpcTransport(
132+
httpx_client=httpx_client, agent_card=agent_card
133+
)
134+
return TransportSetup(
135+
transport=transport,
136+
task_store=task_store,
137+
)
138+
139+
140+
@pytest_asyncio.fixture
141+
async def grpc_setup(
142+
agent_card: AgentCard, base_e2e_setup
143+
) -> AsyncGenerator[TransportSetup, None]:
144+
task_store, handler = base_e2e_setup
145+
server = grpc.aio.server()
146+
port = server.add_insecure_port('[::]:0')
147+
server_address = f'localhost:{port}'
148+
149+
# Update the gRPC interface dynamically based on the assigned port
150+
for interface in agent_card.supported_interfaces:
151+
if interface.protocol_binding == TRANSPORT_GRPC:
152+
interface.url = server_address
153+
break
154+
else:
155+
raise ValueError('No gRPC interface found in agent card')
156+
157+
servicer = GrpcHandler(agent_card, handler)
158+
a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
159+
await server.start()
160+
161+
channel = grpc.aio.insecure_channel(server_address)
162+
transport = GrpcTransport(agent_card=agent_card, channel=channel)
163+
yield TransportSetup(
164+
transport=transport,
165+
task_store=task_store,
166+
)
167+
168+
await channel.close()
169+
await server.stop(0)
170+
171+
172+
@pytest.fixture(
173+
params=[
174+
pytest.param('rest_setup', id='REST'),
175+
pytest.param('jsonrpc_setup', id='JSON-RPC'),
176+
pytest.param('grpc_setup', id='gRPC'),
177+
]
178+
)
179+
def transport_setups(request) -> TransportSetup:
180+
"""Parametrized fixture that runs tests against all supported transports."""
181+
return request.getfixturevalue(request.param)
182+
183+
184+
@pytest.mark.asyncio
185+
async def test_end_to_end_send_message_blocking(transport_setups):
186+
transport = transport_setups.transport
187+
task_store = transport_setups.task_store
188+
189+
message_to_send = Message(
190+
role=Role.ROLE_USER,
191+
message_id='msg-e2e-blocking',
192+
parts=[Part(text='Run dummy agent!')],
193+
)
194+
configuration = SendMessageConfiguration(blocking=True)
195+
params = SendMessageRequest(
196+
message=message_to_send, configuration=configuration
197+
)
198+
199+
response = await transport.send_message(request=params)
200+
201+
task = response.task
202+
assert task.id
203+
204+
stored_task = await task_store.get(task.id)
205+
assert stored_task is not None
206+
assert stored_task.id == task.id
207+
208+
stored_task = await task_store.get(task.id)
209+
assert stored_task.status.state == TaskState.TASK_STATE_COMPLETED
210+
211+
212+
@pytest.mark.asyncio
213+
async def test_end_to_end_send_message_non_blocking(transport_setups):
214+
transport = transport_setups.transport
215+
task_store = transport_setups.task_store
216+
217+
message_to_send = Message(
218+
role=Role.ROLE_USER,
219+
message_id='msg-e2e-non-blocking',
220+
parts=[Part(text='Run dummy agent!')],
221+
)
222+
configuration = SendMessageConfiguration(blocking=False)
223+
params = SendMessageRequest(
224+
message=message_to_send, configuration=configuration
225+
)
226+
227+
response = await transport.send_message(request=params)
228+
229+
task = response.task
230+
assert task.id
231+
232+
stored_task = await task_store.get(task.id)
233+
assert stored_task is not None
234+
assert stored_task.id == task.id
235+
236+
237+
@pytest.mark.asyncio
238+
async def test_end_to_end_send_message_streaming(transport_setups):
239+
transport = transport_setups.transport
240+
task_store = transport_setups.task_store
241+
242+
message_to_send = Message(
243+
role=Role.ROLE_USER,
244+
message_id='msg-e2e-streaming',
245+
parts=[Part(text='Run dummy agent!')],
246+
)
247+
params = SendMessageRequest(message=message_to_send)
248+
249+
events = [
250+
event
251+
async for event in transport.send_message_streaming(request=params)
252+
]
253+
254+
assert len(events) > 0
255+
final_event = events[-1]
256+
257+
assert final_event.HasField('status_update')
258+
task_id = final_event.status_update.task_id
259+
assert task_id
260+
261+
stored_task = await task_store.get(task_id)
262+
assert stored_task is not None
263+
assert stored_task.id == task_id
264+
assert stored_task.status.state == TaskState.TASK_STATE_COMPLETED
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_end_to_end_get_task(transport_setups):
269+
transport = transport_setups.transport
270+
271+
message_to_send = Message(
272+
role=Role.ROLE_USER,
273+
message_id='msg-e2e-get',
274+
parts=[Part(text='Test Get Task')],
275+
)
276+
response = await transport.send_message(
277+
request=SendMessageRequest(message=message_to_send)
278+
)
279+
task_id = response.task.id
280+
281+
get_request = GetTaskRequest(id=task_id)
282+
retrieved_task = await transport.get_task(request=get_request)
283+
284+
assert retrieved_task.id == task_id
285+
assert retrieved_task.status.state in {
286+
TaskState.TASK_STATE_SUBMITTED,
287+
TaskState.TASK_STATE_WORKING,
288+
TaskState.TASK_STATE_COMPLETED,
289+
}
290+
291+
292+
@pytest.mark.asyncio
293+
async def test_end_to_end_list_tasks(transport_setups):
294+
transport = transport_setups.transport
295+
296+
for i in range(3):
297+
message = Message(
298+
role=Role.ROLE_USER,
299+
message_id=f'msg-e2e-list-{i}',
300+
parts=[Part(text=f'Test List Tasks {i}')],
301+
)
302+
await transport.send_message(
303+
request=SendMessageRequest(message=message)
304+
)
305+
306+
list_request = ListTasksRequest(page_size=10)
307+
list_response = await transport.list_tasks(request=list_request)
308+
309+
assert len(list_response.tasks) == 3

0 commit comments

Comments
 (0)