Skip to content

Commit 682afc5

Browse files
committed
fix
1 parent 7b2ecbb commit 682afc5

4 files changed

Lines changed: 273 additions & 11 deletions

File tree

pyiceberg/expressions/visitors.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
862862
Args:
863863
file_schema (Schema): The schema of the file.
864864
case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True.
865-
projected_fields (Dict[str, Any]): Partition field values for missing fields from projection.
865+
projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file.
866866
867867
Raises:
868868
TypeError: In the case of an UnboundPredicate.
@@ -871,12 +871,14 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]):
871871

872872
file_schema: Schema
873873
case_sensitive: bool
874-
projected_fields: Dict[str, Any]
874+
projected_field_values: Dict[str, Any]
875875

876-
def __init__(self, file_schema: Schema, case_sensitive: bool, projected_fields: Optional[Dict[str, Any]] = None) -> None:
876+
def __init__(
877+
self, file_schema: Schema, case_sensitive: bool, projected_field_values: Optional[Dict[str, Any]] = None
878+
) -> None:
877879
self.file_schema = file_schema
878880
self.case_sensitive = case_sensitive
879-
self.projected_fields = projected_fields or {}
881+
self.projected_field_values = projected_field_values or {}
880882

881883
def visit_true(self) -> BooleanExpression:
882884
return AlwaysTrue()
@@ -901,9 +903,8 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
901903
file_column_name = self.file_schema.find_column_name(field.field_id)
902904

903905
if file_column_name is None:
904-
# In the case of schema evolution, the column might not be present
905-
# we can use the default value as a constant and evaluate it against
906-
# the predicate
906+
# In the case of schema evolution or column projection, the column might not be present in the file schema.
907+
# we can use the projected value or the field's default value as a constant and evaluate it against the predicate
907908
pred: BooleanExpression
908909
if isinstance(predicate, BoundUnaryPredicate):
909910
pred = predicate.as_unbound(field.name)
@@ -917,7 +918,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
917918
return (
918919
AlwaysTrue()
919920
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
920-
Record(field.initial_default or self.projected_fields.get(field.name, None))
921+
Record(self.projected_field_values.get(field.name, None) or field.initial_default)
921922
)
922923
else AlwaysFalse()
923924
)
@@ -933,9 +934,9 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
933934

934935

935936
def translate_column_names(
936-
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_fields: Optional[Dict[str, Any]] = None
937+
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Optional[Dict[str, Any]] = None
937938
) -> BooleanExpression:
938-
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_fields))
939+
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
939940

940941

941942
class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]):

