Skip to content

Commit 95a3887

Browse files
committed
Address review comments
1 parent 78e9da3 commit 95a3887

4 files changed

Lines changed: 96 additions & 101 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,6 @@ build-backend = "hatchling.build"
6464
[tool.hatch.version]
6565
source = "uv-dynamic-versioning"
6666

67-
[tool.hatch.metadata]
68-
allow-direct-references = true
69-
7067
[tool.hatch.build.targets.wheel]
7168
packages = ["src/a2a"]
7269

src/a2a/contrib/tasks/vertex_task_converter.py

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
try:
2-
from vertexai import types
2+
from vertexai import types as vertexai_types
33
except ImportError as e:
44
raise ImportError(
55
'vertex_task_converter requires vertexai. '
@@ -23,68 +23,64 @@
2323
TextPart,
2424
)
2525

26+
_TO_SDK_TASK_STATE = {
27+
vertexai_types.State.STATE_UNSPECIFIED: TaskState.unknown,
28+
vertexai_types.State.SUBMITTED: TaskState.submitted,
29+
vertexai_types.State.WORKING: TaskState.working,
30+
vertexai_types.State.COMPLETED: TaskState.completed,
31+
vertexai_types.State.CANCELLED: TaskState.canceled,
32+
vertexai_types.State.FAILED: TaskState.failed,
33+
vertexai_types.State.REJECTED: TaskState.rejected,
34+
vertexai_types.State.INPUT_REQUIRED: TaskState.input_required,
35+
vertexai_types.State.AUTH_REQUIRED: TaskState.auth_required,
36+
}
2637

