-
Notifications
You must be signed in to change notification settings - Fork 479
Expand file tree
/
Copy pathtest_expire_snapshots.py
More file actions
417 lines (327 loc) · 17.3 KB
/
test_expire_snapshots.py
File metadata and controls
417 lines (327 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import threading
from datetime import datetime, timedelta
from unittest.mock import MagicMock, Mock
from uuid import uuid4
import pytest
from pyiceberg.table import CommitTableResponse, Table
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
from pyiceberg.table.update import RemoveSnapshotRefUpdate, RemoveSnapshotsUpdate, update_table_metadata
from pyiceberg.table.update.snapshot import ExpireSnapshots
def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None:
    """Test that a HEAD (branch) snapshot cannot be expired.

    Expiring the snapshot that ``main`` points at must raise ``ValueError`` and must not
    send anything to the catalog.
    """
    HEAD_SNAPSHOT = 3051729675574597004
    KEEP_SNAPSHOT = 3055729675574597004
    # Replace the catalog so any commit attempt is recorded instead of executed
    table_v2.catalog = MagicMock()
    # Protect HEAD_SNAPSHOT as the head of "main"; tag an unrelated snapshot.
    # Real SnapshotRef models (as used by the other tests in this file) rather than MagicMocks,
    # so the refs carry validated snapshot_id / snapshot_ref_type values like real metadata.
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
            }
        }
    )
    # Assert fixture data
    assert any(ref.snapshot_id == HEAD_SNAPSHOT for ref in table_v2.metadata.refs.values())
    # Attempt to expire the HEAD snapshot and expect a ValueError
    with pytest.raises(ValueError, match=f"Snapshot with ID {HEAD_SNAPSHOT} is protected and cannot be expired."):
        table_v2.maintenance.expire_snapshots().by_id(HEAD_SNAPSHOT).commit()
    # The rejection must happen before anything reaches the catalog
    table_v2.catalog.commit_table.assert_not_called()
def test_cannot_expire_tagged_snapshot(table_v2: Table) -> None:
    """Test that a tagged snapshot cannot be expired.

    Expiring a snapshot referenced by a tag must raise ``ValueError`` and must not send
    anything to the catalog.
    """
    TAGGED_SNAPSHOT = 3051729675574597004
    KEEP_SNAPSHOT = 3055729675574597004
    # Replace the catalog so any commit attempt is recorded instead of executed
    table_v2.catalog = MagicMock()
    # Protect TAGGED_SNAPSHOT via a tag; "main" points elsewhere.
    # Real SnapshotRef models (as used by the other tests in this file) rather than MagicMocks.
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "tag1": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}),
                "main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
            }
        }
    )
    # Assert fixture data
    assert any(ref.snapshot_id == TAGGED_SNAPSHOT for ref in table_v2.metadata.refs.values())
    with pytest.raises(ValueError, match=f"Snapshot with ID {TAGGED_SNAPSHOT} is protected and cannot be expired."):
        table_v2.maintenance.expire_snapshots().by_id(TAGGED_SNAPSHOT).commit()
    # The rejection must happen before anything reaches the catalog
    table_v2.catalog.commit_table.assert_not_called()
def test_expire_unprotected_snapshot(table_v2: Table) -> None:
    """Test that an unprotected snapshot can be expired.

    No branch or tag references the snapshot, so ``by_id`` must accept it. The single commit
    sent to the catalog must remove exactly that snapshot.
    """
    EXPIRE_SNAPSHOT = 3051729675574597004
    KEEP_SNAPSHOT = 3055729675574597004
    # The mocked catalog answers with metadata that lists only KEEP_SNAPSHOT
    mock_response = CommitTableResponse(
        metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}),
        metadata_location="mock://metadata/location",
        uuid=uuid4(),
    )
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = mock_response
    # Point every ref at KEEP_SNAPSHOT so nothing protects the snapshot to be expired
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
            }
        }
    )
    # Assert fixture data
    assert all(ref.snapshot_id != EXPIRE_SNAPSHOT for ref in table_v2.metadata.refs.values())
    # Expire the snapshot
    table_v2.maintenance.expire_snapshots().by_id(EXPIRE_SNAPSHOT).commit()
    table_v2.catalog.commit_table.assert_called_once()
    # Inspect what was actually sent to the catalog. The table_v2.metadata assertions below
    # only observe the mocked response, so on their own they cannot detect a wrong expiration.
    args, _ = table_v2.catalog.commit_table.call_args
    snap_updates = [u for u in args[2] if isinstance(u, RemoveSnapshotsUpdate)]
    expired_ids = {sid for u in snap_updates for sid in u.snapshot_ids}
    assert expired_ids == {EXPIRE_SNAPSHOT}
    # The table now reflects the (mocked) catalog response
    remaining_snapshots = table_v2.metadata.snapshots
    assert EXPIRE_SNAPSHOT not in remaining_snapshots
    assert len(table_v2.metadata.snapshots) == 1
