22from collections .abc import AsyncGenerator , Callable
33from types import TracebackType
44
5- from google .protobuf .message import Message
65from typing_extensions import Self
76
87from a2a .client .middleware import ClientCallContext
@@ -168,20 +167,15 @@ def __init__(self, base: ClientTransport, tenant: str):
168167 self ._base = base
169168 self ._tenant = tenant
170169
171- def update_tenant (self , request : Message ) -> str | None :
172- """Ensures the tenant is set on the request if provided and not already set .
170+ def update_tenant (self , tenant : str ) -> str :
171+ """If tenant is not provided, use the default tenant .
173172
174173 Returns:
175174 The tenant used for the request.
176175 """
177- current_tenant = getattr (request , 'tenant' , None )
178- if current_tenant :
179- return current_tenant
180-
181- if self ._tenant and hasattr (request , 'tenant' ):
182- setattr (request , 'tenant' , self ._tenant )
183- return self ._tenant
184- return None
176+ if tenant != '' :
177+ return tenant
178+ return self ._tenant or ''
185179
186180 async def send_message (
187181 self ,
@@ -191,7 +185,7 @@ async def send_message(
191185 extensions : list [str ] | None = None ,
192186 ) -> SendMessageResponse :
193187 """Sends a streaming message request to the agent and yields responses as they arrive."""
194- self .update_tenant (request )
188+ request . tenant = self .update_tenant (request . tenant )
195189 return await self ._base .send_message (
196190 request , context = context , extensions = extensions
197191 )
@@ -204,7 +198,7 @@ async def send_message_streaming(
204198 extensions : list [str ] | None = None ,
205199 ) -> AsyncGenerator [StreamResponse ]:
206200 """Sends a streaming message request to the agent and yields responses."""
207- self .update_tenant (request )
201+ request . tenant = self .update_tenant (request . tenant )
208202 async for event in self ._base .send_message_streaming (
209203 request , context = context , extensions = extensions
210204 ):
@@ -218,7 +212,7 @@ async def get_task(
218212 extensions : list [str ] | None = None ,
219213 ) -> Task :
220214 """Retrieves the current state and history of a specific task."""
221- self .update_tenant (request )
215+ request . tenant = self .update_tenant (request . tenant )
222216 return await self ._base .get_task (
223217 request , context = context , extensions = extensions
224218 )
@@ -231,7 +225,7 @@ async def list_tasks(
231225 extensions : list [str ] | None = None ,
232226 ) -> ListTasksResponse :
233227 """Retrieves tasks for an agent."""
234- self .update_tenant (request )
228+ request . tenant = self .update_tenant (request . tenant )
235229 return await self ._base .list_tasks (
236230 request , context = context , extensions = extensions
237231 )
@@ -244,7 +238,7 @@ async def cancel_task(
244238 extensions : list [str ] | None = None ,
245239 ) -> Task :
246240 """Requests the agent to cancel a specific task."""
247- self .update_tenant (request )
241+ request . tenant = self .update_tenant (request . tenant )
248242 return await self ._base .cancel_task (
249243 request , context = context , extensions = extensions
250244 )
@@ -257,7 +251,7 @@ async def create_task_push_notification_config(
257251 extensions : list [str ] | None = None ,
258252 ) -> TaskPushNotificationConfig :
259253 """Sets or updates the push notification configuration for a specific task."""
260- self .update_tenant (request )
254+ request . tenant = self .update_tenant (request . tenant )
261255 return await self ._base .create_task_push_notification_config (
262256 request , context = context , extensions = extensions
263257 )
@@ -270,7 +264,7 @@ async def get_task_push_notification_config(
270264 extensions : list [str ] | None = None ,
271265 ) -> TaskPushNotificationConfig :
272266 """Retrieves the push notification configuration for a specific task."""
273- self .update_tenant (request )
267+ request . tenant = self .update_tenant (request . tenant )
274268 return await self ._base .get_task_push_notification_config (
275269 request , context = context , extensions = extensions
276270 )
@@ -283,7 +277,7 @@ async def list_task_push_notification_configs(
283277 extensions : list [str ] | None = None ,
284278 ) -> ListTaskPushNotificationConfigsResponse :
285279 """Lists push notification configurations for a specific task."""
286- self .update_tenant (request )
280+ request . tenant = self .update_tenant (request . tenant )
287281 return await self ._base .list_task_push_notification_configs (
288282 request , context = context , extensions = extensions
289283 )
@@ -296,7 +290,7 @@ async def delete_task_push_notification_config(
296290 extensions : list [str ] | None = None ,
297291 ) -> None :
298292 """Deletes the push notification configuration for a specific task."""
299- self .update_tenant (request )
293+ request . tenant = self .update_tenant (request . tenant )
300294 await self ._base .delete_task_push_notification_config (
301295 request , context = context , extensions = extensions
302296 )
@@ -309,7 +303,7 @@ async def subscribe(
309303 extensions : list [str ] | None = None ,
310304 ) -> AsyncGenerator [StreamResponse ]:
311305 """Reconnects to get task updates."""
312- self .update_tenant (request )
306+ request . tenant = self .update_tenant (request . tenant )
313307 async for event in self ._base .subscribe (
314308 request , context = context , extensions = extensions
315309 ):
0 commit comments