|
16 | 16 | # under the License. |
17 | 17 | from __future__ import annotations |
18 | 18 |
|
| 19 | +import functools |
19 | 20 | import itertools |
| 21 | +import operator |
20 | 22 | import os |
21 | 23 | import uuid |
22 | 24 | import warnings |
@@ -774,39 +776,59 @@ def upsert( |
774 | 776 | matched_predicate = upsert_util.create_match_filter(df, join_cols) |
775 | 777 |
|
776 | 778 | # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. |
777 | | - matched_iceberg_table = DataScan( |
| 779 | + matched_iceberg_record_batches = DataScan( |
778 | 780 | table_metadata=self.table_metadata, |
779 | 781 | io=self._table.io, |
780 | 782 | row_filter=matched_predicate, |
781 | 783 | case_sensitive=case_sensitive, |
782 | | - ).to_arrow() |
| 784 | + ).to_arrow_batch_reader() |
783 | 785 |
|
784 | | - update_row_cnt = 0 |
785 | | - insert_row_cnt = 0 |
| 786 | + batches_to_overwrite = [] |
| 787 | + overwrite_predicates = [] |
| 788 | + insert_filters = [] |
786 | 789 |
|
787 | | - if when_matched_update_all: |
788 | | - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed |
789 | | - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed |
790 | | - # this extra step avoids unnecessary IO and writes |
791 | | - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) |
| 790 | + for batch in matched_iceberg_record_batches: |
| 791 | + rows = pa.Table.from_batches([batch]) |
792 | 792 |
|
793 | | - update_row_cnt = len(rows_to_update) |
| 793 | + if when_matched_update_all: |
| 794 | + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed |
| 795 | + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed |
| 796 | + # this extra step avoids unnecessary IO and writes |
| 797 | + rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols) |
| 798 | + |
| 799 | + if len(rows_to_update) > 0: |
| 800 | + # build the match predicate filter |
| 801 | + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) |
794 | 802 |
|
795 | | - if len(rows_to_update) > 0: |
796 | | - # build the match predicate filter |
797 | | - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) |
| 803 | + batches_to_overwrite.append(rows_to_update) |
| 804 | + overwrite_predicates.append(overwrite_mask_predicate) |
798 | 805 |
|
799 | | - self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) |
| 806 | + if when_not_matched_insert_all: |
| 807 | + expr_match = upsert_util.create_match_filter(rows, join_cols) |
| 808 | + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) |
| 809 | + expr_match_arrow = expression_to_pyarrow(expr_match_bound) |
| 810 | + |
| 811 | + insert_filters.append(~expr_match_arrow) |
| 812 | + |
| 813 | + update_row_cnt = 0 |
| 814 | + insert_row_cnt = 0 |
| 815 | + |
| 816 | + if batches_to_overwrite: |
| 817 | + rows_to_update = pa.concat_tables(batches_to_overwrite) |
| 818 | + update_row_cnt = len(rows_to_update) |
| 819 | + self.overwrite( |
| 820 | + rows_to_update, |
| 821 | + overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], |
| 822 | + ) |
800 | 823 |
|
801 | 824 | if when_not_matched_insert_all: |
802 | | - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) |
803 | | - expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) |
804 | | - expr_match_arrow = expression_to_pyarrow(expr_match_bound) |
805 | | - rows_to_insert = df.filter(~expr_match_arrow) |
| 825 | + if insert_filters: |
| 826 | + rows_to_insert = df.filter(functools.reduce(operator.and_, insert_filters)) |
| 827 | + else: |
| 828 | + rows_to_insert = df |
806 | 829 |
|
807 | 830 | insert_row_cnt = len(rows_to_insert) |
808 | | - |
809 | | - if insert_row_cnt > 0: |
| 831 | + if rows_to_insert: |
810 | 832 | self.append(rows_to_insert) |
811 | 833 |
|
812 | 834 | return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) |
|
0 commit comments