Skip to content

Commit 933716e

Browse files
committed
feat: Add support for writing bloom filters
1 parent 939a6e5 commit 933716e

5 files changed

Lines changed: 243 additions & 6 deletions

File tree

pyiceberg/io/pyarrow.py

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,12 @@
180180
from pyiceberg.utils.config import Config
181181
from pyiceberg.utils.datetime import millis_to_datetime
182182
from pyiceberg.utils.decimal import unscaled_to_decimal
183-
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
183+
from pyiceberg.utils.properties import (
184+
get_first_property_value,
185+
property_as_bool,
186+
property_as_float,
187+
property_as_int,
188+
)
184189
from pyiceberg.utils.singleton import Singleton
185190
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
186191

@@ -2473,6 +2478,120 @@ def parquet_path_to_id_mapping(
24732478
return result
24742479

24752480

2481+
def id_to_parquet_path_mapping(schema: Schema) -> dict[int, str]:
    """
    Compute the mapping of Iceberg column ID to parquet column path.

    Args:
        schema (pyiceberg.schema.Schema): The current table schema.
    """
    # Each visited pair carries a field id and its parquet path; later
    # occurrences of the same id (if any) win, as in a plain assignment loop.
    return {pair.field_id: pair.parquet_path for pair in pre_order_visit(schema, ID2ParquetPathVisitor())}
2492+
2493+
2494+
@dataclass(frozen=True)
class BloomFilterOptions:
    """Resolved bloom filter settings for one parquet column."""

    parquet_path: str  # dotted parquet column path these options apply to
    ndv: int | None  # expected number of distinct values; None -> writer default
    fpp: float | None  # target false-positive probability; None -> writer default
2499+
2500+
2501+
class BloomFilterOptionsCollector(PreOrderSchemaVisitor[list[BloomFilterOptions]]):
    """Pre-order schema visitor that collects bloom filter options per column.

    Emits one ``BloomFilterOptions`` for each primitive field whose
    ``write.parquet.bloom-filter-enabled.column.<name>`` table property is
    true. The parent callbacks (``field``/``list``/``map``) record the id of
    the field about to be visited in ``_field_id``, because the ``primitive``
    callback only receives the type, not the enclosing field.
    """

    # Id of the field whose type is currently being visited; set by the
    # parent callback immediately before descending into a child.
    _field_id: int = 0
    _schema: Schema
    _properties: dict[str, str]

    def __init__(self, schema: Schema, properties: dict[str, str], id_to_parquet_path_mapping: dict[int, str]):
        self._schema = schema
        self._properties = properties
        self._id_to_parquet_path_mapping = id_to_parquet_path_mapping

    def schema(
        self, schema: Schema, struct_result: Callable[[], builtins.list[BloomFilterOptions]]
    ) -> builtins.list[BloomFilterOptions]:
        return struct_result()

    def struct(
        self, struct: StructType, field_results: builtins.list[Callable[[], builtins.list[BloomFilterOptions]]]
    ) -> builtins.list[BloomFilterOptions]:
        # Flatten the option lists collected from every child field.
        return list(itertools.chain(*[result() for result in field_results]))

    def field(
        self, field: NestedField, field_result: Callable[[], builtins.list[BloomFilterOptions]]
    ) -> builtins.list[BloomFilterOptions]:
        self._field_id = field.field_id
        return field_result()

    def list(
        self, list_type: ListType, element_result: Callable[[], builtins.list[BloomFilterOptions]]
    ) -> builtins.list[BloomFilterOptions]:
        # NOTE: this method shadows the builtin `list`, which is why the
        # annotations in this class spell out `builtins.list`.
        self._field_id = list_type.element_id
        return element_result()

    def map(
        self,
        map_type: MapType,
        key_result: Callable[[], builtins.list[BloomFilterOptions]],
        value_result: Callable[[], builtins.list[BloomFilterOptions]],
    ) -> builtins.list[BloomFilterOptions]:
        # Order matters: _field_id must be updated right before each child
        # callback runs, since both callbacks read it transitively.
        self._field_id = map_type.key_id
        k = key_result()
        self._field_id = map_type.value_id
        v = value_result()
        return k + v

    def primitive(self, primitive: PrimitiveType) -> builtins.list[BloomFilterOptions]:
        # Imported locally to avoid a circular import with pyiceberg.table.
        from pyiceberg.table import TableProperties

        column_name = self._schema.find_column_name(self._field_id)
        if column_name is None:
            return []

        parquet_path = self._id_to_parquet_path_mapping.get(self._field_id)
        if parquet_path is None:
            return []

        # Bloom filters are opt-in per column; default is disabled.
        bloom_filter_enabled = property_as_bool(
            self._properties, f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.{column_name}", False
        )
        if not bloom_filter_enabled:
            return []

        # fpp/ndv are optional tuning knobs; None leaves the writer default.
        bloom_filter_fpp = property_as_float(
            self._properties, f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX}.{column_name}", None
        )
        bloom_filter_ndv = property_as_int(
            self._properties, f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_NDV_PREFIX}.{column_name}", None
        )

        return [BloomFilterOptions(parquet_path=parquet_path, ndv=bloom_filter_ndv, fpp=bloom_filter_fpp)]
