Skip to content

Commit 1d6e5c0

Browse files
committed
feat: Implement a vertex based task store for the 1.0 branch
1 parent 942f4ae commit 1d6e5c0

10 files changed

Lines changed: 1805 additions & 7 deletions

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
4040
signing = ["PyJWT>=2.0.0"]
4141
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
4242
db-cli = ["alembic>=1.14.0"]
43+
vertex = ["google-cloud-aiplatform>=1.140.0"]
4344

4445
sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
4546

@@ -51,6 +52,7 @@ all = [
5152
"a2a-sdk[telemetry]",
5253
"a2a-sdk[signing]",
5354
"a2a-sdk[db-cli]",
55+
"a2a-sdk[vertex]",
5456
]
5557

5658
[project.urls]

src/a2a/contrib/tasks/__init__.py

Whitespace-only changes.
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
try:
2+
from vertexai import types as vertexai_types
3+
except ImportError as e:
4+
raise ImportError(
5+
'vertex_task_converter requires vertexai. '
6+
'Install with: '
7+
"'pip install a2a-sdk[vertex]'"
8+
) from e
9+
10+
import base64
11+
import json
12+
13+
from a2a.compat.v0_3.types import (
14+
Artifact,
15+
DataPart,
16+
FilePart,
17+
FileWithBytes,
18+
FileWithUri,
19+
Part,
20+
Task,
21+
TaskState,
22+
TaskStatus,
23+
TextPart,
24+
)
25+
26+
27+
_TO_SDK_TASK_STATE = {
28+
vertexai_types.State.STATE_UNSPECIFIED: TaskState.unknown,
29+
vertexai_types.State.SUBMITTED: TaskState.submitted,
30+
vertexai_types.State.WORKING: TaskState.working,
31+
vertexai_types.State.COMPLETED: TaskState.completed,
32+
vertexai_types.State.CANCELLED: TaskState.canceled,
33+
vertexai_types.State.FAILED: TaskState.failed,
34+
vertexai_types.State.REJECTED: TaskState.rejected,
35+
vertexai_types.State.INPUT_REQUIRED: TaskState.input_required,
36+
vertexai_types.State.AUTH_REQUIRED: TaskState.auth_required,
37+
}
38+
39+
_SDK_TO_STORED_TASK_STATE = {v: k for k, v in _TO_SDK_TASK_STATE.items()}
40+
41+
42+
def to_sdk_task_state(stored_state: vertexai_types.State) -> TaskState:
43+
"""Converts a proto A2aTask.State to a TaskState enum."""
44+
return _TO_SDK_TASK_STATE.get(stored_state, TaskState.unknown)
45+
46+
47+
def to_stored_task_state(task_state: TaskState) -> vertexai_types.State:
48+
"""Converts a TaskState enum to a proto A2aTask.State enum value."""
49+
return _SDK_TO_STORED_TASK_STATE.get(
50+
task_state, vertexai_types.State.STATE_UNSPECIFIED
51+
)
52+
53+
54+
def to_stored_part(part: Part) -> vertexai_types.Part:
55+
"""Converts a SDK Part to a proto Part."""
56+
if isinstance(part.root, TextPart):
57+
return vertexai_types.Part(text=part.root.text)
58+
if isinstance(part.root, DataPart):
59+
data_bytes = json.dumps(part.root.data).encode('utf-8')
60+
return vertexai_types.Part(
61+
inline_data=vertexai_types.Blob(
62+
mime_type='application/json', data=data_bytes
63+
)
64+
)
65+
if isinstance(part.root, FilePart):
66+
file_content = part.root.file
67+
if isinstance(file_content, FileWithBytes):
68+
decoded_bytes = base64.b64decode(file_content.bytes)
69+
return vertexai_types.Part(
70+
inline_data=vertexai_types.Blob(
71+
mime_type=file_content.mime_type or '', data=decoded_bytes
72+
)
73+
)
74+
if isinstance(file_content, FileWithUri):
75+
return vertexai_types.Part(
76+
file_data=vertexai_types.FileData(
77+
mime_type=file_content.mime_type or '',
78+
file_uri=file_content.uri,
79+
)
80+
)
81+
raise ValueError(f'Unsupported part type: {type(part.root)}')
82+
83+
84+
def to_sdk_part(stored_part: vertexai_types.Part) -> Part:
85+
"""Converts a proto Part to a SDK Part."""
86+
if stored_part.text:
87+
return Part(root=TextPart(text=stored_part.text))
88+
if stored_part.inline_data:
89+
encoded_bytes = base64.b64encode(stored_part.inline_data.data).decode(
90+
'utf-8'
91+
)
92+
return Part(
93+
root=FilePart(
94+
file=FileWithBytes(
95+
mime_type=stored_part.inline_data.mime_type,
96+
bytes=encoded_bytes,
97+
)
98+
)
99+
)
100+
if stored_part.file_data:
101+
return Part(
102+
root=FilePart(
103+
file=FileWithUri(
104+
mime_type=stored_part.file_data.mime_type,
105+
uri=stored_part.file_data.file_uri,
106+
)
107+
)
108+
)
109+
110+
raise ValueError(f'Unsupported part: {stored_part}')
111+
112+
113+
def to_stored_artifact(artifact: Artifact) -> vertexai_types.TaskArtifact:
114+
"""Converts a SDK Artifact to a proto TaskArtifact."""
115+
return vertexai_types.TaskArtifact(
116+
artifact_id=artifact.artifact_id,
117+
parts=[to_stored_part(part) for part in artifact.parts],
118+
)
119+
120+
121+
def to_sdk_artifact(stored_artifact: vertexai_types.TaskArtifact) -> Artifact:
122+
"""Converts a proto TaskArtifact to a SDK Artifact."""
123+
return Artifact(
124+
artifact_id=stored_artifact.artifact_id,
125+
parts=[to_sdk_part(part) for part in stored_artifact.parts],
126+
)
127+
128+
129+
def to_stored_task(task: Task) -> vertexai_types.A2aTask:
130+
"""Converts a SDK Task to a proto A2aTask."""
131+
return vertexai_types.A2aTask(
132+
context_id=task.context_id,
133+
metadata=task.metadata,
134+
state=to_stored_task_state(task.status.state),
135+
output=vertexai_types.TaskOutput(
136+
artifacts=[
137+
to_stored_artifact(artifact)
138+
for artifact in task.artifacts or []
139+
]
140+
),
141+
)
142+
143+
144+
def to_sdk_task(a2a_task: vertexai_types.A2aTask) -> Task:
145+
"""Converts a proto A2aTask to a SDK Task."""
146+
return Task(
147+
id=a2a_task.name.split('/')[-1],
148+
context_id=a2a_task.context_id,
149+
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
150+
metadata=a2a_task.metadata or {},
151+
artifacts=[
152+
to_sdk_artifact(artifact)
153+
for artifact in a2a_task.output.artifacts or []
154+
]
155+
if a2a_task.output
156+
else [],
157+
history=[],
158+
)
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import logging
2+
3+
4+
try:
5+
import vertexai
6+
7+
from google.genai import errors as genai_errors
8+
from vertexai import types as vertexai_types
9+
except ImportError as e:
10+
raise ImportError(
11+
'VertexTaskStore requires vertexai. '
12+
'Install with: '
13+
"'pip install a2a-sdk[vertex]'"
14+
) from e
15+
16+
from a2a.compat.v0_3.conversions import to_compat_task, to_core_task
17+
from a2a.compat.v0_3.types import Task as CompatTask
18+
from a2a.contrib.tasks import vertex_task_converter
19+
from a2a.server.context import ServerCallContext
20+
from a2a.server.tasks.task_store import TaskStore
21+
from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task
22+
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
class VertexTaskStore(TaskStore):
28+
"""Implementation of TaskStore using Vertex AI Agent Engine Task Store.
29+
30+
Stores task objects in Vertex AI Agent Engine Task Store.
31+
"""
32+
33+
def __init__(
34+
self,
35+
client: vertexai.Client,
36+
agent_engine_resource_id: str,
37+
) -> None:
38+
"""Initializes the VertexTaskStore.
39+
40+
Args:
41+
client: The Vertex AI client.
42+
agent_engine_resource_id: The resource ID of the agent engine.
43+
"""
44+
self._client = client
45+
self._agent_engine_resource_id = agent_engine_resource_id
46+
47+
async def save(
48+
self, task: Task, context: ServerCallContext | None = None
49+
) -> None:
50+
"""Saves or updates a task in the store."""
51+
compat_task = to_compat_task(task)
52+
previous_task = await self._get_stored_task(compat_task.id)
53+
if previous_task is None:
54+
await self._create(compat_task)
55+
else:
56+
await self._update(previous_task, compat_task)
57+
58+
async def _create(self, sdk_task: CompatTask) -> None:
59+
stored_task = vertex_task_converter.to_stored_task(sdk_task)
60+
await self._client.aio.agent_engines.a2a_tasks.create(
61+
name=self._agent_engine_resource_id,
62+
a2a_task_id=sdk_task.id,
63+
config=vertexai_types.CreateAgentEngineTaskConfig(
64+
context_id=stored_task.context_id,
65+
metadata=stored_task.metadata,
66+
output=stored_task.output,
67+
),
68+
)
69+
70+
def _get_status_change_event(
71+
self, previous_task: CompatTask, task: CompatTask, event_sequence_number: int
72+
) -> vertexai_types.TaskEvent | None:
73+
if task.status.state != previous_task.status.state:
74+
return vertexai_types.TaskEvent(
75+
event_data=vertexai_types.TaskEventData(
76+
state_change=vertexai_types.TaskStateChange(
77+
new_state=vertex_task_converter.to_stored_task_state(
78+
task.status.state
79+
),
80+
),
81+
),
82+
event_sequence_number=event_sequence_number,
83+
)
84+
return None
85+
86+
def _get_metadata_change_event(
87+
self, previous_task: CompatTask, task: CompatTask, event_sequence_number: int
88+
) -> vertexai_types.TaskEvent | None:
89+
if task.metadata != previous_task.metadata:
90+
return vertexai_types.TaskEvent(
91+
event_data=vertexai_types.TaskEventData(
92+
metadata_change=vertexai_types.TaskMetadataChange(
93+
new_metadata=task.metadata,
94+
)
95+
),
96+
event_sequence_number=event_sequence_number,
97+
)
98+
return None
99+
100+
def _get_artifacts_change_event(
101+
self, previous_task: CompatTask, task: CompatTask, event_sequence_number: int
102+
) -> vertexai_types.TaskEvent | None:
103+
if task.artifacts != previous_task.artifacts:
104+
task_artifact_change = vertexai_types.TaskArtifactChange()
105+
event = vertexai_types.TaskEvent(
106+
event_data=vertexai_types.TaskEventData(
107+
output_change=vertexai_types.TaskOutputChange(
108+
task_artifact_change=task_artifact_change
109+
)
110+
),
111+
event_sequence_number=event_sequence_number,
112+
)
113+
task_artifacts = (
114+
{artifact.artifact_id: artifact for artifact in task.artifacts}
115+
if task.artifacts
116+
else {}
117+
)
118+
previous_task_artifacts = (
119+
{
120+
artifact.artifact_id: artifact
121+
for artifact in previous_task.artifacts
122+
}
123+
if previous_task.artifacts
124+
else {}
125+
)
126+
for artifact in previous_task_artifacts.values():
127+
if artifact.artifact_id not in task_artifacts:
128+
if not task_artifact_change.deleted_artifact_ids:
129+
task_artifact_change.deleted_artifact_ids = []
130+
task_artifact_change.deleted_artifact_ids.append(
131+
artifact.artifact_id
132+
)
133+
for artifact in task_artifacts.values():
134+
if artifact.artifact_id not in previous_task_artifacts:
135+
if not task_artifact_change.added_artifacts:
136+
task_artifact_change.added_artifacts = []
137+
task_artifact_change.added_artifacts.append(
138+
vertex_task_converter.to_stored_artifact(artifact)
139+
)
140+
elif artifact != previous_task_artifacts[artifact.artifact_id]:
141+
if not task_artifact_change.updated_artifacts:
142+
task_artifact_change.updated_artifacts = []
143+
task_artifact_change.updated_artifacts.append(
144+
vertex_task_converter.to_stored_artifact(artifact)
145+
)
146+
if task_artifact_change != vertexai_types.TaskArtifactChange():
147+
return event
148+
return None
149+
150+
async def _update(
151+
self, previous_stored_task: vertexai_types.A2aTask, task: CompatTask
152+
) -> None:
153+
previous_task = vertex_task_converter.to_sdk_task(previous_stored_task)
154+
events = []
155+
event_sequence_number = previous_stored_task.next_event_sequence_number
156+
157+
status_event = self._get_status_change_event(
158+
previous_task, task, event_sequence_number
159+
)
160+
if status_event:
161+
events.append(status_event)
162+
event_sequence_number += 1
163+
164+
metadata_event = self._get_metadata_change_event(
165+
previous_task, task, event_sequence_number
166+
)
167+
if metadata_event:
168+
events.append(metadata_event)
169+
event_sequence_number += 1
170+
171+
artifacts_event = self._get_artifacts_change_event(
172+
previous_task, task, event_sequence_number
173+
)
174+
if artifacts_event:
175+
events.append(artifacts_event)
176+
event_sequence_number += 1
177+
178+
if not events:
179+
return
180+
await self._client.aio.agent_engines.a2a_tasks.events.append(
181+
name=self._agent_engine_resource_id + '/a2aTasks/' + task.id,
182+
task_events=events,
183+
)
184+
185+
async def _get_stored_task(
186+
self, task_id: str
187+
) -> vertexai_types.A2aTask | None:
188+
try:
189+
a2a_task = await self._client.aio.agent_engines.a2a_tasks.get(
190+
name=self._agent_engine_resource_id + '/a2aTasks/' + task_id,
191+
)
192+
except genai_errors.APIError as e:
193+
if e.status == 'NOT_FOUND':
194+
logger.debug('Task %s not found in store.', task_id)
195+
return None
196+
raise
197+
return a2a_task
198+
199+
async def get(
200+
self, task_id: str, context: ServerCallContext | None = None
201+
) -> Task | None:
202+
"""Retrieves a task from the database by ID."""
203+
a2a_task = await self._get_stored_task(task_id)
204+
if a2a_task is None:
205+
return None
206+
return to_core_task(vertex_task_converter.to_sdk_task(a2a_task))
207+
208+
async def list(
209+
self,
210+
params: ListTasksRequest,
211+
context: ServerCallContext | None = None,
212+
) -> ListTasksResponse:
213+
"""Retrieves a list of tasks from the store."""
214+
raise NotImplementedError
215+
216+
async def delete(
217+
self, task_id: str, context: ServerCallContext | None = None
218+
) -> None:
219+
"""The backend doesn't support deleting tasks, so this is not implemented."""
220+
raise NotImplementedError

0 commit comments

Comments
 (0)