Links: Paper (arXiv) · OpenReview · ICLR 2026 Poster · Slides
Official PyTorch implementation of "AbsTopK: Rethinking Sparse Autoencoders For Bidirectional Features", accepted to ICLR 2026.
Sparse autoencoders (SAEs) decompose large language model hidden states into interpretable features. We derive common SAE variants (ReLU, JumpReLU, TopK) by unrolling a single step of the proximal gradient method for sparse coding, and reveal that their non-negativity constraint fragments bidirectional concepts (e.g. male ↔ female) into redundant features. AbsTopK SAE is derived from an ℓ₀ constraint and applies hard thresholding over the largest-magnitude activations, keeping both positive and negative activations so a single feature can encode contrasting concepts. Across four LLMs and seven probing/steering tasks, AbsTopK improves reconstruction fidelity, enhances interpretability, and matches or surpasses the supervised Difference-in-Means baseline.
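The difference between the two selection rules can be seen on a single activation vector. The sketch below is a dependency-free illustration, not the code in src/sae.py: the function names `topk` and `abstopk` are hypothetical, the non-negative variant is assumed to apply ReLU before selection, and selection is done per vector, whereas the batch variants in this repo may select across a batch.

```python
def topk(z, k):
    """Non-negative TopK (assumed form): ReLU first, then keep the k largest values."""
    relu = [max(v, 0.0) for v in z]
    keep = set(sorted(range(len(z)), key=lambda i: relu[i], reverse=True)[:k])
    return [relu[i] if i in keep else 0.0 for i in range(len(z))]


def abstopk(z, k):
    """AbsTopK: keep the k entries with the largest magnitude, preserving their sign."""
    keep = set(sorted(range(len(z)), key=lambda i: abs(z[i]), reverse=True)[:k])
    return [z[i] if i in keep else 0.0 for i in range(len(z))]


z = [0.3, -2.1, 1.4, -0.2, 0.9]
print(topk(z, 2))     # [0.0, 0.0, 1.4, 0.0, 0.9]
print(abstopk(z, 2))  # [0.0, -2.1, 1.4, 0.0, 0.0]
```

In this example the strongest activation, -2.1, is discarded by the non-negative rule but kept with its sign by the magnitude-based rule, which is what lets one feature represent both ends of a contrast.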
This codebase provides a unified framework covering every variant discussed in the paper:
| Variant | --sae_name | Sparsity mechanism | Bidirectional |
|---|---|---|---|
| ReLU SAE | relu | ReLU + L1 penalty | |
| JumpReLU SAE | jumprelu | JumpReLU threshold | |
| TopK SAE | batchtopk | Hard top-k selection (non-negative) | |
| Gated SAE | gated | Gated encoder (Rajamanoharan et al. 2024) | |
| AbsTopK SAE | batchabsolutek | Hard top-k by magnitude (signed) | ✅ |
This project uses uv for dependency management
and targets Python ≥ 3.12.
git clone https://github.com/GoXzascc/AbsTopK-SAE.git
cd AbsTopK-SAE
uv sync

(Optional) Export your Hugging Face token to access gated models such as Llama-3.1-8B:

export HF_TOKEN=hf_xxx

The recommended entry point is src/train_abstopk.py, which trains a
BatchAbsoluteKSAE and logs MSE, nMSE, and Loss Recovered to both MLflow and a
per-run CSV file:
uv run src/train_abstopk.py \
--model_name EleutherAI/pythia-70m \
--layer 3 \
--dataset monology/pile-uncopyrighted \
--sae_name batchabsolutek \
--batch_size 128 \
--k 51 \
--dictionary_factor 16 \
--lr 3e-4 \
--training_steps 30000 \
--checkpoint_freq 5000 \
--perf_log_freq 1000

The general trainer src/trainer.py supports every variant via --sae_name:
uv run src/trainer.py --sae_name batchtopk # TopK SAE
uv run src/trainer.py --sae_name jumprelu # JumpReLU SAE
uv run src/trainer.py --sae_name relu # ReLU SAE
uv run src/trainer.py --sae_name gated # Gated SAE

scripts/abstopk_training.sh sweeps AbsTopK over pythia-70m, gemma-2-2b,
Qwen3-4B and GPT-2 at the layers and k values used in the paper:
bash scripts/abstopk_training.sh

Each experiment is described by a YAML file in configs/, named
<model>_<dataset>_<sae_variant>SAE.yaml. Key fields:
| Field | Description |
|---|---|
| model_name | Hugging Face model id |
| layer | Residual stream hook layer |
| sae_name | SAE variant (batchabsolutek, batchtopk, jumprelu, …) |
| k | Number of active features (sparsity) |
| dictionary_factor | Expansion ratio (dict_size = act_size * factor) |
| training_steps | Number of optimisation steps |
| checkpoint_freq | Checkpoint interval |
| perf_log_freq | Loss-recovered evaluation interval |
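
As a worked example of how dictionary_factor and k interact, the sketch below computes the dictionary size and the fraction of active features for the pythia-70m command shown earlier. The helper `dictionary_size` is hypothetical, and the residual stream width of 512 for pythia-70m is an assumption about the model rather than something read from this repo's configs.

```python
def dictionary_size(act_size: int, dictionary_factor: int) -> int:
    """dict_size = act_size * factor, as described in the field table."""
    return act_size * dictionary_factor


act_size = 512  # assumed residual stream width of EleutherAI/pythia-70m
dict_size = dictionary_size(act_size, dictionary_factor=16)
k = 51

print(dict_size)                      # 8192
print(round(100 * k / dict_size, 2))  # 0.62 (percent of features active)
```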
Pre-defined configs are provided for EleutherAI/pythia-70m,
google/gemma-2-2b, google/gemma-3-12b, Qwen/Qwen3-4B,
openai-community/gpt2 and meta-llama/Llama-3.1-8B.
AbsTopK-SAE/
├── src/
│ ├── sae.py # SAE architectures (AbsTopK, TopK, JumpReLU, ReLU, Gated)
│ ├── trainer.py # General SAE trainer with CSV/MLflow logging
│ ├── train_abstopk.py # AbsTopK training entry point
│ ├── data.py # Activation store (TransformerLens hooks)
│ ├── metrics.py # nMSE, loss-recovered and sparsity metrics
│ ├── feature_analysis.py # PCA, difference-in-means, cosine-sim analysis
│ ├── steering.py # Activation steering utilities
│ └── utils.py # Helpers (seeding, notifications, …)
├── configs/ # YAML experiment configs
├── scripts/ # Training launchers for each variant/model
├── tests/ # Test suite for models and analysis modules
├── logs/ # Training logs, checkpoints and CSV metrics
├── pyproject.toml # Dependencies (uv)
├── uv.lock # Locked dependency versions
└── LICENSE
metrics.py, feature_analysis.py and steering.py implement the probing and
steering tasks from the paper:
- Reconstruction fidelity – nMSE and Loss Recovered (zero-ablation & mean-ablation baselines).
- Interpretability – per-feature probing, PCA projections and difference-in-means comparison.
- Steering – add or subtract a single bidirectional feature to flip a contrastive concept.
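
The quantities in the list above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the code in metrics.py or steering.py: the function names are hypothetical, nMSE is assumed to be normalised by the squared norm of the input, Loss Recovered is assumed to take the commonly used form relative to an ablation baseline, and steering is assumed to add a scaled feature direction to the hidden state.

```python
def nmse(x, x_hat):
    """Squared reconstruction error divided by the squared norm of the input (assumed definition)."""
    err = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    norm = sum(a ** 2 for a in x)
    return err / norm


def loss_recovered(clean_loss, sae_loss, ablated_loss):
    """1.0 if splicing in the SAE leaves the LM loss unchanged, 0.0 if it is as bad as ablation."""
    return (ablated_loss - sae_loss) / (ablated_loss - clean_loss)


def steer(hidden, direction, alpha):
    """Shift a hidden state along one feature direction; the sign of alpha picks the pole."""
    return [h + alpha * d for h, d in zip(hidden, direction)]


print(loss_recovered(clean_loss=2.0, sae_loss=2.5, ablated_loss=7.0))  # 0.9
print(steer([1.0, 0.0], [0.5, -0.5], alpha=2.0))   # [2.0, -1.0]
print(steer([1.0, 0.0], [0.5, -0.5], alpha=-2.0))  # [0.0, 1.0]
```

With a bidirectional feature, the same direction serves both steering calls; only the sign of alpha changes.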
If you find this work useful, please cite:
@inproceedings{zhu2026abstopk,
title = {AbsTopK: Rethinking Sparse Autoencoders For Bidirectional Features},
author = {Zhu, Xudong and Khalili, Mohammad Mahdi and Zhu, Zhihui},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2026},
url = {https://openreview.net/forum?id=EEs6I4cO7S}
}

This project is released under the MIT License.