|
1 | 1 | import logging |
2 | | -from typing import Callable, List |
| 2 | +from typing import Callable |
3 | 3 |
|
4 | | -from fastapi import FastAPI |
| 4 | +from fastapi import FastAPI, Request, Response |
| 5 | +from starlette.middleware.base import BaseHTTPMiddleware |
5 | 6 | from starlette.middleware.cors import CORSMiddleware |
| 7 | +from starlette.types import ASGIApp |
6 | 8 |
|
7 | 9 | from app.core.config import settings |
8 | 10 |
|
9 | 11 | logger = logging.getLogger(__name__) |
10 | 12 |
|
| 13 | + |
| 14 | +class CustomCORSMiddleware(BaseHTTPMiddleware): |
| 15 | + def __init__( |
| 16 | + self, |
| 17 | + app: ASGIApp, |
| 18 | + ) -> None: |
| 19 | + super().__init__(app) |
| 20 | + self.allowed_origins = settings.all_cors_origins |
| 21 | + logger.info(f"Configured CORS origins: {self.allowed_origins}") |
| 22 | + |
| 23 | + async def dispatch(self, request: Request, call_next: Callable) -> Response: |
| 24 | + origin = request.headers.get("origin", "") |
| 25 | + logger.info(f"Request from origin: {origin}") |
| 26 | + |
| 27 | + # Handle preflight OPTIONS requests explicitly |
| 28 | + if request.method == "OPTIONS": |
| 29 | + response = Response(status_code=200) |
| 30 | + self._set_cors_headers(response, origin) |
| 31 | + return response |
| 32 | + |
| 33 | + # For all other requests, continue with normal processing |
| 34 | + response = await call_next(request) |
| 35 | + |
| 36 | + # Add CORS headers to response |
| 37 | + self._set_cors_headers(response, origin) |
| 38 | + |
| 39 | + return response |
| 40 | + |
| 41 | + def _set_cors_headers(self, response: Response, origin: str) -> None: |
| 42 | + if not origin: |
| 43 | + return |
| 44 | + |
| 45 | + # Force add origin if in production environment |
| 46 | + if settings.ENVIRONMENT != "local" and origin: |
| 47 | + response.headers["Access-Control-Allow-Origin"] = origin |
| 48 | + response.headers["Access-Control-Allow-Credentials"] = "true" |
| 49 | + response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" |
| 50 | + response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept" |
| 51 | + response.headers["Access-Control-Max-Age"] = "3600" |
| 52 | + logger.info(f"Added CORS headers for origin: {origin}") |
| 53 | + elif origin in self.allowed_origins: |
| 54 | + response.headers["Access-Control-Allow-Origin"] = origin |
| 55 | + response.headers["Access-Control-Allow-Credentials"] = "true" |
| 56 | + response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" |
| 57 | + response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept" |
| 58 | + response.headers["Access-Control-Max-Age"] = "3600" |
| 59 | + logger.info(f"Added CORS headers for origin: {origin}") |
| 60 | + else: |
| 61 | + logger.warning(f"Origin not allowed: {origin}") |
| 62 | + |
| 63 | + |
11 | 64 | def setup_cors(app: FastAPI) -> None: |
12 | | - """ |
13 | | - Configure CORS for the application - must be called before app startup. |
14 | | - """ |
15 | | - origins = settings.all_cors_origins |
| 65 | + """Configure CORS for the application.""" |
16 | 66 |
|
17 | | - logger.info(f"Setting up CORS middleware with origins: {origins}") |
| 67 | + # Remove any existing CORS middleware |
| 68 | + app.middleware_stack = app.build_middleware_stack() |
18 | 69 |
|
19 | | - # Add CORS middleware |
20 | | - app.add_middleware( |
21 | | - CORSMiddleware, |
22 | | - allow_origins=origins or ["*"], |
23 | | - allow_credentials=True, |
24 | | - allow_methods=["*"], |
25 | | - allow_headers=["*"], |
26 | | - expose_headers=["Content-Disposition"], |
27 | | - max_age=3600, |
28 | | - ) |
| 70 | + # Add our custom CORS middleware |
| 71 | + app.add_middleware(CustomCORSMiddleware) |
29 | 72 |
|
30 | | - logger.info("CORS middleware configured") |
| 73 | + logger.info("Custom CORS middleware configured") |
0 commit comments