Skip to content

Commit 9e018a1

Browse files
committed
Merge remote-tracking branch 'upstream/1.0-dev' into guglielmoc/refactor_rest_server
2 parents a8b3d0d + 8c65e84 commit 9e018a1

6 files changed

Lines changed: 544 additions & 17 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from typing import TYPE_CHECKING
6+
7+
8+
if TYPE_CHECKING:
9+
from a2a.server.context import ServerCallContext
10+
from a2a.server.tasks.task_store import TaskStore
11+
from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class CopyingTaskStoreAdapter(TaskStore):
18+
"""An adapter that ensures deep copies of tasks are passed to and returned from the underlying TaskStore.
19+
20+
This prevents accidental shared mutable state bugs where code modifies a Task object
21+
retrieved from the store without explicitly saving it, which hides missing save calls.
22+
"""
23+
24+
def __init__(self, underlying_store: TaskStore):
25+
self._store = underlying_store
26+
27+
async def save(
28+
self, task: Task, context: ServerCallContext | None = None
29+
) -> None:
30+
"""Saves a copy of the task to the underlying store."""
31+
task_copy = Task()
32+
task_copy.CopyFrom(task)
33+
await self._store.save(task_copy, context)
34+
35+
async def get(
36+
self, task_id: str, context: ServerCallContext | None = None
37+
) -> Task | None:
38+
"""Retrieves a task from the underlying store and returns a copy."""
39+
task = await self._store.get(task_id, context)
40+
if task is None:
41+
return None
42+
task_copy = Task()
43+
task_copy.CopyFrom(task)
44+
return task_copy
45+
46+
async def list(
47+
self,
48+
params: ListTasksRequest,
49+
context: ServerCallContext | None = None,
50+
) -> ListTasksResponse:
51+
"""Retrieves a list of tasks from the underlying store and returns a copy."""
52+
response = await self._store.list(params, context)
53+
response_copy = ListTasksResponse()
54+
response_copy.CopyFrom(response)
55+
return response_copy
56+
57+
async def delete(
58+
self, task_id: str, context: ServerCallContext | None = None
59+
) -> None:
60+
"""Deletes a task from the underlying store."""
61+
await self._store.delete(task_id, context)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from a2a.server.context import ServerCallContext
55
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
6+
from a2a.server.tasks.copying_task_store import CopyingTaskStoreAdapter
67
from a2a.server.tasks.task_store import TaskStore
78
from a2a.types import a2a_pb2
89
from a2a.types.a2a_pb2 import Task
@@ -14,8 +15,8 @@
1415
logger = logging.getLogger(__name__)
1516

1617

