Skip to content

Commit aef2f30

Browse files
author
Krzysztof Dziedzic
committed
test: test push notifications in itk
1 parent cfdbe4c commit aef2f30

4 files changed

Lines changed: 105 additions & 54 deletions

File tree

.github/workflows/itk.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ jobs:
3131
run: bash run_itk.sh
3232
working-directory: itk
3333
env:
34-
A2A_SAMPLES_REVISION: itk-v.016-alpha
34+
A2A_SAMPLES_REVISION: itk-v.02-alpha

itk/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re
3636

3737
Example:
3838
```
39-
export A2A_SAMPLES_REVISION=itk-v.015-alpha
39+
export A2A_SAMPLES_REVISION=itk-v.02-alpha
4040
```
4141

4242
### 2. Execute Tests

itk/main.py

Lines changed: 79 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,21 @@
1717
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
1818
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
1919
from a2a.server.agent_execution import AgentExecutor, RequestContext
20-
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
21-
from a2a.server.routes.rest_routes import create_rest_routes
2220
from a2a.server.events import EventQueue
21+
from a2a.server.routes import (
22+
create_agent_card_routes,
23+
create_jsonrpc_routes,
24+
create_rest_routes,
25+
)
2326
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
2427
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
25-
from a2a.server.tasks import TaskUpdater
28+
from a2a.server.tasks import (
29+
TaskUpdater,
30+
BasePushNotificationSender,
31+
InMemoryPushNotificationConfigStore,
32+
)
2633
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
34+
from a2a.server.context import ServerCallContext
2735
from a2a.types import a2a_pb2_grpc
2836
from a2a.types.a2a_pb2 import (
2937
AgentCapabilities,
@@ -35,11 +43,12 @@
3543
Task,
3644
TaskState,
3745
TaskStatus,
46+
TaskPushNotificationConfig,
3847
)
3948
from a2a.utils import TransportProtocol
4049

41-
42-
log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
50+
log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
51+
log_level = getattr(logging, log_level_str, logging.INFO)
4352
logging.basicConfig(level=log_level)
4453
logger = logging.getLogger(__name__)
4554

@@ -106,7 +115,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
106115
)
107116

108117

109-
async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
118+
async def handle_call_agent(
119+
call: instruction_pb2.CallAgent,
120+
) -> list[str]:
110121
"""Handles the CallAgent instruction by invoking another agent."""
111122
logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)
112123

@@ -131,36 +142,47 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
131142
selected_transport == TransportProtocol.GRPC
132143
)
133144

145+
if call.HasField('push_notification'):
146+
url = call.push_notification.url
147+
if not url:
148+
raise ValueError('URL not specified in push_notification behavior')
149+
if not url.startswith(('http://', 'https://')):
150+
url = f'http://{url}'
151+
config.push_notification_config = TaskPushNotificationConfig(
152+
url=f'{url}/notifications',
153+
token='itk-token', # noqa: S106
154+
)
155+
134156
try:
135-
client = await create_client(call.agent_card_uri, client_config=config)
157+
client = await create_client(
158+
call.agent_card_uri,
159+
client_config=config,
160+
)
136161

137162
# Wrap nested instruction
138-
async with client:
139-
nested_msg = wrap_instruction_to_request(call.instruction)
140-
request = SendMessageRequest(message=nested_msg)
141-
142-
results: list[str] = []
143-
async for event in client.send_message(request):
144-
# Event is StreamResponse
145-
logger.info('Event: %s', event)
146-
stream_resp = event
147-
148-
message = None
149-
if stream_resp.HasField('message'):
150-
message = stream_resp.message
151-
elif stream_resp.HasField(
152-
'task'
153-
) and stream_resp.task.status.HasField('message'):
154-
message = stream_resp.task.status.message
155-
elif stream_resp.HasField(
156-
'status_update'
157-
) and stream_resp.status_update.status.HasField('message'):
158-
message = stream_resp.status_update.status.message
159-
160-
if message:
161-
results.extend(
162-
part.text for part in message.parts if part.text
163-
)
163+
nested_msg = wrap_instruction_to_request(call.instruction)
164+
request = SendMessageRequest(message=nested_msg)
165+
166+
results = []
167+
async for event in client.send_message(request):
168+
# Event is streaming response and task
169+
logger.info('Event: %s', event)
170+
stream_resp = event
171+
172+
message = None
173+
if stream_resp.HasField('message'):
174+
message = stream_resp.message
175+
elif stream_resp.HasField(
176+
'task'
177+
) and stream_resp.task.status.HasField('message'):
178+
message = stream_resp.task.status.message
179+
elif stream_resp.HasField(
180+
'status_update'
181+
) and stream_resp.status_update.status.HasField('message'):
182+
message = stream_resp.status_update.status.message
183+
184+
if message:
185+
results.extend(part.text for part in message.parts if part.text)
164186

165187
except Exception as e:
166188
logger.exception('Failed to call outbound agent')
@@ -171,7 +193,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
171193
return results
172194

173195

