Skip to content

Commit 93e3e20

Browse files
committed
make ServerCallCOntext mandatory in task manager
1 parent 72a8a7f commit 93e3e20

4 files changed

Lines changed: 50 additions & 16 deletions

File tree

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ async def build(
6868
):
6969
tasks = await asyncio.gather(
7070
*[
71-
self._task_store.get(task_id)
71+
self._task_store.get(
72+
task_id, context or ServerCallContext()
73+
)
7274
for task_id in params.message.reference_task_ids
7375
]
7476
)

src/a2a/server/tasks/task_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
context_id: str | None,
3232
task_store: TaskStore,
3333
initial_message: Message | None,
34-
context: ServerCallContext | None = None,
34+
context: ServerCallContext,
3535
):
3636
"""Initializes the TaskManager.
3737
@@ -51,7 +51,7 @@ def __init__(
5151
self.task_store = task_store
5252
self._initial_message = initial_message
5353
self._current_task: Task | None = None
54-
self._call_context: ServerCallContext | None = context
54+
self._call_context: ServerCallContext = context
5555
logger.debug(
5656
'TaskManager initialized with task_id: %s, context_id: %s',
5757
task_id,

tests/server/agent_execution/test_simple_request_context_builder.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,12 @@ async def test_build_populate_true_with_reference_task_ids(self) -> None:
127127
mock_ref_task1 = create_sample_task(task_id=ref_task_id1)
128128
mock_ref_task3 = create_sample_task(task_id=ref_task_id3)
129129

130+
server_call_context = ServerCallContext(user=UnauthenticatedUser())
131+
130132
# Configure task_store.get mock
131133
# Note: AsyncMock side_effect needs to handle multiple calls if they have different args.
132134
# A simple way is a list of return values, or a function.
133-
async def get_side_effect(task_id):
135+
async def get_side_effect(task_id, server_call_context):
134136
if task_id == ref_task_id1:
135137
return mock_ref_task1
136138
if task_id == ref_task_id3:
@@ -144,7 +146,6 @@ async def get_side_effect(task_id):
144146
reference_task_ids=[ref_task_id1, ref_task_id2, ref_task_id3]
145147
)
146148
)
147-
server_call_context = ServerCallContext(user=UnauthenticatedUser())
148149

149150
request_context = await builder.build(
150151
params=params,
@@ -155,9 +156,15 @@ async def get_side_effect(task_id):
155156
)
156157

157158
self.assertEqual(self.mock_task_store.get.call_count, 3)
158-
self.mock_task_store.get.assert_any_call(ref_task_id1)
159-
self.mock_task_store.get.assert_any_call(ref_task_id2)
160-
self.mock_task_store.get.assert_any_call(ref_task_id3)
159+
self.mock_task_store.get.assert_any_call(
160+
ref_task_id1, server_call_context
161+
)
162+
self.mock_task_store.get.assert_any_call(
163+
ref_task_id2, server_call_context
164+
)
165+
self.mock_task_store.get.assert_any_call(
166+
ref_task_id3, server_call_context
167+
)
161168

162169
self.assertIsNotNone(request_context.related_tasks)
163170
self.assertEqual(

tests/server/tasks/test_task_manager.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import pytest
55

6+
from a2a.auth.user import User
7+
from a2a.server.context import ServerCallContext
68
from a2a.server.tasks import TaskManager
7-
from a2a.utils.errors import InvalidParamsError
89
from a2a.types.a2a_pb2 import (
910
Artifact,
1011
Message,
@@ -19,6 +20,24 @@
1920
from a2a.utils.errors import InvalidParamsError
2021

2122

23+
class SampleUser(User):
24+
"""A test implementation of the User interface."""
25+
26+
def __init__(self, user_name: str):
27+
self._user_name = user_name
28+
29+
@property
30+
def is_authenticated(self) -> bool:
31+
return True
32+
33+
@property
34+
def user_name(self) -> str:
35+
return self._user_name
36+
37+
38+
TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user'))
39+
40+
2241
# Create proto task instead of dict
2342
def create_minimal_task(
2443
task_id: str = 'task-abc',
@@ -49,6 +68,7 @@ def task_manager(mock_task_store: AsyncMock) -> TaskManager:
4968
context_id=MINIMAL_CONTEXT_ID,
5069
task_store=mock_task_store,
5170
initial_message=None,
71+
context=TEST_CONTEXT,
5272
)
5373

5474

@@ -63,6 +83,7 @@ def test_task_manager_invalid_task_id(
6383
context_id='test_context',
6484
task_store=mock_task_store,
6585
initial_message=None,
86+
context=TEST_CONTEXT,
6687
)
6788

6889

@@ -75,7 +96,7 @@ async def test_get_task_existing(
7596
mock_task_store.get.return_value = expected_task
7697
retrieved_task = await task_manager.get_task()
7798
assert retrieved_task == expected_task
78-
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None)
99+
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT)
79100

80101

81102
@pytest.mark.asyncio
@@ -86,7 +107,7 @@ async def test_get_task_nonexistent(
86107
mock_task_store.get.return_value = None
87108
retrieved_task = await task_manager.get_task()
88109
assert retrieved_task is None
89-
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None)
110+
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT)
90111

91112

92113
@pytest.mark.asyncio
@@ -96,7 +117,7 @@ async def test_save_task_event_new_task(
96117
"""Test saving a new task."""
97118
task = create_minimal_task()
98119
await task_manager.save_task_event(task)
99-
mock_task_store.save.assert_called_once_with(task, None)
120+
mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT)
100121

101122

102123
@pytest.mark.asyncio
@@ -188,7 +209,7 @@ async def test_ensure_task_existing(
188209
)
189210
retrieved_task = await task_manager.ensure_task(event)
190211
assert retrieved_task == expected_task
191-
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None)
212+
mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT)
192213

193214

194215
@pytest.mark.asyncio
@@ -202,6 +223,7 @@ async def test_ensure_task_nonexistent(
202223
context_id=None,
203224
task_store=mock_task_store,
204225
initial_message=None,
226+
context=TEST_CONTEXT,
205227
)
206228
event = TaskStatusUpdateEvent(
207229
task_id='new-task',
@@ -212,7 +234,7 @@ async def test_ensure_task_nonexistent(
212234
assert new_task.id == 'new-task'
213235
assert new_task.context_id == 'some-context'
214236
assert new_task.status.state == TaskState.TASK_STATE_SUBMITTED
215-
mock_task_store.save.assert_called_once_with(new_task, None)
237+
mock_task_store.save.assert_called_once_with(new_task, TEST_CONTEXT)
216238
assert task_manager_without_id.task_id == 'new-task'
217239
assert task_manager_without_id.context_id == 'some-context'
218240

@@ -233,7 +255,7 @@ async def test_save_task(
233255
"""Test saving a task."""
234256
task = create_minimal_task()
235257
await task_manager._save_task(task) # type: ignore
236-
mock_task_store.save.assert_called_once_with(task, None)
258+
mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT)
237259

238260

239261
@pytest.mark.asyncio
@@ -263,14 +285,15 @@ async def test_save_task_event_new_task_no_task_id(
263285
context_id=None,
264286
task_store=mock_task_store,
265287
initial_message=None,
288+
context=TEST_CONTEXT,
266289
)
267290
task = Task(
268291
id='new-task-id',
269292
context_id='some-context',
270293
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
271294
)
272295
await task_manager_without_id.save_task_event(task)
273-
mock_task_store.save.assert_called_once_with(task, None)
296+
mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT)
274297
assert task_manager_without_id.task_id == 'new-task-id'
275298
assert task_manager_without_id.context_id == 'some-context'
276299
# initial submit should be updated to working
@@ -287,6 +310,7 @@ async def test_get_task_no_task_id(
287310
context_id='some-context',
288311
task_store=mock_task_store,
289312
initial_message=None,
313+
context=TEST_CONTEXT,
290314
)
291315
retrieved_task = await task_manager_without_id.get_task()
292316
assert retrieved_task is None
@@ -303,6 +327,7 @@ async def test_save_task_event_no_task_existing(
303327
context_id=None,
304328
task_store=mock_task_store,
305329
initial_message=None,
330+
context=TEST_CONTEXT,
306331
)
307332
mock_task_store.get.return_value = None
308333
event = TaskStatusUpdateEvent(

0 commit comments

Comments
 (0)