27-
def to_sdk_task_state(stored_state: types.State) -> TaskState:
38+
_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()}
39+
40+
41+
def to_sdk_task_state(stored_state: vertexai_types.State) -> TaskState:
2842
"""Converts a proto A2aTask.State to a TaskState enum."""
29-
return {
30-
types.State.STATE_UNSPECIFIED: TaskState.unknown,
31-
types.State.SUBMITTED: TaskState.submitted,
32-
types.State.WORKING: TaskState.working,
33-
types.State.COMPLETED: TaskState.completed,
34-
types.State.CANCELLED: TaskState.canceled,
35-
types.State.FAILED: TaskState.failed,
36-
types.State.REJECTED: TaskState.rejected,
37-
types.State.INPUT_REQUIRED: TaskState.input_required,
38-
types.State.AUTH_REQUIRED: TaskState.auth_required,
39-
}.get(stored_state, TaskState.unknown)
40-
41-
42-
def to_stored_task_state(task_state: TaskState) -> types.State:
43+
return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown)
44+
45+
46+
def to_stored_task_state(task_state: TaskState) -> vertexai_types.State:
4347
"""Converts a TaskState enum to a proto A2aTask.State enum value."""
44-
return {
45-
TaskState.unknown: types.State.STATE_UNSPECIFIED,
46-
TaskState.submitted: types.State.SUBMITTED,
47-
TaskState.working: types.State.WORKING,
48-
TaskState.completed: types.State.COMPLETED,
49-
TaskState.canceled: types.State.CANCELLED,
50-
TaskState.failed: types.State.FAILED,
51-
TaskState.rejected: types.State.REJECTED,
52-
TaskState.input_required: types.State.INPUT_REQUIRED,
53-
TaskState.auth_required: types.State.AUTH_REQUIRED,
54-
}.get(task_state, types.State.STATE_UNSPECIFIED)
55-
56-
57-
def to_stored_part(part: Part) -> types.Part:
48+
return _SDK_TO_STORED_TASK_STATE.get(
49+
task_state, vertexai_types.State.STATE_UNSPECIFIED
50+
)
51+
52+
53+
def to_stored_part(part: Part) -> vertexai_types.Part:
5854
"""Converts a SDK Part to a proto Part."""
5955
if isinstance(part.root, TextPart):
60-
return types.Part(text=part.root.text)
56+
return vertexai_types.Part(text=part.root.text)
6157
if isinstance(part.root, DataPart):
6258
data_bytes = json.dumps(part.root.data).encode('utf-8')
63-
return types.Part(
64-
inline_data=types.Blob(
59+
return vertexai_types.Part(
60+
inline_data=vertexai_types.Blob(
6561
mime_type='application/json', data=data_bytes
6662
)
6763
)
6864
if isinstance(part.root, FilePart):
6965
file_content = part.root.file
7066
if isinstance(file_content, FileWithBytes):
7167
decoded_bytes = base64.b64decode(file_content.bytes)
72-
return types.Part(
73-
inline_data=types.Blob(
68+
return vertexai_types.Part(
69+
inline_data=vertexai_types.Blob(
7470
mime_type=file_content.mime_type or '', data=decoded_bytes
7571
)
7672
)
7773
if isinstance(file_content, FileWithUri):
78-
return types.Part(
79-
file_data=types.FileData(
74+
return vertexai_types.Part(
75+
file_data=vertexai_types.FileData(
8076
mime_type=file_content.mime_type or '',
8177
file_uri=file_content.uri,
8278
)
8379
)
8480
raise ValueError(f'Unsupported part type: {type(part.root)}')
8581

8682

87-
def to_sdk_part(stored_part: types.Part) -> Part:
83+
def to_sdk_part(stored_part: vertexai_types.Part) -> Part:
8884
"""Converts a proto Part to a SDK Part."""
8985
if stored_part.text:
9086
return Part(root=TextPart(text=stored_part.text))
@@ -113,29 +109,29 @@ def to_sdk_part(stored_part: types.Part) -> Part:
113109
return Part(root=TextPart(text=''))
114110

115111

116-
def to_stored_artifact(artifact: Artifact) -> types.TaskArtifact:
112+
def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
117113
"""Converts a SDK Artifact to a proto TaskArtifact."""
118-
return types.TaskArtifact(
114+
return vertexai_types.TaskArtifact(
119115
artifact_id=artifact.artifact_id,
120116
parts=[to_stored_part(part) for part in artifact.parts],
121117
)
122118

123119

124-
def to_sdk_artifact(stored_artifact: types.TaskArtifact) -> Artifact:
120+
def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
125121
"""Converts a proto TaskArtifact to a SDK Artifact."""
126122
return Artifact(
127123
artifact_id=stored_artifact.artifact_id,
128124
parts=[to_sdk_part(part) for part in stored_artifact.parts],
129125
)
130126

131127

132-
def to_stored_task(task: Task) -> types.A2aTask:
128+
def to_stored_task(task: Task) -> vertexai_types.A2aTask:
133129
"""Converts a SDK Task to a proto A2aTask."""
134-
return types.A2aTask(
130+
return vertexai_types.A2aTask(
135131
context_id=task.context_id,
136132
metadata=task.metadata,
137133
state=to_stored_task_state(task.status.state),
138-
output=types.TaskOutput(
134+
output=vertexai_types.TaskOutput(
139135
artifacts=[
140136
to_stored_artifact(artifact)
141137
for artifact in task.artifacts or []
@@ -144,7 +140,7 @@ def to_stored_task(task: Task) -> types.A2aTask:
144140
)
145141

146142

147-
def to_sdk_task(a2a_task: types.A2aTask) -> Task:
143+
def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
148144
"""Converts a proto A2aTask to a SDK Task."""
149145
return Task(
150146
id=a2a_task.name.split('/')[-1],

src/a2a/contrib/tasks/vertex_task_store.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import vertexai
66

77
from google.genai import errors as genai_errors
8-
from vertexai import types
8+
from vertexai import types as vertexai_types
99
except ImportError as e:
1010
raise ImportError(
1111
'VertexTaskStore requires vertexai. '
@@ -57,7 +57,7 @@ async def _create(self, sdk_task: Task) -> None:
5757
await self._client.aio.agent_engines.a2a_tasks.create(
5858
name=self._agent_engine_resource_id,
5959
a2a_task_id=sdk_task.id,
60-
config=types.CreateAgentEngineTaskConfig(
60+
config=vertexai_types.CreateAgentEngineTaskConfig(
6161
context_id=stored_task.context_id,
6262
metadata=stored_task.metadata,
6363
output=stored_task.output,
@@ -66,11 +66,11 @@ async def _create(self, sdk_task: Task) -> None:
6666

6767
def _get_status_change_event(
6868
self, previous_task: Task, task: Task, event_sequence_number: int
69-
) -> types.TaskEvent | None:
69+
) -> vertexai_types.TaskEvent | None:
7070
if task.status.state != previous_task.status.state:
71-
return types.TaskEvent(
72-
event_data=types.TaskEventData(
73-
state_change=types.TaskStateChange(
71+
return vertexai_types.TaskEvent(
72+
event_data=vertexai_types.TaskEventData(
73+
state_change=vertexai_types.TaskStateChange(
7474
new_state=vertex_task_converter.to_stored_task_state(
7575
task.status.state
7676
),
@@ -82,11 +82,11 @@ def _get_status_change_event(
8282

8383
def _get_metadata_change_event(
8484
self, previous_task: Task, task: Task, event_sequence_number: int
85-
) -> types.TaskEvent | None:
85+
) -> vertexai_types.TaskEvent | None:
8686
if task.metadata != previous_task.metadata:
87-
return types.TaskEvent(
88-
event_data=types.TaskEventData(
89-
metadata_change=types.TaskMetadataChange(
87+
return vertexai_types.TaskEvent(
88+
event_data=vertexai_types.TaskEventData(
89+
metadata_change=vertexai_types.TaskMetadataChange(
9090
new_metadata=task.metadata,
9191
)
9292
),
@@ -96,12 +96,12 @@ def _get_metadata_change_event(
9696

9797
def _get_artifacts_change_event(
9898
self, previous_task: Task, task: Task, event_sequence_number: int
99-
) -> types.TaskEvent | None:
99+
) -> vertexai_types.TaskEvent | None:
100100
if task.artifacts != previous_task.artifacts:
101-
task_artifact_change = types.TaskArtifactChange()
102-
event = types.TaskEvent(
103-
event_data=types.TaskEventData(
104-
output_change=types.TaskOutputChange(
101+
task_artifact_change = vertexai_types.TaskArtifactChange()
102+
event = vertexai_types.TaskEvent(
103+
event_data=vertexai_types.TaskEventData(
104+
output_change=vertexai_types.TaskOutputChange(
105105
task_artifact_change=task_artifact_change
106106
)
107107
),
@@ -140,12 +140,12 @@ def _get_artifacts_change_event(
140140
task_artifact_change.updated_artifacts.append(
141141
vertex_task_converter.to_stored_artifact(artifact)
142142
)
143-
if task_artifact_change != types.TaskArtifactChange():
143+
if task_artifact_change != vertexai_types.TaskArtifactChange():
144144
return event
145145
return None
146146

147147
async def _update(
148-
self, previous_stored_task: types.A2aTask, task: Task
148+
self, previous_stored_task: vertexai_types.A2aTask, task: Task
149149
) -> None:
150150
previous_task = vertex_task_converter.to_sdk_task(previous_stored_task)
151151
events = []
@@ -179,7 +179,9 @@ async def _update(
179179
task_events=events,
180180
)
181181

182-
async def _get_stored_task(self, task_id: str) -> types.A2aTask | None:
182+
async def _get_stored_task(
183+
self, task_id: str
184+
) -> vertexai_types.A2aTask | None:
183185
try:
184186
a2a_task = await self._client.aio.agent_engines.a2a_tasks.get(
185187
name=self._agent_engine_resource_id + '/a2aTasks/' + task_id,

0 commit comments

Comments
 (0)