def test_expire_nonexistent_snapshot_raises(table_v2: Table) -> None:
    """Expiring an ID that the table does not know must raise and leave the catalog untouched."""
    NONEXISTENT_SNAPSHOT = 9999999999999999999
    expected_error = f"Snapshot with ID {NONEXISTENT_SNAPSHOT} does not exist."
    catalog = MagicMock()
    table_v2.catalog = catalog
    # Drop every ref so nothing in the table references the ID
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {}})
    with pytest.raises(ValueError, match=expected_error):
        table_v2.maintenance.expire_snapshots().by_id(NONEXISTENT_SNAPSHOT).commit()
    catalog.commit_table.assert_not_called()
def test_expire_snapshots_by_timestamp_skips_protected(table_v2: Table) -> None:
    """``older_than`` must not select snapshots referenced by a branch or a tag.

    Both snapshots are older than the cutoff. One is the head of ``main`` and the other is
    tagged, so nothing is eligible for expiration and no commit may be sent.
    """
    HEAD_SNAPSHOT = 3051729675574597004
    TAGGED_SNAPSHOT = 3055729675574597004
    # Lightweight stand-ins exposing only the snapshot attributes this test needs
    from types import SimpleNamespace

    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": HEAD_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "mytag": SnapshotRef(**{"snapshot-id": TAGGED_SNAPSHOT, "type": SnapshotRefType.TAG}),
            },
            "snapshots": [
                SimpleNamespace(snapshot_id=HEAD_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
                SimpleNamespace(snapshot_id=TAGGED_SNAPSHOT, timestamp_ms=1, parent_snapshot_id=None),
            ],
        }
    )
    table_v2.catalog = MagicMock()
    # Defensive: if a commit is wrongly attempted, let it complete so the failure is reported
    # by assert_not_called below rather than by response handling.
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
    # A cutoff in the future makes every snapshot old enough to be a candidate
    future_datetime = datetime.now() + timedelta(days=1)
    table_v2.maintenance.expire_snapshots().older_than(future_datetime).commit()
    # No snapshots and no refs were expirable, so no commit should have been attempted
    table_v2.catalog.commit_table.assert_not_called()
    # Both protected snapshots remain in the table metadata
    remaining_ids = {s.snapshot_id for s in table_v2.metadata.snapshots}
    assert HEAD_SNAPSHOT in remaining_ids
    assert TAGGED_SNAPSHOT in remaining_ids
def test_expire_snapshots_by_ids(table_v2: Table) -> None:
    """Test that multiple unprotected snapshots can be expired by IDs.

    Neither ID is referenced by a branch or tag, so ``by_ids`` must accept both. The single
    commit sent to the catalog must remove exactly those two snapshots.
    """
    EXPIRE_SNAPSHOT_1 = 3051729675574597004
    EXPIRE_SNAPSHOT_2 = 3051729675574597005
    KEEP_SNAPSHOT = 3055729675574597004
    # The mocked catalog answers with metadata that lists only KEEP_SNAPSHOT
    mock_response = CommitTableResponse(
        metadata=table_v2.metadata.model_copy(update={"snapshots": [KEEP_SNAPSHOT]}),
        metadata_location="mock://metadata/location",
        uuid=uuid4(),
    )
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = mock_response
    # Lightweight stand-ins exposing only the snapshot attributes this test needs
    from types import SimpleNamespace

    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "tag1": SnapshotRef(**{"snapshot-id": KEEP_SNAPSHOT, "type": SnapshotRefType.TAG}),
            },
            "snapshots": [
                SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_1, timestamp_ms=1, parent_snapshot_id=None),
                SimpleNamespace(snapshot_id=EXPIRE_SNAPSHOT_2, timestamp_ms=1, parent_snapshot_id=None),
                SimpleNamespace(snapshot_id=KEEP_SNAPSHOT, timestamp_ms=2, parent_snapshot_id=None),
            ],
        }
    )
    # Assert fixture data
    assert all(ref.snapshot_id not in (EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2) for ref in table_v2.metadata.refs.values())
    # Expire the snapshots
    table_v2.maintenance.expire_snapshots().by_ids([EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2]).commit()
    table_v2.catalog.commit_table.assert_called_once()
    # Inspect what was actually sent to the catalog. The table_v2.metadata assertions below
    # only observe the mocked response, so on their own they cannot detect a wrong expiration.
    args, _ = table_v2.catalog.commit_table.call_args
    snap_updates = [u for u in args[2] if isinstance(u, RemoveSnapshotsUpdate)]
    expired_ids = {sid for u in snap_updates for sid in u.snapshot_ids}
    assert expired_ids == {EXPIRE_SNAPSHOT_1, EXPIRE_SNAPSHOT_2}
    # The table now reflects the (mocked) catalog response
    remaining_snapshots = table_v2.metadata.snapshots
    assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
    assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots
    assert len(table_v2.metadata.snapshots) == 1
