Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions diracx-core/src/diracx/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ class AuthSettings(ServiceSettingsBase):
@model_validator(mode="after")
def check_retention_greater_than_expiration(self) -> Self:
"""Ensure retention times are bigger than expiration times to avoid deleting valid flows."""
if self.completed_flow_retention_minutes <= (
self.device_flow_expiration_seconds / 60
) or self.completed_flow_retention_minutes <= (
self.authorization_flow_expiration_seconds / 60
retention_seconds = self.expired_flow_retention_days * 86400
if (
retention_seconds <= self.device_flow_expiration_seconds
or retention_seconds <= self.authorization_flow_expiration_seconds
):
raise ValueError(
f"completed_flow_retention_minutes ({self.completed_flow_retention_minutes} minutes) must be bigger"
f" than device_flow_expiration_seconds ({self.device_flow_expiration_seconds / 60} minutes) and"
f" authorization_flow_expiration_seconds: ({self.authorization_flow_expiration_seconds / 60} minutes)"
f"expired_flow_retention_days ({self.expired_flow_retention_days} days) must be bigger than"
f" device_flow_expiration_seconds ({self.device_flow_expiration_seconds} s) and"
f" authorization_flow_expiration_seconds ({self.authorization_flow_expiration_seconds} s)"
)
return self

Expand Down Expand Up @@ -202,11 +202,12 @@ def check_retention_greater_than_expiration(self) -> Self:
before it must be exchanged for tokens. Default: 5 minutes.
"""

completed_flow_retention_minutes: int = 60
"""Retention time in minutes for completed flow.
expired_flow_retention_days: int = 7
"""Retention time in days for device and authorization flows.

The maximum retention time of flow after being completed
and before they are deleted. Default: 60 minutes.
Expired flow rows are deleted (in batches) once they are older than this
many days. Must be larger than the device and authorization flow expiration
times so that still-valid flows are never removed. Default: 7 days.
"""

state_key: FernetKey
Expand Down Expand Up @@ -251,11 +252,12 @@ def check_retention_greater_than_expiration(self) -> Self:
through a new authentication flow. Default: 60 minutes.
"""

revoked_refresh_token_retention_minutes: int = 43200
"""Retention time in minutes for revoked refresh tokens.
refresh_token_retention_months: int = 6
"""Retention time in months for refresh tokens.

The maximum retention time of refresh tokens after being
revoked and before they are deleted. Default: 43200 minutes (30 days).
Refresh tokens live in monthly partitions that are dropped once the whole
month is older than this many months. It is therefore the longest a refresh
token (revoked or not) is kept before removal. Default: 6 months.
"""

available_properties: set[SecurityProperty] = Field(
Expand Down
223 changes: 180 additions & 43 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
import re
import secrets
from datetime import UTC, datetime
from itertools import pairwise

from dateutil.relativedelta import relativedelta
from dateutil.rrule import MONTHLY, rrule
from sqlalchemy import delete, insert, select, text, update
from sqlalchemy.exc import IntegrityError, NoResultFound
Expand Down Expand Up @@ -32,6 +34,62 @@

logger = logging.getLogger(__name__)

# Always keep at least this many months of future RefreshTokens partitions ahead
# of "now" so the ``p_future`` catch-all partition never accumulates rows.
PARTITION_MONTHS_AHEAD = 12

# Maximum number of flow rows deleted per transaction when cleaning expired
# device/authorization flows, to bound lock usage on large tables.
FLOW_DELETE_BATCH_SIZE = 50_000


def _month_start(dt: datetime) -> datetime:
"""Truncate ``dt`` to the first instant of its month."""
return dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)


def _partition_name(month_start: datetime) -> str:
"""Name of the partition holding the tokens created during ``month_start``."""
return f"p_{month_start.year}_{month_start.month}"


def _partition_boundary(dt: datetime) -> str:
    """Return the exclusive ``RANGE COLUMNS(JTI)`` upper bound for ``dt``.

    The bound is the non-randomized UUIDv7 that ``uuid7_from_datetime`` builds
    from ``dt``, rendered as a string with its dashes removed.
    """
    boundary_uuid = uuid7_from_datetime(dt, randomize=False)
    return str(boundary_uuid).replace("-", "")


def plan_partition_maintenance(
    existing_months: list[datetime],
    now: datetime,
    retention_months: int,
    months_ahead: int,
) -> tuple[list[datetime], list[datetime]]:
    """Decide which monthly ``RefreshTokens`` partitions to drop and to add.

    Args:
        existing_months: Month-start datetimes of the existing
            ``p_<year>_<month>`` partitions (excluding ``p_future``). They are
            compared with ``now``, so both must be consistently timezone-aware
            or consistently naive.
        now: Reference instant used to compute the retention horizon and the
            last month that must have a partition.
        retention_months: Number of calendar months a partition is kept after
            its upper bound. Must not be negative.
        months_ahead: Number of months after the month of ``now`` for which a
            partition must exist.

    Returns:
        ``(months_to_drop, months_to_add)`` as month-start datetimes, both in
        ascending order.

    Raises:
        ValueError: If ``retention_months`` is negative.
    """
    # A negative retention would move the horizon into the future and select
    # the current (and pre-created future) partitions for dropping, i.e. it
    # would destroy live tokens. Refuse instead of planning that.
    if retention_months < 0:
        raise ValueError(
            f"retention_months must be non-negative, got {retention_months}"
        )

    one_month = relativedelta(months=1)
    existing = sorted(existing_months)

    # A partition for month ``m`` holds tokens created before ``m + 1 month``, so
    # the whole partition is expired once that upper bound is older than the
    # retention horizon. Keeping ``retention_months`` worth of partitions never
    # drops a token younger than that many calendar months.
    horizon = now - relativedelta(months=retention_months)
    months_to_drop = [m for m in existing if m + one_month <= horizon]

    # Ensure a partition exists for every month up to ``now + months_ahead`` by
    # appending months above the highest existing partition. Gaps below the
    # highest existing partition are deliberately not filled here.
    target_last = _month_start(now) + relativedelta(months=months_ahead)
    # ``existing`` is sorted, so its last element is the highest partition;
    # with no partitions at all, start just before the current month so that
    # the current month is the first one added.
    cursor = existing[-1] if existing else _month_start(now) - one_month
    months_to_add: list[datetime] = []
    while cursor < target_last:
        cursor += one_month
        months_to_add.append(cursor)

    return months_to_drop, months_to_add


class AuthDB(BaseSQLDB):
metadata = AuthDBBase.metadata
Expand Down Expand Up @@ -67,8 +125,8 @@ async def post_create(cls, conn: AsyncConnection) -> None:
partition_list = []
for name, limit in pairwise(dates):
partition_list.append(
f"PARTITION p_{name.year}_{name.month} "
f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
f"PARTITION {_partition_name(name)} "
f"VALUES LESS THAN ('{_partition_boundary(limit)}')"
)
partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")

Expand Down Expand Up @@ -340,60 +398,139 @@ async def revoke_user_refresh_tokens(self, subject):
.values(status=RefreshTokenStatus.REVOKED)
)

