@@ -98,17 +98,23 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
9898 # When we are not able to compare (e.g. due to unsupported types),
9999 # fall back to selecting only rows in the source table that do NOT already exist in the target.
100100 # See: https://github.com/apache/arrow/issues/35785
101-
102101 MARKER_COLUMN_NAME = "__from_target"
102+ INDEX_COLUMN_NAME = "__source_index"
103103
104- if MARKER_COLUMN_NAME in join_cols_set :
104+ if MARKER_COLUMN_NAME in join_cols_set or INDEX_COLUMN_NAME in join_cols_set :
105105 raise ValueError (
106- f"{ MARKER_COLUMN_NAME } is used for joining " f"DataFrames, and cannot be used as column name"
106+ f"{ MARKER_COLUMN_NAME } and { INDEX_COLUMN_NAME } are reserved for joining "
107+ f"DataFrames, and cannot be used as column names"
107108 ) from None
108109
109- # Step 1: Prepare source index with join keys and a marker
110+ # Step 1: Prepare source index with join keys and a marker index
110111 # Cast to target table schema, so we can do the join
111- source_index = source_table .cast (target_table .schema ).select (join_cols_set )
112+ # See: https://github.com/apache/arrow/issues/37542
113+ source_index = (
114+ source_table .cast (target_table .schema )
115+ .select (join_cols_set )
116+ .append_column (INDEX_COLUMN_NAME , pa .array (range (len (source_table ))))
117+ )
112118
113119 # Step 2: Prepare target index with join keys and a marker
114120 target_index = target_table .select (join_cols_set ).append_column (
@@ -118,10 +124,14 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
118124 # Step 3: Perform a left outer join to find which rows from source exist in target
119125 joined = source_index .join (target_index , keys = list (join_cols_set ), join_type = "left outer" )
120126
121- # Step 4: Create a boolean mask for rows that do NOT exist in the target
122- # i.e., where market column is null after the join
127+ # Step 4: Restore original source order
128+ joined = joined .sort_by (INDEX_COLUMN_NAME )
129+
130+ # Step 5: Create a boolean mask for rows that do exist in the target
131+ # i.e., where the marker column is not null after the join
123132 to_update_mask = pc .invert (pc .is_null (joined [MARKER_COLUMN_NAME ]))
124133
125- # Step 5: Filter source table using the mask (keep only rows that should be updated),
126- # and cast to the target schema to ensure compatibility (e.g. large_string → string)
127- return source_table .filter (to_update_mask )
134+ # Step 6: Filter source table using the mask (keep only rows that also exist in the target)
135+ filtered = source_table .filter (to_update_mask )
136+
137+ return filtered
0 commit comments