Skip to content

Commit 7bd6fb6

Browse files
committed
Revert not needed
1 parent 57ad594 commit 7bd6fb6

5 files changed

Lines changed: 40 additions & 36 deletions

File tree

src/a2a/server/models.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,6 @@ class Base(DeclarativeBase):
141141

142142

143143
# TaskMixin that can be used with any table name
144-
_task_model_cache: dict[tuple[str, type], type] = {}
145-
_push_notification_config_model_cache: dict[tuple[str, type], type] = {}
146-
147-
148144
class TaskMixin:
149145
"""Mixin providing standard task columns with proper type handling."""
150146

@@ -220,9 +216,6 @@ def create_task_model(
220216
221217
TaskModel = create_task_model('tasks', MyBase)
222218
"""
223-
cache_key = (table_name, base)
224-
if cache_key in _task_model_cache:
225-
return _task_model_cache[cache_key]
226219

227220
class TaskModel(TaskMixin, base): # type: ignore
228221
__tablename__ = table_name
@@ -239,10 +232,16 @@ def __repr__(self) -> str:
239232
TaskModel.__name__ = f'TaskModel_{table_name}'
240233
TaskModel.__qualname__ = f'TaskModel_{table_name}'
241234

242-
_task_model_cache[cache_key] = TaskModel
243235
return TaskModel
244236

245237

238+
# Default TaskModel for backward compatibility
239+
class TaskModel(TaskMixin, Base):
240+
"""Default task model with standard table name."""
241+
242+
__tablename__ = 'tasks'
243+
244+
246245
# PushNotificationConfigMixin that can be used with any table name
247246
class PushNotificationConfigMixin:
248247
"""Mixin providing standard push notification config columns."""
@@ -266,9 +265,6 @@ def create_push_notification_config_model(
266265
base: type[DeclarativeBase] = Base,
267266
) -> type:
268267
"""Create a PushNotificationConfigModel class with a configurable table name."""
269-
cache_key = (table_name, base)
270-
if cache_key in _push_notification_config_model_cache:
271-
return _push_notification_config_model_cache[cache_key]
272268

273269
class PushNotificationConfigModel(PushNotificationConfigMixin, base): # type: ignore
274270
__tablename__ = table_name
@@ -277,19 +273,22 @@ class PushNotificationConfigModel(PushNotificationConfigMixin, base): # type: i
277273
def __repr__(self) -> str:
278274
"""Return a string representation of the push notification config."""
279275
return (
280-
f'<PushNotificationConfigModel[{table_name}](task_id="{self.task_id}", '
281-
f'config_id="{self.config_id}")>'
276+
f'<PushNotificationConfigModel[{table_name}]('
277+
f'task_id="{self.task_id}", config_id="{self.config_id}")>'
282278
)
283279

284-
# Set a dynamic name for better debugging
285280
PushNotificationConfigModel.__name__ = (
286281
f'PushNotificationConfigModel_{table_name}'
287282
)
288283
PushNotificationConfigModel.__qualname__ = (
289284
f'PushNotificationConfigModel_{table_name}'
290285
)
291286

292-
_push_notification_config_model_cache[cache_key] = (
293-
PushNotificationConfigModel
294-
)
295287
return PushNotificationConfigModel
288+
289+
290+
# Default PushNotificationConfigModel for backward compatibility
291+
class PushNotificationConfigModel(PushNotificationConfigMixin, Base):
292+
"""Default push notification config model with standard table name."""
293+
294+
__tablename__ = 'push_notification_configs'

src/a2a/server/tasks/database_push_notification_config_store.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import logging
44

5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING
66

77
from google.protobuf.json_format import MessageToJson, Parse
88

@@ -30,6 +30,7 @@
3030
from a2a.server.context import ServerCallContext
3131
from a2a.server.models import (
3232
Base,
33+
PushNotificationConfigModel,
3334
create_push_notification_config_model,
3435
)
3536
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
@@ -56,7 +57,7 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore):
5657
async_session_maker: async_sessionmaker[AsyncSession]
5758
create_table: bool
5859
_initialized: bool
59-
config_model: Any
60+
config_model: type[PushNotificationConfigModel]
6061
_fernet: 'Fernet | None'
6162
owner_resolver: OwnerResolver
6263

@@ -90,7 +91,11 @@ def __init__(
9091
self.create_table = create_table
9192
self._initialized = False
9293
self.owner_resolver = owner_resolver
93-
self.config_model = create_push_notification_config_model(table_name)
94+
self.config_model = (
95+
PushNotificationConfigModel
96+
if table_name == 'push_notification_configs'
97+
else create_push_notification_config_model(table_name)
98+
)
9499
self._fernet = None
95100

96101
if encryption_key:
@@ -141,7 +146,7 @@ async def _ensure_initialized(self) -> None:
141146

142147
def _to_orm(
143148
self, task_id: str, config: PushNotificationConfig, owner: str
144-
) -> Any:
149+
) -> PushNotificationConfigModel:
145150
"""Maps a PushNotificationConfig proto to a SQLAlchemy model instance.
146151
147152
The config data is serialized to JSON bytes, and encrypted if a key is configured.
@@ -160,7 +165,9 @@ def _to_orm(
160165
config_data=data_to_store,
161166
)
162167

163-
def _from_orm(self, model_instance: Any) -> PushNotificationConfig:
168+
def _from_orm(
169+
self, model_instance: PushNotificationConfigModel
170+
) -> PushNotificationConfig:
164171
"""Maps a SQLAlchemy model instance to a PushNotificationConfig proto.
165172
166173
Handles decryption if a key is configured, with a fallback to plain JSON.

src/a2a/server/tasks/database_task_store.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from google.protobuf.json_format import MessageToDict
3535

3636
from a2a.server.context import ServerCallContext
37-
from a2a.server.models import Base, create_task_model
37+
from a2a.server.models import Base, TaskModel, create_task_model
3838
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
3939
from a2a.server.tasks.task_store import TaskStore
4040
from a2a.types import a2a_pb2
@@ -56,7 +56,7 @@ class DatabaseTaskStore(TaskStore):
5656
async_session_maker: async_sessionmaker[AsyncSession]
5757
create_table: bool
5858
_initialized: bool
59-
task_model: type[Any]
59+
task_model: type[TaskModel]
6060
owner_resolver: OwnerResolver
6161

6262
def __init__(
@@ -86,7 +86,11 @@ def __init__(
8686
self._initialized = False
8787
self.owner_resolver = owner_resolver
8888

89-
self.task_model = create_task_model(table_name)
89+
self.task_model = (
90+
TaskModel
91+
if table_name == 'tasks'
92+
else create_task_model(table_name)
93+
)
9094

9195
async def initialize(self) -> None:
9296
"""Initialize the database and create the table if needed."""
@@ -111,7 +115,7 @@ async def _ensure_initialized(self) -> None:
111115
if not self._initialized:
112116
await self.initialize()
113117

114-
def _to_orm(self, task: Task, owner: str) -> Any:
118+
def _to_orm(self, task: Task, owner: str) -> TaskModel:
115119
"""Maps a Proto Task to a SQLAlchemy TaskModel instance."""
116120
# Pass proto objects directly - PydanticType/PydanticListType
117121
# handle serialization via process_bind_param
@@ -133,7 +137,7 @@ def _to_orm(self, task: Task, owner: str) -> Any:
133137
),
134138
)
135139

136-
def _from_orm(self, task_model: Any) -> Task:
140+
def _from_orm(self, task_model: TaskModel) -> Task:
137141
"""Maps a SQLAlchemy TaskModel to a Proto Task instance."""
138142
# PydanticType/PydanticListType already deserialize to proto objects
139143
# via process_result_value, so we can construct the Task directly

tests/server/tasks/test_database_push_notification_config_store.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import pytest_asyncio
1818

19-
from typing import TYPE_CHECKING, Any
2019
from _pytest.mark.structures import ParameterSet
2120

2221
# Now safe to import SQLAlchemy-dependent modules
@@ -33,10 +32,8 @@
3332

3433
from a2a.server.models import (
3534
Base,
36-
create_push_notification_config_model,
37-
)
38-
39-
PushNotificationConfigModel: Any = create_push_notification_config_model()
35+
PushNotificationConfigModel,
36+
) # Important: To get Base.metadata
4037
from a2a.server.tasks import DatabasePushNotificationConfigStore
4138
from a2a.types.a2a_pb2 import (
4239
PushNotificationConfig,

tests/server/tasks/test_database_task_store.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919

2020
from google.protobuf.json_format import MessageToDict
2121

22-
from a2a.server.models import Base, create_task_model
23-
from typing import Any
24-
25-
TaskModel: Any = create_task_model()
22+
from a2a.server.models import Base, TaskModel # Important: To get Base.metadata
2623
from a2a.server.tasks.database_task_store import DatabaseTaskStore
2724
from a2a.types.a2a_pb2 import (
2825
Artifact,

0 commit comments

Comments
 (0)