Skip to content

Commit a95c870

Browse files
committed
for UpdateSpec
1 parent 9d47ddf commit a95c870

2 files changed

Lines changed: 177 additions & 17 deletions

File tree

pyiceberg/table/update/spec.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,34 +58,48 @@ class UpdateSpec(UpdateTableMetadata["UpdateSpec"]):
5858
_adds: list[PartitionField]
5959
_deletes: set[int]
6060
_last_assigned_partition_id: int
61+
# Store (source_column_name, transform, partition_field_name) for retry support
62+
_field_additions: list[tuple[str, Transform[Any, Any], str | None]]
6163

6264
def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None:
6365
super().__init__(transaction)
64-
self._name_to_field = {field.name: field for field in transaction.table_metadata.spec().fields}
65-
self._name_to_added_field = {}
66-
self._transform_to_field = {
67-
(field.source_id, repr(field.transform)): field for field in transaction.table_metadata.spec().fields
68-
}
69-
self._transform_to_added_field = {}
70-
self._adds = []
71-
self._deletes = set()
72-
self._last_assigned_partition_id = transaction.table_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1
73-
self._renames = {}
7466
self._transaction = transaction
7567
self._case_sensitive = case_sensitive
68+
self._field_additions = []
69+
self._deletes = set()
70+
self._renames = {}
71+
# Initialize state from current metadata
72+
self._init_state_from_metadata()
73+
74+
def _init_state_from_metadata(self) -> None:
    """Initialize or reinitialize the derived state from the transaction's current metadata.

    Rebuilds the name and (source_id, transform) lookup tables from the
    fields of the current partition spec, recomputes the last assigned
    partition field id, and empties every collection that tracks fields
    added through this update.
    """
    metadata = self._transaction.table_metadata

    # Index the fields of the current spec by name and by (source id, transform repr).
    fields_by_name = {}
    fields_by_transform = {}
    for existing in metadata.spec().fields:
        fields_by_name[existing.name] = existing
        fields_by_transform[(existing.source_id, repr(existing.transform))] = existing
    self._name_to_field = fields_by_name
    self._transform_to_field = fields_by_transform

    # Fall back to one below the first partition field id when the metadata has no (or a zero) last id.
    self._last_assigned_partition_id = metadata.last_partition_id or (PARTITION_FIELD_ID_START - 1)

    # Clear intermediate state
    self._name_to_added_field = {}
    self._transform_to_added_field = {}
    self._added_time_fields = {}
    self._adds = []
7785

