Skip to content

Commit 50cf655

Browse files
committed
Make ServerCallContext mandatory in RequestContext
1 parent a21a16a commit 50cf655

5 files changed

Lines changed: 55 additions & 31 deletions

File tree

src/a2a/server/agent_execution/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,35 @@ class RequestContext:
2626

2727
def __init__( # noqa: PLR0913
2828
self,
29+
call_context: ServerCallContext,
2930
request: SendMessageRequest | None = None,
3031
task_id: str | None = None,
3132
context_id: str | None = None,
3233
task: Task | None = None,
3334
related_tasks: list[Task] | None = None,
34-
call_context: ServerCallContext | None = None,
3535
task_id_generator: IDGenerator | None = None,
3636
context_id_generator: IDGenerator | None = None,
3737
):
3838
"""Initializes the RequestContext.
3939
4040
Args:
41+
call_context: The server call context associated with this request.
4142
request: The incoming `SendMessageRequest` request payload.
4243
task_id: The ID of the task explicitly provided in the request or path.
4344
context_id: The ID of the context explicitly provided in the request or path.
4445
task: The existing `Task` object retrieved from the store, if any.
4546
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
46-
call_context: The server call context associated with this request.
4747
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
4848
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
4949
"""
5050
if related_tasks is None:
5151
related_tasks = []
52+
self._call_context = call_context
5253
self._params = request
5354
self._task_id = task_id
5455
self._context_id = context_id
5556
self._current_task = task
5657
self._related_tasks = related_tasks
57-
self._call_context = call_context
5858
self._task_id_generator = (
5959
task_id_generator if task_id_generator else UUIDGenerator()
6060
)

