Skip to content

Commit f77e7e8

Browse files
committed
Restore tests to original
1 parent 7eb3b98 commit f77e7e8

1 file changed

Lines changed: 121 additions & 94 deletions

File tree

tests/contrib/tasks/test_vertex_task_converter.py

Lines changed: 121 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
"""Tests for vertex_task_converter mappings."""
2-
31
import base64
42

53
import pytest
@@ -8,17 +6,14 @@
86
pytest.importorskip(
97
'vertexai', reason='Vertex Task Converter tests require vertexai'
108
)
11-
from google.genai import types as genai_types
129
from vertexai import types as vertexai_types
13-
10+
from google.genai import types as genai_types
1411
from a2a.contrib.tasks.vertex_task_converter import (
1512
to_sdk_artifact,
16-
to_sdk_message,
1713
to_sdk_part,
1814
to_sdk_task,
1915
to_sdk_task_state,
2016
to_stored_artifact,
21-
to_stored_message,
2217
to_stored_part,
2318
to_stored_task,
2419
to_stored_task_state,
@@ -29,114 +24,146 @@
2924
FilePart,
3025
FileWithBytes,
3126
FileWithUri,
32-
Message,
3327
Part,
34-
Role,
3528
Task,
3629
TaskState,
3730
TaskStatus,
3831
TextPart,
3932
)
4033

4134

42-
def test_artifact_conversion_symmetry() -> None:
43-
"""Test converting an Artifact to TaskArtifact and back restores everything."""
44-
original_artifact = Artifact(
45-
artifact_id='art123',
46-
name='My cool artifact',
47-
description='A very interesting description',
48-
extensions=['ext1', 'ext2'],
49-
metadata={'custom': 'value'},
50-
parts=[
51-
Part(
52-
root=TextPart(
53-
text='hello', metadata={'part_meta': 'hello_meta'}
54-
)
55-
),
56-
Part(root=DataPart(data={'foo': 'bar'})), # no metadata
57-
],
35+
def test_to_sdk_task_state() -> None:
36+
assert (
37+
to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED)
38+
== TaskState.unknown
39+
)
40+
assert (
41+
to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED)
42+
== TaskState.submitted
43+
)
44+
assert (
45+
to_sdk_task_state(vertexai_types.A2aTaskState.WORKING)
46+
== TaskState.working
47+
)
48+
assert (
49+
to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED)
50+
== TaskState.completed
51+
)
52+
assert (
53+
to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED)
54+
== TaskState.canceled
55+
)
56+
assert (
57+
to_sdk_task_state(vertexai_types.A2aTaskState.FAILED)
58+
== TaskState.failed
59+
)
60+
assert (
61+
to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED)
62+
== TaskState.rejected
63+
)
64+
assert (
65+
to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED)
66+
== TaskState.input_required
67+
)
68+
assert (
69+
to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED)
70+
== TaskState.auth_required
5871
)
72+
assert to_sdk_task_state(999) == TaskState.unknown # type: ignore
5973

