|
1 | 1 | import logging |
2 | | -from typing import Callable |
| 2 | +from typing import Callable, List |
3 | 3 |
|
4 | | -from fastapi import FastAPI, Request, Response |
5 | | -from starlette.middleware.base import BaseHTTPMiddleware |
| 4 | +from fastapi import FastAPI |
6 | 5 | from starlette.middleware.cors import CORSMiddleware |
7 | | -from starlette.types import ASGIApp |
8 | 6 |
|
9 | 7 | from app.core.config import settings |
10 | 8 |
|
11 | 9 | logger = logging.getLogger(__name__) |
12 | 10 |
|
13 | | - |
14 | | -class CustomCORSMiddleware(BaseHTTPMiddleware): |
15 | | - def __init__( |
16 | | - self, |
17 | | - app: ASGIApp, |
18 | | - ) -> None: |
19 | | - super().__init__(app) |
20 | | - self.allowed_origins = settings.all_cors_origins |
21 | | - logger.info(f"Configured CORS origins: {self.allowed_origins}") |
22 | | - |
23 | | - async def dispatch(self, request: Request, call_next: Callable) -> Response: |
24 | | - origin = request.headers.get("origin", "") |
25 | | - logger.info(f"Request from origin: {origin}") |
26 | | - |
27 | | - # Handle preflight OPTIONS requests explicitly |
28 | | - if request.method == "OPTIONS": |
29 | | - response = Response(status_code=200) |
30 | | - self._set_cors_headers(response, origin) |
31 | | - return response |
32 | | - |
33 | | - # For all other requests, continue with normal processing |
34 | | - response = await call_next(request) |
35 | | - |
36 | | - # Add CORS headers to response |
37 | | - self._set_cors_headers(response, origin) |
38 | | - |
39 | | - return response |
40 | | - |
41 | | - def _set_cors_headers(self, response: Response, origin: str) -> None: |
42 | | - if not origin: |
43 | | - return |
44 | | - |
45 | | - # Force add origin if in production environment |
46 | | - if settings.ENVIRONMENT != "local" and origin: |
47 | | - response.headers["Access-Control-Allow-Origin"] = origin |
48 | | - response.headers["Access-Control-Allow-Credentials"] = "true" |
49 | | - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" |
50 | | - response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept" |
51 | | - response.headers["Access-Control-Max-Age"] = "3600" |
52 | | - logger.info(f"Added CORS headers for origin: {origin}") |
53 | | - elif origin in self.allowed_origins: |
54 | | - response.headers["Access-Control-Allow-Origin"] = origin |
55 | | - response.headers["Access-Control-Allow-Credentials"] = "true" |
56 | | - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" |
57 | | - response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept" |
58 | | - response.headers["Access-Control-Max-Age"] = "3600" |
59 | | - logger.info(f"Added CORS headers for origin: {origin}") |
60 | | - else: |
61 | | - logger.warning(f"Origin not allowed: {origin}") |
62 | | - |
63 | | - |
64 | 11 | def setup_cors(app: FastAPI) -> None: |
65 | | - """Configure CORS for the application.""" |
| 12 | + """ |
| 13 | + Configure CORS for the application - must be called before app startup. |
| 14 | + """ |
| 15 | + origins = settings.all_cors_origins |
66 | 16 |
|
67 | | - # Remove any existing CORS middleware |
68 | | - app.middleware_stack = app.build_middleware_stack() |
| 17 | + logger.info(f"Setting up CORS middleware with origins: {origins}") |
69 | 18 |
|
70 | | - # Add our custom CORS middleware |
71 | | - app.add_middleware(CustomCORSMiddleware) |
| 19 | + # Add CORS middleware |
| 20 | + app.add_middleware( |
| 21 | + CORSMiddleware, |
| 22 | + allow_origins=origins or ["*"], |
| 23 | + allow_credentials=True, |
| 24 | + allow_methods=["*"], |
| 25 | + allow_headers=["*"], |
| 26 | + expose_headers=["Content-Disposition"], |
| 27 | + max_age=3600, |
| 28 | + ) |
72 | 29 |
|
73 | | - logger.info("Custom CORS middleware configured") |
| 30 | + logger.info("CORS middleware configured") |
0 commit comments