Commit a947f0a
feat: Optimize upsert process with coarse match filter and vectorized comparisons
1 parent ba65619 commit a947f0a

3 files changed: 311 additions & 26 deletions


pyiceberg/table/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -836,7 +836,8 @@ def upsert(
         )

         # get list of rows that exist so we don't have to load the entire target table
-        matched_predicate = upsert_util.create_match_filter(df, join_cols)
+        # Use coarse filter for initial scan - exact matching happens in get_rows_to_update()
+        matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols)

         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.

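The new comment captures the contract that makes this swap safe: the coarse filter may scan a superset of the truly matching rows, but get_rows_to_update() re-joins source and target on the exact key columns, so a false positive only costs extra scanned rows, never a wrong update. A minimal sketch of that contract, assuming toy in-memory tables (the data values are illustrative, the functions are the ones this commit touches):

import pyarrow as pa

from pyiceberg.table import upsert_util

source = pa.Table.from_pylist([{"a": 1, "b": "x", "v": 10}, {"a": 2, "b": "y", "v": 20}])

# Coarse predicate is roughly In(a, [1, 2]) AND In(b, ["x", "y"]), so a target
# row with the key (1, "y") would be scanned even though it is not a source key ...
print(upsert_util.create_coarse_match_filter(source, ["a", "b"]))

# ... but the exact join inside get_rows_to_update() drops it again:
target = pa.Table.from_pylist([{"a": 1, "b": "y", "v": 99}, {"a": 2, "b": "y", "v": 21}])
updates = upsert_util.get_rows_to_update(source, target, ["a", "b"])
assert updates.to_pylist() == [{"a": 2, "b": "y", "v": 20}]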

pyiceberg/table/upsert_util.py

Lines changed: 123 additions & 25 deletions

@@ -16,6 +16,7 @@
 # under the License.
 import functools
 import operator
+from typing import Union

 import pyarrow as pa
 from pyarrow import Table as pyarrow_table
@@ -31,8 +32,18 @@


 def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
+    """
+    Create an Iceberg BooleanExpression filter that exactly matches rows based on join columns.
+
+    For single-column keys, uses an efficient In() predicate.
+    For composite keys, creates Or(And(...), And(...), ...) for exact row matching.
+    This function should be used when exact matching is required (e.g., overwrite, insert filtering).
+    """
     unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])

+    if len(unique_keys) == 0:
+        return AlwaysFalse()
+
     if len(join_cols) == 1:
         return In(join_cols[0], unique_keys[0].to_pylist())
     else:
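One behavioral addition here is the empty-input guard: with no unique keys the filter must match nothing, rather than building an In() over an empty list or an Or() with no operands. A quick sketch, assuming toy tables:

import pyarrow as pa

from pyiceberg.expressions import AlwaysFalse
from pyiceberg.table.upsert_util import create_match_filter

schema = pa.schema([pa.field("id", pa.int64())])

# Empty source: nothing to match, so the filter short-circuits to AlwaysFalse
assert isinstance(create_match_filter(schema.empty_table(), ["id"]), AlwaysFalse)

# Non-empty single-column source: one In() predicate over the unique key values
df = pa.Table.from_pylist([{"id": 1}, {"id": 2}, {"id": 2}], schema=schema)
print(create_match_filter(df, ["id"]))  # roughly: In("id", {1, 2})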
@@ -48,17 +59,97 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
         return Or(*filters)


+def create_coarse_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
+    """
+    Create a coarse Iceberg BooleanExpression filter for initial row scanning.
+
+    For single-column keys, uses an efficient In() predicate (exact match).
+    For composite keys, uses In() per column as a coarse filter (AND of In() predicates),
+    which may return false positives but is much more efficient than exact matching.
+
+    This function should only be used for initial scans where exact matching happens
+    downstream (e.g., in get_rows_to_update() via the join operation).
+    """
+    unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
+
+    if len(unique_keys) == 0:
+        return AlwaysFalse()
+
+    if len(join_cols) == 1:
+        return In(join_cols[0], unique_keys[0].to_pylist())
+    else:
+        # For composite keys: use In() per column as a coarse filter
+        # This is more efficient than creating Or(And(...), And(...), ...) for each row
+        # May include false positives, but fine-grained matching happens downstream
+        column_filters = []
+        for col in join_cols:
+            unique_values = pc.unique(unique_keys[col]).to_pylist()
+            column_filters.append(In(col, unique_values))
+        return functools.reduce(operator.and_, column_filters)
+
+
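The efficiency claim is about predicate size: for n unique composite keys over k columns, the exact filter carries n And() terms while the coarse filter carries only k In() lists. The trade-off is false positives, such as the key (2, 'x') below. A small illustration, with assumed toy data:

import pyarrow as pa

from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter

keys = pa.Table.from_pylist([{"a": 1, "b": "x"}, {"a": 2, "b": "y"}, {"a": 1, "b": "z"}])

# Exact: Or(And(a=1, b='x'), And(a=2, b='y'), And(a=1, b='z')) - grows with the row count
print(create_match_filter(keys, ["a", "b"]))

# Coarse: In(a, [1, 2]) AND In(b, ['x', 'y', 'z']) - grows with the column count,
# but also admits combinations like (2, 'x') that never occur in `keys`
print(create_coarse_match_filter(keys, ["a", "b"]))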
 def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
     """Check for duplicate rows in a PyArrow table based on the join columns."""
     return len(df.select(join_cols).group_by(join_cols).aggregate([([], "count_all")]).filter(pc.field("count_all") > 1)) > 0

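has_duplicate_rows (unchanged context above) is the guard the upsert path relies on: it groups by the key columns and reports whether any group holds more than one row. A usage sketch with assumed data:

import pyarrow as pa

from pyiceberg.table.upsert_util import has_duplicate_rows

df = pa.Table.from_pylist([{"id": 1, "v": "a"}, {"id": 1, "v": "b"}, {"id": 2, "v": "c"}])
assert has_duplicate_rows(df, ["id"])           # id=1 appears twice
assert not has_duplicate_rows(df, ["id", "v"])  # the full composite key is unique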

+def _compare_columns_vectorized(
+    source_col: Union[pa.Array, pa.ChunkedArray], target_col: Union[pa.Array, pa.ChunkedArray]
+) -> pa.Array:
+    """
+    Vectorized comparison of two columns, returning a boolean array where True means values differ.
+
+    Handles struct types recursively by comparing each nested field.
+    Handles null values correctly: null != non-null is True, null == null is True (no update needed).
+    """
+    col_type = source_col.type
+
+    if pa.types.is_struct(col_type):
+        # PyArrow cannot directly compare struct columns, so we recursively compare each field
+        diff_masks = []
+        for i, field in enumerate(col_type):
+            src_field = pc.struct_field(source_col, [i])
+            tgt_field = pc.struct_field(target_col, [i])
+            field_diff = _compare_columns_vectorized(src_field, tgt_field)
+            diff_masks.append(field_diff)
+
+        if not diff_masks:
+            # Empty struct - no fields to compare, so no differences
+            return pa.array([False] * len(source_col), type=pa.bool_())
+
+        return functools.reduce(pc.or_, diff_masks)
+
+    elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_map(col_type):
+        # For list/map types, fall back to Python comparison as PyArrow doesn't support vectorized comparison
+        # This is still faster than the original row-by-row approach since we batch the conversion
+        source_py = source_col.to_pylist()
+        target_py = target_col.to_pylist()
+        return pa.array([s != t for s, t in zip(source_py, target_py, strict=True)], type=pa.bool_())
+
+    else:
+        # For primitive types, use vectorized not_equal
+        # Handle nulls: not_equal returns null when comparing with null
+        # We need: null vs non-null = different (True), null vs null = same (False)
+        diff = pc.not_equal(source_col, target_col)
+        source_null = pc.is_null(source_col)
+        target_null = pc.is_null(target_col)
+
+        # XOR of null masks: True if exactly one is null (meaning they differ)
+        null_diff = pc.xor(source_null, target_null)
+
+        # Combine: different if values differ OR exactly one is null
+        # Fill null comparison results with False (both non-null but comparison returned null shouldn't happen,
+        # but if it does, treat as no difference)
+        diff_filled = pc.fill_null(diff, False)
+        return pc.or_(diff_filled, null_diff)
+
+
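The primitive branch is worth replaying once by hand, because pc.not_equal is null-propagating: any comparison involving a null yields null rather than True or False, which is exactly what the XOR-of-null-masks correction repairs. A standalone trace with plain pyarrow.compute (the arrays are assumed examples):

import pyarrow as pa
import pyarrow.compute as pc

source = pa.array([1, None, None, 4])
target = pa.array([1, 2, None, 5])

diff = pc.not_equal(source, target)                         # [False, null, null, True] - nulls propagate
null_diff = pc.xor(pc.is_null(source), pc.is_null(target))  # [False, True, False, False]
result = pc.or_(pc.fill_null(diff, False), null_diff)       # [False, True, False, True]
assert result.to_pylist() == [False, True, False, True]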
 def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
     """
     Return a table with rows that need to be updated in the target table based on the join columns.
+    Uses vectorized PyArrow operations for efficient comparison, avoiding row-by-row Python loops.
     The table is joined on the identifier columns, and then checked if there are any updated rows.
-    Those are selected and everything is renamed correctly.
     """
     all_columns = set(source_table.column_names)
     join_cols_set = set(join_cols)
@@ -69,13 +160,13 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols: list[str]) -> pa.Table:
         raise ValueError("Target table has duplicate rows, aborting upsert")

     if len(target_table) == 0:
-        # When the target table is empty, there is nothing to update :)
+        # When the target table is empty, there is nothing to update
+        return source_table.schema.empty_table()
+
+    if len(non_key_cols) == 0:
+        # No non-key columns to compare, all matched rows are "updates" but with no changes
         return source_table.schema.empty_table()

-    # We need to compare non_key_cols in Python as PyArrow
-    # 1. Cannot do a join when non-join columns have complex types
-    # 2. Cannot compare columns with complex types
-    # See: https://github.com/apache/arrow/issues/35785
     SOURCE_INDEX_COLUMN_NAME = "__source_index"
     TARGET_INDEX_COLUMN_NAME = "__target_index"

@@ -100,25 +191,32 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
100191
# Step 3: Perform an inner join to find which rows from source exist in target
101192
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
102193

103-
# Step 4: Compare all rows using Python
104-
to_update_indices = []
105-
for source_idx, target_idx in zip(
106-
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(),
107-
matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist(),
108-
strict=True,
109-
):
110-
source_row = source_table.slice(source_idx, 1)
111-
target_row = target_table.slice(target_idx, 1)
112-
113-
for key in non_key_cols:
114-
source_val = source_row.column(key)[0].as_py()
115-
target_val = target_row.column(key)[0].as_py()
116-
if source_val != target_val:
117-
to_update_indices.append(source_idx)
118-
break
119-
120-
# Step 5: Take rows from source table using the indices and cast to target schema
121-
if to_update_indices:
194+
if len(matching_indices) == 0:
195+
# No matching rows found
196+
return source_table.schema.empty_table()
197+
198+
# Step 4: Take matched rows in batch (vectorized - single operation)
199+
source_indices = matching_indices[SOURCE_INDEX_COLUMN_NAME]
200+
target_indices = matching_indices[TARGET_INDEX_COLUMN_NAME]
201+
202+
matched_source = source_table.take(source_indices)
203+
matched_target = target_table.take(target_indices)
204+
205+
# Step 5: Vectorized comparison per column
206+
diff_masks = []
207+
for col in non_key_cols:
208+
source_col = matched_source.column(col)
209+
target_col = matched_target.column(col)
210+
col_diff = _compare_columns_vectorized(source_col, target_col)
211+
diff_masks.append(col_diff)
212+
213+
# Step 6: Combine masks with OR (any column different = needs update)
214+
combined_mask = functools.reduce(pc.or_, diff_masks)
215+
216+
# Step 7: Filter to get indices of rows that need updating
217+
to_update_indices = pc.filter(source_indices, combined_mask)
218+
219+
if len(to_update_indices) > 0:
122220
return source_table.take(to_update_indices)
123221
else:
124222
return source_table.schema.empty_table()
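Distilled from the hunk above, the new hot path is three batched operations: take() to align the matched rows, one vectorized comparison per non-key column, and a single boolean filter. The sketch below replays it with plain PyArrow on assumed toy tables and hand-written join indices; it filters the aligned table directly, which is equivalent here to filtering the index vector and re-taking from the source as the commit does:

import functools

import pyarrow as pa
import pyarrow.compute as pc

source = pa.table({"id": [1, 2, 3], "v": [10, 20, 30], "w": ["a", "b", "c"]})
target = pa.table({"id": [2, 3], "v": [20, 99], "w": ["b", "c"]})

# Steps 1-3 (index tables plus inner join) reduce to aligned index vectors:
# keys id=2 and id=3 match, at source positions [1, 2] and target positions [0, 1]
source_indices = pa.array([1, 2])
target_indices = pa.array([0, 1])

# Step 4: align the matched rows with two batched take() calls
matched_source = source.take(source_indices)
matched_target = target.take(target_indices)

# Steps 5-6: one boolean mask per non-key column, OR-ed together
masks = [pc.not_equal(matched_source.column(c), matched_target.column(c)) for c in ("v", "w")]
combined = functools.reduce(pc.or_, masks)

# Step 7: keep only the rows that actually changed (id=3: v is 30 vs 99)
assert matched_source.filter(combined).to_pylist() == [{"id": 3, "v": 30, "w": "c"}]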

tests/table/test_upsert.py

Lines changed: 186 additions & 0 deletions

@@ -885,3 +885,189 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
     for snapshot in snapshots[initial_snapshot_count:]:
         assert snapshot.summary is not None
         assert snapshot.summary.additional_properties.get("test_prop") == "test_value"
+
+
+def test_coarse_match_filter_composite_key() -> None:
+    """
+    Test that create_coarse_match_filter produces efficient In() predicates for composite keys.
+    """
+    from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter
+
+    # Create a table with composite key that has overlapping values:
+    # (1, 'x'), (2, 'y'), (1, 'z') - exact filter should have 3 conditions,
+    # coarse filter should have In(a, [1, 2]) AND In(b, ['x', 'y', 'z'])
+    data = [
+        {"a": 1, "b": "x", "val": 1},
+        {"a": 2, "b": "y", "val": 2},
+        {"a": 1, "b": "z", "val": 3},
+    ]
+    schema = pa.schema([pa.field("a", pa.int32()), pa.field("b", pa.string()), pa.field("val", pa.int32())])
+    table = pa.Table.from_pylist(data, schema=schema)
+
+    exact_filter = create_match_filter(table, ["a", "b"])
+    coarse_filter = create_coarse_match_filter(table, ["a", "b"])
+
+    # Exact filter is an Or of And conditions
+    assert "Or" in str(exact_filter)
+
+    # Coarse filter is an And of In conditions
+    assert "And" in str(coarse_filter)
+    assert "In" in str(coarse_filter)
+
+
+def test_vectorized_comparison_primitives() -> None:
+    """Test vectorized comparison with primitive types."""
+    from pyiceberg.table.upsert_util import _compare_columns_vectorized
+
+    # Test integers
+    source = pa.array([1, 2, 3, 4])
+    target = pa.array([1, 2, 5, 4])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, False, True, False]
+
+    # Test strings
+    source = pa.array(["a", "b", "c"])
+    target = pa.array(["a", "x", "c"])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True, False]
+
+    # Test floats
+    source = pa.array([1.0, 2.5, 3.0])
+    target = pa.array([1.0, 2.5, 3.1])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, False, True]
+
+
+def test_vectorized_comparison_nulls() -> None:
+    """Test vectorized comparison handles nulls correctly."""
+    from pyiceberg.table.upsert_util import _compare_columns_vectorized
+
+    # null vs non-null = different
+    source = pa.array([1, None, 3])
+    target = pa.array([1, 2, 3])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True, False]
+
+    # non-null vs null = different
+    source = pa.array([1, 2, 3])
+    target = pa.array([1, None, 3])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True, False]
+
+    # null vs null = same (no update needed)
+    source = pa.array([1, None, 3])
+    target = pa.array([1, None, 3])
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, False, False]
+
+
+def test_vectorized_comparison_structs() -> None:
+    """Test vectorized comparison with nested struct types."""
+    from pyiceberg.table.upsert_util import _compare_columns_vectorized
+
+    struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())])
+
+    # Same structs
+    source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type)
+    target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type)
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, False]
+
+    # Different struct values
+    source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}], type=struct_type)
+    target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "c"}], type=struct_type)
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True]
+
+
+def test_vectorized_comparison_nested_structs() -> None:
+    """Test vectorized comparison with deeply nested struct types."""
+    from pyiceberg.table.upsert_util import _compare_columns_vectorized
+
+    inner_struct = pa.struct([("val", pa.int32())])
+    outer_struct = pa.struct([("inner", inner_struct), ("name", pa.string())])
+
+    source = pa.array(
+        [{"inner": {"val": 1}, "name": "a"}, {"inner": {"val": 2}, "name": "b"}],
+        type=outer_struct,
+    )
+    target = pa.array(
+        [{"inner": {"val": 1}, "name": "a"}, {"inner": {"val": 3}, "name": "b"}],
+        type=outer_struct,
+    )
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True]
+
+
+def test_vectorized_comparison_lists() -> None:
+    """Test vectorized comparison with list types (falls back to Python comparison)."""
+    from pyiceberg.table.upsert_util import _compare_columns_vectorized
+
+    list_type = pa.list_(pa.int32())
+
+    source = pa.array([[1, 2], [3, 4]], type=list_type)
+    target = pa.array([[1, 2], [3, 5]], type=list_type)
+    diff = _compare_columns_vectorized(source, target)
+    assert diff.to_pylist() == [False, True]
+
+
+def test_get_rows_to_update_no_non_key_cols() -> None:
+    """Test get_rows_to_update when all columns are key columns."""
+    from pyiceberg.table.upsert_util import get_rows_to_update
+
+    # All columns are key columns, so no non-key columns to compare
+    source = pa.Table.from_pydict({"id": [1, 2, 3]})
+    target = pa.Table.from_pydict({"id": [1, 2, 3]})
+    rows = get_rows_to_update(source, target, ["id"])
+    assert len(rows) == 0
+
+
+def test_upsert_with_list_field(catalog: Catalog) -> None:
+    """Test upsert with list type as non-key column."""
+    from pyiceberg.types import ListType
+
+    identifier = "default.test_upsert_with_list_field"
+    _drop_table(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(
+            2,
+            "tags",
+            ListType(element_id=3, element_type=StringType(), element_required=False),
+            required=False,
+        ),
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int32(), nullable=False),
+            pa.field("tags", pa.list_(pa.large_string()), nullable=True),
+        ]
+    )
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {"id": 1, "tags": ["a", "b"]},
+            {"id": 2, "tags": ["c"]},
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(initial_data)
+
+    # Update with changed list
+    update_data = pa.Table.from_pylist(
+        [
+            {"id": 1, "tags": ["a", "b"]},  # Same - no update
+            {"id": 2, "tags": ["c", "d"]},  # Changed - should update
+            {"id": 3, "tags": ["e"]},  # New - should insert
+        ],
+        schema=arrow_schema,
+    )
+
+    res = tbl.upsert(update_data, join_cols=["id"])
+    assert res.rows_updated == 1
+    assert res.rows_inserted == 1
