Skip to content

Commit 958aac4

Browse files
committed
changed design context for branch writes
1 parent 4ed0607 commit 958aac4

2 files changed

Lines changed: 14 additions & 10 deletions

File tree

pyiceberg/table/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from pyiceberg.table.name_mapping import (
8888
NameMapping,
8989
)
90-
from pyiceberg.table.refs import SnapshotRef
90+
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
9191
from pyiceberg.table.snapshots import (
9292
Snapshot,
9393
SnapshotLogEntry,
@@ -437,6 +437,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran
437437
Returns:
438438
A new UpdateSnapshot
439439
"""
440+
if branch is None:
441+
branch = MAIN_BRANCH
442+
440443
return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
441444

442445
def update_statistics(self) -> UpdateStatistics:

pyiceberg/table/update/snapshot.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
105105
_added_data_files: List[DataFile]
106106
_manifest_num_counter: itertools.count[int]
107107
_deleted_data_files: Set[DataFile]
108+
_target_branch = MAIN_BRANCH
108109

109110
def __init__(
110111
self,
@@ -124,20 +125,20 @@ def __init__(
124125
self._deleted_data_files = set()
125126
self.snapshot_properties = snapshot_properties
126127
self._manifest_num_counter = itertools.count(0)
127-
self._set_target_branch(branch=branch)
128+
self._target_branch = self._validate_target_branch(branch=branch)
128129
self._parent_snapshot_id = (
129130
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
130131
)
131132

132-
def _set_target_branch(self, branch: str) -> None:
133+
def _validate_target_branch(self, branch: str) -> str:
133134
# Default is already set to MAIN_BRANCH. So branch name can't be None.
134-
assert branch is not None, "Invalid branch name: null"
135+
if branch is None:
136+
raise ValueError("Invalid branch name: null")
135137
if branch in self._transaction.table_metadata.refs:
136138
ref = self._transaction.table_metadata.refs[branch]
137-
assert (
138-
ref.snapshot_ref_type == SnapshotRefType.BRANCH
139-
), f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots"
140-
self._target_branch = branch
139+
if ref.snapshot_ref_type != SnapshotRefType.BRANCH:
140+
raise ValueError(f"{branch} is a tag, not a branch. Tags cannot be targets for producing snapshots")
141+
return branch
141142

142143
def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]:
143144
self._added_data_files.append(data_file)
@@ -639,13 +640,13 @@ def __init__(
639640
self,
640641
transaction: Transaction,
641642
io: FileIO,
643+
branch: str,
642644
snapshot_properties: Dict[str, str] = EMPTY_DICT,
643-
branch: Optional[str] = MAIN_BRANCH,
644645
) -> None:
645646
self._transaction = transaction
646647
self._io = io
647648
self._snapshot_properties = snapshot_properties
648-
self._branch = branch if branch is not None else MAIN_BRANCH
649+
self._branch = branch
649650

650651
def fast_append(self) -> _FastAppendFiles:
651652
return _FastAppendFiles(

0 commit comments

Comments
 (0)