From 6d5fb11be68258cfe931e0dc95648efe7facd44a Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Wed, 22 Apr 2026 18:20:14 +0200 Subject: [PATCH 1/7] fix: adjust rss logic and db to our needs --- diracx-db/src/diracx/db/sql/rss/db.py | 60 +++++++++++++++----- diracx-db/tests/rss/test_rss_db.py | 25 +++++--- diracx-logic/src/diracx/logic/rss/query.py | 66 +++++++++++++--------- 3 files changed, 102 insertions(+), 49 deletions(-) diff --git a/diracx-db/src/diracx/db/sql/rss/db.py b/diracx-db/src/diracx/db/sql/rss/db.py index a891995f6..1e91d4205 100644 --- a/diracx-db/src/diracx/db/sql/rss/db.py +++ b/diracx-db/src/diracx/db/sql/rss/db.py @@ -1,5 +1,7 @@ from __future__ import annotations +from datetime import datetime + from sqlalchemy import select from sqlalchemy.engine import Row @@ -18,31 +20,31 @@ class ResourceStatusDB(BaseSQLDB): metadata = RSSBase.metadata - async def get_site_status(self, name: str, vo: str = "all") -> tuple[str, str]: - stmt = select(SiteStatus.status, SiteStatus.reason).where( - SiteStatus.name == name, + async def get_site_statuses(self, vo: str = "all") -> list[tuple[str, str, str]]: + stmt = select(SiteStatus.name, SiteStatus.status, SiteStatus.reason).where( SiteStatus.status_type == "all", SiteStatus.vo == vo, ) result = await self.conn.execute(stmt) - row = result.one_or_none() - if not row: - raise ResourceNotFoundError(name) + rows = result.all() + if not rows: + raise ResourceNotFoundError(f"Site statuses for VO {vo}") - return row.Status, row.Reason + return [(row.Name, row.Status, row.Reason) for row in rows] - async def get_resource_status( + async def get_resource_statuses( self, - name: str, status_types: list[str] | None = None, vo: str = "all", - ) -> dict[str, Row]: + ) -> dict[str, dict[str, Row]]: if not status_types: status_types = ["all"] stmt = select( - ResourceStatus.status, ResourceStatus.reason, ResourceStatus.status_type + ResourceStatus.name, + ResourceStatus.status, + ResourceStatus.reason, + ResourceStatus.status_type, ).where( - ResourceStatus.name == name, ResourceStatus.status_type.in_(status_types), ResourceStatus.vo == vo, ) @@ -50,5 +52,35 @@ async def get_resource_status( rows = result.all() if not rows: - raise ResourceNotFoundError(name) - return {row.StatusType: row for row in rows} + raise ResourceNotFoundError(f"Resource statuses for VO {vo}") + statuses: dict[str, dict[str, Row]] = {} + for row in rows: + if row.Name not in statuses: + statuses[row.Name] = {} + statuses[row.Name][row.StatusType] = row + return statuses + + async def get_status_date( + self, + status_types: list[str] | None = None, + vo: str = "all", + ) -> Row[tuple[datetime, datetime]]: + if not status_types: + status_types = ["all"] + stmt = ( + select( + ResourceStatus.date_effective, + ResourceStatus.last_check_time, + ) + .where( + ResourceStatus.status_type.in_(status_types), + ResourceStatus.vo == vo, + ) + .order_by(ResourceStatus.date_effective.desc()) # the most recent date + .limit(1) + ) + result = await self.conn.execute(stmt) + row = result.first() + if not row: + raise ResourceNotFoundError(f"Resource statuses for VO {vo}") + return row diff --git a/diracx-db/tests/rss/test_rss_db.py b/diracx-db/tests/rss/test_rss_db.py index 2956801ee..84332b73f 100644 --- a/diracx-db/tests/rss/test_rss_db.py +++ b/diracx-db/tests/rss/test_rss_db.py @@ -41,14 +41,17 @@ async def test_site_status(rss_db: ResourceStatusDB): # Test with the test Site (should be found) async with rss_db as db: - status, reason = await db.get_site_status("TestSite") + rows = await db.get_site_statuses() + assert rows + name, status, reason = rows[0] + assert name == "TestSite" assert status == "Active" assert reason == "All good" # Test with an unknow Site (should not be found) with pytest.raises(ResourceNotFoundError): async with rss_db as db: - await db.get_site_status("Unknown") + await db.get_site_statuses("Unknown") async def test_resource_status(rss_db: ResourceStatusDB): @@ -102,34 +105,38 @@ async def test_resource_status(rss_db: ResourceStatusDB): # Test with the test Compute Element (should be found) async with rss_db as db: - result = await db.get_resource_status("TestCompute") + result = await db.get_resource_statuses() + assert "TestCompute" in result + result = result["TestCompute"] assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" # Test with the test FTS (should be found) async with rss_db as db: - result = await db.get_resource_status("TestFTS") + result = await db.get_resource_statuses() + assert "TestFTS" in result + result = result["TestFTS"] assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" # Test with the test Storage Element (should be found) async with rss_db as db: - result = await db.get_resource_status( - "TestStorage", ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + result = await db.get_resource_statuses( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] ) - assert set(result.keys()) == { + assert set(result["TestStorage"].keys()) == { "ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess", } - for row in result.values(): + for row in result["TestStorage"].values(): assert row.Status == "Active" assert row.Reason == "All good" # Test with an unknow Resource (should not be found) with pytest.raises(ResourceNotFoundError): async with rss_db as db: - await db.get_resource_status("Unknown") + await db.get_resource_statuses(vo="Unknown") diff --git a/diracx-logic/src/diracx/logic/rss/query.py b/diracx-logic/src/diracx/logic/rss/query.py index 7ce83cbce..0c32034a0 100644 --- a/diracx-logic/src/diracx/logic/rss/query.py +++ b/diracx-logic/src/diracx/logic/rss/query.py @@ -35,36 +35,50 @@ def map_status(db_status: str, reason: str | None = None) -> ResourceStatus: ) -async def get_site_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> SiteStatusModel: - status, reason = await resource_status_db.get_site_status(name, vo) - return SiteStatusModel(all=map_status(status, reason)) +async def get_site_statuses( + resource_status_db: ResourceStatusDB, vo: str +) -> dict[str, SiteStatusModel]: + rows = await resource_status_db.get_site_statuses(vo) + return { + name: SiteStatusModel(all=map_status(status, reason)) + for name, status, reason in rows + } -async def get_compute_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> ComputeElementStatus: - rows = await resource_status_db.get_resource_status(name, ["all"], vo) - return ComputeElementStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) +async def get_compute_statuses( + resource_status_db: ResourceStatusDB, vo: str +) -> dict[str, ComputeElementStatus]: + all_rows = await resource_status_db.get_resource_statuses(["all"], vo) + return { + name: ComputeElementStatus( + all=map_status(rows["all"].Status, rows["all"].Reason) + ) + for name, rows in all_rows.items() + } -async def get_fts_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> FTSStatus: - rows = await resource_status_db.get_resource_status(name, ["all"], vo) - return FTSStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) +async def get_fts_statuses( + resource_status_db: ResourceStatusDB, vo: str +) -> dict[str, FTSStatus]: + all_rows = await resource_status_db.get_resource_statuses(["all"], vo) + return { + name: FTSStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) + for name, rows in all_rows.items() + } -async def get_storage_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> StorageElementStatus: - rows = await resource_status_db.get_resource_status( - name, ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], vo - ) - return StorageElementStatus( - read=map_status(rows["ReadAccess"].Status, rows["ReadAccess"].Reason), - write=map_status(rows["WriteAccess"].Status, rows["WriteAccess"].Reason), - check=map_status(rows["CheckAccess"].Status, rows["CheckAccess"].Reason), - remove=map_status(rows["RemoveAccess"].Status, rows["RemoveAccess"].Reason), +async def get_storage_statuses( + resource_status_db: ResourceStatusDB, vo: str +) -> dict[str, StorageElementStatus]: + all_rows = await resource_status_db.get_resource_statuses( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], vo ) + return { + name: StorageElementStatus( + read=map_status(rows["ReadAccess"].Status, rows["ReadAccess"].Reason), + write=map_status(rows["WriteAccess"].Status, rows["WriteAccess"].Reason), + check=map_status(rows["CheckAccess"].Status, rows["CheckAccess"].Reason), + remove=map_status(rows["RemoveAccess"].Status, rows["RemoveAccess"].Reason), + ) + for name, rows in all_rows.items() + } From 0c30e905df21902462af7f5ad3d808a0f00f837a Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Tue, 5 May 2026 14:56:27 +0200 Subject: [PATCH 2/7] feat(draft): add ResourceStatusSource and rss router --- .../src/diracx/core/config/__init__.py | 2 + diracx-core/src/diracx/core/config/sources.py | 90 +++++++- diracx-core/src/diracx/core/models/rss.py | 20 +- diracx-core/tests/test_status_source.py | 189 ++++++++++++++++ diracx-routers/pyproject.toml | 2 +- diracx-routers/src/diracx/routers/rss.py | 207 ++++++++++++++++++ diracx-routers/tests/test_rss.py | 123 +++++++++++ 7 files changed, 625 insertions(+), 8 deletions(-) create mode 100644 diracx-core/tests/test_status_source.py create mode 100644 diracx-routers/src/diracx/routers/rss.py create mode 100644 diracx-routers/tests/test_rss.py diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 35d5fa4e9..5dadca9c2 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -13,6 +13,7 @@ "OperationsConfig", "RegistryConfig", "RemoteGitConfigSource", + "ResourceStatusSource", "SerializableSet", "SupportInfo", "UserConfig", @@ -35,5 +36,6 @@ ConfigSourceUrl, LocalGitConfigSource, RemoteGitConfigSource, + ResourceStatusSource, is_running_in_async_context, ) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py index f16fffa82..879869e6f 100644 --- a/diracx-core/src/diracx/core/config/sources.py +++ b/diracx-core/src/diracx/core/config/sources.py @@ -12,7 +12,7 @@ from datetime import datetime, timezone from pathlib import Path from tempfile import TemporaryDirectory -from typing import Annotated, Generic, TypeVar +from typing import Annotated, Generic, Literal, TypeVar, Union from urllib.parse import urlparse, urlunparse import sh @@ -20,9 +20,22 @@ from cachetools import Cache, LRUCache from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints -from diracx.core.exceptions import BadConfigurationVersionError +from diracx.core.exceptions import BadConfigurationVersionError, ResourceNotFoundError from diracx.core.extensions import DiracEntryPoint, select_from_extension +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + StorageElementStatus, +) from diracx.core.utils import TwoLevelCache +from diracx.db.sql.rss.db import ResourceStatusDB +from diracx.logic.rss.query import ( + get_compute_statuses, + get_fts_statuses, + get_site_statuses, + get_storage_statuses, +) from .schema import Config @@ -305,3 +318,76 @@ def latest_revision(self) -> tuple[str, datetime]: logger.exception(err) return super().latest_revision() + + +class ResourceStatusSource( + CacheableSource[ + dict[str, StorageElementStatus] + | dict[str, ComputeElementStatus] + | dict[str, SiteStatus] + | dict[str, FTSStatus] + ] +): + """A source that provides the status of a resource.""" + + def __init__( + self, + *, + resource_type: Literal["ComputeElement", "StorageElement", "Site", "FTS"], + vo: str = "all", + resource_status_db: ResourceStatusDB, + ) -> None: + self.resource_type = resource_type + self.vo = vo + self.resource_status_db = resource_status_db + super().__init__() + + def latest_revision(self) -> tuple[str, datetime]: + """Return the latest revision of the resource status. + + This could be a hash of the current status snapshot or the max DateEffective/LastCheckTime. + """ + # Fetch the resource status from the database + status_date = asyncio.run(self.resource_status_db.get_status_date(vo=self.vo)) + # Generate a unique hash for the current status snapshot + status_hash = hash(frozenset(status_date)) + latest_revision = f"rev_{status_hash}" + + modified = status_date.DateEffective + + return latest_revision, modified + + def read_raw( + self, hexsha: str, modified: datetime + ) -> Union[ + dict[str, StorageElementStatus], + dict[str, ComputeElementStatus], + dict[str, SiteStatus], + dict[str, FTSStatus], + ]: + """Read the raw resource status from the database.""" + # Fetch the resource status from the database + status_data = asyncio.run(self.get_status_data()) + for status in status_data.values(): + status._hexsha = hexsha + status._modified = modified + return status_data + + async def get_status_data(self): + status_data = None + if self.resource_type == "Site": + status_data = await get_site_statuses(self.resource_status_db, vo=self.vo) + elif self.resource_type == "ComputeElement": + status_data = await get_compute_statuses( + self.resource_status_db, vo=self.vo + ) + elif self.resource_type == "StorageElement": + status_data = await get_storage_statuses( + self.resource_status_db, vo=self.vo + ) + elif self.resource_type == "FTS": + status_data = await get_fts_statuses(self.resource_status_db, vo=self.vo) + + if not status_data: + raise ResourceNotFoundError(f"Resource type {self.resource_type} not found") + return status_data diff --git a/diracx-core/src/diracx/core/models/rss.py b/diracx-core/src/diracx/core/models/rss.py index 6f022b38f..913d43cd0 100644 --- a/diracx-core/src/diracx/core/models/rss.py +++ b/diracx-core/src/diracx/core/models/rss.py @@ -1,9 +1,19 @@ from __future__ import annotations +from datetime import datetime from enum import StrEnum from typing import Annotated, Literal, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr + + +class CachedModel(BaseModel): + """Base class for models that are cached.""" + + # hash for a unique representation of the status version + _hexsha: str = PrivateAttr() + # modification date + _modified: datetime = PrivateAttr() class AllowedStatus(BaseModel): @@ -34,22 +44,22 @@ class ResourceType(StrEnum): FTS = "FTS" -class StorageElementStatus(BaseModel): +class StorageElementStatus(CachedModel): read: ResourceStatus write: ResourceStatus check: ResourceStatus remove: ResourceStatus -class ComputeElementStatus(BaseModel): +class ComputeElementStatus(CachedModel): all: ResourceStatus -class FTSStatus(BaseModel): +class FTSStatus(CachedModel): all: ResourceStatus -class SiteStatus(BaseModel): +class SiteStatus(CachedModel): all: ResourceStatus diff --git a/diracx-core/tests/test_status_source.py b/diracx-core/tests/test_status_source.py new file mode 100644 index 000000000..99312c012 --- /dev/null +++ b/diracx-core/tests/test_status_source.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from collections import namedtuple +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from diracx.core.config.sources import ResourceStatusSource +from diracx.core.exceptions import ResourceNotFoundError +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + StorageElementStatus, +) +from diracx.db.sql.rss.db import ResourceStatusDB + + +@pytest.fixture +def mock_resource_status_db(): + """Fixture to mock the ResourceStatusDB.""" + db = MagicMock(spec=ResourceStatusDB) + DateRow = namedtuple("DateRow", ["DateEffective", "DateChecked"]) + db.get_status_date = AsyncMock( + return_value=DateRow( + DateEffective=datetime.fromisoformat("2023-01-01T00:00:00+00:00"), + DateChecked=datetime.now(timezone.utc), + ) + ) + return db + + +def test_latest_revision(mock_resource_status_db): + """Test the latest_revision method of ResourceStatusSource.""" + source = ResourceStatusSource( + resource_type="Site", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the method + revision, modified = source.latest_revision() + + # Verify the revision is generated correctly + assert revision.startswith("rev_") + assert isinstance(modified, datetime) + + # Verify the database call + mock_resource_status_db.get_status_date.assert_called_once_with(vo="test_vo") + + +def test_read_raw_site(mock_resource_status_db): + """Test the read_raw method for Site resource type.""" + # Mock the database data + mock_db_data = [("testSite", "Active", "")] + + # Patch the get_site_statuses method of the database to return the mock data + mock_resource_status_db.get_site_statuses = AsyncMock(return_value=mock_db_data) + + # Initialize the ResourceStatusSource with the mocked database + source = ResourceStatusSource( + resource_type="Site", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the read_raw method, which internally calls get_site_statuses from query.py + result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result matches the expected output + expected_result = {"testSite": SiteStatus(all={"allowed": True, "warnings": None})} + for key, value in expected_result.items(): + assert key in result + assert value.model_dump() == result[key].model_dump() + # Verify that the database method was called correctly + mock_resource_status_db.get_site_statuses.assert_awaited_once_with("test_vo") + + +def test_read_raw_compute(mock_resource_status_db): + """Test the read_raw method for ComputeElement resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + + mock_db_data = { + "TestCE": {"all": ResourceStatus(Name="TestCE", Status="Active", Reason="")} + } + mock_resource_status_db.get_resource_statuses = AsyncMock(return_value=mock_db_data) + + source = ResourceStatusSource( + resource_type="ComputeElement", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the method + result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = { + "TestCE": ComputeElementStatus(all={"allowed": True, "warnings": None}) + } + for key, value in expected_result.items(): + assert key in result + assert value.model_dump() == result[key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( + ["all"], "test_vo" + ) + + +def test_read_raw_storage(mock_resource_status_db): + """Test the read_raw method for StorageElement resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + + mock_db_data = { + "TestSE": { + "ReadAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), + "WriteAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), + "CheckAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), + "RemoveAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), + } + } + mock_resource_status_db.get_resource_statuses.return_value = mock_db_data + source = ResourceStatusSource( + resource_type="StorageElement", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the method + result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = { + "TestSE": StorageElementStatus( + read={"allowed": True, "warnings": None}, + write={"allowed": True, "warnings": None}, + check={"allowed": True, "warnings": None}, + remove={"allowed": True, "warnings": None}, + ) + } + for key, value in expected_result.items(): + assert key in result + assert value.model_dump() == result[key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], "test_vo" + ) + + +def test_read_raw_fts(mock_resource_status_db): + """Test the read_raw method for FTS resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + + mock_db_data = { + "FTS": { + "all": ResourceStatus(Name="FTS", Status="Active", Reason=None), + } + } + mock_resource_status_db.get_resource_statuses.return_value = mock_db_data + + source = ResourceStatusSource( + resource_type="FTS", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the method + result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = {"FTS": FTSStatus(all={"allowed": True, "warnings": None})} + for key, value in expected_result.items(): + assert key in result + assert value.model_dump() == result[key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( + ["all"], "test_vo" + ) + + +def test_read_raw_invalid_resource_type(mock_resource_status_db): + """Test the read_raw method for an invalid resource type.""" + source = ResourceStatusSource( + resource_type="InvalidType", + vo="test_vo", + resource_status_db=mock_resource_status_db, + ) + + # Call the method and verify it raises ResourceNotFoundError + with pytest.raises(ResourceNotFoundError): + source.read_raw("test_revision", datetime.now(tz=timezone.utc)) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index d798b1696..7e721eac2 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -47,6 +47,7 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +rss = "diracx.routers.rss:router" [project.entry-points."diracx.access_policies"] wms = "diracx.routers.jobs.access_policies:WMSAccessPolicy" @@ -87,5 +88,4 @@ markers = [ "enabled_dependencies: List of dependencies which should be available to the FastAPI test client", ] - asyncio_default_fixture_loop_scope = "function" diff --git a/diracx-routers/src/diracx/routers/rss.py b/diracx-routers/src/diracx/routers/rss.py new file mode 100644 index 000000000..68e525a4c --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Annotated, Literal, Union, cast + +from fastapi import ( + Depends, + Header, + HTTPException, + Query, + Response, + status, +) + +from diracx.core.config.sources import ResourceStatusSource +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + StorageElementStatus, +) +from diracx.db.sql.rss.db import ResourceStatusDB +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from .fastapi_classes import DiracxRouter + +logger = logging.getLogger(__name__) + +LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" + +router = DiracxRouter() + +# Keep track of the ResourceStatusSource instances +resource_status_sources: dict[tuple[str, str], ResourceStatusSource] = {} + + +# Override the ResourceStatusSource dependency to use +async def get_resource_status_source( + resource_status_db: ResourceStatusDB, + resource_type: Literal[ + "ComputeElement", "StorageElement", "Site", "FTS" + ] = "StorageElement", + vo: str = "all", +) -> ResourceStatusSource: + key = (resource_type, vo) + if key not in resource_status_sources: + logger.debug(f"Creating new ResourceStatusSource for {key}") + resource_status_sources[key] = ResourceStatusSource( + resource_type=resource_type, + vo=vo, + resource_status_db=resource_status_db, + ) + # populate the cache + resource_status_sources[key].read() + else: + logger.debug(f"Reusing existing ResourceStatusSource for {key}") + return resource_status_sources[key] + + +async def get_resource_status( + response: Response, + resource_type: str, + vo: str, + resource_status_db: ResourceStatusDB, + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> Union[ + dict[str, StorageElementStatus], + dict[str, ComputeElementStatus], + dict[str, SiteStatus], + dict[str, FTSStatus], +]: + """Get the latest status of resources. + + If If-None-Match header is given and matches the latest ETag, return 304 + + If If-Modified-Since is given and is newer than latest, + return 304: this is to avoid flip/flopping + """ + resource_status_source = await get_resource_status_source( + resource_type=resource_type, + vo=vo, + resource_status_db=resource_status_db, + ) + status_data = await resource_status_source.read_non_blocking() + + last_modified = max(val._modified for val in status_data.values()) + + headers = { + "ETag": list(status_data.values())[0]._hexsha, + "Last-Modified": last_modified.strftime(LAST_MODIFIED_FORMAT), + } + + if if_none_match == list(status_data.values())[0]._hexsha: + raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) + + # This is to prevent flip/flopping in case + # a server gets out of sync with disk + if if_modified_since: + try: + not_before = datetime.strptime( + if_modified_since, LAST_MODIFIED_FORMAT + ).astimezone(timezone.utc) + except ValueError: + logger.debug( + "Failed to parse If-Modified-Since header: %s", if_modified_since + ) + else: + if not_before > last_modified: + raise HTTPException( + status_code=status.HTTP_304_NOT_MODIFIED, headers=headers + ) + + response.headers.update(headers) + + return status_data + + +@router.get("/storage") +async def get_storage_status( + response: Response, + resource_status_db: ResourceStatusDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +): + """Get the latest status of storage elements.""" + return cast( + dict[str, StorageElementStatus], + await get_resource_status( + response=response, + resource_type="StorageElement", + resource_status_db=resource_status_db, + vo=user_info.vo, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ), + ) + + +@router.get("/compute") +async def get_compute_status( + response: Response, + resource_status_db: ResourceStatusDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +): + """Get the latest status of compute elements.""" + return cast( + dict[str, ComputeElementStatus], + await get_resource_status( + response=response, + resource_type="ComputeElement", + resource_status_db=resource_status_db, + vo=user_info.vo, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ), + ) + + +@router.get("/site") +async def get_site_status( + response: Response, + resource_status_db: ResourceStatusDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + vo: Annotated[str | None, Query()] = None, + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +): + """Get the latest status of sites.""" + return cast( + dict[str, SiteStatus], + await get_resource_status( + response=response, + resource_type="Site", + resource_status_db=resource_status_db, + vo=user_info.vo, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ), + ) + + +@router.get("/fts") +async def get_fts_status( + response: Response, + resource_status_db: ResourceStatusDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + vo: Annotated[str | None, Query()] = None, + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +): + """Get the latest status of FTS servers.""" + return cast( + dict[str, FTSStatus], + await get_resource_status( + response=response, + resource_type="FTS", + resource_status_db=resource_status_db, + vo=user_info.vo, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ), + ) diff --git a/diracx-routers/tests/test_rss.py b/diracx-routers/tests/test_rss.py new file mode 100644 index 000000000..1779e1dfa --- /dev/null +++ b/diracx-routers/tests/test_rss.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import pytest +from fastapi import status + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ResourceStatusDB", + ] +) + + +@pytest.fixture +def normal_user_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +def test_unauthenticated(client_factory): + with client_factory.unauthenticated() as client: + response = client.get("/api/rss/storage") + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +@pytest.mark.parametrize( + "endpoint", + ["/api/rss/storage", "/api/rss/compute", "/api/rss/site", "/api/rss/fts"], +) +def test_get_resource_status(normal_user_client, endpoint): + r = normal_user_client.get(endpoint) + assert r.status_code == status.HTTP_200_OK, r.json() + assert r.json(), r.text + + last_modified = r.headers["Last-Modified"] + etag = r.headers["ETag"] + + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": last_modified, + }, + ) + + assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert not r.text + + # If only an invalid ETAG is passed, we expect a response + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "wrongEtag", + }, + ) + assert r.status_code == status.HTTP_200_OK, r.json() + assert r.json(), r.text + + # If an past ETAG and an past timestamp as give, we expect an response + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "pastEtag", + "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", + }, + ) + assert r.status_code == status.HTTP_200_OK, r.json() + assert r.json(), r.text + + # If an future ETAG and an new timestamp as give, we expect 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "futureEtag", + "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", + }, + ) + assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert not r.text + + # If an invalid ETAG and an invalid modified time, we expect a response + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "futureEtag", + "If-Modified-Since": "wrong format", + }, + ) + assert r.status_code == status.HTTP_200_OK, r.json() + assert r.json(), r.text + + # If the correct ETAG and a past timestamp as give, we expect 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", + }, + ) + assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert not r.text + + # If the correct ETAG and a new timestamp as give, we expect 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", + }, + ) + assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert not r.text + + # If the correct ETAG and an invalid modified time, we expect 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "wrong format", + }, + ) + assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert not r.text From c5b09e88a18bf69278f99b3f0340d96f223687d7 Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Tue, 12 May 2026 15:44:40 +0200 Subject: [PATCH 3/7] refactor(draft): make CacheableSource async native --- .../src/diracx/core/config/__init__.py | 4 - diracx-core/src/diracx/core/config/sources.py | 260 +++++++----------- diracx-core/tests/test_status_source.py | 2 +- diracx-logic/src/diracx/logic/rss/source.py | 118 ++++++++ diracx-routers/src/diracx/routers/factory.py | 47 ++++ diracx-routers/src/diracx/routers/rss.py | 250 +++++++++-------- diracx-routers/tests/test_rss.py | 5 +- diracx-testing/src/diracx/testing/utils.py | 3 + 8 files changed, 411 insertions(+), 278 deletions(-) create mode 100644 diracx-logic/src/diracx/logic/rss/source.py diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 5dadca9c2..eb941ce15 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -13,11 +13,9 @@ "OperationsConfig", "RegistryConfig", "RemoteGitConfigSource", - "ResourceStatusSource", "SerializableSet", "SupportInfo", "UserConfig", - "is_running_in_async_context", ] from .schema import ( @@ -36,6 +34,4 @@ ConfigSourceUrl, LocalGitConfigSource, RemoteGitConfigSource, - ResourceStatusSource, - is_running_in_async_context, ) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py index 879869e6f..5cf3816ac 100644 --- a/diracx-core/src/diracx/core/config/sources.py +++ b/diracx-core/src/diracx/core/config/sources.py @@ -9,10 +9,11 @@ import logging import os from abc import ABCMeta, abstractmethod +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from tempfile import TemporaryDirectory -from typing import Annotated, Generic, Literal, TypeVar, Union +from typing import Annotated, Generic, TypeVar from urllib.parse import urlparse, urlunparse import sh @@ -20,22 +21,9 @@ from cachetools import Cache, LRUCache from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints -from diracx.core.exceptions import BadConfigurationVersionError, ResourceNotFoundError +from diracx.core.exceptions import BadConfigurationVersionError from diracx.core.extensions import DiracEntryPoint, select_from_extension -from diracx.core.models.rss import ( - ComputeElementStatus, - FTSStatus, - SiteStatus, - StorageElementStatus, -) from diracx.core.utils import TwoLevelCache -from diracx.db.sql.rss.db import ResourceStatusDB -from diracx.logic.rss.query import ( - get_compute_statuses, - get_fts_statuses, - get_site_statuses, - get_storage_statuses, -) from .schema import Config @@ -49,14 +37,6 @@ logger = logging.getLogger(__name__) -def is_running_in_async_context(): - try: - asyncio.get_running_loop() - return True - except RuntimeError: - return False - - def _apply_default_scheme(value: str) -> str: """Apply the default git+file:// scheme if not present.""" if "://" not in value: @@ -73,77 +53,115 @@ class AnyUrlWithoutHost(AnyUrl): T = TypeVar("T") +@dataclass(frozen=True) +class Snapshot(Generic[T]): + """Wraps a cached data payload with its cache metadata. + + Decouples cache plumbing (hexsha, modified) from the data models themselves, + replacing the old CachedModel._hexsha / _modified private-attribute pattern. + """ + + data: T + hexsha: str + modified: datetime + + class CacheableSource(Generic[T], metaclass=ABCMeta): - """Abstract base class for sources that can be cached. + """Abstract base class for async sources that can be cached. Handles the caching of the latest revision and its content using a two-level cache. + Subclasses implement async `latest_revision` and `read_raw`; the base class + provides the single-flight refresh logic via asyncio.Event and asyncio.Task. """ def __init__(self): - # Revision cache is used to store the latest revision and its - # modification date. This cache has two TTLs, one which triggers the - # background refresh and the other which is results in a hard failure. - # This allows us to avoid blocking while the refresh is done, while - # maintaining strong guarantees on the data freshness. + # Revision cache stores (hexsha → content) with two TTLs. + # soft_ttl: triggers a background refresh while serving the stale value. + # hard_ttl: absolute deadline; missing it causes a hard miss (await refresh). self._revision_cache = TwoLevelCache( soft_ttl=DEFAULT_CS_REV_CACHE_SOFT_TTL, hard_ttl=DEFAULT_CS_REV_CACHE_HARD_TTL, max_workers=1, max_items=1, ) - # The content of a given revision can be stored in a simple LRU cache - # We keep the last two versions in memory to avoid any potential to flip - # flop between two versions when it changes. + # Keep the last two content versions so there is no flip-flop during a transition. self._content_cache: Cache = LRUCache(maxsize=2) - @abstractmethod - def latest_revision(self) -> tuple[str, datetime]: - """Abstract method. + # Single-flight refresh state: at most one Task is in flight at a time. + self._refresh_task: asyncio.Task | None = None + self._refresh_lock = asyncio.Lock() - Must return: - * a unique hash as a string, representing the last version - * a datetime object corresponding to when the version dates. - """ + @abstractmethod + async def latest_revision(self) -> tuple[str, datetime]: + """Return (hexsha, modified) for the current revision.""" @abstractmethod - def read_raw(self, hexsha: str, modified: datetime) -> T: - """Abstract method. + async def read_raw(self, hexsha: str, modified: datetime) -> T: + """Fetch and return the data for *hexsha*.""" - Return the Source object that corresponds to the specific hash - The `modified` parameter is just added as a attribute to the source. + async def _refresh(self) -> str: + """Fetch the latest revision and populate the content cache. + + Returns the hexsha so callers can look up self._content_cache[hexsha]. """ + hexsha, modified = await self.latest_revision() + if hexsha not in self._content_cache: + self._content_cache[hexsha] = await self.read_raw(hexsha, modified) + return hexsha - def read(self) -> T: - """Load the source from the backend with appropriate caching. + async def _ensure_refresh_task(self) -> asyncio.Task: + """Start a background refresh task if one is not already running.""" + async with self._refresh_lock: + if self._refresh_task is None or self._refresh_task.done(): + self._refresh_task = asyncio.create_task(self._refresh()) + return self._refresh_task + + async def read(self) -> T: + """Load the source with caching; awaits a refresh on a hard cache miss. - :raises: diracx.core.exceptions.NotReadyError if the source is being loaded still :raises: git.exc.BadName if version does not exist """ hexsha = self._revision_cache.get( - "latest_revision", self._read_work, blocking=True + "latest_revision", self._sync_refresh_shim, blocking=True ) return self._content_cache[hexsha] async def read_non_blocking(self) -> T: - """Load the source from the backend with appropriate caching. + """Load the source with caching; raises NotReadyError while a refresh is in flight. - :raises: diracx.core.exceptions.NotReadyError if the source is being loaded still - :raises: git.exc.BadName if version does not exist + Triggers a background refresh when the soft TTL has expired so that + subsequent requests benefit from fresh data without paying the latency now. + + :raises: diracx.core.exceptions.NotReadyError if the cache is cold """ - hexsha = self._revision_cache.get( - "latest_revision", self._read_work, blocking=False - ) - return self._content_cache[hexsha] + # Try the revision cache first (non-blocking). On a soft-miss or hard-miss + # we kick off an async background refresh and either serve stale or raise. + try: + hexsha = self._revision_cache.get( + "latest_revision", self._sync_refresh_shim, blocking=False + ) + return self._content_cache[hexsha] + except KeyError: + # The revision cache returned a hexsha not yet in the content cache — + # shouldn't happen in normal operation; treat as not-ready. + pass + + # Hard miss: nothing in either cache yet. Start (or reuse) a background + # refresh task and raise NotReadyError so the router can serve a 503. + asyncio.create_task(self._ensure_refresh_task()) + from diracx.core.exceptions import NotReadyError - def _read_work(self) -> str: - """Work function for the thread pool of `self._revision_cache`. + raise NotReadyError("Cache is not yet populated; a refresh is in progress.") - This function ensures that the latest revision is loaded into the - content cache before it is admitted into the revision cache. + def _sync_refresh_shim(self) -> str: + """Synchronous shim used by TwoLevelCache's thread-pool worker. + + Runs the async _refresh coroutine on the running event loop via + run_coroutine_threadsafe so the engine's loop is respected. """ - hexsha, modified = self.latest_revision() - if hexsha not in self._content_cache: - self._content_cache[hexsha] = self.read_raw(hexsha, modified) + loop = asyncio.get_event_loop() + future = asyncio.run_coroutine_threadsafe(self._refresh(), loop) + hexsha = future.result() return hexsha def clear_caches(self): @@ -199,22 +217,27 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: self.remote_url = self.extract_remote_url(backend_url) self.git_revision = self.get_git_revision_from_url(backend_url) - def latest_revision(self) -> tuple[str, datetime]: + async def latest_revision(self) -> tuple[str, datetime]: + """Return the latest git revision hash and its commit timestamp.""" try: - rev = sh.git( - "rev-parse", - self.git_revision, - _cwd=self.repo_location, - _tty_out=False, - _async=is_running_in_async_context(), + rev = ( + await asyncio.to_thread( + sh.git, + "rev-parse", + self.git_revision, + _cwd=self.repo_location, + _tty_out=False, + ) ).strip() - commit_info = sh.git.show( - "-s", - "--format=%ct", - rev, - _cwd=self.repo_location, - _tty_out=False, - _async=is_running_in_async_context(), + commit_info = ( + await asyncio.to_thread( + sh.git.show, + "-s", + "--format=%ct", + rev, + _cwd=self.repo_location, + _tty_out=False, + ) ).strip() modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) except sh.ErrorReturnCode as e: @@ -224,15 +247,15 @@ def latest_revision(self) -> tuple[str, datetime]: logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) return rev, modified - def read_raw(self, hexsha: str, modified: datetime) -> Config: + async def read_raw(self, hexsha: str, modified: datetime) -> Config: """:param: hexsha commit hash""" logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified) try: - blob = sh.git.show( + blob = await asyncio.to_thread( + sh.git.show, f"{hexsha}:{DEFAULT_CONFIG_FILE}", _cwd=self.repo_location, _tty_out=False, - _async=False, ) raw_obj = yaml.safe_load(blob) except sh.ErrorReturnCode as e: @@ -244,8 +267,6 @@ def read_raw(self, hexsha: str, modified: datetime) -> Config: group=DiracEntryPoint.CORE, name="config" )[0].load() config = config_class.model_validate(raw_obj) - config._hexsha = hexsha - config._modified = modified return config def extract_remote_url(self, backend_url: ConfigSourceUrl) -> str: @@ -310,84 +331,11 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: def __hash__(self): return hash(self.repo_location) - def latest_revision(self) -> tuple[str, datetime]: + async def latest_revision(self) -> tuple[str, datetime]: logger.debug("Pulling latest version from %s", self) try: - sh.git.pull(_cwd=self.repo_location, _async=False) + await asyncio.to_thread(sh.git.pull, _cwd=self.repo_location) except sh.ErrorReturnCode as err: logger.exception(err) - return super().latest_revision() - - -class ResourceStatusSource( - CacheableSource[ - dict[str, StorageElementStatus] - | dict[str, ComputeElementStatus] - | dict[str, SiteStatus] - | dict[str, FTSStatus] - ] -): - """A source that provides the status of a resource.""" - - def __init__( - self, - *, - resource_type: Literal["ComputeElement", "StorageElement", "Site", "FTS"], - vo: str = "all", - resource_status_db: ResourceStatusDB, - ) -> None: - self.resource_type = resource_type - self.vo = vo - self.resource_status_db = resource_status_db - super().__init__() - - def latest_revision(self) -> tuple[str, datetime]: - """Return the latest revision of the resource status. - - This could be a hash of the current status snapshot or the max DateEffective/LastCheckTime. - """ - # Fetch the resource status from the database - status_date = asyncio.run(self.resource_status_db.get_status_date(vo=self.vo)) - # Generate a unique hash for the current status snapshot - status_hash = hash(frozenset(status_date)) - latest_revision = f"rev_{status_hash}" - - modified = status_date.DateEffective - - return latest_revision, modified - - def read_raw( - self, hexsha: str, modified: datetime - ) -> Union[ - dict[str, StorageElementStatus], - dict[str, ComputeElementStatus], - dict[str, SiteStatus], - dict[str, FTSStatus], - ]: - """Read the raw resource status from the database.""" - # Fetch the resource status from the database - status_data = asyncio.run(self.get_status_data()) - for status in status_data.values(): - status._hexsha = hexsha - status._modified = modified - return status_data - - async def get_status_data(self): - status_data = None - if self.resource_type == "Site": - status_data = await get_site_statuses(self.resource_status_db, vo=self.vo) - elif self.resource_type == "ComputeElement": - status_data = await get_compute_statuses( - self.resource_status_db, vo=self.vo - ) - elif self.resource_type == "StorageElement": - status_data = await get_storage_statuses( - self.resource_status_db, vo=self.vo - ) - elif self.resource_type == "FTS": - status_data = await get_fts_statuses(self.resource_status_db, vo=self.vo) - - if not status_data: - raise ResourceNotFoundError(f"Resource type {self.resource_type} not found") - return status_data + return await super().latest_revision() diff --git a/diracx-core/tests/test_status_source.py b/diracx-core/tests/test_status_source.py index 99312c012..33f0b786d 100644 --- a/diracx-core/tests/test_status_source.py +++ b/diracx-core/tests/test_status_source.py @@ -6,7 +6,6 @@ import pytest -from diracx.core.config.sources import ResourceStatusSource from diracx.core.exceptions import ResourceNotFoundError from diracx.core.models.rss import ( ComputeElementStatus, @@ -15,6 +14,7 @@ StorageElementStatus, ) from diracx.db.sql.rss.db import ResourceStatusDB +from diracx.logic.rss.source import ResourceStatusSource @pytest.fixture diff --git a/diracx-logic/src/diracx/logic/rss/source.py b/diracx-logic/src/diracx/logic/rss/source.py new file mode 100644 index 000000000..c4f15b963 --- /dev/null +++ b/diracx-logic/src/diracx/logic/rss/source.py @@ -0,0 +1,118 @@ +"""Resource Status System source classes. + +These classes sit in the logic layer (diracx-logic) so they can import from +diracx-db without violating the project's dependency flow: + + routers → logic → db → core + +`CacheableSource` (the abstraction) and `Snapshot` live in diracx-core; +the concrete implementations live here because they need diracx-db. +""" + +from __future__ import annotations + +import logging +from datetime import datetime + +from diracx.core.config.sources import CacheableSource, Snapshot +from diracx.db.sql.rss.db import ResourceStatusDB + +from .query import ( + get_compute_statuses, + get_fts_statuses, + get_site_statuses, + get_storage_statuses, +) + +logger = logging.getLogger(__name__) + + +class ResourceStatusSource(CacheableSource[Snapshot]): + """Caching source for Compute, Storage, and FTS resource statuses. + + Holds a long-lived reference to the app-level ``ResourceStatusDB`` instance + (created in ``factory.py`` alongside the engine). Each refresh uses + ``async with self._db`` so that ``__aenter__`` runs — which sets the + ``ContextVar`` connection — on the *same* event loop the engine is bound to. + + One source covers *all* VOs for a given resource type. VO-level filtering + is done in the route after the snapshot is fetched, keeping the cache simple + and avoiding N redundant poll schedules for the same underlying table. + """ + + def __init__(self, *, db: ResourceStatusDB, resource_type: str) -> None: + """ + Args: + db: Long-lived ``ResourceStatusDB`` instance from factory.py. + Must already have ``engine_context`` registered (i.e. its engine + is open for the application lifetime). + resource_type: One of ``"ComputeElement"``, ``"StorageElement"``, + ``"FTS"``. + """ + super().__init__() + self._db = db + self._resource_type = resource_type + + async def latest_revision(self) -> tuple[str, datetime]: + """Query the max DateEffective for this resource type across all VOs. + + Uses ``modified.isoformat()`` as the ETag so the value is deterministic + across replicas (unlike ``hash()``, which is randomised by PYTHONHASHSEED). + """ + async with self._db as db: + status_date = await db.get_resource_status_date( + resource_type=self._resource_type + ) + modified: datetime = status_date.DateEffective + return modified.isoformat(), modified + + async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: + """Fetch all statuses for this resource type and wrap them in a Snapshot.""" + async with self._db as db: + data = await self._fetch(db) + return Snapshot(data=data, hexsha=hexsha, modified=modified) + + async def _fetch(self, db: ResourceStatusDB): + """Dispatch to the appropriate query helper based on resource type.""" + if self._resource_type == "ComputeElement": + return await get_compute_statuses(db, vo="all") + if self._resource_type == "StorageElement": + return await get_storage_statuses(db, vo="all") + if self._resource_type == "FTS": + return await get_fts_statuses(db, vo="all") + raise ValueError(f"Unsupported resource_type: {self._resource_type!r}") + + +class SiteStatusSource(CacheableSource[Snapshot]): + """Caching source for Site statuses. + + Sites have a first-class status row in their own table (``SiteStatus``), + independent of the per-resource ``ResourceStatus`` table that + ``ResourceStatusSource`` queries. They are also always stored with + ``vo="all"``, so no per-VO filtering is needed on the way out. + + Keeping this as a separate class (rather than a subtype of + ``ResourceStatusSource``) avoids conflating two different DB tables and + makes the ``latest_revision`` query correct for each. + """ + + def __init__(self, *, db: ResourceStatusDB) -> None: + """ + Args: + db: Long-lived ``ResourceStatusDB`` instance from factory.py. + """ + super().__init__() + self._db = db + + async def latest_revision(self) -> tuple[str, datetime]: + """Query the max DateEffective from the Site status table.""" + async with self._db as db: + status_date = await db.get_site_status_date() + modified: datetime = status_date.DateEffective + return modified.isoformat(), modified + + async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: + """Fetch all site statuses and wrap them in a Snapshot.""" + async with self._db as db: + data = await get_site_statuses(db, vo="all") + return Snapshot(data=data, hexsha=hexsha, modified=modified) diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py index 4c26f4c58..f121d601f 100644 --- a/diracx-routers/src/diracx/routers/factory.py +++ b/diracx-routers/src/diracx/routers/factory.py @@ -35,8 +35,11 @@ from diracx.core.utils import dotenv_files_from_environment from diracx.db.exceptions import DBUnavailableError from diracx.db.os.utils import BaseOSDB +from diracx.db.sql.rss.db import ResourceStatusDB from diracx.db.sql.utils import BaseSQLDB +from diracx.logic.rss.source import ResourceStatusSource, SiteStatusSource from diracx.routers.access_policies import BaseAccessPolicy, check_permissions +from diracx.routers.rss import RSSSnapshotSentinels from .fastapi_classes import DiracFastAPI, DiracxRouter from .otel import instrument_otel @@ -186,6 +189,9 @@ def create_app_inner( # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() + # Track the app-lifetime ResourceStatusDB instance so we can build sources below. + rss_db_instance: ResourceStatusDB | None = None + for db_name, db_url in database_urls.items(): try: sql_db_classes = BaseSQLDB.available_implementations(db_name) @@ -207,6 +213,12 @@ def create_app_inner( db_no_transaction, sql_db ) + # Capture the long-lived ResourceStatusDB instance for the RSS sources. + # We reuse this instance (not a per-request DI copy) so that the engine's + # event loop and the async context manager are managed correctly. + if isinstance(sql_db, ResourceStatusDB) and rss_db_instance is None: + rss_db_instance = sql_db + # At least one DB works, so we do not fail the startup fail_startup = False except Exception: @@ -215,6 +227,41 @@ def create_app_inner( if fail_startup: raise Exception("No SQL database could be initialized, aborting") + # --------------------------------------------------------------------------- + # Wire RSS sources via dependency_overrides — same pattern as ConfigSource.create. + # + # Each source holds a reference to the *app-lifetime* rss_db_instance; every + # refresh calls `async with self._db` so __aenter__ runs on the FastAPI event + # loop (the same loop the engine is bound to). One source per resource type + # covers all VOs; per-VO filtering is done in the route. + # --------------------------------------------------------------------------- + if rss_db_instance is not None: + compute_source = ResourceStatusSource( + db=rss_db_instance, resource_type="ComputeElement" + ) + storage_source = ResourceStatusSource( + db=rss_db_instance, resource_type="StorageElement" + ) + fts_source = ResourceStatusSource(db=rss_db_instance, resource_type="FTS") + site_source = SiteStatusSource(db=rss_db_instance) + + app.dependency_overrides[RSSSnapshotSentinels.get_compute_snapshot] = ( + compute_source.read_non_blocking + ) + app.dependency_overrides[RSSSnapshotSentinels.get_storage_snapshot] = ( + storage_source.read_non_blocking + ) + app.dependency_overrides[RSSSnapshotSentinels.get_fts_snapshot] = ( + fts_source.read_non_blocking + ) + app.dependency_overrides[RSSSnapshotSentinels.get_site_snapshot] = ( + site_source.read_non_blocking + ) + else: + logger.warning( + "ResourceStatusDB not found; RSS endpoints will not be available." + ) + # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() for db_name, connection_kwargs in os_database_conn_kwargs.items(): diff --git a/diracx-routers/src/diracx/routers/rss.py b/diracx-routers/src/diracx/routers/rss.py index 68e525a4c..833c19e83 100644 --- a/diracx-routers/src/diracx/routers/rss.py +++ b/diracx-routers/src/diracx/routers/rss.py @@ -1,26 +1,27 @@ from __future__ import annotations import logging +from collections.abc import Callable from datetime import datetime, timezone -from typing import Annotated, Literal, Union, cast +from enum import StrEnum, auto +from typing import Annotated, cast from fastapi import ( Depends, Header, HTTPException, - Query, Response, status, ) -from diracx.core.config.sources import ResourceStatusSource +from diracx.core.config.sources import Snapshot from diracx.core.models.rss import ( ComputeElementStatus, FTSStatus, SiteStatus, StorageElementStatus, ) -from diracx.db.sql.rss.db import ResourceStatusDB +from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token from .fastapi_classes import DiracxRouter @@ -31,72 +32,97 @@ router = DiracxRouter() -# Keep track of the ResourceStatusSource instances -resource_status_sources: dict[tuple[str, str], ResourceStatusSource] = {} - - -# Override the ResourceStatusSource dependency to use -async def get_resource_status_source( - resource_status_db: ResourceStatusDB, - resource_type: Literal[ - "ComputeElement", "StorageElement", "Site", "FTS" - ] = "StorageElement", - vo: str = "all", -) -> ResourceStatusSource: - key = (resource_type, vo) - if key not in resource_status_sources: - logger.debug(f"Creating new ResourceStatusSource for {key}") - resource_status_sources[key] = ResourceStatusSource( - resource_type=resource_type, - vo=vo, - resource_status_db=resource_status_db, - ) - # populate the cache - resource_status_sources[key].read() - else: - logger.debug(f"Reusing existing ResourceStatusSource for {key}") - return resource_status_sources[key] - - -async def get_resource_status( - response: Response, - resource_type: str, - vo: str, - resource_status_db: ResourceStatusDB, - if_none_match: Annotated[str | None, Header()] = None, - if_modified_since: Annotated[str | None, Header()] = None, -) -> Union[ - dict[str, StorageElementStatus], - dict[str, ComputeElementStatus], - dict[str, SiteStatus], - dict[str, FTSStatus], -]: - """Get the latest status of resources. - - If If-None-Match header is given and matches the latest ETag, return 304 - - If If-Modified-Since is given and is newer than latest, - return 304: this is to avoid flip/flopping + +# --------------------------------------------------------------------------- +# Access policy +# --------------------------------------------------------------------------- + + +class ActionType(StrEnum): + # Create a job or a sandbox + CREATE = auto() + # Check job status, download a sandbox + READ = auto() + # Delete, kill, remove, set status, etc of a job + # Delete or assign a sandbox + MANAGE = auto() + # Search + QUERY = auto() + # Actions from a pilot (e.g. heartbeat) + PILOT = auto() + + +class ResourceStatusAccessPolicy(BaseAccessPolicy): + """Policy: any authenticated user may READ resource statuses. + + Write/admin actions are rejected here; VO scoping is the route's responsibility. + Registered under ``[project.entry-points."diracx.access_policies"]`` in + ``diracx-routers/pyproject.toml`` so the framework can discover it. """ - resource_status_source = await get_resource_status_source( - resource_type=resource_type, - vo=vo, - resource_status_db=resource_status_db, - ) - status_data = await resource_status_source.read_non_blocking() - last_modified = max(val._modified for val in status_data.values()) + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + ): + if action != ActionType.READ: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Resource Status System is read-only.", + ) + # Any authenticated user may read; VO scoping happens in the route. + + +ResourceStatusAccessPolicyCallable = Annotated[ + Callable, Depends(ResourceStatusAccessPolicy.check) +] + +class RSSSnapshotSentinels: + @classmethod + def get_storage_snapshot(cls) -> Snapshot: + raise NotImplementedError + + @classmethod + def get_compute_snapshot(cls) -> Snapshot: + raise NotImplementedError + + @classmethod + def get_site_snapshot(cls) -> Snapshot: + raise NotImplementedError + + @classmethod + def get_fts_snapshot(cls) -> Snapshot: + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Shared ETag / 304 helper +# --------------------------------------------------------------------------- + + +def _apply_cache_headers( + response: Response, + snapshot: Snapshot, + if_none_match: str | None, + if_modified_since: str | None, +) -> None: + """Set ETag / Last-Modified headers and raise 304 when appropriate. + + Raises: + HTTPException(304): when the client's cached copy is still current. + """ headers = { - "ETag": list(status_data.values())[0]._hexsha, - "Last-Modified": last_modified.strftime(LAST_MODIFIED_FORMAT), + "ETag": snapshot.hexsha, + "Last-Modified": snapshot.modified.strftime(LAST_MODIFIED_FORMAT), } - if if_none_match == list(status_data.values())[0]._hexsha: + if if_none_match == snapshot.hexsha: raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) - # This is to prevent flip/flopping in case - # a server gets out of sync with disk if if_modified_since: try: not_before = datetime.strptime( @@ -107,101 +133,99 @@ async def get_resource_status( "Failed to parse If-Modified-Since header: %s", if_modified_since ) else: - if not_before > last_modified: + # Guard against flip-flop when a replica is momentarily behind. + if not_before > snapshot.modified: raise HTTPException( status_code=status.HTTP_304_NOT_MODIFIED, headers=headers ) response.headers.update(headers) - return status_data + +# --------------------------------------------------------------------------- +# Routes +# --------------------------------------------------------------------------- @router.get("/storage") async def get_storage_status( response: Response, - resource_status_db: ResourceStatusDB, + snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_storage_snapshot)], user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: ResourceStatusAccessPolicyCallable, if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, -): - """Get the latest status of storage elements.""" +) -> dict[str, StorageElementStatus]: + """Get the latest status of storage elements, scoped to the caller's VO.""" + await check_permissions(action=ActionType.READ) + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, StorageElementStatus], - await get_resource_status( - response=response, - resource_type="StorageElement", - resource_status_db=resource_status_db, - vo=user_info.vo, - if_none_match=if_none_match, - if_modified_since=if_modified_since, - ), + { + name: se + for name, se in snapshot.data.items() + if getattr(se, "vo", "all") in (user_info.vo, "all") + }, ) @router.get("/compute") async def get_compute_status( response: Response, - resource_status_db: ResourceStatusDB, + snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_compute_snapshot)], user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: ResourceStatusAccessPolicyCallable, if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, -): - """Get the latest status of compute elements.""" +) -> dict[str, ComputeElementStatus]: + """Get the latest status of compute elements, scoped to the caller's VO.""" + await check_permissions(action=ActionType.READ) + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, ComputeElementStatus], - await get_resource_status( - response=response, - resource_type="ComputeElement", - resource_status_db=resource_status_db, - vo=user_info.vo, - if_none_match=if_none_match, - if_modified_since=if_modified_since, - ), + { + name: ce + for name, ce in snapshot.data.items() + if getattr(ce, "vo", "all") in (user_info.vo, "all") + }, ) @router.get("/site") async def get_site_status( response: Response, - resource_status_db: ResourceStatusDB, + snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_site_snapshot)], user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - vo: Annotated[str | None, Query()] = None, + check_permissions: ResourceStatusAccessPolicyCallable, if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, -): - """Get the latest status of sites.""" - return cast( - dict[str, SiteStatus], - await get_resource_status( - response=response, - resource_type="Site", - resource_status_db=resource_status_db, - vo=user_info.vo, - if_none_match=if_none_match, - if_modified_since=if_modified_since, - ), - ) +) -> dict[str, SiteStatus]: + """Get the latest status of sites. + + Sites are always stored with vo="all" so no VO filtering is applied. + """ + await check_permissions(action=ActionType.READ) + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) + return cast(dict[str, SiteStatus], snapshot.data) @router.get("/fts") async def get_fts_status( response: Response, - resource_status_db: ResourceStatusDB, + snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_fts_snapshot)], user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - vo: Annotated[str | None, Query()] = None, + check_permissions: ResourceStatusAccessPolicyCallable, if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, -): - """Get the latest status of FTS servers.""" +) -> dict[str, FTSStatus]: + """Get the latest status of FTS servers, scoped to the caller's VO.""" + await check_permissions(action=ActionType.READ) + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, FTSStatus], - await get_resource_status( - response=response, - resource_type="FTS", - resource_status_db=resource_status_db, - vo=user_info.vo, - if_none_match=if_none_match, - if_modified_since=if_modified_since, - ), + { + name: fts + for name, fts in snapshot.data.items() + if getattr(fts, "vo", "all") in (user_info.vo, "all") + }, ) diff --git a/diracx-routers/tests/test_rss.py b/diracx-routers/tests/test_rss.py index 1779e1dfa..e3108aeed 100644 --- a/diracx-routers/tests/test_rss.py +++ b/diracx-routers/tests/test_rss.py @@ -4,10 +4,7 @@ from fastapi import status pytestmark = pytest.mark.enabled_dependencies( - [ - "AuthSettings", - "ResourceStatusDB", - ] + ["AuthSettings", "ResourceStatusDB", "RSSSnapshotSentinels"] ) diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 570241e69..bb8d99195 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -242,6 +242,8 @@ def enrich_tokens( all_access_policies=all_access_policies, ) + from diracx.routers.rss import RSSSnapshotSentinels + self.all_dependency_overrides = self.app.dependency_overrides.copy() self.app.dependency_overrides = {} for obj in self.all_dependency_overrides: @@ -253,6 +255,7 @@ def enrich_tokens( BaseOSDB, ConfigSource, BaseAccessPolicy, + RSSSnapshotSentinels, ), ), obj From 884619c8728352bdc4e8865498b7c05ff31dc211 Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Thu, 21 May 2026 17:54:04 +0200 Subject: [PATCH 4/7] feat: add AsyncCacheableSource and RSSAccessPolicy --- .../src/diracx/core/config/__init__.py | 6 + diracx-core/src/diracx/core/config/sources.py | 221 +++++++++-------- diracx-core/src/diracx/core/models/rss.py | 27 +- diracx-core/src/diracx/core/utils.py | 150 +++++++++++- diracx-db/src/diracx/db/sql/rss/db.py | 150 ++++++++++-- diracx-db/tests/rss/test_rss_db.py | 16 +- diracx-logic/src/diracx/logic/rss/query.py | 87 ++++--- diracx-logic/src/diracx/logic/rss/source.py | 131 +++++----- .../rss/{test_rss.py => test_rss_query.py} | 0 .../tests/rss/test_rss_source.py | 141 +++++------ diracx-routers/pyproject.toml | 1 + diracx-routers/src/diracx/routers/factory.py | 49 ++-- diracx-routers/src/diracx/routers/rss.py | 231 ------------------ .../src/diracx/routers/rss/__init__.py | 7 + .../src/diracx/routers/rss/access_policies.py | 27 ++ diracx-routers/src/diracx/routers/rss/rss.py | 143 +++++++++++ diracx-routers/tests/test_rss.py | 106 ++++++-- diracx-testing/src/diracx/testing/utils.py | 6 +- 18 files changed, 900 insertions(+), 599 deletions(-) rename diracx-logic/tests/rss/{test_rss.py => test_rss_query.py} (100%) rename diracx-core/tests/test_status_source.py => diracx-logic/tests/rss/test_rss_source.py (52%) delete mode 100644 diracx-routers/src/diracx/routers/rss.py create mode 100644 diracx-routers/src/diracx/routers/rss/__init__.py create mode 100644 diracx-routers/src/diracx/routers/rss/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/rss/rss.py diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index eb941ce15..15c0c4970 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations __all__ = [ + "AsyncCacheableSource", + "CacheableSource", "Config", "ConfigSource", "ConfigSourceUrl", @@ -16,6 +18,7 @@ "SerializableSet", "SupportInfo", "UserConfig", + "is_running_in_async_context", ] from .schema import ( @@ -30,8 +33,11 @@ UserConfig, ) from .sources import ( + AsyncCacheableSource, + CacheableSource, ConfigSource, ConfigSourceUrl, LocalGitConfigSource, RemoteGitConfigSource, + is_running_in_async_context, ) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py index 5cf3816ac..8072cb1ef 100644 --- a/diracx-core/src/diracx/core/config/sources.py +++ b/diracx-core/src/diracx/core/config/sources.py @@ -9,7 +9,6 @@ import logging import os from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from tempfile import TemporaryDirectory @@ -23,7 +22,7 @@ from diracx.core.exceptions import BadConfigurationVersionError from diracx.core.extensions import DiracEntryPoint, select_from_extension -from diracx.core.utils import TwoLevelCache +from diracx.core.utils import AsyncTwoLevelCache, TwoLevelCache from .schema import Config @@ -37,6 +36,14 @@ logger = logging.getLogger(__name__) +def is_running_in_async_context(): + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + def _apply_default_scheme(value: str) -> str: """Apply the default git+file:// scheme if not present.""" if "://" not in value: @@ -53,115 +60,77 @@ class AnyUrlWithoutHost(AnyUrl): T = TypeVar("T") -@dataclass(frozen=True) -class Snapshot(Generic[T]): - """Wraps a cached data payload with its cache metadata. - - Decouples cache plumbing (hexsha, modified) from the data models themselves, - replacing the old CachedModel._hexsha / _modified private-attribute pattern. - """ - - data: T - hexsha: str - modified: datetime - - class CacheableSource(Generic[T], metaclass=ABCMeta): - """Abstract base class for async sources that can be cached. + """Abstract base class for sources that can be cached. Handles the caching of the latest revision and its content using a two-level cache. - Subclasses implement async `latest_revision` and `read_raw`; the base class - provides the single-flight refresh logic via asyncio.Event and asyncio.Task. """ def __init__(self): - # Revision cache stores (hexsha → content) with two TTLs. - # soft_ttl: triggers a background refresh while serving the stale value. - # hard_ttl: absolute deadline; missing it causes a hard miss (await refresh). + # Revision cache is used to store the latest revision and its + # modification date. This cache has two TTLs, one which triggers the + # background refresh and the other which is results in a hard failure. + # This allows us to avoid blocking while the refresh is done, while + # maintaining strong guarantees on the data freshness. self._revision_cache = TwoLevelCache( soft_ttl=DEFAULT_CS_REV_CACHE_SOFT_TTL, hard_ttl=DEFAULT_CS_REV_CACHE_HARD_TTL, max_workers=1, max_items=1, ) - # Keep the last two content versions so there is no flip-flop during a transition. + # The content of a given revision can be stored in a simple LRU cache + # We keep the last two versions in memory to avoid any potential to flip + # flop between two versions when it changes. self._content_cache: Cache = LRUCache(maxsize=2) - # Single-flight refresh state: at most one Task is in flight at a time. - self._refresh_task: asyncio.Task | None = None - self._refresh_lock = asyncio.Lock() - @abstractmethod - async def latest_revision(self) -> tuple[str, datetime]: - """Return (hexsha, modified) for the current revision.""" + def latest_revision(self) -> tuple[str, datetime]: + """Abstract method. - @abstractmethod - async def read_raw(self, hexsha: str, modified: datetime) -> T: - """Fetch and return the data for *hexsha*.""" + Must return: + * a unique hash as a string, representing the last version + * a datetime object corresponding to when the version dates. + """ - async def _refresh(self) -> str: - """Fetch the latest revision and populate the content cache. + @abstractmethod + def read_raw(self, hexsha: str, modified: datetime) -> T: + """Abstract method. - Returns the hexsha so callers can look up self._content_cache[hexsha]. + Return the Source object that corresponds to the specific hash + The `modified` parameter is just added as a attribute to the source. """ - hexsha, modified = await self.latest_revision() - if hexsha not in self._content_cache: - self._content_cache[hexsha] = await self.read_raw(hexsha, modified) - return hexsha - async def _ensure_refresh_task(self) -> asyncio.Task: - """Start a background refresh task if one is not already running.""" - async with self._refresh_lock: - if self._refresh_task is None or self._refresh_task.done(): - self._refresh_task = asyncio.create_task(self._refresh()) - return self._refresh_task - - async def read(self) -> T: - """Load the source with caching; awaits a refresh on a hard cache miss. + def read(self) -> T: + """Load the source from the backend with appropriate caching. + :raises: diracx.core.exceptions.NotReadyError if the source is being loaded still :raises: git.exc.BadName if version does not exist """ hexsha = self._revision_cache.get( - "latest_revision", self._sync_refresh_shim, blocking=True + "latest_revision", self._read_work, blocking=True ) return self._content_cache[hexsha] async def read_non_blocking(self) -> T: - """Load the source with caching; raises NotReadyError while a refresh is in flight. - - Triggers a background refresh when the soft TTL has expired so that - subsequent requests benefit from fresh data without paying the latency now. + """Load the source from the backend with appropriate caching. - :raises: diracx.core.exceptions.NotReadyError if the cache is cold + :raises: diracx.core.exceptions.NotReadyError if the source is being loaded still + :raises: git.exc.BadName if version does not exist """ - # Try the revision cache first (non-blocking). On a soft-miss or hard-miss - # we kick off an async background refresh and either serve stale or raise. - try: - hexsha = self._revision_cache.get( - "latest_revision", self._sync_refresh_shim, blocking=False - ) - return self._content_cache[hexsha] - except KeyError: - # The revision cache returned a hexsha not yet in the content cache — - # shouldn't happen in normal operation; treat as not-ready. - pass - - # Hard miss: nothing in either cache yet. Start (or reuse) a background - # refresh task and raise NotReadyError so the router can serve a 503. - asyncio.create_task(self._ensure_refresh_task()) - from diracx.core.exceptions import NotReadyError - - raise NotReadyError("Cache is not yet populated; a refresh is in progress.") + hexsha = self._revision_cache.get( + "latest_revision", self._read_work, blocking=False + ) + return self._content_cache[hexsha] - def _sync_refresh_shim(self) -> str: - """Synchronous shim used by TwoLevelCache's thread-pool worker. + def _read_work(self) -> str: + """Work function for the thread pool of `self._revision_cache`. - Runs the async _refresh coroutine on the running event loop via - run_coroutine_threadsafe so the engine's loop is respected. + This function ensures that the latest revision is loaded into the + content cache before it is admitted into the revision cache. """ - loop = asyncio.get_event_loop() - future = asyncio.run_coroutine_threadsafe(self._refresh(), loop) - hexsha = future.result() + hexsha, modified = self.latest_revision() + if hexsha not in self._content_cache: + self._content_cache[hexsha] = self.read_raw(hexsha, modified) return hexsha def clear_caches(self): @@ -170,6 +139,55 @@ def clear_caches(self): self._content_cache.clear() +class AsyncCacheableSource(Generic[T], metaclass=ABCMeta): + """Abstract base class for async sources that can be cached. + + Async equivalent of CacheableSource. Uses AsyncTwoLevelCache so populate + functions are native coroutines. + """ + + def __init__(self): + self._revision_cache = AsyncTwoLevelCache( + soft_ttl=DEFAULT_CS_REV_CACHE_SOFT_TTL, + hard_ttl=DEFAULT_CS_REV_CACHE_HARD_TTL, + max_items=1, + ) + self._content_cache: Cache = LRUCache(maxsize=2) + + @abstractmethod + async def latest_revision(self) -> tuple[str, datetime]: + """Return (revision_str, modified) identifying the current revision.""" + + @abstractmethod + async def read_raw(self, hexsha: str, modified: datetime) -> T: + """Fetch and return the data for the given revision.""" + + async def _read_work(self) -> str: + hexsha, modified = await self.latest_revision() + if hexsha not in self._content_cache: + self._content_cache[hexsha] = await self.read_raw(hexsha, modified) + return hexsha + + async def read(self) -> T: + """Blocking read — awaits refresh on a hard cache miss.""" + hexsha = await self._revision_cache.get( + "latest_revision", self._read_work, blocking=True + ) + return self._content_cache[hexsha] + + async def read_non_blocking(self) -> T: + """Non-blocking read — raises NotReadyError on a hard cache miss.""" + hexsha = await self._revision_cache.get( + "latest_revision", self._read_work, blocking=False + ) + return self._content_cache[hexsha] + + async def clear_caches(self): + """Clear the caches.""" + await self._revision_cache.clear() + self._content_cache.clear() + + class ConfigSource(CacheableSource[Config]): """Abstract class for the configuration source. @@ -217,27 +235,22 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: self.remote_url = self.extract_remote_url(backend_url) self.git_revision = self.get_git_revision_from_url(backend_url) - async def latest_revision(self) -> tuple[str, datetime]: - """Return the latest git revision hash and its commit timestamp.""" + def latest_revision(self) -> tuple[str, datetime]: try: - rev = ( - await asyncio.to_thread( - sh.git, - "rev-parse", - self.git_revision, - _cwd=self.repo_location, - _tty_out=False, - ) + rev = sh.git( + "rev-parse", + self.git_revision, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), ).strip() - commit_info = ( - await asyncio.to_thread( - sh.git.show, - "-s", - "--format=%ct", - rev, - _cwd=self.repo_location, - _tty_out=False, - ) + commit_info = sh.git.show( + "-s", + "--format=%ct", + rev, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), ).strip() modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) except sh.ErrorReturnCode as e: @@ -247,15 +260,15 @@ async def latest_revision(self) -> tuple[str, datetime]: logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) return rev, modified - async def read_raw(self, hexsha: str, modified: datetime) -> Config: + def read_raw(self, hexsha: str, modified: datetime) -> Config: """:param: hexsha commit hash""" logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified) try: - blob = await asyncio.to_thread( - sh.git.show, + blob = sh.git.show( f"{hexsha}:{DEFAULT_CONFIG_FILE}", _cwd=self.repo_location, _tty_out=False, + _async=False, ) raw_obj = yaml.safe_load(blob) except sh.ErrorReturnCode as e: @@ -267,6 +280,8 @@ async def read_raw(self, hexsha: str, modified: datetime) -> Config: group=DiracEntryPoint.CORE, name="config" )[0].load() config = config_class.model_validate(raw_obj) + config._hexsha = hexsha + config._modified = modified return config def extract_remote_url(self, backend_url: ConfigSourceUrl) -> str: @@ -331,11 +346,11 @@ def __init__(self, *, backend_url: ConfigSourceUrl) -> None: def __hash__(self): return hash(self.repo_location) - async def latest_revision(self) -> tuple[str, datetime]: + def latest_revision(self) -> tuple[str, datetime]: logger.debug("Pulling latest version from %s", self) try: - await asyncio.to_thread(sh.git.pull, _cwd=self.repo_location) + sh.git.pull(_cwd=self.repo_location, _async=False) except sh.ErrorReturnCode as err: logger.exception(err) - return await super().latest_revision() + return super().latest_revision() diff --git a/diracx-core/src/diracx/core/models/rss.py b/diracx-core/src/diracx/core/models/rss.py index 913d43cd0..10e1444a9 100644 --- a/diracx-core/src/diracx/core/models/rss.py +++ b/diracx-core/src/diracx/core/models/rss.py @@ -1,19 +1,22 @@ from __future__ import annotations +from dataclasses import dataclass from datetime import datetime from enum import StrEnum -from typing import Annotated, Literal, Union +from typing import Annotated, Generic, Literal, TypeVar, Union -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, Field +T = TypeVar("T") -class CachedModel(BaseModel): - """Base class for models that are cached.""" - # hash for a unique representation of the status version - _hexsha: str = PrivateAttr() - # modification date - _modified: datetime = PrivateAttr() +@dataclass(frozen=True) +class Snapshot(Generic[T]): + """Wraps a cached data payload with its cache metadata.""" + + data: T + hexsha: str + modified: datetime class AllowedStatus(BaseModel): @@ -44,22 +47,22 @@ class ResourceType(StrEnum): FTS = "FTS" -class StorageElementStatus(CachedModel): +class StorageElementStatus(BaseModel): read: ResourceStatus write: ResourceStatus check: ResourceStatus remove: ResourceStatus -class ComputeElementStatus(CachedModel): +class ComputeElementStatus(BaseModel): all: ResourceStatus -class FTSStatus(CachedModel): +class FTSStatus(BaseModel): all: ResourceStatus -class SiteStatus(CachedModel): +class SiteStatus(BaseModel): all: ResourceStatus diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index 8949d07f3..f1ddf0ee5 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -3,6 +3,7 @@ __all__ = [ "EXPIRES_GRACE_SECONDS", "TwoLevelCache", + "AsyncTwoLevelCache", "batched_async", "dotenv_files_from_environment", "read_credentials", @@ -11,6 +12,7 @@ "write_credentials", ] +import asyncio import fcntl import json import logging @@ -19,7 +21,7 @@ import stat import threading from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Coroutine from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime, timedelta, timezone from pathlib import Path @@ -293,6 +295,152 @@ def clear(self): self.locks.clear() +class AsyncTwoLevelCache: + """Async equivalent of TwoLevelCache, for use with async populate functions. + + Mirrors the two-TTL semantics of TwoLevelCache exactly: a soft TTL that + triggers a background refresh while still serving the stale value, and a + hard TTL beyond which a miss either awaits a fresh value (blocking=True) or + raises NotReadyError (blocking=False). + + The key difference from TwoLevelCache is that all coordination uses asyncio + primitives (asyncio.Lock, asyncio.Task) rather than a ThreadPoolExecutor, + so populate_func can be a native coroutine. + + Attributes: + soft_cache (TTLCache): A cache with a shorter TTL for quick access. + hard_cache (TTLCache): A cache with a longer TTL as a fallback. + tasks (dict): In-flight refresh Tasks keyed by cache key. + _lock (asyncio.Lock): Guards task creation to ensure single-flight behaviour. + + Args: + soft_ttl (int): Time-to-live in seconds for the soft cache. + hard_ttl (int): Time-to-live in seconds for the hard cache. + max_items (int): Maximum number of items in each cache tier. + + Example: + >>> cache = AsyncTwoLevelCache(soft_ttl=5, hard_ttl=3600) + >>> async def populate(): + ... return await some_db_query() + >>> value = await cache.get("key", populate) + + """ + + def __init__( + self, + soft_ttl: int, + hard_ttl: int, + *, + max_items: int = 1_000_000, + ): + """Initialize the AsyncTwoLevelCache with specified TTLs.""" + self.soft_cache: Cache = TTLCache(max_items, soft_ttl) + self.hard_cache: Cache = TTLCache(max_items, hard_ttl) + # One Task per key for single-flight refresh deduplication. + self.tasks: dict[str, asyncio.Task] = {} + # A single lock guards task creation across all keys. + # Per-key locks would be cleaner but require careful cleanup; + # contention here is minimal since task creation is very fast. + self._lock = asyncio.Lock() + + async def get( + self, + key: str, + populate_func: Callable[[], Coroutine[Any, Any, T]], + blocking: bool = True, + ) -> T: + """Retrieve a value from the cache, populating it if necessary. + + Checks the soft cache first. On a soft miss, kicks off a background + refresh and returns the stale hard-cache value if one exists. On a hard + miss, either awaits the refresh (blocking=True) or raises NotReadyError + (blocking=False). + + Args: + key (str): The cache key to retrieve or populate. + populate_func: An async callable (coroutine function) that returns + the value to cache. + blocking (bool): If True, wait for the populate_func to complete on + a hard miss. If False, raise NotReadyError instead. + + Returns: + The cached value associated with the key. + + """ + # Fast path: soft cache hit, no locking needed. + if key in self.soft_cache: + return self.soft_cache[key] + + async with self._lock: + # Re-check inside the lock in case another coroutine just populated it. + if key in self.soft_cache: + return self.soft_cache[key] + + # Ensure at most one refresh Task is in flight for this key. + if key not in self.tasks or self.tasks[key].done(): + self.tasks[key] = asyncio.create_task(self._work(key, populate_func)) + task = self.tasks[key] + + if key in self.hard_cache: + # Soft miss but hard hit: serve stale while the refresh runs. + # Pre-fill soft cache so the next request skips the lock entirely. + result = self.hard_cache[key] + self.soft_cache[key] = result + return result + + # Hard miss: no value in either cache yet. + if blocking: + # Await outside the lock so _work can acquire it to write results. + await task + return self.hard_cache[key] + + logger.debug( + "Cache key %r not ready yet, background population in progress", key + ) + raise NotReadyError(f"Cache key {key} is not ready yet.") + + async def _work( + self, key: str, populate_func: Callable[[], Coroutine[Any, Any, T]] + ) -> None: + """Await populate_func and write results into both cache tiers. + + Always removes the task entry so the next soft miss can schedule a fresh + refresh, regardless of whether this attempt succeeded or failed. + + Args: + key (str): The cache key to populate. + populate_func: Async callable that produces the value. + + """ + success = False + result = None + try: + result = await populate_func() + success = True + except Exception: + logger.error( + "Failed to populate cache key %r, will retry on next request", + key, + exc_info=True, + ) + raise + finally: + async with self._lock: + self.tasks.pop(key, None) + if success: + self.hard_cache[key] = result + self.soft_cache[key] = result + + async def clear(self): + """Cancel any in-flight refresh tasks and clear both cache tiers.""" + async with self._lock: + for task in self.tasks.values(): + task.cancel() + self.tasks.clear() + self.soft_cache.clear() + self.hard_cache.clear() + + async def batched_async( iterable: AsyncIterable[T], n: int, *, strict: bool = False ) -> AsyncIterable[tuple[T, ...]]: diff --git a/diracx-db/src/diracx/db/sql/rss/db.py b/diracx-db/src/diracx/db/sql/rss/db.py index 1e91d4205..7fe095da8 100644 --- a/diracx-db/src/diracx/db/sql/rss/db.py +++ b/diracx-db/src/diracx/db/sql/rss/db.py @@ -1,8 +1,8 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone -from sqlalchemy import select +from sqlalchemy import insert, select from sqlalchemy.engine import Row from diracx.core.exceptions import ResourceNotFoundError @@ -20,23 +20,40 @@ class ResourceStatusDB(BaseSQLDB): metadata = RSSBase.metadata - async def get_site_statuses(self, vo: str = "all") -> list[tuple[str, str, str]]: - stmt = select(SiteStatus.name, SiteStatus.status, SiteStatus.reason).where( - SiteStatus.status_type == "all", - SiteStatus.vo == vo, - ) + async def get_site_statuses(self) -> list[tuple[str, str, str, str]]: + """Return all site statuses across all VOs. + + Returns: + List of (name, status, reason, vo) tuples. + + """ + stmt = select( + SiteStatus.name, + SiteStatus.status, + SiteStatus.reason, + SiteStatus.vo, + ).where(SiteStatus.status_type == "all") result = await self.conn.execute(stmt) rows = result.all() if not rows: - raise ResourceNotFoundError(f"Site statuses for VO {vo}") + raise ResourceNotFoundError("Site statuses") - return [(row.Name, row.Status, row.Reason) for row in rows] + return [(row.Name, row.Status, row.Reason, row.VO) for row in rows] async def get_resource_statuses( self, status_types: list[str] | None = None, - vo: str = "all", ) -> dict[str, dict[str, Row]]: + """Return resource statuses for the given status types across all VOs. + + Args: + status_types: Status type filter (e.g. ["ReadAccess", "WriteAccess"]). + Defaults to ["all"]. + + Returns: + Nested dict keyed by resource name then status type. + + """ if not status_types: status_types = ["all"] stmt = select( @@ -44,15 +61,16 @@ async def get_resource_statuses( ResourceStatus.status, ResourceStatus.reason, ResourceStatus.status_type, + ResourceStatus.vo, ).where( ResourceStatus.status_type.in_(status_types), - ResourceStatus.vo == vo, ) result = await self.conn.execute(stmt) rows = result.all() if not rows: - raise ResourceNotFoundError(f"Resource statuses for VO {vo}") + raise ResourceNotFoundError("Resource statuses") + statuses: dict[str, dict[str, Row]] = {} for row in rows: if row.Name not in statuses: @@ -60,11 +78,19 @@ async def get_resource_statuses( statuses[row.Name][row.StatusType] = row return statuses - async def get_status_date( + async def get_resource_status_date( self, status_types: list[str] | None = None, - vo: str = "all", ) -> Row[tuple[datetime, datetime]]: + """Return the most recent DateEffective across all VOs for the given status types. + + Args: + status_types: Status type filter. Defaults to ["all"]. + + Returns: + Row with (date_effective, last_check_time) for the most recent entry. + + """ if not status_types: status_types = ["all"] stmt = ( @@ -72,15 +98,101 @@ async def get_status_date( ResourceStatus.date_effective, ResourceStatus.last_check_time, ) - .where( - ResourceStatus.status_type.in_(status_types), - ResourceStatus.vo == vo, + .where(ResourceStatus.status_type.in_(status_types)) + .order_by(ResourceStatus.date_effective.desc()) + .limit(1) + ) + result = await self.conn.execute(stmt) + row = result.first() + if not row: + raise ResourceNotFoundError("Resource statuses") + return row + + async def get_site_status_date(self) -> Row[tuple[datetime, datetime]]: + """Return the most recent DateEffective from the SiteStatus table across all VOs. + + Returns: + Row with (date_effective, last_check_time) for the most recent entry. + + """ + stmt = ( + select( + SiteStatus.date_effective, + SiteStatus.last_check_time, ) - .order_by(ResourceStatus.date_effective.desc()) # the most recent date + .where(SiteStatus.status_type == "all") + .order_by(SiteStatus.date_effective.desc()) .limit(1) ) result = await self.conn.execute(stmt) row = result.first() if not row: - raise ResourceNotFoundError(f"Resource statuses for VO {vo}") + raise ResourceNotFoundError("Site statuses") return row + + async def insert_resource_status( + self, + name: str, + status: str, + status_type: str, + vo: str, + reason: str = "", + date_effective: datetime | None = None, + last_check_time: datetime | None = None, + ) -> None: + """Insert a single ResourceStatus row. + + Args: + name: Resource name. + status: Status value. + status_type: One of "all", "ReadAccess", "WriteAccess", etc. + vo: Virtual organisation (e.g. "lhcb", "all"). + reason: Human-readable reason string. + date_effective: Timestamp when the status became effective. + Defaults to now. + last_check_time: Timestamp of last check. Defaults to now. + + """ + now = datetime.now(timezone.utc) + stmt = insert(ResourceStatus).values( + Name=name, + Status=status, + StatusType=status_type, + VO=vo, + Reason=reason, + DateEffective=date_effective or now, + LastCheckTime=last_check_time or now, + ) + await self.conn.execute(stmt) + + async def insert_site_status( + self, + name: str, + status: str, + vo: str, + reason: str = "", + date_effective: datetime | None = None, + last_check_time: datetime | None = None, + ) -> None: + """Insert a single SiteStatus row. + + Args: + name: Site name (e.g. "LCG.CERN.cern"). + status: Status value (e.g. "Active", "Banned"). + vo: Virtual organisation. + reason: Human-readable reason string. + date_effective: Defaults to now. + last_check_time: Defaults to now. + + """ + now = datetime.now(timezone.utc) + stmt = insert(SiteStatus).values( + Name=name, + Status=status, + StatusType="all", + VO=vo, + Reason=reason, + DateEffective=date_effective or now, + LastCheckTime=last_check_time or now, + ) + await self.conn.execute(stmt) diff --git a/diracx-db/tests/rss/test_rss_db.py b/diracx-db/tests/rss/test_rss_db.py index 84332b73f..6e199619f 100644 --- a/diracx-db/tests/rss/test_rss_db.py +++ b/diracx-db/tests/rss/test_rss_db.py @@ -5,7 +5,6 @@ import pytest from sqlalchemy import insert -from diracx.core.exceptions import ResourceNotFoundError from diracx.db.sql.rss.db import ResourceStatusDB _NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) @@ -43,15 +42,11 @@ async def test_site_status(rss_db: ResourceStatusDB): async with rss_db as db: rows = await db.get_site_statuses() assert rows - name, status, reason = rows[0] + name, status, reason, vo = rows[0] assert name == "TestSite" assert status == "Active" assert reason == "All good" - - # Test with an unknow Site (should not be found) - with pytest.raises(ResourceNotFoundError): - async with rss_db as db: - await db.get_site_statuses("Unknown") + assert vo == "all" async def test_resource_status(rss_db: ResourceStatusDB): @@ -111,6 +106,7 @@ async def test_resource_status(rss_db: ResourceStatusDB): assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" + assert result["all"].VO == "all" # Test with the test FTS (should be found) async with rss_db as db: @@ -120,6 +116,7 @@ async def test_resource_status(rss_db: ResourceStatusDB): assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" + assert result["all"].VO == "all" # Test with the test Storage Element (should be found) async with rss_db as db: @@ -135,8 +132,3 @@ async def test_resource_status(rss_db: ResourceStatusDB): for row in result["TestStorage"].values(): assert row.Status == "Active" assert row.Reason == "All good" - - # Test with an unknow Resource (should not be found) - with pytest.raises(ResourceNotFoundError): - async with rss_db as db: - await db.get_resource_statuses(vo="Unknown") diff --git a/diracx-logic/src/diracx/logic/rss/query.py b/diracx-logic/src/diracx/logic/rss/query.py index 0c32034a0..cd6a87d46 100644 --- a/diracx-logic/src/diracx/logic/rss/query.py +++ b/diracx-logic/src/diracx/logic/rss/query.py @@ -36,49 +36,80 @@ def map_status(db_status: str, reason: str | None = None) -> ResourceStatus: async def get_site_statuses( - resource_status_db: ResourceStatusDB, vo: str -) -> dict[str, SiteStatusModel]: - rows = await resource_status_db.get_site_statuses(vo) - return { - name: SiteStatusModel(all=map_status(status, reason)) - for name, status, reason in rows - } + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, SiteStatusModel]]: + """Fetch all site statuses across all VOs. + + The returned models carry the vo field so the router can filter to the + caller's VO from the cached all-VO snapshot. + """ + rows = await resource_status_db.get_site_statuses() + + result: dict[str, dict[str, SiteStatusModel]] = {} + + for name, status, reason, vo in rows: + vo = vo or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = SiteStatusModel(all=map_status(status, reason)) + + return result async def get_compute_statuses( - resource_status_db: ResourceStatusDB, vo: str -) -> dict[str, ComputeElementStatus]: - all_rows = await resource_status_db.get_resource_statuses(["all"], vo) - return { - name: ComputeElementStatus( + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, ComputeElementStatus]]: + """Fetch all compute element statuses across all VOs.""" + all_rows = await resource_status_db.get_resource_statuses(["all"]) + + result: dict[str, dict[str, ComputeElementStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["all"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = ComputeElementStatus( all=map_status(rows["all"].Status, rows["all"].Reason) ) - for name, rows in all_rows.items() - } + + return result async def get_fts_statuses( - resource_status_db: ResourceStatusDB, vo: str -) -> dict[str, FTSStatus]: - all_rows = await resource_status_db.get_resource_statuses(["all"], vo) - return { - name: FTSStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) - for name, rows in all_rows.items() - } + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, FTSStatus]]: + """Fetch all FTS server statuses across all VOs.""" + all_rows = await resource_status_db.get_resource_statuses(["all"]) + + result: dict[str, dict[str, FTSStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["all"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = FTSStatus( + all=map_status(rows["all"].Status, rows["all"].Reason) + ) + + return result async def get_storage_statuses( - resource_status_db: ResourceStatusDB, vo: str -) -> dict[str, StorageElementStatus]: + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, StorageElementStatus]]: + """Fetch all storage element statuses across all VOs.""" all_rows = await resource_status_db.get_resource_statuses( - ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], vo + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] ) - return { - name: StorageElementStatus( + + result: dict[str, dict[str, StorageElementStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["ReadAccess"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = StorageElementStatus( read=map_status(rows["ReadAccess"].Status, rows["ReadAccess"].Reason), write=map_status(rows["WriteAccess"].Status, rows["WriteAccess"].Reason), check=map_status(rows["CheckAccess"].Status, rows["CheckAccess"].Reason), remove=map_status(rows["RemoveAccess"].Status, rows["RemoveAccess"].Reason), ) - for name, rows in all_rows.items() - } + + return result diff --git a/diracx-logic/src/diracx/logic/rss/source.py b/diracx-logic/src/diracx/logic/rss/source.py index c4f15b963..e8e94f2db 100644 --- a/diracx-logic/src/diracx/logic/rss/source.py +++ b/diracx-logic/src/diracx/logic/rss/source.py @@ -1,20 +1,32 @@ """Resource Status System source classes. -These classes sit in the logic layer (diracx-logic) so they can import from -diracx-db without violating the project's dependency flow: +These classes live in diracx-logic so they can import from diracx-db without +violating the project's dependency flow: routers → logic → db → core -`CacheableSource` (the abstraction) and `Snapshot` live in diracx-core; -the concrete implementations live here because they need diracx-db. +Dependency injection pattern +---------------------------- +Each source subclass exposes a `create` classmethod used as the FastAPI +dependency key in the router — the same pattern as ConfigSource.create. +factory.py overrides each `create` with the corresponding source instance's +`read_non_blocking`, so routes receive the cached Snapshot directly. + +Note on DB usage +---------------- +Every DB call opens its own connection via `async with self._db as db` because +these calls happen outside FastAPI's DI pipeline (db_transaction never runs). +The engine is already open for the app lifetime via engine_context. """ from __future__ import annotations import logging from datetime import datetime +from typing import ClassVar -from diracx.core.config.sources import CacheableSource, Snapshot +from diracx.core.config.sources import AsyncCacheableSource +from diracx.core.models.rss import Snapshot from diracx.db.sql.rss.db import ResourceStatusDB from .query import ( @@ -27,92 +39,89 @@ logger = logging.getLogger(__name__) -class ResourceStatusSource(CacheableSource[Snapshot]): - """Caching source for Compute, Storage, and FTS resource statuses. +class ResourceStatusSource(AsyncCacheableSource): + """Base caching source for Compute, Storage and FTS resource types. - Holds a long-lived reference to the app-level ``ResourceStatusDB`` instance - (created in ``factory.py`` alongside the engine). Each refresh uses - ``async with self._db`` so that ``__aenter__`` runs — which sets the - ``ContextVar`` connection — on the *same* event loop the engine is bound to. + Subclasses declare `resource_type` as a class attribute — latest_revision + and _fetch dispatch on it automatically. - One source covers *all* VOs for a given resource type. VO-level filtering - is done in the route after the snapshot is fetched, keeping the cache simple - and avoiding N redundant poll schedules for the same underlying table. + One source instance per resource type covers all VOs. VO filtering is done + in the route after the snapshot is fetched from the cache. """ - def __init__(self, *, db: ResourceStatusDB, resource_type: str) -> None: - """ - Args: - db: Long-lived ``ResourceStatusDB`` instance from factory.py. - Must already have ``engine_context`` registered (i.e. its engine - is open for the application lifetime). - resource_type: One of ``"ComputeElement"``, ``"StorageElement"``, - ``"FTS"``. - """ + resource_type: ClassVar[str] + + def __init__(self, *, db: ResourceStatusDB) -> None: super().__init__() self._db = db - self._resource_type = resource_type async def latest_revision(self) -> tuple[str, datetime]: - """Query the max DateEffective for this resource type across all VOs. - - Uses ``modified.isoformat()`` as the ETag so the value is deterministic - across replicas (unlike ``hash()``, which is randomised by PYTHONHASHSEED). - """ async with self._db as db: - status_date = await db.get_resource_status_date( - resource_type=self._resource_type - ) - modified: datetime = status_date.DateEffective + row = await db.get_resource_status_date() + modified: datetime = row.DateEffective return modified.isoformat(), modified async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: - """Fetch all statuses for this resource type and wrap them in a Snapshot.""" async with self._db as db: data = await self._fetch(db) return Snapshot(data=data, hexsha=hexsha, modified=modified) - async def _fetch(self, db: ResourceStatusDB): - """Dispatch to the appropriate query helper based on resource type.""" - if self._resource_type == "ComputeElement": - return await get_compute_statuses(db, vo="all") - if self._resource_type == "StorageElement": - return await get_storage_statuses(db, vo="all") - if self._resource_type == "FTS": - return await get_fts_statuses(db, vo="all") - raise ValueError(f"Unsupported resource_type: {self._resource_type!r}") + async def _fetch(self, db: ResourceStatusDB) -> dict: + if self.resource_type == "StorageElement": + return await get_storage_statuses(db) + if self.resource_type == "ComputeElement": + return await get_compute_statuses(db) + if self.resource_type == "FTS": + return await get_fts_statuses(db) + raise ValueError(f"Unsupported resource_type: {self.resource_type!r}") -class SiteStatusSource(CacheableSource[Snapshot]): - """Caching source for Site statuses. +class StorageElementStatusSource(ResourceStatusSource): + resource_type = "StorageElement" + + @classmethod + async def create(cls) -> Snapshot: + raise NotImplementedError("This method should not be called") + + +class ComputeElementStatusSource(ResourceStatusSource): + resource_type = "ComputeElement" - Sites have a first-class status row in their own table (``SiteStatus``), - independent of the per-resource ``ResourceStatus`` table that - ``ResourceStatusSource`` queries. They are also always stored with - ``vo="all"``, so no per-VO filtering is needed on the way out. + @classmethod + async def create(cls) -> Snapshot: + raise NotImplementedError("This method should not be called") - Keeping this as a separate class (rather than a subtype of - ``ResourceStatusSource``) avoids conflating two different DB tables and - makes the ``latest_revision`` query correct for each. + +class FTSStatusSource(ResourceStatusSource): + resource_type = "FTS" + + @classmethod + async def create(cls) -> Snapshot: + raise NotImplementedError("This method should not be called") + + +class SiteStatusSource(AsyncCacheableSource): + """Caching source for Site statuses. + + Uses its own DB table (SiteStatus) and a dedicated date query, so it is a + direct subclass of AsyncCacheableSource rather than ResourceStatusSource. """ def __init__(self, *, db: ResourceStatusDB) -> None: - """ - Args: - db: Long-lived ``ResourceStatusDB`` instance from factory.py. - """ super().__init__() self._db = db async def latest_revision(self) -> tuple[str, datetime]: - """Query the max DateEffective from the Site status table.""" async with self._db as db: - status_date = await db.get_site_status_date() - modified: datetime = status_date.DateEffective + row = await db.get_site_status_date() + modified: datetime = row.DateEffective return modified.isoformat(), modified async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: - """Fetch all site statuses and wrap them in a Snapshot.""" async with self._db as db: - data = await get_site_statuses(db, vo="all") + data = await get_site_statuses(db) return Snapshot(data=data, hexsha=hexsha, modified=modified) + + @classmethod + async def create(cls) -> Snapshot: + raise NotImplementedError("This method should not be called") diff --git a/diracx-logic/tests/rss/test_rss.py b/diracx-logic/tests/rss/test_rss_query.py similarity index 100% rename from diracx-logic/tests/rss/test_rss.py rename to diracx-logic/tests/rss/test_rss_query.py diff --git a/diracx-core/tests/test_status_source.py b/diracx-logic/tests/rss/test_rss_source.py similarity index 52% rename from diracx-core/tests/test_status_source.py rename to diracx-logic/tests/rss/test_rss_source.py index 33f0b786d..4ef55df14 100644 --- a/diracx-core/tests/test_status_source.py +++ b/diracx-logic/tests/rss/test_rss_source.py @@ -6,7 +6,6 @@ import pytest -from diracx.core.exceptions import ResourceNotFoundError from diracx.core.models.rss import ( ComputeElementStatus, FTSStatus, @@ -14,7 +13,12 @@ StorageElementStatus, ) from diracx.db.sql.rss.db import ResourceStatusDB -from diracx.logic.rss.source import ResourceStatusSource +from diracx.logic.rss.source import ( + ComputeElementStatusSource, + FTSStatusSource, + SiteStatusSource, + StorageElementStatusSource, +) @pytest.fixture @@ -22,7 +26,9 @@ def mock_resource_status_db(): """Fixture to mock the ResourceStatusDB.""" db = MagicMock(spec=ResourceStatusDB) DateRow = namedtuple("DateRow", ["DateEffective", "DateChecked"]) - db.get_status_date = AsyncMock( + db.__aenter__ = AsyncMock(return_value=db) + db.__aexit__ = AsyncMock(return_value=None) + db.get_resource_status_date = AsyncMock( return_value=DateRow( DateEffective=datetime.fromisoformat("2023-01-01T00:00:00+00:00"), DateChecked=datetime.now(timezone.utc), @@ -31,103 +37,97 @@ def mock_resource_status_db(): return db -def test_latest_revision(mock_resource_status_db): +async def test_latest_revision(mock_resource_status_db): """Test the latest_revision method of ResourceStatusSource.""" - source = ResourceStatusSource( - resource_type="Site", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) + source = ComputeElementStatusSource(db=mock_resource_status_db) # Call the method - revision, modified = source.latest_revision() + revision, modified = await source.latest_revision() # Verify the revision is generated correctly - assert revision.startswith("rev_") + assert revision assert isinstance(modified, datetime) # Verify the database call - mock_resource_status_db.get_status_date.assert_called_once_with(vo="test_vo") + mock_resource_status_db.get_resource_status_date.assert_called_once() -def test_read_raw_site(mock_resource_status_db): +async def test_read_raw_site(mock_resource_status_db): """Test the read_raw method for Site resource type.""" # Mock the database data - mock_db_data = [("testSite", "Active", "")] + mock_db_data = [("testSite", "Active", "", "test_vo")] # Patch the get_site_statuses method of the database to return the mock data mock_resource_status_db.get_site_statuses = AsyncMock(return_value=mock_db_data) # Initialize the ResourceStatusSource with the mocked database - source = ResourceStatusSource( - resource_type="Site", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) + source = SiteStatusSource(db=mock_resource_status_db) # Call the read_raw method, which internally calls get_site_statuses from query.py - result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) # Verify the result matches the expected output expected_result = {"testSite": SiteStatus(all={"allowed": True, "warnings": None})} for key, value in expected_result.items(): - assert key in result - assert value.model_dump() == result[key].model_dump() + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() # Verify that the database method was called correctly - mock_resource_status_db.get_site_statuses.assert_awaited_once_with("test_vo") + mock_resource_status_db.get_site_statuses.assert_awaited_once() -def test_read_raw_compute(mock_resource_status_db): +async def test_read_raw_compute(mock_resource_status_db): """Test the read_raw method for ComputeElement resource type.""" - ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) mock_db_data = { - "TestCE": {"all": ResourceStatus(Name="TestCE", Status="Active", Reason="")} + "TestCE": { + "all": ResourceStatus( + Name="TestCE", Status="Active", Reason="", VO="test_vo" + ) + } } mock_resource_status_db.get_resource_statuses = AsyncMock(return_value=mock_db_data) - source = ResourceStatusSource( - resource_type="ComputeElement", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) + source = ComputeElementStatusSource(db=mock_resource_status_db) # Call the method - result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) # Verify the result expected_result = { "TestCE": ComputeElementStatus(all={"allowed": True, "warnings": None}) } for key, value in expected_result.items(): - assert key in result - assert value.model_dump() == result[key].model_dump() - mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( - ["all"], "test_vo" - ) + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with(["all"]) -def test_read_raw_storage(mock_resource_status_db): +async def test_read_raw_storage(mock_resource_status_db): """Test the read_raw method for StorageElement resource type.""" - ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) mock_db_data = { "TestSE": { - "ReadAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), - "WriteAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), - "CheckAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), - "RemoveAccess": ResourceStatus(Name="TestSE", Status="Active", Reason=None), + "ReadAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "WriteAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "CheckAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "RemoveAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), } } mock_resource_status_db.get_resource_statuses.return_value = mock_db_data - source = ResourceStatusSource( - resource_type="StorageElement", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) + source = StorageElementStatusSource(db=mock_resource_status_db) # Call the method - result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) # Verify the result expected_result = { @@ -139,51 +139,34 @@ def test_read_raw_storage(mock_resource_status_db): ) } for key, value in expected_result.items(): - assert key in result - assert value.model_dump() == result[key].model_dump() + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( - ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], "test_vo" + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] ) -def test_read_raw_fts(mock_resource_status_db): +async def test_read_raw_fts(mock_resource_status_db): """Test the read_raw method for FTS resource type.""" - ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason"]) + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) mock_db_data = { "FTS": { - "all": ResourceStatus(Name="FTS", Status="Active", Reason=None), + "all": ResourceStatus( + Name="FTS", Status="Active", Reason=None, VO="test_vo" + ), } } mock_resource_status_db.get_resource_statuses.return_value = mock_db_data - source = ResourceStatusSource( - resource_type="FTS", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) + source = FTSStatusSource(db=mock_resource_status_db) # Call the method - result = source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) # Verify the result expected_result = {"FTS": FTSStatus(all={"allowed": True, "warnings": None})} for key, value in expected_result.items(): - assert key in result - assert value.model_dump() == result[key].model_dump() - mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( - ["all"], "test_vo" - ) - - -def test_read_raw_invalid_resource_type(mock_resource_status_db): - """Test the read_raw method for an invalid resource type.""" - source = ResourceStatusSource( - resource_type="InvalidType", - vo="test_vo", - resource_status_db=mock_resource_status_db, - ) - - # Call the method and verify it raises ResourceNotFoundError - with pytest.raises(ResourceNotFoundError): - source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with(["all"]) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 7e721eac2..26e49317f 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -52,6 +52,7 @@ rss = "diracx.routers.rss:router" [project.entry-points."diracx.access_policies"] wms = "diracx.routers.jobs.access_policies:WMSAccessPolicy" sandbox = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +rss = "diracx.routers.rss.access_policies:RSSAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py index f121d601f..a2ba1cbcd 100644 --- a/diracx-routers/src/diracx/routers/factory.py +++ b/diracx-routers/src/diracx/routers/factory.py @@ -37,9 +37,13 @@ from diracx.db.os.utils import BaseOSDB from diracx.db.sql.rss.db import ResourceStatusDB from diracx.db.sql.utils import BaseSQLDB -from diracx.logic.rss.source import ResourceStatusSource, SiteStatusSource +from diracx.logic.rss.source import ( + ComputeElementStatusSource, + FTSStatusSource, + SiteStatusSource, + StorageElementStatusSource, +) from diracx.routers.access_policies import BaseAccessPolicy, check_permissions -from diracx.routers.rss import RSSSnapshotSentinels from .fastapi_classes import DiracFastAPI, DiracxRouter from .otel import instrument_otel @@ -189,7 +193,6 @@ def create_app_inner( # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() - # Track the app-lifetime ResourceStatusDB instance so we can build sources below. rss_db_instance: ResourceStatusDB | None = None for db_name, db_url in database_urls.items(): @@ -213,11 +216,8 @@ def create_app_inner( db_no_transaction, sql_db ) - # Capture the long-lived ResourceStatusDB instance for the RSS sources. - # We reuse this instance (not a per-request DI copy) so that the engine's - # event loop and the async context manager are managed correctly. - if isinstance(sql_db, ResourceStatusDB) and rss_db_instance is None: - rss_db_instance = sql_db + if isinstance(sql_db, ResourceStatusDB) and rss_db_instance is None: + rss_db_instance = sql_db # At least one DB works, so we do not fail the startup fail_startup = False @@ -227,39 +227,26 @@ def create_app_inner( if fail_startup: raise Exception("No SQL database could be initialized, aborting") - # --------------------------------------------------------------------------- - # Wire RSS sources via dependency_overrides — same pattern as ConfigSource.create. - # - # Each source holds a reference to the *app-lifetime* rss_db_instance; every - # refresh calls `async with self._db` so __aenter__ runs on the FastAPI event - # loop (the same loop the engine is bound to). One source per resource type - # covers all VOs; per-VO filtering is done in the route. - # --------------------------------------------------------------------------- if rss_db_instance is not None: - compute_source = ResourceStatusSource( - db=rss_db_instance, resource_type="ComputeElement" - ) - storage_source = ResourceStatusSource( - db=rss_db_instance, resource_type="StorageElement" - ) - fts_source = ResourceStatusSource(db=rss_db_instance, resource_type="FTS") + compute_source = ComputeElementStatusSource(db=rss_db_instance) + storage_source = StorageElementStatusSource(db=rss_db_instance) + fts_source = FTSStatusSource(db=rss_db_instance) site_source = SiteStatusSource(db=rss_db_instance) - app.dependency_overrides[RSSSnapshotSentinels.get_compute_snapshot] = ( - compute_source.read_non_blocking - ) - app.dependency_overrides[RSSSnapshotSentinels.get_storage_snapshot] = ( + app.dependency_overrides[StorageElementStatusSource.create] = ( storage_source.read_non_blocking ) - app.dependency_overrides[RSSSnapshotSentinels.get_fts_snapshot] = ( - fts_source.read_non_blocking + app.dependency_overrides[ComputeElementStatusSource.create] = ( + compute_source.read_non_blocking ) - app.dependency_overrides[RSSSnapshotSentinels.get_site_snapshot] = ( + app.dependency_overrides[SiteStatusSource.create] = ( site_source.read_non_blocking ) + app.dependency_overrides[FTSStatusSource.create] = fts_source.read_non_blocking else: logger.warning( - "ResourceStatusDB not found; RSS endpoints will not be available." + "ResourceStatusDB not found in database_urls; " + "RSS endpoints will return 503 until it becomes available." ) # Add the OpenSearch DBs to the application diff --git a/diracx-routers/src/diracx/routers/rss.py b/diracx-routers/src/diracx/routers/rss.py deleted file mode 100644 index 833c19e83..000000000 --- a/diracx-routers/src/diracx/routers/rss.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -import logging -from collections.abc import Callable -from datetime import datetime, timezone -from enum import StrEnum, auto -from typing import Annotated, cast - -from fastapi import ( - Depends, - Header, - HTTPException, - Response, - status, -) - -from diracx.core.config.sources import Snapshot -from diracx.core.models.rss import ( - ComputeElementStatus, - FTSStatus, - SiteStatus, - StorageElementStatus, -) -from diracx.routers.access_policies import BaseAccessPolicy -from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token - -from .fastapi_classes import DiracxRouter - -logger = logging.getLogger(__name__) - -LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" - -router = DiracxRouter() - - -# --------------------------------------------------------------------------- -# Access policy -# --------------------------------------------------------------------------- - - -class ActionType(StrEnum): - # Create a job or a sandbox - CREATE = auto() - # Check job status, download a sandbox - READ = auto() - # Delete, kill, remove, set status, etc of a job - # Delete or assign a sandbox - MANAGE = auto() - # Search - QUERY = auto() - # Actions from a pilot (e.g. heartbeat) - PILOT = auto() - - -class ResourceStatusAccessPolicy(BaseAccessPolicy): - """Policy: any authenticated user may READ resource statuses. - - Write/admin actions are rejected here; VO scoping is the route's responsibility. - Registered under ``[project.entry-points."diracx.access_policies"]`` in - ``diracx-routers/pyproject.toml`` so the framework can discover it. - """ - - @staticmethod - async def policy( - policy_name: str, - user_info: AuthorizedUserInfo, - /, - *, - action: ActionType | None = None, - ): - if action != ActionType.READ: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Resource Status System is read-only.", - ) - # Any authenticated user may read; VO scoping happens in the route. - - -ResourceStatusAccessPolicyCallable = Annotated[ - Callable, Depends(ResourceStatusAccessPolicy.check) -] - - -class RSSSnapshotSentinels: - @classmethod - def get_storage_snapshot(cls) -> Snapshot: - raise NotImplementedError - - @classmethod - def get_compute_snapshot(cls) -> Snapshot: - raise NotImplementedError - - @classmethod - def get_site_snapshot(cls) -> Snapshot: - raise NotImplementedError - - @classmethod - def get_fts_snapshot(cls) -> Snapshot: - raise NotImplementedError - - -# --------------------------------------------------------------------------- -# Shared ETag / 304 helper -# --------------------------------------------------------------------------- - - -def _apply_cache_headers( - response: Response, - snapshot: Snapshot, - if_none_match: str | None, - if_modified_since: str | None, -) -> None: - """Set ETag / Last-Modified headers and raise 304 when appropriate. - - Raises: - HTTPException(304): when the client's cached copy is still current. - """ - headers = { - "ETag": snapshot.hexsha, - "Last-Modified": snapshot.modified.strftime(LAST_MODIFIED_FORMAT), - } - - if if_none_match == snapshot.hexsha: - raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) - - if if_modified_since: - try: - not_before = datetime.strptime( - if_modified_since, LAST_MODIFIED_FORMAT - ).astimezone(timezone.utc) - except ValueError: - logger.debug( - "Failed to parse If-Modified-Since header: %s", if_modified_since - ) - else: - # Guard against flip-flop when a replica is momentarily behind. - if not_before > snapshot.modified: - raise HTTPException( - status_code=status.HTTP_304_NOT_MODIFIED, headers=headers - ) - - response.headers.update(headers) - - -# --------------------------------------------------------------------------- -# Routes -# --------------------------------------------------------------------------- - - -@router.get("/storage") -async def get_storage_status( - response: Response, - snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_storage_snapshot)], - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - check_permissions: ResourceStatusAccessPolicyCallable, - if_none_match: Annotated[str | None, Header()] = None, - if_modified_since: Annotated[str | None, Header()] = None, -) -> dict[str, StorageElementStatus]: - """Get the latest status of storage elements, scoped to the caller's VO.""" - await check_permissions(action=ActionType.READ) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, StorageElementStatus], - { - name: se - for name, se in snapshot.data.items() - if getattr(se, "vo", "all") in (user_info.vo, "all") - }, - ) - - -@router.get("/compute") -async def get_compute_status( - response: Response, - snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_compute_snapshot)], - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - check_permissions: ResourceStatusAccessPolicyCallable, - if_none_match: Annotated[str | None, Header()] = None, - if_modified_since: Annotated[str | None, Header()] = None, -) -> dict[str, ComputeElementStatus]: - """Get the latest status of compute elements, scoped to the caller's VO.""" - await check_permissions(action=ActionType.READ) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, ComputeElementStatus], - { - name: ce - for name, ce in snapshot.data.items() - if getattr(ce, "vo", "all") in (user_info.vo, "all") - }, - ) - - -@router.get("/site") -async def get_site_status( - response: Response, - snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_site_snapshot)], - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - check_permissions: ResourceStatusAccessPolicyCallable, - if_none_match: Annotated[str | None, Header()] = None, - if_modified_since: Annotated[str | None, Header()] = None, -) -> dict[str, SiteStatus]: - """Get the latest status of sites. - - Sites are always stored with vo="all" so no VO filtering is applied. - """ - await check_permissions(action=ActionType.READ) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast(dict[str, SiteStatus], snapshot.data) - - -@router.get("/fts") -async def get_fts_status( - response: Response, - snapshot: Annotated[Snapshot, Depends(RSSSnapshotSentinels.get_fts_snapshot)], - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - check_permissions: ResourceStatusAccessPolicyCallable, - if_none_match: Annotated[str | None, Header()] = None, - if_modified_since: Annotated[str | None, Header()] = None, -) -> dict[str, FTSStatus]: - """Get the latest status of FTS servers, scoped to the caller's VO.""" - await check_permissions(action=ActionType.READ) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, FTSStatus], - { - name: fts - for name, fts in snapshot.data.items() - if getattr(fts, "vo", "all") in (user_info.vo, "all") - }, - ) diff --git a/diracx-routers/src/diracx/routers/rss/__init__.py b/diracx-routers/src/diracx/routers/rss/__init__.py new file mode 100644 index 000000000..bebfcccc7 --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/__init__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from ..fastapi_classes import DiracxRouter +from .rss import router as rss_router + +router = DiracxRouter() +router.include_router(rss_router) diff --git a/diracx-routers/src/diracx/routers/rss/access_policies.py b/diracx-routers/src/diracx/routers/rss/access_policies.py new file mode 100644 index 000000000..b0b439976 --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/access_policies.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class RSSAccessPolicy(BaseAccessPolicy): + """Any authenticated user can access.""" + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + ): + if user_info.preferred_username: + return + + raise HTTPException(status.HTTP_403_FORBIDDEN) + + +CheckRSSPolicyCallable = Annotated[Callable, Depends(RSSAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/rss/rss.py b/diracx-routers/src/diracx/routers/rss/rss.py new file mode 100644 index 000000000..31f7788b2 --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/rss.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Annotated, cast + +from fastapi import ( + Depends, + Header, + HTTPException, + Response, + status, +) + +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + Snapshot, + StorageElementStatus, +) +from diracx.logic.rss.source import ( + ComputeElementStatusSource, + FTSStatusSource, + SiteStatusSource, + StorageElementStatusSource, +) +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..fastapi_classes import DiracxRouter +from .access_policies import CheckRSSPolicyCallable + +logger = logging.getLogger(__name__) + +LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" + +router = DiracxRouter() + + +def _apply_cache_headers( + response: Response, + snapshot: Snapshot, + if_none_match: str | None, + if_modified_since: str | None, +) -> None: + """Set ETag / Last-Modified headers and raise 304 when appropriate. + + Raises: + HTTPException(304): when the client's cached copy is still current. + + """ + headers = { + "ETag": snapshot.hexsha, + "Last-Modified": snapshot.modified.strftime(LAST_MODIFIED_FORMAT), + } + + if if_none_match == snapshot.hexsha: + raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) + + if if_modified_since: + try: + not_before = datetime.strptime( + if_modified_since, LAST_MODIFIED_FORMAT + ).astimezone(timezone.utc) + except ValueError: + logger.debug( + "Failed to parse If-Modified-Since header: %s", if_modified_since + ) + else: + if not_before > snapshot.modified: + raise HTTPException( + status_code=status.HTTP_304_NOT_MODIFIED, headers=headers + ) + + response.headers.update(headers) + + +@router.get("/storage") +async def get_storage_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(StorageElementStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, StorageElementStatus]: + """Get the latest status of storage elements, scoped to the caller's VO.""" + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) + return cast( + dict[str, StorageElementStatus], + {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, + ) + + +@router.get("/compute") +async def get_compute_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(ComputeElementStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, ComputeElementStatus]: + """Get the latest status of compute elements, scoped to the caller's VO.""" + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) + return cast( + dict[str, ComputeElementStatus], + {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, + ) + + +@router.get("/site") +async def get_site_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(SiteStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, SiteStatus]: + """Get the latest status of sites, scoped to the caller's VO.""" + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) + return cast( + dict[str, SiteStatus], + {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, + ) + + +@router.get("/fts") +async def get_fts_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(FTSStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, FTSStatus]: + """Get the latest status of FTS servers, scoped to the caller's VO.""" + _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) + return cast( + dict[str, FTSStatus], + {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, + ) diff --git a/diracx-routers/tests/test_rss.py b/diracx-routers/tests/test_rss.py index e3108aeed..92056429c 100644 --- a/diracx-routers/tests/test_rss.py +++ b/diracx-routers/tests/test_rss.py @@ -1,23 +1,96 @@ from __future__ import annotations +import asyncio +from datetime import datetime, timezone + import pytest from fastapi import status pytestmark = pytest.mark.enabled_dependencies( - ["AuthSettings", "ResourceStatusDB", "RSSSnapshotSentinels"] + [ + "AuthSettings", + "ResourceStatusDB", + "SiteStatusSource", + "FTSStatusSource", + "ComputeElementStatusSource", + "StorageElementStatusSource", + "RSSAccessPolicy", + "DevelopmentSettings", + ] ) +async def _prepare_rss(client): + """Seed the DB and warm every source cache inside a single connection.""" + from diracx.core.config.sources import AsyncCacheableSource + from diracx.db.sql.rss.db import ResourceStatusDB + + db_override = client.app.dependency_overrides.get(ResourceStatusDB.no_transaction) + if db_override is None: + return + # factory.py stores partial(db_transaction, db_instance); args[0] is the instance. + db = db_override.args[0] + now = datetime.now(tz=timezone.utc) + + # Seed — open one connection, insert all rows, then close it cleanly. + async with db as conn: + for status_type in ("ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"): + await conn.insert_resource_status( + name="SE-CERN", + status="Active", + status_type=status_type, + vo="lhcb", + reason="All good", + date_effective=now, + ) + await conn.insert_resource_status( + name="CE-CERN", + status="Active", + status_type="all", + vo="lhcb", + reason="All good", + date_effective=now, + ) + await conn.insert_resource_status( + name="FTS-CERN", + status="Active", + status_type="all", + vo="lhcb", + reason="All good", + date_effective=now, + ) + await conn.insert_site_status( + name="LCG.CERN.cern", + status="Active", + vo="lhcb", + reason="All good", + date_effective=now, + ) + # Connection is now fully closed and _conn ContextVar is reset. + + # Warm each source — each source.read() opens its own fresh connection. + for override in client.app.dependency_overrides.values(): + source = getattr(override, "__self__", None) + if isinstance(source, AsyncCacheableSource): + await source.read() + + @pytest.fixture def normal_user_client(client_factory): with client_factory.normal_user() as client: + asyncio.get_event_loop().run_until_complete(_prepare_rss(client)) yield client -def test_unauthenticated(client_factory): +@pytest.fixture +def unauthenticated_client(client_factory): with client_factory.unauthenticated() as client: - response = client.get("/api/rss/storage") - assert response.status_code == status.HTTP_401_UNAUTHORIZED + yield client + + +def test_unauthenticated(unauthenticated_client): + response = unauthenticated_client.get("/api/rss/storage") + assert response.status_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.parametrize( @@ -32,28 +105,23 @@ def test_get_resource_status(normal_user_client, endpoint): last_modified = r.headers["Last-Modified"] etag = r.headers["ETag"] + # Matching ETag + matching Last-Modified → 304 r = normal_user_client.get( endpoint, - headers={ - "If-None-Match": etag, - "If-Modified-Since": last_modified, - }, + headers={"If-None-Match": etag, "If-Modified-Since": last_modified}, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text assert not r.text - # If only an invalid ETAG is passed, we expect a response + # Wrong ETag only → 200 r = normal_user_client.get( endpoint, - headers={ - "If-None-Match": "wrongEtag", - }, + headers={"If-None-Match": "wrongEtag"}, ) assert r.status_code == status.HTTP_200_OK, r.json() assert r.json(), r.text - # If an past ETAG and an past timestamp as give, we expect an response + # Past ETag + past timestamp → 200 r = normal_user_client.get( endpoint, headers={ @@ -64,7 +132,7 @@ def test_get_resource_status(normal_user_client, endpoint): assert r.status_code == status.HTTP_200_OK, r.json() assert r.json(), r.text - # If an future ETAG and an new timestamp as give, we expect 304 + # Wrong ETag + future timestamp → 304 (If-Modified-Since takes effect) r = normal_user_client.get( endpoint, headers={ @@ -75,7 +143,7 @@ def test_get_resource_status(normal_user_client, endpoint): assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text assert not r.text - # If an invalid ETAG and an invalid modified time, we expect a response + # Wrong ETag + invalid timestamp → 200 r = normal_user_client.get( endpoint, headers={ @@ -86,7 +154,7 @@ def test_get_resource_status(normal_user_client, endpoint): assert r.status_code == status.HTTP_200_OK, r.json() assert r.json(), r.text - # If the correct ETAG and a past timestamp as give, we expect 304 + # Correct ETag + past timestamp → 304 (ETag match takes priority) r = normal_user_client.get( endpoint, headers={ @@ -97,7 +165,7 @@ def test_get_resource_status(normal_user_client, endpoint): assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text assert not r.text - # If the correct ETAG and a new timestamp as give, we expect 304 + # Correct ETag + future timestamp → 304 r = normal_user_client.get( endpoint, headers={ @@ -108,7 +176,7 @@ def test_get_resource_status(normal_user_client, endpoint): assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text assert not r.text - # If the correct ETAG and an invalid modified time, we expect 304 + # Correct ETag + invalid timestamp → 304 r = normal_user_client.get( endpoint, headers={ diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index bb8d99195..ceb212943 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -46,6 +46,7 @@ from diracx.core.extensions import DiracEntryPoint from diracx.core.models import AccessTokenPayload, RefreshTokenPayload +from diracx.logic.rss.source import ResourceStatusSource, SiteStatusSource if TYPE_CHECKING: from diracx.core.settings import ( @@ -242,8 +243,6 @@ def enrich_tokens( all_access_policies=all_access_policies, ) - from diracx.routers.rss import RSSSnapshotSentinels - self.all_dependency_overrides = self.app.dependency_overrides.copy() self.app.dependency_overrides = {} for obj in self.all_dependency_overrides: @@ -255,7 +254,8 @@ def enrich_tokens( BaseOSDB, ConfigSource, BaseAccessPolicy, - RSSSnapshotSentinels, + ResourceStatusSource, + SiteStatusSource, ), ), obj From 9c08935b531e1b172853c4c5edb2f64826d2f63d Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Tue, 26 May 2026 13:56:38 +0200 Subject: [PATCH 5/7] fix: give up on source dependency injection --- diracx-logic/src/diracx/logic/rss/source.py | 29 ------------- diracx-routers/src/diracx/routers/factory.py | 34 ---------------- diracx-routers/src/diracx/routers/rss/rss.py | 43 ++++++++++++++++++-- diracx-testing/src/diracx/testing/utils.py | 3 -- 4 files changed, 39 insertions(+), 70 deletions(-) diff --git a/diracx-logic/src/diracx/logic/rss/source.py b/diracx-logic/src/diracx/logic/rss/source.py index e8e94f2db..a00e08d48 100644 --- a/diracx-logic/src/diracx/logic/rss/source.py +++ b/diracx-logic/src/diracx/logic/rss/source.py @@ -4,19 +4,6 @@ violating the project's dependency flow: routers → logic → db → core - -Dependency injection pattern ----------------------------- -Each source subclass exposes a `create` classmethod used as the FastAPI -dependency key in the router — the same pattern as ConfigSource.create. -factory.py overrides each `create` with the corresponding source instance's -`read_non_blocking`, so routes receive the cached Snapshot directly. - -Note on DB usage ----------------- -Every DB call opens its own connection via `async with self._db as db` because -these calls happen outside FastAPI's DI pipeline (db_transaction never runs). -The engine is already open for the app lifetime via engine_context. """ from __future__ import annotations @@ -79,26 +66,14 @@ async def _fetch(self, db: ResourceStatusDB) -> dict: class StorageElementStatusSource(ResourceStatusSource): resource_type = "StorageElement" - @classmethod - async def create(cls) -> Snapshot: - raise NotImplementedError("This method should not be called") - class ComputeElementStatusSource(ResourceStatusSource): resource_type = "ComputeElement" - @classmethod - async def create(cls) -> Snapshot: - raise NotImplementedError("This method should not be called") - class FTSStatusSource(ResourceStatusSource): resource_type = "FTS" - @classmethod - async def create(cls) -> Snapshot: - raise NotImplementedError("This method should not be called") - class SiteStatusSource(AsyncCacheableSource): """Caching source for Site statuses. @@ -121,7 +96,3 @@ async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: async with self._db as db: data = await get_site_statuses(db) return Snapshot(data=data, hexsha=hexsha, modified=modified) - - @classmethod - async def create(cls) -> Snapshot: - raise NotImplementedError("This method should not be called") diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py index a2ba1cbcd..4c26f4c58 100644 --- a/diracx-routers/src/diracx/routers/factory.py +++ b/diracx-routers/src/diracx/routers/factory.py @@ -35,14 +35,7 @@ from diracx.core.utils import dotenv_files_from_environment from diracx.db.exceptions import DBUnavailableError from diracx.db.os.utils import BaseOSDB -from diracx.db.sql.rss.db import ResourceStatusDB from diracx.db.sql.utils import BaseSQLDB -from diracx.logic.rss.source import ( - ComputeElementStatusSource, - FTSStatusSource, - SiteStatusSource, - StorageElementStatusSource, -) from diracx.routers.access_policies import BaseAccessPolicy, check_permissions from .fastapi_classes import DiracFastAPI, DiracxRouter @@ -193,8 +186,6 @@ def create_app_inner( # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() - rss_db_instance: ResourceStatusDB | None = None - for db_name, db_url in database_urls.items(): try: sql_db_classes = BaseSQLDB.available_implementations(db_name) @@ -216,9 +207,6 @@ def create_app_inner( db_no_transaction, sql_db ) - if isinstance(sql_db, ResourceStatusDB) and rss_db_instance is None: - rss_db_instance = sql_db - # At least one DB works, so we do not fail the startup fail_startup = False except Exception: @@ -227,28 +215,6 @@ def create_app_inner( if fail_startup: raise Exception("No SQL database could be initialized, aborting") - if rss_db_instance is not None: - compute_source = ComputeElementStatusSource(db=rss_db_instance) - storage_source = StorageElementStatusSource(db=rss_db_instance) - fts_source = FTSStatusSource(db=rss_db_instance) - site_source = SiteStatusSource(db=rss_db_instance) - - app.dependency_overrides[StorageElementStatusSource.create] = ( - storage_source.read_non_blocking - ) - app.dependency_overrides[ComputeElementStatusSource.create] = ( - compute_source.read_non_blocking - ) - app.dependency_overrides[SiteStatusSource.create] = ( - site_source.read_non_blocking - ) - app.dependency_overrides[FTSStatusSource.create] = fts_source.read_non_blocking - else: - logger.warning( - "ResourceStatusDB not found in database_urls; " - "RSS endpoints will return 503 until it becomes available." - ) - # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() for db_name, connection_kwargs in os_database_conn_kwargs.items(): diff --git a/diracx-routers/src/diracx/routers/rss/rss.py b/diracx-routers/src/diracx/routers/rss/rss.py index 31f7788b2..e0fce2735 100644 --- a/diracx-routers/src/diracx/routers/rss/rss.py +++ b/diracx-routers/src/diracx/routers/rss/rss.py @@ -19,13 +19,16 @@ Snapshot, StorageElementStatus, ) +from diracx.db.sql.rss.db import ResourceStatusDB from diracx.logic.rss.source import ( ComputeElementStatusSource, FTSStatusSource, + ResourceStatusSource, SiteStatusSource, StorageElementStatusSource, ) from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token +from diracx.tasks.plumbing.depends import NoTransaction from ..fastapi_classes import DiracxRouter from .access_policies import CheckRSSPolicyCallable @@ -37,6 +40,34 @@ router = DiracxRouter() +async def get_snapshot(rss_type, db) -> Snapshot: + """Get the status snapshot from the unique ResourceStatusSource instance or create it if it does not exist. + + :param rss_type: The type of the resource status source. + :param db: The database instance. + :returns: The status snapshot from the unique ResourceStatusSource instance. + """ + sources: dict[str, ResourceStatusSource | SiteStatusSource] = {} + if rss_type == "ComputeElementStatus": + if rss_type not in sources: + sources[rss_type] = ComputeElementStatusSource(db=db) + return await sources[rss_type].read() + elif rss_type == "StorageElementStatus": + if rss_type not in sources: + sources[rss_type] = StorageElementStatusSource(db=db) + return await sources[rss_type].read() + elif rss_type == "FTSStatus": + if rss_type not in sources: + sources[rss_type] = FTSStatusSource(db=db) + return await sources[rss_type].read() + elif rss_type == "SiteStatus": + if rss_type not in sources: + sources[rss_type] = SiteStatusSource(db=db) + return await sources[rss_type].read() + else: + raise ValueError(f"Unknown resource status source type: {rss_type}") + + def _apply_cache_headers( response: Response, snapshot: Snapshot, @@ -80,11 +111,12 @@ async def get_storage_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - snapshot: Annotated[Snapshot, Depends(StorageElementStatusSource.create)], + db: Annotated[ResourceStatusDB, NoTransaction()], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, StorageElementStatus]: """Get the latest status of storage elements, scoped to the caller's VO.""" + snapshot = await get_snapshot("StorageElementStatus", db) _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, StorageElementStatus], @@ -97,11 +129,12 @@ async def get_compute_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - snapshot: Annotated[Snapshot, Depends(ComputeElementStatusSource.create)], + db: Annotated[ResourceStatusDB, NoTransaction()], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, ComputeElementStatus]: """Get the latest status of compute elements, scoped to the caller's VO.""" + snapshot = await get_snapshot("ComputeElementStatus", db) _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, ComputeElementStatus], @@ -114,11 +147,12 @@ async def get_site_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - snapshot: Annotated[Snapshot, Depends(SiteStatusSource.create)], + db: Annotated[ResourceStatusDB, NoTransaction()], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, SiteStatus]: """Get the latest status of sites, scoped to the caller's VO.""" + snapshot = await get_snapshot("SiteStatus", db) _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, SiteStatus], @@ -131,11 +165,12 @@ async def get_fts_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - snapshot: Annotated[Snapshot, Depends(FTSStatusSource.create)], + db: Annotated[ResourceStatusDB, NoTransaction()], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, FTSStatus]: """Get the latest status of FTS servers, scoped to the caller's VO.""" + snapshot = await get_snapshot("FTSStatus", db) _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) return cast( dict[str, FTSStatus], diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index ceb212943..570241e69 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -46,7 +46,6 @@ from diracx.core.extensions import DiracEntryPoint from diracx.core.models import AccessTokenPayload, RefreshTokenPayload -from diracx.logic.rss.source import ResourceStatusSource, SiteStatusSource if TYPE_CHECKING: from diracx.core.settings import ( @@ -254,8 +253,6 @@ def enrich_tokens( BaseOSDB, ConfigSource, BaseAccessPolicy, - ResourceStatusSource, - SiteStatusSource, ), ), obj From 422564f26ccb439942965e7e9cef8b4491623c80 Mon Sep 17 00:00:00 2001 From: Loxeris <30194187+Loxeris@users.noreply.github.com> Date: Tue, 26 May 2026 15:30:21 +0200 Subject: [PATCH 6/7] chore: regenerate client --- .../src/diracx/client/_generated/_client.py | 5 +- .../diracx/client/_generated/aio/_client.py | 5 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 304 +++++++++++++ .../client/_generated/models/__init__.py | 26 ++ .../client/_generated/models/_models.py | 225 ++++++++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 412 ++++++++++++++++++ diracx-core/src/diracx/core/utils.py | 2 +- .../src/diracx/routers/rss/__init__.py | 3 + .../src/gubbins/client/_generated/_client.py | 6 +- .../gubbins/client/_generated/aio/_client.py | 6 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 304 +++++++++++++ .../client/_generated/models/__init__.py | 26 ++ .../client/_generated/models/_models.py | 225 ++++++++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 412 ++++++++++++++++++ 18 files changed, 1964 insertions(+), 5 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index caf48034f..c5e641f24 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, RssOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 79ab383a9..f2d2b46bf 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, RssOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.aio.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 6be34fb8a..77674abec 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 8aee57b46..4efe2dd6f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -51,6 +51,10 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_rss_get_compute_status_request, + build_rss_get_fts_status_request, + build_rss_get_site_status_request, + build_rss_get_storage_status_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2319,3 +2323,303 @@ async def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 14b5195d4..b4c06cc69 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -12,10 +12,16 @@ from ._models import ( # type: ignore + AllowedStatus, + BannedStatus, BodyAuthGetOidcToken, BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + ComputeElementStatus, + ComputeElementStatusAll, + FTSStatus, + FTSStatusAll, GroupInfo, HTTPValidationError, HeartbeatData, @@ -34,7 +40,14 @@ SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, + SiteStatus, + SiteStatusAll, SortSpec, + StorageElementStatus, + StorageElementStatusCheck, + StorageElementStatusRead, + StorageElementStatusRemove, + StorageElementStatusWrite, SummaryParams, SummaryParamsSearchItem, SupportInfo, @@ -59,10 +72,16 @@ from ._patch import patch_sdk as _patch_sdk __all__ = [ + "AllowedStatus", + "BannedStatus", "BodyAuthGetOidcToken", "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "ComputeElementStatus", + "ComputeElementStatusAll", + "FTSStatus", + "FTSStatusAll", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -81,7 +100,14 @@ "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", + "SiteStatus", + "SiteStatusAll", "SortSpec", + "StorageElementStatus", + "StorageElementStatusCheck", + "StorageElementStatusRead", + "StorageElementStatusRemove", + "StorageElementStatusWrite", "SummaryParams", "SummaryParamsSearchItem", "SupportInfo", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 888ec3b8a..730f15d6d 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -16,6 +16,70 @@ JSON = MutableMapping[str, Any] +class AllowedStatus(_serialization.Model): + """AllowedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar warnings: Warnings. + :vartype warnings: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "warnings": {"key": "warnings", "type": "str"}, + } + + def __init__(self, *, allowed: bool, warnings: Optional[str] = None, **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword warnings: Warnings. + :paramtype warnings: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.warnings = warnings + + +class BannedStatus(_serialization.Model): + """BannedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar reason: Reason. + :vartype reason: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "reason": {"key": "reason", "type": "str"}, + } + + def __init__(self, *, allowed: bool, reason: str = "Unknown", **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword reason: Reason. + :paramtype reason: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.reason = reason + + class BodyAuthGetOidcToken(_serialization.Model): """Body_auth_get_oidc_token. @@ -184,6 +248,66 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class ComputeElementStatus(_serialization.Model): + """ComputeElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.ComputeElementStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "ComputeElementStatusAll"}, + } + + def __init__(self, *, all: "_models.ComputeElementStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.ComputeElementStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class ComputeElementStatusAll(_serialization.Model): + """All.""" + + +class FTSStatus(_serialization.Model): + """FTSStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.FTSStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "FTSStatusAll"}, + } + + def __init__(self, *, all: "_models.FTSStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.FTSStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class FTSStatusAll(_serialization.Model): + """All.""" + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -1261,6 +1385,36 @@ def __init__( self.last_update_time = last_update_time +class SiteStatus(_serialization.Model): + """SiteStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.SiteStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "SiteStatusAll"}, + } + + def __init__(self, *, all: "_models.SiteStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.SiteStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class SiteStatusAll(_serialization.Model): + """All.""" + + class SortSpec(_serialization.Model): """SortSpec. @@ -1294,6 +1448,77 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class StorageElementStatus(_serialization.Model): + """StorageElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar read: Read. Required. + :vartype read: ~_generated.models.StorageElementStatusRead + :ivar write: Write. Required. + :vartype write: ~_generated.models.StorageElementStatusWrite + :ivar check: Check. Required. + :vartype check: ~_generated.models.StorageElementStatusCheck + :ivar remove: Remove. Required. + :vartype remove: ~_generated.models.StorageElementStatusRemove + """ + + _validation = { + "read": {"required": True}, + "write": {"required": True}, + "check": {"required": True}, + "remove": {"required": True}, + } + + _attribute_map = { + "read": {"key": "read", "type": "StorageElementStatusRead"}, + "write": {"key": "write", "type": "StorageElementStatusWrite"}, + "check": {"key": "check", "type": "StorageElementStatusCheck"}, + "remove": {"key": "remove", "type": "StorageElementStatusRemove"}, + } + + def __init__( + self, + *, + read: "_models.StorageElementStatusRead", + write: "_models.StorageElementStatusWrite", + check: "_models.StorageElementStatusCheck", + remove: "_models.StorageElementStatusRemove", + **kwargs: Any + ) -> None: + """ + :keyword read: Read. Required. + :paramtype read: ~_generated.models.StorageElementStatusRead + :keyword write: Write. Required. + :paramtype write: ~_generated.models.StorageElementStatusWrite + :keyword check: Check. Required. + :paramtype check: ~_generated.models.StorageElementStatusCheck + :keyword remove: Remove. Required. + :paramtype remove: ~_generated.models.StorageElementStatusRemove + """ + super().__init__(**kwargs) + self.read = read + self.write = write + self.check = check + self.remove = remove + + +class StorageElementStatusCheck(_serialization.Model): + """Check.""" + + +class StorageElementStatusRead(_serialization.Model): + """Read.""" + + +class StorageElementStatusRemove(_serialization.Model): + """Remove.""" + + +class StorageElementStatusWrite(_serialization.Model): + """Write.""" + + class SummaryParams(_serialization.Model): """SummaryParams. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 6be34fb8a..77674abec 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 11ffdcff7..69089682a 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -565,6 +565,118 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_rss_get_storage_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/storage" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_compute_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/compute" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_site_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/site" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_fts_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/fts" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2818,3 +2930,303 @@ def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any) -> l return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index f1ddf0ee5..c1ac4c7a3 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -2,8 +2,8 @@ __all__ = [ "EXPIRES_GRACE_SECONDS", - "TwoLevelCache", "AsyncTwoLevelCache", + "TwoLevelCache", "batched_async", "dotenv_files_from_environment", "read_credentials", diff --git a/diracx-routers/src/diracx/routers/rss/__init__.py b/diracx-routers/src/diracx/routers/rss/__init__.py index bebfcccc7..1f8b8659e 100644 --- a/diracx-routers/src/diracx/routers/rss/__init__.py +++ b/diracx-routers/src/diracx/routers/rss/__init__.py @@ -1,6 +1,9 @@ from __future__ import annotations +__all__ = ["RSSAccessPolicy", "router"] + from ..fastapi_classes import DiracxRouter +from .access_policies import RSSAccessPolicy from .rss import router as rss_router router = DiracxRouter() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index b19d47c68..04db73e5f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + RssOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.operations.MyOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index 32b9dad3a..1adec809d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + RssOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.aio.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.aio.operations.MyOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.aio.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 5cfdf7253..d7d250107 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index a2e0565c5..fb42de5d0 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -56,6 +56,10 @@ build_lollygag_insert_owner_object_request, build_my_pilots_get_pilot_summary_request, build_my_pilots_submit_pilot_request, + build_rss_get_compute_status_request, + build_rss_get_fts_status_request, + build_rss_get_site_status_request, + build_rss_get_storage_status_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2605,3 +2609,303 @@ async def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index b97d2e439..6684af567 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -12,11 +12,17 @@ from ._models import ( # type: ignore + AllowedStatus, + BannedStatus, BodyAuthGetOidcToken, BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + ComputeElementStatus, + ComputeElementStatusAll, ExtendedMetadata, + FTSStatus, + FTSStatusAll, GroupInfo, HTTPValidationError, HeartbeatData, @@ -34,7 +40,14 @@ SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, + SiteStatus, + SiteStatusAll, SortSpec, + StorageElementStatus, + StorageElementStatusCheck, + StorageElementStatusRead, + StorageElementStatusRemove, + StorageElementStatusWrite, SummaryParams, SummaryParamsSearchItem, SupportInfo, @@ -59,11 +72,17 @@ from diracx.client._generated.models._patch import patch_sdk as _patch_sdk __all__ = [ + "AllowedStatus", + "BannedStatus", "BodyAuthGetOidcToken", "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "ComputeElementStatus", + "ComputeElementStatusAll", "ExtendedMetadata", + "FTSStatus", + "FTSStatusAll", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -81,7 +100,14 @@ "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", + "SiteStatus", + "SiteStatusAll", "SortSpec", + "StorageElementStatus", + "StorageElementStatusCheck", + "StorageElementStatusRead", + "StorageElementStatusRemove", + "StorageElementStatusWrite", "SummaryParams", "SummaryParamsSearchItem", "SupportInfo", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 69b8ffcf1..6953a050f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -16,6 +16,70 @@ JSON = MutableMapping[str, Any] +class AllowedStatus(_serialization.Model): + """AllowedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar warnings: Warnings. + :vartype warnings: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "warnings": {"key": "warnings", "type": "str"}, + } + + def __init__(self, *, allowed: bool, warnings: Optional[str] = None, **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword warnings: Warnings. + :paramtype warnings: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.warnings = warnings + + +class BannedStatus(_serialization.Model): + """BannedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar reason: Reason. + :vartype reason: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "reason": {"key": "reason", "type": "str"}, + } + + def __init__(self, *, allowed: bool, reason: str = "Unknown", **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword reason: Reason. + :paramtype reason: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.reason = reason + + class BodyAuthGetOidcToken(_serialization.Model): """Body_auth_get_oidc_token. @@ -184,6 +248,36 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class ComputeElementStatus(_serialization.Model): + """ComputeElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.ComputeElementStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "ComputeElementStatusAll"}, + } + + def __init__(self, *, all: "_models.ComputeElementStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.ComputeElementStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class ComputeElementStatusAll(_serialization.Model): + """All.""" + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -231,6 +325,36 @@ def __init__( self.gubbins_user_info = gubbins_user_info +class FTSStatus(_serialization.Model): + """FTSStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.FTSStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "FTSStatusAll"}, + } + + def __init__(self, *, all: "_models.FTSStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.FTSStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class FTSStatusAll(_serialization.Model): + """All.""" + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -1282,6 +1406,36 @@ def __init__( self.last_update_time = last_update_time +class SiteStatus(_serialization.Model): + """SiteStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.SiteStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "SiteStatusAll"}, + } + + def __init__(self, *, all: "_models.SiteStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.SiteStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class SiteStatusAll(_serialization.Model): + """All.""" + + class SortSpec(_serialization.Model): """SortSpec. @@ -1315,6 +1469,77 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class StorageElementStatus(_serialization.Model): + """StorageElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar read: Read. Required. + :vartype read: ~_generated.models.StorageElementStatusRead + :ivar write: Write. Required. + :vartype write: ~_generated.models.StorageElementStatusWrite + :ivar check: Check. Required. + :vartype check: ~_generated.models.StorageElementStatusCheck + :ivar remove: Remove. Required. + :vartype remove: ~_generated.models.StorageElementStatusRemove + """ + + _validation = { + "read": {"required": True}, + "write": {"required": True}, + "check": {"required": True}, + "remove": {"required": True}, + } + + _attribute_map = { + "read": {"key": "read", "type": "StorageElementStatusRead"}, + "write": {"key": "write", "type": "StorageElementStatusWrite"}, + "check": {"key": "check", "type": "StorageElementStatusCheck"}, + "remove": {"key": "remove", "type": "StorageElementStatusRemove"}, + } + + def __init__( + self, + *, + read: "_models.StorageElementStatusRead", + write: "_models.StorageElementStatusWrite", + check: "_models.StorageElementStatusCheck", + remove: "_models.StorageElementStatusRemove", + **kwargs: Any + ) -> None: + """ + :keyword read: Read. Required. + :paramtype read: ~_generated.models.StorageElementStatusRead + :keyword write: Write. Required. + :paramtype write: ~_generated.models.StorageElementStatusWrite + :keyword check: Check. Required. + :paramtype check: ~_generated.models.StorageElementStatusCheck + :keyword remove: Remove. Required. + :paramtype remove: ~_generated.models.StorageElementStatusRemove + """ + super().__init__(**kwargs) + self.read = read + self.write = write + self.check = check + self.remove = remove + + +class StorageElementStatusCheck(_serialization.Model): + """Check.""" + + +class StorageElementStatusRead(_serialization.Model): + """Read.""" + + +class StorageElementStatusRemove(_serialization.Model): + """Remove.""" + + +class StorageElementStatusWrite(_serialization.Model): + """Write.""" + + class SummaryParams(_serialization.Model): """SummaryParams. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 5cfdf7253..d7d250107 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 7dcaa92ee..7866a140a 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -647,6 +647,118 @@ def build_my_pilots_get_pilot_summary_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_rss_get_storage_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/storage" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_compute_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/compute" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_site_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/site" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_fts_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/fts" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -3181,3 +3293,303 @@ def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore From 8500412141431fa82b946fe887fe614f8c392b8a Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Thu, 4 Jun 2026 15:08:51 +0200 Subject: [PATCH 7/7] fix: wire RSS sources as singletons and fix caching bugs Instantiate the resource status sources once in the application factory via a new diracx.cacheable_sources entry-point group, mirroring the ConfigSource pattern, so their caches persist across requests. Also: - track storage revisions using the access status types instead of "all" - include the row count in revisions so deletions change the ETag - suffix ETags with the caller's VO and set Vary: Authorization - parse If-Modified-Since as GMT via a shared cache-header helper, also used by the configuration router - return empty results instead of 500 when the tables are empty - simplify RSSAccessPolicy to allow any authenticated user --- .../src/diracx/core/config/__init__.py | 2 + diracx-core/src/diracx/core/config/sources.py | 27 ++- diracx-core/src/diracx/core/extensions.py | 1 + diracx-core/src/diracx/core/models/rss.py | 15 +- diracx-core/tests/test_utils.py | 147 ++++++++++++++++ diracx-db/src/diracx/db/sql/rss/db.py | 68 +++----- diracx-db/tests/rss/test_rss_db.py | 36 ++++ diracx-logic/pyproject.toml | 6 + diracx-logic/src/diracx/logic/rss/source.py | 70 +++++--- diracx-logic/tests/rss/test_rss_source.py | 52 ++++-- .../src/diracx/routers/configuration.py | 39 +---- diracx-routers/src/diracx/routers/factory.py | 24 +++ .../src/diracx/routers/rss/access_policies.py | 10 +- diracx-routers/src/diracx/routers/rss/rss.py | 134 ++++----------- .../src/diracx/routers/utils/__init__.py | 8 +- .../src/diracx/routers/utils/http_cache.py | 72 ++++++++ diracx-routers/tests/test_rss.py | 160 ++++++++++++++---- diracx-testing/src/diracx/testing/utils.py | 3 +- 18 files changed, 605 insertions(+), 269 deletions(-) create mode 100644 diracx-routers/src/diracx/routers/utils/http_cache.py diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 15c0c4970..77f62e081 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -16,6 +16,7 @@ "RegistryConfig", "RemoteGitConfigSource", "SerializableSet", + "Snapshot", "SupportInfo", "UserConfig", "is_running_in_async_context", @@ -39,5 +40,6 @@ ConfigSourceUrl, LocalGitConfigSource, RemoteGitConfigSource, + Snapshot, is_running_in_async_context, ) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py index 8072cb1ef..9853b79cb 100644 --- a/diracx-core/src/diracx/core/config/sources.py +++ b/diracx-core/src/diracx/core/config/sources.py @@ -9,10 +9,11 @@ import logging import os from abc import ABCMeta, abstractmethod +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from tempfile import TemporaryDirectory -from typing import Annotated, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar from urllib.parse import urlparse, urlunparse import sh @@ -139,6 +140,15 @@ def clear_caches(self): self._content_cache.clear() +@dataclass(frozen=True) +class Snapshot(Generic[T]): + """Wraps a cached data payload with its cache metadata.""" + + data: T + hexsha: str + modified: datetime + + class AsyncCacheableSource(Generic[T], metaclass=ABCMeta): """Abstract base class for async sources that can be cached. @@ -146,6 +156,10 @@ class AsyncCacheableSource(Generic[T], metaclass=ABCMeta): functions are native coroutines. """ + #: The database class this source reads from. Used by the application + #: factory to instantiate the source with the matching database instance. + db_class: ClassVar[type] + def __init__(self): self._revision_cache = AsyncTwoLevelCache( soft_ttl=DEFAULT_CS_REV_CACHE_SOFT_TTL, @@ -187,6 +201,17 @@ async def clear_caches(self): await self._revision_cache.clear() self._content_cache.clear() + @classmethod + async def create(cls) -> T: + """Dependency injection stub. + + The application factory instantiates each concrete source and + overrides ``cls.create`` with the instance's ``read`` method, so this + should never actually be called. Each subclass's bound ``create`` + classmethod is a distinct dependency key. + """ + raise NotImplementedError(f"{cls.__name__} was not wired by the factory") + class ConfigSource(CacheableSource[Config]): """Abstract class for the configuration source. diff --git a/diracx-core/src/diracx/core/extensions.py b/diracx-core/src/diracx/core/extensions.py index 4aeac7c15..7a261208f 100644 --- a/diracx-core/src/diracx/core/extensions.py +++ b/diracx-core/src/diracx/core/extensions.py @@ -23,6 +23,7 @@ class DiracEntryPoint(StrEnum): CORE = "diracx" ACCESS_POLICY = "diracx.access_policies" + CACHEABLE_SOURCES = "diracx.cacheable_sources" CLI = "diracx.cli" HIDDEN_CLI = "diracx.cli.hidden" OS_DB = "diracx.dbs.os" diff --git a/diracx-core/src/diracx/core/models/rss.py b/diracx-core/src/diracx/core/models/rss.py index 10e1444a9..6f022b38f 100644 --- a/diracx-core/src/diracx/core/models/rss.py +++ b/diracx-core/src/diracx/core/models/rss.py @@ -1,23 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass -from datetime import datetime from enum import StrEnum -from typing import Annotated, Generic, Literal, TypeVar, Union +from typing import Annotated, Literal, Union from pydantic import BaseModel, Field -T = TypeVar("T") - - -@dataclass(frozen=True) -class Snapshot(Generic[T]): - """Wraps a cached data payload with its cache metadata.""" - - data: T - hexsha: str - modified: datetime - class AllowedStatus(BaseModel): allowed: Literal[True] diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index b88483e39..42edc5346 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -11,6 +11,7 @@ from diracx.core.exceptions import NotReadyError from diracx.core.models import TokenResponse from diracx.core.utils import ( + AsyncTwoLevelCache, TwoLevelCache, dotenv_files_from_environment, read_credentials, @@ -299,3 +300,149 @@ def start_slow(): # Ensure background thread completes thread.join() + + +class TestAsyncTwoLevelCache: + """Tests for AsyncTwoLevelCache, mirroring TestTwoLevelCache.""" + + async def test_successful_population(self): + """Test that cache is populated successfully.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + return "test_value" + + result = await cache.get("key", populate) + assert result == "test_value" + assert call_count == 1 + + # Second call should use cached value + result = await cache.get("key", populate) + assert result == "test_value" + assert call_count == 1 + + async def test_failed_population_logs_and_allows_retry(self, caplog): + """Test that failed population logs error and allows retry on next request.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def failing_populate(): + nonlocal call_count + call_count += 1 + raise ValueError("Test error") + + # First call should fail and log the error + with pytest.raises(ValueError, match="Test error"): + await cache.get("key", failing_populate, blocking=True) + + assert call_count == 1 + assert "Failed to populate cache key 'key'" in caplog.text + assert "Test error" in caplog.text + + # Task should be removed, so next call should retry + with pytest.raises(ValueError, match="Test error"): + await cache.get("key", failing_populate, blocking=True) + + assert call_count == 2 # Should have retried + + async def test_failed_population_then_success(self): + """Test that after a failure, subsequent successful call works.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + should_fail = True + + async def populate(): + if should_fail: + raise ValueError("Test error") + return "success_value" + + # First call fails + with pytest.raises(ValueError): + await cache.get("key", populate, blocking=True) + + # Second call succeeds + should_fail = False + result = await cache.get("key", populate, blocking=True) + assert result == "success_value" + + async def test_none_return_value_is_cached(self): + """Test that None return values are properly cached.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate_none(): + nonlocal call_count + call_count += 1 + return None + + result = await cache.get("key", populate_none) + assert result is None + assert call_count == 1 + + # Second call should use cached None value + result = await cache.get("key", populate_none) + assert result is None + assert call_count == 1 # Should not have called populate again + + async def test_non_blocking_raises_not_ready(self): + """Test that non-blocking mode raises NotReadyError when cache is empty.""" + import asyncio + + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + started = asyncio.Event() + release = asyncio.Event() + + async def slow_populate(): + started.set() + await release.wait() + return "value" + + # Start population in the background + task = asyncio.create_task(cache.get("key", slow_populate, blocking=True)) + + # Wait until slow_populate has started to avoid race conditions + await asyncio.wait_for(started.wait(), timeout=1.0) + + # Non-blocking call should raise NotReadyError since cache isn't populated yet + with pytest.raises(NotReadyError, match="not ready yet"): + await cache.get("key", slow_populate, blocking=False) + + # Let the in-flight population finish + release.set() + assert await task == "value" + + async def test_single_flight_deduplication(self): + """Test that concurrent gets for the same key only populate once.""" + import asyncio + + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.05) + return "value" + + results = await asyncio.gather( + *(cache.get("key", populate, blocking=True) for _ in range(5)) + ) + assert results == ["value"] * 5 + assert call_count == 1 + + async def test_clear(self): + """Test that clear empties both cache tiers.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + return "value" + + assert await cache.get("key", populate) == "value" + await cache.clear() + assert await cache.get("key", populate) == "value" + assert call_count == 2 diff --git a/diracx-db/src/diracx/db/sql/rss/db.py b/diracx-db/src/diracx/db/sql/rss/db.py index 7fe095da8..fc01ab96b 100644 --- a/diracx-db/src/diracx/db/sql/rss/db.py +++ b/diracx-db/src/diracx/db/sql/rss/db.py @@ -2,11 +2,9 @@ from datetime import datetime, timezone -from sqlalchemy import insert, select +from sqlalchemy import func, insert, select from sqlalchemy.engine import Row -from diracx.core.exceptions import ResourceNotFoundError - from ..utils import BaseSQLDB from .schema import ( ResourceStatus, @@ -34,11 +32,7 @@ async def get_site_statuses(self) -> list[tuple[str, str, str, str]]: SiteStatus.vo, ).where(SiteStatus.status_type == "all") result = await self.conn.execute(stmt) - rows = result.all() - if not rows: - raise ResourceNotFoundError("Site statuses") - - return [(row.Name, row.Status, row.Reason, row.VO) for row in rows] + return [(row.Name, row.Status, row.Reason, row.VO) for row in result.all()] async def get_resource_statuses( self, @@ -66,13 +60,9 @@ async def get_resource_statuses( ResourceStatus.status_type.in_(status_types), ) result = await self.conn.execute(stmt) - rows = result.all() - - if not rows: - raise ResourceNotFoundError("Resource statuses") statuses: dict[str, dict[str, Row]] = {} - for row in rows: + for row in result.all(): if row.Name not in statuses: statuses[row.Name] = {} statuses[row.Name][row.StatusType] = row @@ -81,54 +71,42 @@ async def get_resource_statuses( async def get_resource_status_date( self, status_types: list[str] | None = None, - ) -> Row[tuple[datetime, datetime]]: - """Return the most recent DateEffective across all VOs for the given status types. + ) -> tuple[datetime | None, int]: + """Return the most recent DateEffective and row count for the given status types. Args: status_types: Status type filter. Defaults to ["all"]. Returns: - Row with (date_effective, last_check_time) for the most recent entry. + (max_date_effective, row_count) across all VOs. The date is None + when the table contains no matching rows. """ if not status_types: status_types = ["all"] - stmt = ( - select( - ResourceStatus.date_effective, - ResourceStatus.last_check_time, - ) - .where(ResourceStatus.status_type.in_(status_types)) - .order_by(ResourceStatus.date_effective.desc()) - .limit(1) - ) + stmt = select( + func.max(ResourceStatus.date_effective), + func.count(), + ).where(ResourceStatus.status_type.in_(status_types)) result = await self.conn.execute(stmt) - row = result.first() - if not row: - raise ResourceNotFoundError("Resource statuses") - return row + max_date, count = result.one() + return max_date, count - async def get_site_status_date(self) -> Row[tuple[datetime, datetime]]: - """Return the most recent DateEffective from the SiteStatus table across all VOs. + async def get_site_status_date(self) -> tuple[datetime | None, int]: + """Return the most recent DateEffective and row count from the SiteStatus table. Returns: - Row with (date_effective, last_check_time) for the most recent entry. + (max_date_effective, row_count) across all VOs. The date is None + when the table contains no matching rows. """ - stmt = ( - select( - SiteStatus.date_effective, - SiteStatus.last_check_time, - ) - .where(SiteStatus.status_type == "all") - .order_by(SiteStatus.date_effective.desc()) - .limit(1) - ) + stmt = select( + func.max(SiteStatus.date_effective), + func.count(), + ).where(SiteStatus.status_type == "all") result = await self.conn.execute(stmt) - row = result.first() - if not row: - raise ResourceNotFoundError("Site statuses") - return row + max_date, count = result.one() + return max_date, count async def insert_resource_status( self, diff --git a/diracx-db/tests/rss/test_rss_db.py b/diracx-db/tests/rss/test_rss_db.py index 6e199619f..3457163c5 100644 --- a/diracx-db/tests/rss/test_rss_db.py +++ b/diracx-db/tests/rss/test_rss_db.py @@ -132,3 +132,39 @@ async def test_resource_status(rss_db: ResourceStatusDB): for row in result["TestStorage"].values(): assert row.Status == "Active" assert row.Reason == "All good" + + # The date queries should return the latest date and the row count + async with rss_db as db: + max_date, count = await db.get_resource_status_date() + assert max_date == _NOW + assert count == 2 # TestCompute + TestFTS "all" rows + + async with rss_db as db: + max_date, count = await db.get_resource_status_date( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + ) + assert max_date == _NOW + assert count == 4 # TestStorage access rows + + +async def test_empty_tables(rss_db: ResourceStatusDB): + """Empty tables yield empty results rather than errors.""" + async with rss_db as db: + assert await db.get_site_statuses() == [] + assert await db.get_resource_statuses() == {} + assert await db.get_resource_status_date() == (None, 0) + assert await db.get_site_status_date() == (None, 0) + + +async def test_site_status_date(rss_db: ResourceStatusDB): + async with rss_db as db: + await db.insert_site_status( + name="LCG.CERN.cern", + status="Active", + vo="lhcb", + reason="All good", + date_effective=_NOW, + ) + max_date, count = await db.get_site_status_date() + assert max_date == _NOW + assert count == 1 diff --git a/diracx-logic/pyproject.toml b/diracx-logic/pyproject.toml index 889fe9a55..0ddfb5949 100644 --- a/diracx-logic/pyproject.toml +++ b/diracx-logic/pyproject.toml @@ -23,6 +23,12 @@ dependencies = [ ] dynamic = ["version"] +[project.entry-points."diracx.cacheable_sources"] +storage_status = "diracx.logic.rss.source:StorageElementStatusSource" +compute_status = "diracx.logic.rss.source:ComputeElementStatusSource" +fts_status = "diracx.logic.rss.source:FTSStatusSource" +site_status = "diracx.logic.rss.source:SiteStatusSource" + [project.optional-dependencies] testing = ["diracx-testing", "freezegun"] types = [ diff --git a/diracx-logic/src/diracx/logic/rss/source.py b/diracx-logic/src/diracx/logic/rss/source.py index a00e08d48..61091eab2 100644 --- a/diracx-logic/src/diracx/logic/rss/source.py +++ b/diracx-logic/src/diracx/logic/rss/source.py @@ -9,11 +9,11 @@ from __future__ import annotations import logging -from datetime import datetime +from abc import abstractmethod +from datetime import datetime, timezone from typing import ClassVar -from diracx.core.config.sources import AsyncCacheableSource -from diracx.core.models.rss import Snapshot +from diracx.core.config.sources import AsyncCacheableSource, Snapshot from diracx.db.sql.rss.db import ResourceStatusDB from .query import ( @@ -25,18 +25,36 @@ logger = logging.getLogger(__name__) +#: Revision returned when the underlying table contains no rows. +EMPTY_REVISION = ("empty-0", datetime(1970, 1, 1, tzinfo=timezone.utc)) -class ResourceStatusSource(AsyncCacheableSource): + +def _make_revision(max_date: datetime | None, count: int) -> tuple[str, datetime]: + """Build a (revision, modified) pair from the latest date and row count. + + Including the row count in the revision means insertions and deletions + change the ETag even when they do not advance the latest DateEffective. + """ + if max_date is None: + return EMPTY_REVISION + return f"{max_date.isoformat()}-{count}", max_date + + +class ResourceStatusSource(AsyncCacheableSource[Snapshot]): """Base caching source for Compute, Storage and FTS resource types. - Subclasses declare `resource_type` as a class attribute — latest_revision - and _fetch dispatch on it automatically. + Subclasses declare the status types their data lives in and how to fetch + it from the database. One source instance per resource type covers all VOs. VO filtering is done in the route after the snapshot is fetched from the cache. """ - resource_type: ClassVar[str] + db_class = ResourceStatusDB + + #: Status types holding this resource type's data, used both for the + #: revision query and the data fetch. + status_types: ClassVar[list[str]] def __init__(self, *, db: ResourceStatusDB) -> None: super().__init__() @@ -44,53 +62,57 @@ def __init__(self, *, db: ResourceStatusDB) -> None: async def latest_revision(self) -> tuple[str, datetime]: async with self._db as db: - row = await db.get_resource_status_date() - modified: datetime = row.DateEffective - return modified.isoformat(), modified + max_date, count = await db.get_resource_status_date(self.status_types) + return _make_revision(max_date, count) async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: async with self._db as db: data = await self._fetch(db) return Snapshot(data=data, hexsha=hexsha, modified=modified) + @abstractmethod async def _fetch(self, db: ResourceStatusDB) -> dict: - if self.resource_type == "StorageElement": - return await get_storage_statuses(db) - if self.resource_type == "ComputeElement": - return await get_compute_statuses(db) - if self.resource_type == "FTS": - return await get_fts_statuses(db) - raise ValueError(f"Unsupported resource_type: {self.resource_type!r}") + """Fetch this resource type's statuses, keyed by VO then name.""" class StorageElementStatusSource(ResourceStatusSource): - resource_type = "StorageElement" + status_types = ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_storage_statuses(db) class ComputeElementStatusSource(ResourceStatusSource): - resource_type = "ComputeElement" + status_types = ["all"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_compute_statuses(db) class FTSStatusSource(ResourceStatusSource): - resource_type = "FTS" + status_types = ["all"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_fts_statuses(db) -class SiteStatusSource(AsyncCacheableSource): +class SiteStatusSource(AsyncCacheableSource[Snapshot]): """Caching source for Site statuses. Uses its own DB table (SiteStatus) and a dedicated date query, so it is a direct subclass of AsyncCacheableSource rather than ResourceStatusSource. """ + db_class = ResourceStatusDB + def __init__(self, *, db: ResourceStatusDB) -> None: super().__init__() self._db = db async def latest_revision(self) -> tuple[str, datetime]: async with self._db as db: - row = await db.get_site_status_date() - modified: datetime = row.DateEffective - return modified.isoformat(), modified + max_date, count = await db.get_site_status_date() + return _make_revision(max_date, count) async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: async with self._db as db: diff --git a/diracx-logic/tests/rss/test_rss_source.py b/diracx-logic/tests/rss/test_rss_source.py index 4ef55df14..503c364c2 100644 --- a/diracx-logic/tests/rss/test_rss_source.py +++ b/diracx-logic/tests/rss/test_rss_source.py @@ -20,20 +20,17 @@ StorageElementStatusSource, ) +_MAX_DATE = datetime.fromisoformat("2023-01-01T00:00:00+00:00") + @pytest.fixture def mock_resource_status_db(): """Fixture to mock the ResourceStatusDB.""" db = MagicMock(spec=ResourceStatusDB) - DateRow = namedtuple("DateRow", ["DateEffective", "DateChecked"]) db.__aenter__ = AsyncMock(return_value=db) db.__aexit__ = AsyncMock(return_value=None) - db.get_resource_status_date = AsyncMock( - return_value=DateRow( - DateEffective=datetime.fromisoformat("2023-01-01T00:00:00+00:00"), - DateChecked=datetime.now(timezone.utc), - ) - ) + db.get_resource_status_date = AsyncMock(return_value=(_MAX_DATE, 4)) + db.get_site_status_date = AsyncMock(return_value=(_MAX_DATE, 2)) return db @@ -45,11 +42,44 @@ async def test_latest_revision(mock_resource_status_db): revision, modified = await source.latest_revision() # Verify the revision is generated correctly - assert revision - assert isinstance(modified, datetime) + assert revision == f"{_MAX_DATE.isoformat()}-4" + assert modified == _MAX_DATE + + # Verify the database call queries this source's status types + mock_resource_status_db.get_resource_status_date.assert_awaited_once_with(["all"]) + + +async def test_latest_revision_storage_status_types(mock_resource_status_db): + """Storage revisions must track the access status types, not "all".""" + source = StorageElementStatusSource(db=mock_resource_status_db) + + await source.latest_revision() + + mock_resource_status_db.get_resource_status_date.assert_awaited_once_with( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + ) + + +async def test_latest_revision_empty(mock_resource_status_db): + """An empty table yields a stable sentinel revision instead of failing.""" + mock_resource_status_db.get_resource_status_date = AsyncMock(return_value=(None, 0)) + source = ComputeElementStatusSource(db=mock_resource_status_db) + + revision, modified = await source.latest_revision() + + assert revision == "empty-0" + assert modified == datetime(1970, 1, 1, tzinfo=timezone.utc) + + +async def test_latest_revision_site(mock_resource_status_db): + """Test the latest_revision method of SiteStatusSource.""" + source = SiteStatusSource(db=mock_resource_status_db) + + revision, modified = await source.latest_revision() - # Verify the database call - mock_resource_status_db.get_resource_status_date.assert_called_once() + assert revision == f"{_MAX_DATE.isoformat()}-2" + assert modified == _MAX_DATE + mock_resource_status_db.get_site_status_date.assert_awaited_once_with() async def test_read_raw_site(mock_resource_status_db): diff --git a/diracx-routers/src/diracx/routers/configuration.py b/diracx-routers/src/diracx/routers/configuration.py index 63d4ea199..16f5a6536 100644 --- a/diracx-routers/src/diracx/routers/configuration.py +++ b/diracx-routers/src/diracx/routers/configuration.py @@ -3,13 +3,10 @@ __all__ = ["router"] import logging -from datetime import datetime, timezone -from http import HTTPStatus from typing import Annotated from fastapi import ( Header, - HTTPException, Response, ) @@ -17,11 +14,10 @@ from .access_policies import open_access from .fastapi_classes import DiracxRouter +from .utils.http_cache import apply_cache_headers logger = logging.getLogger(__name__) -LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" - router = DiracxRouter() @@ -42,31 +38,12 @@ async def serve_config( return 304: this is to avoid flip/flopping """ # await check_permissions() - headers = { - "ETag": config._hexsha, - "Last-Modified": config._modified.strftime(LAST_MODIFIED_FORMAT), - } - - if if_none_match == config._hexsha: - raise HTTPException(status_code=HTTPStatus.NOT_MODIFIED, headers=headers) - - # This is to prevent flip/flopping in case - # a server gets out of sync with disk - if if_modified_since: - try: - not_before = datetime.strptime( - if_modified_since, LAST_MODIFIED_FORMAT - ).astimezone(timezone.utc) - except ValueError: - logger.debug( - "Failed to parse If-Modified-Since header: %s", if_modified_since - ) - else: - if not_before > config._modified: - raise HTTPException( - status_code=HTTPStatus.NOT_MODIFIED, headers=headers - ) - - response.headers.update(headers) + apply_cache_headers( + response, + etag=config._hexsha, + modified=config._modified, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ) return config diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py index 4c26f4c58..b33490be3 100644 --- a/diracx-routers/src/diracx/routers/factory.py +++ b/diracx-routers/src/diracx/routers/factory.py @@ -185,6 +185,7 @@ def create_app_inner( fail_startup = True # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() + sql_db_instances: dict[type[BaseSQLDB], BaseSQLDB] = {} for db_name, db_url in database_urls.items(): try: @@ -199,6 +200,7 @@ def create_app_inner( for sql_db_class in sql_db_classes: assert sql_db_class.transaction not in app.dependency_overrides available_sql_db_classes.add(sql_db_class) + sql_db_instances[sql_db_class] = sql_db app.dependency_overrides[sql_db_class.transaction] = partial( db_transaction, sql_db @@ -215,6 +217,28 @@ def create_app_inner( if fail_startup: raise Exception("No SQL database could be initialized, aborting") + # Instantiate the cacheable sources and override their create methods, + # mirroring the ConfigSource wiring above. A single instance is used for + # each source so that its caches persist across requests. + wired_source_names = set() + for entry_point in select_from_extension(group=DiracEntryPoint.CACHEABLE_SOURCES): + # The first entry point for a given name is the highest priority one + if entry_point.name in wired_source_names: + continue + wired_source_names.add(entry_point.name) + source_cls = entry_point.load() + source_db = sql_db_instances.get(source_cls.db_class) + if source_db is None: + logger.warning( + "Cannot wire cacheable source %s: %s is not available", + entry_point.name, + source_cls.db_class.__name__, + ) + continue + source = source_cls(db=source_db) + assert source_cls.create not in app.dependency_overrides + app.dependency_overrides[source_cls.create] = source.read + # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() for db_name, connection_kwargs in os_database_conn_kwargs.items(): diff --git a/diracx-routers/src/diracx/routers/rss/access_policies.py b/diracx-routers/src/diracx/routers/rss/access_policies.py index b0b439976..e30e52b32 100644 --- a/diracx-routers/src/diracx/routers/rss/access_policies.py +++ b/diracx-routers/src/diracx/routers/rss/access_policies.py @@ -3,7 +3,7 @@ from collections.abc import Callable from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import Depends from diracx.routers.access_policies import BaseAccessPolicy from diracx.routers.utils.users import AuthorizedUserInfo @@ -18,10 +18,10 @@ async def policy( user_info: AuthorizedUserInfo, /, ): - if user_info.preferred_username: - return - - raise HTTPException(status.HTTP_403_FORBIDDEN) + # Authentication is already guaranteed by verify_dirac_access_token; + # any authenticated user may read resource statuses. VO scoping is + # applied in the routes themselves. + return CheckRSSPolicyCallable = Annotated[Callable, Depends(RSSAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/rss/rss.py b/diracx-routers/src/diracx/routers/rss/rss.py index e0fce2735..ee7adfff5 100644 --- a/diracx-routers/src/diracx/routers/rss/rss.py +++ b/diracx-routers/src/diracx/routers/rss/rss.py @@ -1,109 +1,57 @@ from __future__ import annotations import logging -from datetime import datetime, timezone -from typing import Annotated, cast - -from fastapi import ( - Depends, - Header, - HTTPException, - Response, - status, -) +from typing import Annotated, Any + +from fastapi import Depends, Header, Response +from diracx.core.config.sources import Snapshot from diracx.core.models.rss import ( ComputeElementStatus, FTSStatus, SiteStatus, - Snapshot, StorageElementStatus, ) -from diracx.db.sql.rss.db import ResourceStatusDB from diracx.logic.rss.source import ( ComputeElementStatusSource, FTSStatusSource, - ResourceStatusSource, SiteStatusSource, StorageElementStatusSource, ) from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token -from diracx.tasks.plumbing.depends import NoTransaction from ..fastapi_classes import DiracxRouter +from ..utils.http_cache import apply_cache_headers from .access_policies import CheckRSSPolicyCallable logger = logging.getLogger(__name__) -LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" - router = DiracxRouter() -async def get_snapshot(rss_type, db) -> Snapshot: - """Get the status snapshot from the unique ResourceStatusSource instance or create it if it does not exist. - - :param rss_type: The type of the resource status source. - :param db: The database instance. - :returns: The status snapshot from the unique ResourceStatusSource instance. - """ - sources: dict[str, ResourceStatusSource | SiteStatusSource] = {} - if rss_type == "ComputeElementStatus": - if rss_type not in sources: - sources[rss_type] = ComputeElementStatusSource(db=db) - return await sources[rss_type].read() - elif rss_type == "StorageElementStatus": - if rss_type not in sources: - sources[rss_type] = StorageElementStatusSource(db=db) - return await sources[rss_type].read() - elif rss_type == "FTSStatus": - if rss_type not in sources: - sources[rss_type] = FTSStatusSource(db=db) - return await sources[rss_type].read() - elif rss_type == "SiteStatus": - if rss_type not in sources: - sources[rss_type] = SiteStatusSource(db=db) - return await sources[rss_type].read() - else: - raise ValueError(f"Unknown resource status source type: {rss_type}") - - -def _apply_cache_headers( - response: Response, +def _vo_view( snapshot: Snapshot, + vo: str, + response: Response, if_none_match: str | None, if_modified_since: str | None, -) -> None: - """Set ETag / Last-Modified headers and raise 304 when appropriate. - - Raises: - HTTPException(304): when the client's cached copy is still current. +) -> dict[str, Any]: + """Apply cache headers and return the caller's VO view of a snapshot. + The snapshot covers all VOs so it can be cached once; the response is the + "all" entries overlaid with the caller's VO-specific entries. The ETag is + suffixed with the VO (and Vary: Authorization set) since the same URL + serves different content per VO. """ - headers = { - "ETag": snapshot.hexsha, - "Last-Modified": snapshot.modified.strftime(LAST_MODIFIED_FORMAT), - } - - if if_none_match == snapshot.hexsha: - raise HTTPException(status_code=status.HTTP_304_NOT_MODIFIED, headers=headers) - - if if_modified_since: - try: - not_before = datetime.strptime( - if_modified_since, LAST_MODIFIED_FORMAT - ).astimezone(timezone.utc) - except ValueError: - logger.debug( - "Failed to parse If-Modified-Since header: %s", if_modified_since - ) - else: - if not_before > snapshot.modified: - raise HTTPException( - status_code=status.HTTP_304_NOT_MODIFIED, headers=headers - ) - - response.headers.update(headers) + apply_cache_headers( + response, + etag=f"{snapshot.hexsha}-{vo}", + modified=snapshot.modified, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + vary="Authorization", + ) + return {**snapshot.data.get("all", {}), **snapshot.data.get(vo, {})} @router.get("/storage") @@ -111,17 +59,12 @@ async def get_storage_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - db: Annotated[ResourceStatusDB, NoTransaction()], + snapshot: Annotated[Snapshot, Depends(StorageElementStatusSource.create)], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, StorageElementStatus]: """Get the latest status of storage elements, scoped to the caller's VO.""" - snapshot = await get_snapshot("StorageElementStatus", db) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, StorageElementStatus], - {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, - ) + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) @router.get("/compute") @@ -129,17 +72,12 @@ async def get_compute_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - db: Annotated[ResourceStatusDB, NoTransaction()], + snapshot: Annotated[Snapshot, Depends(ComputeElementStatusSource.create)], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, ComputeElementStatus]: """Get the latest status of compute elements, scoped to the caller's VO.""" - snapshot = await get_snapshot("ComputeElementStatus", db) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, ComputeElementStatus], - {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, - ) + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) @router.get("/site") @@ -147,17 +85,12 @@ async def get_site_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - db: Annotated[ResourceStatusDB, NoTransaction()], + snapshot: Annotated[Snapshot, Depends(SiteStatusSource.create)], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, SiteStatus]: """Get the latest status of sites, scoped to the caller's VO.""" - snapshot = await get_snapshot("SiteStatus", db) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, SiteStatus], - {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, - ) + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) @router.get("/fts") @@ -165,14 +98,9 @@ async def get_fts_status( response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckRSSPolicyCallable, - db: Annotated[ResourceStatusDB, NoTransaction()], + snapshot: Annotated[Snapshot, Depends(FTSStatusSource.create)], if_none_match: Annotated[str | None, Header()] = None, if_modified_since: Annotated[str | None, Header()] = None, ) -> dict[str, FTSStatus]: """Get the latest status of FTS servers, scoped to the caller's VO.""" - snapshot = await get_snapshot("FTSStatus", db) - _apply_cache_headers(response, snapshot, if_none_match, if_modified_since) - return cast( - dict[str, FTSStatus], - {**snapshot.data.get("all", {}), **snapshot.data.get(user_info.vo, {})}, - ) + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) diff --git a/diracx-routers/src/diracx/routers/utils/__init__.py b/diracx-routers/src/diracx/routers/utils/__init__.py index a8c5919dc..260a03638 100644 --- a/diracx-routers/src/diracx/routers/utils/__init__.py +++ b/diracx-routers/src/diracx/routers/utils/__init__.py @@ -1,5 +1,11 @@ from __future__ import annotations -__all__ = ["AuthorizedUserInfo", "verify_dirac_access_token"] +__all__ = [ + "LAST_MODIFIED_FORMAT", + "AuthorizedUserInfo", + "apply_cache_headers", + "verify_dirac_access_token", +] +from .http_cache import LAST_MODIFIED_FORMAT, apply_cache_headers from .users import AuthorizedUserInfo, verify_dirac_access_token diff --git a/diracx-routers/src/diracx/routers/utils/http_cache.py b/diracx-routers/src/diracx/routers/utils/http_cache.py new file mode 100644 index 000000000..7966d0cba --- /dev/null +++ b/diracx-routers/src/diracx/routers/utils/http_cache.py @@ -0,0 +1,72 @@ +"""Helpers for HTTP conditional-request caching (ETag / Last-Modified / 304).""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from http import HTTPStatus + +from fastapi import HTTPException, Response + +logger = logging.getLogger(__name__) + +LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" + + +def apply_cache_headers( + response: Response, + *, + etag: str, + modified: datetime, + if_none_match: str | None, + if_modified_since: str | None, + vary: str | None = None, +) -> None: + """Set ETag / Last-Modified headers and raise 304 when appropriate. + + If If-None-Match matches the current ETag, return 304. + + If If-Modified-Since is given and is newer than the current Last-Modified, + return 304: this is to avoid flip/flopping in case a server gets out of + sync with the source of truth. + + Args: + response: The response whose headers should be updated. + etag: The current entity tag. + modified: The current modification time (timezone-aware). + if_none_match: Value of the If-None-Match request header, if any. + if_modified_since: Value of the If-Modified-Since request header, if any. + vary: Optional value for the Vary header, for responses whose content + depends on more than the URL (e.g. the caller's identity). + + Raises: + HTTPException(304): when the client's cached copy is still current. + + """ + headers = { + "ETag": etag, + "Last-Modified": modified.strftime(LAST_MODIFIED_FORMAT), + } + if vary is not None: + headers["Vary"] = vary + + if if_none_match == etag: + raise HTTPException(status_code=HTTPStatus.NOT_MODIFIED, headers=headers) + + if if_modified_since: + try: + # The If-Modified-Since header is always GMT (RFC 9110) + not_before = datetime.strptime( + if_modified_since, LAST_MODIFIED_FORMAT + ).replace(tzinfo=timezone.utc) + except ValueError: + logger.debug( + "Failed to parse If-Modified-Since header: %s", if_modified_since + ) + else: + if not_before > modified: + raise HTTPException( + status_code=HTTPStatus.NOT_MODIFIED, headers=headers + ) + + response.headers.update(headers) diff --git a/diracx-routers/tests/test_rss.py b/diracx-routers/tests/test_rss.py index 92056429c..64df9c5fa 100644 --- a/diracx-routers/tests/test_rss.py +++ b/diracx-routers/tests/test_rss.py @@ -1,10 +1,9 @@ from __future__ import annotations -import asyncio from datetime import datetime, timezone +from http import HTTPStatus import pytest -from fastapi import status pytestmark = pytest.mark.enabled_dependencies( [ @@ -19,20 +18,43 @@ ] ) +ALL_ENDPOINTS = [ + "/api/rss/storage", + "/api/rss/compute", + "/api/rss/site", + "/api/rss/fts", +] -async def _prepare_rss(client): - """Seed the DB and warm every source cache inside a single connection.""" - from diracx.core.config.sources import AsyncCacheableSource + +def _get_rss_db(client): from diracx.db.sql.rss.db import ResourceStatusDB - db_override = client.app.dependency_overrides.get(ResourceStatusDB.no_transaction) - if db_override is None: - return - # factory.py stores partial(db_transaction, db_instance); args[0] is the instance. - db = db_override.args[0] + db_override = client.app.dependency_overrides[ResourceStatusDB.no_transaction] + # factory.py stores partial(db_no_transaction, db_instance); args[0] is the instance. + return db_override.args[0] + + +async def _clear_source_caches(client): + """Clear the singleton sources' caches. + + The sources live for the whole test session while each test gets a fresh + database, so any snapshot cached by a previous test must be dropped. + """ + from diracx.core.config.sources import AsyncCacheableSource + + for override in client.app.dependency_overrides.values(): + source = getattr(override, "__self__", None) + if isinstance(source, AsyncCacheableSource): + await source.clear_caches() + + +async def _prepare_rss(client): + """Reset the source caches and seed the database.""" + await _clear_source_caches(client) + + db = _get_rss_db(client) now = datetime.now(tz=timezone.utc) - # Seed — open one connection, insert all rows, then close it cleanly. async with db as conn: for status_type in ("ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"): await conn.insert_resource_status( @@ -43,6 +65,16 @@ async def _prepare_rss(client): reason="All good", date_effective=now, ) + # A storage element belonging to another VO, which the test user + # (vo=lhcb) must not see. + await conn.insert_resource_status( + name="SE-OTHER", + status="Active", + status_type=status_type, + vo="other_vo", + reason="All good", + date_effective=now, + ) await conn.insert_resource_status( name="CE-CERN", status="Active", @@ -66,19 +98,29 @@ async def _prepare_rss(client): reason="All good", date_effective=now, ) - # Connection is now fully closed and _conn ContextVar is reset. - - # Warm each source — each source.read() opens its own fresh connection. - for override in client.app.dependency_overrides.values(): - source = getattr(override, "__self__", None) - if isinstance(source, AsyncCacheableSource): - await source.read() + # A site visible to every VO. + await conn.insert_site_status( + name="LCG.Shared.ch", + status="Active", + vo="all", + reason="All good", + date_effective=now, + ) @pytest.fixture def normal_user_client(client_factory): with client_factory.normal_user() as client: - asyncio.get_event_loop().run_until_complete(_prepare_rss(client)) + # Run on the TestClient's portal so async primitives are bound to the + # same event loop that serves the requests. + client.portal.call(_prepare_rss, client) + yield client + + +@pytest.fixture +def empty_db_client(client_factory): + with client_factory.normal_user() as client: + client.portal.call(_clear_source_caches, client) yield client @@ -90,27 +132,28 @@ def unauthenticated_client(client_factory): def test_unauthenticated(unauthenticated_client): response = unauthenticated_client.get("/api/rss/storage") - assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert response.status_code == HTTPStatus.UNAUTHORIZED -@pytest.mark.parametrize( - "endpoint", - ["/api/rss/storage", "/api/rss/compute", "/api/rss/site", "/api/rss/fts"], -) +@pytest.mark.parametrize("endpoint", ALL_ENDPOINTS) def test_get_resource_status(normal_user_client, endpoint): r = normal_user_client.get(endpoint) - assert r.status_code == status.HTTP_200_OK, r.json() + assert r.status_code == HTTPStatus.OK, r.json() assert r.json(), r.text last_modified = r.headers["Last-Modified"] etag = r.headers["ETag"] + # The same URL serves different content per VO, so the ETag must identify + # the VO and caches must be told the response varies with the caller. + assert etag.endswith("-lhcb") + assert "Authorization" in r.headers["Vary"] # Matching ETag + matching Last-Modified → 304 r = normal_user_client.get( endpoint, headers={"If-None-Match": etag, "If-Modified-Since": last_modified}, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text assert not r.text # Wrong ETag only → 200 @@ -118,7 +161,7 @@ def test_get_resource_status(normal_user_client, endpoint): endpoint, headers={"If-None-Match": "wrongEtag"}, ) - assert r.status_code == status.HTTP_200_OK, r.json() + assert r.status_code == HTTPStatus.OK, r.json() assert r.json(), r.text # Past ETag + past timestamp → 200 @@ -129,7 +172,7 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", }, ) - assert r.status_code == status.HTTP_200_OK, r.json() + assert r.status_code == HTTPStatus.OK, r.json() assert r.json(), r.text # Wrong ETag + future timestamp → 304 (If-Modified-Since takes effect) @@ -140,7 +183,7 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", }, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text assert not r.text # Wrong ETag + invalid timestamp → 200 @@ -151,7 +194,7 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "wrong format", }, ) - assert r.status_code == status.HTTP_200_OK, r.json() + assert r.status_code == HTTPStatus.OK, r.json() assert r.json(), r.text # Correct ETag + past timestamp → 304 (ETag match takes priority) @@ -162,7 +205,7 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", }, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text assert not r.text # Correct ETag + future timestamp → 304 @@ -173,7 +216,7 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", }, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text assert not r.text # Correct ETag + invalid timestamp → 304 @@ -184,5 +227,56 @@ def test_get_resource_status(normal_user_client, endpoint): "If-Modified-Since": "wrong format", }, ) - assert r.status_code == status.HTTP_304_NOT_MODIFIED, r.text + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text assert not r.text + + +def test_vo_filtering(normal_user_client): + """Users only see "all" entries plus those of their own VO.""" + r = normal_user_client.get("/api/rss/storage") + assert r.status_code == HTTPStatus.OK, r.json() + assert set(r.json()) == {"SE-CERN"} # not SE-OTHER (vo=other_vo) + + r = normal_user_client.get("/api/rss/site") + assert r.status_code == HTTPStatus.OK, r.json() + assert set(r.json()) == {"LCG.CERN.cern", "LCG.Shared.ch"} + + +@pytest.mark.parametrize("endpoint", ALL_ENDPOINTS) +def test_empty_db(empty_db_client, endpoint): + """An empty database yields an empty result with valid cache headers.""" + r = empty_db_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + assert r.json() == {} + assert r.headers["ETag"] == "empty-0-lhcb" + + # A conditional request against the sentinel revision still works + r = empty_db_client.get(endpoint, headers={"If-None-Match": "empty-0-lhcb"}) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + + +def test_served_from_cache(normal_user_client, monkeypatch): + """Once populated, requests are served from the cache without DB access.""" + from diracx.db.sql.rss.db import ResourceStatusDB + + # Populate the cache for every endpoint + for endpoint in ALL_ENDPOINTS: + r = normal_user_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + + # Break every read path of the DB to prove the cache is used + async def _fail(*args, **kwargs): + raise AssertionError("The database should not be accessed") + + for method in ( + "get_site_statuses", + "get_resource_statuses", + "get_resource_status_date", + "get_site_status_date", + ): + monkeypatch.setattr(ResourceStatusDB, method, _fail) + + for endpoint in ALL_ENDPOINTS: + r = normal_user_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + assert r.json(), r.text diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 570241e69..e700d5fa3 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -168,7 +168,7 @@ def __init__( test_sandbox_settings, test_dev_settings, ): - from diracx.core.config import ConfigSource + from diracx.core.config import AsyncCacheableSource, ConfigSource from diracx.core.extensions import select_from_extension from diracx.core.settings import ServiceSettingsBase from diracx.db.os.utils import BaseOSDB @@ -252,6 +252,7 @@ def enrich_tokens( BaseSQLDB, BaseOSDB, ConfigSource, + AsyncCacheableSource, BaseAccessPolicy, ), ), obj