Skip to content

Commit 20526d7

Browse files
committed
Cursor-based pagination and update time handling
1 parent 10b638a commit 20526d7

6 files changed

Lines changed: 354 additions & 66 deletions

File tree

src/a2a/server/tasks/database_task_store.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
import logging
22

3+
from datetime import datetime, timezone
4+
35

46
try:
5-
from sqlalchemy import Table, delete, func, select
7+
from sqlalchemy import (
8+
Table,
9+
and_,
10+
delete,
11+
func,
12+
or_,
13+
select,
14+
)
615
from sqlalchemy.ext.asyncio import (
716
AsyncEngine,
817
AsyncSession,
@@ -24,6 +33,7 @@
2433
from a2a.server.tasks.task_store import TaskStore, TasksPage
2534
from a2a.types import ListTasksParams, Task
2635
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
36+
from a2a.utils.task import decode_page_token, encode_page_token
2737

2838

2939
logger = logging.getLogger(__name__)
@@ -154,44 +164,85 @@ async def list(
154164
"""Retrieves all tasks from the database."""
155165
await self._ensure_initialized()
156166
async with self.async_session_maker() as session:
157-
page_number = int(params.page_token) if params.page_token else 0
158-
page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
159-
offset = page_number * page_size
160-
161-
# Base query for filtering
162167
base_stmt = select(self.task_model)
168+
169+
# Add filters
163170
if params.context_id:
164171
base_stmt = base_stmt.where(
165172
self.task_model.context_id == params.context_id
166173
)
167-
if params.status is not None:
174+
if params.status and params.status != 'unknown':
168175
base_stmt = base_stmt.where(
169176
self.task_model.status['state'].as_string()
170177
== params.status.value
171178
)
179+
if params.last_updated_after:
180+
last_updated_after_iso = datetime.fromtimestamp(
181+
params.last_updated_after / 1000, tz=timezone.utc
182+
).isoformat()
183+
base_stmt = base_stmt.where(
184+
self.task_model.status['timestamp'].as_string()
185+
>= last_updated_after_iso
186+
)
172187

173188
# Get total count
174189
count_stmt = select(func.count()).select_from(base_stmt.alias())
175190
total_count = (await session.execute(count_stmt)).scalar_one()
176191

177-
# Get paginated results
178-
stmt = (
179-
base_stmt.order_by(self.task_model.id.desc())
180-
.limit(page_size)
181-
.offset(offset)
192+
stmt = base_stmt.order_by(
193+
self.task_model.status['timestamp']
194+
.as_string()
195+
.desc()
196+
.nulls_last(),
197+
self.task_model.id.desc(),
182198
)
199+
200+
# Get paginated results
201+
if params.page_token:
202+
start_task_id = decode_page_token(params.page_token)
203+
start_task = (
204+
await session.execute(
205+
select(self.task_model).where(
206+
self.task_model.id == start_task_id
207+
)
208+
)
209+
).scalar_one_or_none()
210+
if not start_task:
211+
raise ValueError(f'Invalid page token: {params.page_token}')
212+
if start_task.status.timestamp:
213+
stmt = stmt.where(
214+
or_(
215+
self.task_model.status['timestamp']
216+
.as_string()
217+
.is_(None),
218+
self.task_model.status['timestamp'].as_string()
219+
>= start_task.status.timestamp,
220+
)
221+
)
222+
else:
223+
stmt = stmt.where(
224+
and_(
225+
self.task_model.status['timestamp']
226+
.as_string()
227+
.is_(None),
228+
self.task_model.id <= start_task.id,
229+
)
230+
)
231+
page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
232+
stmt = stmt.limit(page_size + 1) # Add 1 for next page token
233+
183234
result = await session.execute(stmt)
184235
tasks_models = result.scalars().all()
185236
tasks = [self._from_orm(task_model) for task_model in tasks_models]
186237

187238
next_page_token = (
188-
str(page_number + 1)
189-
if total_count > (page_number + 1) * page_size
239+
encode_page_token(tasks[-1].id)
240+
if len(tasks) == page_size + 1
190241
else None
191242
)
192243

193244
return TasksPage(
194-
tasks=tasks,
245+
tasks=tasks[:page_size],
195246
total_size=total_count,
196247
next_page_token=next_page_token,
197248
)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import asyncio
22
import logging
33

4+
from datetime import datetime, timezone
5+
46
from a2a.server.context import ServerCallContext
57
from a2a.server.tasks.task_store import TaskStore, TasksPage
68
from a2a.types import ListTasksParams, Task
79
from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE
10+
from a2a.utils.task import decode_page_token, encode_page_token
811

912

1013
logger = logging.getLogger(__name__)
@@ -53,27 +56,58 @@ async def list(
5356
async with self.lock:
5457
tasks = list(self.tasks.values())
5558

56-
# Apply filtering
59+
# Filter tasks
5760
if params.context_id:
5861
tasks = [
5962
task for task in tasks if task.context_id == params.context_id
6063
]
61-
if params.status is not None:
64+
if params.status and params.status != 'unknown':
6265
tasks = [
6366
task for task in tasks if task.status.state == params.status
6467
]
68+
if params.last_updated_after:
69+
last_updated_after_iso = datetime.fromtimestamp(
70+
params.last_updated_after / 1000, tz=timezone.utc
71+
).isoformat()
72+
tasks = [
73+
task
74+
for task in tasks
75+
if (
76+
task.status.timestamp
77+
and task.status.timestamp >= last_updated_after_iso
78+
)
79+
]
6580

66-
# Apply pagination
67-
total_size = len(tasks)
68-
page_token = int(params.page_token) if params.page_token else 0
69-
page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE
70-
tasks = tasks[page_token * page_size : (page_token + 1) * page_size]
81+
# Order tasks by last update time. To ensure stable sorting, in cases where timestamps are null or not unique, do a second order comparison of IDs.
82+
tasks.sort(
83+
key=lambda task: (
84+
task.status.timestamp is not None,
85+
task.status.timestamp,
86+
task.id,
87+
),
88+
reverse=True,
89+
)
7190

91+
# Paginate tasks
92+
total_size = len(tasks)
93+
start_idx = 0
94+
if params.page_token:
95+
start_task_id = decode_page_token(params.page_token)
96+
valid_token = False
97+
for i, task in enumerate(tasks):
98+
if task.id == start_task_id:
99+
start_idx = i
100+
valid_token = True
101+
break
102+
if not valid_token:
103+
raise ValueError(f'Invalid page token: {params.page_token}')
104+
end_idx = start_idx + (params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE)
72105
next_page_token = (
73-
str(page_token + 1)
74-
if (page_token + 1) * page_size < total_size
106+
encode_page_token(tasks[end_idx].id)
107+
if end_idx < total_size
75108
else None
76109
)
110+
tasks = tasks[start_idx:end_idx]
77111

78112
return TasksPage(
79113
next_page_token=next_page_token,

src/a2a/utils/task.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Utility functions for creating A2A Task objects."""
22

3+
import binascii
34
import uuid
45

6+
from base64 import b64decode, b64encode
7+
58
from a2a.types import Artifact, Message, Task, TaskState, TaskStatus, TextPart
69

710

@@ -92,3 +95,39 @@ def apply_history_length(task: Task, history_length: int | None) -> Task:
9295
return task.model_copy(update={'history': limited_history})
9396

9497
return task
98+
99+
100+
_ENCODING = 'utf-8'
101+
102+
103+
def encode_page_token(task_id: str) -> str:
104+
"""Encodes page token for tasks pagination.
105+
106+
Args:
107+
task_id: The ID of the task.
108+
109+
Returns:
110+
The encoded page token.
111+
"""
112+
return b64encode(task_id.encode(_ENCODING)).decode(_ENCODING)
113+
114+
115+
def decode_page_token(page_token: str) -> str:
116+
"""Decodes page token for tasks pagination.
117+
118+
Args:
119+
page_token: The encoded page token.
120+
121+
Returns:
122+
The decoded task ID.
123+
"""
124+
encoded_str = page_token
125+
missing_padding = len(encoded_str) % 4
126+
if missing_padding:
127+
encoded_str += '=' * (4 - missing_padding)
128+
print(f'input: {encoded_str}')
129+
try:
130+
decoded = b64decode(encoded_str.encode(_ENCODING)).decode(_ENCODING)
131+
except (binascii.Error, UnicodeDecodeError) as e:
132+
raise ValueError('Token is not a valid base64-encoded cursor.') from e
133+
return decoded

0 commit comments

Comments
 (0)