174-
async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
196+
async def handle_instruction(
197+
inst: instruction_pb2.Instruction,
198+
) -> list[str]:
175199
"""Recursively handles instructions."""
176200
if inst.HasField('call_agent'):
177201
return await handle_call_agent(inst.call_agent)
@@ -303,33 +327,40 @@ async def main_async(http_port: int, grpc_port: int) -> None:
303327
description='Python agent using SDK 1.0.',
304328
version='1.0.0',
305329
capabilities=AgentCapabilities(
306-
streaming=True,
307-
push_notifications=True,
308-
extended_agent_card=True,
330+
streaming=True, push_notifications=True, extended_agent_card=True
309331
),
310332
default_input_modes=['text/plain'],
311333
default_output_modes=['text/plain'],
312334
supported_interfaces=interfaces,
313335
)
314336

315337
task_store = InMemoryTaskStore()
338+
push_config_store = InMemoryPushNotificationConfigStore()
339+
push_sender = BasePushNotificationSender(
340+
httpx_client=httpx.AsyncClient(),
341+
config_store=push_config_store,
342+
context=ServerCallContext(),
343+
)
344+
316345
handler = DefaultRequestHandler(
317346
agent_executor=V10AgentExecutor(),
318-
task_store=task_store,
319347
agent_card=agent_card,
348+
task_store=task_store,
320349
queue_manager=InMemoryQueueManager(),
350+
push_config_store=push_config_store,
351+
push_sender=push_sender,
321352
)
322353

323354
handler_extended = DefaultRequestHandler(
324355
agent_executor=V10AgentExecutor(),
325-
task_store=task_store,
326356
agent_card=agent_card,
357+
task_store=task_store,
327358
queue_manager=InMemoryQueueManager(),
359+
push_config_store=push_config_store,
360+
push_sender=push_sender,
328361
extended_agent_card=agent_card,
329362
)
330363

331-
app = FastAPI()
332-
333364
agent_card_routes = create_agent_card_routes(
334365
agent_card=agent_card, card_url='/.well-known/agent-card.json'
335366
)
@@ -338,15 +369,16 @@ async def main_async(http_port: int, grpc_port: int) -> None:
338369
rpc_url='/',
339370
enable_v0_3_compat=True,
340371
)
341-
app.mount(
342-
'/jsonrpc',
343-
FastAPI(routes=jsonrpc_routes + agent_card_routes),
344-
)
345-
346372
rest_routes = create_rest_routes(
347373
request_handler=handler,
348374
enable_v0_3_compat=True,
349375
)
376+
377+
app = FastAPI()
378+
app.mount(
379+
'/jsonrpc',
380+
FastAPI(routes=jsonrpc_routes + agent_card_routes),
381+
)
350382
app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes))
351383

352384
server = grpc.aio.server()
@@ -365,9 +397,8 @@ async def main_async(http_port: int, grpc_port: int) -> None:
365397
grpc_port,
366398
)
367399

368-
uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower()
369400
config = uvicorn.Config(
370-
app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level
401+
app, host='127.0.0.1', port=http_port, log_level=log_level_str.lower()
371402
)
372403
uvicorn_server = uvicorn.Server(config)
373404

itk/run_itk.sh

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,30 +119,50 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \
119119
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
120120
"traversal": "euler",
121121
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
122-
"protocols": ["jsonrpc", "grpc"]
122+
"protocols": ["jsonrpc", "grpc"],
123+
"behavior": "send_message"
123124
},
124125
{
125126
"name": "Star Topology (No Go v03) - HTTP_JSON",
126127
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
127128
"traversal": "euler",
128129
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
129-
"protocols": ["http_json"]
130+
"protocols": ["http_json"],
131+
"behavior": "send_message"
130132
},
131133
{
132134
"name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)",
133135
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
134136
"traversal": "euler",
135137
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
136138
"protocols": ["jsonrpc", "grpc"],
137-
"streaming": true
139+
"streaming": true,
140+
"behavior": "send_message"
138141
},
139142
{
140143
"name": "Star Topology (No Go v03) - HTTP_JSON (Streaming)",
141144
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
142145
"traversal": "euler",
143146
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
144147
"protocols": ["http_json"],
145-
"streaming": true
148+
"streaming": true,
149+
"behavior": "send_message"
150+
},
151+
{
152+
"name": "Push Notification Test - JSONRPC & GRPC",
153+
"sdks": ["current", "python_v10", "python_v03", "go_v03"],
154+
"traversal": "euler",
155+
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
156+
"protocols": ["jsonrpc", "grpc"],
157+
"behavior": "push_notification"
158+
},
159+
{
160+
"name": "Push Notification Test - HTTP_JSON",
161+
"sdks": ["current", "python_v10", "python_v03"],
162+
"traversal": "euler",
163+
"edges": ["0->1", "0->2", "1->0", "2->0"],
164+
"protocols": ["http_json"],
165+
"behavior": "push_notification"
146166
}
147167
]
148168
}')

0 commit comments

Comments
 (0)