Skip to content

Commit 511ea91

Browse files
committed
Implement snapshot expiration with protection for branch/tag heads
Added ExpireSnapshots.expire_snapshots_older_than and expire_snapshots_by_ids methods to support expiring snapshots by timestamp and by multiple IDs. Ensured that protected snapshots (branch/tag heads) cannot be expired, both at the API and commit stages. Updated the expiration logic to always skip protected snapshots, even if they are accidentally included. Added and fixed tests to verify that protected snapshots are never expired and that expiration works as expected for unprotected snapshots. Improved test setup to accurately reflect post-commit metadata and to assert correct expiration behavior.
1 parent 1cac992 commit 511ea91

2 files changed

Lines changed: 140 additions & 1 deletion

File tree

pyiceberg/table/update/snapshot.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,12 +864,15 @@ def _commit(self) -> UpdatesAndRequirements:
864864
"""
865865
Commit the staged updates and requirements.
866866
867-
This will remove the snapshots with the given IDs.
867+
This will remove the snapshots with the given IDs, but will always skip protected snapshots (branch/tag heads).
868868
869869
Returns:
870870
Tuple of updates and requirements to be committed,
871871
as required by the calling parent apply functions.
872872
"""
873+
# Remove any protected snapshot IDs from the set to expire, just in case
874+
protected_ids = self._get_protected_snapshot_ids()
875+
self._snapshot_ids_to_expire -= protected_ids
873876
update = RemoveSnapshotsUpdate(snapshot_ids=self._snapshot_ids_to_expire)
874877
self._updates += (update,)
875878
return self._updates, self._requirements
@@ -911,3 +914,34 @@ def expire_snapshot_by_id(self, snapshot_id: int) -> ExpireSnapshots:
911914
self._snapshot_ids_to_expire.add(snapshot_id)
912915

913916
return self
917+
918+
def expire_snapshots_by_ids(self, snapshot_ids: List[int]) -> "ExpireSnapshots":
919+
"""
920+
Expire multiple snapshots by their IDs.
921+
922+
This will mark the snapshots for expiration.
923+
924+
Args:
925+
snapshot_ids (List[int]): List of snapshot IDs to expire.
926+
Returns:
927+
This for method chaining.
928+
"""
929+
for snapshot_id in snapshot_ids:
930+
self.expire_snapshot_by_id(snapshot_id)
931+
return self
932+
933+
def expire_snapshots_older_than(self, timestamp_ms: int) -> "ExpireSnapshots":
934+
"""
935+
Expire all unprotected snapshots with a timestamp older than a given value.
936+
937+
Args:
938+
timestamp_ms (int): Only snapshots with timestamp_ms < this value will be expired.
939+
940+
Returns:
941+
This for method chaining.
942+
"""
943+
protected_ids = self._get_protected_snapshot_ids()
944+
for snapshot in self._transaction.table_metadata.snapshots:
945+
if snapshot.timestamp_ms < timestamp_ms and snapshot.snapshot_id not in protected_ids:
946+
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
947+
return self

tests/table/test_expire_snapshots.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,108 @@ def test_expire_nonexistent_snapshot_raises(table_v2: Table) -> None:
117117
table_v2.expire_snapshots().expire_snapshot_by_id(NONEXISTENT_SNAPSHOT).commit()
118118

119119
table_v2.catalog.commit_table.assert_not_called()
120+
121+
122+
def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
123+
# Setup: two snapshots; both are old, but one is head/tag protected
124+
HEAD_SNAPSHOT = 3051729675574597004
125+
TAGGED_SNAPSHOT = 3055729675574597004
126+
127+
# Add snapshots to metadata for timestamp/protected test
128+
from types import SimpleNamespace
129+
130+
table_v2.metadata = table_v2.metadata.model_copy(
131+
update={
132+
"refs": {
133+
"main": MagicMock(snapshot_id=HEAD_SNAPSHOT, snapshot_ref_type="branch"),
134+
"mytag": MagicMock(snapshot_id=TAGGED_SNAPSHOT, snapshot_ref_type="tag"),
135+
},
136+
"snapshots": [
137+
SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
138+
SimpleNamespace(snapshot_id=TAGGED_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
139+
],
140+
}
141+
)
142+
table_v2.catalog = MagicMock()
143+
144+
# Attempt to expire all snapshots before a future timestamp (so both are candidates)
145+
future_timestamp = 9999999999999 # Far in the future, after any real snapshot
146+
147+
# Mock the catalog's commit_table to return the current metadata (simulate no change)
148+
mock_response = CommitTableResponse(
149+
metadata=table_v2.metadata, # protected snapshots remain
150+
metadata_location="mock://metadata/location",
151+
uuid=uuid4(),
152+
)
153+
table_v2.catalog.commit_table.return_value = mock_response
154+
155+
table_v2.expire_snapshots().expire_snapshots_older_than(future_timestamp).commit()
156+
# Update metadata to reflect the commit (as in other tests)
157+
table_v2.metadata = mock_response.metadata
158+
159+
# Both protected snapshots should remain
160+
remaining_ids = {s.snapshot_id for s in table_v2.metadata.snapshots}
161+
assert HEAD_SNAPSHOT in remaining_ids
162+
assert TAGGED_SNAPSHOT in remaining_ids
163+
164+
# No snapshots should have been expired (commit_table called, but with empty snapshot_ids)
165+
args, kwargs = table_v2.catalog.commit_table.call_args
166+
updates = args[2] if len(args) > 2 else ()
167+
# Find RemoveSnapshotsUpdate in updates
168+
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
169+
assert remove_update is not None
170+
assert remove_update.snapshot_ids == []
171+
172+
173+
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
174+
"""Test that multiple unprotected snapshots can be expired by IDs."""
175+
EXPIRE_SNAPSHOT_1 = 3051729675574597004
176+
EXPIRE_SNAPSHOT_2 = 3051729675574597005
177+
KEEP_SNAPSHOT = 3055729675574597004
178+
179+
mock_response = CommitTableResponse(
180+
metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}),
181+
metadata_location="mock://metadata/location",
182+
uuid=uuid4(),
183+
)
184+
table_v2.catalog = MagicMock()
185+
table_v2.catalog.commit_table.return_value = mock_response
186+
187+
# Remove any refs that protect the snapshots to be expired
188+
table_v2.metadata = table_v2.metadata.model_copy(
189+
update={
190+
"refs": {
191+
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
192+
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
193+
}
194+
}
195+
)
196+
197+
# Add snapshots to metadata for multi-id test
198+
from types import SimpleNamespace
199+
200+
table_v2.metadata = table_v2.metadata.model_copy(
201+
update={
202+
"refs": {
203+
"main": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="branch"),
204+
"tag1": MagicMock(snapshot_id=KEEP_SNAPSHOT, snapshot_ref_type="tag"),
205+
},
206+
"snapshots": [
207+
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
208+
SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_2, timestamp_ms=1, parent_snapshot_id=None),
209+
SimpleNamespace(snapshot_id=KEEP_SNAPSHOT, timestamp_ms=2, parent_snapshot_id=None),
210+
],
211+
}
212+
)
213+
214+
# Assert fixture data
215+
assert all(ref.snapshot_id not in (EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2) for ref in table_v2.metadata.refs.values())
216+
217+
# Expire the snapshots
218+
table_v2.expire_snapshots().expire_snapshots_by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2]).commit()
219+
220+
table_v2.catalog.commit_table.assert_called_once()
221+
remaining_snapshots = table_v2.metadata.snapshots
222+
assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
223+
assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots
224+
assert len(table_v2.metadata.snapshots) == 1

0 commit comments

Comments
 (0)