src/a2a/server/agent_execution/request_context_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ class RequestContextBuilder(ABC):
1111
@abstractmethod
1212
async def build(
1313
self,
14+
context: ServerCallContext,
1415
params: SendMessageRequest | None = None,
1516
task_id: str | None = None,
1617
context_id: str | None = None,
1718
task: Task | None = None,
18-
context: ServerCallContext | None = None,
1919
) -> RequestContext:
2020
pass

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ def __init__(
3535

3636
async def build(
3737
self,
38+
context: ServerCallContext,
3839
params: SendMessageRequest | None = None,
3940
task_id: str | None = None,
4041
context_id: str | None = None,
4142
task: Task | None = None,
42-
context: ServerCallContext | None = None,
4343
) -> RequestContext:
4444
"""Builds the request context for an agent execution.
4545
@@ -48,11 +48,11 @@ async def build(
4848
referenced in `params.message.reference_task_ids` from the `task_store`.
4949
5050
Args:
51+
context: The server call context, containing metadata about the call.
5152
params: The parameters of the incoming message send request.
5253
task_id: The ID of the task being executed.
5354
context_id: The ID of the current execution context.
5455
task: The primary task object associated with the request.
55-
context: The server call context, containing metadata about the call.
5656
5757
Returns:
5858
An instance of RequestContext populated with the provided information
@@ -77,12 +77,12 @@ async def build(
7777
related_tasks = [x for x in tasks if x is not None]
7878

7979
return RequestContext(
80+
call_context=context,
8081
request=params,
8182
task_id=task_id,
8283
context_id=context_id,
8384
task=task,
8485
related_tasks=related_tasks,
85-
call_context=context,
8686
task_id_generator=self._task_id_generator,
8787
context_id_generator=self._context_id_generator,
8888
)

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ async def on_cancel_task(
195195

196196
await self.agent_executor.cancel(
197197
RequestContext(
198-
None,
198+
call_context=context,
199+
request=None,
199200
task_id=task.id,
200201
context_id=task.context_id,
201202
task=task,

tests/server/agent_execution/test_context.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def mock_task(self) -> Mock:
3535

3636
def test_init_without_params(self) -> None:
3737
"""Test initialization without parameters."""
38-
context = RequestContext()
38+
context = RequestContext(ServerCallContext())
3939
assert context.message is None
4040
assert context.task_id is None
4141
assert context.context_id is None
@@ -51,7 +51,7 @@ def test_init_with_params_no_ids(self, mock_params: Mock) -> None:
5151
uuid.UUID('00000000-0000-0000-0000-000000000002'),
5252
],
5353
):
54-
context = RequestContext(request=mock_params)
54+
context = RequestContext(ServerCallContext(), request=mock_params)
5555

5656
assert context.message == mock_params.message
5757
assert context.task_id == '00000000-0000-0000-0000-000000000001'
@@ -68,15 +68,19 @@ def test_init_with_params_no_ids(self, mock_params: Mock) -> None:
6868
def test_init_with_task_id(self, mock_params: Mock) -> None:
6969
"""Test initialization with task ID provided."""
7070
task_id = 'task-123'
71-
context = RequestContext(request=mock_params, task_id=task_id)
71+
context = RequestContext(
72+
ServerCallContext(), request=mock_params, task_id=task_id
73+
)
7274

7375
assert context.task_id == task_id
7476
assert mock_params.message.task_id == task_id
7577

7678
def test_init_with_context_id(self, mock_params: Mock) -> None:
7779
"""Test initialization with context ID provided."""
7880
context_id = 'context-456'
79-
context = RequestContext(request=mock_params, context_id=context_id)
81+
context = RequestContext(
82+
ServerCallContext(), request=mock_params, context_id=context_id
83+
)
8084

8185
assert context.context_id == context_id
8286
assert mock_params.message.context_id == context_id
@@ -86,7 +90,10 @@ def test_init_with_both_ids(self, mock_params: Mock) -> None:
8690
task_id = 'task-123'
8791
context_id = 'context-456'
8892
context = RequestContext(
89-
request=mock_params, task_id=task_id, context_id=context_id
93+
ServerCallContext(),
94+
request=mock_params,
95+
task_id=task_id,
96+
context_id=context_id,
9097
)
9198

9299
assert context.task_id == task_id
@@ -96,18 +103,20 @@ def test_init_with_both_ids(self, mock_params: Mock) -> None:
96103

97104
def test_init_with_task(self, mock_params: Mock, mock_task: Mock) -> None:
98105
"""Test initialization with a task object."""
99-
context = RequestContext(request=mock_params, task=mock_task)
106+
context = RequestContext(
107+
ServerCallContext(), request=mock_params, task=mock_task
108+
)
100109

101110
assert context.current_task == mock_task
102111

103112
def test_get_user_input_no_params(self) -> None:
104113
"""Test get_user_input with no params returns empty string."""
105-
context = RequestContext()
114+
context = RequestContext(ServerCallContext())
106115
assert context.get_user_input() == ''
107116

108117
def test_attach_related_task(self, mock_task: Mock) -> None:
109118
"""Test attach_related_task adds a task to related_tasks."""
110-
context = RequestContext()
119+
context = RequestContext(ServerCallContext())
111120
assert len(context.related_tasks) == 0
112121

113122
context.attach_related_task(mock_task)
@@ -122,7 +131,7 @@ def test_attach_related_task(self, mock_task: Mock) -> None:
122131

123132
def test_current_task_property(self, mock_task: Mock) -> None:
124133
"""Test current_task getter and setter."""
125-
context = RequestContext()
134+
context = RequestContext(ServerCallContext())
126135
assert context.current_task is None
127136

128137
context.current_task = mock_task
@@ -135,7 +144,7 @@ def test_current_task_property(self, mock_task: Mock) -> None:
135144

136145
def test_check_or_generate_task_id_no_params(self) -> None:
137146
"""Test _check_or_generate_task_id with no params does nothing."""
138-
context = RequestContext()
147+
context = RequestContext(ServerCallContext())
139148
context._check_or_generate_task_id()
140149
assert context.task_id is None
141150

@@ -146,7 +155,7 @@ def test_check_or_generate_task_id_with_existing_task_id(
146155
existing_id = 'existing-task-id'
147156
mock_params.message.task_id = existing_id
148157

149-
context = RequestContext(request=mock_params)
158+
context = RequestContext(ServerCallContext(), request=mock_params)
150159
# The method is called during initialization
151160

152161
assert context.task_id == existing_id
@@ -160,15 +169,17 @@ def test_check_or_generate_task_id_with_custom_id_generator(
160169
id_generator.generate.return_value = 'custom-task-id'
161170

162171
context = RequestContext(
163-
request=mock_params, task_id_generator=id_generator
172+
ServerCallContext(),
173+
request=mock_params,
174+
task_id_generator=id_generator,
164175
)
165176
# The method is called during initialization
166177

167178
assert context.task_id == 'custom-task-id'
168179

169180
def test_check_or_generate_context_id_no_params(self) -> None:
170181
"""Test _check_or_generate_context_id with no params does nothing."""
171-
context = RequestContext()
182+
context = RequestContext(ServerCallContext())
172183
context._check_or_generate_context_id()
173184
assert context.context_id is None
174185

@@ -179,7 +190,7 @@ def test_check_or_generate_context_id_with_existing_context_id(
179190
existing_id = 'existing-context-id'
180191
mock_params.message.context_id = existing_id
181192

182-
context = RequestContext(request=mock_params)
193+
context = RequestContext(ServerCallContext(), request=mock_params)
183194
# The method is called during initialization
184195

185196
assert context.context_id == existing_id
@@ -193,7 +204,9 @@ def test_check_or_generate_context_id_with_custom_id_generator(
193204
id_generator.generate.return_value = 'custom-context-id'
194205

195206
context = RequestContext(
196-
request=mock_params, context_id_generator=id_generator
207+
ServerCallContext(),
208+
request=mock_params,
209+
context_id_generator=id_generator,
197210
)
198211
# The method is called during initialization
199212

@@ -205,7 +218,10 @@ def test_init_raises_error_on_task_id_mismatch(
205218
"""Test that an error is raised if provided task_id mismatches task.id."""
206219
with pytest.raises(InvalidParamsError) as exc_info:
207220
RequestContext(
208-
request=mock_params, task_id='wrong-task-id', task=mock_task
221+
ServerCallContext(),
222+
request=mock_params,
223+
task_id='wrong-task-id',
224+
task=mock_task,
209225
)
210226
assert 'bad task id' in exc_info.value.message
211227

@@ -218,6 +234,7 @@ def test_init_raises_error_on_context_id_mismatch(
218234

219235
with pytest.raises(InvalidParamsError) as exc_info:
220236
RequestContext(
237+
ServerCallContext(),
221238
request=mock_params,
222239
task_id=mock_task.id,
223240
context_id='wrong-context-id',
@@ -229,30 +246,32 @@ def test_init_raises_error_on_context_id_mismatch(
229246
def test_with_related_tasks_provided(self, mock_task: Mock) -> None:
230247
"""Test initialization with related tasks provided."""
231248
related_tasks = [mock_task, Mock(spec=Task)]
232-
context = RequestContext(related_tasks=related_tasks) # type: ignore[arg-type]
249+
context = RequestContext(
250+
ServerCallContext(), related_tasks=related_tasks
251+
) # type: ignore[arg-type]
233252

234253
assert context.related_tasks == related_tasks
235254
assert len(context.related_tasks) == 2
236255

237256
def test_message_property_without_params(self) -> None:
238257
"""Test message property returns None when no params are provided."""
239-
context = RequestContext()
258+
context = RequestContext(ServerCallContext())
240259
assert context.message is None
241260

242261
def test_message_property_with_params(self, mock_params: Mock) -> None:
243262
"""Test message property returns the message from params."""
244-
context = RequestContext(request=mock_params)
263+
context = RequestContext(ServerCallContext(), request=mock_params)
245264
assert context.message == mock_params.message
246265

247266
def test_metadata_property_without_content(self) -> None:
248267
"""Test metadata property returns empty dict when no content are provided."""
249-
context = RequestContext()
268+
context = RequestContext(ServerCallContext())
250269
assert context.metadata == {}
251270

252271
def test_metadata_property_with_content(self, mock_params: Mock) -> None:
253272
"""Test metadata property returns the metadata from params."""
254273
mock_params.metadata = {'key': 'value'}
255-
context = RequestContext(request=mock_params)
274+
context = RequestContext(ServerCallContext(), request=mock_params)
256275
assert context.metadata == {'key': 'value'}
257276

258277
def test_init_with_existing_ids_in_message(
@@ -262,7 +281,7 @@ def test_init_with_existing_ids_in_message(
262281
mock_message.task_id = 'existing-task-id'
263282
mock_message.context_id = 'existing-context-id'
264283

265-
context = RequestContext(request=mock_params)
284+
context = RequestContext(ServerCallContext(), request=mock_params)
266285

267286
assert context.task_id == 'existing-task-id'
268287
assert context.context_id == 'existing-context-id'
@@ -275,7 +294,10 @@ def test_init_with_task_id_and_existing_task_id_match(
275294
mock_params.message.task_id = mock_task.id
276295

277296
context = RequestContext(
278-
request=mock_params, task_id=mock_task.id, task=mock_task
297+
ServerCallContext(),
298+
request=mock_params,
299+
task_id=mock_task.id,
300+
task=mock_task,
279301
)
280302

281303
assert context.task_id == mock_task.id
@@ -289,6 +311,7 @@ def test_init_with_context_id_and_existing_context_id_match(
289311
mock_params.message.context_id = mock_task.context_id
290312

291313
context = RequestContext(
314+
ServerCallContext(),
292315
request=mock_params,
293316
task_id=mock_task.id,
294317
context_id=mock_task.context_id,

0 commit comments

Comments
 (0)