pyiceberg/io/pyarrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,7 @@ def _task_to_record_batches(
14691469
pyarrow_filter = None
14701470
if bound_row_filter is not AlwaysTrue():
14711471
translated_row_filter = translate_column_names(
1472-
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_fields=projected_missing_fields
1472+
bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields
14731473
)
14741474
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
14751475
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

tests/expressions/test_visitors.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@
7272
expression_to_plain_format,
7373
rewrite_not,
7474
rewrite_to_dnf,
75+
translate_column_names,
7576
visit,
7677
visit_bound_predicate,
7778
)
7879
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
7980
from pyiceberg.schema import Accessor, Schema
8081
from pyiceberg.typedef import Record
8182
from pyiceberg.types import (
83+
BooleanType,
8284
DoubleType,
8385
FloatType,
8486
IcebergType,
@@ -1623,3 +1625,252 @@ def test_expression_evaluator_null() -> None:
16231625
assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False
16241626
assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False
16251627
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True
1628+
1629+
1630+
def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
1631+
"""Test translate_column_names with matching column names."""
1632+
# Create a bound expression using the original schema
1633+
unbound_expr = EqualTo("foo", "test_value")
1634+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
1635+
1636+
# File schema has the same column names
1637+
file_schema = Schema(
1638+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
1639+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
1640+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
1641+
schema_id=1,
1642+
)
1643+
1644+
# Translate column names
1645+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1646+
1647+
# Should return an unbound expression with the same column name since they match
1648+
assert isinstance(translated_expr, EqualTo)
1649+
assert translated_expr.term == Reference("foo")
1650+
assert translated_expr.literal == literal("test_value")
1651+
1652+
1653+
def test_translate_column_names_different_column_names() -> None:
1654+
"""Test translate_column_names with different column names in file schema."""
1655+
# Original schema
1656+
original_schema = Schema(
1657+
NestedField(field_id=1, name="original_name", field_type=StringType(), required=False),
1658+
schema_id=1,
1659+
)
1660+
1661+
# Create bound expression
1662+
unbound_expr = EqualTo("original_name", "test_value")
1663+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1664+
1665+
# File schema has different column name but same field ID
1666+
file_schema = Schema(
1667+
NestedField(field_id=1, name="file_column_name", field_type=StringType(), required=False),
1668+
schema_id=1,
1669+
)
1670+
1671+
# Translate column names
1672+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1673+
1674+
# Should use the file schema's column name
1675+
assert isinstance(translated_expr, EqualTo)
1676+
assert translated_expr.term == Reference("file_column_name")
1677+
assert translated_expr.literal == literal("test_value")
1678+
1679+
1680+
def test_translate_column_names_missing_column() -> None:
1681+
"""Test translate_column_names when column is missing from file schema (such as in schema evolution)."""
1682+
# Original schema
1683+
original_schema = Schema(
1684+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1685+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1686+
schema_id=1,
1687+
)
1688+
1689+
# Create bound expression for the missing column
1690+
unbound_expr = EqualTo("missing_col", 42)
1691+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1692+
1693+
# File schema only has the existing column (field_id=1), missing field_id=2
1694+
file_schema = Schema(
1695+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1696+
schema_id=1,
1697+
)
1698+
1699+
# Translate column names
1700+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1701+
1702+
# missing_col's initial_default (None) does not match the expression literal (42)
1703+
assert translated_expr == AlwaysFalse()
1704+
1705+
1706+
def test_translate_column_names_missing_column_is_null() -> None:
1707+
"""Test translate_column_names when missing column is checked for null."""
1708+
# Original schema
1709+
original_schema = Schema(
1710+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1711+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1712+
schema_id=1,
1713+
)
1714+
1715+
# Create bound expression for the missing column
1716+
unbound_expr = IsNull("missing_col")
1717+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1718+
1719+
# File schema only has the existing column (field_id=1), missing field_id=2
1720+
file_schema = Schema(
1721+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1722+
schema_id=1,
1723+
)
1724+
1725+
# Translate column names
1726+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1727+
1728+
# Should evaluate to AlwaysTrue because the missing column is treated as null
1729+
# missing_col's initial_default (None) satisfies the IsNull predicate
1730+
assert translated_expr == AlwaysTrue()
1731+
1732+
1733+
def test_translate_column_names_missing_column_with_initial_default() -> None:
1734+
"""Test translate_column_names when missing column has initial_default that matches expression."""
1735+
# Original schema
1736+
original_schema = Schema(
1737+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1738+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42),
1739+
schema_id=1,
1740+
)
1741+
1742+
# Create bound expression for the missing column
1743+
unbound_expr = EqualTo("missing_col", 42)
1744+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1745+
1746+
# File schema only has the existing column (field_id=1), missing field_id=2
1747+
file_schema = Schema(
1748+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1749+
schema_id=1,
1750+
)
1751+
1752+
# Translate column names
1753+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1754+
1755+
# Should evaluate to AlwaysTrue because the initial_default value (42) matches the literal (42)
1756+
assert translated_expr == AlwaysTrue()
1757+
1758+
1759+
def test_translate_column_names_missing_column_initial_default_mismatch() -> None:
1760+
"""Test translate_column_names when missing column's initial_default doesn't match expression."""
1761+
# Original schema
1762+
original_schema = Schema(
1763+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
1764+
schema_id=1,
1765+
)
1766+
1767+
# Create bound expression that won't match the default value
1768+
unbound_expr = EqualTo("missing_col", 42)
1769+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1770+
1771+
# File schema doesn't have this column
1772+
file_schema = Schema(
1773+
NestedField(field_id=1, name="other_col", field_type=StringType(), required=False),
1774+
schema_id=1,
1775+
)
1776+
1777+
# Translate column names
1778+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1779+
1780+
# Should evaluate to AlwaysFalse because initial_default value (10) doesn't match literal (42)
1781+
assert translated_expr == AlwaysFalse()
1782+
1783+
1784+
def test_translate_column_names_projected_field_matches() -> None:
1785+
"""Test translate_column_names with projected field value that matches expression."""
1786+
# Original schema with a field that has no initial_default (defaults to None)
1787+
original_schema = Schema(
1788+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1789+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1790+
schema_id=1,
1791+
)
1792+
1793+
# Create bound expression for the missing column
1794+
unbound_expr = EqualTo("missing_col", 42)
1795+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1796+
1797+
# File schema only has the existing column (field_id=1), missing field_id=2
1798+
file_schema = Schema(
1799+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1800+
schema_id=1,
1801+
)
1802+
1803+
# Projected column that is missing in the file schema
1804+
projected_field_values = {"missing_col": 42}
1805+
1806+
# Translate column names
1807+
translated_expr = translate_column_names(
1808+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1809+
)
1810+
1811+
# Should evaluate to AlwaysTrue since projected field value matches the expression literal
1812+
# even though the field is missing in the file schema
1813+
assert translated_expr == AlwaysTrue()
1814+
1815+
1816+
def test_translate_column_names_projected_field_mismatch() -> None:
1817+
"""Test translate_column_names with projected field value that doesn't match expression."""
1818+
# Original schema with a field that has no initial_default (defaults to None)
1819+
original_schema = Schema(
1820+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1821+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False),
1822+
schema_id=1,
1823+
)
1824+
1825+
# Create bound expression for the missing column
1826+
unbound_expr = EqualTo("missing_col", 42)
1827+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1828+
1829+
# File schema only has the existing column (field_id=1), missing field_id=2
1830+
file_schema = Schema(
1831+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1832+
schema_id=1,
1833+
)
1834+
1835+
# Projected column that is missing in the file schema
1836+
projected_field_values = {"missing_col": 1}
1837+
1838+
# Translate column names
1839+
translated_expr = translate_column_names(
1840+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1841+
)
1842+
1843+
# Should evaluate to AlwaysFalse since projected field value does not match the expression literal
1844+
assert translated_expr == AlwaysFalse()
1845+
1846+
1847+
def test_translate_column_names_projected_field_and_initial_default() -> None:
1848+
"""Test translate_column_names with both projected field values and initial_default values."""
1849+
# Original schema with mixed field configurations
1850+
original_schema = Schema(
1851+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1852+
NestedField(field_id=2, name="missing_col_1", field_type=IntegerType(), required=False),
1853+
NestedField(field_id=3, name="missing_col_2", field_type=StringType(), required=False, initial_default="test"),
1854+
schema_id=1,
1855+
)
1856+
1857+
# Create bound expression for both missing columns
1858+
unbound_expr = And(EqualTo("missing_col_1", 42), EqualTo("missing_col_2", "test"))
1859+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1860+
1861+
# File schema only has the existing column (field_id=1), missing field_id=2 and field_id=3
1862+
file_schema = Schema(
1863+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1864+
schema_id=1,
1865+
)
1866+
1867+
# Projected value for one missing column
1868+
projected_field_values = {"missing_col_1": 42}
1869+
1870+
# Translate column names
1871+
translated_expr = translate_column_names(
1872+
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1873+
)
1874+
1875+
# Should evaluate to AlwaysTrue since both missing_col_1's projected value and missing_col_2's initial_default match their respective expression literals
1876+
assert translated_expr == AlwaysTrue()

tests/io/test_pyarrow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,16 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCa
11971197
},
11981198
schema=schema,
11991199
)
1200+
# Test that row filter works with partition value projection
1201+
assert table.scan(row_filter="partition_id = 1").to_arrow() == pa.table(
1202+
{
1203+
"other_field": ["foo", "bar", "baz"],
1204+
"partition_id": [1, 1, 1],
1205+
},
1206+
schema=schema,
1207+
)
1208+
# Test that row filter does not return any rows for a non-existing partition value
1209+
assert len(table.scan(row_filter="partition_id = -1").to_arrow()) == 0
12001210

12011211

12021212
def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:

0 commit comments

Comments
 (0)