-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathtest_copying_observability.py
More file actions
190 lines (166 loc) · 6.05 KB
/
test_copying_observability.py
File metadata and controls
190 lines (166 loc) · 6.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import httpx
import pytest
from typing import NamedTuple
from starlette.applications import Starlette
from a2a.client.client import Client, ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import TaskUpdater
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentInterface,
Artifact,
GetTaskRequest,
Message,
Part,
Role,
SendMessageRequest,
TaskState,
)
from a2a.utils import TransportProtocol
from a2a.utils.task import new_task
class MockMutatingAgentExecutor(AgentExecutor):
    """Test executor that either initialises a task or mutates it without saving.

    - For the user input ``'Init task'`` it enqueues a fresh task in the
      WORKING state and then publishes a WORKING status update.
    - For any other input it appends an artifact directly onto
      ``context.current_task`` without enqueuing any event. Whether that
      change becomes visible to clients therefore depends on whether the task
      store hands out shared or copied task objects.
    """

    async def execute(self, context: RequestContext, event_queue: EventQueue):
        """Handle one incoming message (see the class docstring for branches)."""
        assert context.task_id is not None
        assert context.context_id is not None
        updater = TaskUpdater(event_queue, context.task_id, context.context_id)

        if context.get_user_input() != 'Init task':
            # Deliberately mutate the task in place WITHOUT saving it through
            # the event queue; this is the "leak" the test looks for.
            assert context.current_task is not None
            leaked = Artifact(
                name='leaked-artifact',
                parts=[Part(text='leaked artifact')],
            )
            context.current_task.artifacts.append(leaked)
            return

        # Enqueue the task explicitly so it exists with a known state.
        initial_task = new_task(context.message)
        initial_task.id = context.task_id
        initial_task.context_id = context.context_id
        initial_task.status.state = TaskState.TASK_STATE_WORKING
        await event_queue.enqueue_event(initial_task)

        working_message = updater.new_agent_message([Part(text='task working')])
        await updater.update_status(
            TaskState.TASK_STATE_WORKING,
            message=working_message,
        )

    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Cancellation is intentionally unsupported by this test executor."""
        raise NotImplementedError('Cancellation is not supported')


@pytest.fixture
def agent_card() -> AgentCard:
    """Agent card for the in-memory test agent.

    Advertises streaming (no push notifications), plain-text I/O, no skills,
    and a single JSON-RPC interface at ``http://testserver``.
    """
    capabilities = AgentCapabilities(streaming=True, push_notifications=False)
    jsonrpc_interface = AgentInterface(
        protocol_binding=TransportProtocol.JSONRPC,
        url='http://testserver',
    )
    return AgentCard(
        name='Mutating Agent',
        description='Real in-memory integration testing.',
        version='1.0.0',
        capabilities=capabilities,
        skills=[],
        default_input_modes=['text/plain'],
        default_output_modes=['text/plain'],
        supported_interfaces=[jsonrpc_interface],
    )


class ClientSetup(NamedTuple):
    """Objects produced by ``setup_client`` for use in a single test."""

    client: Client
    task_store: InMemoryTaskStore
    use_copying: bool
    # The HTTP client backing ``client``; exposed so the caller can close it.
    # Defaults to None so existing three-field construction keeps working.
    httpx_client: httpx.AsyncClient | None = None


def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup:
    """Build an in-process A2A JSON-RPC client backed by a real request handler.

    The Starlette app is served through ``httpx.ASGITransport``, so requests
    never leave the process.

    Args:
        agent_card: Card served by the app and used to create the client.
        use_copying: Forwarded to ``InMemoryTaskStore(use_copying=...)``.

    Returns:
        A ``ClientSetup`` with the client, the task store, the ``use_copying``
        flag, and the ``httpx.AsyncClient``. The caller owns that HTTP client
        and should ``await ...aclose()`` it when finished.
    """
    task_store = InMemoryTaskStore(use_copying=use_copying)
    handler = DefaultRequestHandler(
        agent_executor=MockMutatingAgentExecutor(),
        task_store=task_store,
        agent_card=agent_card,
        queue_manager=InMemoryQueueManager(),
        extended_agent_card=agent_card,
    )
    agent_card_routes = create_agent_card_routes(
        agent_card=agent_card, card_url='/'
    )
    jsonrpc_routes = create_jsonrpc_routes(
        request_handler=handler,
        rpc_url='/',
    )
    app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes])
    httpx_client = httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app), base_url='http://testserver'
    )
    factory = ClientFactory(
        config=ClientConfig(
            httpx_client=httpx_client,
            supported_protocol_bindings=[TransportProtocol.JSONRPC],
        )
    )
    client = factory.create(agent_card)
    return ClientSetup(
        client=client,
        task_store=task_store,
        use_copying=use_copying,
        httpx_client=httpx_client,
    )


@pytest.mark.asyncio
@pytest.mark.parametrize('use_copying', [True, False])
async def test_mutation_observability(agent_card: AgentCard, use_copying: bool):
    """Check whether un-saved, in-place task mutations are visible to clients.

    When copying is disabled, the agent mutates the stored task in place and
    the change is observable by the client. When copying is enabled, the
    agent mutates a copy of the task and the change is not observable.

    It is fine to remove the ``use_copying`` parameter from the system in the
    future to make InMemoryTaskStore consistent with other task stores.
    """
    client_setup = setup_client(agent_card, use_copying)
    client = client_setup.client
    try:
        # 1. First message creates the task.
        message_to_send = Message(
            role=Role.ROLE_USER,
            message_id='msg-mut-init',
            parts=[Part(text='Init task')],
        )
        events = [
            event
            async for event in client.send_message(
                request=SendMessageRequest(message=message_to_send)
            )
        ]
        # Fail with a clear message instead of an IndexError on events[-1].
        assert events, 'expected at least one event for the init message'
        last_event = events[-1]
        assert last_event.HasField('status_update')
        task_id = last_event.status_update.task_id
        assert task_id, 'status update did not carry a task id'

        # 2. Second message mutates the task without saving it.
        message_to_send_2 = Message(
            role=Role.ROLE_USER,
            message_id='msg-mut-do',
            task_id=task_id,
            parts=[Part(text='Update task without saving it')],
        )
        # Drain the stream so the executor runs to completion.
        _ = [
            event
            async for event in client.send_message(
                request=SendMessageRequest(message=message_to_send_2)
            )
        ]

        # 3. Read the task back through the client.
        retrieved_task = await client.get_task(
            request=GetTaskRequest(id=task_id)
        )

        # 4. Visibility of the un-saved artifact depends on `use_copying`.
        if use_copying:
            # The un-saved artifact IS NOT leaked to the client.
            assert len(retrieved_task.artifacts) == 0
        else:
            # The un-saved artifact IS leaked to the client.
            assert len(retrieved_task.artifacts) == 1
            assert retrieved_task.artifacts[0].name == 'leaked-artifact'
    finally:
        # Release the HTTP client even when an assertion above fails.
        if client_setup.httpx_client is not None:
            await client_setup.httpx_client.aclose()