Skip to content

Commit 0d1cb60

Browse files
committed
CopyingTaskStoreAdapter.
1 parent 405be3f commit 0d1cb60

5 files changed

Lines changed: 460 additions & 4 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from typing import TYPE_CHECKING
6+
7+
8+
if TYPE_CHECKING:
9+
from a2a.server.context import ServerCallContext
10+
from a2a.server.tasks.task_store import TaskStore
11+
from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class CopyingTaskStoreAdapter(TaskStore):
18+
"""An adapter that ensures deep copies of tasks are passed to and returned from the underlying TaskStore.
19+
20+
This prevents accidental shared mutable state bugs where code modifies a Task object
21+
retrieved from the store without explicitly saving it, which hides missing save calls.
22+
"""
23+
24+
def __init__(self, underlying_store: TaskStore):
25+
self._store = underlying_store
26+
27+
async def save(
28+
self, task: Task, context: ServerCallContext | None = None
29+
) -> None:
30+
"""Saves a copy of the task to the underlying store."""
31+
task_copy = Task()
32+
task_copy.CopyFrom(task)
33+
await self._store.save(task_copy, context)
34+
35+
async def get(
36+
self, task_id: str, context: ServerCallContext | None = None
37+
) -> Task | None:
38+
"""Retrieves a task from the underlying store and returns a copy."""
39+
task = await self._store.get(task_id, context)
40+
if task is None:
41+
return None
42+
task_copy = Task()
43+
task_copy.CopyFrom(task)
44+
return task_copy
45+
46+
async def list(
47+
self,
48+
params: ListTasksRequest,
49+
context: ServerCallContext | None = None,
50+
) -> ListTasksResponse:
51+
"""Retrieves a list of tasks from the underlying store and returns a copy."""
52+
response = await self._store.list(params, context)
53+
response_copy = ListTasksResponse()
54+
response_copy.CopyFrom(response)
55+
return response_copy
56+
57+
async def delete(
58+
self, task_id: str, context: ServerCallContext | None = None
59+
) -> None:
60+
"""Deletes a task from the underlying store."""
61+
await self._store.delete(task_id, context)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from a2a.server.context import ServerCallContext
55
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
6+
from a2a.server.tasks.copying_task_store import CopyingTaskStoreAdapter
67
from a2a.server.tasks.task_store import TaskStore
78
from a2a.types import a2a_pb2
89
from a2a.types.a2a_pb2 import Task
@@ -14,8 +15,8 @@
1415
logger = logging.getLogger(__name__)
1516

1617