def test_thread_safety_fix() -> None:
    """Each ExpireSnapshots instance must own a private set of snapshot IDs to expire.

    Regression guard: the set was once a class attribute shared by every instance.
    """
    first, second = ExpireSnapshots(Mock()), ExpireSnapshots(Mock())
    first_ids = first._snapshot_ids_to_expire
    second_ids = second._snapshot_ids_to_expire
    # Distinct objects, not merely equal contents
    assert first_ids is not second_ids, (
        "ExpireSnapshots instances are sharing the same snapshot set - thread safety bug still exists"
    )
    # Mutating one instance's set must not be visible through the other
    first_ids.add(1001)
    second_ids.add(2001)
    for own_ids, foreign_id in ((first_ids, 2001), (second_ids, 1001)):
        assert foreign_id not in own_ids, "Snapshot IDs are leaking between instances"
def test_concurrent_operations() -> None:
    """ExpireSnapshots instances built on separate threads keep their snapshot IDs apart."""
    results: dict[str, set[int]] = {"expire1_snapshots": set(), "expire2_snapshots": set()}
    def populate(result_key: str, snapshot_ids: list[int]) -> None:
        # Each thread owns its instance and records a snapshot copy of the IDs it holds
        expire = ExpireSnapshots(Mock())
        expire._snapshot_ids_to_expire.update(snapshot_ids)
        results[result_key] = expire._snapshot_ids_to_expire.copy()
    workers = [
        threading.Thread(target=populate, args=("expire1_snapshots", [1001, 1002, 1003])),
        threading.Thread(target=populate, args=("expire2_snapshots", [2001, 2002, 2003])),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    # Neither result may contain IDs added by the other thread
    assert results["expire1_snapshots"] == {1001, 1002, 1003}, "Worker 1 snapshots contaminated"
    assert results["expire2_snapshots"] == {2001, 2002, 2003}, "Worker 2 snapshots contaminated"
def test_update_remove_snapshots_with_statistics(table_v2_with_statistics: Table) -> None:
    """Removing a snapshot also drops the statistics file attached to it.

    Applying RemoveSnapshotsUpdate goes through the path that builds a RemoveStatisticsUpdate
    for the removed snapshot. Before the fix for #2558 that path raised
    ``TypeError: BaseModel.__init__() takes 1 positional argument but 2 were given``.
    """
    REMOVE_SNAPSHOT = 3051729675574597004
    KEEP_SNAPSHOT = 3055729675574597004
    metadata = table_v2_with_statistics.metadata
    # Fixture sanity: two snapshots, and the one being removed has statistics
    assert len(metadata.snapshots) == 2
    assert len(metadata.statistics) == 2
    assert REMOVE_SNAPSHOT in {stat.snapshot_id for stat in metadata.statistics}, (
        "Snapshot to remove should have statistics"
    )
    new_metadata = update_table_metadata(metadata, (RemoveSnapshotsUpdate(snapshot_ids=[REMOVE_SNAPSHOT]),))
    # Only the kept snapshot survives
    assert [snap.snapshot_id for snap in new_metadata.snapshots] == [KEEP_SNAPSHOT]
    # Its statistics survive too; those of the removed snapshot are gone
    surviving_stat_ids = [stat.snapshot_id for stat in new_metadata.statistics]
    assert surviving_stat_ids == [KEEP_SNAPSHOT]
    assert REMOVE_SNAPSHOT not in surviving_stat_ids, "Statistics for removed snapshot should be gone"
def _make_commit_response(table: Table) -> CommitTableResponse:
    """Build a CommitTableResponse that hands ``table``'s current metadata straight back."""
    response_fields = {
        "metadata": table.metadata,
        "metadata_location": "mock://metadata/location",
        "uuid": uuid4(),
    }
    return CommitTableResponse(**response_fields)
def test_ref_expiration_removes_old_tag_and_snapshot(table_v2: Table) -> None:
    """A tag whose snapshot age exceeds max_ref_age_ms is removed; its orphaned snapshot
    is also expired when older_than() is combined."""
    OLD_SNAPSHOT = 3051729675574597004
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
    # "test" tag (fixture) points to OLD_SNAPSHOT with max-ref-age-ms=10000000 (~2.7 h).
    # OLD_SNAPSHOT timestamp is from 2018 — definitely older than 2.7 h.
    assert "test" in table_v2.metadata.refs
    assert table_v2.metadata.refs["test"].snapshot_id == OLD_SNAPSHOT
    future = datetime.now() + timedelta(days=1)
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).older_than(future).commit()
    # Fail with a clear message if no commit was made, instead of a TypeError when
    # unpacking a None call_args below
    table_v2.catalog.commit_table.assert_called_once()
    args, _ = table_v2.catalog.commit_table.call_args
    updates = args[2]
    ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
    snap_updates = [u for u in updates if isinstance(u, RemoveSnapshotsUpdate)]
    assert any(u.ref_name == "test" for u in ref_updates), "Expected 'test' tag to be removed"
    assert any(OLD_SNAPSHOT in u.snapshot_ids for u in snap_updates), (
        "Expected OLD_SNAPSHOT to be removed since it is no longer referenced"
    )
