Skip to content

Commit 44b74d8

Browse files
committed
wip refactoring
1 parent b58b03e commit 44b74d8

25 files changed

Lines changed: 357 additions & 1465 deletions

src/a2a/client/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
A2AClientTimeoutError,
2323
AgentCardResolutionError,
2424
)
25-
from a2a.client.helpers import create_text_message_object
2625
from a2a.client.interceptors import ClientCallInterceptor
2726

2827

@@ -41,6 +40,5 @@
4140
'CredentialService',
4241
'InMemoryContextCredentialStore',
4342
'create_client',
44-
'create_text_message_object',
4543
'minimal_agent_card',
4644
]

src/a2a/client/helpers.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Helper functions for the A2A client."""
22

33
from typing import Any
4-
from uuid import uuid4
54

65
from google.protobuf.json_format import ParseDict
76

8-
from a2a.types.a2a_pb2 import AgentCard, Message, Part, Role
7+
from a2a.types.a2a_pb2 import AgentCard
98

109

1110
def parse_agent_card(agent_card_data: dict[str, Any]) -> AgentCard:
@@ -111,20 +110,3 @@ def _handle_security_compatibility(agent_card_data: dict[str, Any]) -> None:
111110
new_scheme_wrapper = {mapped_name: scheme.copy()}
112111
scheme.clear()
113112
scheme.update(new_scheme_wrapper)
114-
115-
116-
def create_text_message_object(
117-
role: Role = Role.ROLE_USER, content: str = ''
118-
) -> Message:
119-
"""Create a Message object containing a single text Part.
120-
121-
Args:
122-
role: The role of the message sender (user or agent). Defaults to Role.ROLE_USER.
123-
content: The text content of the message. Defaults to an empty string.
124-
125-
Returns:
126-
A `Message` object with a new UUID message_id.
127-
"""
128-
return Message(
129-
role=role, parts=[Part(text=content)], message_id=str(uuid4())
130-
)

src/a2a/helpers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Helper functions for the A2A Python SDK."""

src/a2a/helpers/types.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Unified helper functions for creating and handling A2A types."""
2+
3+
import uuid
4+
from collections.abc import Sequence
5+
from typing import Any
6+
7+
from google.protobuf.json_format import MessageToDict
8+
9+
from a2a.types.a2a_pb2 import (
10+
Artifact,
11+
Message,
12+
Part,
13+
Role,
14+
StreamResponse,
15+
Task,
16+
TaskArtifactUpdateEvent,
17+
TaskState,
18+
TaskStatus,
19+
TaskStatusUpdateEvent,
20+
)
21+
22+
23+
# --- Message Helpers ---
24+
25+
26+
def new_message(
27+
parts: list[Part],
28+
role: Role = Role.ROLE_AGENT,
29+
context_id: str | None = None,
30+
task_id: str | None = None,
31+
) -> Message:
32+
"""Creates a new message containing a list of Parts."""
33+
return Message(
34+
role=role,
35+
parts=parts,
36+
message_id=str(uuid.uuid4()),
37+
task_id=task_id,
38+
context_id=context_id,
39+
)
40+
41+
42+
def new_text_message(
43+
text: str,
44+
context_id: str | None = None,
45+
task_id: str | None = None,
46+
role: Role = Role.ROLE_AGENT,
47+
) -> Message:
48+
"""Creates a new message containing a single text Part."""
49+
return new_message(
50+
parts=[Part(text=text)],
51+
role=role,
52+
task_id=task_id,
53+
context_id=context_id,
54+
)
55+
56+
57+
def get_message_text(message: Message, delimiter: str = '\n') -> str:
58+
"""Extracts and joins all text content from a Message's parts."""
59+
return delimiter.join(get_text_parts(message.parts))
60+
61+
62+
# --- Artifact Helpers ---
63+
64+
65+
def new_artifact(
66+
parts: list[Part],
67+
name: str,
68+
description: str | None = None,
69+
) -> Artifact:
70+
"""Creates a new Artifact object."""
71+
return Artifact(
72+
artifact_id=str(uuid.uuid4()),
73+
parts=parts,
74+
name=name,
75+
description=description,
76+
)
77+
78+
79+
def new_text_artifact(
80+
name: str,
81+
text: str,
82+
description: str | None = None,
83+
) -> Artifact:
84+
"""Creates a new Artifact object containing only a single text Part."""
85+
return new_artifact(
86+
[Part(text=text)],
87+
name,
88+
description,
89+
)
90+
91+
92+
def get_artifact_text(artifact: Artifact, delimiter: str = '\n') -> str:
93+
"""Extracts and joins all text content from an Artifact's parts."""
94+
return delimiter.join(get_text_parts(artifact.parts))
95+
96+
97+
# --- Task Helpers ---
98+
99+
100+
def new_task_from_request(request: Message) -> Task:
101+
"""Creates a new Task object from an initial user message."""
102+
if not request.role:
103+
raise TypeError('Message role cannot be None')
104+
if not request.parts:
105+
raise ValueError('Message parts cannot be empty')
106+
for part in request.parts:
107+
if part.HasField('text') and not part.text:
108+
raise ValueError('Message.text cannot be empty')
109+
110+
return Task(
111+
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
112+
id=request.task_id or str(uuid.uuid4()),
113+
context_id=request.context_id or str(uuid.uuid4()),
114+
history=[request],
115+
)
116+
117+
118+
def new_task(
119+
task_id: str,
120+
context_id: str,
121+
state: TaskState,
122+
artifacts: list[Artifact] | None = None,
123+
history: list[Message] | None = None,
124+
) -> Task:
125+
"""Creates a Task object with a specified status."""
126+
if not artifacts or not all(isinstance(a, Artifact) for a in artifacts):
127+
raise ValueError(
128+
'artifacts must be a non-empty list of Artifact objects'
129+
)
130+
131+
if history is None:
132+
history = []
133+
return Task(
134+
status=TaskStatus(state=state),
135+
id=task_id,
136+
context_id=context_id,
137+
artifacts=artifacts,
138+
history=history,
139+
)
140+
141+
142+
# --- Part Helpers ---
143+
144+
145+
def get_text_parts(parts: Sequence[Part]) -> list[str]:
146+
"""Extracts text content from all text Parts."""
147+
return [part.text for part in parts if part.HasField('text')]
148+
149+
150+
# --- Event & Stream Helpers ---
151+
152+
153+
def new_text_status_update_event(
154+
task_id: str,
155+
context_id: str,
156+
state: TaskState,
157+
text: str,
158+
) -> TaskStatusUpdateEvent:
159+
"""Creates a TaskStatusUpdateEvent with a single text message."""
160+
return TaskStatusUpdateEvent(
161+
task_id=task_id,
162+
context_id=context_id,
163+
status=TaskStatus(
164+
state=state,
165+
message=new_text_message(
166+
text=text,
167+
role=Role.ROLE_AGENT,
168+
context_id=context_id,
169+
task_id=task_id,
170+
),
171+
),
172+
)
173+
174+
175+
def new_text_artifact_update_event(
176+
task_id: str,
177+
context_id: str,
178+
name: str,
179+
text: str,
180+
append: bool = False,
181+
last_chunk: bool = False,
182+
) -> TaskArtifactUpdateEvent:
183+
"""Creates a TaskArtifactUpdateEvent with a single text artifact."""
184+
return TaskArtifactUpdateEvent(
185+
task_id=task_id,
186+
context_id=context_id,
187+
artifact=Artifact(
188+
artifact_id=str(uuid.uuid4()), name=name, parts=[Part(text=text)]
189+
),
190+
append=append,
191+
last_chunk=last_chunk,
192+
)
193+
194+
195+
def get_stream_response_text(response: StreamResponse, delimiter: str = '\n') -> str:
196+
"""Extracts text content from a StreamResponse."""
197+
if response.HasField('message'):
198+
return get_message_text(response.message, delimiter)
199+
elif response.HasField('task'):
200+
if response.task.status.HasField('message'):
201+
return get_message_text(response.task.status.message, delimiter)
202+
return ''
203+
elif response.HasField('status_update'):
204+
if response.status_update.status.HasField('message'):
205+
return get_message_text(response.status_update.status.message, delimiter)
206+
return ''
207+
elif response.HasField('artifact_update'):
208+
return get_artifact_text(response.artifact_update.artifact, delimiter)
209+
return ''