17-
class InMemoryTaskStore(TaskStore):
18-
"""In-memory implementation of TaskStore.
18+
class _InMemoryTaskStoreImpl(TaskStore):
19+
"""Internal In-memory implementation of TaskStore.
1920
2021
Stores task objects in a nested dictionary in memory, keyed by owner then task_id.
2122
Task data is lost when the server process stops.
@@ -25,8 +26,8 @@ def __init__(
2526
self,
2627
owner_resolver: OwnerResolver = resolve_user_scope,
2728
) -> None:
28-
"""Initializes the InMemoryTaskStore."""
29-
logger.debug('Initializing InMemoryTaskStore')
29+
"""Initializes the internal _InMemoryTaskStoreImpl."""
30+
logger.debug('Initializing _InMemoryTaskStoreImpl')
3031
self.tasks: dict[str, dict[str, Task]] = {}
3132
self.lock = asyncio.Lock()
3233
self.owner_resolver = owner_resolver
@@ -183,3 +184,55 @@ async def delete(
183184
if not owner_tasks:
184185
del self.tasks[owner]
185186
logger.debug('Removed empty owner %s from store.', owner)
187+
188+
189+
class InMemoryTaskStore(TaskStore):
190+
"""In-memory implementation of TaskStore.
191+
192+
Can optionally use CopyingTaskStoreAdapter to wrap the internal dictionary-based
193+
implementation, preventing shared mutable state issues by always returning and
194+
storing deep copies.
195+
"""
196+
197+
def __init__(
198+
self,
199+
owner_resolver: OwnerResolver = resolve_user_scope,
200+
use_copying: bool = True,
201+
) -> None:
202+
"""Initializes the InMemoryTaskStore.
203+
204+
Args:
205+
owner_resolver: Resolver for task owners.
206+
use_copying: If True, the store will return and save deep copies of tasks.
207+
Copying behavior is consistent with database task stores.
208+
"""
209+
self._impl = _InMemoryTaskStoreImpl(owner_resolver=owner_resolver)
210+
self._store: TaskStore = (
211+
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
212+
)
213+
214+
async def save(
215+
self, task: Task, context: ServerCallContext | None = None
216+
) -> None:
217+
"""Saves or updates a task in the store."""
218+
await self._store.save(task, context)
219+
220+
async def get(
221+
self, task_id: str, context: ServerCallContext | None = None
222+
) -> Task | None:
223+
"""Retrieves a task from the store by ID."""
224+
return await self._store.get(task_id, context)
225+
226+
async def list(
227+
self,
228+
params: a2a_pb2.ListTasksRequest,
229+
context: ServerCallContext | None = None,
230+
) -> a2a_pb2.ListTasksResponse:
231+
"""Retrieves a list of tasks from the store."""
232+
return await self._store.list(params, context)
233+
234+
async def delete(
235+
self, task_id: str, context: ServerCallContext | None = None
236+
) -> None:
237+
"""Deletes a task from the store by ID."""
238+
await self._store.delete(task_id, context)
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import httpx
2+
import pytest
3+
from typing import NamedTuple
4+
5+
from a2a.client.client import Client, ClientConfig
6+
from a2a.client.client_factory import ClientFactory
7+
from a2a.server.agent_execution import AgentExecutor, RequestContext
8+
from a2a.server.apps import A2AFastAPIApplication
9+
from a2a.server.events import EventQueue
10+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
11+
from a2a.server.request_handlers import DefaultRequestHandler
12+
from a2a.server.tasks import TaskUpdater
13+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
14+
from a2a.types import (
15+
AgentCapabilities,
16+
AgentCard,
17+
AgentInterface,
18+
Artifact,
19+
GetTaskRequest,
20+
Message,
21+
Part,
22+
Role,
23+
SendMessageRequest,
24+
TaskState,
25+
)
26+
from a2a.utils import TransportProtocol
27+
28+
29+
class MockMutatingAgentExecutor(AgentExecutor):
30+
async def execute(self, context: RequestContext, event_queue: EventQueue):
31+
assert context.task_id is not None
32+
assert context.context_id is not None
33+
task_updater = TaskUpdater(
34+
event_queue,
35+
context.task_id,
36+
context.context_id,
37+
)
38+
39+
user_input = context.get_user_input()
40+
41+
if user_input == 'Init task':
42+
# Explicitly save status change to ensure task exists with some state
43+
await task_updater.update_status(
44+
TaskState.TASK_STATE_WORKING,
45+
message=task_updater.new_agent_message(
46+
[Part(text='task working')]
47+
),
48+
)
49+
else:
50+
# Mutate the task WITHOUT saving it properly
51+
context.current_task.artifacts.append(
52+
Artifact(
53+
name='leaked-artifact',
54+
parts=[Part(text='leaked artifact')],
55+
)
56+
)
57+
58+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
59+
raise NotImplementedError('Cancellation is not supported')
60+
61+
62+
@pytest.fixture
63+
def agent_card() -> AgentCard:
64+
return AgentCard(
65+
name='Mutating Agent',
66+
description='Real in-memory integration testing.',
67+
version='1.0.0',
68+
capabilities=AgentCapabilities(
69+
streaming=True, push_notifications=False
70+
),
71+
skills=[],
72+
default_input_modes=['text/plain'],
73+
default_output_modes=['text/plain'],
74+
supported_interfaces=[
75+
AgentInterface(
76+
protocol_binding=TransportProtocol.JSONRPC,
77+
url='http://testserver',
78+
),
79+
],
80+
)
81+
82+
83+
class ClientSetup(NamedTuple):
84+
client: Client
85+
task_store: InMemoryTaskStore
86+
use_copying: bool
87+
88+
89+
def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup:
90+
task_store = InMemoryTaskStore(use_copying=use_copying)
91+
handler = DefaultRequestHandler(
92+
agent_executor=MockMutatingAgentExecutor(),
93+
task_store=task_store,
94+
queue_manager=InMemoryQueueManager(),
95+
)
96+
app_builder = A2AFastAPIApplication(
97+
agent_card, handler, extended_agent_card=agent_card
98+
)
99+
app = app_builder.build()
100+
httpx_client = httpx.AsyncClient(
101+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
102+
)
103+
factory = ClientFactory(
104+
config=ClientConfig(
105+
httpx_client=httpx_client,
106+
supported_protocol_bindings=[TransportProtocol.JSONRPC],
107+
)
108+
)
109+
client = factory.create(agent_card)
110+
return ClientSetup(
111+
client=client,
112+
task_store=task_store,
113+
use_copying=use_copying,
114+
)
115+
116+
117+
@pytest.mark.asyncio
118+
@pytest.mark.parametrize('use_copying', [True, False])
119+
async def test_mutation_observability(agent_card: AgentCard, use_copying: bool):
120+
"""Tests that task mutations are observable when copying is disabled.
121+
122+
When copying is disabled, the agent mutates the task in-place and the
123+
changes are observable by the client. When copying is enabled, the agent
124+
mutates a copy of the task and the changes are not observable by the client.
125+
126+
It is ok to remove the `use_copying` parameter from the system in the future
127+
to make InMemoryTaskStore consistent with other task stores.
128+
"""
129+
client_setup = setup_client(agent_card, use_copying)
130+
client = client_setup.client
131+
132+
# 1. First message to create the task
133+
message_to_send = Message(
134+
role=Role.ROLE_USER,
135+
message_id='msg-mut-init',
136+
parts=[Part(text='Init task')],
137+
)
138+
139+
events = [
140+
event
141+
async for event in client.send_message(
142+
request=SendMessageRequest(message=message_to_send)
143+
)
144+
]
145+
146+
task = events[-1][1]
147+
assert task is not None
148+
task_id = task.id
149+
150+
# 2. Second message to mutate it
151+
message_to_send_2 = Message(
152+
role=Role.ROLE_USER,
153+
message_id='msg-mut-do',
154+
task_id=task_id,
155+
parts=[Part(text='Update task without saving it')],
156+
)
157+
158+
_ = [
159+
event
160+
async for event in client.send_message(
161+
request=SendMessageRequest(message=message_to_send_2)
162+
)
163+
]
164+
165+
# 3. Get task via client
166+
retrieved_task = await client.get_task(request=GetTaskRequest(id=task_id))
167+
168+
# 4. Assert behavior based on `use_copying`
169+
if use_copying:
170+
# The un-saved artifact IS NOT leaked to the client
171+
assert len(retrieved_task.artifacts) == 0
172+
else:
173+
# The un-saved artifact IS leaked to the client
174+
assert len(retrieved_task.artifacts) == 1
175+
assert retrieved_task.artifacts[0].name == 'leaked-artifact'

0 commit comments

Comments
 (0)