Skip to content

Commit f16f8b3

Browse files
committed
Fallback for upsert when Arrow cannot compare source rows with target rows
1 parent 7a6a7c8 commit f16f8b3

2 files changed

Lines changed: 82 additions & 12 deletions

File tree

pyiceberg/table/upsert_util.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,18 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
6767

6868
diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols])
6969

70-
return (
71-
source_table
72-
# We already know that the schema is compatible, this is to fix large_ types
73-
.cast(target_table.schema)
74-
.join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
75-
.filter(diff_expr)
76-
.drop_columns([f"{col}-rhs" for col in non_key_cols])
77-
.rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
78-
# Finally cast to the original schema since it doesn't carry nullability:
79-
# https://github.com/apache/arrow/issues/45557
80-
).cast(target_table.schema)
70+
try:
71+
return (
72+
source_table
73+
# We already know that the schema is compatible, this is to fix large_ types
74+
.cast(target_table.schema)
75+
.join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
76+
.filter(diff_expr)
77+
.drop_columns([f"{col}-rhs" for col in non_key_cols])
78+
.rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
79+
# Finally cast to the original schema since it doesn't carry nullability:
80+
# https://github.com/apache/arrow/issues/45557
81+
).cast(target_table.schema)
82+
except pa.ArrowInvalid:
83+
# When we are not able to compare, just update all rows from source table
84+
return source_table.cast(target_table.schema)

tests/table/test_upsert.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pyiceberg.table import UpsertResult
3131
from pyiceberg.table.snapshots import Operation
3232
from pyiceberg.table.upsert_util import create_match_filter
33-
from pyiceberg.types import IntegerType, NestedField, StringType
33+
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
3434
from tests.catalog.test_base import InMemoryCatalog, Table
3535

3636

@@ -509,3 +509,69 @@ def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
509509
ValueError, match="Join columns could not be found, please set identifier-field-ids or pass in explicitly."
510510
):
511511
tbl.upsert(df)
512+
513+
514+
def test_upsert_struct_field_fails_in_join(catalog: Catalog) -> None:
    """Upsert falls back to updating every matched row when rows cannot be compared.

    The source and target tables contain a struct column, which pyarrow cannot
    use in the join-based row diff (it raises ``ArrowInvalid``). The upsert is
    expected to fall back to treating all matched source rows as updates rather
    than raising — so even an unchanged row counts as updated.
    """
    identifier = "default.test_upsert_struct_field_fails"
    # Make the test re-runnable against a shared catalog.
    _drop_table(catalog, identifier)

    # Iceberg schema: an identifier column plus a struct column that the
    # join-based comparison cannot handle.
    schema = Schema(
        NestedField(1, "id", IntegerType(), required=True),
        NestedField(
            2,
            "nested_type",
            # Struct<sub1: string, sub2: string>
            StructType(
                NestedField(3, "sub1", StringType(), required=True),
                NestedField(4, "sub2", StringType(), required=True),
            ),
            required=False,
        ),
        identifier_field_ids=[1],
    )

    tbl = catalog.create_table(identifier, schema=schema)

    # Arrow schema mirroring the Iceberg schema; large_string is used so the
    # upsert path also exercises the cast from large_ types.
    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int32(), nullable=False),
            pa.field(
                "nested_type",
                pa.struct(
                    [
                        pa.field("sub1", pa.large_string(), nullable=False),
                        pa.field("sub2", pa.large_string(), nullable=False),
                    ]
                ),
                nullable=True,
            ),
        ]
    )

    initial_data = pa.Table.from_pylist(
        [
            {
                "id": 1,
                "nested_type": {"sub1": "bla1", "sub2": "bla"},
            }
        ],
        schema=arrow_schema,
    )
    tbl.append(initial_data)

    # Identical row to the one already in the table: no column actually changed.
    update_data = pa.Table.from_pylist(
        [
            {
                "id": 1,
                "nested_type": {"sub1": "bla1", "sub2": "bla"},
            }
        ],
        schema=arrow_schema,
    )

    upd = tbl.upsert(update_data, join_cols=["id"])

    # The row needs to be counted as updated even though it's not changed:
    # when pyarrow isn't able to compare rows, everything matched is updated.
    assert upd.rows_updated == 1
    assert upd.rows_inserted == 0

0 commit comments

Comments
 (0)