162 changes: 162 additions & 0 deletions pyiceberg/table/__init__.py
@@ -178,6 +178,14 @@ class UpsertResult:
rows_inserted: int = 0


@dataclass()
class MergeResult:
"""Summary of the merge operation."""

rows_deleted: int = 0
rows_inserted: int = 0


class TableProperties:
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
@@ -885,6 +893,126 @@ def upsert(

return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)

def merge(
self,
df: pa.Table,
join_cols: List[str],
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: Optional[str] = MAIN_BRANCH,
check_duplicate_keys: bool = False,
) -> MergeResult:
"""Atomic delete-insert merge by join columns.

Deletes all target rows matching the source data's join column
values and inserts the source rows, all in a single OVERWRITE
snapshot.

Uses per-column ``In`` filters for file pruning (O(sum of
per-column cardinalities) predicate terms instead of O(product)),
then an in-memory anti-join for row-level correctness.

Unlike ``upsert()``, ``merge()`` does not enforce key uniqueness
on the source or the target by default.

Args:
df: The Arrow dataframe containing replacement rows.
join_cols: Columns used to match source rows against target rows.
snapshot_properties: Custom properties to be added to the snapshot summary.
branch: Branch reference to run the operation.
check_duplicate_keys: If True, raise ValueError when the source
data contains duplicate key tuples based on the join columns.
This is a data-quality guard, not a correctness requirement;
merge() produces correct results with duplicate keys.

Returns:
A MergeResult with row counts (deleted from target, inserted from source).
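
        Example:
            An illustrative sketch; ``tbl`` is a hypothetical table whose
            schema has an ``id`` key column and a ``value`` column::

                import pyarrow as pa

                source = pa.table({"id": [1, 2], "value": ["a", "b"]})
                with tbl.transaction() as tx:
                    result = tx.merge(source, join_cols=["id"])
                # result.rows_deleted: matching target rows removed
                # result.rows_inserted: 2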
"""
try:
import pyarrow as pa
import pyarrow.compute as pc
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

import functools

from pyiceberg.expressions import And, BooleanExpression, In
from pyiceberg.io.pyarrow import ArrowScan, _check_pyarrow_schema_compatible, _dataframe_to_data_files
from pyiceberg.table import upsert_util

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if not join_cols:
raise ValueError("join_cols must be a non-empty list of column names.")

missing = set(join_cols) - set(df.column_names)
if missing:
raise ValueError(f"join_cols not found in source data: {missing}")

if df.num_rows == 0:
return MergeResult()

if check_duplicate_keys and upsert_util.has_duplicate_rows(df, join_cols):
raise ValueError("Duplicate rows found in source data based on join columns.")

downcast_ns = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
provided_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns,
format_version=self.table_metadata.format_version,
)

# Step 1: Build per-column In filters for file pruning.
# O(sum of cardinalities) instead of O(product).
# Over-approximates the match set, which is fine; row-level
# correctness is enforced by the anti-join in step 3.
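# e.g. join_cols=["date_id", "account"] builds
# And(In("date_id", <unique dates>), In("account", <unique accounts>)).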
in_filters: list[BooleanExpression] = [In(col, pc.unique(df[col]).to_pylist()) for col in join_cols]
candidate_filter: BooleanExpression = functools.reduce(And, in_filters)

# Step 2: Find candidate files via manifest pruning.
scan = self._scan(row_filter=candidate_filter, case_sensitive=True)
if branch is not None and branch in self.table_metadata.refs:
scan = scan.use_ref(branch)
tasks = list(scan.plan_files())

if not tasks:
# No candidate files overlap the source keys; just append.
self.append(df, snapshot_properties=snapshot_properties, branch=branch)
return MergeResult(rows_inserted=df.num_rows)

# Step 3: Read ALL rows from candidate files, anti-join to keep
# non-matching rows. The candidate_filter was only for file
# pruning; row-level correctness comes from the anti-join.
arrow_scan = ArrowScan(
self.table_metadata,
self._table.io,
projected_schema=self.table_metadata.schema(),
row_filter=ALWAYS_TRUE,
case_sensitive=True,
)
target_data = arrow_scan.to_table(tasks)
source_keys = df.select(join_cols)

kept_rows = target_data.join(source_keys, keys=join_cols, join_type="left anti")
rows_deleted = target_data.num_rows - kept_rows.num_rows
new_content = pa.concat_tables([kept_rows, df], promote_options="default")

# Step 4: Atomic single-snapshot commit.
# Delete old files, append rewritten content.
with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).overwrite() as overwrite_op:
for task in tasks:
overwrite_op.delete_data_file(task.file)
for data_file in _dataframe_to_data_files(
table_metadata=self.table_metadata,
df=new_content,
io=self._table.io,
write_uuid=overwrite_op.commit_uuid,
):
overwrite_op.append_data_file(data_file)

return MergeResult(rows_deleted=rows_deleted, rows_inserted=df.num_rows)

def add_files(
self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
) -> None:
@@ -1415,6 +1543,40 @@ def upsert(
branch=branch,
)

def merge(
self,
df: pa.Table,
join_cols: List[str],
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: Optional[str] = MAIN_BRANCH,
check_duplicate_keys: bool = False,
) -> MergeResult:
"""Atomic delete-insert merge by join columns.

Unlike ``upsert()``, ``merge()`` does not enforce key uniqueness
on the source or the target by default.

Args:
df: The Arrow dataframe containing replacement rows.
join_cols: Columns used to match source rows against target rows.
snapshot_properties: Custom properties to be added to the snapshot summary.
branch: Branch reference to run the operation.
check_duplicate_keys: If True, raise ValueError when the source
data contains duplicate key tuples based on the join columns.
This is a data-quality guard, not a correctness requirement.

Returns:
A MergeResult with row counts (deleted from target, inserted from source).
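
        Example:
            An illustrative sketch; the catalog and table identifier are
            hypothetical::

                tbl = catalog.load_table("default.orders")
                result = tbl.merge(df, join_cols=["order_id"])
                print(result.rows_deleted, result.rows_inserted)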
"""
with self.transaction() as tx:
return tx.merge(
df=df,
join_cols=join_cols,
snapshot_properties=snapshot_properties,
branch=branch,
check_duplicate_keys=check_duplicate_keys,
)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> None:
"""
Shorthand API for appending a PyArrow table to the table.
157 changes: 157 additions & 0 deletions tests/benchmark/test_merge_filter.py
@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""End-to-end benchmark: merge() vs create_match_filter + overwrite().

Compares the full write path (filter construction + file I/O + snapshot commit)
between the new merge() implementation and the previous
create_match_filter() + overwrite() approach.

Usage:
poetry run pytest tests/benchmark/test_merge_filter.py -v -s -m benchmark
"""

import gc
import itertools
import timeit
import tracemalloc
from pathlib import PosixPath
from typing import Any, Callable

import pyarrow as pa
import pytest

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.types import IntegerType, NestedField, StringType
from tests.catalog.test_base import InMemoryCatalog


def _make_schema(col_cardinalities: dict[str, int]) -> Schema:
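"""One required key field per entry in col_cardinalities ("date_id" is int, others string), plus an int field "v"."""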
fields = []
for i, col in enumerate(col_cardinalities):
field_type = IntegerType() if col == "date_id" else StringType()
fields.append(NestedField(i + 1, col, field_type, required=True))
fields.append(NestedField(len(col_cardinalities) + 1, "v", IntegerType(), required=True))
return Schema(*fields)


def _build_table(col_cardinalities: dict[str, int]) -> tuple[pa.Table, list[str], Schema]:
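"""Cross product of per-column key values plus a running int column "v"; with COLS this is 252 x 100 = 25,200 rows."""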
from pyiceberg.io.pyarrow import schema_to_pyarrow

schema = _make_schema(col_cardinalities)
arrow_schema = schema_to_pyarrow(schema)

vals: list[list[Any]] = []
for col, card in col_cardinalities.items():
if col == "date_id":
vals.append(list(range(20260101, 20260101 + card)))
else:
vals.append([f"{col}_{i}" for i in range(card)])
combos = list(itertools.product(*vals))
data = {col: [c[i] for c in combos] for i, col in enumerate(col_cardinalities)}
data["v"] = list(range(len(combos)))
return pa.table(data, schema=arrow_schema), list(col_cardinalities.keys()), schema


def _fresh_table(catalog: Catalog, name: str, schema: Schema, data: pa.Table) -> Table:
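"""Drop the table if it already exists, recreate it, and seed it with the target data."""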
ident = f"default.{name}"
try:
catalog.drop_table(ident)
except NoSuchTableError:
pass
tbl = catalog.create_table(ident, schema=schema)
tbl.append(data)
return tbl


def _measure(fn: Callable[[], Any], runs: int = 3) -> tuple[float, int]:
"""Returns (avg_seconds, peak_memory_bytes)."""
times = []
peak = 0
for _ in range(runs):
gc.collect()
tracemalloc.start()
t0 = timeit.default_timer()
fn()
times.append(timeit.default_timer() - t0)
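# get_traced_memory() returns (current, peak); keep the largest peak seen.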
_, p = tracemalloc.get_traced_memory()
tracemalloc.stop()
peak = max(peak, p)
return sum(times) / len(times), peak


def _fmt(secs: float, mem_bytes: int) -> str:
mem = f"{mem_bytes / 1024:.0f} KB" if mem_bytes < 1048576 else f"{mem_bytes / 1048576:.1f} MB"
return f"{secs * 1000:.0f} ms, peak {mem}"


COLS = {"date_id": 252, "account": 100} # 25,200 target rows


@pytest.mark.benchmark
@pytest.mark.parametrize("n_source", [100, 5000])
def test_e2e_merge(n_source: int, tmp_path: PosixPath) -> None:
"""End-to-end merge(): per-column In + anti-join + single OVERWRITE snapshot."""
target_data, join_cols, schema = _build_table(COLS)

catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
catalog.create_namespace("default")
tbl = _fresh_table(catalog, f"merge_{n_source}", schema, target_data)

source_dict = {col: target_data.column(col).to_pylist()[:n_source] for col in target_data.column_names}
source_dict["v"] = [x + 9000 for x in source_dict["v"]]
source = pa.table(source_dict, schema=target_data.schema)

avg, peak = _measure(lambda: tbl.merge(source, join_cols=join_cols))
print(f"\n merge(): {target_data.num_rows:,} target, {n_source:,} source -> {_fmt(avg, peak)}")


@pytest.mark.benchmark
def test_e2e_overwrite_100src(tmp_path: PosixPath) -> None:
"""End-to-end overwrite() with 100 source rows."""
target_data, join_cols, schema = _build_table(COLS)

catalog = InMemoryCatalog("bench", warehouse=str(tmp_path))
catalog.create_namespace("default")
tbl = _fresh_table(catalog, "overwrite_100", schema, target_data)

source_dict = {col: target_data.column(col).to_pylist()[:100] for col in target_data.column_names}
source_dict["v"] = [x + 9000 for x in source_dict["v"]]
source = pa.table(source_dict, schema=target_data.schema)

avg, peak = _measure(lambda: tbl.overwrite(source, overwrite_filter=create_match_filter(source, join_cols)))
print(f"\n overwrite(): {target_data.num_rows:,} target, 100 source -> {_fmt(avg, peak)}")


@pytest.mark.benchmark
def test_e2e_overwrite_5ksrc_filter_only(tmp_path: PosixPath) -> None:
"""At 5,000 source rows, just constructing the filter takes seconds.

We only measure filter construction here because the full overwrite()
with a 20,000-node expression tree causes process termination during
manifest evaluation.
"""
target_data, join_cols, schema = _build_table(COLS)

source_dict = {col: target_data.column(col).to_pylist()[:5000] for col in target_data.column_names}
source_dict["v"] = [x + 9000 for x in source_dict["v"]]
source = pa.table(source_dict, schema=target_data.schema)

avg, peak = _measure(lambda: create_match_filter(source, join_cols), runs=1)
print(f"\n create_match_filter only (no overwrite): 5,000 source -> {_fmt(avg, peak)}")