Skip to content

Commit db39b67

Browse files
committed
feat: Optimize insert filtering in upsert process using anti-join for matched keys
1 parent 4ab5c93 commit db39b67

1 file changed

Lines changed: 27 additions & 7 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,6 @@ def upsert(
804804
except ModuleNotFoundError as e:
805805
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
806806

807-
from pyiceberg.io.pyarrow import expression_to_pyarrow
808807
from pyiceberg.table import upsert_util
809808

810809
if join_cols is None:
@@ -855,7 +854,7 @@ def upsert(
855854

856855
batches_to_overwrite = []
857856
overwrite_predicates = []
858-
rows_to_insert = df
857+
matched_target_keys: list[pa.Table] = [] # Accumulate matched keys for insert filtering
859858

860859
for batch in matched_iceberg_record_batches:
861860
rows = pa.Table.from_batches([batch])
@@ -873,13 +872,34 @@ def upsert(
873872
batches_to_overwrite.append(rows_to_update)
874873
overwrite_predicates.append(overwrite_mask_predicate)
875874

875+
# Collect matched keys for insert filtering (will use anti-join after loop)
876876
if when_not_matched_insert_all:
877-
expr_match = upsert_util.create_match_filter(rows, join_cols)
878-
expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
879-
expr_match_arrow = expression_to_pyarrow(expr_match_bound)
877+
matched_target_keys.append(rows.select(join_cols))
878+
879+
batch_loop_end = time.perf_counter()
880+
logger.info(
881+
f"Batch processing: {batch_loop_end - batch_loop_start:.3f}s "
882+
f"({batch_count} batches, get_rows_to_update total: {total_rows_to_update_time:.3f}s)"
883+
)
880884

881-
# Filter rows per batch.
882-
rows_to_insert = rows_to_insert.filter(~expr_match_arrow)
885+
# Use anti-join to find rows to insert (replaces per-batch expression filtering)
886+
rows_to_insert = df
887+
if when_not_matched_insert_all and matched_target_keys:
888+
filter_start = time.perf_counter()
889+
# Combine all matched keys and deduplicate
890+
combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([])
891+
# Cast matched keys to source schema types for join compatibility
892+
source_key_schema = df.select(join_cols).schema
893+
combined_matched_keys = combined_matched_keys.cast(source_key_schema)
894+
# Use anti-join on key columns only (with row indices) to avoid issues with
895+
# struct/list types in non-key columns that PyArrow join doesn't support
896+
row_indices = pa.chunked_array([pa.array(range(len(df)), type=pa.int64())])
897+
source_keys_with_idx = df.select(join_cols).append_column("__row_idx__", row_indices)
898+
not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti")
899+
indices_to_keep = not_matched_keys.column("__row_idx__").combine_chunks()
900+
rows_to_insert = df.take(indices_to_keep)
901+
filter_end = time.perf_counter()
902+
logger.info(f"Insert filtering (anti-join): {filter_end - filter_start:.3f}s ({len(combined_matched_keys)} matched keys)")
883903

884904
update_row_cnt = 0
885905
insert_row_cnt = 0

0 commit comments

Comments (0)