Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 162 additions & 1 deletion rl/buffer_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import threading
import time
import urllib.request
import uuid
import numpy as np
from datetime import datetime
Expand Down Expand Up @@ -70,6 +71,7 @@
# LLM Proxy URL (constructed from host and port)
_llm_proxy_host = get_env("LLM_PROXY_HOST")
_llm_proxy_port = get_env("LLM_PROXY_PORT")
llm_proxy_base_url: str = f"http://{_llm_proxy_host}:{_llm_proxy_port}"
llm_proxy_url: str = f"http://{_llm_proxy_host}:{_llm_proxy_port}/v1"

# Track last served step ID for cursor-based pagination
Expand All @@ -82,10 +84,27 @@
# step, but this server only reads rows marked is_trainable by the rollout side.
completed_sessions_by_instance: Dict[str, Set[str]] = {}

# Dropped incomplete groups are ignored if late rows arrive later.
dropped_instance_ids: Set[str] = set()

# Group size (set by /start_rollout)
group_size: int = 1


def _env_float(name: str, default: float) -> float:
raw = os.getenv(name)
if raw is None:
return default
try:
return float(raw)
except ValueError:
logger.warning("Invalid %s=%r, fallback to %s", name, raw, default)
return default


INCOMPLETE_GROUP_TTL_SECONDS = _env_float("AIEVOBOX_BUFFER_INCOMPLETE_GROUP_TTL_SECONDS", 1800.0)


