Skip to content

Commit 42aada9

Browse files
committed
refactor: improve SQLAlchemy ORM to Protobuf message mapping for JSON fields and refine model column definitions.
1 parent 335ea41 commit 42aada9

2 files changed

Lines changed: 56 additions & 23 deletions

File tree

src/a2a/server/models.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,23 @@ class TaskMixin:
4848
last_updated: Mapped[datetime | None] = mapped_column(
4949
DateTime, nullable=True
5050
)
51-
status: Mapped[Any] = mapped_column(JSON)
52-
artifacts: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True)
53-
history: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True)
51+
status: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
52+
artifacts: Mapped[list[dict[str, Any]] | None] = mapped_column(
53+
JSON, nullable=True
54+
)
55+
history: Mapped[list[dict[str, Any]] | None] = mapped_column(
56+
JSON, nullable=True
57+
)
5458
protocol_version: Mapped[str | None] = mapped_column(
5559
String(16), nullable=True
5660
)
5761

58-
# Using 'task_metadata' to avoid conflict with SQLAlchemy's 'Base.metadata'
59-
task_metadata: Mapped[dict[str, Any] | None] = mapped_column(
60-
JSON, nullable=True, name='metadata'
61-
)
62+
# Using declared_attr to avoid conflict with Pydantic's metadata
63+
@declared_attr
64+
@classmethod
65+
def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
66+
"""Define the 'metadata' column, avoiding name conflicts with Pydantic."""
67+
return mapped_column(JSON, nullable=True, name='metadata')
6268

6369
@override
6470
def __repr__(self) -> str:

src/a2a/server/tasks/database_task_store.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
from datetime import datetime, timezone
4+
from typing import Any, cast
45

56

67
try:
@@ -118,37 +119,63 @@ async def _ensure_initialized(self) -> None:
118119

119120
def _to_orm(self, task: Task, owner: str) -> TaskModel:
120121
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
121-
task_dict = MessageToDict(task)
122122
return self.task_model(
123123
id=task.id,
124124
context_id=task.context_id,
125125
kind='task', # Default kind for tasks
126126
owner=owner,
127127
last_updated=(
128128
task.status.timestamp.ToDatetime()
129-
if task.HasField('status') and task.status.HasField('timestamp')
129+
if task.status.HasField('timestamp')
130130
else None
131131
),
132-
status=task_dict.get('status'),
133-
artifacts=task_dict.get('artifacts', []),
134-
history=task_dict.get('history', []),
135-
task_metadata=task_dict.get('metadata'),
132+
status=MessageToDict(task.status),
133+
artifacts=[MessageToDict(artifact) for artifact in task.artifacts],
134+
history=[MessageToDict(history) for history in task.history],
135+
task_metadata=(
136+
MessageToDict(task.metadata) if task.metadata.fields else None
137+
),
136138
protocol_version='1.0',
137139
)
138140

139141
def _from_orm(self, task_model: TaskModel) -> Task:
140142
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
141-
task_dict = {
142-
'id': task_model.id,
143-
'context_id': task_model.context_id,
144-
'status': task_model.status,
145-
'artifacts': task_model.artifacts,
146-
'history': task_model.history,
147-
'metadata': task_model.task_metadata,
148-
}
149143
if task_model.protocol_version == '1.0':
150-
return ParseDict(task_dict, Task())
151-
legacy_task = types_v03.Task.model_validate(task_dict)
144+
task = Task(
145+
id=task_model.id,
146+
context_id=task_model.context_id,
147+
)
148+
if task_model.status:
149+
ParseDict(
150+
cast('dict[str, Any]', task_model.status), task.status
151+
)
152+
if task_model.artifacts:
153+
for art_dict in cast(
154+
'list[dict[str, Any]]', task_model.artifacts
155+
):
156+
art = task.artifacts.add()
157+
ParseDict(art_dict, art)
158+
if task_model.history:
159+
for msg_dict in cast(
160+
'list[dict[str, Any]]', task_model.history
161+
):
162+
msg = task.history.add()
163+
ParseDict(msg_dict, msg)
164+
if task_model.task_metadata:
165+
task.metadata.update(
166+
cast('dict[str, Any]', task_model.task_metadata)
167+
)
168+
return task
169+
170+
# Legacy conversion
171+
legacy_task = types_v03.Task(
172+
id=task_model.id,
173+
context_id=task_model.context_id,
174+
status=task_model.status,
175+
artifacts=task_model.artifacts or [],
176+
history=task_model.history or [],
177+
metadata=task_model.task_metadata or {},
178+
)
152179
return conversions.to_core_task(legacy_task)
153180

154181
async def save(

0 commit comments

Comments
 (0)