33
44import pytest
55
6+ from a2a .auth .user import User
7+ from a2a .server .context import ServerCallContext
68from a2a .server .tasks import TaskManager
7- from a2a .utils .errors import InvalidParamsError
89from a2a .types .a2a_pb2 import (
910 Artifact ,
1011 Message ,
1920from a2a .utils .errors import InvalidParamsError
2021
2122
23+ class SampleUser (User ):
24+ """A test implementation of the User interface."""
25+
26+ def __init__ (self , user_name : str ):
27+ self ._user_name = user_name
28+
29+ @property
30+ def is_authenticated (self ) -> bool :
31+ return True
32+
33+ @property
34+ def user_name (self ) -> str :
35+ return self ._user_name
36+
37+
38+ TEST_CONTEXT = ServerCallContext (user = SampleUser ('test_user' ))
39+
40+
2241# Create proto task instead of dict
2342def create_minimal_task (
2443 task_id : str = 'task-abc' ,
@@ -49,6 +68,7 @@ def task_manager(mock_task_store: AsyncMock) -> TaskManager:
4968 context_id = MINIMAL_CONTEXT_ID ,
5069 task_store = mock_task_store ,
5170 initial_message = None ,
71+ context = TEST_CONTEXT ,
5272 )
5373
5474
@@ -63,6 +83,7 @@ def test_task_manager_invalid_task_id(
6383 context_id = 'test_context' ,
6484 task_store = mock_task_store ,
6585 initial_message = None ,
86+ context = TEST_CONTEXT ,
6687 )
6788
6889
@@ -75,7 +96,7 @@ async def test_get_task_existing(
7596 mock_task_store .get .return_value = expected_task
7697 retrieved_task = await task_manager .get_task ()
7798 assert retrieved_task == expected_task
78- mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , None )
99+ mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , TEST_CONTEXT )
79100
80101
81102@pytest .mark .asyncio
@@ -86,7 +107,7 @@ async def test_get_task_nonexistent(
86107 mock_task_store .get .return_value = None
87108 retrieved_task = await task_manager .get_task ()
88109 assert retrieved_task is None
89- mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , None )
110+ mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , TEST_CONTEXT )
90111
91112
92113@pytest .mark .asyncio
@@ -96,7 +117,7 @@ async def test_save_task_event_new_task(
96117 """Test saving a new task."""
97118 task = create_minimal_task ()
98119 await task_manager .save_task_event (task )
99- mock_task_store .save .assert_called_once_with (task , None )
120+ mock_task_store .save .assert_called_once_with (task , TEST_CONTEXT )
100121
101122
102123@pytest .mark .asyncio
@@ -188,7 +209,7 @@ async def test_ensure_task_existing(
188209 )
189210 retrieved_task = await task_manager .ensure_task (event )
190211 assert retrieved_task == expected_task
191- mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , None )
212+ mock_task_store .get .assert_called_once_with (MINIMAL_TASK_ID , TEST_CONTEXT )
192213
193214
194215@pytest .mark .asyncio
@@ -202,6 +223,7 @@ async def test_ensure_task_nonexistent(
202223 context_id = None ,
203224 task_store = mock_task_store ,
204225 initial_message = None ,
226+ context = TEST_CONTEXT ,
205227 )
206228 event = TaskStatusUpdateEvent (
207229 task_id = 'new-task' ,
@@ -212,7 +234,7 @@ async def test_ensure_task_nonexistent(
212234 assert new_task .id == 'new-task'
213235 assert new_task .context_id == 'some-context'
214236 assert new_task .status .state == TaskState .TASK_STATE_SUBMITTED
215- mock_task_store .save .assert_called_once_with (new_task , None )
237+ mock_task_store .save .assert_called_once_with (new_task , TEST_CONTEXT )
216238 assert task_manager_without_id .task_id == 'new-task'
217239 assert task_manager_without_id .context_id == 'some-context'
218240
@@ -233,7 +255,7 @@ async def test_save_task(
233255 """Test saving a task."""
234256 task = create_minimal_task ()
235257 await task_manager ._save_task (task ) # type: ignore
236- mock_task_store .save .assert_called_once_with (task , None )
258+ mock_task_store .save .assert_called_once_with (task , TEST_CONTEXT )
237259
238260
239261@pytest .mark .asyncio
@@ -263,14 +285,15 @@ async def test_save_task_event_new_task_no_task_id(
263285 context_id = None ,
264286 task_store = mock_task_store ,
265287 initial_message = None ,
288+ context = TEST_CONTEXT ,
266289 )
267290 task = Task (
268291 id = 'new-task-id' ,
269292 context_id = 'some-context' ,
270293 status = TaskStatus (state = TaskState .TASK_STATE_WORKING ),
271294 )
272295 await task_manager_without_id .save_task_event (task )
273- mock_task_store .save .assert_called_once_with (task , None )
296+ mock_task_store .save .assert_called_once_with (task , TEST_CONTEXT )
274297 assert task_manager_without_id .task_id == 'new-task-id'
275298 assert task_manager_without_id .context_id == 'some-context'
276299 # initial submit should be updated to working
@@ -287,6 +310,7 @@ async def test_get_task_no_task_id(
287310 context_id = 'some-context' ,
288311 task_store = mock_task_store ,
289312 initial_message = None ,
313+ context = TEST_CONTEXT ,
290314 )
291315 retrieved_task = await task_manager_without_id .get_task ()
292316 assert retrieved_task is None
@@ -303,6 +327,7 @@ async def test_save_task_event_no_task_existing(
303327 context_id = None ,
304328 task_store = mock_task_store ,
305329 initial_message = None ,
330+ context = TEST_CONTEXT ,
306331 )
307332 mock_task_store .get .return_value = None
308333 event = TaskStatusUpdateEvent (
0 commit comments