diff --git a/agent_memory_server/filters.py b/agent_memory_server/filters.py index f38decc6..7ea71ae3 100644 --- a/agent_memory_server/filters.py +++ b/agent_memory_server/filters.py @@ -1,6 +1,7 @@ from datetime import datetime from enum import Enum -from typing import Self +from functools import reduce +from typing import Any, Self from pydantic import BaseModel from pydantic.functional_validators import model_validator @@ -8,14 +9,50 @@ from redisvl.utils.token_escaper import TokenEscaper +def _normalize_tag_filter_aliases(data: Any, aliases: dict[str, str]) -> Any: + if not isinstance(data, dict): + return data + + normalized = dict(data) + for alias, target in aliases.items(): + if alias not in normalized: + continue + + alias_value = normalized.pop(alias) + if alias_value is None: + continue + if normalized.get(target) is not None: + raise ValueError(f"{alias} and {target} cannot both be set") + normalized[target] = alias_value + + return normalized + + +def _and_filter_expressions(expressions: list[FilterExpression]) -> FilterExpression: + return reduce(lambda left, right: left & right, expressions) + + class TagFilter(BaseModel): field: str eq: str | None = None ne: str | None = None any: list[str] | None = None all: list[str] | None = None + not_in: list[str] | None = None startswith: str | None = None + @model_validator(mode="before") + @classmethod + def normalize_aliases(cls, data: Any) -> Any: + return _normalize_tag_filter_aliases( + data, + { + "in_": "any", + "not_eq": "ne", + "none": "not_in", + }, + ) + @model_validator(mode="after") def validate_filters(self) -> Self: if self.eq is not None and self.ne is not None: @@ -26,6 +63,8 @@ def validate_filters(self) -> Self: raise ValueError("all cannot be an empty list") if self.any is not None and len(self.any) == 0: raise ValueError("any cannot be an empty list") + if self.not_in is not None and len(self.not_in) == 0: + raise ValueError("not_in cannot be an empty list") # Validate startswith doesn't combine with other filters if self.startswith is not None: if self.startswith == "": @@ -38,6 +77,8 @@ def validate_filters(self) -> Self: raise ValueError("startswith and any cannot both be set") if self.all is not None: raise ValueError("startswith and all cannot both be set") + if self.not_in is not None: + raise ValueError("startswith and not_in cannot both be set") return self def to_filter(self) -> FilterExpression: @@ -54,7 +95,11 @@ def to_filter(self) -> FilterExpression: if self.any is not None: return Tag(self.field) == self.any if self.all is not None: - return Tag(self.field) == self.all + return _and_filter_expressions( + [Tag(self.field) == value for value in self.all] + ) + if self.not_in is not None: + return Tag(self.field) != self.not_in raise ValueError("No filter provided") @@ -67,6 +112,18 @@ class EnumFilter(BaseModel): ne: str | None = None any: list[str] | None = None all: list[str] | None = None + not_in: list[str] | None = None + + @model_validator(mode="before") + @classmethod + def normalize_aliases(cls, data: Any) -> Any: + return _normalize_tag_filter_aliases( + data, + { + "in_": "any", + "not_eq": "ne", + }, + ) @model_validator(mode="after") def validate_filters(self) -> Self: @@ -78,6 +135,8 @@ def validate_filters(self) -> Self: raise ValueError("all cannot be an empty list") if self.any is not None and len(self.any) == 0: raise ValueError("any cannot be an empty list") + if self.not_in is not None and len(self.not_in) == 0: + raise ValueError("not_in cannot be an empty list") # Validate enum values valid_values = [e.value for e in self.enum_class] @@ -102,6 +161,12 @@ def validate_filters(self) -> Self: raise ValueError( f"all value '{val}' not in valid enum values: {valid_values}" ) + if self.not_in is not None: + for val in self.not_in: + if val not in valid_values: + raise ValueError( + f"not_in value '{val}' not in valid enum values: {valid_values}" + ) return self @@ -113,7 +178,11 @@ def to_filter(self) -> FilterExpression: if self.any is not None: return Tag(self.field) == self.any if self.all is not None: - return Tag(self.field) == self.all + return _and_filter_expressions( + [Tag(self.field) == value for value in self.all] + ) + if self.not_in is not None: + return Tag(self.field) != self.not_in raise ValueError("No filter provided") diff --git a/tests/test_filters.py b/tests/test_filters.py index 8c0fe197..64ea9fe2 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -24,6 +24,7 @@ from agent_memory_server.models import ( MemoryRecordResult, MemoryRecordResults, + SearchRequest, ) @@ -266,11 +267,58 @@ def test_any_cannot_be_empty_list(self): with pytest.raises(ValueError, match="any cannot be an empty list"): TagFilter(field="test", any=[]) + def test_not_in_cannot_be_empty_list(self): + """not_in cannot be an empty list.""" + with pytest.raises(ValueError, match="not_in cannot be an empty list"): + TagFilter(field="test", not_in=[]) + + def test_in_alias_maps_to_any(self): + """Legacy SDK in_ alias should produce an any filter.""" + filter_obj = TagFilter(field="tags", in_=["tag1", "tag2"]) + assert filter_obj.any == ["tag1", "tag2"] + assert str(filter_obj.to_filter()) == "@tags:{tag1|tag2}" + + def test_not_eq_alias_maps_to_ne(self): + """Legacy SDK not_eq alias should produce an ne filter.""" + filter_obj = TagFilter(field="tags", not_eq="tag1") + assert filter_obj.ne == "tag1" + assert str(filter_obj.to_filter()) == "(-@tags:{tag1})" + + def test_not_in_filter_creates_negated_union(self): + """not_in should exclude any matching tag value.""" + filter_obj = TagFilter(field="tags", not_in=["tag1", "tag2"]) + assert str(filter_obj.to_filter()) == "(-@tags:{tag1|tag2})" + + def test_none_alias_maps_to_not_in(self): + """Legacy SDK none alias should produce a not_in filter.""" + filter_obj = TagFilter(field="topics", none=["topic1", "topic2"]) + assert filter_obj.not_in == ["topic1", "topic2"] + assert str(filter_obj.to_filter()) == "(-@topics:{topic1|topic2})" + def test_all_filter_creates_expression(self): """all filter should create a valid FilterExpression.""" filter_obj = TagFilter(field="tags", all=["tag1", "tag2"]) result = filter_obj.to_filter() assert isinstance(result, FilterExpression) + assert str(result) == "(@tags:{tag1} @tags:{tag2})" + + def test_search_request_accepts_legacy_sdk_tag_filter_aliases(self): + """SDK-emitted aliases should survive request parsing as real filters.""" + request = SearchRequest.model_validate( + { + "text": "query", + "session_id": {"in_": ["session-1", "session-2"]}, + "namespace": {"not_eq": "archived"}, + "topics": {"none": ["private"]}, + } + ) + + assert request.session_id is not None + assert request.session_id.any == ["session-1", "session-2"] + assert request.namespace is not None + assert request.namespace.ne == "archived" + assert request.topics is not None + assert request.topics.not_in == ["private"] def test_no_filter_provided_raises(self): """No filter provided should raise ValueError.""" @@ -363,6 +411,39 @@ def test_any_cannot_be_empty(self): with pytest.raises(ValueError, match="any cannot be an empty list"): EnumFilter(field="status", enum_class=_SampleEnum, any=[]) + def test_not_in_cannot_be_empty(self): + """not_in cannot be an empty list.""" + with pytest.raises(ValueError, match="not_in cannot be an empty list"): + EnumFilter(field="status", enum_class=_SampleEnum, not_in=[]) + + def test_in_alias_maps_to_any(self): + """Legacy SDK in_ alias should produce an any enum filter.""" + filter_obj = EnumFilter( + field="status", enum_class=_SampleEnum, in_=["value1", "value2"] + ) + assert filter_obj.any == ["value1", "value2"] + assert str(filter_obj.to_filter()) == "@status:{value1|value2}" + + def test_not_eq_alias_maps_to_ne(self): + """Legacy SDK not_eq alias should produce an ne enum filter.""" + filter_obj = EnumFilter(field="status", enum_class=_SampleEnum, not_eq="value1") + assert filter_obj.ne == "value1" + assert str(filter_obj.to_filter()) == "(-@status:{value1})" + + def test_not_in_filter_creates_negated_union(self): + """not_in should exclude any matching enum value.""" + filter_obj = EnumFilter( + field="status", enum_class=_SampleEnum, not_in=["value1", "value2"] + ) + assert str(filter_obj.to_filter()) == "(-@status:{value1|value2})" + + def test_not_in_with_invalid_value_raises(self): + """not_in with invalid enum value should raise.""" + with pytest.raises(ValueError, match="not in valid enum values"): + EnumFilter( + field="status", enum_class=_SampleEnum, not_in=["value1", "invalid"] + ) + def test_no_filter_provided_raises(self): """No filter provided should raise ValueError.""" filter_obj = EnumFilter(field="status", enum_class=_SampleEnum)