78-
def add_field(
86+
def _reset_state(self) -> None:
    """Reset state for retry, rebuilding it from the refreshed metadata.

    Called on transaction retry: the derived state is reinitialized from
    the transaction's current table metadata, and every recorded field
    addition is then reapplied on top of it, in the order it was recorded.
    """
    self._init_state_from_metadata()
    # Each entry is a (source_column_name, transform, partition_field_name) tuple.
    for recorded_addition in self._field_additions:
        self._do_add_field(*recorded_addition)
93+
94+
def _do_add_field(
7995
self,
8096
source_column_name: str,
81-
transform: str | Transform[Any, Any],
82-
partition_field_name: str | None = None,
83-
) -> UpdateSpec:
97+
transform: Transform[Any, Any],
98+
partition_field_name: str | None,
99+
) -> None:
84100
ref = Reference(source_column_name)
85101
bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
86-
if isinstance(transform, str):
87-
transform = parse_transform(transform)
88-
# verify transform can actually bind it
102+
89103
output_type = bound_ref.field.field_type
90104
if not transform.can_transform(output_type):
91105
raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")
@@ -121,6 +135,16 @@ def add_field(
121135

122136
self._name_to_added_field[new_field.name] = new_field
123137
self._adds.append(new_field)
138+
139+
def add_field(
    self,
    source_column_name: str,
    transform: str | Transform[Any, Any],
    partition_field_name: str | None = None,
) -> UpdateSpec:
    """Add a partition field and record the request so it can be replayed on retry.

    Args:
        source_column_name: Name of the schema column the partition field is derived from.
        transform: The transform to apply, either as a transform object or its string form.
        partition_field_name: Optional name for the new partition field.

    Returns:
        This update, to allow call chaining.

    Raises:
        Whatever `_do_add_field` raises when the addition is rejected; in that
        case nothing is recorded for replay.
    """
    # Only string forms need parsing; transform objects are used as given.
    if isinstance(transform, str):
        transform = parse_transform(transform)

    # Validate and apply first. Recording before this call would leave a rejected
    # addition in the replay log, and _reset_state would then re-raise the same
    # error on every retry even if the caller handled the original failure.
    self._do_add_field(source_column_name, transform, partition_field_name)
    self._field_additions.append((source_column_name, transform, partition_field_name))
    return self
125149

126150
def add_identity(self, source_column_name: str) -> UpdateSpec:
@@ -178,6 +202,10 @@ def _commit(self) -> UpdatesAndRequirements:
178202

179203
return updates, requirements
180204

205+
def commit(self) -> None:
    """Build this update's changes and apply them to the owning transaction.

    The updates and requirements produced by `_commit()` are passed to the
    transaction's `_apply`, together with this object as `pending_update`.
    """
    # NOTE(review): presumably `pending_update` lets the transaction call
    # `_reset_state` on this object when a commit is retried — confirm
    # against Transaction._apply.
    staged_updates, staged_requirements = self._commit()
    self._transaction._apply(staged_updates, staged_requirements, pending_update=self)
208+
181209
def _apply(self) -> PartitionSpec:
182210
def _check_and_add_partition_name(
183211
schema: Schema, name: str, source_id: int, transform: Transform[Any, Any], partition_names: set[str]

tests/table/test_commit_retry.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,3 +997,135 @@ def mock_commit(
997997
result = table.scan().to_arrow()
998998
assert len(result) == 3
999999
assert result["id"].to_pylist() == [4, 5, 6]
1000+
1001+
1002+
class TestUpdateSpecRetry:
    """Retry behaviour of partition spec updates when the catalog commit is patched."""

    @staticmethod
    def _retry_properties(num_retries: str) -> dict[str, str]:
        """Return table properties with the given retry count and short retry waits."""
        return {
            TableProperties.COMMIT_NUM_RETRIES: num_retries,
            TableProperties.COMMIT_MIN_RETRY_WAIT_MS: "1",
            TableProperties.COMMIT_MAX_RETRY_WAIT_MS: "10",
        }

    def test_update_spec_retried_on_conflict(self, catalog: SqlCatalog, schema: Schema) -> None:
        """Test that UpdateSpec operations are retried on CommitFailedException."""
        from pyiceberg.transforms import BucketTransform

        table = catalog.create_table(
            "default.test_spec_retry",
            schema=schema,
            properties=self._retry_properties("3"),
        )

        real_commit = catalog.commit_table
        attempts = 0

        def fail_first_commit(
            tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
        ) -> CommitTableResponse:
            # Only the very first attempt is rejected; later ones reach the real catalog.
            nonlocal attempts
            attempts += 1
            if attempts > 1:
                return real_commit(tbl, requirements, updates)
            raise CommitFailedException("Simulated spec conflict")

        with patch.object(catalog, "commit_table", side_effect=fail_first_commit):
            with table.update_spec() as update_spec:
                update_spec.add_field(source_column_name="id", transform=BucketTransform(16), partition_field_name="id_bucket")

        assert attempts == 2

    def test_update_spec_resolves_conflict_on_retry(self, catalog: SqlCatalog, schema: Schema) -> None:
        """Test that a spec update can resolve conflicts via retry."""
        from pyiceberg.transforms import BucketTransform

        table_name = "default.test_spec_conflict_resolved"
        table = catalog.create_table(
            table_name,
            schema=schema,
            properties=self._retry_properties("5"),
        )

        with table.update_spec() as update_spec:
            update_spec.add_field(source_column_name="id", transform=BucketTransform(16), partition_field_name="id_bucket")

        # A second handle on the same table evolves the spec behind the first handle's back.
        table2 = catalog.load_table(table_name)
        with table2.update_spec() as update_spec2:
            update_spec2.add_identity("id")

        assert table.spec().spec_id == 1
        assert table2.spec().spec_id == 2

        real_commit = catalog.commit_table
        attempts = 0

        def counting_commit(
            tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
        ) -> CommitTableResponse:
            nonlocal attempts
            attempts += 1
            return real_commit(tbl, requirements, updates)

        with patch.object(catalog, "commit_table", side_effect=counting_commit):
            # Retry resolves conflicts caused by a mismatched spec_id
            with table.update_spec() as update_spec:
                update_spec.add_field(source_column_name="id", transform=BucketTransform(8), partition_field_name="id_bucket_new")

        assert attempts == 2

    def test_transaction_with_spec_change_and_append_retries(
        self, catalog: SqlCatalog, schema: Schema, arrow_table: pa.Table
    ) -> None:
        """Test that a transaction with spec change and append handles retry correctly."""
        from pyiceberg.transforms import IdentityTransform

        table = catalog.create_table(
            "default.test_transaction_spec_and_append",
            schema=schema,
            properties=self._retry_properties("3"),
        )

        real_commit = catalog.commit_table
        attempts = 0
        captured_updates: list[tuple[TableUpdate, ...]] = []

        def fail_first_commit(
            tbl: Table, requirements: tuple[TableRequirement, ...], updates: tuple[TableUpdate, ...]
        ) -> CommitTableResponse:
            # Record what every attempt tried to commit, then reject only the first one.
            nonlocal attempts
            attempts += 1
            captured_updates.append(updates)
            if attempts > 1:
                return real_commit(tbl, requirements, updates)
            raise CommitFailedException("Simulated conflict")

        with patch.object(catalog, "commit_table", side_effect=fail_first_commit):
            with table.transaction() as txn:
                with txn.update_spec() as update_spec:
                    update_spec.add_identity("id")
                txn.append(arrow_table)

        assert attempts == 2

        # Both the failed attempt and the retry must carry the spec change and the snapshot.
        first_attempt_update_types = [type(update).__name__ for update in captured_updates[0]]
        assert "AddPartitionSpecUpdate" in first_attempt_update_types
        assert "AddSnapshotUpdate" in first_attempt_update_types

        retry_attempt_update_types = [type(update).__name__ for update in captured_updates[1]]
        assert "AddPartitionSpecUpdate" in retry_attempt_update_types
        assert "AddSnapshotUpdate" in retry_attempt_update_types

        assert len(table.scan().to_arrow()) == 3

        assert table.spec().spec_id == 1
        assert len(table.spec().fields) == 1
        partition_field = table.spec().fields[0]
        assert partition_field.name == "id"
        assert partition_field.source_id == 1  # "id" column's field_id
        assert isinstance(partition_field.transform, IdentityTransform)

0 commit comments

Comments
 (0)