Skip to content

Commit c911b9b

Browse files
committed
MOARRR CODE
1 parent f400313 commit c911b9b

9 files changed

Lines changed: 135 additions & 71 deletions

File tree

pyiceberg/avro/resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType],
290290
# There is a default value
291291
if file_field.write_default is not None:
292292
# The field is not in the record, but there is a write default value
293-
results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore
293+
results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))
294294
elif file_field.required:
295295
raise ValueError(f"Field is required, and there is no write default: {file_field}")
296296
else:

pyiceberg/conversions.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -503,27 +503,47 @@ def _(_: Union[IntegerType, LongType], val: int) -> int:
503503

504504

505505
@from_json.register(DateType)
def _(_: DateType, val: Union[str, int, date]) -> date:
    """Convert a JSON date into a Python date.

    Accepts an ISO date string, a day-ordinal int, or an already-parsed date.
    """
    # Normalize a string to its day ordinal first, then lift ints to date.
    days = date_str_to_days(val) if isinstance(val, str) else val
    return days_to_date(days) if isinstance(days, int) else days
509514

510515

511516
@from_json.register(TimeType)
def _(_: TimeType, val: Union[str, int, time]) -> time:
    """Convert a JSON time into a Python time.

    Accepts an ISO8601 string, a microseconds int, or an already-parsed time.
    """
    # Strings become microseconds-since-midnight; ints become time objects.
    micros = time_str_to_micros(val) if isinstance(val, str) else val
    return micros_to_time(micros) if isinstance(micros, int) else micros
515525

516526

517527
@from_json.register(TimestampType)
def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime:
    """Convert a JSON timestamp into a Python datetime.

    Accepts an ISO8601 string, a microseconds int, or an already-parsed datetime.
    """
    # Strings become epoch-microseconds; ints become datetime objects.
    micros = timestamp_to_micros(val) if isinstance(val, str) else val
    return micros_to_timestamp(micros) if isinstance(micros, int) else micros
521536

522537

523538
@from_json.register(TimestamptzType)
def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime:
    """Convert a JSON timestamptz into a tz-aware Python datetime.

    Accepts an ISO8601 string, a microseconds int, or an already-parsed datetime.
    """
    # Strings become epoch-microseconds; ints become tz-aware datetimes.
    micros = timestamptz_to_micros(val) if isinstance(val, str) else val
    return micros_to_timestamptz(micros) if isinstance(micros, int) else micros
527547

528548

529549
@from_json.register(FloatType)
@@ -540,20 +560,26 @@ def _(_: StringType, val: str) -> str:
540560

541561

542562
@from_json.register(FixedType)
def _(t: FixedType, val: Union[str, bytes]) -> bytes:
    """JSON hexadecimal encoded string (or raw bytes) into fixed-length bytes.

    Args:
        t: The FixedType, which fixes the required byte length.
        val: A hex-encoded string, or bytes that were already decoded upstream.

    Returns:
        The decoded bytes.

    Raises:
        ValueError: If the value's length does not match the type's length.
    """
    b = codecs.decode(val.encode(UTF8), "hex") if isinstance(val, str) else val

    # Validate the length for BOTH input kinds; previously a bytes value
    # bypassed the check, letting wrong-sized defaults slip through.
    if len(t) != len(b):
        raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")

    return b
551574

552575

553576
@from_json.register(BinaryType)
def _(_: BinaryType, val: Union[bytes, str]) -> bytes:
    """JSON hexadecimal encoded string into bytes.

    Bytes values are passed through untouched.
    """
    if not isinstance(val, str):
        return val
    return codecs.decode(val.encode(UTF8), "hex")
557583

558584

559585
@from_json.register(DecimalType)
@@ -563,6 +589,11 @@ def _(_: DecimalType, val: str) -> Decimal:
563589

564590

