Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions python/flink_agents/api/memory_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,66 @@
from flink_agents.api.memory_reference import MemoryRef


# Exact builtin types Pemja materializes into native, checkpoint-stable JVM values.
# Exact-type (not isinstance): a str/int Enum or numpy scalar is a subclass that Pemja
# PyObject-wraps despite passing isinstance — accepting it would defeat the validator.
_CHECKPOINT_STABLE_SCALARS = (bool, int, float, str)


def validate_memory_value(path: str, value: Any) -> None:
"""Reject memory values that are not recursively checkpoint-stable.

Python memory values cross the Pemja boundary into Flink state. Only values Pemja
materializes into native JVM types survive checkpoint and restore; anything else is
stored as a stale PyObject wrapper and crashes on restore. Raises TypeError with a
clear, actionable message naming the offending location, type, and a conversion.

Parameters
----------
path: str
The memory path the value is being set at, used to build the error breadcrumb.
value: Any
The value to validate. Must be recursively composed of None, bool, int, float,
str, list, or dict with str keys.
"""
_validate(value, f"value at memory path {path!r}")


def _validate(value: Any, where: str) -> None:
if value is None or type(value) in _CHECKPOINT_STABLE_SCALARS:
return
if isinstance(value, MemoryObject):
msg = (
f"{where} is a MemoryObject; use new_object(...) to store a nested object "
f"instead of passing it to set()."
)
raise TypeError(msg)
if type(value) is list:
for i, item in enumerate(value):
_validate(item, f"{where}[{i}]")
return
if type(value) is dict:
for key, val in value.items():
if type(key) is not str:
msg = (
f"{where} has a non-str key {key!r} ({type(key).__name__}); memory "
f"dict keys must be str. Convert with "
f"{{str(k): v for k, v in value.items()}}."
)
raise TypeError(msg)
_validate(val, f"{where}[{key!r}]")
return
msg = (
f"{where} has type {type(value).__name__!r}, which is not checkpoint-stable. "
f"Python memory values must be recursively composed of None, bool, int, float, "
f"str, list, or dict with str keys, because they cross the Pemja boundary into "
f"Flink state and non-primitive objects cannot be safely checkpointed/restored. "
f"Materialize it first, e.g. str(value) for a UUID, value.model_dump(mode='json')"
f" for a Pydantic model, or list(value) for a tuple/set."
)
raise TypeError(msg)


