Skip to content

Commit 4ab5c93

Browse files
committed
feat: Enhance vectorized comparison to handle struct-level nulls and empty structs
1 parent a947f0a commit 4ab5c93

2 files changed

Lines changed: 60 additions & 8 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
# under the License.
1717
import functools
1818
import operator
19-
from typing import Union
2019

2120
import pyarrow as pa
2221
from pyarrow import Table as pyarrow_table
@@ -94,7 +93,7 @@ def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
9493

9594

9695
def _compare_columns_vectorized(
97-
source_col: Union[pa.Array, pa.ChunkedArray], target_col: Union[pa.Array, pa.ChunkedArray]
96+
source_col: pa.Array | pa.ChunkedArray, target_col: pa.Array | pa.ChunkedArray
9897
) -> pa.Array:
9998
"""
10099
Vectorized comparison of two columns, returning a boolean array where True means values differ.
@@ -105,6 +104,11 @@ def _compare_columns_vectorized(
105104
col_type = source_col.type
106105

107106
if pa.types.is_struct(col_type):
107+
# Handle struct-level nulls first
108+
source_null = pc.is_null(source_col)
109+
target_null = pc.is_null(target_col)
110+
struct_null_diff = pc.xor(source_null, target_null) # Different if exactly one is null
111+
108112
# PyArrow cannot directly compare struct columns, so we recursively compare each field
109113
diff_masks = []
110114
for i, field in enumerate(col_type):
@@ -114,12 +118,14 @@ def _compare_columns_vectorized(
114118
diff_masks.append(field_diff)
115119

116120
if not diff_masks:
117-
# Empty struct - no fields to compare, so no differences
118-
return pa.array([False] * len(source_col), type=pa.bool_())
121+
# Empty struct - only null differences matter
122+
return struct_null_diff
119123

120-
return functools.reduce(pc.or_, diff_masks)
124+
# Combine field differences with struct-level null differences
125+
field_diff = functools.reduce(pc.or_, diff_masks)
126+
return pc.or_(field_diff, struct_null_diff)
121127

122-
elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_map(col_type):
128+
elif pa.types.is_list(col_type) or pa.types.is_large_list(col_type) or pa.types.is_fixed_size_list(col_type) or pa.types.is_map(col_type):
123129
# For list/map types, fall back to Python comparison as PyArrow doesn't support vectorized comparison
124130
# This is still faster than the original row-by-row approach since we batch the conversion
125131
source_py = source_col.to_pylist()

tests/table/test_upsert.py

Lines changed: 48 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -892,6 +892,7 @@ def test_coarse_match_filter_composite_key() -> None:
892892
Test that create_coarse_match_filter produces efficient In() predicates for composite keys.
893893
"""
894894
from pyiceberg.table.upsert_util import create_coarse_match_filter, create_match_filter
895+
from pyiceberg.expressions import Or, And, In
895896

896897
# Create a table with composite key that has overlapping values
897898
# (1, 'x'), (2, 'y'), (1, 'z') - exact filter should have 3 conditions
@@ -908,10 +909,10 @@ def test_coarse_match_filter_composite_key() -> None:
908909
coarse_filter = create_coarse_match_filter(table, ["a", "b"])
909910

910911
# Exact filter is an Or of And conditions
911-
assert "Or" in str(exact_filter)
912+
assert isinstance(exact_filter, Or)
912913

913914
# Coarse filter is an And of In conditions
914-
assert "And" in str(coarse_filter)
915+
assert isinstance(coarse_filter, And)
915916
assert "In" in str(coarse_filter)
916917

917918

@@ -1071,3 +1072,48 @@ def test_upsert_with_list_field(catalog: Catalog) -> None:
10711072
res = tbl.upsert(update_data, join_cols=["id"])
10721073
assert res.rows_updated == 1
10731074
assert res.rows_inserted == 1
1075+
1076+
1077+
def test_vectorized_comparison_struct_level_nulls() -> None:
1078+
"""Test vectorized comparison handles struct-level nulls correctly (not just field-level nulls)."""
1079+
from pyiceberg.table.upsert_util import _compare_columns_vectorized
1080+
1081+
struct_type = pa.struct([("x", pa.int32()), ("y", pa.string())])
1082+
1083+
# null struct vs non-null struct = different
1084+
source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type)
1085+
target = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type)
1086+
diff = _compare_columns_vectorized(source, target)
1087+
assert diff.to_pylist() == [False, True, False]
1088+
1089+
# non-null struct vs null struct = different
1090+
source = pa.array([{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}], type=struct_type)
1091+
target = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type)
1092+
diff = _compare_columns_vectorized(source, target)
1093+
assert diff.to_pylist() == [False, True, False]
1094+
1095+
# null struct vs null struct = same (no update needed)
1096+
source = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type)
1097+
target = pa.array([{"x": 1, "y": "a"}, None, {"x": 3, "y": "c"}], type=struct_type)
1098+
diff = _compare_columns_vectorized(source, target)
1099+
assert diff.to_pylist() == [False, False, False]
1100+
1101+
1102+
def test_vectorized_comparison_empty_struct_with_nulls() -> None:
1103+
"""Test that empty structs with null values are compared correctly."""
1104+
from pyiceberg.table.upsert_util import _compare_columns_vectorized
1105+
1106+
# Empty struct type - edge case where only struct-level null handling matters
1107+
empty_struct_type = pa.struct([])
1108+
1109+
# null vs non-null empty struct = different
1110+
source = pa.array([{}, None, {}], type=empty_struct_type)
1111+
target = pa.array([{}, {}, {}], type=empty_struct_type)
1112+
diff = _compare_columns_vectorized(source, target)
1113+
assert diff.to_pylist() == [False, True, False]
1114+
1115+
# null vs null empty struct = same
1116+
source = pa.array([None, None], type=empty_struct_type)
1117+
target = pa.array([None, None], type=empty_struct_type)
1118+
diff = _compare_columns_vectorized(source, target)
1119+
assert diff.to_pylist() == [False, False]

0 commit comments

Comments
 (0)