Skip to content

Commit 3eb4bf6

Browse files
committed
Merge remote-tracking branch 'origin/1.0-dev' into ishymko/history-length
2 parents 6085fa0 + e71ac62 commit 3eb4bf6

3 files changed

Lines changed: 35 additions & 18 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,11 +552,15 @@ async def on_subscribe_to_task(
552552

553553
if task.status.state in TERMINAL_TASK_STATES:
554554
raise ServerError(
555-
error=InvalidParamsError(
555+
error=UnsupportedOperationError(
556556
message=f'Task {task.id} is in terminal state: {task.status.state}'
557557
)
558558
)
559559

560+
# The operation MUST return a Task object as the first event in the stream
561+
# https://a2a-protocol.org/latest/specification/#316-subscribe-to-task
562+
yield task
563+
560564
task_manager = TaskManager(
561565
task_id=task.id,
562566
context_id=task.context_id,

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,9 @@ async def streaming_coro():
703703
collected_events: list[Any] = []
704704
async for event in response:
705705
collected_events.append(event)
706-
assert len(collected_events) == len(events)
706+
assert (
707+
len(collected_events) == len(events) + 1
708+
) # First event is task itself
707709
assert mock_task.history is not None and len(mock_task.history) == 0
708710

709711
async def test_on_subscribe_no_existing_task_error(self) -> None:

tests/utils/test_task.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55

66
import pytest
77

8-
from a2a.types.a2a_pb2 import Artifact, Message, Part, Role, TaskState, GetTaskRequest, SendMessageConfiguration
8+
from a2a.types.a2a_pb2 import (
9+
Artifact,
10+
Message,
11+
Part,
12+
Role,
13+
TaskState,
14+
GetTaskRequest,
15+
SendMessageConfiguration,
16+
)
917
from a2a.utils.task import (
1018
apply_history_length,
1119
completed_task,
@@ -214,51 +222,54 @@ def test_decode_page_token_fails(self):
214222
)
215223

216224

217-
if __name__ == '__main__':
218-
unittest.main()
219-
220-
221-
222225
class TestApplyHistoryLength(unittest.TestCase):
223226
def setUp(self):
224-
# Create a task with some history
225227
self.history = [
226-
Message(message_id=str(i), role=Role.ROLE_USER, parts=[Part(text=f'msg {i}')])
228+
Message(
229+
message_id=str(i),
230+
role=Role.ROLE_USER,
231+
parts=[Part(text=f'msg {i}')],
232+
)
227233
for i in range(5)
228234
]
229235
artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])]
230236
self.task = completed_task(
231237
task_id='t1',
232238
context_id='c1',
233239
artifacts=artifacts,
234-
history=self.history
240+
history=self.history,
235241
)
236242

237243
def test_none_config_returns_full_history(self):
238-
# Test None (no limit) - config is None
239244
result = apply_history_length(self.task, None)
240245
self.assertEqual(len(result.history), 5)
241246
self.assertEqual(result.history, self.history)
242247

243248
def test_unset_history_length_returns_full_history(self):
244-
# Test unset (HasField returns False)
245-
# Using GetTaskRequest as it has history_length field
246249
result = apply_history_length(self.task, GetTaskRequest())
247250
self.assertEqual(len(result.history), 5)
248251
self.assertEqual(result.history, self.history)
249252

250253
def test_positive_history_length_truncates(self):
251-
# Test > 0 (partial)
252-
result = apply_history_length(self.task, GetTaskRequest(history_length=2))
254+
result = apply_history_length(
255+
self.task, GetTaskRequest(history_length=2)
256+
)
253257
self.assertEqual(len(result.history), 2)
254258
self.assertEqual(result.history, self.history[-2:])
255259

256260
def test_large_history_length_returns_full_history(self):
257-
result = apply_history_length(self.task, GetTaskRequest(history_length=10))
261+
result = apply_history_length(
262+
self.task, GetTaskRequest(history_length=10)
263+
)
258264
self.assertEqual(len(result.history), 5)
259265
self.assertEqual(result.history, self.history)
260266

261267
def test_zero_history_length_returns_empty_history(self):
262-
result = apply_history_length(self.task, SendMessageConfiguration(history_length=0))
268+
result = apply_history_length(
269+
self.task, SendMessageConfiguration(history_length=0)
270+
)
263271
self.assertEqual(len(result.history), 0)
264272

273+
274+
if __name__ == '__main__':
275+
unittest.main()

0 commit comments

Comments
 (0)