From ea9065e3b8e66f914631e51559bfecf7d9076bf0 Mon Sep 17 00:00:00 2001 From: Ayush Patel Date: Wed, 29 Apr 2026 12:34:06 -0500 Subject: [PATCH] feat: Add merge() to Table and Transaction Atomic delete-insert merge by join columns using per-column In filters for file pruning and in-memory anti-join for row-level correctness, committed as a single OVERWRITE snapshot. Unlike upsert(), does not enforce uniqueness on source or target. --- pyiceberg/table/__init__.py | 162 +++++ tests/benchmark/test_merge_filter.py | 157 +++++ tests/table/test_merge.py | 858 +++++++++++++++++++++++++++ 3 files changed, 1177 insertions(+) create mode 100644 tests/benchmark/test_merge_filter.py create mode 100644 tests/table/test_merge.py diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9071f99e10..cd0bf5ad92 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -178,6 +178,14 @@ class UpsertResult: rows_inserted: int = 0 +@dataclass() +class MergeResult: + """Summary of the merge operation.""" + + rows_deleted: int = 0 + rows_inserted: int = 0 + + class TableProperties: PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes" PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB @@ -885,6 +893,126 @@ def upsert( return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def merge( + self, + df: pa.Table, + join_cols: List[str], + snapshot_properties: Dict[str, str] = EMPTY_DICT, + branch: Optional[str] = MAIN_BRANCH, + check_duplicate_keys: bool = False, + ) -> MergeResult: + """Atomic delete-insert merge by join columns. + + Deletes all target rows matching the source data's join column + values and inserts the source rows, all in a single OVERWRITE + snapshot. + + Uses per-column ``In`` filters for file pruning (O(sum of + cardinalities) instead of O(product)), then an in-memory + anti-join for row-level correctness. + + Unlike ``upsert()``, does not enforce uniqueness on source or + target by default. + + Args: + df: The Arrow dataframe containing replacement rows. + join_cols: Columns used to match source rows against target rows. + snapshot_properties: Custom properties to be added to the snapshot summary. + branch: Branch reference to run the operation. + check_duplicate_keys: If True, raise ValueError when the source + data contains duplicate key tuples based on the join columns. + This is a data quality guard, not a correctness requirement - + merge() produces correct results with duplicate keys. + + Returns: + A MergeResult with row counts (deleted from target, inserted from source). 
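+
+        Example:
+            A minimal sketch (assumes an open transaction ``tx`` on a table
+            whose schema matches ``source``; column names are illustrative)::
+
+                import pyarrow as pa
+
+                source = pa.table({"user_id": [1, 2], "score": [150, 250]})
+                result = tx.merge(source, join_cols=["user_id"])
+                # target rows with user_id 1 or 2 are deleted; both source
+                # rows are inserted, all in one OVERWRITE snapshot
+                print(result.rows_deleted, result.rows_inserted)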
+ """ + try: + import pyarrow as pa + import pyarrow.compute as pc + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + import functools + + from pyiceberg.expressions import In + from pyiceberg.io.pyarrow import ArrowScan, _check_pyarrow_schema_compatible, _dataframe_to_data_files + from pyiceberg.table import upsert_util + + if not isinstance(df, pa.Table): + raise ValueError(f"Expected PyArrow table, got: {df}") + + if not join_cols: + raise ValueError("join_cols must be a non-empty list of column names.") + + missing = set(join_cols) - set(df.column_names) + if missing: + raise ValueError(f"join_cols not found in source data: {missing}") + + if df.num_rows == 0: + return MergeResult() + + if check_duplicate_keys and upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source data based on join columns.") + + downcast_ns = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + self.table_metadata.schema(), + provided_schema=df.schema, + downcast_ns_timestamp_to_us=downcast_ns, + format_version=self.table_metadata.format_version, + ) + + # Step 1: Build per-column In filters for file pruning. + # O(sum of cardinalities) instead of O(product). + # Over-approximates the match set, which is fine - row-level + # correctness is enforced by the anti-join in step 3. + in_filters: list[BooleanExpression] = [In(col, pc.unique(df[col]).to_pylist()) for col in join_cols] + candidate_filter: BooleanExpression = functools.reduce(And, in_filters) + + # Step 2: Find candidate files via manifest pruning. + scan = self._scan(row_filter=candidate_filter, case_sensitive=True) + if branch is not None and branch in self.table_metadata.refs: + scan = scan.use_ref(branch) + tasks = list(scan.plan_files()) + + if not tasks: + # No files overlap - just append. + self.append(df, snapshot_properties=snapshot_properties, branch=branch) + return MergeResult(rows_inserted=df.num_rows) + + # Step 3: Read ALL rows from candidate files, anti-join to keep + # non-matching rows. The candidate_filter was only for file + # pruning - row-level correctness comes from the anti-join. + arrow_scan = ArrowScan( + self.table_metadata, + self._table.io, + projected_schema=self.table_metadata.schema(), + row_filter=ALWAYS_TRUE, + case_sensitive=True, + ) + target_data = arrow_scan.to_table(tasks) + source_keys = df.select(join_cols) + + kept_rows = target_data.join(source_keys, keys=join_cols, join_type="left anti") + rows_deleted = target_data.num_rows - kept_rows.num_rows + new_content = pa.concat_tables([kept_rows, df], promote_options="default") + + # Step 4: Atomic single-snapshot commit. + # Delete old files, append rewritten content. 
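+        # Committing a single OVERWRITE snapshot (rather than a DELETE
+        # followed by an APPEND) means readers never observe an intermediate
+        # state where matching rows are deleted but not yet re-inserted.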
+ with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).overwrite() as overwrite_op: + for task in tasks: + overwrite_op.delete_data_file(task.file) + for data_file in _dataframe_to_data_files( + table_metadata=self.table_metadata, + df=new_content, + io=self._table.io, + write_uuid=overwrite_op.commit_uuid, + ): + overwrite_op.append_data_file(data_file) + + return MergeResult(rows_deleted=rows_deleted, rows_inserted=df.num_rows) + def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True ) -> None: @@ -1415,6 +1543,40 @@ def upsert( branch=branch, ) + def merge( + self, + df: pa.Table, + join_cols: List[str], + snapshot_properties: Dict[str, str] = EMPTY_DICT, + branch: Optional[str] = MAIN_BRANCH, + check_duplicate_keys: bool = False, + ) -> MergeResult: + """Atomic delete-insert merge by join columns. + + Unlike ``upsert()``, does not enforce uniqueness on source or + target by default. + + Args: + df: The Arrow dataframe containing replacement rows. + join_cols: Columns used to match source rows against target rows. + snapshot_properties: Custom properties to be added to the snapshot summary. + branch: Branch reference to run the operation. + check_duplicate_keys: If True, raise ValueError when the source + data contains duplicate key tuples based on the join columns. + This is a data quality guard, not a correctness requirement. + + Returns: + A MergeResult with row counts (deleted from target, inserted from source). + """ + with self.transaction() as tx: + return tx.merge( + df=df, + join_cols=join_cols, + snapshot_properties=snapshot_properties, + branch=branch, + check_duplicate_keys=check_duplicate_keys, + ) + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None: """ Shorthand API for appending a PyArrow table to the table. diff --git a/tests/benchmark/test_merge_filter.py b/tests/benchmark/test_merge_filter.py new file mode 100644 index 0000000000..46549247d5 --- /dev/null +++ b/tests/benchmark/test_merge_filter.py @@ -0,0 +1,157 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""End-to-end benchmark: merge() vs create_match_filter + overwrite(). + +Compares the full write path (filter construction + file I/O + snapshot commit) +between the new merge() implementation and the previous approach. 
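+
+For composite keys, create_match_filter() builds an Or with one And of
+per-column EqualTo predicates per distinct source key, so the expression tree
+grows linearly with the number of source rows and becomes impractical for
+large sources (see the 5,000-row case below).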
+ +Usage: + poetry run pytest tests/benchmark/test_merge_filter.py -v -s -m benchmark +""" + +import gc +import itertools +import timeit +import tracemalloc +from pathlib import PosixPath +from typing import Any, Callable + +import pyarrow as pa +import pytest + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.table.upsert_util import create_match_filter +from pyiceberg.types import IntegerType, NestedField, StringType +from tests.catalog.test_base import InMemoryCatalog + + +def _make_schema(col_cardinalities: dict[str, int]) -> Schema: + fields = [] + for i, col in enumerate(col_cardinalities): + field_type = IntegerType() if col == "date_id" else StringType() + fields.append(NestedField(i + 1, col, field_type, required=True)) + fields.append(NestedField(len(col_cardinalities) + 1, "v", IntegerType(), required=True)) + return Schema(*fields) + + +def _build_table(col_cardinalities: dict[str, int]) -> tuple[pa.Table, list[str], Schema]: + from pyiceberg.io.pyarrow import schema_to_pyarrow + + schema = _make_schema(col_cardinalities) + arrow_schema = schema_to_pyarrow(schema) + + vals: list[list[Any]] = [] + for col, card in col_cardinalities.items(): + if col == "date_id": + vals.append(list(range(20260101, 20260101 + card))) + else: + vals.append([f"{col}_{i}" for i in range(card)]) + combos = list(itertools.product(*vals)) + data = {col: [c[i] for c in combos] for i, col in enumerate(col_cardinalities)} + data["v"] = list(range(len(combos))) + return pa.table(data, schema=arrow_schema), list(col_cardinalities.keys()), schema + + +def _fresh_table(catalog: Catalog, name: str, schema: Schema, data: pa.Table) -> Table: + ident = f"default.{name}" + try: + catalog.drop_table(ident) + except NoSuchTableError: + pass + tbl = catalog.create_table(ident, schema=schema) + tbl.append(data) + return tbl + + +def _measure(fn: Callable[[], Any], runs: int = 3) -> tuple[float, int]: + """Returns (avg_seconds, peak_memory_bytes).""" + times = [] + peak = 0 + for _ in range(runs): + gc.collect() + tracemalloc.start() + t0 = timeit.default_timer() + fn() + times.append(timeit.default_timer() - t0) + _, p = tracemalloc.get_traced_memory() + tracemalloc.stop() + peak = max(peak, p) + return sum(times) / len(times), peak + + +def _fmt(secs: float, mem_bytes: int) -> str: + mem = f"{mem_bytes / 1024:.0f} KB" if mem_bytes < 1048576 else f"{mem_bytes / 1048576:.1f} MB" + return f"{secs * 1000:.0f} ms, peak {mem}" + + +COLS = {"date_id": 252, "account": 100} # 25,200 target rows + + +@pytest.mark.benchmark +@pytest.mark.parametrize("n_source", [100, 5000]) +def test_e2e_merge(n_source: int, tmp_path: PosixPath) -> None: + """End-to-end merge(): per-column In + anti-join + single OVERWRITE snapshot.""" + target_data, join_cols, schema = _build_table(COLS) + + catalog = InMemoryCatalog("bench", warehouse=str(tmp_path)) + catalog.create_namespace("default") + tbl = _fresh_table(catalog, f"merge_{n_source}", schema, target_data) + + source_dict = {col: target_data.column(col).to_pylist()[:n_source] for col in target_data.column_names} + source_dict["v"] = [x + 9000 for x in source_dict["v"]] + source = pa.table(source_dict, schema=target_data.schema) + + avg, peak = _measure(lambda: tbl.merge(source, join_cols=join_cols)) + print(f"\n merge(): {target_data.num_rows:,} target, {n_source:,} source -> {_fmt(avg, peak)}") + + +@pytest.mark.benchmark +def test_e2e_overwrite_100src(tmp_path: 
PosixPath) -> None:
+    """End-to-end overwrite() with 100 source rows."""
+    target_data, join_cols, schema = _build_table(COLS)
+
+    catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
+    catalog.create_namespace("default")
+    tbl = _fresh_table(catalog, "overwrite_100", schema, target_data)
+
+    source_dict = {col: target_data.column(col).to_pylist()[:100] for col in target_data.column_names}
+    source_dict["v"] = [x + 9000 for x in source_dict["v"]]
+    source = pa.table(source_dict, schema=target_data.schema)
+
+    avg, peak = _measure(lambda: tbl.overwrite(source, overwrite_filter=create_match_filter(source, join_cols)))
+    print(f"\n  overwrite(): {target_data.num_rows:,} target, 100 source -> {_fmt(avg, peak)}")
+
+
+@pytest.mark.benchmark
+def test_e2e_overwrite_5ksrc_filter_only() -> None:
+    """At 5,000 source rows, just constructing the filter takes seconds.
+
+    We measure only filter construction here because a full overwrite()
+    with the 20,000-node expression tree causes the test process to be
+    killed during manifest evaluation.
+    """
+    target_data, join_cols, schema = _build_table(COLS)
+
+    source_dict = {col: target_data.column(col).to_pylist()[:5000] for col in target_data.column_names}
+    source_dict["v"] = [x + 9000 for x in source_dict["v"]]
+    source = pa.table(source_dict, schema=target_data.schema)
+
+    avg, peak = _measure(lambda: create_match_filter(source, join_cols), runs=1)
+    print(f"\n  create_match_filter only (no overwrite): 5,000 source -> {_fmt(avg, peak)}")
diff --git a/tests/table/test_merge.py b/tests/table/test_merge.py
new file mode 100644
index 0000000000..c168c22cc1
--- /dev/null
+++ b/tests/table/test_merge.py
@@ -0,0 +1,858 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
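+"""Tests for Table.merge() and Transaction.merge().
+
+Covers basic replace/insert behavior, duplicate keys on source and target,
+composite join keys, edge cases, snapshot semantics, input validation,
+check_duplicate_keys, over-approximation/anti-join correctness, and
+MergeResult counts.
+"""
+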
+import warnings +from pathlib import PosixPath + +import pyarrow as pa +import pytest + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.schema import Schema +from pyiceberg.table import MergeResult +from pyiceberg.table.snapshots import Operation +from pyiceberg.types import IntegerType, NestedField, StringType +from tests.catalog.test_base import InMemoryCatalog + + +@pytest.fixture +def catalog(tmp_path: PosixPath) -> InMemoryCatalog: + catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix()) + catalog.create_namespace("default") + return catalog + + +def _drop(catalog: Catalog, ident: str) -> None: + try: + catalog.drop_table(ident) + except NoSuchTableError: + pass + + +SCHEMA = Schema( + NestedField(1, "user_id", IntegerType(), required=True), + NestedField(2, "name", StringType(), required=True), + NestedField(3, "score", IntegerType(), required=True), +) + +ARROW = pa.schema( + [ + pa.field("user_id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + pa.field("score", pa.int32(), nullable=False), + ] +) + + +# ==================== BASIC MERGE BEHAVIOR ==================== + + +def test_merge_replaces_matching_rows(catalog: Catalog) -> None: + ident = "default.merge_replace" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + ], + schema=ARROW, + ) + ) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + {"user_id": 2, "name": "Bob", "score": 250}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 3 + assert result.column("score").to_pylist() == [150, 250, 300] + + +def test_merge_inserts_new_rows(catalog: Catalog) -> None: + ident = "default.merge_insert" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [150, 200] + + +def test_merge_all_new_rows(catalog: Catalog) -> None: + """Source has no overlap with target - all rows inserted.""" + ident = "default.merge_all_new" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 99, "name": "New", "score": 999}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("user_id").to_pylist() == [1, 99] + + +# ==================== UNIQUENESS: THE KEY DIFFERENCE FROM UPSERT ==================== + + +def test_merge_allows_duplicate_source_keys(catalog: Catalog) -> None: + """upsert() would reject this - merge() allows it.""" + ident = "default.merge_dup_src" + _drop(catalog, ident) + tbl = catalog.create_table(ident, 
schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + # Two rows with same key - upsert raises ValueError + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice V1", "score": 150}, + {"user_id": 1, "name": "Alice V2", "score": 200}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow() + assert result.num_rows == 2 + assert set(result.column("score").to_pylist()) == {150, 200} + + +def test_merge_allows_duplicate_target_keys(catalog: Catalog) -> None: + """upsert() would reject this - merge() allows it.""" + ident = "default.merge_dup_tgt" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + # Target has duplicates + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 1, "name": "Alice V2", "score": 110}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ) + ) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 999}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [999, 200] + + +# ==================== COMPOSITE KEYS ==================== + + +def test_merge_composite_join_cols(catalog: Catalog) -> None: + ident = "default.merge_composite" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 1, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ) + ) + + # Only replace (user_id=1, name="Alice") + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 999}, + ], + schema=ARROW, + ), + join_cols=["user_id", "name"], + ) + + result = tbl.scan().to_arrow().sort_by("name") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [999, 200] + + +def test_merge_three_join_cols(catalog: Catalog) -> None: + """Three-column composite key (realistic ETL scenario).""" + schema = Schema( + NestedField(1, "date_id", IntegerType(), required=True), + NestedField(2, "account", StringType(), required=True), + NestedField(3, "security", StringType(), required=True), + NestedField(4, "value", IntegerType(), required=True), + ) + arrow = pa.schema( + [ + pa.field("date_id", pa.int32(), nullable=False), + pa.field("account", pa.string(), nullable=False), + pa.field("security", pa.string(), nullable=False), + pa.field("value", pa.int32(), nullable=False), + ] + ) + ident = "default.merge_3col" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=schema) + + tbl.append( + pa.Table.from_pylist( + [ + {"date_id": 20260101, "account": "A", "security": "S1", "value": 100}, + {"date_id": 20260101, "account": "A", "security": "S2", "value": 200}, + {"date_id": 20260101, "account": "B", "security": "S1", "value": 300}, + ], + schema=arrow, + ) + ) + + # Replace only (20260101, A, S1) + tbl.merge( + pa.Table.from_pylist( + [ + {"date_id": 20260101, "account": "A", "security": "S1", "value": 999}, + ], + schema=arrow, + ), + join_cols=["date_id", "account", "security"], + ) + + result = tbl.scan().to_arrow().sort_by([("account", "ascending"), ("security", "ascending")]) + assert result.num_rows == 3 + assert result.column("value").to_pylist() == [999, 200, 300] + + +# ==================== EDGE CASES ==================== + + +def 
test_merge_into_empty_table(catalog: Catalog) -> None: + ident = "default.merge_empty_tgt" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [100, 200] + + +def test_merge_empty_source_is_noop(catalog: Catalog) -> None: + ident = "default.merge_empty_src" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + tbl.merge(pa.Table.from_pylist([], schema=ARROW), join_cols=["user_id"]) + + result = tbl.scan().to_arrow() + assert result.num_rows == 1 + assert result.column("score").to_pylist() == [100] + + +def test_merge_idempotent(catalog: Catalog) -> None: + """Running merge twice with the same data produces the same result.""" + ident = "default.merge_idempotent" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ) + ) + + source = pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + ], + schema=ARROW, + ) + + tbl.merge(source, join_cols=["user_id"]) + tbl.merge(source, join_cols=["user_id"]) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [150, 200] + + +# ==================== SNAPSHOT BEHAVIOR ==================== + + +def test_merge_produces_overwrite_snapshot(catalog: Catalog) -> None: + """merge() should produce a single OVERWRITE snapshot, not DELETE + APPEND.""" + ident = "default.merge_snapshot" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + snapshot_count_before = len(tbl.snapshots()) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + snapshots_after = tbl.snapshots() + # Should add exactly ONE snapshot (OVERWRITE), not two (DELETE + APPEND) + assert len(snapshots_after) == snapshot_count_before + 1 + assert snapshots_after[-1].summary is not None + assert snapshots_after[-1].summary.operation == Operation.OVERWRITE + + +# ==================== VALIDATION ==================== + + +def test_merge_rejects_empty_join_cols(catalog: Catalog) -> None: + ident = "default.merge_no_cols" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + with pytest.raises(ValueError, match="non-empty"): + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ), + join_cols=[], + ) + + +def test_merge_rejects_missing_join_cols(catalog: Catalog) -> None: + ident = "default.merge_bad_col" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + with pytest.raises(ValueError, match="not found in source"): + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + 
schema=ARROW, + ), + join_cols=["nonexistent"], + ) + + +def test_merge_rejects_non_arrow_input(catalog: Catalog) -> None: + ident = "default.merge_bad_type" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + with pytest.raises(ValueError, match="Expected PyArrow"): + tbl.merge("not a table", join_cols=["user_id"]) + + +# ==================== check_duplicate_keys ==================== + + +def test_merge_check_duplicate_keys_raises(catalog: Catalog) -> None: + """check_duplicate_keys=True raises on duplicate source keys.""" + ident = "default.merge_dup_check" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append(pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice", "score": 100}, + ], schema=ARROW)) + + with pytest.raises(ValueError, match="Duplicate rows"): + tbl.merge( + pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice V1", "score": 150}, + {"user_id": 1, "name": "Alice V2", "score": 200}, + ], schema=ARROW), + join_cols=["user_id"], + check_duplicate_keys=True, + ) + + +def test_merge_check_duplicate_keys_allows_unique(catalog: Catalog) -> None: + """check_duplicate_keys=True passes when source keys are unique.""" + ident = "default.merge_dup_check_ok" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append(pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice", "score": 100}, + ], schema=ARROW)) + + tbl.merge( + pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice", "score": 150}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], schema=ARROW), + join_cols=["user_id"], + check_duplicate_keys=True, + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [150, 200] + + +def test_merge_default_allows_duplicate_keys(catalog: Catalog) -> None: + """Default (check_duplicate_keys=False) allows duplicates.""" + ident = "default.merge_default_dup" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append(pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice", "score": 100}, + ], schema=ARROW)) + + tbl.merge( + pa.Table.from_pylist([ + {"user_id": 1, "name": "Alice V1", "score": 150}, + {"user_id": 1, "name": "Alice V2", "score": 200}, + ], schema=ARROW), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow() + assert result.num_rows == 2 + + +# ==================== CORRECTNESS: OVER-APPROXIMATION + ANTI-JOIN ==================== + + +def test_merge_preserves_non_matching_rows_in_same_file(catalog: Catalog) -> None: + """The per-column In filter over-approximates: In(a,[1,2]) AND In(b,[x,y]) + matches (1,y) and (2,x) even if they're not in the source. + + The anti-join must correctly keep those rows. 
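+    Here the source holds only (1, x) and (2, y); rows (1, y), (2, x),
+    and (3, z) must survive with their original values.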
+ """ + schema = Schema( + NestedField(1, "a", IntegerType(), required=True), + NestedField(2, "b", StringType(), required=True), + NestedField(3, "val", IntegerType(), required=True), + ) + arrow = pa.schema( + [ + pa.field("a", pa.int32(), nullable=False), + pa.field("b", pa.string(), nullable=False), + pa.field("val", pa.int32(), nullable=False), + ] + ) + ident = "default.merge_overapprox" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=schema) + + tbl.append( + pa.Table.from_pylist( + [ + {"a": 1, "b": "x", "val": 100}, + {"a": 1, "b": "y", "val": 200}, # over-approx match, NOT in source + {"a": 2, "b": "x", "val": 300}, # over-approx match, NOT in source + {"a": 2, "b": "y", "val": 400}, + {"a": 3, "b": "z", "val": 500}, # completely unrelated + ], + schema=arrow, + ) + ) + + # Source only has (1,x) and (2,y) - not (1,y) or (2,x) + tbl.merge( + pa.Table.from_pylist( + [ + {"a": 1, "b": "x", "val": 999}, + {"a": 2, "b": "y", "val": 888}, + ], + schema=arrow, + ), + join_cols=["a", "b"], + ) + + result = tbl.scan().to_arrow().sort_by([("a", "ascending"), ("b", "ascending")]) + assert result.num_rows == 5 + assert result.column("val").to_pylist() == [999, 200, 300, 888, 500] + + +def test_merge_preserves_unrelated_rows_in_same_file(catalog: Catalog) -> None: + """Rows in the same data file that don't match any join key must survive.""" + ident = "default.merge_unrelated" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + # All in one file (single append) + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + {"user_id": 4, "name": "Diana", "score": 400}, + {"user_id": 5, "name": "Eve", "score": 500}, + ], + schema=ARROW, + ) + ) + + # Only replace user_id=2 + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 2, "name": "Bob", "score": 999}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 5 + assert result.column("score").to_pylist() == [100, 999, 300, 400, 500] + + +def test_merge_schema_preserved_after_anti_join(catalog: Catalog) -> None: + """Anti-join can strip nullability and field metadata. 
+ The output schema must match the Iceberg table schema exactly.""" + ident = "default.merge_schema" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + ], + schema=ARROW, + ) + ) + + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + # Read back and verify data is correct + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [150, 200] + + # Verify we can merge again (schema must be valid for next write) + tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 2, "name": "Bob", "score": 250}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + result = tbl.scan().to_arrow().sort_by("user_id") + assert result.num_rows == 2 + assert result.column("score").to_pylist() == [150, 250] + + +# ==================== MERGE RESULT ==================== + + +def test_merge_result_empty_df(catalog: Catalog) -> None: + """Empty source returns zero counts and is a no-op.""" + ident = "default.merge_result_empty" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge(pa.Table.from_pylist([], schema=ARROW), join_cols=["user_id"]) + + assert result == MergeResult(rows_deleted=0, rows_inserted=0) + assert tbl.scan().to_arrow().num_rows == 3 + + +def test_merge_result_no_overlap(catalog: Catalog) -> None: + """Source keys don't match any target rows - takes append fast-path.""" + ident = "default.merge_result_no_overlap" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 4, "name": "Dave", "score": 400}, + {"user_id": 5, "name": "Eve", "score": 500}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + assert result == MergeResult(rows_deleted=0, rows_inserted=2) + assert tbl.scan().to_arrow().num_rows == 5 + + +def test_merge_result_full_overlap(catalog: Catalog) -> None: + """Every source key matches a target row - all replaced.""" + ident = "default.merge_result_full_overlap" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + {"user_id": 2, "name": "Bob", "score": 250}, + {"user_id": 3, "name": "Charlie", "score": 350}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + assert result == MergeResult(rows_deleted=3, rows_inserted=3) + assert tbl.scan().to_arrow().num_rows == 3 + + +def test_merge_result_partial_overlap(catalog: Catalog) -> None: + """Some source keys match, some don't.""" + ident = "default.merge_result_partial" + _drop(catalog, ident) + tbl 
= catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 2, "name": "Bob", "score": 200}, + {"user_id": 3, "name": "Charlie", "score": 300}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 2, "name": "Bob", "score": 250}, + {"user_id": 3, "name": "Charlie", "score": 350}, + {"user_id": 4, "name": "Dave", "score": 400}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + assert result == MergeResult(rows_deleted=2, rows_inserted=3) + assert tbl.scan().to_arrow().num_rows == 4 + + +def test_merge_result_target_duplicates(catalog: Catalog) -> None: + """Target has multiple rows per matching key - delete-insert is N:M, not 1:1.""" + ident = "default.merge_result_target_dups" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + {"user_id": 1, "name": "Alice", "score": 110}, + {"user_id": 1, "name": "Alice", "score": 120}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 999}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + assert result == MergeResult(rows_deleted=3, rows_inserted=1) + assert tbl.scan().to_arrow().num_rows == 1 + + +def test_merge_result_source_duplicates(catalog: Catalog) -> None: + """Source has duplicate keys - rows_inserted counts source verbatim.""" + ident = "default.merge_result_source_dups" + _drop(catalog, ident) + tbl = catalog.create_table(ident, schema=SCHEMA) + + tbl.append( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 100}, + ], + schema=ARROW, + ) + ) + + result = tbl.merge( + pa.Table.from_pylist( + [ + {"user_id": 1, "name": "Alice", "score": 150}, + {"user_id": 1, "name": "Alice", "score": 160}, + {"user_id": 1, "name": "Alice", "score": 170}, + ], + schema=ARROW, + ), + join_cols=["user_id"], + ) + + assert result == MergeResult(rows_deleted=1, rows_inserted=3) + assert tbl.scan().to_arrow().num_rows == 3
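+
+
+# ==================== TRANSACTION API ====================
+
+
+def test_merge_within_transaction(catalog: Catalog) -> None:
+    """Illustrative sketch: merge() via an explicit transaction.
+
+    Mirrors what the Table.merge() wrapper does internally; assumes the
+    OVERWRITE snapshot commits when the transaction exits.
+    """
+    ident = "default.merge_via_tx"
+    _drop(catalog, ident)
+    tbl = catalog.create_table(ident, schema=SCHEMA)
+
+    tbl.append(
+        pa.Table.from_pylist(
+            [
+                {"user_id": 1, "name": "Alice", "score": 100},
+            ],
+            schema=ARROW,
+        )
+    )
+
+    with tbl.transaction() as tx:
+        result = tx.merge(
+            pa.Table.from_pylist(
+                [
+                    {"user_id": 1, "name": "Alice", "score": 150},
+                ],
+                schema=ARROW,
+            ),
+            join_cols=["user_id"],
+        )
+
+    assert result == MergeResult(rows_deleted=1, rows_inserted=1)
+    assert tbl.scan().to_arrow().column("score").to_pylist() == [150]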