From 60deff46f045dd2916b4cff3e53565e0a95b7c3c Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 21 May 2026 00:41:33 +0000 Subject: [PATCH 1/3] evo2_megatron: load Savanna HF checkpoints with weights_only=False MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch 2.6 changed the default of `weights_only` to True. The Savanna checkpoint pickle includes numpy globals (`numpy.core.multiarray._reconstruct`), which the safer loader rejects. The converter then exits 0 with no output written and the error gets buried in stderr — silent failure. The Savanna repos under arcinstitute/* are trusted sources, so load with weights_only=False. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py index 811b07153e..156ce530b5 100644 --- a/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py +++ b/bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/utils/checkpoint/savanna_to_mbridge.py @@ -135,7 +135,7 @@ def load_savanna_state_dict(path: Path) -> dict[str, torch.Tensor]: Returns: Flat state dict with keys like 'sequential.{i}.xxx'. """ - raw = torch.load(str(path), map_location="cpu", weights_only=True, mmap=True) + raw = torch.load(str(path), map_location="cpu", weights_only=False, mmap=True) if "module" in raw: raw = raw["module"] From b640f66ab62600ccf5da05f8ced1bb7f2797534f Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 21 May 2026 00:43:15 +0000 Subject: [PATCH 2/3] interpretability/sae: add Evo2 1B SAE recipe Mirrors the existing esm2 / codonfm SAE recipes. Pipeline: chunk -> convert (Savanna->MBridge) -> predict_evo2 -> pt_to_parquet -> train Differences from esm2/codonfm are forced by Evo2 specifics: - Hyena/Megatron-Core model, no HF AutoModel path => reuses the existing `predict_evo2` CLI for inference instead of writing a custom extract.py - `pt_to_parquet.py` shim bridges predict_evo2's .pt output to the universal `sae.activation_store` parquet contract - `chunk_fasta.py` preprocessor keeps inputs within the model's trained context length (8192 bp for 1B); Hyena fftconv OOMs on long sequences even at micro-batch=1 - `train.py` is the same as codonfm's, copied verbatim per bionemo-recipes' KISS-over-DRY convention Validated end-to-end on 100 organelle sequences (Evo2 1B layer 12): loss 0.67 -> 0.045, FVU 0.90 -> 0.10, var_exp 0.10 -> 0.90, 2m14s wall. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2/README.md | 30 ++ .../recipes/evo2/pyproject.toml | 24 ++ .../recipes/evo2/scripts/1b.sh | 116 +++++++ .../recipes/evo2/scripts/chunk_fasta.py | 73 ++++ .../recipes/evo2/scripts/pt_to_parquet.py | 65 ++++ .../recipes/evo2/scripts/train.py | 321 ++++++++++++++++++ .../recipes/evo2/src/evo2_sae/__init__.py | 16 + 7 files changed, 645 insertions(+) create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml create mode 100755 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py create mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md new file mode 100644 index 0000000000..ad749dbedb --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/README.md @@ -0,0 +1,30 @@ +# Evo2 SAE Recipe + +Train a sparse autoencoder on Evo2 (DNA language model) residual-stream activations. + +Pipeline: + +``` +HF Savanna ckpt --convert--> MBridge ckpt + | + predict_evo2 --embedding-layer N (FASTA in, .pt out) + | + pt_to_parquet shim (.pt -> ActivationStore parquet shards) + | + train.py (TopK SAE) +``` + +The eval / dashboard stage from the esm2 recipe is intentionally not ported in v1. + +## Quick start (1B model, single GPU) + +```bash +bash scripts/1b.sh +``` + +This will: + +1. Convert `arcinstitute/savanna_evo2_1b_base` to MBridge format +2. Run `predict_evo2` on the OpenGenome2 organelle FASTA, extracting layer-12 embeddings +3. Convert the .pt outputs to parquet shards +4. Train a TopK SAE (expansion=8, k=32) diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml new file mode 100644 index 0000000000..26eff6b55c --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -0,0 +1,24 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "evo2-sae" +version = "0.1.0" +description = "Sparse Autoencoders for the Evo2 DNA language model" +readme = "README.md" +requires-python = ">=3.10" + +dependencies = [ + "sae", + "torch>=2.0", + "numpy>=1.20", + "tqdm>=4.60", + "pyarrow>=10.0", +] + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.uv.sources] +sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh new file mode 100755 index 0000000000..d499b4f365 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/1b.sh @@ -0,0 +1,116 @@ +#!/bin/bash +# Evo2 1B SAE pipeline: convert -> predict_evo2 -> pt_to_parquet -> train. +# +# Assumes: +# - bionemo-recipes/recipes/evo2_megatron has been built (.ci_build.sh) and +# its .venv is active, providing predict_evo2 + evo2_convert_savanna_to_mbridge. +# - The sae workspace package is importable in that same venv. +# - HF_TOKEN is set if Savanna checkpoint repo is gated. +# +# Override any of these by exporting before invocation. + +set -euo pipefail + +EVO2_MEGATRON_DIR="${EVO2_MEGATRON_DIR:-/workspace/bionemo-framework/bionemo-recipes/recipes/evo2_megatron}" +RECIPE_DIR="$(cd "$(dirname "$0")/.." && pwd)" + +MODEL="${MODEL:-arcinstitute/savanna_evo2_1b_base}" +MODEL_SIZE="${MODEL_SIZE:-evo2_1b_base}" +LAYER="${LAYER:-12}" +# Trained context length. 1B = 8192. Bump for 7B/40B (context-extended). +CHUNK_BP="${CHUNK_BP:-8192}" + +FASTA="${FASTA:-/data/interp/evo2/OpenGenome2/fasta/organelles/organelle_sequences.fasta.gz}" +WORK_ROOT="${WORK_ROOT:-/data/interp/evo2}" + +CKPT_DIR="${WORK_ROOT}/checkpoints/${MODEL_SIZE}_mbridge" +PREDICT_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_pt" +PARQUET_DIR="${WORK_ROOT}/activations/${MODEL_SIZE}_layer${LAYER}_parquet" +OUTPUT_DIR="${WORK_ROOT}/sae/${MODEL_SIZE}_layer${LAYER}" + +source "${EVO2_MEGATRON_DIR}/.venv/bin/activate" + +echo "============================================================" +echo "STEP 0: Chunk FASTA to <=${CHUNK_BP} bp (model trained context)" +echo "============================================================" +# chunk_fasta.py reads .gz directly and writes plain .fasta; no separate gunzip needed. +INPUT_STEM="$(basename "$FASTA")" +INPUT_STEM="${INPUT_STEM%.gz}" +INPUT_STEM="${INPUT_STEM%.fasta}" +CHUNKED_FASTA="${WORK_ROOT}/scratch/${INPUT_STEM}_chunked${CHUNK_BP}.fasta" +if [[ -f "$CHUNKED_FASTA" ]]; then + echo "Reusing existing chunked FASTA: $CHUNKED_FASTA" +else + python "${RECIPE_DIR}/scripts/chunk_fasta.py" \ + --input "$FASTA" \ + --output "$CHUNKED_FASTA" \ + --window "$CHUNK_BP" +fi +FASTA="$CHUNKED_FASTA" + +echo "============================================================" +echo "STEP 1: Convert Savanna -> MBridge" +echo "============================================================" +if [[ ! -f "${CKPT_DIR}/latest_checkpointed_iteration.txt" ]]; then + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path "$MODEL" \ + --mbridge-ckpt-dir "$CKPT_DIR" \ + --model-size "$MODEL_SIZE" \ + --tokenizer-path "${EVO2_MEGATRON_DIR}/tokenizers/nucleotide_fast_tokenizer_512" +else + echo "Reusing existing checkpoint at $CKPT_DIR" +fi + +echo "============================================================" +echo "STEP 2: Extract layer-${LAYER} embeddings (predict_evo2)" +echo "============================================================" +mkdir -p "$PREDICT_DIR" +if compgen -G "${PREDICT_DIR}/predictions__*.pt" > /dev/null; then + echo "Reusing existing .pt files in $PREDICT_DIR" +else + predict_evo2 \ + --fasta "$FASTA" \ + --ckpt-dir "$CKPT_DIR" \ + --output-dir "$PREDICT_DIR" \ + --embedding-layer "$LAYER" \ + --micro-batch-size 1 \ + --devices 1 \ + --write-interval batch +fi + +echo "============================================================" +echo "STEP 3: Convert .pt -> parquet ActivationStore" +echo "============================================================" +if [[ -f "${PARQUET_DIR}/metadata.json" ]]; then + echo "Reusing existing parquet shards at $PARQUET_DIR" +else + python "${RECIPE_DIR}/scripts/pt_to_parquet.py" \ + --predict-dir "$PREDICT_DIR" \ + --output "$PARQUET_DIR" \ + --model-name "$MODEL" \ + --layer "$LAYER" +fi + +echo "============================================================" +echo "STEP 4: Train TopK SAE" +echo "============================================================" +python "${RECIPE_DIR}/scripts/train.py" \ + --cache-dir "$PARQUET_DIR" \ + --model-path "$MODEL" \ + --layer "$LAYER" \ + --model-type topk \ + --expansion-factor 8 --top-k 32 \ + --auxk 64 --auxk-coef 0.03125 \ + --init-pre-bias \ + --n-epochs 3 \ + --batch-size 4096 \ + --lr 3e-4 \ + --log-interval 50 \ + --no-wandb \ + --output-dir "$OUTPUT_DIR" \ + --checkpoint-dir "${OUTPUT_DIR}/checkpoints" \ + --checkpoint-steps 999999 + +echo "============================================================" +echo "DONE: SAE checkpoint at ${OUTPUT_DIR}/checkpoints/checkpoint_final.pt" +echo "============================================================" diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py new file mode 100644 index 0000000000..55b26cad30 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/chunk_fasta.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk a FASTA into <=N-bp windows so predict_evo2 stays inside the model's trained context. + +Evo2 1B was trained with seq_length=8192; longer inputs OOM in the Hyena +fftconv path (intermediates scale super-linearly with L). For 7B/40B raise +--window to whatever those checkpoints were context-extended to. + +Non-overlapping windows by default. Each chunk gets a header of the form +">{orig_id}:{start}-{end}" so downstream parquet can be back-mapped. +""" + +import argparse +import gzip +from pathlib import Path + + +def parse_fasta(path: Path): + """Yield (seq_id, sequence) tuples from a FASTA file (transparently handles .gz).""" + opener = gzip.open if path.suffix == ".gz" else open + seq_id, parts = None, [] + with opener(path, "rt") as f: + for line in f: + line = line.rstrip() + if line.startswith(">"): + if seq_id is not None: + yield seq_id, "".join(parts) + seq_id = line[1:].split()[0] + parts = [] + else: + parts.append(line) + if seq_id is not None: + yield seq_id, "".join(parts) + + +def main(): + """Read input FASTA, write non-overlapping <=window-bp chunks to output FASTA.""" + p = argparse.ArgumentParser() + p.add_argument("--input", type=Path, required=True) + p.add_argument("--output", type=Path, required=True) + p.add_argument("--window", type=int, default=8192) + args = p.parse_args() + + n_in = n_out = bp_out = 0 + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as out: + for seq_id, seq in parse_fasta(args.input): + n_in += 1 + for start in range(0, len(seq), args.window): + end = min(start + args.window, len(seq)) + chunk = seq[start:end] + out.write(f">{seq_id}:{start}-{end}\n{chunk}\n") + n_out += 1 + bp_out += len(chunk) + + print(f"Chunked {n_in} sequences -> {n_out} chunks ({bp_out:,} bp) at window={args.window}") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py new file mode 100644 index 0000000000..6a182b575d --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/pt_to_parquet.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert predict_evo2 .pt outputs to SAE ActivationStore parquet shards. + +predict_evo2 with --embedding-layer writes dicts of: + hidden_embeddings: [B, S, H] (bf16) + pad_mask: [B, S] (1 = valid token, 0 = padding) + seq_idx, tokens: metadata, ignored here + +We read each file, mask out padding, flatten to [N_tokens, H], and append +to an ActivationStore so train.py's load_activations() can consume it. +""" + +import argparse +import json +from pathlib import Path + +import torch +from sae.activation_store import ActivationStore, ActivationStoreConfig +from tqdm import tqdm + + +def main(): + """Walk predict_evo2 .pt files, mask padding, and write to an ActivationStore.""" + p = argparse.ArgumentParser() + p.add_argument("--predict-dir", type=Path, required=True, help="Dir containing predictions__*.pt") + p.add_argument("--output", type=Path, required=True, help="ActivationStore output dir") + p.add_argument("--model-name", type=str, required=True, help="Stamped into metadata.json") + p.add_argument("--layer", type=int, required=True, help="Stamped into metadata.json") + p.add_argument("--shard-size", type=int, default=100_000) + args = p.parse_args() + + pt_files = sorted(args.predict_dir.rglob("predictions__*.pt")) + if not pt_files: + raise FileNotFoundError(f"No predictions__*.pt under {args.predict_dir}") + + store = ActivationStore(args.output, ActivationStoreConfig(shard_size=args.shard_size)) + n_sequences = 0 + for pt in tqdm(pt_files, desc="pt->parquet"): + d = torch.load(pt, map_location="cpu", weights_only=False) + hidden = d["hidden_embeddings"] + mask = d["pad_mask"].bool() + flat = hidden[mask].float() + store.append(flat) + n_sequences += hidden.shape[0] + + store.finalize(metadata={"model_name": args.model_name, "layer": args.layer, "n_sequences": n_sequences}) + print(json.dumps(store.metadata, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py new file mode 100644 index 0000000000..19355822ae --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/scripts/train.py @@ -0,0 +1,321 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Step 2: Train SAE from cached CodonFM activations. + +Loads pre-extracted activations from an ActivationStore cache directory +and trains a Sparse Autoencoder. Requires extract.py to have been run first. + +Single-GPU: + python scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 + +Multi-GPU DDP: + torchrun --nproc_per_node=4 scripts/train.py \ + --cache-dir .cache/activations/encodon_1b_layer-2 \ + --model-path path/to/encodon_1b --layer -2 \ + --expansion-factor 8 --top-k 32 --batch-size 4096 --n-epochs 3 \ + --dp-size 4 +""" + +import argparse +import os +from pathlib import Path + +import numpy as np +import torch +from sae.activation_store import load_activations +from sae.architectures import ReLUSAE, TopKSAE +from sae.perf_logger import PerfLogger +from sae.training import ParallelConfig, Trainer, TrainingConfig, WandbConfig +from sae.utils import get_device, set_seed + + +def parse_args(): # noqa: D103 + p = argparse.ArgumentParser( + description="Train SAE from cached CodonFM activations", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required + p.add_argument("--cache-dir", type=str, required=True, help="Path to activation cache (from extract.py)") + p.add_argument("--model-path", type=str, required=True, help="Encodon model path (for cache validation)") + p.add_argument("--layer", type=int, required=True, help="Layer index (for cache validation)") + + # SAE architecture + sae_group = p.add_argument_group("SAE model") + sae_group.add_argument("--model-type", type=str, default="topk", choices=["topk", "relu"]) + sae_group.add_argument("--expansion-factor", type=int, default=8) + sae_group.add_argument("--top-k", type=int, default=32) + sae_group.add_argument("--normalize-input", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--auxk", type=int, default=None) + sae_group.add_argument("--auxk-coef", type=float, default=1 / 32) + sae_group.add_argument("--dead-tokens-threshold", type=int, default=10_000_000) + sae_group.add_argument("--init-pre-bias", action=argparse.BooleanOptionalAction, default=False) + sae_group.add_argument("--l1-coeff", type=float, default=1e-2, help="L1 coefficient (relu only)") + + # Training + train_group = p.add_argument_group("Training") + train_group.add_argument("--lr", type=float, default=3e-4) + train_group.add_argument("--n-epochs", type=int, default=3) + train_group.add_argument("--batch-size", type=int, default=4096) + train_group.add_argument("--log-interval", type=int, default=50) + train_group.add_argument("--shuffle", action=argparse.BooleanOptionalAction, default=True) + train_group.add_argument("--num-workers", type=int, default=0) + train_group.add_argument("--pin-memory", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--max-grad-norm", type=float, default=None) + train_group.add_argument("--lr-scale-with-latents", action=argparse.BooleanOptionalAction, default=False) + train_group.add_argument("--lr-reference-hidden-dim", type=int, default=2048) + train_group.add_argument("--warmup-steps", type=int, default=0, help="Linear LR warmup steps") + train_group.add_argument( + "--lr-schedule", + type=str, + default="constant", + choices=["constant", "cosine", "linear"], + help="LR schedule after warmup", + ) + train_group.add_argument("--lr-min", type=float, default=0.0, help="Minimum LR for decay schedules") + train_group.add_argument( + "--lr-decay-steps", + type=int, + default=None, + help="Total steps for LR decay (None = full training)", + ) + + # W&B + wb_group = p.add_argument_group("Weights & Biases") + wb_group.add_argument("--wandb", action=argparse.BooleanOptionalAction, default=False, dest="wandb_enabled") + wb_group.add_argument("--wandb-project", type=str, default="sae_codonfm_recipe") + wb_group.add_argument("--wandb-run-name", type=str, default=None) + wb_group.add_argument("--wandb-group", type=str, default=None) + wb_group.add_argument("--wandb-job-type", type=str, default=None) + + # Checkpointing + ckpt_group = p.add_argument_group("Checkpointing") + ckpt_group.add_argument("--checkpoint-dir", type=str, default=None) + ckpt_group.add_argument("--checkpoint-steps", type=int, default=None) + ckpt_group.add_argument("--resume-from", type=str, default=None) + + # Infrastructure + p.add_argument("--dp-size", type=int, default=1) + p.add_argument("--output-dir", type=str, default="./outputs") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--device", type=str, default=None) + p.add_argument( + "--num-sequences", + type=int, + default=None, + help="Subset cached activations to this many sequences' worth of shards", + ) + + return p.parse_args() + + +def build_sae(args, input_dim: int) -> torch.nn.Module: # noqa: D103 + hidden_dim = input_dim * args.expansion_factor + + if args.model_type == "topk": + return TopKSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + top_k=args.top_k, + normalize_input=args.normalize_input, + auxk=args.auxk, + auxk_coef=args.auxk_coef, + dead_tokens_threshold=args.dead_tokens_threshold, + ) + elif args.model_type == "relu": + return ReLUSAE( + input_dim=input_dim, + hidden_dim=hidden_dim, + l1_coeff=args.l1_coeff, + ) + else: + raise ValueError(f"Unknown model type: {args.model_type}") + + +def build_training_config(args, device: str) -> TrainingConfig: # noqa: D103 + return TrainingConfig( + lr=args.lr, + n_epochs=args.n_epochs, + batch_size=args.batch_size, + device=device, + log_interval=args.log_interval, + shuffle=args.shuffle, + num_workers=args.num_workers, + pin_memory=args.pin_memory, + checkpoint_dir=args.checkpoint_dir, + checkpoint_steps=args.checkpoint_steps, + lr_scale_with_latents=args.lr_scale_with_latents, + lr_reference_hidden_dim=args.lr_reference_hidden_dim, + warmup_steps=args.warmup_steps, + max_grad_norm=args.max_grad_norm, + lr_schedule=args.lr_schedule, + lr_min=args.lr_min, + lr_decay_steps=args.lr_decay_steps, + ) + + +def build_wandb_config(args) -> WandbConfig: # noqa: D103 + return WandbConfig( + enabled=args.wandb_enabled, + project=args.wandb_project, + run_name=args.wandb_run_name, + group=args.wandb_group, + job_type=args.wandb_job_type, + config=vars(args), + ) + + +def build_parallel_config(args) -> ParallelConfig: # noqa: D103 + return ParallelConfig(dp_size=args.dp_size) + + +def main(): # noqa: D103 + args = parse_args() + + set_seed(args.seed) + device = args.device or get_device() + print(f"Using device: {device}") + print(f"Config: {vars(args)}") + + # Load cached activations + cache_path = Path(args.cache_dir) + if not (cache_path / "metadata.json").exists(): + raise FileNotFoundError(f"No cache found at {cache_path}. Run extract.py first.") + + store = load_activations(cache_path) + meta = store.metadata + + # Validate cache matches config + cached_model = meta.get("model_path", meta.get("model_name", "")) + if cached_model and cached_model != args.model_path: + print(f"WARNING: Cache model '{cached_model}' != '{args.model_path}'") + if meta.get("layer") != args.layer: + raise ValueError(f"Cache layer mismatch: {meta['layer']} vs {args.layer}") + + # Compute subsetting + cached_sequences = meta.get("n_sequences", None) + max_shards = None + if args.num_sequences and cached_sequences and args.num_sequences < cached_sequences: + keep_ratio = args.num_sequences / cached_sequences + max_shards = max(1, int(np.ceil(keep_ratio * meta["n_shards"]))) + print( + f"Subsetting: {args.num_sequences}/{cached_sequences} sequences " + f"-> using {max_shards}/{meta['n_shards']} shards (~{keep_ratio:.1%})" + ) + + # Estimate memory + n_shards_to_use = max_shards or meta["n_shards"] + shard_size = meta.get("shard_size", 100_000) + est_tokens = n_shards_to_use * shard_size + est_gb = est_tokens * meta["hidden_dim"] * 4 / (1024**3) + use_streaming = est_gb > 50 + + input_dim = meta["hidden_dim"] + sae = build_sae(args, input_dim) + print(f"SAE: {args.model_type}, input_dim={input_dim}, hidden_dim={sae.hidden_dim}") + + # Initialize pre_bias + if args.init_pre_bias and hasattr(sae, "init_pre_bias_from_data"): + print("Initializing pre_bias from geometric median of data...") + first_shard = torch.from_numpy(store._load_shard(0)).float() + sample_size = min(32768, len(first_shard)) + sae.init_pre_bias_from_data(first_shard[:sample_size]) + print(f" pre_bias initialized (mean={sae.pre_bias.mean().item():.4f})") + del first_shard + + # Build configs + training_config = build_training_config(args, device) + wandb_config = build_wandb_config(args) + parallel_config = build_parallel_config(args) + + perf_logger = PerfLogger( + log_interval=args.log_interval, + use_wandb=args.wandb_enabled, + print_logs=True, + device=device, + ) + + # Train + trainer = Trainer( + sae, + training_config, + wandb_config=wandb_config, + perf_logger=perf_logger, + parallel_config=parallel_config, + ) + + if use_streaming: + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + print( + f"Streaming from disk (~{est_gb:.0f}GB). " + f"Peak RAM: ~{shard_size * meta['hidden_dim'] * 4 / (1024**3):.1f}GB/process" + ) + + dataloader = store.get_streaming_dataloader( + batch_size=args.batch_size, + shuffle=args.shuffle, + seed=args.seed, + rank=rank, + world_size=world_size, + max_shards=max_shards, + ) + # Compute min batch count across all ranks to keep DDP in sync + # Read parquet footers for all ranks' shards (a few KB each, no data loading) + if world_size > 1: + import pyarrow.parquet as pq_meta + + dataset = dataloader.dataset + per_rank = len(dataset.shard_indices) + # Each rank got per_rank contiguous shards; compute batch count for each rank + min_batches = None + for r in range(world_size): + total_rows = sum( + pq_meta.read_metadata(store.path / f"shard_{idx:05d}.parquet").num_rows + for idx in range(r * per_rank, (r + 1) * per_rank) + ) + batches = total_rows // args.batch_size + if min_batches is None or batches < min_batches: + min_batches = batches + dataset.max_batches = min_batches + print(f"[rank {rank}] capped to {min_batches} batches/epoch for DDP sync") + trainer.fit( + dataloader, + resume_from=args.resume_from, + data_sharded=True, + ) + else: + shards = [] + for i, shard in enumerate(store.iter_shards(shuffle_shards=False)): + if max_shards is not None and i >= max_shards: + break + shards.append(torch.from_numpy(shard).float()) + activations_flat = torch.cat(shards) + print(f"Loaded {activations_flat.shape[0]:,} cached activations into memory") + + trainer.fit( + activations_flat, + resume_from=args.resume_from, + ) + + print("Training complete.") + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py new file mode 100644 index 0000000000..d8ac513dc8 --- /dev/null +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sparse autoencoders for the Evo2 DNA language model.""" From 5edbf6ef0d8b658d5cc24a1b36dff634237689fc Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Tue, 26 May 2026 21:14:43 +0000 Subject: [PATCH 3/3] evo2 recipe: drop empty src/evo2_sae package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The recipe currently has no model-specific Python module — the extractor is upstream (`predict_evo2`) and the two scripts are simple CLIs in scripts/. Drop the empty package and adjust pyproject.toml so setuptools doesn't try to discover anything. Will reintroduce when there's actual library code to put there (eval, dashboard, dataloaders). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../recipes/evo2/pyproject.toml | 7 +++++-- .../recipes/evo2/src/evo2_sae/__init__.py | 16 ---------------- 2 files changed, 5 insertions(+), 18 deletions(-) delete mode 100644 bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml index 26eff6b55c..1f00a62bc5 100644 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml +++ b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/pyproject.toml @@ -17,8 +17,11 @@ dependencies = [ "pyarrow>=10.0", ] -[tool.setuptools.packages.find] -where = ["src"] +# No package code lives here yet — the recipe is just an entry-point for +# scripts/ that depends on the shared `sae` workspace package. Declare no +# packages so setuptools doesn't try to discover anything. +[tool.setuptools] +packages = [] [tool.uv.sources] sae = { workspace = true } diff --git a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py b/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py deleted file mode 100644 index d8ac513dc8..0000000000 --- a/bionemo-recipes/interpretability/sparse_autoencoders/recipes/evo2/src/evo2_sae/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Sparse autoencoders for the Evo2 DNA language model."""