2570+
2571+
2572+
def get_bloom_filter_options(
    schema: Schema,
    table_properties: dict[str, str],
) -> dict[str, dict[str, Any]]:
    """
    Get the bloom filter options from the table properties.

    Args:
        schema (pyiceberg.schema.Schema): The current table schema.
        table_properties (dict[str, str]): The table properties.
    """
    collector = BloomFilterOptionsCollector(schema, table_properties, id_to_parquet_path_mapping(schema))
    per_column: dict[str, dict[str, Any]] = {}
    for options in pre_order_visit(schema, collector):
        # Only include the keys that were explicitly configured.
        entry: dict[str, Any] = {}
        if options.ndv is not None:
            entry["ndv"] = options.ndv
        if options.fpp is not None:
            entry["fpp"] = options.fpp
        per_column[options.parquet_path] = entry
    return per_column
2593+
2594+
24762595
@dataclass(frozen=True)
24772596
class DataFileStatistics:
24782597
record_count: int
@@ -2668,7 +2787,6 @@ def data_file_statistics_from_parquet_metadata(
26682787
def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
26692788
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties
26702789

2671-
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
26722790
row_group_size = property_as_int(
26732791
properties=table_metadata.properties,
26742792
property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
@@ -2685,6 +2803,8 @@ def write_parquet(task: WriteTask) -> DataFile:
26852803
else:
26862804
file_schema = table_schema
26872805

2806+
parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties, file_schema)
2807+
26882808
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
26892809
batches = [
26902810
_to_requested_schema(
@@ -2829,14 +2949,25 @@ def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_pa
28292949
PYARROW_UNCOMPRESSED_CODEC = "none"
28302950

28312951

2832-
def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]:
2952+
def _get_parquet_writer_kwargs(table_properties: Properties, file_schema: Schema) -> dict[str, Any]:
28332953
from pyiceberg.table import TableProperties
28342954

2835-
for key_pattern in [
2955+
unsupported_key_patterns = [
28362956
TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
28372957
TableProperties.PARQUET_BLOOM_FILTER_MAX_BYTES,
2838-
f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.*",
2839-
]:
2958+
]
2959+
2960+
from packaging import version
2961+
2962+
MIN_PYARROW_VERSION_SUPPORTING_BLOOM_FILTER_WRITES = "24.0.0"
2963+
if version.parse(pyarrow.__version__) < version.parse(MIN_PYARROW_VERSION_SUPPORTING_BLOOM_FILTER_WRITES):
2964+
unsupported_key_patterns += [
2965+
f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.*",
2966+
f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX}.*",
2967+
f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_NDV_PREFIX}.*",
2968+
]
2969+
2970+
for key_pattern in unsupported_key_patterns:
28402971
if unsupported_keys := fnmatch.filter(table_properties, key_pattern):
28412972
warnings.warn(f"Parquet writer option(s) {unsupported_keys} not implemented", stacklevel=2)
28422973

@@ -2849,6 +2980,8 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]:
28492980
if compression_codec == ICEBERG_UNCOMPRESSED_CODEC:
28502981
compression_codec = PYARROW_UNCOMPRESSED_CODEC
28512982

