8787from pyiceberg .table .name_mapping import (
8888 NameMapping ,
8989)
90- from pyiceberg .table .refs import MAIN_BRANCH , SnapshotRef
90+ from pyiceberg .table .refs import SnapshotRef
9191from pyiceberg .table .snapshots import (
9292 Snapshot ,
9393 SnapshotLogEntry ,
@@ -398,7 +398,7 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
398398 expr = Or (expr , match_partition_expression )
399399 return expr
400400
401- def _append_snapshot_producer (self , snapshot_properties : Dict [str , str ], branch : str ) -> _FastAppendFiles :
401+ def _append_snapshot_producer (self , snapshot_properties : Dict [str , str ], branch : Optional [ str ] ) -> _FastAppendFiles :
402402 """Determine the append type based on table properties.
403403
404404 Args:
@@ -431,7 +431,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
431431 name_mapping = self .table_metadata .name_mapping (),
432432 )
433433
434- def update_snapshot (self , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : str = MAIN_BRANCH ) -> UpdateSnapshot :
434+ def update_snapshot (self , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : Optional [ str ] = None ) -> UpdateSnapshot :
435435 """Create a new UpdateSnapshot to produce a new snapshot for the table.
436436
437437 Returns:
@@ -448,7 +448,7 @@ def update_statistics(self) -> UpdateStatistics:
448448 """
449449 return UpdateStatistics (transaction = self )
450450
451- def append (self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : str = MAIN_BRANCH ) -> None :
451+ def append (self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : Optional [ str ] = None ) -> None :
452452 """
453453 Shorthand API for appending a PyArrow table to a table transaction.
454454
@@ -490,7 +490,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
490490 append_files .append_data_file (data_file )
491491
492492 def dynamic_partition_overwrite (
493- self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : str = MAIN_BRANCH
493+ self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : Optional [ str ] = None
494494 ) -> None :
495495 """
496496 Shorthand for overwriting existing partitions with a PyArrow table.
@@ -554,7 +554,7 @@ def overwrite(
554554 overwrite_filter : Union [BooleanExpression , str ] = ALWAYS_TRUE ,
555555 snapshot_properties : Dict [str , str ] = EMPTY_DICT ,
556556 case_sensitive : bool = True ,
557- branch : str = MAIN_BRANCH ,
557+ branch : Optional [ str ] = None ,
558558 ) -> None :
559559 """
560560 Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -617,7 +617,7 @@ def delete(
617617 delete_filter : Union [str , BooleanExpression ],
618618 snapshot_properties : Dict [str , str ] = EMPTY_DICT ,
619619 case_sensitive : bool = True ,
620- branch : str = MAIN_BRANCH ,
620+ branch : Optional [ str ] = None ,
621621 ) -> None :
622622 """
623623 Shorthand for deleting record from a table.
@@ -656,7 +656,10 @@ def delete(
656656 bound_delete_filter = bind (self .table_metadata .schema (), delete_filter , case_sensitive )
657657 preserve_row_filter = _expression_to_complementary_pyarrow (bound_delete_filter )
658658
659- files = self ._scan (row_filter = delete_filter , case_sensitive = case_sensitive ).use_ref (branch ).plan_files ()
659+ if branch is None :
660+ files = self ._scan (row_filter = delete_filter , case_sensitive = case_sensitive ).plan_files ()
661+ else :
662+ files = self ._scan (row_filter = delete_filter , case_sensitive = case_sensitive ).use_ref (branch ).plan_files ()
660663
661664 commit_uuid = uuid .uuid4 ()
662665 counter = itertools .count (0 )
@@ -717,6 +720,7 @@ def upsert(
717720 when_matched_update_all : bool = True ,
718721 when_not_matched_insert_all : bool = True ,
719722 case_sensitive : bool = True ,
723+ branch : Optional [str ] = None ,
720724 ) -> UpsertResult :
721725 """Shorthand API for performing an upsert to an iceberg table.
722726
@@ -727,6 +731,7 @@ def upsert(
727731 when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
728732 when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
729733 case_sensitive: Bool indicating if the match should be case-sensitive
734+ branch: Branch Reference to run the upsert operation
730735
731736 To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
732737
@@ -789,12 +794,24 @@ def upsert(
789794 matched_predicate = upsert_util .create_match_filter (df , join_cols )
790795
791796 # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
792- matched_iceberg_table = DataScan (
793- table_metadata = self .table_metadata ,
794- io = self ._table .io ,
795- row_filter = matched_predicate ,
796- case_sensitive = case_sensitive ,
797- ).to_arrow ()
797+ if branch is None :
798+ matched_iceberg_table = DataScan (
799+ table_metadata = self .table_metadata ,
800+ io = self ._table .io ,
801+ row_filter = matched_predicate ,
802+ case_sensitive = case_sensitive ,
803+ ).to_arrow ()
804+ else :
805+ matched_iceberg_table = (
806+ DataScan (
807+ table_metadata = self .table_metadata ,
808+ io = self ._table .io ,
809+ row_filter = matched_predicate ,
810+ case_sensitive = case_sensitive ,
811+ )
812+ .use_ref (branch )
813+ .to_arrow ()
814+ )
798815
799816 update_row_cnt = 0
800817 insert_row_cnt = 0
@@ -811,7 +828,7 @@ def upsert(
811828 # build the match predicate filter
812829 overwrite_mask_predicate = upsert_util .create_match_filter (rows_to_update , join_cols )
813830
814- self .overwrite (rows_to_update , overwrite_filter = overwrite_mask_predicate )
831+ self .overwrite (rows_to_update , overwrite_filter = overwrite_mask_predicate , branch = branch )
815832
816833 if when_not_matched_insert_all :
817834 expr_match = upsert_util .create_match_filter (matched_iceberg_table , join_cols )
@@ -822,7 +839,7 @@ def upsert(
822839 insert_row_cnt = len (rows_to_insert )
823840
824841 if insert_row_cnt > 0 :
825- self .append (rows_to_insert )
842+ self .append (rows_to_insert , branch = branch )
826843
827844 return UpsertResult (rows_updated = update_row_cnt , rows_inserted = insert_row_cnt )
828845
@@ -1255,6 +1272,7 @@ def upsert(
12551272 when_matched_update_all : bool = True ,
12561273 when_not_matched_insert_all : bool = True ,
12571274 case_sensitive : bool = True ,
1275+ branch : Optional [str ] = None ,
12581276 ) -> UpsertResult :
12591277 """Shorthand API for performing an upsert to an iceberg table.
12601278
@@ -1265,6 +1283,7 @@ def upsert(
12651283 when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
12661284 when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
12671285 case_sensitive: Bool indicating if the match should be case-sensitive
1286+ branch: Branch Reference to run the upsert operation
12681287
12691288 To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
12701289
@@ -1297,9 +1316,10 @@ def upsert(
12971316 when_matched_update_all = when_matched_update_all ,
12981317 when_not_matched_insert_all = when_not_matched_insert_all ,
12991318 case_sensitive = case_sensitive ,
1319+ branch = branch ,
13001320 )
13011321
1302- def append (self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : str = MAIN_BRANCH ) -> None :
1322+ def append (self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : Optional [ str ] = None ) -> None :
13031323 """
13041324 Shorthand API for appending a PyArrow table to the table.
13051325
@@ -1312,7 +1332,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
13121332 tx .append (df = df , snapshot_properties = snapshot_properties , branch = branch )
13131333
13141334 def dynamic_partition_overwrite (
1315- self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : str = MAIN_BRANCH
1335+ self , df : pa .Table , snapshot_properties : Dict [str , str ] = EMPTY_DICT , branch : Optional [ str ] = None
13161336 ) -> None :
13171337 """Shorthand for dynamic overwriting the table with a PyArrow table.
13181338
@@ -1331,7 +1351,7 @@ def overwrite(
13311351 overwrite_filter : Union [BooleanExpression , str ] = ALWAYS_TRUE ,
13321352 snapshot_properties : Dict [str , str ] = EMPTY_DICT ,
13331353 case_sensitive : bool = True ,
1334- branch : str = MAIN_BRANCH ,
1354+ branch : Optional [ str ] = None ,
13351355 ) -> None :
13361356 """
13371357 Shorthand for overwriting the table with a PyArrow table.
@@ -1364,7 +1384,7 @@ def delete(
13641384 delete_filter : Union [BooleanExpression , str ] = ALWAYS_TRUE ,
13651385 snapshot_properties : Dict [str , str ] = EMPTY_DICT ,
13661386 case_sensitive : bool = True ,
1367- branch : str = MAIN_BRANCH ,
1387+ branch : Optional [ str ] = None ,
13681388 ) -> None :
13691389 """
13701390 Shorthand for deleting rows from the table.
0 commit comments