17-
class InMemoryTaskStore(TaskStore):
18-
"""In-memory implementation of TaskStore.
18+
class _InMemoryTaskStoreImpl(TaskStore):
19+
"""Internal In-memory implementation of TaskStore.
1920
2021
Stores task objects in a nested dictionary in memory, keyed by owner then task_id.
2122
Task data is lost when the server process stops.
@@ -25,8 +26,8 @@ def __init__(
2526
self,
2627
owner_resolver: OwnerResolver = resolve_user_scope,
2728
) -> None:
28-
"""Initializes the InMemoryTaskStore."""
29-
logger.debug('Initializing InMemoryTaskStore')
29+
"""Initializes the internal _InMemoryTaskStoreImpl."""
30+
logger.debug('Initializing _InMemoryTaskStoreImpl')
3031
self.tasks: dict[str, dict[str, Task]] = {}
3132
self.lock = asyncio.Lock()
3233
self.owner_resolver = owner_resolver
@@ -183,3 +184,55 @@ async def delete(
183184
if not owner_tasks:
184185
del self.tasks[owner]
185186
logger.debug('Removed empty owner %s from store.', owner)
187+
188+
189+
class InMemoryTaskStore(TaskStore):
190+
"""In-memory implementation of TaskStore.
191+
192+
Can optionally use CopyingTaskStoreAdapter to wrap the internal dictionary-based
193+
implementation, preventing shared mutable state issues by always returning and
194+
storing deep copies.
195+
"""
196+
197+
def __init__(
198+
self,
199+
owner_resolver: OwnerResolver = resolve_user_scope,
200+
use_copying: bool = True,
201+
) -> None:
202+
"""Initializes the InMemoryTaskStore.
203+
204+
Args:
205+
owner_resolver: Resolver for task owners.
206+
use_copying: If True, the store will return and save deep copies of tasks.
207+
Copying behavior is consistent with database task stores.
208+
"""
209+
self._impl = _InMemoryTaskStoreImpl(owner_resolver=owner_resolver)
210+
self._store: TaskStore = (
211+
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
212+
)
213+
214+
async def save(
215+
self, task: Task, context: ServerCallContext | None = None
216+
) -> None:
217+
"""Saves or updates a task in the store."""
218+
await self._store.save(task, context)
219+
220+
async def get(
221+
self, task_id: str, context: ServerCallContext | None = None
222+
) -> Task | None:
223+
"""Retrieves a task from the store by ID."""
224+
return await self._store.get(task_id, context)
225+
226+
async def list(
227+
self,
228+
params: a2a_pb2.ListTasksRequest,
229+
context: ServerCallContext | None = None,
230+
) -> a2a_pb2.ListTasksResponse:
231+
"""Retrieves a list of tasks from the store."""
232+
return await self._store.list(params, context)
233+
234+
async def delete(
235+
self, task_id: str, context: ServerCallContext | None = None
236+
) -> None:
237+
"""Deletes a task from the store by ID."""
238+
await self._store.delete(task_id, context)

tests/integration/test_client_server_integration.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
32
from collections.abc import AsyncGenerator
43
from typing import Any, NamedTuple
54
from unittest.mock import ANY, AsyncMock, patch
@@ -8,22 +7,25 @@
87
import httpx
98
import pytest
109
import pytest_asyncio
11-
1210
from cryptography.hazmat.primitives.asymmetric import ec
1311
from google.protobuf.json_format import MessageToDict
1412
from google.protobuf.timestamp_pb2 import Timestamp
1513

1614
from a2a.client import Client, ClientConfig
1715
from a2a.client.base_client import BaseClient
1816
from a2a.client.card_resolver import A2ACardResolver
19-
from a2a.client.client_factory import ClientFactory
2017
from a2a.client.client import ClientCallContext
18+
from a2a.client.client_factory import ClientFactory
2119
from a2a.client.service_parameters import (
2220
ServiceParametersFactory,
2321
with_a2a_extensions,
2422
)
2523
from a2a.client.transports import JsonRpcTransport, RestTransport
2624
from starlette.applications import Starlette
25+
26+
# Compat v0.3 imports for dedicated tests
27+
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
28+
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
2729
from a2a.server.routes import (
2830
create_agent_card_routes,
2931
create_jsonrpc_routes,
@@ -55,12 +57,10 @@
5557
TaskStatus,
5658
TaskStatusUpdateEvent,
5759
)
58-
from a2a.utils.constants import (
59-
TransportProtocol,
60-
)
60+
from a2a.utils.constants import TransportProtocol
6161
from a2a.utils.errors import (
62-
ExtendedAgentCardNotConfiguredError,
6362
ContentTypeNotSupportedError,
63+
ExtendedAgentCardNotConfiguredError,
6464
ExtensionSupportRequiredError,
6565
InternalError,
6666
InvalidAgentResponseError,
@@ -78,11 +78,6 @@
7878
create_signature_verifier,
7979
)
8080

81-
# Compat v0.3 imports for dedicated tests
82-
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
83-
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
84-
85-
8681
# --- Test Constants ---
8782

8883
TASK_FROM_STREAM = Task(
@@ -374,9 +369,9 @@ def grpc_03_setup(
374369
) -> TransportSetup:
375370
"""Sets up the CompatGrpcTransport and in-process 0.3 server."""
376371
server_address, handler = grpc_03_server_and_handler
377-
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
378372
from a2a.client.base_client import BaseClient
379373
from a2a.client.client import ClientConfig
374+
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
380375

381376
channel = grpc.aio.insecure_channel(server_address)
382377
transport = CompatGrpcTransport(channel=channel, agent_card=agent_card)
@@ -932,6 +927,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None:
932927
await client.close()
933928

934929

930+
@pytest.mark.asyncio
931+
@pytest.mark.parametrize(
932+
'error_cls',
933+
[
934+
TaskNotFoundError,
935+
TaskNotCancelableError,
936+
PushNotificationNotSupportedError,
937+
UnsupportedOperationError,
938+
ContentTypeNotSupportedError,
939+
InvalidAgentResponseError,
940+
ExtendedAgentCardNotConfiguredError,
941+
ExtensionSupportRequiredError,
942+
VersionNotSupportedError,
943+
],
944+
)
945+
@pytest.mark.parametrize(
946+
'handler_attr, client_method, request_params',
947+
[
948+
pytest.param(
949+
'on_message_send_stream',
950+
'send_message',
951+
SendMessageRequest(
952+
message=Message(
953+
role=Role.ROLE_USER,
954+
message_id='msg-integration-test',
955+
parts=[Part(text='Hello, integration test!')],
956+
)
957+
),
958+
id='stream',
959+
),
960+
pytest.param(
961+
'on_subscribe_to_task',
962+
'subscribe',
963+
SubscribeToTaskRequest(id='some-id'),
964+
id='subscribe',
965+
),
966+
],
967+
)
968+
async def test_client_handles_a2a_errors_streaming(
969+
transport_setups, error_cls, handler_attr, client_method, request_params
970+
) -> None:
971+
"""Integration test to verify error propagation from streaming handlers to client.
972+
973+
The handler raises an A2AError before yielding any events. All transports
974+
must propagate this as the exact error_cls, not wrapped in an ExceptionGroup
975+
or converted to a generic client error.
976+
"""
977+
client = transport_setups.client
978+
handler = transport_setups.handler
979+
980+
async def mock_generator(*args, **kwargs):
981+
raise error_cls('Test error message')
982+
yield
983+
984+
getattr(handler, handler_attr).side_effect = mock_generator
985+
986+
with pytest.raises(error_cls) as exc_info:
987+
async for _ in getattr(client, client_method)(request=request_params):
988+
pass
989+
990+
assert 'Test error message' in str(exc_info.value)
991+
992+
getattr(handler, handler_attr).side_effect = None
993+
994+
await client.close()
995+
996+
935997
@pytest.mark.asyncio
936998
@pytest.mark.parametrize(
937999
'request_kwargs, expected_error_code',

0 commit comments

Comments
 (0)