|
1 | 1 | import logging |
2 | 2 |
|
3 | 3 | from datetime import datetime, timezone |
| 4 | +from typing import Any, cast |
4 | 5 |
|
5 | 6 |
|
6 | 7 | try: |
@@ -118,37 +119,63 @@ async def _ensure_initialized(self) -> None: |
118 | 119 |
|
119 | 120 | def _to_orm(self, task: Task, owner: str) -> TaskModel: |
120 | 121 | """Maps a Proto Task to a SQLAlchemy TaskModel instance.""" |
121 | | - task_dict = MessageToDict(task) |
122 | 122 | return self.task_model( |
123 | 123 | id=task.id, |
124 | 124 | context_id=task.context_id, |
125 | 125 | kind='task', # Default kind for tasks |
126 | 126 | owner=owner, |
127 | 127 | last_updated=( |
128 | 128 | task.status.timestamp.ToDatetime() |
129 | | - if task.HasField('status') and task.status.HasField('timestamp') |
| 129 | + if task.status.HasField('timestamp') |
130 | 130 | else None |
131 | 131 | ), |
132 | | - status=task_dict.get('status'), |
133 | | - artifacts=task_dict.get('artifacts', []), |
134 | | - history=task_dict.get('history', []), |
135 | | - task_metadata=task_dict.get('metadata'), |
| 132 | + status=MessageToDict(task.status), |
| 133 | + artifacts=[MessageToDict(artifact) for artifact in task.artifacts], |
| 134 | + history=[MessageToDict(history) for history in task.history], |
| 135 | + task_metadata=( |
| 136 | + MessageToDict(task.metadata) if task.metadata.fields else None |
| 137 | + ), |
136 | 138 | protocol_version='1.0', |
137 | 139 | ) |
138 | 140 |
|
139 | 141 | def _from_orm(self, task_model: TaskModel) -> Task: |
140 | 142 | """Maps a SQLAlchemy TaskModel to a Proto Task instance.""" |
141 | | - task_dict = { |
142 | | - 'id': task_model.id, |
143 | | - 'context_id': task_model.context_id, |
144 | | - 'status': task_model.status, |
145 | | - 'artifacts': task_model.artifacts, |
146 | | - 'history': task_model.history, |
147 | | - 'metadata': task_model.task_metadata, |
148 | | - } |
149 | 143 | if task_model.protocol_version == '1.0': |
150 | | - return ParseDict(task_dict, Task()) |
151 | | - legacy_task = types_v03.Task.model_validate(task_dict) |
| 144 | + task = Task( |
| 145 | + id=task_model.id, |
| 146 | + context_id=task_model.context_id, |
| 147 | + ) |
| 148 | + if task_model.status: |
| 149 | + ParseDict( |
| 150 | + cast('dict[str, Any]', task_model.status), task.status |
| 151 | + ) |
| 152 | + if task_model.artifacts: |
| 153 | + for art_dict in cast( |
| 154 | + 'list[dict[str, Any]]', task_model.artifacts |
| 155 | + ): |
| 156 | + art = task.artifacts.add() |
| 157 | + ParseDict(art_dict, art) |
| 158 | + if task_model.history: |
| 159 | + for msg_dict in cast( |
| 160 | + 'list[dict[str, Any]]', task_model.history |
| 161 | + ): |
| 162 | + msg = task.history.add() |
| 163 | + ParseDict(msg_dict, msg) |
| 164 | + if task_model.task_metadata: |
| 165 | + task.metadata.update( |
| 166 | + cast('dict[str, Any]', task_model.task_metadata) |
| 167 | + ) |
| 168 | + return task |
| 169 | + |
| 170 | + # Legacy conversion |
| 171 | + legacy_task = types_v03.Task( |
| 172 | + id=task_model.id, |
| 173 | + context_id=task_model.context_id, |
| 174 | + status=task_model.status, |
| 175 | + artifacts=task_model.artifacts or [], |
| 176 | + history=task_model.history or [], |
| 177 | + metadata=task_model.task_metadata or {}, |
| 178 | + ) |
152 | 179 | return conversions.to_core_task(legacy_task) |
153 | 180 |
|
154 | 181 | async def save( |
|
0 commit comments