Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,56 @@ These artifacts enable reproducibility and further analysis (error inspection, e

---

## Running Gemma4-E4B through local oMLX with GPT-5.5 Judge

For local Gemma4-E4B experiments, use the orchestration runner:

```bash
uv run python eval/run_gemma4_e4b_omlx_eval.py \
--data-path /path/to/HaluMem-medium.jsonl \
--version gemma4-e4b-medium-smoke \
--limit-users 1
```

The runner:

* starts an isolated local `omlx serve` instance for `gemma-4-E4B-it-MLX-8bit`;
* verifies `/health`, `/v1/models`, and a tiny `/v1/chat/completions` request;
* runs the local `gemma4-e4b` adapter with the local Gemma4-E4B endpoint;
* stops `omlx serve` and verifies the port is released;
* runs `evaluation.py` with `JUDGE_USE_CODEX_CONFIG=true`, so GPT-5.5 judge credentials and base URL are read from Codex config;
* writes a markdown report under `eval/reports/`.

Default local paths:

```text
oMLX binary: /Users/qiang/arika/.venv-omlx/bin/omlx
Gemma4-E4B model: /Users/qiang/.omlx/models/gemma-4-E4B-it-MLX-8bit
oMLX base path: /Users/qiang/.omlx-halumem-gemma4-e4b
Port: 50634
```

The runner intentionally separates model roles:

```text
System under test: memory backend + local Gemma4-E4B via oMLX
Judge: GPT-5.5 from Codex config, streaming Responses API
```

By default it requires at least `20 GiB` free on the oMLX base-path volume before starting. Use `--min-free-gb` to tune that guard, or `--force-low-disk` only for a deliberate low-disk smoke test.

Useful scope controls:

```bash
--max-sessions 0
--max-questions-per-session 0
--max-extracted-memories-per-session 6
```

`0` means no limit for sessions/questions. `--max-extracted-memories-per-session` caps the number of candidate memories scored by the official Memory Accuracy judge.

---

## Special Configurations for Some Memory Systems

While the experimental setup strives to maintain consistent configurations across all evaluated systems, certain memory systems exhibit unique API constraints that necessitate specific adjustments or workarounds.
Expand Down
273 changes: 273 additions & 0 deletions eval/eval_gemma4_e4b_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import argparse
import copy
import json
import os
import re
import time
from pathlib import Path
from typing import Iterable

from tqdm import tqdm

from llms import llm_request
from prompts import PROMPT_MEMOS


TEMPLATE_LOCAL = """Memories for user {user_id}:

{memories}
"""


def iter_jsonl(file_path: Path) -> Iterable[dict]:
with file_path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
yield json.loads(line)


def extract_user_name(persona_info: str) -> str:
match = re.search(r"Name:\s*(.*?); Gender:", persona_info)
if not match:
raise ValueError("No name found.")
return match.group(1).strip()


def dialogue_for_prompt(session: dict, max_chars: int = 24000) -> str:
lines: list[str] = []
for turn in session.get("dialogue", []):
content = " ".join(str(turn.get("content", "")).split())
if not content:
continue
lines.append(f"[{turn.get('timestamp', '')}] {turn.get('role', '')}: {content}")
dialogue = "\n".join(lines)
if len(dialogue) <= max_chars:
return dialogue
return dialogue[-max_chars:]


def parse_json_object(content: str) -> dict:
match = re.search(r"```json\s*(\{.*?\})\s*```", content, re.DOTALL)
if not match:
match = re.search(r"```\s*(\{.*?\})\s*```", content, re.DOTALL)
if match:
return json.loads(match.group(1).strip())

start = content.find("{")
end = content.rfind("}")
if start >= 0 and end > start:
return json.loads(content[start:end + 1])

raise ValueError(f"No JSON object found in model output: {content[:500]}")


def normalize_memory(memory: str) -> str:
return " ".join(str(memory).split()).strip(" -")


def fallback_user_memories(session: dict, max_memories: int) -> list[str]:
memories: list[str] = []
for turn in session.get("dialogue", []):
if turn.get("role") != "user":
continue
content = normalize_memory(turn.get("content", ""))
if not content:
continue
memories.append(f"{turn.get('timestamp', '')}: {content}")
if len(memories) >= max_memories:
break
return memories


def extract_memories_with_gemma(session: dict, max_memories: int) -> list[str]:
if max_memories <= 0:
return []

prompt = f"""You are the memory extraction module for a personal assistant.

Extract at most {max_memories} concise, self-contained memories from the user's messages only.

Rules:
- Keep only durable personal facts, preferences, plans, relationships, health/work/life details, and explicit updates.
- Include relevant names, dates, places, and quantities when present.
- Ignore assistant-only claims, chit-chat, and instructions that are not user facts.
- Do not invent facts.
- Each memory must be one short sentence.
- Return only JSON with this shape:
{{"memories": ["2023-05-04: The user ..."]}}

Conversation:
{dialogue_for_prompt(session)}
"""
try:
result = parse_json_object(llm_request(prompt))
memories = result.get("memories", [])
except Exception as exc:
print(f"Memory extraction failed, using fallback: {exc}")
return fallback_user_memories(session, max_memories)

normalized: list[str] = []
seen: set[str] = set()
for memory in memories:
item = normalize_memory(memory)
if not item or item in seen:
continue
seen.add(item)
normalized.append(item)
if len(normalized) >= max_memories:
break

if not normalized:
return fallback_user_memories(session, max_memories)
return normalized


def score_memory(query: str, memory: str) -> int:
query_terms = {t.lower() for t in re.findall(r"[A-Za-z0-9']+", query) if len(t) > 2}
memory_terms = {t.lower() for t in re.findall(r"[A-Za-z0-9']+", memory) if len(t) > 2}
return len(query_terms & memory_terms)


def search_memory(query: str, memory_bank: list[str], top_k: int) -> tuple[str, list[str], float]:
start = time.time()
scored = sorted(
((score_memory(query, memory), idx, memory) for idx, memory in enumerate(memory_bank)),
key=lambda item: (item[0], -item[1]),
reverse=True,
)
memories = [memory for score, _, memory in scored if score > 0][:top_k]
if not memories:
memories = memory_bank[-top_k:]
context = TEMPLATE_LOCAL.format(user_id="local", memories="\n".join(memories))
duration_ms = (time.time() - start) * 1000
return context, memories, duration_ms


def process_user(
user_data: dict,
*,
top_k: int,
max_sessions: int,
max_questions_per_session: int,
max_extracted_memories_per_session: int,
) -> dict:
user_name = extract_user_name(user_data["persona_info"])
sessions = user_data["sessions"]
if max_sessions > 0:
sessions = sessions[:max_sessions]

new_user_data = {
"uuid": user_data["uuid"],
"user_name": user_name,
"sessions": [],
}

memory_bank: list[str] = []
for session in tqdm(sessions, total=len(sessions), desc=f"Processing user {user_name}"):
start_add = time.time()
extracted_memories = extract_memories_with_gemma(
session,
max_memories=max_extracted_memories_per_session,
)
memory_bank.extend(extracted_memories)

new_session = {
"memory_points": copy.deepcopy(session["memory_points"]),
"dialogue": session["dialogue"],
"extracted_memories": extracted_memories,
"add_dialogue_duration_ms": (time.time() - start_add) * 1000,
}

for memory in new_session["memory_points"]:
if memory["is_update"] == "False" or not memory["original_memories"]:
continue
_, memories_from_system, duration_ms = search_memory(memory["memory_content"], memory_bank, top_k=10)
memory["memories_from_system"] = memories_from_system
memory["search_duration_ms"] = duration_ms

if "questions" in session:
questions = session["questions"]
if max_questions_per_session > 0:
questions = questions[:max_questions_per_session]
new_session["questions"] = []
for qa in questions:
context, _, duration_ms = search_memory(qa["question"], memory_bank, top_k=top_k)
prompt = PROMPT_MEMOS.format(context=context, question=qa["question"])
start_response = time.time()
response = llm_request(prompt)

new_qa = copy.deepcopy(qa)
new_qa["context"] = context
new_qa["search_duration_ms"] = duration_ms
new_qa["system_response"] = response
new_qa["response_duration_ms"] = (time.time() - start_response) * 1000
new_session["questions"].append(new_qa)

new_user_data["sessions"].append(new_session)

return new_user_data


def main(
*,
data_path: Path,
version: str,
top_k: int,
user_num: int,
max_sessions: int,
max_questions_per_session: int,
max_extracted_memories_per_session: int,
) -> Path:
frame = "gemma4-e4b"
save_path = Path("results") / f"{frame}-{version}"
save_path.mkdir(parents=True, exist_ok=True)
output_file = save_path / f"{frame}_eval_results.jsonl"
tmp_dir = save_path / "tmp"
tmp_dir.mkdir(parents=True, exist_ok=True)

if output_file.exists():
output_file.unlink()

processed = 0
for user_data in iter_jsonl(data_path):
processed += 1
tmp_file = tmp_dir / f"{user_data['uuid']}.json"
result = process_user(
user_data,
top_k=top_k,
max_sessions=max_sessions,
max_questions_per_session=max_questions_per_session,
max_extracted_memories_per_session=max_extracted_memories_per_session,
)
tmp_file.write_text(json.dumps(result, ensure_ascii=False), encoding="utf-8")
with output_file.open("a", encoding="utf-8") as f:
f.write(json.dumps(result, ensure_ascii=False) + "\n")
print(f"[{processed}/{user_num}] Saved {tmp_file}")
if processed >= user_num:
break

print(f"Final results saved to: {output_file}")
return output_file


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=Path, required=True)
parser.add_argument("--version", required=True)
parser.add_argument("--top-k", type=int, default=20)
parser.add_argument("--user-num", type=int, default=1)
parser.add_argument("--max-sessions", type=int, default=0)
parser.add_argument("--max-questions-per-session", type=int, default=0)
parser.add_argument("--max-extracted-memories-per-session", type=int, default=6)
args = parser.parse_args()

main(
data_path=args.data_path,
version=args.version,
top_k=args.top_k,
user_num=args.user_num,
max_sessions=args.max_sessions,
max_questions_per_session=args.max_questions_per_session,
max_extracted_memories_per_session=args.max_extracted_memories_per_session,
)
Loading