Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions graph_net/agent/graph_net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
extract_timeout: Optional[int] = None,
verify_timeout: Optional[int] = None,
llm_timeout: int = 360,
max_model_size_b: float = 20.0,
):
"""
Initialize GraphNet Agent
Expand All @@ -63,13 +64,17 @@ def __init__(
verify_timeout: Timeout in seconds for forward verification subprocess
(default None -> 300s).
llm_timeout: Timeout in seconds for LLM script fix (default: 360).
max_model_size_b: Maximum model size in billions of parameters to attempt.
Models exceeding this limit are skipped (AnalysisError).
Default: 20.0B.
"""
if workspace is None:
workspace = os.environ.get(
"GRAPH_NET_EXTRACT_WORKSPACE",
os.path.expanduser("~/graphnet_workspace"),
)
self.workspace = WorkspaceManager(workspace)
self.max_model_size_b = max_model_size_b
self.logger = setup_logger(
"GraphNetAgent",
log_file=self.workspace.get_log_path("agent"),
Expand Down Expand Up @@ -125,7 +130,7 @@ def extract_sample(self, model_id: str) -> ExtractionStatus:

model_dir = self._fetch_model(model_id)
model_dir = self._resolve_model_dir(model_dir)
model_metadata = self._analyze_model(model_dir)
model_metadata = self._analyze_model(model_dir, self.max_model_size_b)
script_path = self._generate_script(model_dir, model_metadata, model_id)

# ── First attempt (template script) ──────────────────────────
Expand Down Expand Up @@ -303,10 +308,12 @@ def _resolve_model_dir(self, model_dir: Path) -> Path:
)
return model_dir

def _analyze_model(self, model_dir: Path):
def _analyze_model(self, model_dir: Path, max_param_b: float = 20.0):
"""Analyze model configuration to extract metadata"""
self.logger.info("Analyzing model configuration")
model_metadata = self.metadata_analyzer.analyze(model_dir)
model_metadata = self.metadata_analyzer.analyze(
model_dir, max_param_b=max_param_b
)
self.logger.info(
f"Metadata: model_type={model_metadata.model_type}, vocab_size={model_metadata.vocab_size}"
)
Expand Down
98 changes: 95 additions & 3 deletions graph_net/agent/metadata_analyzer/config_metadata_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,23 @@ class ConfigMetadataAnalyzer(BaseMetadataAnalyzer):
when available to leverage rich config object properties for architecture detection.
"""

def analyze(self, model_dir: Path) -> ModelMetadata:
def analyze(self, model_dir: Path, max_param_b: float = 20.0) -> ModelMetadata:
"""
Analyze model by parsing config.json (with transformers AutoConfig if available).
Also handles diffusers-style configs that lack a 'model_type' key but have
'_class_name' (e.g., UNet2DConditionModel).

Args:
model_dir: Path to model directory
model_dir: Path to model directory
max_param_b: Maximum allowed estimated parameter count in billions.
Models exceeding this limit are rejected with AnalysisError.
Default 20B.

Returns:
ModelMetadata object

Raises:
MetadataAnalysisError: If analysis fails
MetadataAnalysisError: If analysis fails or model is too large
"""
config_path = model_dir / "config.json"
if not config_path.exists():
Expand All @@ -75,6 +78,16 @@ def analyze(self, model_dir: Path) -> ModelMetadata:
with open(config_path, "r", encoding="utf-8") as f:
cfg_dict = json.load(f)

# Reject models that are too large to load even with random weights
param_b = self._estimate_param_count_billion(cfg_dict)
if param_b > max_param_b:
raise MetadataAnalysisError(
f"Model too large to extract: estimated {param_b:.1f}B parameters "
f"(limit {max_param_b:.1f}B). "
f"Loading random fp32 weights would require ~{param_b * 4:.0f}GB RAM.",
error_category=GraphExtractionErrorCategory.METADATA_ANALYSIS_FAILED,
)

arch_type = self._classify_architecture(cfg_obj, cfg_dict)
input_shapes, input_dtypes = self._extract_input_info(
cfg_obj, cfg_dict, arch_type
Expand Down Expand Up @@ -489,6 +502,85 @@ def _infer_model_type_from_fields(cfg_dict: Dict) -> Optional[str]:
return "bert"
return None

@staticmethod
def _estimate_param_count_billion(config: Dict) -> float:
"""Rough estimate of model parameter count in billions.

Formula covers standard Transformer decoder/encoder:
- Per layer: attention (4 × hidden²) + FFN (3 × hidden × intermediate, SwiGLU style)
- Embedding: 2 × vocab_size × hidden_size (input + output, unshared)
MoE models: total params ≈ num_experts × expert_params, but only a few
experts are active per token. We use total params here because all expert
weights must be loaded into memory even when inactive.
"""
hidden_size = config.get("hidden_size", 768) or 768
num_layers = config.get("num_hidden_layers", 12) or 12
intermediate_size = (
config.get("intermediate_size") or config.get("ffn_dim") or hidden_size * 4
)
vocab_size = config.get("vocab_size", 32000) or 32000

# MoE: total expert count (all experts loaded into RAM)
num_experts = (
config.get("num_experts")
or config.get("num_local_experts")
or config.get("moe_num_experts")
or 1
)

attn_params = 4 * hidden_size * hidden_size # Q, K, V, O
ffn_params = 3 * hidden_size * intermediate_size # gate, up, down (SwiGLU)
expert_ffn_params = ffn_params * int(num_experts)
embed_params = 2 * vocab_size * hidden_size # input + output embedding

total = num_layers * (attn_params + expert_ffn_params) + embed_params
return total / 1e9

@staticmethod
def _estimate_oom_risk(config: Dict) -> str:
"""Estimate GPU OOM risk from config fields.

