Skip to content

Commit b53c89e

Browse files
committed
Implement a vertex based task store
1 parent cced34d commit b53c89e

9 files changed

Lines changed: 2645 additions & 1242 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"]
3737
mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"]
3838
signing = ["PyJWT>=2.0.0"]
3939
sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"]
40+
vertex = ["google-cloud-aiplatform @ git+https://github.com/googleapis/python-aiplatform.git@copybara_874762764"]
4041

4142
sql = ["a2a-sdk[postgresql,mysql,sqlite]"]
4243

@@ -47,6 +48,7 @@ all = [
4748
"a2a-sdk[grpc]",
4849
"a2a-sdk[telemetry]",
4950
"a2a-sdk[signing]",
51+
"a2a-sdk[vertex]",
5052
]
5153

5254
[project.urls]
@@ -62,6 +64,9 @@ build-backend = "hatchling.build"
6264
[tool.hatch.version]
6365
source = "uv-dynamic-versioning"
6466

67+
[tool.hatch.metadata]
68+
allow-direct-references = true
69+
6570
[tool.hatch.build.targets.wheel]
6671
packages = ["src/a2a"]
6772

scripts/run_vertex_tests.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
set -e
3+
4+
for var in VERTEX_PROJECT VERTEX_LOCATION VERTEX_BASE_URL VERTEX_API_VERSION; do
5+
if [ -z "${!var}" ]; then
6+
echo "Error: Environment variable $var is undefined or empty." >&2
7+
exit 1
8+
fi
9+
done
10+
11+
# Get the directory of this script
12+
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
13+
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
14+
PYTEST_ARGS=("$@")
15+
16+
uv run pytest -v "${PYTEST_ARGS[@]}" tests/server/tasks/test_vertex_task_store.py tests/server/tasks/test_vertex_task_converter.py
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
try:
2+
from vertexai import types
3+
except ImportError as e:
4+
raise ImportError(
5+
'vertex_task_converter requires vertexai. '
6+
'Install with one of: '
7+
"'pip install a2a-sdk[vertexai]'"
8+
) from e
9+
10+
import base64
11+
import json
12+
13+
from a2a.types import (
14+
Artifact,
15+
DataPart,
16+
FilePart,
17+
FileWithBytes,
18+
FileWithUri,
19+
Part,
20+
Task,
21+
TaskState,
22+
TaskStatus,
23+
TextPart,
24+
)
25+
26+
27+
def to_sdk_task_state(stored_state: types.State) -> TaskState:
28+
"""Converts a proto A2aTask.State to a TaskState enum."""
29+
return {
30+
types.State.STATE_UNSPECIFIED: TaskState.unknown,
31+
types.State.SUBMITTED: TaskState.submitted,
32+
types.State.WORKING: TaskState.working,
33+
types.State.COMPLETED: TaskState.completed,
34+
types.State.CANCELLED: TaskState.canceled,
35+
types.State.FAILED: TaskState.failed,
36+
types.State.REJECTED: TaskState.rejected,
37+
types.State.INPUT_REQUIRED: TaskState.input_required,
38+
types.State.AUTH_REQUIRED: TaskState.auth_required,
39+
}.get(stored_state, TaskState.unknown)
40+
41+
42+
def to_stored_task_state(task_state: TaskState) -> types.State:
43+
"""Converts a TaskState enum to a proto A2aTask.State enum value."""
44+
return {
45+
TaskState.unknown: types.State.STATE_UNSPECIFIED,
46+
TaskState.submitted: types.State.SUBMITTED,
47+
TaskState.working: types.State.WORKING,
48+
TaskState.completed: types.State.COMPLETED,
49+
TaskState.canceled: types.State.CANCELLED,
50+
TaskState.failed: types.State.FAILED,
51+
TaskState.rejected: types.State.REJECTED,
52+
TaskState.input_required: types.State.INPUT_REQUIRED,
53+
TaskState.auth_required: types.State.AUTH_REQUIRED,
54+
}.get(task_state, types.State.STATE_UNSPECIFIED)
55+
56+
57+
def to_stored_part(part: Part) -> types.Part:
58+
"""Converts a SDK Part to a proto Part."""
59+
if isinstance(part.root, TextPart):
60+
return types.Part(text=part.root.text)
61+
if isinstance(part.root, DataPart):
62+
data_bytes = json.dumps(part.root.data).encode('utf-8')
63+
return types.Part(
64+
inline_data=types.Blob(
65+
mime_type='application/json', data=data_bytes
66+
)
67+
)
68+
if isinstance(part.root, FilePart):
69+
file_content = part.root.file
70+
if isinstance(file_content, FileWithBytes):
71+
decoded_bytes = base64.b64decode(file_content.bytes)
72+
return types.Part(
73+
inline_data=types.Blob(
74+
mime_type=file_content.mime_type or '', data=decoded_bytes
75+
)
76+
)
77+
if isinstance(file_content, FileWithUri):
78+
return types.Part(
79+
file_data=types.FileData(
80+
mime_type=file_content.mime_type or '',
81+
file_uri=file_content.uri,
82+
)
83+
)
84+
raise ValueError(f'Unsupported part type: {type(part.root)}')
85+
86+
87+
def to_sdk_part(stored_part: types.Part) -> Part:
88+
"""Converts a proto Part to a SDK Part."""
89+
# https://github.com/a2aproject/a2a-python/blob/cced34d4eb22f7ec78422b2d04ce54924e132215/src/a2a/types.py#L1372
90+
# https://source.corp.google.com/piper///depot/google3/google/cloud/aiplatform/master/agent_engine_task_store.proto;l=152;rcl=875163491
91+
if stored_part.text:
92+
return Part(root=TextPart(text=stored_part.text))
93+
if stored_part.inline_data:
94+
encoded_bytes = base64.b64encode(stored_part.inline_data.data).decode(
95+
'utf-8'
96+
)
97+
return Part(
98+
root=FilePart(
99+
file=FileWithBytes(
100+
mime_type=stored_part.inline_data.mime_type,
101+
bytes=encoded_bytes,
102+
)
103+
)
104+
)
105+
if stored_part.file_data:
106+
return Part(
107+
root=FilePart(
108+
file=FileWithUri(
109+
mime_type=stored_part.file_data.mime_type,
110+
uri=stored_part.file_data.file_uri,
111+
)
112+
)
113+
)
114+
115+
return Part(root=TextPart(text=''))
116+
117+
118+
def to_stored_artifact(artifact: Artifact) -> types.TaskArtifact:
119+
"""Converts a SDK Artifact to a proto TaskArtifact."""
120+
return types.TaskArtifact(
121+
artifact_id=artifact.artifact_id,
122+
parts=[to_stored_part(part) for part in artifact.parts],
123+
)
124+
125+
126+
def to_sdk_artifact(stored_artifact: types.TaskArtifact) -> Artifact:
127+
"""Converts a proto TaskArtifact to a SDK Artifact."""
128+
return Artifact(
129+
artifact_id=stored_artifact.artifact_id,
130+
parts=[to_sdk_part(part) for part in stored_artifact.parts],
131+
)
132+
133+
134+
def to_stored_task(task: Task) -> types.A2aTask:
135+
"""Converts a SDK Task to a proto A2aTask."""
136+
return types.A2aTask(
137+
context_id=task.context_id,
138+
metadata=task.metadata,
139+
state=to_stored_task_state(task.status.state),
140+
output=types.TaskOutput(
141+
artifacts=[
142+
to_stored_artifact(artifact)
143+
for artifact in task.artifacts or []
144+
]
145+
),
146+
)
147+
148+
149+
def to_sdk_task(a2a_task: types.A2aTask) -> Task:
150+
"""Converts a proto A2aTask to a SDK Task."""
151+
return Task(
152+
id=a2a_task.name.split('/')[-1],
153+
context_id=a2a_task.context_id,
154+
status=TaskStatus(state=to_sdk_task_state(a2a_task.state)),
155+
metadata=a2a_task.metadata or {},
156+
artifacts=[
157+
to_sdk_artifact(artifact)
158+
for artifact in a2a_task.output.artifacts or []
159+
]
160+
if a2a_task.output
161+
else [],
162+
history=[],
163+
)
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import logging
2+
3+
4+
try:
5+
import vertexai
6+
7+
from google.genai import errors as genai_errors
8+
from vertexai import types
9+
except ImportError as e:
10+
raise ImportError(
11+
'VertexTaskStore requires vertexai. '
12+
'Install with one of: '
13+
"'pip install a2a-sdk[vertexai]'" # TODO: or how exactly?
14+
) from e
15+
16+
from a2a.server.context import ServerCallContext
17+
from a2a.server.tasks import vertex_task_converter
18+
from a2a.server.tasks.task_store import TaskStore
19+
from a2a.types import Task # Task is the Pydantic model
20+
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class VertexTaskStore(TaskStore):
26+
"""SQLAlchemy-based implementation of TaskStore.
27+
28+
Stores task objects in a database supported by SQLAlchemy.
29+
"""
30+
31+
agent_engine_resource_id: str
32+
project_id: str
33+
location: str
34+
35+
def __init__(
36+
self,
37+
client: vertexai.Client,
38+
agent_engine_resource_id: str,
39+
) -> None:
40+
"""Initializes the DatabaseTaskStore.
41+
42+
Args:
43+
agent_engine_resource_id: The resource ID of the agent engine.
44+
location: The location of the agent engine.
45+
"""
46+
self._client = client
47+
self._agent_engine_resource_id = agent_engine_resource_id
48+
49+
async def save(
50+
self, task: Task, context: ServerCallContext | None = None
51+
) -> None:
52+
"""Saves or updates a task in the database."""
53+
previous_task = await self._get_stored_task(task.id)
54+
if previous_task is None:
55+
self._create(task)
56+
else:
57+
self._update(previous_task, task)
58+
59+
def _create(self, sdk_task: Task) -> None:
60+
stored_task = vertex_task_converter.to_stored_task(sdk_task)
61+
self._client.agent_engines.a2a_tasks.create(
62+
name=self._agent_engine_resource_id,
63+
a2a_task_id=sdk_task.id,
64+
config=types.CreateAgentEngineTaskConfig(
65+
context_id=stored_task.context_id,
66+
metadata=stored_task.metadata,
67+
state=stored_task.state,
68+
output=stored_task.output,
69+
),
70+
)
71+
72+
def _update(self, previous_stored_task: types.A2aTask, task: Task) -> None:
73+
previous_task = vertex_task_converter.to_sdk_task(previous_stored_task)
74+
events = []
75+
event_sequence_number = previous_stored_task.next_event_sequence_number
76+
if task.status.state != previous_task.status.state:
77+
events.append(
78+
types.TaskEvent(
79+
event_data=types.TaskEventData(
80+
state_change=types.TaskStateChange(
81+
new_state=vertex_task_converter.to_stored_task_state(
82+
task.status.state
83+
),
84+
),
85+
),
86+
event_sequence_number=event_sequence_number,
87+
),
88+
)
89+
event_sequence_number += 1
90+
if task.metadata != previous_task.metadata:
91+
events.append(
92+
types.TaskEvent(
93+
event_data=types.TaskEventData(
94+
metadata_change=types.TaskMetadataChange(
95+
new_metadata=task.metadata,
96+
)
97+
),
98+
event_sequence_number=event_sequence_number,
99+
),
100+
)
101+
event_sequence_number += 1
102+
if task.artifacts != previous_task.artifacts:
103+
task_artifact_change = types.TaskArtifactChange()
104+
event = types.TaskEvent(
105+
event_data=types.TaskEventData(
106+
output_change=types.TaskOutputChange(
107+
task_artifact_change=task_artifact_change
108+
)
109+
),
110+
event_sequence_number=event_sequence_number,
111+
)
112+
task_artifacts = (
113+
{artifact.artifact_id: artifact for artifact in task.artifacts}
114+
if task.artifacts
115+
else {}
116+
)
117+
previous_task_artifacts = (
118+
{
119+
artifact.artifact_id: artifact
120+
for artifact in previous_task.artifacts
121+
}
122+
if previous_task.artifacts
123+
else {}
124+
)
125+
for artifact in previous_task_artifacts.values():
126+
if artifact.artifact_id not in task_artifacts:
127+
if not task_artifact_change.deleted_artifact_ids:
128+
task_artifact_change.deleted_artifact_ids = []
129+
task_artifact_change.deleted_artifact_ids.append(
130+
artifact.artifact_id
131+
)
132+
for artifact in task_artifacts.values():
133+
if artifact.artifact_id not in previous_task_artifacts:
134+
if not task_artifact_change.added_artifacts:
135+
task_artifact_change.added_artifacts = []
136+
task_artifact_change.added_artifacts.append(
137+
vertex_task_converter.to_stored_artifact(artifact)
138+
)
139+
elif artifact != previous_task_artifacts[artifact.artifact_id]:
140+
if not task_artifact_change.updated_artifacts:
141+
task_artifact_change.updated_artifacts = []
142+
task_artifact_change.updated_artifacts.append(
143+
vertex_task_converter.to_stored_artifact(artifact)
144+
)
145+
if task_artifact_change != types.TaskArtifactChange():
146+
events.append(event)
147+
event_sequence_number += 1
148+
print('events= ', events)
149+
if not events:
150+
return
151+
self._client.agent_engines.a2a_tasks.events.append(
152+
name=self._agent_engine_resource_id + '/a2aTasks/' + task.id,
153+
task_events=events,
154+
)
155+
156+
async def _get_stored_task(self, task_id: str) -> types.A2aTask | None:
157+
try:
158+
a2a_task = self._client.agent_engines.a2a_tasks.get(
159+
name=self._agent_engine_resource_id + '/a2aTasks/' + task_id,
160+
)
161+
except genai_errors.APIError as e:
162+
if e.status == 'NOT_FOUND':
163+
logger.debug('Task %s not found in store.', task_id)
164+
return None
165+
raise
166+
return a2a_task
167+
168+
async def get(
169+
self, task_id: str, context: ServerCallContext | None = None
170+
) -> Task | None:
171+
"""Retrieves a task from the database by ID."""
172+
a2a_task = await self._get_stored_task(task_id)
173+
if a2a_task is None:
174+
return None
175+
return vertex_task_converter.to_sdk_task(a2a_task)
176+
177+
async def delete(
178+
self, task_id: str, context: ServerCallContext | None = None
179+
) -> None:
180+
"""Deletes a task from the database by ID."""
181+
raise NotImplementedError

0 commit comments

Comments
 (0)