async def clean_expired_refresh_tokens(self, max_validity: int) -> int:
"""Delete expired refresh tokens.
async def maintain_refresh_token_partitions(
self,
retention_months: int,
months_ahead: int = PARTITION_MONTHS_AHEAD,
) -> None:
"""Maintain the monthly partitions of the RefreshTokens table.

max_validity: Maximum validity time in minutes for refresh tokens.
Drops partitions whose entire month is older than ``retention_months``
and adds partitions ahead of time so the ``p_future`` catch-all never
fills. Cleanup of expired refresh tokens is achieved by dropping whole
partitions rather than deleting rows.

Only implemented for MySQL; raises ``NotImplementedError`` for any other
dialect (the table is only partitioned on MySQL).
"""
expired_date = str(
uuid7_from_datetime(substract_date(minutes=max_validity), randomize=False)
)
stmt_expired = delete(RefreshTokens).where(
RefreshTokens.status == RefreshTokenStatus.CREATED,
RefreshTokens.jti < expired_date,
dialect = self.conn.dialect.name
if dialect != "mysql":
raise NotImplementedError(
"Refresh token partition maintenance is only implemented for "
f"MySQL, not {dialect!r}"
)

check_partition_query = text(
"SELECT PARTITION_NAME FROM information_schema.partitions "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = 'RefreshTokens' "
"AND PARTITION_NAME IS NOT NULL"
)
res_expired = await self.conn.execute(stmt_expired)
partition_names = (await self.conn.execute(check_partition_query)).all()

return res_expired.rowcount
existing_months = []
for (name,) in partition_names:
if match := re.fullmatch(r"p_(\d+)_(\d+)", name):
existing_months.append(
datetime(int(match.group(1)), int(match.group(2)), 1, tzinfo=UTC)
)

async def clean_revoked_refresh_tokens(self, max_retention: int) -> int:
"""Delete old revoked refresh tokens.
if not existing_months:
logger.warning(
"RefreshTokens is not partitioned; skipping partition maintenance. "
"Partition the table manually (see AuthDB.post_create)."
)
return

