Skip to content

Commit 8e32e9c

Browse files
committed
Preserve order in get_rows_to_update for complex types
1 parent a699602 commit 8e32e9c

2 files changed

Lines changed: 21 additions & 11 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,23 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
9898
# When we are not able to compare (e.g. due to unsupported types),
9999
# fall back to selecting only rows in the source table that do NOT already exist in the target.
100100
# See: https://github.com/apache/arrow/issues/35785
101-
102101
MARKER_COLUMN_NAME = "__from_target"
102+
INDEX_COLUMN_NAME = "__source_index"
103103

104-
if MARKER_COLUMN_NAME in join_cols_set:
104+
if MARKER_COLUMN_NAME in join_cols_set or INDEX_COLUMN_NAME in join_cols_set:
105105
raise ValueError(
106-
f"{MARKER_COLUMN_NAME} is used for joining " f"DataFrames, and cannot be used as column name"
106+
f"{MARKER_COLUMN_NAME} and {INDEX_COLUMN_NAME} are reserved for joining "
107+
f"DataFrames, and cannot be used as column names"
107108
) from None
108109

109-
# Step 1: Prepare source index with join keys and a marker
110+
# Step 1: Prepare source index with join keys and a row-order index
110111
# Cast to target table schema, so we can do the join
111-
source_index = source_table.cast(target_table.schema).select(join_cols_set)
112+
# See: https://github.com/apache/arrow/issues/37542
113+
source_index = (
114+
source_table.cast(target_table.schema)
115+
.select(join_cols_set)
116+
.append_column(INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
117+
)
112118

113119
# Step 2: Prepare target index with join keys and a marker
114120
target_index = target_table.select(join_cols_set).append_column(
@@ -118,10 +124,14 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
118124
# Step 3: Perform a left outer join to find which rows from source exist in target
119125
joined = source_index.join(target_index, keys=list(join_cols_set), join_type="left outer")
120126

121-
# Step 4: Create a boolean mask for rows that do NOT exist in the target
122-
# i.e., where marker column is null after the join
127+
# Step 4: Restore original source order
128+
joined = joined.sort_by(INDEX_COLUMN_NAME)
129+
130+
# Step 5: Create a boolean mask for rows that do exist in the target
131+
# i.e., where the marker column is not null after the join
123132
to_update_mask = pc.invert(pc.is_null(joined[MARKER_COLUMN_NAME]))
124133

125-
# Step 5: Filter source table using the mask (keep only rows that should be updated),
126-
# and cast to the target schema to ensure compatibility (e.g. large_string → string)
127-
return source_table.filter(to_update_mask)
134+
# Step 6: Filter source table using the mask (keep only rows that should be updated)
135+
filtered = source_table.filter(to_update_mask)
136+
137+
return filtered

tests/table/test_upsert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
568568
{
569569
"id": 1,
570570
"nested_type": {"sub1": "bla1", "sub2": "bla"},
571-
}
571+
},
572572
],
573573
schema=arrow_schema,
574574
)

0 commit comments

Comments
 (0)