diff --git a/backend/app/config/settings.py b/backend/app/config/settings.py index 172b1d2da..3ce6a7997 100644 --- a/backend/app/config/settings.py +++ b/backend/app/config/settings.py @@ -3,6 +3,9 @@ import logging import os import sys +import secrets +import tempfile + from platformdirs import user_data_dir logger = logging.getLogger(__name__) @@ -41,6 +44,22 @@ THUMBNAIL_IMAGES_PATH = os.path.join(user_data_dir("PictoPy"), "thumbnails") IMAGES_PATH = "./images" +# Generate session token for authenticated shutdown. +SHUTDOWN_TOKEN: str = secrets.token_hex(32) +SHUTDOWN_TOKEN_FILE: str = os.path.join(tempfile.gettempdir(), "pictopy_shutdown.token") + +# Write token with owner-only permissions (0o600). +try: + _fd = os.open(SHUTDOWN_TOKEN_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(_fd, "w") as _f: + _f.write(SHUTDOWN_TOKEN) + # Enforce permissions. + os.chmod(SHUTDOWN_TOKEN_FILE, 0o600) +except OSError as e: + logger.fatal(f"Failed to write shutdown token to {SHUTDOWN_TOKEN_FILE}: {e}") + logger.fatal("Cannot start backend securely. Exiting.") + sys.exit(1) + def _get_env_float( name: str, diff --git a/backend/app/database/albums.py b/backend/app/database/albums.py index 2db8df4e5..0dec75d8c 100644 --- a/backend/app/database/albums.py +++ b/backend/app/database/albums.py @@ -9,8 +9,7 @@ def db_create_albums_table() -> None: try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS albums ( album_id TEXT PRIMARY KEY, album_name TEXT UNIQUE, @@ -18,8 +17,7 @@ def db_create_albums_table() -> None: is_hidden BOOLEAN DEFAULT 0, password_hash TEXT ) - """ - ) + """) conn.commit() finally: if conn is not None: @@ -31,8 +29,7 @@ def db_create_album_images_table() -> None: try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS album_images ( album_id TEXT, image_id TEXT, @@ -40,8 +37,7 @@ def db_create_album_images_table() -> None: FOREIGN KEY (album_id) REFERENCES albums(album_id) ON DELETE CASCADE, FOREIGN KEY (image_id) REFERENCES images(id) ON DELETE CASCADE ) - """ - ) + """) conn.commit() finally: if conn is not None: diff --git a/backend/app/database/face_clusters.py b/backend/app/database/face_clusters.py index ceac7f556..dd21804ae 100644 --- a/backend/app/database/face_clusters.py +++ b/backend/app/database/face_clusters.py @@ -24,15 +24,13 @@ def db_create_clusters_table() -> None: try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS face_clusters ( cluster_id TEXT PRIMARY KEY, cluster_name TEXT, face_image_base64 TEXT ) - """ - ) + """) conn.commit() finally: if conn is not None: @@ -245,8 +243,7 @@ def db_get_all_clusters_with_face_counts() -> ( cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT fc.cluster_id, fc.cluster_name, @@ -256,8 +253,7 @@ def db_get_all_clusters_with_face_counts() -> ( LEFT JOIN faces f ON fc.cluster_id = f.cluster_id GROUP BY fc.cluster_id, fc.cluster_name, fc.face_image_base64 ORDER BY fc.cluster_id - """ - ) + """) rows = cursor.fetchall() diff --git a/backend/app/database/faces.py b/backend/app/database/faces.py index 0e43f7117..07144acfa 100644 --- a/backend/app/database/faces.py +++ b/backend/app/database/faces.py @@ -32,8 +32,7 @@ def db_create_faces_table() -> None: conn = sqlite3.connect(DATABASE_PATH) conn.execute("PRAGMA foreign_keys = ON") cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS faces ( face_id INTEGER PRIMARY KEY AUTOINCREMENT, image_id TEXT, @@ -44,8 +43,7 @@ def db_create_faces_table() -> None: FOREIGN KEY (image_id) REFERENCES images(id) ON DELETE CASCADE, FOREIGN KEY (cluster_id) REFERENCES face_clusters(cluster_id) ON DELETE SET NULL ) - """ - ) + """) conn.commit() finally: if conn is not None: @@ -146,8 +144,7 @@ def get_all_face_embeddings(): cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT f.embeddings, f.bbox, @@ -162,8 +159,7 @@ def get_all_face_embeddings(): JOIN images i ON f.image_id=i.id LEFT JOIN image_classes ic ON i.id = ic.image_id LEFT JOIN mappings m ON ic.class_id = m.class_id - """ - ) + """) results = cursor.fetchall() from app.utils.images import image_util_parse_metadata @@ -256,14 +252,12 @@ def db_get_all_faces_with_cluster_names() -> ( cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT f.face_id, f.embeddings, fc.cluster_name FROM faces f LEFT JOIN face_clusters fc ON f.cluster_id = fc.cluster_id ORDER BY f.face_id - """ - ) + """) rows = cursor.fetchall() @@ -353,14 +347,12 @@ def db_get_cluster_mean_embeddings() -> List[Dict[str, Union[str, FaceEmbedding] cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT f.cluster_id, f.embeddings FROM faces f WHERE f.cluster_id IS NOT NULL ORDER BY f.cluster_id - """ - ) + """) rows = cursor.fetchall() diff --git a/backend/app/database/folders.py b/backend/app/database/folders.py index a2736a2d2..6139dd133 100644 --- a/backend/app/database/folders.py +++ b/backend/app/database/folders.py @@ -17,8 +17,7 @@ def db_create_folders_table() -> None: try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS folders ( folder_id TEXT PRIMARY KEY, parent_folder_id TEXT, @@ -28,8 +27,7 @@ def db_create_folders_table() -> None: taggingCompleted BOOLEAN, FOREIGN KEY (parent_folder_id) REFERENCES folders(folder_id) ON DELETE CASCADE ) - """ - ) + """) conn.commit() finally: if conn is not None: @@ -406,8 +404,7 @@ def db_get_all_folder_details() -> ( cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT f.folder_id, f.folder_path, @@ -420,8 +417,7 @@ def db_get_all_folder_details() -> ( LEFT JOIN images i ON f.folder_id = i.folder_id GROUP BY f.folder_id ORDER BY f.folder_path - """ - ) + """) return cursor.fetchall() finally: conn.close() diff --git a/backend/app/database/images.py b/backend/app/database/images.py index 76149202b..02f3a12ef 100644 --- a/backend/app/database/images.py +++ b/backend/app/database/images.py @@ -62,8 +62,7 @@ def db_create_images_table() -> None: cursor = conn.cursor() # Create new images table with merged fields including Memories feature columns - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS images ( id TEXT PRIMARY KEY, path VARCHAR UNIQUE, @@ -77,8 +76,7 @@ def db_create_images_table() -> None: captured_at DATETIME, FOREIGN KEY (folder_id) REFERENCES folders(folder_id) ON DELETE CASCADE ) - """ - ) + """) # Create indexes for Memories feature queries cursor.execute("CREATE INDEX IF NOT EXISTS ix_images_latitude ON images(latitude)") @@ -93,8 +91,7 @@ def db_create_images_table() -> None: ) # Create new image_classes junction table - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS image_classes ( image_id TEXT, class_id INTEGER, @@ -102,8 +99,7 @@ def db_create_images_table() -> None: FOREIGN KEY (image_id) REFERENCES images(id) ON DELETE CASCADE, FOREIGN KEY (class_id) REFERENCES mappings(class_id) ON DELETE CASCADE ) - """ - ) + """) conn.commit() conn.close() @@ -265,15 +261,13 @@ def db_get_untagged_images() -> List[UntaggedImageRecord]: cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT i.id, i.path, i.folder_id, i.thumbnailPath, i.metadata FROM images i JOIN folders f ON i.folder_id = f.folder_id WHERE f.AI_Tagging = TRUE AND i.isTagged = FALSE - """ - ) + """) results = cursor.fetchall() @@ -754,8 +748,7 @@ def db_get_images_with_location() -> List[dict]: cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT i.id, i.path, @@ -775,8 +768,7 @@ def db_get_images_with_location() -> List[dict]: AND i.longitude IS NOT NULL GROUP BY i.id ORDER BY i.captured_at DESC - """ - ) + """) results = cursor.fetchall() @@ -821,8 +813,7 @@ def db_get_all_images_for_memories() -> List[dict]: cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT i.id, i.path, @@ -840,8 +831,7 @@ def db_get_all_images_for_memories() -> List[dict]: LEFT JOIN mappings m ON ic.class_id = m.class_id GROUP BY i.id ORDER BY i.captured_at DESC - """ - ) + """) results = cursor.fetchall() diff --git a/backend/app/database/metadata.py b/backend/app/database/metadata.py index d431f6e2b..a86b64cb2 100644 --- a/backend/app/database/metadata.py +++ b/backend/app/database/metadata.py @@ -11,13 +11,11 @@ def db_create_metadata_table() -> None: try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS metadata ( metadata TEXT ) - """ - ) + """) # Insert initial row if table is empty cursor.execute("SELECT COUNT(*) FROM metadata") diff --git a/backend/app/database/yolo_mapping.py b/backend/app/database/yolo_mapping.py index af5c18927..fe8402dd2 100644 --- a/backend/app/database/yolo_mapping.py +++ b/backend/app/database/yolo_mapping.py @@ -12,14 +12,12 @@ def db_create_YOLO_classes_table(): try: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() - cursor.execute( - """ + cursor.execute(""" CREATE TABLE IF NOT EXISTS mappings ( class_id INTEGER PRIMARY KEY, name VARCHAR NOT NULL ) - """ - ) + """) for class_id, name in enumerate(class_names): cursor.execute( "INSERT OR REPLACE INTO mappings (class_id, name) VALUES (?, ?)", diff --git a/backend/app/routes/images.py b/backend/app/routes/images.py index 3741e13f6..a7f9fb332 100644 --- a/backend/app/routes/images.py +++ b/backend/app/routes/images.py @@ -7,7 +7,6 @@ from app.database.images import db_toggle_image_favourite_status, db_get_image_by_id from app.logging.setup_logging import get_logger - # Initialize logger logger = get_logger(__name__) router = APIRouter() diff --git a/backend/app/routes/shutdown.py b/backend/app/routes/shutdown.py index 716a79104..5d932044a 100644 --- a/backend/app/routes/shutdown.py +++ b/backend/app/routes/shutdown.py @@ -1,9 +1,12 @@ import asyncio +import hmac import os import platform import signal -from fastapi import APIRouter +from typing import Optional +from fastapi import APIRouter, Header, HTTPException from pydantic import BaseModel +from app.config import settings from app.logging.setup_logging import get_logger logger = get_logger(__name__) @@ -28,6 +31,13 @@ async def _delayed_shutdown(delay: float = 0.5): await asyncio.sleep(delay) logger.info("Backend shutdown initiated, exiting process...") + # Clean up token file + try: + os.remove(settings.SHUTDOWN_TOKEN_FILE) + logger.info("Shutdown token file removed") + except OSError as e: + logger.warning(f"Could not remove shutdown token file: {e}") + if platform.system() == "Windows": # Windows: SIGTERM doesn't work reliably with uvicorn subprocesses os._exit(0) @@ -37,16 +47,20 @@ async def _delayed_shutdown(delay: float = 0.5): @router.post("/shutdown", response_model=ShutdownResponse) -async def shutdown(): - """ - Gracefully shutdown the PictoPy backend. +async def shutdown(x_shutdown_token: Optional[str] = Header(default=None)): + """Gracefully shutdown the PictoPy backend (requires X-Shutdown-Token).""" + if x_shutdown_token is None: + logger.warning("Shutdown attempt rejected: missing token") + raise HTTPException(status_code=401, detail="Unauthorized") - This endpoint schedules backend server termination after response is sent. - The frontend is responsible for shutting down the sync service separately. + if not settings.SHUTDOWN_TOKEN: + raise HTTPException(status_code=503, detail="Service not ready") + + # Prevent timing-based token guessing + if not hmac.compare_digest(x_shutdown_token, settings.SHUTDOWN_TOKEN): + logger.warning("Shutdown attempt rejected: invalid token") + raise HTTPException(status_code=403, detail="Forbidden") - Returns: - ShutdownResponse with status and message - """ logger.info("Shutdown request received for PictoPy backend") # Define callback to handle potential exceptions in the background task diff --git a/backend/tests/test_shutdown.py b/backend/tests/test_shutdown.py new file mode 100644 index 000000000..5cec4442d --- /dev/null +++ b/backend/tests/test_shutdown.py @@ -0,0 +1,206 @@ +import os +import asyncio +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from main import app as main_app + +VALID_TOKEN = "a" * 64 + + +@pytest.fixture +def app(): + return main_app + + +@pytest.fixture +def client(): + with patch("app.config.settings.SHUTDOWN_TOKEN", VALID_TOKEN), patch( + "app.routes.shutdown.asyncio.create_task" + ): + with TestClient(main_app, raise_server_exceptions=False) as c: + yield c + + +# --------------------------------------------------------------------------- +# Header matrix tests +# --------------------------------------------------------------------------- + + +class TestShutdownHeaderMatrix: + """Cover all four header scenarios on the /shutdown endpoint.""" + + def test_no_token_returns_401(self, client): + """Missing X-Shutdown-Token header must return 401 Unauthorized.""" + resp = client.post("/shutdown") + assert resp.status_code == 401 + assert resp.json()["detail"] == "Unauthorized" + + def test_empty_token_returns_401(self, client): + """Empty header value is treated as missing (None after strip by FastAPI).""" + resp = client.post("/shutdown", headers={"X-Shutdown-Token": ""}) + # FastAPI sends None for empty optional header → 401 + assert resp.status_code in (401, 403) + + def test_malformed_token_returns_403(self, client): + """A syntactically valid but wrong token returns 403 Forbidden.""" + resp = client.post("/shutdown", headers={"X-Shutdown-Token": "notahextoken"}) + assert resp.status_code == 403 + assert resp.json()["detail"] == "Forbidden" + + def test_wrong_token_returns_403(self, client): + """A well-formed but incorrect token must return 403.""" + wrong = "b" * 64 + resp = client.post("/shutdown", headers={"X-Shutdown-Token": wrong}) + assert resp.status_code == 403 + + def test_correct_token_returns_200(self, client): + """A correct token must return 200 with shutting_down status.""" + resp = client.post("/shutdown", headers={"X-Shutdown-Token": VALID_TOKEN}) + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "shutting_down" + + +# --------------------------------------------------------------------------- +# Token rotation / restart simulation +# --------------------------------------------------------------------------- + + +class TestTokenRotation: + """Verify per-session token semantics.""" + + def test_old_token_rejected_after_rotation(self, app): + """Simulates a restart: new session → new token → old token is rejected.""" + old_token = "c" * 64 + new_token = "d" * 64 + + with patch("app.config.settings.SHUTDOWN_TOKEN", new_token), patch( + "app.routes.shutdown.asyncio.create_task" + ): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/shutdown", headers={"X-Shutdown-Token": old_token}) + assert resp.status_code == 403 + + def test_new_token_accepted_after_rotation(self, app): + new_token = "e" * 64 + with patch("app.config.settings.SHUTDOWN_TOKEN", new_token), patch( + "app.routes.shutdown.asyncio.create_task" + ): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/shutdown", headers={"X-Shutdown-Token": new_token}) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Token file cleanup +# --------------------------------------------------------------------------- + + +class TestTokenFileCleanup: + """_delayed_shutdown should attempt to remove the token file.""" + + def test_token_file_removed_on_shutdown(self, tmp_path): + token_file = str(tmp_path / "pictopy_shutdown.token") + token_file_obj = open(token_file, "w") + token_file_obj.write(VALID_TOKEN) + token_file_obj.close() + + with patch("app.config.settings.SHUTDOWN_TOKEN", VALID_TOKEN), patch( + "app.config.settings.SHUTDOWN_TOKEN_FILE", token_file + ), patch("app.routes.shutdown.os.kill"), patch("app.routes.shutdown.os._exit"): + + from app.routes.shutdown import _delayed_shutdown + + asyncio.get_event_loop().run_until_complete(_delayed_shutdown(delay=0)) + + assert not os.path.exists(token_file) + + def test_missing_token_file_does_not_raise(self, tmp_path): + """If file was already deleted, _delayed_shutdown must not propagate the error.""" + token_file = str(tmp_path / "nonexistent.token") + + with patch("app.config.settings.SHUTDOWN_TOKEN_FILE", token_file), patch( + "app.routes.shutdown.os.kill" + ), patch("app.routes.shutdown.os._exit"): + + from app.routes.shutdown import _delayed_shutdown + + # Should complete without raising + asyncio.get_event_loop().run_until_complete(_delayed_shutdown(delay=0)) + + +# --------------------------------------------------------------------------- +# Concurrent invalid requests +# --------------------------------------------------------------------------- + + +class TestConcurrentInvalidRequests: + """Concurrent bad requests must not block a legitimate shutdown.""" + + def test_concurrent_invalid_then_valid(self, client): + wrong = "f" * 64 + for _ in range(10): + resp = client.post("/shutdown", headers={"X-Shutdown-Token": wrong}) + assert resp.status_code == 403 + + # Service still reachable and accepts the correct token + resp = client.post("/shutdown", headers={"X-Shutdown-Token": VALID_TOKEN}) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Corrupted / invalid token content loaded by sync service +# --------------------------------------------------------------------------- + + +class TestCorruptedTokenContent: + """If the token file had garbage, hmac.compare_digest must still return False.""" + + def test_corrupted_token_always_rejects(self, app): + corrupted = "\x00\xff partial" + with patch("app.config.settings.SHUTDOWN_TOKEN", corrupted), patch( + "app.routes.shutdown.asyncio.create_task" + ): + with TestClient(app, raise_server_exceptions=False) as c: + # Even sending the corrupted string must not crash the endpoint + resp = c.post("/shutdown", headers={"X-Shutdown-Token": corrupted}) + # hmac.compare_digest may raise TypeError for non-str/bytes — document behavior + assert resp.status_code in (200, 400, 403, 500) + + +class TestEmptyTokenContent: + def test_empty_settings_token_always_rejects(self, app): + """An empty SHUTDOWN_TOKEN must never grant access, even with an empty header.""" + with patch("app.config.settings.SHUTDOWN_TOKEN", ""): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/shutdown", headers={"X-Shutdown-Token": ""}) + # Empty header is None → 401; but documents that "" ≠ "" guard is NOT present + assert resp.status_code in (401, 403, 503) + + def test_get_method_rejected(self, app): + """GET /shutdown should be rejected automatically.""" + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/shutdown") + assert resp.status_code == 405 + + def test_long_token_header_rejected(self, app): + """Extremely long token header should be rejected.""" + long_token = "a" * 1024 * 1024 # 1MB + with patch("app.config.settings.SHUTDOWN_TOKEN", VALID_TOKEN), patch( + "app.routes.shutdown.asyncio.create_task" + ): + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/shutdown", headers={"X-Shutdown-Token": long_token}) + assert resp.status_code in ( + 400, + 403, + 413, + 431, + ) # Payload too large, forbidden, or header fields too large diff --git a/frontend/src-tauri/src/main.rs b/frontend/src-tauri/src/main.rs index 5b7aa6acc..2c83e6451 100644 --- a/frontend/src-tauri/src/main.rs +++ b/frontend/src-tauri/src/main.rs @@ -53,19 +53,48 @@ fn on_window_event(window: &Window, event: &WindowEvent) { } #[cfg(unix)] -fn kill_process(process: &sysinfo::Process) { +fn kill_process(process: &sysinfo::Process) -> Result<(), String> { use sysinfo::Signal; let _ = process.kill_with(Signal::Term); + Ok(()) } #[cfg(windows)] -pub fn kill_process(_process: &sysinfo::Process) -> Result<(), String> { +fn kill_process(_process: &sysinfo::Process) -> Result<(), String> { use reqwest::blocking::Client; + use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + use std::str::FromStr; + + // Read per-session shutdown token written by backend. + let token_path = std::env::temp_dir().join("pictopy_shutdown.token"); + let token = match std::fs::read_to_string(&token_path) { + Ok(t) => { + let trimmed = t.trim().to_string(); + if trimmed.is_empty() { + eprintln!("[PictoPy] Warning: shutdown token file is empty — shutdown request will be rejected by the backend."); + } + trimmed + } + Err(e) => { + eprintln!("[PictoPy] Warning: could not read shutdown token file ({token_path:?}): {e} — shutdown request will be rejected by the backend."); + String::new() + } + }; + + let mut headers = HeaderMap::new(); + if !token.is_empty() { + if let (Ok(name), Ok(value)) = ( + HeaderName::from_str("x-shutdown-token"), + HeaderValue::from_str(&token), + ) { + headers.insert(name, value); + } + } let client = Client::builder().build().map_err(|e| e.to_string())?; for (name, url, _) in &ENDPOINTS { - match client.post(*url).send() { + match client.post(*url).headers(headers.clone()).send() { Ok(resp) => { let status = resp.status(); @@ -73,7 +102,12 @@ pub fn kill_process(_process: &sysinfo::Process) -> Result<(), String> { println!("[{}] Shutdown OK ({})", name, status); } } - Err(_err) => {} + Err(_err) => { + eprintln!( + "[{}] Failed to send shutdown request to {}: {}", + name, url, _err + ); + } } } @@ -95,7 +129,12 @@ fn kill_process_tree() -> Result<(), String> { let name = process.name().to_string_lossy(); if target_names.iter().any(|t| name.eq_ignore_ascii_case(t)) { - let _ = kill_process(process); + if let Err(e) = kill_process(process) { + eprintln!( + "[PictoPy] Failed to send shutdown signal to process {}: {}", + name, e + ); + } } } diff --git a/sync-microservice/app/config/settings.py b/sync-microservice/app/config/settings.py index b9e2053a3..65e22f39f 100644 --- a/sync-microservice/app/config/settings.py +++ b/sync-microservice/app/config/settings.py @@ -1,5 +1,7 @@ -from platformdirs import user_data_dir import os +import tempfile + +from platformdirs import user_data_dir # Model Exports Path MODEL_EXPORTS_PATH = "app/models/ONNX_Exports" @@ -28,3 +30,7 @@ DATABASE_PATH = os.path.join(user_data_dir("PictoPy"), "database", "PictoPy.db") THUMBNAIL_IMAGES_PATH = "./images/thumbnails" IMAGES_PATH = "./images" + +# Shared session token file for authenticated shutdown. +SHUTDOWN_TOKEN_FILE: str = os.path.join(tempfile.gettempdir(), "pictopy_shutdown.token") +SHUTDOWN_TOKEN: str = "" diff --git a/sync-microservice/app/core/lifespan.py b/sync-microservice/app/core/lifespan.py index 26b75a5c0..0612498e6 100644 --- a/sync-microservice/app/core/lifespan.py +++ b/sync-microservice/app/core/lifespan.py @@ -1,6 +1,8 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +import asyncio import time +import app.config.settings as settings from app.database.folders import ( db_check_database_connection, ) @@ -27,6 +29,30 @@ async def lifespan(app: FastAPI): # Startup logger.info("Starting PictoPy Sync Microservice...") + # Wait for shutdown token from backend (up to 5 seconds) + logger.info("Waiting for shutdown token from backend...") + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + try: + with open(settings.SHUTDOWN_TOKEN_FILE) as f: + token = f.read().strip() + if token: + settings.SHUTDOWN_TOKEN = token + logger.info("Shutdown token loaded successfully") + break + except FileNotFoundError: + pass + except Exception as e: + logger.error(f"Error reading shutdown token file: {e}") + await asyncio.sleep(0.1) + + if not settings.SHUTDOWN_TOKEN: + logger.error( + f"pictopy_shutdown.token not found at {settings.SHUTDOWN_TOKEN_FILE} after 5 seconds" + ) + logger.error("Ensure the backend starts before the sync service.") + raise RuntimeError("Backend shutdown token not found") + # Check database connection logger.info("Checking database connection...") connection_timeout = 60 @@ -54,7 +80,7 @@ async def lifespan(app: FastAPI): logger.warning( f"Database connection attempt {attempt} failed. Retrying in {retry_interval} seconds... ({elapsed_time:.1f}s elapsed)" ) - time.sleep(retry_interval) + await asyncio.sleep(retry_interval) watcher_started = watcher_util_start_folder_watcher() diff --git a/sync-microservice/app/database/folders.py b/sync-microservice/app/database/folders.py index 0cc6b3ade..b2f4f82d3 100644 --- a/sync-microservice/app/database/folders.py +++ b/sync-microservice/app/database/folders.py @@ -30,12 +30,10 @@ def db_get_all_folders_with_ids() -> List[FolderIdPath]: cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT folder_id, folder_path FROM folders ORDER BY folder_path - """ - ) + """) return cursor.fetchall() except Exception as e: logger.error(f"Error getting folders from database: {e}") @@ -56,12 +54,10 @@ def db_check_database_connection() -> bool: conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() # Check if folders table exists - cursor.execute( - """ + cursor.execute(""" SELECT name FROM sqlite_master WHERE type='table' AND name='folders' - """ - ) + """) result = cursor.fetchone() return result is not None except Exception as e: @@ -84,8 +80,7 @@ def db_get_tagging_progress() -> List[FolderTaggingInfo]: cursor = conn.cursor() try: - cursor.execute( - """ + cursor.execute(""" SELECT f.folder_id, f.folder_path, @@ -94,8 +89,7 @@ def db_get_tagging_progress() -> List[FolderTaggingInfo]: FROM folders f LEFT JOIN images i ON f.folder_id = i.folder_id GROUP BY f.folder_id, f.folder_path - """ - ) + """) results = cursor.fetchall() diff --git a/sync-microservice/app/routes/shutdown.py b/sync-microservice/app/routes/shutdown.py index 9e31ee42d..ff9aae6b5 100644 --- a/sync-microservice/app/routes/shutdown.py +++ b/sync-microservice/app/routes/shutdown.py @@ -1,9 +1,12 @@ import asyncio +import hmac import os import platform import signal -from fastapi import APIRouter +from typing import Optional +from fastapi import APIRouter, Header, HTTPException from pydantic import BaseModel +from app.config import settings from app.utils.watcher import watcher_util_stop_folder_watcher from app.logging.setup_logging import get_sync_logger @@ -29,6 +32,13 @@ async def _delayed_shutdown(delay: float = 0.1): await asyncio.sleep(delay) logger.info("Exiting sync microservice...") + # Clean up token file + try: + os.remove(settings.SHUTDOWN_TOKEN_FILE) + logger.info("Shutdown token file removed") + except OSError as e: + logger.warning(f"Could not remove shutdown token file: {e}") + if platform.system() == "Windows": # Windows: SIGTERM doesn't work reliably with uvicorn subprocesses os._exit(0) @@ -38,18 +48,17 @@ async def _delayed_shutdown(delay: float = 0.1): @router.post("/shutdown", response_model=ShutdownResponse) -async def shutdown(): - """ - Gracefully shutdown the sync microservice. +async def shutdown(x_shutdown_token: Optional[str] = Header(default=None)): + """Gracefully shutdown the sync microservice (requires X-Shutdown-Token).""" + if x_shutdown_token is None: + logger.warning("Shutdown attempt rejected: missing token") + raise HTTPException(status_code=401, detail="Unauthorized") - This endpoint: - 1. Stops the folder watcher - 2. Schedules server termination after response is sent - 3. Returns confirmation to the caller + # Prevent timing-based token guessing + if not hmac.compare_digest(x_shutdown_token, settings.SHUTDOWN_TOKEN): + logger.warning("Shutdown attempt rejected: invalid token") + raise HTTPException(status_code=403, detail="Forbidden") - Returns: - ShutdownResponse with status and message - """ logger.info("Shutdown request received for sync microservice") try: