evo2 SAE recipe: streaming extract.py + env-overridable sweep runner#1583
Draft
polinabinder1 wants to merge 8 commits into
torch 2.6 changed the default of `weights_only` in `torch.load` to True. The Savanna checkpoint pickle includes numpy globals (`numpy.core.multiarray._reconstruct`), which the safer loader rejects. The converter then exits 0 with no output written and the error gets buried in stderr: a silent failure. The Savanna repos under arcinstitute/* are trusted sources, so load with `weights_only=False`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
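A minimal sketch of the two ideas in this commit, assuming a converter shaped roughly like this: pass `weights_only=False` explicitly for the trusted checkpoint, and return a nonzero status when loading fails instead of exiting 0. The function names `load_trusted_checkpoint` and `run_conversion` are hypothetical and are not taken from the converter itself.

```python
import sys


def load_trusted_checkpoint(path):
    """Load a pickle-based checkpoint that comes from a trusted source."""
    # Imported lazily so the sketch can be read and tested without torch installed.
    import torch

    # torch >= 2.6 defaults to weights_only=True, which rejects pickled numpy
    # globals; only opt out for checkpoints whose origin is trusted.
    return torch.load(path, map_location="cpu", weights_only=False)


def run_conversion(path, loader=load_trusted_checkpoint):
    """Return a process exit status: 0 on success, 1 if the checkpoint fails to load."""
    try:
        state = loader(path)
    except Exception as exc:
        # Report the failure and signal it through the return code, so a
        # wrapping shell script does not treat the step as successful.
        print(f"conversion failed for {path}: {exc}", file=sys.stderr)
        return 1
    # Hypothetical placeholder: the real conversion of `state` would happen here.
    _ = state
    return 0
```

A caller would typically end with `sys.exit(run_conversion(path))` so the failure propagates to the shell.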
Mirrors the existing esm2 / codonfm SAE recipes. Pipeline:
chunk -> convert (Savanna->MBridge) -> predict_evo2 -> pt_to_parquet -> train
Differences from esm2/codonfm are forced by Evo2 specifics:
- Hyena/Megatron-Core model, no HF AutoModel path => reuses the
existing `predict_evo2` CLI for inference instead of writing
a custom extract.py
- `pt_to_parquet.py` shim bridges predict_evo2's .pt output to
the universal `sae.activation_store` parquet contract
- `chunk_fasta.py` preprocessor keeps inputs within the model's
trained context length (8192 bp for 1B); Hyena fftconv OOMs
on long sequences even at micro-batch=1
- `train.py` is the same as codonfm's, copied verbatim per
bionemo-recipes' KISS-over-DRY convention
Validated end-to-end on 100 organelle sequences (Evo2 1B layer 12):
loss 0.67 -> 0.045, FVU 0.90 -> 0.10, var_exp 0.10 -> 0.90, 2m14s wall.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The recipe currently has no model-specific Python module: the extractor is upstream (`predict_evo2`) and the two scripts are simple CLIs in scripts/. Drop the empty package and adjust pyproject.toml so setuptools doesn't try to discover anything. Will reintroduce it when there's actual library code to put there (eval, dashboard, dataloaders).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`torchrun --nproc_per_node N` can hand a rank an empty batch when the last micro-batch falls past the shard boundary. `_padding_collate_fn` then crashed in `max()` with "iterable argument is empty". Return None from the collate function when the batch is empty and skip the loop iteration in `predict()`. Required for predict_evo2 to run reliably under DP > 1.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
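The guard described above can be sketched in plain Python. This is an illustration of the pattern only, not the code in `predict.py`: `padding_collate`, `predict`, and the `forward` callable are hypothetical stand-ins that operate on lists of token ids.

```python
def padding_collate(batch, pad_id=0):
    """Right-pad token lists to a common length; return None for an empty batch."""
    if not batch:
        # Without this guard, max() below raises on an empty iterable.
        return None
    max_len = max(len(seq) for seq in batch)
    return [list(seq) + [pad_id] * (max_len - len(seq)) for seq in batch]


def predict(raw_batches, forward):
    """Run `forward` on every non-empty batch and collect the results."""
    outputs = []
    for raw in raw_batches:
        batch = padding_collate(raw)
        if batch is None:
            # This rank received nothing for this step; skip it.
            continue
        outputs.append(forward(batch))
    return outputs
```

Returning a sentinel from the collate function keeps the empty-batch decision in one place, so the inference loop only needs a single `is None` check.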
scripts/extract.py
Codonfm-style streaming activation extractor. Reuses predict_evo2's
Megatron model/DP/dataloader machinery by monkey-patching its
_write_predictions_batch, then streams pad-stripped layer-N activations
directly into an ActivationStore (parquet shards) inside the inference
loop — skipping the .pt intermediate that pt_to_parquet had to walk.
--max-tokens caps each rank's budget. File-based rank wait + merge
(not dist.barrier — predict.main tears down the process group before
the writer hook returns, so the barrier silently no-ops and rank 0
races ahead; observed orphaned 18M tokens before this was fixed).
Saves ~30 min and ~7 TB scratch per 25M-token run vs the old pipeline.
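A file-based rank wait of the kind described above could look like the following sketch. It assumes each rank drops a marker file into a shared directory when it finishes writing, and rank 0 polls for all markers before merging. The marker naming scheme, the timeout, and the function names are assumptions and are not taken from `extract.py`.

```python
import time
from pathlib import Path


def mark_rank_done(out_dir, rank):
    """Create this rank's completion marker in the shared output directory."""
    marker = Path(out_dir) / f"rank_{rank}.done"
    marker.touch()
    return marker


def wait_for_ranks(out_dir, world_size, timeout_s=600.0, poll_s=0.5):
    """Block until every rank's marker exists, or raise TimeoutError."""
    expected = [Path(out_dir) / f"rank_{r}.done" for r in range(world_size)]
    deadline = time.monotonic() + timeout_s
    while True:
        missing = [p.name for p in expected if not p.exists()]
        if not missing:
            return
        if time.monotonic() >= deadline:
            raise TimeoutError(f"ranks did not finish in time: {missing}")
        time.sleep(poll_s)
```

Because it relies only on the filesystem, this wait still works after the process group has been torn down, which is the situation where a barrier would no longer synchronize anything.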
scripts/compose_prokeuk_fasta.py
Builds a balanced prokaryotic + eukaryotic mixed FASTA from
OpenGenome2 subsets (metagenomes + eukaryotic_genic_windows). Truncates
metagenome contigs to --metagenome-window bp each (default 50k) — they
average ~1.1 Mbp, so a handful of full contigs would dominate the mix.
Emits unique seq_{i} headers so predict_evo2's dup-id check passes.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
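The two mechanical steps named in the composer description, truncating metagenome contigs to a window and renaming every record to a unique `seq_{i}` header, can be sketched as below. This is a simplified illustration: it does not implement the balancing or any token budget, and `compose_records` and `to_fasta` are hypothetical names rather than functions from `compose_prokeuk_fasta.py`.

```python
def compose_records(metagenome_records, eukaryotic_records, metagenome_window=50_000):
    """Yield (header, sequence) pairs with unique seq_{i} headers.

    Both inputs are iterables of (original_header, sequence) tuples.
    Metagenome contigs are cut to the first `metagenome_window` bases.
    """
    index = 0
    for _header, sequence in metagenome_records:
        yield f"seq_{index}", sequence[:metagenome_window]
        index += 1
    for _header, sequence in eukaryotic_records:
        yield f"seq_{index}", sequence
        index += 1


def to_fasta(records, line_width=80):
    """Render (header, sequence) pairs as FASTA text with wrapped sequence lines."""
    lines = []
    for header, sequence in records:
        lines.append(f">{header}")
        lines.extend(
            sequence[i : i + line_width] for i in range(0, len(sequence), line_width)
        )
    return "\n".join(lines) + "\n"
```

Discarding the original headers in favor of a running counter is what guarantees uniqueness even when two source files reuse the same record id.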
extract.py replaces the predict_evo2 -> .pt -> pt_to_parquet path with a single streaming step that writes ActivationStore parquet shards directly during inference. Delete the now-unused shim, rewrite 1b.sh as a 3-step pipeline (convert -> extract -> train), and update the README accordingly.
1b.sh:
- collapse predict_evo2 + pt_to_parquet into a single 'STEP 2: extract' that calls torchrun extract.py
- expose RUN_TAG, PARQUET_DIR/OUTPUT_DIR, MAX_TOKENS, MICRO_BATCH, DEVICES, and SAE training hyperparams (EXPANSION_FACTOR, TOP_K, AUXK, AUXK_COEF, DEAD_TOKENS_THRESHOLD, N_EPOCHS, LR) as env overrides so the same script drives a multi-config sweep
- TRAIN_ONLY=1 skips chunk/convert/extract against a cached parquet
- WANDB_API_KEY gates wandb logging; WANDB_PROJECT/WANDB_RUN_NAME override
README:
- pipeline diagram + quick-start examples for the new env-overridable flow; remove all references to .pt intermediates and pt_to_parquet.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
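Since the hyperparameters are plain environment overrides, a sweep driver only needs to build one environment per configuration. The sketch below generates those environments from a grid. The variable names (`TOP_K`, `EXPANSION_FACTOR`, `TRAIN_ONLY`, `RUN_TAG`) come from the commit message, while the `RUN_TAG` naming scheme and the `sweep_envs` helper are assumptions made for illustration.

```python
import itertools
import os


def sweep_envs(grid, base_env=None):
    """Yield one environment dict per point in the hyperparameter grid.

    `grid` maps an env variable name (e.g. "TOP_K") to a list of values.
    """
    base = dict(os.environ if base_env is None else base_env)
    keys = sorted(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        overrides = {k: str(v) for k, v in zip(keys, values)}
        # Hypothetical tag format; any string unique per config would do.
        tag = "_".join(f"{k.lower()}{v}" for k, v in overrides.items())
        # TRAIN_ONLY=1 reuses cached parquet instead of re-running extraction.
        yield {**base, **overrides, "TRAIN_ONLY": "1", "RUN_TAG": tag}
```

Each yielded dict can then be passed to something like `subprocess.run(["bash", "scripts/1b.sh"], env=env, check=True)`, assuming the cached parquet location is also set in the environment.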
topk.py: aux-loss target was `x - recon + pre_bias`, which simplifies to `x - decoder(codes)` -- norm dominated by ||pre_bias|| (~449 on evo2 L22) rather than the actual reconstruction error (~8). The denominator (`target_var = residual.pow(2).mean(-1)`) was inflated by the same factor, so the aux gradient was scaled by roughly (||pre_bias|| / ||error||)^2 ~ 3000x below the canonical formulation. Fix to `residual = x - recon`, matching the OpenAI/Gao TopK formulation.
Numerically verified on the 500M L22 checkpoint: residual (a) ||x - recon|| = 8.0 vs buggy (b) ||x - recon + pre_bias|| = 449.7.
1b.sh: default DEAD_TOKENS_THRESHOLD to 10_000_000, matching the train.py default and codonfm convention (Gao et al.). Prior 500_000 default flagged ~70% of latents as 'dead' even when they were firing once per ~half-million tokens, vs codonfm's 0.003% under the canonical threshold. Still overridable via env.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
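The effect of the extra `pre_bias` term can be reproduced on toy vectors. The sketch below assumes, as the commit message implies, that the reconstruction has the form `recon = decoder(codes) + pre_bias`; all numbers are made up for illustration and are unrelated to the checkpoint measurements quoted above.

```python
import math


def l2_norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(a * a for a in v))


def canonical_residual(x, recon):
    """Aux-loss target after the fix: the reconstruction error itself."""
    return [xi - ri for xi, ri in zip(x, recon)]


def buggy_target(x, recon, pre_bias):
    """Aux-loss target before the fix: the error plus pre_bias."""
    return [xi - ri + bi for xi, ri, bi in zip(x, recon, pre_bias)]


# Toy values (assumptions): a large bias, a small decoder output, a small error.
pre_bias = [100.0, -100.0, 100.0, -100.0]
decoded = [1.0, -2.0, 0.5, 0.0]
recon = [d + b for d, b in zip(decoded, pre_bias)]
error = [0.1, -0.2, 0.0, 0.3]
x = [r + e for r, e in zip(recon, error)]

fixed = canonical_residual(x, recon)
buggy = buggy_target(x, recon, pre_bias)
```

Here `l2_norm(fixed)` stays at the scale of `error`, while `l2_norm(buggy)` is close to `l2_norm(pre_bias)`. With the reported norms, (449.7 / 8.0)^2 is about 3160, consistent with the roughly 3000x scaling stated in the commit message.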
Read this before reviewing
This PR is stacked on top of #1579 (`evo2-sae-recipe`). Both branches live on the same fork, so GitHub can't show a base-on-fork comparison; the diff below includes #1579's commits at the bottom. The new work here is the top three commits only:
- a5b1db7 evo2 sae recipe: streaming-extract pipeline, drop pt_to_parquet shim
- 9b7856a evo2 sae: streaming extractor + prok+euk FASTA composer (`extract.py` (+206) and `compose_prokeuk_fasta.py` (+142), both new)
- 89ff40c evo2_megatron predict: skip empty batches on DP shard boundary (`predict.py`)
Net diff for the new work: 6 files, +439 / −107. Once #1579 merges into `main`, this PR's diff auto-cleans to just these three commits.
What changes
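The recipe's extraction pipeline collapses from two steps to one.
Before: predict_evo2 -> .pt -> pt_to_parquet -> ActivationStore parquet shards
After: extract.py -> ActivationStore parquet shards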
`extract.py` monkey-patches `predict_evo2`'s writer so activations stream straight into the SAE `ActivationStore` parquet format during inference. No `.pt` intermediate, no shim. This is the pipeline used for the current 100M / 500M layer-22 prok+euk training runs; the old `.pt` path was never used by any real run, so `pt_to_parquet.py` is deleted.
Files in the new work
- `scripts/extract.py` (new, +206): streaming activation extractor. Reuses `bionemo.evo2.run.predict` for all the heavy machinery (Megatron model load, DP/CP/TP/PP, FASTA dataloader, inference loop) but swaps the per-batch `.pt` writer for an in-process `ActivationStore`. Handles multi-rank merge with a file-based wait (`dist.barrier()` no-ops after `predict.main()` tears down the process group).
- `scripts/compose_prokeuk_fasta.py` (new, +142): composes a prokaryotic + eukaryotic FASTA mix from OpenGenome2 shards with a hard token budget. Excludes metagenome sources; matches the training mix used in 100M / 500M runs.
- `scripts/pt_to_parquet.py` (deleted): obsoleted by `extract.py`.
- `scripts/1b.sh` (rewritten): 3-step pipeline (convert -> extract -> train) instead of 4. Env-overridable hyperparams (`RUN_TAG`, `LAYER`, `MAX_TOKENS`, `MICRO_BATCH`, `DEVICES`, `EXPANSION_FACTOR`, `TOP_K`, `AUXK`, `AUXK_COEF`, `DEAD_TOKENS_THRESHOLD`, `N_EPOCHS`, `LR`) so the same script drives a multi-config sweep. `TRAIN_ONLY=1` skips extraction against a cached parquet. `WANDB_API_KEY` gates wandb.
- `evo2_megatron/.../predict.py` (+8): skip empty batches on DP shard boundaries. Fixes a hang where the last DP rank could receive 0 sequences and stall `predict_evo2`.
- `README.md`: updated pipeline diagram and quick-start examples for the new flow.
How to use
Compose a custom mix first if needed:
Test plan
- `bash scripts/1b.sh` end-to-end on a small FASTA: convert step skips if MBridge ckpt exists, extract writes parquet shards + metadata.json, train completes and writes `checkpoint_final.pt`.
- `compose_prokeuk_fasta.py --target-tokens 1_000_000` writes a FASTA close to the target (±genome-record granularity) and excludes metagenome sources.
- `TRAIN_ONLY=1` against an existing parquet skips chunk/convert/extract and runs train directly.
- `predict_evo2` on a sequence count not divisible by DP world size completes instead of hanging.
- Env overrides (`LAYER`, `RUN_TAG`, `AUXK`, `N_EPOCHS`, etc.) surface in the printed step banners and the underlying CLI flags.
🤖 Generated with Claude Code