-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathtask_updater.py
More file actions
208 lines (183 loc) · 6.93 KB
/
task_updater.py
File metadata and controls
208 lines (183 loc) · 6.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
Message,
Part,
Role,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
class TaskUpdater:
    """Helper class for agents to publish updates to a task's event queue.

    Simplifies the process of creating and enqueueing standard task events.

    Status updates go through ``update_status``, which is serialized by an
    internal ``asyncio.Lock``. Once a terminal state (completed, canceled,
    failed, rejected) has been published, any further call to
    ``update_status`` -- and to the convenience wrappers built on it --
    raises ``RuntimeError``. ``add_artifact`` does not perform this check.
    """

    def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
        """Initializes the TaskUpdater.

        Args:
            event_queue: The `EventQueue` associated with the task.
            task_id: The ID of the task.
            context_id: The context ID of the task.
        """
        self.event_queue = event_queue
        self.task_id = task_id
        self.context_id = context_id
        # Serializes update_status so the terminal-state check, the enqueue,
        # and the terminal flag update happen as one unit.
        self._lock = asyncio.Lock()
        # Becomes True only after a terminal status event has been enqueued.
        self._terminal_state_reached = False
        # States after which no further status updates are accepted.
        self._terminal_states = {
            TaskState.completed,
            TaskState.canceled,
            TaskState.failed,
            TaskState.rejected,
        }

    async def update_status(
        self,
        state: TaskState,
        message: Message | None = None,
        final: bool = False,
        timestamp: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Updates the status of the task and publishes a `TaskStatusUpdateEvent`.

        Args:
            state: The new state of the task.
            message: An optional message associated with the status update.
            final: If True, indicates this is the final status update for the
                task. Always forced to True when ``state`` is a terminal state.
            timestamp: Optional ISO 8601 datetime string. Defaults to the
                current UTC time when omitted or empty.
            metadata: Optional metadata for extensions.

        Raises:
            RuntimeError: If this updater has already published a terminal
                status for the task.
        """
        async with self._lock:
            if self._terminal_state_reached:
                raise RuntimeError(
                    f'Task {self.task_id} is already in a terminal state.'
                )
            is_terminal = state in self._terminal_states
            if is_terminal:
                # A terminal state is always the last status event of a task.
                final = True
            current_timestamp = (
                timestamp or datetime.now(timezone.utc).isoformat()
            )
            await self.event_queue.enqueue_event(
                TaskStatusUpdateEvent(
                    task_id=self.task_id,
                    context_id=self.context_id,
                    final=final,
                    metadata=metadata,
                    status=TaskStatus(
                        state=state,
                        message=message,
                        timestamp=current_timestamp,
                    ),
                )
            )
            # Record the terminal state only once its event was actually
            # enqueued. If enqueue_event raised, nothing terminal was
            # published, and later updates must not be rejected with a
            # misleading "already in a terminal state" error.
            if is_terminal:
                self._terminal_state_reached = True

    async def add_artifact(  # noqa: PLR0913
        self,
        parts: list[Part],
        artifact_id: str | None = None,
        name: str | None = None,
        metadata: dict[str, Any] | None = None,
        append: bool | None = None,
        last_chunk: bool | None = None,
    ) -> None:
        """Adds an artifact chunk to the task and publishes a `TaskArtifactUpdateEvent`.

        Note: this method does not take the status lock and does not check
        whether a terminal state has already been reached.

        Args:
            parts: A list of `Part` objects forming the artifact chunk.
            artifact_id: The ID of the artifact. A new UUID is generated if
                not provided (or empty).
            name: Optional name for the artifact.
            metadata: Optional metadata for the artifact.
            append: Optional boolean indicating if this chunk appends to a
                previous one.
            last_chunk: Optional boolean indicating if this is the last chunk.
        """
        if not artifact_id:
            artifact_id = str(uuid.uuid4())
        await self.event_queue.enqueue_event(
            TaskArtifactUpdateEvent(
                task_id=self.task_id,
                context_id=self.context_id,
                artifact=Artifact(
                    artifact_id=artifact_id,
                    name=name,
                    parts=parts,
                    metadata=metadata,
                ),
                append=append,
                last_chunk=last_chunk,
            )
        )

    async def complete(self, message: Message | None = None) -> None:
        """Marks the task as completed and publishes a final status update."""
        await self.update_status(
            TaskState.completed,
            message=message,
            final=True,
        )

    async def failed(self, message: Message | None = None) -> None:
        """Marks the task as failed and publishes a final status update."""
        await self.update_status(TaskState.failed, message=message, final=True)

    async def reject(self, message: Message | None = None) -> None:
        """Marks the task as rejected and publishes a final status update."""
        await self.update_status(
            TaskState.rejected, message=message, final=True
        )

    async def submit(self, message: Message | None = None) -> None:
        """Marks the task as submitted and publishes a status update."""
        await self.update_status(
            TaskState.submitted,
            message=message,
        )

    async def start_work(self, message: Message | None = None) -> None:
        """Marks the task as working and publishes a status update."""
        await self.update_status(
            TaskState.working,
            message=message,
        )

    async def cancel(self, message: Message | None = None) -> None:
        """Marks the task as canceled and publishes a final status update."""
        await self.update_status(
            TaskState.canceled, message=message, final=True
        )

    async def requires_input(
        self, message: Message | None = None, final: bool = False
    ) -> None:
        """Marks the task as input required and publishes a status update.

        ``input_required`` is not a terminal state, so further status updates
        remain possible even when ``final`` is True.
        """
        await self.update_status(
            TaskState.input_required,
            message=message,
            final=final,
        )

    async def requires_auth(
        self, message: Message | None = None, final: bool = False
    ) -> None:
        """Marks the task as auth required and publishes a status update.

        ``auth_required`` is not a terminal state, so further status updates
        remain possible even when ``final`` is True.
        """
        await self.update_status(
            TaskState.auth_required, message=message, final=final
        )

    def new_agent_message(
        self,
        parts: list[Part],
        metadata: dict[str, Any] | None = None,
    ) -> Message:
        """Creates a new message object sent by the agent for this task/context.

        Note: This method only *creates* the message object. It does not
        automatically enqueue it.

        Args:
            parts: A list of `Part` objects for the message content.
            metadata: Optional metadata for the message.

        Returns:
            A new `Message` object with role ``agent``, this updater's task
            and context IDs, and a freshly generated UUID message ID.
        """
        return Message(
            role=Role.agent,
            task_id=self.task_id,
            context_id=self.context_id,
            message_id=str(uuid.uuid4()),
            metadata=metadata,
            parts=parts,
        )