Skip to content

Commit 6085fa0

Browse files
committed
fix: properly handle unset and zero history length
According to https://a2a-protocol.org/latest/specification/#324-history-length-semantics. It changes behavior so the fix was postponed till 1.0. Fixes #573
1 parent 427a75b commit 6085fa0

4 files changed

Lines changed: 97 additions & 19 deletions

File tree

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ async def on_get_task(
127127
if not task:
128128
raise ServerError(error=TaskNotFoundError())
129129

130-
# Apply historyLength parameter if specified
131-
return apply_history_length(task, params.history_length)
130+
return apply_history_length(task, params)
132131

133132
async def on_list_tasks(
134133
self,
@@ -141,7 +140,7 @@ async def on_list_tasks(
141140
if not params.include_artifacts:
142141
task.ClearField('artifacts')
143142

144-
updated_task = apply_history_length(task, params.history_length)
143+
updated_task = apply_history_length(task, params)
145144
if updated_task is not task:
146145
task.CopyFrom(updated_task)
147146

@@ -380,9 +379,7 @@ async def push_notification_callback() -> None:
380379
if isinstance(result, Task):
381380
self._validate_task_id_match(task_id, result.id)
382381
if params.configuration:
383-
result = apply_history_length(
384-
result, params.configuration.history_length
385-
)
382+
result = apply_history_length(result, params.configuration)
386383

387384
await self._send_push_notification_if_needed(task_id, result_aggregator)
388385

src/a2a/utils/task.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import uuid
55

66
from base64 import b64decode, b64encode
7+
from typing import Literal, Protocol, runtime_checkable
78

89
from a2a.types.a2a_pb2 import (
910
Artifact,
@@ -81,27 +82,57 @@ def completed_task(
8182
)
8283

8384

84-
def apply_history_length(task: Task, history_length: int | None) -> Task:
85+
@runtime_checkable
86+
class HistoryLengthConfig(Protocol):
87+
"""Protocol for configuration arguments containing history_length field."""
88+
89+
history_length: int
90+
91+
def HasField(self, field_name: Literal['history_length']) -> bool: # noqa: N802 -- Protobuf generated code
92+
"""Checks if a field is set.
93+
94+
This method name matches the generated Protobuf code.
95+
"""
96+
...
97+
98+
99+
def apply_history_length(
100+
task: Task, config: HistoryLengthConfig | None
101+
) -> Task:
85102
"""Applies history_length parameter on task and returns a new task object.
86103
87104
Args:
88105
task: The original task object with complete history
89-
history_length: History length configuration value
106+
config: Configuration object containing 'history_length' field and HasField method.
90107
91108
Returns:
92109
A new task object with limited history
93110
"""
94-
# Apply historyLength parameter if specified
95-
if history_length is not None and history_length > 0 and task.history:
96-
# Limit history to the most recent N messages
97-
limited_history = list(task.history[-history_length:])
98-
# Create a new task instance with limited history
111+
if config is None:
112+
return task
113+
114+
# See https://a2a-protocol.org/latest/specification/#324-history-length-semantics
115+
116+
if not config.HasField('history_length'):
117+
return task
118+
119+
history_length = config.history_length
120+
121+
if history_length == 0:
122+
task_copy = Task()
123+
task_copy.CopyFrom(task)
124+
task_copy.ClearField('history')
125+
return task_copy
126+
127+
if history_length > 0 and task.history:
128+
if len(task.history) <= history_length:
129+
return task
130+
99131
task_copy = Task()
100132
task_copy.CopyFrom(task)
101-
# Clear and re-add history items
102-
del task_copy.history[:]
103-
task_copy.history.extend(limited_history)
133+
del task_copy.history[:-history_length]
104134
return task_copy
135+
105136
return task
106137

107138

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,7 +1499,10 @@ async def exec_side_effect(_request, queue: EventQueue):
14991499
# Allow producer to emit the next event
15001500
allow_second_event.set()
15011501

1502-
received = await resub_gen.__anext__()
1502+
first_subscribe_event = await anext(resub_gen)
1503+
assert first_subscribe_event == task_for_resub
1504+
1505+
received = await anext(resub_gen)
15031506
assert received == second_event
15041507

15051508
# Finish producer to allow cleanup paths to complete
@@ -2706,7 +2709,7 @@ async def test_on_subscribe_to_task_in_terminal_state(terminal_state):
27062709
async for _ in request_handler.on_subscribe_to_task(params, context):
27072710
pass # pragma: no cover
27082711

2709-
assert isinstance(exc_info.value.error, InvalidParamsError)
2712+
assert isinstance(exc_info.value.error, UnsupportedOperationError)
27102713
assert exc_info.value.error.message
27112714
assert (
27122715
f'Task {task_id} is in terminal state: {terminal_state}'

tests/utils/test_task.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import pytest
77

8-
from a2a.types.a2a_pb2 import Artifact, Message, Part, Role, TaskState
8+
from a2a.types.a2a_pb2 import Artifact, Message, Part, Role, TaskState, GetTaskRequest, SendMessageConfiguration
99
from a2a.utils.task import (
10+
apply_history_length,
1011
completed_task,
1112
decode_page_token,
1213
encode_page_token,
@@ -215,3 +216,49 @@ def test_decode_page_token_fails(self):
215216

216217
if __name__ == '__main__':
217218
unittest.main()
219+
220+
221+
222+
class TestApplyHistoryLength(unittest.TestCase):
223+
def setUp(self):
224+
# Create a task with some history
225+
self.history = [
226+
Message(message_id=str(i), role=Role.ROLE_USER, parts=[Part(text=f'msg {i}')])
227+
for i in range(5)
228+
]
229+
artifacts = [Artifact(artifact_id='a1', parts=[Part(text='a')])]
230+
self.task = completed_task(
231+
task_id='t1',
232+
context_id='c1',
233+
artifacts=artifacts,
234+
history=self.history
235+
)
236+
237+
def test_none_config_returns_full_history(self):
238+
# Test None (no limit) - config is None
239+
result = apply_history_length(self.task, None)
240+
self.assertEqual(len(result.history), 5)
241+
self.assertEqual(result.history, self.history)
242+
243+
def test_unset_history_length_returns_full_history(self):
244+
# Test unset (HasField returns False)
245+
# Using GetTaskRequest as it has history_length field
246+
result = apply_history_length(self.task, GetTaskRequest())
247+
self.assertEqual(len(result.history), 5)
248+
self.assertEqual(result.history, self.history)
249+
250+
def test_positive_history_length_truncates(self):
251+
# Test > 0 (partial)
252+
result = apply_history_length(self.task, GetTaskRequest(history_length=2))
253+
self.assertEqual(len(result.history), 2)
254+
self.assertEqual(result.history, self.history[-2:])
255+
256+
def test_large_history_length_returns_full_history(self):
257+
result = apply_history_length(self.task, GetTaskRequest(history_length=10))
258+
self.assertEqual(len(result.history), 5)
259+
self.assertEqual(result.history, self.history)
260+
261+
def test_zero_history_length_returns_empty_history(self):
262+
result = apply_history_length(self.task, SendMessageConfiguration(history_length=0))
263+
self.assertEqual(len(result.history), 0)
264+

0 commit comments

Comments
 (0)