class MemoryType(Enum):
"""Memory types based on MemoryObject."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,17 @@ def log_to_stdout(input: Any, total: int) -> bool:
content.review += " first action, log success=" + str(log_success) + ","
content.memory_info = {"total_reviews": total}

data_ref = stm.set(f"processed_items.item_{content.id}", content)
data_ref = stm.set(
f"processed_items.item_{content.id}", content.model_dump(mode="json")
)
ctx.send_event(MyEvent(value=data_ref))

@action(MyEvent.EVENT_TYPE)
@staticmethod
def second_action(event: Event, ctx: RunnerContext) -> None:
input_data = MyEvent.from_event(event).value
stm = ctx.short_term_memory
resolved_data: ItemData = stm.get(input_data)
resolved_data = ItemData.model_validate(stm.get(input_data))

content = copy.deepcopy(resolved_data)
content.review += " second action"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ def first_action(event: Event, ctx: RunnerContext) -> None:
memory = ctx.short_term_memory

data_path = f"user_data.{key}"
previous_data: ProcessedData = memory.get(data_path)
stored = memory.get(data_path)
previous_data = ProcessedData.model_validate(stored) if stored else None
current_count = previous_data.visit_count if previous_data else 0
new_count = current_count + 1

data_to_store = ProcessedData(content=input_message, visit_count=new_count)
data_ref = memory.set(data_path, data_to_store)
data_ref = memory.set(data_path, data_to_store.model_dump(mode="json"))

ctx.send_event(MyEvent(value=data_ref))

Expand All @@ -95,7 +96,7 @@ def second_action(event: Event, ctx: RunnerContext) -> None:
content_ref: MemoryRef = MyEvent.from_event(event).value
memory = ctx.short_term_memory

processed_data: ProcessedData = memory.get(content_ref)
processed_data = ProcessedData.model_validate(memory.get(content_ref))

base_message = processed_data.content
current_count = processed_data.visit_count
Expand All @@ -104,7 +105,7 @@ def second_action(event: Event, ctx: RunnerContext) -> None:
updated_data_to_store = ProcessedData(
content=base_message, visit_count=new_count
)
memory.set(content_ref.path, updated_data_to_store)
memory.set(content_ref.path, updated_data_to_store.model_dump(mode="json"))

final_content = f"{base_message} -> processed by second_action"
key_with_count = f"(visit {new_count} times)"
Expand Down
7 changes: 6 additions & 1 deletion python/flink_agents/runtime/flink_memory_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
#################################################################################
from typing import Any, Dict, List

from flink_agents.api.memory_object import MemoryObject, MemoryType
from flink_agents.api.memory_object import (
MemoryObject,
MemoryType,
validate_memory_value,
)
from flink_agents.api.memory_reference import MemoryRef


Expand Down Expand Up @@ -66,6 +70,7 @@ def get(self, path_or_ref: str | MemoryRef) -> Any:

def set(self, path: str, value: Any) -> MemoryRef:
"""Set a value at the given path. Creates intermediate objects if needed."""
validate_memory_value(path, value)
try:
j_ref = self._j_memory_object.set(path, value)
return MemoryRef.create(memory_type=self.__type, path=j_ref.getPath())
Expand Down
7 changes: 6 additions & 1 deletion python/flink_agents/runtime/local_memory_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
#################################################################################
from typing import Any, ClassVar, Dict, List

from flink_agents.api.memory_object import MemoryObject, MemoryType
from flink_agents.api.memory_object import (
MemoryObject,
MemoryType,
validate_memory_value,
)
from flink_agents.api.memory_reference import MemoryRef


Expand Down Expand Up @@ -106,6 +110,7 @@ def set(self, path: str, value: Any) -> MemoryRef:
if isinstance(value, LocalMemoryObject):
msg = "Do not set a MemoryObject instance directly; use new_object()."
raise TypeError(msg)
validate_memory_value(path, value)

abs_path = self._full_path(path)

Expand Down
26 changes: 1 addition & 25 deletions python/flink_agents/runtime/tests/test_local_memory_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
from typing import Dict, List, Set
from typing import Dict, List

from flink_agents.api.memory_object import MemoryType
from flink_agents.runtime.local_memory_object import LocalMemoryObject
Expand All @@ -26,20 +26,6 @@ def create_memory() -> LocalMemoryObject:
return LocalMemoryObject(MemoryType.SHORT_TERM, {})


class User:
def __init__(self, name: str, age: int) -> None:
"""Store for later comparison."""
self.name = name
self.age = age

def __eq__(self, other: object) -> bool:
return (
isinstance(other, User)
and other.name == self.name
and other.age == self.age
)


def test_basic_set_get_various_types() -> None:
mem = create_memory()

Expand All @@ -63,16 +49,6 @@ def test_basic_set_get_various_types() -> None:
mem.set("dict", d)
assert mem.get("dict") == d

# set
s: Set[int] = {1, 2, 3}
mem.set("set", s)
assert mem.get("set") == s

# custom object
user = User("Alice", 20)
mem.set("user", user)
assert mem.get("user") == user


def test_nested_set_and_get() -> None:
mem = create_memory()
Expand Down
18 changes: 0 additions & 18 deletions python/flink_agents/runtime/tests/test_memory_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ def create_memory() -> LocalMemoryObject:
return LocalMemoryObject(MemoryType.SHORT_TERM, {})


class User:
def __init__(self, name: str, age: int) -> None:
"""Store for later comparison."""
self.name = name
self.age = age

def __eq__(self, other: object) -> bool:
return (
isinstance(other, User)
and other.name == self.name
and other.age == self.age
)


def test_set_get_involved_ref() -> None:
mem = create_memory()

Expand All @@ -59,8 +45,6 @@ def test_set_get_involved_ref() -> None:
("my_str", "hello", "str"),
("my_list", ["a", "b"], "list"),
("my_dict", {"x": 10}, "dict"),
("my_set", {1, 2, 3}, "set"),
("my_user", User("Alice", 30), "User"),
]

for path, value, _expected_type_name in test_cases:
Expand Down Expand Up @@ -90,8 +74,6 @@ def test_memory_ref_resolve() -> None:
"my_str": "hello",
"my_list": ["a", "b"],
"my_dict": {"x": 10},
"my_set": {1, 2, 3},
"my_user": User("Charlie", 50),
}

for path, value in test_data.items():
Expand Down
105 changes: 105 additions & 0 deletions python/flink_agents/runtime/tests/test_memory_value_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import uuid
from enum import Enum
from unittest.mock import MagicMock

import pytest
from pydantic import BaseModel

from flink_agents.api.memory_object import MemoryType, validate_memory_value
from flink_agents.runtime.flink_memory_object import (
FlinkMemoryObject,
MemoryObjectError,
)


class _Model(BaseModel):
name: str


class _StrEnum(str, Enum):
A = "a"


class _Plain:
pass


def test_accepts_none_and_scalars() -> None:
for value in (None, True, False, 0, 1, -3, 3.14, "", "hello"):
validate_memory_value("p", value)


def test_accepts_nested_list_and_dict() -> None:
validate_memory_value("p", [1, "a", [2, 3], {"k": [4, None]}])
validate_memory_value("p", {"a": 1, "b": {"c": [True, "x"]}})


def test_rejects_pydantic_model() -> None:
with pytest.raises(TypeError, match="model_dump"):
validate_memory_value("p", _Model(name="x"))


def test_rejects_uuid() -> None:
with pytest.raises(TypeError, match=r"str\(value\)"):
validate_memory_value("p", uuid.uuid4())


def test_rejects_tuple_set_frozenset() -> None:
for value in ((1, 2), {1, 2}, frozenset({1, 2})):
with pytest.raises(TypeError, match=r"list\(value\)"):
validate_memory_value("p", value)


def test_rejects_str_enum() -> None:
# str-Enum passes isinstance(str) but is PyObject-wrapped by Pemja; the
# exact-type check must reject it.
with pytest.raises(TypeError, match="not checkpoint-stable"):
validate_memory_value("p", _StrEnum.A)


def test_rejects_custom_class() -> None:
with pytest.raises(TypeError, match="not checkpoint-stable"):
validate_memory_value("p", _Plain())


def test_rejects_non_str_dict_key() -> None:
with pytest.raises(TypeError, match="non-str key"):
validate_memory_value("p", {1: "v"})


def test_rejects_nested_value_reports_breadcrumb() -> None:
with pytest.raises(TypeError, match=r"\[2\]\['bad'\]"):
validate_memory_value("p", [1, 2, {"bad": object()}])


def test_memory_object_value_suggests_new_object() -> None:
inner = FlinkMemoryObject(MemoryType.SHORT_TERM, MagicMock())
with pytest.raises(TypeError, match="new_object"):
validate_memory_value("p", inner)


def test_flink_set_raises_raw_type_error() -> None:
j_obj = MagicMock()
mem = FlinkMemoryObject(MemoryType.SHORT_TERM, j_obj)
with pytest.raises(TypeError) as exc_info:
mem.set("p", uuid.uuid4())
# Validation fires before the Java call, raising a raw TypeError.
assert not isinstance(exc_info.value, MemoryObjectError)
j_obj.set.assert_not_called()
Loading