|
1 | 1 | import logging |
2 | 2 |
|
3 | | -from collections.abc import AsyncIterator |
4 | | -from contextlib import asynccontextmanager |
5 | 3 | from typing import TYPE_CHECKING, Any |
6 | 4 |
|
7 | 5 |
|
|
36 | 34 | logger = logging.getLogger(__name__) |
37 | 35 |
|
38 | 36 |
|
| 37 | +class A2AFastAPI(FastAPI): |
| 38 | + """A FastAPI application that adds A2A-specific OpenAPI components.""" |
| 39 | + |
| 40 | + _a2a_components_added: bool = False |
| 41 | + |
| 42 | + def openapi(self) -> dict[str, Any]: |
| 43 | + """Generates the OpenAPI schema for the application.""" |
| 44 | + openapi_schema = super().openapi() |
| 45 | + if not self._a2a_components_added: |
| 46 | + a2a_request_schema = A2ARequest.model_json_schema( |
| 47 | + ref_template='#/components/schemas/{model}' |
| 48 | + ) |
| 49 | + defs = a2a_request_schema.pop('$defs', {}) |
| 50 | + component_schemas = openapi_schema.setdefault( |
| 51 | + 'components', {} |
| 52 | + ).setdefault('schemas', {}) |
| 53 | + component_schemas.update(defs) |
| 54 | + component_schemas['A2ARequest'] = a2a_request_schema |
| 55 | + self._a2a_components_added = True |
| 56 | + return openapi_schema |
| 57 | + |
| 58 | + |
39 | 59 | class A2AFastAPIApplication(JSONRPCApplication): |
40 | 60 | """A FastAPI application implementing the A2A protocol server endpoints. |
41 | 61 |
|
@@ -139,23 +159,7 @@ def build( |
139 | 159 | Returns: |
140 | 160 | A configured FastAPI application instance. |
141 | 161 | """ |
142 | | - |
143 | | - @asynccontextmanager |
144 | | - async def lifespan(app: FastAPI) -> AsyncIterator[None]: |
145 | | - a2a_request_schema = A2ARequest.model_json_schema( |
146 | | - ref_template='#/components/schemas/{model}' |
147 | | - ) |
148 | | - defs = a2a_request_schema.pop('$defs', {}) |
149 | | - openapi_schema = app.openapi() |
150 | | - component_schemas = openapi_schema.setdefault( |
151 | | - 'components', {} |
152 | | - ).setdefault('schemas', {}) |
153 | | - component_schemas.update(defs) |
154 | | - component_schemas['A2ARequest'] = a2a_request_schema |
155 | | - |
156 | | - yield |
157 | | - |
158 | | - app = FastAPI(lifespan=lifespan, **kwargs) |
| 162 | + app = A2AFastAPI(**kwargs) |
159 | 163 |
|
160 | 164 | self.add_routes_to_app( |
161 | 165 | app, agent_card_url, rpc_url, extended_agent_card_url |
|
0 commit comments