Skip to content

Commit 9257a6d

Browse files
committed
feat: make LiteralPredicate serializable via internal IcebergBaseModel
1 parent 7c6792e commit 9257a6d

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

pyiceberg/expressions/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535

3636
from pydantic import Field
3737

38+
from pydantic import Field
39+
3840
from pyiceberg.expressions.literals import (
3941
AboveMax,
4042
BelowMin,
@@ -750,6 +752,39 @@ def __init__(self, term: Union[str, UnboundTerm[Any]], literal: Union[L, Literal
750752
super().__init__(term)
751753
self.literal = _to_literal(literal) # pylint: disable=W0621
752754

755+
# ---- JSON (Pydantic) serialization helpers ----
756+
757+
class _LiteralPredicateModel(IcebergBaseModel):
758+
type: str = Field(alias="type")
759+
term: str
760+
value: Any
761+
762+
def _json_op(self) -> str:
763+
mapping = {
764+
EqualTo: "eq",
765+
NotEqualTo: "not-eq",
766+
LessThan: "lt",
767+
LessThanOrEqual: "lt-eq",
768+
GreaterThan: "gt",
769+
GreaterThanOrEqual: "gt-eq",
770+
StartsWith: "starts-with",
771+
NotStartsWith: "not-starts-with",
772+
}
773+
for cls, op in mapping.items():
774+
if isinstance(self, cls):
775+
return op
776+
raise ValueError(f"Unknown LiteralPredicate: {type(self).__name__}")
777+
778+
def model_dump(self, **kwargs: Any) -> dict:
779+
term_name = getattr(self.term, "name", str(self.term))
780+
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump(**kwargs)
781+
782+
def model_dump_json(self, **kwargs: Any) -> str:
783+
term_name = getattr(self.term, "name", str(self.term))
784+
return self._LiteralPredicateModel(type=self._json_op(), term=term_name, value=self.literal.value).model_dump_json(
785+
**kwargs
786+
)
787+
753788
def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredicate[L]:
754789
bound_term = self.term.bind(schema, case_sensitive)
755790
lit = self.literal.to(bound_term.ref().field.field_type)

tests/expressions/test_expressions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@
5555
NotIn,
5656
NotNaN,
5757
NotNull,
58+
NotStartsWith,
5859
Or,
5960
Reference,
61+
StartsWith,
6062
UnboundPredicate,
6163
)
6264
from pyiceberg.expressions.literals import Literal, literal
@@ -933,6 +935,7 @@ def test_bound_less_than_or_equal(term: BoundReference[Any]) -> None:
933935

934936
def test_equal_to() -> None:
935937
equal_to = EqualTo(Reference("a"), literal("a"))
938+
assert equal_to.model_dump_json() == '{"type":"eq","term":"a","value":"a"}'
936939
assert str(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
937940
assert repr(equal_to) == "EqualTo(term=Reference(name='a'), literal=literal('a'))"
938941
assert equal_to == eval(repr(equal_to))
@@ -941,6 +944,7 @@ def test_equal_to() -> None:
941944

942945
def test_not_equal_to() -> None:
943946
not_equal_to = NotEqualTo(Reference("a"), literal("a"))
947+
assert not_equal_to.model_dump_json() == '{"type":"not-eq","term":"a","value":"a"}'
944948
assert str(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
945949
assert repr(not_equal_to) == "NotEqualTo(term=Reference(name='a'), literal=literal('a'))"
946950
assert not_equal_to == eval(repr(not_equal_to))
@@ -949,6 +953,7 @@ def test_not_equal_to() -> None:
949953

950954
def test_greater_than_or_equal_to() -> None:
951955
greater_than_or_equal_to = GreaterThanOrEqual(Reference("a"), literal("a"))
956+
assert greater_than_or_equal_to.model_dump_json() == '{"type":"gt-eq","term":"a","value":"a"}'
952957
assert str(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
953958
assert repr(greater_than_or_equal_to) == "GreaterThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
954959
assert greater_than_or_equal_to == eval(repr(greater_than_or_equal_to))
@@ -957,6 +962,7 @@ def test_greater_than_or_equal_to() -> None:
957962

958963
def test_greater_than() -> None:
959964
greater_than = GreaterThan(Reference("a"), literal("a"))
965+
assert greater_than.model_dump_json() == '{"type":"gt","term":"a","value":"a"}'
960966
assert str(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
961967
assert repr(greater_than) == "GreaterThan(term=Reference(name='a'), literal=literal('a'))"
962968
assert greater_than == eval(repr(greater_than))
@@ -965,14 +971,26 @@ def test_greater_than() -> None:
965971

966972
def test_less_than() -> None:
967973
less_than = LessThan(Reference("a"), literal("a"))
974+
assert less_than.model_dump_json() == '{"type":"lt","term":"a","value":"a"}'
968975
assert str(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
969976
assert repr(less_than) == "LessThan(term=Reference(name='a'), literal=literal('a'))"
970977
assert less_than == eval(repr(less_than))
971978
assert less_than == pickle.loads(pickle.dumps(less_than))
972979

973980

981+
def test_starts_with() -> None:
982+
starts_with = StartsWith(Reference("a"), literal("a"))
983+
assert starts_with.model_dump_json() == '{"type":"starts-with","term":"a","value":"a"}'
984+
985+
986+
def test_not_starts_with() -> None:
987+
not_starts_with = NotStartsWith(Reference("a"), literal("a"))
988+
assert not_starts_with.model_dump_json() == '{"type":"not-starts-with","term":"a","value":"a"}'
989+
990+
974991
def test_less_than_or_equal() -> None:
975992
less_than_or_equal = LessThanOrEqual(Reference("a"), literal("a"))
993+
assert less_than_or_equal.model_dump_json() == '{"type":"lt-eq","term":"a","value":"a"}'
976994
assert str(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
977995
assert repr(less_than_or_equal) == "LessThanOrEqual(term=Reference(name='a'), literal=literal('a'))"
978996
assert less_than_or_equal == eval(repr(less_than_or_equal))

0 commit comments

Comments
 (0)