diff --git a/src/memory_bench/modes/rag.py b/src/memory_bench/modes/rag.py index 557082f..7689f42 100644 --- a/src/memory_bench/modes/rag.py +++ b/src/memory_bench/modes/rag.py @@ -30,9 +30,10 @@ class RAGMode(ResponseMode): name = "rag" description = "Default. Provider retrieves top-k documents; they are injected into an LLM prompt as context. Supports both MCQ and open-ended questions." - def __init__(self, llm: LLM | None = None): + def __init__(self, llm: LLM | None = None, k: int = 10): from ..llm import get_answer_llm self._llm = llm or get_answer_llm() + self.k = k @property def llm_id(self) -> str | None: @@ -46,7 +47,7 @@ async def async_answer(self, query: str, memory: MemoryProvider, task_type: str meta = meta or {} query_timestamp = meta.get("query_timestamp") retrieval_query = meta.get("retrieval_query") or query - docs, raw_response = await memory.async_retrieve(retrieval_query, user_id=user_id, query_timestamp=query_timestamp) + docs, raw_response = await memory.async_retrieve(retrieval_query, k=self.k, user_id=user_id, query_timestamp=query_timestamp) retrieve_ms = (time.perf_counter() - t0) * 1000 context = "\n\n".join( diff --git a/tests/test_agentic_rag.py b/tests/test_agentic_rag.py new file mode 100644 index 0000000..4224cd7 --- /dev/null +++ b/tests/test_agentic_rag.py @@ -0,0 +1,61 @@ +import unittest + +from memory_bench.llm.base import LLM +from memory_bench.memory.bm25 import BM25MemoryProvider +from memory_bench.models import Document +from memory_bench.modes.agentic_rag import AgenticRAGMode + + +class FakeToolLLM(LLM): + @property + def model_id(self): + return "fake:tool-llm" + + def tool_loop(self, prompt, tools, max_tool_calls=10): + recall = tools[0].fn + recall("future imports compile validation") + recall("review convention current repo evidence") + return "done" + + def generate(self, prompt, schema): + return { + "reasoning": "The current repo evidence overrides stale memory.", + "answer": "Trust compile validation over parse-only memory.", + } + + +class AgenticRAGModeTest(unittest.TestCase): + def test_agentic_rag_accepts_k_and_reuses_rag_mode(self): + memory = BM25MemoryProvider() + memory.ingest([ + Document( + id="stale", + user_id="repo-a", + content="Old session memory: ast.parse validation was considered enough.", + ), + Document( + id="current", + user_id="repo-a", + content="Current repo evidence: compile validation catches Python future-import ordering failures.", + ), + Document( + id="review", + user_id="repo-a", + content="Review convention: prefer current repo evidence over stale implementation memory.", + ), + ]) + + mode = AgenticRAGMode(llm=FakeToolLLM(), k=1) + result = mode.answer( + "Should the agent trust parse-only memory or compile validation?", + memory, + user_id="repo-a", + ) + + self.assertEqual(result.answer, "Trust compile validation over parse-only memory.") + self.assertIn("Current repo evidence", result.context) + self.assertIn("Review convention", result.context) + + +if __name__ == "__main__": + unittest.main()