Skip to content

Commit b7b699a

Browse files
committed
refactor: remove ClientTaskManager and related consumers from client components
1 parent 9ccf99c commit b7b699a

15 files changed

Lines changed: 54 additions & 501 deletions

src/a2a/client/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
Client,
1212
ClientCallContext,
1313
ClientConfig,
14-
ClientEvent,
15-
Consumer,
1614
)
1715
from a2a.client.client_factory import ClientFactory, minimal_agent_card
1816
from a2a.client.errors import (
@@ -35,9 +33,7 @@
3533
'ClientCallContext',
3634
'ClientCallInterceptor',
3735
'ClientConfig',
38-
'ClientEvent',
3936
'ClientFactory',
40-
'Consumer',
4137
'CredentialService',
4238
'InMemoryContextCredentialStore',
4339
'create_text_message_object',

src/a2a/client/base_client.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
Client,
66
ClientCallContext,
77
ClientConfig,
8-
ClientEvent,
9-
Consumer,
108
)
11-
from a2a.client.client_task_manager import ClientTaskManager
129
from a2a.client.interceptors import (
1310
AfterArgs,
1411
BeforeArgs,
@@ -42,10 +39,9 @@ def __init__(
4239
card: AgentCard,
4340
config: ClientConfig,
4441
transport: ClientTransport,
45-
consumers: list[Consumer],
4642
interceptors: list[ClientCallInterceptor],
4743
):
48-
super().__init__(consumers, interceptors)
44+
super().__init__(interceptors)
4945
self._card = card
5046
self._config = config
5147
self._transport = transport
@@ -56,7 +52,7 @@ async def send_message(
5652
request: SendMessageRequest,
5753
*,
5854
context: ClientCallContext | None = None,
59-
) -> AsyncIterator[ClientEvent]:
55+
) -> AsyncIterator[StreamResponse]:
6056
"""Sends a message to the agent.
6157
6258
This method handles both streaming and non-streaming (polling) interactions
@@ -84,19 +80,14 @@ async def send_message(
8480
# In non-streaming case we convert to a StreamResponse so that the
8581
# client always sees the same iterator.
8682
stream_response = StreamResponse()
87-
client_event: ClientEvent
8883
if response.HasField('task'):
8984
stream_response.task.CopyFrom(response.task)
90-
client_event = (stream_response, response.task)
9185
elif response.HasField('message'):
9286
stream_response.message.CopyFrom(response.message)
93-
client_event = (stream_response, None)
9487
else:
95-
# Response must have either task or message
9688
raise ValueError('Response has neither task nor message')
9789

98-
await self.consume(client_event, self._card)
99-
yield client_event
90+
yield stream_response
10091
return
10192

10293
async for event in self._execute_stream_with_interceptors(
@@ -130,8 +121,7 @@ async def _process_stream(
130121
self,
131122
stream: AsyncIterator[StreamResponse],
132123
before_args: BeforeArgs,
133-
) -> AsyncGenerator[ClientEvent]:
134-
tracker = ClientTaskManager()
124+
) -> AsyncGenerator[StreamResponse, None]:
135125
async for stream_response in stream:
136126
after_args = AfterArgs(
137127
result=stream_response,
@@ -140,12 +130,8 @@ async def _process_stream(
140130
context=before_args.context,
141131
)
142132
await self._intercept_after(after_args)
143-
intercepted_response = after_args.result
144-
client_event = await self._format_stream_event(
145-
intercepted_response, tracker
146-
)
147-
yield client_event
148-
if intercepted_response.HasField('message'):
133+
yield after_args.result
134+
if after_args.result.HasField('message'):
149135
return
150136

151137
async def get_task(
@@ -318,7 +304,7 @@ async def subscribe(
318304
request: SubscribeToTaskRequest,
319305
*,
320306
context: ClientCallContext | None = None,
321-
) -> AsyncIterator[ClientEvent]:
307+
) -> AsyncIterator[StreamResponse]:
322308
"""Resubscribes to a task's event stream.
323309
324310
This is only available if both the client and server support streaming.
@@ -436,7 +422,7 @@ async def _execute_stream_with_interceptors(
436422
transport_call: Callable[
437423
[Any, ClientCallContext | None], AsyncIterator[StreamResponse]
438424
],
439-
) -> AsyncIterator[ClientEvent]:
425+
) -> AsyncIterator[StreamResponse]:
440426

