Skip to content

Commit b138aff

Browse files
committed
feat: add ArtifactStreamer for streaming artifact updates with stable ID
Adds a stateful streaming helper to a2a.utils that maintains a stable artifact_id across chunks, enabling correct append=True semantics for TaskArtifactUpdateEvent. Closes #833
1 parent b941eef commit b138aff

3 files changed

Lines changed: 238 additions & 1 deletion

File tree

src/a2a/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utility functions for the A2A Python SDK."""
22

33
from a2a.utils.artifact import (
4+
ArtifactStreamer,
45
get_artifact_text,
56
new_artifact,
67
new_data_artifact,
@@ -39,6 +40,7 @@
3940
'DEFAULT_RPC_URL',
4041
'EXTENDED_AGENT_CARD_PATH',
4142
'PREV_AGENT_CARD_WELL_KNOWN_PATH',
43+
'ArtifactStreamer',
4244
'append_artifact_to_task',
4345
'are_modalities_compatible',
4446
'build_text_artifact',

src/a2a/utils/artifact.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
from typing import Any
66

7-
from a2a.types import Artifact, DataPart, Part, TextPart
7+
from a2a.types import (
8+
Artifact,
9+
DataPart,
10+
Part,
11+
TaskArtifactUpdateEvent,
12+
TextPart,
13+
)
814
from a2a.utils.parts import get_text_parts
915

1016

@@ -86,3 +92,104 @@ def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str:
8692
A single string containing all text content, or an empty string if no text parts are found.
8793
"""
8894
return delimiter.join(get_text_parts(artifact.parts))
95+
96+
97+
class ArtifactStreamer:
98+
"""A stateful helper for streaming artifact updates with a stable artifact ID.
99+
100+
Solves the problem where calling ``new_text_artifact`` in a loop generates
101+
a fresh ``artifact_id`` each time, making ``append=True`` unusable.
102+
103+
Example::
104+
105+
streamer = ArtifactStreamer(context_id, task_id, name='response')
106+
107+
async for chunk in llm.stream(prompt):
108+
await event_queue.enqueue_event(streamer.append(chunk))
109+
110+
await event_queue.enqueue_event(streamer.finalize())
111+
112+
Args:
113+
context_id: The context ID associated with the task.
114+
task_id: The ID of the task this artifact belongs to.
115+
name: A human-readable name for the artifact.
116+
description: An optional description of the artifact.
117+
"""
118+
119+
def __init__(
120+
self,
121+
context_id: str,
122+
task_id: str,
123+
name: str,
124+
description: str | None = None,
125+
) -> None:
126+
self._context_id = context_id
127+
self._task_id = task_id
128+
self._name = name
129+
self._description = description
130+
self._artifact_id = str(uuid.uuid4())
131+
self._finalized = False
132+
133+
@property
134+
def artifact_id(self) -> str:
135+
"""The stable artifact ID used across all chunks."""
136+
return self._artifact_id
137+
138+
def append(self, text: str) -> TaskArtifactUpdateEvent:
139+
"""Create an append event for the next chunk of text.
140+
141+
Args:
142+
text: The text content to append.
143+
144+
Returns:
145+
A ``TaskArtifactUpdateEvent`` with ``append=True`` and
146+
``last_chunk=False``.
147+
148+
Raises:
149+
RuntimeError: If ``finalize()`` has already been called.
150+
"""
151+
if self._finalized:
152+
raise RuntimeError(
153+
'Cannot append after finalize() has been called.'
154+
)
155+
return TaskArtifactUpdateEvent(
156+
context_id=self._context_id,
157+
task_id=self._task_id,
158+
append=True,
159+
last_chunk=False,
160+
artifact=Artifact(
161+
artifact_id=self._artifact_id,
162+
name=self._name,
163+
description=self._description,
164+
parts=[Part(root=TextPart(text=text))],
165+
),
166+
)
167+
168+
def finalize(self, text: str = '') -> TaskArtifactUpdateEvent:
169+
"""Create the final chunk event, closing the stream.
170+
171+
Args:
172+
text: Optional final text content. Defaults to empty string.
173+
174+
Returns:
175+
A ``TaskArtifactUpdateEvent`` with ``append=True`` and
176+
``last_chunk=True``.
177+
178+
Raises:
179+
RuntimeError: If ``finalize()`` has already been called.
180+
"""
181+
if self._finalized:
182+
raise RuntimeError('finalize() has already been called.')
183+
self._finalized = True
184+
return TaskArtifactUpdateEvent(
185+
context_id=self._context_id,
186+
task_id=self._task_id,
187+
append=True,
188+
last_chunk=True,
189+
artifact=Artifact(
190+
artifact_id=self._artifact_id,
191+
name=self._name,
192+
description=self._description,
193+
parts=[Part(root=TextPart(text=text))],
194+
),
195+
)