2983+
bloom_filter_options = get_bloom_filter_options(file_schema, table_properties)
2984+
28522985
return {
28532986
"compression": compression_codec,
28542987
"compression_level": compression_level,
@@ -2867,6 +3000,7 @@ def _get_parquet_writer_kwargs(table_properties: Properties) -> dict[str, Any]:
28673000
property_name=TableProperties.PARQUET_PAGE_ROW_LIMIT,
28683001
default=TableProperties.PARQUET_PAGE_ROW_LIMIT_DEFAULT,
28693002
),
3003+
**({"bloom_filter_options": bloom_filter_options} if bloom_filter_options else {}),
28703004
}
28713005

28723006

pyiceberg/table/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ class TableProperties:
147147

148148
PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX = "write.parquet.bloom-filter-enabled.column"
149149

150+
PARQUET_BLOOM_FILTER_COLUMN_FPP_PREFIX = "write.parquet.bloom-filter-fpp.column"
151+
152+
PARQUET_BLOOM_FILTER_COLUMN_NDV_PREFIX = "write.parquet.bloom-filter-ndv.column"
153+
150154
WRITE_TARGET_FILE_SIZE_BYTES = "write.target-file-size-bytes"
151155
WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT = 512 * 1024 * 1024 # 512 MB
152156

