Skip to content

Commit c665e3c

Browse files
committed
Create a fake vertex backend for testing
1 parent 0490e99 commit c665e3c

2 files changed

Lines changed: 178 additions & 19 deletions

File tree

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Fake Vertex AI Client implementations for testing."""
2+
3+
import copy
4+
5+
from google.genai import errors as genai_errors
6+
from vertexai import types as vertexai_types
7+
8+
9+
class FakeAgentEnginesA2aTasksEventsClient:
10+
def __init__(self, parent_client):
11+
self.parent_client = parent_client
12+
13+
async def append(
14+
self, name: str, task_events: list[vertexai_types.TaskEvent]
15+
) -> None:
16+
task = self.parent_client.tasks.get(name)
17+
if not task:
18+
raise genai_errors.APIError(
19+
code=404,
20+
response_json={
21+
'error': {
22+
'status': 'NOT_FOUND',
23+
'message': 'Task not found',
24+
}
25+
},
26+
)
27+
28+
task = copy.deepcopy(task)
29+
if (
30+
not hasattr(task, 'next_event_sequence_number')
31+
or not task.next_event_sequence_number
32+
):
33+
task.next_event_sequence_number = 0
34+
35+
for event in task_events:
36+
data = event.event_data
37+
if getattr(data, 'state_change', None):
38+
task.state = getattr(data.state_change, 'new_state', task.state)
39+
if getattr(data, 'metadata_change', None):
40+
task.metadata = getattr(
41+
data.metadata_change, 'new_metadata', task.metadata
42+
)
43+
if getattr(data, 'output_change', None):
44+
change = getattr(
45+
data.output_change, 'task_artifact_change', None
46+
)
47+
if not change:
48+
continue
49+
if not getattr(task, 'output', None):
50+
task.output = vertexai_types.TaskOutput()
51+
52+
current_artifacts = (
53+
list(task.output.artifacts)
54+
if getattr(task.output, 'artifacts', None)
55+
else []
56+
)
57+
58+
deleted_ids = getattr(change, 'deleted_artifact_ids', []) or []
59+
if deleted_ids:
60+
current_artifacts = [
61+
a
62+
for a in current_artifacts
63+
if a.artifact_id not in deleted_ids
64+
]
65+
66+
added = getattr(change, 'added_artifacts', []) or []
67+
if added:
68+
current_artifacts.extend(added)
69+
70+
updated = getattr(change, 'updated_artifacts', []) or []
71+
if updated:
72+
updated_map = {a.artifact_id: a for a in updated}
73+
current_artifacts = [
74+
updated_map.get(a.artifact_id, a)
75+
for a in current_artifacts
76+
]
77+
78+
try:
79+
del task.output.artifacts[:]
80+
task.output.artifacts.extend(current_artifacts)
81+
except Exception:
82+
task.output.artifacts = current_artifacts
83+
task.next_event_sequence_number += 1
84+
85+
self.parent_client.tasks[name] = task
86+
87+
88+
class FakeAgentEnginesA2aTasksClient:
89+
def __init__(self):
90+
self.tasks: dict[str, vertexai_types.A2aTask] = {}
91+
self.events = FakeAgentEnginesA2aTasksEventsClient(self)
92+
93+
async def create(
94+
self,
95+
name: str,
96+
a2a_task_id: str,
97+
config: vertexai_types.CreateAgentEngineTaskConfig,
98+
) -> vertexai_types.A2aTask:
99+
full_name = f'{name}/a2aTasks/{a2a_task_id}'
100+
task = vertexai_types.A2aTask(
101+
name=full_name,
102+
context_id=config.context_id,
103+
metadata=config.metadata,
104+
output=config.output,
105+
state=vertexai_types.State.SUBMITTED,
106+
)
107+
task.next_event_sequence_number = 1
108+
self.tasks[full_name] = task
109+
return task
110+
111+
async def get(self, name: str) -> vertexai_types.A2aTask:
112+
if name not in self.tasks:
113+
raise genai_errors.APIError(
114+
code=404,
115+
response_json={
116+
'error': {
117+
'status': 'NOT_FOUND',
118+
'message': 'Task not found',
119+
}
120+
},
121+
)
122+
return copy.deepcopy(self.tasks[name])
123+
124+
125+
class FakeAgentEnginesClient:
126+
def __init__(self):
127+
self.a2a_tasks = FakeAgentEnginesA2aTasksClient()
128+
129+
130+
class FakeAioClient:
131+
def __init__(self):
132+
self.agent_engines = FakeAgentEnginesClient()
133+
134+
135+
class FakeVertexClient:
136+
def __init__(self):
137+
self.aio = FakeAioClient()

tests/contrib/tasks/test_vertex_task_store.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import vertexai
3030

3131

32-
# Skip the entire test module if required environment variables are not set
32+
# Skip the real backend tests if required environment variables are not set
3333
missing_env_vars = not all(
3434
os.environ.get(var)
3535
for var in [
@@ -39,13 +39,25 @@
3939
'VERTEX_API_VERSION',
4040
]
4141
)
42-
pytestmark = pytest.mark.skipif(
43-
missing_env_vars,
44-
reason=(
45-
'Vertex Task Store tests require VERTEX_PROJECT, VERTEX_LOCATION, '
46-
'VERTEX_BASE_URL, and VERTEX_API_VERSION environment variables. '
47-
),
42+
import sys
43+
44+
45+
@pytest.fixture(
46+
scope='module',
47+
params=[
48+
'fake',
49+
pytest.param(
50+
'real',
51+
marks=pytest.mark.skipif(
52+
missing_env_vars,
53+
reason='Missing required environment variables for real Vertex Task Store.',
54+
),
55+
),
56+
],
4857
)
58+
def backend_type(request) -> str:
59+
return request.param
60+
4961

5062
from a2a.contrib.tasks.vertex_task_store import VertexTaskStore
5163
from a2a.types import (
@@ -75,12 +87,15 @@
7587

7688

7789
@pytest.fixture(scope='module')
78-
def agent_engine_resource_id() -> Generator[str, None, None]:
90+
def agent_engine_resource_id(backend_type: str) -> Generator[str, None, None]:
7991
"""
8092
Module-scoped fixture that creates and deletes a single Agent Engine
81-
for all the tests. This is on the module scope to speed up the testing
82-
process.
93+
for all the tests. For fake backend, it yields a mock resource.
8394
"""
95+
if backend_type == 'fake':
96+
yield 'projects/mock-project/locations/mock-location/agentEngines/mock-engine'
97+
return
98+
8499
project = os.environ.get('VERTEX_PROJECT')
85100
location = os.environ.get('VERTEX_LOCATION')
86101
base_url = os.environ.get('VERTEX_BASE_URL')
@@ -95,23 +110,30 @@ def agent_engine_resource_id() -> Generator[str, None, None]:
95110

96111
@pytest_asyncio.fixture
97112
async def vertex_store(
113+
backend_type: str,
98114
agent_engine_resource_id: str,
99115
) -> AsyncGenerator[VertexTaskStore, None]:
100116
"""
101117
Function-scoped fixture providing a fresh VertexTaskStore per test,
102-
reusing the module-scoped engine.
118+
reusing the module-scoped engine. Uses fake client for 'fake' backend.
103119
"""
104-
project = os.environ.get('VERTEX_PROJECT')
105-
location = os.environ.get('VERTEX_LOCATION')
106-
base_url = os.environ.get('VERTEX_BASE_URL')
107-
api_version = os.environ.get('VERTEX_API_VERSION')
120+
if backend_type == 'fake':
121+
sys.path.append(os.path.dirname(__file__))
122+
from fake_vertex_client import FakeVertexClient
108123

109-
client = vertexai.Client(project=project, location=location)
110-
client._api_client._http_options.base_url = base_url
111-
client._api_client._http_options.api_version = api_version
124+
client = FakeVertexClient()
125+
else:
126+
project = os.environ.get('VERTEX_PROJECT')
127+
location = os.environ.get('VERTEX_LOCATION')
128+
base_url = os.environ.get('VERTEX_BASE_URL')
129+
api_version = os.environ.get('VERTEX_API_VERSION')
130+
131+
client = vertexai.Client(project=project, location=location)
132+
client._api_client._http_options.base_url = base_url
133+
client._api_client._http_options.api_version = api_version
112134

113135
store = VertexTaskStore(
114-
client=client,
136+
client=client, # type: ignore
115137
agent_engine_resource_id=agent_engine_resource_id,
116138
)
117139
yield store

0 commit comments

Comments
 (0)