Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions diracx-core/src/diracx/core/config/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import asyncio
import logging
import os
from abc import ABCMeta, abstractmethod
from datetime import datetime, timezone
from pathlib import Path
Expand Down Expand Up @@ -162,7 +161,12 @@ def __init_subclass__(cls) -> None:

@classmethod
def create(cls):
    """Create a config source from the configured backend URL.

    The URL is read from the ``diracx_config_backend_url`` field of a freshly
    instantiated ``FactorySettings`` and handed to ``cls.create_from_url``.
    """
    # Avoid circular import
    from diracx.core.settings import FactorySettings

    return cls.create_from_url(
        backend_url=FactorySettings().diracx_config_backend_url
    )

@classmethod
def create_from_url(
Expand Down
143 changes: 143 additions & 0 deletions diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Self, TypeVar, cast

import dotenv
from aiobotocore.session import get_session
from botocore.config import Config
from botocore.errorfactory import ClientError
Expand All @@ -32,16 +33,21 @@
SecretStr,
TypeAdapter,
UrlConstraints,
create_model,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict

from .config.sources import ConfigSourceUrl
from .extensions import DiracEntryPoint, select_from_extension
from .properties import SecurityProperty
from .s3 import s3_bucket_exists
from .utils import dotenv_files_from_environment

if TYPE_CHECKING:
from types_aiobotocore_s3.client import S3Client


T = TypeVar("T")


Expand Down Expand Up @@ -358,3 +364,140 @@ def s3_client(self) -> S3Client:
if self._client is None:
raise RuntimeError("S3 client accessed before lifetime function")
return self._client


def _dynamic_settings_model(
    model_name: str,
    doc: str,
    field_definitions: dict[str, tuple[Any, Any]],
) -> type[ServiceSettingsBase]:
    """Create a frozen settings model from dynamically generated fields.

    Args:
        model_name: Name given to the generated model class.
        doc: Docstring attached to the generated model class.
        field_definitions: Mapping of field name to ``(type, Field(...))``
            tuples, forwarded to ``create_model`` as keyword arguments.

    Returns:
        The newly created model class.
    """

    class _FrozenSettingsBase(ServiceSettingsBase):
        model_config = SettingsConfigDict(
            frozen=True,
            use_attribute_docstrings=True,
        )

    return create_model(
        model_name,
        __doc__=doc,
        __base__=_FrozenSettingsBase,
        **cast(dict[str, Any], field_definitions),
    )


def _build_factory_settings_model() -> type[ServiceSettingsBase]:
    """Build the ``FactorySettings`` model from the registered entry points.

    Three nested settings models are generated, one field per entry point:

    * ``EnabledServices``: a ``bool`` per ``DiracEntryPoint.SERVICES`` entry
      (entries whose name contains ``"well-known"`` are skipped), read from
      ``DIRACX_SERVICE_<NAME>_ENABLED`` and defaulting to ``True``.
    * ``OpenSearchDBSettings``: a ``str`` per ``DiracEntryPoint.OS_DB`` entry,
      read from ``DIRACX_OS_DB_<NAME>`` and defaulting to ``""``.
    * ``SqlDBSettings``: a ``str`` per ``DiracEntryPoint.SQL_DB`` entry,
      read from ``DIRACX_DB_URL_<NAME>`` and defaulting to ``""``.

    Returns:
        The ``FactorySettings`` model class, which embeds the three models
        above next to the statically declared fields.
    """
    service_fields: dict[str, tuple[Any, Any]] = {}
    for entry_point in select_from_extension(group=DiracEntryPoint.SERVICES):
        service_name = entry_point.name
        if "well-known" in service_name:
            continue
        service_fields[service_name] = (
            bool,
            Field(
                default=True,
                validation_alias=f"DIRACX_SERVICE_{service_name.upper()}_ENABLED",
                description=f"Enable the {service_name.upper()} router",
            ),
        )

    opensearch_db_fields: dict[str, tuple[Any, Any]] = {}
    for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
        db_name = entry_point.name
        opensearch_db_fields[db_name] = (
            str,
            Field(
                default="",
                validation_alias=f"DIRACX_OS_DB_{db_name.upper()}",
                description="A JSON-encoded dictionary of connection keyword arguments"
                f" for the OpenSearch database {db_name}.",
            ),
        )

    sql_db_fields: dict[str, tuple[Any, Any]] = {}
    for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
        db_name = entry_point.name
        sql_db_fields[db_name] = (
            str,
            Field(
                default="",
                validation_alias=f"DIRACX_DB_URL_{db_name.upper()}",
                description=f"The URL for the SQL database {db_name}.",
            ),
        )

    # These names are used as field annotations in the class body below, so
    # they must stay bound as locals of this function under exactly these names.
    EnabledServices = _dynamic_settings_model(  # noqa: N806
        "EnabledServices", "Enabled services", service_fields
    )
    OpenSearchDBSettings = _dynamic_settings_model(  # noqa: N806
        "OpenSearchDBSettings", "OpenSearch database settings", opensearch_db_fields
    )
    SqlDBSettings = _dynamic_settings_model(  # noqa: N806
        "SqlDBSettings", "SQL database settings", sql_db_fields
    )

    class _BaseFactorySettings(ServiceSettingsBase):
        """Factory settings.

        Settings which do not fit into dedicated classes,
        or are dynamically generated.
        """

        model_config = SettingsConfigDict(use_attribute_docstrings=True)

        diracx_config_backend_url: ConfigSourceUrl
        """The URL of the configuration backend.
        """

        diracx_legacy_exchange_hashed_api_key: str = ""
        """The hashed API key for the legacy exchange endpoint.
        """

        enabled_services: EnabledServices = Field(
            default_factory=EnabledServices,
            description="""The following environment variables dictates which routers are enabled.""",
        )

        opensearch_dbs: OpenSearchDBSettings = Field(
            default_factory=OpenSearchDBSettings,
            description="""The following environment variables configure the OpenSearch database connections.""",
        )

        sql_dbs: SqlDBSettings = Field(
            default_factory=SqlDBSettings,
            description="""The following environment variables configure the SQL database connections.""",
        )

        @model_validator(mode="before")
        @classmethod
        def load_dotenv_files(cls, data: Any) -> Any:
            """Load the dotenv files listed by ``DIRACX_SERVICE_DOTENV``.

            The incoming ``data`` is returned untouched; the validator is
            only used for its side effect of calling ``dotenv.load_dotenv``.

            Raises:
                NotImplementedError: If ``dotenv.load_dotenv`` reports that a
                    file could not be loaded.
            """
            # NOTE(review): confirm that this "before" validator runs early
            # enough for the top-level fields (e.g. diracx_config_backend_url)
            # to see values coming from the dotenv files -- TODO verify against
            # the order in which pydantic-settings reads the environment.
            for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
                if not dotenv.load_dotenv(env_file):
                    raise NotImplementedError(f"Could not load dotenv file {env_file}")
            return data

    # Re-export the class under its public name; no extra fields are added.
    return create_model(
        "FactorySettings",
        __doc__=_BaseFactorySettings.__doc__,
        __base__=_BaseFactorySettings,
    )


# Built once, when this module is imported: the generated fields reflect the
# entry points returned by select_from_extension at that moment.
FactorySettings = _build_factory_settings_model()
28 changes: 18 additions & 10 deletions diracx-db/src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import contextlib
import json
import logging
import os
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator
from contextvars import ContextVar
Expand All @@ -14,6 +13,7 @@

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.settings import FactorySettings
from diracx.db.exceptions import DBUnavailableError

logger = logging.getLogger(__name__)
Expand All @@ -38,7 +38,8 @@ class BaseOSDB(metaclass=ABCMeta):
This method returns a dictionary of database names to connection parameters.
The available databases are determined by the `diracx.dbs.os` entrypoint in
the `pyproject.toml` file and the connection parameters are taken from the
`opensearch_dbs` field in FactorySettings, which reads from environment variables
prefixed with `DIRACX_OS_DB_{DB_NAME}`.

If extensions to DiracX are being used, there can be multiple implementations
of the same database. To list the available implementations use
Expand Down Expand Up @@ -104,19 +105,26 @@ def available_implementations(cls, db_name: str) -> list[type[BaseOSDB]]:
def available_urls(cls) -> dict[str, dict[str, Any]]:
    """Return a dict of available OpenSearch database urls.

    The list of available URLs is determined by the opensearch_dbs field
    in FactorySettings, which reads from environment variables
    prefixed with ``DIRACX_OS_DB_{DB_NAME}``.

    Each non-empty setting is decoded with ``json.loads``; databases whose
    setting is missing or empty are left out of the result.

    Raises:
        Exception: Any error raised while decoding a setting is logged and
            re-raised unchanged.
    """
    opensearch_dbs = FactorySettings().opensearch_dbs

    conn_kwargs: dict[str, dict[str, Any]] = {}
    for entry_point in select_from_extension(group=DiracEntryPoint.OS_DB):
        db_name = entry_point.name
        # An unset database keeps the default empty string, which is skipped.
        raw_kwargs = getattr(opensearch_dbs, db_name, None)
        if not raw_kwargs:
            continue
        try:
            conn_kwargs[db_name] = json.loads(raw_kwargs)
        except Exception:
            logger.error("Error loading connection parameters for %s", db_name)
            raise
    return conn_kwargs

@classmethod
Expand Down
66 changes: 37 additions & 29 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import logging
import os
import re
from abc import ABCMeta
from collections.abc import AsyncIterator
Expand Down Expand Up @@ -53,8 +52,9 @@ class BaseSQLDB(metaclass=ABCMeta):
The available databases are discovered by calling `BaseSQLDB.available_urls`.
This method returns a mapping of database names to connection URLs. The
available databases are determined by the `diracx.dbs.sql` entrypoint in the
`pyproject.toml` file and the connection URLs are taken from the
`sql_dbs` field in FactorySettings, which reads from environment variables
of the form `DIRACX_DB_URL_<db-name>`.

If extensions to DiracX are being used, there can be multiple implementations
of the same database. To list the available implementations use
Expand Down Expand Up @@ -125,37 +125,45 @@ def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
def available_urls(cls) -> dict[str, str]:
    """Return a dict of available database urls.

    The list of available URLs is determined by the sql_dbs field
    in FactorySettings, which reads from environment variables
    prefixed with ``DIRACX_DB_URL_{DB_NAME}``.

    Databases whose setting is missing or empty are left out of the result.
    Every other URL is validated against ``SqlalchemyDsn``, except the
    in-memory SQLite URL which is accepted as-is.

    Raises:
        Exception: Any error raised while validating a URL is logged and
            re-raised unchanged.
    """
    from diracx.core.settings import FactorySettings

    sql_dbs = FactorySettings().sql_dbs
    # The adapter does not depend on the database, so build it only once.
    dsn_adapter = TypeAdapter(SqlalchemyDsn)

    db_urls: dict[str, str] = {}
    for entry_point in select_from_extension(group=DiracEntryPoint.SQL_DB):
        db_name = entry_point.name
        # An unset database keeps the default empty string, which is skipped.
        db_url = getattr(sql_dbs, db_name, None)
        if not db_url:
            continue
        try:
            scheme, separator, remainder = db_url.partition(":")
            if db_url == "sqlite+aiosqlite:///:memory:":
                db_urls[db_name] = db_url
            # pydantic does not allow for underscore in scheme
            # so we do a special case
            elif "_" in scheme:
                # Validate the URL with a fake schema, and then store
                # the original one
                fake_url = scheme.replace("_", "-") + separator + remainder
                dsn_adapter.validate_python(fake_url)
                db_urls[db_name] = db_url
            else:
                db_urls[db_name] = str(dsn_adapter.validate_python(db_url))
        except Exception:
            logger.error("Error loading URL for %s", db_name)
            raise
    return db_urls

@classmethod
Expand Down
Loading
Loading