Skip to content

Commit 123e36e

Browse files
committed
chore: remove the use of ambigouos and soon-to-be-deleted vertex types for VertexTaskStore
1 parent 4ebbb2e commit 123e36e

3 files changed

Lines changed: 54 additions & 53 deletions

File tree

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
try:
2+
from google.genai import types as genai_types
23
from vertexai import types as vertexai_types
34
except ImportError as e:
45
raise ImportError(
@@ -25,63 +26,63 @@
2526

2627

2728
_TO_SDK_TASK_STATE = {
28-
vertexai_types.State.STATE_UNSPECIFIED: TaskState.unknown,
29-
vertexai_types.State.SUBMITTED: TaskState.submitted,
30-
vertexai_types.State.WORKING: TaskState.working,
31-
vertexai_types.State.COMPLETED: TaskState.completed,
32-
vertexai_types.State.CANCELLED: TaskState.canceled,
33-
vertexai_types.State.FAILED: TaskState.failed,
34-
vertexai_types.State.REJECTED: TaskState.rejected,
35-
vertexai_types.State.INPUT_REQUIRED: TaskState.input_required,
36-
vertexai_types.State.AUTH_REQUIRED: TaskState.auth_required,
29+
vertexai_types.A2aTaskState.STATE_UNSPECIFIED: TaskState.unknown,
30+
vertexai_types.A2aTaskState.SUBMITTED: TaskState.submitted,
31+
vertexai_types.A2aTaskState.WORKING: TaskState.working,
32+
vertexai_types.A2aTaskState.COMPLETED: TaskState.completed,
33+
vertexai_types.A2aTaskState.CANCELLED: TaskState.canceled,
34+
vertexai_types.A2aTaskState.FAILED: TaskState.failed,
35+
vertexai_types.A2aTaskState.REJECTED: TaskState.rejected,
36+
vertexai_types.A2aTaskState.INPUT_REQUIRED: TaskState.input_required,
37+
vertexai_types.A2aTaskState.AUTH_REQUIRED: TaskState.auth_required,
3738
}
3839

3940
_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()}
4041

4142

42-
def to_sdk_task_state(stored_state: vertexai_types.State) -> TaskState:
43+
def to_sdk_task_state(stored_state: vertexai_types.A2aTaskState) -> TaskState:
4344
"""Converts a proto A2aTask.State to a TaskState enum."""
4445
return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown)
4546

4647

47-
def to_stored_task_state(task_state: TaskState) -> vertexai_types.State:
48+
def to_stored_task_state(task_state: TaskState) -> vertexai_types.A2aTaskState:
4849
"""Converts a TaskState enum to a proto A2aTask.State enum value."""
4950
return _SDK_TO_STORED_TASK_STATE.get(
50-
task_state, vertexai_types.State.STATE_UNSPECIFIED
51+
task_state, vertexai_types.A2aTaskState.STATE_UNSPECIFIED
5152
)
5253

5354

