
Commit e36d994

feat: Further optimize the filter for big datasets
Remove logging reset pyarrow squash
1 parent db39b67 commit e36d994

3 files changed: 614 additions & 36 deletions


pyiceberg/table/__init__.py

Lines changed: 12 additions & 19 deletions
@@ -834,8 +834,9 @@ def upsert(
             format_version=self.table_metadata.format_version,
         )

-        # get list of rows that exist so we don't have to load the entire target table
-        # Use coarse filter for initial scan - exact matching happens in get_rows_to_update()
+        # Create a coarse filter for the initial scan to reduce the number of rows read.
+        # This filter is intentionally less precise but faster to evaluate than exact matching.
+        # Exact key matching happens downstream in get_rows_to_update() via PyArrow joins.
         matched_predicate = upsert_util.create_coarse_match_filter(df, join_cols)

         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
@@ -854,38 +855,32 @@ def upsert(

         batches_to_overwrite = []
         overwrite_predicates = []
-        matched_target_keys: list[pa.Table] = []  # Accumulate matched keys for insert filtering
+        # Accumulate matched keys for anti-join insert filtering after the batch loop.
+        # We only store key columns (not full rows) to minimize memory usage.
+        matched_target_keys: list[pa.Table] = []

         for batch in matched_iceberg_record_batches:
             rows = pa.Table.from_batches([batch])

             if when_matched_update_all:
-                # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
-                # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
-                # this extra step avoids unnecessary IO and writes
+                # Check non-key columns to see if values have actually changed.
+                # We don't want to do a blanket overwrite for matched rows if the
+                # actual non-key column data hasn't changed - this avoids unnecessary IO and writes.
                 rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)

                 if len(rows_to_update) > 0:
-                    # build the match predicate filter
                     overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
-
                     batches_to_overwrite.append(rows_to_update)
                     overwrite_predicates.append(overwrite_mask_predicate)

-            # Collect matched keys for insert filtering (will use anti-join after loop)
             if when_not_matched_insert_all:
                 matched_target_keys.append(rows.select(join_cols))

-        batch_loop_end = time.perf_counter()
-        logger.info(
-            f"Batch processing: {batch_loop_end - batch_loop_start:.3f}s "
-            f"({batch_count} batches, get_rows_to_update total: {total_rows_to_update_time:.3f}s)"
-        )
-
-        # Use anti-join to find rows to insert (replaces per-batch expression filtering)
+        # Use anti-join to find rows to insert. This is more efficient than per-batch
+        # expression filtering because: (1) we build expressions once, not per batch,
+        # and (2) PyArrow joins are faster than evaluating large Or(...) expressions.
         rows_to_insert = df
         if when_not_matched_insert_all and matched_target_keys:
-            filter_start = time.perf_counter()
             # Combine all matched keys and deduplicate
             combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([])
             # Cast matched keys to source schema types for join compatibility
@@ -898,8 +893,6 @@ def upsert(
             not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti")
             indices_to_keep = not_matched_keys.column("__row_idx__").combine_chunks()
             rows_to_insert = df.take(indices_to_keep)
-            filter_end = time.perf_counter()
-            logger.info(f"Insert filtering (anti-join): {filter_end - filter_start:.3f}s ({len(combined_matched_keys)} matched keys)")

         update_row_cnt = 0
         insert_row_cnt = 0
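The anti-join pattern introduced above can be exercised in isolation with plain PyArrow. The sketch below is illustrative only: the tables, values, and the single `id` join column are invented for the example, but the deduplicate / anti-join / take sequence mirrors the new code in this file.

```python
import pyarrow as pa

# Invented source table and matched key batches, standing in for df and the
# key-only tables collected from the scanned target batches.
df = pa.table({"id": [1, 2, 3, 4], "val": ["a", "b", "c", "d"]})
matched_target_keys = [pa.table({"id": [2]}), pa.table({"id": [3]})]
join_cols = ["id"]

# Deduplicate the matched keys, then left-anti join the source keys against them:
# every source row whose key is NOT among the matched keys is a row to insert.
combined_matched_keys = pa.concat_tables(matched_target_keys).group_by(join_cols).aggregate([])
source_keys_with_idx = df.select(join_cols).append_column("__row_idx__", pa.array(range(len(df))))
not_matched_keys = source_keys_with_idx.join(combined_matched_keys, keys=join_cols, join_type="left anti")
rows_to_insert = df.take(not_matched_keys.column("__row_idx__").combine_chunks())

print(sorted(rows_to_insert.column("id").to_pylist()))  # [1, 4]
```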

pyiceberg/table/upsert_util.py

Lines changed: 154 additions & 17 deletions
@@ -23,12 +23,24 @@

 from pyiceberg.expressions import (
     AlwaysFalse,
+    AlwaysTrue,
+    And,
     BooleanExpression,
     EqualTo,
+    GreaterThanOrEqual,
     In,
+    LessThanOrEqual,
     Or,
 )

+# Threshold for switching from In() predicate to range-based or no filter.
+# When unique keys exceed this, the In() predicate becomes too expensive to process.
+LARGE_FILTER_THRESHOLD = 10_000
+
+# Minimum density (ratio of unique values to range size) for range filter to be effective.
+# Below this threshold, range filters read too much irrelevant data.
+DENSITY_THRESHOLD = 0.1
+

 def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
     """
@@ -58,32 +70,119 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
5870
return Or(*filters)
5971

6072

73+
def _is_numeric_type(arrow_type: pa.DataType) -> bool:
74+
"""Check if a PyArrow type is numeric (suitable for range filtering)."""
75+
return pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type)
76+
77+
78+
def _create_range_filter(col_name: str, values: pa.Array) -> BooleanExpression:
79+
"""Create a min/max range filter for a numeric column."""
80+
min_val = pc.min(values).as_py()
81+
max_val = pc.max(values).as_py()
82+
return And(GreaterThanOrEqual(col_name, min_val), LessThanOrEqual(col_name, max_val))
83+
84+
6185
def create_coarse_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
6286
"""
6387
Create a coarse Iceberg BooleanExpression filter for initial row scanning.
6488
65-
For single-column keys, uses an efficient In() predicate (exact match).
66-
For composite keys, uses In() per column as a coarse filter (AND of In() predicates),
67-
which may return false positives but is much more efficient than exact matching.
89+
This is an optimization for reducing the scan size before exact matching happens
90+
downstream (e.g., in get_rows_to_update() via the join operation). It trades filter
91+
precision for filter evaluation speed.
92+
93+
IMPORTANT: This is not a silver bullet optimization. It only helps specific use cases:
94+
- Datasets with < 10,000 unique keys benefit from In() predicates
95+
- Large datasets with dense numeric keys (>10% density) benefit from range filters
96+
- Large datasets with sparse keys or non-numeric columns fall back to full scan
97+
98+
For small datasets (< LARGE_FILTER_THRESHOLD unique keys, currently 10,000):
99+
- Single-column keys: uses In() predicate
100+
- Composite keys: uses AND of In() predicates per column
101+
102+
For large datasets (>= LARGE_FILTER_THRESHOLD unique keys):
103+
- Single numeric column with dense IDs (>10% coverage): uses min/max range filter
104+
- Otherwise: returns AlwaysTrue() to skip filtering (full scan)
68105
69-
This function should only be used for initial scans where exact matching happens
70-
downstream (e.g., in get_rows_to_update() via the join operation).
106+
The density threshold (DENSITY_THRESHOLD = 0.1 or 10%) determines whether a range
107+
filter is efficient. Below this threshold, the range would include too many
108+
non-matching rows, making a full scan more practical.
109+
110+
Args:
111+
df: PyArrow table containing the source data with join columns
112+
join_cols: List of column names to use for matching
113+
114+
Returns:
115+
BooleanExpression filter for Iceberg table scan
71116
"""
72117
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
118+
num_unique_keys = len(unique_keys)
73119

74-
if len(unique_keys) == 0:
120+
if num_unique_keys == 0:
75121
return AlwaysFalse()
76122

123+
# For small datasets, use the standard In() approach
124+
if num_unique_keys < LARGE_FILTER_THRESHOLD:
125+
if len(join_cols) == 1:
126+
return In(join_cols[0], unique_keys[0].to_pylist())
127+
else:
128+
column_filters = []
129+
for col in join_cols:
130+
unique_values = pc.unique(unique_keys[col]).to_pylist()
131+
column_filters.append(In(col, unique_values))
132+
if len(column_filters) == 0:
133+
return AlwaysFalse()
134+
if len(column_filters) == 1:
135+
return column_filters[0]
136+
return functools.reduce(operator.and_, column_filters)
137+
138+
# For large datasets, use optimized strategies
77139
if len(join_cols) == 1:
78-
return In(join_cols[0], unique_keys[0].to_pylist())
140+
col_name = join_cols[0]
141+
col_data = unique_keys[col_name]
142+
col_type = col_data.type
143+
144+
# For numeric columns, check if range filter is efficient (dense IDs)
145+
if _is_numeric_type(col_type):
146+
min_val = pc.min(col_data).as_py()
147+
max_val = pc.max(col_data).as_py()
148+
value_range = max_val - min_val + 1
149+
density = num_unique_keys / value_range if value_range > 0 else 0
150+
151+
# If IDs are dense (>10% coverage of the range), use range filter
152+
# Otherwise, range filter would read too much irrelevant data
153+
if density > DENSITY_THRESHOLD:
154+
return _create_range_filter(col_name, col_data)
155+
else:
156+
return AlwaysTrue()
157+
else:
158+
# Non-numeric single column with many values - skip filter
159+
return AlwaysTrue()
79160
else:
80-
# For composite keys: use In() per column as a coarse filter
81-
# This is more efficient than creating Or(And(...), And(...), ...) for each row
82-
# May include false positives, but fine-grained matching happens downstream
161+
# Composite keys with many values - use range filters for numeric columns where possible
83162
column_filters = []
84163
for col in join_cols:
85-
unique_values = pc.unique(unique_keys[col]).to_pylist()
86-
column_filters.append(In(col, unique_values))
164+
col_data = unique_keys[col]
165+
col_type = col_data.type
166+
unique_values = pc.unique(col_data)
167+
168+
if _is_numeric_type(col_type) and len(unique_values) >= LARGE_FILTER_THRESHOLD:
169+
# Use range filter for large numeric columns
170+
min_val = pc.min(unique_values).as_py()
171+
max_val = pc.max(unique_values).as_py()
172+
value_range = max_val - min_val + 1
173+
density = len(unique_values) / value_range if value_range > 0 else 0
174+
175+
if density > DENSITY_THRESHOLD:
176+
column_filters.append(_create_range_filter(col, unique_values))
177+
else:
178+
# Sparse numeric column - still use In() as it's part of composite key
179+
column_filters.append(In(col, unique_values.to_pylist()))
180+
else:
181+
# Small column or non-numeric - use In()
182+
column_filters.append(In(col, unique_values.to_pylist()))
183+
184+
if len(column_filters) == 0:
185+
return AlwaysTrue()
87186
return functools.reduce(operator.and_, column_filters)
88187

89188
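As a rough illustration of the strategy selection above, the standalone helper below restates the single-column branch with the same thresholds. The helper name `coarse_filter_for_single_key` and the example key distributions are invented for this sketch; only the imported expression classes and the two constants come from the diff.

```python
import pyarrow as pa
import pyarrow.compute as pc
from pyiceberg.expressions import AlwaysTrue, And, GreaterThanOrEqual, In, LessThanOrEqual

LARGE_FILTER_THRESHOLD = 10_000
DENSITY_THRESHOLD = 0.1


def coarse_filter_for_single_key(keys: pa.Array, col: str):
    """Pick a coarse scan filter for one numeric key column (illustrative only)."""
    unique = pc.unique(keys)
    if len(unique) < LARGE_FILTER_THRESHOLD:
        return In(col, unique.to_pylist())  # small key set: plain In() predicate
    min_val, max_val = pc.min(unique).as_py(), pc.max(unique).as_py()
    density = len(unique) / (max_val - min_val + 1)
    if density > DENSITY_THRESHOLD:  # dense ids: a cheap min/max range filter stays selective
        return And(GreaterThanOrEqual(col, min_val), LessThanOrEqual(col, max_val))
    return AlwaysTrue()  # sparse ids: the range would match almost everything, so skip filtering


# 50,000 consecutive ids -> density 1.0 -> range filter
print(coarse_filter_for_single_key(pa.array(range(50_000)), "id"))
# 50,000 ids spread over a ~100M range -> density ~0.0005 -> AlwaysTrue (full scan)
print(coarse_filter_for_single_key(pa.array(range(0, 50_000 * 2_000, 2_000)), "id"))
```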

@@ -98,8 +197,21 @@ def _compare_columns_vectorized(
     """
     Vectorized comparison of two columns, returning a boolean array where True means values differ.

-    Handles struct types recursively by comparing each nested field.
-    Handles null values correctly: null != non-null is True, null == null is True (no update needed).
+    Handles different PyArrow types:
+    - Primitive types: Uses pc.not_equal() with proper null handling
+    - Struct types: Recursively compares each nested field
+    - List/Map types: Falls back to Python comparison (still batched, not row-by-row)
+
+    Null handling semantics:
+    - null != non-null -> True (values differ, needs update)
+    - null == null -> False (values same, no update needed)
+
+    Args:
+        source_col: Column from the source table
+        target_col: Column from the target table (must have same length)
+
+    Returns:
+        Boolean PyArrow array where True indicates the values at that index differ
     """
     col_type = source_col.type

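The null semantics documented above for primitive columns can be reproduced with a few PyArrow compute calls. This is a simplified sketch: `values_differ` is a hypothetical helper, not the module's `_compare_columns_vectorized`, which additionally handles struct and list/map columns.

```python
import pyarrow as pa
import pyarrow.compute as pc


def values_differ(source_col: pa.ChunkedArray, target_col: pa.ChunkedArray) -> pa.ChunkedArray:
    """True where values differ: null vs non-null -> True, null vs null -> False."""
    # not_equal() yields null whenever either side is null, so fill those slots with False
    # and add the null/non-null mismatches back in explicitly.
    plain_diff = pc.fill_null(pc.not_equal(source_col, target_col), False)
    null_mismatch = pc.xor(pc.is_null(source_col), pc.is_null(target_col))
    return pc.or_(plain_diff, null_mismatch)


src = pa.chunked_array([[1, None, 3, None]])
tgt = pa.chunked_array([[1, 2, None, None]])
print(values_differ(src, tgt))  # values: false, true, true, false
```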

@@ -155,7 +267,32 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     Return a table with rows that need to be updated in the target table based on the join columns.

     Uses vectorized PyArrow operations for efficient comparison, avoiding row-by-row Python loops.
-    The table is joined on the identifier columns, and then checked if there are any updated rows.
+    The function performs an inner join on the identifier columns, then compares non-key columns
+    to find rows where values have actually changed.
+
+    Algorithm:
+    1. Prepare source and target index tables with row indices
+    2. Inner join on join columns to find matching rows
+    3. Use take() to extract matched rows in batch
+    4. Compare non-key columns using vectorized operations
+    5. Filter to rows where at least one non-key column differs
+
+    Note: The column names '__source_index' and '__target_index' are reserved for internal use
+    and cannot be used as join column names.
+
+    Args:
+        source_table: PyArrow table with new/updated data
+        target_table: PyArrow table with existing data
+        join_cols: List of column names that form the unique key
+
+    Returns:
+        PyArrow table containing only the rows from source_table that exist in target_table
+        and have at least one non-key column with a different value. Returns an empty table
+        if no updates are needed.
+
+    Raises:
+        ValueError: If target_table has duplicate rows based on join_cols
+        ValueError: If join_cols contains reserved column names
     """
     all_columns = set(source_table.column_names)
     join_cols_set = set(join_cols)
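The algorithm steps listed in the docstring above can be followed end to end with a small PyArrow-only sketch. The tables and the single non-key `val` column are invented; the reserved index column names come from the docstring, and the real code routes the comparison through `_compare_columns_vectorized` for null and nested-type handling.

```python
import pyarrow as pa
import pyarrow.compute as pc

# Invented source/target tables with one key column and one non-key column.
source = pa.table({"id": [1, 2, 3], "val": ["a", "B", "c"]})
target = pa.table({"id": [1, 2, 4], "val": ["a", "b", "d"]})
join_cols = ["id"]

# Steps 1-3: attach row indices, inner-join on the key columns, take() matched rows in batch.
src_idx = source.select(join_cols).append_column("__source_index", pa.array(range(len(source))))
tgt_idx = target.select(join_cols).append_column("__target_index", pa.array(range(len(target))))
matched = src_idx.join(tgt_idx, keys=join_cols, join_type="inner")
matched_source = source.take(matched.column("__source_index").combine_chunks())
matched_target = target.take(matched.column("__target_index").combine_chunks())

# Steps 4-5: vectorized comparison of the non-key column, keep only rows that changed.
differs = pc.not_equal(matched_source.column("val"), matched_target.column("val"))
rows_to_update = matched_source.filter(differs)
print(rows_to_update.to_pydict())  # {'id': [2], 'val': ['B']}
```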
@@ -183,8 +320,8 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
         ) from None

     # Step 1: Prepare source index with join keys and a marker index
-    # Cast to target table schema, so we can do the join
-    # See: https://github.com/apache/arrow/issues/37542
+    # Cast source to target schema to ensure type compatibility for the join
+    # (e.g., source int32 vs target int64 would cause join issues)
     source_index = (
         source_table.cast(target_table.schema)
         .select(join_cols_set)
