Skip to content

Commit 79f6181

Browse files
committed
Do row comparison in Python
1 parent 4e75ce1 commit 79f6181

2 files changed

Lines changed: 69 additions & 41 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
6060
The table is joined on the identifier columns, and then checked if there are any updated rows.
6161
Those are selected and everything is renamed correctly.
6262
"""
63+
all_columns = set(source_table.column_names)
6364
join_cols_set = set(join_cols)
6465

66+
non_key_cols = list(all_columns - join_cols_set)
67+
6568
if has_duplicate_rows(target_table, join_cols):
6669
raise ValueError("Target table has duplicate rows, aborting upsert")
6770

@@ -73,11 +76,12 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
7376
# fall back to selecting only rows in the source table that do NOT already exist in the target.
7477
# See: https://github.com/apache/arrow/issues/35785
7578
MARKER_COLUMN_NAME = "__from_target"
76-
INDEX_COLUMN_NAME = "__source_index"
79+
SOURCE_INDEX_COLUMN_NAME = "__source_index"
80+
TARGET_INDEX_COLUMN_NAME = "__target_index"
7781

78-
if MARKER_COLUMN_NAME in join_cols or INDEX_COLUMN_NAME in join_cols:
82+
if MARKER_COLUMN_NAME in join_cols or SOURCE_INDEX_COLUMN_NAME in join_cols or TARGET_INDEX_COLUMN_NAME in join_cols:
7983
raise ValueError(
80-
f"{MARKER_COLUMN_NAME} and {INDEX_COLUMN_NAME} are reserved for joining "
84+
f"{MARKER_COLUMN_NAME}, {SOURCE_INDEX_COLUMN_NAME} and {TARGET_INDEX_COLUMN_NAME} are reserved for joining "
8185
f"DataFrames, and cannot be used as column names"
8286
) from None
8387

@@ -87,17 +91,39 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
8791
source_index = (
8892
source_table.cast(target_table.schema)
8993
.select(join_cols_set)
90-
.append_column(INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
94+
.append_column(SOURCE_INDEX_COLUMN_NAME, pa.array(range(len(source_table))))
9195
)
9296

9397
# Step 2: Prepare target index with join keys and a marker
94-
target_index = target_table.select(join_cols_set).append_column(MARKER_COLUMN_NAME, pa.repeat(True, len(target_table)))
98+
target_index = (
99+
target_table.select(join_cols_set)
100+
.append_column(TARGET_INDEX_COLUMN_NAME, pa.array(range(len(target_table))))
101+
.append_column(MARKER_COLUMN_NAME, pa.repeat(True, len(target_table)))
102+
)
95103

96104
# Step 3: Perform a left outer join to find which rows from source exist in target
97105
joined = source_index.join(target_index, keys=list(join_cols_set), join_type="left outer")
98106

99107
# Step 4: Create indices for rows that do exist in the target i.e., where marker column is true after the join
100-
to_update_indices = joined.filter(pc.field(MARKER_COLUMN_NAME))[INDEX_COLUMN_NAME]
101-
102-
# Step 5: Take rows from source table using the indices and cast to target schema
103-
return source_table.take(to_update_indices)
108+
matching_indices = joined.filter(pc.field(MARKER_COLUMN_NAME))
109+
110+
# Step 5: Compare all rows using Python
111+
to_update_indices = []
112+
for source_idx, target_idx in zip(
113+
matching_indices[SOURCE_INDEX_COLUMN_NAME].to_pylist(), matching_indices[TARGET_INDEX_COLUMN_NAME].to_pylist()
114+
):
115+
source_row = source_table.slice(source_idx, 1)
116+
target_row = target_table.slice(target_idx, 1)
117+
118+
for key in non_key_cols:
119+
source_val = source_row.column(key)[0].as_py()
120+
target_val = target_row.column(key)[0].as_py()
121+
if source_val != target_val:
122+
to_update_indices.append(source_idx)
123+
break
124+
125+
# Step 6: Take rows from source table using the indices and cast to target schema
126+
if to_update_indices:
127+
return source_table.take(to_update_indices)
128+
else:
129+
return source_table.schema.empty_table()

tests/table/test_upsert.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_merge_scenario_skip_upd_row(catalog: Catalog) -> None:
186186

187187
res = table.upsert(df=source_df, join_cols=["order_id"])
188188

189-
expected_updated = 2
189+
expected_updated = 1
190190
expected_inserted = 1
191191

192192
assert_upsert_result(res, expected_updated, expected_inserted)
@@ -222,7 +222,7 @@ def test_merge_scenario_date_as_key(catalog: Catalog) -> None:
222222

223223
res = table.upsert(df=source_df, join_cols=["order_date"])
224224

225-
expected_updated = 2
225+
expected_updated = 1
226226
expected_inserted = 1
227227

228228
assert_upsert_result(res, expected_updated, expected_inserted)
@@ -258,7 +258,7 @@ def test_merge_scenario_string_as_key(catalog: Catalog) -> None:
258258

259259
res = table.upsert(df=source_df, join_cols=["order_id"])
260260

261-
expected_updated = 2
261+
expected_updated = 1
262262
expected_inserted = 1
263263

264264
assert_upsert_result(res, expected_updated, expected_inserted)
@@ -371,25 +371,16 @@ def test_upsert_with_identifier_fields(catalog: Catalog) -> None:
371371

372372
expected_operations = [Operation.APPEND, Operation.OVERWRITE, Operation.APPEND, Operation.APPEND]
373373

374-
assert upd.rows_updated == 2
374+
assert upd.rows_updated == 1
375375
assert upd.rows_inserted == 1
376376

377377
assert [snap.summary.operation for snap in tbl.snapshots() if snap.summary is not None] == expected_operations
378378

379-
# This will update all 3 rows
379+
# This should be a no-op
380380
upd = tbl.upsert(df)
381381

382-
assert upd.rows_updated == 3
382+
assert upd.rows_updated == 0
383383
assert upd.rows_inserted == 0
384-
expected_operations = [
385-
Operation.APPEND,
386-
Operation.OVERWRITE,
387-
Operation.APPEND,
388-
Operation.APPEND,
389-
Operation.DELETE,
390-
Operation.OVERWRITE,
391-
Operation.APPEND,
392-
]
393384

394385
assert [snap.summary.operation for snap in tbl.snapshots() if snap.summary is not None] == expected_operations
395386

@@ -561,7 +552,7 @@ def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
561552
[
562553
{
563554
"id": 1,
564-
"nested_type": {"sub1": "1_sub1_init", "sub2": "1sub2_init"},
555+
"nested_type": {"sub1": "bla1", "sub2": "bla"},
565556
}
566557
],
567558
schema=arrow_schema,
@@ -572,32 +563,43 @@ def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
572563
[
573564
{
574565
"id": 2,
575-
"nested_type": {"sub1": "2_sub1_new", "sub2": "2_sub2_new"},
566+
"nested_type": {"sub1": "bla1", "sub2": "bla"},
576567
},
577568
{
578569
"id": 1,
579-
"nested_type": {"sub1": "1sub1_init", "sub2": "1sub2_new"},
570+
"nested_type": {"sub1": "bla1", "sub2": "bla2"},
580571
},
581-
# TODO: struct changes should cause _check_pyarrow_schema_compatible to fail. Introduce a new `sub3` attribute
582-
# {
583-
# "id": 1,
584-
# "nested_type": {"sub3": "1sub3_init", "sub2": "1sub2_new"},
585-
# },
586572
],
587573
schema=arrow_schema,
588574
)
589575

590-
upd = tbl.upsert(update_data, join_cols=["id"])
576+
res = tbl.upsert(update_data, join_cols=["id"])
591577

592-
# Row needs to be updated even tho it's not changed.
593-
# When pyarrow isn't able to compare rows, just update everything
594-
assert upd.rows_updated == 1
595-
assert upd.rows_inserted == 1
578+
expected_updated = 1
579+
expected_inserted = 1
596580

597-
assert tbl.scan().to_arrow().to_pylist() == [
598-
{"id": 2, "nested_type": {"sub1": "2_sub1_new", "sub2": "2_sub2_new"}},
599-
{"id": 1, "nested_type": {"sub1": "1sub1_init", "sub2": "1sub2_new"}},
600-
]
581+
assert_upsert_result(res, expected_updated, expected_inserted)
582+
583+
update_data = pa.Table.from_pylist(
584+
[
585+
{
586+
"id": 2,
587+
"nested_type": {"sub1": "bla1", "sub2": "bla"},
588+
},
589+
{
590+
"id": 1,
591+
"nested_type": {"sub1": "bla1", "sub2": "bla2"},
592+
},
593+
],
594+
schema=arrow_schema,
595+
)
596+
597+
res = tbl.upsert(update_data, join_cols=["id"])
598+
599+
expected_updated = 0
600+
expected_inserted = 0
601+
602+
assert_upsert_result(res, expected_updated, expected_inserted)
601603

602604

603605
def test_upsert_with_nulls(catalog: Catalog) -> None:

0 commit comments

Comments
 (0)