1- from unittest .mock import MagicMock
1+ from unittest .mock import AsyncMock , MagicMock
22
33import pytest
4-
4+ from starlette .applications import Starlette
5+ from starlette .testclient import TestClient
56
67# Attempt to import StarletteBaseUser, fallback to MagicMock if not available
78try :
89 from starlette .authentication import BaseUser as StarletteBaseUser
910except ImportError :
1011 StarletteBaseUser = MagicMock () # type: ignore
1112
13+ from a2a .extensions .common import HTTP_EXTENSION_HEADER
1214from a2a .server .apps .jsonrpc .jsonrpc_app import (
13- JSONRPCApplication , # Still needed for JSONRPCApplication default constructor arg
15+ JSONRPCApplication ,
1416 StarletteUserProxy ,
1517)
18+ from a2a .server .apps .jsonrpc .starlette_app import A2AStarletteApplication
19+ from a2a .server .context import ServerCallContext
1620from a2a .server .request_handlers .request_handler import (
17- RequestHandler , # For mock spec
21+ RequestHandler ,
22+ ) # For mock spec
23+ from a2a .types import (
24+ AgentCapabilities ,
25+ AgentCard ,
26+ Message ,
27+ MessageSendParams ,
28+ Role ,
29+ SendMessageRequest ,
30+ SendMessageResponse ,
31+ SendMessageSuccessResponse ,
32+ TextPart ,
1833)
19- from a2a .types import AgentCard # For mock spec
20-
2134
2235# --- StarletteUserProxy Tests ---
2336
@@ -69,6 +82,7 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror(
6982 mock_agent_card .url = 'http://mockurl.com'
7083 # Ensure 'supportsAuthenticatedExtendedCard' attribute exists
7184 mock_agent_card .supportsAuthenticatedExtendedCard = False
85+ mock_agent_card .capabilities = AgentCapabilities (streaming = True )
7286
7387 # This will fail at definition time if an abstract method is not implemented
7488 with pytest .raises (
@@ -86,5 +100,149 @@ def some_other_method(self):
86100 )
87101
88102
103+ class TestJSONRPCExtensions :
104+ @pytest .fixture
105+ def mock_handler (self ):
106+ handler = AsyncMock (spec = RequestHandler )
107+ handler .on_message_send .return_value = SendMessageResponse (
108+ root = SendMessageSuccessResponse (
109+ id = '1' ,
110+ result = Message (
111+ messageId = 'test' ,
112+ role = Role .agent ,
113+ parts = [TextPart (text = 'response message' )],
114+ ),
115+ )
116+ )
117+ return handler
118+
119+ @pytest .fixture
120+ def test_app (self , mock_handler ):
121+ mock_agent_card = MagicMock (spec = AgentCard )
122+ mock_agent_card .url = 'http://mockurl.com'
123+ mock_agent_card .supportsAuthenticatedExtendedCard = False
124+
125+ return A2AStarletteApplication (
126+ agent_card = mock_agent_card , http_handler = mock_handler
127+ )
128+
129+ @pytest .fixture
130+ def client (self , test_app ):
131+ return TestClient (test_app .build ())
132+
133+ def test_request_with_single_extension (self , client , mock_handler ):
134+ headers = {HTTP_EXTENSION_HEADER : 'foo' }
135+ response = client .post (
136+ '/' ,
137+ headers = headers ,
138+ json = SendMessageRequest (
139+ id = '1' ,
140+ params = MessageSendParams (
141+ message = Message (
142+ messageId = '1' ,
143+ role = Role .user ,
144+ parts = [TextPart (text = 'hi' )],
145+ )
146+ ),
147+ ).model_dump (),
148+ )
149+ response .raise_for_status ()
150+
151+ mock_handler .on_message_send .assert_called_once ()
152+ call_context = mock_handler .on_message_send .call_args [0 ][1 ]
153+ assert isinstance (call_context , ServerCallContext )
154+ assert call_context .requested_extensions == {'foo' }
155+
156+ def test_request_with_comma_separated_extensions (
157+ self , client , mock_handler
158+ ):
159+ headers = {HTTP_EXTENSION_HEADER : 'foo, bar' }
160+ response = client .post (
161+ '/' ,
162+ headers = headers ,
163+ json = SendMessageRequest (
164+ id = '1' ,
165+ params = MessageSendParams (
166+ message = Message (
167+ messageId = '1' ,
168+ role = Role .user ,
169+ parts = [TextPart (text = 'hi' )],
170+ )
171+ ),
172+ ).model_dump (),
173+ )
174+ response .raise_for_status ()
175+
176+ mock_handler .on_message_send .assert_called_once ()
177+ call_context = mock_handler .on_message_send .call_args [0 ][1 ]
178+ assert call_context .requested_extensions == {'foo' , 'bar' }
179+
180+ def test_request_with_multiple_extension_headers (
181+ self , client , mock_handler
182+ ):
183+ headers = [
184+ (HTTP_EXTENSION_HEADER , 'foo' ),
185+ (HTTP_EXTENSION_HEADER , 'bar' ),
186+ ]
187+ response = client .post (
188+ '/' ,
189+ headers = headers ,
190+ json = SendMessageRequest (
191+ id = '1' ,
192+ params = MessageSendParams (
193+ message = Message (
194+ messageId = '1' ,
195+ role = Role .user ,
196+ parts = [TextPart (text = 'hi' )],
197+ )
198+ ),
199+ ).model_dump (),
200+ )
201+ response .raise_for_status ()
202+
203+ mock_handler .on_message_send .assert_called_once ()
204+ call_context = mock_handler .on_message_send .call_args [0 ][1 ]
205+ assert call_context .requested_extensions == {'foo' , 'bar' }
206+
207+ def test_response_with_activated_extensions (self , client , mock_handler ):
208+ def side_effect (request , context : ServerCallContext ):
209+ context .activated_extensions .add ('foo' )
210+ context .activated_extensions .add ('baz' )
211+ return SendMessageResponse (
212+ root = SendMessageSuccessResponse (
213+ id = '1' ,
214+ result = Message (
215+ messageId = 'test' ,
216+ role = Role .agent ,
217+ parts = [TextPart (text = 'response message' )],
218+ ),
219+ )
220+ )
221+
222+ mock_handler .on_message_send .side_effect = side_effect
223+
224+ response = client .post (
225+ '/' ,
226+ json = SendMessageRequest (
227+ id = '1' ,
228+ params = MessageSendParams (
229+ message = Message (
230+ messageId = '1' ,
231+ role = Role .user ,
232+ parts = [TextPart (text = 'hi' )],
233+ )
234+ ),
235+ ).model_dump (),
236+ )
237+ response .raise_for_status ()
238+
239+ assert response .status_code == 200
240+ assert HTTP_EXTENSION_HEADER in response .headers
241+ assert set (response .headers [HTTP_EXTENSION_HEADER ].split (', ' )) == {
242+ 'foo' ,
243+ 'baz' ,
244+ }
245+
246+
89247if __name__ == '__main__' :
90248 pytest .main ([__file__ ])
0 commit comments