Skip to content

Commit be01104

Browse files
committed
feat: handle tenant in Client
1 parent 041f0f5 commit be01104

8 files changed

Lines changed: 606 additions & 29 deletions

File tree

src/a2a/client/client_factory.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from a2a.client.card_resolver import A2ACardResolver
1212
from a2a.client.client import Client, ClientConfig, Consumer
1313
from a2a.client.middleware import ClientCallInterceptor
14-
from a2a.client.transports.base import ClientTransport
14+
from a2a.client.transports.base import ClientTransport, TenantTransportDecorator
1515
from a2a.client.transports.jsonrpc import JsonRpcTransport
1616
from a2a.client.transports.rest import RestTransport
1717
from a2a.types.a2a_pb2 import (
@@ -208,28 +208,27 @@ def create(
208208
TransportProtocol.JSONRPC
209209
]
210210
transport_protocol = None
211-
transport_url = None
211+
selected_interface = None
212212
if self._config.use_client_preference:
213213
for protocol_binding in client_set:
214-
supported_interface = next(
214+
selected_interface = next(
215215
(
216216
si
217217
for si in card.supported_interfaces
218218
if si.protocol_binding == protocol_binding
219219
),
220220
None,
221221
)
222-
if supported_interface:
222+
if selected_interface:
223223
transport_protocol = protocol_binding
224-
transport_url = supported_interface.url
225224
break
226225
else:
227226
for supported_interface in card.supported_interfaces:
228227
if supported_interface.protocol_binding in client_set:
229228
transport_protocol = supported_interface.protocol_binding
230-
transport_url = supported_interface.url
229+
selected_interface = supported_interface
231230
break
232-
if not transport_protocol or not transport_url:
231+
if not transport_protocol or not selected_interface:
233232
raise ValueError('no compatible transports found.')
234233
if transport_protocol not in self._registry:
235234
raise ValueError(f'no client available for {transport_protocol}')
@@ -244,9 +243,14 @@ def create(
244243
self._config.extensions = all_extensions
245244

246245
transport = self._registry[transport_protocol](
247-
card, transport_url, self._config, interceptors or []
246+
card, selected_interface.url, self._config, interceptors or []
248247
)
249248

249+
if selected_interface.tenant:
250+
transport = TenantTransportDecorator(
251+
transport, selected_interface.tenant
252+
)
253+
250254
return BaseClient(
251255
card,
252256
self._config,

src/a2a/client/transports/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""A2A Client Transports."""
22

3-
from a2a.client.transports.base import ClientTransport
3+
from a2a.client.transports.base import ClientTransport, TenantTransportDecorator
44
from a2a.client.transports.jsonrpc import JsonRpcTransport
55
from a2a.client.transports.rest import RestTransport
66

@@ -16,4 +16,5 @@
1616
'GrpcTransport',
1717
'JsonRpcTransport',
1818
'RestTransport',
19+
'TenantTransportDecorator',
1920
]

src/a2a/client/transports/base.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from types import TracebackType
44

55
from typing_extensions import Self
6+
from google.protobuf.message import Message
67

78
from a2a.client.middleware import ClientCallContext
89
from a2a.types.a2a_pb2 import (
@@ -158,3 +159,164 @@ async def get_extended_agent_card(
158159
@abstractmethod
159160
async def close(self) -> None:
160161
"""Closes the transport."""
162+
163+
164+
class TenantTransportDecorator(ClientTransport):
165+
"""A transport decorator that attaches a tenant to all requests."""
166+
167+
def __init__(self, base: ClientTransport, tenant: str):
168+
self._base = base
169+
self._tenant = tenant
170+
171+
def update_tenant(self, request: Message) -> str | None:
172+
"""Ensures the tenant is set on the request if provided and not already set.
173+
174+
Returns:
175+
The tenant used for the request.
176+
"""
177+
current_tenant = getattr(request, 'tenant', None)
178+
if current_tenant:
179+
return current_tenant
180+
181+
if self._tenant and hasattr(request, 'tenant'):
182+
request.tenant = self._tenant
183+
return self._tenant
184+
return None
185+
186+
async def send_message(
187+
self,
188+
request: SendMessageRequest,
189+
*,
190+
context: ClientCallContext | None = None,
191+
extensions: list[str] | None = None,
192+
) -> SendMessageResponse:
193+
self.update_tenant(request)
194+
return await self._base.send_message(
195+
request, context=context, extensions=extensions
196+
)
197+
198+
async def send_message_streaming(
199+
self,
200+
request: SendMessageRequest,
201+
*,
202+
context: ClientCallContext | None = None,
203+
extensions: list[str] | None = None,
204+
) -> AsyncGenerator[StreamResponse]:
205+
self.update_tenant(request)
206+
async for event in self._base.send_message_streaming(
207+
request, context=context, extensions=extensions
208+
):
209+
yield event
210+
211+
async def get_task(
212+
self,
213+
request: GetTaskRequest,
214+
*,
215+
context: ClientCallContext | None = None,
216+
extensions: list[str] | None = None,
217+
) -> Task:
218+
self.update_tenant(request)
219+
return await self._base.get_task(
220+
request, context=context, extensions=extensions
221+
)
222+
223+
async def list_tasks(
224+
self,
225+
request: ListTasksRequest,
226+
*,
227+
context: ClientCallContext | None = None,
228+
extensions: list[str] | None = None,
229+
) -> ListTasksResponse:
230+
self.update_tenant(request)
231+
return await self._base.list_tasks(
232+
request, context=context, extensions=extensions
233+
)
234+
235+
async def cancel_task(
236+
self,
237+
request: CancelTaskRequest,
238+
*,
239+
context: ClientCallContext | None = None,
240+
extensions: list[str] | None = None,
241+
) -> Task:
242+
self.update_tenant(request)
243+
return await self._base.cancel_task(
244+
request, context=context, extensions=extensions
245+
)
246+
247+
async def set_task_callback(
248+
self,
249+
request: CreateTaskPushNotificationConfigRequest,
250+
*,
251+
context: ClientCallContext | None = None,
252+
extensions: list[str] | None = None,
253+
) -> TaskPushNotificationConfig:
254+
self.update_tenant(request)
255+
return await self._base.set_task_callback(
256+
request, context=context, extensions=extensions
257+
)
258+
259+
async def get_task_callback(
260+
self,
261+
request: GetTaskPushNotificationConfigRequest,
262+
*,
263+
context: ClientCallContext | None = None,
264+
extensions: list[str] | None = None,
265+
) -> TaskPushNotificationConfig:
266+
self.update_tenant(request)
267+
return await self._base.get_task_callback(
268+
request, context=context, extensions=extensions
269+
)
270+
271+
async def list_task_callback(
272+
self,
273+
request: ListTaskPushNotificationConfigsRequest,
274+
*,
275+
context: ClientCallContext | None = None,
276+
extensions: list[str] | None = None,
277+
) -> ListTaskPushNotificationConfigsResponse:
278+
self.update_tenant(request)
279+
return await self._base.list_task_callback(
280+
request, context=context, extensions=extensions
281+
)
282+
283+
async def delete_task_callback(
284+
self,
285+
request: DeleteTaskPushNotificationConfigRequest,
286+
*,
287+
context: ClientCallContext | None = None,
288+
extensions: list[str] | None = None,
289+
) -> None:
290+
self.update_tenant(request)
291+
await self._base.delete_task_callback(
292+
request, context=context, extensions=extensions
293+
)
294+
295+
async def subscribe(
296+
self,
297+
request: SubscribeToTaskRequest,
298+
*,
299+
context: ClientCallContext | None = None,
300+
extensions: list[str] | None = None,
301+
) -> AsyncGenerator[StreamResponse]:
302+
self.update_tenant(request)
303+
async for event in self._base.subscribe(
304+
request, context=context, extensions=extensions
305+
):
306+
yield event
307+
308+
async def get_extended_agent_card(
309+
self,
310+
*,
311+
context: ClientCallContext | None = None,
312+
extensions: list[str] | None = None,
313+
signature_verifier: Callable[[AgentCard], None] | None = None,
314+
) -> AgentCard:
315+
return await self._base.get_extended_agent_card(
316+
context=context,
317+
extensions=extensions,
318+
signature_verifier=signature_verifier,
319+
)
320+
321+
async def close(self) -> None:
322+
await self._base.close()

0 commit comments

Comments
 (0)