Skip to content

Commit f647c5d

Browse files
author
Krzysztof Dziedzic
committed
test: test push notifications in itk
1 parent 24db37e commit f647c5d

4 files changed

Lines changed: 113 additions & 59 deletions

File tree

.github/workflows/itk.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ jobs:
2828
run: bash run_itk.sh
2929
working-directory: itk
3030
env:
31-
A2A_SAMPLES_REVISION: itk-v.016-alpha
31+
A2A_SAMPLES_REVISION: itk-v.02-alpha

itk/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re
3636

3737
Example:
3838
```
39-
export A2A_SAMPLES_REVISION=itk-v.015-alpha
39+
export A2A_SAMPLES_REVISION=itk-v.02-alpha
4040
```
4141

4242
### 2. Execute Tests

itk/main.py

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,21 @@
1717
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
1818
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
1919
from a2a.server.agent_execution import AgentExecutor, RequestContext
20-
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
21-
from a2a.server.routes.rest_routes import create_rest_routes
2220
from a2a.server.events import EventQueue
21+
from a2a.server.routes import (
22+
create_agent_card_routes,
23+
create_jsonrpc_routes,
24+
create_rest_routes,
25+
)
2326
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
2427
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
25-
from a2a.server.tasks import TaskUpdater
28+
from a2a.server.tasks import (
29+
TaskUpdater,
30+
BasePushNotificationSender,
31+
InMemoryPushNotificationConfigStore,
32+
)
2633
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
34+
from a2a.server.context import ServerCallContext
2735
from a2a.types import a2a_pb2_grpc
2836
from a2a.types.a2a_pb2 import (
2937
AgentCapabilities,
@@ -35,15 +43,20 @@
3543
Task,
3644
TaskState,
3745
TaskStatus,
46+
TaskPushNotificationConfig,
3847
)
3948
from a2a.utils import TransportProtocol
49+
from a2a.server.tasks.push_notification_sender import PushNotificationEvent
4050

41-
42-
log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
51+
log_level_str = os.environ.get('ITK_LOG_LEVEL', 'INFO').upper()
52+
log_level = getattr(logging, log_level_str, logging.INFO)
4353
logging.basicConfig(level=log_level)
4454
logger = logging.getLogger(__name__)
4555

4656

57+
58+
59+
4760
def extract_instruction(
4861
message: Message | None,
4962
) -> instruction_pb2.Instruction | None:
@@ -65,7 +78,7 @@ def extract_instruction(
6578
# Some clients might send it as base64 in text part
6679
raw = base64.b64decode(part.text)
6780
inst.ParseFromString(raw)
68-
except Exception:
81+
except Exception: # noqa: BLE001
6982
logger.debug(
7083
'Failed to parse instruction from binary part',
7184
exc_info=True,
@@ -80,7 +93,7 @@ def extract_instruction(
8093
raw = base64.b64decode(part.text)
8194
inst = instruction_pb2.Instruction()
8295
inst.ParseFromString(raw)
83-
except Exception:
96+
except Exception: # noqa: BLE001
8497
logger.debug(
8598
'Failed to parse instruction from text part', exc_info=True
8699
)
@@ -106,7 +119,9 @@ def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
106119
)
107120

108121

109-
async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
122+
async def handle_call_agent(
123+
call: instruction_pb2.CallAgent,
124+
) -> list[str]:
110125
"""Handles the CallAgent instruction by invoking another agent."""
111126
logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)
112127

@@ -119,7 +134,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
119134
'GRPC': TransportProtocol.GRPC,
120135
}
121136

122-
selected_transport = transport_map.get(call.transport.upper())
137+
selected_transport = transport_map.get(
138+
call.transport.upper(), TransportProtocol.JSONRPC
139+
)
123140
if selected_transport is None:
124141
raise ValueError(f'Unsupported transport: {call.transport}')
125142

@@ -131,36 +148,46 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
131148
selected_transport == TransportProtocol.GRPC
132149
)
133150