tests/utils/test_artifact.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
Artifact,
88
DataPart,
99
Part,
10+
TaskArtifactUpdateEvent,
1011
TextPart,
1112
)
1213
from a2a.utils.artifact import (
14+
ArtifactStreamer,
1315
get_artifact_text,
1416
new_artifact,
1517
new_data_artifact,
@@ -155,5 +157,131 @@ def test_get_artifact_text_empty_parts(self):
155157
assert result == ''
156158

157159

160+
class TestArtifactStreamer(unittest.TestCase):
161+
def setUp(self):
162+
self.context_id = 'ctx-123'
163+
self.task_id = 'task-456'
164+
self.name = 'response'
165+
166+
def test_stable_artifact_id_across_appends(self):
167+
streamer = ArtifactStreamer(
168+
self.context_id, self.task_id, name=self.name
169+
)
170+
event1 = streamer.append('Hello ')
171+
event2 = streamer.append('world')
172+
self.assertEqual(
173+
event1.artifact.artifact_id, event2.artifact.artifact_id
174+
)
175+
176+
def test_append_returns_correct_event_type(self):
177+
streamer = ArtifactStreamer(
178+
self.context_id, self.task_id, name=self.name
179+
)
180+
event = streamer.append('chunk')
181+
self.assertIsInstance(event, TaskArtifactUpdateEvent)
182+
183+
def test_append_sets_append_true_last_chunk_false(self):
184+
streamer = ArtifactStreamer(
185+
self.context_id, self.task_id, name=self.name
186+
)
187+
event = streamer.append('chunk')
188+
self.assertTrue(event.append)
189+
self.assertFalse(event.last_chunk)
190+
191+
def test_append_sets_context_and_task_ids(self):
192+
streamer = ArtifactStreamer(
193+
self.context_id, self.task_id, name=self.name
194+
)
195+
event = streamer.append('chunk')
196+
self.assertEqual(event.context_id, self.context_id)
197+
self.assertEqual(event.task_id, self.task_id)
198+
199+
def test_append_sets_text_content(self):
200+
streamer = ArtifactStreamer(
201+
self.context_id, self.task_id, name=self.name
202+
)
203+
event = streamer.append('Hello world')
204+
self.assertEqual(len(event.artifact.parts), 1)
205+
self.assertEqual(event.artifact.parts[0].root.text, 'Hello world')
206+
207+
def test_append_sets_artifact_name_and_description(self):
208+
streamer = ArtifactStreamer(
209+
self.context_id,
210+
self.task_id,
211+
name='my-artifact',
212+
description='A streamed response',
213+
)
214+
event = streamer.append('chunk')
215+
self.assertEqual(event.artifact.name, 'my-artifact')
216+
self.assertEqual(event.artifact.description, 'A streamed response')
217+
218+
def test_finalize_sets_last_chunk_true(self):
219+
streamer = ArtifactStreamer(
220+
self.context_id, self.task_id, name=self.name
221+
)
222+
event = streamer.finalize('done')
223+
self.assertTrue(event.append)
224+
self.assertTrue(event.last_chunk)
225+
226+
def test_finalize_with_empty_text(self):
227+
streamer = ArtifactStreamer(
228+
self.context_id, self.task_id, name=self.name
229+
)
230+
event = streamer.finalize()
231+
self.assertEqual(event.artifact.parts[0].root.text, '')
232+
self.assertTrue(event.last_chunk)
233+
234+
def test_finalize_uses_same_artifact_id(self):
235+
streamer = ArtifactStreamer(
236+
self.context_id, self.task_id, name=self.name
237+
)
238+
append_event = streamer.append('chunk')
239+
finalize_event = streamer.finalize()
240+
self.assertEqual(
241+
append_event.artifact.artifact_id,
242+
finalize_event.artifact.artifact_id,
243+
)
244+
245+
def test_append_after_finalize_raises(self):
246+
streamer = ArtifactStreamer(
247+
self.context_id, self.task_id, name=self.name
248+
)
249+
streamer.finalize()
250+
with self.assertRaises(RuntimeError):
251+
streamer.append('too late')
252+
253+
def test_double_finalize_raises(self):
254+
streamer = ArtifactStreamer(
255+
self.context_id, self.task_id, name=self.name
256+
)
257+
streamer.finalize()
258+
with self.assertRaises(RuntimeError):
259+
streamer.finalize()
260+
261+
def test_artifact_id_property(self):
262+
streamer = ArtifactStreamer(
263+
self.context_id, self.task_id, name=self.name
264+
)
265+
artifact_id = streamer.artifact_id
266+
self.assertIsInstance(artifact_id, str)
267+
self.assertTrue(len(artifact_id) > 0)
268+
269+
@patch('uuid.uuid4')
270+
def test_artifact_id_from_uuid(self, mock_uuid4):
271+
mock_uuid = uuid.UUID('12345678-1234-5678-1234-567812345678')
272+
mock_uuid4.return_value = mock_uuid
273+
streamer = ArtifactStreamer(
274+
self.context_id, self.task_id, name=self.name
275+
)
276+
self.assertEqual(streamer.artifact_id, str(mock_uuid))
277+
278+
def test_description_defaults_to_none(self):
279+
streamer = ArtifactStreamer(
280+
self.context_id, self.task_id, name=self.name
281+
)
282+
event = streamer.append('chunk')
283+
self.assertIsNone(event.artifact.description)
284+
285+
158286
if __name__ == '__main__':
159287
unittest.main()

0 commit comments

Comments
 (0)