Skip to content

Commit f910736

Browse files
committed
Add tests, rest updates for fastapi, and jsonrpc updates for starlette
1 parent 4b586b4 commit f910736

6 files changed

Lines changed: 108 additions & 4 deletions

File tree

src/a2a/server/apps/jsonrpc/starlette_app.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
import logging
22

3-
from collections.abc import Awaitable, Callable
3+
from collections.abc import Awaitable, Callable, Sequence
44
from typing import TYPE_CHECKING, Any
55

66

77
if TYPE_CHECKING:
88
from starlette.applications import Starlette
9+
from starlette.middleware import Middleware
910
from starlette.routing import Route
1011

1112
_package_starlette_installed = True
1213

1314
else:
1415
try:
1516
from starlette.applications import Starlette
17+
from starlette.middleware import Middleware
1618
from starlette.routing import Route
1719

1820
_package_starlette_installed = True
1921
except ImportError:
2022
Starlette = Any
23+
Middleware = Any
2124
Route = Any
2225

2326
_package_starlette_installed = False
@@ -102,23 +105,30 @@ def routes(
102105
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
103106
rpc_url: str = DEFAULT_RPC_URL,
104107
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
108+
middleware: Sequence[Middleware] | None = None,
105109
) -> list[Route]:
106110
"""Returns the Starlette Routes for handling A2A requests.
107111
108112
Args:
109113
agent_card_url: The URL path for the agent card endpoint.
110114
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
111115
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
116+
middleware: Optional sequence of Starlette Middleware (e.g.
117+
`[Middleware(AuthenticationMiddleware)]`) applied to the RPC
118+
endpoint and the authenticated extended agent card endpoint.
112119
113120
Returns:
114121
A list of Starlette Route objects.
115122
"""
123+
route_mw = list(middleware) if middleware else None
124+
116125
app_routes = [
117126
Route(
118127
rpc_url,
119128
self._handle_requests,
120129
methods=['POST'],
121130
name='a2a_handler',
131+
middleware=route_mw,
122132
),
123133
Route(
124134
agent_card_url,
@@ -148,6 +158,7 @@ def routes(
148158
self._handle_get_authenticated_extended_agent_card,
149159
methods=['GET'],
150160
name='authenticated_extended_agent_card',
161+
middleware=route_mw,
151162
)
152163
)
153164
return app_routes
@@ -158,6 +169,7 @@ def add_routes_to_app(
158169
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
159170
rpc_url: str = DEFAULT_RPC_URL,
160171
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
172+
middleware: Sequence[Middleware] | None = None,
161173
) -> None:
162174
"""Adds the routes to the Starlette application.
163175
@@ -166,11 +178,13 @@ def add_routes_to_app(
166178
agent_card_url: The URL path for the agent card endpoint.
167179
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
168180
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
181+
middleware: Optional sequence of Starlette Middleware.
169182
"""
170183
routes = self.routes(
171184
agent_card_url=agent_card_url,
172185
rpc_url=rpc_url,
173186
extended_agent_card_url=extended_agent_card_url,
187+
middleware=middleware,
174188
)
175189
app.routes.extend(routes)
176190

@@ -179,6 +193,7 @@ def build(
179193
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
180194
rpc_url: str = DEFAULT_RPC_URL,
181195
extended_agent_card_url: str = EXTENDED_AGENT_CARD_PATH,
196+
middleware: Sequence[Middleware] | None = None,
182197
**kwargs: Any,
183198
) -> Starlette:
184199
"""Builds and returns the Starlette application instance.
@@ -187,6 +202,7 @@ def build(
187202
agent_card_url: The URL path for the agent card endpoint.
188203
rpc_url: The URL path for the A2A JSON-RPC endpoint (POST requests).
189204
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
205+
middleware: Optional sequence of Starlette Middleware applied to authenticated routes.
190206
**kwargs: Additional keyword arguments to pass to the Starlette constructor.
191207
192208
Returns:
@@ -195,7 +211,7 @@ def build(
195211
app = Starlette(**kwargs)
196212

197213
self.add_routes_to_app(
198-
app, agent_card_url, rpc_url, extended_agent_card_url
214+
app, agent_card_url, rpc_url, extended_agent_card_url, middleware
199215
)
200216

201217
return app

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
import logging
22

3-
from collections.abc import Awaitable, Callable
3+
from collections.abc import Awaitable, Callable, Sequence
44
from typing import TYPE_CHECKING, Any
55

66

77
if TYPE_CHECKING:
88
from fastapi import APIRouter, FastAPI, Request, Response
9+
from fastapi.params import Depends
910
from fastapi.responses import JSONResponse
1011

1112
_package_fastapi_installed = True
1213
else:
1314
try:
1415
from fastapi import APIRouter, FastAPI, Request, Response
16+
from fastapi.params import Depends
1517
from fastapi.responses import JSONResponse
1618

1719
_package_fastapi_installed = True
@@ -92,6 +94,8 @@ def build(
9294
self,
9395
agent_card_url: str = AGENT_CARD_WELL_KNOWN_PATH,
9496
rpc_url: str = '',
97+
extended_agent_card_url: str = '',
98+
dependencies: Sequence[Depends] | None = None,
9599
**kwargs: Any,
96100
) -> FastAPI:
97101
"""Builds and returns the FastAPI application instance.
@@ -100,16 +104,29 @@ def build(
100104
agent_card_url: The URL for the agent card endpoint.
101105
rpc_url: The URL for the A2A JSON-RPC endpoint.
102106
extended_agent_card_url: The URL for the authenticated extended agent card endpoint.
107+
dependencies: Optional sequence of FastAPI dependencies (e.g.
108+
`[Security(get_current_active_user, scopes=["a2a"])]`)
109+
applied to the RPC endpoint and the authenticated extended
110+
agent card endpoint. The public agent card endpoint is left
111+
unprotected.
103112
**kwargs: Additional keyword arguments to pass to the FastAPI constructor.
104113
105114
Returns:
106115
A configured FastAPI application instance.
107116
"""
108117
app = FastAPI(**kwargs)
118+
119+
route_deps: dict[str, Any] = {}
120+
if dependencies:
121+
route_deps['dependencies'] = list(dependencies)
122+
109123
router = APIRouter()
110124
for route, callback in self._adapter.routes().items():
111125
router.add_api_route(
112-
f'{rpc_url}{route[0]}', callback, methods=[route[1]]
126+
f'{rpc_url}{route[0]}',
127+
callback,
128+
methods=[route[1]],
129+
**route_deps,
113130
)
114131

115132
@router.get(f'{rpc_url}{agent_card_url}')

tests/server/apps/jsonrpc/test_fastapi_app.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,26 @@ def test_create_a2a_fastapi_app_with_present_deps_succeeds(
6161
' A2AFastAPIApplication instance should not raise ImportError'
6262
)
6363

64+
def test_build_a2a_fastapi_app_with_dependencies_succeeds(
65+
self, mock_app_params: dict
66+
):
67+
from fastapi import Depends
68+
69+
def mock_dependency():
70+
return 'mock'
71+
72+
app = A2AFastAPIApplication(**mock_app_params)
73+
fastapi_app = app.build(dependencies=[Depends(mock_dependency)])
74+
75+
from fastapi.routing import APIRoute
76+
77+
# Check that routes have the dependency
78+
for route in fastapi_app.routes:
79+
if getattr(route, 'path', '') in ['/v1/message:send', '/v1/card']:
80+
assert isinstance(route, APIRoute)
81+
assert len(route.dependencies) == 1
82+
assert route.dependencies[0].dependency == mock_dependency
83+
6484
def test_create_a2a_fastapi_app_with_missing_deps_raises_importerror(
6585
self,
6686
mock_app_params: dict,

tests/server/apps/jsonrpc/test_serialization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def test_starlette_agent_card_with_api_key_scheme_alias(
7070

7171
try:
7272
parsed_card = AgentCard.model_validate(response_data)
73+
assert parsed_card.security_schemes is not None
7374
parsed_scheme_wrapper = parsed_card.security_schemes['api_key_auth']
7475
assert isinstance(parsed_scheme_wrapper.root, APIKeySecurityScheme)
7576
assert parsed_scheme_wrapper.root.in_ == In.header

tests/server/apps/jsonrpc/test_starlette_app.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,30 @@ def test_create_a2a_starlette_app_with_present_deps_succeeds(
6666
' A2AStarletteApplication instance should not raise ImportError'
6767
)
6868

69+
def test_build_a2a_starlette_app_with_middleware_succeeds(
70+
self, mock_app_params: dict
71+
):
72+
from starlette.middleware import Middleware
73+
from starlette.middleware.base import BaseHTTPMiddleware
74+
75+
class MockMiddleware(BaseHTTPMiddleware):
76+
async def dispatch(self, request, call_next):
77+
return await call_next(request)
78+
79+
app = A2AStarletteApplication(**mock_app_params)
80+
starlette_app = app.build(middleware=[Middleware(MockMiddleware)])
81+
82+
from starlette.routing import Route
83+
84+
# Check that routes have the middleware
85+
for route in starlette_app.routes:
86+
if getattr(route, 'path', '') in [
87+
'/',
88+
'/agent/authenticatedExtendedCard',
89+
]:
90+
assert isinstance(route, Route)
91+
assert isinstance(route.app, MockMiddleware)
92+
6993
def test_create_a2a_starlette_app_with_missing_deps_raises_importerror(
7094
self,
7195
mock_app_params: dict,

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,32 @@ async def test_create_a2a_rest_fastapi_app_with_present_deps_succeeds(
160160
)
161161

162162

163+
@pytest.mark.anyio
164+
async def test_build_a2a_rest_fastapi_app_with_dependencies_succeeds(
165+
agent_card: AgentCard, request_handler: RequestHandler
166+
):
167+
from fastapi import Depends
168+
169+
def mock_dependency():
170+
return 'mock'
171+
172+
app = A2ARESTFastAPIApplication(agent_card, request_handler)
173+
fastapi_app = app.build(
174+
agent_card_url='/well-known/agent.json',
175+
rpc_url='',
176+
dependencies=[Depends(mock_dependency)],
177+
)
178+
179+
from fastapi.routing import APIRoute
180+
181+
# Check that routes have the dependency
182+
for route in fastapi_app.routes:
183+
if getattr(route, 'path', '') in ['/v1/message:send']:
184+
assert isinstance(route, APIRoute)
185+
assert len(route.dependencies) == 1
186+
assert route.dependencies[0].dependency == mock_dependency
187+
188+
163189
@pytest.mark.anyio
164190
async def test_create_a2a_rest_fastapi_app_with_missing_deps_raises_importerror(
165191
agent_card: AgentCard,

0 commit comments

Comments
 (0)