-
Notifications
You must be signed in to change notification settings - Fork 429
Expand file tree
/
Copy pathproto_helpers.py
More file actions
214 lines (175 loc) · 5.48 KB
/
proto_helpers.py
File metadata and controls
214 lines (175 loc) · 5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""Unified helper functions for creating and handling A2A types."""
import uuid
from collections.abc import Sequence
from a2a.types.a2a_pb2 import (
Artifact,
Message,
Part,
Role,
StreamResponse,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
# --- Message Helpers ---
def new_message(
    parts: list[Part],
    role: Role = Role.ROLE_AGENT,
    context_id: str | None = None,
    task_id: str | None = None,
) -> Message:
    """Build a Message from the given Parts, stamped with a fresh ID.

    Args:
        parts: The Parts that make up the message body.
        role: Author of the message; defaults to ``Role.ROLE_AGENT``.
        context_id: Optional context ID, passed straight to the Message
            constructor (``None`` included).
        task_id: Optional task ID, passed straight to the Message
            constructor (``None`` included).

    Returns:
        A new ``Message`` whose ``message_id`` is a random UUID4 string.
    """
    fresh_id = str(uuid.uuid4())
    return Message(
        message_id=fresh_id,
        role=role,
        parts=parts,
        context_id=context_id,
        task_id=task_id,
    )
def new_text_message(
    text: str,
    context_id: str | None = None,
    task_id: str | None = None,
    role: Role = Role.ROLE_AGENT,
) -> Message:
    """Build a Message whose body is exactly one text Part.

    Convenience wrapper around ``new_message``.

    Args:
        text: Content of the single text Part.
        context_id: Optional context ID forwarded to ``new_message``.
        task_id: Optional task ID forwarded to ``new_message``.
        role: Author of the message; defaults to ``Role.ROLE_AGENT``.

    Returns:
        The ``Message`` produced by ``new_message``.
    """
    text_part = Part(text=text)
    return new_message(
        parts=[text_part],
        role=role,
        context_id=context_id,
        task_id=task_id,
    )
def get_message_text(message: Message, delimiter: str = '\n') -> str:
    """Return the text of *message*'s text Parts joined by *delimiter*.

    Parts without a ``text`` field set are ignored (see ``get_text_parts``);
    a message with no text Parts yields an empty string.
    """
    text_chunks = get_text_parts(message.parts)
    return delimiter.join(text_chunks)
# --- Artifact Helpers ---
def new_artifact(
    parts: list[Part],
    name: str,
    description: str | None = None,
    artifact_id: str | None = None,
) -> Artifact:
    """Build an Artifact from the given Parts.

    Args:
        parts: The Parts that make up the artifact content.
        name: Artifact name.
        description: Optional description, passed straight to the
            Artifact constructor (``None`` included).
        artifact_id: Explicit ID to use; when falsy (``None`` or ``''``)
            a random UUID4 string is generated instead.

    Returns:
        The newly constructed ``Artifact``.
    """
    # Only draw a UUID when the caller did not supply a usable ID.
    resolved_id = artifact_id if artifact_id else str(uuid.uuid4())
    return Artifact(
        artifact_id=resolved_id,
        name=name,
        description=description,
        parts=parts,
    )
def new_text_artifact(
    name: str,
    text: str,
    description: str | None = None,
    artifact_id: str | None = None,
) -> Artifact:
    """Build an Artifact whose content is exactly one text Part.

    Convenience wrapper around ``new_artifact``.

    Args:
        name: Artifact name.
        text: Content of the single text Part.
        description: Optional description forwarded to ``new_artifact``.
        artifact_id: Optional explicit ID forwarded to ``new_artifact``.

    Returns:
        The ``Artifact`` produced by ``new_artifact``.
    """
    single_part = [Part(text=text)]
    return new_artifact(
        parts=single_part,
        name=name,
        description=description,
        artifact_id=artifact_id,
    )
def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str:
    """Return the text of *artifact*'s text Parts joined by *delimiter*.

    Parts without a ``text`` field set are ignored (see ``get_text_parts``);
    an artifact with no text Parts yields an empty string.
    """
    text_chunks = get_text_parts(artifact.parts)
    return delimiter.join(text_chunks)
# --- Task Helpers ---
def new_task_from_user_message(user_message: Message) -> Task:
    """Create a submitted Task seeded with the user's opening message.

    The task reuses the message's ``task_id`` / ``context_id`` when they are
    non-empty; otherwise a random UUID4 string is generated for each.

    Args:
        user_message: The initial message; must come from the user.

    Returns:
        A ``Task`` in ``TASK_STATE_SUBMITTED`` whose history holds only
        *user_message*.

    Raises:
        ValueError: If the message is not from ``Role.ROLE_USER``, has no
            Parts, or contains a text Part whose text is empty.
    """
    if user_message.role != Role.ROLE_USER:
        raise ValueError('Message must be from a user')
    if not user_message.parts:
        raise ValueError('Message parts cannot be empty')

    has_blank_text = any(
        part.HasField('text') and not part.text
        for part in user_message.parts
    )
    if has_blank_text:
        raise ValueError('Message.text cannot be empty')

    task_id = user_message.task_id or str(uuid.uuid4())
    context_id = user_message.context_id or str(uuid.uuid4())
    return Task(
        id=task_id,
        context_id=context_id,
        status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
        history=[user_message],
    )
def new_task(
    task_id: str,
    context_id: str,
    state: TaskState,
    artifacts: list[Artifact] | None = None,
    history: list[Message] | None = None,
) -> Task:
    """Build a Task with the given IDs and status state.

    Args:
        task_id: ID of the task.
        context_id: ID of the context the task belongs to.
        state: State recorded in the task's ``TaskStatus``.
        artifacts: Artifacts to attach; ``None`` means none.
        history: Messages to record as history; ``None`` means none.

    Returns:
        The newly constructed ``Task``.
    """
    return Task(
        id=task_id,
        context_id=context_id,
        status=TaskStatus(state=state),
        artifacts=[] if artifacts is None else artifacts,
        history=[] if history is None else history,
    )
# --- Part Helpers ---
def get_text_parts(parts: Sequence[Part]) -> list[str]:
    """Return the ``text`` of each Part whose ``text`` field is set.

    Input order is preserved; Parts without ``text`` set are skipped.
    """
    return [candidate.text for candidate in parts if candidate.HasField('text')]
# --- Event & Stream Helpers ---
def new_text_status_update_event(
    task_id: str,
    context_id: str,
    state: TaskState,
    text: str,
) -> TaskStatusUpdateEvent:
    """Build a status-update event whose status carries one agent text message.

    Args:
        task_id: ID of the task being updated (also set on the message).
        context_id: Context ID of the task (also set on the message).
        state: New task state.
        text: Text of the agent message attached to the status.

    Returns:
        The newly constructed ``TaskStatusUpdateEvent``.
    """
    status_message = new_text_message(
        text=text,
        context_id=context_id,
        task_id=task_id,
        role=Role.ROLE_AGENT,
    )
    return TaskStatusUpdateEvent(
        task_id=task_id,
        context_id=context_id,
        status=TaskStatus(state=state, message=status_message),
    )
def new_text_artifact_update_event(  # noqa: PLR0913
    task_id: str,
    context_id: str,
    name: str,
    text: str,
    append: bool = False,
    last_chunk: bool = False,
    artifact_id: str | None = None,
) -> TaskArtifactUpdateEvent:
    """Build an artifact-update event carrying a single-text-Part artifact.

    Args:
        task_id: ID of the task the artifact belongs to.
        context_id: Context ID of the task.
        name: Artifact name.
        text: Content of the artifact's single text Part.
        append: Value copied into the event's ``append`` field.
        last_chunk: Value copied into the event's ``last_chunk`` field.
        artifact_id: Optional explicit artifact ID; when falsy,
            ``new_text_artifact`` generates one.

    Returns:
        The newly constructed ``TaskArtifactUpdateEvent``.
    """
    artifact = new_text_artifact(
        name=name,
        text=text,
        artifact_id=artifact_id,
    )
    return TaskArtifactUpdateEvent(
        task_id=task_id,
        context_id=context_id,
        artifact=artifact,
        append=append,
        last_chunk=last_chunk,
    )
def get_stream_response_text(
    response: StreamResponse, delimiter: str = '\n'
) -> str:
    """Return the text carried by a StreamResponse.

    Payloads are checked in this order, and the first one that is set
    determines the result:

    * ``message`` -- the message's text Parts.
    * ``task`` -- each artifact's text, skipping artifacts with no text.
    * ``status_update`` -- the status message's text, if it has a message.
    * ``artifact_update`` -- the updated artifact's text.

    Args:
        response: The stream response to inspect.
        delimiter: Separator used when joining multiple text fragments.

    Returns:
        The joined text, or ``''`` when no handled payload is set or a
        status update has no message.
    """
    if response.HasField('message'):
        return get_message_text(response.message, delimiter)

    if response.HasField('task'):
        per_artifact = (
            get_artifact_text(artifact, delimiter)
            for artifact in response.task.artifacts
        )
        # Drop artifacts that contributed no text so we don't emit
        # dangling delimiters.
        return delimiter.join(chunk for chunk in per_artifact if chunk)

    if response.HasField('status_update'):
        status = response.status_update.status
        if not status.HasField('message'):
            return ''
        return get_message_text(status.message, delimiter)

    if response.HasField('artifact_update'):
        return get_artifact_text(response.artifact_update.artifact, delimiter)

    return ''