Skip to content

Commit fda725a

Browse files
authored
agentsts: allow subject token lookup from session state (#1559)
Runtimes like Vertex do not pass along the HTTP headers to ADK plugins, and instead we need to rely on passing custom metadata such as application specific state via the session state. This change allows retrieving the subject token using a callback function that defaults to extracting it from the Authorization header in the session state. Signed-off-by: Shashank Ram <shashank.ram@solo.io>
1 parent 27d6495 commit fda725a

3 files changed

Lines changed: 133 additions & 13 deletions

File tree

python/packages/agentsts-adk/src/agentsts/adk/_base.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import logging
55
import time
6-
from typing import Any, Awaitable, Callable, Dict, Optional, Union
6+
from typing import Awaitable, Callable, Dict, Optional, Union
77

88
import jwt
99
from google.adk.agents import BaseAgent, LlmAgent
@@ -33,8 +33,27 @@
3333
HEADERS_KEY = "headers"
3434

3535

36+
def _default_get_subject_token(state: dict) -> Optional[str]:
37+
"""Default subject token retrieval from Authorization header in session state."""
38+
headers = state.get(HEADERS_KEY, None)
39+
return _extract_jwt_from_headers(headers)
40+
41+
3642
class ADKSTSIntegration(STSIntegrationBase):
37-
"""Google ADK-specific STS integration."""
43+
"""Google ADK-specific STS integration.
44+
45+
By default, the subject token is read from the ``Authorization`` header
46+
stored in the session state under the ``headers`` key. To retrieve the
47+
subject token from a custom source, pass a ``get_subject_token`` callback::
48+
49+
integration = ADKSTSIntegration(
50+
well_known_uri="https://example.com/.well-known/sts",
51+
get_subject_token=lambda state: state.get("my_custom_token_key"),
52+
)
53+
54+
The callback receives ``session.state`` (a dict) and should return the
55+
subject token string, or ``None`` if not available.
56+
"""
3857

3958
def __init__(
4059
self,
@@ -44,7 +63,7 @@ def __init__(
4463
timeout: int = 5,
4564
verify_ssl: bool = True,
4665
use_issuer_host: bool = False,
47-
additional_config: Optional[Dict[str, Any]] = None,
66+
get_subject_token: Optional[Callable[[dict], Optional[str]]] = None,
4867
):
4968
"""Initialize the ADK STS integration.
5069
@@ -55,7 +74,9 @@ def __init__(
5574
timeout: Request timeout in seconds
5675
verify_ssl: Whether to verify SSL certificates
5776
use_issuer_host: Replace the host:port in token_endpoint with the host:port from well_known_uri
58-
additional_config: Additional configuration
77+
get_subject_token: Optional callback that takes session.state (dict) and returns
78+
the subject token string or None. If not set, defaults to extracting the
79+
JWT from the Authorization header in session.state["headers"].
5980
"""
6081
super().__init__(
6182
well_known_uri=well_known_uri,
@@ -64,7 +85,7 @@ def __init__(
6485
timeout=timeout,
6586
verify_ssl=verify_ssl,
6687
use_issuer_host=use_issuer_host,
67-
additional_config=additional_config,
88+
get_subject_token=get_subject_token or _default_get_subject_token,
6889
)
6990

7091

@@ -143,10 +164,14 @@ async def before_run_callback(
143164
return None
144165

145166
# No valid cached token, need to get/exchange subject token
146-
headers = invocation_context.session.state.get(HEADERS_KEY, None)
147-
subject_token = _extract_jwt_from_headers(headers)
167+
get_subject_token = (
168+
self.sts_integration.get_subject_token
169+
if self.sts_integration and self.sts_integration.get_subject_token
170+
else _default_get_subject_token
171+
)
172+
subject_token = get_subject_token(invocation_context.session.state)
148173
if not subject_token:
149-
logger.debug("No subject token found in headers for token propagation")
174+
logger.debug("subject token not found in session state for token propagation")
150175
return None
151176

152177
if self.sts_integration:

python/packages/agentsts-adk/tests/test_adk_integration.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
class TestADKTokenPropagationPlugin:
1919
"""Unit tests for token propagation plugin covering: none, downstream, and STS exchange."""
2020

21-
def _make_invocation_context(self, session_id: str, headers: dict | None):
21+
def _make_invocation_context(self, session_id: str, headers: dict | None, extra_state: dict | None = None):
2222
session = Mock()
2323
session.id = session_id
2424
session.state = {}
2525
if headers is not None:
2626
session.state[HEADERS_KEY] = headers
27+
if extra_state is not None:
28+
session.state.update(extra_state)
2729
invocation_context = Mock()
2830
invocation_context.session = session
2931
return invocation_context
@@ -48,9 +50,83 @@ async def test_before_run_callback_no_headers(self):
4850
with patch("agentsts.adk._base.logger") as mock_logger:
4951
result = await plugin.before_run_callback(invocation_context=ic)
5052
assert result is None
51-
mock_logger.debug.assert_called_once_with("No subject token found in headers for token propagation")
53+
mock_logger.debug.assert_called_once_with("subject token not found in session state for token propagation")
5254
assert plugin.token_cache == {}
5355

56+
@pytest.mark.asyncio
57+
async def test_subject_token_from_callback(self):
58+
"""Case: get_subject_token callback set -> reads token from session state via callback."""
59+
sts = Mock(spec=ADKSTSIntegration)
60+
sts.get_subject_token = None
61+
sts.get_subject_token = lambda state: state.get("subject-token")
62+
sts.fetch_actor_token = None
63+
sts._actor_token = "actor-token"
64+
sts.exchange_token = AsyncMock(return_value="exchanged-token")
65+
plugin = ADKTokenPropagationPlugin(sts)
66+
ic = self._make_invocation_context(
67+
"sess-key-1",
68+
headers=None,
69+
extra_state={"subject-token": "subject-jwt-from-vertex"},
70+
)
71+
result = await plugin.before_run_callback(invocation_context=ic)
72+
assert result is None
73+
sts.exchange_token.assert_called_once_with(
74+
subject_token="subject-jwt-from-vertex",
75+
subject_token_type=TokenType.JWT,
76+
actor_token="actor-token",
77+
actor_token_type=TokenType.JWT,
78+
)
79+
assert "sess-key-1" in plugin.token_cache
80+
assert plugin.token_cache["sess-key-1"].token == "exchanged-token"
81+
82+
@pytest.mark.asyncio
83+
async def test_subject_token_callback_returns_none(self):
84+
"""Case: get_subject_token callback returns None -> returns None."""
85+
sts = Mock(spec=ADKSTSIntegration)
86+
sts.get_subject_token = None
87+
sts.get_subject_token = lambda state: None
88+
plugin = ADKTokenPropagationPlugin(sts)
89+
ic = self._make_invocation_context("sess-key-2", headers=None)
90+
with patch("agentsts.adk._base.logger") as mock_logger:
91+
result = await plugin.before_run_callback(invocation_context=ic)
92+
assert result is None
93+
mock_logger.debug.assert_called_once_with("subject token not found in session state for token propagation")
94+
assert plugin.token_cache == {}
95+
96+
@pytest.mark.asyncio
97+
async def test_no_sts_no_headers_returns_none(self):
98+
"""Case: no STS integration, no headers -> default callback returns None, no token cached."""
99+
plugin = ADKTokenPropagationPlugin(sts_integration=None)
100+
ic = self._make_invocation_context(
101+
"sess-key-3",
102+
headers=None,
103+
)
104+
# Default callback looks for headers, finds none -> no token cached
105+
result = await plugin.before_run_callback(invocation_context=ic)
106+
assert result is None
107+
assert plugin.token_cache == {}
108+
109+
@pytest.mark.asyncio
110+
async def test_default_callback_extracts_from_headers(self):
111+
"""Case: no get_subject_token callback -> default extracts from headers."""
112+
sts = Mock(spec=ADKSTSIntegration)
113+
sts.get_subject_token = None
114+
sts.get_subject_token = None
115+
sts.fetch_actor_token = None
116+
sts._actor_token = "actor-token"
117+
sts.exchange_token = AsyncMock(return_value="exchanged-via-headers")
118+
plugin = ADKTokenPropagationPlugin(sts)
119+
ic = self._make_invocation_context("sess-key-4", headers={"Authorization": "Bearer header-jwt"})
120+
result = await plugin.before_run_callback(invocation_context=ic)
121+
assert result is None
122+
sts.exchange_token.assert_called_once_with(
123+
subject_token="header-jwt",
124+
subject_token_type=TokenType.JWT,
125+
actor_token="actor-token",
126+
actor_token_type=TokenType.JWT,
127+
)
128+
assert plugin.token_cache["sess-key-4"].token == "exchanged-via-headers"
129+
54130
@pytest.mark.asyncio
55131
async def test_downstream_token_propagation_without_sts(self):
56132
"""Case: headers present, no STS integration -> subject token cached and available via header_provider."""
@@ -83,6 +159,7 @@ async def test_downstream_token_propagation_without_sts(self):
83159
async def test_sts_token_exchange_success(self):
84160
"""Case: STS integration exchanges token -> access token cached and returned by header provider."""
85161
sts = Mock(spec=ADKSTSIntegration)
162+
sts.get_subject_token = None
86163
sts.fetch_actor_token = None
87164
sts._actor_token = "actor-token"
88165
sts.exchange_token = AsyncMock(return_value="access-token-XYZ")
@@ -115,6 +192,7 @@ async def test_sts_token_exchange_success(self):
115192
async def test_sts_token_exchange_failure(self):
116193
"""Case: STS exchange raises -> no cache entry, graceful warning."""
117194
sts = Mock(spec=ADKSTSIntegration)
195+
sts.get_subject_token = None
118196
sts.fetch_actor_token = None
119197
sts._actor_token = "actor-token"
120198
sts.exchange_token = AsyncMock(side_effect=Exception("boom"))
@@ -161,6 +239,7 @@ async def test_dynamic_token_fetch_success_sync(self):
161239
"""Case: sync fetch_actor_token is called successfully and token is exchanged."""
162240
fetch_token_mock = Mock(return_value="dynamic-actor-token")
163241
sts = Mock(spec=ADKSTSIntegration)
242+
sts.get_subject_token = None
164243
sts.fetch_actor_token = fetch_token_mock
165244
sts._actor_token = None
166245
sts.exchange_token = AsyncMock(return_value="access-token-dynamic")
@@ -200,6 +279,7 @@ async def async_fetch_token():
200279
return "dynamic-actor-token-async"
201280

202281
sts = Mock(spec=ADKSTSIntegration)
282+
sts.get_subject_token = None
203283
sts.fetch_actor_token = async_fetch_token
204284
sts._actor_token = None
205285
sts.exchange_token = AsyncMock(return_value="access-token-dynamic-async")
@@ -233,6 +313,7 @@ async def test_dynamic_token_fetch_failure_sync(self):
233313
"""Case: sync fetch_actor_token raises exception -> no token exchange, graceful handling."""
234314
fetch_token_mock = Mock(side_effect=Exception("Token fetch failed"))
235315
sts = Mock(spec=ADKSTSIntegration)
316+
sts.get_subject_token = None
236317
sts.fetch_actor_token = fetch_token_mock
237318
sts._actor_token = None
238319

@@ -262,6 +343,7 @@ async def async_fetch_token_failing():
262343
raise Exception("Async token fetch failed")
263344

264345
sts = Mock(spec=ADKSTSIntegration)
346+
sts.get_subject_token = None
265347
sts.fetch_actor_token = async_fetch_token_failing
266348
sts._actor_token = None
267349

@@ -290,6 +372,7 @@ async def test_dynamic_token_preserved_when_not_expired(self):
290372
jwt_token = "header.payload.signature" # Mock JWT
291373

292374
sts = Mock(spec=ADKSTSIntegration)
375+
sts.get_subject_token = None
293376
sts.fetch_actor_token = Mock(return_value="dynamic-actor")
294377
sts.exchange_token = AsyncMock(return_value=jwt_token)
295378

@@ -321,6 +404,7 @@ async def test_dynamic_token_removed_when_expired(self):
321404
jwt_token = "header.payload.signature" # Mock JWT
322405

323406
sts = Mock(spec=ADKSTSIntegration)
407+
sts.get_subject_token = None
324408
sts.fetch_actor_token = Mock(return_value="dynamic-actor")
325409
sts.exchange_token = AsyncMock(return_value=jwt_token)
326410

@@ -344,6 +428,7 @@ async def test_dynamic_token_removed_when_expired(self):
344428
async def test_valid_token_preserved_in_cache(self):
345429
"""Case: valid token (not expired) is preserved in after_run_callback."""
346430
sts = Mock(spec=ADKSTSIntegration)
431+
sts.get_subject_token = None
347432
sts.fetch_actor_token = None
348433
sts._actor_token = "static-actor"
349434
sts.exchange_token = AsyncMock(return_value="access-token-static")
@@ -378,6 +463,7 @@ def sync_fetch_token():
378463
return "dynamic-actor-token"
379464

380465
sts = Mock(spec=ADKSTSIntegration)
466+
sts.get_subject_token = None
381467
sts.fetch_actor_token = sync_fetch_token
382468
sts._actor_token = None
383469
sts.exchange_token = AsyncMock(return_value="access-token")
@@ -417,6 +503,7 @@ def sync_fetch_token():
417503
return f"dynamic-actor-token-{fetch_count}"
418504

419505
sts = Mock(spec=ADKSTSIntegration)
506+
sts.get_subject_token = None
420507
sts.fetch_actor_token = sync_fetch_token
421508
sts._actor_token = None
422509
sts.exchange_token = AsyncMock(return_value="access-token")
@@ -450,6 +537,7 @@ async def test_actor_token_cache_cleanup_on_expiry(self):
450537
past_expiry = int(time.time()) - 100
451538

452539
sts = Mock(spec=ADKSTSIntegration)
540+
sts.get_subject_token = None
453541
sts.fetch_actor_token = Mock(return_value="dynamic-actor")
454542
sts._actor_token = None
455543
sts.exchange_token = AsyncMock(return_value="access-token")
@@ -482,6 +570,7 @@ async def test_actor_token_cache_preserved_when_not_expired(self):
482570
future_expiry = int(time.time()) + 3600
483571

484572
sts = Mock(spec=ADKSTSIntegration)
573+
sts.get_subject_token = None
485574
sts.fetch_actor_token = Mock(return_value="dynamic-actor")
486575
sts._actor_token = None
487576
sts.exchange_token = AsyncMock(return_value="access-token")
@@ -514,6 +603,7 @@ def sync_fetch_token():
514603
return "actor-token-no-expiry"
515604

516605
sts = Mock(spec=ADKSTSIntegration)
606+
sts.get_subject_token = None
517607
sts.fetch_actor_token = sync_fetch_token
518608
sts._actor_token = None
519609
sts.exchange_token = AsyncMock(return_value="access-token")
@@ -548,6 +638,7 @@ async def test_subject_token_cached_and_reused(self):
548638
future_expiry = int(time.time()) + 3600
549639

550640
sts = Mock(spec=ADKSTSIntegration)
641+
sts.get_subject_token = None
551642
sts.fetch_actor_token = None
552643
sts._actor_token = "static-actor"
553644
sts.exchange_token = AsyncMock(return_value="exchanged-token")
@@ -582,6 +673,7 @@ async def test_subject_token_reexchanged_after_expiry(self):
582673
future_expiry = int(time.time()) + 3600
583674

584675
sts = Mock(spec=ADKSTSIntegration)
676+
sts.get_subject_token = None
585677
sts.fetch_actor_token = None
586678
sts._actor_token = "static-actor"
587679
sts.exchange_token = AsyncMock(side_effect=["token-1", "token-2"])
@@ -613,6 +705,7 @@ async def test_subject_token_reexchanged_after_expiry(self):
613705
async def test_subject_token_cache_no_expiry(self):
614706
"""Case: subject token without expiry is cached indefinitely and reused."""
615707
sts = Mock(spec=ADKSTSIntegration)
708+
sts.get_subject_token = None
616709
sts.fetch_actor_token = None
617710
sts._actor_token = "static-actor"
618711
sts.exchange_token = AsyncMock(return_value="exchanged-token-no-exp")
@@ -650,6 +743,7 @@ async def async_fetch_token():
650743
return "dynamic-actor-token-async"
651744

652745
sts = Mock(spec=ADKSTSIntegration)
746+
sts.get_subject_token = None
653747
sts.fetch_actor_token = async_fetch_token
654748
sts._actor_token = None
655749
sts.exchange_token = AsyncMock(return_value="access-token")

python/packages/agentsts-core/src/agentsts/core/_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(
2121
timeout: int = 30,
2222
verify_ssl: bool = True,
2323
use_issuer_host: bool = False,
24-
additional_config: Optional[Dict[str, Any]] = None,
24+
get_subject_token: Optional[Callable[[dict], Optional[str]]] = None,
2525
):
2626
"""Initialize the STS integration.
2727
@@ -32,13 +32,14 @@ def __init__(
3232
timeout: Request timeout in seconds
3333
verify_ssl: Whether to verify SSL certificates
3434
use_issuer_host: Replace the host:port in token_endpoint with the host:port from well_known_uri
35-
additional_config: Additional configuration for the specific framework
35+
get_subject_token: Optional callback that takes session state (dict) and returns
36+
the subject token string or None
3637
"""
3738
self.well_known_uri = well_known_uri
3839
self.timeout = timeout
3940
self.verify_ssl = verify_ssl
40-
self.additional_config = additional_config or {}
4141
self.fetch_actor_token = fetch_actor_token
42+
self.get_subject_token = get_subject_token
4243

4344
# Initialize STS client
4445
config = STSConfig(

0 commit comments

Comments
 (0)