Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions agent_memory_server/filters.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,58 @@
from datetime import datetime
from enum import Enum
from typing import Self
from functools import reduce
from typing import Any, Self

from pydantic import BaseModel
from pydantic.functional_validators import model_validator
from redisvl.query.filter import FilterExpression, Num, Tag
from redisvl.utils.token_escaper import TokenEscaper


def _normalize_tag_filter_aliases(data: Any, aliases: dict[str, str]) -> Any:
if not isinstance(data, dict):
return data

normalized = dict(data)
for alias, target in aliases.items():
if alias not in normalized:
continue

alias_value = normalized.pop(alias)
if alias_value is None:
continue
if normalized.get(target) is not None:
raise ValueError(f"{alias} and {target} cannot both be set")
normalized[target] = alias_value

return normalized


def _and_filter_expressions(expressions: list[FilterExpression]) -> FilterExpression:
return reduce(lambda left, right: left & right, expressions)


class TagFilter(BaseModel):
field: str
eq: str | None = None
ne: str | None = None
any: list[str] | None = None
all: list[str] | None = None
not_in: list[str] | None = None
startswith: str | None = None

@model_validator(mode="before")
@classmethod
def normalize_aliases(cls, data: Any) -> Any:
return _normalize_tag_filter_aliases(
data,
{
"in_": "any",
"not_eq": "ne",
"none": "not_in",
},
)

@model_validator(mode="after")
def validate_filters(self) -> Self:
if self.eq is not None and self.ne is not None:
Expand All @@ -26,6 +63,8 @@ def validate_filters(self) -> Self:
raise ValueError("all cannot be an empty list")
if self.any is not None and len(self.any) == 0:
raise ValueError("any cannot be an empty list")
if self.not_in is not None and len(self.not_in) == 0:
raise ValueError("not_in cannot be an empty list")
# Validate startswith doesn't combine with other filters
if self.startswith is not None:
if self.startswith == "":
Expand All @@ -38,6 +77,8 @@ def validate_filters(self) -> Self:
raise ValueError("startswith and any cannot both be set")
if self.all is not None:
raise ValueError("startswith and all cannot both be set")
if self.not_in is not None:
raise ValueError("startswith and not_in cannot both be set")
return self

def to_filter(self) -> FilterExpression:
Expand All @@ -54,7 +95,11 @@ def to_filter(self) -> FilterExpression:
if self.any is not None:
return Tag(self.field) == self.any
if self.all is not None:
return Tag(self.field) == self.all
return _and_filter_expressions(
[Tag(self.field) == value for value in self.all]
)
if self.not_in is not None:
return Tag(self.field) != self.not_in
raise ValueError("No filter provided")


Expand All @@ -67,6 +112,18 @@ class EnumFilter(BaseModel):
ne: str | None = None
any: list[str] | None = None
all: list[str] | None = None
not_in: list[str] | None = None

@model_validator(mode="before")
@classmethod
def normalize_aliases(cls, data: Any) -> Any:
return _normalize_tag_filter_aliases(
data,
{
"in_": "any",
"not_eq": "ne",
},
)

@model_validator(mode="after")
def validate_filters(self) -> Self:
Expand All @@ -78,6 +135,8 @@ def validate_filters(self) -> Self:
raise ValueError("all cannot be an empty list")
if self.any is not None and len(self.any) == 0:
raise ValueError("any cannot be an empty list")
if self.not_in is not None and len(self.not_in) == 0:
raise ValueError("not_in cannot be an empty list")

# Validate enum values
valid_values = [e.value for e in self.enum_class]
Expand All @@ -102,6 +161,12 @@ def validate_filters(self) -> Self:
raise ValueError(
f"all value '{val}' not in valid enum values: {valid_values}"
)
if self.not_in is not None:
for val in self.not_in:
if val not in valid_values:
raise ValueError(
f"not_in value '{val}' not in valid enum values: {valid_values}"
)

return self

Expand All @@ -113,7 +178,11 @@ def to_filter(self) -> FilterExpression:
if self.any is not None:
return Tag(self.field) == self.any
if self.all is not None:
return Tag(self.field) == self.all
return _and_filter_expressions(
[Tag(self.field) == value for value in self.all]
)
if self.not_in is not None:
return Tag(self.field) != self.not_in
raise ValueError("No filter provided")


Expand Down
81 changes: 81 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from agent_memory_server.models import (
MemoryRecordResult,
MemoryRecordResults,
SearchRequest,
)