565591
@from_json.register(UUIDType)
def _(_: UUIDType, val: Union[str, bytes, uuid.UUID]) -> uuid.UUID:
    """Convert a JSON value into a Python UUID.

    Accepts the canonical string form, the 16-byte form, or a UUID instance.
    """
    # Guard clauses instead of an if/elif/else ladder.
    if isinstance(val, str):
        return uuid.UUID(val)
    if isinstance(val, bytes):
        return uuid.UUID(bytes=val)
    return val

pyiceberg/expressions/literals.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import struct
2525
from abc import ABC, abstractmethod
26-
from datetime import date, datetime
26+
from datetime import date, datetime, time
2727
from decimal import ROUND_HALF_UP, Decimal
2828
from functools import singledispatchmethod
2929
from math import isnan
@@ -54,6 +54,7 @@
5454
datetime_to_micros,
5555
micros_to_days,
5656
time_str_to_micros,
57+
time_to_micros,
5758
timestamp_to_micros,
5859
timestamptz_to_micros,
5960
)
@@ -152,6 +153,8 @@ def literal(value: L) -> Literal[L]:
152153
return TimestampLiteral(datetime_to_micros(value)) # type: ignore
153154
elif isinstance(value, date):
154155
return DateLiteral(date_to_days(value)) # type: ignore
156+
elif isinstance(value, time):
157+
return TimeLiteral(time_to_micros(value))
155158
else:
156159
raise TypeError(f"Invalid literal value: {repr(value)}")
157160

pyiceberg/table/update/schema.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,11 +398,11 @@ def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_v
398398
except ValueError as e:
399399
raise ValueError(f"Invalid default value: {e}") from e
400400

401-
if field.required and default_value != field.write_default:
401+
if field.required and default_value == field.write_default:
402402
# if the change is a noop, allow it even if allowIncompatibleChanges is false
403403
return
404404

405-
if not self._allow_incompatible_changes and field.required and default_value is not None:
405+
if not self._allow_incompatible_changes and field.required and default_value is None:
406406
raise ValueError("Cannot change change default-value of a required column to None")
407407

408408
if field.field_id in self._deletes:
@@ -729,19 +729,35 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]])
729729
name = field.name
730730
doc = field.doc
731731
required = field.required
732+
write_default = field.write_default
732733

733734
# There is an update
734735
if update := self._updates.get(field.field_id):
735736
name = update.name
736737
doc = update.doc
737738
required = update.required
738-
739-
if field.name == name and field.field_type == result_type and field.required == required and field.doc == doc:
739+
write_default = update.write_default
740+
741+
if (
742+
field.name == name
743+
and field.field_type == result_type
744+
and field.required == required
745+
and field.doc == doc
746+
and field.write_default == write_default
747+
):
740748
new_fields.append(field)
741749
else:
742750
has_changes = True
743751
new_fields.append(
744-
NestedField(field_id=field.field_id, name=name, field_type=result_type, required=required, doc=doc)
752+
NestedField(
753+
field_id=field.field_id,
754+
name=name,
755+
field_type=result_type,
756+
required=required,
757+
doc=doc,
758+
initial_default=field.initial_default,
759+
write_default=write_default,
760+
)
745761
)
746762

747763
if has_changes:

pyiceberg/types.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,24 @@
3535
import re
3636
from functools import cached_property
3737
from typing import (
38+
Annotated,
3839
Any,
3940
ClassVar,
4041
Dict,
4142
Literal,
4243
Optional,
43-
Tuple, Annotated,
44+
Tuple,
4445
)
4546

4647
from pydantic import (
48+
BeforeValidator,
4749
Field,
4850
PrivateAttr,
4951
SerializeAsAny,
5052
model_serializer,
51-
model_validator, PlainSerializer, BeforeValidator,
53+
model_validator,
5254
)
53-
from pydantic_core.core_schema import ValidatorFunctionWrapHandler, SerializationInfo, ValidationInfo
55+
from pydantic_core.core_schema import ValidationInfo, ValidatorFunctionWrapHandler
5456

