Skip to content

Commit c6e22ad

Browse files
committed
refactor: simplify tenant resolution logic in base
1 parent a4f7d91 commit c6e22ad

1 file changed

Lines changed: 15 additions & 21 deletions

File tree

src/a2a/client/transports/base.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import AsyncGenerator, Callable
33
from types import TracebackType
44

5-
from google.protobuf.message import Message
65
from typing_extensions import Self
76

87
from a2a.client.middleware import ClientCallContext
@@ -168,20 +167,15 @@ def __init__(self, base: ClientTransport, tenant: str):
168167
self._base = base
169168
self._tenant = tenant
170169

171-
def update_tenant(self, request: Message) -> str | None:
172-
"""Ensures the tenant is set on the request if provided and not already set.
170+
def update_tenant(self, tenant: str) -> str:
171+
"""If tenant is not provided, use the default tenant.
173172
174173
Returns:
175174
The tenant used for the request.
176175
"""
177-
current_tenant = getattr(request, 'tenant', None)
178-
if current_tenant:
179-
return current_tenant
180-
181-
if self._tenant and hasattr(request, 'tenant'):
182-
setattr(request, 'tenant', self._tenant)
183-
return self._tenant
184-
return None
176+
if tenant != '':
177+
return tenant
178+
return self._tenant or ''
185179

186180
async def send_message(
187181
self,
@@ -191,7 +185,7 @@ async def send_message(
191185
extensions: list[str] | None = None,
192186
) -> SendMessageResponse:
193187
"""Sends a streaming message request to the agent and yields responses as they arrive."""
194-
self.update_tenant(request)
188+
request.tenant = self.update_tenant(request.tenant)
195189
return await self._base.send_message(
196190
request, context=context, extensions=extensions
197191
)
@@ -204,7 +198,7 @@ async def send_message_streaming(
204198
extensions: list[str] | None = None,
205199
) -> AsyncGenerator[StreamResponse]:
206200
"""Sends a streaming message request to the agent and yields responses."""
207-
self.update_tenant(request)
201+
request.tenant = self.update_tenant(request.tenant)
208202
async for event in self._base.send_message_streaming(
209203
request, context=context, extensions=extensions
210204
):
@@ -218,7 +212,7 @@ async def get_task(
218212
extensions: list[str] | None = None,
219213
) -> Task:
220214
"""Retrieves the current state and history of a specific task."""
221-
self.update_tenant(request)
215+
request.tenant = self.update_tenant(request.tenant)
222216
return await self._base.get_task(
223217
request, context=context, extensions=extensions
224218
)
@@ -231,7 +225,7 @@ async def list_tasks(
231225
extensions: list[str] | None = None,
232226
) -> ListTasksResponse:
233227
"""Retrieves tasks for an agent."""
234-
self.update_tenant(request)
228+
request.tenant = self.update_tenant(request.tenant)
235229
return await self._base.list_tasks(
236230
request, context=context, extensions=extensions
237231
)
@@ -244,7 +238,7 @@ async def cancel_task(
244238
extensions: list[str] | None = None,
245239
) -> Task:
246240
"""Requests the agent to cancel a specific task."""
247-
self.update_tenant(request)
241+
request.tenant = self.update_tenant(request.tenant)
248242
return await self._base.cancel_task(
249243
request, context=context, extensions=extensions
250244
)
@@ -257,7 +251,7 @@ async def create_task_push_notification_config(
257251
extensions: list[str] | None = None,
258252
) -> TaskPushNotificationConfig:
259253
"""Sets or updates the push notification configuration for a specific task."""
260-
self.update_tenant(request)
254+
request.tenant = self.update_tenant(request.tenant)
261255
return await self._base.create_task_push_notification_config(
262256
request, context=context, extensions=extensions
263257
)
@@ -270,7 +264,7 @@ async def get_task_push_notification_config(
270264
extensions: list[str] | None = None,
271265
) -> TaskPushNotificationConfig:
272266
"""Retrieves the push notification configuration for a specific task."""
273-
self.update_tenant(request)
267+
request.tenant = self.update_tenant(request.tenant)
274268
return await self._base.get_task_push_notification_config(
275269
request, context=context, extensions=extensions
276270
)
@@ -283,7 +277,7 @@ async def list_task_push_notification_configs(
283277
extensions: list[str] | None = None,
284278
) -> ListTaskPushNotificationConfigsResponse:
285279
"""Lists push notification configurations for a specific task."""
286-
self.update_tenant(request)
280+
request.tenant = self.update_tenant(request.tenant)
287281
return await self._base.list_task_push_notification_configs(
288282
request, context=context, extensions=extensions
289283
)
@@ -296,7 +290,7 @@ async def delete_task_push_notification_config(
296290
extensions: list[str] | None = None,
297291
) -> None:
298292
"""Deletes the push notification configuration for a specific task."""
299-
self.update_tenant(request)
293+
request.tenant = self.update_tenant(request.tenant)
300294
await self._base.delete_task_push_notification_config(
301295
request, context=context, extensions=extensions
302296
)
@@ -309,7 +303,7 @@ async def subscribe(
309303
extensions: list[str] | None = None,
310304
) -> AsyncGenerator[StreamResponse]:
311305
"""Reconnects to get task updates."""
312-
self.update_tenant(request)
306+
request.tenant = self.update_tenant(request.tenant)
313307
async for event in self._base.subscribe(
314308
request, context=context, extensions=extensions
315309
):

0 commit comments

Comments
 (0)