Skip to content

Commit dcf5e55

Browse files
committed
restore dropped tests
1 parent b908a8a commit dcf5e55

1 file changed

Lines changed: 271 additions & 0 deletions

File tree

tests/contrib/tasks/test_vertex_task_converter.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,40 @@
11
"""Tests for vertex_task_converter mappings."""
22

3+
import base64
4+
5+
import pytest
6+
7+
8+
pytest.importorskip(
9+
'vertexai', reason='Vertex Task Converter tests require vertexai'
10+
)
11+
from google.genai import types as genai_types
312
from vertexai import types as vertexai_types
413

514
from a2a.contrib.tasks.vertex_task_converter import (
615
to_sdk_artifact,
716
to_sdk_message,
17+
to_sdk_part,
18+
to_sdk_task,
19+
to_sdk_task_state,
820
to_stored_artifact,
921
to_stored_message,
22+
to_stored_part,
23+
to_stored_task,
24+
to_stored_task_state,
1025
)
1126
from a2a.types import (
1227
Artifact,
1328
DataPart,
29+
FilePart,
30+
FileWithBytes,
31+
FileWithUri,
1432
Message,
1533
Part,
1634
Role,
35+
Task,
36+
TaskState,
37+
TaskStatus,
1738
TextPart,
1839
)
1940

