Skip to content

Commit 64ba8f0

Browse files
committed
feat: add methods to expire snapshots by IDs and older than a timestamp, with updated docstrings
1 parent 1c2f631 commit 64ba8f0

1 file changed

Lines changed: 46 additions & 12 deletions

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 46 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -919,9 +919,10 @@ class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
919919
_requirements: Tuple[TableRequirement, ...] = ()
920920

921921
def _commit(self) -> UpdatesAndRequirements:
922-
"""Commit the staged updates and requirements.
922+
"""
923+
Commit the staged updates and requirements.
923924
924-
This will remove the snapshots with the given IDs.
925+
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
925926
926927
Returns:
927928
Tuple of updates and requirements to be committed,
@@ -935,37 +936,70 @@ def _commit(self) -> UpdatesAndRequirements:
935936
return self._updates, self._requirements
936937

937938
def _get_protected_snapshot_ids(self) -> Set[int]:
938-
"""Get the IDs of protected snapshots.
939+
"""
940+
Get the IDs of protected snapshots.
939941
940-
These are the HEAD snapshots of all branches and all tagged snapshots.
941-
These ids are to be excluded from expiration.
942+
These are the HEAD snapshots of all branches and all tagged snapshots. These ids are to be excluded from expiration.
942943
943944
Returns:
944945
Set of protected snapshot IDs to exclude from expiration.
945946
"""
946-
protected_ids = set()
947+
protected_ids: Set[int] = set()
947948

948949
for ref in self._transaction.table_metadata.refs.values():
949950
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]:
950951
protected_ids.add(ref.snapshot_id)
951952

952953
return protected_ids
953954

954-
def expire_snapshot_by_id(self, snapshot_id: int) -> "ExpireSnapshots":
    """Expire a snapshot by its ID.

    This will mark the snapshot for expiration.

    Args:
        snapshot_id (int): The ID of the snapshot to expire.

    Returns:
        This for method chaining.

    Raises:
        ValueError: If no snapshot with the given ID exists in the table
            metadata, or if the ID is one of the protected snapshot IDs
            returned by `_get_protected_snapshot_ids`.
    """
    if self._transaction.table_metadata.snapshot_by_id(snapshot_id) is None:
        raise ValueError(f"Snapshot with ID {snapshot_id} does not exist.")

    if snapshot_id in self._get_protected_snapshot_ids():
        raise ValueError(f"Snapshot with ID {snapshot_id} is protected and cannot be expired.")

    # Both checks passed: stage the ID.
    self._snapshot_ids_to_expire.add(snapshot_id)
    return self
975+
976+
def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots":
    """Expire multiple snapshots by their IDs.

    Each ID is staged through `expire_snapshot_by_id`. The call is
    all-or-nothing: if any ID is rejected, the IDs staged earlier in this
    same call are un-staged again before the error propagates, so a failed
    call leaves the pending expirations exactly as they were.

    Args:
        snapshot_ids (List[int]): List of snapshot IDs to expire.

    Returns:
        This for method chaining.

    Raises:
        Exception: Whatever `expire_snapshot_by_id` raises for a rejected
            ID is re-raised unchanged after the rollback.
    """
    # Remember what was staged before this call so a failure can restore exactly that.
    # NOTE(review): assumes `_snapshot_ids_to_expire` is a set — confirm against its declaration.
    previously_staged = set(self._snapshot_ids_to_expire)
    try:
        for snapshot_id in snapshot_ids:
            self.expire_snapshot_by_id(snapshot_id)
    except Exception:
        # Roll back in place (no rebinding) so the same container object stays in use.
        self._snapshot_ids_to_expire.intersection_update(previously_staged)
        raise
    return self
990+
991+
def expire_snapshots_older_than(self, timestamp_ms: int) -> "ExpireSnapshots":
    """Expire all unprotected snapshots with a timestamp older than a given value.

    Protected snapshots (see `_get_protected_snapshot_ids`) that fall below
    the cutoff are skipped silently; no error is raised for them.

    Args:
        timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired.

    Returns:
        This for method chaining.
    """
    # Compute the protected set once instead of once per snapshot.
    protected_ids = self._get_protected_snapshot_ids()
    self._snapshot_ids_to_expire.update(
        snapshot.snapshot_id
        for snapshot in self._transaction.table_metadata.snapshots
        # Strictly older than the cutoff; a snapshot exactly at the cutoff is kept.
        if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids
    )
    return self

0 commit comments

Comments
 (0)