Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions graph_net/agent/parallel_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,13 @@ def _parse_args() -> argparse.Namespace:
"--gpus",
type=str,
default=None,
help="Comma-separated GPU indices to use (GPU mode; if set, ignores --num-workers)",
help="Comma-separated GPU indices to use (default: auto-detect all available GPUs)",
)
parser.add_argument(
"--num-workers",
"--cpu-workers",
type=int,
default=None,
help="Number of worker processes in CPU mode (default: CPU count)",
help="Number of worker processes in CPU-only mode (default: half of CPU cores)",
)
parser.add_argument(
"--output",
Expand Down Expand Up @@ -435,22 +435,29 @@ def _resolve_config(args: argparse.Namespace):
)
print(f"[INFO] Workspace: {workspace}")

if get_device_type() == "cuda":
# Decide GPU vs CPU mode: if --cpu-workers is set, force CPU-only mode.
# If no CUDA available, also fall back to CPU mode.
if (args.cpu_workers and args.cpu_workers > 0) or get_device_type() == "cpu":
# CPU-only mode. Default to half of CPU cores to avoid overloading the
# system, since each worker is a heavy process (model loading + graph
# extraction).
gpus = []
num_workers = (
args.cpu_workers if args.cpu_workers else max(1, (os.cpu_count() or 2) // 2)
)
print(f"[INFO] CPU-only mode: {num_workers} workers")
extract_timeout = (
args.extract_timeout if args.extract_timeout is not None else 2000
)
verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600
else:
gpus = get_gpu_ids(args)
num_workers = len(gpus)
print(f"[INFO] GPU mode (torch fallback): {gpus}")
print(f"[INFO] GPU mode: {num_workers} workers on GPUs {gpus}")
extract_timeout = (
args.extract_timeout if args.extract_timeout is not None else 1000
)
verify_timeout = args.verify_timeout if args.verify_timeout is not None else 300
else:
gpus = []
num_workers = args.num_workers if args.num_workers else 1
print(f"[INFO] CPU mode: {num_workers} workers")
extract_timeout = (
args.extract_timeout if args.extract_timeout is not None else 2000
)
verify_timeout = args.verify_timeout if args.verify_timeout is not None else 600

return workspace, gpus, num_workers, extract_timeout, verify_timeout

Expand All @@ -468,7 +475,10 @@ def main() -> int:
print("[ERROR] Empty model list, nothing to do")
return 1

print(f"[INFO] Total models: {len(model_ids)}, workers: {num_workers}")
if gpus:
print(f"[INFO] Total models: {len(model_ids)}, GPU workers: {num_workers}")
else:
print(f"[INFO] Total models: {len(model_ids)}, CPU workers: {num_workers}")

# --- Populate shared task queue ---
task_queue: multiprocessing.Queue = multiprocessing.Queue()
Expand All @@ -479,8 +489,9 @@ def main() -> int:
result_queue: multiprocessing.Queue = multiprocessing.Queue()

start_time = datetime.now()
worker_type = "GPU" if gpus else "CPU"
print(
f"\n[START] {start_time.strftime('%Y-%m-%d %H:%M:%S')} — launching {num_workers} workers\n"
f"\n[START] {start_time.strftime('%Y-%m-%d %H:%M:%S')} — launching {num_workers} {worker_type} workers\n"
)

processes = []
Expand Down
Loading