max_retention: Maximum retention time in minutes for revoked refresh tokens.
"""
revoked_date = str(
uuid7_from_datetime(substract_date(minutes=max_retention), randomize=False)
months_to_drop, months_to_add = plan_partition_maintenance(
existing_months,
now=datetime.now(tz=UTC),
retention_months=retention_months,
months_ahead=months_ahead,
)
stmt_revoked = delete(RefreshTokens).where(
RefreshTokens.status == RefreshTokenStatus.REVOKED,
RefreshTokens.jti < revoked_date,
)
res_revoked = await self.conn.execute(stmt_revoked)

return res_revoked.rowcount
# Add new partitions first, by splitting the p_future catch-all.
if months_to_add:
new_partitions = [
f"PARTITION {_partition_name(m)} "
f"VALUES LESS THAN ('{_partition_boundary(m + relativedelta(months=1))}')"
for m in months_to_add
]
new_partitions.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")
await self.conn.execute(
text(
"ALTER TABLE RefreshTokens REORGANIZE PARTITION p_future INTO ("
+ ", ".join(new_partitions)
+ ")"
)
)

# Then drop the partitions whose whole month is past the retention horizon.
if months_to_drop:
drop_names = ", ".join(_partition_name(m) for m in months_to_drop)
await self.conn.execute(
text(f"ALTER TABLE RefreshTokens DROP PARTITION {drop_names}")
)

logger.info(
"Refresh token partition maintenance: added %d, dropped %d",
len(months_to_add),
len(months_to_drop),
)

async def clean_expired_authorization_flows(self, max_retention: int) -> int:
"""Delete old authorization flows.
async def _delete_flows_in_batches(
self, table, pk_column, cutoff: datetime, batch_size: int
) -> int:
"""Delete rows of ``table`` created before ``cutoff`` in batches.

max_retention: Maximum retention time in minutes for expired authorization flows.
Must be bigger than authorization_flow_expiration_seconds.
Each batch deletes up to ``batch_size`` rows in its own transaction to
bound lock usage on large tables. Returns the number of rows deleted.
"""
stmt_auth = delete(AuthorizationFlows).where(
AuthorizationFlows.creation_time < substract_date(minutes=max_retention),
total = 0
while True:
pks = (
(
await self.conn.execute(
select(pk_column)
.where(table.creation_time < cutoff)
.limit(batch_size)
)
)
.scalars()
.all()
)
if not pks:
break
await self.conn.execute(delete(table).where(pk_column.in_(pks)))
await self.conn.commit()
total += len(pks)
if len(pks) < batch_size:
break
return total

async def clean_expired_authorization_flows(
    self, retention_days: int, batch_size: int = FLOW_DELETE_BATCH_SIZE
) -> int:
    """Remove authorization flows older than ``retention_days`` days.

    The work is delegated to ``_delete_flows_in_batches``, which removes at
    most ``batch_size`` rows per batch; its return value is passed through.

    ``retention_days`` must be bigger than authorization_flow_expiration_seconds.
    """
    cutoff = substract_date(days=retention_days)
    return await self._delete_flows_in_batches(
        table=AuthorizationFlows,
        pk_column=AuthorizationFlows.uuid,
        cutoff=cutoff,
        batch_size=batch_size,
    )
res_auth = await self.conn.execute(stmt_auth)

return res_auth.rowcount
async def clean_expired_device_flows(
self, retention_days: int, batch_size: int = FLOW_DELETE_BATCH_SIZE
) -> int:
"""Delete device flows older than ``retention_days`` days.

async def clean_expired_device_flows(self, max_retention: int) -> int:
"""Delete old device flows.

max_retention: Maximum retention time in minutes for expired device flows.
Must be bigger than device_flow_expiration_seconds.
``retention_days`` must be bigger than device_flow_expiration_seconds.
"""
stmt_device = delete(DeviceFlows).where(
DeviceFlows.creation_time < substract_date(minutes=max_retention),
return await self._delete_flows_in_batches(
DeviceFlows,
DeviceFlows.user_code,
substract_date(days=retention_days),
batch_size,
)
res_device = await self.conn.execute(stmt_device)

return res_device.rowcount
4 changes: 4 additions & 0 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class DeviceFlows(Base):
) # Should be a hash
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")

__table_args__ = (Index("index_device_flows_creation_time", creation_time),)


class AuthorizationFlows(Base):
__tablename__ = "AuthorizationFlows"
Expand All @@ -87,6 +89,8 @@ class AuthorizationFlows(Base):
code: Mapped[Optional[str255]] = mapped_column("Code") # Should be a hash
id_token: Mapped[Optional[dict[str, Any]]] = mapped_column("IDToken")

__table_args__ = (Index("index_authorization_flows_creation_time", creation_time),)


class RefreshTokenStatus(Enum):
"""CREATED -> REVOKED.
Expand Down
36 changes: 30 additions & 6 deletions diracx-db/tests/auth/test_authorization_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,41 @@ async def test_clean_authorization_flows(auth_db: AuthDB):
await auth_db.update_authorization_flow_status(code3, FlowStatus.DONE)
await auth_db.update_authorization_flow_status(code4, FlowStatus.ERROR)

# Check the number of deleted authorization flow (should be 0)
# Nothing is older than the retention window: nothing deleted.
async with auth_db as auth_db:
deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=30)
deleted_auth = await auth_db.clean_expired_authorization_flows(
retention_days=30
)
assert deleted_auth == 0