54-
def to_stored_part(part: Part) -> vertexai_types.Part:
55+
def to_stored_part(part: Part) -> genai_types.Part:
5556
"""Converts a SDK Part to a proto Part."""
5657
if isinstance(part.root, TextPart):
57-
return vertexai_types.Part(text=part.root.text)
58+
return genai_types.Part(text=part.root.text)
5859
if isinstance(part.root, DataPart):
5960
data_bytes = json.dumps(part.root.data).encode('utf-8')
60-
return vertexai_types.Part(
61-
inline_data=vertexai_types.Blob(
61+
return genai_types.Part(
62+
inline_data=genai_types.Blob(
6263
mime_type='application/json', data=data_bytes
6364
)
6465
)
6566
if isinstance(part.root, FilePart):
6667
file_content = part.root.file
6768
if isinstance(file_content, FileWithBytes):
6869
decoded_bytes = base64.b64decode(file_content.bytes)
69-
return vertexai_types.Part(
70-
inline_data=vertexai_types.Blob(
70+
return genai_types.Part(
71+
inline_data=genai_types.Blob(
7172
mime_type=file_content.mime_type or '', data=decoded_bytes
7273
)
7374
)
7475
if isinstance(file_content, FileWithUri):
75-
return vertexai_types.Part(
76-
file_data=vertexai_types.FileData(
76+
return genai_types.Part(
77+
file_data=genai_types.FileData(
7778
mime_type=file_content.mime_type or '',
7879
file_uri=file_content.uri,
7980
)
8081
)
8182
raise ValueError(f'Unsupported part type: {type(part.root)}')
8283

8384

84-
def to_sdk_part(stored_part: vertexai_types.Part) -> Part:
85+
def to_sdk_part(stored_part: genai_types.Part) -> Part:
8586
"""Converts a proto Part to a SDK Part."""
8687
if stored_part.text:
8788
return Part(root=TextPart(text=stored_part.text))

src/a2a/contrib/tasks/vertex_task_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
try:
55
import vertexai
6-
7-
from google.genai import errors as genai_errors
86
from vertexai import types as vertexai_types
7+
from google.genai import types as genai_types
8+
from google.genai import errors as genai_errors
99
except ImportError as e:
1010
raise ImportError(
1111
'VertexTaskStore requires vertexai. '

tests/contrib/tasks/test_vertex_task_converter.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
'vertexai', reason='Vertex Task Converter tests require vertexai'
88
)
99
from vertexai import types as vertexai_types
10-
10+
from google.genai import types as genai_types
1111
from a2a.contrib.tasks.vertex_task_converter import (
1212
to_sdk_artifact,
1313
to_sdk_part,
@@ -34,29 +34,29 @@
3434

3535
def test_to_sdk_task_state() -> None:
3636
assert (
37-
to_sdk_task_state(vertexai_types.State.STATE_UNSPECIFIED)
37+
to_sdk_task_state(vertexai_types.A2aTaskState.STATE_UNSPECIFIED)
3838
== TaskState.unknown
3939
)
4040
assert (
41-
to_sdk_task_state(vertexai_types.State.SUBMITTED) == TaskState.submitted
41+
to_sdk_task_state(vertexai_types.A2aTaskState.SUBMITTED) == TaskState.submitted
4242
)
43-
assert to_sdk_task_state(vertexai_types.State.WORKING) == TaskState.working
43+
assert to_sdk_task_state(vertexai_types.A2aTaskState.WORKING) == TaskState.working
4444
assert (
45-
to_sdk_task_state(vertexai_types.State.COMPLETED) == TaskState.completed
45+
to_sdk_task_state(vertexai_types.A2aTaskState.COMPLETED) == TaskState.completed
4646
)
4747
assert (
48-
to_sdk_task_state(vertexai_types.State.CANCELLED) == TaskState.canceled
48+
to_sdk_task_state(vertexai_types.A2aTaskState.CANCELLED) == TaskState.canceled
4949
)
50-
assert to_sdk_task_state(vertexai_types.State.FAILED) == TaskState.failed
50+
assert to_sdk_task_state(vertexai_types.A2aTaskState.FAILED) == TaskState.failed
5151
assert (
52-
to_sdk_task_state(vertexai_types.State.REJECTED) == TaskState.rejected
52+
to_sdk_task_state(vertexai_types.A2aTaskState.REJECTED) == TaskState.rejected
5353
)
5454
assert (
55-
to_sdk_task_state(vertexai_types.State.INPUT_REQUIRED)
55+
to_sdk_task_state(vertexai_types.A2aTaskState.INPUT_REQUIRED)
5656
== TaskState.input_required
5757
)
5858
assert (
59-
to_sdk_task_state(vertexai_types.State.AUTH_REQUIRED)
59+
to_sdk_task_state(vertexai_types.A2aTaskState.AUTH_REQUIRED)
6060
== TaskState.auth_required
6161
)
6262
assert to_sdk_task_state(999) == TaskState.unknown # type: ignore
@@ -65,35 +65,35 @@ def test_to_sdk_task_state() -> None:
6565
def test_to_stored_task_state() -> None:
6666
assert (
6767
to_stored_task_state(TaskState.unknown)
68-
== vertexai_types.State.STATE_UNSPECIFIED
68+
== vertexai_types.A2aTaskState.STATE_UNSPECIFIED
6969
)
7070
assert (
7171
to_stored_task_state(TaskState.submitted)
72-
== vertexai_types.State.SUBMITTED
72+
== vertexai_types.A2aTaskState.SUBMITTED
7373
)
7474
assert (
75-
to_stored_task_state(TaskState.working) == vertexai_types.State.WORKING
75+
to_stored_task_state(TaskState.working) == vertexai_types.A2aTaskState.WORKING
7676
)
7777
assert (
7878
to_stored_task_state(TaskState.completed)
79-
== vertexai_types.State.COMPLETED
79+
== vertexai_types.A2aTaskState.COMPLETED
8080
)
8181
assert (
8282
to_stored_task_state(TaskState.canceled)
83-
== vertexai_types.State.CANCELLED
83+
== vertexai_types.A2aTaskState.CANCELLED
8484
)
85-
assert to_stored_task_state(TaskState.failed) == vertexai_types.State.FAILED
85+
assert to_stored_task_state(TaskState.failed) == vertexai_types.A2aTaskState.FAILED
8686
assert (
8787
to_stored_task_state(TaskState.rejected)
88-
== vertexai_types.State.REJECTED
88+
== vertexai_types.A2aTaskState.REJECTED
8989
)
9090
assert (
9191
to_stored_task_state(TaskState.input_required)
92-
== vertexai_types.State.INPUT_REQUIRED
92+
== vertexai_types.A2aTaskState.INPUT_REQUIRED
9393
)
9494
assert (
9595
to_stored_task_state(TaskState.auth_required)
96-
== vertexai_types.State.AUTH_REQUIRED
96+
== vertexai_types.A2aTaskState.AUTH_REQUIRED
9797
)
9898

9999

@@ -155,15 +155,15 @@ class BadPart:
155155

156156

157157
def test_to_sdk_part_text() -> None:
158-
stored_part = vertexai_types.Part(text='hello back')
158+
stored_part = genai_types.Part(text='hello back')
159159
sdk_part = to_sdk_part(stored_part)
160160
assert isinstance(sdk_part.root, TextPart)
161161
assert sdk_part.root.text == 'hello back'
162162

163163

164164
def test_to_sdk_part_inline_data() -> None:
165-
stored_part = vertexai_types.Part(
166-
inline_data=vertexai_types.Blob(
165+
stored_part = genai_types.Part(
166+
inline_data=genai_types.Blob(
167167
mime_type='application/json',
168168
data=b'{"key": "val"}',
169169
)
@@ -177,8 +177,8 @@ def test_to_sdk_part_inline_data() -> None:
177177

178178

179179
def test_to_sdk_part_file_data() -> None:
180-
stored_part = vertexai_types.Part(
181-
file_data=vertexai_types.FileData(
180+
stored_part = genai_types.Part(
181+
file_data=genai_types.FileData(
182182
mime_type='image/jpeg',
183183
file_uri='gs://bucket/image.jpg',
184184
)
@@ -191,7 +191,7 @@ def test_to_sdk_part_file_data() -> None:
191191

192192

193193
def test_to_sdk_part_unsupported() -> None:
194-
stored_part = vertexai_types.Part()
194+
stored_part = genai_types.Part()
195195
with pytest.raises(ValueError, match='Unsupported part:'):
196196
to_sdk_part(stored_part)
197197

@@ -210,7 +210,7 @@ def test_to_stored_artifact() -> None:
210210
def test_to_sdk_artifact() -> None:
211211
stored_artifact = vertexai_types.TaskArtifact(
212212
artifact_id='art-456',
213-
parts=[vertexai_types.Part(text='part_2')],
213+
parts=[genai_types.Part(text='part_2')],
214214
)
215215
sdk_artifact = to_sdk_artifact(stored_artifact)
216216
assert sdk_artifact.artifact_id == 'art-456'
@@ -236,7 +236,7 @@ def test_to_stored_task() -> None:
236236
stored_task = to_stored_task(sdk_task)
237237
assert stored_task.context_id == 'ctx-1'
238238
assert stored_task.metadata == {'foo': 'bar'}
239-
assert stored_task.state == vertexai_types.State.WORKING
239+
assert stored_task.state == vertexai_types.A2aTaskState.WORKING
240240
assert stored_task.output is not None
241241
assert stored_task.output.artifacts is not None
242242
assert len(stored_task.output.artifacts) == 1
@@ -247,13 +247,13 @@ def test_to_sdk_task() -> None:
247247
stored_task = vertexai_types.A2aTask(
248248
name='projects/123/locations/us-central1/agentEngines/456/tasks/task-2',
249249
context_id='ctx-2',
250-
state=vertexai_types.State.COMPLETED,
250+
state=vertexai_types.A2aTaskState.COMPLETED,
251251
metadata={'a': 'b'},
252252
output=vertexai_types.TaskOutput(
253253
artifacts=[
254254
vertexai_types.TaskArtifact(
255255
artifact_id='art-2',
256-
parts=[vertexai_types.Part(text='result')],
256+
parts=[genai_types.Part(text='result')],
257257
)
258258
]
259259
),
@@ -275,7 +275,7 @@ def test_to_sdk_task_no_output() -> None:
275275
stored_task = vertexai_types.A2aTask(
276276
name='tasks/task-3',
277277
context_id='ctx-3',
278-
state=vertexai_types.State.SUBMITTED,
278+
state=vertexai_types.A2aTaskState.SUBMITTED,
279279
metadata=None,
280280
)
281281
sdk_task = to_sdk_task(stored_task)

0 commit comments

Comments
 (0)