Skip to content

Commit 92bc02f

Browse files
committed
refactor: Update stream event processing to use a task manager, adjust _intercept_before return type, and optimize message event handling.
1 parent ab134b3 commit 92bc02f

1 file changed

Lines changed: 38 additions & 31 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
2-
from typing import Any, Literal, cast
2+
from typing import Any, cast
33

44
from a2a.client.client import (
55
Client,
@@ -108,9 +108,13 @@ async def send_message(
108108
return
109109

110110
async for event in self._execute_stream_with_interceptors(
111-
input_data=ClientCallInput(method='send_message_streaming', value=request),
111+
input_data=ClientCallInput(
112+
method='send_message_streaming', value=request
113+
),
112114
context=context,
113-
transport_call=lambda req, ctx: self._transport.send_message_streaming(req, context=ctx),
115+
transport_call=lambda req, ctx: (
116+
self._transport.send_message_streaming(req, context=ctx)
117+
),
114118
):
115119
yield event
116120

@@ -148,22 +152,13 @@ async def _process_stream(
148152
)
149153
await self._intercept_after(cast('UnionAfterArgs', after_args))
150154
intercepted_response = after_args.result.value
151-
client_event: ClientEvent
152-
# When we get a message in the stream then we don't expect any
153-
# further messages so yield and return
155+
client_event = await self._format_stream_event(
156+
intercepted_response, tracker
157+
)
158+
yield client_event
154159
if intercepted_response.HasField('message'):
155-
client_event = (intercepted_response, None)
156-
await self.consume(client_event, self._card)
157-
yield client_event
158160
return
159161

160-
# Otherwise track the task / task update then yield to the client
161-
await tracker.process(intercepted_response)
162-
updated_task = tracker.get_task_or_raise()
163-
client_event = (intercepted_response, updated_task)
164-
await self.consume(client_event, self._card)
165-
yield client_event
166-
167162
async def get_task(
168163
self,
169164
request: GetTaskRequest,
@@ -358,7 +353,9 @@ async def subscribe(
358353
async for event in self._execute_stream_with_interceptors(
359354
input_data=ClientCallInput(method='subscribe', value=request),
360355
context=context,
361-
transport_call=lambda req, ctx: self._transport.subscribe(req, context=ctx),
356+
transport_call=lambda req, ctx: self._transport.subscribe(
357+
req, context=ctx
358+
),
362359
):
363360
yield event
364361

@@ -443,33 +440,39 @@ async def _execute_with_interceptors(
443440
await self._intercept_after(cast('UnionAfterArgs', after_args))
444441

445442
return after_args.result.value
446-
443+
447444
async def _execute_stream_with_interceptors(
448445
self,
449446
input_data: ClientCallInput[M, P],
450447
context: ClientCallContext | None,
451-
transport_call: Callable[[P, ClientCallContext | None], AsyncIterator[StreamResponse]],
448+
transport_call: Callable[
449+
[P, ClientCallContext | None], AsyncIterator[StreamResponse]
450+
],
452451
) -> AsyncIterator[ClientEvent]:
453452

454453
before_args: BeforeArgs[M, P, StreamResponse] = BeforeArgs(
455454
input=input_data,
456455
agent_card=self._card,
457456
context=context,
458457
)
459-
before_result = await self._intercept_before(cast('UnionBeforeArgs', before_args))
458+
before_result = await self._intercept_before(
459+
cast('UnionBeforeArgs', before_args)
460+
)
460461

461462
if before_result:
462463
after_args: AfterArgs[M, StreamResponse] = AfterArgs(
463-
result=before_result.early_return,
464+
result=before_result['early_return'],
464465
agent_card=self._card,
465466
context=before_args.context,
466467
)
467468
await self._intercept_after(
468-
cast('UnionAfterArgs', after_args),
469-
before_result.executed
469+
cast('UnionAfterArgs', after_args), before_result['executed']
470+
)
471+
472+
tracker = ClientTaskManager()
473+
yield await self._format_stream_event(
474+
after_args.result.value, tracker
470475
)
471-
472-
yield await self._format_stream_event(after_args.result.value)
473476
return
474477

475478
stream = transport_call(before_args.input.value, before_args.context)
@@ -509,13 +512,17 @@ async def _intercept_after(
509512
if args.early_return:
510513
return
511514

512-
async def _format_stream_event(self, stream_response: StreamResponse) -> ClientEvent:
515+
async def _format_stream_event(
516+
self, stream_response: StreamResponse, tracker: ClientTaskManager
517+
) -> ClientEvent:
513518
if stream_response.HasField('message'):
514519
client_event = (stream_response, None)
515-
elif stream_response.HasField('task'):
516-
client_event = (stream_response, stream_response.task)
517-
else:
518-
client_event = (stream_response, None)
519-
520+
await self.consume(client_event, self._card)
521+
return client_event
522+
523+
await tracker.process(stream_response)
524+
updated_task = tracker.get_task_or_raise()
525+
client_event = (stream_response, updated_task)
526+
520527
await self.consume(client_event, self._card)
521528
return client_event

0 commit comments

Comments
 (0)