Skip to content

Commit 123deea

Browse files
committed
feat: Zero-Downtown migration support for DataBases
1 parent 709b1ff commit 123deea

6 files changed

Lines changed: 415 additions & 37 deletions

File tree

src/a2a/compat/v0_3/conversions.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
from google.protobuf.json_format import MessageToDict, ParseDict
66

77
from a2a.compat.v0_3 import types as types_v03
8+
from a2a.server.models import PushNotificationConfigModel, TaskModel
89
from a2a.types import a2a_pb2 as pb2_v10
10+
from cryptography.fernet import (
11+
Fernet,
12+
)
913

1014

1115
_COMPAT_TO_CORE_TASK_STATE: dict[types_v03.TaskState, Any] = {
@@ -1367,3 +1371,79 @@ def to_compat_get_extended_agent_card_request(
13671371
) -> types_v03.GetAuthenticatedExtendedCardRequest:
13681372
"""Convert get extended agent card request to v0.3 compat type."""
13691373
return types_v03.GetAuthenticatedExtendedCardRequest(id=request_id)
1374+
1375+
1376+
def core_to_compat_task_model(task: pb2_v10.Task, owner: str) -> TaskModel:
1377+
"""Converts a 1.0 core Task to a TaskModel using v0.3 JSON structure."""
1378+
compat_task = to_compat_task(task)
1379+
data = compat_task.model_dump(mode='json')
1380+
1381+
return TaskModel(
1382+
id=task.id,
1383+
context_id=task.context_id,
1384+
owner=owner,
1385+
status=data.get('status'),
1386+
history=data.get('history'),
1387+
artifacts=data.get('artifacts'),
1388+
task_metadata=data.get('metadata'),
1389+
protocol_version='0.3',
1390+
)
1391+
1392+
1393+
def compat_task_model_to_core(task_model: TaskModel) -> pb2_v10.Task:
1394+
"""Converts a TaskModel with v0.3 structure to a 1.0 core Task."""
1395+
compat_task = types_v03.Task(
1396+
id=task_model.id,
1397+
context_id=task_model.context_id,
1398+
status=types_v03.TaskStatus.model_validate(task_model.status)
1399+
if task_model.status
1400+
else None,
1401+
artifacts=(
1402+
[types_v03.Artifact.model_validate(a) for a in task_model.artifacts]
1403+
if task_model.artifacts
1404+
else []
1405+
),
1406+
history=(
1407+
[types_v03.Message.model_validate(h) for h in task_model.history]
1408+
if task_model.history
1409+
else []
1410+
),
1411+
metadata=task_model.task_metadata,
1412+
)
1413+
return to_core_task(compat_task)
1414+
1415+
1416+
def core_to_compat_push_notification_config_model(
1417+
task_id: str,
1418+
config: pb2_v10.TaskPushNotificationConfig,
1419+
owner: str,
1420+
fernet: Fernet | None = None,
1421+
) -> PushNotificationConfigModel:
1422+
"""Converts a 1.0 core TaskPushNotificationConfig to a PushNotificationConfigModel using v0.3 JSON structure."""
1423+
compat_config = to_compat_push_notification_config(config)
1424+
1425+
json_payload = compat_config.model_dump_json().encode('utf-8')
1426+
data_to_store = fernet.encrypt(json_payload) if fernet else json_payload
1427+
1428+
return PushNotificationConfigModel(
1429+
task_id=task_id,
1430+
config_id=config.id,
1431+
owner=owner,
1432+
config_data=data_to_store,
1433+
protocol_version='0.3',
1434+
)
1435+
1436+
1437+
def compat_push_notification_config_model_to_core(
1438+
model_instance: str, task_id: str
1439+
) -> pb2_v10.TaskPushNotificationConfig:
1440+
"""Converts a PushNotificationConfigModel with v0.3 structure back to a 1.0 core TaskPushNotificationConfig."""
1441+
inner_config = types_v03.PushNotificationConfig.model_validate_json(
1442+
model_instance
1443+
)
1444+
return to_core_task_push_notification_config(
1445+
types_v03.TaskPushNotificationConfig(
1446+
task_id=task_id,
1447+
push_notification_config=inner_config,
1448+
)
1449+
)

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626
"or 'pip install a2a-sdk[sql]'"
2727
) from e
2828

29-
from a2a.compat.v0_3 import conversions
30-
from a2a.compat.v0_3 import types as types_v03
29+
if TYPE_CHECKING:
30+
from collections.abc import Callable
31+
32+
from a2a.compat.v0_3.conversions import (
33+
compat_push_notification_config_model_to_core,
34+
)
3135
from a2a.server.context import ServerCallContext
3236
from a2a.server.models import (
3337
Base,
@@ -40,11 +44,8 @@
4044
)
4145
from a2a.types.a2a_pb2 import TaskPushNotificationConfig
4246

43-
4447
if TYPE_CHECKING:
4548
from cryptography.fernet import Fernet
46-
47-
4849
logger = logging.getLogger(__name__)
4950

