|
| 1 | +"""OAuth2 configuration database storage with encryption. |
| 2 | +
|
| 3 | +This module provides database persistence for OAuth2 configuration, |
| 4 | +with automatic encryption/decryption of sensitive fields like client_secret. |
| 5 | +""" |
| 6 | + |
| 7 | +import json |
| 8 | +import logging |
| 9 | +from typing import Any, Dict, List, Optional |
| 10 | + |
| 11 | +from cryptography.fernet import Fernet |
| 12 | +from sqlalchemy import Column, DateTime, Integer, String, Text |
| 13 | + |
| 14 | +from derisk._private.config import Config |
| 15 | +from derisk.storage.metadata import BaseDao, Model |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class OAuth2ConfigEntity(Model): |
| 21 | + """OAuth2 configuration entity for database storage (plain text).""" |
| 22 | + |
| 23 | + __tablename__ = "oauth2_config" |
| 24 | + |
| 25 | + id = Column(Integer, primary_key=True, autoincrement=True) |
| 26 | + config_key = Column( |
| 27 | + String(64), |
| 28 | + nullable=False, |
| 29 | + default="global", |
| 30 | + comment="Configuration key (default: global)", |
| 31 | + ) |
| 32 | + enabled = Column( |
| 33 | + Integer, |
| 34 | + nullable=False, |
| 35 | + default=0, |
| 36 | + comment="OAuth2 enabled flag (1=true, 0=false)", |
| 37 | + ) |
| 38 | + providers_json = Column( |
| 39 | + Text, |
| 40 | + nullable=True, |
| 41 | + comment="OAuth2 providers configuration (JSON array)", |
| 42 | + ) |
| 43 | + admin_users_json = Column( |
| 44 | + Text, |
| 45 | + nullable=True, |
| 46 | + comment="Admin users list (JSON array)", |
| 47 | + ) |
| 48 | + gmt_create = Column(DateTime, nullable=True) |
| 49 | + gmt_modify = Column(DateTime, nullable=True) |
| 50 | + |
| 51 | + def to_dict(self) -> Dict[str, Any]: |
| 52 | + """Convert to dictionary.""" |
| 53 | + return { |
| 54 | + "id": self.id, |
| 55 | + "config_key": self.config_key, |
| 56 | + "enabled": bool(self.enabled), |
| 57 | + "providers_json": self.providers_json, |
| 58 | + "admin_users_json": self.admin_users_json, |
| 59 | + } |
| 60 | + |
| 61 | + |
| 62 | +class OAuth2ConfigDao(BaseDao[OAuth2ConfigEntity, Any, Any]): |
| 63 | + """DAO for OAuth2 configuration.""" |
| 64 | + |
| 65 | + def get_by_key(self, config_key: str = "global") -> Optional[OAuth2ConfigEntity]: |
| 66 | + """Get OAuth2 config by key.""" |
| 67 | + with self.session() as session: |
| 68 | + return ( |
| 69 | + session.query(OAuth2ConfigEntity) |
| 70 | + .filter(OAuth2ConfigEntity.config_key == config_key) |
| 71 | + .first() |
| 72 | + ) |
| 73 | + |
| 74 | + @staticmethod |
| 75 | + def _is_masked_secret(secret: str) -> bool: |
| 76 | + """Check if a secret is masked (e.g., 'abcd****').""" |
| 77 | + return bool(secret and "****" in secret) |
| 78 | + |
| 79 | + def _merge_secrets( |
| 80 | + self, |
| 81 | + new_providers: List[Dict[str, Any]], |
| 82 | + old_providers: List[Dict[str, Any]], |
| 83 | + ) -> List[Dict[str, Any]]: |
| 84 | + """Merge new providers with old, preserving secrets when new values are masked. |
| 85 | +
|
| 86 | + If new provider's client_secret is masked (e.g., 'abcd****'), |
| 87 | + use the corresponding old provider's secret (if provider id matches). |
| 88 | + """ |
| 89 | + if not old_providers: |
| 90 | + return new_providers |
| 91 | + |
| 92 | + # Build lookup for old providers by id |
| 93 | + old_by_id = {p.get("id", ""): p for p in old_providers if p.get("id")} |
| 94 | + |
| 95 | + merged = [] |
| 96 | + for new_p in new_providers: |
| 97 | + pid = new_p.get("id", "") |
| 98 | + new_secret = new_p.get("client_secret", "") |
| 99 | + |
| 100 | + # If secret is masked and we have old provider with same id |
| 101 | + if self._is_masked_secret(new_secret) and pid in old_by_id: |
| 102 | + old_p = old_by_id[pid] |
| 103 | + old_secret = old_p.get("client_secret", "") |
| 104 | + |
| 105 | + # Make a copy and replace masked secret with original |
| 106 | + merged_p = dict(new_p) |
| 107 | + merged_p["client_secret"] = old_secret |
| 108 | + merged.append(merged_p) |
| 109 | + else: |
| 110 | + # Secret is not masked (new value) or no old provider |
| 111 | + merged.append(new_p) |
| 112 | + |
| 113 | + return merged |
| 114 | + |
| 115 | + def save_or_update( |
| 116 | + self, |
| 117 | + enabled: bool, |
| 118 | + providers: List[Dict[str, Any]], |
| 119 | + admin_users: List[str], |
| 120 | + config_key: str = "global", |
| 121 | + ) -> OAuth2ConfigEntity: |
| 122 | + """Save or update OAuth2 config (stored in plain text, mask on display).""" |
| 123 | + from datetime import datetime |
| 124 | + |
| 125 | + with self.session() as session: |
| 126 | + entity = ( |
| 127 | + session.query(OAuth2ConfigEntity) |
| 128 | + .filter(OAuth2ConfigEntity.config_key == config_key) |
| 129 | + .first() |
| 130 | + ) |
| 131 | + |
| 132 | + # If entity exists, merge secrets to avoid overwriting with masked values |
| 133 | + if entity and entity.providers_json: |
| 134 | + try: |
| 135 | + old_providers = json.loads(entity.providers_json) |
| 136 | + providers = self._merge_secrets(providers, old_providers) |
| 137 | + except json.JSONDecodeError: |
| 138 | + pass |
| 139 | + |
| 140 | + # Store providers as plain JSON (client_secret included, unmasked) |
| 141 | + providers_json = json.dumps(providers, ensure_ascii=False) |
| 142 | + admin_users_json = json.dumps(admin_users, ensure_ascii=False) |
| 143 | + |
| 144 | + if entity: |
| 145 | + entity.enabled = 1 if enabled else 0 |
| 146 | + entity.providers_json = providers_json |
| 147 | + entity.admin_users_json = admin_users_json |
| 148 | + entity.gmt_modify = datetime.utcnow() |
| 149 | + else: |
| 150 | + entity = OAuth2ConfigEntity( |
| 151 | + config_key=config_key, |
| 152 | + enabled=1 if enabled else 0, |
| 153 | + providers_json=providers_json, |
| 154 | + admin_users_json=admin_users_json, |
| 155 | + gmt_create=datetime.utcnow(), |
| 156 | + gmt_modify=datetime.utcnow(), |
| 157 | + ) |
| 158 | + session.add(entity) |
| 159 | + |
| 160 | + session.commit() |
| 161 | + session.refresh(entity) |
| 162 | + return entity |
| 163 | + |
| 164 | + def _mask_providers_for_display( |
| 165 | + self, providers: List[Dict[str, Any]] |
| 166 | + ) -> List[Dict[str, Any]]: |
| 167 | + """Mask sensitive fields (client_secret) for display purposes. |
| 168 | +
|
| 169 | + This returns a copy with client_secret hidden (showing only first 4 chars). |
| 170 | + The actual secret remains in the database. |
| 171 | + """ |
| 172 | + if not providers: |
| 173 | + return [] |
| 174 | + |
| 175 | + masked = json.loads(json.dumps(providers)) |
| 176 | + for provider in masked: |
| 177 | + secret = provider.get("client_secret", "") |
| 178 | + if secret and len(secret) > 4: |
| 179 | + # Show first 4 chars, mask the rest |
| 180 | + provider["client_secret"] = secret[:4] + "****" |
| 181 | + elif secret: |
| 182 | + provider["client_secret"] = "****" |
| 183 | + return masked |
| 184 | + |
| 185 | + def get_config(self, config_key: str = "global", mask_secrets: bool = True) -> Optional[Dict[str, Any]]: |
| 186 | + """Get OAuth2 config from database. |
| 187 | +
|
| 188 | + Args: |
| 189 | + config_key: Configuration key (default: global) |
| 190 | + mask_secrets: If True, mask client_secret in providers for display |
| 191 | + """ |
| 192 | + entity = self.get_by_key(config_key) |
| 193 | + if not entity: |
| 194 | + return None |
| 195 | + |
| 196 | + try: |
| 197 | + admin_users = ( |
| 198 | + json.loads(entity.admin_users_json) |
| 199 | + if entity.admin_users_json |
| 200 | + else [] |
| 201 | + ) |
| 202 | + except json.JSONDecodeError: |
| 203 | + admin_users = [] |
| 204 | + |
| 205 | + try: |
| 206 | + providers = json.loads(entity.providers_json or "[]") |
| 207 | + except json.JSONDecodeError: |
| 208 | + providers = [] |
| 209 | + |
| 210 | + # Mask secrets if requested (for display purposes) |
| 211 | + if mask_secrets: |
| 212 | + providers = self._mask_providers_for_display(providers) |
| 213 | + |
| 214 | + return { |
| 215 | + "enabled": bool(entity.enabled), |
| 216 | + "providers": providers, |
| 217 | + "admin_users": admin_users, |
| 218 | + } |
| 219 | + |
| 220 | + def get_config_with_secrets(self, config_key: str = "global") -> Optional[Dict[str, Any]]: |
| 221 | + """Get OAuth2 config with actual secrets (for internal use only).""" |
| 222 | + return self.get_config(config_key, mask_secrets=False) |
| 223 | + |
| 224 | + |
| 225 | +class OAuth2DbStorage: |
| 226 | + """High-level storage interface for OAuth2 config.""" |
| 227 | + |
| 228 | + def __init__(self): |
| 229 | + self._dao: Optional[OAuth2ConfigDao] = None |
| 230 | + |
| 231 | + @property |
| 232 | + def dao(self) -> OAuth2ConfigDao: |
| 233 | + if self._dao is None: |
| 234 | + self._dao = OAuth2ConfigDao() |
| 235 | + return self._dao |
| 236 | + |
| 237 | + def load(self, mask_secrets: bool = True) -> Optional[Dict[str, Any]]: |
| 238 | + """Load OAuth2 config from database.""" |
| 239 | + return self.dao.get_config("global", mask_secrets=mask_secrets) |
| 240 | + |
| 241 | + def load_with_secrets(self) -> Optional[Dict[str, Any]]: |
| 242 | + """Load OAuth2 config with actual secrets (for internal use only).""" |
| 243 | + return self.dao.get_config_with_secrets("global") |
| 244 | + |
| 245 | + def save(self, enabled: bool, providers: List[Dict], admin_users: List[str]) -> bool: |
| 246 | + """Save OAuth2 config to database.""" |
| 247 | + try: |
| 248 | + self.dao.save_or_update(enabled, providers, admin_users, "global") |
| 249 | + return True |
| 250 | + except Exception as e: |
| 251 | + logger.exception(f"Failed to save OAuth2 config: {e}") |
| 252 | + return False |
| 253 | + |
| 254 | + |
| 255 | +# Singleton instance |
| 256 | +_oauth2_storage: Optional[OAuth2DbStorage] = None |
| 257 | + |
| 258 | + |
| 259 | +def get_oauth2_db_storage() -> OAuth2DbStorage: |
| 260 | + """Get OAuth2 database storage singleton.""" |
| 261 | + global _oauth2_storage |
| 262 | + if _oauth2_storage is None: |
| 263 | + _oauth2_storage = OAuth2DbStorage() |
| 264 | + return _oauth2_storage |
0 commit comments