Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion pyiceberg/table/upsert_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import functools
import operator
from typing import Callable, List, TypeVar

import pyarrow as pa
from pyarrow import Table as pyarrow_table
Expand All @@ -28,6 +29,43 @@
In,
)

T = TypeVar("T")


def build_balanced_tree(operator_: Callable[[T, T], T], items: List[T]) -> T:
    """
    Combine ``items`` into a balanced binary tree using the provided binary operator.

    This function is a safer and more scalable alternative to:
        reduce(operator_, items)

    ``reduce`` folds from the left, producing a fully one-sided tree
    (e.g., operator_(operator_(operator_(a, b), c), ...)) whose depth grows linearly
    with the number of items. Recursively traversing such a tree can raise
    RecursionError in Python when the number of items is large (e.g., >1000).

    In contrast, this function halves the input at every level, so the resulting
    tree has logarithmic depth (O(log n)). The input is addressed by index rather
    than sliced, so no intermediate sub-lists are created.

    The order of the items is preserved: for every combined node, all items in the
    left operand come before all items in the right operand. When the number of
    items in a range is odd, the right operand receives the extra item.

    Parameters:
        operator_ (Callable[[T, T], T]): A binary operator function (e.g., operator.or_, operator.and_).
        items (List[T]): A non-empty list of objects to combine.

    Returns:
        T: The single item itself if ``items`` has one element, otherwise the
            balanced combination of all input items.

    Raises:
        ValueError: If the input list is empty.
    """
    if not items:
        raise ValueError("No expressions to combine")

    def _combine(start: int, stop: int) -> T:
        # Combine items[start:stop] by index so that no sub-lists are copied.
        if stop - start == 1:
            return items[start]
        # Same split point as slicing the range in half: the left side gets
        # floor(n / 2) items and the right side gets the rest.
        mid = start + (stop - start) // 2
        return operator_(_combine(start, mid), _combine(mid, stop))

    return _combine(0, len(items))


def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpression:
unique_keys = df.select(join_cols).group_by(join_cols).aggregate([])
Expand All @@ -39,7 +77,7 @@ def create_match_filter(df: pyarrow_table, join_cols: list[str]) -> BooleanExpre
functools.reduce(operator.and_, [EqualTo(col, row[col]) for col in join_cols]) for row in unique_keys.to_pylist()
]

return AlwaysFalse() if len(filters) == 0 else functools.reduce(operator.or_, filters)
return AlwaysFalse() if len(filters) == 0 else build_balanced_tree(operator.or_, filters)
Comment thread
Fokko marked this conversation as resolved.
Outdated


def has_duplicate_rows(df: pyarrow_table, join_cols: list[str]) -> bool:
Expand Down
Loading