|
15 | 15 | # specific language governing permissions and limitations |
16 | 16 | # under the License. |
17 | 17 | import threading |
18 | | -import uuid |
19 | 18 | from datetime import datetime, timedelta |
20 | | -from typing import Any, Dict, List |
| 19 | +from typing import Dict |
21 | 20 | from unittest.mock import MagicMock, Mock |
22 | 21 | from uuid import uuid4 |
23 | 22 |
|
@@ -283,324 +282,3 @@ def worker2() -> None: |
283 | 282 |
|
284 | 283 | assert results["expire1_snapshots"] == expected_1, "Worker 1 snapshots contaminated" |
285 | 284 | assert results["expire2_snapshots"] == expected_2, "Worker 2 snapshots contaminated" |
286 | | - |
287 | | - |
288 | | -def test_concurrent_different_tables_expiration() -> None: |
289 | | - """Test that concurrent snapshot expiration on DIFFERENT tables works correctly. |
290 | | -
|
291 | | - This test reproduces the issue described in: |
292 | | - https://github.com/apache/iceberg-python/issues/2409 |
293 | | -
|
294 | | - The issue occurs when expiring snapshots from different tables concurrently, |
295 | | - where snapshot IDs from one table get applied to another table. |
296 | | - """ |
297 | | - # Create two mock tables with different snapshot IDs |
298 | | - table1 = Mock() |
299 | | - table1.metadata = Mock() |
300 | | - table1.metadata.table_uuid = uuid4() |
301 | | - |
302 | | - table2 = Mock() |
303 | | - table2.metadata = Mock() |
304 | | - table2.metadata.table_uuid = uuid4() |
305 | | - |
306 | | - # Track calls to each table's expire_snapshots method |
307 | | - table1_expire_calls = [] |
308 | | - table2_expire_calls = [] |
309 | | - |
310 | | - def create_table1_expire_mock() -> Mock: |
311 | | - expire_mock = Mock() |
312 | | - |
313 | | - def side_effect(sid: int) -> Mock: |
314 | | - table1_expire_calls.append(sid) |
315 | | - return expire_mock |
316 | | - |
317 | | - expire_mock.by_id = Mock(side_effect=side_effect) |
318 | | - expire_mock.commit = Mock(return_value=None) |
319 | | - return expire_mock |
320 | | - |
321 | | - def create_table2_expire_mock() -> Mock: |
322 | | - expire_mock = Mock() |
323 | | - |
324 | | - def side_effect(sid: int) -> Mock: |
325 | | - table2_expire_calls.append(sid) |
326 | | - return expire_mock |
327 | | - |
328 | | - expire_mock.by_id = Mock(side_effect=side_effect) |
329 | | - expire_mock.commit = Mock(return_value=None) |
330 | | - return expire_mock |
331 | | - |
332 | | - table1.maintenance = Mock() |
333 | | - table1.maintenance.expire_snapshots = Mock(side_effect=create_table1_expire_mock) |
334 | | - |
335 | | - table2.maintenance = Mock() |
336 | | - table2.maintenance.expire_snapshots = Mock(side_effect=create_table2_expire_mock) |
337 | | - |
338 | | - # Define different snapshot IDs for each table |
339 | | - table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005] |
340 | | - table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005] |
341 | | - |
342 | | - def expire_table_snapshots(table_obj: Any, table_name: str, snapshots_to_expire: List[int], results: Dict[str, Any]) -> None: |
343 | | - """Expire specific snapshots from a table.""" |
344 | | - try: |
345 | | - # Expire the snapshots one by one (as in the user's example) |
346 | | - for snapshot_id in snapshots_to_expire: |
347 | | - table_obj.maintenance.expire_snapshots().by_id(snapshot_id).commit() |
348 | | - |
349 | | - results["success"] = True |
350 | | - results["expired_snapshots"] = snapshots_to_expire |
351 | | - |
352 | | - except Exception as e: |
353 | | - results["success"] = False |
354 | | - results["error"] = str(e) |
355 | | - |
356 | | - # Prepare snapshots to expire (first 2 from each table) |
357 | | - table1_to_expire = table1_snapshot_ids[:2] |
358 | | - table2_to_expire = table2_snapshot_ids[:2] |
359 | | - |
360 | | - results1: Dict[str, Any] = {} |
361 | | - results2: Dict[str, Any] = {} |
362 | | - |
363 | | - # Create threads to expire snapshots from different tables concurrently |
364 | | - thread1 = threading.Thread(target=expire_table_snapshots, args=(table1, "table1", table1_to_expire, results1)) |
365 | | - thread2 = threading.Thread(target=expire_table_snapshots, args=(table2, "table2", table2_to_expire, results2)) |
366 | | - |
367 | | - # Start threads concurrently |
368 | | - thread1.start() |
369 | | - thread2.start() |
370 | | - |
371 | | - # Wait for completion |
372 | | - thread1.join() |
373 | | - thread2.join() |
374 | | - |
375 | | - # Check results - both should succeed if thread safety is correct |
376 | | - # Assert both operations succeeded |
377 | | - assert results1.get("success", False), f"Table1 expiration failed: {results1.get('error', 'Unknown error')}" |
378 | | - assert results2.get("success", False), f"Table2 expiration failed: {results2.get('error', 'Unknown error')}" |
379 | | - |
380 | | - # CRITICAL: Verify that each table only received its own snapshot IDs |
381 | | - # This is the key test - if the bug exists, snapshot IDs will cross-contaminate |
382 | | - for sid in table1_expire_calls: |
383 | | - assert sid in table1_snapshot_ids, f"Table1 received unexpected snapshot ID {sid}" |
384 | | - |
385 | | - for sid in table2_expire_calls: |
386 | | - assert sid in table2_snapshot_ids, f"Table2 received unexpected snapshot ID {sid}" |
387 | | - |
388 | | - # Verify expected snapshots were expired |
389 | | - assert set(table1_expire_calls) == set(table1_to_expire), "Table1 didn't expire expected snapshots" |
390 | | - assert set(table2_expire_calls) == set(table2_to_expire), "Table2 didn't expire expected snapshots" |
391 | | - |
392 | | - |
393 | | -def test_concurrent_same_table_different_snapshots(table_v2_with_extensive_snapshots: Table) -> None: |
394 | | - """Test that concurrent snapshot expiration operations on the same table work correctly.""" |
395 | | - # Mock the catalog's commit_table method for both operations |
396 | | - table_v2_with_extensive_snapshots.catalog = MagicMock() |
397 | | - table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( |
398 | | - metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" |
399 | | - ) |
400 | | - |
401 | | - # Use existing snapshot IDs from fixture data, but filter out protected snapshots |
402 | | - all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) |
403 | | - snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] |
404 | | - |
405 | | - # Get protected snapshot IDs from refs |
406 | | - protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} |
407 | | - |
408 | | - # Find unprotected snapshots that we can expire |
409 | | - unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] |
410 | | - |
411 | | - # If we don't have enough unprotected snapshots, skip the test |
412 | | - if len(unprotected_snapshot_ids) < 2: |
413 | | - pytest.skip("Not enough unprotected snapshots available for testing") |
414 | | - |
415 | | - # We'll expire the first two unprotected snapshots concurrently |
416 | | - to_expire1 = [unprotected_snapshot_ids[0]] |
417 | | - to_expire2 = [unprotected_snapshot_ids[1]] |
418 | | - |
419 | | - def expire_snapshots_thread_func(table: Any, snapshot_ids_to_expire: List[int], results: Dict[str, Any]) -> None: |
420 | | - """Function to run in a thread that expires snapshots and captures results.""" |
421 | | - try: |
422 | | - # Expire snapshots |
423 | | - expire_op = table.maintenance.expire_snapshots() |
424 | | - for snapshot_id in snapshot_ids_to_expire: |
425 | | - expire_op = expire_op.by_id(snapshot_id) |
426 | | - expire_op.commit() |
427 | | - results["success"] = True |
428 | | - except Exception as e: |
429 | | - results["success"] = False |
430 | | - results["error"] = str(e) |
431 | | - |
432 | | - # Prepare result dictionaries to capture thread outcomes |
433 | | - results1: Dict[str, Any] = {} |
434 | | - results2: Dict[str, Any] = {} |
435 | | - |
436 | | - # Create threads to expire snapshots concurrently |
437 | | - thread1 = threading.Thread( |
438 | | - target=expire_snapshots_thread_func, args=(table_v2_with_extensive_snapshots, to_expire1, results1) |
439 | | - ) |
440 | | - thread2 = threading.Thread( |
441 | | - target=expire_snapshots_thread_func, args=(table_v2_with_extensive_snapshots, to_expire2, results2) |
442 | | - ) |
443 | | - |
444 | | - # Start and join threads |
445 | | - thread1.start() |
446 | | - thread2.start() |
447 | | - thread1.join() |
448 | | - thread2.join() |
449 | | - |
450 | | - # Assert that both operations succeeded |
451 | | - assert results1.get("success", False), f"Thread 1 expiration failed: {results1.get('error', 'Unknown error')}" |
452 | | - assert results2.get("success", False), f"Thread 2 expiration failed: {results2.get('error', 'Unknown error')}" |
453 | | - |
454 | | - # Verify that both commit_table calls were made |
455 | | - assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 2 |
456 | | - |
457 | | - |
458 | | -def test_cross_table_snapshot_id_isolation() -> None: |
459 | | - """Test that verifies snapshot IDs don't get mixed up between different tables. |
460 | | -
|
461 | | - This test validates the fix for GitHub issue #2409 by ensuring that concurrent |
462 | | - operations on different table objects properly isolate their snapshot IDs. |
463 | | - """ |
464 | | - |
465 | | - # Create two mock table objects to simulate the user's scenario |
466 | | - # Mock table 1 with its own snapshot IDs |
467 | | - table1 = Mock() |
468 | | - table1.metadata = Mock() |
469 | | - table1.metadata.table_uuid = uuid.uuid4() |
470 | | - table1_snapshot_ids = [1001, 1002, 1003, 1004, 1005] |
471 | | - |
472 | | - # Mock table 2 with different snapshot IDs |
473 | | - table2 = Mock() |
474 | | - table2.metadata = Mock() |
475 | | - table2.metadata.table_uuid = uuid.uuid4() |
476 | | - table2_snapshot_ids = [2001, 2002, 2003, 2004, 2005] |
477 | | - |
478 | | - # Track which snapshot IDs each table's expire operation receives |
479 | | - table1_expire_calls = [] |
480 | | - table2_expire_calls = [] |
481 | | - |
482 | | - def mock_table1_expire() -> Mock: |
483 | | - expire_mock = Mock() |
484 | | - |
485 | | - def side_effect(sid: int) -> Mock: |
486 | | - table1_expire_calls.append(sid) |
487 | | - return expire_mock |
488 | | - |
489 | | - expire_mock.by_id = Mock(side_effect=side_effect) |
490 | | - expire_mock.commit = Mock(return_value=None) |
491 | | - return expire_mock |
492 | | - |
493 | | - def mock_table2_expire() -> Mock: |
494 | | - expire_mock = Mock() |
495 | | - |
496 | | - def side_effect(sid: int) -> Mock: |
497 | | - table2_expire_calls.append(sid) |
498 | | - return expire_mock |
499 | | - |
500 | | - expire_mock.by_id = Mock(side_effect=side_effect) |
501 | | - expire_mock.commit = Mock(return_value=None) |
502 | | - return expire_mock |
503 | | - |
504 | | - table1.maintenance = Mock() |
505 | | - table1.maintenance.expire_snapshots = Mock(side_effect=mock_table1_expire) |
506 | | - table2.maintenance = Mock() |
507 | | - table2.maintenance.expire_snapshots = Mock(side_effect=mock_table2_expire) |
508 | | - |
509 | | - def expire_from_table(table: Any, table_name: str, snapshot_ids: List[int], results: Dict[str, Any]) -> None: |
510 | | - """Expire snapshots from a specific table.""" |
511 | | - try: |
512 | | - for snapshot_id in snapshot_ids: |
513 | | - table.maintenance.expire_snapshots().by_id(snapshot_id).commit() |
514 | | - results["success"] = True |
515 | | - results["expired_ids"] = snapshot_ids |
516 | | - except Exception as e: |
517 | | - results["success"] = False |
518 | | - results["error"] = str(e) |
519 | | - |
520 | | - # Prepare snapshots to expire |
521 | | - table1_to_expire = table1_snapshot_ids[:2] # [1001, 1002] |
522 | | - table2_to_expire = table2_snapshot_ids[:2] # [2001, 2002] |
523 | | - |
524 | | - results1: Dict[str, Any] = {} |
525 | | - results2: Dict[str, Any] = {} |
526 | | - |
527 | | - # Run concurrent expiration operations |
528 | | - thread1 = threading.Thread(target=expire_from_table, args=(table1, "table1", table1_to_expire, results1)) |
529 | | - thread2 = threading.Thread(target=expire_from_table, args=(table2, "table2", table2_to_expire, results2)) |
530 | | - |
531 | | - thread1.start() |
532 | | - thread2.start() |
533 | | - thread1.join() |
534 | | - thread2.join() |
535 | | - |
536 | | - # CRITICAL ASSERTION: Each table should only receive its own snapshot IDs |
537 | | - # If this fails, it means the thread safety bug exists |
538 | | - |
539 | | - # Table1 should only see table1 snapshot IDs |
540 | | - assert all(sid in table1_snapshot_ids for sid in table1_expire_calls), ( |
541 | | - f"Table1 received unexpected snapshot IDs: {table1_expire_calls} (should only contain {table1_snapshot_ids})" |
542 | | - ) |
543 | | - |
544 | | - # Table2 should only see table2 snapshot IDs |
545 | | - assert all(sid in table2_snapshot_ids for sid in table2_expire_calls), ( |
546 | | - f"Table2 received unexpected snapshot IDs: {table2_expire_calls} (should only contain {table2_snapshot_ids})" |
547 | | - ) |
548 | | - |
549 | | - # Verify no cross-contamination |
550 | | - table1_received_table2_ids = [sid for sid in table1_expire_calls if sid in table2_snapshot_ids] |
551 | | - table2_received_table1_ids = [sid for sid in table2_expire_calls if sid in table1_snapshot_ids] |
552 | | - |
553 | | - assert len(table1_received_table2_ids) == 0, f"Table1 incorrectly received Table2 snapshot IDs: {table1_received_table2_ids}" |
554 | | - |
555 | | - assert len(table2_received_table1_ids) == 0, f"Table2 incorrectly received Table1 snapshot IDs: {table2_received_table1_ids}" |
556 | | - |
557 | | - |
558 | | -def test_batch_expire_snapshots(table_v2_with_extensive_snapshots: Table) -> None: |
559 | | - """Test that batch expiration of multiple snapshots works correctly.""" |
560 | | - # Mock the catalog's commit_table method |
561 | | - table_v2_with_extensive_snapshots.catalog = MagicMock() |
562 | | - table_v2_with_extensive_snapshots.catalog.commit_table.return_value = CommitTableResponse( |
563 | | - metadata=table_v2_with_extensive_snapshots.metadata, metadata_location="test://new_location" |
564 | | - ) |
565 | | - |
566 | | - # Use existing snapshot IDs from fixture data, but filter out protected snapshots |
567 | | - all_snapshots = list(table_v2_with_extensive_snapshots.snapshots()) |
568 | | - snapshot_ids = [snapshot.snapshot_id for snapshot in all_snapshots] |
569 | | - |
570 | | - # Get protected snapshot IDs from refs |
571 | | - protected_snapshot_ids = {ref.snapshot_id for ref in table_v2_with_extensive_snapshots.metadata.refs.values()} |
572 | | - |
573 | | - # Find unprotected snapshots that we can expire |
574 | | - unprotected_snapshot_ids = [sid for sid in snapshot_ids if sid not in protected_snapshot_ids] |
575 | | - |
576 | | - # If we don't have enough unprotected snapshots, skip the test |
577 | | - if len(unprotected_snapshot_ids) < 2: |
578 | | - pytest.skip("Not enough unprotected snapshots available for testing") |
579 | | - |
580 | | - # We'll expire the first two unprotected snapshots in a batch |
581 | | - to_expire = unprotected_snapshot_ids[:2] |
582 | | - |
583 | | - def batch_expire_thread_func(table: Any, snapshot_ids_to_expire: List[int], results: Dict[str, Any]) -> None: |
584 | | - try: |
585 | | - # Expire all snapshots in a single batch operation |
586 | | - table.maintenance.expire_snapshots().by_ids(snapshot_ids_to_expire).commit() |
587 | | - results["success"] = True |
588 | | - except Exception as e: |
589 | | - results["success"] = False |
590 | | - results["error"] = str(e) |
591 | | - |
592 | | - # Prepare result dictionary to capture thread outcome |
593 | | - results: Dict[str, Any] = {} |
594 | | - |
595 | | - # Create thread to expire snapshots |
596 | | - thread = threading.Thread(target=batch_expire_thread_func, args=(table_v2_with_extensive_snapshots, to_expire, results)) |
597 | | - |
598 | | - # Start and join thread |
599 | | - thread.start() |
600 | | - thread.join() |
601 | | - |
602 | | - # Assert that the operation succeeded |
603 | | - assert results.get("success", False), f"Batch expiration failed: {results.get('error', 'Unknown error')}" |
604 | | - |
605 | | - # Verify that commit_table was called once |
606 | | - assert table_v2_with_extensive_snapshots.catalog.commit_table.call_count == 1 |
0 commit comments