Skip to content

Commit 4e03b9d

Browse files
committed
refactor: replace PydanticType/PydanticListType with explicit JSON serialization/deserialization for protobuf models and add protocol versioning to task storage.
1 parent 0f8f9a9 commit 4e03b9d

6 files changed

Lines changed: 90 additions & 281 deletions

File tree

src/a2a/server/models.py

Lines changed: 9 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import TYPE_CHECKING, Any, Generic, TypeVar
2+
from typing import TYPE_CHECKING, Any
33

44

55
if TYPE_CHECKING:
@@ -11,26 +11,14 @@ def override(func): # noqa: ANN001, ANN201
1111
return func
1212

1313

14-
from google.protobuf.json_format import MessageToDict, ParseDict, ParseError
15-
from google.protobuf.message import Message as ProtoMessage
16-
from pydantic import BaseModel, ValidationError
17-
18-
from a2a.compat.v0_3 import conversions
19-
from a2a.compat.v0_3 import types as types_v03
20-
from a2a.types.a2a_pb2 import Artifact, Message, TaskStatus
21-
22-
2314
try:
24-
from sqlalchemy import JSON, DateTime, Dialect, Index, LargeBinary, String
15+
from sqlalchemy import JSON, DateTime, Index, LargeBinary, String
2516
from sqlalchemy.orm import (
2617
DeclarativeBase,
2718
Mapped,
2819
declared_attr,
2920
mapped_column,
3021
)
31-
from sqlalchemy.types import (
32-
TypeDecorator,
33-
)
3422
except ImportError as e:
3523
raise ImportError(
3624
'Database models require SQLAlchemy. '
@@ -42,130 +30,6 @@ def override(func): # noqa: ANN001, ANN201
4230
) from e
4331

4432

45-
T = TypeVar('T')
46-
47-
48-
class PydanticType(TypeDecorator[T], Generic[T]):
49-
"""SQLAlchemy type that handles Pydantic model and Protobuf message serialization."""
50-
51-
impl = JSON
52-
cache_ok = True
53-
54-
def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]):
55-
"""Initialize the PydanticType.
56-
57-
Args:
58-
pydantic_type: The Pydantic model or Protobuf message type to handle.
59-
**kwargs: Additional arguments for TypeDecorator.
60-
"""
61-
self.pydantic_type = pydantic_type
62-
super().__init__(**kwargs)
63-
64-
def process_bind_param(
65-
self, value: T | None, dialect: Dialect
66-
) -> dict[str, Any] | None:
67-
"""Convert Pydantic model or Protobuf message to a JSON-serializable dictionary for the database."""
68-
if value is None:
69-
return None
70-
if isinstance(value, ProtoMessage):
71-
return MessageToDict(value, preserving_proto_field_name=False)
72-
if isinstance(value, BaseModel):
73-
return value.model_dump(mode='json')
74-
return value # type: ignore[return-value]
75-
76-
def process_result_value(
77-
self, value: dict[str, Any] | None, dialect: Dialect
78-
) -> T | None:
79-
"""Convert a JSON-like dictionary from the database back to a Pydantic model or Protobuf message."""
80-
if value is None:
81-
return None
82-
# Check if it's a protobuf message class
83-
if isinstance(self.pydantic_type, type) and issubclass(
84-
self.pydantic_type, ProtoMessage
85-
):
86-
try:
87-
return ParseDict(value, self.pydantic_type()) # type: ignore[return-value]
88-
except (ParseError, ValueError):
89-
# Try legacy conversion
90-
legacy_map = _get_legacy_conversions()
91-
if self.pydantic_type in legacy_map:
92-
legacy_type, convert_func = legacy_map[self.pydantic_type]
93-
try:
94-
legacy_instance = legacy_type.model_validate(value)
95-
return convert_func(legacy_instance)
96-
except ValidationError:
97-
pass
98-
raise
99-
# Assume it's a Pydantic model
100-
return self.pydantic_type.model_validate(value) # type: ignore[attr-defined]
101-
102-
103-
class PydanticListType(TypeDecorator, Generic[T]):
104-
"""SQLAlchemy type that handles lists of Pydantic models or Protobuf messages."""
105-
106-
impl = JSON
107-
cache_ok = True
108-
109-
def __init__(self, pydantic_type: type[T], **kwargs: dict[str, Any]):
110-
"""Initialize the PydanticListType.
111-
112-
Args:
113-
pydantic_type: The Pydantic model or Protobuf message type for items in the list.
114-
**kwargs: Additional arguments for TypeDecorator.
115-
"""
116-
self.pydantic_type = pydantic_type
117-
super().__init__(**kwargs)
118-
119-
def process_bind_param(
120-
self, value: list[T] | None, dialect: Dialect
121-
) -> list[dict[str, Any]] | None:
122-
"""Convert a list of Pydantic models or Protobuf messages to a JSON-serializable list for the DB."""
123-
if value is None:
124-
return None
125-
result: list[dict[str, Any]] = []
126-
for item in value:
127-
if isinstance(item, ProtoMessage):
128-
result.append(
129-
MessageToDict(item, preserving_proto_field_name=False)
130-
)
131-
elif isinstance(item, BaseModel):
132-
result.append(item.model_dump(mode='json'))
133-
else:
134-
result.append(item) # type: ignore[arg-type]
135-
return result
136-
137-
def process_result_value(
138-
self, value: list[dict[str, Any]] | None, dialect: Dialect
139-
) -> list[T] | None:
140-
"""Convert a JSON-like list from the DB back to a list of Pydantic models or Protobuf messages."""
141-
if value is None:
142-
return None
143-
# Check if it's a protobuf message class
144-
if isinstance(self.pydantic_type, type) and issubclass(
145-
self.pydantic_type, ProtoMessage
146-
):
147-
result = []
148-
legacy_map = _get_legacy_conversions()
149-
legacy_info = legacy_map.get(self.pydantic_type)
150-
151-
for item in value:
152-
try:
153-
result.append(ParseDict(item, self.pydantic_type()))
154-
except (ParseError, ValueError): # noqa: PERF203
155-
if legacy_info:
156-
legacy_type, convert_func = legacy_info
157-
try:
158-
legacy_instance = legacy_type.model_validate(item)
159-
result.append(convert_func(legacy_instance))
160-
continue
161-
except ValidationError:
162-
pass
163-
raise
164-
return result # type: ignore[return-value]
165-
# Assume it's a Pydantic model
166-
return [self.pydantic_type.model_validate(item) for item in value] # type: ignore[attr-defined]
167-
168-
16933
# Base class for all database models
17034
class Base(DeclarativeBase):
17135
"""Base class for declarative models in A2A SDK."""
@@ -184,25 +48,17 @@ class TaskMixin:
18448
last_updated: Mapped[datetime | None] = mapped_column(
18549
DateTime, nullable=True
18650
)
187-
188-
# Properly typed Pydantic fields with automatic serialization
189-
status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus))
190-
artifacts: Mapped[list[Artifact] | None] = mapped_column(
191-
PydanticListType(Artifact), nullable=True
192-
)
193-
history: Mapped[list[Message] | None] = mapped_column(
194-
PydanticListType(Message), nullable=True
195-
)
51+
status: Mapped[Any] = mapped_column(JSON)
52+
artifacts: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True)
53+
history: Mapped[list[Any] | None] = mapped_column(JSON, nullable=True)
19654
protocol_version: Mapped[str | None] = mapped_column(
19755
String(16), nullable=True
19856
)
19957

