Skip to content

Commit e7d25f8

Browse files
committed
Remove redundant concurrent snapshot expiration tests and improve thread safety assertions
1 parent b12706e commit e7d25f8

1 file changed

Lines changed: 1 addition & 323 deletions

File tree

tests/table/test_expire_snapshots.py

Lines changed: 1 addition & 323 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
15 15
# specific language governing permissions and limitations
16 16
# under the License.
17 17
import threading
18-
import uuid
19 18
from datetime import datetime, timedelta
20-
from typing import Any, Dict, List
19+
from typing import Dict
21 20
from unittest.mock import MagicMock, Mock
22 21
from uuid import uuid4
23 22

@@ -283,324 +282,3 @@ def worker2() -> None:
283 282

284 283
assert results["expire1_snapshots"] == expected_1, "Worker 1 snapshots contaminated"
285 284
assert results["expire2_snapshots"] == expected_2, "Worker 2 snapshots contaminated"
286-
287-
288-
def test_concurrent_different_tables_expiration() -> None:
    """Test that concurrent snapshot expiration on DIFFERENT tables works correctly.

    This test reproduces the issue described in:
    https://github.com/apache/iceberg-python/issues/2409

    The issue occurs when expiring snapshots from different tables concurrently,
    where snapshot IDs from one table get applied to another table.

    Each table is a ``Mock`` whose ``maintenance.expire_snapshots()`` hands out a
    fresh expire operation that records every ID passed to ``by_id``. Two worker
    threads expire disjoint ID lists, and the recorded calls are compared with
    what each worker was asked to expire.
    """
    # Different snapshot IDs per table, so any leak between tables is detectable.
    table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005]
    table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005]

    # IDs seen by each table's expire operation, in call order.
    table1_expire_calls: List[int] = []
    table2_expire_calls: List[int] = []

    def make_tracking_table(recorded_ids: List[int]) -> Mock:
        """Build a mock table whose expire operations append ``by_id`` arguments to ``recorded_ids``."""

        def new_expire_operation() -> Mock:
            expire_mock = Mock()

            def record(sid: int) -> Mock:
                recorded_ids.append(sid)
                # Return the operation itself so ``.by_id(...).commit()`` chains.
                return expire_mock

            expire_mock.by_id = Mock(side_effect=record)
            expire_mock.commit = Mock(return_value=None)
            return expire_mock

        table = Mock()
        table.metadata = Mock()
        table.metadata.table_uuid = uuid4()
        table.maintenance = Mock()
        # ``side_effect`` makes every ``expire_snapshots()`` call return a new operation.
        table.maintenance.expire_snapshots = Mock(side_effect=new_expire_operation)
        return table

    table1 = make_tracking_table(table1_expire_calls)
    table2 = make_tracking_table(table2_expire_calls)

    # Both workers wait here before expiring, so their loops actually overlap
    # instead of the first thread finishing before the second one starts.
    start_together = threading.Barrier(2, timeout=5)

    def expire_table_snapshots(table_obj: Any, snapshots_to_expire: List[int], results: Dict[str, Any]) -> None:
        """Expire specific snapshots from a table and record the outcome in ``results``."""
        try:
            start_together.wait()
            # Expire the snapshots one by one (as in the user's example)
            for snapshot_id in snapshots_to_expire:
                table_obj.maintenance.expire_snapshots().by_id(snapshot_id).commit()
        except Exception as e:
            # An exception in a worker thread would not fail the test by itself,
            # so hand it to the main thread; ``repr`` keeps the exception type.
            results["success"] = False
            results["error"] = repr(e)
        else:
            results["success"] = True
            results["expired_snapshots"] = snapshots_to_expire

    # Prepare snapshots to expire (first 2 from each table)
    table1_to_expire = table1_snapshot_ids[:2]
    table2_to_expire = table2_snapshot_ids[:2]

    results1: Dict[str, Any] = {}
    results2: Dict[str, Any] = {}

    threads = [
        threading.Thread(target=expire_table_snapshots, args=(table1, table1_to_expire, results1)),
        threading.Thread(target=expire_table_snapshots, args=(table2, table2_to_expire, results2)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Both operations should succeed if thread safety is correct
    assert results1.get("success", False), f"Table1 expiration failed: {results1.get('error', 'Unknown error')}"
    assert results2.get("success", False), f"Table2 expiration failed: {results2.get('error', 'Unknown error')}"

    # CRITICAL: each table must only have received its own snapshot IDs.
    # If the bug exists, snapshot IDs will cross-contaminate.
    foreign_in_table1 = [sid for sid in table1_expire_calls if sid not in table1_snapshot_ids]
    foreign_in_table2 = [sid for sid in table2_expire_calls if sid not in table2_snapshot_ids]
    assert not foreign_in_table1, f"Table1 received unexpected snapshot IDs {foreign_in_table1}"
    assert not foreign_in_table2, f"Table2 received unexpected snapshot IDs {foreign_in_table2}"

    # Each list is appended to by exactly one worker, so the order is
    # deterministic and an exact comparison also catches duplicates.
    assert table1_expire_calls == table1_to_expire, "Table1 didn't expire expected snapshots"
    assert table2_expire_calls == table2_to_expire, "Table2 didn't expire expected snapshots"
391-
392-
393-
def test_concurrent_same_table_different_snapshots(table_v2_with_extensive_snapshots: Table) -> None:
    """Expire two different snapshots of the same table from two threads; both commits must go through."""
    table = table_v2_with_extensive_snapshots

    # Swap in a mocked catalog whose commit_table answers with the table's current metadata.
    table.catalog = MagicMock()
    table.catalog.commit_table.return_value = CommitTableResponse(
        metadata=table.metadata, metadata_location="test://new_location"
    )

    # Collect every snapshot ID from the fixture, then drop the ones a ref points at
    # (those are treated as protected by this test).
    every_snapshot_id = [snap.snapshot_id for snap in table.snapshots()]
    ref_snapshot_ids = {ref.snapshot_id for ref in table.metadata.refs.values()}
    expirable_ids = [sid for sid in every_snapshot_id if sid not in ref_snapshot_ids]

    # Two independent candidates are required, one per thread.
    if len(expirable_ids) < 2:
        pytest.skip("Not enough unprotected snapshots available for testing")

    def run_expiration(target: Any, ids_to_expire: List[int], outcome: Dict[str, Any]) -> None:
        """Thread body: chain ``by_id`` for each ID, commit once, and record success or the error text."""
        try:
            operation = target.maintenance.expire_snapshots()
            for sid in ids_to_expire:
                operation = operation.by_id(sid)
            operation.commit()
            outcome["success"] = True
        except Exception as exc:
            outcome["success"] = False
            outcome["error"] = str(exc)

    # One outcome dictionary and one single-ID work list per thread.
    outcomes: List[Dict[str, Any]] = [{}, {}]
    workers = [
        threading.Thread(target=run_expiration, args=(table, [snapshot_id], outcome))
        for snapshot_id, outcome in zip(expirable_ids[:2], outcomes)
    ]

    # Start both before joining either.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Both threads must have reported success.
    for number, outcome in enumerate(outcomes, start=1):
        assert outcome.get("success", False), f"Thread {number} expiration failed: {outcome.get('error', 'Unknown error')}"

    # One commit_table call per thread.
    assert table.catalog.commit_table.call_count == 2
456-
457-
458-
def test_cross_table_snapshot_id_isolation() -> None:
    """Test that verifies snapshot IDs don't get mixed up between different tables.

    This test validates the fix for GitHub issue #2409 by ensuring that concurrent
    operations on different table objects properly isolate their snapshot IDs.

    Both tables are ``Mock`` objects whose expire operations record every ID
    passed to ``by_id``. Besides checking for cross-contamination, the test
    asserts that each worker succeeded and recorded exactly the IDs it was
    given; without that, two failing workers would leave both call lists
    empty and the isolation checks would pass vacuously.
    """
    # Disjoint ID ranges so a leaked ID is attributable to the other table.
    table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005]
    table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005]

    # Track which snapshot IDs each table's expire operation receives
    table1_expire_calls: List[int] = []
    table2_expire_calls: List[int] = []

    def make_tracking_table(recorded_ids: List[int]) -> Mock:
        """Build a mock table whose expire operations append ``by_id`` arguments to ``recorded_ids``."""

        def new_expire_operation() -> Mock:
            expire_mock = Mock()

            def record(sid: int) -> Mock:
                recorded_ids.append(sid)
                # Return the operation itself so ``.by_id(...).commit()`` chains.
                return expire_mock

            expire_mock.by_id = Mock(side_effect=record)
            expire_mock.commit = Mock(return_value=None)
            return expire_mock

        table = Mock()
        table.metadata = Mock()
        table.metadata.table_uuid = uuid4()
        table.maintenance = Mock()
        # ``side_effect`` makes every ``expire_snapshots()`` call return a new operation.
        table.maintenance.expire_snapshots = Mock(side_effect=new_expire_operation)
        return table

    table1 = make_tracking_table(table1_expire_calls)
    table2 = make_tracking_table(table2_expire_calls)

    # Hold both workers at the start line so their expire loops overlap.
    start_together = threading.Barrier(2, timeout=5)

    def expire_from_table(table: Any, snapshot_ids: List[int], results: Dict[str, Any]) -> None:
        """Expire snapshots from a specific table and record the outcome in ``results``."""
        try:
            start_together.wait()
            for snapshot_id in snapshot_ids:
                table.maintenance.expire_snapshots().by_id(snapshot_id).commit()
        except Exception as e:
            # Surface worker failures to the main thread; ``repr`` keeps the exception type.
            results["success"] = False
            results["error"] = repr(e)
        else:
            results["success"] = True
            results["expired_ids"] = snapshot_ids

    # Prepare snapshots to expire
    table1_to_expire = table1_snapshot_ids[:2]  # [1001, 1002]
    table2_to_expire = table2_snapshot_ids[:2]  # [2001, 2002]

    results1: Dict[str, Any] = {}
    results2: Dict[str, Any] = {}

    # Run concurrent expiration operations
    threads = [
        threading.Thread(target=expire_from_table, args=(table1, table1_to_expire, results1)),
        threading.Thread(target=expire_from_table, args=(table2, table2_to_expire, results2)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Guard against a vacuous pass: both workers must have completed their work.
    assert results1.get("success", False), f"Table1 expiration failed: {results1.get('error', 'Unknown error')}"
    assert results2.get("success", False), f"Table2 expiration failed: {results2.get('error', 'Unknown error')}"

    # CRITICAL ASSERTION: Each table should only receive its own snapshot IDs
    # If this fails, it means the thread safety bug exists
    assert all(sid in table1_snapshot_ids for sid in table1_expire_calls), (
        f"Table1 received unexpected snapshot IDs: {table1_expire_calls} (should only contain {table1_snapshot_ids})"
    )
    assert all(sid in table2_snapshot_ids for sid in table2_expire_calls), (
        f"Table2 received unexpected snapshot IDs: {table2_expire_calls} (should only contain {table2_snapshot_ids})"
    )

    # Verify no cross-contamination
    table1_received_table2_ids = [sid for sid in table1_expire_calls if sid in table2_snapshot_ids]
    table2_received_table1_ids = [sid for sid in table2_expire_calls if sid in table1_snapshot_ids]
    assert not table1_received_table2_ids, f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}"
    assert not table2_received_table1_ids, f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}"

    # Each list has a single writer, so the recorded order is deterministic.
    assert table1_expire_calls == table1_to_expire, f"Table1 recorded {table1_expire_calls}, expected {table1_to_expire}"
    assert table2_expire_calls == table2_to_expire, f"Table2 recorded {table2_expire_calls}, expected {table2_to_expire}"
