Skip to content

Commit 971258c

Browse files
committed
Make SetPredicate and subclasses JSON serializable with Pydantic
1 parent 40521c8 commit 971258c

2 files changed

Lines changed: 25 additions & 5 deletions

File tree

pyiceberg/expressions/__init__.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@
4242
from pyiceberg.typedef import L, StructProtocol
4343
from pyiceberg.types import DoubleType, FloatType, NestedField
4444
from pyiceberg.utils.singleton import Singleton
45+
from pyiceberg.utils.pydantic import IcebergBaseModel
46+
from pydantic import Field
4547

4648

4749
def _to_unbound_term(term: Union[str, UnboundTerm[Any]]) -> UnboundTerm[Any]:
@@ -559,12 +561,19 @@ def as_bound(self) -> Type[BoundNotNaN[L]]:
559561
return BoundNotNaN[L]
560562

561563

562-
class SetPredicate(UnboundPredicate[L], ABC):
563-
literals: Set[Literal[L]]
564+
class SetPredicate(UnboundPredicate[L], IcebergBaseModel, ABC):
565+
type: str = Field(default="in", alias="type")
566+
term: str
567+
value: list[Any]
564568

565569
def __init__(self, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]):
566-
super().__init__(term)
567-
self.literals = _to_literal_set(literals)
570+
# Convert term to string for serialization
571+
term_str = term.name if isinstance(term, Reference) else str(term)
572+
literals_set = _to_literal_set(literals)
573+
value_list = [lit.value for lit in literals_set]
574+
super().__init__(term=term_str, value=value_list)
575+
self.literals = literals_set
576+
self.term_obj = term
568577

569578
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSetPredicate[L]:
570579
bound_term = self.term.bind(schema, case_sensitive)
@@ -676,6 +685,8 @@ def as_unbound(self) -> Type[NotIn[L]]:
676685

677686

678687
class In(SetPredicate[L]):
688+
type: str = Field(default="in", alias="type")
689+
679690
def __new__( # type: ignore # pylint: disable=W0221
680691
cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]
681692
) -> BooleanExpression:
@@ -698,6 +709,8 @@ def as_bound(self) -> Type[BoundIn[L]]:
698709

699710

700711
class NotIn(SetPredicate[L], ABC):
712+
type: str = Field(default="not-in", alias="type")
713+
701714
def __new__( # type: ignore # pylint: disable=W0221
702715
cls, term: Union[str, UnboundTerm[Any]], literals: Union[Iterable[L], Iterable[Literal[L]]]
703716
) -> BooleanExpression:
@@ -712,7 +725,7 @@ def __new__( # type: ignore # pylint: disable=W0221
712725

713726
def __invert__(self) -> In[L]:
    """Transform the Expression into its negated version.

    Returns:
        The ``In[L]`` expression constructed from this predicate's term
        and its set of literals.
    """
    # The constructor stores the converted literal set on ``self.literals``;
    # the visible code never assigns a ``_literals`` attribute, so reading
    # ``self._literals`` here would fail at runtime.
    return In[L](self.term, self.literals)
716729

717730
@property
718731
def as_bound(self) -> Type[BoundNotIn[L]]:

tests/expressions/test_expressions.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -868,6 +868,13 @@ def test_not_in() -> None:
868868
assert not_in == eval(repr(not_in))
869869
assert not_in == pickle.loads(pickle.dumps(not_in))
870870

871+
def test_serialize_in() -> None:
    """Check that an ``In`` predicate dumps to the expected JSON string."""
    pred = In(term="foo", literals=[1, 2, 3])
    # The exact string is compared, so key order and compact separators
    # are part of what this test pins down.
    assert pred.model_dump_json() == '{"type":"in","term":"foo","value":[1,2,3]}'
874+
875+
def test_serialize_not_in() -> None:
    """Check that a ``NotIn`` predicate dumps to the expected JSON string."""
    pred = NotIn(term="foo", literals=[1, 2, 3])
    # The exact string is compared, so key order and compact separators
    # are part of what this test pins down.
    assert pred.model_dump_json() == '{"type":"not-in","term":"foo","value":[1,2,3]}'
871878

872879
def test_bound_equal_to(term: BoundReference[Any]) -> None:
873880
bound_equal_to = BoundEqualTo(term, literal("a"))

0 commit comments

Comments
 (0)