From 4edbe5c9b0192525c4eed37835b3717c162c9acb Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 2 Jun 2026 13:30:23 +0200 Subject: [PATCH 1/2] feat: clean up AuthDB refresh tokens by dropping partitions Replace the row-level DELETE sweeps for the RefreshTokens table with maintenance of its monthly JTI range-partitions: drop partitions whose whole month is older than the retention horizon, and add partitions ahead of time so the p_future catch-all never fills. Dropping a partition is an O(1) metadata operation and avoids the row-lock and lock-memory cost of large DELETEs. Retention is now expressed in calendar months via the new DIRACX_SERVICE_AUTH_REFRESH_TOKEN_RETENTION_MONTHS setting (default 6), replacing revoked_refresh_token_retention_minutes. Partition maintenance is implemented for MySQL only and raises NotImplementedError for other dialects. The unpartitioned flow tables keep their existing DELETE-based cleanup. --- diracx-core/src/diracx/core/settings.py | 9 +- diracx-db/src/diracx/db/sql/auth/db.py | 151 +++++++++++--- diracx-db/tests/auth/test_partitions.py | 185 ++++++++++++++++++ diracx-db/tests/auth/test_refresh_token.py | 43 ---- diracx-logic/src/diracx/logic/__main__.py | 6 +- .../src/diracx/logic/auth/management.py | 15 +- docs/admin/reference/env-variables.md | 11 +- 7 files changed, 332 insertions(+), 88 deletions(-) create mode 100644 diracx-db/tests/auth/test_partitions.py diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index dd55a88a8..820de242c 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -251,11 +251,12 @@ def check_retention_greater_than_expiration(self) -> Self: through a new authentication flow. Default: 60 minutes. """ - revoked_refresh_token_retention_minutes: int = 43200 - """Retention time in minutes for revoked refresh tokens. + refresh_token_retention_months: int = 6 + """Retention time in months for refresh tokens. - The maximum retention time of refresh tokens after being - revoked and before they are deleted. Default: 43200 minutes (30 days). + Refresh tokens live in monthly partitions that are dropped once the whole + month is older than this many months. It is therefore the longest a refresh + token (revoked or not) is kept before removal. Default: 6 months. """ available_properties: set[SecurityProperty] = Field( diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index caf1b0226..a35a0ad19 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -1,10 +1,12 @@ from __future__ import annotations import logging +import re import secrets from datetime import UTC, datetime from itertools import pairwise +from dateutil.relativedelta import relativedelta from dateutil.rrule import MONTHLY, rrule from sqlalchemy import delete, insert, select, text, update from sqlalchemy.exc import IntegrityError, NoResultFound @@ -32,6 +34,58 @@ logger = logging.getLogger(__name__) +# Always keep at least this many months of future RefreshTokens partitions ahead +# of "now" so the ``p_future`` catch-all partition never accumulates rows. +PARTITION_MONTHS_AHEAD = 12 + + +def _month_start(dt: datetime) -> datetime: + """Truncate ``dt`` to the first instant of its month.""" + return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + + +def _partition_name(month_start: datetime) -> str: + """Name of the partition holding the tokens created during ``month_start``.""" + return f"p_{month_start.year}_{month_start.month}" + + +def _partition_boundary(dt: datetime) -> str: + """``RANGE COLUMNS(JTI)`` upper bound (exclusive) for tokens created before ``dt``.""" + return str(uuid7_from_datetime(dt, randomize=False)).replace("-", "") + + +def plan_partition_maintenance( + existing_months: list[datetime], + now: datetime, + retention_months: int, + months_ahead: int, +) -> tuple[list[datetime], list[datetime]]: + """Decide which monthly ``RefreshTokens`` partitions to drop and to add. + + ``existing_months`` are the month-start datetimes of the existing + ``p__`` partitions (excluding ``p_future``). Returns + ``(months_to_drop, months_to_add)`` as month-start datetimes. + """ + existing = sorted(existing_months) + + # A partition for month ``m`` holds tokens created before ``m + 1 month``, so + # the whole partition is expired once that upper bound is older than the + # retention horizon. Keeping ``retention_months`` worth of partitions never + # drops a token younger than that many calendar months. + horizon = now - relativedelta(months=retention_months) + months_to_drop = [m for m in existing if m + relativedelta(months=1) <= horizon] + + # Ensure a partition exists for every month up to ``now + months_ahead`` by + # appending months above the highest existing partition. + target_last = _month_start(now) + relativedelta(months=months_ahead) + cursor = max(existing) if existing else _month_start(now) - relativedelta(months=1) + months_to_add: list[datetime] = [] + while cursor < target_last: + cursor += relativedelta(months=1) + months_to_add.append(cursor) + + return months_to_drop, months_to_add + class AuthDB(BaseSQLDB): metadata = AuthDBBase.metadata @@ -67,8 +121,8 @@ async def post_create(cls, conn: AsyncConnection) -> None: partition_list = [] for name, limit in pairwise(dates): partition_list.append( - f"PARTITION p_{name.year}_{name.month} " - f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')" + f"PARTITION {_partition_name(name)} " + f"VALUES LESS THAN ('{_partition_boundary(limit)}')" ) partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)") @@ -340,37 +394,84 @@ async def revoke_user_refresh_tokens(self, subject): .values(status=RefreshTokenStatus.REVOKED) ) - async def clean_expired_refresh_tokens(self, max_validity: int) -> int: - """Delete expired refresh tokens. + async def maintain_refresh_token_partitions( + self, + retention_months: int, + months_ahead: int = PARTITION_MONTHS_AHEAD, + ) -> None: + """Maintain the monthly partitions of the RefreshTokens table. - max_validity: Maximum validity time in minutes for refresh tokens. + Drops partitions whose entire month is older than ``retention_months`` + and adds partitions ahead of time so the ``p_future`` catch-all never + fills. Cleanup of expired refresh tokens is achieved by dropping whole + partitions rather than deleting rows. + + Only implemented for MySQL; raises ``NotImplementedError`` for any other + dialect (the table is only partitioned on MySQL). """ - expired_date = str( - uuid7_from_datetime(substract_date(minutes=max_validity), randomize=False) - ) - stmt_expired = delete(RefreshTokens).where( - RefreshTokens.status == RefreshTokenStatus.CREATED, - RefreshTokens.jti < expired_date, + dialect = self.conn.dialect.name + if dialect != "mysql": + raise NotImplementedError( + "Refresh token partition maintenance is only implemented for " + f"MySQL, not {dialect!r}" + ) + + check_partition_query = text( + "SELECT PARTITION_NAME FROM information_schema.partitions " + "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'RefreshTokens' " + "AND PARTITION_NAME IS NOT NULL" ) - res_expired = await self.conn.execute(stmt_expired) + partition_names = (await self.conn.execute(check_partition_query)).all() - return res_expired.rowcount + existing_months = [] + for (name,) in partition_names: + if match := re.fullmatch(r"p_(\d+)_(\d+)", name): + existing_months.append( + datetime(int(match.group(1)), int(match.group(2)), 1, tzinfo=UTC) + ) - async def clean_revoked_refresh_tokens(self, max_retention: int) -> int: - """Delete old revoked refresh tokens. + if not existing_months: + logger.warning( + "RefreshTokens is not partitioned; skipping partition maintenance. " + "Partition the table manually (see AuthDB.post_create)." + ) + return - max_retention: Maximum retention time in minutes for revoked refresh tokens. - """ - revoked_date = str( - uuid7_from_datetime(substract_date(minutes=max_retention), randomize=False) + months_to_drop, months_to_add = plan_partition_maintenance( + existing_months, + now=datetime.now(tz=UTC), + retention_months=retention_months, + months_ahead=months_ahead, ) - stmt_revoked = delete(RefreshTokens).where( - RefreshTokens.status == RefreshTokenStatus.REVOKED, - RefreshTokens.jti < revoked_date, - ) - res_revoked = await self.conn.execute(stmt_revoked) - return res_revoked.rowcount + # Add new partitions first, by splitting the p_future catch-all. + if months_to_add: + new_partitions = [ + f"PARTITION {_partition_name(m)} " + f"VALUES LESS THAN ('{_partition_boundary(m + relativedelta(months=1))}')" + for m in months_to_add + ] + new_partitions.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)") + await self.conn.execute( + text( + "ALTER TABLE RefreshTokens REORGANIZE PARTITION p_future INTO (" + + ", ".join(new_partitions) + + ")" + ) + ) + + # Then drop the partitions whose whole month is past the retention horizon. + if months_to_drop: + drop_names = ", ".join(_partition_name(m) for m in months_to_drop) + await self.conn.execute( + text(f"ALTER TABLE RefreshTokens DROP PARTITION {drop_names}") + ) + + logger.info( + "Refresh token partition maintenance: added %d, dropped %d", + len(months_to_add), + len(months_to_drop), + ) async def clean_expired_authorization_flows(self, max_retention: int) -> int: """Delete old authorization flows. diff --git a/diracx-db/tests/auth/test_partitions.py b/diracx-db/tests/auth/test_partitions.py new file mode 100644 index 000000000..979981a8b --- /dev/null +++ b/diracx-db/tests/auth/test_partitions.py @@ -0,0 +1,185 @@ +"""Tests for the RefreshTokens partition-maintenance logic. + +The pure planner (``plan_partition_maintenance``) and the name/boundary helpers +are dialect-independent and exercised directly here. The MySQL-only executor +(``maintain_refresh_token_partitions``) cannot be run against the in-memory +SQLite test database, so we only assert that it refuses to run on SQLite. +""" + +from __future__ import annotations + +from datetime import UTC, datetime + +import pytest +from dateutil.relativedelta import relativedelta + +from diracx.db.sql.auth.db import ( + AuthDB, + _partition_boundary, + _partition_name, + plan_partition_maintenance, +) +from diracx.db.sql.utils import uuid7_from_datetime + + +def m(year: int, month: int) -> datetime: + """Month-start datetime helper.""" + return datetime(year, month, 1, tzinfo=UTC) + + +@pytest.fixture +async def auth_db(tmp_path): + auth_db = AuthDB("sqlite+aiosqlite:///:memory:") + async with auth_db.engine_context(): + async with auth_db.engine.begin() as conn: + await conn.run_sync(auth_db.metadata.create_all) + yield auth_db + + +# --- helpers --------------------------------------------------------------- + + +def test_partition_name(): + assert _partition_name(m(2026, 3)) == "p_2026_3" + assert _partition_name(m(2026, 12)) == "p_2026_12" + + +def test_partition_boundary_matches_uuid7(): + dt = m(2026, 4) + boundary = _partition_boundary(dt) + # The boundary is the dash-stripped lowest UUIDv7 for the timestamp. + assert boundary == str(uuid7_from_datetime(dt, randomize=False)).replace("-", "") + assert len(boundary) == 32 # 32 hex chars, no dashes + + +def test_partition_boundary_is_monotonic(): + # The executor relies on lexical ordering of the JTI string boundaries. + assert _partition_boundary(m(2026, 1)) < _partition_boundary(m(2026, 2)) + assert _partition_boundary(m(2026, 12)) < _partition_boundary(m(2027, 1)) + + +# --- planner: drop --------------------------------------------------------- + + +def test_plan_drops_only_fully_expired_partitions(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 9)] # Jan..Aug 2026 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=0 + ) + + # A partition for month X has upper bound X+1mo; drop when that is older than + # now - 1 month (2026-05-15). Jan..Apr have bounds Feb1..May1 (all <= May15). + assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)] + + +def test_plan_drop_boundary_is_inclusive(): + # Upper bound exactly equal to the horizon must be dropped (<=). + now = m(2026, 6) + # now - 1 month == 2026-05-01 + existing = [m(2026, 4), m(2026, 5)] # bounds: May1, Jun1 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=0 + ) + assert to_drop == [m(2026, 4)] # May1 <= May1 drops April; June kept + + +def test_plan_keeps_last_six_months_by_default(): + # The deployment policy: keep the last 6 months worth of refresh tokens. + now = datetime(2026, 7, 15, tzinfo=UTC) + existing = [m(2025, month) for month in range(6, 13)] + [ + m(2026, month) for month in range(1, 8) + ] # 2025-06 .. 2026-07 + + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=0 + ) + + # Horizon is 2026-01-15: nothing from the last 6 months is dropped. + assert all(d < m(2026, 1) for d in to_drop) + assert m(2025, 12) in to_drop + assert m(2026, 1) not in to_drop + assert max(to_drop) == m(2025, 12) + + +def test_plan_keeps_everything_when_retention_is_huge(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 9)] + to_drop, _ = plan_partition_maintenance( + existing, now=now, retention_months=120, months_ahead=0 + ) + assert to_drop == [] + + +# --- planner: add ---------------------------------------------------------- + + +def test_plan_adds_months_up_to_horizon(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, 7)] # highest existing partition is July + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + # target_last = month_start(now) + 3 = 2026-09; append above July. + assert to_add == [m(2026, 8), m(2026, 9)] + + +def test_plan_adds_nothing_when_buffer_already_covered(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(6, 10)] # Jun..Sep + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + assert to_add == [] # highest existing (Sep) already == now+3mo + + +def test_plan_crosses_year_boundary(): + now = datetime(2026, 11, 15, tzinfo=UTC) + existing = [m(2026, 11)] + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=3 + ) + assert to_add == [m(2026, 12), m(2027, 1), m(2027, 2)] + + +def test_plan_empty_existing_seeds_from_current_month(): + now = datetime(2026, 6, 15, tzinfo=UTC) + _, to_add = plan_partition_maintenance( + [], now=now, retention_months=6, months_ahead=2 + ) + # No partitions yet: seed current month + buffer. + assert to_add == [m(2026, 6), m(2026, 7), m(2026, 8)] + + +def test_plan_combined_drop_and_add(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, month) for month in range(1, 8)] # Jan..Jul + to_drop, to_add = plan_partition_maintenance( + existing, now=now, retention_months=1, months_ahead=2 + ) + assert to_drop == [m(2026, 1), m(2026, 2), m(2026, 3), m(2026, 4)] + assert to_add == [m(2026, 8)] # target_last = 2026-08, append above July + + +def test_plan_added_months_are_contiguous_and_increasing(): + now = datetime(2026, 6, 15, tzinfo=UTC) + existing = [m(2026, 6)] + _, to_add = plan_partition_maintenance( + existing, now=now, retention_months=6, months_ahead=12 + ) + # Each added month is exactly one month after the previous. + for previous, current in zip(to_add, to_add[1:]): + assert current == previous + relativedelta(months=1) + assert to_add[0] == m(2026, 7) + assert to_add[-1] == m(2027, 6) # now + 12 months + + +# --- executor: dialect guard ---------------------------------------------- + + +async def test_maintain_partitions_requires_mysql(auth_db: AuthDB): + async with auth_db as auth_db: + with pytest.raises(NotImplementedError, match="MySQL"): + await auth_db.maintain_refresh_token_partitions(retention_months=6) diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 39fc69b37..28d6dfe9d 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -257,46 +257,3 @@ async def test_get_refresh_tokens(auth_db: AuthDB): # Check the number of retrieved refresh tokens (should be 3 refresh tokens) assert len(refresh_tokens) == 2 - - -async def test_clean_refresh_tokens(auth_db: AuthDB): - # Insert two refresh tokens - jtis = [] - async with auth_db as auth_db: - for _ in range(2): - jti = uuid7() - await auth_db.insert_refresh_token( - jti, - "subject", - "scope", - ) - jtis.append(jti) - - # Revoke one of the refresh token - async with auth_db as auth_db: - await auth_db.revoke_refresh_token(jtis[0]) - - # Check the number of deleted refresh tokens (should be 0) - async with auth_db as auth_db: - deleted_expired = await auth_db.clean_expired_refresh_tokens(max_validity=10) - assert deleted_expired == 0 - - async with auth_db as auth_db: - deleted_revoked = await auth_db.clean_revoked_refresh_tokens(max_retention=30) - assert deleted_revoked == 0 - - # Check the number of deleted refresh tokens (should be 1 of each) - async with auth_db as auth_db: - deleted_expired = await auth_db.clean_expired_refresh_tokens(max_validity=0) - assert deleted_expired == 1 - - async with auth_db as auth_db: - deleted_revoked = await auth_db.clean_revoked_refresh_tokens(max_retention=0) - assert deleted_revoked == 1 - - # Get all refresh tokens (Admin) - async with auth_db as auth_db: - refresh_tokens = await auth_db.get_user_refresh_tokens() - - # Check the number of retrieved refresh tokens (should be 0) - assert len(refresh_tokens) == 0 diff --git a/diracx-logic/src/diracx/logic/__main__.py b/diracx-logic/src/diracx/logic/__main__.py index d7b7674a5..e54837092 100644 --- a/diracx-logic/src/diracx/logic/__main__.py +++ b/diracx-logic/src/diracx/logic/__main__.py @@ -91,8 +91,8 @@ async def delete_jwk(args): async def cleanup_authdb(args): - """Delete expired tokens and flows from the AuthDB.""" - logger.info("Deleting expired tokens and flows") + """Maintain AuthDB partitions and remove expired flows.""" + logger.info("Maintaining AuthDB partitions and removing expired flows") import os from diracx.core.settings import AuthSettings @@ -138,7 +138,7 @@ def parse_args(): delete_jwk_parser.set_defaults(func=delete_jwk) cleanup_authdb_parser = subparsers.add_parser( - "cleanup-authdb", help="Delete expired tokens and flows from the AuthDB" + "cleanup-authdb", help="Maintain AuthDB partitions and remove expired flows" ) cleanup_authdb_parser.set_defaults(func=cleanup_authdb) diff --git a/diracx-logic/src/diracx/logic/auth/management.py b/diracx-logic/src/diracx/logic/auth/management.py index 4dfab8ce6..a70ed96a3 100644 --- a/diracx-logic/src/diracx/logic/auth/management.py +++ b/diracx-logic/src/diracx/logic/auth/management.py @@ -65,16 +65,15 @@ async def revoke_refresh_token_by_refresh_token( async def cleanup_expired_data(auth_db: AuthDB, settings: AuthSettings) -> None: - """Remove expired data from the auth database.""" - expired_tokens = await auth_db.clean_expired_refresh_tokens( - max_validity=settings.refresh_token_expire_minutes, - ) - logger.info("Deleted %d expired refresh tokens", expired_tokens) + """Remove expired data from the auth database. - revoked_tokens = await auth_db.clean_revoked_refresh_tokens( - max_retention=settings.revoked_refresh_token_retention_minutes, + Expired refresh tokens are removed by dropping whole monthly partitions of + the RefreshTokens table (see ``AuthDB.maintain_refresh_token_partitions``). + The flow tables are not partitioned, so their expired rows are deleted. + """ + await auth_db.maintain_refresh_token_partitions( + retention_months=settings.refresh_token_retention_months, ) - logger.info("Deleted %d revoked refresh tokens", revoked_tokens) auth = await auth_db.clean_expired_authorization_flows( max_retention=settings.completed_flow_retention_minutes, diff --git a/docs/admin/reference/env-variables.md b/docs/admin/reference/env-variables.md index ea9a04a49..d81d1e85e 100644 --- a/docs/admin/reference/env-variables.md +++ b/docs/admin/reference/env-variables.md @@ -120,14 +120,15 @@ Expiration time in minutes for refresh tokens. The maximum lifetime of refresh tokens before they must be re-issued through a new authentication flow. Default: 60 minutes. -### `DIRACX_SERVICE_AUTH_REVOKED_REFRESH_TOKEN_RETENTION_MINUTES` +### `DIRACX_SERVICE_AUTH_REFRESH_TOKEN_RETENTION_MONTHS` -*Optional*, default value: `43200` +*Optional*, default value: `6` -Retention time in minutes for revoked refresh tokens. +Retention time in months for refresh tokens. -The maximum retention time of refresh tokens after being -revoked and before they are deleted. Default: 43200 minutes (30 days). +Refresh tokens live in monthly partitions that are dropped once the whole +month is older than this many months. It is therefore the longest a refresh +token (revoked or not) is kept before removal. Default: 6 months. ### `DIRACX_SERVICE_AUTH_AVAILABLE_PROPERTIES` From 08879a0b381577eead5fcc32b1a85f3811021195 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Thu, 4 Jun 2026 13:20:32 +0200 Subject: [PATCH 2/2] feat: batch-delete expired auth flows and switch retention to days Replace the single bulk DELETE of expired authorization/device flows with a batched loop that deletes up to FLOW_DELETE_BATCH_SIZE rows per transaction, bounding lock usage on large tables. Add creation_time indexes on both flow tables to keep the cutoff scans cheap. Rename the completed_flow_retention_minutes setting to expired_flow_retention_days (default 7 days) and update the retention vs. expiration validation accordingly. --- diracx-core/src/diracx/core/settings.py | 23 +++--- diracx-db/src/diracx/db/sql/auth/db.py | 72 ++++++++++++++----- diracx-db/src/diracx/db/sql/auth/schema.py | 4 ++ .../tests/auth/test_authorization_flow.py | 36 ++++++++-- diracx-db/tests/auth/test_device_flow.py | 32 +++++++-- .../src/diracx/logic/auth/management.py | 4 +- docs/admin/reference/env-variables.md | 11 +-- 7 files changed, 134 insertions(+), 48 deletions(-) diff --git a/diracx-core/src/diracx/core/settings.py b/diracx-core/src/diracx/core/settings.py index 820de242c..78d11ae14 100644 --- a/diracx-core/src/diracx/core/settings.py +++ b/diracx-core/src/diracx/core/settings.py @@ -158,15 +158,15 @@ class AuthSettings(ServiceSettingsBase): @model_validator(mode="after") def check_retention_greater_than_expiration(self) -> Self: """Ensure retention times are bigger than expiration times to avoid deleting valid flows.""" - if self.completed_flow_retention_minutes <= ( - self.device_flow_expiration_seconds / 60 - ) or self.completed_flow_retention_minutes <= ( - self.authorization_flow_expiration_seconds / 60 + retention_seconds = self.expired_flow_retention_days * 86400 + if ( + retention_seconds <= self.device_flow_expiration_seconds + or retention_seconds <= self.authorization_flow_expiration_seconds ): raise ValueError( - f"completed_flow_retention_minutes ({self.completed_flow_retention_minutes} minutes) must be bigger" - f" than device_flow_expiration_seconds ({self.device_flow_expiration_seconds / 60} minutes) and" - f" authorization_flow_expiration_seconds: ({self.authorization_flow_expiration_seconds / 60} minutes)" + f"expired_flow_retention_days ({self.expired_flow_retention_days} days) must be bigger than" + f" device_flow_expiration_seconds ({self.device_flow_expiration_seconds} s) and" + f" authorization_flow_expiration_seconds ({self.authorization_flow_expiration_seconds} s)" ) return self @@ -202,11 +202,12 @@ def check_retention_greater_than_expiration(self) -> Self: before it must be exchanged for tokens. Default: 5 minutes. """ - completed_flow_retention_minutes: int = 60 - """Retention time in minutes for completed flow. + expired_flow_retention_days: int = 7 + """Retention time in days for device and authorization flows. - The maximum retention time of flow after being completed - and before they are deleted. Default: 60 minutes. + Expired flow rows are deleted (in batches) once they are older than this + many days. Must be larger than the device and authorization flow expiration + times so that still-valid flows are never removed. Default: 7 days. """ state_key: FernetKey diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index a35a0ad19..c7963918d 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -38,6 +38,10 @@ # of "now" so the ``p_future`` catch-all partition never accumulates rows. PARTITION_MONTHS_AHEAD = 12 +# Maximum number of flow rows deleted per transaction when cleaning expired +# device/authorization flows, to bound lock usage on large tables. +FLOW_DELETE_BATCH_SIZE = 50_000 + def _month_start(dt: datetime) -> datetime: """Truncate ``dt`` to the first instant of its month.""" @@ -473,28 +477,60 @@ async def maintain_refresh_token_partitions( len(months_to_drop), ) - async def clean_expired_authorization_flows(self, max_retention: int) -> int: - """Delete old authorization flows. + async def _delete_flows_in_batches( + self, table, pk_column, cutoff: datetime, batch_size: int + ) -> int: + """Delete rows of ``table`` created before ``cutoff`` in batches. - max_retention: Maximum retention time in minutes for expired authorization flows. - Must be bigger than authorization_flow_expiration_seconds. + Each batch deletes up to ``batch_size`` rows in its own transaction to + bound lock usage on large tables. Returns the number of rows deleted. """ - stmt_auth = delete(AuthorizationFlows).where( - AuthorizationFlows.creation_time < substract_date(minutes=max_retention), + total = 0 + while True: + pks = ( + ( + await self.conn.execute( + select(pk_column) + .where(table.creation_time < cutoff) + .limit(batch_size) + ) + ) + .scalars() + .all() + ) + if not pks: + break + await self.conn.execute(delete(table).where(pk_column.in_(pks))) + await self.conn.commit() + total += len(pks) + if len(pks) < batch_size: + break + return total + + async def clean_expired_authorization_flows( + self, retention_days: int, batch_size: int = FLOW_DELETE_BATCH_SIZE + ) -> int: + """Delete authorization flows older than ``retention_days`` days. + + ``retention_days`` must be bigger than authorization_flow_expiration_seconds. + """ + return await self._delete_flows_in_batches( + AuthorizationFlows, + AuthorizationFlows.uuid, + substract_date(days=retention_days), + batch_size, ) - res_auth = await self.conn.execute(stmt_auth) - - return res_auth.rowcount - async def clean_expired_device_flows(self, max_retention: int) -> int: - """Delete old device flows. + async def clean_expired_device_flows( + self, retention_days: int, batch_size: int = FLOW_DELETE_BATCH_SIZE + ) -> int: + """Delete device flows older than ``retention_days`` days. - max_retention: Maximum retention time in minutes for expired device flows. - Must be bigger than device_flow_expiration_seconds. + ``retention_days`` must be bigger than device_flow_expiration_seconds. """ - stmt_device = delete(DeviceFlows).where( - DeviceFlows.creation_time < substract_date(minutes=max_retention), + return await self._delete_flows_in_batches( + DeviceFlows, + DeviceFlows.user_code, + substract_date(days=retention_days), + batch_size, ) - res_device = await self.conn.execute(stmt_device) - - return res_device.rowcount diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 0c7554318..d6053153b 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -71,6 +71,8 @@ class DeviceFlows(Base): ) # Should be a hash id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken") + __table_args__ = (Index("index_device_flows_creation_time", creation_time),) + class AuthorizationFlows(Base): __tablename__ = "AuthorizationFlows" @@ -87,6 +89,8 @@ class AuthorizationFlows(Base): code: Mapped[Optional[str255]] = mapped_column("Code") # Should be a hash id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken") + __table_args__ = (Index("index_authorization_flows_creation_time", creation_time),) + class RefreshTokenStatus(Enum): """CREATED -> REVOKED. diff --git a/diracx-db/tests/auth/test_authorization_flow.py b/diracx-db/tests/auth/test_authorization_flow.py index 1fb41731e..1159f8d92 100644 --- a/diracx-db/tests/auth/test_authorization_flow.py +++ b/diracx-db/tests/auth/test_authorization_flow.py @@ -104,17 +104,41 @@ async def test_clean_authorization_flows(auth_db: AuthDB): await auth_db.update_authorization_flow_status(code3, FlowStatus.DONE) await auth_db.update_authorization_flow_status(code4, FlowStatus.ERROR) - # Check the number of deleted authorization flow (should be 0) + # Nothing is older than the retention window: nothing deleted. async with auth_db as auth_db: - deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=30) + deleted_auth = await auth_db.clean_expired_authorization_flows( + retention_days=30 + ) assert deleted_auth == 0 - # Check the number of deleted authorization flow (should be 4: 1 PENDING, 1 READY, 1 DONE, 1 ERROR) + # retention_days=0 deletes everything (1 PENDING, 1 READY, 1 DONE, 1 ERROR). async with auth_db as auth_db: - deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=0) + deleted_auth = await auth_db.clean_expired_authorization_flows(retention_days=0) assert deleted_auth == 4 - # Check the number of deleted authorization flow (should be 0 because there is nothing left to delete) + # Nothing left to delete. async with auth_db as auth_db: - deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=0) + deleted_auth = await auth_db.clean_expired_authorization_flows(retention_days=0) assert deleted_auth == 0 + + +async def test_clean_authorization_flows_batched(auth_db: AuthDB): + # Insert five flows, then delete them all with a small batch size to + # exercise the multi-batch loop (each batch commits its own transaction). + async with auth_db as auth_db: + for i in range(5): + await auth_db.insert_authorization_flow( + f"client_id{i}", "scope", "code_challenge", "S256", "redirect_uri" + ) + + async with auth_db as auth_db: + deleted = await auth_db.clean_expired_authorization_flows( + retention_days=0, batch_size=2 + ) + assert deleted == 5 + + async with auth_db as auth_db: + deleted = await auth_db.clean_expired_authorization_flows( + retention_days=0, batch_size=2 + ) + assert deleted == 0 diff --git a/diracx-db/tests/auth/test_device_flow.py b/diracx-db/tests/auth/test_device_flow.py index 168888f7e..a7ce10b4a 100644 --- a/diracx-db/tests/auth/test_device_flow.py +++ b/diracx-db/tests/auth/test_device_flow.py @@ -157,17 +157,37 @@ async def test_clean_device_flows(auth_db: AuthDB): await auth_db.update_device_flow_status(device_code3, FlowStatus.DONE) await auth_db.update_device_flow_status(device_code4, FlowStatus.ERROR) - # Check the number of deleted device flows (should be 0) + # Nothing is older than the retention window: nothing deleted. async with auth_db as auth_db: - deleted_device = await auth_db.clean_expired_device_flows(max_retention=30) + deleted_device = await auth_db.clean_expired_device_flows(retention_days=30) assert deleted_device == 0 - # Check the number of deleted device flows (should be 4: 1 PENDING, 1 READY, 1 DONE, 1 ERROR) + # retention_days=0 deletes everything (1 PENDING, 1 READY, 1 DONE, 1 ERROR). async with auth_db as auth_db: - deleted_device = await auth_db.clean_expired_device_flows(max_retention=0) + deleted_device = await auth_db.clean_expired_device_flows(retention_days=0) assert deleted_device == 4 - # Check the number of deleted device flow (should be 0 because there is nothing left to delete) + # Nothing left to delete. async with auth_db as auth_db: - deleted_device = await auth_db.clean_expired_device_flows(max_retention=0) + deleted_device = await auth_db.clean_expired_device_flows(retention_days=0) assert deleted_device == 0 + + +async def test_clean_device_flows_batched(auth_db: AuthDB): + # Insert five device flows, then delete them all with a small batch size to + # exercise the multi-batch loop (each batch commits its own transaction). + async with auth_db as auth_db: + for i in range(5): + await auth_db.insert_device_flow(f"client_id{i}", "scope") + + async with auth_db as auth_db: + deleted = await auth_db.clean_expired_device_flows( + retention_days=0, batch_size=2 + ) + assert deleted == 5 + + async with auth_db as auth_db: + deleted = await auth_db.clean_expired_device_flows( + retention_days=0, batch_size=2 + ) + assert deleted == 0 diff --git a/diracx-logic/src/diracx/logic/auth/management.py b/diracx-logic/src/diracx/logic/auth/management.py index a70ed96a3..dc5f4f88d 100644 --- a/diracx-logic/src/diracx/logic/auth/management.py +++ b/diracx-logic/src/diracx/logic/auth/management.py @@ -76,11 +76,11 @@ async def cleanup_expired_data(auth_db: AuthDB, settings: AuthSettings) -> None: ) auth = await auth_db.clean_expired_authorization_flows( - max_retention=settings.completed_flow_retention_minutes, + retention_days=settings.expired_flow_retention_days, ) logger.info("Deleted %d expired authorization flows", auth) device = await auth_db.clean_expired_device_flows( - max_retention=settings.completed_flow_retention_minutes, + retention_days=settings.expired_flow_retention_days, ) logger.info("Deleted %d expired device flows", device) diff --git a/docs/admin/reference/env-variables.md b/docs/admin/reference/env-variables.md index d81d1e85e..1c1405ccd 100644 --- a/docs/admin/reference/env-variables.md +++ b/docs/admin/reference/env-variables.md @@ -57,14 +57,15 @@ Expiration time in seconds for authorization code flow. The time window during which the authorization code remains valid before it must be exchanged for tokens. Default: 5 minutes. -### `DIRACX_SERVICE_AUTH_COMPLETED_FLOW_RETENTION_MINUTES` +### `DIRACX_SERVICE_AUTH_EXPIRED_FLOW_RETENTION_DAYS` -*Optional*, default value: `60` +*Optional*, default value: `7` -Retention time in minutes for completed flow. +Retention time in days for device and authorization flows. -The maximum retention time of flow after being completed -and before they are deleted. Default: 60 minutes. +Expired flow rows are deleted (in batches) once they are older than this +many days. Must be larger than the device and authorization flow expiration +times so that still-valid flows are never removed. Default: 7 days. ### `DIRACX_SERVICE_AUTH_STATE_KEY`