Expand Down Expand Up @@ -266,11 +267,58 @@ def test_any_cannot_be_empty_list(self):
with pytest.raises(ValueError, match="any cannot be an empty list"):
TagFilter(field="test", any=[])

def test_not_in_cannot_be_empty_list(self):
"""not_in cannot be an empty list."""
with pytest.raises(ValueError, match="not_in cannot be an empty list"):
TagFilter(field="test", not_in=[])

def test_in_alias_maps_to_any(self):
"""Legacy SDK in_ alias should produce an any filter."""
filter_obj = TagFilter(field="tags", in_=["tag1", "tag2"])
assert filter_obj.any == ["tag1", "tag2"]
assert str(filter_obj.to_filter()) == "@tags:{tag1|tag2}"

def test_not_eq_alias_maps_to_ne(self):
"""Legacy SDK not_eq alias should produce an ne filter."""
filter_obj = TagFilter(field="tags", not_eq="tag1")
assert filter_obj.ne == "tag1"
assert str(filter_obj.to_filter()) == "(-@tags:{tag1})"

def test_not_in_filter_creates_negated_union(self):
"""not_in should exclude any matching tag value."""
filter_obj = TagFilter(field="tags", not_in=["tag1", "tag2"])
assert str(filter_obj.to_filter()) == "(-@tags:{tag1|tag2})"

def test_none_alias_maps_to_not_in(self):
"""Legacy SDK none alias should produce a not_in filter."""
filter_obj = TagFilter(field="topics", none=["topic1", "topic2"])
assert filter_obj.not_in == ["topic1", "topic2"]
assert str(filter_obj.to_filter()) == "(-@topics:{topic1|topic2})"

def test_all_filter_creates_expression(self):
"""all filter should create a valid FilterExpression."""
filter_obj = TagFilter(field="tags", all=["tag1", "tag2"])
result = filter_obj.to_filter()
assert isinstance(result, FilterExpression)
assert str(result) == "(@tags:{tag1} @tags:{tag2})"

def test_search_request_accepts_legacy_sdk_tag_filter_aliases(self):
"""SDK-emitted aliases should survive request parsing as real filters."""
request = SearchRequest.model_validate(
{
"text": "query",
"session_id": {"in_": ["session-1", "session-2"]},
"namespace": {"not_eq": "archived"},
"topics": {"none": ["private"]},
}
)

assert request.session_id is not None
assert request.session_id.any == ["session-1", "session-2"]
assert request.namespace is not None
assert request.namespace.ne == "archived"
assert request.topics is not None
assert request.topics.not_in == ["private"]

def test_no_filter_provided_raises(self):
"""No filter provided should raise ValueError."""
Expand Down Expand Up @@ -363,6 +411,39 @@ def test_any_cannot_be_empty(self):
with pytest.raises(ValueError, match="any cannot be an empty list"):
EnumFilter(field="status", enum_class=_SampleEnum, any=[])

def test_not_in_cannot_be_empty(self):
"""not_in cannot be an empty list."""
with pytest.raises(ValueError, match="not_in cannot be an empty list"):
EnumFilter(field="status", enum_class=_SampleEnum, not_in=[])

def test_in_alias_maps_to_any(self):
"""Legacy SDK in_ alias should produce an any enum filter."""
filter_obj = EnumFilter(
field="status", enum_class=_SampleEnum, in_=["value1", "value2"]
)
assert filter_obj.any == ["value1", "value2"]
assert str(filter_obj.to_filter()) == "@status:{value1|value2}"

def test_not_eq_alias_maps_to_ne(self):
"""Legacy SDK not_eq alias should produce an ne enum filter."""
filter_obj = EnumFilter(field="status", enum_class=_SampleEnum, not_eq="value1")
assert filter_obj.ne == "value1"
assert str(filter_obj.to_filter()) == "(-@status:{value1})"

def test_not_in_filter_creates_negated_union(self):
"""not_in should exclude any matching enum value."""
filter_obj = EnumFilter(
field="status", enum_class=_SampleEnum, not_in=["value1", "value2"]
)
assert str(filter_obj.to_filter()) == "(-@status:{value1|value2})"

def test_not_in_with_invalid_value_raises(self):
"""not_in with invalid enum value should raise."""
with pytest.raises(ValueError, match="not in valid enum values"):
EnumFilter(
field="status", enum_class=_SampleEnum, not_in=["value1", "invalid"]
)

def test_no_filter_provided_raises(self):
"""No filter provided should raise ValueError."""
filter_obj = EnumFilter(field="status", enum_class=_SampleEnum)
Expand Down
Loading