Skip to content

Commit 85424d8

Browse files
committed
Replace left outer join + filter with inner join
1 parent e9e9485 commit 85424d8

1 file changed

Lines changed: 7 additions & 15 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
7676
# 1. Cannot do a join when non-join columns have complex types
7777
# 2. Cannot compare columns with complex types
7878
# See: https://github.com/apache/arrow/issues/35785
79-
MARKER_COLUMN_NAME = "__from_target"
8079
SOURCE_INDEX_COLUMN_NAME = "__source_index"
8180
TARGET_INDEX_COLUMN_NAME = "__target_index"
8281

83-
if MARKER_COLUMN_NAME in join_cols or SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols:
82+
if SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols:
8483
raise ValueError(
85-
f"{MARKER_COLUMN_NAME}, {SOURCE_INDEX_COLUMN_NAME} and {TARGET_INDEX_COLUMN_NAME} are reserved for joining "
84+
f"{SOURCE_INDEX_COLUMN_NAME} and {TARGET_INDEX_COLUMN_NAME} are reserved for joining "
8685
f"DataFrames, and cannot be used as column names"
8786
) from None
8887

@@ -96,19 +95,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
9695
)
9796

9897
# Step 2: Prepare target index with join keys and a row-index column
99-
target_index = (
100-
target_table.select(join_cols_set)
101-
.append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
102-
.append_column(MARKER_COLUMN_NAME, pa.repeat(True, len(target_table)))
103-
)
104-
105-
# Step 3: Perform a left outer join to find which rows from source exist in target
106-
joined = source_index.join(target_index, keys=list(join_cols_set), join_type="left outer")
98+
target_index = target_table.select(join_cols_set).append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
10799

108-
# Step 4: Create indices for rows that do exist in the target i.e., where marker column is true after the join
109-
matching_indices = joined.filter(pc.field(MARKER_COLUMN_NAME))
100+
# Step 3: Perform an inner join to find which rows from source exist in target
101+
matching_indices = source_index.join(target_index, keys=list(join_cols_set), join_type="inner")
110102

111-
# Step 5: Compare all rows using Python
103+
# Step 4: Compare all rows using Python
112104
to_update_indices = []
113105
for source_idx, target_idx in zip(
114106
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
@@ -123,7 +115,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
123115
to_update_indices.append(source_idx)
124116
break
125117

126-
# Step 6: Take rows from source table using the indices and cast to target schema
118+
# Step 5: Take rows from source table using the indices and cast to target schema
127119
if to_update_indices:
128120
return source_table.take(to_update_indices)
129121
else:

0 commit comments

Comments
 (0)