Skip to content

Commit b908a8a

Browse files
committed
feat: Add support for more Task Message and Artifact fields in the Vertex Task Store
1 parent 6d49122 commit b908a8a

2 files changed

Lines changed: 253 additions & 388 deletions

File tree

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 170 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,33 @@
1111
import base64
1212
import json
1313

14+
from typing import Any
15+
1416
from a2a.types import (
1517
Artifact,
1618
DataPart,
1719
FilePart,
1820
FileWithBytes,
1921
FileWithUri,
22+
Message,
2023
Part,
24+
Role,
2125
Task,
2226
TaskState,
2327
TaskStatus,
2428
TextPart,
2529
)
2630

2731

32+
_ORIGINAL_METADATA_KEY = 'originalMetadata'
33+
_EXTENSIONS_KEY = 'extensions'
34+
_REFERENCE_TASK_IDS_KEY = 'referenceTaskIds'
35+
_PART_METADATA_KEY = 'partMetadata'
36+
_PART_TYPES_KEY = 'partTypes'
37+
_METADATA_VERSION_KEY = '__vertex_compat_v'
38+
_METADATA_VERSION_NUMBER = 1.0
39+
40+
2841
_TO_SDK_TASK_STATE = {
2942
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
3043
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
@@ -52,6 +65,51 @@ def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
5265
)
5366

5467

68+
def to_stored_metadata(
69+
original_metadata: dict[str, Any] | None,
70+
extensions: list[str] | None,
71+
reference_task_ids: list[str] | None,
72+
parts: list[Part],
73+
) -> dict[str, Any]:
74+
"""Packs original metadata, extensions, and part types/metadata into a storage dictionary."""
75+
metadata: dict[str, Any] = {_METADATA_VERSION_KEY: _METADATA_VERSION_NUMBER}
76+
if original_metadata:
77+
metadata[_ORIGINAL_METADATA_KEY] = original_metadata
78+
if extensions:
79+
metadata[_EXTENSIONS_KEY] = extensions
80+
if reference_task_ids:
81+
metadata[_REFERENCE_TASK_IDS_KEY] = reference_task_ids
82+
83+
part_types = []
84+
part_metadata = []
85+
for part in parts:
86+
part_types.append('data' if isinstance(part.root, DataPart) else '')
87+
part_metadata.append(part.root.metadata)
88+
89+
metadata[_PART_TYPES_KEY] = part_types
90+
metadata[_PART_METADATA_KEY] = part_metadata
91+
92+
return metadata
93+
94+
95+
def to_sdk_metadata(stored_metadata: dict[str, Any] | None) -> dict[str, Any]:
96+
"""Unpacks metadata, extensions, and part types/metadata from a storage dictionary."""
97+
if not stored_metadata:
98+
return {}
99+
100+
version = stored_metadata.get(_METADATA_VERSION_KEY)
101+
if version is None:
102+
return {'original_metadata': stored_metadata}
103+
104+
return {
105+
'original_metadata': stored_metadata.get(_ORIGINAL_METADATA_KEY),
106+
'extensions': stored_metadata.get(_EXTENSIONS_KEY),
107+
'reference_tasks': stored_metadata.get(_REFERENCE_TASK_IDS_KEY),
108+
'part_metadata': stored_metadata.get(_PART_METADATA_KEY),
109+
'part_types': stored_metadata.get(_PART_TYPES_KEY),
110+
}
111+
112+
55113
def to_stored_part(part: Part) -> genai_types.Part:
56114
"""Converts a SDK Part to a proto Part."""
57115
if isinstance(part.root, TextPart):
@@ -82,20 +140,32 @@ def to_stored_part(part: Part) -> genai_types.Part:
82140
raise ValueError(f'Unsupported part type: {type(part.root)}')
83141

84142

85-
def to_sdk_part(stored_part: genai_types.Part) -> Part:
143+
def to_sdk_part(
144+
stored_part: genai_types.Part,
145+
part_metadata: dict[str, Any] | None = None,
146+
part_type: str = '',
147+
) -> Part:
86148
"""Converts a proto Part to a SDK Part."""
87149
if stored_part.text:
88-
return Part(root=TextPart(text=stored_part.text))
150+
return Part(
151+
root=TextPart(text=stored_part.text, metadata=part_metadata)
152+
)
89153
if stored_part.inline_data:
154+
mime_type = stored_part.inline_data.mime_type
155+
if part_type == 'data' and mime_type == 'application/json':
156+
data_dict = json.loads(stored_part.inline_data.data or b'{}')
157+
return Part(root=DataPart(data=data_dict, metadata=part_metadata))
158+
90159
encoded_bytes = base64.b64encode(
91160
stored_part.inline_data.data or b''
92161
).decode('utf-8')
93162
return Part(
94163
root=FilePart(
95164
file=FileWithBytes(
96-
mime_type=stored_part.inline_data.mime_type,
165+
mime_type=mime_type,
97166
bytes=encoded_bytes,
98-
)
167+
),
168+
metadata=part_metadata,
99169
)
100170
)
101171
if stored_part.file_data:
@@ -104,7 +174,8 @@ def to_sdk_part(stored_part: genai_types.Part) -> Part:
104174
file=FileWithUri(
105175
mime_type=stored_part.file_data.mime_type,
106176
uri=stored_part.file_data.file_uri,
107-
)
177+
),
178+
metadata=part_metadata,
108179
)
109180
)
110181

