11from datetime import datetime
2- from typing import TYPE_CHECKING , Any , Generic , TypeVar
2+ from typing import TYPE_CHECKING , Any
33
44
55if TYPE_CHECKING :
@@ -11,26 +11,14 @@ def override(func): # noqa: ANN001, ANN201
1111 return func
1212
1313
14- from google .protobuf .json_format import MessageToDict , ParseDict , ParseError
15- from google .protobuf .message import Message as ProtoMessage
16- from pydantic import BaseModel , ValidationError
17-
18- from a2a .compat .v0_3 import conversions
19- from a2a .compat .v0_3 import types as types_v03
20- from a2a .types .a2a_pb2 import Artifact , Message , TaskStatus
21-
22-
2314try :
24- from sqlalchemy import JSON , DateTime , Dialect , Index , LargeBinary , String
15+ from sqlalchemy import JSON , DateTime , Index , LargeBinary , String
2516 from sqlalchemy .orm import (
2617 DeclarativeBase ,
2718 Mapped ,
2819 declared_attr ,
2920 mapped_column ,
3021 )
31- from sqlalchemy .types import (
32- TypeDecorator ,
33- )
3422except ImportError as e :
3523 raise ImportError (
3624 'Database models require SQLAlchemy. '
@@ -42,130 +30,6 @@ def override(func): # noqa: ANN001, ANN201
4230 ) from e
4331
4432
45- T = TypeVar ('T' )
46-
47-
48- class PydanticType (TypeDecorator [T ], Generic [T ]):
49- """SQLAlchemy type that handles Pydantic model and Protobuf message serialization."""
50-
51- impl = JSON
52- cache_ok = True
53-
54- def __init__ (self , pydantic_type : type [T ], ** kwargs : dict [str , Any ]):
55- """Initialize the PydanticType.
56-
57- Args:
58- pydantic_type: The Pydantic model or Protobuf message type to handle.
59- **kwargs: Additional arguments for TypeDecorator.
60- """
61- self .pydantic_type = pydantic_type
62- super ().__init__ (** kwargs )
63-
64- def process_bind_param (
65- self , value : T | None , dialect : Dialect
66- ) -> dict [str , Any ] | None :
67- """Convert Pydantic model or Protobuf message to a JSON-serializable dictionary for the database."""
68- if value is None :
69- return None
70- if isinstance (value , ProtoMessage ):
71- return MessageToDict (value , preserving_proto_field_name = False )
72- if isinstance (value , BaseModel ):
73- return value .model_dump (mode = 'json' )
74- return value # type: ignore[return-value]
75-
76- def process_result_value (
77- self , value : dict [str , Any ] | None , dialect : Dialect
78- ) -> T | None :
79- """Convert a JSON-like dictionary from the database back to a Pydantic model or Protobuf message."""
80- if value is None :
81- return None
82- # Check if it's a protobuf message class
83- if isinstance (self .pydantic_type , type ) and issubclass (
84- self .pydantic_type , ProtoMessage
85- ):
86- try :
87- return ParseDict (value , self .pydantic_type ()) # type: ignore[return-value]
88- except (ParseError , ValueError ):
89- # Try legacy conversion
90- legacy_map = _get_legacy_conversions ()
91- if self .pydantic_type in legacy_map :
92- legacy_type , convert_func = legacy_map [self .pydantic_type ]
93- try :
94- legacy_instance = legacy_type .model_validate (value )
95- return convert_func (legacy_instance )
96- except ValidationError :
97- pass
98- raise
99- # Assume it's a Pydantic model
100- return self .pydantic_type .model_validate (value ) # type: ignore[attr-defined]
101-
102-
103- class PydanticListType (TypeDecorator , Generic [T ]):
104- """SQLAlchemy type that handles lists of Pydantic models or Protobuf messages."""
105-
106- impl = JSON
107- cache_ok = True
108-
109- def __init__ (self , pydantic_type : type [T ], ** kwargs : dict [str , Any ]):
110- """Initialize the PydanticListType.
111-
112- Args:
113- pydantic_type: The Pydantic model or Protobuf message type for items in the list.
114- **kwargs: Additional arguments for TypeDecorator.
115- """
116- self .pydantic_type = pydantic_type
117- super ().__init__ (** kwargs )
118-
119- def process_bind_param (
120- self , value : list [T ] | None , dialect : Dialect
121- ) -> list [dict [str , Any ]] | None :
122- """Convert a list of Pydantic models or Protobuf messages to a JSON-serializable list for the DB."""
123- if value is None :
124- return None
125- result : list [dict [str , Any ]] = []
126- for item in value :
127- if isinstance (item , ProtoMessage ):
128- result .append (
129- MessageToDict (item , preserving_proto_field_name = False )
130- )
131- elif isinstance (item , BaseModel ):
132- result .append (item .model_dump (mode = 'json' ))
133- else :
134- result .append (item ) # type: ignore[arg-type]
135- return result
136-
137- def process_result_value (
138- self , value : list [dict [str , Any ]] | None , dialect : Dialect
139- ) -> list [T ] | None :
140- """Convert a JSON-like list from the DB back to a list of Pydantic models or Protobuf messages."""
141- if value is None :
142- return None
143- # Check if it's a protobuf message class
144- if isinstance (self .pydantic_type , type ) and issubclass (
145- self .pydantic_type , ProtoMessage
146- ):
147- result = []
148- legacy_map = _get_legacy_conversions ()
149- legacy_info = legacy_map .get (self .pydantic_type )
150-
151- for item in value :
152- try :
153- result .append (ParseDict (item , self .pydantic_type ()))
154- except (ParseError , ValueError ): # noqa: PERF203
155- if legacy_info :
156- legacy_type , convert_func = legacy_info
157- try :
158- legacy_instance = legacy_type .model_validate (item )
159- result .append (convert_func (legacy_instance ))
160- continue
161- except ValidationError :
162- pass
163- raise
164- return result # type: ignore[return-value]
165- # Assume it's a Pydantic model
166- return [self .pydantic_type .model_validate (item ) for item in value ] # type: ignore[attr-defined]
167-
168-
16933# Base class for all database models
17034class Base (DeclarativeBase ):
17135 """Base class for declarative models in A2A SDK."""
@@ -184,25 +48,17 @@ class TaskMixin:
18448 last_updated : Mapped [datetime | None ] = mapped_column (
18549 DateTime , nullable = True
18650 )
187-
188- # Properly typed Pydantic fields with automatic serialization
189- status : Mapped [TaskStatus ] = mapped_column (PydanticType (TaskStatus ))
190- artifacts : Mapped [list [Artifact ] | None ] = mapped_column (
191- PydanticListType (Artifact ), nullable = True
192- )
193- history : Mapped [list [Message ] | None ] = mapped_column (
194- PydanticListType (Message ), nullable = True
195- )
51+ status : Mapped [Any ] = mapped_column (JSON )
52+ artifacts : Mapped [list [Any ] | None ] = mapped_column (JSON , nullable = True )
53+ history : Mapped [list [Any ] | None ] = mapped_column (JSON , nullable = True )
19654 protocol_version : Mapped [str | None ] = mapped_column (
19755 String (16 ), nullable = True
19856 )
19957
200- # Using declared_attr to avoid conflict with Pydantic's metadata
201- @declared_attr
202- @classmethod
203- def task_metadata (cls ) -> Mapped [dict [str , Any ] | None ]:
204- """Define the 'metadata' column, avoiding name conflicts with Pydantic."""
205- return mapped_column (JSON , nullable = True , name = 'metadata' )
58+ # Using 'task_metadata' to avoid conflict with SQLAlchemy's 'Base.metadata'
59+ task_metadata : Mapped [dict [str , Any ] | None ] = mapped_column (
60+ JSON , nullable = True , name = 'metadata'
61+ )
20662
20763 @override
20864 def __repr__ (self ) -> str :
@@ -329,15 +185,3 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
329185 """Default push notification config model with standard table name."""
330186
331187 __tablename__ = 'push_notification_configs'
332-
333-
334- def _get_legacy_conversions () -> dict [type , tuple [type , Any ]]:
335- """Get the mapping of current types to their legacy counterparts and conversion functions."""
336- return {
337- TaskStatus : (
338- types_v03 .TaskStatus ,
339- conversions .to_core_task_status ,
340- ),
341- Message : (types_v03 .Message , conversions .to_core_message ),
342- Artifact : (types_v03 .Artifact , conversions .to_core_artifact ),
343- }
0 commit comments