Scale beam search: multi-LLM, samples-per-prompt, multi-GPU, PTX dedup, NCU caching #139

Open
jiannanWang wants to merge 1 commit into main from jiannanWang/scale-beam-search

Summary

Beam search on main runs with a single LLM, a single GPU, and no candidate deduplication. That's enough for easy problems but plateaus quickly: a single model produces near-duplicate kernels each round, and on harder problems a single model often can't produce any working kernel at all (e.g., claude-opus-4.5 alone yields 0/32 successful kernels on the Gemma3 fused gate-up workload).

This PR scales beam search to run multiple LLMs in parallel, expand multiple samples per (parent, bottleneck, model) triple, distribute workers across multiple GPUs, and deduplicate beam entries by compiled-PTX fingerprint so the beam doesn't fill with byte-equivalent copies of the leader.

All new features are opt-in via strategy_config. Existing flows that don't set models or samples_per_prompt are unchanged.

Example commands

Three new presets ship with the PR:

# 90-worker concentrated: 2 parents × 3 bottlenecks × 3 models × 5 samples
python examples/run_opt_manager.py \
  --kernel-dir examples/optimize_01_matvec \
  --strategy beam_search_diverse_concentrated \
  --max-rounds 8

# 90-worker spread: 5 parents × 3 bottlenecks × 3 models × 2 samples
python examples/run_opt_manager.py \
  --kernel-dir examples/optimize_01_matvec \
  --strategy beam_search_diverse \
  --max-rounds 8

# 9-worker smoke for pipeline validation
python examples/run_opt_manager.py \
  --kernel-dir examples/optimize_01_matvec \
  --strategy beam_search_diverse_smoke \
  --max-rounds 1

Each preset configures models: [claude-opus-4.6, gpt-5-4, gemini-2-5-pro] out of the box, routed via the existing RelayProvider fallback.
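The worker counts in the command comments above follow from a plain Cartesian product. Here is a minimal sketch, assuming one candidate per (parent, bottleneck, model, sample) combination. `expand_candidates` and the dict layout are hypothetical. Only the `openai_model` and `sample_idx` field names come from this PR's description.

```python
from itertools import product


def expand_candidates(parents, bottleneck_ids, models, samples_per_prompt=1):
    """Enumerate one candidate per (parent, bottleneck, model, sample).

    Hypothetical helper for illustration; the real strategy may order or
    filter candidates differently.
    """
    return [
        {
            "parent": parent,
            "bottleneck_id": bottleneck_id,
            "openai_model": model,
            "sample_idx": sample_idx,
        }
        for parent, bottleneck_id, model, sample_idx in product(
            parents, bottleneck_ids, models, range(samples_per_prompt)
        )
    ]


MODELS = ["claude-opus-4.6", "gpt-5-4", "gemini-2-5-pro"]
BOTTLENECKS = [1, 2, 3]  # assumed 1-based ranks

# Concentrated: 2 parents x 3 bottlenecks x 3 models x 5 samples
concentrated = expand_candidates(["parent_0", "parent_1"], BOTTLENECKS, MODELS, 5)

# Spread: 5 parents x 3 bottlenecks x 3 models x 2 samples
spread = expand_candidates([f"parent_{i}" for i in range(5)], BOTTLENECKS, MODELS, 2)

print(len(concentrated), len(spread))  # 90 90
```

Both presets spend the same 90-worker budget. The concentrated one draws more samples from fewer parents, while the spread one covers more parents with fewer draws each.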

Experiment results

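Two problems, 8 rounds each. Multipliers are speedups over the named baseline, so a value below 1× is slower than that baseline: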

┌───────────────┬───────────────────────────────────┬──────────────────────────────────────────────┐
│    Problem    │ main (4 workers, claude-opus-4.5) │  This PR (90 workers, 3 LLMs concentrated)   │
├───────────────┼───────────────────────────────────┼──────────────────────────────────────────────┤
│ matvec        │ 2.1494 ms (0.95× PyTorch eager)   │ 1.8318 ms (1.12× PyTorch eager)              │
├───────────────┼───────────────────────────────────┼──────────────────────────────────────────────┤
│ gemma3_swiglu │ failed: 0/32 working kernels      │ 0.2716 ms (3.10× eager, 1.14× torch.compile) │
└───────────────┴───────────────────────────────────┴──────────────────────────────────────────────┘

The matvec improvement comes from richer fanout exposing more refinement directions per round. The gemma3_swiglu improvement comes from multi-LLM search breaking through the single-model correctness floor: gpt-5-4 and gemini-2-5-pro find working kernels for this problem where claude-opus-4.5 does not.

What changed

  • models: list[str] strategy knob — per-candidate LLM override; each worker round-robins through the configured list.
  • samples_per_prompt: int strategy knob — N independent LLM draws per (parent, bottleneck, model) triple. Default = 1.
  • Per-GPU benchmark lock pool — dict[gpu_id, mp.Lock] replaces the single global lock; each worker acquires only its own GPU's lock.
  • Worker→GPU pinning — CUDA_VISIBLE_DEVICES=<gpu_id> is set in each worker process before any torch import.
  • nvidia-smi-based GPU detection — OptimizationManager._detect_gpus no longer touches torch.cuda (which would lock the device list for forked children).
  • PTX-fingerprint module — searching/ptx_fingerprint.py normalizes PTX (strip comments, canonicalize registers / labels) and SHA-256-hashes it; persisted on ProgramEntry.ptx_hash.
  • PTX capture in benchmark — sets a per-call TRITON_CACHE_DIR and reads the resulting *.ptx files to compute the hash alongside time_ms.
  • PTX-hash dedup in beam selection — BeamSearchStrategy.update_with_results collapses pool entries with an identical hash to the fastest representative before sort+truncate. Round logs gain a new "PTX dedup X→Y" line.
  • Per-parent NCU caching — manager profiles each unique parent kernel once per round and shares the result with siblings instead of N re-profiles.
  • Three new strategy presets — beam_search_diverse.yaml, beam_search_diverse_concentrated.yaml, beam_search_diverse_smoke.yaml. Existing presets unchanged.
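The PTX-fingerprint and dedup bullets above can be sketched as follows. This is an illustration under stated assumptions, not the code in searching/ptx_fingerprint.py. The regexes, the skipped directives, the placeholder names (`%v0`, `$L0`), and the helper names are all hypothetical. Only the overall recipe comes from this PR: strip comments, canonicalize registers and labels, SHA-256 the result, and keep the fastest entry per hash.

```python
import hashlib
import re

# Assumed patterns; the PR's real normalization rules may differ.
_COMMENT = re.compile(r"//[^\n]*|/\*.*?\*/", re.DOTALL)
_SKIPPED_PREFIXES = (".version", ".target", ".address_size", ".loc ", ".file ")
_REGISTER = re.compile(r"%[A-Za-z_]+\d+")   # e.g. %r12, %rd3, %f0, %p1
_LABEL = re.compile(r"\$[A-Za-z_][\w$]*")   # e.g. $L__BB0_4


def _renumber(text, pattern, prefix):
    """Rename each distinct match to prefix+index, in order of first appearance."""
    mapping = {}

    def replace(match):
        return mapping.setdefault(match.group(0), f"{prefix}{len(mapping)}")

    return pattern.sub(replace, text)


def normalize_ptx(ptx):
    """Drop comments, header/debug lines, and naming differences."""
    text = _COMMENT.sub("", ptx)
    lines = []
    for raw in text.splitlines():
        line = " ".join(raw.split())  # collapse whitespace
        if line and not line.startswith(_SKIPPED_PREFIXES):
            lines.append(line)
    text = _renumber("\n".join(lines), _REGISTER, "%v")
    return _renumber(text, _LABEL, "$L")


def ptx_hash(ptx):
    """SHA-256 fingerprint of the normalized PTX text."""
    return hashlib.sha256(normalize_ptx(ptx).encode("utf-8")).hexdigest()


def dedup_by_ptx_hash(entries):
    """Keep the fastest entry per hash, then sort by runtime.

    Assumes every entry is a dict with 'ptx_hash' and 'time_ms' keys.
    """
    best = {}
    for entry in entries:
        kept = best.get(entry["ptx_hash"])
        if kept is None or entry["time_ms"] < kept["time_ms"]:
            best[entry["ptx_hash"]] = entry
    return sorted(best.values(), key=lambda e: e["time_ms"])
```

Renaming in order of first appearance means two kernels that differ only in register or label numbering hash identically, while any change to an opcode or operand order still produces a different fingerprint.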

- BeamSearchStrategy: add models / samples_per_prompt / num_expanding_parents
  knobs; expansion now P × M × K × C with per-candidate openai_model and
  sample_idx threaded through the worker dispatch.

- PTX fingerprint dedup: new ptx_fingerprint.py captures compiled PTX
  from a per-call TRITON_CACHE_DIR during benchmarking, normalizes
  (strip comments/debug/headers, canonicalize register/label names),
  SHA-256 hashes. update_with_results dedups the combined pool by hash
  before sort+truncate; ProgramEntry / json_db carry ptx_hash.

- Multi-GPU: per-GPU mp.Lock pool (single lock covers both benchmark
  and NCU on a given GPU), round-robin worker -> GPU assignment,
  CUDA_VISIBLE_DEVICES=<gpu_id> set in the worker process before any
  torch import. Manager auto-detects via nvidia-smi (NOT torch.cuda) to
  avoid poisoning forked children with an inherited CUDA context.

- Per-parent baseline NCU cache: manager profiles each unique parent
  once per round and attaches baseline_metrics to each candidate dict;
  workers skip their own NCU when the cache is populated.

- Bottleneck plumbing fix: num_bottlenecks is now wired from
  strategy_config -> worker_kwargs[num_bottlenecks_to_request] ->
  BottleneckAnalyzer. Pre-fix the analyzer always asked for 1 ranked
  bottleneck so workers with bottleneck_id >= 2 silently fell back to
  rank 1.

- mp.Queue feeder-thread deadlock fix: NvidiaWorkerRunner.run_workers
  now drains the queue interleaved with join(timeout=0.5) polling
  instead of joining all workers serially before draining.

- best_runtime_ptx_hash propagation: orchestrator captures the hash
  after _update_kernels (was previously checked before, when the
  comparison was tautologically false), and parent-hash byte-identity
  fallback in update_with_results lets unchanged-parent results inherit
  the parent's hash so they collapse correctly in dedup.

- ncu_profiler.py: NaN-safe units-row detection (str.lower()
  propagated pd.NA back to float NaN, breaking the substring check).

- Configs: examples/configs/beam_search_diverse.yaml (spread, P=5/C=2),
  beam_search_diverse_concentrated.yaml (P=2/C=5),
  beam_search_diverse_smoke.yaml (smoke variant).
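
The mp.Queue feeder-thread fix listed above follows a general multiprocessing rule: a child that has put items on a queue cannot exit until its buffered data is flushed to the pipe, so joining every worker before reading the queue can hang. A minimal sketch of the interleaved pattern, assuming workers are already started; `run_and_drain` and `_drain` are hypothetical names, not the actual NvidiaWorkerRunner.run_workers code:

```python
import queue


def _drain(result_queue):
    """Pull every item that is currently available without blocking."""
    items = []
    while True:
        try:
            items.append(result_queue.get_nowait())
        except queue.Empty:
            return items


def run_and_drain(workers, result_queue, poll_s=0.5):
    """Collect results while workers run, instead of join-then-drain.

    `workers` must already be started and expose join(timeout) and
    is_alive(), as multiprocessing.Process does.
    """
    results = []
    pending = list(workers)
    while pending:
        # Read first so blocked writers can make progress and exit.
        results.extend(_drain(result_queue))
        # Wait briefly on one worker, then re-check which are still alive.
        pending[0].join(timeout=poll_s)
        pending = [w for w in pending if w.is_alive()]
    # Workers have exited; pick up anything they flushed on the way out.
    results.extend(_drain(result_queue))
    return results
```

Because the helper relies only on `join(timeout)`, `is_alive()`, and `get_nowait()`, it works unchanged with `threading.Thread` and `queue.Queue`, which makes the loop easy to unit-test without spawning processes.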