@@ -115,15 +186,98 @@ def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
115186
"""Converts a SDK Artifact to a proto TaskArtifact."""
116187
return vertexai_types.TaskArtifact(
117188
artifact_id=artifact.artifact_id,
189+
display_name=artifact.name,
190+
description=artifact.description,
118191
parts=[to_stored_part(part) for part in artifact.parts],
192+
metadata=to_stored_metadata(
193+
original_metadata=artifact.metadata,
194+
extensions=artifact.extensions,
195+
reference_task_ids=None,
196+
parts=artifact.parts,
197+
),
119198
)
120199

121200

122201
def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
123202
"""Converts a proto TaskArtifact to a SDK Artifact."""
203+
unpacked_meta = to_sdk_metadata(stored_artifact.metadata)
204+
part_metadatas = unpacked_meta.get('part_metadata') or []
205+
part_types = unpacked_meta.get('part_types') or []
206+
207+
parts = []
208+
for i, part in enumerate(stored_artifact.parts or []):
209+
meta: dict[str, Any] | None = None
210+
if i < len(part_metadatas):
211+
meta = part_metadatas[i]
212+
ptype = ''
213+
if i < len(part_types):
214+
ptype = part_types[i]
215+
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))
216+
124217
return Artifact(
125218
artifact_id=stored_artifact.artifact_id,
126-
parts=[to_sdk_part(part) for part in stored_artifact.parts],
219+
name=stored_artifact.display_name,
220+
description=stored_artifact.description,
221+
extensions=unpacked_meta.get('extensions'),
222+
metadata=unpacked_meta.get('original_metadata'),
223+
parts=parts,
224+
)
225+
226+
227+
def to_stored_message(
228+
message: Message | None,
229+
) -> vertexai_types.TaskMessage | None:
230+
"""Converts a SDK Message to a proto Message."""
231+
if not message:
232+
return None
233+
role = message.role.value if message.role else ''
234+
return vertexai_types.TaskMessage(
235+
message_id=message.message_id,
236+
role=role,
237+
parts=[to_stored_part(part) for part in message.parts],
238+
metadata=to_stored_metadata(
239+
original_metadata=message.metadata,
240+
extensions=message.extensions,
241+
reference_task_ids=message.reference_task_ids,
242+
parts=message.parts,
243+
),
244+
)
245+
246+
247+
def to_sdk_message(
248+
stored_msg: vertexai_types.TaskMessage | None,
249+
) -> Message | None:
250+
"""Converts a proto Message to a SDK Message."""
251+
if not stored_msg:
252+
return None
253+
unpacked_meta = to_sdk_metadata(stored_msg.metadata)
254+
part_metadatas = unpacked_meta.get('part_metadata') or []
255+
part_types = unpacked_meta.get('part_types') or []
256+
257+
parts = []
258+
for i, part in enumerate(stored_msg.parts or []):
259+
meta: dict[str, Any] | None = None
260+
if i < len(part_metadatas):
261+
meta = part_metadatas[i]
262+
ptype = ''
263+
if i < len(part_types):
264+
ptype = part_types[i]
265+
parts.append(to_sdk_part(part, part_metadata=meta, part_type=ptype))
266+
267+
role = None
268+
if stored_msg.role:
269+
try:
270+
role = Role(stored_msg.role)
271+
except ValueError:
272+
role = None
273+
274+
return Message(
275+
message_id=stored_msg.message_id,
276+
role=role, # type: ignore
277+
extensions=unpacked_meta.get('extensions'),
278+
reference_task_ids=unpacked_meta.get('reference_tasks'),
279+
metadata=unpacked_meta.get('original_metadata'),
280+
parts=parts,
127281
)
128282

129283

@@ -133,6 +287,11 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
133287
context_id=task.context_id,
134288
metadata=task.metadata,
135289
state=to_stored_task_state(task.status.state),
290+
status_details=vertexai_types.TaskStatusDetails(
291+
task_message=to_stored_message(task.status.message)
292+
)
293+
if task.status.message
294+
else None,
136295
output=vertexai_types.TaskOutput(
137296
artifacts=[
138297
to_stored_artifact(artifact)
@@ -144,10 +303,14 @@ def to_stored_task(task: Task) -> vertexai_types.A2aTask:
144303

145304
def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
146305
"""Converts a proto A2aTask to a SDK Task."""
306+
msg: Message | None = None
307+
if a2a_task.status_details and a2a_task.status_details.task_message:
308+
msg = to_sdk_message(a2a_task.status_details.task_message)
309+
147310
return Task(
148311
id=a2a_task.name.split('/')[-1],
149312
context_id=a2a_task.context_id,
150-
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
313+
status=TaskStatus(state=to_sdk_task_state(a2a_task.state), message=msg),
151314
metadata=a2a_task.metadata or {},
152315
artifacts=[
153316
to_sdk_artifact(artifact)

0 commit comments

Comments
 (0)