|
1 | 1 | import logging |
2 | 2 |
|
| 3 | +from datetime import datetime, timezone |
| 4 | + |
3 | 5 |
|
4 | 6 | try: |
5 | | - from sqlalchemy import Table, delete, func, select |
| 7 | + from sqlalchemy import ( |
| 8 | + Table, |
| 9 | + and_, |
| 10 | + delete, |
| 11 | + func, |
| 12 | + or_, |
| 13 | + select, |
| 14 | + ) |
6 | 15 | from sqlalchemy.ext.asyncio import ( |
7 | 16 | AsyncEngine, |
8 | 17 | AsyncSession, |
|
24 | 33 | from a2a.server.tasks.task_store import TaskStore, TasksPage |
25 | 34 | from a2a.types import ListTasksParams, Task |
26 | 35 | from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE |
| 36 | +from a2a.utils.task import decode_page_token, encode_page_token |
27 | 37 |
|
28 | 38 |
|
29 | 39 | logger = logging.getLogger(__name__) |
@@ -154,44 +164,85 @@ async def list( |
154 | 164 | """Retrieves all tasks from the database.""" |
155 | 165 | await self._ensure_initialized() |
156 | 166 | async with self.async_session_maker() as session: |
157 | | - page_number = int(params.page_token) if params.page_token else 0 |
158 | | - page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE |
159 | | - offset = page_number * page_size |
160 | | - |
161 | | - # Base query for filtering |
162 | 167 | base_stmt = select(self.task_model) |
| 168 | + |
| 169 | + # Add filters |
163 | 170 | if params.context_id: |
164 | 171 | base_stmt = base_stmt.where( |
165 | 172 | self.task_model.context_id == params.context_id |
166 | 173 | ) |
167 | | - if params.status is not None: |
| 174 | + if params.status and params.status != 'unknown': |
168 | 175 | base_stmt = base_stmt.where( |
169 | 176 | self.task_model.status['state'].as_string() |
170 | 177 | == params.status.value |
171 | 178 | ) |
| 179 | + if params.last_updated_after: |
| 180 | + last_updated_after_iso = datetime.fromtimestamp( |
| 181 | + params.last_updated_after / 1000, tz=timezone.utc |
| 182 | + ).isoformat() |
| 183 | + base_stmt = base_stmt.where( |
| 184 | + self.task_model.status['timestamp'].as_string() |
| 185 | + >= last_updated_after_iso |
| 186 | + ) |
172 | 187 |
|
173 | 188 | # Get total count |
174 | 189 | count_stmt = select(func.count()).select_from(base_stmt.alias()) |
175 | 190 | total_count = (await session.execute(count_stmt)).scalar_one() |
176 | 191 |
|
177 | | - # Get paginated results |
178 | | - stmt = ( |
179 | | - base_stmt.order_by(self.task_model.id.desc()) |
180 | | - .limit(page_size) |
181 | | - .offset(offset) |
| 192 | + stmt = base_stmt.order_by( |
| 193 | + self.task_model.status['timestamp'] |
| 194 | + .as_string() |
| 195 | + .desc() |
| 196 | + .nulls_last(), |
| 197 | + self.task_model.id.desc(), |
182 | 198 | ) |
| 199 | + |
| 200 | + # Get paginated results |
| 201 | + if params.page_token: |
| 202 | + start_task_id = decode_page_token(params.page_token) |
| 203 | + start_task = ( |
| 204 | + await session.execute( |
| 205 | + select(self.task_model).where( |
| 206 | + self.task_model.id == start_task_id |
| 207 | + ) |
| 208 | + ) |
| 209 | + ).scalar_one_or_none() |
| 210 | + if not start_task: |
| 211 | + raise ValueError(f'Invalid page token: {params.page_token}') |
| 212 | + if start_task.status.timestamp: |
| 213 | + stmt = stmt.where( |
| 214 | + or_( |
| 215 | + self.task_model.status['timestamp'] |
| 216 | + .as_string() |
| 217 | + .is_(None), |
| 218 | + self.task_model.status['timestamp'].as_string() |
| 219 | + >= start_task.status.timestamp, |
| 220 | + ) |
| 221 | + ) |
| 222 | + else: |
| 223 | + stmt = stmt.where( |
| 224 | + and_( |
| 225 | + self.task_model.status['timestamp'] |
| 226 | + .as_string() |
| 227 | + .is_(None), |
| 228 | + self.task_model.id <= start_task.id, |
| 229 | + ) |
| 230 | + ) |
| 231 | + page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE |
| 232 | + stmt = stmt.limit(page_size + 1) # Add 1 for next page token |
| 233 | + |
183 | 234 | result = await session.execute(stmt) |
184 | 235 | tasks_models = result.scalars().all() |
185 | 236 | tasks = [self._from_orm(task_model) for task_model in tasks_models] |
186 | 237 |
|
187 | 238 | next_page_token = ( |
188 | | - str(page_number + 1) |
189 | | - if total_count > (page_number + 1) * page_size |
| 239 | + encode_page_token(tasks[-1].id) |
| 240 | + if len(tasks) == page_size + 1 |
190 | 241 | else None |
191 | 242 | ) |
192 | 243 |
|
193 | 244 | return TasksPage( |
194 | | - tasks=tasks, |
| 245 | + tasks=tasks[:page_size], |
195 | 246 | total_size=total_count, |
196 | 247 | next_page_token=next_page_token, |
197 | 248 | ) |
|
0 commit comments