Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
"jaxtyping>=0.2.11",
"numpy>=1.24; python_version >= '3.10' and python_version < '3.12'",
"numpy>=1.26; python_version >= '3.12' and python_version < '3.13'",
"packaging>=23.0",
"pandas>=1.1.5",
"protobuf>=3.20.0",
"rich>=12.6.0",
"sentencepiece",
"torch>=2.6",
"tqdm>=4.64.1",
"transformers-stream-generator>=0.0.5,<0.1",
"transformers>=4.56",
"transformers>=5.4.0",
"typeguard>=4.2,<5",
"typing-extensions",
"wandb>=0.13.5",
Expand All @@ -36,7 +37,6 @@
# whenever chardet>=6 is installed. Remove the pin when psf/requests bumps the cap.
evals=["lm-eval>=0.4", "chardet<6"]
lit=["lit-nlp>=1.3"]
qwen35=["packaging>=23.0", "transformers>=5.2.0"]

[project.scripts]
build-docs="docs.make_docs:build_docs"
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/components/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_attention_load_in_4bit():
assert attn.cfg.n_ctx == 1024
assert attn.cfg.d_head == 64
assert attn.cfg.n_heads == 8
assert attn.cfg.load_in_4bit == False
assert attn.cfg.load_in_4bit == True
assert attn.cfg.dtype == torch.float32
assert attn.cfg.act_fn == "relu"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Unit tests for Gemma1ArchitectureAdapter.

Mirrors test_gemma2_adapter.py: GemmaTextScaledWordEmbedding scales internally,
so the adapter must NOT override setup_hook_compatibility (an override that
installed a hook_conversion would double-scale embed.hook_out).
"""

import pytest

from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.model_bridge.supported_architectures.gemma1 import (
Gemma1ArchitectureAdapter,
)


def _make_cfg(d_model: int = 32) -> TransformerBridgeConfig:
return TransformerBridgeConfig(
d_model=d_model,
d_head=d_model // 4,
n_layers=2,
n_ctx=128,
n_heads=4,
d_vocab=256,
d_mlp=64,
architecture="GemmaForCausalLM",
)


@pytest.fixture(scope="module")
def adapter() -> Gemma1ArchitectureAdapter:
return Gemma1ArchitectureAdapter(_make_cfg())


class TestGemma1HookCompatibility:
def test_adapter_does_not_override_setup_hook_compatibility(
self, adapter: Gemma1ArchitectureAdapter
) -> None:
# bridge.py:763 uses hasattr() to decide whether to call the override.
assert "setup_hook_compatibility" not in vars(type(adapter))
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Unit tests for Gemma2ArchitectureAdapter.

Post-fix invariant: the adapter MUST NOT override setup_hook_compatibility.
Gemma2TextScaledWordEmbedding scales by sqrt(d_model) inside its own forward(),
so the bridge's wrapped layer already returns the scaled value. An override
that installed a hook_conversion would double-scale (cache R * d_model instead
of R * sqrt(d_model)) — the dev-4.x vLLM-comparison investigation surfaced this.
"""

import os

import pytest
import torch

from transformer_lens.config.transformer_bridge_config import TransformerBridgeConfig
from transformer_lens.model_bridge.supported_architectures.gemma2 import (
Gemma2ArchitectureAdapter,
)


def _make_cfg(d_model: int = 32) -> TransformerBridgeConfig:
return TransformerBridgeConfig(
d_model=d_model,
d_head=d_model // 4,
n_layers=2,
n_ctx=128,
n_heads=4,
d_vocab=256,
d_mlp=64,
architecture="Gemma2ForCausalLM",
)


@pytest.fixture(scope="module")
def adapter() -> Gemma2ArchitectureAdapter:
return Gemma2ArchitectureAdapter(_make_cfg())


class TestGemma2HookCompatibility:
def test_adapter_does_not_override_setup_hook_compatibility(
self, adapter: Gemma2ArchitectureAdapter
) -> None:
# bridge.py:763 uses hasattr() to decide whether to call the override.
# Absence is the contract — the base ArchitectureAdapter has no such
# method, so the bridge skips installation entirely.
assert "setup_hook_compatibility" not in vars(type(adapter))


@pytest.mark.skipif(bool(os.getenv("CI")), reason="Network/disk fetch of tiny Gemma2 — skip in CI")
def test_gemma2_embed_hook_out_magnitude_matches_sqrt_d_model_scaling():
"""End-to-end regression for the embed double-scale bug.

Before the fix: embed.hook_out held R * d_model (double-scaled).
After: embed.hook_out holds R * sqrt(d_model) — i.e. exactly what
Gemma2TextScaledWordEmbedding.forward returns and what flows
into block 0.
"""
import torch.nn.functional as F

from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers(
"hf-internal-testing/tiny-random-Gemma2ForCausalLM",
dtype=torch.float32,
device="cpu",
)
assert bridge.embed.hook_out.hook_conversion is None

