Skip to content

Commit bac2fde

Browse files
committed
Added middleware
1 parent c929cc0 commit bac2fde

6 files changed

Lines changed: 520 additions & 24 deletions

File tree

backend/app/api/routes/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,18 @@ def test_email(email_to: EmailStr) -> Message:
2929
@router.get("/health-check/")
3030
async def health_check() -> bool:
3131
return True
32+
33+
34+
@router.get("/cors-debug", tags=["utils"], response_model=dict)
35+
def cors_debug() -> dict:
36+
"""
37+
Endpoint for debugging CORS settings
38+
"""
39+
from app.core.config import settings
40+
41+
return {
42+
"cors_origins": settings.BACKEND_CORS_ORIGINS,
43+
"all_cors_origins": settings.all_cors_origins,
44+
"frontend_host": settings.FRONTEND_HOST,
45+
"environment": settings.ENVIRONMENT,
46+
}

backend/app/core/middleware.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import logging
2+
from typing import Callable
3+
4+
from fastapi import FastAPI, Request, Response
5+
from starlette.middleware.base import BaseHTTPMiddleware
6+
from starlette.middleware.cors import CORSMiddleware
7+
from starlette.types import ASGIApp
8+
9+
from app.core.config import settings
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class CustomCORSMiddleware(BaseHTTPMiddleware):
15+
def __init__(
16+
self,
17+
app: ASGIApp,
18+
) -> None:
19+
super().__init__(app)
20+
self.allowed_origins = settings.all_cors_origins
21+
logger.info(f"Configured CORS origins: {self.allowed_origins}")
22+
23+
async def dispatch(self, request: Request, call_next: Callable) -> Response:
24+
origin = request.headers.get("origin", "")
25+
logger.info(f"Request from origin: {origin}")
26+
27+
# Handle preflight OPTIONS requests explicitly
28+
if request.method == "OPTIONS":
29+
response = Response(status_code=200)
30+
self._set_cors_headers(response, origin)
31+
return response
32+
33+
# For all other requests, continue with normal processing
34+
response = await call_next(request)
35+
36+
# Add CORS headers to response
37+
self._set_cors_headers(response, origin)
38+
39+
return response
40+
41+
def _set_cors_headers(self, response: Response, origin: str) -> None:
42+
if not origin:
43+
return
44+
45+
# Force add origin if in production environment
46+
if settings.ENVIRONMENT != "local" and origin:
47+
response.headers["Access-Control-Allow-Origin"] = origin
48+
response.headers["Access-Control-Allow-Credentials"] = "true"
49+
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
50+
response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept"
51+
response.headers["Access-Control-Max-Age"] = "3600"
52+
logger.info(f"Added CORS headers for origin: {origin}")
53+
elif origin in self.allowed_origins:
54+
response.headers["Access-Control-Allow-Origin"] = origin
55+
response.headers["Access-Control-Allow-Credentials"] = "true"
56+
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
57+
response.headers["Access-Control-Allow-Headers"] = "Authorization, Content-Type, Accept"
58+
response.headers["Access-Control-Max-Age"] = "3600"
59+
logger.info(f"Added CORS headers for origin: {origin}")
60+
else:
61+
logger.warning(f"Origin not allowed: {origin}")
62+
63+
64+
def setup_cors(app: FastAPI) -> None:
65+
"""Configure CORS for the application."""
66+
67+
# Remove any existing CORS middleware
68+
app.middleware_stack = app.build_middleware_stack()
69+
70+
# Add our custom CORS middleware
71+
app.add_middleware(CustomCORSMiddleware)
72+
73+
logger.info("Custom CORS middleware configured")

backend/app/main.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,57 @@
1-
import sentry_sdk
1+
import logging
2+
from contextlib import asynccontextmanager
3+
24
from fastapi import FastAPI
3-
from fastapi.routing import APIRoute
4-
from starlette.middleware.cors import CORSMiddleware
5+
from fastapi.staticfiles import StaticFiles
56

67
from app.api.main import api_router
78
from app.core.config import settings
9+
from app.core.middleware import setup_cors
10+
11+
logging.basicConfig(
12+
level=logging.INFO,
13+
format="%(asctime)s %(levelname)s: %(message)s",
14+
datefmt="%Y-%m-%d %H:%M:%S",
15+
)
16+
logger = logging.getLogger(__name__)
817

918

10-
def custom_generate_unique_id(route: APIRoute) -> str:
11-
return f"{route.tags[0]}-{route.name}"
19+
@asynccontextmanager
20+
async def lifespan(app: FastAPI):
21+
"""
22+
Function that runs before the application starts,
23+
and again when the application is shutting down.
1224
25+
For now, just a placeholder, but we can use this to setup
26+
database connections, etc.
27+
"""
28+
yield
1329

14-
if settings.SENTRY_DSN and settings.ENVIRONMENT != "local":
15-
sentry_sdk.init(dsn=str(settings.SENTRY_DSN), enable_tracing=True)
1630

1731
app = FastAPI(
18-
title=settings.PROJECT_NAME,
19-
openapi_url=f"{settings.API_V1_STR}/openapi.json",
20-
generate_unique_id_function=custom_generate_unique_id,
32+
title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json", lifespan=lifespan
2133
)
2234

23-
# Set all CORS enabled origins
24-
if settings.all_cors_origins:
25-
app.add_middleware(
26-
CORSMiddleware,
27-
allow_origins=settings.all_cors_origins,
28-
allow_credentials=True,
29-
allow_methods=["*"],
30-
allow_headers=["*"],
31-
)
35+
# Use our custom CORS middleware instead of the default FastAPI one
36+
setup_cors(app)
3237

3338
app.include_router(api_router, prefix=settings.API_V1_STR)
39+
40+
# Mount the static files directory to serve static files
41+
# app.mount("/static", StaticFiles(directory="static"), name="static")
42+
43+
44+
@app.get("/")
45+
def root():
46+
"""
47+
Root endpoint for health checks
48+
"""
49+
return {"message": "Hello! This is the API root. Go to /docs for API documentation."}
50+
51+
52+
@app.get("/health")
53+
def health_check():
54+
"""
55+
Health check endpoint
56+
"""
57+
return {"status": "ok"}

backend/app/models.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class User(UserBase, table=True):
4747
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
4848
hashed_password: str
4949
items: list["Item"] = Relationship(back_populates="owner", cascade_delete=True)
50+
tickets: list["Ticket"] = Relationship(back_populates="user", sa_relationship_kwargs={"cascade": "all, delete-orphan"})
5051

5152

5253
# Properties to return via API, id is always required
@@ -208,8 +209,3 @@ class TicketsPublic(SQLModel):
208209
data: list[TicketPublic]
209210
count: int
210211
page: int
211-
212-
213-
# Update User model to include tickets relationship
214-
User.model_rebuild()
215-
setattr(User, "tickets", Relationship(back_populates="user", sa_relationship_kwargs={"cascade": "all, delete-orphan"}))

0 commit comments

Comments
 (0)