Skip to content

Commit 49f75b4

Browse files
committed
Fixed bug for empty tables
1 parent e4463df commit 49f75b4

3 files changed

Lines changed: 73 additions & 40 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from pyiceberg.table.name_mapping import (
8888
NameMapping,
8989
)
90-
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
90+
from pyiceberg.table.refs import SnapshotRef
9191
from pyiceberg.table.snapshots import (
9292
Snapshot,
9393
SnapshotLogEntry,
@@ -398,7 +398,7 @@ def _build_partition_predicate(self, partition_records: Set[Record]) -> BooleanE
398398
expr = Or(expr, match_partition_expression)
399399
return expr
400400

401-
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: str) -> _FastAppendFiles:
401+
def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles:
402402
"""Determine the append type based on table properties.
403403
404404
Args:
@@ -431,7 +431,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
431431
name_mapping=self.table_metadata.name_mapping(),
432432
)
433433

434-
def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> UpdateSnapshot:
434+
def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
435435
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
436436
437437
Returns:
@@ -448,7 +448,7 @@ def update_statistics(self) -> UpdateStatistics:
448448
"""
449449
return UpdateStatistics(transaction=self)
450450

451-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
451+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
452452
"""
453453
Shorthand API for appending a PyArrow table to a table transaction.
454454
@@ -490,7 +490,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
490490
append_files.append_data_file(data_file)
491491

492492
def dynamic_partition_overwrite(
493-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH
493+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
494494
) -> None:
495495
"""
496496
Shorthand for overwriting existing partitions with a PyArrow table.
@@ -554,7 +554,7 @@ def overwrite(
554554
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
555555
snapshot_properties: Dict[str, str] = EMPTY_DICT,
556556
case_sensitive: bool = True,
557-
branch: str = MAIN_BRANCH,
557+
branch: Optional[str] = None,
558558
) -> None:
559559
"""
560560
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
@@ -617,7 +617,7 @@ def delete(
617617
delete_filter: Union[str, BooleanExpression],
618618
snapshot_properties: Dict[str, str] = EMPTY_DICT,
619619
case_sensitive: bool = True,
620-
branch: str = MAIN_BRANCH,
620+
branch: Optional[str] = None,
621621
) -> None:
622622
"""
623623
Shorthand for deleting record from a table.
@@ -656,7 +656,10 @@ def delete(
656656
bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive)
657657
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)
658658

659-
files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).use_ref(branch).plan_files()
659+
if branch is None:
660+
files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files()
661+
else:
662+
files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).use_ref(branch).plan_files()
660663