tokens = torch.tensor([[1, 2, 3, 4]])
raw_R = F.embedding(tokens, bridge.embed.original_component.weight)
scaled_direct = bridge.embed.original_component(tokens)
_, cache = bridge.run_with_cache(tokens)

# embed.hook_out should equal the ScaledWordEmbedding's actual output
assert torch.allclose(cache["embed.hook_out"], scaled_direct, atol=1e-6)
# And the ratio to the raw lookup is sqrt(d_model), not d_model
ratio = cache["embed.hook_out"].abs().max().item() / raw_R.abs().max().item()
assert ratio == pytest.approx(bridge.cfg.d_model**0.5, rel=1e-4)
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
"""Unit tests for XGLMArchitectureAdapter: cfg, components, weight conversions, hook compat, factory."""

import math
from types import SimpleNamespace

import pytest
import torch

from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
Expand Down Expand Up @@ -191,44 +187,15 @@ def test_mlp_out_name(self, adapter: XGLMArchitectureAdapter) -> None:
assert mlp.submodules["out"].name == "fc2"


def _make_mock_bridge() -> SimpleNamespace:
"""Minimal mock bridge with embed.hook_out for hook-compat tests."""
hook_out = SimpleNamespace(hook_conversion=None)
embed = SimpleNamespace(hook_out=hook_out)
return SimpleNamespace(embed=embed)


class TestXGLMAdapterHookCompatibility:
"""setup_hook_compatibility attaches a scale conversion to hook_embed."""

def test_sets_hook_conversion_on_embed_hook_out(self, adapter: XGLMArchitectureAdapter) -> None:
bridge = _make_mock_bridge()
adapter.setup_hook_compatibility(bridge)
assert bridge.embed.hook_out.hook_conversion is not None

def test_scales_by_sqrt_d_model(self, adapter: XGLMArchitectureAdapter) -> None:
bridge = _make_mock_bridge()
adapter.setup_hook_compatibility(bridge)
conv = bridge.embed.hook_out.hook_conversion
x = torch.ones(2, 4, 64)
result = conv.handle_conversion(x)
expected_scale = math.sqrt(64)
assert torch.allclose(result, x * expected_scale, atol=1e-6)

def test_revert_inverts_scale(self, adapter: XGLMArchitectureAdapter) -> None:
bridge = _make_mock_bridge()
adapter.setup_hook_compatibility(bridge)
conv = bridge.embed.hook_out.hook_conversion
x = torch.randn(2, 4, 64)
assert torch.allclose(conv.revert(conv.handle_conversion(x)), x, atol=1e-6)

def test_no_error_when_embed_missing(self, adapter: XGLMArchitectureAdapter) -> None:
bridge = SimpleNamespace()
adapter.setup_hook_compatibility(bridge)

def test_no_error_when_hook_out_missing(self, adapter: XGLMArchitectureAdapter) -> None:
bridge = SimpleNamespace(embed=SimpleNamespace())
adapter.setup_hook_compatibility(bridge)
"""Adapter must not override setup_hook_compatibility — XGLMScaledWordEmbedding
scales internally, so any override would double-scale embed.hook_out."""

def test_adapter_does_not_override_setup_hook_compatibility(
self, adapter: XGLMArchitectureAdapter
) -> None:
# bridge.py:763 uses hasattr() to decide whether to call the override.
assert "setup_hook_compatibility" not in vars(type(adapter))


class TestXGLMFactoryRegistration:
Expand Down
45 changes: 6 additions & 39 deletions transformer_lens/model_bridge/supported_architectures/gemma1.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ def __init__(self, cfg: Any) -> None:
self.cfg.rmsnorm_uses_offset = True

self.weight_processing_conversions = {
# NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME in
# GemmaModel.forward(). We must NOT pre-scale embed weights here
# because that would cause double-scaling (pre-scale + runtime).
# The runtime hook_conversion in setup_hook_compatibility() handles
# scaling the hook output so it matches HookedTransformer's behavior.
# NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME inside
# GemmaTextScaledWordEmbedding.forward() (HF transformers >= 5.0).
# That layer is what bridge.embed wraps, so embed.hook_out already
# captures the scaled value — matching HookedTransformer's hook_embed
# (which uses pre-scaled W_E). We must NOT pre-scale weights here and
# we must NOT install a runtime hook_conversion that re-scales.
#
# Attention weight conversions
**self._qkvo_weight_conversions(),
Expand Down Expand Up @@ -118,40 +119,6 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head"),
}

def setup_hook_compatibility(self, bridge: Any) -> None:
"""Setup hook compatibility for Gemma1 models.

Gemma1 scales embeddings by sqrt(d_model) in its forward pass,
but the HuggingFace embed_tokens layer doesn't include this scaling.
We need to apply it to hook_embed to match HookedTransformer behavior.

Args:
bridge: The TransformerBridge instance
"""
from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
BaseTensorConversion,
)

class EmbeddingScaleConversion(BaseTensorConversion):
"""Scale embeddings by sqrt(d_model) for Gemma models."""

