@@ -76,13 +76,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
7676 # 1. Cannot do a join when non-join columns have complex types
7777 # 2. Cannot compare columns with complex types
7878 # See: https://github.com/apache/arrow/issues/35785
79- MARKER_COLUMN_NAME = "__from_target"
8079 SOURCE_INDEX_COLUMN_NAME = "__source_index"
8180 TARGET_INDEX_COLUMN_NAME = "__target_index"
8281
83- if MARKER_COLUMN_NAME in join_cols or SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols :
82+ if SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols :
8483 raise ValueError (
85- f"{ MARKER_COLUMN_NAME } , { SOURCE_INDEX_COLUMN_NAME } and { TARGET_INDEX_COLUMN_NAME } are reserved for joining "
84+ f"{ SOURCE_INDEX_COLUMN_NAME } and { TARGET_INDEX_COLUMN_NAME } are reserved for joining "
8685 f"DataFrames, and cannot be used as column names"
8786 ) from None
8887
@@ -96,19 +95,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
9695 )
9796
9897 # Step 2: Prepare target index with join keys and a marker
99- target_index = (
100- target_table .select (join_cols_set )
101- .append_column (TARGET_INDEX_COLUMN_NAME , pa .array (range (len (target_table ))))
102- .append_column (MARKER_COLUMN_NAME , pa .repeat (True , len (target_table )))
103- )
104-
105- # Step 3: Perform a left outer join to find which rows from source exist in target
106- joined = source_index .join (target_index , keys = list (join_cols_set ), join_type = "left outer" )
98+ target_index = target_table .select (join_cols_set ).append_column (TARGET_INDEX_COLUMN_NAME , pa .array (range (len (target_table ))))
10799
108- # Step 4: Create indices for rows that do exist in the target i.e., where marker column is true after the join
109- matching_indices = joined . filter ( pc . field ( MARKER_COLUMN_NAME ) )
100+ # Step 3: Perform an inner join to find which rows from source exist in target
101+ matching_indices = source_index . join ( target_index , keys = list ( join_cols_set ), join_type = "inner" )
110102
111- # Step 5 : Compare all rows using Python
103+ # Step 4 : Compare all rows using Python
112104 to_update_indices = []
113105 for source_idx , target_idx in zip (
114106 matching_indices [SOURCE_INDEX_COLUMN_NAME ].to_pylist (), matching_indices [TARGET_INDEX_COLUMN_NAME ].to_pylist ()
@@ -123,7 +115,7 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
123115 to_update_indices .append (source_idx )
124116 break
125117
126- # Step 6 : Take rows from source table using the indices and cast to target schema
118+ # Step 5 : Take rows from source table using the indices and cast to target schema
127119 if to_update_indices :
128120 return source_table .take (to_update_indices )
129121 else :
0 commit comments