661664
commit_uuid = uuid.uuid4()
662665
counter = itertools.count(0)
@@ -717,6 +720,7 @@ def upsert(
717720
when_matched_update_all: bool = True,
718721
when_not_matched_insert_all: bool = True,
719722
case_sensitive: bool = True,
723+
branch: Optional[str] = None,
720724
) -> UpsertResult:
721725
"""Shorthand API for performing an upsert to an iceberg table.
722726
@@ -727,6 +731,7 @@ def upsert(
727731
when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
728732
when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
729733
case_sensitive: Bool indicating if the match should be case-sensitive
734+
branch: Branch Reference to run the upsert operation
730735
731736
To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
732737
@@ -789,12 +794,24 @@ def upsert(
789794
matched_predicate = upsert_util.create_match_filter(df, join_cols)
790795

791796
# We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
792-
matched_iceberg_table = DataScan(
793-
table_metadata=self.table_metadata,
794-
io=self._table.io,
795-
row_filter=matched_predicate,
796-
case_sensitive=case_sensitive,
797-
).to_arrow()
797+
if branch is None:
798+
matched_iceberg_table = DataScan(
799+
table_metadata=self.table_metadata,
800+
io=self._table.io,
801+
row_filter=matched_predicate,
802+
case_sensitive=case_sensitive,
803+
).to_arrow()
804+
else:
805+
matched_iceberg_table = (
806+
DataScan(
807+
table_metadata=self.table_metadata,
808+
io=self._table.io,
809+
row_filter=matched_predicate,
810+
case_sensitive=case_sensitive,
811+
)
812+
.use_ref(branch)
813+
.to_arrow()
814+
)
798815

799816
update_row_cnt = 0
800817
insert_row_cnt = 0
@@ -811,7 +828,7 @@ def upsert(
811828
# build the match predicate filter
812829
overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
813830

814-
self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
831+
self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate, branch=branch)
815832

816833
if when_not_matched_insert_all:
817834
expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
@@ -822,7 +839,7 @@ def upsert(
822839
insert_row_cnt = len(rows_to_insert)
823840

824841
if insert_row_cnt > 0:
825-
self.append(rows_to_insert)
842+
self.append(rows_to_insert, branch=branch)
826843

827844
return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
828845

@@ -1255,6 +1272,7 @@ def upsert(
12551272
when_matched_update_all: bool = True,
12561273
when_not_matched_insert_all: bool = True,
12571274
case_sensitive: bool = True,
1275+
branch: Optional[str] = None,
12581276
) -> UpsertResult:
12591277
"""Shorthand API for performing an upsert to an iceberg table.
12601278
@@ -1265,6 +1283,7 @@ def upsert(
12651283
when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing
12661284
when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table
12671285
case_sensitive: Bool indicating if the match should be case-sensitive
1286+
branch: Branch Reference to run the upsert operation
12681287
12691288
To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids
12701289
@@ -1297,9 +1316,10 @@ def upsert(
12971316
when_matched_update_all=when_matched_update_all,
12981317
when_not_matched_insert_all=when_not_matched_insert_all,
12991318
case_sensitive=case_sensitive,
1319+
branch=branch,
13001320
)
13011321

1302-
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH) -> None:
1322+
def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None:
13031323
"""
13041324
Shorthand API for appending a PyArrow table to the table.
13051325
@@ -1312,7 +1332,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT,
13121332
tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch)
13131333

13141334
def dynamic_partition_overwrite(
1315-
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH
1335+
self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None
13161336
) -> None:
13171337
"""Shorthand for dynamic overwriting the table with a PyArrow table.
13181338
@@ -1331,7 +1351,7 @@ def overwrite(
13311351
overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
13321352
snapshot_properties: Dict[str, str] = EMPTY_DICT,
13331353
case_sensitive: bool = True,
1334-
branch: str = MAIN_BRANCH,
1354+
branch: Optional[str] = None,
13351355
) -> None:
13361356
"""
13371357
Shorthand for overwriting the table with a PyArrow table.
@@ -1364,7 +1384,7 @@ def delete(
13641384
delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE,
13651385
snapshot_properties: Dict[str, str] = EMPTY_DICT,
13661386
case_sensitive: bool = True,
1367-
branch: str = MAIN_BRANCH,
1387+
branch: Optional[str] = None,
13681388
) -> None:
13691389
"""
13701390
Shorthand for deleting rows from the table.

pyiceberg/table/update/snapshot.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,30 +105,39 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
105105
_added_data_files: List[DataFile]
106106
_manifest_num_counter: itertools.count[int]
107107
_deleted_data_files: Set[DataFile]
108-
_branch: str
109108

110109
def __init__(
111110
self,
112111
operation: Operation,
113112
transaction: Transaction,
114113
io: FileIO,
115-
branch: str,
116114
commit_uuid: Optional[uuid.UUID] = None,
117115
snapshot_properties: Dict[str, str] = EMPTY_DICT,
116+
branch: str = MAIN_BRANCH,
118117
) -> None:
119118
super().__init__(transaction)
120119
self.commit_uuid = commit_uuid or uuid.uuid4()
121120
self._io = io
122121
self._operation = operation
123122
self._snapshot_id = self._transaction.table_metadata.new_snapshot_id()
124-
self._branch = branch
125-
self._parent_snapshot_id = (
126-
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._branch)) else None
127-
)
128123
self._added_data_files = []
129124
self._deleted_data_files = set()
130125
self.snapshot_properties = snapshot_properties
131126
self._manifest_num_counter = itertools.count(0)
127+
self._set_target_branch(branch=branch)
128+
self._parent_snapshot_id = (
129+
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
130+
)
131+
132+
def _set_target_branch(self, branch: str) -> None:
133+
# Default is already set to MAIN_BRANCH. So branch name can't be None.
134+
assert branch is not None, ValueError("Invalid branch name: null")
135+
if branch in self._transaction.table_metadata.refs:
136+
ref = self._transaction.table_metadata.refs[branch]
137+
assert ref.snapshot_ref_type == SnapshotRefType.BRANCH, ValueError(
138+
f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots"
139+
)
140+
self._target_branch = branch
132141

