Skip to content

Commit 580c986

Browse files
author
redpheonixx
committed
handle decimal physical type mapping
1 parent 1c0e2b0 commit 580c986

2 files changed

Lines changed: 41 additions & 11 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
176176
from pyiceberg.utils.singleton import Singleton
177177
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
178+
from decimal import Decimal, Context
178179

179180
if TYPE_CHECKING:
180181
from pyiceberg.table import FileScanTask, WriteTask
@@ -194,7 +195,7 @@
194195
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}
195196

196197
T = TypeVar("T")
197-
198+
DECIMAL_REGEX = re.compile(r"decimal\((\d+),\s*(\d+)\)")
198199

199200
@lru_cache
200201
def _cached_resolve_s3_region(bucket: str) -> Optional[str]:
@@ -1868,7 +1869,11 @@ def visit_fixed(self, fixed_type: FixedType) -> str:
18681869
return "FIXED_LEN_BYTE_ARRAY"
18691870

18701871
def visit_decimal(self, decimal_type: DecimalType) -> str:
1871-
return "FIXED_LEN_BYTE_ARRAY"
1872+
return (
1873+
"INT32" if decimal_type.precision <= 9
1874+
else "INT64" if decimal_type.precision <= 18
1875+
else "FIXED_LEN_BYTE_ARRAY"
1876+
)
18721877

18731878
def visit_boolean(self, boolean_type: BooleanType) -> str:
18741879
return "BOOLEAN"
@@ -2335,9 +2340,18 @@ def data_file_statistics_from_parquet_metadata(
23352340
col_aggs[field_id] = StatsAggregator(
23362341
stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length
23372342
)
2338-
2339-
col_aggs[field_id].update_min(statistics.min)
2340-
col_aggs[field_id].update_max(statistics.max)
2343+
matches=DECIMAL_REGEX.search(str(stats_col.iceberg_type))
2344+
if matches and statistics.physical_type != "FIXED_LEN_BYTE_ARRAY":
2345+
precision=int(matches.group(1))
2346+
scale=int(matches.group(2))
2347+
local_context = Context(prec=precision)
2348+
decoded_min = local_context.create_decimal(Decimal(statistics.min_raw)/ (10 ** scale))
2349+
decoded_max = local_context.create_decimal(Decimal(statistics.max_raw)/ (10 ** scale))
2350+
col_aggs[field_id].update_min(decoded_min)
2351+
col_aggs[field_id].update_max(decoded_max)
2352+
else:
2353+
col_aggs[field_id].update_min(statistics.min)
2354+
col_aggs[field_id].update_max(statistics.max)
23412355

23422356
except pyarrow.lib.ArrowNotImplementedError as e:
23432357
invalidate_col.add(field_id)

tests/io/test_pyarrow_stats.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
StringType,
7373
)
7474
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros
75-
75+
from decimal import Decimal
7676

7777
@dataclass(frozen=True)
7878
class TestStruct:
@@ -446,6 +446,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
446446
{"id": 10, "name": "strings", "required": False, "type": "string"},
447447
{"id": 11, "name": "uuids", "required": False, "type": "uuid"},
448448
{"id": 12, "name": "binaries", "required": False, "type": "binary"},
449+
{"id": 13, "name": "decimal8", "required": False, "type": "decimal(8, 2)"},
450+
{"id": 14, "name": "decimal16", "required": False, "type": "decimal(16, 6)"},
451+
{"id": 15, "name": "decimal32", "required": False, "type": "decimal(20, 6)"},
449452
],
450453
},
451454
],
@@ -470,6 +473,9 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
470473
strings = ["hello", "world"]
471474
uuids = [uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes, uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes]
472475
binaries = [b"hello", b"world"]
476+
decimal8 = [Decimal("123.45"), Decimal("678.91")]
477+
decimal16 = [Decimal("123456789.123456"), Decimal("678912345.678912")]
478+
decimal32 = [Decimal("12345678901234.123456"), Decimal("98765432109870.654321")]
473479

474480
table = pa.Table.from_pydict(
475481
{
@@ -485,14 +491,17 @@ def construct_test_table_primitive_types() -> Tuple[pq.FileMetaData, Union[Table
485491
"strings": strings,
486492
"uuids": uuids,
487493
"binaries": binaries,
494+
"decimal8": decimal8,
495+
"decimal16": decimal16,
496+
"decimal32": decimal32,
488497
},
489498
schema=arrow_schema,
490499
)
491500

492501
metadata_collector: List[Any] = []
493502

494503
with pa.BufferOutputStream() as f:
495-
with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector) as writer:
504+
with pq.ParquetWriter(f, table.schema, metadata_collector=metadata_collector, store_decimal_as_integer=True) as writer:
496505
writer.write_table(table)
497506

498507
return metadata_collector[0], table_metadata
@@ -510,13 +519,13 @@ def test_metrics_primitive_types() -> None:
510519
)
511520
datafile = DataFile(**statistics.to_serialized_dict())
512521

513-
assert len(datafile.value_counts) == 12
514-
assert len(datafile.null_value_counts) == 12
522+
assert len(datafile.value_counts) == 15
523+
assert len(datafile.null_value_counts) == 15
515524
assert len(datafile.nan_value_counts) == 0
516525

517526
tz = timezone(timedelta(seconds=19800))
518527

519-
assert len(datafile.lower_bounds) == 12
528+
assert len(datafile.lower_bounds) == 15
520529
assert datafile.lower_bounds[1] == STRUCT_BOOL.pack(False)
521530
assert datafile.lower_bounds[2] == STRUCT_INT32.pack(23)
522531
assert datafile.lower_bounds[3] == STRUCT_INT64.pack(2)
@@ -529,8 +538,12 @@ def test_metrics_primitive_types() -> None:
529538
assert datafile.lower_bounds[10] == b"he"
530539
assert datafile.lower_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "foo").bytes
531540
assert datafile.lower_bounds[12] == b"he"
541+
assert int.from_bytes(datafile.lower_bounds[13], byteorder="big", signed=True) == int(12345)
542+
assert int.from_bytes(datafile.lower_bounds[14], byteorder="big", signed=True) == int(123456789123456)
543+
assert int.from_bytes(datafile.lower_bounds[15], byteorder="big", signed=True) == int(12345678901234123456)
544+
532545

533-
assert len(datafile.upper_bounds) == 12
546+
assert len(datafile.upper_bounds) == 15
534547
assert datafile.upper_bounds[1] == STRUCT_BOOL.pack(True)
535548
assert datafile.upper_bounds[2] == STRUCT_INT32.pack(89)
536549
assert datafile.upper_bounds[3] == STRUCT_INT64.pack(54)
@@ -543,6 +556,9 @@ def test_metrics_primitive_types() -> None:
543556
assert datafile.upper_bounds[10] == b"wp"
544557
assert datafile.upper_bounds[11] == uuid.uuid3(uuid.NAMESPACE_DNS, "bar").bytes
545558
assert datafile.upper_bounds[12] == b"wp"
559+
assert int.from_bytes(datafile.upper_bounds[13], byteorder="big", signed=True) == int(67891)
560+
assert int.from_bytes(datafile.upper_bounds[14], byteorder="big", signed=True) == int(678912345678912)
561+
assert int.from_bytes(datafile.upper_bounds[15], byteorder="big", signed=True) == int(98765432109870654321)
546562

547563

548564
def construct_test_table_invalid_upper_bound() -> Tuple[pq.FileMetaData, Union[TableMetadataV1, TableMetadataV2]]:

0 commit comments

Comments
 (0)