@app.middleware("http")
async def set_body_size(request: Request, call_next):
request._body_size_limit = 1_073_741_824 # 1GB
Expand Down Expand Up @@ -175,6 +194,145 @@ def _propagate_terminal_rewards(group: List[Dict[str, Any]]) -> List[Dict[str, A
return group


def _group_session_ids(bucket: List[Dict[str, Any]]) -> Set[str]:
session_ids = set()
for item in bucket:
extra = item.get("extra_info") or {}
session_id = str(extra.get("session_id") or "")
if session_id:
session_ids.add(session_id)
return session_ids


def _group_latest_timestamp(bucket: List[Dict[str, Any]]) -> float:
timestamps = []
for item in bucket:
extra = item.get("extra_info") or {}
try:
timestamps.append(float(extra.get("timestamp") or 0.0))
except (TypeError, ValueError):
pass
return max(timestamps) if timestamps else 0.0


def _notify_llm_proxy_clear_sessions(session_ids: Set[str], reason: str, group_ids: List[str]) -> None:
if not session_ids:
return

# Best-effort cleanup: rollout serving should continue even if the proxy is unavailable.
payload = json.dumps(
{
"session_ids": sorted(session_ids),
"reason": reason,
"group_ids": group_ids,
}
).encode("utf-8")
request = urllib.request.Request(
f"{llm_proxy_base_url}/admin/clear_sessions",
data=payload,
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=5) as response:
body = response.read().decode("utf-8", errors="replace")
logger.info("llm_proxy clear_sessions response: %s", body)
except Exception as exc:
logger.warning(
"Failed to notify llm_proxy to clear %d sessions for %s: %s",
len(session_ids),
reason,
exc,
)


def _drop_pending_groups(group_ids: List[str], reason: str) -> None:
global pending_items_by_instance, completed_sessions_by_instance, dropped_instance_ids

now = time.time()
session_ids: Set[str] = set()
dropped_summary = {}
for group_id in group_ids:
bucket = pending_items_by_instance.pop(group_id, [])
completed_sessions = completed_sessions_by_instance.pop(group_id, set())
dropped_instance_ids.add(group_id)
group_session_ids = _group_session_ids(bucket)
latest_ts = _group_latest_timestamp(bucket)
session_ids.update(group_session_ids)
dropped_summary[group_id] = {
"items": len(bucket),
"sessions": len(group_session_ids),
"completed_sessions": len(completed_sessions),
"age_seconds": round(now - latest_ts, 1) if latest_ts else None,
}

if not dropped_summary:
return

logger.warning(
"Dropped incomplete rollout groups: reason=%s groups=%s",
reason,
dropped_summary,
)
_notify_llm_proxy_clear_sessions(session_ids, reason, group_ids)


def _filter_dropped_group_items(new_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
if not dropped_instance_ids:
return new_items

# Late DB rows from dropped groups must not recreate stale pending state.
kept = []
skipped_session_ids: Set[str] = set()
skipped_group_ids: Set[str] = set()
for item in new_items:
group_id = str(item.get("instance_id", ""))
if group_id not in dropped_instance_ids:
kept.append(item)
continue

skipped_group_ids.add(group_id)
extra = item.get("extra_info") or {}
session_id = str(extra.get("session_id") or "")
if session_id:
skipped_session_ids.add(session_id)

if skipped_group_ids:
logger.warning(
"Skipped late rows from dropped rollout groups: groups=%s sessions=%d",
sorted(skipped_group_ids),
len(skipped_session_ids),
)
_notify_llm_proxy_clear_sessions(
skipped_session_ids,
"late_rows_from_dropped_groups",
sorted(skipped_group_ids),
)

return kept


def cleanup_incomplete_pending_groups() -> None:
if not pending_items_by_instance:
return

# Drop groups that never reached the configured repeat count within the TTL.
now = time.time()
ttl_drop_group_ids = []
if INCOMPLETE_GROUP_TTL_SECONDS <= 0:
return

for group_id, bucket in pending_items_by_instance.items():
if len(completed_sessions_by_instance.get(group_id, set())) >= group_size:
continue
latest_ts = _group_latest_timestamp(bucket)
if latest_ts and now - latest_ts >= INCOMPLETE_GROUP_TTL_SECONDS:
ttl_drop_group_ids.append(group_id)

if ttl_drop_group_ids:
_drop_pending_groups(ttl_drop_group_ids, "incomplete_group_ttl")


async def fetch_new_items_from_db(limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""Fetch new trainable rows from the database using cursor-based pagination."""
global data_manager, last_served_id
Expand Down Expand Up @@ -280,7 +438,9 @@ async def get_rollout_data(request: Request):

# Fetch new items from database and accumulate groups
new_items = await fetch_new_items_from_db(limit=None)
new_items = _filter_dropped_group_items(new_items)
ready_groups, finished_ids = accumulate_and_pop_ready_groups(new_items, max_groups=max_groups)
cleanup_incomplete_pending_groups()

# Log pending status
pending_counts = {
Expand Down Expand Up @@ -360,7 +520,7 @@ def start_aievobox_process(data: dict):
NOTE: LLM Proxy is now hosted in-process by slime_generator.
It must already be running before this function is called.
"""
global aievobox_process, group_size, last_served_id, pending_items_by_instance, completed_sessions_by_instance, data_manager
global aievobox_process, group_size, last_served_id, pending_items_by_instance, completed_sessions_by_instance, dropped_instance_ids, data_manager

# Set group size (num_repeat_per_sample)
group_size = max(1, int(data.get("num_repeat_per_sample", 16)))
Expand All @@ -370,6 +530,7 @@ def start_aievobox_process(data: dict):
if restart_training:
pending_items_by_instance.clear()
completed_sessions_by_instance.clear()
dropped_instance_ids.clear()
logger.info("restart_training=True, cleared pending items")

# Keep a single job_session for both reader and writer process.
Expand Down
29 changes: 28 additions & 1 deletion rl/llm_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import sys
import time
from logging.handlers import RotatingFileHandler
from typing import Any, Optional
from typing import Any, List, Optional

# Add rl directory to path for utils import
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -323,6 +323,33 @@ async def health_check():
}


@app.post("/admin/clear_sessions")
async def admin_clear_sessions(request: Request):
# Called by buffer_server to release mask-builder state for dropped rollout groups.
builder = STATE.trajectory_mask_builder
if builder is None:
return {"cleared": 0, "requested": 0}

try:
payload = await request.json()
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}")

raw_session_ids = payload.get("session_ids", []) if isinstance(payload, dict) else []
if not isinstance(raw_session_ids, list):
raise HTTPException(status_code=400, detail="session_ids must be a list")

session_ids: List[str] = [str(session_id) for session_id in raw_session_ids if session_id]
cleared = builder.clear_sessions(session_ids)
logger.info(
"Cleared trajectory sessions: requested=%d cleared=%d reason=%s",
len(session_ids),
cleared,
payload.get("reason") if isinstance(payload, dict) else None,
)
return {"cleared": cleared, "requested": len(session_ids)}


@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown."""
Expand Down
12 changes: 12 additions & 0 deletions rl/mask/trajectory_mask_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,15 @@ def get_training_info(

def clear_session(self, session_id: str) -> None:
self.session_roots.pop(session_id, None)

def clear_sessions(self, session_ids: List[str]) -> int:
# Batch cleanup keeps the admin proxy endpoint cheap and deterministic.
cleared = 0
for session_id in session_ids:
if session_id in self.session_roots:
self.session_roots.pop(session_id, None)
cleared += 1
return cleared

def session_count(self) -> int:
return len(self.session_roots)