5051

@@ -62,6 +63,9 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
6263
_fernet: 'Fernet | None'
6364
owner_resolver: OwnerResolver
6465

66+
core_to_model_conversion: 'Callable[[str, TaskPushNotificationConfig, str, Fernet | None], PushNotificationConfigModel] | None' = None
67+
model_to_core_conversion: 'Callable[[PushNotificationConfigModel], TaskPushNotificationConfig] | None' = None
68+
6569
def __init__(
6670
self,
6771
engine: AsyncEngine,
@@ -152,6 +156,14 @@ def _to_orm(
152156
153157
The config data is serialized to JSON bytes, and encrypted if a key is configured.
154158
"""
159+
if self.core_to_model_conversion:
160+
conversion = self.core_to_model_conversion
161+
# bound method
162+
if hasattr(conversion, '__func__'):
163+
return conversion.__func__(task_id, config, owner, self._fernet)
164+
# instance method
165+
return conversion(task_id, config, owner, self._fernet)
166+
155167
json_payload = MessageToJson(config).encode('utf-8')
156168

157169
if self._fernet:
@@ -174,6 +186,14 @@ def _from_orm(
174186
175187
Handles decryption if a key is configured, with a fallback to plain JSON.
176188
"""
189+
if self.model_to_core_conversion:
190+
conversion = self.model_to_core_conversion
191+
# bound method
192+
if hasattr(conversion, '__func__'):
193+
return conversion.__func__(model_instance)
194+
# instance method
195+
return conversion(model_instance)
196+
177197
payload = model_instance.config_data
178198

179199
if self._fernet:
@@ -359,12 +379,7 @@ def _parse_config(
359379
"""
360380
if protocol_version == '1.0':
361381
return Parse(json_payload, TaskPushNotificationConfig())
362-
inner_config = types_v03.PushNotificationConfig.model_validate_json(
363-
json_payload
364-
)
365-
return conversions.to_core_task_push_notification_config(
366-
types_v03.TaskPushNotificationConfig(
367-
task_id=task_id or '',
368-
push_notification_config=inner_config,
369-
)
382+
383+
return compat_push_notification_config_model_to_core(
384+
json_payload, task_id or ''
370385
)

src/a2a/server/tasks/database_task_store.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
from datetime import datetime, timezone
4+
from typing import TYPE_CHECKING
45

56

67
try:
@@ -30,10 +31,12 @@
3031
"or 'pip install a2a-sdk[sql]'"
3132
) from e
3233

34+
if TYPE_CHECKING:
35+
from collections.abc import Callable
36+
3337
from google.protobuf.json_format import MessageToDict, ParseDict
3438

3539
from a2a.compat.v0_3 import conversions
36-
from a2a.compat.v0_3 import types as types_v03
3740
from a2a.server.context import ServerCallContext
3841
from a2a.server.models import Base, TaskModel, create_task_model
3942
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -61,6 +64,9 @@ class DatabaseTaskStore(TaskStore):
6164
task_model: type[TaskModel]
6265
owner_resolver: OwnerResolver
6366

67+
core_to_model_conversion: 'Callable[[Task, str], TaskModel] | None' = None
68+
model_to_core_conversion: 'Callable[[TaskModel], Task] | None' = None
69+
6470
def __init__(
6571
self,
6672
engine: AsyncEngine,
@@ -119,6 +125,14 @@ async def _ensure_initialized(self) -> None:
119125

120126
def _to_orm(self, task: Task, owner: str) -> TaskModel:
121127
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
128+
if self.core_to_model_conversion:
129+
conversion = self.core_to_model_conversion
130+
# bound method
131+
if hasattr(conversion, '__func__'):
132+
return conversion.__func__(task, owner)
133+
# instance method
134+
return conversion(task, owner)
135+
122136
return self.task_model(
123137
id=task.id,
124138
context_id=task.context_id,
@@ -140,6 +154,14 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
140154

141155
def _from_orm(self, task_model: TaskModel) -> Task:
142156
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
157+
if self.model_to_core_conversion:
158+
conversion = self.model_to_core_conversion
159+
# bound method
160+
if hasattr(conversion, '__func__'):
161+
return conversion.__func__(task_model)
162+
# instance method
163+
return conversion(task_model)
164+
143165
if task_model.protocol_version == '1.0':
144166
task = Task(
145167
id=task_model.id,
@@ -160,29 +182,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
160182
return task
161183

162184
# Legacy conversion
163-
legacy_task = types_v03.Task(
164-
id=task_model.id,
165-
context_id=task_model.context_id,
166-
status=types_v03.TaskStatus.model_validate(task_model.status),
167-
artifacts=(
168-
[
169-
types_v03.Artifact.model_validate(a)
170-
for a in task_model.artifacts
171-
]
172-
if task_model.artifacts
173-
else []
174-
),
175-
history=(
176-
[
177-
types_v03.Message.model_validate(m)
178-
for m in task_model.history
179-
]
180-
if task_model.history
181-
else []
182-
),
183-
metadata=task_model.task_metadata or {},
184-
)
185-
return conversions.to_core_task(legacy_task)
185+
return conversions.compat_task_model_to_core(task_model)
186186

187187
async def save(
188188
self, task: Task, context: ServerCallContext | None = None

tests/compat/v0_3/test_conversions.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@
7272
to_core_task_push_notification_config,
7373
to_core_task_status,
7474
to_core_task_status_update_event,
75+
core_to_compat_task_model,
76+
compat_task_model_to_core,
77+
core_to_compat_push_notification_config_model,
78+
compat_push_notification_config_model_to_core,
7579
)
80+
from a2a.server.models import PushNotificationConfigModel, TaskModel
81+
from cryptography.fernet import Fernet
7682
from a2a.types import a2a_pb2 as pb2_v10
7783

7884

@@ -1911,3 +1917,100 @@ def test_to_core_part_unknown_part():
19111917
assert not core_part.HasField('data')
19121918
assert not core_part.HasField('raw')
19131919
assert not core_part.HasField('url')
1920+
1921+
1922+
def test_task_db_conversion():
1923+
v10_task = pb2_v10.Task(
1924+
id='task-123',
1925+
context_id='ctx-456',
1926+
status=pb2_v10.TaskStatus(
1927+
state=pb2_v10.TaskState.TASK_STATE_WORKING,
1928+
),
1929+
metadata={'m1': 'v1'},
1930+
)
1931+
owner = 'owner-789'
1932+
1933+
# Test Core -> Model
1934+
model = core_to_compat_task_model(v10_task, owner)
1935+
assert model.id == 'task-123'
1936+
assert model.context_id == 'ctx-456'
1937+
assert model.owner == owner
1938+
assert model.protocol_version == '0.3'
1939+
assert model.status['state'] == 'working'
1940+
assert model.task_metadata == {'m1': 'v1'}
1941+
1942+
# Test Model -> Core
1943+
v10_restored = compat_task_model_to_core(model)
1944+
assert v10_restored.id == v10_task.id
1945+
assert v10_restored.context_id == v10_task.context_id
1946+
assert v10_restored.status.state == v10_task.status.state
1947+
assert v10_restored.metadata == v10_task.metadata
1948+
1949+
1950+
def test_push_notification_config_db_conversion():
1951+
task_id = 'task-123'
1952+
v10_config = pb2_v10.TaskPushNotificationConfig(
1953+
id='pnc-1',
1954+
url='https://example.com/push',
1955+
token='secret-token',
1956+
)
1957+
owner = 'owner-789'
1958+
1959+
# Test Core -> Model (No encryption)
1960+
model = core_to_compat_push_notification_config_model(
1961+
task_id, v10_config, owner
1962+
)
1963+
assert model.task_id == task_id
1964+
assert model.config_id == 'pnc-1'
1965+
assert model.owner == owner
1966+
assert model.protocol_version == '0.3'
1967+
1968+
import json
1969+
1970+
data = json.loads(model.config_data.decode('utf-8'))
1971+
assert data['url'] == 'https://example.com/push'
1972+
assert data['token'] == 'secret-token'
1973+
1974+
# Test Model -> Core
1975+
v10_restored = compat_push_notification_config_model_to_core(
1976+
model.config_data.decode('utf-8'), task_id
1977+
)
1978+
assert v10_restored.id == v10_config.id
1979+
assert v10_restored.url == v10_config.url
1980+
assert v10_restored.token == v10_config.token
1981+
1982+
1983+
def test_push_notification_config_persistence_conversion_with_encryption():
1984+
task_id = 'task-123'
1985+
v10_config = pb2_v10.TaskPushNotificationConfig(
1986+
id='pnc-1',
1987+
url='https://example.com/push',
1988+
token='secret-token',
1989+
)
1990+
owner = 'owner-789'
1991+
key = Fernet.generate_key()
1992+
fernet = Fernet(key)
1993+
1994+
# Test Core -> Model (With encryption)
1995+
model = core_to_compat_push_notification_config_model(
1996+
task_id, v10_config, owner, fernet=fernet
1997+
)
1998+
assert (
1999+
model.config_data != v10_config.SerializeToString()
2000+
) # Should be encrypted
2001+
2002+
# Decrypt and verify
2003+
decrypted_data = fernet.decrypt(model.config_data)
2004+
import json
2005+
2006+
data = json.loads(decrypted_data.decode('utf-8'))
2007+
assert data['url'] == 'https://example.com/push'
2008+
assert data['token'] == 'secret-token'
2009+
2010+
# Test Model -> Core
2011+
v10_restored = compat_push_notification_config_model_to_core(
2012+
decrypted_data.decode('utf-8'), task_id
2013+
)
2014+
assert v10_restored.id == v10_config.id
2015+
assert v10_restored.url == v10_config.url
2016+
assert v10_restored.token == v10_config.token

0 commit comments

Comments
 (0)