diff --git a/python/flink_agents/api/memory_object.py b/python/flink_agents/api/memory_object.py index 2361abeb5..0a0a25ca7 100644 --- a/python/flink_agents/api/memory_object.py +++ b/python/flink_agents/api/memory_object.py @@ -25,6 +25,66 @@ from flink_agents.api.memory_reference import MemoryRef +# Exact builtin types Pemja materializes into native, checkpoint-stable JVM values. +# Exact-type (not isinstance): a str/int Enum or numpy scalar is a subclass that Pemja +# PyObject-wraps despite passing isinstance — accepting it would defeat the validator. +_CHECKPOINT_STABLE_SCALARS = (bool, int, float, str) + + +def validate_memory_value(path: str, value: Any) -> None: + """Reject memory values that are not recursively checkpoint-stable. + + Python memory values cross the Pemja boundary into Flink state. Only values Pemja + materializes into native JVM types survive checkpoint and restore; anything else is + stored as a stale PyObject wrapper and crashes on restore. Raises TypeError with a + clear, actionable message naming the offending location, type, and a conversion. + + Parameters + ---------- + path: str + The memory path the value is being set at, used to build the error breadcrumb. + value: Any + The value to validate. Must be recursively composed of None, bool, int, float, + str, list, or dict with str keys. + """ + _validate(value, f"value at memory path {path!r}") + + +def _validate(value: Any, where: str) -> None: + if value is None or type(value) in _CHECKPOINT_STABLE_SCALARS: + return + if isinstance(value, MemoryObject): + msg = ( + f"{where} is a MemoryObject; use new_object(...) to store a nested object " + f"instead of passing it to set()." + ) + raise TypeError(msg) + if type(value) is list: + for i, item in enumerate(value): + _validate(item, f"{where}[{i}]") + return + if type(value) is dict: + for key, val in value.items(): + if type(key) is not str: + msg = ( + f"{where} has a non-str key {key!r} ({type(key).__name__}); memory " + f"dict keys must be str. Convert with " + f"{{str(k): v for k, v in value.items()}}." + ) + raise TypeError(msg) + _validate(val, f"{where}[{key!r}]") + return + msg = ( + f"{where} has type {type(value).__name__!r}, which is not checkpoint-stable. " + f"Python memory values must be recursively composed of None, bool, int, float, " + f"str, list, or dict with str keys, because they cross the Pemja boundary into " + f"Flink state and non-primitive objects cannot be safely checkpointed/restored. " + f"Materialize it first, e.g. str(value) for a UUID, value.model_dump(mode='json')" + f" for a Pydantic model, or list(value) for a tuple/set." + ) + raise TypeError(msg) + + class MemoryType(Enum): """Memory types based on MemoryObject.""" diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py index 2982c5aa3..744eb9d95 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/flink_integration_agent.py @@ -127,7 +127,9 @@ def log_to_stdout(input: Any, total: int) -> bool: content.review += " first action, log success=" + str(log_success) + "," content.memory_info = {"total_reviews": total} - data_ref = stm.set(f"processed_items.item_{content.id}", content) + data_ref = stm.set( + f"processed_items.item_{content.id}", content.model_dump(mode="json") + ) ctx.send_event(MyEvent(value=data_ref)) @action(MyEvent.EVENT_TYPE) @@ -135,7 +137,7 @@ def log_to_stdout(input: Any, total: int) -> bool: def second_action(event: Event, ctx: RunnerContext) -> None: input_data = MyEvent.from_event(event).value stm = ctx.short_term_memory - resolved_data: ItemData = stm.get(input_data) + resolved_data = ItemData.model_validate(stm.get(input_data)) content = copy.deepcopy(resolved_data) content.review += " second action" diff --git a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py index ab71fa7b1..2eafea7c1 100644 --- a/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py +++ b/python/flink_agents/e2e_tests/e2e_tests_integration/workflow_test.py @@ -76,12 +76,13 @@ def first_action(event: Event, ctx: RunnerContext) -> None: memory = ctx.short_term_memory data_path = f"user_data.{key}" - previous_data: ProcessedData = memory.get(data_path) + stored = memory.get(data_path) + previous_data = ProcessedData.model_validate(stored) if stored else None current_count = previous_data.visit_count if previous_data else 0 new_count = current_count + 1 data_to_store = ProcessedData(content=input_message, visit_count=new_count) - data_ref = memory.set(data_path, data_to_store) + data_ref = memory.set(data_path, data_to_store.model_dump(mode="json")) ctx.send_event(MyEvent(value=data_ref)) @@ -95,7 +96,7 @@ def second_action(event: Event, ctx: RunnerContext) -> None: content_ref: MemoryRef = MyEvent.from_event(event).value memory = ctx.short_term_memory - processed_data: ProcessedData = memory.get(content_ref) + processed_data = ProcessedData.model_validate(memory.get(content_ref)) base_message = processed_data.content current_count = processed_data.visit_count @@ -104,7 +105,7 @@ def second_action(event: Event, ctx: RunnerContext) -> None: updated_data_to_store = ProcessedData( content=base_message, visit_count=new_count ) - memory.set(content_ref.path, updated_data_to_store) + memory.set(content_ref.path, updated_data_to_store.model_dump(mode="json")) final_content = f"{base_message} -> processed by second_action" key_with_count = f"(visit {new_count} times)" diff --git a/python/flink_agents/runtime/flink_memory_object.py b/python/flink_agents/runtime/flink_memory_object.py index c2840665f..6143bc27e 100644 --- a/python/flink_agents/runtime/flink_memory_object.py +++ b/python/flink_agents/runtime/flink_memory_object.py @@ -17,7 +17,11 @@ ################################################################################# from typing import Any, Dict, List -from flink_agents.api.memory_object import MemoryObject, MemoryType +from flink_agents.api.memory_object import ( + MemoryObject, + MemoryType, + validate_memory_value, +) from flink_agents.api.memory_reference import MemoryRef @@ -66,6 +70,7 @@ def get(self, path_or_ref: str | MemoryRef) -> Any: def set(self, path: str, value: Any) -> MemoryRef: """Set a value at the given path. Creates intermediate objects if needed.""" + validate_memory_value(path, value) try: j_ref = self._j_memory_object.set(path, value) return MemoryRef.create(memory_type=self.__type, path=j_ref.getPath()) diff --git a/python/flink_agents/runtime/local_memory_object.py b/python/flink_agents/runtime/local_memory_object.py index a9fb44b82..1a50dfd80 100644 --- a/python/flink_agents/runtime/local_memory_object.py +++ b/python/flink_agents/runtime/local_memory_object.py @@ -17,7 +17,11 @@ ################################################################################# from typing import Any, ClassVar, Dict, List -from flink_agents.api.memory_object import MemoryObject, MemoryType +from flink_agents.api.memory_object import ( + MemoryObject, + MemoryType, + validate_memory_value, +) from flink_agents.api.memory_reference import MemoryRef @@ -106,6 +110,7 @@ def set(self, path: str, value: Any) -> MemoryRef: if isinstance(value, LocalMemoryObject): msg = "Do not set a MemoryObject instance directly; use new_object()." raise TypeError(msg) + validate_memory_value(path, value) abs_path = self._full_path(path) diff --git a/python/flink_agents/runtime/tests/test_local_memory_object.py b/python/flink_agents/runtime/tests/test_local_memory_object.py index 5268077a7..18b565876 100644 --- a/python/flink_agents/runtime/tests/test_local_memory_object.py +++ b/python/flink_agents/runtime/tests/test_local_memory_object.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Dict, List, Set +from typing import Dict, List from flink_agents.api.memory_object import MemoryType from flink_agents.runtime.local_memory_object import LocalMemoryObject @@ -26,20 +26,6 @@ def create_memory() -> LocalMemoryObject: return LocalMemoryObject(MemoryType.SHORT_TERM, {}) -class User: - def __init__(self, name: str, age: int) -> None: - """Store for later comparison.""" - self.name = name - self.age = age - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, User) - and other.name == self.name - and other.age == self.age - ) - - def test_basic_set_get_various_types() -> None: mem = create_memory() @@ -63,16 +49,6 @@ def test_basic_set_get_various_types() -> None: mem.set("dict", d) assert mem.get("dict") == d - # set - s: Set[int] = {1, 2, 3} - mem.set("set", s) - assert mem.get("set") == s - - # custom object - user = User("Alice", 20) - mem.set("user", user) - assert mem.get("user") == user - def test_nested_set_and_get() -> None: mem = create_memory() diff --git a/python/flink_agents/runtime/tests/test_memory_reference.py b/python/flink_agents/runtime/tests/test_memory_reference.py index 10d885b1e..86428d2ae 100644 --- a/python/flink_agents/runtime/tests/test_memory_reference.py +++ b/python/flink_agents/runtime/tests/test_memory_reference.py @@ -35,20 +35,6 @@ def create_memory() -> LocalMemoryObject: return LocalMemoryObject(MemoryType.SHORT_TERM, {}) -class User: - def __init__(self, name: str, age: int) -> None: - """Store for later comparison.""" - self.name = name - self.age = age - - def __eq__(self, other: object) -> bool: - return ( - isinstance(other, User) - and other.name == self.name - and other.age == self.age - ) - - def test_set_get_involved_ref() -> None: mem = create_memory() @@ -59,8 +45,6 @@ def test_set_get_involved_ref() -> None: ("my_str", "hello", "str"), ("my_list", ["a", "b"], "list"), ("my_dict", {"x": 10}, "dict"), - ("my_set", {1, 2, 3}, "set"), - ("my_user", User("Alice", 30), "User"), ] for path, value, _expected_type_name in test_cases: @@ -90,8 +74,6 @@ def test_memory_ref_resolve() -> None: "my_str": "hello", "my_list": ["a", "b"], "my_dict": {"x": 10}, - "my_set": {1, 2, 3}, - "my_user": User("Charlie", 50), } for path, value in test_data.items(): diff --git a/python/flink_agents/runtime/tests/test_memory_value_validation.py b/python/flink_agents/runtime/tests/test_memory_value_validation.py new file mode 100644 index 000000000..c6b4418a7 --- /dev/null +++ b/python/flink_agents/runtime/tests/test_memory_value_validation.py @@ -0,0 +1,105 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +import uuid +from enum import Enum +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + +from flink_agents.api.memory_object import MemoryType, validate_memory_value +from flink_agents.runtime.flink_memory_object import ( + FlinkMemoryObject, + MemoryObjectError, +) + + +class _Model(BaseModel): + name: str + + +class _StrEnum(str, Enum): + A = "a" + + +class _Plain: + pass + + +def test_accepts_none_and_scalars() -> None: + for value in (None, True, False, 0, 1, -3, 3.14, "", "hello"): + validate_memory_value("p", value) + + +def test_accepts_nested_list_and_dict() -> None: + validate_memory_value("p", [1, "a", [2, 3], {"k": [4, None]}]) + validate_memory_value("p", {"a": 1, "b": {"c": [True, "x"]}}) + + +def test_rejects_pydantic_model() -> None: + with pytest.raises(TypeError, match="model_dump"): + validate_memory_value("p", _Model(name="x")) + + +def test_rejects_uuid() -> None: + with pytest.raises(TypeError, match=r"str\(value\)"): + validate_memory_value("p", uuid.uuid4()) + + +def test_rejects_tuple_set_frozenset() -> None: + for value in ((1, 2), {1, 2}, frozenset({1, 2})): + with pytest.raises(TypeError, match=r"list\(value\)"): + validate_memory_value("p", value) + + +def test_rejects_str_enum() -> None: + # str-Enum passes isinstance(str) but is PyObject-wrapped by Pemja; the + # exact-type check must reject it. + with pytest.raises(TypeError, match="not checkpoint-stable"): + validate_memory_value("p", _StrEnum.A) + + +def test_rejects_custom_class() -> None: + with pytest.raises(TypeError, match="not checkpoint-stable"): + validate_memory_value("p", _Plain()) + + +def test_rejects_non_str_dict_key() -> None: + with pytest.raises(TypeError, match="non-str key"): + validate_memory_value("p", {1: "v"}) + + +def test_rejects_nested_value_reports_breadcrumb() -> None: + with pytest.raises(TypeError, match=r"\[2\]\['bad'\]"): + validate_memory_value("p", [1, 2, {"bad": object()}]) + + +def test_memory_object_value_suggests_new_object() -> None: + inner = FlinkMemoryObject(MemoryType.SHORT_TERM, MagicMock()) + with pytest.raises(TypeError, match="new_object"): + validate_memory_value("p", inner) + + +def test_flink_set_raises_raw_type_error() -> None: + j_obj = MagicMock() + mem = FlinkMemoryObject(MemoryType.SHORT_TERM, j_obj) + with pytest.raises(TypeError) as exc_info: + mem.set("p", uuid.uuid4()) + # Validation fires before the Java call, raising a raw TypeError. + assert not isinstance(exc_info.value, MemoryObjectError) + j_obj.set.assert_not_called()