Skip to content

Commit 11d54a9

Browse files
author
Roman Shanin
committed
Evaluate projected field predicate
1 parent 2d34662 commit 11d54a9

2 files changed

Lines changed: 37 additions & 9 deletions

File tree

pyiceberg/expressions/visitors.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -870,8 +870,11 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
870870
file_schema: Schema
871871
case_sensitive: bool
872872

873-
def __init__(self, file_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any]) -> None:
873+
def __init__(
874+
self, file_schema: Schema, projected_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any]
875+
) -> None:
874876
self.file_schema = file_schema
877+
self.projected_schema = projected_schema
875878
self.case_sensitive = case_sensitive
876879
self.projected_missing_fields = projected_missing_fields
877880

@@ -895,17 +898,30 @@ def visit_unbound_predicate(self, predicate: UnboundPredicate[L]) -> BooleanExpr
895898

896899
def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpression:
897900
file_column_name = self.file_schema.find_column_name(predicate.term.ref().field.field_id)
898-
field_name = predicate.term.ref().field.name
899901

900902
if file_column_name is None:
901903
# In the case of schema evolution, the column might not be present
902904
# in the file schema when reading older data
903905
if isinstance(predicate, BoundIsNull):
904906
return AlwaysTrue()
905-
# Projected fields are only available for identity partition fields
906-
# This means that partition pruning has already excluded partition fields that evaluate to false
907-
elif field_name in self.projected_missing_fields:
908-
return AlwaysTrue()
907+
# The projected field is absent from the file, so evaluate the predicate against the value extracted from the partition
908+
elif (field_name := predicate.term.ref().field.name) in self.projected_missing_fields:
909+
unbound_predicate: BooleanExpression
910+
if isinstance(predicate, BoundUnaryPredicate):
911+
unbound_predicate = predicate.as_unbound(field_name)
912+
elif isinstance(predicate, BoundLiteralPredicate):
913+
unbound_predicate = predicate.as_unbound(field_name, predicate.literal)
914+
elif isinstance(predicate, BoundSetPredicate):
915+
unbound_predicate = predicate.as_unbound(field_name, predicate.literals)
916+
else:
917+
raise ValueError(f"Unsupported predicate: {predicate}")
918+
field = self.projected_schema.find_field(field_name)
919+
schema = Schema(field)
920+
evaluator = expression_evaluator(schema, unbound_predicate, self.case_sensitive)
921+
if evaluator(Record(self.projected_missing_fields[field_name])):
922+
return AlwaysTrue()
923+
else:
924+
return AlwaysFalse()
909925
else:
910926
return AlwaysFalse()
911927

@@ -919,8 +935,14 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
919935
raise ValueError(f"Unsupported predicate: {predicate}")
920936

921937

922-
def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_missing_fields: dict[str, Any]) -> BooleanExpression:
923-
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_missing_fields))
938+
def translate_column_names(
939+
expr: BooleanExpression,
940+
file_schema: Schema,
941+
projected_schema: Schema,
942+
case_sensitive: bool,
943+
projected_missing_fields: dict[str, Any],
944+
) -> BooleanExpression:
945+
return visit(expr, _ColumnNameTranslator(file_schema, projected_schema, case_sensitive, projected_missing_fields))
924946

925947

926948
class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]):

pyiceberg/io/pyarrow.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,13 @@ def _task_to_record_batches(
14121412

14131413
pyarrow_filter = None
14141414
if bound_row_filter is not AlwaysTrue():
1415-
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_missing_fields=projected_missing_fields)
1415+
translated_row_filter = translate_column_names(
1416+
bound_row_filter,
1417+
file_schema,
1418+
projected_schema,
1419+
case_sensitive=case_sensitive,
1420+
projected_missing_fields=projected_missing_fields,
1421+
)
14161422
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
14171423
pyarrow_filter = expression_to_pyarrow(bound_file_filter)
14181424

0 commit comments

Comments
 (0)