pyiceberg/utils/properties.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,41 @@ def property_as_bool(
6666
return default
6767

6868

69+
def properties_as_int_dict(
    properties: dict[str, str],
    property_prefix: str,
) -> dict[str, int]:
    """Collect int-valued properties under ``property_prefix``, keyed by suffix.

    Only keys of the form ``<property_prefix>.<suffix>`` are considered;
    entries whose value does not resolve to an int via ``property_as_int``
    are omitted.

    Args:
        properties: The raw string-to-string property map.
        property_prefix: Prefix (without the trailing dot) to filter on.
    """
    # Require the "." separator: a plain startswith(property_prefix) would
    # also match keys like "<prefix>extra.x" and leave them un-stripped.
    prefix = property_prefix + "."
    return {
        key.removeprefix(prefix): value
        for key in properties
        if key.startswith(prefix)
        if (value := property_as_int(properties, key, None)) is not None
    }
79+
80+
81+
def properties_as_float_dict(
    properties: dict[str, str],
    property_prefix: str,
) -> dict[str, float]:
    """Collect float-valued properties under ``property_prefix``, keyed by suffix.

    Only keys of the form ``<property_prefix>.<suffix>`` are considered;
    entries whose value does not resolve to a float via ``property_as_float``
    are omitted.

    Args:
        properties: The raw string-to-string property map.
        property_prefix: Prefix (without the trailing dot) to filter on.
    """
    # Require the "." separator: a plain startswith(property_prefix) would
    # also match keys like "<prefix>extra.x" and leave them un-stripped.
    prefix = property_prefix + "."
    return {
        key.removeprefix(prefix): value
        for key in properties
        if key.startswith(prefix)
        if (value := property_as_float(properties, key, None)) is not None
    }
91+
92+
93+
def properties_as_bool_dict(
    properties: dict[str, str],
    property_prefix: str,
) -> dict[str, bool]:
    """Collect bool-valued properties under ``property_prefix``, keyed by suffix.

    Only keys of the form ``<property_prefix>.<suffix>`` are considered;
    values are parsed with ``property_as_bool`` with a default of ``False``.

    Args:
        properties: The raw string-to-string property map.
        property_prefix: Prefix (without the trailing dot) to filter on.
    """
    # Require the "." separator: a plain startswith(property_prefix) would
    # also match keys like "<prefix>extra.x" and leave them un-stripped.
    prefix = property_prefix + "."
    return {
        key.removeprefix(prefix): property_as_bool(properties, key, False)
        for key in properties
        if key.startswith(prefix)
    }
102+
103+
69104
def get_first_property_value(
70105
properties: Properties,
71106
*property_names: str,

tests/integration/test_writes/test_writes.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@
3030
import fastavro
3131
import pandas as pd
3232
import pandas.testing
33+
import pyarrow
3334
import pyarrow as pa
3435
import pyarrow.compute as pc
3536
import pyarrow.parquet as pq
3637
import pytest
3738
import pytz
39+
from packaging import version
3840
from pyarrow.fs import S3FileSystem
3941
from pydantic_core import ValidationError
4042
from pyspark.sql import SparkSession
@@ -68,6 +70,11 @@
6870
from pyiceberg.view.metadata import SQLViewRepresentation, ViewVersion
6971
from utils import TABLE_SCHEMA, _create_table
7072

73+
# Bloom filter writing requires pyarrow >= 24.0.0; skip the dependent
# integration tests on older installations.
skip_if_bloom_filter_not_supported = pytest.mark.skipif(
    version.parse(pyarrow.__version__) < version.parse("24.0.0"),
    reason="Requires pyarrow version >= 24.0.0",
)
77+
7178

7279
@pytest.fixture(scope="session", autouse=True)
7380
def table_v1_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
@@ -712,6 +719,27 @@ def test_write_parquet_unsupported_properties(
712719
tbl.append(arrow_table_with_null)
713720

714721

722+
@pytest.mark.integration
@skip_if_bloom_filter_not_supported
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_parquet_bloom_filter_properties(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Append data to a table configured with per-column bloom filter properties.

    NOTE(review): this only verifies that the write succeeds with the
    bloom-filter properties set; it does not inspect the written parquet
    footer to confirm a bloom filter was actually produced — consider
    asserting on the file metadata. TODO confirm.
    """
    identifier = "default.write_parquet_bloom_filter_properties"

    _create_table(
        session_catalog,
        identifier,
        {
            "format-version": format_version,
            "write.parquet.bloom-filter-enabled.column.string": "true",
            "write.parquet.bloom-filter-fpp.column.string": "0.1",
            "write.parquet.bloom-filter-ndv.column.string": "100",
        },
        [arrow_table_with_null],
    )
741+
742+
715743
@pytest.mark.integration
716744
@pytest.mark.parametrize("format_version", [1, 2])
717745
def test_spark_writes_orc_pyiceberg_reads(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:

tests/table/test_metadata.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pytest
2727

2828
from pyiceberg.exceptions import ValidationError
29+
from pyiceberg.io.pyarrow import get_bloom_filter_options
2930
from pyiceberg.partitioning import PartitionField, PartitionSpec
3031
from pyiceberg.schema import Schema
3132
from pyiceberg.serializers import FromByteStream
@@ -876,3 +877,38 @@ def test_new_table_metadata_format_v2_with_v3_schema_fails(field_type: Primitive
876877
location="s3://some_v1_location/",
877878
properties={"format-version": "2"},
878879
)
880+
881+
882+
def test_get_bloom_filter_options() -> None:
    """Only explicitly enabled columns (including nested ones) yield options."""
    test_schema = Schema(
        NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=22, name="bar", field_type=IntegerType(), required=False),
        NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
        NestedField(
            field_id=34,
            name="qux",
            field_type=StructType(
                NestedField(field_id=35, name="quux", field_type=StringType(), required=False),
                NestedField(field_id=36, name="quuux", field_type=IntegerType(), required=False),
            ),
            required=False,
        ),
    )

    # "bar" has fpp/ndv configured but is explicitly disabled, so it must be
    # ignored; "baz" and "quuux" have no bloom filter properties at all.
    props = {
        "write.parquet.bloom-filter-enabled.column.foo": "true",
        "write.parquet.bloom-filter-fpp.column.foo": "0.01",
        "write.parquet.bloom-filter-ndv.column.foo": "1000",
        "write.parquet.bloom-filter-enabled.column.bar": "false",
        "write.parquet.bloom-filter-fpp.column.bar": "0.02",
        "write.parquet.bloom-filter-ndv.column.bar": "2000",
        "write.parquet.bloom-filter-enabled.column.qux.quux": "true",
        "write.parquet.bloom-filter-fpp.column.qux.quux": "0.03",
        "write.parquet.bloom-filter-ndv.column.qux.quux": "3000",
    }

    expected = {
        "foo": {"fpp": 0.01, "ndv": 1000},
        "qux.quux": {"fpp": 0.03, "ndv": 3000},
    }
    assert get_bloom_filter_options(test_schema, props) == expected

0 commit comments

Comments
 (0)