556-
557-
558-
def test_batch_expire_snapshots(table_v2_with_extensive_snapshots: Table) -> None:
    """Test that batch expiration of multiple snapshots works correctly.

    Two snapshot IDs that no ref points at are expired through a single
    ``by_ids(...)`` call, and the mocked catalog must then have seen exactly
    one ``commit_table`` call. The test is skipped when the fixture offers
    fewer than two such snapshots.
    """
    table = table_v2_with_extensive_snapshots

    # Mock the catalog's commit_table method
    table.catalog = MagicMock()
    table.catalog.commit_table.return_value = CommitTableResponse(
        metadata=table.metadata, metadata_location="test://new_location"
    )

    # Use existing snapshot IDs from fixture data, but filter out protected snapshots
    snapshot_ids = [snapshot.snapshot_id for snapshot in table.snapshots()]

    # Get protected snapshot IDs from refs
    protected_snapshot_ids = {ref.snapshot_id for ref in table.metadata.refs.values()}

    # Find unprotected snapshots that we can expire
    unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids]

    # If we don't have enough unprotected snapshots, skip the test
    if len(unprotected_snapshot_ids) < 2:
        pytest.skip("Not enough unprotected snapshots available for testing")

    # We'll expire the first two unprotected snapshots in a batch
    to_expire = unprotected_snapshot_ids[:2]

    # Run the batch directly on the test thread: a lone worker thread adds no
    # concurrency, and letting an exception propagate keeps its full traceback
    # in the test report instead of reducing it to a string.
    table.maintenance.expire_snapshots().by_ids(to_expire).commit()

    # Verify that commit_table was called once
    assert table.catalog.commit_table.call_count == 1

0 commit comments

Comments
 (0)