diff --git a/client/src/locales/en.json b/client/src/locales/en.json
index 8f4925756..ecb1c5f05 100644
--- a/client/src/locales/en.json
+++ b/client/src/locales/en.json
@@ -119,7 +119,9 @@
"resetPassword": "Reset password",
"resetPasswordSuccess": "Password reset",
"resetPasswordFailed": "Failed to reset password - Token incorrect or expired",
- "resetPasswordNote": "Enter your new password below and click submit to reset your password."
+ "resetPasswordNote": "Enter your new password below and click submit to reset your password.",
+ "oidc_login": "Login with OpenID Connect",
+ "oidc_with": "Login with"
}
},
"core": {
diff --git a/server/src/api/http/auth.py b/server/src/api/http/auth.py
index 12887cb0f..71e5ee49d 100644
--- a/server/src/api/http/auth.py
+++ b/server/src/api/http/auth.py
@@ -1,5 +1,7 @@
import asyncio
import random
+import secrets
+import logging
from aiohttp import web
from aiohttp_security import authorized_userid, forget, remember
@@ -11,7 +13,7 @@
from ...db.models.user import User
from ...mail import send_mail
from ...state.auth import auth_state
-
+from .oidc import oidc_auth
async def is_authed(request):
user = await get_authorized_user(request)
@@ -25,6 +27,8 @@ async def is_authed(request):
async def login(request):
+ if "local" not in cfg().general.authentication_methods:
+ return web.HTTPForbidden(reason="Local authentication is disabled")
try:
user = await get_authorized_user(request)
except web.HTTPUnauthorized:
@@ -83,6 +87,8 @@ async def logout(request):
async def forgot_password(request):
+ if "local" not in cfg().general.authentication_methods:
+ return web.HTTPForbidden(reason="Local authentication is disabled")
data = await request.json()
email = data["email"]
user = User.by_email(email)
@@ -110,6 +116,8 @@ async def forgot_password(request):
async def reset_password(request):
+ if "local" not in cfg().general.authentication_methods:
+ return web.HTTPForbidden(reason="Local authentication is disabled")
data = await request.json()
token = data["token"]
password = data["password"]
@@ -123,3 +131,87 @@ async def reset_password(request):
user.save()
return web.HTTPOk()
+
+# OIDC Authentication endpoints
+
+async def oidc_providers(request):
+ if "oidc" not in cfg().general.authentication_methods:
+ return web.HTTPForbidden(reason="OIDC authentication is disabled")
+ providers = oidc_auth.get_providers()
+ return web.json_response({"providers": providers})
+
+async def oidc_login(request):
+ logger = logging.getLogger("PlanarAllyServer")
+ if "oidc" not in cfg().general.authentication_methods:
+ logger.warning(f"OIDC authentication attempted but disabled in config")
+ return web.HTTPForbidden(reason="OIDC authentication is disabled")
+ try:
+ data = await request.json()
+ except Exception as e:
+ return web.HTTPBadRequest(reason="Invalid request format")
+
+ # Grab the state from the browser's request ensuring we can validate it later
+ provider_name = data.get("provider_name")
+ if not provider_name:
+ return web.HTTPBadRequest(reason="Missing required parameters")
+
+ auth_url = await oidc_auth.get_authorization_url(provider_name)
+
+ if not auth_url:
+ return web.HTTPInternalServerError(reason="Failed to initiate OIDC login")
+ logger.debug(f"Redirecting to OIDC provider with URL: {auth_url}")
+ # Instruct the client to redirect to the auth_url
+ return web.json_response({"authorization_url": auth_url})
+
+async def oidc_callback(request):
+ logger = logging.getLogger("PlanarAllyServer")
+ logger.debug("OIDC callback invoked")
+ if "oidc" not in cfg().general.authentication_methods:
+ logger.warning(f"OIDC authentication attempted but disabled in config")
+ return web.HTTPForbidden(reason="OIDC authentication is disabled")
+
+ # Get the provider name from the URL path since we have multiple providers
+ # and cannot rely on a single endpoint
+ provider_name = request.match_info["provider"]
+ code = request.query.get("code")
+ if not code or not provider_name:
+ logger.error("Missing code or provider_name in OIDC callback request")
+ return web.HTTPBadRequest(reason="Missing required parameters")
+
+ try:
+ user_info = await oidc_auth.exchange_code_for_token(code, provider_name, request.query.get("state"))
+ if not user_info:
+ logger.error("Failed to retrieve user info from OIDC provider")
+ return web.HTTPUnauthorized(reason="OIDC authentication failed")
+
+
+ if not user_info.username or not user_info.email:
+ logger.error("OIDC user info missing username/email")
+ return web.HTTPUnauthorized(reason="OIDC authentication failed")
+
+ user = User.by_name(user_info.username) or User.by_email(user_info.email)
+ if user is None:
+ # Check if auto-signup is allowed
+ if not cfg().general.allow_signups:
+ logger.info(f"Auto-signup disabled, rejecting OIDC login for unknown user: {user_info.username}")
+ return web.HTTPForbidden(reason="User does not exist and auto-signup is disabled")
+ # Auto-register the user
+ with db.atomic():
+ # Generate a sufficiently random password, since it won't be used for login
+ password = secrets.token_urlsafe(16)
+ user = User.create_new(user_info.username, password, user_info.email)
+ stats.events.user_created(user.id)
+ logger.info(f"Auto-registered new user: {user_info.username}")
+ response = web.HTTPOk()
+ user.update_last_login()
+ await remember(request, response, user_info.username)
+ logger.info(f"User {user_info.username} logged in via OIDC")
+ # Convert this response into a redirect for the browser
+ # we need the the auth system to modify the response headers
+ response.headers["Location"] = f"{cfg().general.client_url}"
+ response.set_status(302)
+ return response
+
+ except Exception as e:
+ logger.error(f"Error during OIDC callback processing: {e}")
+ return web.HTTPInternalServerError(reason="An error occurred during OIDC authentication")
diff --git a/server/src/api/http/oidc.py b/server/src/api/http/oidc.py
new file mode 100644
index 000000000..23e5e9666
--- /dev/null
+++ b/server/src/api/http/oidc.py
@@ -0,0 +1,257 @@
+
+from datetime import UTC, datetime
+import aiohttp
+
+from ...config import cfg, cfg_last_update
+from ...logs import logger
+from dataclasses import dataclass
+from ...config.types import OidcConfig
+
+
+class CodeCache:
+ """This service caches PKCE code challenges and states for OIDC authentication flows.
+
+ Cache entries are removed after 5 minutes to keep memory usage low."""
+ def __init__(self):
+ self.cache = dict[str, dict]()
+
+ def set(self, state: str, data: dict):
+ """Set the data for the given state"""
+ # Clear the cache of stale entries
+ now = datetime.now(UTC)
+ for key in list(self.cache.keys()):
+ _, timestamp = self.cache[key]
+ if (now - timestamp).total_seconds() > 300:
+ del self.cache[key]
+ self.cache[state] = (data, now)
+
+ def get(self, state: str) -> dict | None:
+ """Get and remove the data for the given state"""
+ entry = self.cache.pop(state, None)
+ if entry is None:
+ return None
+ data, timestamp = entry
+ # Do not return stale entries
+ if (datetime.now(UTC) - timestamp).total_seconds() > 300:
+ return None
+ return data
+
+
+class OidcServerConfig:
+ """Configuration for a single OIDC provider"""
+ refreshed: bool = False
+
+ def __init__(self, config: OidcConfig, redirect_uri: str):
+ self.display_name = config.display_name or "OpenID Connect"
+ """The display name for this OIDC provider"""
+ self.provider_id = config.provider_id
+ """The id of the provider"""
+ self.client_id = config.client_id
+ """"The OIDC Client ID to use for authentication"""
+ self.client_secret = config.client_secret
+ """The OIDC Client Secret to use for authentication"""
+ self.discovery_url = config.discovery_url
+ """The OIDC Discovery URL used to fetch provider configuration"""
+ self.redirect_uri = redirect_uri + config.provider_id
+ """The Redirect URI for this provider"""
+ self.scopes = config.scopes
+ """The scopes to request during authentication"""
+ self.username_claim = config.username_claim
+ """The claim to use as the username"""
+ self.email_claim = config.email_claim
+ """The claim to use as the email"""
+ self.doc = None
+ """The OIDC discovery document, once fetched"""
+ self.pkce = config.pkce
+ """Whether to use PKCE (Proof Key for Code Exchange) during authentication"""
+
+
+@dataclass
+class ExchangeData:
+ """Standardized data returned after exchanging code for token"""
+ username: str
+ email: str
+
+class OIDCAuth:
+ """Handles OIDC authentication logic
+
+ Provides methods to get authorization URLs, exchange codes for tokens, and manages multiple OIDC providers.
+ """
+ # Set to the minimum time so that we always initialize on first call
+ last_update: datetime = datetime.min.replace(tzinfo=UTC)
+
+ def __init__(self):
+ #self.last_update = datetime.min
+ # For each configured OIDC provider, create an OidcConfig
+ self.code_cache = CodeCache()
+ self.providers = dict[str, OidcServerConfig]()
+
+ def get_provider(self, provider_id: str) -> OidcServerConfig | None:
+ """Get the OIDC provider configuration by ID
+
+ Also refreshes the providers from the config if needed."""
+ # check the config and remove any providers that are no longer present
+ # Generate the redirect URI based on the server config
+ # but only if the config has changed since last update
+ if cfg_last_update() > self.last_update:
+ self.last_update = cfg_last_update()
+ logger.debug("Refreshing OIDC providers from config")
+ redirect_uri = (cfg().general.client_url or "") + "/api/oidc/callback/"
+
+ # Mark all providers as not refreshed first
+ for provider_id in list(self.providers.keys()):
+ self.providers[provider_id].refreshed = False
+
+ # Now refresh from the config
+ for oidc_cfg in cfg().oidc:
+ provider: OidcServerConfig
+ if oidc_cfg.provider_id not in self.providers:
+ provider = OidcServerConfig(oidc_cfg, redirect_uri)
+ self.providers[oidc_cfg.provider_id] = provider
+ else:
+ provider = self.providers[oidc_cfg.provider_id]
+ provider.refreshed = True
+
+ # Remove any providers that were not refreshed
+ for provider_id in list(self.providers.keys()):
+ if not self.providers[provider_id].refreshed:
+ del self.providers[provider_id]
+
+ return self.providers.get(provider_id)
+
+ def get_providers(self) -> list[dict] | None:
+ """Get a list of configured OIDC providers
+
+ Returns a list of dictionaries with provider display names and IDs."""
+ # This is often the first call, so we will attempt to get a None provider
+ self.get_provider("") # Trigger load if needed
+ return [
+ {
+ "display_name": provider.display_name,
+ "provider_id": provider.provider_id,
+ }
+ for provider in self.providers.values()
+ ]
+
+ async def get_discovery_document(self, provider: OidcServerConfig) -> dict | None:
+ """Fetch the OIDC discovery document for the given provider"""
+ # Check if we have already fetched it
+ if provider.doc is not None:
+ return provider.doc
+
+ async with aiohttp.ClientSession() as session:
+ async with session.get(provider.discovery_url) as resp:
+ if resp.status == 200:
+ doc = await resp.json()
+ provider.doc = doc # Cache the discovery document
+ return doc
+ else:
+ logger.error(
+ f"Failed to fetch OIDC discovery document from {provider.discovery_url}: {resp.status}"
+ )
+ return None
+
+ async def get_authorization_url(self, provider_id: str) -> str | None:
+ """Generate the authorization URL for the OIDC provider"""
+
+ provider = self.get_provider(provider_id)
+ if not provider:
+ return None
+
+ discovery_doc = await self.get_discovery_document(provider)
+ if not discovery_doc:
+ return None
+
+ auth_endpoint = discovery_doc.get("authorization_endpoint")
+ if not auth_endpoint:
+ logger.warning(f"Authorization endpoint not found in discovery document for provider '{provider_id}'")
+ return None
+ scope_str = " ".join(provider.scopes)
+ if provider.pkce:
+ # Generate a code challenge and verifier
+ import secrets
+ import hashlib
+ import base64
+ code_verifier = secrets.token_urlsafe(64)
+ code_challenge = base64.urlsafe_b64encode(
+ hashlib.sha256(code_verifier.encode()).digest()
+ ).rstrip(b'=').decode('utf-8')
+ state = secrets.token_urlsafe(16)
+ # Cache the code verifier associated with this state
+ self.code_cache.set(state, code_verifier)
+ auth_url = (
+ f"{auth_endpoint}?response_type=code&client_id={provider.client_id}"
+ f"&redirect_uri={provider.redirect_uri}&scope={scope_str}"
+ f"&code_challenge={code_challenge}&code_challenge_method=S256&state={state}"
+ )
+ else:
+ auth_url = (
+ f"{auth_endpoint}?response_type=code&client_id={provider.client_id}"
+ f"&redirect_uri={provider.redirect_uri}&scope={scope_str}"
+ )
+ return auth_url
+
+ async def exchange_code_for_token(self, code: str, provider_name: str, state: str | None = None) -> dict | None:
+ """Exchange the authorization code for tokens"""
+
+ provider = self.get_provider(provider_name)
+ if not provider:
+ return None
+ discovery_doc = await self.get_discovery_document(provider)
+ if not discovery_doc:
+ return None
+
+ token_endpoint = discovery_doc.get("token_endpoint")
+ if not token_endpoint:
+ logger.warning("Token endpoint not found in discovery document")
+ return None
+
+ user_info_endpoint = discovery_doc.get("userinfo_endpoint")
+ if not user_info_endpoint:
+ logger.warning("Userinfo endpoint not found in discovery document")
+ return None
+
+ data = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": provider.redirect_uri,
+ "client_id": provider.client_id,
+ "client_secret": provider.client_secret,
+ }
+ if provider.pkce:
+ if not state:
+ logger.debug("State is required for PKCE but not provided")
+ return None
+ # Retrieve the code verifier from the cache
+ code_verifier = self.code_cache.get(state)
+ if not code_verifier:
+ logger.debug("No code verifier found for the given state")
+ return None
+ data["code_verifier"] = code_verifier
+
+ async with aiohttp.ClientSession() as session:
+ async with session.post(token_endpoint, data=data) as resp:
+ if resp.status == 200:
+ token_response = await resp.json()
+ # Now fetch user info
+ async with session.get(
+ user_info_endpoint,
+ headers={"Authorization": f"Bearer {token_response.get('access_token')}"}
+ ) as userinfo_resp:
+ if userinfo_resp.status == 200:
+ user_info = await userinfo_resp.json()
+ # Useful to track down issues with claim mappings
+ logger.debug(f"OIDC user info response: {user_info}")
+ return ExchangeData(
+ username=user_info.get(provider.username_claim, ""),
+ email=user_info.get(provider.email_claim, ""),
+ )
+ else:
+ logger.error(f"Failed to fetch user info: {userinfo_resp.status}")
+ return None
+ else:
+ logger.error(f"Failed to exchange code for token: {resp.status}")
+ return None
+
+
+oidc_auth = OIDCAuth()
\ No newline at end of file
diff --git a/server/src/config/__init__.py b/server/src/config/__init__.py
index ee64aea2c..c0b332dd9 100644
--- a/server/src/config/__init__.py
+++ b/server/src/config/__init__.py
@@ -9,3 +9,5 @@
# Otherwise, updates would not be reflected if they change the pointer
def cfg():
return config_manager.config
+def cfg_last_update():
+ return config_manager.last_update
diff --git a/server/src/config/manager.py b/server/src/config/manager.py
index f39fea3f3..3e3c7b20c 100644
--- a/server/src/config/manager.py
+++ b/server/src/config/manager.py
@@ -28,13 +28,14 @@ def on_modified(self, event):
class ConfigManager:
+
def __init__(self, config_path: Path):
self.config_path = config_path
self.config = ServerConfig()
self._file_observer = Observer()
+ self.last_update: datetime = datetime.now(UTC)
self.load_config(startup=True)
-
# Setup file watching
event_handler = ConfigFileHandler(self.load_config)
self._file_observer.schedule(event_handler, str(self.config_path.parent), recursive=False)
@@ -54,6 +55,7 @@ def load_config(self, *, startup=False) -> None:
logger.info("Config file changed, reloading")
reset_email()
+ self.last_update = datetime.now(UTC)
except rtoml.TomlParsingError as e:
print(f"Error loading config: {e}")
except ValidationError as e:
@@ -75,21 +77,6 @@ def save_config(self) -> None:
except Exception as e:
logger.error(f"Error saving config: {e}")
- def update_config(self, updates: dict[str, Any]) -> None:
- """Update config with new values"""
- if "admin_user" in updates:
- raise ValueError("admin_user cannot be updated dynamically for security reasons.")
-
- try:
- # Create new config with updates
- new_config = ServerConfig(**(self.config.model_dump() | updates))
- self.config = new_config
-
- # Save and notify
- self.save_config()
- except ValidationError as e:
- raise ValueError(f"Invalid configuration: {e}")
-
def cleanup(self) -> None:
"""Cleanup resources"""
if self._file_observer.is_alive():
diff --git a/server/src/config/types.py b/server/src/config/types.py
index 904c97403..1594c4e3d 100644
--- a/server/src/config/types.py
+++ b/server/src/config/types.py
@@ -66,6 +66,17 @@ class AssetsConfig(ConfigModel):
max_single_asset_size_in_bytes: int = 0
max_total_asset_size_in_bytes: int = 0
+class LoggingConfig(ConfigModel):
+ # Enable logging to a file
+ destinations: list[Literal["stdout", "file"]] = ["stdout", "file"]
+ # The log file path
+ file_path: str = "planarally.log"
+ # The Log Level
+ level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
+ # These settings are used for log rotating,
+ # see https://docs.python.org/3/library/logging.handlers.html#logging.handlers.RotatingFileHandler for details
+ max_log_size_in_bytes: int = 200_000
+ max_log_backups: int = 5
class GeneralConfig(ConfigModel):
# Location of the save file
@@ -90,11 +101,16 @@ class GeneralConfig(ConfigModel):
# These settings are used for log rotating,
# see https://docs.python.org/3/library/logging.handlers.html#logging.handlers.RotatingFileHandler for details
- max_log_size_in_bytes: int = 200_000
- max_log_backups: int = 5
+ # Retiring this for a dedicated logging section
+ max_log_size_in_bytes: int | None = None
+ max_log_backups: int | None = None
admin_user: str | None = None
+ # Authentication options (local = built-in username/password, oidc = OpenID Connect)
+ # Both methods can be enabled at the same time
+ authentication_methods: list[Literal["local", "oidc"]] = ["local"]
+
class MailConfig(ConfigModel):
# Can be used to disable email functionality
@@ -126,6 +142,25 @@ class StatsConfig(ConfigModel):
# The base URL to send stats to
stats_url: str = "https://stats.planarally.io"
+class OidcConfig(ConfigModel):
+ # The display name for this OIDC provider
+ display_name: str
+ # The id of the provider
+ provider_id: str
+ # The OIDC Client ID to use for authentication
+ client_id: str
+ # The OIDC Client Secret to use for authentication
+ client_secret: str
+ # The OIDC Discovery URL used to fetch provider configuration
+ discovery_url: str
+ # Whether to use PKCE (Proof Key for Code Exchange) during authentication
+ pkce: bool = True
+ # The list of scopes to request during authentication (normally you shouldn't need to change this)
+ scopes: list[str] = ["openid", "email", "profile"]
+ # The claim to use as the username
+ username_claim: str = "preferred_username"
+ # The claim to use as the email
+ email_claim: str = "email"
class ServerConfig(ConfigModel):
general: GeneralConfig = GeneralConfig()
@@ -133,3 +168,5 @@ class ServerConfig(ConfigModel):
webserver: WebserverConfig = WebserverConfig()
stats: StatsConfig = StatsConfig()
mail: MailConfig | None = None
+ logging: LoggingConfig = LoggingConfig()
+ oidc: list[OidcConfig] = []
diff --git a/server/src/logs.py b/server/src/logs.py
index 3fdb2f04e..d78627ee9 100644
--- a/server/src/logs.py
+++ b/server/src/logs.py
@@ -7,20 +7,32 @@
# SETUP LOGGING
+config = cfg()
+
+# Initialize the main logger destination
logger = logging.getLogger("PlanarAllyServer")
-logger.setLevel(logging.INFO)
-file_handler = RotatingFileHandler(
- str(FILE_DIR / "planarallyserver.log"),
- maxBytes=cfg().general.max_log_size_in_bytes,
- backupCount=cfg().general.max_log_backups,
-)
-file_handler.setLevel(logging.INFO)
+# Set the logging level based on configuration
+logger.setLevel(config.logging.level)
+# Define the log message format
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)")
-file_handler.setFormatter(formatter)
-stream_handler = logging.StreamHandler(sys.stdout)
-stream_handler.setFormatter(formatter)
-logger.addHandler(file_handler)
-logger.addHandler(stream_handler)
+
+# Initialize the stream handler for logging to stdout if enabled, always at least log to stdout even if no destinations are set
+if "stdout" in config.logging.destinations or not config.logging.destinations:
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+# Initialize the file handler for logging
+if "file" in config.logging.destinations:
+ # Check the logging file path from configuration to see if its absolute
+ log_file_path = config.logging.file_path
+ if not log_file_path.startswith("/") and not log_file_path[1:3] == ":\\": # Windows drive letter check
+ log_file_path = str(FILE_DIR / log_file_path)
+ maxBytes = config.logging.max_log_size_in_bytes or config.general.max_log_size_in_bytes
+ backupCount = config.logging.max_log_backups or config.general.max_log_backups
+ file_handler = RotatingFileHandler(filename=log_file_path, maxBytes=maxBytes, backupCount=backupCount)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
def handle_exception(exc_type, exc_value, exc_traceback):
diff --git a/server/src/routes.py b/server/src/routes.py
index 7e34b618e..6524e9a7b 100644
--- a/server/src/routes.py
+++ b/server/src/routes.py
@@ -22,6 +22,8 @@ def __replace_config_data(data: bytes) -> bytes:
data = data.replace(b'name="PA-signup" content="true"', b'name="PA-signup" content="false"')
if not config.mail or not config.mail.enabled:
data = data.replace(b'name="PA-mail" content="true"', b'name="PA-mail" content="false"')
+ # Replace the PA-auth meta tag based on enabled authentication methods
+ data = data.replace(b'name="PA-auth" content="local"', ('name="PA-auth" content="'+ " ".join(config.general.authentication_methods) +'"').encode())
return data
@@ -70,6 +72,9 @@ async def root_dev(request):
main_app.router.add_get(f"{subpath}/api/changelog", version.get_changelog)
main_app.router.add_get(f"{subpath}/api/notifications", notifications.collect)
main_app.router.add_post(f"{subpath}/api/mod/upload", mods.upload)
+main_app.router.add_get(f"{subpath}/api/oidc/providers", auth.oidc_providers)
+main_app.router.add_post(f"{subpath}/api/oidc/login", auth.oidc_login)
+main_app.router.add_get(f"{subpath}/api/oidc/callback/{{provider}}", auth.oidc_callback)
TAIL_REGEX = "/{tail:(?!api).*}"
if "dev" in sys.argv:
diff --git a/server/src/save.py b/server/src/save.py
index f572d5eb1..4b73988a7 100644
--- a/server/src/save.py
+++ b/server/src/save.py
@@ -19,7 +19,6 @@
import asyncio
from collections import defaultdict
import json
-import logging
import secrets
import shutil
import sys
@@ -33,10 +32,8 @@
from .db.models.constants import Constants
from .thumbnail import generate_thumbnail_for_asset
from .utils import ASSETS_DIR, FILE_DIR, SAVE_PATH, OldVersionException, UnknownVersionException, get_asset_hash_subpath
-
-logger: logging.Logger = logging.getLogger("PlanarAllyServer")
-logger.setLevel(logging.INFO)
-
+# Import logger to retain configuration settings
+from .logs import logger
def get_save_version(db: SqliteExtDatabase):
return db.execute_sql("SELECT save_version FROM constants").fetchone()[0]
@@ -56,7 +53,6 @@ def create_new_db(db: SqliteExtDatabase, version: int):
api_token=secrets.token_hex(32),
)
-
def check_existence() -> bool:
if not SAVE_PATH.exists():
logger.warning("Provided save file does not exist. Creating a new one.")