diff --git a/eval/README.md b/eval/README.md index ec0f9ed..1eb892e 100644 --- a/eval/README.md +++ b/eval/README.md @@ -125,6 +125,56 @@ These artifacts enable reproducibility and further analysis (error inspection, e --- +## Running Gemma4-E4B through local oMLX with GPT-5.5 Judge + +For local Gemma4-E4B experiments, use the orchestration runner: + +```bash +uv run python eval/run_gemma4_e4b_omlx_eval.py \ + --data-path /path/to/HaluMem-medium.jsonl \ + --version gemma4-e4b-medium-smoke \ + --limit-users 1 +``` + +The runner: + +* starts an isolated local `omlx serve` instance for `gemma-4-E4B-it-MLX-8bit`; +* verifies `/health`, `/v1/models`, and a tiny `/v1/chat/completions` request; +* runs the local `gemma4-e4b` adapter with the local Gemma4-E4B endpoint; +* stops `omlx serve` and verifies the port is released; +* runs `evaluation.py` with `JUDGE_USE_CODEX_CONFIG=true`, so GPT-5.5 judge credentials and base URL are read from Codex config; +* writes a markdown report under `eval/reports/`. + +Default local paths: + +```text +oMLX binary: /Users/qiang/arika/.venv-omlx/bin/omlx +Gemma4-E4B model: /Users/qiang/.omlx/models/gemma-4-E4B-it-MLX-8bit +oMLX base path: /Users/qiang/.omlx-halumem-gemma4-e4b +Port: 50634 +``` + +The runner intentionally separates model roles: + +```text +System under test: memory backend + local Gemma4-E4B via oMLX +Judge: GPT-5.5 from Codex config, streaming Responses API +``` + +By default it requires at least `20 GiB` free on the oMLX base-path volume before starting. Use `--min-free-gb` to tune that guard, or `--force-low-disk` only for a deliberate low-disk smoke test. + +Useful scope controls: + +```bash +--max-sessions 0 +--max-questions-per-session 0 +--max-extracted-memories-per-session 6 +``` + +`0` means no limit for sessions/questions. `--max-extracted-memories-per-session` caps the number of candidate memories scored by the official Memory Accuracy judge. + +--- + ## Special Configurations for Some Memory Systems While the experimental setup strives to maintain consistent configurations across all evaluated systems, certain memory systems exhibit unique API constraints that necessitate specific adjustments or workarounds. diff --git a/eval/eval_gemma4_e4b_local.py b/eval/eval_gemma4_e4b_local.py new file mode 100644 index 0000000..c360477 --- /dev/null +++ b/eval/eval_gemma4_e4b_local.py @@ -0,0 +1,273 @@ +import argparse +import copy +import json +import os +import re +import time +from pathlib import Path +from typing import Iterable + +from tqdm import tqdm + +from llms import llm_request +from prompts import PROMPT_MEMOS + + +TEMPLATE_LOCAL = """Memories for user {user_id}: + +{memories} +""" + + +def iter_jsonl(file_path: Path) -> Iterable[dict]: + with file_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + yield json.loads(line) + + +def extract_user_name(persona_info: str) -> str: + match = re.search(r"Name:\s*(.*?); Gender:", persona_info) + if not match: + raise ValueError("No name found.") + return match.group(1).strip() + + +def dialogue_for_prompt(session: dict, max_chars: int = 24000) -> str: + lines: list[str] = [] + for turn in session.get("dialogue", []): + content = " ".join(str(turn.get("content", "")).split()) + if not content: + continue + lines.append(f"[{turn.get('timestamp', '')}] {turn.get('role', '')}: {content}") + dialogue = "\n".join(lines) + if len(dialogue) <= max_chars: + return dialogue + return dialogue[-max_chars:] + + +def parse_json_object(content: str) -> dict: + match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if not match: + match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL) + if match: + return json.loads(match.group(1).strip()) + + start = content.find("{") + end = content.rfind("}") + if start >= 0 and end > start: + return json.loads(content[start:end + 1]) + + raise ValueError(f"No JSON object found in model output: {content[:500]}") + + +def normalize_memory(memory: str) -> str: + return " ".join(str(memory).split()).strip(" -") + + +def fallback_user_memories(session: dict, max_memories: int) -> list[str]: + memories: list[str] = [] + for turn in session.get("dialogue", []): + if turn.get("role") != "user": + continue + content = normalize_memory(turn.get("content", "")) + if not content: + continue + memories.append(f"{turn.get('timestamp', '')}: {content}") + if len(memories) >= max_memories: + break + return memories + + +def extract_memories_with_gemma(session: dict, max_memories: int) -> list[str]: + if max_memories <= 0: + return [] + + prompt = f"""You are the memory extraction module for a personal assistant. + +Extract at most {max_memories} concise, self-contained memories from the user's messages only. + +Rules: +- Keep only durable personal facts, preferences, plans, relationships, health/work/life details, and explicit updates. +- Include relevant names, dates, places, and quantities when present. +- Ignore assistant-only claims, chit-chat, and instructions that are not user facts. +- Do not invent facts. +- Each memory must be one short sentence. +- Return only JSON with this shape: +{{"memories": ["2023-05-04: The user ..."]}} + +Conversation: +{dialogue_for_prompt(session)} +""" + try: + result = parse_json_object(llm_request(prompt)) + memories = result.get("memories", []) + except Exception as exc: + print(f"Memory extraction failed, using fallback: {exc}") + return fallback_user_memories(session, max_memories) + + normalized: list[str] = [] + seen: set[str] = set() + for memory in memories: + item = normalize_memory(memory) + if not item or item in seen: + continue + seen.add(item) + normalized.append(item) + if len(normalized) >= max_memories: + break + + if not normalized: + return fallback_user_memories(session, max_memories) + return normalized + + +def score_memory(query: str, memory: str) -> int: + query_terms = {t.lower() for t in re.findall(r"[A-Za-z0-9']+", query) if len(t) > 2} + memory_terms = {t.lower() for t in re.findall(r"[A-Za-z0-9']+", memory) if len(t) > 2} + return len(query_terms & memory_terms) + + +def search_memory(query: str, memory_bank: list[str], top_k: int) -> tuple[str, list[str], float]: + start = time.time() + scored = sorted( + ((score_memory(query, memory), idx, memory) for idx, memory in enumerate(memory_bank)), + key=lambda item: (item[0], -item[1]), + reverse=True, + ) + memories = [memory for score, _, memory in scored if score > 0][:top_k] + if not memories: + memories = memory_bank[-top_k:] + context = TEMPLATE_LOCAL.format(user_id="local", memories="\n".join(memories)) + duration_ms = (time.time() - start) * 1000 + return context, memories, duration_ms + + +def process_user( + user_data: dict, + *, + top_k: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> dict: + user_name = extract_user_name(user_data["persona_info"]) + sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] + + new_user_data = { + "uuid": user_data["uuid"], + "user_name": user_name, + "sessions": [], + } + + memory_bank: list[str] = [] + for session in tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"): + start_add = time.time() + extracted_memories = extract_memories_with_gemma( + session, + max_memories=max_extracted_memories_per_session, + ) + memory_bank.extend(extracted_memories) + + new_session = { + "memory_points": copy.deepcopy(session["memory_points"]), + "dialogue": session["dialogue"], + "extracted_memories": extracted_memories, + "add_dialogue_duration_ms": (time.time() - start_add) * 1000, + } + + for memory in new_session["memory_points"]: + if memory["is_update"] == "False" or not memory["original_memories"]: + continue + _, memories_from_system, duration_ms = search_memory(memory["memory_content"], memory_bank, top_k=10) + memory["memories_from_system"] = memories_from_system + memory["search_duration_ms"] = duration_ms + + if "questions" in session: + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] + new_session["questions"] = [] + for qa in questions: + context, _, duration_ms = search_memory(qa["question"], memory_bank, top_k=top_k) + prompt = PROMPT_MEMOS.format(context=context, question=qa["question"]) + start_response = time.time() + response = llm_request(prompt) + + new_qa = copy.deepcopy(qa) + new_qa["context"] = context + new_qa["search_duration_ms"] = duration_ms + new_qa["system_response"] = response + new_qa["response_duration_ms"] = (time.time() - start_response) * 1000 + new_session["questions"].append(new_qa) + + new_user_data["sessions"].append(new_session) + + return new_user_data + + +def main( + *, + data_path: Path, + version: str, + top_k: int, + user_num: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> Path: + frame = "gemma4-e4b" + save_path = Path("results") / f"{frame}-{version}" + save_path.mkdir(parents=True, exist_ok=True) + output_file = save_path / f"{frame}_eval_results.jsonl" + tmp_dir = save_path / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + + if output_file.exists(): + output_file.unlink() + + processed = 0 + for user_data in iter_jsonl(data_path): + processed += 1 + tmp_file = tmp_dir / f"{user_data['uuid']}.json" + result = process_user( + user_data, + top_k=top_k, + max_sessions=max_sessions, + max_questions_per_session=max_questions_per_session, + max_extracted_memories_per_session=max_extracted_memories_per_session, + ) + tmp_file.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8") + with output_file.open("a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"[{processed}/{user_num}] Saved {tmp_file}") + if processed >= user_num: + break + + print(f"Final results saved to: {output_file}") + return output_file + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=Path, required=True) + parser.add_argument("--version", required=True) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--user-num", type=int, default=1) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + args = parser.parse_args() + + main( + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + user_num=args.user_num, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, + max_extracted_memories_per_session=args.max_extracted_memories_per_session, + ) diff --git a/eval/eval_memobase.py b/eval/eval_memobase.py index 3d78b93..2fda322 100644 --- a/eval/eval_memobase.py +++ b/eval/eval_memobase.py @@ -3,9 +3,10 @@ import time import json import copy +import argparse import traceback -from datetime import datetime, timezone, timedelta +from datetime import datetime, timezone from dotenv import load_dotenv from concurrent.futures import ProcessPoolExecutor, as_completed @@ -150,7 +151,7 @@ def add_peer_memory(client, user_id, messages): def add_memory(client, user_id, dialogues, batch_size=20): start = time.time() - start_add_time = (datetime.now() - timedelta(hours=8)).strftime("%Y-%m-%d %H:%M:%S") + start_add_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") batch_size = 20 for i in range(0, len(dialogues), batch_size): @@ -167,7 +168,7 @@ def add_memory(client, user_id, dialogues, batch_size=20): bid = u.insert(blob) u.flush(sync=True) - end_add_time = (datetime.now() - timedelta(hours=8)).strftime("%Y-%m-%d %H:%M:%S") + end_add_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") memories = query_content_by_time(start_add_time, user_id) duration_ms = (time.time() - start) * 1000 @@ -217,13 +218,22 @@ def search_memory(client, query, user_id, max_token_size=1000): return "", [], duration_ms -def process_user(user_data, max_token_size, save_path, version): +def process_user( + user_data, + max_token_size, + save_path, + version, + max_sessions=0, + max_questions_per_session=0, +): user_name = extract_user_name(user_data["persona_info"]) + f"_{version}" user_id = client.add_user({ "name":user_name }) sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] tmp_dir = os.path.join(save_path, "tmp") os.makedirs(tmp_dir, exist_ok=True) @@ -289,7 +299,11 @@ def process_user(user_data, max_token_size, save_path, version): new_session["questions"] = [] - for qa in session["questions"]: + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] + + for qa in questions: context, _, duration_ms = search_memory( client=client, @@ -342,7 +356,9 @@ def main( data_path: str, version: str = "default", max_token_size: int = 500, - max_workers: int = 5 + max_workers: int = 5, + max_sessions: int = 0, + max_questions_per_session: int = 0, ): frame = "memobase" save_path = f"results/{frame}-{version}/" @@ -358,7 +374,15 @@ def main( futures = {} for idx, user_data in enumerate(iter_jsonl(data_path), 1): uuid = user_data["uuid"] - future = executor.submit(process_user, user_data, max_token_size, save_path, version) + future = executor.submit( + process_user, + user_data, + max_token_size, + save_path, + version, + max_sessions, + max_questions_per_session, + ) futures[future] = uuid total_users = idx @@ -389,12 +413,20 @@ def main( if __name__ == "__main__": - data_path = "../data/HaluMem-medium.jsonl" - version = "default" - max_token_size = 500 + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="../data/HaluMem-medium.jsonl") + parser.add_argument("--version", default="default") + parser.add_argument("--max-token-size", type=int, default=500) + parser.add_argument("--max-workers", type=int, default=5) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + args = parser.parse_args() main( - data_path=data_path, - version=version, - max_token_size=max_token_size - ) \ No newline at end of file + data_path=args.data_path, + version=args.version, + max_token_size=args.max_token_size, + max_workers=args.max_workers, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, + ) diff --git a/eval/eval_memos.py b/eval/eval_memos.py index c745b4d..eba7a28 100644 --- a/eval/eval_memos.py +++ b/eval/eval_memos.py @@ -4,6 +4,7 @@ import uuid import json import copy +import argparse import requests import traceback from datetime import datetime, timezone @@ -34,7 +35,9 @@ WAIT_TIME = 60 memos_url = os.getenv("MEMOS_URL") -headers = {"Content-Type": "application/json", "Authorization": os.getenv("MEMOS_KEY")} +headers = {"Content-Type": "application/json"} +if os.getenv("MEMOS_KEY"): + headers["Authorization"] = os.getenv("MEMOS_KEY") @retry( @@ -53,7 +56,8 @@ def add(messages, user_id, conv_id): "messages": messages, "user_id": user_id, "mem_cube_id": user_id, - "conversation_id": conv_id, + "writable_cube_ids": [user_id], + "session_id": conv_id, "mode": "fine", "async_mode": "sync", } @@ -89,14 +93,30 @@ def add_dialogue(dialogue, user_id, conv_id): for i in range(0, len(formatted_dialogue), batch_num): batch = formatted_dialogue[i:i+batch_num] response, batch_duration_ms = add(batch, user_id, conv_id) - memories.extend( - [item['memory'] for item in response["data"]] - ) + memories.extend(extract_memories_from_add_response(response)) duration_ms += batch_duration_ms return memories, duration_ms +def extract_memories_from_add_response(response_json): + memories = [] + + def walk(value): + if isinstance(value, dict): + memory = value.get("memory") or value.get("content") or value.get("text") + if isinstance(memory, str) and memory.strip(): + memories.append(memory.strip()) + for child in value.values(): + walk(child) + elif isinstance(value, list): + for child in value: + walk(child) + + walk(response_json.get("data", [])) + return list(dict.fromkeys(memories)) + + @retry( retry=retry_if_exception_type(Exception), wait=wait_fixed(WAIT_TIME), @@ -114,11 +134,13 @@ def search_memory(query, user_id, top_k, pref_top_k=6): "query": query, "user_id": user_id, "mem_cube_id": user_id, - "conversation_id": "", + "readable_cube_ids": [user_id], + "session_id": "", "top_k": top_k, "mode": os.getenv("SEARCH_MODE", "fast"), "include_preference": True, - "pref_top_k": pref_top_k + "pref_top_k": pref_top_k, + "relativity": float(os.getenv("MEMOS_RELATIVITY", "0")), }, ensure_ascii=False ) @@ -127,8 +149,8 @@ def search_memory(query, user_id, top_k, pref_top_k=6): assert json.loads(response.text)["message"] == "Search completed successfully", response.text results = json.loads(response.text)["data"] - memories = [i["memory"] for i in results["text_mem"][0]["memories"]] - pref_memories = [i["memory"] for i in results["pref_mem"][0]["memories"]] + memories = extract_memories_from_search_bucket(results.get("text_mem", [])) + pref_memories = extract_memories_from_search_bucket(results.get("pref_mem", [])) context = TEMPLATE_MEMOS.format( user_id=user_id, @@ -140,6 +162,30 @@ def search_memory(query, user_id, top_k, pref_top_k=6): return context, memories + pref_memories, duration_ms +def extract_memories_from_search_bucket(bucket): + memories = [] + + def add_memory(value): + if isinstance(value, str) and value.strip(): + memories.append(value.strip()) + elif isinstance(value, dict): + memory = value.get("memory") or value.get("content") or value.get("text") + if isinstance(memory, str) and memory.strip(): + memories.append(memory.strip()) + + if isinstance(bucket, dict): + bucket = [bucket] + + for group in bucket or []: + if isinstance(group, dict) and isinstance(group.get("memories"), list): + for item in group["memories"]: + add_memory(item) + else: + add_memory(group) + + return list(dict.fromkeys(memories)) + + def extract_user_name(persona_info: str): match = re.search(r'Name:\s*(.*?); Gender:', persona_info) @@ -150,10 +196,20 @@ def extract_user_name(persona_info: str): raise ValueError("No name found.") -def process_user(user_data, top_k, pref_top_k, save_path, version): +def process_user( + user_data, + top_k, + pref_top_k, + save_path, + version, + max_sessions=0, + max_questions_per_session=0, +): user_name = extract_user_name(user_data["persona_info"]) + f"_{version}" sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] tmp_dir = os.path.join(save_path, "tmp") os.makedirs(tmp_dir, exist_ok=True) @@ -221,7 +277,11 @@ def process_user(user_data, top_k, pref_top_k, save_path, version): new_session["questions"] = [] - for qa in session["questions"]: + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] + + for qa in questions: context, _, duration_ms = search_memory( query=qa["question"], @@ -275,7 +335,9 @@ def main( version: str = "default", top_k: int = 20, pref_top_k: int = 6, - max_workers: int = 2 + max_workers: int = 2, + max_sessions: int = 0, + max_questions_per_session: int = 0, ): frame = "memos" save_path = f"results/{frame}-{version}/" @@ -291,7 +353,16 @@ def main( futures = {} for idx, user_data in enumerate(iter_jsonl(data_path), 1): uuid = user_data["uuid"] - future = executor.submit(process_user, user_data, top_k, pref_top_k, save_path, version) + future = executor.submit( + process_user, + user_data, + top_k, + pref_top_k, + save_path, + version, + max_sessions, + max_questions_per_session, + ) futures[future] = uuid total_users = idx @@ -322,12 +393,22 @@ def main( if __name__ == "__main__": - data_path = "../data/HaluMem-medium.jsonl" - version = "default" - top_k = 20 + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="../data/HaluMem-medium.jsonl") + parser.add_argument("--version", default="default") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--pref-top-k", type=int, default=6) + parser.add_argument("--max-workers", type=int, default=2) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + args = parser.parse_args() main( - data_path=data_path, - version=version, - top_k=top_k + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + pref_top_k=args.pref_top_k, + max_workers=args.max_workers, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, ) diff --git a/eval/eval_memzero.py b/eval/eval_memzero.py index 6fb7621..9b1fa8c 100644 --- a/eval/eval_memzero.py +++ b/eval/eval_memzero.py @@ -3,6 +3,7 @@ import time import json import copy +import argparse import traceback from datetime import datetime, timezone from dotenv import load_dotenv @@ -328,12 +329,16 @@ def main( if __name__ == "__main__": - data_path = "../data/HaluMem-long.jsonl" - version = "long" - top_k = 20 + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="../data/HaluMem-long.jsonl") + parser.add_argument("--version", default="long") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--max-workers", type=int, default=2) + args = parser.parse_args() main( - data_path=data_path, - version=version, - top_k=top_k - ) \ No newline at end of file + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + max_workers=args.max_workers, + ) diff --git a/eval/eval_memzero_graph.py b/eval/eval_memzero_graph.py index 3b9d4ed..1fa766d 100644 --- a/eval/eval_memzero_graph.py +++ b/eval/eval_memzero_graph.py @@ -1,56 +1,69 @@ +import argparse +import copy +import json import os import re import time -import json -import copy -import traceback -from datetime import datetime, timezone -from dotenv import load_dotenv -from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Iterable +from mem0 import Memory from tqdm import tqdm -from mem0 import MemoryClient -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_fixed -) +from local_mem0_components import HashLangchainEmbeddings from llms import llm_request from prompts import PROMPT_MEMZERO +import mem0.memory.graph_memory as mem0_graph_memory +from langchain_neo4j import Neo4jGraph as LangchainNeo4jGraph -load_dotenv() -# Update custom instructions -custom_instructions = """ -Generate personal memories that follow these guidelines: +class CompatNeo4jGraph(LangchainNeo4jGraph): + def __init__( + self, + url=None, + username=None, + password=None, + database=None, + refresh_schema=True, + **kwargs, + ): + super().__init__( + url=url, + username=username, + password=password, + database=database, + refresh_schema=refresh_schema, + **kwargs, + ) -1. Each memory should be self-contained with complete context, including: - - The person's name, do not use "user" while creating memories - - Personal details (career aspirations, hobbies, life circumstances) - - Emotional states and reactions - - Ongoing journeys or future plans - - Specific dates when events occurred -2. Include meaningful personal narratives focusing on: - - Identity and self-acceptance journeys - - Family planning and parenting - - Creative outlets and hobbies - - Mental health and self-care activities - - Career aspirations and education goals - - Important life events and milestones +mem0_graph_memory.Neo4jGraph = CompatNeo4jGraph -3. Make each memory rich with specific details rather than general statements - - Include timeframes (exact dates when possible) - - Name specific activities (e.g., "charity race for mental health" rather than just "exercise") - - Include emotional context and personal growth elements -4. Extract memories only from user messages, not incorporating assistant responses +def cypher_safe_relationship(value: object) -> str: + relationship = re.sub(r"[^0-9A-Za-z_]+", "_", str(value or "").strip().lower()) + relationship = re.sub(r"_+", "_", relationship).strip("_") + if not relationship: + relationship = "related_to" + if relationship[0].isdigit(): + relationship = f"rel_{relationship}" + return relationship[:128] + + +_original_establish_relations = mem0_graph_memory.MemoryGraph._establish_nodes_relations_from_data + + +def _establish_cypher_safe_relations(self, data, filters, entity_type_map): + relations = _original_establish_relations(self, data, filters, entity_type_map) + for relation in relations: + if isinstance(relation, dict): + relation["relationship"] = cypher_safe_relationship(relation.get("relationship")) + return relations + + +mem0_graph_memory.MemoryGraph._establish_nodes_relations_from_data = _establish_cypher_safe_relations -5. Format each memory as a paragraph with a clear narrative structure that captures the person's experience, challenges, and aspirations -""" TEMPLATE_MEM0_GRAPH = """Memories for user {user_id}: @@ -61,295 +74,405 @@ {graph_memories} """ +CUSTOM_FACT_EXTRACTION_PROMPT = """ +You extract durable personal memories from dialogue for a long-term memory system. -RETRY_TIMES = 10 -WAIT_TIME = 60 +Rules: +- Extract facts only from user messages. +- Keep personal details, preferences, goals, plans, relationships, health/work/life events, and explicit updates. +- Preserve names, dates, places, quantities, and emotional context when present. +- Do not invent facts. +- Return JSON only with this shape: {"facts": ["fact 1", "fact 2"]}. +""" -client = MemoryClient( - api_key=os.getenv("MEM0_API_KEY"), -) -client.project.update(custom_instructions=custom_instructions) +CUSTOM_UPDATE_MEMORY_PROMPT = """ +You update a user's long-term memories. +Given old memories and newly extracted facts, decide whether each new fact should ADD a new memory, UPDATE an old one, DELETE an obsolete one, or do nothing. -@retry( - retry=retry_if_exception_type(Exception), - wait=wait_fixed(WAIT_TIME), - stop=stop_after_attempt(RETRY_TIMES), - reraise=True -) -def add_memory(client, user_id, message, timestamp, retries=3, retry_delay=1.0): - start = time.time() +Return JSON only: +{ + "memory": [ + {"event": "ADD", "text": "new memory"}, + {"event": "UPDATE", "id": "old memory id", "old_memory": "old memory", "text": "updated memory"}, + {"event": "DELETE", "id": "old memory id", "text": "memory to delete"}, + {"event": "NONE", "text": "fact already covered"} + ] +} +""" - result = client.add( - message, - user_id=user_id, - version="v2", - output_format="v1.1", - timestamp=timestamp, - enable_graph=True - ) - duration_ms = (time.time() - start) * 1000 - return result, duration_ms +def parse_json_object(content: str) -> dict: + match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if not match: + match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL) + if match: + return json.loads(match.group(1).strip()) + start = content.find("{") + end = content.rfind("}") + if start >= 0 and end > start: + return json.loads(content[start:end + 1]) + raise ValueError(f"No JSON object found in model output: {content[:500]}") -@retry( - retry=retry_if_exception_type(Exception), - wait=wait_fixed(WAIT_TIME), - stop=stop_after_attempt(RETRY_TIMES), - reraise=True -) -def search_memory( - client, - query, - user_id, - top_k=20 -): - start = time.time() +def iter_jsonl(file_path: Path) -> Iterable[dict]: + with file_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + yield json.loads(line) - results = client.search( - query=query, - top_k=top_k, - user_id=user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - filters={"AND": [{"user_id": f"{user_id}"}]}, - ) - memories = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in results["results"] - ] +def extract_user_name(persona_info: str) -> str: + match = re.search(r"Name:\s*(.*?); Gender:", persona_info) + if not match: + raise ValueError("No name found.") + return match.group(1).strip() + + +def local_openai_config() -> dict: + return { + "model": os.environ["OPENAI_MODEL"], + "api_key": os.environ.get("OPENAI_API_KEY", "local-omlx"), + "openai_base_url": os.environ["OPENAI_BASE_URL"], + "temperature": float(os.getenv("MEM0_GRAPH_LLM_TEMPERATURE", "0")), + "max_tokens": int(os.getenv("MEM0_GRAPH_LLM_MAX_TOKENS", "2048")), + } - memories = [ - f"{item['timestamp']}: {item['memory']}" for item in memories - ] - graph_memories = [ +def make_local_mem0_graph(version: str) -> Memory: + base_path = Path(os.getenv("MEM0_GRAPH_LOCAL_DIR", "runs/mem0-graph-local")) / version + base_path.mkdir(parents=True, exist_ok=True) + embedding_dims = int(os.getenv("MEM0_GRAPH_EMBEDDING_DIMS", "384")) + collection_name = f"halumem_graph_{version}".replace("-", "_") + qdrant_config = { + "collection_name": collection_name, + "embedding_model_dims": embedding_dims, + } + qdrant_url = os.getenv("MEM0_GRAPH_QDRANT_URL", "http://localhost:6333") + if qdrant_url: + qdrant_config["url"] = qdrant_url + qdrant_api_key = os.getenv("MEM0_GRAPH_QDRANT_API_KEY") + if qdrant_api_key: + qdrant_config["api_key"] = qdrant_api_key + else: + qdrant_config["path"] = str(base_path / "qdrant") + qdrant_config["on_disk"] = True + + return Memory.from_config( { - "source": relation["source"], - "relationship": relation["relationship"], - "target": relation["target"], + "version": "v1.1", + "history_db_path": str(base_path / "history.db"), + "llm": { + "provider": "openai", + "config": local_openai_config(), + }, + "embedder": { + "provider": "langchain", + "config": { + "model": HashLangchainEmbeddings(embedding_dims=embedding_dims), + "embedding_dims": embedding_dims, + }, + }, + "vector_store": { + "provider": "qdrant", + "config": qdrant_config, + }, + "graph_store": { + "provider": "neo4j", + "config": { + "url": os.getenv("MEM0_GRAPH_NEO4J_URL", "bolt://localhost:7687"), + "username": os.getenv("MEM0_GRAPH_NEO4J_USERNAME", "neo4j"), + "password": os.getenv("MEM0_GRAPH_NEO4J_PASSWORD", "12345678"), + "database": os.getenv("MEM0_GRAPH_NEO4J_DATABASE", "neo4j"), + "base_label": True, + }, + "llm": { + "provider": "openai", + "config": local_openai_config(), + }, + "custom_prompt": os.getenv( + "MEM0_GRAPH_CUSTOM_PROMPT", + "Extract compact personal entities and relationships only from user-provided facts.", + ), + }, + "custom_fact_extraction_prompt": CUSTOM_FACT_EXTRACTION_PROMPT, + "custom_update_memory_prompt": CUSTOM_UPDATE_MEMORY_PROMPT, } - for relation in results["relations"] - ] - - context = TEMPLATE_MEM0_GRAPH.format( - user_id=user_id, - memories=json.dumps(memories, indent=4), - graph_memories=json.dumps(graph_memories, indent=4) ) - duration_ms = (time.time() - start) * 1000 - return context, memories, duration_ms +def extract_facts_with_gemma(message: list[dict], max_facts: int) -> list[str]: + if max_facts <= 0: + return [] + dialogue = "\n".join( + f"{turn.get('role', '')}: {' '.join(str(turn.get('content', '')).split())}" + for turn in message + if turn.get("role") != "system" and str(turn.get("content", "")).strip() + ) + prompt = f"""{CUSTOM_FACT_EXTRACTION_PROMPT} -def extract_user_name(persona_info: str): - match = re.search(r'Name:\s*(.*?); Gender:', persona_info) +Extract at most {max_facts} facts. - if match: - username = match.group(1).strip() - return username +Conversation: +{dialogue} +""" + try: + data = parse_json_object(llm_request(prompt)) + facts = data.get("facts", []) + except Exception as exc: + print(f"[memzero-graph] fact extraction failed: {exc}") + return [] + + normalized = [] + seen = set() + for fact in facts: + item = " ".join(str(fact).split()).strip(" -") + if not item or item in seen: + continue + seen.add(item) + normalized.append(item) + if len(normalized) >= max_facts: + break + return normalized + + +def add_memory(client: Memory, user_id: str, message: list[dict], max_facts: int) -> tuple[dict, float]: + start = time.time() + facts = extract_facts_with_gemma(message, max_facts) + if facts: + result = client.add( + [{"role": "user", "content": fact} for fact in facts], + user_id=user_id, + infer=False, + ) else: - raise ValueError("No name found.") - + result = {"results": [], "relations": None} + return result, (time.time() - start) * 1000 + + +def normalize_relations(relations) -> list[dict]: + if not relations: + return [] + + if isinstance(relations, dict): + added = relations.get("added_entities") or [] + deleted = relations.get("deleted_entities") or [] + relations = added + deleted + + normalized = [] + for relation in relations: + if not isinstance(relation, dict): + continue + source = relation.get("source") or relation.get("source_name") + relationship = relation.get("relationship") + target = relation.get("target") or relation.get("destination") or relation.get("dest_name") + if source and relationship and target: + normalized.append( + { + "source": source, + "relationship": relationship, + "target": target, + } + ) + return normalized -def process_user(user_data, top_k, save_path): - user_name = extract_user_name(user_data["persona_info"]) - sessions = user_data["sessions"] +def search_memory(client: Memory, query: str, user_id: str, top_k: int = 20) -> tuple[str, list[str], float]: + start = time.time() + results = client.search(query=query, limit=top_k, user_id=user_id) + + memories = [] + for memory in results.get("results", []): + timestamp = memory.get("updated_at") or memory.get("created_at") or "" + memories.append(f"{timestamp}: {memory.get('memory', '')}") - tmp_dir = os.path.join(save_path, "tmp") - os.makedirs(tmp_dir, exist_ok=True) + graph_memories = normalize_relations(results.get("relations")) - tmp_file = os.path.join(tmp_dir, f"{user_data['uuid']}.json") + context = TEMPLATE_MEM0_GRAPH.format( + user_id=user_id, + memories=json.dumps(memories, indent=4, ensure_ascii=False), + graph_memories=json.dumps(graph_memories, indent=4, ensure_ascii=False), + ) + return context, memories, (time.time() - start) * 1000 + + +def process_user( + client: Memory, + user_data: dict, + *, + top_k: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> dict: + user_name = extract_user_name(user_data["persona_info"]) + sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] - client.delete_all(user_id=user_name) + try: + client.delete_all(user_id=user_name) + except Exception: + pass new_user_data = { "uuid": user_data["uuid"], "user_name": user_name, - "sessions": [] + "sessions": [], } - try: + for session in tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"): + new_session = { + "memory_points": copy.deepcopy(session["memory_points"]), + "dialogue": session["dialogue"], + } - for session in tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"): - new_session = { - "memory_points": session["memory_points"], - "dialogue": session["dialogue"] + formatted_dialogue = [ + { + "role": turn["role"], + "content": turn["content"], } + for turn in session["dialogue"] + ] + + result, duration_ms = add_memory( + client, + user_name, + formatted_dialogue, + max_facts=max_extracted_memories_per_session, + ) + new_session["add_dialogue_duration_ms"] = duration_ms + + if session.get("is_generated_qa_session", False): + new_session["is_generated_qa_session"] = True + del new_session["dialogue"] + del new_session["memory_points"] + new_user_data["sessions"].append(new_session) + continue - # add messages - dialogue = session["dialogue"] - formatted_dialogue = [ - { - "role": turn["role"], - "content": turn["content"], - } - for turn in dialogue - ] - - date_format = "%b %d, %Y, %H:%M:%S" - dt = datetime.strptime(session["start_time"], date_format).replace(tzinfo=timezone.utc) - iso_date = dt.isoformat() - timestamp = int(dt.timestamp()) - - result, duration_ms = add_memory( - client=client, - user_id=user_name, - message=formatted_dialogue, - timestamp=timestamp - ) - - if session.get('is_generated_qa_session', False): - new_session["add_dialogue_duration_ms"] = duration_ms - new_session["is_generated_qa_session"] = True - del new_session["dialogue"] - del new_session["memory_points"] - new_user_data["sessions"].append(new_session) - continue + new_session["extracted_memories"] = [ + item["memory"] + for item in result.get("results", []) + if item.get("memory") + ] + new_session["graph_relations"] = normalize_relations(result.get("relations")) - extracted_memories = [ - item["memory"] for item in result["results"] - ] - new_session["extracted_memories"] = extracted_memories - new_session["add_dialogue_duration_ms"] = duration_ms - - # search updated memories - for memory in new_session["memory_points"]: - if memory["is_update"] == "False" or not memory["original_memories"]: - continue - - _, memories_from_system, duration_ms = search_memory( - client=client, - query=memory["memory_content"], - user_id=user_name, - top_k=10 - ) - - memory["memories_from_system"] = memories_from_system - - # search and query - if "questions" not in session: - new_user_data["sessions"].append(new_session) + for memory in new_session["memory_points"]: + if memory["is_update"] == "False" or not memory["original_memories"]: continue + _, memories_from_system, duration_ms = search_memory( + client, + memory["memory_content"], + user_name, + top_k=10, + ) + memory["memories_from_system"] = memories_from_system + memory["search_duration_ms"] = duration_ms - new_session["questions"] = [] - - for qa in session["questions"]: - - context, _, duration_ms = search_memory( - client=client, - query=qa["question"], - user_id=user_name, - top_k=top_k - ) - - new_qa = copy.deepcopy(qa) - new_qa["context"] = context - new_qa["search_duration_ms"] = duration_ms + if "questions" not in session: + new_user_data["sessions"].append(new_session) + continue - prompt = PROMPT_MEMZERO.format( - context=context, - question=qa["question"] - ) + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] - start_time = time.time() - response = llm_request(prompt) - new_qa["system_response"] = response - new_qa["response_duration_ms"] = (time.time() - start_time) * 1000 + new_session["questions"] = [] + for qa in questions: + context, _, duration_ms = search_memory(client, qa["question"], user_name, top_k=top_k) - new_session["questions"].append(new_qa) - - new_user_data["sessions"].append(new_session) - - with open(tmp_file, "w", encoding="utf-8") as f: - json.dump(new_user_data, f, ensure_ascii=False) + prompt = PROMPT_MEMZERO.format( + context=context, + question=qa["question"], + ) - print(f"✅ Saved user {user_name} to {tmp_file}") - return {"uuid": user_data["uuid"], "status": "ok", "path": tmp_file} + start_response = time.time() + response = llm_request(prompt) - except Exception as e: - error_path = os.path.join(tmp_dir, f"{user_data['uuid']}_error.log") - with open(error_path, "w", encoding="utf-8") as f: - f.write(traceback.format_exc()) - print(f"❌ Error in user {user_name}: {e}") - return {"uuid": user_data["uuid"], "status": "error", "path": error_path} + new_qa = copy.deepcopy(qa) + new_qa["context"] = context + new_qa["search_duration_ms"] = duration_ms + new_qa["system_response"] = response + new_qa["response_duration_ms"] = (time.time() - start_response) * 1000 + new_session["questions"].append(new_qa) + new_user_data["sessions"].append(new_session) -def iter_jsonl(file_path): - with open(file_path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if line: - yield json.loads(line) + return new_user_data def main( - data_path: str, - version: str = "default", - top_k: int = 20, - max_workers: int = 2 -): + *, + data_path: Path, + version: str, + top_k: int, + user_num: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> Path: frame = "memzero-graph" - save_path = f"results/{frame}-{version}/" - os.makedirs(save_path, exist_ok=True) - - output_file = os.path.join(save_path, f"{frame}_eval_results.jsonl") - tmp_dir = os.path.join(save_path, "tmp") - os.makedirs(tmp_dir, exist_ok=True) - - start_time = time.time() - - with ProcessPoolExecutor(max_workers=max_workers) as executor: - futures = {} - for idx, user_data in enumerate(iter_jsonl(data_path), 1): - uuid = user_data["uuid"] - future = executor.submit(process_user, user_data, top_k, save_path) - futures[future] = uuid - - total_users = idx - - for i, future in enumerate(as_completed(futures), 1): - uuid = futures[future] - try: - result = future.result() - print(f"[{i}/{total_users}] ✅ Finished {uuid} ({result['status']})") - except Exception as e: - print(f"[{i}/{total_users}] ❌ Error processing {uuid}: {e}") - traceback.print_exc() - - with open(output_file, "a", encoding="utf-8") as f_out: - for file in os.listdir(tmp_dir): - if file.endswith(".json"): - file_path = os.path.join(tmp_dir, file) - try: - with open(file_path, "r", encoding="utf-8") as f_in: - data = json.load(f_in) - f_out.write(json.dumps(data, ensure_ascii=False) + "\n") - except Exception as e: - print(f"⚠️ Skipped {file}: {e}") - - elapsed = time.time() - start_time - print(f"✅ All done in {elapsed:.2f}s") - print(f"✅ Final results saved to: {output_file}") + save_path = Path("results") / f"{frame}-{version}" + save_path.mkdir(parents=True, exist_ok=True) + output_file = save_path / f"{frame}_eval_results.jsonl" + tmp_dir = save_path / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + + if output_file.exists(): + output_file.unlink() + + client = make_local_mem0_graph(version) + try: + client.reset() + except Exception: + pass + + processed = 0 + for user_data in iter_jsonl(data_path): + processed += 1 + result = process_user( + client, + user_data, + top_k=top_k, + max_sessions=max_sessions, + max_questions_per_session=max_questions_per_session, + max_extracted_memories_per_session=max_extracted_memories_per_session, + ) + + tmp_file = tmp_dir / f"{user_data['uuid']}.json" + tmp_file.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8") + with output_file.open("a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"[{processed}/{user_num}] Saved {tmp_file}") + + if processed >= user_num: + break + + print(f"Final results saved to: {output_file}") + return output_file if __name__ == "__main__": - data_path = "../data/HaluMem-long.jsonl" - version = "long" - top_k = 20 + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=Path, required=True) + parser.add_argument("--version", required=True) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--user-num", type=int, default=1) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + args = parser.parse_args() main( - data_path=data_path, - version=version, - top_k=top_k - ) \ No newline at end of file + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + user_num=args.user_num, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, + max_extracted_memories_per_session=args.max_extracted_memories_per_session, + ) + os._exit(0) diff --git a/eval/eval_memzero_local.py b/eval/eval_memzero_local.py new file mode 100644 index 0000000..5cd3c81 --- /dev/null +++ b/eval/eval_memzero_local.py @@ -0,0 +1,362 @@ +import argparse +import copy +import json +import os +import re +import time +from pathlib import Path +from typing import Iterable + +from mem0 import Memory +from tqdm import tqdm + +from local_mem0_components import HashLangchainEmbeddings +from llms import llm_request +from prompts import PROMPT_MEMZERO + + +TEMPLATE_MEM0_LOCAL = """Memories for user {user_id}: + + {memories} +""" + +CUSTOM_FACT_EXTRACTION_PROMPT = """ +You extract durable personal memories from dialogue for a long-term memory system. + +Rules: +- Extract facts only from user messages. +- Keep personal details, preferences, goals, plans, relationships, health/work/life events, and explicit updates. +- Preserve names, dates, places, quantities, and emotional context when present. +- Do not invent facts. +- Return JSON only with this shape: {"facts": ["fact 1", "fact 2"]}. +""" + +CUSTOM_UPDATE_MEMORY_PROMPT = """ +You update a user's long-term memories. + +Given old memories and newly extracted facts, decide whether each new fact should ADD a new memory, UPDATE an old one, DELETE an obsolete one, or do nothing. + +Return JSON only: +{ + "memory": [ + {"event": "ADD", "text": "new memory"}, + {"event": "UPDATE", "id": "old memory id", "old_memory": "old memory", "text": "updated memory"}, + {"event": "DELETE", "id": "old memory id", "text": "memory to delete"}, + {"event": "NONE", "text": "fact already covered"} + ] +} +""" + + +def parse_json_object(content: str) -> dict: + match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if not match: + match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL) + if match: + return json.loads(match.group(1).strip()) + + start = content.find("{") + end = content.rfind("}") + if start >= 0 and end > start: + return json.loads(content[start:end + 1]) + raise ValueError(f"No JSON object found in model output: {content[:500]}") + + +def iter_jsonl(file_path: Path) -> Iterable[dict]: + with file_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + yield json.loads(line) + + +def extract_user_name(persona_info: str) -> str: + match = re.search(r"Name:\s*(.*?); Gender:", persona_info) + if not match: + raise ValueError("No name found.") + return match.group(1).strip() + + +def make_local_mem0(version: str) -> Memory: + base_path = Path(os.getenv("MEM0_LOCAL_DIR", "runs/mem0-local")) / version + base_path.mkdir(parents=True, exist_ok=True) + + return Memory.from_config( + { + "version": "v1.1", + "history_db_path": str(base_path / "history.db"), + "llm": { + "provider": "openai", + "config": { + "model": os.environ["OPENAI_MODEL"], + "api_key": os.environ.get("OPENAI_API_KEY", "local-omlx"), + "openai_base_url": os.environ["OPENAI_BASE_URL"], + "temperature": float(os.getenv("MEM0_LOCAL_LLM_TEMPERATURE", "0")), + "max_tokens": int(os.getenv("MEM0_LOCAL_LLM_MAX_TOKENS", "2048")), + }, + }, + "embedder": { + "provider": "langchain", + "config": { + "model": HashLangchainEmbeddings( + embedding_dims=int(os.getenv("MEM0_LOCAL_EMBEDDING_DIMS", "384")), + ), + "embedding_dims": int(os.getenv("MEM0_LOCAL_EMBEDDING_DIMS", "384")), + }, + }, + "vector_store": { + "provider": "qdrant", + "config": { + "collection_name": f"halumem_{version}".replace("-", "_"), + "embedding_model_dims": int(os.getenv("MEM0_LOCAL_EMBEDDING_DIMS", "384")), + "path": str(base_path / "qdrant"), + "on_disk": True, + }, + }, + "custom_fact_extraction_prompt": CUSTOM_FACT_EXTRACTION_PROMPT, + "custom_update_memory_prompt": CUSTOM_UPDATE_MEMORY_PROMPT, + } + ) + + +def extract_facts_with_gemma(message: list[dict], max_facts: int) -> list[str]: + if max_facts <= 0: + return [] + + dialogue = "\n".join( + f"{turn.get('role', '')}: {' '.join(str(turn.get('content', '')).split())}" + for turn in message + if turn.get("role") != "system" and str(turn.get("content", "")).strip() + ) + prompt = f"""{CUSTOM_FACT_EXTRACTION_PROMPT} + +Extract at most {max_facts} facts. + +Conversation: +{dialogue} +""" + try: + data = parse_json_object(llm_request(prompt)) + facts = data.get("facts", []) + except Exception as exc: + print(f"[memzero-local] fact extraction failed: {exc}") + return [] + + normalized = [] + seen = set() + for fact in facts: + item = " ".join(str(fact).split()).strip(" -") + if not item or item in seen: + continue + seen.add(item) + normalized.append(item) + if len(normalized) >= max_facts: + break + return normalized + + +def add_memory(client: Memory, user_id: str, message: list[dict], max_facts: int) -> tuple[dict, float]: + start = time.time() + facts = extract_facts_with_gemma(message, max_facts) + if facts: + result = client.add( + [{"role": "user", "content": fact} for fact in facts], + user_id=user_id, + infer=False, + ) + else: + result = {"results": []} + return result, (time.time() - start) * 1000 + + +def search_memory(client: Memory, query: str, user_id: str, top_k: int = 20) -> tuple[str, list[str], float]: + start = time.time() + results = client.search(query=query, limit=top_k, user_id=user_id) + memories = [] + for memory in results.get("results", []): + timestamp = memory.get("updated_at") or memory.get("created_at") or "" + memories.append(f"{timestamp}: {memory.get('memory', '')}") + + context = TEMPLATE_MEM0_LOCAL.format( + user_id=user_id, + memories=json.dumps(memories, indent=4), + ) + return context, memories, (time.time() - start) * 1000 + + +def process_user( + client: Memory, + user_data: dict, + *, + top_k: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> dict: + user_name = extract_user_name(user_data["persona_info"]) + sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] + + try: + client.delete_all(user_id=user_name) + except Exception: + pass + + new_user_data = { + "uuid": user_data["uuid"], + "user_name": user_name, + "sessions": [], + } + + for session in tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"): + new_session = { + "memory_points": copy.deepcopy(session["memory_points"]), + "dialogue": session["dialogue"], + } + + formatted_dialogue = [ + { + "role": turn["role"], + "content": turn["content"], + } + for turn in session["dialogue"] + ] + + result, duration_ms = add_memory( + client, + user_name, + formatted_dialogue, + max_facts=max_extracted_memories_per_session, + ) + new_session["add_dialogue_duration_ms"] = duration_ms + + if session.get("is_generated_qa_session", False): + new_session["is_generated_qa_session"] = True + del new_session["dialogue"] + del new_session["memory_points"] + new_user_data["sessions"].append(new_session) + continue + + new_session["extracted_memories"] = [ + item["memory"] + for item in result.get("results", []) + if item.get("memory") + ] + + for memory in new_session["memory_points"]: + if memory["is_update"] == "False" or not memory["original_memories"]: + continue + _, memories_from_system, duration_ms = search_memory( + client, + memory["memory_content"], + user_name, + top_k=10, + ) + memory["memories_from_system"] = memories_from_system + memory["search_duration_ms"] = duration_ms + + if "questions" not in session: + new_user_data["sessions"].append(new_session) + continue + + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] + + new_session["questions"] = [] + for qa in questions: + context, _, duration_ms = search_memory(client, qa["question"], user_name, top_k=top_k) + + prompt = PROMPT_MEMZERO.format( + context=context, + question=qa["question"], + ) + + start_response = time.time() + response = llm_request(prompt) + + new_qa = copy.deepcopy(qa) + new_qa["context"] = context + new_qa["search_duration_ms"] = duration_ms + new_qa["system_response"] = response + new_qa["response_duration_ms"] = (time.time() - start_response) * 1000 + new_session["questions"].append(new_qa) + + new_user_data["sessions"].append(new_session) + + return new_user_data + + +def main( + *, + data_path: Path, + version: str, + top_k: int, + user_num: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> Path: + frame = "memzero-local" + save_path = Path("results") / f"{frame}-{version}" + save_path.mkdir(parents=True, exist_ok=True) + output_file = save_path / f"{frame}_eval_results.jsonl" + tmp_dir = save_path / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + + if output_file.exists(): + output_file.unlink() + + client = make_local_mem0(version) + try: + client.reset() + except Exception: + pass + + processed = 0 + for user_data in iter_jsonl(data_path): + processed += 1 + result = process_user( + client, + user_data, + top_k=top_k, + max_sessions=max_sessions, + max_questions_per_session=max_questions_per_session, + max_extracted_memories_per_session=max_extracted_memories_per_session, + ) + + tmp_file = tmp_dir / f"{user_data['uuid']}.json" + tmp_file.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8") + with output_file.open("a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"[{processed}/{user_num}] Saved {tmp_file}") + + if processed >= user_num: + break + + print(f"Final results saved to: {output_file}") + return output_file + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=Path, required=True) + parser.add_argument("--version", required=True) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--user-num", type=int, default=1) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + args = parser.parse_args() + + main( + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + user_num=args.user_num, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, + max_extracted_memories_per_session=args.max_extracted_memories_per_session, + ) + os._exit(0) diff --git a/eval/eval_supermemory.py b/eval/eval_supermemory.py index 346a593..b7a4ecc 100644 --- a/eval/eval_supermemory.py +++ b/eval/eval_supermemory.py @@ -4,6 +4,7 @@ import math import json import copy +import argparse import requests import traceback from datetime import datetime, timezone @@ -338,12 +339,16 @@ def main( if __name__ == "__main__": - data_path = "../data/HaluMem-Medium.jsonl" - version = "default" - top_k = 20 + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="../data/HaluMem-Medium.jsonl") + parser.add_argument("--version", default="default") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--max-workers", type=int, default=2) + args = parser.parse_args() main( - data_path=data_path, - version=version, - top_k=top_k - ) \ No newline at end of file + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + max_workers=args.max_workers, + ) diff --git a/eval/eval_supermemory_local.py b/eval/eval_supermemory_local.py new file mode 100644 index 0000000..9f2e1be --- /dev/null +++ b/eval/eval_supermemory_local.py @@ -0,0 +1,372 @@ +import argparse +import copy +import json +import os +import re +import time +from pathlib import Path +from typing import Iterable + +from supermemory import Supermemory +from tqdm import tqdm + +from llms import llm_request +from prompts import PROMPT_MEMOS + + +TEMPLATE_SUPERMEMORY_LOCAL = """Memories for user {user_id}: + + {memories} +""" + +CUSTOM_FACT_EXTRACTION_PROMPT = """ +You extract durable personal memories from dialogue for a long-term memory system. + +Rules: +- Extract facts only from user messages. +- Keep personal details, preferences, goals, plans, relationships, health/work/life events, and explicit updates. +- Preserve names, dates, places, quantities, and emotional context when present. +- Do not invent facts. +- Return JSON only with this shape: {"facts": ["fact 1", "fact 2"]}. +""" + + +def env_bool(name: str, default: bool) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def parse_json_object(content: str) -> dict: + match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if not match: + match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL) + if match: + return json.loads(match.group(1).strip()) + + start = content.find("{") + end = content.rfind("}") + if start >= 0 and end > start: + return json.loads(content[start:end + 1]) + raise ValueError(f"No JSON object found in model output: {content[:500]}") + + +def iter_jsonl(file_path: Path) -> Iterable[dict]: + with file_path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + yield json.loads(line) + + +def extract_user_name(persona_info: str) -> str: + match = re.search(r"Name:\s*(.*?); Gender:", persona_info) + if not match: + raise ValueError("No name found.") + return match.group(1).strip() + + +def make_client() -> Supermemory: + return Supermemory( + api_key=os.getenv("SUPERMEMORY_API_KEY", "local"), + base_url=os.getenv("SUPERMEMORY_BASE_URL", "http://127.0.0.1:6767"), + ) + + +def extract_facts_with_gemma(dialogue: list[dict], max_facts: int) -> list[str]: + if max_facts <= 0: + return [] + + conversation = "\n".join( + f"{turn.get('role', '')}: {' '.join(str(turn.get('content', '')).split())}" + for turn in dialogue + if turn.get("role") != "system" and str(turn.get("content", "")).strip() + ) + prompt = f"""{CUSTOM_FACT_EXTRACTION_PROMPT} + +Extract at most {max_facts} facts. + +Conversation: +{conversation} +""" + try: + data = parse_json_object(llm_request(prompt)) + facts = data.get("facts", []) + except Exception as exc: + print(f"[supermemory-local] fact extraction failed: {exc}") + return [] + + normalized = [] + seen = set() + for fact in facts: + item = " ".join(str(fact).split()).strip(" -") + if not item or item in seen: + continue + seen.add(item) + normalized.append(item) + if len(normalized) >= max_facts: + break + return normalized + + +def wait_for_document(client: Supermemory, document_id: str) -> tuple[str, float]: + start = time.time() + timeout_s = float(os.getenv("SUPERMEMORY_LOCAL_INDEX_TIMEOUT", "12")) + poll_s = float(os.getenv("SUPERMEMORY_LOCAL_INDEX_POLL", "2")) + stop_statuses = set( + item.strip() + for item in os.getenv("SUPERMEMORY_LOCAL_INDEX_READY_STATUSES", "embedding,indexing,done,failed").split(",") + if item.strip() + ) + status = "unknown" + + while time.time() - start < timeout_s: + response = client.memories.get(document_id) + status = response.model_dump().get("status") or getattr(response, "status", "unknown") + if status in stop_statuses: + break + time.sleep(poll_s) + + return status, (time.time() - start) * 1000 + + +def add_memory( + client: Supermemory, + user_id: str, + dialogue: list[dict], + conv_id: str, + max_facts: int, +) -> tuple[list[dict], float]: + start = time.time() + facts = extract_facts_with_gemma(dialogue, max_facts) + if not facts: + return [], (time.time() - start) * 1000 + + response = client.memories.add( + content="\n".join(f"- {fact}" for fact in facts), + container_tag=user_id, + metadata={ + "conv_id": conv_id, + "fact_count": len(facts), + "source": "gemma4_e4b_extracted_facts", + }, + ) + status, index_duration_ms = wait_for_document(client, response.id) + documents = [ + { + "id": response.id, + "memory": fact, + "status": status, + "index_duration_ms": index_duration_ms, + } + for fact in facts + ] + + return documents, (time.time() - start) * 1000 + + +def search_memory( + client: Supermemory, + query: str, + user_id: str, + top_k: int = 20, +) -> tuple[str, list[str], float]: + start = time.time() + response = client.search.documents( + q=query, + container_tags=[user_id], + limit=top_k, + rerank=env_bool("SUPERMEMORY_LOCAL_RERANK", False), + rewrite_query=env_bool("SUPERMEMORY_LOCAL_REWRITE_QUERY", False), + chunk_threshold=float(os.getenv("SUPERMEMORY_LOCAL_CHUNK_THRESHOLD", "0")), + document_threshold=float(os.getenv("SUPERMEMORY_LOCAL_DOCUMENT_THRESHOLD", "0")), + include_full_docs=True, + only_matching_chunks=env_bool("SUPERMEMORY_LOCAL_ONLY_MATCHING_CHUNKS", False), + ) + + memories = [] + seen = set() + for item in response.model_dump().get("results", []): + text = item.get("content") or "" + if not text: + chunks = item.get("chunks") or [] + text = "\n".join(chunk.get("content", "") for chunk in chunks if chunk.get("content")) + text = " ".join(str(text).split()) + if not text or text in seen: + continue + seen.add(text) + timestamp = item.get("updated_at") or item.get("updatedAt") or item.get("created_at") or item.get("createdAt") or "" + memories.append(f"{timestamp}: {text}" if timestamp else text) + + context = TEMPLATE_SUPERMEMORY_LOCAL.format( + user_id=user_id, + memories=json.dumps(memories, indent=4, ensure_ascii=False), + ) + return context, memories, (time.time() - start) * 1000 + + +def process_user( + client: Supermemory, + user_data: dict, + *, + top_k: int, + version: str, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> dict: + user_name = f"{extract_user_name(user_data['persona_info'])}_{version}".replace(" ", "_") + sessions = user_data["sessions"] + if max_sessions > 0: + sessions = sessions[:max_sessions] + + new_user_data = { + "uuid": user_data["uuid"], + "user_name": user_name, + "sessions": [], + } + + for session_id, session in enumerate(tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"), 1): + new_session = { + "memory_points": copy.deepcopy(session["memory_points"]), + "dialogue": session["dialogue"], + } + + dialogue = [ + { + "role": turn["role"], + "content": turn["content"], + "timestamp": turn.get("timestamp", ""), + } + for turn in session["dialogue"] + ] + documents, duration_ms = add_memory( + client, + user_name, + dialogue, + conv_id=f"{session_id}_{user_name}", + max_facts=max_extracted_memories_per_session, + ) + new_session["response_id_ls"] = [item["id"] for item in documents] + new_session["supermemory_document_statuses"] = [item["status"] for item in documents] + new_session["add_dialogue_duration_ms"] = duration_ms + + if session.get("is_generated_qa_session", False): + new_session["is_generated_qa_session"] = True + del new_session["dialogue"] + del new_session["memory_points"] + new_user_data["sessions"].append(new_session) + continue + + new_session["extracted_memories"] = [item["memory"] for item in documents] + + for memory in new_session["memory_points"]: + if memory["is_update"] == "False" or not memory["original_memories"]: + continue + _, memories_from_system, duration_ms = search_memory( + client, + memory["memory_content"], + user_name, + top_k=10, + ) + memory["memories_from_system"] = memories_from_system + memory["search_duration_ms"] = duration_ms + + if "questions" not in session: + new_user_data["sessions"].append(new_session) + continue + + questions = session["questions"] + if max_questions_per_session > 0: + questions = questions[:max_questions_per_session] + + new_session["questions"] = [] + for qa in questions: + context, _, duration_ms = search_memory(client, qa["question"], user_name, top_k=top_k) + prompt = PROMPT_MEMOS.format(context=context, question=qa["question"]) + + start_response = time.time() + response = llm_request(prompt) + + new_qa = copy.deepcopy(qa) + new_qa["context"] = context + new_qa["search_duration_ms"] = duration_ms + new_qa["system_response"] = response + new_qa["response_duration_ms"] = (time.time() - start_response) * 1000 + new_session["questions"].append(new_qa) + + new_user_data["sessions"].append(new_session) + + return new_user_data + + +def main( + *, + data_path: Path, + version: str, + top_k: int, + user_num: int, + max_sessions: int, + max_questions_per_session: int, + max_extracted_memories_per_session: int, +) -> Path: + frame = "supermemory-local" + save_path = Path("results") / f"{frame}-{version}" + save_path.mkdir(parents=True, exist_ok=True) + output_file = save_path / f"{frame}_eval_results.jsonl" + tmp_dir = save_path / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + + if output_file.exists(): + output_file.unlink() + + client = make_client() + + processed = 0 + for user_data in iter_jsonl(data_path): + processed += 1 + result = process_user( + client, + user_data, + top_k=top_k, + version=version, + max_sessions=max_sessions, + max_questions_per_session=max_questions_per_session, + max_extracted_memories_per_session=max_extracted_memories_per_session, + ) + + tmp_file = tmp_dir / f"{user_data['uuid']}.json" + tmp_file.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8") + with output_file.open("a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + print(f"[{processed}/{user_num}] Saved {tmp_file}") + + if processed >= user_num: + break + + print(f"Final results saved to: {output_file}") + return output_file + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=Path, required=True) + parser.add_argument("--version", required=True) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--user-num", type=int, default=1) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + args = parser.parse_args() + + main( + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + user_num=args.user_num, + max_sessions=args.max_sessions, + max_questions_per_session=args.max_questions_per_session, + max_extracted_memories_per_session=args.max_extracted_memories_per_session, + ) + os._exit(0) diff --git a/eval/eval_zep.py b/eval/eval_zep.py index 82b5162..cd92947 100644 --- a/eval/eval_zep.py +++ b/eval/eval_zep.py @@ -4,6 +4,7 @@ import uuid import json import copy +import argparse import traceback from typing import Literal from datetime import datetime, timezone @@ -538,14 +539,18 @@ def main( if __name__ == "__main__": - data_path = "../data/HaluMem-medium.jsonl" - version = "default" - top_k = 20 - run_task = "add" + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="../data/HaluMem-medium.jsonl") + parser.add_argument("--version", default="default") + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--max-workers", type=int, default=2) + parser.add_argument("--run-task", choices=["add", "search"], default="add") + args = parser.parse_args() main( - data_path=data_path, - version=version, - top_k=top_k, - run_task=run_task - ) \ No newline at end of file + data_path=args.data_path, + version=args.version, + top_k=args.top_k, + max_workers=args.max_workers, + run_task=args.run_task, + ) diff --git a/eval/evaluation.py b/eval/evaluation.py index 696e838..5da2043 100644 --- a/eval/evaluation.py +++ b/eval/evaluation.py @@ -30,6 +30,12 @@ def compute_f1(precision: float, recall: float) -> float: return 2 * (precision * recall) / (precision + recall) +def safe_div(numerator: float, denominator: float) -> float: + if denominator == 0: + return 0.0 + return numerator / denominator + + def process_user(idx: int, user_data: dict, max_workers: int = 10): uuid = user_data["uuid"] user_name = user_data["user_name"] @@ -236,16 +242,16 @@ def aggregate_eval_results(eval_results): interference_memory_scores += 1 interference_memory_valid_num += 1 - eval_results["overall_score"]["memory_integrity"]["recall(all)"] = memory_integrity_scores / memory_integrity_num - eval_results["overall_score"]["memory_integrity"]["recall(valid)"] = memory_integrity_scores / memory_integrity_valid_num - eval_results["overall_score"]["memory_integrity"]["weighted_recall(all)"] = memory_integrity_weighted_scores / memory_integrity_weighted_num - eval_results["overall_score"]["memory_integrity"]["weighted_recall(valid)"] = memory_integrity_weighted_scores / memory_integrity_weighted_valid_num + eval_results["overall_score"]["memory_integrity"]["recall(all)"] = safe_div(memory_integrity_scores, memory_integrity_num) + eval_results["overall_score"]["memory_integrity"]["recall(valid)"] = safe_div(memory_integrity_scores, memory_integrity_valid_num) + eval_results["overall_score"]["memory_integrity"]["weighted_recall(all)"] = safe_div(memory_integrity_weighted_scores, memory_integrity_weighted_num) + eval_results["overall_score"]["memory_integrity"]["weighted_recall(valid)"] = safe_div(memory_integrity_weighted_scores, memory_integrity_weighted_valid_num) eval_results["overall_score"]["memory_integrity"]["memory_valid_importance_sum"] = memory_integrity_weighted_valid_num eval_results["overall_score"]["memory_integrity"]["memory_importance_sum"] = memory_integrity_weighted_num eval_results["overall_score"]["memory_integrity"]["memory_valid_num"] = memory_integrity_valid_num eval_results["overall_score"]["memory_integrity"]["memory_num"] = memory_integrity_num - eval_results["overall_score"]["memory_accuracy"]["interference_accuracy(all)"] = interference_memory_scores / interference_memory_num - eval_results["overall_score"]["memory_accuracy"]["interference_accuracy(valid)"] = interference_memory_scores / interference_memory_valid_num + eval_results["overall_score"]["memory_accuracy"]["interference_accuracy(all)"] = safe_div(interference_memory_scores, interference_memory_num) + eval_results["overall_score"]["memory_accuracy"]["interference_accuracy(valid)"] = safe_div(interference_memory_scores, interference_memory_valid_num) eval_results["overall_score"]["memory_accuracy"]["interference_memory_valid_num"] = interference_memory_valid_num eval_results["overall_score"]["memory_accuracy"]["interference_memory_num"] = interference_memory_num @@ -276,12 +282,12 @@ def aggregate_eval_results(eval_results): memory_accuracy_weighted_scores += 0.5 * item["memory_accuracy_score"] memory_accuracy_valid_num += 1 - eval_results["overall_score"]["memory_accuracy"]["target_accuracy(all)"] = target_memory_accuracy_scores / target_memory_accuracy_num - eval_results["overall_score"]["memory_accuracy"]["target_accuracy(valid)"] = target_memory_accuracy_scores / target_memory_accuracy_valid_num + eval_results["overall_score"]["memory_accuracy"]["target_accuracy(all)"] = safe_div(target_memory_accuracy_scores, target_memory_accuracy_num) + eval_results["overall_score"]["memory_accuracy"]["target_accuracy(valid)"] = safe_div(target_memory_accuracy_scores, target_memory_accuracy_valid_num) eval_results["overall_score"]["memory_accuracy"]["target_memory_valid_num"] = target_memory_accuracy_valid_num eval_results["overall_score"]["memory_accuracy"]["target_memory_num"] = target_memory_accuracy_num - eval_results["overall_score"]["memory_accuracy"]["weighted_accuracy(all)"] = memory_accuracy_weighted_scores / memory_accuracy_num - eval_results["overall_score"]["memory_accuracy"]["weighted_accuracy(valid)"] = memory_accuracy_weighted_scores / memory_accuracy_valid_num + eval_results["overall_score"]["memory_accuracy"]["weighted_accuracy(all)"] = safe_div(memory_accuracy_weighted_scores, memory_accuracy_num) + eval_results["overall_score"]["memory_accuracy"]["weighted_accuracy(valid)"] = safe_div(memory_accuracy_weighted_scores, memory_accuracy_valid_num) eval_results["overall_score"]["memory_accuracy"]["memory_valid_num"] = memory_accuracy_valid_num eval_results["overall_score"]["memory_accuracy"]["memory_num"] = memory_accuracy_num @@ -318,14 +324,14 @@ def aggregate_eval_results(eval_results): update_memory_valid_num += 1 - eval_results["overall_score"]["memory_update"]["correct_update_memory_ratio(all)"] = correct_update_memory_num / update_memory_num - eval_results["overall_score"]["memory_update"]["correct_update_memory_ratio(valid)"] = correct_update_memory_num / update_memory_valid_num - eval_results["overall_score"]["memory_update"]["hallucination_update_memory_ratio(all)"] = hallucination_update_memory_num / update_memory_num - eval_results["overall_score"]["memory_update"]["hallucination_update_memory_ratio(valid)"] = hallucination_update_memory_num / update_memory_valid_num - eval_results["overall_score"]["memory_update"]["omission_update_memory_ratio(all)"] = omission_update_memory_num / update_memory_num - eval_results["overall_score"]["memory_update"]["omission_update_memory_ratio(valid)"] = omission_update_memory_num / update_memory_valid_num - eval_results["overall_score"]["memory_update"]["other_update_memory_ratio(all)"] = other_update_memory_num / update_memory_num - eval_results["overall_score"]["memory_update"]["other_update_memory_ratio(valid)"] = other_update_memory_num / update_memory_valid_num + eval_results["overall_score"]["memory_update"]["correct_update_memory_ratio(all)"] = safe_div(correct_update_memory_num, update_memory_num) + eval_results["overall_score"]["memory_update"]["correct_update_memory_ratio(valid)"] = safe_div(correct_update_memory_num, update_memory_valid_num) + eval_results["overall_score"]["memory_update"]["hallucination_update_memory_ratio(all)"] = safe_div(hallucination_update_memory_num, update_memory_num) + eval_results["overall_score"]["memory_update"]["hallucination_update_memory_ratio(valid)"] = safe_div(hallucination_update_memory_num, update_memory_valid_num) + eval_results["overall_score"]["memory_update"]["omission_update_memory_ratio(all)"] = safe_div(omission_update_memory_num, update_memory_num) + eval_results["overall_score"]["memory_update"]["omission_update_memory_ratio(valid)"] = safe_div(omission_update_memory_num, update_memory_valid_num) + eval_results["overall_score"]["memory_update"]["other_update_memory_ratio(all)"] = safe_div(other_update_memory_num, update_memory_num) + eval_results["overall_score"]["memory_update"]["other_update_memory_ratio(valid)"] = safe_div(other_update_memory_num, update_memory_valid_num) eval_results["overall_score"]["memory_update"]["update_memory_valid_num"] = update_memory_valid_num eval_results["overall_score"]["memory_update"]["update_memory_num"] = update_memory_num @@ -352,12 +358,12 @@ def aggregate_eval_results(eval_results): qa_valid_num += 1 - eval_results["overall_score"]["question_answering"]["correct_qa_ratio(all)"] = correct_qa_num / qa_num - eval_results["overall_score"]["question_answering"]["correct_qa_ratio(valid)"] = correct_qa_num / qa_valid_num - eval_results["overall_score"]["question_answering"]["hallucination_qa_ratio(all)"] = hallucination_qa_num / qa_num - eval_results["overall_score"]["question_answering"]["hallucination_qa_ratio(valid)"] = hallucination_qa_num / qa_valid_num - eval_results["overall_score"]["question_answering"]["omission_qa_ratio(all)"] = omission_qa_num / qa_num - eval_results["overall_score"]["question_answering"]["omission_qa_ratio(valid)"] = omission_qa_num / qa_valid_num + eval_results["overall_score"]["question_answering"]["correct_qa_ratio(all)"] = safe_div(correct_qa_num, qa_num) + eval_results["overall_score"]["question_answering"]["correct_qa_ratio(valid)"] = safe_div(correct_qa_num, qa_valid_num) + eval_results["overall_score"]["question_answering"]["hallucination_qa_ratio(all)"] = safe_div(hallucination_qa_num, qa_num) + eval_results["overall_score"]["question_answering"]["hallucination_qa_ratio(valid)"] = safe_div(hallucination_qa_num, qa_valid_num) + eval_results["overall_score"]["question_answering"]["omission_qa_ratio(all)"] = safe_div(omission_qa_num, qa_num) + eval_results["overall_score"]["question_answering"]["omission_qa_ratio(valid)"] = safe_div(omission_qa_num, qa_valid_num) eval_results["overall_score"]["question_answering"]["qa_valid_num"] = qa_valid_num eval_results["overall_score"]["question_answering"]["qa_num"] = qa_num @@ -378,8 +384,9 @@ def aggregate_eval_results(eval_results): for key in eval_results["overall_score"]["memory_type_accuracy"]: - eval_results["overall_score"]["memory_type_accuracy"][key]['memory_integrity_acc'] = eval_results["overall_score"]["memory_type_accuracy"][key]['memory_integrity_acc'] / eval_results["overall_score"]["memory_type_accuracy"][key]['total_num'] - eval_results["overall_score"]["memory_type_accuracy"][key]['memory_update_acc'] = eval_results["overall_score"]["memory_type_accuracy"][key]['memory_update_acc'] / eval_results["overall_score"]["memory_type_accuracy"][key]['total_num'] + total_num = eval_results["overall_score"]["memory_type_accuracy"][key]['total_num'] + eval_results["overall_score"]["memory_type_accuracy"][key]['memory_integrity_acc'] = safe_div(eval_results["overall_score"]["memory_type_accuracy"][key]['memory_integrity_acc'], total_num) + eval_results["overall_score"]["memory_type_accuracy"][key]['memory_update_acc'] = safe_div(eval_results["overall_score"]["memory_type_accuracy"][key]['memory_update_acc'], total_num) eval_results["overall_score"]["memory_type_accuracy"][key]['memory_acc'] = eval_results["overall_score"]["memory_type_accuracy"][key]['memory_integrity_acc'] + eval_results["overall_score"]["memory_type_accuracy"][key]['memory_update_acc'] return eval_results @@ -510,7 +517,7 @@ def main( parser.add_argument( "--frame", type=str, - choices=["memzero", "memzero-graph", "zep", "memos", "memobase", "supermemory"], + choices=["memzero-local", "memzero", "memzero-graph", "zep", "memos", "memobase", "supermemory", "supermemory-local", "gemma4-e4b"], ) parser.add_argument( "--version", diff --git a/eval/llms.py b/eval/llms.py index 56de13d..56d62fe 100644 --- a/eval/llms.py +++ b/eval/llms.py @@ -1,7 +1,9 @@ import re import os import json +import sys import logging +from pathlib import Path from dotenv import load_dotenv from openai import OpenAI @@ -13,9 +15,9 @@ logger = logging.getLogger(__name__) # Define retry strategy parameters -RETRY_TIMES = int(os.getenv('RETRY_TIMES')) -WAIT_TIME_LOWER = int(os.getenv('WAIT_TIME_LOWER')) -WAIT_TIME_UPPER = int(os.getenv('WAIT_TIME_UPPER')) +RETRY_TIMES = int(os.getenv('RETRY_TIMES', '3')) +WAIT_TIME_LOWER = int(os.getenv('WAIT_TIME_LOWER', '1')) +WAIT_TIME_UPPER = int(os.getenv('WAIT_TIME_UPPER', '10')) OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL') OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') @@ -34,10 +36,242 @@ common_params["timeout"] = int(os.getenv('OPENAI_TIMEOUT')) -client = OpenAI( - base_url=OPENAI_BASE_URL, - api_key=OPENAI_API_KEY -) +def _load_toml(path: Path) -> dict: + if sys.version_info >= (3, 11): + import tomllib + + return tomllib.loads(path.read_text(encoding="utf-8")) + + data: dict = {} + current = data + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("[") and line.endswith("]"): + current = data + for part in line[1:-1].split("."): + key = part.strip().strip('"') + current = current.setdefault(key, {}) + continue + if "=" not in line: + continue + key, value = line.split("=", 1) + key = key.strip().strip('"') + value = value.strip() + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.lower() in {"true", "false"}: + value = value.lower() == "true" + elif value.isdigit(): + value = int(value) + current[key] = value + return data + + +def _bool_env(name: str, default: bool = False) -> bool: + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _codex_home() -> Path: + return Path(os.getenv("CODEX_HOME", "~/.codex")).expanduser() + + +def _load_codex_openai_config() -> dict: + codex_home = _codex_home() + config_path = Path(os.getenv("CODEX_CONFIG_PATH", codex_home / "config.toml")).expanduser() + auth_path = Path(os.getenv("CODEX_AUTH_PATH", codex_home / "auth.json")).expanduser() + + config = _load_toml(config_path) + provider_name = config.get("model_provider", "OpenAI") + provider = config.get("model_providers", {}).get(provider_name, {}) + auth = json.loads(auth_path.read_text(encoding="utf-8")) + + api_key = auth.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError(f"No OPENAI_API_KEY found in {auth_path} or environment") + + return { + "model": os.getenv("JUDGE_OPENAI_MODEL") or config.get("model"), + "base_url": os.getenv("JUDGE_OPENAI_BASE_URL") or provider.get("base_url"), + "api_key": api_key, + "wire_api": os.getenv("JUDGE_OPENAI_WIRE_API") or provider.get("wire_api", "responses"), + "service_tier": os.getenv("JUDGE_OPENAI_SERVICE_TIER") or config.get("service_tier"), + } + + +def _openai_client(base_url: str | None, api_key: str | None, timeout: int | None = None) -> OpenAI: + kwargs = { + "base_url": base_url, + "api_key": api_key, + } + if timeout: + kwargs["timeout"] = timeout + return OpenAI(**kwargs) + + +def _answer_params() -> dict: + params = {} + if os.getenv('OPENAI_MAX_TOKENS'): + params["max_tokens"] = int(os.getenv('OPENAI_MAX_TOKENS')) + if os.getenv('OPENAI_TEMPERATURE'): + params["temperature"] = float(os.getenv('OPENAI_TEMPERATURE')) + if os.getenv('OPENAI_TIMEOUT'): + params["timeout"] = int(os.getenv('OPENAI_TIMEOUT')) + return params + + +def _judge_params_for_responses(service_tier: str | None = None) -> dict: + params = {} + max_tokens = os.getenv("JUDGE_OPENAI_MAX_TOKENS") or os.getenv("OPENAI_MAX_TOKENS") + temperature = os.getenv("JUDGE_OPENAI_TEMPERATURE") + + if max_tokens: + params["max_output_tokens"] = int(max_tokens) + if temperature: + params["temperature"] = float(temperature) + if service_tier: + params["service_tier"] = service_tier + return params + + +def _judge_timeout() -> int | None: + value = os.getenv("JUDGE_OPENAI_TIMEOUT") or os.getenv("OPENAI_TIMEOUT") + return int(value) if value else None + + +def _json_from_text(content: str): + match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) + if not match: + match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL) + if match: + return json.loads(match.group(1).strip()) + + start = content.find("{") + end = content.rfind("}") + if start >= 0 and end > start: + return json.loads(content[start:end + 1]) + + raise ValueError(f"No JSON object found in model output: {content}") + + +def _extract_response_delta(event) -> str: + event_type = getattr(event, "type", None) + if event_type in {"response.output_text.delta", "response.output_text.done", "response.refusal.delta"}: + return getattr(event, "delta", "") or "" + + if isinstance(event, dict): + event_type = event.get("type") + if event_type in {"response.output_text.delta", "response.refusal.delta"}: + return event.get("delta", "") or "" + if event_type == "response.output_text.done": + return event.get("text", "") or "" + return "" + + +def _extract_response_text(response) -> str: + output_text = getattr(response, "output_text", None) + if output_text: + return output_text + + if isinstance(response, dict): + output_text = response.get("output_text") + if output_text: + return output_text + output = response.get("output", []) + else: + output = getattr(response, "output", []) + + chunks = [] + for item in output or []: + content = item.get("content", []) if isinstance(item, dict) else getattr(item, "content", []) + for part in content or []: + if isinstance(part, dict): + text = part.get("text") + else: + text = getattr(part, "text", None) + if text: + chunks.append(text) + return "".join(chunks) + + +def _stream_responses_request(prompt: str, cfg: dict) -> str: + client = _openai_client(cfg["base_url"], cfg["api_key"], timeout=_judge_timeout()) + stream = client.responses.create( + model=cfg["model"], + input=prompt, + stream=True, + **_judge_params_for_responses(cfg.get("service_tier")) + ) + + chunks = [] + completed_text = "" + for event in stream: + delta = _extract_response_delta(event) + if delta: + chunks.append(delta) + continue + + event_type = getattr(event, "type", None) + if isinstance(event, dict): + event_type = event.get("type") + response = event.get("response") + else: + response = getattr(event, "response", None) + if event_type == "response.completed" and response: + completed_text = _extract_response_text(response) + + if completed_text and not chunks: + return completed_text + return "".join(chunks) + + +def _stream_chat_request(prompt: str, cfg: dict) -> str: + client = _openai_client(cfg["base_url"], cfg["api_key"], timeout=_judge_timeout()) + params = {} + max_tokens = os.getenv("JUDGE_OPENAI_MAX_TOKENS") or os.getenv("OPENAI_MAX_TOKENS") + temperature = os.getenv("JUDGE_OPENAI_TEMPERATURE") + if max_tokens: + params["max_tokens"] = int(max_tokens) + if temperature: + params["temperature"] = float(temperature) + + stream = client.chat.completions.create( + model=cfg["model"], + messages=[{'role': 'user', 'content': prompt}], + stream=True, + **params, + ) + + chunks = [] + for event in stream: + delta = event.choices[0].delta.content if event.choices else None + if delta: + chunks.append(delta) + return "".join(chunks) + + +def _judge_request(prompt: str) -> str: + if _bool_env("JUDGE_USE_CODEX_CONFIG", False): + cfg = _load_codex_openai_config() + else: + cfg = { + "model": os.getenv("JUDGE_OPENAI_MODEL") or MODEL, + "base_url": os.getenv("JUDGE_OPENAI_BASE_URL") or OPENAI_BASE_URL, + "api_key": os.getenv("JUDGE_OPENAI_API_KEY") or OPENAI_API_KEY, + "wire_api": os.getenv("JUDGE_OPENAI_WIRE_API", "chat"), + "service_tier": os.getenv("JUDGE_OPENAI_SERVICE_TIER"), + } + + if not cfg.get("model") or not cfg.get("base_url") or not cfg.get("api_key"): + raise ValueError("Judge model, base_url, and api_key must be configured") + + if cfg.get("wire_api") == "responses": + return _stream_responses_request(prompt, cfg) + return _stream_chat_request(prompt, cfg) @retry( @@ -57,6 +291,7 @@ def llm_request(prompt): str: The response generated by the model. """ + client = _openai_client(OPENAI_BASE_URL, OPENAI_API_KEY) response_obj = client.chat.completions.create( model=MODEL, messages=[ @@ -65,7 +300,7 @@ def llm_request(prompt): 'content': prompt } ], - **common_params + **_answer_params() ) return response_obj.choices[0].message.content @@ -78,23 +313,10 @@ def llm_request(prompt): ) def llm_request_for_json(prompt): - response_obj = client.chat.completions.create( - model=MODEL, - messages=[{'role': 'user', 'content': prompt}], - **common_params - ) - - content = response_obj.choices[0].message.content or "" - - match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL) - if not match: - raise ValueError(f"No JSON block found in model output: {content}") - - json_str = match.group(1).strip() - - return json.loads(json_str) + content = _judge_request(prompt) or "" + return _json_from_text(content) if __name__ == '__main__': r = llm_request_for_json('hello') - print(r) \ No newline at end of file + print(r) diff --git a/eval/local_embedding_server.py b/eval/local_embedding_server.py new file mode 100644 index 0000000..1cc9d7d --- /dev/null +++ b/eval/local_embedding_server.py @@ -0,0 +1,113 @@ +import argparse +import hashlib +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +import json +import math +import re +import time +from typing import Any + + +def embed_text(text: str, dims: int) -> list[float]: + vector = [0.0] * dims + tokens = re.findall(r"[\w']+", str(text).lower()) + for token in tokens: + digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest() + bucket = int.from_bytes(digest[:4], "big") % dims + sign = 1.0 if digest[4] & 1 else -1.0 + vector[bucket] += sign + + norm = math.sqrt(sum(value * value for value in vector)) + if norm: + vector = [value / norm for value in vector] + return vector + + +class EmbeddingHandler(BaseHTTPRequestHandler): + default_dims = 384 + + def log_message(self, format: str, *args: Any) -> None: + return + + def send_json(self, status: int, payload: dict[str, Any]) -> None: + body = json.dumps(payload).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def do_GET(self) -> None: + if self.path == "/health": + self.send_json(200, {"status": "ok"}) + return + if self.path == "/v1/models": + self.send_json( + 200, + { + "object": "list", + "data": [ + { + "id": "hash-embedding", + "object": "model", + "created": 0, + "owned_by": "local", + } + ], + }, + ) + return + self.send_json(404, {"error": f"not found: {self.path}"}) + + def do_POST(self) -> None: + if self.path != "/v1/embeddings": + self.send_json(404, {"error": f"not found: {self.path}"}) + return + + try: + length = int(self.headers.get("Content-Length", "0")) + request = json.loads(self.rfile.read(length).decode("utf-8") or "{}") + raw_inputs = request.get("input", []) + inputs = raw_inputs if isinstance(raw_inputs, list) else [raw_inputs] + dims = int(request.get("dimensions") or self.default_dims) + model = request.get("model") or "hash-embedding" + token_count = sum(len(str(text).split()) for text in inputs) + self.send_json( + 200, + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": idx, + "embedding": embed_text(str(text), dims), + } + for idx, text in enumerate(inputs) + ], + "model": model, + "usage": { + "prompt_tokens": token_count, + "total_tokens": token_count, + }, + "created": int(time.time()), + }, + ) + except Exception as exc: + self.send_json(500, {"error": str(exc)}) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=18000) + parser.add_argument("--dims", type=int, default=384) + args = parser.parse_args() + + EmbeddingHandler.default_dims = args.dims + server = ThreadingHTTPServer((args.host, args.port), EmbeddingHandler) + print(f"Serving local hash embeddings at http://{args.host}:{args.port}/v1") + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/eval/local_mem0_components.py b/eval/local_mem0_components.py new file mode 100644 index 0000000..62b5a37 --- /dev/null +++ b/eval/local_mem0_components.py @@ -0,0 +1,53 @@ +import hashlib +import math +import re +from typing import Iterable, Literal, Optional + +from mem0.embeddings.base import EmbeddingBase + +try: + from langchain.embeddings.base import Embeddings +except ImportError: # pragma: no cover + Embeddings = object + + +class HashEmbedding(EmbeddingBase): + def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None): + dims = int(getattr(self.config, "embedding_dims", None) or 384) + vector = [0.0] * dims + tokens = re.findall(r"[\w']+", str(text).lower()) + for token in tokens: + digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest() + bucket = int.from_bytes(digest[:4], "big") % dims + sign = 1.0 if digest[4] & 1 else -1.0 + vector[bucket] += sign + + norm = math.sqrt(sum(value * value for value in vector)) + if norm: + vector = [value / norm for value in vector] + return vector + + +class HashLangchainEmbeddings(Embeddings): + def __init__(self, embedding_dims: int = 384): + self.embedding_dims = embedding_dims + + def _embed_one(self, text: str) -> list[float]: + vector = [0.0] * self.embedding_dims + tokens = re.findall(r"[\w']+", str(text).lower()) + for token in tokens: + digest = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest() + bucket = int.from_bytes(digest[:4], "big") % self.embedding_dims + sign = 1.0 if digest[4] & 1 else -1.0 + vector[bucket] += sign + + norm = math.sqrt(sum(value * value for value in vector)) + if norm: + vector = [value / norm for value in vector] + return vector + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._embed_one(text) for text in texts] + + def embed_query(self, text: str) -> list[float]: + return self._embed_one(text) diff --git a/eval/run_gemma4_e4b_memory_matrix.py b/eval/run_gemma4_e4b_memory_matrix.py new file mode 100644 index 0000000..3c952a8 --- /dev/null +++ b/eval/run_gemma4_e4b_memory_matrix.py @@ -0,0 +1,147 @@ +import argparse +import json +import os +import subprocess +import sys +from datetime import datetime, timezone +from pathlib import Path + +from run_gemma4_e4b_omlx_eval import ( + DEFAULT_BASE, + DEFAULT_MODEL, + DEFAULT_OMLX, + EVAL_DIR, + FRAMES, + load_env_file, + preflight_frame, +) + + +DEFAULT_VERSION_PREFIX = "gemma4-e4b-medium-user1" +LOCAL_FRAMES = ["memzero-local", "memos", "memobase", "supermemory-local"] + + +def python_cmd() -> list[str]: + from shutil import which + + uv = which("uv") + if uv: + return [uv, "run", "python"] + return [sys.executable] + + +def result_path(frame: str, version: str) -> Path: + return EVAL_DIR / "results" / f"{frame}-{version}" / f"{frame}_eval_stat_result.json" + + +def run_frame(args: argparse.Namespace, frame: str, version: str) -> dict: + existing = result_path(frame, version) + if existing.exists() and not args.force: + return { + "frame": frame, + "version": version, + "status": "existing", + "result_path": str(existing), + } + + try: + preflight_frame(frame, os.environ.copy()) + except Exception as exc: + return { + "frame": frame, + "version": version, + "status": "skipped", + "reason": str(exc), + } + + cmd = python_cmd() + [ + "run_gemma4_e4b_omlx_eval.py", + "--frame", + frame, + "--data-path", + str(args.data_path), + "--version", + version, + "--limit-users", + str(args.limit_users), + "--max-workers", + str(args.max_workers), + "--top-k", + str(args.top_k), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + "--max-extracted-memories-per-session", + str(args.max_extracted_memories_per_session), + "--omlx-bin", + str(args.omlx_bin), + "--model-path", + str(args.model_path), + "--base-dir", + str(args.base_dir), + "--port", + str(args.port), + ] + + completed = subprocess.run(cmd, cwd=EVAL_DIR, check=False) + if completed.returncode != 0: + return { + "frame": frame, + "version": version, + "status": "failed", + "returncode": completed.returncode, + } + return { + "frame": frame, + "version": version, + "status": "completed", + "result_path": str(result_path(frame, version)), + } + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=Path, default=Path("data/HaluMem-Medium.jsonl")) + parser.add_argument("--version-prefix", default=DEFAULT_VERSION_PREFIX) + parser.add_argument("--frames", nargs="+", default=LOCAL_FRAMES) + parser.add_argument("--limit-users", type=int, default=1) + parser.add_argument("--max-workers", type=int, default=4) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + parser.add_argument("--force", action="store_true") + parser.add_argument("--omlx-bin", type=Path, default=DEFAULT_OMLX) + parser.add_argument("--model-path", type=Path, default=DEFAULT_MODEL) + parser.add_argument("--base-dir", type=Path, default=DEFAULT_BASE) + parser.add_argument("--port", type=int, default=50634) + args = parser.parse_args() + + load_env_file(EVAL_DIR / ".env") + load_env_file(EVAL_DIR.parent / ".env") + load_env_file(EVAL_DIR.parent / ".env copy") + + manifest = { + "started_at": datetime.now(timezone.utc).isoformat(), + "data_path": str(args.data_path), + "version_prefix": args.version_prefix, + "frames": [], + } + + for frame in args.frames: + version = f"{args.version_prefix}-{frame}" + if frame == "gemma4-e4b" and args.version_prefix == DEFAULT_VERSION_PREFIX: + version = "gemma4-e4b-medium-user1-fullsessions" + manifest["frames"].append(run_frame(args, frame, version)) + + manifest["finished_at"] = datetime.now(timezone.utc).isoformat() + report_dir = EVAL_DIR / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + manifest_path = report_dir / f"{args.version_prefix}-matrix-manifest.json" + manifest_path.write_text(json.dumps(manifest, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"Matrix manifest written to {manifest_path}") + + +if __name__ == "__main__": + main() diff --git a/eval/run_gemma4_e4b_omlx_eval.py b/eval/run_gemma4_e4b_omlx_eval.py new file mode 100644 index 0000000..d67359a --- /dev/null +++ b/eval/run_gemma4_e4b_omlx_eval.py @@ -0,0 +1,627 @@ +import argparse +import importlib.util +import json +import os +import shutil +import signal +import subprocess +import sys +import time +from datetime import datetime, timezone +from pathlib import Path + +import requests + + +REPO_ROOT = Path(__file__).resolve().parents[1] +EVAL_DIR = Path(__file__).resolve().parent +DEFAULT_OMLX = Path("/Users/qiang/arika/.venv-omlx/bin/omlx") +DEFAULT_MODEL = Path("/Users/qiang/.omlx/models/gemma-4-E4B-it-MLX-8bit") +DEFAULT_BASE = Path("/Users/qiang/.omlx-halumem-gemma4-e4b") +FRAMES = ["memzero-local", "memzero", "memzero-graph", "zep", "memos", "memobase", "supermemory", "supermemory-local", "gemma4-e4b"] +FRAME_MODULES = { + "memzero-local": ["mem0", "qdrant_client"], + "memzero": ["mem0"], + "memzero-graph": ["mem0"], + "zep": ["zep_cloud"], + "memos": [], + "memobase": ["memobase", "psycopg2", "sqlalchemy"], + "supermemory": ["supermemory"], + "supermemory-local": ["supermemory"], + "gemma4-e4b": [], +} +FRAME_ENV = { + "memzero-local": [], + "memzero": ["MEM0_API_KEY"], + "memzero-graph": [], + "zep": ["ZEP_API_KEY"], + "memos": ["MEMOS_URL"], + "memobase": [ + "MEMOBASE_PROJECT_URL", + "MEMOBASE_PROJECT_TOKEN", + "MEMOBASE_DB_HOST", + "MEMOBASE_DB_PORT", + "MEMOBASE_DB_USER", + "MEMOBASE_DB_PASSWORD", + "MEMOBASE_DB_NAME", + ], + "supermemory": ["SUPERMEMORY_API_KEY"], + "supermemory-local": [], + "gemma4-e4b": [], +} + + +def load_env_file(path: Path) -> None: + if not path.exists(): + return + + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + key = key.strip() + value = value.strip().strip('"').strip("'") + os.environ.setdefault(key, value) + + +def run( + cmd: list[str], + *, + cwd: Path = EVAL_DIR, + env: dict[str, str] | None = None, + check: bool = True, +) -> subprocess.CompletedProcess: + print(f"$ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, env=env, check=check) + + +def ensure_single_model_dir(base_dir: Path, model_path: Path) -> Path: + model_dir = base_dir / "models" + model_dir.mkdir(parents=True, exist_ok=True) + link_path = model_dir / model_path.name + + if link_path.exists() or link_path.is_symlink(): + if link_path.is_symlink() and Path(os.readlink(link_path)) == model_path: + return model_dir + if link_path.is_dir() and not link_path.is_symlink(): + raise RuntimeError(f"{link_path} exists and is not a symlink") + link_path.unlink() + + link_path.symlink_to(model_path) + return model_dir + + +def wait_for_port_release(host: str, port: int, timeout_s: int = 30) -> None: + deadline = time.time() + timeout_s + while time.time() < deadline: + try: + requests.get(f"http://{host}:{port}/health", timeout=1) + except requests.RequestException: + return + time.sleep(0.5) + raise TimeoutError(f"Port {port} still responds after {timeout_s}s") + + +def wait_for_health(host: str, port: int, timeout_s: int = 240) -> None: + deadline = time.time() + timeout_s + last_error = "" + while time.time() < deadline: + try: + response = requests.get(f"http://{host}:{port}/health", timeout=5) + if response.status_code < 500: + return + last_error = response.text + except requests.RequestException as exc: + last_error = str(exc) + time.sleep(2) + raise TimeoutError(f"oMLX did not become healthy on {host}:{port}: {last_error}") + + +def smoke_chat(host: str, port: int, model: str) -> None: + models = requests.get(f"http://{host}:{port}/v1/models", timeout=30) + models.raise_for_status() + model_ids = [item["id"] for item in models.json().get("data", [])] + if model not in model_ids: + raise RuntimeError(f"Expected model {model!r}, got {model_ids!r}") + + response = requests.post( + f"http://{host}:{port}/v1/chat/completions", + json={ + "model": model, + "messages": [{"role": "user", "content": "Return exactly: OK"}], + "temperature": 0, + "max_tokens": 8, + }, + timeout=180, + ) + response.raise_for_status() + content = response.json()["choices"][0]["message"]["content"] + print(f"Smoke completion: {content!r}") + + +def check_disk_space(path: Path, min_free_gb: float, force: bool = False) -> None: + path.mkdir(parents=True, exist_ok=True) + usage = shutil.disk_usage(path) + free_gb = usage.free / (1024 ** 3) + if free_gb < min_free_gb and not force: + raise RuntimeError( + f"Only {free_gb:.2f} GiB free at {path}. " + f"Need at least {min_free_gb:.2f} GiB; pass --force-low-disk to override." + ) + print(f"Disk free at {path}: {free_gb:.2f} GiB") + + +def preflight_memos(env: dict[str, str]) -> None: + memos_url = env.get("MEMOS_URL") + if not memos_url: + raise RuntimeError("MEMOS_URL is required for --frame memos") + + try: + response = requests.get(memos_url.rstrip("/") + "/health", timeout=5) + print(f"MemOS health probe: HTTP {response.status_code}") + except requests.RequestException as exc: + raise RuntimeError(f"MemOS service is not reachable at {memos_url}: {exc}") from exc + + +def preflight_frame(frame: str, env: dict[str, str]) -> None: + missing_env = [name for name in FRAME_ENV.get(frame, []) if not env.get(name)] + missing_modules = [name for name in FRAME_MODULES.get(frame, []) if importlib.util.find_spec(name) is None] + + problems = [] + if missing_env: + problems.append(f"missing required environment variables: {', '.join(missing_env)}") + if missing_modules: + problems.append(f"missing Python modules: {', '.join(missing_modules)}") + if problems: + raise RuntimeError(f"{frame} " + "; ".join(problems)) + + if frame == "memos": + preflight_memos(env) + + +def start_omlx(args: argparse.Namespace, log_path: Path) -> subprocess.Popen: + model_dir = ensure_single_model_dir(args.base_dir, args.model_path) + paged_ssd = args.base_dir / "paged-ssd" + paged_ssd.mkdir(parents=True, exist_ok=True) + log_path.parent.mkdir(parents=True, exist_ok=True) + + cmd = [ + str(args.omlx_bin), + "serve", + "--model-dir", + str(model_dir), + "--host", + args.host, + "--port", + str(args.port), + "--log-level", + args.log_level, + "--max-concurrent-requests", + str(args.max_concurrent_requests), + "--base-path", + str(args.base_dir), + "--paged-ssd-cache-dir", + str(paged_ssd), + "--paged-ssd-cache-max-size", + args.paged_ssd_cache_max_size, + "--hot-cache-max-size", + args.hot_cache_max_size, + ] + + log_file = log_path.open("a", encoding="utf-8") + print(f"$ {' '.join(cmd)} > {log_path} 2>&1") + process = subprocess.Popen( + cmd, + cwd=str(EVAL_DIR), + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + ) + process._halumem_log_file = log_file # type: ignore[attr-defined] + return process + + +def stop_omlx(process: subprocess.Popen | None, host: str, port: int) -> None: + if process and process.poll() is None: + print(f"Stopping oMLX pid={process.pid}") + process.send_signal(signal.SIGTERM) + try: + process.wait(timeout=30) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=30) + + log_file = getattr(process, "_halumem_log_file", None) + if log_file: + log_file.close() + + wait_for_port_release(host, port) + + +def make_subset(data_path: Path, limit_users: int, output_path: Path) -> Path: + if limit_users <= 0: + return data_path + output_path.parent.mkdir(parents=True, exist_ok=True) + count = 0 + with data_path.open("r", encoding="utf-8") as src, output_path.open("w", encoding="utf-8") as dst: + for line in src: + if not line.strip(): + continue + dst.write(line) + count += 1 + if count >= limit_users: + break + if count == 0: + raise RuntimeError(f"No JSONL records found in {data_path}") + print(f"Wrote {count} users to {output_path}") + return output_path + + +def load_stat_result(frame: str, version: str) -> dict: + path = EVAL_DIR / "results" / f"{frame}-{version}" / f"{frame}_eval_stat_result.json" + with path.open("r", encoding="utf-8") as f: + return json.load(f) + + +def load_overall_score(frame: str, version: str) -> dict: + return load_stat_result(frame, version)["overall_score"] + + +def pct(value) -> str: + if value is None: + return "-" + return f"{value * 100:.2f}%" + + +def write_report(args: argparse.Namespace, version: str, started_at: str, finished_at: str) -> Path: + stat_result = load_stat_result(args.frame, version) + score = stat_result["overall_score"] + report_dir = EVAL_DIR / "reports" + report_dir.mkdir(parents=True, exist_ok=True) + report_path = report_dir / f"{version}.md" + + memory_integrity = score["memory_integrity"] + memory_accuracy = score["memory_accuracy"] + memory_update = score["memory_update"] + qa = score["question_answering"] + timing = score["time_consuming"] + + body = f"""# HaluMem Gemma4-E4B oMLX Evaluation + +- Version: `{version}` +- Frame: `{args.frame}` +- Dataset: `{args.data_path}` +- User limit: `{args.limit_users if args.limit_users else 'all'}` +- Max extracted memories/session: `{args.max_extracted_memories_per_session}` +- Started: `{started_at}` +- Finished: `{finished_at}` + +## System Under Test + +- Memory backend: `{args.frame}` +- Answer model: `{args.model_path.name}` +- Answer endpoint: `http://{args.host}:{args.port}/v1` +- oMLX restart policy: started before adapter run, stopped and port-verified before scoring + +## Judge + +- Model source: Codex config +- Model override: `{os.getenv('JUDGE_OPENAI_MODEL', 'Codex config model')}` +- Wire API: Responses API when configured by Codex +- Streaming: `true` + +## Scores + +| Area | Metric | Value | +| --- | --- | ---: | +| Memory Extraction | Recall(all) | {pct(memory_integrity.get('recall(all)'))} | +| Memory Extraction | Weighted Recall(all) | {pct(memory_integrity.get('weighted_recall(all)'))} | +| Memory Extraction | Target Accuracy(all) | {pct(memory_accuracy.get('target_accuracy(all)'))} | +| Memory Extraction | Weighted Accuracy(all) | {pct(memory_accuracy.get('weighted_accuracy(all)'))} | +| Memory Extraction | False Memory Resistance(all) | {pct(memory_accuracy.get('interference_accuracy(all)'))} | +| Memory Extraction | F1 | {pct(score.get('memory_extraction_f1'))} | +| Memory Update | Correct(all) | {pct(memory_update.get('correct_update_memory_ratio(all)'))} | +| Memory Update | Hallucination(all) | {pct(memory_update.get('hallucination_update_memory_ratio(all)'))} | +| Memory Update | Omission(all) | {pct(memory_update.get('omission_update_memory_ratio(all)'))} | +| QA | Correct(all) | {pct(qa.get('correct_qa_ratio(all)'))} | +| QA | Hallucination(all) | {pct(qa.get('hallucination_qa_ratio(all)'))} | +| QA | Omission(all) | {pct(qa.get('omission_qa_ratio(all)'))} | + +## Evaluation Records + +| Record Type | Count | +| --- | ---: | +| Memory Integrity | {len(stat_result.get('memory_integrity_records', []))} | +| Memory Accuracy | {len(stat_result.get('memory_accuracy_records', []))} | +| Memory Update | {len(stat_result.get('memory_update_records', []))} | +| QA | {len(stat_result.get('question_answering_records', []))} | + +## Timing + +| Metric | Minutes | +| --- | ---: | +| Add dialogue | {timing.get('add_dialogue_duration_time', 0):.2f} | +| Search memory | {timing.get('search_memory_duration_time', 0):.2f} | +| Total | {timing.get('total_duration_time', 0):.2f} | +""" + report_path.write_text(body, encoding="utf-8") + return report_path + + +def python_cmd() -> list[str]: + uv = shutil.which("uv") + if uv: + return [uv, "run", "python"] + return [sys.executable] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--frame", choices=FRAMES, default="gemma4-e4b") + parser.add_argument("--data-path", type=Path, required=True) + parser.add_argument("--version", default="") + parser.add_argument("--limit-users", type=int, default=1) + parser.add_argument("--max-workers", type=int, default=1) + parser.add_argument("--top-k", type=int, default=20) + parser.add_argument("--pref-top-k", type=int, default=6) + parser.add_argument("--skip-adapter", action="store_true") + parser.add_argument("--skip-scoring", action="store_true") + parser.add_argument("--keep-omlx-running", action="store_true") + parser.add_argument("--omlx-bin", type=Path, default=DEFAULT_OMLX) + parser.add_argument("--model-path", type=Path, default=DEFAULT_MODEL) + parser.add_argument("--base-dir", type=Path, default=DEFAULT_BASE) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", type=int, default=50634) + parser.add_argument("--log-level", default="info") + parser.add_argument("--max-concurrent-requests", type=int, default=1) + parser.add_argument("--paged-ssd-cache-max-size", default="80GB") + parser.add_argument("--hot-cache-max-size", default="8GB") + parser.add_argument("--min-free-gb", type=float, default=20.0) + parser.add_argument("--force-low-disk", action="store_true") + parser.add_argument("--max-sessions", type=int, default=0) + parser.add_argument("--max-questions-per-session", type=int, default=0) + parser.add_argument("--max-extracted-memories-per-session", type=int, default=6) + args = parser.parse_args() + + load_env_file(EVAL_DIR / ".env") + load_env_file(REPO_ROOT / ".env") + load_env_file(REPO_ROOT / ".env copy") + + if not args.data_path.exists(): + raise FileNotFoundError(f"Dataset not found: {args.data_path}") + if not args.omlx_bin.exists(): + raise FileNotFoundError(f"oMLX binary not found: {args.omlx_bin}") + if not args.model_path.exists(): + raise FileNotFoundError(f"Model path not found: {args.model_path}") + if not args.skip_adapter: + check_disk_space(args.base_dir, args.min_free_gb, args.force_low_disk) + preflight_frame(args.frame, os.environ.copy()) + + timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + version = args.version or f"gemma4-e4b-{args.frame}-{timestamp}" + run_dir = EVAL_DIR / "runs" / version + run_dir.mkdir(parents=True, exist_ok=True) + data_path = make_subset(args.data_path, args.limit_users, run_dir / "input.jsonl") + started_at = datetime.now(timezone.utc).isoformat() + + omlx_process = None + try: + if not args.skip_adapter: + wait_for_port_release(args.host, args.port, timeout_s=3) + omlx_process = start_omlx(args, run_dir / "omlx.log") + wait_for_health(args.host, args.port) + smoke_chat(args.host, args.port, args.model_path.name) + + env = os.environ.copy() + env.update( + { + "OPENAI_BASE_URL": f"http://{args.host}:{args.port}/v1", + "OPENAI_API_KEY": env.get("OPENAI_API_KEY", "local-omlx"), + "OPENAI_MODEL": args.model_path.name, + "OPENAI_TEMPERATURE": env.get("ANSWER_OPENAI_TEMPERATURE", "0"), + "OPENAI_MAX_TOKENS": env.get("ANSWER_OPENAI_MAX_TOKENS", "2048"), + "OPENAI_TIMEOUT": env.get("ANSWER_OPENAI_TIMEOUT", "600"), + } + ) + + if args.frame == "memzero-local": + adapter_cmd = [ + "eval_memzero_local.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--user-num", + str(args.limit_users if args.limit_users > 0 else 20), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + "--max-extracted-memories-per-session", + str(args.max_extracted_memories_per_session), + ] + elif args.frame == "memzero": + adapter_cmd = [ + "eval_memzero.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--max-workers", + str(args.max_workers), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + ] + elif args.frame == "memzero-graph": + adapter_cmd = [ + "eval_memzero_graph.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--user-num", + str(args.limit_users if args.limit_users > 0 else 20), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + "--max-extracted-memories-per-session", + str(args.max_extracted_memories_per_session), + ] + elif args.frame == "zep": + adapter_cmd = [ + "eval_zep.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--max-workers", + str(args.max_workers), + "--run-task", + "add", + ] + run(python_cmd() + adapter_cmd, env=env) + adapter_cmd = [ + "eval_zep.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--max-workers", + str(args.max_workers), + "--run-task", + "search", + ] + elif args.frame == "memos": + adapter_cmd = [ + "eval_memos.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--pref-top-k", + str(args.pref_top_k), + "--max-workers", + str(args.max_workers), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + ] + elif args.frame == "memobase": + adapter_cmd = [ + "eval_memobase.py", + "--data-path", + str(data_path), + "--version", + version, + "--max-token-size", + str(args.top_k * 25), + "--max-workers", + str(args.max_workers), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + ] + elif args.frame == "supermemory": + adapter_cmd = [ + "eval_supermemory.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--max-workers", + str(args.max_workers), + ] + elif args.frame == "supermemory-local": + env.setdefault("SUPERMEMORY_API_KEY", "local") + env.setdefault("SUPERMEMORY_BASE_URL", "http://127.0.0.1:6767") + adapter_cmd = [ + "eval_supermemory_local.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--user-num", + str(args.limit_users if args.limit_users > 0 else 20), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + "--max-extracted-memories-per-session", + str(args.max_extracted_memories_per_session), + ] + else: + adapter_cmd = [ + "eval_gemma4_e4b_local.py", + "--data-path", + str(data_path), + "--version", + version, + "--top-k", + str(args.top_k), + "--user-num", + str(args.limit_users if args.limit_users > 0 else 20), + "--max-sessions", + str(args.max_sessions), + "--max-questions-per-session", + str(args.max_questions_per_session), + "--max-extracted-memories-per-session", + str(args.max_extracted_memories_per_session), + ] + + run(python_cmd() + adapter_cmd, env=env) + finally: + if not args.keep_omlx_running: + stop_omlx(omlx_process, args.host, args.port) + + if not args.skip_scoring: + judge_env = os.environ.copy() + judge_env.update( + { + "JUDGE_USE_CODEX_CONFIG": "true", + "RETRY_TIMES": judge_env.get("RETRY_TIMES", "3"), + "WAIT_TIME_LOWER": judge_env.get("WAIT_TIME_LOWER", "1"), + "WAIT_TIME_UPPER": judge_env.get("WAIT_TIME_UPPER", "10"), + "JUDGE_OPENAI_TIMEOUT": judge_env.get("JUDGE_OPENAI_TIMEOUT", "600"), + } + ) + run( + python_cmd() + + [ + "evaluation.py", + "--frame", + args.frame, + "--version", + version, + ], + env=judge_env, + ) + + finished_at = datetime.now(timezone.utc).isoformat() + report_path = write_report(args, version, started_at, finished_at) + print(f"Report written to {report_path}") + + +if __name__ == "__main__": + main() diff --git a/eval/write_gemma4_e4b_comparison_report.py b/eval/write_gemma4_e4b_comparison_report.py new file mode 100644 index 0000000..cced4e4 --- /dev/null +++ b/eval/write_gemma4_e4b_comparison_report.py @@ -0,0 +1,410 @@ +import argparse +import json +from pathlib import Path + + +EVAL_DIR = Path(__file__).resolve().parent +REPO_ROOT = EVAL_DIR.parent + + +ORIGINAL_MEDIUM = { + "MemOS": { + "recall": 79.60, + "weighted_recall": 83.52, + "target_accuracy": 68.55, + "weighted_accuracy": 37.51, + "fmr": 50.58, + "f1": 73.65, + "update_correct": 61.46, + "update_hallucination": 0.60, + "update_omission": 37.94, + "qa_correct": 61.82, + "qa_hallucination": 16.73, + "qa_omission": 21.45, + }, + "Zep": { + "recall": None, + "weighted_recall": None, + "target_accuracy": None, + "weighted_accuracy": None, + "fmr": None, + "f1": None, + "update_correct": 51.55, + "update_hallucination": 0.75, + "update_omission": 47.70, + "qa_correct": 51.05, + "qa_hallucination": 22.09, + "qa_omission": 26.86, + }, + "Mem0": { + "recall": 31.83, + "weighted_recall": 45.01, + "target_accuracy": 81.74, + "weighted_accuracy": 60.88, + "fmr": 57.31, + "f1": 45.80, + "update_correct": 25.50, + "update_hallucination": 0.45, + "update_omission": 74.02, + "qa_correct": 53.02, + "qa_hallucination": 19.17, + "qa_omission": 27.81, + }, + "Supermemory": { + "recall": 41.53, + "weighted_recall": 64.76, + "target_accuracy": 90.32, + "weighted_accuracy": 60.83, + "fmr": 51.77, + "f1": 56.90, + "update_correct": 16.37, + "update_hallucination": 1.15, + "update_omission": 82.47, + "qa_correct": 54.07, + "qa_hallucination": 22.24, + "qa_omission": 23.69, + }, + "Memobase": { + "recall": 14.55, + "weighted_recall": 25.88, + "target_accuracy": 92.24, + "weighted_accuracy": 32.29, + "fmr": 80.78, + "f1": 25.13, + "update_correct": 5.20, + "update_hallucination": 0.55, + "update_omission": 94.25, + "qa_correct": 35.33, + "qa_hallucination": 29.97, + "qa_omission": 34.71, + }, +} + +ORIGINAL_MEDIUM["Mem0-Graph"] = { + "recall": 31.55, + "weighted_recall": 45.80, + "target_accuracy": 82.82, + "weighted_accuracy": 56.43, + "fmr": 56.49, + "f1": 45.71, + "update_correct": 28.85, + "update_hallucination": 0.45, + "update_omission": 70.70, + "qa_correct": 53.51, + "qa_hallucination": 19.06, + "qa_omission": 27.43, +} + + +FRAME_NAMES = { + "gemma4-e4b": "Gemma4-E4B local", + "memzero-local": "Mem0 local", + "memzero": "Mem0", + "memzero-graph": "Mem0-Graph", + "memos": "MemOS", + "memobase": "Memobase", + "zep": "Zep", + "supermemory": "Supermemory", + "supermemory-local": "Supermemory local", +} + + +def pct(value: float | None) -> str: + if value is None: + return "-" + return f"{value:.2f}%" + + +def delta(current: float | None, original: float | None) -> str: + if current is None or original is None: + return "-" + value = current - original + sign = "+" if value >= 0 else "" + return f"{sign}{value:.2f} pp" + + +def stat_to_metrics(stat: dict) -> dict: + score = stat["overall_score"] + integrity = score["memory_integrity"] + accuracy = score["memory_accuracy"] + update = score["memory_update"] + qa = score["question_answering"] + return { + "recall": integrity.get("recall(all)", 0) * 100, + "weighted_recall": integrity.get("weighted_recall(all)", 0) * 100, + "target_accuracy": accuracy.get("target_accuracy(all)", 0) * 100, + "weighted_accuracy": accuracy.get("weighted_accuracy(all)", 0) * 100, + "fmr": accuracy.get("interference_accuracy(all)", 0) * 100, + "f1": score.get("memory_extraction_f1", 0) * 100, + "update_correct": update.get("correct_update_memory_ratio(all)", 0) * 100, + "update_hallucination": update.get("hallucination_update_memory_ratio(all)", 0) * 100, + "update_omission": update.get("omission_update_memory_ratio(all)", 0) * 100, + "qa_correct": qa.get("correct_qa_ratio(all)", 0) * 100, + "qa_hallucination": qa.get("hallucination_qa_ratio(all)", 0) * 100, + "qa_omission": qa.get("omission_qa_ratio(all)", 0) * 100, + "add_minutes": score["time_consuming"].get("add_dialogue_duration_time", 0), + "search_minutes": score["time_consuming"].get("search_memory_duration_time", 0), + "total_minutes": score["time_consuming"].get("total_duration_time", 0), + "records": { + "integrity": len(stat.get("memory_integrity_records", [])), + "accuracy": len(stat.get("memory_accuracy_records", [])), + "update": len(stat.get("memory_update_records", [])), + "qa": len(stat.get("question_answering_records", [])), + }, + } + + +def load_stat(frame: str, version: str) -> dict | None: + path = EVAL_DIR / "results" / f"{frame}-{version}" / f"{frame}_eval_stat_result.json" + if not path.exists(): + return None + return json.loads(path.read_text(encoding="utf-8")) + + +def load_manifest(path: Path) -> dict: + if path.exists(): + return json.loads(path.read_text(encoding="utf-8")) + return {"frames": []} + + +def comparison_rows(metrics_by_name: dict[str, dict]) -> str: + rows = [ + "| System | Scope | Recall | Weighted Recall | Target Acc | Weighted Acc | FMR | F1 | Update C | QA C |", + "| --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |", + ] + for name, metrics in metrics_by_name.items(): + rows.append( + "| " + + " | ".join( + [ + name, + "Gemma4-E4B run", + pct(metrics["recall"]), + pct(metrics["weighted_recall"]), + pct(metrics["target_accuracy"]), + pct(metrics["weighted_accuracy"]), + pct(metrics["fmr"]), + pct(metrics["f1"]), + pct(metrics["update_correct"]), + pct(metrics["qa_correct"]), + ] + ) + + " |" + ) + return "\n".join(rows) + + +def original_rows() -> str: + rows = [ + "| System | Recall | Weighted Recall | Target Acc | Weighted Acc | FMR | F1 | Update C | QA C |", + "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |", + ] + for name, metrics in ORIGINAL_MEDIUM.items(): + rows.append( + "| " + + " | ".join( + [ + name, + pct(metrics["recall"]), + pct(metrics["weighted_recall"]), + pct(metrics["target_accuracy"]), + pct(metrics["weighted_accuracy"]), + pct(metrics["fmr"]), + pct(metrics["f1"]), + pct(metrics["update_correct"]), + pct(metrics["qa_correct"]), + ] + ) + + " |" + ) + return "\n".join(rows) + + +def delta_rows(metrics_by_name: dict[str, dict]) -> str: + rows = [ + "| System | Recall Delta | Weighted Recall Delta | F1 Delta | Update C Delta | QA C Delta |", + "| --- | ---: | ---: | ---: | ---: | ---: |", + ] + for name, metrics in metrics_by_name.items(): + original_name = "Mem0" if name == "Mem0 local" else name + original = ORIGINAL_MEDIUM.get(original_name) + if not original: + continue + rows.append( + "| " + + " | ".join( + [ + name, + delta(metrics["recall"], original["recall"]), + delta(metrics["weighted_recall"], original["weighted_recall"]), + delta(metrics["f1"], original["f1"]), + delta(metrics["update_correct"], original["update_correct"]), + delta(metrics["qa_correct"], original["qa_correct"]), + ] + ) + + " |" + ) + return "\n".join(rows) + + +def baseline_summary_rows(metrics_by_name: dict[str, dict]) -> str: + metric_keys = [ + ("recall", "Recall"), + ("weighted_recall", "Weighted Recall"), + ("f1", "F1"), + ("update_correct", "Update C"), + ("qa_correct", "QA C"), + ] + rows = [ + "| Metric | Gemma4-E4B local | Original Best | Original Mean | Delta vs Best | Delta vs Mean |", + "| --- | ---: | ---: | ---: | ---: | ---: |", + ] + gemma = metrics_by_name.get("Gemma4-E4B local") + if not gemma: + rows.append("| - | - | - | - | - | - |") + return "\n".join(rows) + + for key, label in metric_keys: + values = [metrics[key] for metrics in ORIGINAL_MEDIUM.values() if metrics.get(key) is not None] + best = max(values) if values else None + mean = sum(values) / len(values) if values else None + rows.append( + "| " + + " | ".join( + [ + label, + pct(gemma.get(key)), + pct(best), + pct(mean), + delta(gemma.get(key), best), + delta(gemma.get(key), mean), + ] + ) + + " |" + ) + return "\n".join(rows) + + +def skipped_rows(manifest: dict) -> str: + rows = ["| Frame | Version | Status | Reason |", "| --- | --- | --- | --- |"] + for item in manifest.get("frames", []): + if item.get("status") in {"completed", "existing"}: + continue + rows.append( + f"| {item.get('frame')} | {item.get('version')} | {item.get('status')} | " + f"{item.get('reason') or item.get('returncode') or ''} |" + ) + if len(rows) == 2: + rows.append("| - | - | - | No skipped frames in manifest |") + return "\n".join(rows) + + +def record_rows(metrics_by_name: dict[str, dict]) -> str: + rows = [ + "| System | Integrity | Accuracy | Update | QA | Add min | Search min | Total min |", + "| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |", + ] + for name, metrics in metrics_by_name.items(): + records = metrics["records"] + rows.append( + f"| {name} | {records['integrity']} | {records['accuracy']} | {records['update']} | " + f"{records['qa']} | {metrics['add_minutes']:.2f} | {metrics['search_minutes']:.2f} | " + f"{metrics['total_minutes']:.2f} |" + ) + return "\n".join(rows) + + +def completed_system_summary(metrics_by_name: dict[str, dict]) -> str: + if not metrics_by_name: + return "none" + return ", ".join(metrics_by_name) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--manifest", type=Path, default=EVAL_DIR / "reports/gemma4-e4b-medium-user1-matrix-manifest.json") + parser.add_argument("--output", type=Path, default=EVAL_DIR / "reports/gemma4-e4b-memory-system-comparison.md") + args = parser.parse_args() + + manifest = load_manifest(args.manifest) + metrics_by_name = {} + + for item in manifest.get("frames", []): + if item.get("status") not in {"completed", "existing"}: + continue + frame = item["frame"] + version = item["version"] + stat = load_stat(frame, version) + if stat: + metrics_by_name[FRAME_NAMES.get(frame, frame)] = stat_to_metrics(stat) + + if not metrics_by_name: + stat = load_stat("gemma4-e4b", "gemma4-e4b-medium-user1-fullsessions") + if stat: + metrics_by_name["Gemma4-E4B local"] = stat_to_metrics(stat) + + body = f"""# Gemma4-E4B Memory System Comparison + +## Scope + +- Dataset: `data/HaluMem-Medium.jsonl` +- Current runnable scope: first user, all 65 sessions +- Answer model for Gemma4-E4B runs: local `gemma-4-E4B-it-MLX-8bit` via oMLX +- Judge: GPT-5.5, streaming Responses API. MemOS final scoring used the user-provided API key with the Codex-configured base URL. +- Original baseline source: `README.md` HaluMem-Medium tables + +## Gemma4-E4B Runs + +{comparison_rows(metrics_by_name)} + +## Run Size And Timing + +{record_rows(metrics_by_name)} + +## Original README Baseline + +These are the project README's HaluMem-Medium results for the original large-model memory-system evaluation. + +{original_rows()} + +## Delta Against Original README Baseline + +Positive deltas are better for Recall, Weighted Recall, F1, Update C, and QA C. + +{delta_rows(metrics_by_name)} + +## Interpretation + +- MemOS local produced the strongest extraction and update scores in this run, with the highest Recall, Weighted Recall, F1, Update Correct, and QA Correct among the completed local systems. +- Mem0 local, Mem0-Graph, and Supermemory local kept very high target-memory accuracy, but their update and QA scores were lower than MemOS local. +- Memobase local was conservative: it had the highest false-memory resistance among the completed local systems, but the weakest recall, update, and QA scores. +- Compared with the README's original large-model baseline, several Gemma4-E4B local runs improved extraction F1, but QA Correct and Update Correct were lower. Treat this as directional because the current run is one user with all 65 sessions, while the README table is the original HaluMem-Medium benchmark summary. + +## Frames Not Run In Current Environment + +{skipped_rows(manifest)} + +## Notes + +- Completed local memory-system runs: {completed_system_summary(metrics_by_name)}. +- The cloud-hosted adapters for Mem0, Zep, and Supermemory were not used because this run intentionally avoided memory-system cloud keys. +- This report does not treat skipped cloud adapters as failed model results; it records them as out of scope for the local-only setup. +- The original README baseline appears to be full HaluMem-Medium system results, while the current Gemma4-E4B run is a one-user full-session run. Compare directionally, not as a full benchmark replacement. + +## Reproduction Command + +Regenerate this report from the saved stat files: + +```bash +uv run python eval/write_gemma4_e4b_comparison_report.py \\ + --manifest eval/reports/gemma4-e4b-medium-user1-matrix-manifest.json \\ + --output eval/reports/gemma4-e4b-memory-system-comparison.md +``` +""" + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text(body, encoding="utf-8") + print(f"Comparison report written to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..7518fc9 --- /dev/null +++ b/uv.lock @@ -0,0 +1,3 @@ +version = 1 +revision = 3 +requires-python = ">=3.12"