@@ -107,43 +107,12 @@ async def send_message(
107107 yield client_event
108108 return
109109
110- before_args : BeforeArgs [
111- Literal ['send_message_streaming' ],
112- SendMessageRequest ,
113- StreamResponse ,
114- ] = BeforeArgs (
115- input = ClientCallInput (
116- method = 'send_message_streaming' , value = request
117- ),
118- agent_card = self ._card ,
110+ async for event in self ._execute_stream_with_interceptors (
111+ input_data = ClientCallInput (method = 'send_message_streaming' , value = request ),
119112 context = context ,
120- )
121- before_result = await self ._intercept_before (before_args )
122-
123- if before_result is not None :
124- after_args = AfterArgs (
125- result = ClientCallResult (
126- method = before_args .input .method ,
127- value = before_result ['early_return' ].value ,
128- ),
129- agent_card = self ._card ,
130- context = before_args .context ,
131- )
132- await self ._intercept_after (
133- cast ('UnionAfterArgs' , after_args ), before_result ['executed' ]
134- )
135- yield after_args .result .value
136- return
137-
138- stream = self ._transport .send_message_streaming (
139- before_args .input .value , context = before_args .context
140- )
141-
142- async for client_event in self ._process_stream (
143- stream ,
144- before_args = before_args ,
113+ transport_call = lambda req , ctx : self ._transport .send_message_streaming (req , context = ctx ),
145114 ):
146- yield client_event
115+ yield event
147116
148117 def _apply_client_config (self , request : SendMessageRequest ) -> None :
149118 if not request .configuration .blocking and self ._config .polling :
@@ -386,40 +355,12 @@ async def subscribe(
386355 'client and/or server do not support resubscription.'
387356 )
388357
389- # Note: resubscribe can only be called on an existing task. As such,
390- before_args : BeforeArgs [
391- Literal ['subscribe' ], SubscribeToTaskRequest , StreamResponse
392- ] = BeforeArgs (
393- input = ClientCallInput (method = 'subscribe' , value = request ),
394- agent_card = self ._card ,
358+ async for event in self ._execute_stream_with_interceptors (
359+ input_data = ClientCallInput (method = 'subscribe' , value = request ),
395360 context = context ,
396- )
397- before_result = await self ._intercept_before (before_args )
398-
399- if before_result is not None :
400- after_args = AfterArgs (
401- result = ClientCallResult (
402- method = before_args .input .method ,
403- value = before_result ['early_return' ].value ,
404- ),
405- agent_card = self ._card ,
406- context = before_args .context ,
407- )
408- await self ._intercept_after (
409- cast ('UnionAfterArgs' , after_args ), before_result ['executed' ]
410- )
411- yield after_args .result .value
412- return
413-
414- stream = self ._transport .subscribe (
415- before_args .input .value , context = before_args .context
416- )
417-
418- async for client_event in self ._process_stream (
419- stream ,
420- before_args = before_args ,
361+ transport_call = lambda req , ctx : self ._transport .subscribe (req , context = ctx ),
421362 ):
422- yield client_event
363+ yield event
423364
424365 async def get_extended_agent_card (
425366 self ,
@@ -502,6 +443,39 @@ async def _execute_with_interceptors(
502443 await self ._intercept_after (cast ('UnionAfterArgs' , after_args ))
503444
504445 return after_args .result .value
446+
447+ async def _execute_stream_with_interceptors (
448+ self ,
449+ input_data : ClientCallInput [M , P ],
450+ context : ClientCallContext | None ,
451+ transport_call : Callable [[P , ClientCallContext | None ], AsyncIterator [StreamResponse ]],
452+ ) -> AsyncIterator [ClientEvent ]:
453+
454+ before_args : BeforeArgs [M , P , StreamResponse ] = BeforeArgs (
455+ input = input_data ,
456+ agent_card = self ._card ,
457+ context = context ,
458+ )
459+ before_result = await self ._intercept_before (cast ('UnionBeforeArgs' , before_args ))
460+
461+ if before_result :
462+ after_args : AfterArgs [M , StreamResponse ] = AfterArgs (
463+ result = before_result .early_return ,
464+ agent_card = self ._card ,
465+ context = before_args .context ,
466+ )
467+ await self ._intercept_after (
468+ cast ('UnionAfterArgs' , after_args ),
469+ before_result .executed
470+ )
471+
472+ yield await self ._format_stream_event (after_args .result .value )
473+ return
474+
475+ stream = transport_call (before_args .input .value , before_args .context )
476+
477+ async for client_event in self ._process_stream (stream , before_args ):
478+ yield client_event
505479
506480 async def _intercept_before (
507481 self ,
@@ -534,3 +508,14 @@ async def _intercept_after(
534508 await interceptor .after (args )
535509 if args .early_return :
536510 return
511+
512+ async def _format_stream_event (self , stream_response : StreamResponse ) -> ClientEvent :
513+ if stream_response .HasField ('message' ):
514+ client_event = (stream_response , None )
515+ elif stream_response .HasField ('task' ):
516+ client_event = (stream_response , stream_response .task )
517+ else :
518+ client_event = (stream_response , None )
519+
520+ await self .consume (client_event , self ._card )
521+ return client_event
0 commit comments