Skip to content

Commit 2f0a5cb

Browse files
committed
Fix UUID support
1 parent bedf777 commit 2f0a5cb

4 files changed

Lines changed: 63 additions & 6 deletions

File tree

pyiceberg/avro/writer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
List,
3333
Optional,
3434
Tuple,
35+
Union,
3536
)
3637
from uuid import UUID
3738

@@ -121,8 +122,11 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None:
121122

122123
@dataclass(frozen=True)
123124
class UUIDWriter(Writer):
124-
def write(self, encoder: BinaryEncoder, val: UUID) -> None:
125-
encoder.write(val.bytes)
125+
def write(self, encoder: BinaryEncoder, val: Union[UUID, bytes]) -> None:
126+
if isinstance(val, UUID):
127+
encoder.write(val.bytes)
128+
else:
129+
encoder.write(val)
126130

127131

128132
@dataclass(frozen=True)

pyiceberg/io/pyarrow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def visit_string(self, _: StringType) -> pa.DataType:
684684
return pa.large_string()
685685

686686
def visit_uuid(self, _: UUIDType) -> pa.DataType:
687-
return pa.binary(16)
687+
return pa.uuid()
688688

689689
def visit_unknown(self, _: UnknownType) -> pa.DataType:
690690
return pa.null()
@@ -1252,6 +1252,8 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
12521252
return FixedType(primitive.byte_width)
12531253
elif pa.types.is_null(primitive):
12541254
return UnknownType()
1255+
elif isinstance(primitive, pa.UuidType):
1256+
return UUIDType()
12551257

12561258
raise TypeError(f"Unsupported type: {primitive}")
12571259

pyiceberg/partitioning.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,13 @@ def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
467467

468468

469469
@_to_partition_representation.register(UUIDType)
470-
def _(type: IcebergType, value: Optional[Union[uuid.UUID, int]]) -> Optional[Union[str, int]]:
470+
def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
471471
if value is None:
472472
return None
473+
elif isinstance(value, bytes):
474+
return value # IdentityTransform
473475
elif isinstance(value, uuid.UUID):
474-
return str(value) # IdentityTransform
476+
return value.bytes # IdentityTransform
475477
elif isinstance(value, int):
476478
return value # BucketTransform
477479
else:

tests/integration/test_writes/test_writes.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import random
2121
import time
22+
import uuid
2223
from datetime import date, datetime, timedelta
2324
from decimal import Decimal
2425
from pathlib import Path
@@ -48,7 +49,7 @@
4849
from pyiceberg.schema import Schema
4950
from pyiceberg.table import TableProperties
5051
from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
51-
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
52+
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, BucketTransform, Transform
5253
from pyiceberg.types import (
5354
DateType,
5455
DecimalType,
@@ -58,6 +59,7 @@
5859
LongType,
5960
NestedField,
6061
StringType,
62+
UUIDType,
6163
)
6264
from utils import _create_table
6365

@@ -1841,3 +1843,50 @@ def test_read_write_decimals(session_catalog: Catalog) -> None:
18411843
tbl.append(arrow_table)
18421844

18431845
assert tbl.scan().to_arrow() == arrow_table
1846+
1847+
1848+
@pytest.mark.integration
1849+
@pytest.mark.parametrize("transform", [IdentityTransform(), BucketTransform(32)])
1850+
def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transform: Transform) -> None:
1851+
identifier = f"default.test_uuid_partitioning_{str(transform).replace('[32]', '')}"
1852+
1853+
schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))
1854+
1855+
try:
1856+
session_catalog.drop_table(identifier=identifier)
1857+
except NoSuchTableError:
1858+
pass
1859+
1860+
partition_spec = PartitionSpec(
1861+
PartitionField(source_id=1, field_id=1000, transform=transform, name="uuid_identity")
1862+
)
1863+
1864+
import pyarrow as pa
1865+
1866+
arr_table = pa.Table.from_pydict(
1867+
{
1868+
"uuid": [
1869+
uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
1870+
uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
1871+
],
1872+
},
1873+
schema=pa.schema(
1874+
[
1875+
# Uuid not yet supported, so we have to stick with `binary(16)`
1876+
# https://github.com/apache/arrow/issues/46468
1877+
pa.field("uuid", pa.binary(16), nullable=False),
1878+
]
1879+
),
1880+
)
1881+
1882+
tbl = session_catalog.create_table(
1883+
identifier=identifier,
1884+
schema=schema,
1885+
partition_spec=partition_spec,
1886+
)
1887+
1888+
tbl.append(arr_table)
1889+
1890+
lhs = [r[0] for r in spark.table(identifier).collect()]
1891+
rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
1892+
assert lhs == rhs

0 commit comments

Comments
 (0)