|
1 | 1 | import httpx |
2 | 2 |
|
3 | 3 | from fastapi import FastAPI |
| 4 | +from starlette.applications import Starlette |
| 5 | +from starlette.requests import Request |
4 | 6 |
|
| 7 | +from a2a.auth.user import UnauthenticatedUser, User |
5 | 8 | from a2a.server.agent_execution import AgentExecutor, RequestContext |
6 | 9 | from a2a.server.context import ServerCallContext |
7 | 10 | from a2a.server.events import EventQueue |
8 | | -from starlette.applications import Starlette |
9 | | -from a2a.server.routes.rest_routes import create_rest_routes |
10 | | -from a2a.server.routes import create_agent_card_routes |
11 | 11 | from a2a.server.request_handlers import DefaultRequestHandler |
| 12 | +from a2a.server.routes import create_agent_card_routes |
| 13 | +from a2a.server.routes.common import DefaultServerCallContextBuilder |
| 14 | +from a2a.server.routes.rest_routes import create_rest_routes |
12 | 15 | from a2a.server.tasks import ( |
13 | 16 | BasePushNotificationSender, |
14 | 17 | InMemoryPushNotificationConfigStore, |
|
30 | 33 | ) |
31 | 34 |
|
32 | 35 |
|
| 36 | +_TEST_USER_HEADER = 'x-test-user' |
| 37 | + |
| 38 | + |
33 | 39 | def test_agent_card(url: str) -> AgentCard: |
34 | 40 | """Returns an agent card for the test agent.""" |
35 | 41 | return AgentCard( |
@@ -151,11 +157,85 @@ def create_agent_app( |
151 | 157 | push_sender=BasePushNotificationSender( |
152 | 158 | httpx_client=notification_client, |
153 | 159 | config_store=push_config_store, |
154 | | - context=ServerCallContext(), |
155 | 160 | ), |
156 | 161 | ) |
157 | 162 | rest_routes = create_rest_routes(request_handler=handler) |
158 | 163 | agent_card_routes = create_agent_card_routes( |
159 | 164 | agent_card=card, card_url='/.well-known/agent-card.json' |
160 | 165 | ) |
161 | 166 | return Starlette(routes=[*rest_routes, *agent_card_routes]) |
| 167 | + |
| 168 | + |
| 169 | +class _NamedTestUser(User): |
| 170 | + """Authenticated test user identified by ``user_name``.""" |
| 171 | + |
| 172 | + def __init__(self, user_name: str) -> None: |
| 173 | + self._user_name = user_name |
| 174 | + |
| 175 | + @property |
| 176 | + def is_authenticated(self) -> bool: |
| 177 | + return True |
| 178 | + |
| 179 | + @property |
| 180 | + def user_name(self) -> str: |
| 181 | + return self._user_name |
| 182 | + |
| 183 | + |
| 184 | +class _HeaderUserContextBuilder(DefaultServerCallContextBuilder): |
| 185 | + """Builds a ServerCallContext whose user is read from a request header.""" |
| 186 | + |
| 187 | + def build_user(self, request: Request) -> User: |
| 188 | + user_name = request.headers.get(_TEST_USER_HEADER) |
| 189 | + if user_name: |
| 190 | + return _NamedTestUser(user_name) |
| 191 | + return UnauthenticatedUser() |
| 192 | + |
| 193 | + |
| 194 | +def create_multi_user_agent_app( |
| 195 | + url: str, notification_client: httpx.AsyncClient |
| 196 | +) -> Starlette: |
| 197 | + """Creates a multi-user variant of the test agent app. |
| 198 | +
|
| 199 | + Differences from create_agent_app: |
| 200 | +
|
| 201 | + - Identity is read from the x-test-user header on each request |
| 202 | + via _HeaderUserContextBuilder. Multiple authenticated |
| 203 | + users (e.g. alice, bob) can therefore call the same |
| 204 | + server. |
| 205 | + - The InMemoryTaskStore uses a constant owner resolver, so |
| 206 | + every authenticated user has access to every task. |
| 207 | + - The InMemoryPushNotificationConfigStore keeps the default |
| 208 | + per-user owner resolver, so each registrar's configs live in their |
| 209 | + own owner partition; this exercises cross-owner aggregation in |
| 210 | + get_info_for_dispatch. |
| 211 | + """ |
| 212 | + # Shared task visibility: any authenticated user can see any task. |
| 213 | + task_store = InMemoryTaskStore(owner_resolver=lambda _ctx: 'shared') |
| 214 | + |
| 215 | + # Per-user push-config partitioning (the default). |
| 216 | + push_config_store = InMemoryPushNotificationConfigStore() |
| 217 | + |
| 218 | + card = test_agent_card(url) |
| 219 | + extended_card = test_agent_card(url) |
| 220 | + extended_card.name = 'Test Agent Extended' |
| 221 | + |
| 222 | + handler = DefaultRequestHandler( |
| 223 | + agent_executor=TestAgentExecutor(), |
| 224 | + task_store=task_store, |
| 225 | + agent_card=card, |
| 226 | + extended_agent_card=extended_card, |
| 227 | + push_config_store=push_config_store, |
| 228 | + push_sender=BasePushNotificationSender( |
| 229 | + httpx_client=notification_client, |
| 230 | + config_store=push_config_store, |
| 231 | + ), |
| 232 | + ) |
| 233 | + |
| 234 | + rest_routes = create_rest_routes( |
| 235 | + request_handler=handler, |
| 236 | + context_builder=_HeaderUserContextBuilder(), |
| 237 | + ) |
| 238 | + agent_card_routes = create_agent_card_routes( |
| 239 | + agent_card=card, card_url='/.well-known/agent-card.json' |
| 240 | + ) |
| 241 | + return Starlette(routes=[*rest_routes, *agent_card_routes]) |
0 commit comments