diff --git a/rl/buffer_server.py b/rl/buffer_server.py index 91108d7..e9496c5 100644 --- a/rl/buffer_server.py +++ b/rl/buffer_server.py @@ -7,6 +7,7 @@ import sys import threading import time +import urllib.request import uuid import numpy as np from datetime import datetime @@ -70,6 +71,7 @@ # LLM Proxy URL (constructed from host and port) _llm_proxy_host = get_env("LLM_PROXY_HOST") _llm_proxy_port = get_env("LLM_PROXY_PORT") +llm_proxy_base_url: str = f"http://{_llm_proxy_host}:{_llm_proxy_port}" llm_proxy_url: str = f"http://{_llm_proxy_host}:{_llm_proxy_port}/v1" # Track last served step ID for cursor-based pagination @@ -82,10 +84,27 @@ # step, but this server only reads rows marked is_trainable by the rollout side. completed_sessions_by_instance: Dict[str, Set[str]] = {} +# Dropped incomplete groups are ignored if late rows arrive later. +dropped_instance_ids: Set[str] = set() + # Group size (set by /start_rollout) group_size: int = 1 +def _env_float(name: str, default: float) -> float: + raw = os.getenv(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + logger.warning("Invalid %s=%r, fallback to %s", name, raw, default) + return default + + +INCOMPLETE_GROUP_TTL_SECONDS = _env_float("AIEVOBOX_BUFFER_INCOMPLETE_GROUP_TTL_SECONDS", 1800.0) + + @app.middleware("http") async def set_body_size(request: Request, call_next): request._body_size_limit = 1_073_741_824 # 1GB @@ -175,6 +194,145 @@ def _propagate_terminal_rewards(group: List[Dict[str, Any]]) -> List[Dict[str, A return group +def _group_session_ids(bucket: List[Dict[str, Any]]) -> Set[str]: + session_ids = set() + for item in bucket: + extra = item.get("extra_info") or {} + session_id = str(extra.get("session_id") or "") + if session_id: + session_ids.add(session_id) + return session_ids + + +def _group_latest_timestamp(bucket: List[Dict[str, Any]]) -> float: + timestamps = [] + for item in bucket: + extra = item.get("extra_info") or {} + try: + 
def _notify_llm_proxy_clear_sessions(session_ids: Set[str], reason: str, group_ids: List[str]) -> None:
    """Ask the LLM proxy to release state for the given sessions (best effort).

    Sends ``{"session_ids", "reason", "group_ids"}`` as a JSON POST to
    ``{llm_proxy_base_url}/admin/clear_sessions``.

    The HTTP call is made on a daemon thread. This function is reached
    synchronously from the ``async`` rollout-data handler, and a blocking
    ``urlopen`` there would stall the event loop for up to the 5 s timeout
    whenever the proxy is slow or unreachable.

    Args:
        session_ids: Sessions whose proxy-side state should be cleared. An
            empty set makes this a no-op.
        reason: Short tag forwarded to the proxy and used in log messages.
        group_ids: Rollout group ids the sessions belonged to (informational).

    Returns:
        None. Failures are logged as warnings and never raised: rollout
        serving must continue even if the proxy is unavailable.
    """
    if not session_ids:
        return

    # Serialize now so the worker thread works on a snapshot and is not
    # affected by the caller mutating its collections afterwards.
    payload = json.dumps(
        {
            "session_ids": sorted(session_ids),
            "reason": reason,
            "group_ids": list(group_ids),
        }
    ).encode("utf-8")
    session_total = len(session_ids)

    def _post() -> None:
        request = urllib.request.Request(
            f"{llm_proxy_base_url}/admin/clear_sessions",
            data=payload,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urllib.request.urlopen(request, timeout=5) as response:
                body = response.read().decode("utf-8", errors="replace")
        except Exception as exc:  # best-effort boundary: log and move on
            logger.warning(
                "Failed to notify llm_proxy to clear %d sessions for %s: %s",
                session_total,
                reason,
                exc,
            )
        else:
            logger.info("llm_proxy clear_sessions response: %s", body)

    # daemon=True: a pending cleanup notification must never block shutdown.
    threading.Thread(target=_post, name="llm-proxy-clear-sessions", daemon=True).start()
def cleanup_incomplete_pending_groups() -> None:
    """Expire pending rollout groups that stalled before completing.

    A group is dropped when it has fewer than ``group_size`` completed
    sessions and the newest timestamp among its pending rows is at least
    ``INCOMPLETE_GROUP_TTL_SECONDS`` old. Expired groups are handed to
    ``_drop_pending_groups`` with reason ``"incomplete_group_ttl"``.

    Does nothing when there are no pending groups or when the TTL is
    configured as zero or negative.

    NOTE: a group whose rows carry no usable timestamp has a latest timestamp
    of 0.0 and is therefore never expired by this function.
    """
    if not pending_items_by_instance or INCOMPLETE_GROUP_TTL_SECONDS <= 0:
        return

    current_time = time.time()

    def _has_expired(instance_id: str, rows: List[Dict[str, Any]]) -> bool:
        # Groups that already reached the repeat count are left alone.
        finished = completed_sessions_by_instance.get(instance_id, set())
        if len(finished) >= group_size:
            return False
        last_seen = _group_latest_timestamp(rows)
        if not last_seen:
            return False
        return current_time - last_seen >= INCOMPLETE_GROUP_TTL_SECONDS

    # Collect first, drop afterwards: dropping mutates the dict being scanned.
    expired_group_ids = [
        instance_id
        for instance_id, rows in pending_items_by_instance.items()
        if _has_expired(instance_id, rows)
    ]

    if expired_group_ids:
        _drop_pending_groups(expired_group_ids, "incomplete_group_ttl")
@app.post("/admin/clear_sessions")
async def admin_clear_sessions(request: Request):
    """Release mask-builder state for sessions of dropped rollout groups.

    Called by buffer_server. Expects a JSON object of the form
    ``{"session_ids": [...], "reason": "..."}``; falsy entries in
    ``session_ids`` are ignored and the rest are stringified.

    Returns:
        ``{"cleared": <sessions actually removed>, "requested": <ids received>}``.
        When no trajectory mask builder is configured, both counts are 0 and
        the request body is not read.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or ``session_ids`` is not a list.
    """
    builder = STATE.trajectory_mask_builder
    if builder is None:
        return {"cleared": 0, "requested": 0}

    try:
        payload = await request.json()
    except Exception as exc:
        # Chain the cause so the original parse error stays in the traceback.
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {exc}") from exc

    # A non-object body (list, string, number) is a malformed request; reject
    # it instead of silently treating it as "nothing to clear".
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    raw_session_ids = payload.get("session_ids", [])
    if not isinstance(raw_session_ids, list):
        raise HTTPException(status_code=400, detail="session_ids must be a list")

    session_ids: List[str] = [str(session_id) for session_id in raw_session_ids if session_id]
    cleared = builder.clear_sessions(session_ids)
    logger.info(
        "Cleared trajectory sessions: requested=%d cleared=%d reason=%s",
        len(session_ids),
        cleared,
        payload.get("reason"),
    )
    return {"cleared": cleared, "requested": len(session_ids)}
def session_count(self) -> int:
    """Return the number of sessions currently held in ``session_roots``."""
    tracked_roots = self.session_roots
    return len(tracked_roots)