# Check the number of deleted authorization flow (should be 4: 1 PENDING, 1 READY, 1 DONE, 1 ERROR)
# retention_days=0 deletes everything (1 PENDING, 1 READY, 1 DONE, 1 ERROR).
async with auth_db as auth_db:
deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=0)
deleted_auth = await auth_db.clean_expired_authorization_flows(retention_days=0)
assert deleted_auth == 4

# Check the number of deleted authorization flow (should be 0 because there is nothing left to delete)
# Nothing left to delete.
async with auth_db as auth_db:
deleted_auth = await auth_db.clean_expired_authorization_flows(max_retention=0)
deleted_auth = await auth_db.clean_expired_authorization_flows(retention_days=0)
assert deleted_auth == 0


async def test_clean_authorization_flows_batched(auth_db: AuthDB):
    """Delete five flows with ``batch_size=2`` so several batches are needed."""
    flow_count = 5

    # Populate the table with more rows than fit in a single batch.
    async with auth_db as auth_db:
        for index in range(flow_count):
            await auth_db.insert_authorization_flow(
                f"client_id{index}", "scope", "code_challenge", "S256", "redirect_uri"
            )

    # With retention_days=0 every inserted flow is eligible, so the batches
    # together must report all of them as deleted.
    async with auth_db as auth_db:
        deleted = await auth_db.clean_expired_authorization_flows(
            retention_days=0, batch_size=2
        )
    assert deleted == flow_count

    # A second pass has nothing left to remove.
    async with auth_db as auth_db:
        deleted = await auth_db.clean_expired_authorization_flows(
            retention_days=0, batch_size=2
        )
    assert deleted == 0
Loading
Loading