|
6 | 6 | import pytest |
7 | 7 |
|
8 | 8 | from a2a.types import Artifact, Message, Part, Role, TextPart |
9 | | -from a2a.utils.task import completed_task, new_task |
| 9 | +from a2a.utils.task import apply_history_length, completed_task, new_task |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class TestTask(unittest.TestCase): |
@@ -188,6 +188,45 @@ def test_completed_task_invalid_artifact_type(self): |
188 | 188 | history=[], |
189 | 189 | ) |
190 | 190 |
|
| 191 | + def test_apply_history_length_cases(self): |
| 192 | + # Setup task with 3 messages |
| 193 | + history = [ |
| 194 | + Message(role=Role.user, parts=[Part(root=TextPart(text='1'))], message_id='1'), |
| 195 | + Message(role=Role.agent, parts=[Part(root=TextPart(text='2'))], message_id='2'), |
| 196 | + Message(role=Role.user, parts=[Part(root=TextPart(text='3'))], message_id='3'), |
| 197 | + ] |
| 198 | + task_id = str(uuid.uuid4()) |
| 199 | + context_id = str(uuid.uuid4()) |
| 200 | + task = completed_task( |
| 201 | + task_id=task_id, |
| 202 | + context_id=context_id, |
| 203 | + artifacts=[Artifact(artifact_id='a', parts=[Part(root=TextPart(text='a'))])], |
| 204 | + history=history |
| 205 | + ) |
| 206 | + |
| 207 | + # historyLength = 0 -> empty |
| 208 | + t0 = apply_history_length(task, 0) |
| 209 | + self.assertEqual(len(t0.history), 0) |
| 210 | + |
| 211 | + # historyLength = 1 -> last one |
| 212 | + t1 = apply_history_length(task, 1) |
| 213 | + self.assertEqual(len(t1.history), 1) |
| 214 | + self.assertEqual(t1.history[0].message_id, '3') |
| 215 | + |
| 216 | + # historyLength = 2 -> last two |
| 217 | + t2 = apply_history_length(task, 2) |
| 218 | + self.assertEqual(len(t2.history), 2) |
| 219 | + self.assertEqual(t2.history[0].message_id, '2') |
| 220 | + self.assertEqual(t2.history[1].message_id, '3') |
| 221 | + |
| 222 | + # historyLength = None -> all |
| 223 | + tn = apply_history_length(task, None) |
| 224 | + self.assertEqual(len(tn.history), 3) |
| 225 | + |
| 226 | + # historyLength = 10 -> all |
| 227 | + t10 = apply_history_length(task, 10) |
| 228 | + self.assertEqual(len(t10.history), 3) |
| 229 | + |
191 | 230 |
|
192 | 231 | if __name__ == '__main__': |
193 | 232 | unittest.main() |
0 commit comments