|
24 | 24 | import threading |
25 | 25 | import time |
26 | 26 | from dataclasses import dataclass, field |
27 | | -from typing import Dict, Optional |
| 27 | +from typing import Annotated, Dict, List, Optional |
28 | 28 |
|
29 | 29 | try: |
30 | | - from fastapi import FastAPI, Header, HTTPException, Request |
| 30 | + from fastapi import Depends, FastAPI, Header, HTTPException, Request |
31 | 31 | from fastapi.middleware.cors import CORSMiddleware |
32 | 32 | from fastapi.staticfiles import StaticFiles |
33 | 33 | from pydantic import BaseModel |
@@ -113,100 +113,137 @@ class _AnswerIn(BaseModel): |
113 | 113 | sdp: str |
114 | 114 |
|
115 | 115 |
|
116 | | -def create_app(shared_secret: Optional[str] = None, |
117 | | - ttl_s: float = _DEFAULT_TTL_S, |
118 | | - serve_web_viewer: bool = True, |
119 | | - cors_origins: Optional[list] = None) -> FastAPI: |
120 | | - """Build the FastAPI app. Importable for embedding in larger services.""" |
121 | | - app = FastAPI(title="AutoControl Signaling", version="1.0.0") |
122 | | - store = _SessionStore(ttl_s=ttl_s) |
| 116 | +_AUTH_RESPONSES = {401: {"description": "bad shared secret"}} |
| 117 | +_VALIDATION_RESPONSES = { |
| 118 | + 400: {"description": "invalid host_id or sdp"}, |
| 119 | + **_AUTH_RESPONSES, |
| 120 | +} |
| 121 | +_NOT_FOUND_RESPONSES = { |
| 122 | + 404: {"description": "session or message not found"}, |
| 123 | + **_AUTH_RESPONSES, |
| 124 | +} |
| 125 | + |
| 126 | + |
| 127 | +def _build_secret_dependency(shared_secret: Optional[str]): |
| 128 | + """Return a FastAPI dependency that enforces ``X-Signaling-Secret``.""" |
| 129 | + def _check( |
| 130 | + x_signaling_secret: Annotated[ |
| 131 | + Optional[str], Header(alias="X-Signaling-Secret"), |
| 132 | + ] = None, |
| 133 | + ) -> None: |
| 134 | + if shared_secret and x_signaling_secret != shared_secret: |
| 135 | + raise HTTPException(status_code=401, detail="bad shared secret") |
| 136 | + return _check |
| 137 | + |
| 138 | + |
| 139 | +def _validate_host_id(host_id: str) -> None: |
| 140 | + if not host_id or len(host_id) > 128 or not host_id.isalnum(): |
| 141 | + raise HTTPException(status_code=400, detail="invalid host_id") |
| 142 | + |
| 143 | + |
| 144 | +def _validate_sdp(sdp: str) -> None: |
| 145 | + if not sdp or len(sdp.encode("utf-8")) > _MAX_SDP_BYTES: |
| 146 | + raise HTTPException(status_code=400, detail="invalid sdp size") |
| 147 | + |
| 148 | + |
| 149 | +def _configure_cors(app: FastAPI, cors_origins: Optional[List[str]]) -> None: |
| 150 | + # ``["*"]`` is the documented default — the signaling server is |
| 151 | + # meant to be reached from any browser tab running the viewer SPA; |
| 152 | + # access control runs at the X-Signaling-Secret layer, not Origin. |
| 153 | + # Operators tighten this via the repeatable --cors-origin CLI flag. |
123 | 154 | app.add_middleware( |
124 | 155 | CORSMiddleware, |
125 | | - allow_origins=cors_origins or ["*"], |
| 156 | + allow_origins=cors_origins or ["*"], # nosemgrep: python.fastapi.security.wildcard-cors.wildcard-cors |
126 | 157 | allow_methods=["GET", "POST", "DELETE", "OPTIONS"], |
127 | 158 | allow_headers=["Content-Type", "X-Signaling-Secret"], |
128 | 159 | ) |
| 160 | + |
| 161 | + |
| 162 | +def _maybe_mount_viewer(app: FastAPI, serve_web_viewer: bool) -> None: |
129 | 163 | if serve_web_viewer and _WEB_VIEWER_DIR.exists(): |
130 | 164 | app.mount( |
131 | 165 | "/viewer", |
132 | 166 | StaticFiles(directory=str(_WEB_VIEWER_DIR), html=True), |
133 | 167 | name="viewer", |
134 | 168 | ) |
135 | 169 |
|
136 | | - def _check_secret(secret_header: Optional[str]) -> None: |
137 | | - if shared_secret and secret_header != shared_secret: |
138 | | - raise HTTPException(status_code=401, detail="bad shared secret") |
139 | | - |
140 | | - def _validate_host_id(host_id: str) -> None: |
141 | | - if not host_id or len(host_id) > 128 or not host_id.isalnum(): |
142 | | - raise HTTPException(status_code=400, detail="invalid host_id") |
143 | 170 |
|
144 | | - def _validate_sdp(sdp: str) -> None: |
145 | | - if not sdp or len(sdp.encode("utf-8")) > _MAX_SDP_BYTES: |
146 | | - raise HTTPException(status_code=400, detail="invalid sdp size") |
| 171 | +def _register_routes(app: FastAPI, store: "_SessionStore", |
| 172 | + secret_dep) -> None: |
| 173 | + # Apply the auth dependency at the route layer so each handler's |
| 174 | + # signature stays free of plumbing parameters. The dependency |
| 175 | + # itself uses the recommended ``Annotated[Optional[str], Header(...)]`` |
| 176 | + # form for its ``X-Signaling-Secret`` parameter — see |
| 177 | + # ``_build_secret_dependency`` above. |
| 178 | + auth_only = [Depends(secret_dep)] |
147 | 179 |
|
148 | 180 | @app.get("/health") |
149 | 181 | def _health() -> dict: |
150 | 182 | return {"status": "ok"} |
151 | 183 |
|
152 | | - @app.post("/sessions/{host_id}/offer") |
153 | | - def _post_offer(host_id: str, body: _OfferIn, |
154 | | - x_signaling_secret: Optional[str] = Header(default=None) |
155 | | - ) -> dict: |
156 | | - _check_secret(x_signaling_secret) |
| 184 | + @app.post("/sessions/{host_id}/offer", |
| 185 | + responses=_VALIDATION_RESPONSES, dependencies=auth_only) |
| 186 | + def _post_offer(host_id: str, body: _OfferIn) -> dict: |
157 | 187 | _validate_host_id(host_id) |
158 | 188 | _validate_sdp(body.sdp) |
159 | 189 | store.upsert_offer(host_id, body.sdp) |
160 | 190 | return {"ok": True} |
161 | 191 |
|
162 | | - @app.get("/sessions/{host_id}/offer") |
163 | | - def _get_offer(host_id: str, |
164 | | - x_signaling_secret: Optional[str] = Header(default=None) |
165 | | - ) -> dict: |
166 | | - _check_secret(x_signaling_secret) |
| 192 | + @app.get("/sessions/{host_id}/offer", |
| 193 | + responses=_NOT_FOUND_RESPONSES, dependencies=auth_only) |
| 194 | + def _get_offer(host_id: str) -> dict: |
167 | 195 | _validate_host_id(host_id) |
168 | 196 | sdp = store.fetch_offer(host_id) |
169 | 197 | if sdp is None: |
170 | 198 | raise HTTPException(status_code=404, detail="no offer pending") |
171 | 199 | return {"sdp": sdp} |
172 | 200 |
|
173 | | - @app.post("/sessions/{host_id}/answer") |
174 | | - def _post_answer(host_id: str, body: _AnswerIn, |
175 | | - x_signaling_secret: Optional[str] = Header(default=None) |
176 | | - ) -> dict: |
177 | | - _check_secret(x_signaling_secret) |
| 201 | + @app.post("/sessions/{host_id}/answer", |
| 202 | + responses={**_VALIDATION_RESPONSES, **_NOT_FOUND_RESPONSES}, |
| 203 | + dependencies=auth_only) |
| 204 | + def _post_answer(host_id: str, body: _AnswerIn) -> dict: |
178 | 205 | _validate_host_id(host_id) |
179 | 206 | _validate_sdp(body.sdp) |
180 | 207 | if not store.upsert_answer(host_id, body.sdp): |
181 | 208 | raise HTTPException(status_code=404, detail="no offer to match") |
182 | 209 | return {"ok": True} |
183 | 210 |
|
184 | | - @app.get("/sessions/{host_id}/answer") |
185 | | - def _get_answer(host_id: str, |
186 | | - x_signaling_secret: Optional[str] = Header(default=None) |
187 | | - ) -> dict: |
188 | | - _check_secret(x_signaling_secret) |
| 211 | + @app.get("/sessions/{host_id}/answer", |
| 212 | + responses=_NOT_FOUND_RESPONSES, dependencies=auth_only) |
| 213 | + def _get_answer(host_id: str) -> dict: |
189 | 214 | _validate_host_id(host_id) |
190 | 215 | sdp = store.fetch_answer(host_id) |
191 | 216 | if sdp is None: |
192 | 217 | raise HTTPException(status_code=404, detail="no answer yet") |
193 | 218 | return {"sdp": sdp} |
194 | 219 |
|
195 | | - @app.delete("/sessions/{host_id}") |
196 | | - def _delete(host_id: str, |
197 | | - x_signaling_secret: Optional[str] = Header(default=None) |
198 | | - ) -> dict: |
199 | | - _check_secret(x_signaling_secret) |
| 220 | + @app.delete("/sessions/{host_id}", |
| 221 | + responses=_AUTH_RESPONSES, dependencies=auth_only) |
| 222 | + def _delete(host_id: str) -> dict: |
200 | 223 | _validate_host_id(host_id) |
201 | 224 | return {"deleted": store.delete(host_id)} |
202 | 225 |
|
| 226 | + |
| 227 | +def _register_request_logging(app: FastAPI) -> None: |
203 | 228 | @app.middleware("http") |
204 | 229 | async def _log_request(request: Request, call_next): |
205 | 230 | response = await call_next(request) |
206 | 231 | _LOG.info("%s %s -> %d", request.method, request.url.path, |
207 | 232 | response.status_code) |
208 | 233 | return response |
209 | 234 |
|
| 235 | + |
| 236 | +def create_app(shared_secret: Optional[str] = None, |
| 237 | + ttl_s: float = _DEFAULT_TTL_S, |
| 238 | + serve_web_viewer: bool = True, |
| 239 | + cors_origins: Optional[list] = None) -> FastAPI: |
| 240 | + """Build the FastAPI app. Importable for embedding in larger services.""" |
| 241 | + app = FastAPI(title="AutoControl Signaling", version="1.0.0") |
| 242 | + store = _SessionStore(ttl_s=ttl_s) |
| 243 | + _configure_cors(app, cors_origins) |
| 244 | + _maybe_mount_viewer(app, serve_web_viewer) |
| 245 | + _register_routes(app, store, _build_secret_dependency(shared_secret)) |
| 246 | + _register_request_logging(app) |
210 | 247 | return app |
211 | 248 |
|
212 | 249 |
|
|
0 commit comments