Skip to content

Commit 97b82e2

Browse files
committed
Add conversion functions as constructor parameters
1 parent 7686573 commit 97b82e2

4 files changed

Lines changed: 89 additions & 106 deletions

File tree

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# ruff: noqa: PLC0415
2-
import inspect
32
import logging
43

54
from typing import TYPE_CHECKING
@@ -25,8 +24,7 @@
2524
"or 'pip install a2a-sdk[sql]'"
2625
) from e
2726

28-
if TYPE_CHECKING:
29-
from collections.abc import Callable
27+
from collections.abc import Callable
3028

3129
from a2a.compat.v0_3.conversions import (
3230
compat_push_notification_config_model_to_core,
@@ -62,17 +60,34 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
6260
config_model: type[PushNotificationConfigModel]
6361
_fernet: 'Fernet | None'
6462
owner_resolver: OwnerResolver
63+
core_to_model_conversion: (
64+
Callable[
65+
[str, TaskPushNotificationConfig, str, 'Fernet | None'],
66+
PushNotificationConfigModel,
67+
]
68+
| None
69+
)
70+
model_to_core_conversion: (
71+
Callable[[PushNotificationConfigModel], TaskPushNotificationConfig]
72+
| None
73+
)
6574

66-
core_to_model_conversion: 'Callable[[str, TaskPushNotificationConfig, str, Fernet | None], PushNotificationConfigModel] | None' = None
67-
model_to_core_conversion: 'Callable[[PushNotificationConfigModel], TaskPushNotificationConfig] | None' = None
68-
69-
def __init__(
75+
def __init__( # noqa: PLR0913
7076
self,
7177
engine: AsyncEngine,
7278
create_table: bool = True,
7379
table_name: str = 'push_notification_configs',
7480
encryption_key: str | bytes | None = None,
7581
owner_resolver: OwnerResolver = resolve_user_scope,
82+
core_to_model_conversion: Callable[
83+
[str, TaskPushNotificationConfig, str, 'Fernet | None'],
84+
PushNotificationConfigModel,
85+
]
86+
| None = None,
87+
model_to_core_conversion: Callable[
88+
[PushNotificationConfigModel], TaskPushNotificationConfig
89+
]
90+
| None = None,
7691
) -> None:
7792
"""Initializes the DatabasePushNotificationConfigStore.
7893
@@ -84,6 +99,8 @@ def __init__(
8499
If provided, `config_data` will be encrypted in the database.
85100
The key must be a URL-safe base64-encoded 32-byte key.
86101
owner_resolver: Function to resolve the owner from the context.
102+
core_to_model_conversion: Optional function to convert a TaskPushNotificationConfig to a TaskPushNotificationConfigModel.
103+
model_to_core_conversion: Optional function to convert a TaskPushNotificationConfigModel to a TaskPushNotificationConfig.
87104
"""
88105
logger.debug(
89106
'Initializing DatabasePushNotificationConfigStore with existing engine, table: %s',
@@ -102,6 +119,8 @@ def __init__(
102119
else create_push_notification_config_model(table_name)
103120
)
104121
self._fernet = None
122+
self.core_to_model_conversion = core_to_model_conversion
123+
self.model_to_core_conversion = model_to_core_conversion
105124

106125
if encryption_key:
107126
try:
@@ -156,12 +175,10 @@ def _to_orm(
156175
157176
The config data is serialized to JSON bytes, and encrypted if a key is configured.
158177
"""
159-
if conversion := self.core_to_model_conversion:
160-
# If it's a bound method of this instance, call the underlying function
161-
# to avoid passing 'self' twice.
162-
if inspect.ismethod(conversion):
163-
return conversion.__func__(task_id, config, owner, self._fernet)
164-
return conversion(task_id, config, owner, self._fernet)
178+
if self.core_to_model_conversion:
179+
return self.core_to_model_conversion(
180+
task_id, config, owner, self._fernet
181+
)
165182

166183
json_payload = MessageToJson(config).encode('utf-8')
167184

@@ -185,12 +202,8 @@ def _from_orm(
185202
186203
Handles decryption if a key is configured, with a fallback to plain JSON.
187204
"""
188-
if conversion := self.model_to_core_conversion:
189-
# If it's a bound method of this instance, call the underlying function
190-
# to avoid passing 'self' twice.
191-
if inspect.ismethod(conversion):
192-
return conversion.__func__(model_instance)
193-
return conversion(model_instance)
205+
if self.model_to_core_conversion:
206+
return self.model_to_core_conversion(model_instance)
194207

195208
payload = model_instance.config_data
196209

src/a2a/server/tasks/database_task_store.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import inspect
21
import logging
32

3+
from collections.abc import Callable
44
from datetime import datetime, timezone
5-
from typing import TYPE_CHECKING
65

76

87
try:
@@ -22,13 +21,11 @@
2221
"'pip install a2a-sdk[sqlite]', "
2322
"or 'pip install a2a-sdk[sql]'"
2423
) from e
25-
26-
if TYPE_CHECKING:
27-
from collections.abc import Callable
28-
2924
from google.protobuf.json_format import MessageToDict, ParseDict
3025

31-
from a2a.compat.v0_3 import conversions
26+
from a2a.compat.v0_3.conversions import (
27+
compat_task_model_to_core,
28+
)
3229
from a2a.server.context import ServerCallContext
3330
from a2a.server.models import Base, TaskModel, create_task_model
3431
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -55,16 +52,18 @@ class DatabaseTaskStore(TaskStore):
5552
_initialized: bool
5653
task_model: type[TaskModel]
5754
owner_resolver: OwnerResolver
55+
core_to_model_conversion: Callable[[Task, str], TaskModel] | None = None
56+
model_to_core_conversion: Callable[[TaskModel], Task] | None = None
5857

59-
core_to_model_conversion: 'Callable[[Task, str], TaskModel] | None' = None
60-
model_to_core_conversion: 'Callable[[TaskModel], Task] | None' = None
61-
62-
def __init__(
58+
def __init__( # noqa: PLR0913
6359
self,
6460
engine: AsyncEngine,
6561
create_table: bool = True,
6662
table_name: str = 'tasks',
6763
owner_resolver: OwnerResolver = resolve_user_scope,
64+
core_to_model_conversion: Callable[[Task, str], TaskModel]
65+
| None = None,
66+
model_to_core_conversion: Callable[[TaskModel], Task] | None = None,
6867
) -> None:
6968
"""Initializes the DatabaseTaskStore.
7069
@@ -73,6 +72,8 @@ def __init__(
7372
create_table: If true, create tasks table on initialization.
7473
table_name: Name of the database table. Defaults to 'tasks'.
7574
owner_resolver: Function to resolve the owner from the context.
75+
core_to_model_conversion: Optional function to convert a Task to a TaskModel.
76+
model_to_core_conversion: Optional function to convert a TaskModel to a Task.
7677
"""
7778
logger.debug(
7879
'Initializing DatabaseTaskStore with existing engine, table: %s',
@@ -85,6 +86,8 @@ def __init__(
8586
self.create_table = create_table
8687
self._initialized = False
8788
self.owner_resolver = owner_resolver
89+
self.core_to_model_conversion = core_to_model_conversion
90+
self.model_to_core_conversion = model_to_core_conversion
8891

8992
self.task_model = (
9093
TaskModel
@@ -117,12 +120,8 @@ async def _ensure_initialized(self) -> None:
117120

118121
def _to_orm(self, task: Task, owner: str) -> TaskModel:
119122
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
120-
if conversion := self.core_to_model_conversion:
121-
# If it's a bound method of this instance, call the underlying function
122-
# to avoid passing 'self' twice.
123-
if inspect.ismethod(conversion):
124-
return conversion.__func__(task, owner)
125-
return conversion(task, owner)
123+
if self.core_to_model_conversion:
124+
return self.core_to_model_conversion(task, owner)
126125

127126
return self.task_model(
128127
id=task.id,
@@ -145,12 +144,8 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
145144

146145
def _from_orm(self, task_model: TaskModel) -> Task:
147146
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
148-
if conversion := self.model_to_core_conversion:
149-
# If it's a bound method of this instance, call the underlying function
150-
# to avoid passing 'self' twice.
151-
if inspect.ismethod(conversion):
152-
return conversion.__func__(task_model)
153-
return conversion(task_model)
147+
if self.model_to_core_conversion:
148+
return self.model_to_core_conversion(task_model)
154149

155150
if task_model.protocol_version == '1.0':
156151
task = Task(
@@ -172,7 +167,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
172167
return task
173168

174169
# Legacy conversion
175-
return conversions.compat_task_model_to_core(task_model)
170+
return compat_task_model_to_core(task_model)
176171

177172
async def save(
178173
self, task: Task, context: ServerCallContext | None = None

tests/server/tasks/test_database_push_notification_config_store.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,6 @@ async def test_get_0_3_push_notification_config_detailed(
788788
@pytest.mark.asyncio
789789
async def test_custom_conversion():
790790
engine = MagicMock()
791-
store = DatabasePushNotificationConfigStore(engine=engine)
792791

793792
# Custom callables
794793
mock_to_orm = MagicMock(
@@ -797,33 +796,26 @@ async def test_custom_conversion():
797796
mock_from_orm = MagicMock(
798797
return_value=TaskPushNotificationConfig(id='custom_config')
799798
)
799+
store = DatabasePushNotificationConfigStore(
800+
engine=engine,
801+
core_to_model_conversion=mock_to_orm,
802+
model_to_core_conversion=mock_from_orm,
803+
)
800804

801-
DatabasePushNotificationConfigStore.core_to_model_conversion = mock_to_orm
802-
DatabasePushNotificationConfigStore.model_to_core_conversion = mock_from_orm
803-
804-
try:
805-
config = TaskPushNotificationConfig(id='orig')
806-
model = store._to_orm('t1', config, 'owner')
807-
assert model.config_id == 'c1'
808-
mock_to_orm.assert_called_once_with('t1', config, 'owner', None)
805+
config = TaskPushNotificationConfig(id='orig')
806+
model = store._to_orm('t1', config, 'owner')
807+
assert model.config_id == 'c1'
808+
mock_to_orm.assert_called_once_with('t1', config, 'owner', None)
809809

810-
model_instance = PushNotificationConfigModel(
811-
task_id='t1', config_id='c1'
812-
)
813-
loaded_config = store._from_orm(model_instance)
814-
assert loaded_config.id == 'custom_config'
815-
mock_from_orm.assert_called_once_with(model_instance)
816-
finally:
817-
# Reset class variables
818-
DatabasePushNotificationConfigStore.core_to_model_conversion = None
819-
DatabasePushNotificationConfigStore.model_to_core_conversion = None
810+
model_instance = PushNotificationConfigModel(task_id='t1', config_id='c1')
811+
loaded_config = store._from_orm(model_instance)
812+
assert loaded_config.id == 'custom_config'
813+
mock_from_orm.assert_called_once_with(model_instance)
820814

821815

822-
@pytest.mark.parametrize('assignment_type', ['class', 'instance'])
823816
@pytest.mark.asyncio
824817
async def test_core_to_0_3_model_conversion(
825818
db_store_parameterized: DatabasePushNotificationConfigStore,
826-
assignment_type: str,
827819
) -> None:
828820
"""Test storing and retrieving push notification configs in v0.3 format using conversion utilities.
829821
@@ -834,14 +826,9 @@ async def test_core_to_0_3_model_conversion(
834826
store = db_store_parameterized
835827

836828
# Set the v0.3 persistence utilities
837-
if assignment_type == 'class':
838-
DatabasePushNotificationConfigStore.core_to_model_conversion = (
839-
core_to_compat_push_notification_config_model
840-
)
841-
else:
842-
store.core_to_model_conversion = (
843-
core_to_compat_push_notification_config_model
844-
)
829+
store.core_to_model_conversion = (
830+
core_to_compat_push_notification_config_model
831+
)
845832

846833
task_id = 'v03-persistence-task'
847834
config_id = 'c1'
@@ -852,6 +839,7 @@ async def test_core_to_0_3_model_conversion(
852839
)
853840
# 1. Save the config (will use core_to_compat_push_notification_config_model)
854841
await store.set_info(task_id, original_config, MINIMAL_CALL_CONTEXT)
842+
855843
# 2. Verify it's stored in v0.3 format directly in DB
856844
async with store.async_session_maker() as session:
857845
db_model = await session.get(store.config_model, (task_id, config_id))
@@ -868,16 +856,15 @@ async def test_core_to_0_3_model_conversion(
868856
assert data['id'] == 'c1'
869857
assert data['token'] == 'legacy-token'
870858
assert 'taskId' not in data
859+
871860
# 3. Retrieve the config (will use compat_push_notification_config_model_to_core)
872861
retrieved_configs = await store.get_info(task_id, MINIMAL_CALL_CONTEXT)
873862
assert len(retrieved_configs) == 1
874863
retrieved = retrieved_configs[0]
875864
assert retrieved.id == original_config.id
876865
assert retrieved.url == original_config.url
877866
assert retrieved.token == original_config.token
867+
878868
# Reset conversion attributes
879-
if assignment_type == 'class':
880-
DatabasePushNotificationConfigStore.core_to_model_conversion = None
881-
else:
882-
store.core_to_model_conversion = None
869+
store.core_to_model_conversion = None
883870
await store.delete_info(task_id, MINIMAL_CALL_CONTEXT)

0 commit comments

Comments
 (0)