Skip to content

Commit bc89cfa

Browse files
committed
add unit tests
1 parent 58e5ad6 commit bc89cfa

1 file changed

Lines changed: 176 additions & 0 deletions

File tree

tests/expressions/test_visitors.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,15 @@
7272
expression_to_plain_format,
7373
rewrite_not,
7474
rewrite_to_dnf,
75+
translate_column_names,
7576
visit,
7677
visit_bound_predicate,
7778
)
7879
from pyiceberg.manifest import ManifestFile, PartitionFieldSummary
7980
from pyiceberg.schema import Accessor, Schema
8081
from pyiceberg.typedef import Record
8182
from pyiceberg.types import (
83+
BooleanType,
8284
DoubleType,
8385
FloatType,
8486
IcebergType,
@@ -1623,3 +1625,177 @@ def test_expression_evaluator_null() -> None:
16231625
assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False
16241626
assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False
16251627
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True
1628+
1629+
1630+
def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
1631+
"""Test translate_column_names with matching column names."""
1632+
# Create a bound expression using the original schema
1633+
unbound_expr = EqualTo("foo", "test_value")
1634+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
1635+
1636+
# File schema has the same column names
1637+
file_schema = Schema(
1638+
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
1639+
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
1640+
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
1641+
schema_id=1,
1642+
)
1643+
1644+
# Translate column names
1645+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1646+
1647+
# Should return the same unbound expression since column names match
1648+
assert isinstance(translated_expr, EqualTo)
1649+
assert translated_expr.term == Reference("foo")
1650+
assert translated_expr.literal == literal("test_value")
1651+
1652+
1653+
def test_translate_column_names_different_column_names() -> None:
1654+
"""Test translate_column_names with different column names in file schema."""
1655+
# Original schema
1656+
original_schema = Schema(
1657+
NestedField(field_id=1, name="original_name", field_type=StringType(), required=False),
1658+
schema_id=1,
1659+
)
1660+
1661+
# Create bound expression
1662+
unbound_expr = EqualTo("original_name", "test_value")
1663+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1664+
1665+
# File schema has different column name but same field ID
1666+
file_schema = Schema(
1667+
NestedField(field_id=1, name="file_column_name", field_type=StringType(), required=False),
1668+
schema_id=1,
1669+
)
1670+
1671+
# Translate column names
1672+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1673+
1674+
# Should use the file schema's column name
1675+
assert isinstance(translated_expr, EqualTo)
1676+
assert translated_expr.term == Reference("file_column_name")
1677+
assert translated_expr.literal == literal("test_value")
1678+
1679+
1680+
def test_translate_column_names_missing_column() -> None:
1681+
"""Test translate_column_names when column is missing from file schema (schema evolution)."""
1682+
# Original schema
1683+
original_schema = Schema(
1684+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False, initial_default="default"),
1685+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42),
1686+
schema_id=1,
1687+
)
1688+
1689+
# Create bound expression for the missing column
1690+
unbound_expr = EqualTo("missing_col", 42)
1691+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1692+
1693+
# File schema only has the existing column (field_id=1), missing field_id=2
1694+
file_schema = Schema(
1695+
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
1696+
schema_id=1,
1697+
)
1698+
1699+
# Translate column names
1700+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1701+
1702+
# Should evaluate to AlwaysTrue because the default value (42) matches the literal (42)
1703+
assert translated_expr == AlwaysTrue()
1704+
1705+
1706+
def test_translate_column_names_missing_column_false_evaluation() -> None:
1707+
"""Test translate_column_names when missing column evaluates to false."""
1708+
# Original schema
1709+
original_schema = Schema(
1710+
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
1711+
schema_id=1,
1712+
)
1713+
1714+
# Create bound expression that won't match the default value
1715+
unbound_expr = EqualTo("missing_col", 42) # default is 10, literal is 42
1716+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1717+
1718+
# File schema doesn't have this column
1719+
file_schema = Schema(
1720+
NestedField(field_id=1, name="other_col", field_type=StringType(), required=False),
1721+
schema_id=1,
1722+
)
1723+
1724+
# Translate column names
1725+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1726+
1727+
# Should evaluate to AlwaysFalse because default value (10) doesn't match literal (42)
1728+
assert translated_expr == AlwaysFalse()
1729+
1730+
1731+
def test_translate_column_names_complex_expression() -> None:
1732+
"""Test translate_column_names with complex boolean expressions."""
1733+
# Original schema
1734+
original_schema = Schema(
1735+
NestedField(field_id=1, name="col1", field_type=StringType(), required=False),
1736+
NestedField(field_id=2, name="col2", field_type=IntegerType(), required=True),
1737+
schema_id=1,
1738+
)
1739+
1740+
# Create complex bound expression
1741+
unbound_expr = And(EqualTo("col1", "test"), GreaterThan("col2", 10))
1742+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1743+
1744+
# File schema has different column names
1745+
file_schema = Schema(
1746+
NestedField(field_id=1, name="file_col1", field_type=StringType(), required=False),
1747+
NestedField(field_id=2, name="file_col2", field_type=IntegerType(), required=True),
1748+
schema_id=1,
1749+
)
1750+
1751+
# Translate column names
1752+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1753+
1754+
# Should be an And expression with translated column names
1755+
assert isinstance(translated_expr, And)
1756+
assert isinstance(translated_expr.left, EqualTo)
1757+
assert translated_expr.left.term == Reference("file_col1")
1758+
assert isinstance(translated_expr.right, GreaterThan)
1759+
assert translated_expr.right.term == Reference("file_col2")
1760+
1761+
1762+
def test_translate_column_names_case_sensitive() -> None:
1763+
"""Test translate_column_names with case sensitivity."""
1764+
# Original schema
1765+
original_schema = Schema(
1766+
NestedField(field_id=1, name="TestColumn", field_type=StringType(), required=False),
1767+
schema_id=1,
1768+
)
1769+
1770+
# Create bound expression
1771+
unbound_expr = EqualTo("TestColumn", "value")
1772+
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))
1773+
1774+
# File schema has same field ID but different case
1775+
file_schema = Schema(
1776+
NestedField(field_id=1, name="testcolumn", field_type=StringType(), required=False),
1777+
schema_id=1,
1778+
)
1779+
1780+
# Translate with case sensitivity
1781+
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True)
1782+
1783+
# Should use the file schema's column name (different case)
1784+
assert isinstance(translated_expr, EqualTo)
1785+
assert translated_expr.term == Reference("testcolumn")
1786+
1787+
1788+
def test_translate_column_names_always_true_false() -> None:
1789+
"""Test translate_column_names with AlwaysTrue and AlwaysFalse expressions."""
1790+
file_schema = Schema(
1791+
NestedField(field_id=1, name="col", field_type=StringType(), required=False),
1792+
schema_id=1,
1793+
)
1794+
1795+
# Test AlwaysTrue
1796+
translated_true = translate_column_names(AlwaysTrue(), file_schema, case_sensitive=True)
1797+
assert translated_true == AlwaysTrue()
1798+
1799+
# Test AlwaysFalse
1800+
translated_false = translate_column_names(AlwaysFalse(), file_schema, case_sensitive=True)
1801+
assert translated_false == AlwaysFalse()

0 commit comments

Comments
 (0)