Skip to content

Commit 4200134

Browse files
authored
Merge branch '1.0-dev' into ishymko/678-push-notifications
2 parents 4e02151 + 72a1007 commit 4200134

5 files changed

Lines changed: 116 additions & 21 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ async def on_get_task(
128128
if not task:
129129
raise ServerError(error=TaskNotFoundError())
130130

131-
# Apply historyLength parameter if specified
132-
return apply_history_length(task, params.history_length)
131+
return apply_history_length(task, params)
133132

134133
async def on_list_tasks(
135134
self,
@@ -142,7 +141,7 @@ async def on_list_tasks(
142141
if not params.include_artifacts:
143142
task.ClearField('artifacts')
144143

145-
updated_task = apply_history_length(task, params.history_length)
144+
updated_task = apply_history_length(task, params)
146145
if updated_task is not task:
147146
task.CopyFrom(updated_task)
148147

@@ -381,9 +380,7 @@ async def push_notification_callback(event: Event) -> None:
381380
if isinstance(result, Task):
382381
self._validate_task_id_match(task_id, result.id)
383382
if params.configuration:
384-
result = apply_history_length(
385-
result, params.configuration.history_length
386-
)
383+
result = apply_history_length(result, params.configuration)
387384

388385
return result
389386

@@ -552,11 +549,15 @@ async def on_subscribe_to_task(
552549

553550
if task.status.state in TERMINAL_TASK_STATES:
554551
raise ServerError(
555-
error=InvalidParamsError(
552+
error=UnsupportedOperationError(
556553
message=f'Task {task.id} is in terminal state: {task.status.state}'
557554
)
558555
)
559556

557+
# The operation MUST return a Task object as the first event in the stream
558+
# https://a2a-protocol.org/latest/specification/#316-subscribe-to-task
559+
yield task
560+
560561
task_manager = TaskManager(
561562
task_id=task.id,
562563
context_id=task.context_id,

src/a2a/utils/task.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import uuid
55

66
from base64 import b64decode, b64encode
7+
from typing import Literal, Protocol, runtime_checkable
78

89
from a2a.types.a2a_pb2 import (
910
Artifact,
@@ -81,27 +82,57 @@ def completed_task(
8182
)
8283

8384

84-
def apply_history_length(task: Task, history_length: int | None) -> Task:
85+
@runtime_checkable
86+
class HistoryLengthConfig(Protocol):
87+
"""Protocol for configuration arguments containing history_length field."""
88+
89+
history_length: int
90+
91+
def HasField(self, field_name: Literal['history_length']) -> bool: # noqa: N802 -- Protobuf generated code
92+
"""Checks if a field is set.
93+
94+
This method name matches the generated Protobuf code.
95+
"""
96+
...
97+
98+
99+
def apply_history_length(
100+
task: Task, config: HistoryLengthConfig | None
101+
) -> Task:
85102
"""Applies history_length parameter on task and returns a new task object.
86103
87104
Args:
88105
task: The original task object with complete history
89-
history_length: History length configuration value
106+
config: Configuration object containing 'history_length' field and HasField method.
90107
91108
Returns:
92109
A new task object with limited history
110+
111+
See Also:
112+
https://a2a-protocol.org/latest/specification/#324-history-length-semantics
93113
"""
94-
# Apply historyLength parameter if specified
95-
if history_length is not None and history_length > 0 and task.history:
96-
# Limit history to the most recent N messages
97-
limited_history = list(task.history[-history_length:])
98-
# Create a new task instance with limited history
114+
if config is None or not config.HasField('history_length'):
115+
return task
116+
117+
history_length = config.history_length
118+
119+
if history_length == 0:
120+
if not task.history:
121+
return task
99122
task_copy = Task()
100123
task_copy.CopyFrom(task)
101-
# Clear and re-add history items
102-
del task_copy.history[:]
103-
task_copy.history.extend(limited_history)
124+
task_copy.ClearField('history')
104125
return task_copy
126+
127+
if history_length > 0 and task.history:
128+
if len(task.history) <= history_length:
129+
return task
130+
131+
task_copy = Task()
132+
task_copy.CopyFrom(task)
133+
del task_copy.history[:-history_length]
134+
return task_copy
135+
105136
return task
106137

107138

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,10 @@ async def exec_side_effect(_request, queue: EventQueue):
15061506
# Allow producer to emit the next event
15071507
allow_second_event.set()
15081508

1509-
received = await resub_gen.__anext__()
1509+
first_subscribe_event = await anext(resub_gen)
1510+
assert first_subscribe_event == task_for_resub
1511+
1512+
received = await anext(resub_gen)
15101513
assert received == second_event
15111514

15121515
# Finish producer to allow cleanup paths to complete
@@ -2713,7 +2716,7 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state):
27132716
async for _ in request_handler.on_subscribe_to_task(params, context):
27142717
pass # pragma: no cover
27152718

2716-
assert isinstance(exc_info.value.error, InvalidParamsError)
2719+
assert isinstance(exc_info.value.error, UnsupportedOperationError)
27172720
assert exc_info.value.error.message
27182721
assert (
27192722
f'Task {task_id} is in terminal state: {terminal_state}'

tests/server/request_handlers/test_jsonrpc_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,9 @@ async def streaming_coro():
703703
collected_events: list[Any] = []
704704
async for event in response:
705705
collected_events.append(event)
706-
assert len(collected_events) == len(events)
706+
assert (
707+
len(collected_events) == len(events) + 1
708+
) # First event is task itself
707709
assert mock_task.history is not None and len(mock_task.history) == 0
708710

709711
async def test_on_subscribe_no_existing_task_error(self) -> None:

tests/utils/test_task.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,17 @@
55

66
import pytest
77

8-
from a2a.types.a2a_pb2 import Artifact, Message, Part, Role, TaskState
8+
from a2a.types.a2a_pb2 import (
9+
Artifact,
10+
Message,
11+
Part,
12+
Role,
13+
TaskState,
14+
GetTaskRequest,
15+
SendMessageConfiguration,
16+
)
917
from a2a.utils.task import (
18+
apply_history_length,
1019
completed_task,
1120
decode_page_token,
1221
encode_page_token,
@@ -213,5 +222,54 @@ def test_decode_page_token_fails(self):
213222
)
214223

215224

225+
class TestApplyHistoryLength(unittest.TestCase):
226+
def setUp(self):
227+
self.history = [
228+
Message(
229+
message_id=str(i),
230+
role=Role.ROLE_USER,
231+
parts=[Part(text=f'msg {i}')],
232+
)
233+
for i in range(5)
234+
]
235+
artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])]
236+
self.task = completed_task(
237+
task_id='t1',
238+
context_id='c1',
239+
artifacts=artifacts,
240+
history=self.history,
241+
)
242+
243+
def test_none_config_returns_full_history(self):
244+
result = apply_history_length(self.task, None)
245+
self.assertEqual(len(result.history), 5)
246+
self.assertEqual(result.history, self.history)
247+
248+
def test_unset_history_length_returns_full_history(self):
249+
result = apply_history_length(self.task, GetTaskRequest())
250+
self.assertEqual(len(result.history), 5)
251+
self.assertEqual(result.history, self.history)
252+
253+
def test_positive_history_length_truncates(self):
254+
result = apply_history_length(
255+
self.task, GetTaskRequest(history_length=2)
256+
)
257+
self.assertEqual(len(result.history), 2)
258+
self.assertEqual(result.history, self.history[-2:])
259+
260+
def test_large_history_length_returns_full_history(self):
261+
result = apply_history_length(
262+
self.task, GetTaskRequest(history_length=10)
263+
)
264+
self.assertEqual(len(result.history), 5)
265+
self.assertEqual(result.history, self.history)
266+
267+
def test_zero_history_length_returns_empty_history(self):
268+
result = apply_history_length(
269+
self.task, SendMessageConfiguration(history_length=0)
270+
)
271+
self.assertEqual(len(result.history), 0)
272+
273+
216274
if __name__ == '__main__':
217275
unittest.main()

0 commit comments

Comments
 (0)