forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path text_client.py
More file actions
114 lines (94 loc) · 3.91 KB
/
text_client.py
File metadata and controls
114 lines (94 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import uuid
from types import TracebackType
from typing_extensions import Self
from a2a.client.client import Client, ClientCallContext
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
from a2a.utils import get_artifact_text, get_message_text
# Task states the A2A protocol treats as terminal (the task will not make
# further progress). When a status update reports one of these, TextClient
# stops referencing that task in follow-up messages.
_TERMINAL_STATES: frozenset[TaskState] = frozenset((
    TaskState.TASK_STATE_COMPLETED,
    TaskState.TASK_STATE_FAILED,
    TaskState.TASK_STATE_CANCELED,
    TaskState.TASK_STATE_REJECTED,
))
class TextClient:
    """A facade around Client that simplifies text-based communication.

    Wraps an underlying Client instance and exposes a simplified interface
    for sending plain-text messages and receiving aggregated text responses.
    Maintains session state (context_id, task_id) automatically across calls.

    For full Client API access, use the underlying client directly via
    the `client` property.
    """

    def __init__(self, client: Client):
        """Wraps `client`.

        Args:
            client: The Client used to send messages. It is closed by
                `close()`, which also runs when the async context manager
                exits.
        """
        self._client = client
        # Attached to every outgoing message until reset_session() is called.
        self._context_id: str = str(uuid.uuid4())
        # ID of the task the next message continues; None means the next
        # message is sent without a task_id.
        self._task_id: str | None = None

    async def __aenter__(self) -> Self:
        """Enters the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exits the async context manager and closes the client."""
        await self.close()

    @property
    def client(self) -> Client:
        """Returns the underlying Client instance for full API access."""
        return self._client

    def reset_session(self) -> None:
        """Starts a new session by generating a fresh context ID and clearing the task ID."""
        self._context_id = str(uuid.uuid4())
        self._task_id = None

    async def send_text_message(
        self,
        text: str,
        *,
        delimiter: str = ' ',
        context: ClientCallContext | None = None,
    ) -> str:
        """Sends a text message and returns the aggregated text response.

        Session state (context_id, task_id) is managed automatically across
        calls. Use reset_session() to start a new conversation.

        Args:
            text: The plain-text message to send.
            delimiter: String used to join response parts. Defaults to a
                single space. Use '' for token-streamed responses or a
                newline for paragraph-separated chunks.
            context: Optional call-level context.

        Returns:
            The text of every message event, status-update message and
            artifact-update event received for this send, in arrival order,
            joined with `delimiter`. An empty string if none were received.
        """
        request = SendMessageRequest(
            message=Message(
                role=Role.ROLE_USER,
                message_id=str(uuid.uuid4()),
                context_id=self._context_id,
                task_id=self._task_id,
                parts=[Part(text=text)],
            )
        )
        response_parts: list[str] = []
        async for event in self._client.send_message(request, context=context):
            if event.HasField('task'):
                task = event.task
                # Same rule as the status_update branch below: a task that is
                # already in a terminal state (e.g. a completed Task returned
                # by a non-streaming call) must not be referenced by the next
                # message, otherwise that message would try to continue a
                # finished task.
                if task.status.state in _TERMINAL_STATES:
                    self._task_id = None
                else:
                    self._task_id = task.id
            elif event.HasField('message'):
                response_parts.append(get_message_text(event.message))
            elif event.HasField('status_update'):
                update = event.status_update
                # Adopt the task ID from the first status update when no task
                # event has supplied one yet.
                if not self._task_id and update.task_id:
                    self._task_id = update.task_id
                if update.status.state in _TERMINAL_STATES:
                    self._task_id = None
                if update.status.HasField('message'):
                    response_parts.append(
                        get_message_text(update.status.message)
                    )
            elif event.HasField('artifact_update'):
                response_parts.append(
                    get_artifact_text(event.artifact_update.artifact)
                )
        return delimiter.join(response_parts)

    async def close(self) -> None:
        """Closes the underlying client."""
        await self._client.close()