200-
# Using declared_attr to avoid conflict with Pydantic's metadata
201-
@declared_attr
202-
@classmethod
203-
def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
204-
"""Define the 'metadata' column, avoiding name conflicts with Pydantic."""
205-
return mapped_column(JSON, nullable=True, name='metadata')
58+
# Using 'task_metadata' to avoid conflict with SQLAlchemy's 'Base.metadata'
59+
task_metadata: Mapped[dict[str, Any] | None] = mapped_column(
60+
JSON, nullable=True, name='metadata'
61+
)
20662

20763
@override
20864
def __repr__(self) -> str:
@@ -329,15 +185,3 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
329185
"""Default push notification config model with standard table name."""
330186

331187
__tablename__ = 'push_notification_configs'
332-
333-
334-
def _get_legacy_conversions() -> dict[type, tuple[type, Any]]:
335-
"""Get the mapping of current types to their legacy counterparts and conversion functions."""
336-
return {
337-
TaskStatus: (
338-
types_v03.TaskStatus,
339-
conversions.to_core_task_status,
340-
),
341-
Message: (types_v03.Message, conversions.to_core_message),
342-
Artifact: (types_v03.Artifact, conversions.to_core_artifact),
343-
}

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -167,22 +167,6 @@ def _to_orm(
167167
protocol_version='1.0',
168168
)
169169