5557
from pyiceberg.exceptions import ValidationError
5658
from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, TableVersion
@@ -289,24 +291,21 @@ def __eq__(self, other: Any) -> bool:
289291
return self.root == other.root if isinstance(other, DecimalType) else False
290292

291293

292-
# def _serialize_default_value(v: Any, context: SerializationInfo) -> Any:
293-
# from pyiceberg.conversions import to_json, from_json
294-
# return v
295294
def _deserialize_default_value(v: Any, context: ValidationInfo) -> Any:
    """Parse a JSON default value into its Python representation.

    Uses the already-validated sibling ``field_type`` from the pydantic
    validation context to pick the right conversion; None passes through.
    """
    if v is None:
        return None

    # Imported lazily to avoid a circular import with pyiceberg.conversions.
    from pyiceberg.conversions import from_json

    return from_json(context.data.get("field_type"), v)
301+
304302

305-
# PlainSerializer(_serialize_default_value, return_type=Any),
306303
DefaultValue = Annotated[
307-
Any, BeforeValidator(_deserialize_default_value)
304+
L,
305+
BeforeValidator(_deserialize_default_value),
308306
]
309307

308+
310309
class NestedField(IcebergType):
311310
"""Represents a field of a struct, a map key, a map value, or a list element.
312311
@@ -335,8 +334,8 @@ class NestedField(IcebergType):
335334
field_type: SerializeAsAny[IcebergType] = Field(alias="type")
336335
required: bool = Field(default=False)
337336
doc: Optional[str] = Field(default=None, repr=False)
338-
initial_default: DefaultValue = Field(alias="initial-default", default=None, repr=False)
339-
write_default: DefaultValue = Field(alias="write-default", default=None, repr=False) # type: ignore
337+
initial_default: Optional[DefaultValue] = Field(alias="initial-default", default=None, repr=False) # type: ignore
338+
write_default: Optional[DefaultValue] = Field(alias="write-default", default=None, repr=False) # type: ignore
340339

341340
def __init__(
342341
self,
@@ -360,6 +359,26 @@ def __init__(
360359
data["write-default"] = data["write-default"] if "write-default" in data else write_default
361360
super().__init__(**data)
362361

362+
@model_serializer()
def serialize_model(self) -> Dict[str, Any]:
    """Serialize the NestedField, JSON-encoding any default values.

    Optional entries (doc and the two defaults) are emitted only when set,
    and defaults are converted with ``to_json`` against the field's type.
    """
    # Imported lazily to avoid a circular import with pyiceberg.conversions.
    from pyiceberg.conversions import to_json

    fields: Dict[str, Any] = {
        "id": self.field_id,
        "name": self.name,
        "type": self.field_type,
        "required": self.required,
    }

    if self.doc is not None:
        fields["doc"] = self.doc
    # Keep emission order: doc, initial-default, write-default.
    for key, default in (("initial-default", self.initial_default), ("write-default", self.write_default)):
        if default is not None:
            fields[key] = to_json(self.field_type, default)

    return fields
381+
363382
def __str__(self) -> str:
364383
"""Return the string representation of the NestedField class."""
365384
doc = "" if not self.doc else f" ({self.doc})"

pyiceberg/utils/schema_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def field(self, field: NestedField, field_result: AvroType) -> AvroType:
530530
}
531531

532532
if field.write_default is not None:
533-
result["default"] = field.write_default # type: ignore
533+
result["default"] = field.write_default
534534
elif field.optional:
535535
result["default"] = None
536536

tests/integration/test_rest_schema.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
from pyiceberg.catalog import Catalog, load_catalog
2626
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError
27-
from pyiceberg.expressions import literal # type: ignore
2827
from pyiceberg.partitioning import PartitionField, PartitionSpec
2928
from pyiceberg.schema import Schema, prune_columns
3029
from pyiceberg.table import Table, TableProperties
@@ -1100,25 +1099,25 @@ def test_add_required_column(catalog: Catalog) -> None:
11001099
@pytest.mark.parametrize(
11011100
"iceberg_type, default_value, write_default",
11021101
[
1103-
# (BooleanType(), True, False),
1104-
# (IntegerType(), 123, 456),
1105-
# (LongType(), 123, 456),
1106-
# (FloatType(), 19.25, 22.27),
1107-
# (DoubleType(), 19.25, 22.27),
1102+
(BooleanType(), True, False),
1103+
(IntegerType(), 123, 456),
1104+
(LongType(), 123, 456),
1105+
(FloatType(), 19.25, 22.27),
1106+
(DoubleType(), 19.25, 22.27),
11081107
(DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")),
1109-
# (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")),
1110-
# (StringType(), "abc", "def"),
1111-
# (DateType(), date(1990, 3, 1), date(1991, 3, 1)),
1112-
# (TimeType(), time(19, 25, 22), time(22, 25, 22)),
1113-
# (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)),
1114-
# (
1115-
# TimestamptzType(),
1116-
# datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc),
1117-
# datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc),
1118-
# ),
1119-
# (BinaryType(), b"123", b"456"),
1120-
# (FixedType(4), b"1234", b"5678"),
1121-
# (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)),
1108+
(DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")),
1109+
(StringType(), "abc", "def"),
1110+
(DateType(), date(1990, 3, 1), date(1991, 3, 1)),
1111+
(TimeType(), time(19, 25, 22), time(22, 25, 22)),
1112+
(TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)),
1113+
(
1114+
TimestamptzType(),
1115+
datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc),
1116+
datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc),
1117+
),
1118+
(BinaryType(), b"123", b"456"),
1119+
(FixedType(4), b"1234", b"5678"),
1120+
(UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)),
11221121
],
11231122
)
11241123
def test_initial_default_all_columns(
@@ -1132,17 +1131,15 @@ def test_initial_default_all_columns(
11321131
tx.commit()
11331132

11341133
field = table.schema().find_field(1)
1135-
physical_type = literal(default_value).to(iceberg_type).value
1136-
assert field.initial_default == physical_type
1137-
assert field.write_default == physical_type
1134+
assert field.initial_default == default_value
1135+
assert field.write_default == default_value
11381136

11391137
with table.update_schema() as tx:
11401138
tx.set_default_value("data", write_default)
11411139

11421140
field = table.schema().find_field(1)
1143-
write_physical_type = literal(default_value).to(iceberg_type).value
1144-
assert field.initial_default == physical_type
1145-
assert field.write_default == write_physical_type
1141+
assert field.initial_default == default_value
1142+
assert field.write_default == write_default
11461143

11471144

11481145
@pytest.mark.integration

tests/test_types.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,9 @@ def test_nested_field() -> None:
225225
assert str(field_var) == str(eval(repr(field_var)))
226226
assert field_var == pickle.loads(pickle.dumps(field_var))
227227

228-
with pytest.raises(pydantic_core.ValidationError) as exc_info:
228+
with pytest.raises(pydantic_core.ValidationError, match=".*validation errors for NestedField.*"):
229229
_ = (NestedField(1, "field", StringType(), required=True, write_default=(1, "a", True)),) # type: ignore
230230

231-
assert "validation errors for NestedField" in str(exc_info.value)
232-
233231

234232
@pytest.mark.parametrize("input_index,input_type", non_parameterized_types)
235233
@pytest.mark.parametrize("check_index,check_type", non_parameterized_types)

tests/utils/test_manifest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def test_write_manifest(
416416

417417
data_file = manifest_entry.data_file
418418

419-
assert data_file.content is DataFileContent.DATA
419+
assert data_file.content == DataFileContent.DATA
420420
assert (
421421
data_file.file_path
422422
== "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet"

0 commit comments

Comments
 (0)