GoXzascc/AbsTopK-SAE

AbsTopK: Rethinking Sparse Autoencoders For Bidirectional Features

ICLR 2026 · arXiv · License: MIT · Python · PyTorch

Links: Paper (arXiv) · OpenReview · ICLR 2026 Poster · Slides

Official PyTorch implementation of "AbsTopK: Rethinking Sparse Autoencoders For Bidirectional Features", accepted to ICLR 2026.

Sparse autoencoders (SAEs) decompose large language model hidden states into interpretable features. We derive common SAE variants (ReLU, JumpReLU, TopK) by unrolling a single step of the proximal gradient method for sparse coding, and reveal that their non-negativity constraint fragments bidirectional concepts (e.g. male ↔ female) into redundant features. AbsTopK SAE is derived from an ℓ₀ constraint and applies hard thresholding over the largest-magnitude activations, keeping both positive and negative activations so a single feature can encode contrasting concepts. Across four LLMs and seven probing/steering tasks, AbsTopK improves reconstruction fidelity, enhances interpretability, and matches or surpasses the supervised Difference-in-Means baseline.
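The derivation mentioned above can be sketched as follows. This is a reconstruction of the standard proximal-gradient argument in notation chosen here (x is an input activation, W_e and W_d the encoder and decoder, b_e and b_d their biases, step size absorbed into λ), not a transcription of the paper's equations. Running one proximal gradient step on min_z ½‖x − b_d − W_d z‖² + R(z) from z = 0 gives z = prox_R(W_dᵀ(x − b_d)). An SAE replaces W_dᵀ with a learned encoder W_e and adds a bias b_e, so each variant corresponds to a different choice of regularizer R:

```latex
u = W_e (x - b_d) + b_e, \qquad z = \operatorname{prox}_{R}(u)

\begin{aligned}
R(z) &= \lambda \|z\|_1 + \iota_{\{z \ge 0\}}(z)
  &&\Rightarrow\; z_i = \max(u_i - \lambda,\, 0) && \text{(ReLU)}\\
R(z) &= \lambda \|z\|_0 + \iota_{\{z \ge 0\}}(z)
  &&\Rightarrow\; z_i = u_i \, \mathbb{1}\!\left[u_i > \sqrt{2\lambda}\right] && \text{(JumpReLU)}\\
R(z) &= \iota_{\{\|z\|_0 \le k,\; z \ge 0\}}(z)
  &&\Rightarrow\; z = \operatorname{TopK}_k\!\left(\max(u, 0)\right) && \text{(TopK)}\\
R(z) &= \iota_{\{\|z\|_0 \le k\}}(z)
  &&\Rightarrow\; z_i = u_i \, \mathbb{1}\!\left[\,|u_i| \text{ is among the } k \text{ largest entries of } |u|\,\right] && \text{(AbsTopK)}
\end{aligned}
```

Only the last regularizer leaves the sign of z unconstrained. This is why a single AbsTopK feature can take a positive value for one pole of a concept and a negative value for the other.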

SAE variants implemented

This codebase provides a unified framework covering every variant discussed in the paper:

| Variant | --sae_name | Sparsity mechanism | Bidirectional |
|---|---|---|---|
| ReLU SAE | relu | ReLU + L1 penalty | No |
| JumpReLU SAE | jumprelu | JumpReLU threshold | No |
| TopK SAE | batchtopk | Hard top-k selection (non-negative) | No |
| Gated SAE | gated | Gated encoder (Rajamanoharan et al. 2024) | No |
| AbsTopK SAE | batchabsolutek | Hard top-k by magnitude (signed) | Yes |
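To make the "Sparsity mechanism" column concrete, here is a minimal per-sample sketch of four of the activation rules applied to encoder pre-activations u. The function names and the fixed thresholds are hypothetical and chosen for illustration. They are not the classes in src/sae.py, which also learn thresholds and use batch-level selection.

```python
import torch


def relu_act(u: torch.Tensor, theta: float = 0.0) -> torch.Tensor:
    # Non-negative soft threshold (the L1 penalty itself lives in the training loss)
    return torch.relu(u - theta)


def jumprelu_act(u: torch.Tensor, theta: float) -> torch.Tensor:
    # Hard threshold: keep u_i unchanged where u_i > theta, otherwise 0
    return u * (u > theta).to(u.dtype)


def topk_act(u: torch.Tensor, k: int) -> torch.Tensor:
    # Non-negative TopK: keep the k largest entries of ReLU(u) in each row
    r = torch.relu(u)
    vals, idx = r.topk(k, dim=-1)
    return torch.zeros_like(r).scatter(-1, idx, vals)


def abstopk_act(u: torch.Tensor, k: int) -> torch.Tensor:
    # AbsTopK: keep the k entries with the largest |u_i| in each row, signs preserved
    idx = u.abs().topk(k, dim=-1).indices
    return torch.zeros_like(u).scatter(-1, idx, u.gather(-1, idx))


u = torch.tensor([[0.5, -2.0, 1.0, -0.1]])
print(topk_act(u, 2))     # keeps 0.5 and 1.0, drops -2.0
print(abstopk_act(u, 2))  # keeps -2.0 and 1.0
```

On this input, TopK discards −2.0 even though it is the largest-magnitude pre-activation. AbsTopK keeps it with its sign.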

Installation

This project uses uv for dependency management and targets Python ≥ 3.12.

git clone https://github.com/GoXzascc/AbsTopK-SAE.git
cd AbsTopK-SAE
uv sync

(Optional) Export your Hugging Face token to access gated models such as Llama-3.1-8B:

export HF_TOKEN=hf_xxx

Quick start

Train an AbsTopK SAE

The recommended entry point is src/train_abstopk.py, which trains a BatchAbsoluteKSAE and logs MSE, nMSE, and Loss Recovered to both MLflow and a per-run CSV file:

uv run src/train_abstopk.py \
    --model_name EleutherAI/pythia-70m \
    --layer 3 \
    --dataset monology/pile-uncopyrighted \
    --sae_name batchabsolutek \
    --batch_size 128 \
    --k 51 \
    --dictionary_factor 16 \
    --lr 3e-4 \
    --training_steps 30000 \
    --checkpoint_freq 5000 \
    --perf_log_freq 1000
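For intuition about what one optimisation step of such a run involves, here is a toy sketch on random data. TinyAbsTopKSAE is a hypothetical class, not the repo's BatchAbsoluteKSAE. The sketch assumes that the "batch" prefix follows the BatchTopK scheme, where k × batch_size entries are kept across the whole batch, so k becomes an average per-sample budget. It omits anything the real trainer may add, such as decoder-norm constraints, auxiliary losses, or learning-rate schedules.

```python
import torch
import torch.nn as nn


class TinyAbsTopKSAE(nn.Module):
    """Hypothetical minimal AbsTopK SAE for illustration; not the repo's BatchAbsoluteKSAE."""

    def __init__(self, act_size: int, dictionary_factor: int, k: int):
        super().__init__()
        dict_size = act_size * dictionary_factor
        self.k = k
        self.b_dec = nn.Parameter(torch.zeros(act_size))
        self.W_enc = nn.Parameter(torch.randn(act_size, dict_size) / act_size**0.5)
        self.b_enc = nn.Parameter(torch.zeros(dict_size))
        # Decoder initialised as the encoder transpose (an assumed, common choice)
        self.W_dec = nn.Parameter(self.W_enc.detach().T.clone())

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        u = (x - self.b_dec) @ self.W_enc + self.b_enc
        # Assumed BatchTopK-style budget: keep k * batch_size entries over the whole
        # batch, ranked by |u| so negative activations compete with positive ones
        flat = u.reshape(-1)
        n_keep = min(self.k * x.shape[0], flat.numel())
        idx = flat.abs().topk(n_keep).indices
        mask = torch.zeros_like(flat).scatter(0, idx, 1.0).reshape(u.shape)
        return u * mask  # signs preserved; gradients flow only through kept entries

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return z @ self.W_dec + self.b_dec


torch.manual_seed(0)
sae = TinyAbsTopKSAE(act_size=16, dictionary_factor=4, k=3)
opt = torch.optim.Adam(sae.parameters(), lr=3e-4)

x = torch.randn(8, 16)  # stand-in for a batch of residual-stream activations
z = sae.encode(x)
x_hat = sae.decode(z)
loss = (x_hat - x).pow(2).mean()  # plain MSE reconstruction loss

opt.zero_grad()
loss.backward()
opt.step()
```

Because selection is done over the flattened batch, one sample can use more than k features while another uses fewer.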

Train other SAE variants

The general trainer src/trainer.py supports every variant via --sae_name:

uv run src/trainer.py --sae_name batchtopk        # TopK SAE
uv run src/trainer.py --sae_name jumprelu         # JumpReLU SAE
uv run src/trainer.py --sae_name relu             # ReLU SAE
uv run src/trainer.py --sae_name gated            # Gated SAE

Reproduce across models

scripts/abstopk_training.sh sweeps AbsTopK over pythia-70m, gemma-2-2b, Qwen3-4B and GPT-2 at the layers and k values used in the paper:

bash scripts/abstopk_training.sh

Configuration

Each experiment is described by a YAML file in configs/, named <model>_<dataset>_<sae_variant>SAE.yaml. Key fields:

| Field | Description |
|---|---|
| model_name | Hugging Face model id |
| layer | Residual stream hook layer |
| sae_name | SAE variant (batchabsolutek, batchtopk, jumprelu, …) |
| k | Number of active features (sparsity) |
| dictionary_factor | Expansion ratio (dict_size = act_size * factor) |
| training_steps | Number of optimisation steps |
| checkpoint_freq | Checkpoint interval |
| perf_log_freq | Loss-recovered evaluation interval |
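The dictionary-size formula in the table can be checked with a small sketch. SAEConfig is a hypothetical dataclass mirroring the fields above and is not the repo's config loader. act_size is also an assumption: it is the model's residual-stream width, which is 512 for pythia-70m, and it is not listed as a YAML field.

```python
from dataclasses import dataclass


@dataclass
class SAEConfig:
    """Hypothetical mirror of the YAML fields above; not the repo's own loader."""

    model_name: str
    layer: int
    sae_name: str
    k: int
    dictionary_factor: int
    training_steps: int
    checkpoint_freq: int
    perf_log_freq: int
    act_size: int  # assumption: residual-stream width, not a documented YAML field

    @property
    def dict_size(self) -> int:
        # dict_size = act_size * factor, as documented in the table
        return self.act_size * self.dictionary_factor

    @property
    def active_fraction(self) -> float:
        # Share of dictionary features active per sample (an average for batch variants)
        return self.k / self.dict_size


cfg = SAEConfig(
    model_name="EleutherAI/pythia-70m", layer=3, sae_name="batchabsolutek",
    k=51, dictionary_factor=16, training_steps=30000, checkpoint_freq=5000,
    perf_log_freq=1000, act_size=512,
)
print(cfg.dict_size)  # 8192
print(f"{cfg.active_fraction:.4f}")
```

With the Quick start settings, about 0.6% of the 8192 features are active per token.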

Pre-defined configs are provided for EleutherAI/pythia-70m, google/gemma-2-2b, google/gemma-3-12b, Qwen/Qwen3-4B, openai-community/gpt2 and meta-llama/Llama-3.1-8B.

Repository structure

AbsTopK-SAE/
├── src/
│   ├── sae.py                # SAE architectures (AbsTopK, TopK, JumpReLU, ReLU, Gated)
│   ├── trainer.py            # General SAE trainer with CSV/MLflow logging
│   ├── train_abstopk.py      # AbsTopK training entry point
│   ├── data.py               # Activation store (TransformerLens hooks)
│   ├── metrics.py            # nMSE, loss-recovered and sparsity metrics
│   ├── feature_analysis.py   # PCA, difference-in-means, cosine-sim analysis
│   ├── steering.py           # Activation steering utilities
│   └── utils.py              # Helpers (seeding, notifications, …)
├── configs/                  # YAML experiment configs
├── scripts/                  # Training launchers for each variant/model
├── tests/                    # Test suite for models and analysis modules
├── logs/                     # Training logs, checkpoints and CSV metrics
├── pyproject.toml            # Dependencies (uv)
├── uv.lock                   # Locked dependency versions
└── LICENSE
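data.py collects training activations through TransformerLens hooks. The sketch below shows the same idea using a plain PyTorch forward hook on a hypothetical toy residual model, since it does not reproduce the TransformerLens API. The captured block outputs are flattened into one row per token, which is the shape an SAE trains on. Buffering and shuffling, which a real activation store would need, are omitted.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical toy residual block standing in for a transformer layer."""

    def __init__(self, d: int):
        super().__init__()
        self.mlp = nn.Linear(d, d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(x)  # residual-stream update


d_model = 32
model = nn.Sequential(*[Block(d_model) for _ in range(4)])
captured = []


def save_resid(module, inputs, output):
    # Store the hooked block's residual-stream output, detached from the graph
    captured.append(output.detach())


layer = 3
handle = model[layer].register_forward_hook(save_resid)
with torch.no_grad():
    model(torch.randn(2, 10, d_model))  # (batch, seq, d_model)
handle.remove()

acts = torch.cat(captured).reshape(-1, d_model)  # one SAE training row per token
print(acts.shape)  # torch.Size([20, 32])
```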

Evaluation & analysis

metrics.py, feature_analysis.py and steering.py implement the probing and steering tasks from the paper:

  • Reconstruction fidelity – nMSE and Loss Recovered (zero-ablation & mean-ablation baselines).
  • Interpretability – per-feature probing, PCA projections and difference-in-means comparison.
  • Steering – add or subtract a single bidirectional feature to flip a contrastive concept.
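The quantities behind these bullets can be sketched as follows. The function names are hypothetical, and the definitions are common conventions, not necessarily the exact ones in metrics.py. In particular, some codebases normalise nMSE by the variance of x instead of its energy. Loss Recovered compares the LM loss with the SAE reconstruction spliced in against a clean run and an ablated run (zero or mean ablation). For steering an AbsTopK feature, the direction would be a single decoder vector, and the sign of alpha would choose which pole of the concept to push toward.

```python
import torch


def nmse(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    # Reconstruction error divided by the input's energy (assumed normalisation)
    return ((x - x_hat).pow(2).sum() / x.pow(2).sum()).item()


def loss_recovered(clean_loss: float, sae_loss: float, ablated_loss: float) -> float:
    # Fraction of the clean-vs-ablated loss gap closed by the SAE reconstruction:
    # 1.0 matches the clean model, 0.0 is no better than the ablation baseline
    return (ablated_loss - sae_loss) / (ablated_loss - clean_loss)


def diff_in_means(pos_acts: torch.Tensor, neg_acts: torch.Tensor) -> torch.Tensor:
    # Supervised baseline direction: mean activation on concept A minus on concept B
    return pos_acts.mean(0) - neg_acts.mean(0)


def steer(h: torch.Tensor, direction: torch.Tensor, alpha: float) -> torch.Tensor:
    # Add (alpha > 0) or subtract (alpha < 0) a unit-norm direction from hidden state h
    return h + alpha * direction / direction.norm()


print(loss_recovered(3.0, 3.5, 8.0))  # 0.9
```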

Citation

If you find this work useful, please cite:

@inproceedings{zhu2026abstopk,
  title     = {AbsTopK: Rethinking Sparse Autoencoders For Bidirectional Features},
  author    = {Zhu, Xudong and Khalili, Mohammad Mahdi and Zhu, Zhihui},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2026},
  url       = {https://openreview.net/forum?id=EEs6I4cO7S}
}

License

This project is released under the MIT License.

