Skip to content

Commit f807b7d

Browse files
lbobinskiŁukasz Bobiński
authored andcommitted
fix: historyLength=0 returns full history (#573)
1 parent cb7cdb3 commit f807b7d

2 files changed

Lines changed: 45 additions & 3 deletions

File tree

src/a2a/utils/task.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,12 @@ def apply_history_length(task: Task, history_length: int | None) -> Task:
8383
A new task object with limited history
8484
"""
8585
# Apply historyLength parameter if specified
86-
if history_length is not None and history_length > 0 and task.history:
86+
if history_length is not None and history_length >= 0:
8787
# Limit history to the most recent N messages
88-
limited_history = task.history[-history_length:]
88+
if task.history and history_length > 0:
89+
limited_history = task.history[-history_length:]
90+
else:
91+
limited_history = []
8992
# Create a new task instance with limited history
9093
return task.model_copy(update={'history': limited_history})
9194

tests/utils/test_task.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from a2a.types import Artifact, Message, Part, Role, TextPart
9-
from a2a.utils.task import completed_task, new_task
9+
from a2a.utils.task import apply_history_length, completed_task, new_task
1010

1111

1212
class TestTask(unittest.TestCase):
@@ -188,6 +188,45 @@ def test_completed_task_invalid_artifact_type(self):
188188
history=[],
189189
)
190190

191+
def test_apply_history_length_cases(self):
192+
# Setup task with 3 messages
193+
history = [
194+
Message(role=Role.user, parts=[Part(root=TextPart(text='1'))], message_id='1'),
195+
Message(role=Role.agent, parts=[Part(root=TextPart(text='2'))], message_id='2'),
196+
Message(role=Role.user, parts=[Part(root=TextPart(text='3'))], message_id='3'),
197+
]
198+
task_id = str(uuid.uuid4())
199+
context_id = str(uuid.uuid4())
200+
task = completed_task(
201+
task_id=task_id,
202+
context_id=context_id,
203+
artifacts=[Artifact(artifact_id='a', parts=[Part(root=TextPart(text='a'))])],
204+
history=history
205+
)
206+
207+
# historyLength = 0 -> empty
208+
t0 = apply_history_length(task, 0)
209+
self.assertEqual(len(t0.history), 0)
210+
211+
# historyLength = 1 -> last one
212+
t1 = apply_history_length(task, 1)
213+
self.assertEqual(len(t1.history), 1)
214+
self.assertEqual(t1.history[0].message_id, '3')
215+
216+
# historyLength = 2 -> last two
217+
t2 = apply_history_length(task, 2)
218+
self.assertEqual(len(t2.history), 2)
219+
self.assertEqual(t2.history[0].message_id, '2')
220+
self.assertEqual(t2.history[1].message_id, '3')
221+
222+
# historyLength = None -> all
223+
tn = apply_history_length(task, None)
224+
self.assertEqual(len(tn.history), 3)
225+
226+
# historyLength = 10 -> all
227+
t10 = apply_history_length(task, 10)
228+
self.assertEqual(len(t10.history), 3)
229+
191230

192231
if __name__ == '__main__':
193232
unittest.main()

0 commit comments

Comments
 (0)