Skip to content

Commit 4cd6708

Browse files
committed
make ServerCallCOntext mandatory
1 parent e77f8d9 commit 4cd6708

4 files changed

Lines changed: 80 additions & 61 deletions

File tree

src/a2a/contrib/tasks/vertex_task_store.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def __init__(
4444
self._client = client
4545
self._agent_engine_resource_id = agent_engine_resource_id
4646

47-
async def save(
48-
self, task: Task, context: ServerCallContext | None = None
49-
) -> None:
47+
async def save(self, task: Task, context: ServerCallContext) -> None:
5048
"""Saves or updates a task in the store."""
5149
compat_task = to_compat_task(task)
5250
previous_task = await self._get_stored_task(compat_task.id)
@@ -206,7 +204,7 @@ async def _get_stored_task(
206204
return a2a_task
207205

208206
async def get(
209-
self, task_id: str, context: ServerCallContext | None = None
207+
self, task_id: str, context: ServerCallContext
210208
) -> Task | None:
211209
"""Retrieves a task from the database by ID."""
212210
a2a_task = await self._get_stored_task(task_id)
@@ -217,13 +215,11 @@ async def get(
217215
async def list(
218216
self,
219217
params: ListTasksRequest,
220-
context: ServerCallContext | None = None,
218+
context: ServerCallContext,
221219
) -> ListTasksResponse:
222220
"""Retrieves a list of tasks from the store."""
223221
raise NotImplementedError
224222

225-
async def delete(
226-
self, task_id: str, context: ServerCallContext | None = None
227-
) -> None:
223+
async def delete(self, task_id: str, context: ServerCallContext) -> None:
228224
"""The backend doesn't support deleting tasks, so this is not implemented."""
229225
raise NotImplementedError

src/a2a/server/agent_execution/context.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def configuration(self) -> SendMessageConfiguration | None:
140140
return self._params.configuration if self._params else None
141141

142142
@property
143-
def call_context(self) -> ServerCallContext | None:
143+
def call_context(self) -> ServerCallContext:
144144
"""The server call context associated with this request."""
145145
return self._call_context
146146

@@ -157,22 +157,17 @@ def add_activated_extension(self, uri: str) -> None:
157157
This causes the extension to be indicated back to the client in the
158158
response.
159159
"""
160-
if self._call_context:
161-
self._call_context.activated_extensions.add(uri)
160+
self._call_context.activated_extensions.add(uri)
162161

163162
@property
164163
def tenant(self) -> str:
165164
"""The tenant associated with this request."""
166-
return self._call_context.tenant if self._call_context else ''
165+
return self._call_context.tenant
167166

168167
@property
169168
def requested_extensions(self) -> set[str]:
170169
"""Extensions that the client requested to activate."""
171-
return (
172-
self._call_context.requested_extensions
173-
if self._call_context
174-
else set()
175-
)
170+
return self._call_context.requested_extensions
176171

177172
def _check_or_generate_task_id(self) -> None:
178173
"""Ensures a task ID is present, generating one if necessary."""

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ async def _setup_message_execution(
290290
await self._push_config_store.set_info(
291291
task_id,
292292
params.configuration.task_push_notification_config,
293-
context or ServerCallContext(),
293+
context,
294294
)
295295

296296
queue = await self._queue_manager.create_or_tap(task_id)
@@ -504,7 +504,7 @@ async def on_create_task_push_notification_config(
504504
await self._push_config_store.set_info(
505505
task_id,
506506
params,
507-
context or ServerCallContext(),
507+
context,
508508
)
509509

510510
return params
@@ -529,10 +529,7 @@ async def on_get_task_push_notification_config(
529529
raise TaskNotFoundError
530530

531531
push_notification_configs: list[TaskPushNotificationConfig] = (
532-
await self._push_config_store.get_info(
533-
task_id, context or ServerCallContext()
534-
)
535-
or []
532+
await self._push_config_store.get_info(task_id, context) or []
536533
)
537534

538535
for config in push_notification_configs:
@@ -603,7 +600,7 @@ async def on_list_task_push_notification_configs(
603600
raise TaskNotFoundError
604601

605602
push_notification_config_list = await self._push_config_store.get_info(
606-
task_id, context or ServerCallContext()
603+
task_id, context
607604
)
608605

609606
return ListTaskPushNotificationConfigsResponse(
@@ -629,6 +626,4 @@ async def on_delete_task_push_notification_config(
629626
if not task:
630627
raise TaskNotFoundError
631628

632-
await self._push_config_store.delete_info(
633-
task_id, context or ServerCallContext(), config_id
634-
)
629+
await self._push_config_store.delete_info(task_id, context, config_id)

tests/contrib/tasks/test_vertex_task_store.py

Lines changed: 67 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def backend_type(request) -> str:
6262

6363

6464
from a2a.contrib.tasks.vertex_task_store import VertexTaskStore
65+
from a2a.server.context import ServerCallContext
6566
from a2a.types.a2a_pb2 import (
6667
Artifact,
6768
Part,
@@ -140,9 +141,11 @@ async def test_save_task(vertex_store: VertexTaskStore) -> None:
140141
task_to_save = Task()
141142
task_to_save.CopyFrom(MINIMAL_TASK_OBJ)
142143
task_to_save.id = 'save-test-task-2'
143-
await vertex_store.save(task_to_save)
144+
await vertex_store.save(task_to_save, ServerCallContext())
144145

145-
retrieved_task = await vertex_store.get(task_to_save.id)
146+
retrieved_task = await vertex_store.get(
147+
task_to_save.id, ServerCallContext()
148+
)
146149
assert retrieved_task is not None
147150
assert retrieved_task.id == task_to_save.id
148151

@@ -156,9 +159,11 @@ async def test_get_task(vertex_store: VertexTaskStore) -> None:
156159
task_to_save = Task()
157160
task_to_save.CopyFrom(MINIMAL_TASK_OBJ)
158161
task_to_save.id = task_id
159-
await vertex_store.save(task_to_save)
162+
await vertex_store.save(task_to_save, ServerCallContext())
160163

161-
retrieved_task = await vertex_store.get(task_to_save.id)
164+
retrieved_task = await vertex_store.get(
165+
task_to_save.id, ServerCallContext()
166+
)
162167
assert retrieved_task is not None
163168
assert retrieved_task.id == task_to_save.id
164169
assert retrieved_task.context_id == task_to_save.context_id
@@ -170,7 +175,9 @@ async def test_get_nonexistent_task(
170175
vertex_store: VertexTaskStore,
171176
) -> None:
172177
"""Test retrieving a nonexistent task."""
173-
retrieved_task = await vertex_store.get('nonexistent-task-id')
178+
retrieved_task = await vertex_store.get(
179+
'nonexistent-task-id', ServerCallContext()
180+
)
174181
assert retrieved_task is None
175182

176183

@@ -196,8 +203,8 @@ async def test_save_and_get_detailed_task(
196203
test_task.metadata['key1'] = 'value1'
197204
test_task.metadata['key2'] = 123
198205

199-
await vertex_store.save(test_task)
200-
retrieved_task = await vertex_store.get(test_task.id)
206+
await vertex_store.save(test_task, ServerCallContext())
207+
retrieved_task = await vertex_store.get(test_task.id, ServerCallContext())
201208

202209
assert retrieved_task is not None
203210
assert retrieved_task.id == test_task.id
@@ -221,9 +228,11 @@ async def test_update_task_status_and_metadata(
221228
artifacts=[],
222229
history=[],
223230
)
224-
await vertex_store.save(original_task)
231+
await vertex_store.save(original_task, ServerCallContext())
225232

226-
retrieved_before_update = await vertex_store.get(task_id)
233+
retrieved_before_update = await vertex_store.get(
234+
task_id, ServerCallContext()
235+
)
227236
assert retrieved_before_update is not None
228237
assert (
229238
retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED
@@ -236,9 +245,11 @@ async def test_update_task_status_and_metadata(
236245
updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z')
237246
updated_task.metadata.update({'update_key': 'update_value'})
238247

239-
await vertex_store.save(updated_task)
248+
await vertex_store.save(updated_task, ServerCallContext())
240249

241-
retrieved_after_update = await vertex_store.get(task_id)
250+
retrieved_after_update = await vertex_store.get(
251+
task_id, ServerCallContext()
252+
)
242253
assert retrieved_after_update is not None
243254
assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED
244255
assert retrieved_after_update.metadata == {'update_key': 'update_value'}
@@ -260,9 +271,11 @@ async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None:
260271
],
261272
history=[],
262273
)
263-
await vertex_store.save(original_task)
274+
await vertex_store.save(original_task, ServerCallContext())
264275

265-
retrieved_before_update = await vertex_store.get(task_id)
276+
retrieved_before_update = await vertex_store.get(
277+
task_id, ServerCallContext()
278+
)
266279
assert retrieved_before_update is not None
267280
assert (
268281
retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED
@@ -281,9 +294,11 @@ async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None:
281294
)
282295
)
283296

284-
await vertex_store.save(updated_task)
297+
await vertex_store.save(updated_task, ServerCallContext())
285298

286-
retrieved_after_update = await vertex_store.get(task_id)
299+
retrieved_after_update = await vertex_store.get(
300+
task_id, ServerCallContext()
301+
)
287302
assert retrieved_after_update is not None
288303
assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING
289304

@@ -321,9 +336,11 @@ async def test_update_task_update_artifact(
321336
],
322337
history=[],
323338
)
324-
await vertex_store.save(original_task)
339+
await vertex_store.save(original_task, ServerCallContext())
325340

326-
retrieved_before_update = await vertex_store.get(task_id)
341+
retrieved_before_update = await vertex_store.get(
342+
task_id, ServerCallContext()
343+
)
327344
assert retrieved_before_update is not None
328345
assert (
329346
retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED
@@ -337,9 +354,11 @@ async def test_update_task_update_artifact(
337354

338355
updated_task.artifacts[0].parts[0].text = 'ahoy'
339356

340-
await vertex_store.save(updated_task)
357+
await vertex_store.save(updated_task, ServerCallContext())
341358

342-
retrieved_after_update = await vertex_store.get(task_id)
359+
retrieved_after_update = await vertex_store.get(
360+
task_id, ServerCallContext()
361+
)
343362
assert retrieved_after_update is not None
344363
assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING
345364

@@ -377,9 +396,11 @@ async def test_update_task_delete_artifact(
377396
],
378397
history=[],
379398
)
380-
await vertex_store.save(original_task)
399+
await vertex_store.save(original_task, ServerCallContext())
381400

382-
retrieved_before_update = await vertex_store.get(task_id)
401+
retrieved_before_update = await vertex_store.get(
402+
task_id, ServerCallContext()
403+
)
383404
assert retrieved_before_update is not None
384405
assert (
385406
retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED
@@ -393,9 +414,11 @@ async def test_update_task_delete_artifact(
393414

394415
del updated_task.artifacts[1]
395416

396-
await vertex_store.save(updated_task)
417+
await vertex_store.save(updated_task, ServerCallContext())
397418

398-
retrieved_after_update = await vertex_store.get(task_id)
419+
retrieved_after_update = await vertex_store.get(
420+
task_id, ServerCallContext()
421+
)
399422
assert retrieved_after_update is not None
400423
assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING
401424

@@ -426,8 +449,10 @@ async def test_metadata_field_mapping(
426449
context_id='session-meta-1',
427450
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
428451
)
429-
await vertex_store.save(task_no_metadata)
430-
retrieved_no_metadata = await vertex_store.get('task-metadata-test-1')
452+
await vertex_store.save(task_no_metadata, ServerCallContext())
453+
retrieved_no_metadata = await vertex_store.get(
454+
'task-metadata-test-1', ServerCallContext()
455+
)
431456
assert retrieved_no_metadata is not None
432457
assert retrieved_no_metadata.metadata == {}
433458

@@ -439,8 +464,10 @@ async def test_metadata_field_mapping(
439464
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
440465
metadata=simple_metadata,
441466
)
442-
await vertex_store.save(task_simple_metadata)
443-
retrieved_simple = await vertex_store.get('task-metadata-test-2')
467+
await vertex_store.save(task_simple_metadata, ServerCallContext())
468+
retrieved_simple = await vertex_store.get(
469+
'task-metadata-test-2', ServerCallContext()
470+
)
444471
assert retrieved_simple is not None
445472
assert retrieved_simple.metadata == simple_metadata
446473

@@ -463,8 +490,10 @@ async def test_metadata_field_mapping(
463490
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
464491
metadata=complex_metadata,
465492
)
466-
await vertex_store.save(task_complex_metadata)
467-
retrieved_complex = await vertex_store.get('task-metadata-test-3')
493+
await vertex_store.save(task_complex_metadata, ServerCallContext())
494+
retrieved_complex = await vertex_store.get(
495+
'task-metadata-test-3', ServerCallContext()
496+
)
468497
assert retrieved_complex is not None
469498
assert retrieved_complex.metadata == complex_metadata
470499

@@ -474,16 +503,18 @@ async def test_metadata_field_mapping(
474503
context_id='session-meta-4',
475504
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
476505
)
477-
await vertex_store.save(task_update_metadata)
506+
await vertex_store.save(task_update_metadata, ServerCallContext())
478507

479508
# Update metadata
480509
task_update_metadata.metadata.Clear()
481510
task_update_metadata.metadata.update(
482511
{'updated': True, 'timestamp': '2024-01-01'}
483512
)
484-
await vertex_store.save(task_update_metadata)
513+
await vertex_store.save(task_update_metadata, ServerCallContext())
485514

486-
retrieved_updated = await vertex_store.get('task-metadata-test-4')
515+
retrieved_updated = await vertex_store.get(
516+
'task-metadata-test-4', ServerCallContext()
517+
)
487518
assert retrieved_updated is not None
488519
assert retrieved_updated.metadata == {
489520
'updated': True,
@@ -492,8 +523,10 @@ async def test_metadata_field_mapping(
492523

493524
# Test 5: Update metadata from dict to None
494525
task_update_metadata.metadata.Clear()
495-
await vertex_store.save(task_update_metadata)
526+
await vertex_store.save(task_update_metadata, ServerCallContext())
496527

497-
retrieved_none = await vertex_store.get('task-metadata-test-4')
528+
retrieved_none = await vertex_store.get(
529+
'task-metadata-test-4', ServerCallContext()
530+
)
498531
assert retrieved_none is not None
499532
assert retrieved_none.metadata == {}

0 commit comments

Comments
 (0)