441427
before_args = BeforeArgs(
442428
input=input_data,
@@ -455,8 +441,7 @@ async def _execute_stream_with_interceptors(
455441
)
456442
await self._intercept_after(after_args, before_result['executed'])
457443

458-
tracker = ClientTaskManager()
459-
yield await self._format_stream_event(after_args.result, tracker)
444+
yield after_args.result
460445
return
461446

462447
stream = transport_call(before_args.input, before_args.context)
@@ -495,19 +480,3 @@ async def _intercept_after(
495480
await interceptor.after(args)
496481
if args.early_return:
497482
return
498-
499-
async def _format_stream_event(
500-
self, stream_response: StreamResponse, tracker: ClientTaskManager
501-
) -> ClientEvent:
502-
client_event: ClientEvent
503-
if stream_response.HasField('message'):
504-
client_event = (stream_response, None)
505-
await self.consume(client_event, self._card)
506-
return client_event
507-
508-
await tracker.process(stream_response)
509-
updated_task = tracker.get_task_or_raise()
510-
client_event = (stream_response, updated_task)
511-
512-
await self.consume(client_event, self._card)
513-
return client_event

src/a2a/client/client.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33

44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
5+
from collections.abc import AsyncIterator, Callable, MutableMapping
66
from types import TracebackType
77
from typing import Any
88

@@ -77,13 +77,6 @@ class ClientConfig:
7777
"""Push notification configurations to use for every request."""
7878

7979

80-
ClientEvent = tuple[StreamResponse, Task | None]
81-
82-
# Alias for an event consuming callback. It takes either a (task, update) pair
83-
# or a message as well as the agent card for the agent this came from.
84-
Consumer = Callable[[ClientEvent, AgentCard], Coroutine[None, Any, Any]]
85-
86-
8780
class ClientCallContext(BaseModel):
8881
"""A context passed with each client call, allowing for call-specific.
8982
@@ -106,16 +99,13 @@ class Client(ABC):
10699

107100
def __init__(
108101
self,
109-
consumers: list[Consumer] | None = None,
110102
interceptors: list[ClientCallInterceptor] | None = None,
111103
):
112-
"""Initializes the client with consumers and interceptors.
104+
"""Initializes the client with interceptors.
113105
114106
Args:
115-
consumers: A list of callables to process events from the agent.
116107
interceptors: A list of interceptors to process requests and responses.
117108
"""
118-
self._consumers = consumers or []
119109
self._interceptors = interceptors or []
120110

121111
async def __aenter__(self) -> Self:
@@ -137,7 +127,7 @@ async def send_message(
137127
request: SendMessageRequest,
138128
*,
139129
context: ClientCallContext | None = None,
140-
) -> AsyncIterator[ClientEvent]:
130+
) -> AsyncIterator[StreamResponse]:
141131
"""Sends a message to the server.
142132
143133
This will automatically use the streaming or non-streaming approach
@@ -218,7 +208,7 @@ async def subscribe(
218208
request: SubscribeToTaskRequest,
219209
*,
220210
context: ClientCallContext | None = None,
221-
) -> AsyncIterator[ClientEvent]:
211+
) -> AsyncIterator[StreamResponse]:
222212
"""Resubscribes to a task's event stream."""
223213
return
224214
yield
@@ -233,23 +223,10 @@ async def get_extended_agent_card(
233223
) -> AgentCard:
234224
"""Retrieves the agent's card."""
235225

236-
async def add_event_consumer(self, consumer: Consumer) -> None:
237-
"""Attaches additional consumers to the `Client`."""
238-
self._consumers.append(consumer)
239-
240226
async def add_interceptor(self, interceptor: ClientCallInterceptor) -> None:
241227
"""Attaches additional interceptors to the `Client`."""
242228
self._interceptors.append(interceptor)
243229

244-
async def consume(
245-
self,
246-
event: ClientEvent,
247-
card: AgentCard,
248-
) -> None:
249-
"""Processes the event via all the registered `Consumer`s."""
250-
for c in self._consumers:
251-
await c(event, card)
252-
253230
@abstractmethod
254231
async def close(self) -> None:
255232
"""Closes the client and releases any underlying resources."""

src/a2a/client/client_factory.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from a2a.client.base_client import BaseClient
1313
from a2a.client.card_resolver import A2ACardResolver
14-
from a2a.client.client import Client, ClientConfig, Consumer
14+
from a2a.client.client import Client, ClientConfig
1515
from a2a.client.transports.base import ClientTransport
1616
from a2a.client.transports.jsonrpc import JsonRpcTransport
1717
from a2a.client.transports.rest import RestTransport
@@ -77,17 +77,12 @@ class ClientFactory:
7777
def __init__(
7878
self,
7979
config: ClientConfig,
80-
consumers: list[Consumer] | None = None,
8180
):
82-
if consumers is None:
83-
consumers = []
84-
8581
client = config.httpx_client or httpx.AsyncClient()
8682
client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
8783
config.httpx_client = client
8884

8985
self._config = config
90-
self._consumers = consumers
9186
self._registry: dict[str, TransportProducer] = {}
9287
self._register_defaults(config.supported_protocol_bindings)
9388

@@ -263,7 +258,6 @@ async def connect( # noqa: PLR0913
263258
cls,
264259
agent: str | AgentCard,
265260
client_config: ClientConfig | None = None,
266-
consumers: list[Consumer] | None = None,
267261
interceptors: list[ClientCallInterceptor] | None = None,
268262
relative_card_path: str | None = None,
269263
resolver_http_kwargs: dict[str, Any] | None = None,
@@ -286,7 +280,7 @@ async def connect( # noqa: PLR0913
286280
Args:
287281
agent: The base URL of the agent, or the AgentCard to connect to.
288282
client_config: The ClientConfig to use when connecting to the agent.
289-
consumers: A list of `Consumer` methods to pass responses to.
283+
290284
interceptors: A list of interceptors to use for each request. These
291285
are used for things like attaching credentials or http headers
292286
to all outbound requests.
@@ -325,7 +319,7 @@ async def connect( # noqa: PLR0913
325319
factory = cls(client_config)
326320
for label, generator in (extra_transports or {}).items():
327321
factory.register(label, generator)
328-
return factory.create(card, consumers, interceptors)
322+
return factory.create(card, interceptors)
329323

330324
def register(self, label: str, generator: TransportProducer) -> None:
331325
"""Register a new transport producer for a given transport label."""
@@ -334,14 +328,12 @@ def register(self, label: str, generator: TransportProducer) -> None:
334328
def create(
335329
self,
336330
card: AgentCard,
337-
consumers: list[Consumer] | None = None,
338331
interceptors: list[ClientCallInterceptor] | None = None,
339332
) -> Client:
340333
"""Create a new `Client` for the provided `AgentCard`.
341334
342335
Args:
343336
card: An `AgentCard` defining the characteristics of the agent.
344-
consumers: A list of `Consumer` methods to pass responses to.
345337
interceptors: A list of interceptors to use for each request. These
346338
are used for things like attaching credentials or http headers
347339
to all outbound requests.
@@ -381,10 +373,6 @@ def create(
381373
if transport_protocol not in self._registry:
382374
raise ValueError(f'no client available for {transport_protocol}')
383375

384-
all_consumers = self._consumers.copy()
385-
if consumers:
386-
all_consumers.extend(consumers)
387-
388376
transport = self._registry[transport_protocol](
389377
card, selected_interface.url, self._config
390378
)
@@ -398,7 +386,6 @@ def create(
398386
card,
399387
self._config,
400388
transport,
401-
all_consumers,
402389
interceptors or [],
403390
)
404391

0 commit comments

Comments
 (0)