133142
def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
134143
self._added_data_files.append(data_file)
@@ -276,16 +285,16 @@ def _commit(self) -> UpdatesAndRequirements:
276285
SetSnapshotRefUpdate(
277286
snapshot_id=self._snapshot_id,
278287
parent_snapshot_id=self._parent_snapshot_id,
279-
ref_name=self._branch,
288+
ref_name=self._target_branch,
280289
type=SnapshotRefType.BRANCH,
281290
),
282291
),
283292
(
284293
AssertRefSnapshotId(
285-
snapshot_id=self._transaction.table_metadata.refs[self._branch].snapshot_id
286-
if self._branch in self._transaction.table_metadata.refs
294+
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
295+
if self._target_branch in self._transaction.table_metadata.refs
287296
else self._transaction.table_metadata.current_snapshot_id,
288-
ref=self._branch,
297+
ref=self._target_branch,
289298
),
290299
),
291300
)
@@ -338,7 +347,7 @@ def __init__(
338347
commit_uuid: Optional[uuid.UUID] = None,
339348
snapshot_properties: Dict[str, str] = EMPTY_DICT,
340349
):
341-
super().__init__(operation, transaction, io, branch, commit_uuid, snapshot_properties)
350+
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
342351
self._predicate = AlwaysFalse()
343352
self._case_sensitive = True
344353

@@ -503,7 +512,7 @@ def __init__(
503512
) -> None:
504513
from pyiceberg.table import TableProperties
505514

506-
super().__init__(operation, transaction, io, branch, commit_uuid, snapshot_properties)
515+
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
507516
self._target_size_bytes = property_as_int(
508517
self._transaction.table_metadata.properties,
509518
TableProperties.MANIFEST_TARGET_SIZE_BYTES,
@@ -549,7 +558,7 @@ def _existing_manifests(self) -> List[ManifestFile]:
549558
"""Determine if there are any existing manifest files."""
550559
existing_files = []
551560

552-
if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._branch):
561+
if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._target_branch):
553562
for manifest_file in snapshot.manifests(io=self._io):
554563
entries = manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True)
555564
found_deleted_data_files = [entry.data_file for entry in entries if entry.data_file in self._deleted_data_files]
@@ -623,12 +632,16 @@ class UpdateSnapshot:
623632
_snapshot_properties: Dict[str, str]
624633

625634
def __init__(
626-
self, transaction: Transaction, io: FileIO, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: str = MAIN_BRANCH
635+
self,
636+
transaction: Transaction,
637+
io: FileIO,
638+
snapshot_properties: Dict[str, str] = EMPTY_DICT,
639+
branch: Optional[str] = MAIN_BRANCH,
627640
) -> None:
628641
self._transaction = transaction
629642
self._io = io
630643
self._snapshot_properties = snapshot_properties
631-
self._branch = branch
644+
self._branch = branch if branch is not None else MAIN_BRANCH
632645

633646
def fast_append(self) -> _FastAppendFiles:
634647
return _FastAppendFiles(

pyiceberg/utils/concurrent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525
class ExecutorFactory:
2626
_instance: Optional[Executor] = None
2727

28+
@staticmethod
29+
def max_workers() -> Optional[int]:
30+
"""Return the max number of workers configured."""
31+
return Config().get_int("max-workers")
32+
2833
@staticmethod
2934
def get_or_create() -> Executor:
3035
"""Return the same executor in each call."""
@@ -33,8 +38,3 @@ def get_or_create() -> Executor:
3338
ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers)
3439

3540
return ExecutorFactory._instance
36-
37-
@staticmethod
38-
def max_workers() -> Optional[int]:
39-
"""Return the max number of workers configured."""
40-
return Config().get_int("max-workers")

0 commit comments

Comments
 (0)