def __init__(self, scale: float):
super().__init__()
self.scale = scale

def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
"""Scale the embedding output."""
return input_value * self.scale

def revert(self, input_value: Any, *full_context: Any) -> Any:
"""Unscale the embedding output (for user modifications)."""
return input_value / self.scale

# Apply scaling to embed.hook_out
if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
scale_factor = self.cfg.d_model**0.5
bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion(scale_factor)

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references for Gemma1 component testing.

Expand Down
34 changes: 7 additions & 27 deletions transformer_lens/model_bridge/supported_architectures/gemma2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
"""Gemma2 architecture adapter."""

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
pass
from typing import Any

from transformer_lens.conversion_utils.conversion_steps import (
ArithmeticTensorConversion,
Expand Down Expand Up @@ -60,11 +57,12 @@ def __init__(self, cfg: Any) -> None:
# by map_default_transformer_lens_config() in sources/transformers.py

self.weight_processing_conversions = {
# NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME in
# Gemma2Model.forward(). We must NOT pre-scale embed weights here
# because that would cause double-scaling (pre-scale + runtime).
# The runtime hook_conversion in setup_hook_compatibility() handles
# scaling the hook output so it matches HookedTransformer's behavior.
# NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME inside
# Gemma2TextScaledWordEmbedding.forward() (HF transformers >= 5.0).
# That layer is what bridge.embed wraps, so embed.hook_out already
# captures the scaled value — matching HookedTransformer's hook_embed
# (which uses pre-scaled W_E). We must NOT pre-scale weights here and
# we must NOT install a runtime hook_conversion that re-scales.
"blocks.{i}.attn.q.weight": ParamProcessingConversion(
tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
),
Expand Down Expand Up @@ -162,24 +160,6 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
}

def setup_hook_compatibility(self, bridge: Any) -> None:
"""Setup hook compatibility for Gemma2 models.

Gemma2 scales embeddings by sqrt(d_model). The weights are pre-scaled via
preprocess_weights(), but we still need to apply the scaling conversion to
the hook output for proper hook functionality (so user modifications are
correctly scaled/unscaled).

Args:
bridge: The TransformerBridge instance
"""
# Apply embedding scaling conversion to hook output
if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
scale_factor = self.cfg.d_model**0.5
bridge.embed.hook_out.hook_conversion = ArithmeticTensorConversion(
OperationTypes.MULTIPLICATION, scale_factor
)

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references and attention implementation for Gemma-2 component testing.

Expand Down
26 changes: 4 additions & 22 deletions transformer_lens/model_bridge/supported_architectures/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def __init__(self, cfg: Any) -> None:
self.cfg.attn_implementation = "eager"

self.weight_processing_conversions = {
# Note: Gemma3 scales embeddings by sqrt(d_model) in the forward pass.
# This is handled in setup_hook_compatibility() which applies the scaling
# to hook_embed output at runtime, matching HuggingFace's behavior.
# We do NOT scale the stored weights here.
# Note: Gemma3TextScaledWordEmbedding scales by sqrt(d_model) inside
# its own forward(). Bridge.embed wraps that layer, so embed.hook_out
# already captures the scaled value — no weight pre-scaling and no
# hook_conversion needed (setup_hook_compatibility is a no-op).
#
# Q/K/V weight conversions
"blocks.{i}.attn.q.weight": ParamProcessingConversion(
Expand Down Expand Up @@ -171,24 +171,6 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head"),
}

def setup_hook_compatibility(self, bridge: Any) -> None:
"""Setup hook compatibility for Gemma3 models.

Unlike Gemma1/Gemma2, Gemma3 uses Gemma3TextScaledWordEmbedding which
scales embeddings by sqrt(d_model) INSIDE the embedding layer's forward().
Therefore we do NOT need a hook_conversion — the embed.hook_out already
captures the scaled output. Adding a conversion would double-scale.

(Gemma1/Gemma2 scale in GemmaModel.forward() AFTER the embedding layer,
so their adapters correctly use EmbeddingScaleConversion to match HT.)

Args:
bridge: The TransformerBridge instance
"""
# No embed scaling conversion needed — Gemma3TextScaledWordEmbedding
# already applies sqrt(d_model) scaling in its forward() method.
pass

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references and native autograd for Gemma-3 component testing.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,6 @@ def __init__(self, cfg: Any) -> None:
"unembed": UnembeddingBridge(name="lm_head"),
}

def setup_hook_compatibility(self, bridge: Any) -> None:
"""Setup hook compatibility for Gemma3 multimodal models.

Like text-only Gemma 3, the multimodal model uses
Gemma3TextScaledWordEmbedding which scales embeddings by sqrt(d_model)
internally in its forward() method. No additional hook conversion is
needed — adding one would double-scale the embeddings.

Args:
bridge: The TransformerBridge instance
"""
pass

def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
"""Set up rotary embedding references for Gemma-3 multimodal component testing.

Expand Down
Loading
Loading