Returns 'low', 'medium', or 'high'.
'high' means the model should fall back to CPU to avoid OOM.
"""
# Large models (>7B params) need >28GB GPU RAM in fp32 → CPU fallback
param_b = ConfigMetadataAnalyzer._estimate_param_count_billion(config)
if param_b > 7.0:
return "high"
if param_b > 3.0:
return "medium"

vocab_size = config.get("vocab_size", 0) or 0
hidden_size = config.get("hidden_size", 768) or 768
num_layers = config.get("num_hidden_layers", 12) or 12
num_experts = config.get("num_experts") or config.get("num_local_experts") or 1
# Raw context window: models may allocate internal attention buffers
# based on max_position_embeddings regardless of actual input length
raw_ctx_len = config.get("max_position_embeddings", 512) or 512
seq_len = min(raw_ctx_len, 2048)

# Models with very large context may allocate causal mask buffers of
# size max_position_embeddings × max_position_embeddings internally
if raw_ctx_len > 65536:
return "high"
if raw_ctx_len > 16384:
return "medium"

# lm_head output tensor (MB): batch=1 × seq_len × vocab_size × fp32
lm_head_mb = seq_len * vocab_size * 4 / 1024**2
# Activations estimate (MB): layers × seq_len × hidden_size × fp32
activation_mb = num_layers * seq_len * hidden_size * 4 / 1024**2
# MoE amplification (cap at 8 concurrent experts)
moe_factor = min(int(num_experts), 8) if int(num_experts) > 1 else 1

total_est_mb = (lm_head_mb + activation_mb) * moe_factor

if total_est_mb > 8000:
return "high"
if total_est_mb > 3000:
return "medium"
return "low"

def _get_model_id(self, model_dir: Path, config: Dict) -> str:
"""Get model ID from directory or config."""
if "name_or_path" in config:
Expand Down
48 changes: 44 additions & 4 deletions graph_net/agent/parallel_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def worker_fn(
extract_timeout: int,
verify_timeout: int,
llm_retry: bool,
max_model_size_b: float = 20.0,
) -> None:
"""
Worker function, runs in a dedicated subprocess bound to a single GPU or CPU.
Expand Down Expand Up @@ -203,6 +204,7 @@ def _orphan_watcher():
llm_retry=llm_retry,
extract_timeout=extract_timeout,
verify_timeout=verify_timeout,
max_model_size_b=max_model_size_b,
)
except Exception as e:
print(f"{prefix} Failed to initialize agent: {e}", flush=True)
Expand Down Expand Up @@ -415,6 +417,14 @@ def _parse_args() -> argparse.Namespace:
default=False,
help="Enable LLM retry for failed extractions",
)
parser.add_argument(
"--max-model-size-b",
type=str,
default="auto",
help="Maximum model size in billions of parameters to attempt. "
"'auto' calculates from total RAM / workers (default); "
"or specify manually, e.g. '10' for 10B.",
)
return parser.parse_args()


Expand Down Expand Up @@ -459,15 +469,44 @@ def _resolve_config(args: argparse.Namespace):
)
verify_timeout = args.verify_timeout if args.verify_timeout is not None else 300

return workspace, gpus, num_workers, extract_timeout, verify_timeout
# Resolve max_model_size_b
if args.max_model_size_b == "auto":
try:
import psutil

total_ram_gb = psutil.virtual_memory().total / 1024**3
except ImportError:
total_ram_gb = 256.0 # conservative fallback
max_model_size_b = total_ram_gb * 0.7 / num_workers / 4
print(
f"[INFO] max_model_size_b=auto → {max_model_size_b:.1f}B "
f"(RAM={total_ram_gb:.0f}GB, workers={num_workers})"
)
else:
max_model_size_b = float(args.max_model_size_b)
print(f"[INFO] max_model_size_b={max_model_size_b:.1f}B (manually set)")

return (
workspace,
gpus,
num_workers,
extract_timeout,
verify_timeout,
max_model_size_b,
)


def main() -> int:
args = _parse_args()

workspace, gpus, num_workers, extract_timeout, verify_timeout = _resolve_config(
args
)
(
workspace,
gpus,
num_workers,
extract_timeout,
verify_timeout,
max_model_size_b,
) = _resolve_config(args)
llm_retry = args.use_llm

model_ids = _load_model_ids(args)
Expand Down Expand Up @@ -510,6 +549,7 @@ def main() -> int:
extract_timeout,
verify_timeout,
llm_retry,
max_model_size_b,
),
name=f"worker-{worker_id}"
+ (f"-gpu{gpu_id}" if gpu_id is not None else "-cpu"),
Expand Down
Loading