1818class TestADKTokenPropagationPlugin :
1919 """Unit tests for token propagation plugin covering: none, downstream, and STS exchange."""
2020
21- def _make_invocation_context (self , session_id : str , headers : dict | None ):
21+ def _make_invocation_context (self , session_id : str , headers : dict | None , extra_state : dict | None = None ):
2222 session = Mock ()
2323 session .id = session_id
2424 session .state = {}
2525 if headers is not None :
2626 session .state [HEADERS_KEY ] = headers
27+ if extra_state is not None :
28+ session .state .update (extra_state )
2729 invocation_context = Mock ()
2830 invocation_context .session = session
2931 return invocation_context
@@ -48,9 +50,83 @@ async def test_before_run_callback_no_headers(self):
4850 with patch ("agentsts.adk._base.logger" ) as mock_logger :
4951 result = await plugin .before_run_callback (invocation_context = ic )
5052 assert result is None
51- mock_logger .debug .assert_called_once_with ("No subject token found in headers for token propagation" )
53+ mock_logger .debug .assert_called_once_with ("subject token not found in session state for token propagation" )
5254 assert plugin .token_cache == {}
5355
56+ @pytest .mark .asyncio
57+ async def test_subject_token_from_callback (self ):
58+ """Case: get_subject_token callback set -> reads token from session state via callback."""
59+ sts = Mock (spec = ADKSTSIntegration )
60+ sts .get_subject_token = None
61+ sts .get_subject_token = lambda state : state .get ("subject-token" )
62+ sts .fetch_actor_token = None
63+ sts ._actor_token = "actor-token"
64+ sts .exchange_token = AsyncMock (return_value = "exchanged-token" )
65+ plugin = ADKTokenPropagationPlugin (sts )
66+ ic = self ._make_invocation_context (
67+ "sess-key-1" ,
68+ headers = None ,
69+ extra_state = {"subject-token" : "subject-jwt-from-vertex" },
70+ )
71+ result = await plugin .before_run_callback (invocation_context = ic )
72+ assert result is None
73+ sts .exchange_token .assert_called_once_with (
74+ subject_token = "subject-jwt-from-vertex" ,
75+ subject_token_type = TokenType .JWT ,
76+ actor_token = "actor-token" ,
77+ actor_token_type = TokenType .JWT ,
78+ )
79+ assert "sess-key-1" in plugin .token_cache
80+ assert plugin .token_cache ["sess-key-1" ].token == "exchanged-token"
81+
82+ @pytest .mark .asyncio
83+ async def test_subject_token_callback_returns_none (self ):
84+ """Case: get_subject_token callback returns None -> returns None."""
85+ sts = Mock (spec = ADKSTSIntegration )
86+ sts .get_subject_token = None
87+ sts .get_subject_token = lambda state : None
88+ plugin = ADKTokenPropagationPlugin (sts )
89+ ic = self ._make_invocation_context ("sess-key-2" , headers = None )
90+ with patch ("agentsts.adk._base.logger" ) as mock_logger :
91+ result = await plugin .before_run_callback (invocation_context = ic )
92+ assert result is None
93+ mock_logger .debug .assert_called_once_with ("subject token not found in session state for token propagation" )
94+ assert plugin .token_cache == {}
95+
96+ @pytest .mark .asyncio
97+ async def test_no_sts_no_headers_returns_none (self ):
98+ """Case: no STS integration, no headers -> default callback returns None, no token cached."""
99+ plugin = ADKTokenPropagationPlugin (sts_integration = None )
100+ ic = self ._make_invocation_context (
101+ "sess-key-3" ,
102+ headers = None ,
103+ )
104+ # Default callback looks for headers, finds none -> no token cached
105+ result = await plugin .before_run_callback (invocation_context = ic )
106+ assert result is None
107+ assert plugin .token_cache == {}
108+
109+ @pytest .mark .asyncio
110+ async def test_default_callback_extracts_from_headers (self ):
111+ """Case: no get_subject_token callback -> default extracts from headers."""
112+ sts = Mock (spec = ADKSTSIntegration )
113+ sts .get_subject_token = None
114+ sts .get_subject_token = None
115+ sts .fetch_actor_token = None
116+ sts ._actor_token = "actor-token"
117+ sts .exchange_token = AsyncMock (return_value = "exchanged-via-headers" )
118+ plugin = ADKTokenPropagationPlugin (sts )
119+ ic = self ._make_invocation_context ("sess-key-4" , headers = {"Authorization" : "Bearer header-jwt" })
120+ result = await plugin .before_run_callback (invocation_context = ic )
121+ assert result is None
122+ sts .exchange_token .assert_called_once_with (
123+ subject_token = "header-jwt" ,
124+ subject_token_type = TokenType .JWT ,
125+ actor_token = "actor-token" ,
126+ actor_token_type = TokenType .JWT ,
127+ )
128+ assert plugin .token_cache ["sess-key-4" ].token == "exchanged-via-headers"
129+
54130 @pytest .mark .asyncio
55131 async def test_downstream_token_propagation_without_sts (self ):
56132 """Case: headers present, no STS integration -> subject token cached and available via header_provider."""
@@ -83,6 +159,7 @@ async def test_downstream_token_propagation_without_sts(self):
83159 async def test_sts_token_exchange_success (self ):
84160 """Case: STS integration exchanges token -> access token cached and returned by header provider."""
85161 sts = Mock (spec = ADKSTSIntegration )
162+ sts .get_subject_token = None
86163 sts .fetch_actor_token = None
87164 sts ._actor_token = "actor-token"
88165 sts .exchange_token = AsyncMock (return_value = "access-token-XYZ" )
@@ -115,6 +192,7 @@ async def test_sts_token_exchange_success(self):
115192 async def test_sts_token_exchange_failure (self ):
116193 """Case: STS exchange raises -> no cache entry, graceful warning."""
117194 sts = Mock (spec = ADKSTSIntegration )
195+ sts .get_subject_token = None
118196 sts .fetch_actor_token = None
119197 sts ._actor_token = "actor-token"
120198 sts .exchange_token = AsyncMock (side_effect = Exception ("boom" ))
@@ -161,6 +239,7 @@ async def test_dynamic_token_fetch_success_sync(self):
161239 """Case: sync fetch_actor_token is called successfully and token is exchanged."""
162240 fetch_token_mock = Mock (return_value = "dynamic-actor-token" )
163241 sts = Mock (spec = ADKSTSIntegration )
242+ sts .get_subject_token = None
164243 sts .fetch_actor_token = fetch_token_mock
165244 sts ._actor_token = None
166245 sts .exchange_token = AsyncMock (return_value = "access-token-dynamic" )
@@ -200,6 +279,7 @@ async def async_fetch_token():
200279 return "dynamic-actor-token-async"
201280
202281 sts = Mock (spec = ADKSTSIntegration )
282+ sts .get_subject_token = None
203283 sts .fetch_actor_token = async_fetch_token
204284 sts ._actor_token = None
205285 sts .exchange_token = AsyncMock (return_value = "access-token-dynamic-async" )
@@ -233,6 +313,7 @@ async def test_dynamic_token_fetch_failure_sync(self):
233313 """Case: sync fetch_actor_token raises exception -> no token exchange, graceful handling."""
234314 fetch_token_mock = Mock (side_effect = Exception ("Token fetch failed" ))
235315 sts = Mock (spec = ADKSTSIntegration )
316+ sts .get_subject_token = None
236317 sts .fetch_actor_token = fetch_token_mock
237318 sts ._actor_token = None
238319
@@ -262,6 +343,7 @@ async def async_fetch_token_failing():
262343 raise Exception ("Async token fetch failed" )
263344
264345 sts = Mock (spec = ADKSTSIntegration )
346+ sts .get_subject_token = None
265347 sts .fetch_actor_token = async_fetch_token_failing
266348 sts ._actor_token = None
267349
@@ -290,6 +372,7 @@ async def test_dynamic_token_preserved_when_not_expired(self):
290372 jwt_token = "header.payload.signature" # Mock JWT
291373
292374 sts = Mock (spec = ADKSTSIntegration )
375+ sts .get_subject_token = None
293376 sts .fetch_actor_token = Mock (return_value = "dynamic-actor" )
294377 sts .exchange_token = AsyncMock (return_value = jwt_token )
295378
@@ -321,6 +404,7 @@ async def test_dynamic_token_removed_when_expired(self):
321404 jwt_token = "header.payload.signature" # Mock JWT
322405
323406 sts = Mock (spec = ADKSTSIntegration )
407+ sts .get_subject_token = None
324408 sts .fetch_actor_token = Mock (return_value = "dynamic-actor" )
325409 sts .exchange_token = AsyncMock (return_value = jwt_token )
326410
@@ -344,6 +428,7 @@ async def test_dynamic_token_removed_when_expired(self):
344428 async def test_valid_token_preserved_in_cache (self ):
345429 """Case: valid token (not expired) is preserved in after_run_callback."""
346430 sts = Mock (spec = ADKSTSIntegration )
431+ sts .get_subject_token = None
347432 sts .fetch_actor_token = None
348433 sts ._actor_token = "static-actor"
349434 sts .exchange_token = AsyncMock (return_value = "access-token-static" )
@@ -378,6 +463,7 @@ def sync_fetch_token():
378463 return "dynamic-actor-token"
379464
380465 sts = Mock (spec = ADKSTSIntegration )
466+ sts .get_subject_token = None
381467 sts .fetch_actor_token = sync_fetch_token
382468 sts ._actor_token = None
383469 sts .exchange_token = AsyncMock (return_value = "access-token" )
@@ -417,6 +503,7 @@ def sync_fetch_token():
417503 return f"dynamic-actor-token-{ fetch_count } "
418504
419505 sts = Mock (spec = ADKSTSIntegration )
506+ sts .get_subject_token = None
420507 sts .fetch_actor_token = sync_fetch_token
421508 sts ._actor_token = None
422509 sts .exchange_token = AsyncMock (return_value = "access-token" )
@@ -450,6 +537,7 @@ async def test_actor_token_cache_cleanup_on_expiry(self):
450537 past_expiry = int (time .time ()) - 100
451538
452539 sts = Mock (spec = ADKSTSIntegration )
540+ sts .get_subject_token = None
453541 sts .fetch_actor_token = Mock (return_value = "dynamic-actor" )
454542 sts ._actor_token = None
455543 sts .exchange_token = AsyncMock (return_value = "access-token" )
@@ -482,6 +570,7 @@ async def test_actor_token_cache_preserved_when_not_expired(self):
482570 future_expiry = int (time .time ()) + 3600
483571
484572 sts = Mock (spec = ADKSTSIntegration )
573+ sts .get_subject_token = None
485574 sts .fetch_actor_token = Mock (return_value = "dynamic-actor" )
486575 sts ._actor_token = None
487576 sts .exchange_token = AsyncMock (return_value = "access-token" )
@@ -514,6 +603,7 @@ def sync_fetch_token():
514603 return "actor-token-no-expiry"
515604
516605 sts = Mock (spec = ADKSTSIntegration )
606+ sts .get_subject_token = None
517607 sts .fetch_actor_token = sync_fetch_token
518608 sts ._actor_token = None
519609 sts .exchange_token = AsyncMock (return_value = "access-token" )
@@ -548,6 +638,7 @@ async def test_subject_token_cached_and_reused(self):
548638 future_expiry = int (time .time ()) + 3600
549639
550640 sts = Mock (spec = ADKSTSIntegration )
641+ sts .get_subject_token = None
551642 sts .fetch_actor_token = None
552643 sts ._actor_token = "static-actor"
553644 sts .exchange_token = AsyncMock (return_value = "exchanged-token" )
@@ -582,6 +673,7 @@ async def test_subject_token_reexchanged_after_expiry(self):
582673 future_expiry = int (time .time ()) + 3600
583674
584675 sts = Mock (spec = ADKSTSIntegration )
676+ sts .get_subject_token = None
585677 sts .fetch_actor_token = None
586678 sts ._actor_token = "static-actor"
587679 sts .exchange_token = AsyncMock (side_effect = ["token-1" , "token-2" ])
@@ -613,6 +705,7 @@ async def test_subject_token_reexchanged_after_expiry(self):
613705 async def test_subject_token_cache_no_expiry (self ):
614706 """Case: subject token without expiry is cached indefinitely and reused."""
615707 sts = Mock (spec = ADKSTSIntegration )
708+ sts .get_subject_token = None
616709 sts .fetch_actor_token = None
617710 sts ._actor_token = "static-actor"
618711 sts .exchange_token = AsyncMock (return_value = "exchanged-token-no-exp" )
@@ -650,6 +743,7 @@ async def async_fetch_token():
650743 return "dynamic-actor-token-async"
651744
652745 sts = Mock (spec = ADKSTSIntegration )
746+ sts .get_subject_token = None
653747 sts .fetch_actor_token = async_fetch_token
654748 sts ._actor_token = None
655749 sts .exchange_token = AsyncMock (return_value = "access-token" )
0 commit comments