151+
if call.HasField('push_notification'):
152+
url = call.push_notification.url
153+
if not url:
154+
raise ValueError('URL not specified in push_notification behavior')
155+
if not url.startswith(('http://', 'https://')):
156+
url = f'http://{url}'
157+
config.push_notification_config = TaskPushNotificationConfig(
158+
url=f'{url}/notifications',
159+
token='itk-token',
160+
)
161+
134162
try:
135-
client = await create_client(call.agent_card_uri, client_config=config)
163+
client = await create_client(
164+
call.agent_card_uri,
165+
client_config=config,
166+
)
136167

137168
# Wrap nested instruction
138-
async with client:
139-
nested_msg = wrap_instruction_to_request(call.instruction)
140-
request = SendMessageRequest(message=nested_msg)
141-
142-
results: list[str] = []
143-
async for event in client.send_message(request):
144-
# Event is StreamResponse
145-
logger.info('Event: %s', event)
146-
stream_resp = event
147-
148-
message = None
149-
if stream_resp.HasField('message'):
150-
message = stream_resp.message
151-
elif stream_resp.HasField(
152-
'task'
153-
) and stream_resp.task.status.HasField('message'):
154-
message = stream_resp.task.status.message
155-
elif stream_resp.HasField(
156-
'status_update'
157-
) and stream_resp.status_update.status.HasField('message'):
158-
message = stream_resp.status_update.status.message
159-
160-
if message:
161-
results.extend(
162-
part.text for part in message.parts if part.text
163-
)
169+
nested_msg = wrap_instruction_to_request(call.instruction)
170+
request = SendMessageRequest(message=nested_msg)
171+
172+
173+
results = []
174+
async for event in client.send_message(request):
175+
# Event is streaming response and task
176+
logger.info('Event: %s', event)
177+
stream_resp = event
178+
179+
message = None
180+
if stream_resp.HasField('message'):
181+
message = stream_resp.message
182+
elif stream_resp.HasField('task') and stream_resp.task.status.HasField('message'):
183+
message = stream_resp.task.status.message
184+
elif stream_resp.HasField(
185+
'status_update'
186+
) and stream_resp.status_update.status.HasField('message'):
187+
message = stream_resp.status_update.status.message
188+
189+
if message:
190+
results.extend(part.text for part in message.parts if part.text)
164191

165192
except Exception as e:
166193
logger.exception('Failed to call outbound agent')
@@ -171,7 +198,9 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
171198
return results
172199

173200

174-
async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
201+
async def handle_instruction(
202+
inst: instruction_pb2.Instruction,
203+
) -> list[str]:
175204
"""Recursively handles instructions."""
176205
if inst.HasField('call_agent'):
177206
return await handle_call_agent(inst.call_agent)
@@ -302,34 +331,39 @@ async def main_async(http_port: int, grpc_port: int) -> None:
302331
name='ITK v10 Agent',
303332
description='Python agent using SDK 1.0.',
304333
version='1.0.0',
305-
capabilities=AgentCapabilities(
306-
streaming=True,
307-
push_notifications=True,
308-
extended_agent_card=True,
309-
),
334+
capabilities=AgentCapabilities(streaming=True),
310335
default_input_modes=['text/plain'],
311336
default_output_modes=['text/plain'],
312337
supported_interfaces=interfaces,
313338
)
314339

315340
task_store = InMemoryTaskStore()
341+
push_config_store = InMemoryPushNotificationConfigStore()
342+
push_sender = BasePushNotificationSender(
343+
httpx_client=httpx.AsyncClient(),
344+
config_store=push_config_store,
345+
context=ServerCallContext(),
346+
)
347+
316348
handler = DefaultRequestHandler(
317349
agent_executor=V10AgentExecutor(),
318-
task_store=task_store,
319350
agent_card=agent_card,
351+
task_store=task_store,
320352
queue_manager=InMemoryQueueManager(),
353+
push_config_store=push_config_store,
354+
push_sender=push_sender,
321355
)
322356