def test_ref_expiration_removes_old_branch(table_v2: Table) -> None:
    """A non-main branch whose snapshot age exceeds max_ref_age_ms is removed."""
    OLD_SNAPSHOT = 3051729675574597004
    CURRENT_SNAPSHOT = 3055729675574597004
    table_v2.catalog = MagicMock()
    table_v2.catalog.commit_table.return_value = _make_commit_response(table_v2)
    # "stale-branch" has max-ref-age-ms=1, so it is past its age limit immediately
    table_v2.metadata = table_v2.metadata.model_copy(
        update={
            "refs": {
                "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
                "stale-branch": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1}),
            }
        }
    )
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
    # Fail with a clear message if no commit was made, instead of a TypeError when
    # unpacking a None call_args below
    table_v2.catalog.commit_table.assert_called_once()
    args, _ = table_v2.catalog.commit_table.call_args
    updates = args[2]
    ref_updates = [u for u in updates if isinstance(u, RemoveSnapshotRefUpdate)]
    assert any(u.ref_name == "stale-branch" for u in ref_updates)
    assert not any(u.ref_name == "main" for u in ref_updates)
def test_main_branch_never_expires(table_v2: Table) -> None:
    """The ``main`` branch survives ref expiration even with its own max_ref_age_ms of 1."""
    CURRENT_SNAPSHOT = 3055729675574597004
    table_v2.catalog = MagicMock()
    main_ref = SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH, "max-ref-age-ms": 1})
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {"main": main_ref}})
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
    # "main" is the only ref and it is exempt, so nothing should be committed
    table_v2.catalog.commit_table.assert_not_called()
def test_young_ref_is_retained(table_v2: Table) -> None:
    """A ref still inside its max_ref_age_ms window survives while an expired sibling is removed."""
    OLD_SNAPSHOT = 3051729675574597004
    CURRENT_SNAPSHOT = 3055729675574597004
    catalog = MagicMock()
    table_v2.catalog = catalog
    catalog.commit_table.return_value = _make_commit_response(table_v2)
    refs = {
        "main": SnapshotRef(**{"snapshot-id": CURRENT_SNAPSHOT, "type": SnapshotRefType.BRANCH}),
        # Age limit far beyond the snapshot's age: must be kept
        "fresh-tag": SnapshotRef(
            **{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 9999999999999}
        ),
        # Age limit of 1 ms: expires, which guarantees there is a commit to inspect
        "stale-tag": SnapshotRef(**{"snapshot-id": OLD_SNAPSHOT, "type": SnapshotRefType.TAG, "max-ref-age-ms": 1}),
    }
    table_v2.metadata = table_v2.metadata.model_copy(update={"refs": refs})
    table_v2.maintenance.expire_snapshots().remove_expired_refs(default_max_ref_age_ms=1).commit()
    catalog.commit_table.assert_called_once()
    sent_updates = catalog.commit_table.call_args.args[2]
    removed_ref_names = {u.ref_name for u in sent_updates if isinstance(u, RemoveSnapshotRefUpdate)}
    assert "stale-tag" in removed_ref_names, "stale-tag should be expired"
    assert "fresh-tag" not in removed_ref_names, "fresh-tag must not be expired"