Skip to content

Commit 86665b5

Browse files
committed
refactor: extract StarletteUserProxy and add comprehensive route utility tests
1 parent 60e9fed commit 86665b5

2 files changed

Lines changed: 183 additions & 13 deletions

File tree

src/a2a/server/routes/common.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from collections.abc import Callable
2+
from typing import TYPE_CHECKING, Any
23

3-
from starlette.requests import Request
4+
if TYPE_CHECKING:
5+
from starlette.requests import Request
6+
else:
7+
try:
8+
from starlette.requests import Request
9+
except ImportError:
10+
Request = Any
411

512
from a2a.auth.user import UnauthenticatedUser, User
613
from a2a.extensions.common import (
@@ -13,20 +20,27 @@
1320
UserBuilder = Callable[[Request], User]
1421

1522

16-
def default_user_builder(request: Request) -> User:
17-
"""Default strategy for creating an A2AUser from a Starlette Request."""
18-
if 'user' in request.scope:
23+
class StarletteUser(User):
24+
"""Adapts a Starlette BaseUser to the A2A User interface."""
25+
26+
def __init__(self, user: Any):
27+
self._user = user
1928

20-
class BaseUser(User):
21-
@property
22-
def is_authenticated(self) -> bool:
23-
return request.user.is_authenticated
29+
@property
30+
def is_authenticated(self) -> bool:
31+
"""Returns whether the current user is authenticated."""
32+
return self._user.is_authenticated
2433

25-
@property
26-
def user_name(self) -> str:
27-
return request.user.display_name
34+
@property
35+
def user_name(self) -> str:
36+
"""Returns the user name of the current user."""
37+
return self._user.display_name
2838

29-
return BaseUser()
39+
40+
def default_user_builder(request: Request) -> User:
41+
"""Default strategy for creating an A2AUser from a Starlette Request."""
42+
if 'user' in request.scope:
43+
return StarletteUser(request.user)
3044
return UnauthenticatedUser()
3145

3246

@@ -37,7 +51,7 @@ def build_server_call_context(
3751
3852
Args:
3953
request: The incoming Starlette Request object.
40-
user_builder: Optional custom user builder.
54+
user_builder: A callable that creates a User from the request.
4155
4256
Returns:
4357
A ServerCallContext instance populated with user and state.

tests/server/routes/test_common.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from starlette.datastructures import Headers
5+
6+
try:
7+
from starlette.authentication import BaseUser as StarletteBaseUser
8+
except ImportError:
9+
StarletteBaseUser = MagicMock() # type: ignore
10+
11+
from a2a.auth.user import UnauthenticatedUser
12+
from a2a.extensions.common import HTTP_EXTENSION_HEADER
13+
from a2a.server.context import ServerCallContext
14+
from a2a.server.routes.common import (
15+
StarletteUser,
16+
build_server_call_context,
17+
default_user_builder,
18+
)
19+
20+
21+
# --- StarletteUser Tests ---
22+
23+
24+
class TestStarletteUser:
25+
def test_is_authenticated_true(self):
26+
starlette_user = MagicMock(spec=StarletteBaseUser)
27+
starlette_user.is_authenticated = True
28+
proxy = StarletteUser(starlette_user)
29+
assert proxy.is_authenticated is True
30+
31+
def test_is_authenticated_false(self):
32+
starlette_user = MagicMock(spec=StarletteBaseUser)
33+
starlette_user.is_authenticated = False
34+
proxy = StarletteUser(starlette_user)
35+
assert proxy.is_authenticated is False
36+
37+
def test_user_name(self):
38+
starlette_user = MagicMock(spec=StarletteBaseUser)
39+
starlette_user.display_name = 'Test User'
40+
proxy = StarletteUser(starlette_user)
41+
assert proxy.user_name == 'Test User'
42+
43+
def test_user_name_raises_attribute_error(self):
44+
starlette_user = MagicMock(spec=StarletteBaseUser)
45+
del starlette_user.display_name
46+
proxy = StarletteUser(starlette_user)
47+
with pytest.raises(AttributeError, match='display_name'):
48+
_ = proxy.user_name
49+
50+
51+
# --- default_user_builder Tests ---
52+
53+
54+
def _make_mock_request(scope=None, headers=None):
55+
request = MagicMock()
56+
request.scope = scope or {}
57+
request.headers = Headers(headers or {})
58+
return request
59+
60+
61+
class TestDefaultUserBuilder:
62+
def test_returns_unauthenticated_user_when_no_user_in_scope(self):
63+
request = _make_mock_request(scope={})
64+
user = default_user_builder(request)
65+
assert isinstance(user, UnauthenticatedUser)
66+
assert user.is_authenticated is False
67+
assert user.user_name == ''
68+
69+
def test_returns_proxy_when_user_in_scope(self):
70+
starlette_user = MagicMock()
71+
starlette_user.is_authenticated = True
72+
starlette_user.display_name = 'Alice'
73+
request = _make_mock_request(scope={'user': starlette_user})
74+
request.user = starlette_user
75+
76+
user = default_user_builder(request)
77+
assert isinstance(user, StarletteUser)
78+
assert user.is_authenticated is True
79+
assert user.user_name == 'Alice'
80+
81+
def test_returns_unauthenticated_proxy_when_user_not_authenticated(self):
82+
starlette_user = MagicMock()
83+
starlette_user.is_authenticated = False
84+
starlette_user.display_name = ''
85+
request = _make_mock_request(scope={'user': starlette_user})
86+
request.user = starlette_user
87+
88+
user = default_user_builder(request)
89+
assert isinstance(user, StarletteUser)
90+
assert user.is_authenticated is False
91+
92+
93+
# --- build_server_call_context Tests ---
94+
95+
96+
class TestBuildServerCallContext:
97+
def test_basic_context_with_default_user_builder(self):
98+
request = _make_mock_request(
99+
scope={}, headers={'content-type': 'application/json'}
100+
)
101+
ctx = build_server_call_context(request, default_user_builder)
102+
103+
assert isinstance(ctx, ServerCallContext)
104+
assert isinstance(ctx.user, UnauthenticatedUser)
105+
assert 'headers' in ctx.state
106+
assert ctx.state['headers']['content-type'] == 'application/json'
107+
assert 'auth' not in ctx.state
108+
109+
def test_auth_populated_when_in_scope(self):
110+
auth_credentials = MagicMock()
111+
request = _make_mock_request(scope={'auth': auth_credentials})
112+
request.auth = auth_credentials
113+
114+
ctx = build_server_call_context(request, default_user_builder)
115+
assert ctx.state['auth'] is auth_credentials
116+
117+
def test_auth_not_populated_when_not_in_scope(self):
118+
request = _make_mock_request(scope={})
119+
ctx = build_server_call_context(request, default_user_builder)
120+
assert 'auth' not in ctx.state
121+
122+
def test_headers_captured_in_state(self):
123+
request = _make_mock_request(
124+
headers={'x-custom': 'value', 'authorization': 'Bearer tok'}
125+
)
126+
ctx = build_server_call_context(request, default_user_builder)
127+
assert ctx.state['headers']['x-custom'] == 'value'
128+
assert ctx.state['headers']['authorization'] == 'Bearer tok'
129+
130+
def test_requested_extensions_single(self):
131+
request = _make_mock_request(headers={HTTP_EXTENSION_HEADER: 'foo'})
132+
ctx = build_server_call_context(request, default_user_builder)
133+
assert ctx.requested_extensions == {'foo'}
134+
135+
def test_requested_extensions_comma_separated(self):
136+
request = _make_mock_request(
137+
headers={HTTP_EXTENSION_HEADER: 'foo, bar'}
138+
)
139+
ctx = build_server_call_context(request, default_user_builder)
140+
assert ctx.requested_extensions == {'foo', 'bar'}
141+
142+
def test_no_extensions(self):
143+
request = _make_mock_request()
144+
ctx = build_server_call_context(request, default_user_builder)
145+
assert ctx.requested_extensions == set()
146+
147+
def test_custom_user_builder(self):
148+
custom_user = MagicMock(spec=UnauthenticatedUser)
149+
custom_user.is_authenticated = True
150+
151+
def my_builder(req):
152+
return custom_user
153+
154+
request = _make_mock_request()
155+
ctx = build_server_call_context(request, my_builder)
156+
assert ctx.user is custom_user

0 commit comments

Comments
 (0)