src/a2a/server/tasks/task_manager.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,83 @@
44
from a2a.server.events.event_queue import Event
55
from a2a.server.tasks.task_store import TaskStore
66
from a2a.types.a2a_pb2 import (
7+
Artifact,
78
Message,
89
Task,
910
TaskArtifactUpdateEvent,
1011
TaskState,
1112
TaskStatus,
1213
TaskStatusUpdateEvent,
1314
)
14-
from a2a.utils import append_artifact_to_task
1515
from a2a.utils.errors import InvalidParamsError
1616

1717

1818
logger = logging.getLogger(__name__)
1919

2020

21+
def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None:
22+
"""Helper method for updating a Task object with new artifact data from an event.
23+
24+
Handles creating the artifacts list if it doesn't exist, adding new artifacts,
25+
and appending parts to existing artifacts based on the `append` flag in the event.
26+
27+
Args:
28+
task: The `Task` object to modify.
29+
event: The `TaskArtifactUpdateEvent` containing the artifact data.
30+
"""
31+
new_artifact_data: Artifact = event.artifact
32+
artifact_id: str = new_artifact_data.artifact_id
33+
append_parts: bool = event.append
34+
35+
existing_artifact: Artifact | None = None
36+
existing_artifact_list_index: int | None = None
37+
38+
# Find existing artifact by its id
39+
for i, art in enumerate(task.artifacts):
40+
if art.artifact_id == artifact_id:
41+
existing_artifact = art
42+
existing_artifact_list_index = i
43+
break
44+
45+
if not append_parts:
46+
# This represents the first chunk for this artifact index.
47+
if existing_artifact_list_index is not None:
48+
# Replace the existing artifact entirely with the new data
49+
logger.debug(
50+
'Replacing artifact at id %s for task %s', artifact_id, task.id
51+
)
52+
task.artifacts[existing_artifact_list_index].CopyFrom(
53+
new_artifact_data
54+
)
55+
else:
56+
# Append the new artifact since no artifact with this index exists yet
57+
logger.debug(
58+
'Adding new artifact with id %s for task %s',
59+
artifact_id,
60+
task.id,
61+
)
62+
task.artifacts.append(new_artifact_data)
63+
elif existing_artifact:
64+
# Append new parts to the existing artifact's part list
65+
logger.debug(
66+
'Appending parts to artifact id %s for task %s',
67+
artifact_id,
68+
task.id,
69+
)
70+
existing_artifact.parts.extend(new_artifact_data.parts)
71+
existing_artifact.metadata.update(
72+
dict(new_artifact_data.metadata.items())
73+
)
74+
else:
75+
# We received a chunk to append, but we don't have an existing artifact.
76+
# we will ignore this chunk
77+
logger.warning(
78+
'Received append=True for nonexistent artifact index %s in task %s. Ignoring chunk.',
79+
artifact_id,
80+
task.id,
81+
)
82+
83+
2184
class TaskManager:
2285
"""Helps manage a task's lifecycle during execution of a request.
2386

0 commit comments

Comments
 (0)