@@ -105,3 +126,253 @@ def test_message_conversion_symmetry() -> None:
105126
assert isinstance(restored_message.parts[0].root, TextPart)
106127
assert restored_message.parts[0].root.text == 'message text'
107128
assert restored_message.parts[0].root.metadata is None
129+
130+
131+
def test_to_stored_part_unsupported() -> None:
132+
part = Part.model_construct(
133+
root=Task( # type: ignore[arg-type]
134+
id='invalid-part',
135+
context_id='ctx',
136+
status=TaskStatus(state=TaskState.submitted),
137+
history=[],
138+
)
139+
)
140+
with pytest.raises(ValueError, match='Unsupported part type'):
141+
to_stored_part(part)
142+
143+
144+
def test_to_sdk_part_text() -> None:
145+
stored_part = genai_types.Part(text='hello back')
146+
sdk_part = to_sdk_part(stored_part)
147+
assert isinstance(sdk_part.root, TextPart)
148+
assert sdk_part.root.text == 'hello back'
149+
150+
151+
def test_to_sdk_part_inline_data() -> None:
152+
stored_part = genai_types.Part(
153+
inline_data=genai_types.Blob(
154+
mime_type='application/json',
155+
data=b'{"key": "val"}',
156+
)
157+
)
158+
sdk_part = to_sdk_part(stored_part)
159+
assert isinstance(sdk_part.root, FilePart)
160+
assert isinstance(sdk_part.root.file, FileWithBytes)
161+
expected_b64 = base64.b64encode(b'{"key": "val"}').decode('utf-8')
162+
assert sdk_part.root.file.mime_type == 'application/json'
163+
assert sdk_part.root.file.bytes == expected_b64
164+
165+
166+
def test_to_sdk_part_file_data() -> None:
167+
stored_part = genai_types.Part(
168+
file_data=genai_types.FileData(
169+
mime_type='image/jpeg',
170+
file_uri='gs://bucket/image.jpg',
171+
)
172+
)
173+
sdk_part = to_sdk_part(stored_part)
174+
assert isinstance(sdk_part.root, FilePart)
175+
assert isinstance(sdk_part.root.file, FileWithUri)
176+
assert sdk_part.root.file.mime_type == 'image/jpeg'
177+
assert sdk_part.root.file.uri == 'gs://bucket/image.jpg'
178+
179+
180+
def test_to_sdk_part_unsupported() -> None:
181+
stored_part = genai_types.Part()
182+
with pytest.raises(ValueError, match='Unsupported part:'):
183+
to_sdk_part(stored_part)
184+
185+
186+
def test_to_stored_artifact() -> None:
187+
sdk_artifact = Artifact(
188+
artifact_id='art-123',
189+
parts=[Part(root=TextPart(text='part_1'))],
190+
)
191+
stored_artifact = to_stored_artifact(sdk_artifact)
192+
assert stored_artifact.artifact_id == 'art-123'
193+
assert len(stored_artifact.parts) == 1
194+
assert stored_artifact.parts[0].text == 'part_1'
195+
196+
197+
def test_to_sdk_artifact() -> None:
198+
stored_artifact = vertexai_types.TaskArtifact(
199+
artifact_id='art-456',
200+
parts=[genai_types.Part(text='part_2')],
201+
)
202+
sdk_artifact = to_sdk_artifact(stored_artifact)
203+
assert sdk_artifact.artifact_id == 'art-456'
204+
assert len(sdk_artifact.parts) == 1
205+
assert isinstance(sdk_artifact.parts[0].root, TextPart)
206+
assert sdk_artifact.parts[0].root.text == 'part_2'
207+
208+
209+
def test_to_stored_task() -> None:
210+
sdk_task = Task(
211+
id='task-1',
212+
context_id='ctx-1',
213+
status=TaskStatus(state=TaskState.working),
214+
metadata={'foo': 'bar'},
215+
artifacts=[
216+
Artifact(
217+
artifact_id='art-1',
218+
parts=[Part(root=TextPart(text='stuff'))],
219+
)
220+
],
221+
history=[],
222+
)
223+
stored_task = to_stored_task(sdk_task)
224+
assert stored_task.context_id == 'ctx-1'
225+
assert stored_task.metadata == {'foo': 'bar'}
226+
assert stored_task.state == vertexai_types.A2aTaskState.WORKING
227+
assert stored_task.output is not None
228+
assert stored_task.output.artifacts is not None
229+
assert len(stored_task.output.artifacts) == 1
230+
assert stored_task.output.artifacts[0].artifact_id == 'art-1'
231+
232+
233+
def test_to_sdk_task() -> None:
234+
stored_task = vertexai_types.A2aTask(
235+
name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2',
236+
context_id='ctx-2',
237+
state=vertexai_types.A2aTaskState.COMPLETED,
238+
metadata={'a': 'b'},
239+
output=vertexai_types.TaskOutput(
240+
artifacts=[
241+
vertexai_types.TaskArtifact(
242+
artifact_id='art-2',
243+
parts=[genai_types.Part(text='result')],
244+
)
245+
]
246+
),
247+
)
248+
sdk_task = to_sdk_task(stored_task)
249+
assert sdk_task.id == 'task-2'
250+
assert sdk_task.context_id == 'ctx-2'
251+
assert sdk_task.status.state == TaskState.completed
252+
assert sdk_task.metadata == {'a': 'b'}
253+
assert sdk_task.history == []
254+
assert sdk_task.artifacts is not None
255+
assert len(sdk_task.artifacts) == 1
256+
assert sdk_task.artifacts[0].artifact_id == 'art-2'
257+
assert isinstance(sdk_task.artifacts[0].parts[0].root, TextPart)
258+
assert sdk_task.artifacts[0].parts[0].root.text == 'result'
259+
260+
261+
def test_to_sdk_task_no_output() -> None:
262+
stored_task = vertexai_types.A2aTask(
263+
name='tasks/task-3',
264+
context_id='ctx-3',
265+
state=vertexai_types.A2aTaskState.SUBMITTED,
266+
metadata=None,
267+
)
268+
sdk_task = to_sdk_task(stored_task)
269+
assert sdk_task.id == 'task-3'
270+
assert sdk_task.metadata == {}
271+
assert sdk_task.artifacts == []
272+
273+
274+
def test_sdk_task_state_conversion_round_trip() -> None:
275+
for state in TaskState:
276+
stored_state = to_stored_task_state(state)
277+
round_trip_state = to_sdk_task_state(stored_state)
278+
assert round_trip_state == state
279+
280+
281+
def test_sdk_part_text_conversion_round_trip() -> None:
282+
sdk_part = Part(root=TextPart(text='hello world'))
283+
stored_part = to_stored_part(sdk_part)
284+
round_trip_sdk_part = to_sdk_part(stored_part)
285+
assert round_trip_sdk_part == sdk_part
286+
287+
288+
def test_sdk_part_data_conversion_round_trip() -> None:
289+
# A DataPart is converted to `inline_data` in Vertex AI, which lacks the original
290+
# `DataPart` vs `FilePart` distinction. When reading it back from the stored
291+
# protocol format, it becomes a `FilePart` with base64-encoded `FileWithBytes`
292+
# and `mime_type="application/json"`.
293+
sdk_part = Part(root=DataPart(data={'key': 'value'}))
294+
stored_part = to_stored_part(sdk_part)
295+
round_trip_sdk_part = to_sdk_part(stored_part)
296+
297+
expected_b64 = base64.b64encode(b'{"key": "value"}').decode('utf-8')
298+
assert round_trip_sdk_part == Part(
299+
root=FilePart(
300+
file=FileWithBytes(
301+
bytes=expected_b64,
302+
mime_type='application/json',
303+
)
304+
)
305+
)
306+
307+
308+
def test_sdk_part_file_bytes_conversion_round_trip() -> None:
309+
encoded_b64 = base64.b64encode(b'test data').decode('utf-8')
310+
sdk_part = Part(
311+
root=FilePart(
312+
file=FileWithBytes(
313+
bytes=encoded_b64,
314+
mime_type='text/plain',
315+
)
316+
)
317+
)
318+
stored_part = to_stored_part(sdk_part)
319+
round_trip_sdk_part = to_sdk_part(stored_part)
320+
assert round_trip_sdk_part == sdk_part
321+
322+
323+
def test_sdk_part_file_uri_conversion_round_trip() -> None:
324+
sdk_part = Part(
325+
root=FilePart(
326+
file=FileWithUri(
327+
uri='gs://test-bucket/file.txt',
328+
mime_type='text/plain',
329+
)
330+
)
331+
)
332+
stored_part = to_stored_part(sdk_part)
333+
round_trip_sdk_part = to_sdk_part(stored_part)
334+
assert round_trip_sdk_part == sdk_part
335+
336+
337+
def test_sdk_artifact_conversion_round_trip() -> None:
338+
sdk_artifact = Artifact(
339+
artifact_id='art-123',
340+
parts=[Part(root=TextPart(text='part_1'))],
341+
)
342+
stored_artifact = to_stored_artifact(sdk_artifact)
343+
round_trip_sdk_artifact = to_sdk_artifact(stored_artifact)
344+
assert round_trip_sdk_artifact == sdk_artifact
345+
346+
347+
def test_sdk_task_conversion_round_trip() -> None:
348+
sdk_task = Task(
349+
id='task-1',
350+
context_id='ctx-1',
351+
status=TaskStatus(state=TaskState.working),
352+
metadata={'foo': 'bar'},
353+
artifacts=[
354+
Artifact(
355+
artifact_id='art-1',
356+
parts=[Part(root=TextPart(text='stuff'))],
357+
)
358+
],
359+
history=[
360+
# History is not yet implemented and later will be supported
361+
# via events.
362+
],
363+
)
364+
stored_task = to_stored_task(sdk_task)
365+
# Simulate Vertex storing the ID in the fully qualified resource name.
366+
# The task ID during creation gets appended to the parent name.
367+
stored_task.name = (
368+
f'projects/p/locations/l/agentEngines/e/tasks/{sdk_task.id}'
369+
)
370+
371+
round_trip_sdk_task = to_sdk_task(stored_task)
372+
373+
assert round_trip_sdk_task.id == sdk_task.id
374+
assert round_trip_sdk_task.context_id == sdk_task.context_id
375+
assert round_trip_sdk_task.status == sdk_task.status
376+
assert round_trip_sdk_task.metadata == sdk_task.metadata
377+
assert round_trip_sdk_task.artifacts == sdk_task.artifacts
378+
assert round_trip_sdk_task.history == []

0 commit comments

Comments
 (0)