60-
stored = to_stored_artifact(original_artifact)
61-
assert isinstance(stored, vertexai_types.TaskArtifact)
62-
63-
# ensure it was populated correctly
64-
assert stored.display_name == 'My cool artifact'
65-
assert stored.description == 'A very interesting description'
66-
assert stored.metadata['__vertex_compat_v'] == 1.0
67-
68-
restored_artifact = to_sdk_artifact(stored)
69-
70-
assert restored_artifact.artifact_id == original_artifact.artifact_id
71-
assert restored_artifact.name == original_artifact.name
72-
assert restored_artifact.description == original_artifact.description
73-
assert restored_artifact.extensions == original_artifact.extensions
74-
assert restored_artifact.metadata == original_artifact.metadata
75-
76-
assert len(restored_artifact.parts) == 2
77-
assert isinstance(restored_artifact.parts[0].root, TextPart)
78-
assert restored_artifact.parts[0].root.text == 'hello'
79-
assert restored_artifact.parts[0].root.metadata == {
80-
'part_meta': 'hello_meta'
81-
}
82-
83-
assert isinstance(restored_artifact.parts[1].root, DataPart)
84-
assert restored_artifact.parts[1].root.data == {'foo': 'bar'}
85-
assert restored_artifact.parts[1].root.metadata is None
86-
87-
88-
def test_message_conversion_symmetry() -> None:
89-
"""Test converting a Message to TaskMessage and back restores everything."""
90-
original_message = Message(
91-
message_id='msg456',
92-
role=Role.agent,
93-
context_id='ctx1',
94-
task_id='tsk1',
95-
reference_task_ids=['tsk2', 'tsk3'],
96-
extensions=['ext_msg'],
97-
metadata={'msg_meta': 42},
98-
parts=[
99-
Part(root=TextPart(text='message text')),
100-
],
74+
75+
def test_to_stored_task_state() -> None:
76+
assert (
77+
to_stored_task_state(TaskState.unknown)
78+
== vertexai_types.A2aTaskState.STATE_UNSPECIFIED
79+
)
80+
assert (
81+
to_stored_task_state(TaskState.submitted)
82+
== vertexai_types.A2aTaskState.SUBMITTED
83+
)
84+
assert (
85+
to_stored_task_state(TaskState.working)
86+
== vertexai_types.A2aTaskState.WORKING
87+
)
88+
assert (
89+
to_stored_task_state(TaskState.completed)
90+
== vertexai_types.A2aTaskState.COMPLETED
91+
)
92+
assert (
93+
to_stored_task_state(TaskState.canceled)
94+
== vertexai_types.A2aTaskState.CANCELLED
95+
)
96+
assert (
97+
to_stored_task_state(TaskState.failed)
98+
== vertexai_types.A2aTaskState.FAILED
99+
)
100+
assert (
101+
to_stored_task_state(TaskState.rejected)
102+
== vertexai_types.A2aTaskState.REJECTED
103+
)
104+
assert (
105+
to_stored_task_state(TaskState.input_required)
106+
== vertexai_types.A2aTaskState.INPUT_REQUIRED
107+
)
108+
assert (
109+
to_stored_task_state(TaskState.auth_required)
110+
== vertexai_types.A2aTaskState.AUTH_REQUIRED
101111
)
102112

103-
stored = to_stored_message(original_message)
104-
assert stored is not None
105-
assert isinstance(stored, vertexai_types.TaskMessage)
106113

107-
assert stored.message_id == 'msg456'
108-
assert stored.role == 'agent'
109-
assert stored.metadata['__vertex_compat_v'] == 1.0
114+
def test_to_stored_part_text() -> None:
115+
sdk_part = Part(root=TextPart(text='hello world'))
116+
stored_part = to_stored_part(sdk_part)
117+
assert stored_part.text == 'hello world'
118+
assert not stored_part.inline_data
119+
assert not stored_part.file_data
110120

111-
restored_message = to_sdk_message(stored)
112-
assert restored_message is not None
113121

114-
assert restored_message.message_id == original_message.message_id
115-
assert restored_message.role == original_message.role
116-
# context_id and task_id are not serialized via Message metadata in Go implementation but via Task,
117-
# but reference_task_ids and extensions ARE part of Message metadata.
118-
assert (
119-
restored_message.reference_task_ids
120-
== original_message.reference_task_ids
121-
)
122-
assert restored_message.extensions == original_message.extensions
123-
assert restored_message.metadata == original_message.metadata
122+
def test_to_stored_part_data() -> None:
123+
sdk_part = Part(root=DataPart(data={'key': 'value'}))
124+
stored_part = to_stored_part(sdk_part)
125+
assert stored_part.inline_data is not None
126+
assert stored_part.inline_data.mime_type == 'application/json'
127+
assert stored_part.inline_data.data == b'{"key": "value"}'
124128

125-
assert len(restored_message.parts) == 1
126-
assert isinstance(restored_message.parts[0].root, TextPart)
127-
assert restored_message.parts[0].root.text == 'message text'
128-
assert restored_message.parts[0].root.metadata is None
129129

130+
def test_to_stored_part_file_bytes() -> None:
131+
encoded_b64 = base64.b64encode(b'test data').decode('utf-8')
132+
sdk_part = Part(
133+
root=FilePart(
134+
file=FileWithBytes(
135+
bytes=encoded_b64,
136+
mime_type='text/plain',
137+
)
138+
)
139+
)
140+
stored_part = to_stored_part(sdk_part)
141+
assert stored_part.inline_data is not None
142+
assert stored_part.inline_data.mime_type == 'text/plain'
143+
assert stored_part.inline_data.data == b'test data'
130144

131-
def test_to_stored_part_unsupported() -> None:
132-
part = Part.model_construct(
133-
root=Task( # type: ignore[arg-type]
134-
id='invalid-part',
135-
context_id='ctx',
136-
status=TaskStatus(state=TaskState.submitted),
137-
history=[],
145+
146+
def test_to_stored_part_file_uri() -> None:
147+
sdk_part = Part(
148+
root=FilePart(
149+
file=FileWithUri(
150+
uri='gs://test-bucket/file.txt',
151+
mime_type='text/plain',
152+
)
138153
)
139154
)
155+
stored_part = to_stored_part(sdk_part)
156+
assert stored_part.file_data is not None
157+
assert stored_part.file_data.mime_type == 'text/plain'
158+
assert stored_part.file_data.file_uri == 'gs://test-bucket/file.txt'
159+
160+
161+
def test_to_stored_part_unsupported() -> None:
162+
class BadPart:
163+
pass
164+
165+
part = Part(root=TextPart(text='t'))
166+
part.root = BadPart() # type: ignore
140167
with pytest.raises(ValueError, match='Unsupported part type'):
141168
to_stored_part(part)
142169

0 commit comments

Comments
 (0)