323357
handler_extended = DefaultRequestHandler(
324358
agent_executor=V10AgentExecutor(),
325-
task_store=task_store,
326359
agent_card=agent_card,
360+
task_store=task_store,
327361
queue_manager=InMemoryQueueManager(),
362+
push_config_store=push_config_store,
363+
push_sender=push_sender,
328364
extended_agent_card=agent_card,
329365
)
330366

331-
app = FastAPI()
332-
333367
agent_card_routes = create_agent_card_routes(
334368
agent_card=agent_card, card_url='/.well-known/agent-card.json'
335369
)
@@ -338,15 +372,16 @@ async def main_async(http_port: int, grpc_port: int) -> None:
338372
rpc_url='/',
339373
enable_v0_3_compat=True,
340374
)
341-
app.mount(
342-
'/jsonrpc',
343-
FastAPI(routes=jsonrpc_routes + agent_card_routes),
344-
)
345-
346375
rest_routes = create_rest_routes(
347376
request_handler=handler,
348377
enable_v0_3_compat=True,
349378
)
379+
380+
app = FastAPI()
381+
app.mount(
382+
'/jsonrpc',
383+
FastAPI(routes=jsonrpc_routes + agent_card_routes),
384+
)
350385
app.mount('/rest', FastAPI(routes=rest_routes + agent_card_routes))
351386

352387
server = grpc.aio.server()
@@ -365,9 +400,8 @@ async def main_async(http_port: int, grpc_port: int) -> None:
365400
grpc_port,
366401
)
367402

368-
uvicorn_log_level = os.environ.get('ITK_LOG_LEVEL', 'INFO').lower()
369403
config = uvicorn.Config(
370-
app, host='127.0.0.1', port=http_port, log_level=uvicorn_log_level
404+
app, host='127.0.0.1', port=http_port, log_level='info'
371405
)
372406
uvicorn_server = uvicorn.Server(config)
373407

itk/run_itk.sh

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,30 +119,50 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \
119119
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
120120
"traversal": "euler",
121121
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
122-
"protocols": ["jsonrpc", "grpc"]
122+
"protocols": ["jsonrpc", "grpc"],
123+
"behavior": "send_message"
123124
},
124125
{
125126
"name": "Star Topology (No Go v03) - HTTP_JSON",
126127
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
127128
"traversal": "euler",
128129
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
129-
"protocols": ["http_json"]
130+
"protocols": ["http_json"],
131+
"behavior": "send_message"
130132
},
131133
{
132134
"name": "Star Topology (Full) - JSONRPC & GRPC (Streaming)",
133135
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
134136
"traversal": "euler",
135137
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
136138
"protocols": ["jsonrpc", "grpc"],
137-
"streaming": true
139+
"streaming": true,
140+
"behavior": "send_message"
138141
},
139142
{
140143
"name": "Star Topology (No Go v03) - HTTP_JSON (Streaming)",
141144
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
142145
"traversal": "euler",
143146
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
144147
"protocols": ["http_json"],
145-
"streaming": true
148+
"streaming": true,
149+
"behavior": "send_message"
150+
},
151+
{
152+
"name": "Push Notification Test - JSONRPC & GRPC",
153+
"sdks": ["current", "python_v10", "python_v03", "go_v03"],
154+
"traversal": "euler",
155+
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
156+
"protocols": ["jsonrpc", "grpc"],
157+
"behavior": "push_notification"
158+
},
159+
{
160+
"name": "Push Notification Test - HTTP_JSON",
161+
"sdks": ["current", "python_v10", "python_v03"],
162+
"traversal": "euler",
163+
"edges": ["0->1", "0->2", "1->0", "2->0"],
164+
"protocols": ["http_json"],
165+
"behavior": "push_notification"
146166
}
147167
]
148168
}')

0 commit comments

Comments
 (0)