Commit 4651ea4

Author: Hanzhi Wang
Parent: 6020f24

table.inspect.partitions(): add filter expression

Allow users to filter partitions using expressions when inspecting table partitions, similar to scan planning.
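
As a usage sketch of the feature described above (not part of the commit): the catalog name and table identifier below are hypothetical placeholders, but the parameters match the partitions() signature added in pyiceberg/table/inspect.py.

# Minimal sketch, assuming a configured catalog and a table partitioned by dt.
from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import GreaterThanOrEqual

tbl = load_catalog("default").load_table("default.events")  # hypothetical names

# The filter may be a string expression, as in scan planning...
df = tbl.inspect.partitions(row_filter="dt >= '2021-01-01'")

# ...or a BooleanExpression; partitions() accepts Union[str, BooleanExpression].
df = tbl.inspect.partitions(row_filter=GreaterThanOrEqual("dt", "2021-01-01"))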

3 files changed: 179 additions & 89 deletions

pyiceberg/table/__init__.py

Lines changed: 20 additions & 10 deletions
@@ -31,6 +31,7 @@
     Callable,
     Dict,
     Iterable,
+    Iterator,
     List,
     Optional,
     Set,
@@ -1942,11 +1943,11 @@ def _check_sequence_number(min_sequence_number: int, manifest: ManifestFile) ->
             and (manifest.sequence_number or INITIAL_SEQUENCE_NUMBER) >= min_sequence_number
         )
 
-    def plan_files(self) -> Iterable[FileScanTask]:
-        """Plans the relevant files by filtering on the PartitionSpecs.
+    def scan_plan_helper(self) -> Iterator[ManifestEntry]:
+        """Filter and return manifest entries based on partition and metrics evaluators.
 
         Returns:
-            List of FileScanTasks that contain both data and delete files.
+            Iterator of ManifestEntry objects that match the scan's partition filter.
         """
         snapshot = self.snapshot()
         if not snapshot:
@@ -1957,8 +1958,6 @@ def plan_files(self) -> Iterable[FileScanTask]:
 
         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)
 
-        residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)
-
         manifests = [
             manifest_file
             for manifest_file in snapshot.manifests(self.io)
@@ -1972,11 +1971,9 @@ def plan_files(self) -> Iterable[FileScanTask]:
 
         min_sequence_number = _min_sequence_number(manifests)
 
-        data_entries: List[ManifestEntry] = []
-        positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER)
-
         executor = ExecutorFactory.get_or_create()
-        for manifest_entry in chain(
+
+        return chain(
             *executor.map(
                 lambda args: _open_manifest(*args),
                 [
@@ -1990,7 +1987,20 @@ def plan_files(self) -> Iterable[FileScanTask]:
                     if self._check_sequence_number(min_sequence_number, manifest)
                 ],
             )
-        ):
+        )
+
+    def plan_files(self) -> Iterable[FileScanTask]:
+        """Plans the relevant files by filtering on the PartitionSpecs.
+
+        Returns:
+            List of FileScanTasks that contain both data and delete files.
+        """
+        data_entries: List[ManifestEntry] = []
+        positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER)
+
+        residual_evaluators: Dict[int, Callable[[DataFile], ResidualEvaluator]] = KeyDefaultDict(self._build_residual_evaluator)
+
+        for manifest_entry in self.scan_plan_helper():
             data_file = manifest_entry.data_file
             if data_file.content == DataFileContent.DATA:
                 data_entries.append(manifest_entry)
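
The refactor above splits the old plan_files() in two: scan_plan_helper() yields the partition-filtered stream of ManifestEntry objects, and plan_files() keeps the data/delete bookkeeping on top of it. A hedged sketch of consuming the helper directly (tbl is assumed to be an already-loaded Table):

from pyiceberg.manifest import DataFileContent

# Table.scan() returns a DataScan; scan_plan_helper() (added above) yields the
# manifest entries that survive the partition and metrics evaluators.
scan = tbl.scan(row_filter="dt >= '2021-01-01'")
entries = list(scan.scan_plan_helper())
n_data = sum(1 for e in entries if e.data_file.content == DataFileContent.DATA)
print(f"{n_data} matching data files")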

pyiceberg/table/inspect.py

Lines changed: 69 additions & 71 deletions
@@ -17,10 +17,11 @@
 from __future__ import annotations
 
 from datetime import datetime, timezone
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Union
 
 from pyiceberg.conversions import from_bytes
-from pyiceberg.manifest import DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
+from pyiceberg.expressions import AlwaysTrue, BooleanExpression
+from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.table.snapshots import Snapshot, ancestors_of
 from pyiceberg.types import PrimitiveType
@@ -32,6 +33,8 @@
 
     from pyiceberg.table import Table
 
+ALWAYS_TRUE = AlwaysTrue()
+
 
 class InspectTable:
     tbl: Table
@@ -255,10 +258,16 @@ def refs(self) -> "pa.Table":
 
         return pa.Table.from_pylist(ref_results, schema=ref_schema)
 
-    def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
+    def partitions(
+        self,
+        snapshot_id: Optional[int] = None,
+        row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
+        case_sensitive: bool = True,
+    ) -> "pa.Table":
         import pyarrow as pa
 
         from pyiceberg.io.pyarrow import schema_to_pyarrow
+        from pyiceberg.table import DataScan
 
         table_schema = pa.schema(
             [
@@ -289,85 +298,74 @@ def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         table_schema = pa.unify_schemas([partitions_schema, table_schema])
 
         snapshot = self._get_snapshot(snapshot_id)
-        executor = ExecutorFactory.get_or_create()
-        local_partitions_maps = executor.map(self._process_manifest, snapshot.manifests(self.tbl.io))
-
-        partitions_map: Dict[Tuple[str, Any], Any] = {}
-        for local_map in local_partitions_maps:
-            for partition_record_key, partition_row in local_map.items():
-                if partition_record_key not in partitions_map:
-                    partitions_map[partition_record_key] = partition_row
-                else:
-                    existing = partitions_map[partition_record_key]
-                    existing["record_count"] += partition_row["record_count"]
-                    existing["file_count"] += partition_row["file_count"]
-                    existing["total_data_file_size_in_bytes"] += partition_row["total_data_file_size_in_bytes"]
-                    existing["position_delete_record_count"] += partition_row["position_delete_record_count"]
-                    existing["position_delete_file_count"] += partition_row["position_delete_file_count"]
-                    existing["equality_delete_record_count"] += partition_row["equality_delete_record_count"]
-                    existing["equality_delete_file_count"] += partition_row["equality_delete_file_count"]
-
-                    if partition_row["last_updated_at"] and (
-                        not existing["last_updated_at"] or partition_row["last_updated_at"] > existing["last_updated_at"]
-                    ):
-                        existing["last_updated_at"] = partition_row["last_updated_at"]
-                        existing["last_updated_snapshot_id"] = partition_row["last_updated_snapshot_id"]
 
-        return pa.Table.from_pylist(
-            partitions_map.values(),
-            schema=table_schema,
+        scan = DataScan(
+            table_metadata=self.tbl.metadata,
+            io=self.tbl.io,
+            row_filter=row_filter,
+            case_sensitive=case_sensitive,
+            snapshot_id=snapshot.snapshot_id,
         )
 
-    def _process_manifest(self, manifest: ManifestFile) -> Dict[Tuple[str, Any], Any]:
         partitions_map: Dict[Tuple[str, Any], Any] = {}
-        for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+
+        for entry in scan.scan_plan_helper():
             partition = entry.data_file.partition
             partition_record_dict = {
-                field.name: partition[pos]
-                for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                field.name: partition[pos] for pos, field in enumerate(self.tbl.metadata.specs()[entry.data_file.spec_id].fields)
            }
             entry_snapshot = self.tbl.snapshot_by_id(entry.snapshot_id) if entry.snapshot_id is not None else None
+            self._update_partitions_map_from_manifest_entry(
+                partitions_map, entry.data_file, partition_record_dict, entry_snapshot
+            )
 
-            partition_record_key = _convert_to_hashable_type(partition_record_dict)
-            if partition_record_key not in partitions_map:
-                partitions_map[partition_record_key] = {
-                    "partition": partition_record_dict,
-                    "spec_id": entry.data_file.spec_id,
-                    "record_count": 0,
-                    "file_count": 0,
-                    "total_data_file_size_in_bytes": 0,
-                    "position_delete_record_count": 0,
-                    "position_delete_file_count": 0,
-                    "equality_delete_record_count": 0,
-                    "equality_delete_file_count": 0,
-                    "last_updated_at": entry_snapshot.timestamp_ms if entry_snapshot else None,
-                    "last_updated_snapshot_id": entry_snapshot.snapshot_id if entry_snapshot else None,
-                }
+        return pa.Table.from_pylist(
+            partitions_map.values(),
+            schema=table_schema,
+        )
 
-            partition_row = partitions_map[partition_record_key]
-
-            if entry_snapshot is not None:
-                if (
-                    partition_row["last_updated_at"] is None
-                    or partition_row["last_updated_snapshot_id"] < entry_snapshot.timestamp_ms
-                ):
-                    partition_row["last_updated_at"] = entry_snapshot.timestamp_ms
-                    partition_row["last_updated_snapshot_id"] = entry_snapshot.snapshot_id
-
-            if entry.data_file.content == DataFileContent.DATA:
-                partition_row["record_count"] += entry.data_file.record_count
-                partition_row["file_count"] += 1
-                partition_row["total_data_file_size_in_bytes"] += entry.data_file.file_size_in_bytes
-            elif entry.data_file.content == DataFileContent.POSITION_DELETES:
-                partition_row["position_delete_record_count"] += entry.data_file.record_count
-                partition_row["position_delete_file_count"] += 1
-            elif entry.data_file.content == DataFileContent.EQUALITY_DELETES:
-                partition_row["equality_delete_record_count"] += entry.data_file.record_count
-                partition_row["equality_delete_file_count"] += 1
-            else:
-                raise ValueError(f"Unknown DataFileContent ({entry.data_file.content})")
+    def _update_partitions_map_from_manifest_entry(
+        self,
+        partitions_map: Dict[Tuple[str, Any], Any],
+        file: DataFile,
+        partition_record_dict: Dict[str, Any],
+        snapshot: Optional[Snapshot],
+    ) -> None:
+        partition_record_key = _convert_to_hashable_type(partition_record_dict)
+        if partition_record_key not in partitions_map:
+            partitions_map[partition_record_key] = {
+                "partition": partition_record_dict,
+                "spec_id": file.spec_id,
+                "record_count": 0,
+                "file_count": 0,
+                "total_data_file_size_in_bytes": 0,
+                "position_delete_record_count": 0,
+                "position_delete_file_count": 0,
+                "equality_delete_record_count": 0,
+                "equality_delete_file_count": 0,
+                "last_updated_at": snapshot.timestamp_ms if snapshot else None,
+                "last_updated_snapshot_id": snapshot.snapshot_id if snapshot else None,
+            }
 
-        return partitions_map
+        partition_row = partitions_map[partition_record_key]
+
+        if snapshot is not None:
+            if partition_row["last_updated_at"] is None or partition_row["last_updated_snapshot_id"] < snapshot.timestamp_ms:
+                partition_row["last_updated_at"] = snapshot.timestamp_ms
+                partition_row["last_updated_snapshot_id"] = snapshot.snapshot_id
+
+        if file.content == DataFileContent.DATA:
+            partition_row["record_count"] += file.record_count
+            partition_row["file_count"] += 1
+            partition_row["total_data_file_size_in_bytes"] += file.file_size_in_bytes
+        elif file.content == DataFileContent.POSITION_DELETES:
+            partition_row["position_delete_record_count"] += file.record_count
+            partition_row["position_delete_file_count"] += 1
+        elif file.content == DataFileContent.EQUALITY_DELETES:
+            partition_row["equality_delete_record_count"] += file.record_count
+            partition_row["equality_delete_file_count"] += 1
+        else:
+            raise ValueError(f"Unknown DataFileContent ({file.content})")
 
     def _get_manifests_schema(self) -> "pa.Schema":
         import pyarrow as pa
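
Taken together, a hedged sketch of the resulting call surface (tbl is again an assumed, already-loaded Table; the selected columns are illustrative):

from pyiceberg.expressions import And, GreaterThanOrEqual, LessThan

# Combine the new parameters: an explicit snapshot id, a compound expression
# filter, and case-sensitive column resolution (the default).
partitions = tbl.inspect.partitions(
    snapshot_id=tbl.metadata.current_snapshot_id,  # assumes the table has a snapshot
    row_filter=And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-03-01")),
    case_sensitive=True,
)
print(partitions.to_pandas()[["record_count", "file_count"]])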

tests/integration/test_inspect_table.py

Lines changed: 90 additions & 8 deletions
@@ -18,6 +18,7 @@
 
 import math
 from datetime import date, datetime
+from typing import Union
 
 import pyarrow as pa
 import pytest
@@ -26,6 +27,13 @@
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.expressions import (
+    And,
+    BooleanExpression,
+    EqualTo,
+    GreaterThanOrEqual,
+    LessThan,
+)
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
 from pyiceberg.typedef import Properties
@@ -198,6 +206,14 @@ def _inspect_files_asserts(df: pa.Table, spark_df: DataFrame) -> None:
             assert left == right, f"Difference in column {column}: {left} != {right}"
 
 
+def _check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
+    lhs = df.to_pandas().sort_values("last_updated_at")
+    rhs = spark_df.toPandas().sort_values("last_updated_at")
+    for column in df.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right, f"Difference in column {column}: {left} != {right}"
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_inspect_snapshots(
@@ -581,18 +597,84 @@ def test_inspect_partitions_partitioned(spark: SparkSession, session_catalog: Ca
     """
     )
 
-    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
-        lhs = df.to_pandas().sort_values("spec_id")
-        rhs = spark_df.toPandas().sort_values("spec_id")
-        for column in df.column_names:
-            for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-                assert left == right, f"Difference in column {column}: {left} != {right}"
-
     tbl = session_catalog.load_table(identifier)
     for snapshot in tbl.metadata.snapshots:
         df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id)
         spark_df = spark.sql(f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id}")
-        check_pyiceberg_df_equals_spark_df(df, spark_df)
+        _check_pyiceberg_df_equals_spark_df(df, spark_df)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_inspect_partitions_partitioned_with_filter(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = "default.table_metadata_partitions_with_filter"
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    spark.sql(
+        f"""
+        CREATE TABLE {identifier} (
+            name string,
+            dt date
+        )
+        PARTITIONED BY (dt)
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('John', CAST('2021-01-01' AS date))
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('Doe', CAST('2021-01-05' AS date))
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {identifier} VALUES ('Jenny', CAST('2021-02-01' AS date))
+        """
+    )
+
+    tbl = session_catalog.load_table(identifier)
+    for snapshot in tbl.metadata.snapshots:
+        test_cases: list[tuple[Union[str, BooleanExpression], str]] = [
+            ("dt >= '2021-01-01'", "partition.dt >= '2021-01-01'"),
+            (GreaterThanOrEqual("dt", "2021-01-01"), "partition.dt >= '2021-01-01'"),
+            ("dt >= '2021-01-01' and dt < '2021-03-01'", "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'"),
+            (
+                And(GreaterThanOrEqual("dt", "2021-01-01"), LessThan("dt", "2021-03-01")),
+                "partition.dt >= '2021-01-01' AND partition.dt < '2021-03-01'",
+            ),
+            ("dt == '2021-02-01'", "partition.dt = '2021-02-01'"),
+            (EqualTo("dt", "2021-02-01"), "partition.dt = '2021-02-01'"),
+        ]
+        for filter_predicate_lt, filter_predicate_rt in test_cases:
+            df = tbl.inspect.partitions(snapshot_id=snapshot.snapshot_id, row_filter=filter_predicate_lt)
+            spark_df = spark.sql(
+                f"SELECT * FROM {identifier}.partitions VERSION AS OF {snapshot.snapshot_id} WHERE {filter_predicate_rt}"
+            )
+            _check_pyiceberg_df_equals_spark_df(df, spark_df)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog")])
+def test_inspect_partitions_partitioned_transform_with_filter(spark: SparkSession, catalog: Catalog) -> None:
+    for table_name, predicate, partition_predicate in [
+        ("test_partitioned_by_identity", "ts >= '2023-03-05T00:00:00+00:00'", "ts >= '2023-03-05T00:00:00+00:00'"),
+        ("test_partitioned_by_years", "dt >= '2023-03-05'", "dt_year >= 53"),
+        ("test_partitioned_by_months", "dt >= '2023-03-05'", "dt_month >= 638"),
+        ("test_partitioned_by_days", "ts >= '2023-03-05T00:00:00+00:00'", "ts_day >= '2023-03-05'"),
+    ]:
+        table = catalog.load_table(f"default.{table_name}")
+        df = table.inspect.partitions(row_filter=predicate)
+        expected_df = spark.sql(f"select * from default.{table_name}.partitions where partition.{partition_predicate}")
+        assert len(df.to_pandas()) == len(expected_df.toPandas())
 
 
 @pytest.mark.integration
