|
72 | 72 | expression_to_plain_format, |
73 | 73 | rewrite_not, |
74 | 74 | rewrite_to_dnf, |
| 75 | + translate_column_names, |
75 | 76 | visit, |
76 | 77 | visit_bound_predicate, |
77 | 78 | ) |
78 | 79 | from pyiceberg.manifest import ManifestFile, PartitionFieldSummary |
79 | 80 | from pyiceberg.schema import Accessor, Schema |
80 | 81 | from pyiceberg.typedef import Record |
81 | 82 | from pyiceberg.types import ( |
| 83 | + BooleanType, |
82 | 84 | DoubleType, |
83 | 85 | FloatType, |
84 | 86 | IcebergType, |
@@ -1623,3 +1625,177 @@ def test_expression_evaluator_null() -> None: |
1623 | 1625 | assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False |
1624 | 1626 | assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False |
1625 | 1627 | assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True |
| 1628 | + |
| 1629 | + |
| 1630 | +def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None: |
| 1631 | + """Test translate_column_names with matching column names.""" |
| 1632 | + # Create a bound expression using the original schema |
| 1633 | + unbound_expr = EqualTo("foo", "test_value") |
| 1634 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) |
| 1635 | + |
| 1636 | + # File schema has the same column names |
| 1637 | + file_schema = Schema( |
| 1638 | + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), |
| 1639 | + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), |
| 1640 | + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), |
| 1641 | + schema_id=1, |
| 1642 | + ) |
| 1643 | + |
| 1644 | + # Translate column names |
| 1645 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1646 | + |
| 1647 | + # Should return the same unbound expression since column names match |
| 1648 | + assert isinstance(translated_expr, EqualTo) |
| 1649 | + assert translated_expr.term == Reference("foo") |
| 1650 | + assert translated_expr.literal == literal("test_value") |
| 1651 | + |
| 1652 | + |
| 1653 | +def test_translate_column_names_different_column_names() -> None: |
| 1654 | + """Test translate_column_names with different column names in file schema.""" |
| 1655 | + # Original schema |
| 1656 | + original_schema = Schema( |
| 1657 | + NestedField(field_id=1, name="original_name", field_type=StringType(), required=False), |
| 1658 | + schema_id=1, |
| 1659 | + ) |
| 1660 | + |
| 1661 | + # Create bound expression |
| 1662 | + unbound_expr = EqualTo("original_name", "test_value") |
| 1663 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) |
| 1664 | + |
| 1665 | + # File schema has different column name but same field ID |
| 1666 | + file_schema = Schema( |
| 1667 | + NestedField(field_id=1, name="file_column_name", field_type=StringType(), required=False), |
| 1668 | + schema_id=1, |
| 1669 | + ) |
| 1670 | + |
| 1671 | + # Translate column names |
| 1672 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1673 | + |
| 1674 | + # Should use the file schema's column name |
| 1675 | + assert isinstance(translated_expr, EqualTo) |
| 1676 | + assert translated_expr.term == Reference("file_column_name") |
| 1677 | + assert translated_expr.literal == literal("test_value") |
| 1678 | + |
| 1679 | + |
| 1680 | +def test_translate_column_names_missing_column() -> None: |
| 1681 | + """Test translate_column_names when column is missing from file schema (schema evolution).""" |
| 1682 | + # Original schema |
| 1683 | + original_schema = Schema( |
| 1684 | + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False, initial_default="default"), |
| 1685 | + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42), |
| 1686 | + schema_id=1, |
| 1687 | + ) |
| 1688 | + |
| 1689 | + # Create bound expression for the missing column |
| 1690 | + unbound_expr = EqualTo("missing_col", 42) |
| 1691 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) |
| 1692 | + |
| 1693 | + # File schema only has the existing column (field_id=1), missing field_id=2 |
| 1694 | + file_schema = Schema( |
| 1695 | + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), |
| 1696 | + schema_id=1, |
| 1697 | + ) |
| 1698 | + |
| 1699 | + # Translate column names |
| 1700 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1701 | + |
| 1702 | + # Should evaluate to AlwaysTrue because the default value (42) matches the literal (42) |
| 1703 | + assert translated_expr == AlwaysTrue() |
| 1704 | + |
| 1705 | + |
| 1706 | +def test_translate_column_names_missing_column_false_evaluation() -> None: |
| 1707 | + """Test translate_column_names when missing column evaluates to false.""" |
| 1708 | + # Original schema |
| 1709 | + original_schema = Schema( |
| 1710 | + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10), |
| 1711 | + schema_id=1, |
| 1712 | + ) |
| 1713 | + |
| 1714 | + # Create bound expression that won't match the default value |
| 1715 | + unbound_expr = EqualTo("missing_col", 42) # default is 10, literal is 42 |
| 1716 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) |
| 1717 | + |
| 1718 | + # File schema doesn't have this column |
| 1719 | + file_schema = Schema( |
| 1720 | + NestedField(field_id=1, name="other_col", field_type=StringType(), required=False), |
| 1721 | + schema_id=1, |
| 1722 | + ) |
| 1723 | + |
| 1724 | + # Translate column names |
| 1725 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1726 | + |
| 1727 | + # Should evaluate to AlwaysFalse because default value (10) doesn't match literal (42) |
| 1728 | + assert translated_expr == AlwaysFalse() |
| 1729 | + |
| 1730 | + |
| 1731 | +def test_translate_column_names_complex_expression() -> None: |
| 1732 | + """Test translate_column_names with complex boolean expressions.""" |
| 1733 | + # Original schema |
| 1734 | + original_schema = Schema( |
| 1735 | + NestedField(field_id=1, name="col1", field_type=StringType(), required=False), |
| 1736 | + NestedField(field_id=2, name="col2", field_type=IntegerType(), required=True), |
| 1737 | + schema_id=1, |
| 1738 | + ) |
| 1739 | + |
| 1740 | + # Create complex bound expression |
| 1741 | + unbound_expr = And(EqualTo("col1", "test"), GreaterThan("col2", 10)) |
| 1742 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) |
| 1743 | + |
| 1744 | + # File schema has different column names |
| 1745 | + file_schema = Schema( |
| 1746 | + NestedField(field_id=1, name="file_col1", field_type=StringType(), required=False), |
| 1747 | + NestedField(field_id=2, name="file_col2", field_type=IntegerType(), required=True), |
| 1748 | + schema_id=1, |
| 1749 | + ) |
| 1750 | + |
| 1751 | + # Translate column names |
| 1752 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1753 | + |
| 1754 | + # Should be an And expression with translated column names |
| 1755 | + assert isinstance(translated_expr, And) |
| 1756 | + assert isinstance(translated_expr.left, EqualTo) |
| 1757 | + assert translated_expr.left.term == Reference("file_col1") |
| 1758 | + assert isinstance(translated_expr.right, GreaterThan) |
| 1759 | + assert translated_expr.right.term == Reference("file_col2") |
| 1760 | + |
| 1761 | + |
| 1762 | +def test_translate_column_names_case_sensitive() -> None: |
| 1763 | + """Test translate_column_names with case sensitivity.""" |
| 1764 | + # Original schema |
| 1765 | + original_schema = Schema( |
| 1766 | + NestedField(field_id=1, name="TestColumn", field_type=StringType(), required=False), |
| 1767 | + schema_id=1, |
| 1768 | + ) |
| 1769 | + |
| 1770 | + # Create bound expression |
| 1771 | + unbound_expr = EqualTo("TestColumn", "value") |
| 1772 | + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) |
| 1773 | + |
| 1774 | + # File schema has same field ID but different case |
| 1775 | + file_schema = Schema( |
| 1776 | + NestedField(field_id=1, name="testcolumn", field_type=StringType(), required=False), |
| 1777 | + schema_id=1, |
| 1778 | + ) |
| 1779 | + |
| 1780 | + # Translate with case sensitivity |
| 1781 | + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) |
| 1782 | + |
| 1783 | + # Should use the file schema's column name (different case) |
| 1784 | + assert isinstance(translated_expr, EqualTo) |
| 1785 | + assert translated_expr.term == Reference("testcolumn") |
| 1786 | + |
| 1787 | + |
| 1788 | +def test_translate_column_names_always_true_false() -> None: |
| 1789 | + """Test translate_column_names with AlwaysTrue and AlwaysFalse expressions.""" |
| 1790 | + file_schema = Schema( |
| 1791 | + NestedField(field_id=1, name="col", field_type=StringType(), required=False), |
| 1792 | + schema_id=1, |
| 1793 | + ) |
| 1794 | + |
| 1795 | + # Test AlwaysTrue |
| 1796 | + translated_true = translate_column_names(AlwaysTrue(), file_schema, case_sensitive=True) |
| 1797 | + assert translated_true == AlwaysTrue() |
| 1798 | + |
| 1799 | + # Test AlwaysFalse |
| 1800 | + translated_false = translate_column_names(AlwaysFalse(), file_schema, case_sensitive=True) |
| 1801 | + assert translated_false == AlwaysFalse() |
0 commit comments