Skip to content

Commit ab134b3

Browse files
committed
refactor: centralize stream interception and execution logic into _execute_stream_with_interceptors and add _format_stream_event.
1 parent ef658ff commit ab134b3

1 file changed

Lines changed: 52 additions & 67 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 52 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -107,43 +107,12 @@ async def send_message(
107107
yield client_event
108108
return
109109

110-
before_args: BeforeArgs[
111-
Literal['send_message_streaming'],
112-
SendMessageRequest,
113-
StreamResponse,
114-
] = BeforeArgs(
115-
input=ClientCallInput(
116-
method='send_message_streaming', value=request
117-
),
118-
agent_card=self._card,
110+
async for event in self._execute_stream_with_interceptors(
111+
input_data=ClientCallInput(method='send_message_streaming', value=request),
119112
context=context,
120-
)
121-
before_result = await self._intercept_before(before_args)
122-
123-
if before_result is not None:
124-
after_args = AfterArgs(
125-
result=ClientCallResult(
126-
method=before_args.input.method,
127-
value=before_result['early_return'].value,
128-
),
129-
agent_card=self._card,
130-
context=before_args.context,
131-
)
132-
await self._intercept_after(
133-
cast('UnionAfterArgs', after_args), before_result['executed']
134-
)
135-
yield after_args.result.value
136-
return
137-
138-
stream = self._transport.send_message_streaming(
139-
before_args.input.value, context=before_args.context
140-
)
141-
142-
async for client_event in self._process_stream(
143-
stream,
144-
before_args=before_args,
113+
transport_call=lambda req, ctx: self._transport.send_message_streaming(req, context=ctx),
145114
):
146-
yield client_event
115+
yield event
147116

148117
def _apply_client_config(self, request: SendMessageRequest) -> None:
149118
if not request.configuration.blocking and self._config.polling:
@@ -386,40 +355,12 @@ async def subscribe(
386355
'client and/or server do not support resubscription.'
387356
)
388357

389-
# Note: resubscribe can only be called on an existing task. As such,
390-
before_args: BeforeArgs[
391-
Literal['subscribe'], SubscribeToTaskRequest, StreamResponse
392-
] = BeforeArgs(
393-
input=ClientCallInput(method='subscribe', value=request),
394-
agent_card=self._card,
358+
async for event in self._execute_stream_with_interceptors(
359+
input_data=ClientCallInput(method='subscribe', value=request),
395360
context=context,
396-
)
397-
before_result = await self._intercept_before(before_args)
398-
399-
if before_result is not None:
400-
after_args = AfterArgs(
401-
result=ClientCallResult(
402-
method=before_args.input.method,
403-
value=before_result['early_return'].value,
404-
),
405-
agent_card=self._card,
406-
context=before_args.context,
407-
)
408-
await self._intercept_after(
409-
cast('UnionAfterArgs', after_args), before_result['executed']
410-
)
411-
yield after_args.result.value
412-
return
413-
414-
stream = self._transport.subscribe(
415-
before_args.input.value, context=before_args.context
416-
)
417-
418-
async for client_event in self._process_stream(
419-
stream,
420-
before_args=before_args,
361+
transport_call=lambda req, ctx: self._transport.subscribe(req, context=ctx),
421362
):
422-
yield client_event
363+
yield event
423364

424365
async def get_extended_agent_card(
425366
self,
@@ -502,6 +443,39 @@ async def _execute_with_interceptors(
502443
await self._intercept_after(cast('UnionAfterArgs', after_args))
503444

504445
return after_args.result.value
446+
447+
async def _execute_stream_with_interceptors(
448+
self,
449+
input_data: ClientCallInput[M, P],
450+
context: ClientCallContext | None,
451+
transport_call: Callable[[P, ClientCallContext | None], AsyncIterator[StreamResponse]],
452+
) -> AsyncIterator[ClientEvent]:
453+
454+
before_args: BeforeArgs[M, P, StreamResponse] = BeforeArgs(
455+
input=input_data,
456+
agent_card=self._card,
457+
context=context,
458+
)
459+
before_result = await self._intercept_before(cast('UnionBeforeArgs', before_args))
460+
461+
if before_result:
462+
after_args: AfterArgs[M, StreamResponse] = AfterArgs(
463+
result=before_result.early_return,
464+
agent_card=self._card,
465+
context=before_args.context,
466+
)
467+
await self._intercept_after(
468+
cast('UnionAfterArgs', after_args),
469+
before_result.executed
470+
)
471+
472+
yield await self._format_stream_event(after_args.result.value)
473+
return
474+
475+
stream = transport_call(before_args.input.value, before_args.context)
476+
477+
async for client_event in self._process_stream(stream, before_args):
478+
yield client_event
505479

506480
async def _intercept_before(
507481
self,
@@ -534,3 +508,14 @@ async def _intercept_after(
534508
await interceptor.after(args)
535509
if args.early_return:
536510
return
511+
512+
async def _format_stream_event(self, stream_response: StreamResponse) -> ClientEvent:
513+
if stream_response.HasField('message'):
514+
client_event = (stream_response, None)
515+
elif stream_response.HasField('task'):
516+
client_event = (stream_response, stream_response.task)
517+
else:
518+
client_event = (stream_response, None)
519+
520+
await self.consume(client_event, self._card)
521+
return client_event

0 commit comments

Comments
 (0)