Skip to content

Commit 39d5351

Browse files
committed
refactor: add helper _build_call_context and reduce code duplication in routes
1 parent fa2863b commit 39d5351

1 file changed

Lines changed: 17 additions & 57 deletions

File tree

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 17 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ async def _handle_request(
110110
method: Callable[[Request, ServerCallContext], Awaitable[Any]],
111111
request: Request,
112112
) -> Response:
113-
call_context = self._context_builder.build(request)
114-
if 'tenant' in request.path_params:
115-
call_context.tenant = request.path_params['tenant']
113+
call_context = self._build_call_context(request)
116114

117115
response = await method(request, call_context)
118116
return JSONResponse(content=response)
@@ -133,9 +131,7 @@ async def _handle_streaming_request(
133131
message=f'Failed to pre-consume request body: {e}'
134132
) from e
135133

136-
call_context = self._context_builder.build(request)
137-
if 'tenant' in request.path_params:
138-
call_context.tenant = request.path_params['tenant']
134+
call_context = self._build_call_context(request)
139135

140136
async def event_generator(
141137
stream: AsyncIterable[Any],
@@ -210,7 +206,7 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
210206
A dictionary where each key is a tuple of (path, http_method) and
211207
the value is the callable handler for that route.
212208
"""
213-
routes: dict[tuple[str, str], Callable[[Request], Any]] = {
209+
base_routes: dict[tuple[str, str], Callable[[Request], Any]] = {
214210
('/message:send', 'POST'): functools.partial(
215211
self._handle_request, self.handler.on_message_send
216212
),
@@ -255,59 +251,23 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
255251
('/tasks', 'GET'): functools.partial(
256252
self._handle_request, self.handler.list_tasks
257253
),
258-
# Tenant prefixed routes
259-
('/{tenant}/message:send', 'POST'): functools.partial(
260-
self._handle_request,
261-
self.handler.on_message_send,
262-
),
263-
('/{tenant}/message:stream', 'POST'): functools.partial(
264-
self._handle_streaming_request,
265-
self.handler.on_message_send_stream,
266-
),
267-
('/{tenant}/tasks/{id}:cancel', 'POST'): functools.partial(
268-
self._handle_request, self.handler.on_cancel_task
269-
),
270-
('/{tenant}/tasks/{id}:subscribe', 'GET'): functools.partial(
271-
self._handle_streaming_request,
272-
self.handler.on_subscribe_to_task,
273-
),
274-
('/{tenant}/tasks/{id}', 'GET'): functools.partial(
275-
self._handle_request, self.handler.on_get_task
276-
),
277-
(
278-
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
279-
'GET',
280-
): functools.partial(
281-
self._handle_request, self.handler.get_push_notification
282-
),
283-
(
284-
'/{tenant}/tasks/{id}/pushNotificationConfigs/{push_id}',
285-
'DELETE',
286-
): functools.partial(
287-
self._handle_request, self.handler.delete_push_notification
288-
),
289-
(
290-
'/{tenant}/tasks/{id}/pushNotificationConfigs',
291-
'POST',
292-
): functools.partial(
293-
self._handle_request, self.handler.set_push_notification
294-
),
295-
(
296-
'/{tenant}/tasks/{id}/pushNotificationConfigs',
297-
'GET',
298-
): functools.partial(
299-
self._handle_request, self.handler.list_push_notifications
300-
),
301-
('/{tenant}/tasks', 'GET'): functools.partial(
302-
self._handle_request, self.handler.list_tasks
303-
),
304254
}
255+
305256
if self.agent_card.capabilities.extended_agent_card:
306-
routes[('/extendedAgentCard', 'GET')] = functools.partial(
307-
self._handle_request, self.handle_authenticated_agent_card
308-
)
309-
routes[('/{tenant}/extendedAgentCard', 'GET')] = functools.partial(
257+
base_routes[('/extendedAgentCard', 'GET')] = functools.partial(
310258
self._handle_request, self.handle_authenticated_agent_card
311259
)
312260

261+
routes: dict[tuple[str, str], Callable[[Request], Any]] = {
262+
(p, method): handler
263+
for (path, method), handler in base_routes.items()
264+
for p in (path, f'/{{tenant}}{path}')
265+
}
266+
313267
return routes
268+
269+
def _build_call_context(self, request: Request) -> ServerCallContext:
270+
call_context = self._context_builder.build(request)
271+
if 'tenant' in request.path_params:
272+
call_context.tenant = request.path_params['tenant']
273+
return call_context

0 commit comments

Comments
 (0)