diff --git a/graph_net/agent/graph_net_agent.py b/graph_net/agent/graph_net_agent.py index c25928e86..d62d6f296 100644 --- a/graph_net/agent/graph_net_agent.py +++ b/graph_net/agent/graph_net_agent.py @@ -48,6 +48,7 @@ def __init__( extract_timeout: Optional[int] = None, verify_timeout: Optional[int] = None, llm_timeout: int = 360, + max_model_size_b: float = 20.0, ): """ Initialize GraphNet Agent @@ -63,6 +64,9 @@ def __init__( verify_timeout: Timeout in seconds for forward verification subprocess (default None -> 300s). llm_timeout: Timeout in seconds for LLM script fix (default: 360). + max_model_size_b: Maximum model size in billions of parameters to attempt. + Models exceeding this limit are skipped (AnalysisError). + Default: 20.0B. """ if workspace is None: workspace = os.environ.get( @@ -70,6 +74,7 @@ def __init__( os.path.expanduser("~/graphnet_workspace"), ) self.workspace = WorkspaceManager(workspace) + self.max_model_size_b = max_model_size_b self.logger = setup_logger( "GraphNetAgent", log_file=self.workspace.get_log_path("agent"), @@ -125,7 +130,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus: model_dir = self._fetch_model(model_id) model_dir = self._resolve_model_dir(model_dir) - model_metadata = self._analyze_model(model_dir) + model_metadata = self._analyze_model(model_dir, self.max_model_size_b) script_path = self._generate_script(model_dir, model_metadata, model_id) # ── First attempt (template script) ────────────────────────── @@ -303,10 +308,12 @@ def _resolve_model_dir(self, model_dir: Path) -> Path: ) return model_dir - def _analyze_model(self, model_dir: Path): + def _analyze_model(self, model_dir: Path, max_param_b: float = 20.0): """Analyze model configuration to extract metadata""" self.logger.info("Analyzing model configuration") - model_metadata = self.metadata_analyzer.analyze(model_dir) + model_metadata = self.metadata_analyzer.analyze( + model_dir, max_param_b=max_param_b + ) self.logger.info( f"Metadata: 
model_type={model_metadata.model_type}, vocab_size={model_metadata.vocab_size}" ) diff --git a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py index 3e6213306..81b81326f 100644 --- a/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py +++ b/graph_net/agent/metadata_analyzer/config_metadata_analyzer.py @@ -37,20 +37,23 @@ class ConfigMetadataAnalyzer(BaseMetadataAnalyzer): when available to leverage rich config object properties for architecture detection. """ - def analyze(self, model_dir: Path) -> ModelMetadata: + def analyze(self, model_dir: Path, max_param_b: float = 20.0) -> ModelMetadata: """ Analyze model by parsing config.json (with transformers AutoConfig if available). Also handles diffusers-style configs that lack a 'model_type' key but have '_class_name' (e.g., UNet2DConditionModel). Args: - model_dir: Path to model directory + model_dir: Path to model directory + max_param_b: Maximum allowed estimated parameter count in billions. + Models exceeding this limit are rejected with AnalysisError. + Default 20B. Returns: ModelMetadata object Raises: - MetadataAnalysisError: If analysis fails + MetadataAnalysisError: If analysis fails or model is too large """ config_path = model_dir / "config.json" if not config_path.exists(): @@ -75,6 +78,16 @@ def analyze(self, model_dir: Path) -> ModelMetadata: with open(config_path, "r", encoding="utf-8") as f: cfg_dict = json.load(f) + # Reject models that are too large to load even with random weights + param_b = self._estimate_param_count_billion(cfg_dict) + if param_b > max_param_b: + raise MetadataAnalysisError( + f"Model too large to extract: estimated {param_b:.1f}B parameters " + f"(limit {max_param_b:.1f}B). 
@staticmethod
def _estimate_param_count_billion(config: Dict) -> float:
    """Roughly estimate the total parameter count of a model, in billions.

    The estimate models a standard Transformer encoder/decoder stack:

    - per layer: attention (4 x hidden^2 for Q, K, V, O) plus, per expert,
      a gated FFN (3 x hidden x intermediate, SwiGLU style);
    - embeddings: 2 x vocab_size x hidden_size (input + output, untied).

    For MoE models every expert is counted, because all expert weights must
    be loaded into memory even when only a few are active per token. Shared
    experts and other architecture-specific extras are not modelled.

    Field lookup is tolerant of common naming variants so that the size
    guard is not silently bypassed by falling back to the small defaults:

    - hidden size: ``hidden_size``, ``n_embd``, ``d_model``
    - layer count: ``num_hidden_layers``, ``n_layer``, ``num_layers``
    - FFN width: ``intermediate_size``, ``ffn_dim``, ``n_inner``, ``d_ff``
    - expert count: ``num_experts``, ``num_local_experts``,
      ``moe_num_experts``, ``n_routed_experts``
    - per-expert FFN width: ``moe_intermediate_size`` (MoE configs only)

    Top-level keys win; a nested ``text_config`` dict is consulted only for
    values the top level does not provide.

    Args:
        config: Parsed ``config.json`` contents.

    Returns:
        Estimated parameter count divided by 1e9.
    """
    nested = config.get("text_config")
    sources = (config, nested) if isinstance(nested, dict) else (config,)

    def lookup(*keys):
        """Return the first positive numeric value stored under any of *keys*."""
        for source in sources:
            for key in keys:
                value = source.get(key)
                # bool is an int subclass; a stray True must not count as 1.
                if (
                    isinstance(value, (int, float))
                    and not isinstance(value, bool)
                    and value > 0
                ):
                    return value
        return None

    hidden_size = lookup("hidden_size", "n_embd", "d_model") or 768
    num_layers = lookup("num_hidden_layers", "n_layer", "num_layers") or 12
    intermediate_size = (
        lookup("intermediate_size", "ffn_dim", "n_inner", "d_ff") or hidden_size * 4
    )
    vocab_size = lookup("vocab_size") or 32000

    # MoE: total expert count (all experts are resident in RAM).
    num_experts = int(
        lookup(
            "num_experts",
            "num_local_experts",
            "moe_num_experts",
            "n_routed_experts",
        )
        or 1
    )

    # Some MoE configs keep the dense FFN width in ``intermediate_size`` and
    # the (smaller) per-expert width in ``moe_intermediate_size``; multiplying
    # the dense width by the expert count would grossly overestimate.
    expert_intermediate_size = intermediate_size
    if num_experts > 1:
        expert_intermediate_size = (
            lookup("moe_intermediate_size") or intermediate_size
        )

    attn_params = 4 * hidden_size * hidden_size  # Q, K, V, O
    expert_ffn_params = (
        3 * hidden_size * expert_intermediate_size * num_experts
    )  # gate, up, down per expert
    embed_params = 2 * vocab_size * hidden_size  # input + output embedding

    total = num_layers * (attn_params + expert_ffn_params) + embed_params
    return total / 1e9
@staticmethod
def _estimate_oom_risk(config: Dict) -> str:
    """Estimate GPU OOM risk from config fields.

    Checks are applied in order and the first one that triggers wins:

    1. Estimated parameter count (> 7B -> 'high', > 3B -> 'medium').
    2. Raw context window ``max_position_embeddings``
       (> 65536 -> 'high', > 16384 -> 'medium').
    3. A rough activation estimate in MB: lm_head output plus per-layer
       activations, scaled by an MoE factor
       (> 8000 -> 'high', > 3000 -> 'medium').

    Args:
        config: Parsed ``config.json`` contents.

    Returns:
        'low', 'medium', or 'high'. 'high' means the model should fall back
        to CPU to avoid OOM.
    """
    # Large models (>7B params) need >28GB GPU RAM in fp32 -> CPU fallback.
    param_b = ConfigMetadataAnalyzer._estimate_param_count_billion(config)
    if param_b > 7.0:
        return "high"
    if param_b > 3.0:
        return "medium"

    # Raw context window: models may allocate internal attention buffers
    # (e.g. causal masks of max_position_embeddings x max_position_embeddings)
    # regardless of the actual input length.
    raw_ctx_len = config.get("max_position_embeddings", 512) or 512
    if raw_ctx_len > 65536:
        return "high"
    if raw_ctx_len > 16384:
        return "medium"

    vocab_size = config.get("vocab_size", 0) or 0
    hidden_size = config.get("hidden_size", 768) or 768
    num_layers = config.get("num_hidden_layers", 12) or 12
    # Same expert-count keys as the parameter estimator, so a config that
    # only declares ``moe_num_experts`` is still treated as MoE here.
    num_experts = int(
        config.get("num_experts")
        or config.get("num_local_experts")
        or config.get("moe_num_experts")
        or 1
    )
    seq_len = min(raw_ctx_len, 2048)

    mib = 1024**2
    # lm_head output tensor (MB): batch=1 x seq_len x vocab_size x fp32.
    lm_head_mb = seq_len * vocab_size * 4 / mib
    # Activations estimate (MB): layers x seq_len x hidden_size x fp32.
    activation_mb = num_layers * seq_len * hidden_size * 4 / mib
    # MoE amplification (cap at 8 concurrent experts).
    moe_factor = min(num_experts, 8) if num_experts > 1 else 1

    total_est_mb = (lm_head_mb + activation_mb) * moe_factor

    if total_est_mb > 8000:
        return "high"
    if total_est_mb > 3000:
        return "medium"
    return "low"
max_model_size_b: float = 20.0, ) -> None: """ Worker function, runs in a dedicated subprocess bound to a single GPU or CPU. @@ -203,6 +204,7 @@ def _orphan_watcher(): llm_retry=llm_retry, extract_timeout=extract_timeout, verify_timeout=verify_timeout, + max_model_size_b=max_model_size_b, ) except Exception as e: print(f"{prefix} Failed to initialize agent: {e}", flush=True) @@ -415,6 +417,14 @@ def _parse_args() -> argparse.Namespace: default=False, help="Enable LLM retry for failed extractions", ) + parser.add_argument( + "--max-model-size-b", + type=str, + default="auto", + help="Maximum model size in billions of parameters to attempt. " + "'auto' calculates from total RAM / workers (default); " + "or specify manually, e.g. '10' for 10B.", + ) return parser.parse_args() @@ -459,15 +469,44 @@ def _resolve_config(args: argparse.Namespace): ) verify_timeout = args.verify_timeout if args.verify_timeout is not None else 300 - return workspace, gpus, num_workers, extract_timeout, verify_timeout + # Resolve max_model_size_b + if args.max_model_size_b == "auto": + try: + import psutil + + total_ram_gb = psutil.virtual_memory().total / 1024**3 + except ImportError: + total_ram_gb = 256.0 # conservative fallback + max_model_size_b = total_ram_gb * 0.7 / num_workers / 4 + print( + f"[INFO] max_model_size_b=auto → {max_model_size_b:.1f}B " + f"(RAM={total_ram_gb:.0f}GB, workers={num_workers})" + ) + else: + max_model_size_b = float(args.max_model_size_b) + print(f"[INFO] max_model_size_b={max_model_size_b:.1f}B (manually set)") + + return ( + workspace, + gpus, + num_workers, + extract_timeout, + verify_timeout, + max_model_size_b, + ) def main() -> int: args = _parse_args() - workspace, gpus, num_workers, extract_timeout, verify_timeout = _resolve_config( - args - ) + ( + workspace, + gpus, + num_workers, + extract_timeout, + verify_timeout, + max_model_size_b, + ) = _resolve_config(args) llm_retry = args.use_llm model_ids = _load_model_ids(args) @@ -510,6 +549,7 @@ def 
main() -> int: extract_timeout, verify_timeout, llm_retry, + max_model_size_b, ), name=f"worker-{worker_id}" + (f"-gpu{gpu_id}" if gpu_id is not None else "-cpu"),