Skip to content

Commit 65e1068

Browse files
committed
Refactor database stores for robust custom conversion handling and improve v0.3 compatibility.
1 parent 3a31d97 commit 65e1068

5 files changed

Lines changed: 56 additions & 50 deletions

File tree

src/a2a/compat/v0_3/conversions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,9 +1395,7 @@ def compat_task_model_to_core(task_model: TaskModel) -> pb2_v10.Task:
13951395
compat_task = types_v03.Task(
13961396
id=task_model.id,
13971397
context_id=task_model.context_id,
1398-
status=types_v03.TaskStatus.model_validate(task_model.status)
1399-
if task_model.status
1400-
else None,
1398+
status=types_v03.TaskStatus.model_validate(task_model.status),
14011399
artifacts=(
14021400
[types_v03.Artifact.model_validate(a) for a in task_model.artifacts]
14031401
if task_model.artifacts

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# ruff: noqa: PLC0415
2+
import inspect
23
import logging
34

45
from typing import TYPE_CHECKING
@@ -13,9 +14,7 @@
1314
AsyncSession,
1415
async_sessionmaker,
1516
)
16-
from sqlalchemy.orm import (
17-
class_mapper,
18-
)
17+
from sqlalchemy.orm import class_mapper
1918
except ImportError as e:
2019
raise ImportError(
2120
'DatabasePushNotificationConfigStore requires SQLAlchemy and a database driver. '
@@ -157,12 +156,11 @@ def _to_orm(
157156
158157
The config data is serialized to JSON bytes, and encrypted if a key is configured.
159158
"""
160-
if self.core_to_model_conversion:
161-
conversion = self.core_to_model_conversion
162-
# bound method
163-
if hasattr(conversion, '__func__'):
159+
if conversion := self.core_to_model_conversion:
160+
# If it's a bound method of this instance, call the underlying function
161+
# to avoid passing 'self' twice.
162+
if inspect.ismethod(conversion):
164163
return conversion.__func__(task_id, config, owner, self._fernet)
165-
# instance method
166164
return conversion(task_id, config, owner, self._fernet)
167165

168166
json_payload = MessageToJson(config).encode('utf-8')
@@ -187,12 +185,11 @@ def _from_orm(
187185
188186
Handles decryption if a key is configured, with a fallback to plain JSON.
189187
"""
190-
if self.model_to_core_conversion:
191-
conversion = self.model_to_core_conversion
192-
# bound method
193-
if hasattr(conversion, '__func__'):
188+
if conversion := self.model_to_core_conversion:
189+
# If it's a bound method of this instance, call the underlying function
190+
# to avoid passing 'self' twice.
191+
if inspect.ismethod(conversion):
194192
return conversion.__func__(model_instance)
195-
# instance method
196193
return conversion(model_instance)
197194

198195
payload = model_instance.config_data

src/a2a/server/tasks/database_task_store.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
1+
import inspect
12
import logging
23

34
from datetime import datetime, timezone
45
from typing import TYPE_CHECKING
56

67

78
try:
8-
from sqlalchemy import (
9-
Table,
10-
and_,
11-
delete,
12-
func,
13-
or_,
14-
select,
15-
)
9+
from sqlalchemy import Table, and_, delete, func, or_, select
1610
from sqlalchemy.ext.asyncio import (
1711
AsyncEngine,
1812
AsyncSession,
1913
async_sessionmaker,
2014
)
21-
from sqlalchemy.orm import (
22-
class_mapper,
23-
)
15+
from sqlalchemy.orm import class_mapper
2416
except ImportError as e:
2517
raise ImportError(
2618
'DatabaseTaskStore requires SQLAlchemy and a database driver. '
@@ -125,12 +117,11 @@ async def _ensure_initialized(self) -> None:
125117

126118
def _to_orm(self, task: Task, owner: str) -> TaskModel:
127119
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
128-
if self.core_to_model_conversion:
129-
conversion = self.core_to_model_conversion
130-
# bound method
131-
if hasattr(conversion, '__func__'):
120+
if conversion := self.core_to_model_conversion:
121+
# If it's a bound method of this instance, call the underlying function
122+
# to avoid passing 'self' twice.
123+
if inspect.ismethod(conversion):
132124
return conversion.__func__(task, owner)
133-
# instance method
134125
return conversion(task, owner)
135126

136127
return self.task_model(
@@ -154,12 +145,11 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
154145

155146
def _from_orm(self, task_model: TaskModel) -> Task:
156147
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
157-
if self.model_to_core_conversion:
158-
conversion = self.model_to_core_conversion
159-
# bound method
160-
if hasattr(conversion, '__func__'):
148+
if conversion := self.model_to_core_conversion:
149+
# If it's a bound method of this instance, call the underlying function
150+
# to avoid passing 'self' twice.
151+
if inspect.ismethod(conversion):
161152
return conversion.__func__(task_model)
162-
# instance method
163153
return conversion(task_model)
164154

165155
if task_model.protocol_version == '1.0':

tests/server/tasks/test_database_push_notification_config_store.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -820,21 +820,29 @@ async def test_custom_conversion():
820820
DatabasePushNotificationConfigStore.model_to_core_conversion = None
821821

822822

823+
@pytest.mark.parametrize('assignment_type', ['class', 'instance'])
823824
@pytest.mark.asyncio
824825
async def test_core_to_0_3_model_conversion(
825826
db_store_parameterized: DatabasePushNotificationConfigStore,
827+
assignment_type: str,
826828
) -> None:
827829
"""Test storing and retrieving push notification configs in v0.3 format using conversion utilities.
828830
831+
Tests both class-level and instance-level assignment of the conversion function.
829832
Setting the model_to_core_conversion to compat_push_notification_config_model_to_core would be redundant as
830833
it is always called when retrieving 0.3 PushNotificationConfigs.
831834
"""
832835
store = db_store_parameterized
833836

834837
# Set the v0.3 persistence utilities
835-
DatabasePushNotificationConfigStore.core_to_model_conversion = (
836-
core_to_compat_push_notification_config_model
837-
)
838+
if assignment_type == 'class':
839+
DatabasePushNotificationConfigStore.core_to_model_conversion = (
840+
core_to_compat_push_notification_config_model
841+
)
842+
else:
843+
store.core_to_model_conversion = (
844+
core_to_compat_push_notification_config_model
845+
)
838846

839847
try:
840848
task_id = 'v03-persistence-task'
@@ -877,6 +885,10 @@ async def test_core_to_0_3_model_conversion(
877885
assert retrieved.token == original_config.token
878886

879887
finally:
880-
# Reset class variables
881-
DatabasePushNotificationConfigStore.core_to_model_conversion = None
888+
# Reset conversion attributes
889+
if assignment_type == 'class':
890+
DatabasePushNotificationConfigStore.core_to_model_conversion = None
891+
else:
892+
store.core_to_model_conversion = None
893+
882894
await store.delete_info(task_id, MINIMAL_CALL_CONTEXT)

tests/server/tasks/test_database_task_store.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -830,9 +830,6 @@ async def test_get_0_3_task_detailed(
830830
await db_store_parameterized.delete(task_id, context_user)
831831

832832

833-
# Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml).
834-
835-
836833
@pytest.mark.asyncio
837834
async def test_custom_conversion():
838835
engine = MagicMock()
@@ -863,19 +860,25 @@ async def test_custom_conversion():
863860
DatabaseTaskStore.model_to_core_conversion = None
864861

865862

863+
@pytest.mark.parametrize('assignment_type', ['class', 'instance'])
866864
@pytest.mark.asyncio
867865
async def test_core_to_0_3_model_conversion(
868866
db_store_parameterized: DatabaseTaskStore,
867+
assignment_type: str,
869868
) -> None:
870869
"""Test storing and retrieving tasks in v0.3 format using conversion utilities.
871870
872-
Setting the model_to_core_conversion class variables to compat_task_model_to_core is redundant
871+
Tests both class-level and instance-level assignment of the conversion function.
872+
Setting the model_to_core_conversion class variables to compat_task_model_to_core would be redundant
873873
as it is always called when retrieving 0.3 tasks.
874874
"""
875875
store = db_store_parameterized
876876

877877
# Set the v0.3 persistence utilities
878-
DatabaseTaskStore.core_to_model_conversion = core_to_compat_task_model
878+
if assignment_type == 'class':
879+
DatabaseTaskStore.core_to_model_conversion = core_to_compat_task_model
880+
else:
881+
store.core_to_model_conversion = core_to_compat_task_model
879882

880883
try:
881884
task_id = 'v03-persistence-task'
@@ -905,7 +908,13 @@ async def test_core_to_0_3_model_conversion(
905908
assert dict(retrieved_task.metadata) == {'key': 'value'}
906909

907910
finally:
908-
# Reset class variables
909-
DatabaseTaskStore.core_to_model_conversion = None
910-
DatabaseTaskStore.model_to_core_conversion = None
911+
# Reset conversion attributes
912+
if assignment_type == 'class':
913+
DatabaseTaskStore.core_to_model_conversion = None
914+
else:
915+
store.core_to_model_conversion = None
916+
911917
await store.delete('v03-persistence-task')
918+
919+
920+
# Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml).

0 commit comments

Comments
 (0)