170-
def _parse_config(
171-
self, json_payload: str, protocol_version: str | None = None
172-
) -> PushNotificationConfig:
173-
"""Parses a JSON payload into a PushNotificationConfig proto.
174-
175-
Uses protocol_version to decide between modern parsing and legacy fallback.
176-
"""
177-
if protocol_version == '1.0':
178-
return Parse(json_payload, PushNotificationConfig())
179-
180-
# Legacy case: no version or older
181-
legacy_instance = types_v03.PushNotificationConfig.model_validate_json(
182-
json_payload
183-
)
184-
return conversions.to_core_push_notification_config(legacy_instance)
185-
186170
def _from_orm(
187171
self, model_instance: PushNotificationConfigModel
188172
) -> TaskPushNotificationConfig:
@@ -355,3 +339,22 @@ async def delete_info(
355339
owner,
356340
config_id,
357341
)
342+
343+
def _parse_config(
344+
self, json_payload: str, protocol_version: str | None = None
345+
) -> TaskPushNotificationConfig:
346+
"""Parses a JSON payload into a TaskPushNotificationConfig proto.
347+
348+
Uses protocol_version to decide between modern parsing and legacy fallback.
349+
"""
350+
if protocol_version == '1.0':
351+
return Parse(json_payload, TaskPushNotificationConfig())
352+
353+
legacy_instance = (
354+
types_v03.TaskPushNotificationConfig.model_validate_json(
355+
json_payload
356+
)
357+
)
358+
return conversions.to_core_task_push_notification_config(
359+
legacy_instance
360+
)

src/a2a/server/tasks/database_task_store.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
"or 'pip install a2a-sdk[sql]'"
3232
) from e
3333

34-
from google.protobuf.json_format import MessageToDict
34+
from google.protobuf.json_format import MessageToDict, ParseDict
3535

36+
from a2a.compat.v0_3 import conversions
37+
from a2a.compat.v0_3 import types as types_v03
3638
from a2a.server.context import ServerCallContext
3739
from a2a.server.models import Base, TaskModel, create_task_model
3840
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -117,8 +119,7 @@ async def _ensure_initialized(self) -> None:
117119

118120
def _to_orm(self, task: Task, owner: str) -> TaskModel:
119121
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
120-
# Pass proto objects directly - PydanticType/PydanticListType
121-
# handle serialization via process_bind_param
122+
task_dict = MessageToDict(task)
122123
return self.task_model(
123124
id=task.id,
124125
context_id=task.context_id,
@@ -129,33 +130,52 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
129130
if task.HasField('status') and task.status.HasField('timestamp')
130131
else None
131132
),
132-
status=task.status if task.HasField('status') else None,
133-
artifacts=list(task.artifacts) if task.artifacts else [],
134-
history=list(task.history) if task.history else [],
135-
task_metadata=(
136-
MessageToDict(task.metadata) if task.metadata.fields else None
137-
),
133+
status=task_dict.get('status'),
134+
artifacts=task_dict.get('artifacts', []),
135+
history=task_dict.get('history', []),
136+
task_metadata=task_dict.get('metadata'),
137+
protocol_version='1.0',
138138
)
139139

140140
def _from_orm(self, task_model: TaskModel) -> Task:
141141
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
142-
# PydanticType/PydanticListType already deserialize to proto objects
143-
# via process_result_value, so we can construct the Task directly
142+
# Data is stored as raw JSON (dicts/lists), so we parse it manually
144143
task = Task(
145144
id=task_model.id,
146145
context_id=task_model.context_id,
147146
)
148-
if task_model.status:
149-
task.status.CopyFrom(task_model.status)
150-
if task_model.artifacts:
151-
task.artifacts.extend(task_model.artifacts)
152-
if task_model.history:
153-
task.history.extend(task_model.history)
154-
if task_model.task_metadata:
155-
task.metadata.update(
156-
cast('dict[str, Any]', task_model.task_metadata)
157-
)
158-
return task
147+
if task_model.protocol_version == '1.0':
148+
if task_model.status:
149+
ParseDict(
150+
cast('dict[str, Any]', task_model.status), task.status
151+
)
152+
if task_model.artifacts:
153+
for art_dict in cast(
154+
'list[dict[str, Any]]', task_model.artifacts
155+
):
156+
art = task.artifacts.add()
157+
ParseDict(art_dict, art)
158+
if task_model.history:
159+
for msg_dict in cast(
160+
'list[dict[str, Any]]', task_model.history
161+
):
162+
msg = task.history.add()
163+
ParseDict(msg_dict, msg)
164+
if task_model.task_metadata:
165+
task.metadata.update(
166+
cast('dict[str, Any]', task_model.task_metadata)
167+
)
168+
return task
169+
# Reconstruct legacy task from raw columns (which are dicts/lists here)
170+
legacy_task = types_v03.Task(
171+
id=task_model.id,
172+
context_id=task_model.context_id,
173+
status=cast('dict[str, Any]', task_model.status),
174+
artifacts=cast('list[dict[str, Any]]', task_model.artifacts),
175+
history=cast('list[dict[str, Any]]', task_model.history),
176+
metadata=cast('dict[str, Any]', task_model.task_metadata),
177+
)
178+
return conversions.to_core_task(legacy_task)
159179

160180
async def save(
161181
self, task: Task, context: ServerCallContext | None = None

0 commit comments

Comments (0)