|
1 | 1 | from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable |
2 | | -from typing import Any, Literal, cast |
| 2 | +from typing import Any, cast |
3 | 3 |
|
4 | 4 | from a2a.client.client import ( |
5 | 5 | Client, |
@@ -108,9 +108,13 @@ async def send_message( |
108 | 108 | return |
109 | 109 |
|
110 | 110 | async for event in self._execute_stream_with_interceptors( |
111 | | - input_data=ClientCallInput(method='send_message_streaming', value=request), |
| 111 | + input_data=ClientCallInput( |
| 112 | + method='send_message_streaming', value=request |
| 113 | + ), |
112 | 114 | context=context, |
113 | | - transport_call=lambda req, ctx: self._transport.send_message_streaming(req, context=ctx), |
| 115 | + transport_call=lambda req, ctx: ( |
| 116 | + self._transport.send_message_streaming(req, context=ctx) |
| 117 | + ), |
114 | 118 | ): |
115 | 119 | yield event |
116 | 120 |
|
@@ -148,22 +152,13 @@ async def _process_stream( |
148 | 152 | ) |
149 | 153 | await self._intercept_after(cast('UnionAfterArgs', after_args)) |
150 | 154 | intercepted_response = after_args.result.value |
151 | | - client_event: ClientEvent |
152 | | - # When we get a message in the stream then we don't expect any |
153 | | - # further messages so yield and return |
| 155 | + client_event = await self._format_stream_event( |
| 156 | + intercepted_response, tracker |
| 157 | + ) |
| 158 | + yield client_event |
154 | 159 | if intercepted_response.HasField('message'): |
155 | | - client_event = (intercepted_response, None) |
156 | | - await self.consume(client_event, self._card) |
157 | | - yield client_event |
158 | 160 | return |
159 | 161 |
|
160 | | - # Otherwise track the task / task update then yield to the client |
161 | | - await tracker.process(intercepted_response) |
162 | | - updated_task = tracker.get_task_or_raise() |
163 | | - client_event = (intercepted_response, updated_task) |
164 | | - await self.consume(client_event, self._card) |
165 | | - yield client_event |
166 | | - |
167 | 162 | async def get_task( |
168 | 163 | self, |
169 | 164 | request: GetTaskRequest, |
@@ -358,7 +353,9 @@ async def subscribe( |
358 | 353 | async for event in self._execute_stream_with_interceptors( |
359 | 354 | input_data=ClientCallInput(method='subscribe', value=request), |
360 | 355 | context=context, |
361 | | - transport_call=lambda req, ctx: self._transport.subscribe(req, context=ctx), |
| 356 | + transport_call=lambda req, ctx: self._transport.subscribe( |
| 357 | + req, context=ctx |
| 358 | + ), |
362 | 359 | ): |
363 | 360 | yield event |
364 | 361 |
|
@@ -443,33 +440,39 @@ async def _execute_with_interceptors( |
443 | 440 | await self._intercept_after(cast('UnionAfterArgs', after_args)) |
444 | 441 |
|
445 | 442 | return after_args.result.value |
446 | | - |
| 443 | + |
447 | 444 | async def _execute_stream_with_interceptors( |
448 | 445 | self, |
449 | 446 | input_data: ClientCallInput[M, P], |
450 | 447 | context: ClientCallContext | None, |
451 | | - transport_call: Callable[[P, ClientCallContext | None], AsyncIterator[StreamResponse]], |
| 448 | + transport_call: Callable[ |
| 449 | + [P, ClientCallContext | None], AsyncIterator[StreamResponse] |
| 450 | + ], |
452 | 451 | ) -> AsyncIterator[ClientEvent]: |
453 | 452 |
|
454 | 453 | before_args: BeforeArgs[M, P, StreamResponse] = BeforeArgs( |
455 | 454 | input=input_data, |
456 | 455 | agent_card=self._card, |
457 | 456 | context=context, |
458 | 457 | ) |
459 | | - before_result = await self._intercept_before(cast('UnionBeforeArgs', before_args)) |
| 458 | + before_result = await self._intercept_before( |
| 459 | + cast('UnionBeforeArgs', before_args) |
| 460 | + ) |
460 | 461 |
|
461 | 462 | if before_result: |
462 | 463 | after_args: AfterArgs[M, StreamResponse] = AfterArgs( |
463 | | - result=before_result.early_return, |
| 464 | + result=before_result['early_return'], |
464 | 465 | agent_card=self._card, |
465 | 466 | context=before_args.context, |
466 | 467 | ) |
467 | 468 | await self._intercept_after( |
468 | | - cast('UnionAfterArgs', after_args), |
469 | | - before_result.executed |
| 469 | + cast('UnionAfterArgs', after_args), before_result['executed'] |
| 470 | + ) |
| 471 | + |
| 472 | + tracker = ClientTaskManager() |
| 473 | + yield await self._format_stream_event( |
| 474 | + after_args.result.value, tracker |
470 | 475 | ) |
471 | | - |
472 | | - yield await self._format_stream_event(after_args.result.value) |
473 | 476 | return |
474 | 477 |
|
475 | 478 | stream = transport_call(before_args.input.value, before_args.context) |
@@ -509,13 +512,17 @@ async def _intercept_after( |
509 | 512 | if args.early_return: |
510 | 513 | return |
511 | 514 |
|
512 | | - async def _format_stream_event(self, stream_response: StreamResponse) -> ClientEvent: |
| 515 | + async def _format_stream_event( |
| 516 | + self, stream_response: StreamResponse, tracker: ClientTaskManager |
| 517 | + ) -> ClientEvent: |
513 | 518 | if stream_response.HasField('message'): |
514 | 519 | client_event = (stream_response, None) |
515 | | - elif stream_response.HasField('task'): |
516 | | - client_event = (stream_response, stream_response.task) |
517 | | - else: |
518 | | - client_event = (stream_response, None) |
519 | | - |
| 520 | + await self.consume(client_event, self._card) |
| 521 | + return client_event |
| 522 | + |
| 523 | + await tracker.process(stream_response) |
| 524 | + updated_task = tracker.get_task_or_raise() |
| 525 | + client_event = (stream_response, updated_task) |
| 526 | + |
520 | 527 | await self.consume(client_event, self._card) |
521 | 528 | return client_event |
0 commit comments