Skip to content

Commit 1e19e18

Browse files
committed
fix: change last_updated type from String to DateTime
1 parent b7bbc17 commit 1e19e18

4 files changed

Lines changed: 23 additions & 17 deletions

File tree

src/a2a/migrations/versions/6419d2d130f6_add_columns_owner_last_updated.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,18 @@ def _get_inspector() -> sa.engine.reflection.Inspector:
3232
return inspector
3333

3434

35-
def _add_column(table: str, value: str, column_name: str) -> None:
35+
def _add_column(
36+
table: str,
37+
column_name: str,
38+
type_: sa.types.TypeEngine,
39+
value: str | None = None,
40+
) -> None:
3641
if not _column_exists(table, column_name):
3742
op.add_column(
3843
table,
3944
sa.Column(
4045
column_name,
41-
sa.String(128),
46+
type_,
4247
nullable=False,
4348
server_default=value,
4449
),
@@ -107,8 +112,8 @@ def upgrade() -> None:
107112
)
108113

109114
if _table_exists(tasks_table):
110-
_add_column(tasks_table, owner, 'owner')
111-
_add_column(tasks_table, '0', 'last_updated')
115+
_add_column(tasks_table, 'owner', sa.String(128), owner)
116+
_add_column(tasks_table, 'last_updated', sa.DateTime(timezone=True))
112117
_add_index(
113118
tasks_table,
114119
f'idx_{tasks_table}_owner_last_updated',
@@ -120,7 +125,9 @@ def upgrade() -> None:
120125
)
121126

122127
if _table_exists(push_notification_configs_table):
123-
_add_column(push_notification_configs_table, owner, 'owner')
128+
_add_column(
129+
push_notification_configs_table, 'owner', sa.String(128), owner
130+
)
124131
else:
125132
logging.warning(
126133
f"Table '{push_notification_configs_table}' does not exist. Skipping upgrade for this table."

src/a2a/server/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import datetime
12
from typing import TYPE_CHECKING, Any, Generic, TypeVar
23

34

@@ -18,7 +19,7 @@ def override(func): # noqa: ANN001, ANN201
1819

1920

2021
try:
21-
from sqlalchemy import JSON, Dialect, Index, LargeBinary, String
22+
from sqlalchemy import JSON, DateTime, Dialect, Index, LargeBinary, String
2223
from sqlalchemy.orm import (
2324
DeclarativeBase,
2425
Mapped,
@@ -149,7 +150,9 @@ class TaskMixin:
149150
String(16), nullable=False, default='task'
150151
)
151152
owner: Mapped[str] = mapped_column(String(128), nullable=False)
152-
last_updated: Mapped[str] = mapped_column(String(22), nullable=True)
153+
last_updated: Mapped[datetime | None] = mapped_column(
154+
DateTime(timezone=True), nullable=True
155+
)
153156

154157
# Properly typed Pydantic fields with automatic serialization
155158
status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus))

src/a2a/server/tasks/database_task_store.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
124124
kind='task', # Default kind for tasks
125125
owner=owner,
126126
last_updated=(
127-
task.status.timestamp.ToJsonString()
127+
task.status.timestamp.ToDatetime()
128128
if task.HasField('status') and task.status.HasField('timestamp')
129129
else None
130130
),
@@ -227,21 +227,16 @@ async def list(
227227
== a2a_pb2.TaskState.Name(params.status)
228228
)
229229
if params.HasField('status_timestamp_after'):
230-
last_updated_after_iso = (
231-
params.status_timestamp_after.ToJsonString()
232-
)
233-
base_stmt = base_stmt.where(
234-
timestamp_col >= last_updated_after_iso
235-
)
230+
last_updated_after = params.status_timestamp_after.ToDatetime()
231+
base_stmt = base_stmt.where(timestamp_col >= last_updated_after)
236232

237233
# Get total count
238234
count_stmt = select(func.count()).select_from(base_stmt.alias())
239235
total_count = (await session.execute(count_stmt)).scalar_one()
240236

241-
# Use coalesce to treat NULL timestamps as empty strings,
242-
# which sort last in descending order
237+
# Use nulls_last() to ensure NULL timestamps sort last in descending order
243238
stmt = base_stmt.order_by(
244-
func.coalesce(timestamp_col, '').desc(),
239+
timestamp_col.desc().nulls_last(),
245240
self.task_model.id.desc(),
246241
)
247242

tests/migrations/versions/test_migration_6419d2d130f6.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def test_migration_6419d2d130f6_full_cycle(
102102
tasks_columns = {row[1]: row for row in cursor.fetchall()}
103103
assert 'owner' in tasks_columns
104104
assert 'last_updated' in tasks_columns
105+
assert tasks_columns['last_updated'][2] == 'DATETIME'
105106

106107
# Check default value for owner in tasks
107108
# row[4] is dflt_value in PRAGMA table_info

0 commit comments

Comments
 (0)