Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions pyiceberg/table/update/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from copy import copy
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import literal # type: ignore
from pyiceberg.schema import (
PartnerAccessor,
Schema,
Expand Down Expand Up @@ -153,7 +154,12 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
return self

def add_column(
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
self,
path: Union[str, Tuple[str, ...]],
field_type: IcebergType,
doc: Optional[str] = None,
required: bool = False,
default_value: Optional[Any] = None,
) -> UpdateSchema:
"""Add a new column to a nested struct or Add a new top-level column.

Expand All @@ -168,6 +174,7 @@ def add_column(
field_type: Type for the new column.
doc: Documentation string for the new column.
required: Whether the new column is required.
default_value: Default value for the new column.

Returns:
This for method chaining.
Expand All @@ -177,10 +184,6 @@ def add_column(
raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead")
path = (path,)

if required and not self._allow_incompatible_changes:
# Table format version 1 and 2 cannot add required column because there is no initial value
raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")

name = path[-1]
parent = path[:-1]

Expand Down Expand Up @@ -212,13 +215,34 @@ def add_column(

# assign new IDs in order
new_id = self.assign_new_column_id()
new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)

if default_value is not None:
try:
# To make sure that the value is valid for the type
initial_default = literal(default_value).to(new_type).value
except ValueError as e:
raise ValueError(f"Invalid default value: {e}") from e
else:
initial_default = default_value

if (required and initial_default is None) and not self._allow_incompatible_changes:
# Table format version 1 and 2 cannot add required column because there is no initial value
raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")

# update tracking for moves
self._added_name_to_id[full_name] = new_id
self._id_to_parent[new_id] = parent_full_path

new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)
field = NestedField(field_id=new_id, name=name, field_type=new_type, required=required, doc=doc)
field = NestedField(
field_id=new_id,
name=name,
field_type=new_type,
required=required,
doc=doc,
initial_default=initial_default,
write_default=initial_default,
Comment thread
Fokko marked this conversation as resolved.
)
Comment thread
Fokko marked this conversation as resolved.

if parent_id in self._adds:
self._adds[parent_id].append(field)
Expand Down Expand Up @@ -330,6 +354,8 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
field_type=updated.field_type,
doc=updated.doc,
required=required,
initial_default=updated.initial_default,
Comment thread
Fokko marked this conversation as resolved.
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
Expand All @@ -338,6 +364,8 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
field_type=field.field_type,
doc=field.doc,
required=required,
initial_default=field.initial_default,
write_default=field.write_default,
)

def update_column(
Comment thread
Fokko marked this conversation as resolved.
Expand Down Expand Up @@ -387,6 +415,8 @@ def update_column(
field_type=field_type or updated.field_type,
doc=doc if doc is not None else updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
Expand All @@ -395,6 +425,8 @@ def update_column(
field_type=field_type or field.field_type,
doc=doc if doc is not None else field.doc,
required=field.required,
initial_default=field.initial_default,
write_default=field.write_default,
)

if required is not None:
Expand Down
36 changes: 30 additions & 6 deletions tests/integration/test_rest_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pyiceberg.table.sorting import SortField, SortOrder
from pyiceberg.table.update.schema import UpdateSchema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import EMPTY_DICT, Properties
from pyiceberg.types import (
BinaryType,
BooleanType,
Expand Down Expand Up @@ -69,7 +70,7 @@ def simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table:
return _create_table_with_schema(catalog, table_schema_simple)


def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
def _create_table_with_schema(catalog: Catalog, schema: Schema, properties: Properties = EMPTY_DICT) -> Table:
tbl_name = "default.test_schema_evolution"
try:
catalog.drop_table(tbl_name)
Expand All @@ -78,7 +79,7 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
return catalog.create_table(
identifier=tbl_name,
schema=schema,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json(), **properties},
)


Expand Down Expand Up @@ -1076,9 +1077,8 @@ def test_add_required_column(catalog: Catalog) -> None:
schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False))
table = _create_table_with_schema(catalog, schema_)
update = table.update_schema()
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match="Incompatible change: cannot add required column: data"):
update.add_column(path="data", field_type=IntegerType(), required=True)
assert "Incompatible change: cannot add required column: data" in str(exc_info.value)

new_schema = (
UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True)
Expand All @@ -1091,16 +1091,40 @@ def test_add_required_column(catalog: Catalog) -> None:
)


@pytest.mark.integration
def test_add_required_column_initial_default(catalog: Catalog) -> None:
schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False))
table = _create_table_with_schema(catalog, schema_)
new_schema = (
UpdateSchema(transaction=table.transaction())
.add_column(path="data", field_type=IntegerType(), required=True, default_value=22)
._apply()
)
assert new_schema == Schema(
NestedField(field_id=1, name="a", field_type=BooleanType(), required=False),
NestedField(field_id=2, name="data", field_type=IntegerType(), required=True, initial_default=22, write_default=22),
schema_id=1,
)


@pytest.mark.integration
def test_add_required_column_initial_default_invalid_value(catalog: Catalog) -> None:
schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False))
table = _create_table_with_schema(catalog, schema_)
update = table.update_schema()
with pytest.raises(ValueError, match="Invalid default value: Could not convert abc into a int"):
update.add_column(path="data", field_type=IntegerType(), required=True, default_value="abc")


@pytest.mark.integration
def test_add_required_column_case_insensitive(catalog: Catalog) -> None:
schema_ = Schema(NestedField(field_id=1, name="id", field_type=BooleanType(), required=False))
table = _create_table_with_schema(catalog, schema_)

with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match="already exists: ID"):
with table.transaction() as txn:
with txn.update_schema(allow_incompatible_changes=True) as update:
update.case_sensitive(False).add_column(path="ID", field_type=IntegerType(), required=True)
assert "already exists: ID" in str(exc_info.value)

new_schema = (
UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True)
Expand Down