diff --git a/bionemo-recipes/models/llama3/modeling_llama_te.py b/bionemo-recipes/models/llama3/modeling_llama_te.py
index 0b2236b5bf..d61976c424 100644
--- a/bionemo-recipes/models/llama3/modeling_llama_te.py
+++ b/bionemo-recipes/models/llama3/modeling_llama_te.py
@@ -52,6 +52,7 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ layer_precision: list[str | None] | None = None
def __init__(
self,
@@ -217,11 +218,54 @@ def _init_method(x):
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
+ def set_recipes(
+ self,
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ ) -> None:
+ """Attach quantization recipe objects for per-layer autocast.
+
+ Recipes are not serializable and must be set at runtime after model creation
+ and sharding (FSDP/DDP) but before training. The per-layer precision
+ assignments are read from ``self.config.layer_precision``.
+
+ Args:
+ fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+ fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+ """
+ self._fp8_recipe = fp8_recipe
+ self._fp4_recipe = fp4_recipe
+
+ def get_layer_autocast(self, layer_number: int):
+ """Return the appropriate TE autocast context manager for a given layer.
+
+ The context interacts with the outer FP8 autocast in the training script:
+ - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+ - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+ Args:
+ layer_number: The 0-indexed layer number.
+
+ Returns:
+ A context manager for the layer's quantization mode.
+ """
+ precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+ if precision == "fp8":
+ return nullcontext()
+ elif precision == "fp4":
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+ else:
+ return transformer_engine.pytorch.autocast(enabled=False)
+
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -298,12 +342,14 @@ def forward(
if te_rope_emb.dtype != torch.float32:
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
- with self.get_autocast_context(None, outer=True):
- for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+ # by get_layer_autocast(), which nests inside this context.
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+ for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
- with self.get_autocast_context(layer_idx):
+ with self.get_layer_autocast(layer_number):
hidden_states = decoder_layer(
hidden_states,
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
@@ -363,8 +409,10 @@ def get_autocast_context(
if init and self.config.use_quantized_model_init:
if precision in ("fp8", "fp4"):
- return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
- return nullcontext()
+ return transformer_engine.pytorch.quantized_model_init(
+ recipe=recipe, preserve_high_precision_init_val=True
+ )
+ return transformer_engine.pytorch.quantized_model_init(enabled=False)
if precision == "fp8":
if recipe is None:
@@ -583,6 +631,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
return 0
return max(self.sequences.values())
+ @property
+ def is_compileable(self) -> bool:
+ """Required by HuggingFace transformers generate() auto-compile check."""
+ return False
+
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder the cache based on the beam indices."""
if isinstance(self.cache_manager, PagedKVCacheManager):
@@ -591,8 +644,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_key_cache = key_cache.index_select(0, beam_idx)
updated_value_cache = value_cache.index_select(0, beam_idx)
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
-
- @property
- def is_compileable(self) -> bool:
- """Return False as this cache is not compatible with torch.compile."""
- return False
diff --git a/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
new file mode 100644
index 0000000000..eb93415d50
--- /dev/null
+++ b/bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
@@ -0,0 +1,243 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import subprocess
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import check_fp8_support
+
+
+def requires_fp8(func):
+ """Decorator to skip tests that require FP8 support."""
+ fp8_available, reason = check_fp8_support()
+ return pytest.mark.skipif(not fp8_available, reason=f"FP8 is not supported on this GPU: {reason}")(func)
+
+
+requires_multi_gpu = pytest.mark.skipif(
+ not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+ reason="Test requires at least 2 GPUs",
+)
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+def test_single_process_attaches_correct_fp8_recipe(strategy, unused_tcp_port):
+ cmd = [
+ "torchrun",
+ "--nproc_per_node=1",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+ os.path.relpath(__file__),
+ "--strategy",
+ strategy,
+ ]
+
+ result = subprocess.run(
+ cmd,
+ check=False,
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ timeout=240,
+ )
+ if result.returncode != 0:
+ print(f"STDOUT:\n{result.stdout}")
+ print(f"STDERR:\n{result.stderr}")
+ pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+@pytest.mark.parametrize("strategy", ["ddp", "fsdp2"])
+@requires_fp8
+@requires_multi_gpu
+def test_multi_process_fp8_recipes_are_synced(strategy, unused_tcp_port):
+ cmd = [
+ "torchrun",
+ "--nproc_per_node=2",
+ "--rdzv-backend=c10d",
+ f"--rdzv-endpoint=localhost:{unused_tcp_port}",
+ os.path.relpath(__file__),
+ "--strategy",
+ strategy,
+ ]
+
+ result = subprocess.run(
+ cmd,
+ check=False,
+ text=True,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ timeout=240,
+ )
+ if result.returncode != 0:
+ print(f"STDOUT:\n{result.stdout}")
+ print(f"STDERR:\n{result.stderr}")
+ pytest.fail(f"Command failed with exit code {result.returncode}")
+
+
+if __name__ == "__main__":
+ import argparse
+ import enum
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from pathlib import Path
+
+ # Ensure the model directory is on sys.path for bare module imports.
+ sys.path.insert(0, Path(__file__).resolve().parent.parent.as_posix())
+
+ import torch.distributed as dist
+ from torch.distributed.device_mesh import init_device_mesh
+ from torch.distributed.fsdp import fully_shard
+ from torch.optim import AdamW
+ from transformer_engine.pytorch.fp8 import DelayedScaling, Format
+
+ from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
+
+ def recursive_assert(a, b, path=""):
+ if isinstance(a, dict) and isinstance(b, dict):
+ assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
+ for k in a:
+ recursive_assert(a[k], b[k], path=f"{path}.{k}")
+ elif isinstance(a, list) and isinstance(b, list):
+ assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
+ for i in range(len(a)):
+ recursive_assert(a[i], b[i], path=f"{path}.{i}")
+ elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
+ torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
+ else:
+ assert a == b, f"Value mismatch at {path}: {a} != {b}"
+
+ class Strategy(enum.StrEnum):
+ DDP = "ddp"
+ FSDP2 = "fsdp2"
+
+ @dataclass
+ class DistributedConfig:
+ """Class to track distributed ranks."""
+
+ rank: int = field(default_factory=dist.get_rank)
+ local_rank: int = field(default_factory=lambda: int(os.environ["LOCAL_RANK"]))
+ world_size: int = field(default_factory=dist.get_world_size)
+
+ def is_main_process(self) -> bool:
+ """This is the global rank 0 process, to be used for wandb logging, etc."""
+ return self.rank == 0
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--strategy", type=Strategy, default=Strategy.DDP, choices=[Strategy.FSDP2, Strategy.DDP])
+ args = parser.parse_args()
+
+ torch.distributed.init_process_group(backend="nccl")
+ dist_config = DistributedConfig()
+ torch.cuda.set_device(dist_config.local_rank)
+ device_mesh = init_device_mesh(
+ "cuda",
+ mesh_shape=(dist_config.world_size, 1),
+ mesh_dim_names=("dp", "tp"),
+ )
+ device = f"cuda:{dist_config.local_rank}"
+
+ fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
+
+ config = NVLlamaConfig(
+ hidden_size=256,
+ intermediate_size=512,
+ num_hidden_layers=6,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ vocab_size=100,
+ dtype=torch.bfloat16,
+ )
+ config.layer_precision = ["fp8"] * config.num_hidden_layers
+ model = NVLlamaForCausalLM(config)
+
+ if args.strategy is Strategy.FSDP2:
+ for layer in model.model.layers:
+ fully_shard(layer, mesh=device_mesh["dp"])
+ fully_shard(model, mesh=device_mesh["dp"])
+ model.to(device)
+
+ elif args.strategy is Strategy.DDP:
+ model.to(device)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[dist_config.local_rank],
+ output_device=dist_config.local_rank,
+ device_mesh=device_mesh["dp"],
+ )
+
+ optimizer = AdamW(model.parameters())
+
+ # Attach FP8 recipes to the model (layer precision is already on config).
+ llama_model = model.module.model if args.strategy is Strategy.DDP else model.model
+ llama_model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+
+ model.train()
+
+ generator = torch.Generator()
+ generator.manual_seed(torch.distributed.get_rank())
+
+ for _ in range(3):
+ input_data = {
+ "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+ "labels": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
+ "attention_mask": torch.ones(1, 32),
+ }
+ input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ outputs = model(**input_data)
+
+ outputs.loss.backward()
+
+ # Access FP8 extra states directly from modules instead of state_dict()
+ # since state_dict() now filters them out for HuggingFace compatibility
+ fp8_extra_states = {}
+ for name, module in model.named_modules():
+ if hasattr(module, "_extra_state") and callable(module._extra_state):
+ extra_state = module._extra_state()
+ if extra_state is not None and len(extra_state) > 0:
+ fp8_extra_states[f"{name}._extra_state"] = extra_state
+
+ # lm_head is BF16, not FP8, so exclude it from FP8 checks
+ fp8_extra_states = {key: val for key, val in fp8_extra_states.items() if "lm_head." not in key}
+
+ # 2 ranks, test to ensure that both ranks have the same FP8 extra states
+ if torch.distributed.get_world_size() == 2:
+ outputs_list = [None] * torch.distributed.get_world_size() if torch.distributed.get_rank() == 0 else None
+ torch.distributed.gather_object(fp8_extra_states, outputs_list, dst=0)
+ if torch.distributed.get_rank() == 0:
+ assert outputs_list is not None
+
+ for key in outputs_list[0]:
+ state_1 = outputs_list[0][key]
+ state_2 = outputs_list[1][key]
+ assert len(state_1) > 0, f"No FP8 extra states for {key}, rank 0"
+ assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
+ dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
+ dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
+ recursive_assert(dict_1, dict_2)
+
+ # One rank, test to ensure the correct FP8 extra states are saved
+ if torch.distributed.get_world_size() == 1:
+ for key, val in fp8_extra_states.items():
+ assert len(val) > 0, f"No FP8 extra states for {key}"
+ fp8_meta_dict = pickle.loads(val.detach().numpy(force=True).tobytes())
+ assert fp8_meta_dict["recipe"] == fp8_recipe, f"Recipe mismatch for {key}"
+
+ torch.distributed.destroy_process_group()
diff --git a/bionemo-recipes/models/llama3/tests/test_layer_quantization.py b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py
new file mode 100644
index 0000000000..a80ff80f2c
--- /dev/null
+++ b/bionemo-recipes/models/llama3/tests/test_layer_quantization.py
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for NVLlamaModel.set_recipes and get_layer_autocast."""
+
+from contextlib import nullcontext
+from unittest.mock import patch
+
+import pytest
+import transformer_engine.common.recipe
+import transformer_engine.pytorch
+
+from modeling_llama_te import NVLlamaConfig, NVLlamaModel
+
+
+@pytest.fixture
+def model():
+ """Create a small NVLlamaModel for testing."""
+ config = NVLlamaConfig(
+ hidden_size=256,
+ intermediate_size=512,
+ num_hidden_layers=6,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ vocab_size=100,
+ )
+ return NVLlamaModel(config)
+
+
+# -- set_recipes --
+
+
+def test_all_fp8(model):
+ model.config.layer_precision = ["fp8"] * 6
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+ assert model._fp8_recipe is fp8_recipe
+ assert model._fp4_recipe is None
+ assert all(p == "fp8" for p in model.config.layer_precision)
+
+
+def test_all_fp4(model):
+ model.config.layer_precision = ["fp4"] * 6
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+ assert model._fp8_recipe is None
+ assert model._fp4_recipe is fp4_recipe
+ assert all(p == "fp4" for p in model.config.layer_precision)
+
+
+def test_all_bf16(model):
+ model.config.layer_precision = [None] * 6
+ model.set_recipes(fp8_recipe=None, fp4_recipe=None)
+ assert all(p is None for p in model.config.layer_precision)
+
+
+def test_mixed_fp8_fp4(model):
+ model.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_mixed_fp8_bf16(model):
+ model.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+ assert model.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None]
+
+
+def test_mixed_all_three(model):
+ model.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"]
+
+
+def test_covers_all_layers(model):
+ model.config.layer_precision = ["fp8"] + [None] * 5
+ model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+ assert len(model.config.layer_precision) == 6
+
+
+def test_recipes_stored_as_attributes(model):
+ model.config.layer_precision = ["fp8", "fp4", None, None, None, None]
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ assert model._fp8_recipe is fp8_recipe
+ assert model._fp4_recipe is fp4_recipe
+ # The precision list only contains strings/None, not recipe objects.
+ for v in model.config.layer_precision:
+ assert v is None or isinstance(v, str)
+
+
+# -- get_layer_autocast --
+
+
+def test_fp8_layer_returns_nullcontext(model):
+ model.config.layer_precision = ["fp8"] + [None] * 5
+ model.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+ ctx = model.get_layer_autocast(0)
+ assert isinstance(ctx, nullcontext)
+
+
+def test_fp4_layer_returns_te_autocast(model):
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.config.layer_precision = ["fp4"] + [None] * 5
+ model.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "fp4_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
+ assert ctx == "fp4_context"
+
+
+def test_bf16_layer_returns_te_autocast_disabled(model):
+ model.config.layer_precision = [None] * 6
+ model.set_recipes(fp8_recipe=None, fp4_recipe=None)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=False)
+ assert ctx == "bf16_context"
+
+
+def test_uninitialized_defaults_to_bf16(model):
+ """When layer_precision is None (default), all layers default to BF16."""
+ assert model.config.layer_precision is None
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ ctx = model.get_layer_autocast(0)
+ mock_autocast.assert_called_once_with(enabled=False)
+ assert ctx == "bf16_context"
+
+
+def test_mixed_layers_return_correct_contexts(model):
+ fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+ fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+ model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+ model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
+ # FP8 layers -> nullcontext
+ assert isinstance(model.get_layer_autocast(0), nullcontext)
+ assert isinstance(model.get_layer_autocast(1), nullcontext)
+
+ # FP4 layers -> te.pytorch.autocast
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "fp4_context"
+ model.get_layer_autocast(2)
+ mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)
+
+ # BF16 layers -> te.pytorch.autocast(enabled=False)
+ with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+ mock_autocast.return_value = "bf16_context"
+ model.get_layer_autocast(4)
+ mock_autocast.assert_called_with(enabled=False)
+
+
+def test_layer_precision_is_pickleable(model):
+ """The config.layer_precision list should be trivially pickleable."""
+ import pickle
+
+ model.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+ roundtripped = pickle.loads(pickle.dumps(model.config.layer_precision))
+ assert roundtripped == model.config.layer_precision
diff --git a/bionemo-recipes/recipes/llama3_native_te/Dockerfile b/bionemo-recipes/recipes/llama3_native_te/Dockerfile
index b72c36b890..c3dd031de2 100644
--- a/bionemo-recipes/recipes/llama3_native_te/Dockerfile
+++ b/bionemo-recipes/recipes/llama3_native_te/Dockerfile
@@ -1,5 +1,14 @@
# syntax=docker/dockerfile:1.4
-FROM nvcr.io/nvidia/pytorch:26.04-py3
+FROM nvcr.io/nvidia/pytorch:26.03-py3
+
+# Rebuild TransformerEngine from main branch (includes PR #2753: MXFP8 FusedAdam support).
+# To pin a specific commit, replace 'main' with a commit hash.
+# Build: docker build -t llama3_native_te:te-main-26.03 .
+RUN pip uninstall -y transformer_engine transformer_engine_torch transformer_engine_cu12 && \
+ git clone --recursive https://github.com/NVIDIA/TransformerEngine.git /opt/te && \
+ cd /opt/te && git checkout main && \
+ NVTE_FRAMEWORK=pytorch MAX_JOBS=8 pip install --no-build-isolation . && \
+ rm -rf /opt/te
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/requirements.txt \
diff --git a/bionemo-recipes/recipes/llama3_native_te/README.md b/bionemo-recipes/recipes/llama3_native_te/README.md
index 2be3b0f11e..dffd92d775 100644
--- a/bionemo-recipes/recipes/llama3_native_te/README.md
+++ b/bionemo-recipes/recipes/llama3_native_te/README.md
@@ -1,8 +1,8 @@
# TransformerEngine-accelerated Llama 3 training with native PyTorch training loop
This folder demonstrates how to train TE-accelerated Llama 3 with a native PyTorch training loop, including sequence
-packing and FP8 precision, using fully sharded data parallel (FSDP) for distributed training. This recipe is configured
-for genomic sequences using a custom nucleotide tokenizer.
+packing, FP8/MXFP8/NVFP4 precision with layer-wise control, using fully sharded data parallel (FSDP) for distributed
+training. This recipe is configured for genomic sequences using a custom nucleotide tokenizer.
## How to use this recipe
@@ -16,9 +16,9 @@ bionemo-framework repository. You can download a zipped directory of this folder
## Supported Models and Training Features
-| Model | BF16 | FP8[1] | THD Input Format | FP8 with THD Input Format | MXFP8[2] | Context Parallelism | Tensor Parallelism |
-| ---------------------------------------- | ---- | ----------------- | ---------------- | ------------------------- | ------------------- | ------------------- | ------------------ |
-| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
+| Model | BF16 | FP8[1] | MXFP8[2] | NVFP4[3] | THD Input Format | Context Parallelism | Tensor Parallelism |
+| ---------------------------------------- | ---- | ----------------- | ------------------- | ------------------- | ---------------- | ------------------- | ------------------ |
+| [Llama 3](../../models/llama3/README.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 |
✅: Supported
🚧: Under development
@@ -26,6 +26,7 @@ bionemo-framework repository. You can download a zipped directory of this folder
\[1\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 9.0 and above (Hopper+)
\[2\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and 10.3 (Blackwell), 12.0 support pending
+\[3\]: Requires [compute capability](https://developer.nvidia.com/cuda-gpus) 10.0 and above (Blackwell+)
### Installing Dependencies
@@ -64,19 +65,87 @@ def compute_model_pflops(seq_len, global_batch_size, step_time_s):
return model_flops / 1e15
```
+### Low precision performance benchmarks
+
+
+In the above plot we can see the performance increases as we lower the precision of our transformer layers across the 1B and 8B variant of LLAMA3.
+
+#### MXFP8 vs BF16 throughput on Llama-3.1 (Lingua / DCLM)
+
+We benchmarked MXFP8 (`MXFP8BlockScaling`, E4M3) against the BF16 baseline using the Lingua / DCLM training setup on NVIDIA Blackwell GPUs. All runs use TransformerEngine, fused AdamW with FP32 master weights, and THD sequence packing.
+
+
+
+
+
+**Key finding:** plain MXFP8 over BF16 gives roughly the same ~30% throughput uplift on both 8B and 70B. Quantized model init (`qinit`) adds essentially nothing on 8B (+0.8 pp) but **adds ~10 percentage points on 70B (+9.7 pp)** — the per-layer quantize/dequantize work saved by qinit scales with depth (80 vs 32 transformer layers). On 70B, MXFP8 + qinit delivers a **+38.4% throughput gain over BF16** on a single B300 node.
+
+
+Single-node detail: per-model 3-way comparisons
+
+The per-model charts below show the 3-way (BF16 / MXFP8 no-qinit / MXFP8 + qinit) comparison underlying the headline figure above.
+
+**Llama-3.1-8B** (1 node / 8× B300 SXM6 AC, mbs=4, grad_acc=1, gbs = 32 seqs / 262k tokens, seq_len = 8192):
+
+
+
+
+
+On a single B300 node, **MXFP8 + qinit (+31.1%) and MXFP8 without qinit (+30.4%) deliver essentially the same throughput gain over BF16**. At this layer count the per-layer quantize/dequantize saving qinit provides is small; the speedup comes mainly from the FP8 GEMMs themselves. Averaged over global step ∈ [500, 990].
+
+**Llama-3.1-70B** (1 node / 8× B300 SXM6 AC, mbs=1, grad_acc=1, cp=2, dp=4, gbs = 4 seqs / ≈34k packed tokens, seq_len = 8192):
+
+
+
+
+
+On a single B300 node, **MXFP8 + qinit (+39.4%) pulls ahead of MXFP8 without qinit (+28.7%) — a ~10 percentage point gap that doesn't appear at 8B**. With 80 transformer layers, the per-step quantize/dequantize work avoided by qinit (the FP8 weight is already in compute format, so no on-the-fly cast every forward and backward) adds up to a meaningful throughput gain over the no-qinit path. Averaged over global step ∈ [500, 990]. We also separately measured `preserve_high_precision_init_val=True` (HPIV) and found it within 1% of the qinit-without-HPIV throughput, so HPIV's startup-time master-weight seeding is essentially free at steady state.
+
+
+
+
+Multi-node throughput (B200, production-scale runs)
+
+We also measured MXFP8 + qinit throughput at scale, on multi-node B200 with longer Lingua DCLM runs to confirm the single-node findings hold in production conditions.
+
+**Llama-3.1-8B** (8 nodes / 64× B200, mbs=2, grad_acc=2, global batch = 256 seqs, seq_len = 8192):
+
+
+
+
+
+MXFP8 + qinit reaches **22,517 unpadded tokens / s / GPU vs 17,644 for BF16 — a +27.6% throughput gain (×1.28 speedup, −21.7% step time)**. Averaged over global step ∈ [500, 1000].
+
+**Llama-3.1-70B** (4 nodes / 32× B200, cp=2, dp=16, mbs=1, grad_acc=1, gbs = 16 seqs, seq_len = 8192):
+
+
+
+
+
+MXFP8 + qinit reaches **2,725 unpadded tokens / s / GPU vs 1,972 for BF16 — a +38.2% throughput gain (×1.40 speedup, −27.6% step time)**. Averaged over global step ∈ [100, 490]. The larger relative gain on 70B vs 8B at scale matches the size-dependent pattern shown in the single-node headline above.
+
+
+
+Wandb runs:
+
+- Single-node 8B — [BF16](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_bf16_mbs4_1n_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_qinit_mbs4_1n_bia) / [MXFP8 (no qinit)](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_no_qinit_mbs4_1n_bia)
+- Single-node 70B (1k steps) — [BF16](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_bf16_mbs1_1n_1k_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_mbs1_1n_1k_bia) / [MXFP8 + qinit + HPIV](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_hpiv_mbs1_1n_1k_bia) / [MXFP8 (no qinit)](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_no_qinit_mbs1_1n_1k_bia)
+- Multi-node 8B — [BF16](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua-7b-bf16-baseline) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-7b/runs/lingua_7b_mxfp8_qinit_v6_te_main_8n_prenyx)
+- Multi-node 70B — [BF16](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_bf16_thd_fusedadam_4n_cp2_bia) / [MXFP8 + qinit](https://wandb.ai/clara-discovery/lingua-70b/runs/lingua_70b_mxfp8_qinit_thd_fusedadam_4n_cp2_bia)
+
### Convergence Benchmarks
-
-
+
+
-We compared the convergence of this Llama3 recipe (with FSDP2) against
-[NeMo 2.0](https://github.com/NVIDIA-NeMo/NeMo) and the [facebookresearch/lingua](https://github.com/facebookresearch/lingua)
+We compared the convergence of this Llama3 recipe (with FSDP2) against NeMo 2.0
+(https://github.com/NVIDIA-NeMo/NeMo) and the [facebookresearch/lingua](https://github.com/facebookresearch/lingua)
implementation on the DCLM Baseline 1.0 dataset. See [Training on Natural Language Data (Lingua
-Reproduction)](#training-on-natural-language-data-lingua-reproduction) for more details. The figure above shows similar loss convergence and step time to
+Reproduction)](#lingua-reproduction) for more details. The figure above shows similar loss convergence and step time to
the NeMo 2.0 training example, and the following table shows downstream performance on various tasks using the
-[lm-eval](https://github.com/eleutherai/lm-evaluation-harness) library. The variation in training step time every 10,000 steps
+[lm-eval](github.com/eleutherai/lm-evaluation-harness) library. The variation in training step time every 10,000 steps
are due checkpointing, further work will be done to improve training step time stability.
| name | arc_challenge | arc_easy | boolq | copa | hella_swag | piqa | winogrande |
@@ -88,6 +157,10 @@ are due checkpointing, further work will be done to improve training step time s
Models were trained on 64 NVIDIA H100 GPUs with a micro batch size of 4 and a context length of 4096 for 60,000 steps.
Training was performed with BF16 precision.
+### Low Precision convergence benchmarks
+
+For the multi-node 8B run on DCLM, the MXFP8 + quantized init training loss tracks the BF16 baseline to within ~0.1% over 60k steps, confirming the throughput gains above come with no measurable convergence regression. A small additional improvement is observed when keeping the first and last transformer layers in BF16 while running all other layers in MXFP8 (configurable via `fp8_layers`).
+
### Distributed Training
This recipe supports distributed training using DDP, FSDP2, and FSDP2 with Context Parallelism, shown in three separate training entrypoints:
@@ -127,10 +200,10 @@ batch size while running on a smaller number of GPUs.
python train_fsdp2.py --config-name L0_sanity grad_acc_steps=2
```
-### FP8 Training
+### Quantized Training (FP8 / MXFP8 / NVFP4)
To run training with FP8, enable it by overriding the `fp8_config.enabled=true` configuration parameter. Additional FP8
-configuration parameters, including switching to `MXFP8BlockScaling`, can be set via the hydra configuration.
+configuration parameters, including switching to `MXFP8BlockScaling`, can be set using the hydra configuration.
```bash
python train_fsdp2.py --config-name L0_sanity fp8_config.enabled=true
@@ -150,24 +223,60 @@ python train_fsdp2.py --config-name L0_sanity \
#### FP8 Debugging
-We also provide a mechanism to receive tensor data related to FP8 layers during training which may include activations, weights and gradients.
+```bash
+python train_fsdp2.py --config-name L0_sanity fp4_config.enabled=true
+```
+
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet
+supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today.
+
+Additional recipe parameters (e.g., switching to `MXFP8BlockScaling`) can be set via the hydra configuration.
+
+#### Layer-Wise Precision
+
+You can control which transformer layers use FP8 or FP4 by specifying 1-indexed layer numbers via `fp8_layers` and
+`fp4_layers`. Layers not assigned to either format will run in BF16.
+
+For example, to run layers 1-3 in FP8, layers 4-6 in FP4, and the rest in BF16 on a model with more than 6 layers:
+
+```bash
+python train_fsdp2.py --config-name L0_sanity \
+ fp8_config.enabled=true \
+ fp4_config.enabled=true \
+ 'fp8_layers=[1,2,3]' \
+ 'fp4_layers=[4,5,6]'
+```
+
+When both `fp8_config` and `fp4_config` are enabled but only one layer list is provided, the other format automatically
+claims the remaining layers. For example, if `fp8_layers=[1,2,3]` is set and `fp4_config.enabled=true` with no
+`fp4_layers`, then layers 4 through N will default to FP4.
+
+#### Quantization Stats Debugging
+
+We provide a mechanism to log tensor statistics (activations, weights, gradients) for quantized layers during training.
+When layer-wise precision is used, the stats config is automatically updated so that only the relevant layers are
+tracked.
-To enable this please select the following config options.
+To enable stats logging:
```bash
python train_fsdp2.py \
- fp8_stats_config.enabled=True \
- fp8_stats_config.fp8_log_dir=./logs/fp8_stats_logs_dummy \
- fp8_stats_config.fp8_stats_file=./fp8_debugging_stats.yaml \
- fp8_config.enabled=True
+ quant_stats_config.enabled=true \
+ quant_stats_config.quant_log_dir=./logs/quant_stats \
+ quant_stats_config.quant_stats_file=./fp8_debugging_stats.yaml \
+ fp8_config.enabled=true
```
-Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts.
+Note: This feature is available for the `train_ddp` and the `train_fsdp2` scripts. NVFP4 stats logging is not yet
+supported and will be enabled in a future TransformerEngine release; FP8/MXFP8 stats logging works today.
-The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the [NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html) in more detail. Below we will cover some very basic elements of the file structure.
+The config file structure [fp8_debugging_stats.yaml](fp8_debugging_stats.yaml) is explained in the
+[NVIDIA Transformer Engine config file documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/debug/2_config_file_structure.html)
+in more detail.
-This comes as a performance cost that is dependent on the `freq` parameter mentioned above. `freq=1` collects stats on every step which in our
-experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We recommend using `freq>=10` to reduce this performance hit.
+Stats collection has a performance cost dependent on the `freq` parameter in the config file. `freq=1` collects stats
+on every step which in our experiments caused a ~29% decrease in throughput (executed on a single RTX 5090). We
+recommend using `freq>=10` to reduce this performance hit.
### Sequence Packing (THD input format)
@@ -217,7 +326,7 @@ python train_fsdp2.py --config-name L0_sanity \
dataset.load_dataset_kwargs.path=/path/to/download/directory
```
-## Training on Natural Language Data (Lingua Reproduction)
+## Training on Natural Language Data (Lingua Reproduction) {#lingua-reproduction}
We provide a configuration to reproduce the Llama-3.2-1B training experiments from [Meta
Lingua](https://github.com/facebookresearch/lingua), using the [DCLM Baseline
diff --git a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
index 2dc5d10dcf..d0accd200d 100644
--- a/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
+++ b/bionemo-recipes/recipes/llama3_native_te/checkpoint.py
@@ -34,6 +34,7 @@
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam as _FSDPParam
from torch.distributed.tensor import DTensor
from torchdata.stateful_dataloader import StatefulDataLoader
from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
@@ -41,6 +42,81 @@
from distributed_config import DistributedConfig
+# ---------------------------------------------------------------------------
+# Monkey-patch FSDP2's FSDPParam.reset_sharded_param to handle QuantizedTensor.
+#
+# After checkpoint load, set_state_dict calls copy_() on FSDP-sharded params.
+# For QuantizedTensor (MXFP8Tensor), copy_() re-quantizes which can invalidate
+# the old untyped_storage, causing data_ptr() to crash. The original code
+# (PyTorch _fsdp_param.py) compares storage pointers without guarding against
+# QuantizedTensor. This patch wraps the comparison in a try/except so that
+# reset_sharded_param can proceed normally (re-recording _sharded_param_data).
+# ---------------------------------------------------------------------------
+
+
+def _patched_reset_sharded_param(self): # type: ignore[no-untyped-def]
+ """reset_sharded_param with QuantizedTensor safety."""
+ module_info = self._module_info
+ new_param = getattr(module_info.module, module_info.param_name)
+ if new_param is not self.sharded_param:
+ if torch.__future__.get_swap_module_params_on_conversion():
+ raise AssertionError(
+ f"Expects swap_tensors to preserve object but got {new_param} instead of {self.sharded_param}"
+ )
+ self.sharded_param = new_param
+
+ local_tensor = new_param._local_tensor
+ if local_tensor.is_meta:
+ return
+
+ updated_local_tensor = False
+ same_local_tensor = False
+
+ if type(self._sharded_param_data) is torch.Tensor:
+ try:
+ same_local_tensor = (
+ self._sharded_param_data.untyped_storage().data_ptr() > 0
+ and self._sharded_param_data.untyped_storage().data_ptr() == local_tensor.untyped_storage().data_ptr()
+ )
+ except RuntimeError:
+ # QuantizedTensor (e.g. MXFP8Tensor) can have invalid storage
+ # after copy_() re-quantization. Treat as not-same so that
+ # _sharded_param_data gets re-recorded below.
+ same_local_tensor = False
+
+ padded_sharded_size = self.padded_sharded_param_size
+ shard_dim = self.fsdp_placement.dim
+ length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
+
+ if local_tensor.size() != padded_sharded_size and not same_local_tensor:
+ if shard_dim != 0:
+ raise AssertionError(f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}")
+ padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
+ padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(local_tensor)
+ local_tensor = padded_local_tensor
+ updated_local_tensor = True
+
+ if self.pin_memory and not local_tensor.is_pinned():
+ local_tensor = local_tensor.cpu().pin_memory()
+ updated_local_tensor = True
+
+ if not same_local_tensor:
+ self._sharded_param_data = local_tensor.view(-1)
+
+ if not isinstance(self.sharded_param, DTensor):
+ raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
+
+ if updated_local_tensor:
+ self.sharded_param._local_tensor = local_tensor.narrow(dim=shard_dim, start=0, length=length)
+ if not self.sharded_param._local_tensor.is_contiguous():
+ raise AssertionError("Expected sharded_param._local_tensor to be contiguous")
+
+ self._sharding_spec = self.sharded_param._spec
+
+
+_FSDPParam.reset_sharded_param = _patched_reset_sharded_param
+
+
logger = logging.getLogger(__name__)
# Tracks in-flight async checkpoint futures keyed by strategy name (e.g. "fsdp2").
diff --git a/bionemo-recipes/recipes/llama3_native_te/dataset.py b/bionemo-recipes/recipes/llama3_native_te/dataset.py
index adca62a43a..797e47c729 100644
--- a/bionemo-recipes/recipes/llama3_native_te/dataset.py
+++ b/bionemo-recipes/recipes/llama3_native_te/dataset.py
@@ -192,7 +192,6 @@ def create_bshd_dataloader(
data_collator = base_collator
logger.info("Using standard DataCollatorForLanguageModeling")
- # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
train_dataloader = dataloader_class(
tokenized_dataset,
@@ -200,7 +199,7 @@ def create_bshd_dataloader(
batch_size=micro_batch_size,
collate_fn=data_collator,
num_workers=num_workers,
- pin_memory=not use_stateful_dataloader,
+ pin_memory=True,
persistent_workers=num_workers > 0,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
)
@@ -288,7 +287,6 @@ def create_thd_dataloader(
f"Using GenomicDataCollator (uppercase={uppercase_labels}, mask_degenerate={mask_degenerate_bases})"
)
- # TODO(BIONEMO-3246) - remove the pin_memory=False once StatefulDataLoader supports pin_memory again.
dataloader_class = StatefulDataLoader if use_stateful_dataloader else DataLoader
train_dataloader = dataloader_class(
TokenPackingDataset(
@@ -299,7 +297,7 @@ def create_thd_dataloader(
batch_size=None, # The TokenPackingDataset will handle the batching.
collate_fn=data_collator,
num_workers=num_workers,
- pin_memory=not use_stateful_dataloader,
+ pin_memory=True,
persistent_workers=num_workers > 0,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
)
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
new file mode 100644
index 0000000000..9046d44caf
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
@@ -0,0 +1,33 @@
+example_fp4_tensor_stat_collection:
+ enabled: True
+ layers:
+ # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming)
+ # This matches: model.model.layers.[1-5].*.(layernorm_qkv|proj|fc1|fc2)
+ layer_name_regex_pattern: 'model\.model\.layers\.[1-5]\..*(layernorm_qkv|proj|fc1|fc2)'
+ transformer_engine:
+ LogNvfp4TensorStats:
+ enabled: True
+ tensors_struct:
+ - tensor: activation
+ stats: [underflows%, mse]
+ freq: 100
+ - tensor: gradient
+ stats: [underflows%, mse]
+ freq: 100
+
+example_fp8_tensor_stat_collection:
+ enabled: True
+ layers:
+ # Use regex to select layers (1-indexed as layers.1 through layers.N in the naming)
+ # This matches: model.model.layers.[6-10].*.(layernorm_qkv|proj|fc1|fc2)
+ layer_name_regex_pattern: 'model\.model\.layers\.([6-9]|10)\..*(layernorm_qkv|proj|fc1|fc2)'
+ transformer_engine:
+ LogFp8TensorStats:
+ enabled: True
+ tensors_struct:
+ - tensor: activation
+ stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
+ freq: 100
+ - tensor: gradient
+ stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
+ freq: 100
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
deleted file mode 100644
index d01024f04c..0000000000
--- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import os
-from pathlib import Path
-
-import nvdlfw_inspect.api as debug_api
-import transformer_engine
-
-from distributed_config import DistributedConfig
-
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-def initialize_fp8_debugging(
- dist_config: DistributedConfig,
- enabled: bool,
- fp8_stats_file: str,
- fp8_log_dir: str | os.PathLike,
- fp8_enabled: bool,
-) -> None:
- """Initialize FP8 debugging.
-
- Args:
- dist_config: The distributed configuration.
- enabled: Whether to enable FP8 debugging.
- fp8_stats_file: The file containing the FP8 stats.
- fp8_log_dir: The directory to log the FP8 stats to.
- fp8_enabled: Whether FP8 autocast is enabled.
- """
- if not enabled:
- return
-
- if not fp8_enabled:
- raise ValueError(
- "fp8_stats_config.enabled is true but fp8_config.enabled is false, "
- "please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
- )
-
- fp8_log_dir = Path(fp8_log_dir) / f"rank_{dist_config.rank}"
- fp8_log_dir.mkdir(parents=True, exist_ok=True)
- logger.info(f"Logging FP8 stats to {fp8_log_dir}")
- te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
- debug_api.initialize(
- config_file=fp8_stats_file,
- feature_dirs=[te_features_dir],
- log_dir=fp8_log_dir.as_posix(),
- default_logging_enabled=True,
- )
diff --git a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
index 7544bbedcf..ba640a6cbb 100644
--- a/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
@@ -2,7 +2,7 @@ example_fp8_tensor_stat_collection:
enabled: True
layers:
# Match the actual linear layers within attention that support FP8 stats
- layer_types: [layernorm_qkv]
+ layer_types: [layernorm_qkv, proj, fc1, fc2]
transformer_engine:
LogFp8TensorStats:
enabled: True
@@ -16,3 +16,8 @@ example_fp8_tensor_stat_collection:
- tensor: weight
stats: [underflows%, scale_inv_min, scale_inv_max, mse]
freq: 10
+ LogTensorStats:
+ enabled: True
+ stats: [max, min, mean, std, l1_norm]
+ tensors: [dgrad, wgrad]
+ freq: 1
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml
new file mode 100644
index 0000000000..84196961cc
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b.yaml
@@ -0,0 +1,73 @@
+# Lingua 70B BF16 with Context Parallelism (CP=2).
+# Uses train_fsdp2_cp.py, not train_fsdp2.py.
+
+defaults:
+ - defaults
+ - _self_
+
+config_name_or_path: ./model_configs/meta-llama/Llama-3.1-70B
+
+config_kwargs:
+ attn_input_format: bshd
+ self_attn_mask_type: causal
+
+# CP=2 halves per-GPU sequence length, cutting attention activation memory ~4x.
+cp_size: 2
+
+# BSHD format required for CP (no sequence packing).
+use_sequence_packing: false
+
+# Meta device init is critical for 70B to avoid OOM during model construction.
+use_meta_device: true
+
+# FP32 master weights via TE FusedAdam
+use_fp32_master_weights: true
+
+wandb:
+ name: lingua-70b-bf16-cp2
+ project: lingua-70b
+ id: lingua-70b-bf16-cp2
+
+num_train_steps: 60_000
+
+dataset:
+ tokenizer_name_or_path: ./tokenizers/Meta-Llama-3-8B
+ micro_batch_size: 1
+ num_workers: 4
+ max_seq_length: 8192
+ stride: 512
+ buffer_size: 10_000
+ use_stateful_dataloader: false
+ mask_degenerate_bases: false
+ uppercase_labels: false
+ load_dataset_kwargs:
+ path: parquet
+ data_files: "/workspace/data/dclm-baseline/global-shard_01_of_10/**/*.parquet"
+ split: "train"
+ streaming: true
+
+# With CP=2, dp_size = 32/2 = 16. GBS = 1 * 16 * 16 = 256
+grad_acc_steps: 16
+
+adamw_kwargs:
+ lr: 3e-4
+ fused: true
+ betas: [0.9, 0.95]
+ eps: 1e-8
+ weight_decay: 0.01
+
+lr_scheduler_kwargs:
+ num_warmup_steps: 5_000
+ num_decay_steps: 55_000
+ min_lr_ratio: 0.000001
+
+# Checkpoint config
+checkpoint:
+ ckpt_dir: null
+ save_final_model: true
+ resume_from_checkpoint: true
+ save_every_n_steps: 1000
+ async_save: true
+
+profiler:
+ enabled: false
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml
new file mode 100644
index 0000000000..d226d32214
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_70b_mxfp8_qinit.yaml
@@ -0,0 +1,23 @@
+# Lingua 70B MXFP8 with quantized model init (all layers FP8) + CP=2.
+# Quantized model init stores only FP8 weights (no BF16 copies), saving memory.
+# FP32 master weights are maintained in FusedAdam optimizer.
+# BSHD format (no sequence packing) — required for CP.
+# Requires Blackwell GPUs (GB200/B300) for hardware MXFP8 support.
+
+defaults:
+ - L2_lingua_70b
+ - _self_
+
+# All layers in FP8 (no FL1 exclusion) — compatible with quantized_model_init.
+fp8_config:
+ enabled: true
+ fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling
+ fp8_format: E4M3
+ fp8_recipe_kwargs: {}
+ quantized_model_init_kwargs:
+ enabled: true
+ preserve_high_precision_init_val: true
+
+wandb:
+ name: lingua-70b-mxfp8-qinit-cp2
+ id: lingua-70b-mxfp8-qinit-cp2
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml
new file mode 100644
index 0000000000..65098a53fd
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b.yaml
@@ -0,0 +1,63 @@
+# Config to match the Llama-3.1-8B model pre-training experiments from https://github.com/facebookresearch/lingua.
+
+defaults:
+ - defaults
+ - _self_
+
+config_name_or_path: ./model_configs/meta-llama/Llama-3.1-8B
+
+config_kwargs:
+ attn_input_format: thd
+
+use_sequence_packing: true
+
+# FP32 master weights via TE FusedAdam (recommended over MixedPrecisionPolicy)
+use_fp32_master_weights: true
+
+wandb:
+ name: lingua-8b-bf16
+ project: lingua-8b
+ id: lingua-8b-bf16
+
+num_train_steps: 60_000
+
+dataset:
+ tokenizer_name_or_path: ./tokenizers/Meta-Llama-3-8B
+ micro_batch_size: 2
+ num_workers: 8
+ max_seq_length: 8192
+ stride: 512
+ buffer_size: 10_000
+ use_stateful_dataloader: false
+ mask_degenerate_bases: false
+ uppercase_labels: false
+ load_dataset_kwargs:
+ path: parquet
+ data_files: "/workspace/data/dclm-baseline/global-shard_01_of_10/**/*.parquet"
+ split: "train"
+ streaming: true
+
+grad_acc_steps: 4 # GBS = 2 * 4 * 32 GPUs = 256 (4 nodes)
+
+adamw_kwargs:
+ lr: 3e-4
+ fused: true
+ betas: [0.9, 0.95]
+ eps: 1e-8
+ weight_decay: 0.01
+
+lr_scheduler_kwargs:
+ num_warmup_steps: 5_000
+ num_decay_steps: 55_000 # total_steps - num_warmup_steps = 60_000 - 5_000
+ min_lr_ratio: 0.000001
+
+# Checkpoint config
+checkpoint:
+ ckpt_dir: null
+ save_final_model: true
+ resume_from_checkpoint: true
+ save_every_n_steps: 10_000
+ async_save: false
+
+profiler:
+ enabled: false
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml
new file mode 100644
index 0000000000..b9227cf101
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/L2_lingua_8b_mxfp8_qinit.yaml
@@ -0,0 +1,22 @@
+# Lingua 8B MXFP8 with quantized model init (all layers FP8).
+# Quantized model init stores only FP8 weights (no BF16 copies), saving memory.
+# FP32 master weights are maintained in FusedAdam optimizer.
+# Requires Blackwell GPUs (GB200) for hardware MXFP8 support.
+
+defaults:
+ - L2_lingua_8b
+ - _self_
+
+# All layers in FP8 (no FL1 exclusion) — compatible with outer quantized_model_init.
+fp8_config:
+ enabled: true
+ fp8_recipe: transformer_engine.common.recipe.MXFP8BlockScaling
+ fp8_format: E4M3
+ fp8_recipe_kwargs: {}
+ quantized_model_init_kwargs:
+ enabled: true
+ preserve_high_precision_init_val: true
+
+wandb:
+ name: lingua-8b-mxfp8-qinit
+ id: lingua-8b-mxfp8-qinit
diff --git a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
index 9302a0758d..82336651a7 100644
--- a/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
+++ b/bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
@@ -32,6 +32,8 @@ dataset:
wandb:
name: ???
project: null # Optional: set to your wandb project name
+ id: null # Set to a fixed ID to resume the same run across restarts
+ resume: allow # "allow" resumes if id exists, else creates new run
# TransformerEngine FP8 config. See
# https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html for more information on
@@ -41,6 +43,8 @@ fp8_config:
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
fp8_format: "HYBRID"
fp8_recipe_kwargs: {}
+ quantized_model_init_kwargs:
+ enabled: false # If this is set to true, fp8_config.enabled must also be set to true.
fp4_config:
enabled: false
@@ -74,10 +78,15 @@ checkpoint:
logger:
frequency: 100
-fp8_stats_config:
+quant_stats_config:
enabled: false
- fp8_stats_file: ./fp8_debugging_stats.yaml
- fp8_log_dir: ./log_fp8_stats
+ quant_stats_file: ./fp8_debugging_stats.yaml
+ quant_log_dir: ./log_quant_stats
+
+# Note: The layers are going to come in 1 indexed and we convert them to be 0 indexed at runtime.
+fp8_layers: null
+fp4_layers: null
+use_fp32_master_weights: null # Use TE FusedAdam for FP32 master weights
profiler:
enabled: false
diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json
new file mode 100644
index 0000000000..bd1408afc6
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-70B/config.json
@@ -0,0 +1,35 @@
+{
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": 128001,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 8192,
+ "initializer_range": 0.02,
+ "intermediate_size": 28672,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "model_type": "llama",
+ "num_attention_heads": 64,
+ "num_hidden_layers": 80,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.57.3",
+ "use_cache": true,
+ "vocab_size": 128256
+}
diff --git a/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json
new file mode 100644
index 0000000000..460f2f1b71
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/model_configs/meta-llama/Llama-3.1-8B/config.json
@@ -0,0 +1,35 @@
+{
+ "architectures": [
+ "LlamaForCausalLM"
+ ],
+ "attention_bias": false,
+ "attention_dropout": 0.0,
+ "bos_token_id": 128000,
+ "eos_token_id": 128001,
+ "head_dim": 128,
+ "hidden_act": "silu",
+ "hidden_size": 4096,
+ "initializer_range": 0.02,
+ "intermediate_size": 14336,
+ "max_position_embeddings": 131072,
+ "mlp_bias": false,
+ "model_type": "llama",
+ "num_attention_heads": 32,
+ "num_hidden_layers": 32,
+ "num_key_value_heads": 8,
+ "pretraining_tp": 1,
+ "rms_norm_eps": 1e-05,
+ "rope_scaling": {
+ "factor": 8.0,
+ "high_freq_factor": 4.0,
+ "low_freq_factor": 1.0,
+ "original_max_position_embeddings": 8192,
+ "rope_type": "llama3"
+ },
+ "rope_theta": 500000.0,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.57.3",
+ "use_cache": true,
+ "vocab_size": 128256
+}
diff --git a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
index 62171cd237..0bdb8a23e8 100644
--- a/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
@@ -58,6 +58,7 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ layer_precision: list[str | None] | None = None
def __init__(
self,
@@ -223,11 +224,54 @@ def _init_method(x):
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
+ def set_recipes(
+ self,
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ ) -> None:
+ """Attach quantization recipe objects for per-layer autocast.
+
+ Recipes are not serializable and must be set at runtime after model creation
+ and sharding (FSDP/DDP) but before training. The per-layer precision
+ assignments are read from ``self.config.layer_precision``.
+
+ Args:
+ fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+ fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+ """
+ self._fp8_recipe = fp8_recipe
+ self._fp4_recipe = fp4_recipe
+
+ def get_layer_autocast(self, layer_number: int):
+ """Return the appropriate TE autocast context manager for a given layer.
+
+ The context interacts with the outer FP8 autocast in the training script:
+ - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+ - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+ Args:
+ layer_number: The 0-indexed layer number.
+
+ Returns:
+ A context manager for the layer's quantization mode.
+ """
+ precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+ if precision == "fp8":
+ return nullcontext()
+ elif precision == "fp4":
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+ else:
+ return transformer_engine.pytorch.autocast(enabled=False)
+
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -304,12 +348,14 @@ def forward(
if te_rope_emb.dtype != torch.float32:
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
- with self.get_autocast_context(None, outer=True):
- for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+ # by get_layer_autocast(), which nests inside this context.
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+ for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
- with self.get_autocast_context(layer_idx):
+ with self.get_layer_autocast(layer_number):
hidden_states = decoder_layer(
hidden_states,
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
@@ -369,8 +415,10 @@ def get_autocast_context(
if init and self.config.use_quantized_model_init:
if precision in ("fp8", "fp4"):
- return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
- return nullcontext()
+ return transformer_engine.pytorch.quantized_model_init(
+ recipe=recipe, preserve_high_precision_init_val=True
+ )
+ return transformer_engine.pytorch.quantized_model_init(enabled=False)
if precision == "fp8":
if recipe is None:
@@ -589,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
return 0
return max(self.sequences.values())
+ @property
+ def is_compileable(self) -> bool:
+ """Required by HuggingFace transformers generate() auto-compile check."""
+ return False
+
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder the cache based on the beam indices."""
if isinstance(self.cache_manager, PagedKVCacheManager):
@@ -597,8 +650,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_key_cache = key_cache.index_select(0, beam_idx)
updated_value_cache = value_cache.index_select(0, beam_idx)
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
-
- @property
- def is_compileable(self) -> bool:
- """Return False as this cache is not compatible with torch.compile."""
- return False
diff --git a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
index 726eb19e8e..4b1a8d4ec7 100644
--- a/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/perf_logger.py
@@ -91,7 +91,7 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig, start_step:
self.grad_acc_step_count = 0
# Whether to step debug_api.step() after each step
- self.fp8_stats_enabled = args.fp8_stats_config.enabled
+ self.quant_stats_config = args.quant_stats_config.enabled
@nvtx.annotate("PerfLogger.log_micro_step", color="pink")
def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: CausalLMOutputWithPast):
@@ -150,7 +150,7 @@ def log_step(
if self._profiler is not None:
self._profiler.step(step)
- if self.fp8_stats_enabled:
+ if self.quant_stats_config:
debug_api.step()
if step % self.logging_frequency == 0 and step > 0:
@@ -201,15 +201,15 @@ def log_step(
def finish(self):
"""Finish the logger and close the progress bar."""
+ if self.quant_stats_config:
+ debug_api.end_debug()
+
if not self._dist_config.is_main_process():
return
wandb.finish()
self._progress_bar.close()
- if self.fp8_stats_enabled:
- debug_api.end_debug()
-
class NsightProfiler:
"""Nsight Systems profiler wrapper for performance analysis.
diff --git a/bionemo-recipes/recipes/llama3_native_te/quantization.py b/bionemo-recipes/recipes/llama3_native_te/quantization.py
new file mode 100644
index 0000000000..e479b13c02
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/quantization.py
@@ -0,0 +1,223 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for layer-wise quantization configuration (FP8/FP4)."""
+
+import logging
+import tempfile
+from pathlib import Path
+
+import yaml
+
+
+logger = logging.getLogger(__name__)
+
+
+def generate_layer_regex(layer_numbers: list[int] | None) -> str:
+ """Generate a regex pattern to match specific layer numbers (1-indexed).
+
+ The debug API (nvdlfw_inspect) uses 1-indexed layer names after ``infer_and_assign_layer_names``.
+
+ Args:
+ layer_numbers: List of layer numbers (1-indexed, as shown in debug logs).
+ If empty or None, returns a pattern that matches nothing.
+
+ Returns:
+ Regex pattern string for matching those layers' linear sublayers.
+ """
+ if not layer_numbers:
+ return r"model\.model\.layers\.DISABLED_NO_LAYERS_SPECIFIED"
+ layer_pattern = "|".join(str(n) for n in sorted(layer_numbers))
+ return rf"model\.model\.layers\.({layer_pattern})\..*(layernorm_qkv|proj|fc1|fc2)"
+
+
+def update_quant_stats_config(
+ config_file: str,
+ fp4_layers: list[int] | None,
+ fp8_layers: list[int] | None,
+) -> str:
+ """Update the quant stats YAML config with layer-specific regex patterns.
+
+ Args:
+ config_file: Path to the original YAML config file.
+ fp4_layers: List of layer numbers for FP4 (1-indexed).
+ fp8_layers: List of layer numbers for FP8 (1-indexed).
+
+ Returns:
+ Path to the updated config file (a temp file).
+ """
+ with open(config_file, "r") as f:
+ config = yaml.safe_load(f)
+
+ if "example_fp4_tensor_stat_collection" in config:
+ # TODO: Remove this block and replace with FP8-style regex update once a TransformerEngine
+ # release with LogNvfp4TensorStats support is available. At that point, this becomes:
+ # fp4_regex = generate_layer_regex(fp4_layers)
+ # config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex
+ config["example_fp4_tensor_stat_collection"]["enabled"] = False
+ if fp4_layers:
+ logger.warning(
+ "NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). "
+ f"Disabling FP4 stats collection for layers {fp4_layers}. FP8 stats will still be collected."
+ )
+ else:
+ logger.info("FP4 stats section disabled (no FP4 layers and feature not yet supported)")
+
+ if "example_fp8_tensor_stat_collection" in config:
+ fp8_regex = generate_layer_regex(fp8_layers)
+ config["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp8_regex
+ if fp8_layers:
+ logger.info(f"Updated FP8 layer regex to match layers: {fp8_layers}")
+ else:
+ logger.info("FP8 layers empty - regex set to match nothing")
+
+ temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
+ yaml.dump(config, temp_file, default_flow_style=False)
+ temp_file.close()
+
+ config_str = yaml.dump(config, default_flow_style=False)
+ logger.info(f"Created updated quant stats config at: {temp_file.name}")
+ logger.info(f"Updated quant stats config contents:\n{config_str}")
+
+ return temp_file.name
+
+
+def initialize_quant_stats_logging(
+ quant_stats_file: str,
+ quant_log_dir: str,
+ rank: int,
+ layer_precision: list[str | None],
+) -> None:
+ """Set up quantization stats logging via nvdlfw_inspect.
+
+ Updates the quant stats YAML config with resolved layer regex patterns, creates
+ the per-rank log directory, and initializes the debug API.
+
+ Args:
+ quant_stats_file: Path to the base quant stats YAML config file.
+ quant_log_dir: Base directory for quant stats logs (a rank subdirectory will be created).
+ rank: The global rank of this process.
+ layer_precision: Per-layer precision list (0-indexed by position). Each element is
+ ``"fp8"``, ``"fp4"``, or ``None``.
+ """
+ import nvdlfw_inspect.api as debug_api
+ import transformer_engine
+
+ # Derive 1-indexed layer lists for the debug API, which uses 1-indexed layer names.
+ fp8_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp8"] or None
+ fp4_layers_1indexed = [i + 1 for i, p in enumerate(layer_precision) if p == "fp4"] or None
+ updated_config = update_quant_stats_config(
+ config_file=quant_stats_file,
+ fp4_layers=fp4_layers_1indexed,
+ fp8_layers=fp8_layers_1indexed,
+ )
+
+ rank_log_dir = Path(quant_log_dir) / f"rank_{rank}"
+ rank_log_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Logging quant stats to {rank_log_dir}")
+
+ te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
+ debug_api.initialize(
+ config_file=updated_config,
+ feature_dirs=[te_features_dir],
+ log_dir=rank_log_dir,
+ default_logging_enabled=True,
+ )
+
+
+def resolve_layer_precision(
+ num_layers: int,
+ fp8_enabled: bool,
+ fp4_enabled: bool,
+ fp8_layers: list[int] | None,
+ fp4_layers: list[int] | None,
+) -> list[str | None]:
+ """Resolve layer-wise quantization assignments from user config.
+
+ Takes 1-indexed layer lists (as specified by the user in YAML config) and returns a per-layer
+ precision list (0-indexed by position). When a quantization format is enabled but no layer list
+ is provided, all layers default to that format. When one format has explicit layers and the other
+ is enabled without a layer list, the unspecified format defaults to the remaining (unclaimed) layers.
+
+ Args:
+ num_layers: Total number of transformer layers in the model.
+ fp8_enabled: Whether FP8 quantization is enabled.
+ fp4_enabled: Whether FP4 quantization is enabled.
+ fp8_layers: 1-indexed list of layers for FP8, or None if not specified.
+ fp4_layers: 1-indexed list of layers for FP4, or None if not specified.
+
+ Returns:
+ A list of length ``num_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None``
+ (BF16 fallback), indexed by layer position (0-indexed).
+
+ Raises:
+ ValueError: If both formats are enabled with no layer lists, or if layer lists overlap.
+ """
+ all_layers = set(range(1, num_layers + 1))
+
+ if fp8_enabled and fp4_enabled and fp8_layers is None and fp4_layers is None:
+ raise ValueError(
+ "Both fp8_config and fp4_config are enabled but neither fp8_layers nor fp4_layers is specified. "
+ "When both are enabled, you must explicitly provide layer lists to indicate which layers use which format."
+ )
+
+ # When one format has explicit layers and the other defaults, fill in the remaining layers.
+ if fp8_enabled and fp8_layers is None:
+ claimed_by_fp4 = set(fp4_layers) if fp4_layers is not None else set()
+ fp8_layers = sorted(all_layers - claimed_by_fp4)
+ if claimed_by_fp4:
+ logger.warning(
+ f"fp8_config.enabled=True with no fp8_layers specified, but fp4_layers={sorted(claimed_by_fp4)} "
+ f"are already claimed by FP4. Defaulting FP8 to the remaining layers: {fp8_layers}"
+ )
+ else:
+ logger.info(
+ f"fp8_config.enabled=True with no fp8_layers specified, defaulting all {num_layers} layers to FP8"
+ )
+
+ if fp4_enabled and fp4_layers is None:
+ claimed_by_fp8 = set(fp8_layers) if fp8_layers is not None else set()
+ fp4_layers = sorted(all_layers - claimed_by_fp8)
+ if claimed_by_fp8:
+ logger.warning(
+ f"fp4_config.enabled=True with no fp4_layers specified, but fp8_layers={sorted(claimed_by_fp8)} "
+ f"are already claimed by FP8. Defaulting FP4 to the remaining layers: {fp4_layers}"
+ )
+ else:
+ logger.info(
+ f"fp4_config.enabled=True with no fp4_layers specified, defaulting all {num_layers} layers to FP4"
+ )
+
+ # Disable layer lists when corresponding config is not enabled.
+ if not fp8_enabled:
+ fp8_layers = None
+ if not fp4_enabled:
+ fp4_layers = None
+
+ # Validate no overlap between FP8 and FP4 layer assignments.
+ if fp8_layers is not None and fp4_layers is not None:
+ overlap = set(fp8_layers) & set(fp4_layers)
+ if overlap:
+ raise ValueError(
+ f"fp8_layers and fp4_layers cannot have overlapping layer numbers. Found overlap: {sorted(overlap)}"
+ )
+
+ # Build per-layer precision list (0-indexed by position, 1-indexed for lookup).
+ fp8_set = set(fp8_layers) if fp8_layers is not None else set()
+ fp4_set = set(fp4_layers) if fp4_layers is not None else set()
+ return [
+ "fp8" if layer_1indexed in fp8_set else "fp4" if layer_1indexed in fp4_set else None
+ for layer_1indexed in range(1, num_layers + 1)
+ ]
diff --git a/bionemo-recipes/recipes/llama3_native_te/requirements.txt b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
index 40f36f659d..a36f3df85f 100644
--- a/bionemo-recipes/recipes/llama3_native_te/requirements.txt
+++ b/bionemo-recipes/recipes/llama3_native_te/requirements.txt
@@ -5,7 +5,6 @@ torchao!=0.14.0
torchdata
torchmetrics
tqdm
-transformer_engine[pytorch]
transformers
wandb
zstandard
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
index 08330b12f7..bb7a2d8ed6 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
@@ -56,6 +56,8 @@ def pytest_collection_modifyitems(items):
stats_test_names = {
"test_sanity_ddp_fp8_stats_logging",
"test_sanity_fsdp2_fp8_stats_logging",
+ "test_sanity_ddp_fp8_partial_layers_stats_logging",
+ "test_sanity_fsdp2_fp8_partial_layers_stats_logging",
}
stats_tests = [item for item in items if item.name in stats_test_names]
other_tests = [item for item in items if item.name not in stats_test_names]
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py
new file mode 100644
index 0000000000..f4aaf03404
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_mxfp8_fsdp2_checkpoint_resume.py
@@ -0,0 +1,311 @@
+#!/usr/bin/env python
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Minimal reproduction of FSDP2 + MXFP8 checkpoint resume crash.
+
+Bug: After fully_shard() wraps a model with quantized_model_init (MXFP8) params,
+checkpoint resume via set_state_dict crashes with:
+ RuntimeError: Attempted to access the data pointer on an invalid python storage.
+
+Root cause: set_state_dict -> model.load_state_dict -> copy_() on MXFP8Tensor
+re-quantizes, allocating new internal storage. FSDP2's reset_sharded_param
+(post-load hook) then calls untyped_storage().data_ptr() on the invalidated
+storage. PyTorch has a "# TODO: need to support tensor subclass" comment at
+the crash site (_fsdp_param.py line 892).
+
+Fix: Wrap the data_ptr() comparison in try/except RuntimeError. When it fails,
+treat as same_local_tensor=False so _sharded_param_data gets re-recorded.
+
+Run with: torchrun --nproc_per_node=2 test_mxfp8_fsdp2_checkpoint_resume.py
+"""
+
+import argparse
+import shutil
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+import transformer_engine.pytorch as te
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+from torch.distributed.tensor import DTensor
+from torch.nn import functional as f_nn
+from transformer_engine.common.recipe import MXFP8BlockScaling
+from transformer_engine.pytorch.optimizers import FusedAdam
+from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
+
+
+HIDDEN = 256
+FFN_HIDDEN = 1024
+NUM_HEADS = 8
+NUM_LAYERS = 2
+SEQ_LEN = 32
+BATCH = 2
+
+
+def apply_reset_sharded_param_fix():
+ """Monkey-patch FSDPParam.reset_sharded_param to handle QuantizedTensor.
+
+ After checkpoint load, copy_() on MXFP8Tensor re-quantizes which can
+ invalidate the old untyped_storage, causing data_ptr() to crash.
+ This wraps the comparison in try/except so reset_sharded_param can
+ proceed normally (re-recording _sharded_param_data).
+ """
+
+ def _patched_reset_sharded_param(self):
+ module_info = self._module_info
+ new_param = getattr(module_info.module, module_info.param_name)
+ if new_param is not self.sharded_param:
+ if torch.__future__.get_swap_module_params_on_conversion():
+ raise AssertionError(
+ f"Expects swap_tensors to preserve object but got {new_param} instead of {self.sharded_param}"
+ )
+ self.sharded_param = new_param
+
+ local_tensor = new_param._local_tensor
+ if local_tensor.is_meta:
+ return
+
+ updated_local_tensor = False
+ same_local_tensor = False
+
+ if type(self._sharded_param_data) is torch.Tensor:
+ try:
+ same_local_tensor = (
+ self._sharded_param_data.untyped_storage().data_ptr() > 0
+ and self._sharded_param_data.untyped_storage().data_ptr()
+ == local_tensor.untyped_storage().data_ptr()
+ )
+ except RuntimeError:
+ # QuantizedTensor (e.g. MXFP8Tensor) can have invalid storage
+ # after copy_() re-quantization.
+ same_local_tensor = False
+
+ padded_sharded_size = self.padded_sharded_param_size
+ shard_dim = self.fsdp_placement.dim
+ length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
+
+ if local_tensor.size() != padded_sharded_size and not same_local_tensor:
+ if shard_dim != 0:
+ raise AssertionError(f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}")
+ padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
+ padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(local_tensor)
+ local_tensor = padded_local_tensor
+ updated_local_tensor = True
+
+ if self.pin_memory and not local_tensor.is_pinned():
+ local_tensor = local_tensor.cpu().pin_memory()
+ updated_local_tensor = True
+
+ if not same_local_tensor:
+ self._sharded_param_data = local_tensor.view(-1)
+
+ if not isinstance(self.sharded_param, DTensor):
+ raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
+
+ if updated_local_tensor:
+ self.sharded_param._local_tensor = local_tensor.narrow(dim=shard_dim, start=0, length=length)
+ if not self.sharded_param._local_tensor.is_contiguous():
+ raise AssertionError("Expected sharded_param._local_tensor to be contiguous")
+
+ self._sharding_spec = self.sharded_param._spec
+
+ FSDPParam.reset_sharded_param = _patched_reset_sharded_param
+
+
+def _save_custom_attrs(model):
+ """Save custom attrs on QuantizedTensor params (lost during fully_shard + reset_parameters)."""
+ attrs = {}
+ for name, param in model.named_parameters():
+ local = param._local_tensor if isinstance(param, DTensor) else param
+ if isinstance(local, QuantizedTensor):
+ param_attrs = {}
+ for attr_name in dir(local):
+ if not attr_name.startswith("_") and not callable(getattr(local, attr_name, None)):
+ try:
+ param_attrs[attr_name] = getattr(local, attr_name)
+ except Exception:
+ pass
+ attrs[name] = param_attrs
+ return attrs
+
+
+def _restore_custom_attrs(model, attrs):
+ """Restore custom attrs on QuantizedTensor params."""
+ for name, param in model.named_parameters():
+ if name in attrs:
+ local = param._local_tensor if isinstance(param, DTensor) else param
+ if isinstance(local, QuantizedTensor):
+ for attr_name, attr_val in attrs[name].items():
+ try:
+ setattr(local, attr_name, attr_val)
+ except Exception:
+ pass
+
+
+def build_model(recipe):
+ """Build model with quantized_model_init on meta device."""
+ with te.quantized_model_init(
+ recipe=recipe,
+ enabled=True,
+ preserve_high_precision_init_val=True,
+ ):
+ model = torch.nn.Sequential(
+ *[
+ te.TransformerLayer(
+ HIDDEN,
+ FFN_HIDDEN,
+ NUM_HEADS,
+ fuse_qkv_params=True,
+ params_dtype=torch.bfloat16,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ device="meta",
+ )
+ for _ in range(NUM_LAYERS)
+ ]
+ )
+ return model
+
+
+def shard_model(model, mesh):
+ """Apply FSDP2 sharding, then materialize meta params via reset_parameters."""
+ has_meta = any(p.is_meta for p in model.parameters())
+ custom_attrs = _save_custom_attrs(model)
+ for child in model.children():
+ fully_shard(child, mesh=mesh)
+ fully_shard(model, mesh=mesh)
+ if has_meta:
+ for module in model.modules():
+ if hasattr(module, "reset_parameters"):
+ module.reset_parameters()
+ _restore_custom_attrs(model, custom_attrs)
+ return model
+
+
+def build_and_shard(recipe, mesh, device):
+ """Build model, shard, create optimizer, run one step to populate optimizer state."""
+ model = build_model(recipe)
+ model = shard_model(model, mesh)
+
+ optimizer = FusedAdam(
+ model.parameters(),
+ lr=1e-3,
+ master_weights=True,
+ master_weight_dtype=torch.float32,
+ )
+
+ # Run one training step to populate optimizer state
+ x = torch.randn(SEQ_LEN, BATCH, HIDDEN, dtype=torch.bfloat16, device=device)
+ target = torch.randn_like(x)
+ optimizer.zero_grad(set_to_none=True)
+ with te.autocast(enabled=True, recipe=recipe):
+ out = model(x)
+ loss = f_nn.mse_loss(out, target)
+ loss.backward()
+ optimizer.step()
+
+ return model, optimizer
+
+
+def run(apply_fix: bool):
+ """Run the reproduction: save checkpoint, load it, verify forward pass."""
+ dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+ rank = dist.get_rank()
+ torch.cuda.set_device(rank)
+ device = torch.device(f"cuda:{rank}")
+ world_size = dist.get_world_size()
+ mesh = DeviceMesh("cuda", list(range(world_size)))
+
+ recipe = MXFP8BlockScaling()
+
+ if apply_fix:
+ apply_reset_sharded_param_fix()
+ if rank == 0:
+ print("Applied reset_sharded_param fix")
+
+ # Build model, train one step, save checkpoint
+ model, optimizer = build_and_shard(recipe, mesh, device)
+ if rank == 0:
+ print("Model built and trained for 1 step")
+
+ # Record reference output
+ x = torch.randn(SEQ_LEN, BATCH, HIDDEN, dtype=torch.bfloat16, device=device)
+ with torch.no_grad(), te.autocast(enabled=True, recipe=recipe):
+ ref_output = model(x).clone()
+ if rank == 0:
+ print(f"Reference output recorded, norm={ref_output.norm().item():.4f}")
+
+ checkpoint_dir = "/tmp/te_test_mxfp8_fsdp2_ckpt_resume"
+ if rank == 0:
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
+ dist.barrier()
+
+ try:
+ # Save checkpoint
+ model_state = {k: v for k, v in model.state_dict().items() if not k.endswith("_extra_state")}
+ dcp.save({"model": model_state, "optimizer": optimizer.state_dict()}, checkpoint_id=checkpoint_dir)
+ dist.barrier()
+ if rank == 0:
+ print(f"Checkpoint saved to {checkpoint_dir}")
+
+ # Build fresh model
+ model2, optimizer2 = build_and_shard(recipe, mesh, device)
+ if rank == 0:
+ print("Fresh model built, loading checkpoint...")
+
+ # Load checkpoint — THIS IS WHERE THE CRASH HAPPENS WITHOUT THE FIX
+ model2_state = {k: v for k, v in model2.state_dict().items() if not k.endswith("_extra_state")}
+ state_to_load = {"model": model2_state, "optimizer": optimizer2.state_dict()}
+ dcp.load(state_to_load, checkpoint_id=checkpoint_dir)
+ model2.load_state_dict(state_to_load["model"], strict=False)
+ optimizer2.load_state_dict(state_to_load["optimizer"])
+ dist.barrier()
+ if rank == 0:
+ print("Checkpoint loaded successfully!")
+
+ # Verify output matches
+ with torch.no_grad(), te.autocast(enabled=True, recipe=recipe):
+ loaded_output = model2(x)
+
+ torch.testing.assert_close(
+ loaded_output,
+ ref_output,
+ rtol=0,
+ atol=0,
+ msg=lambda m: f"Output mismatch after checkpoint load: {m}",
+ )
+ if rank == 0:
+ print("Output parity verified — bitwise identical!")
+
+ finally:
+ dist.barrier()
+ if rank == 0:
+ shutil.rmtree(checkpoint_dir, ignore_errors=True)
+
+ dist.destroy_process_group()
+ if rank == 0:
+ print("SUCCESS" if apply_fix else "SUCCESS (unexpected — bug may be fixed upstream)")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--fix", action="store_true", help="Apply the reset_sharded_param monkey-patch fix")
+ args = parser.parse_args()
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+ run(apply_fix=args.fix)
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
index aebdfe17ef..d919278d4a 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
@@ -34,7 +34,7 @@ def _make_args(logging_frequency=1, num_train_steps=100):
"wandb": {"project": "test", "mode": "disabled"},
"num_train_steps": num_train_steps,
"profiler": {"enabled": False},
- "fp8_stats_config": {"enabled": False},
+ "quant_stats_config": {"enabled": False},
}
)
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
new file mode 100644
index 0000000000..2d6e02b050
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
@@ -0,0 +1,332 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import sys
+from pathlib import Path
+
+import pytest
+import yaml
+
+
+sys.path.append(Path(__file__).parent.parent.as_posix())
+
+from quantization import generate_layer_regex, resolve_layer_precision, update_quant_stats_config
+
+
+# -- resolve_layer_precision --
+
+
+def test_fp8_enabled_no_layers_defaults_all():
+ """When fp8 is enabled with no explicit layers, all layers should default to FP8."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp8", "fp8", "fp8"]
+
+
+def test_fp4_enabled_no_layers_defaults_all():
+ """When fp4 is enabled with no explicit layers, all layers should default to FP4."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp4", "fp4", "fp4", "fp4", "fp4", "fp4"]
+
+
+def test_fp8_explicit_layers():
+ """Explicit 1-indexed fp8_layers should produce fp8 at those positions."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[1, 3, 5], fp4_layers=None
+ )
+ assert result == ["fp8", None, "fp8", None, "fp8", None]
+
+
+def test_fp4_explicit_layers():
+ """Explicit 1-indexed fp4_layers should produce fp4 at those positions."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=True, fp8_layers=None, fp4_layers=[2, 4, 6]
+ )
+ assert result == [None, "fp4", None, "fp4", None, "fp4"]
+
+
+def test_mixed_fp8_fp4_explicit():
+ """Both enabled with explicit non-overlapping layers should work correctly."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 3, 4], fp4_layers=[2, 5]
+ )
+ assert result == ["fp8", "fp4", "fp8", "fp8", "fp4", None]
+
+
+def test_both_enabled_no_layers_raises():
+ """Both enabled with no layer lists should raise ValueError."""
+ with pytest.raises(ValueError, match="Both fp8_config and fp4_config are enabled"):
+ resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=None)
+
+
+def test_overlapping_layers_raises():
+ """Overlapping layer assignments should raise ValueError."""
+ with pytest.raises(ValueError, match="fp8_layers and fp4_layers cannot have overlapping"):
+ resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=[3, 4, 5]
+ )
+
+
+def test_disabled_ignores_layers():
+ """When a format is disabled, its layers should be ignored."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=[1, 2, 3], fp4_layers=[4, 5, 6]
+ )
+ assert result == [None, None, None, None, None, None]
+
+
+def test_both_disabled():
+ """Both disabled with no layers should return all None."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == [None, None, None, None, None, None]
+
+
+def test_large_model_defaults_all():
+ """Auto-population should work correctly for larger models (e.g. 36 layers)."""
+ result = resolve_layer_precision(
+ num_layers=36, fp8_enabled=True, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert result == ["fp8"] * 36
+
+
+def test_fp8_enabled_empty_list():
+ """An explicit empty list should remain empty (not default to all)."""
+ result = resolve_layer_precision(num_layers=6, fp8_enabled=True, fp4_enabled=False, fp8_layers=[], fp4_layers=None)
+ assert result == [None, None, None, None, None, None]
+
+
+def test_both_enabled_fp8_specified_fp4_defaults_to_remaining():
+ """When both enabled, FP8 has explicit layers, FP4 should default to the remaining layers."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=[1, 2, 3], fp4_layers=None
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_both_enabled_fp4_specified_fp8_defaults_to_remaining():
+ """When both enabled, FP4 has explicit layers, FP8 should default to the remaining layers."""
+ result = resolve_layer_precision(
+ num_layers=6, fp8_enabled=True, fp4_enabled=True, fp8_layers=None, fp4_layers=[4, 5, 6]
+ )
+ assert result == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_returns_correct_length():
+ """Result list length should always equal num_layers."""
+ for n in [1, 6, 48]:
+ result = resolve_layer_precision(
+ num_layers=n, fp8_enabled=False, fp4_enabled=False, fp8_layers=None, fp4_layers=None
+ )
+ assert len(result) == n
+
+
+# -- generate_layer_regex --
+
+
+def test_single_layer():
+ """Single layer should produce a simple regex."""
+ regex = generate_layer_regex([3])
+ assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv")
+ assert not re.search(regex, "model.model.layers.2.self_attention.layernorm_qkv")
+
+
+def test_multiple_layers():
+ """Multiple layers should match any of them."""
+ regex = generate_layer_regex([1, 2, 3])
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ assert re.search(regex, "model.model.layers.2.layernorm_mlp.fc1")
+ assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
+ assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+
+
+def test_matches_correct_sublayers():
+ """Regex should only match layernorm_qkv, proj, fc1, fc2."""
+ regex = generate_layer_regex([1])
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv_something")
+ assert re.search(regex, "model.model.layers.1.self_attention.proj_something")
+ assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc1_something")
+ assert re.search(regex, "model.model.layers.1.layernorm_mlp.fc2_something")
+ # Should not match unrelated sublayer names
+ assert not re.search(regex, "model.model.layers.1.self_attention.some_other_thing")
+
+
+def test_none_returns_disabled_pattern():
+ """None should return a pattern that matches nothing."""
+ regex = generate_layer_regex(None)
+ assert "DISABLED" in regex
+ assert not re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+
+
+def test_empty_list_returns_disabled_pattern():
+ """Empty list should return a pattern that matches nothing."""
+ regex = generate_layer_regex([])
+ assert "DISABLED" in regex
+
+
+def test_1indexed_layer_names():
+ """Regex should use 1-indexed layer numbers (matching debug API naming)."""
+ regex = generate_layer_regex([1])
+ # Should match layers.1 (1-indexed first layer)
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ # Should NOT match layers.0 (0-indexed first layer)
+ assert not re.search(regex, "model.model.layers.0.self_attention.layernorm_qkv")
+
+
+# -- update_quant_stats_config --
+
+
+@pytest.fixture
+def fp8_only_config(tmp_path):
+ """Create an FP8-only stats config file."""
+ config = {
+ "example_fp8_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogFp8TensorStats": {
+ "enabled": True,
+ "tensors_struct": [{"tensor": "activation", "stats": ["underflows%"], "freq": 10}],
+ }
+ },
+ }
+ }
+ config_path = tmp_path / "fp8_stats.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+ return str(config_path)
+
+
+@pytest.fixture
+def fp4_fp8_config(tmp_path):
+ """Create a combined FP4+FP8 stats config file."""
+ config = {
+ "example_fp4_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogNvfp4TensorStats": {"enabled": True},
+ },
+ },
+ "example_fp8_tensor_stat_collection": {
+ "enabled": True,
+ "layers": {
+ "layer_name_regex_pattern": "PLACEHOLDER",
+ },
+ "transformer_engine": {
+ "LogFp8TensorStats": {"enabled": True},
+ },
+ },
+ }
+ config_path = tmp_path / "fp4_fp8_stats.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+ return str(config_path)
+
+
+def test_fp8_layers_updates_regex(fp8_only_config):
+ """FP8 layer list should update the regex in the output config."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+ assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
+ assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+
+
+def test_none_layers_disables_matching(fp8_only_config):
+ """None layers should set regex to match nothing."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=None)
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert "DISABLED" in regex
+
+
+def test_fp4_section_disabled_fp8_still_updated(fp4_fp8_config):
+ """FP4 stats section should be disabled (not yet supported), FP8 should still be updated."""
+ output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+
+ # FP4 section should be disabled
+ assert result["example_fp4_tensor_stat_collection"]["enabled"] is False
+
+ # FP8 regex should still match layers 4-6
+ fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj")
+ assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj")
+
+
+def test_original_file_not_modified(fp8_only_config):
+ """update_quant_stats_config should write to a temp file, not modify the original."""
+ with open(fp8_only_config) as f:
+ original_content = f.read()
+
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2])
+
+ assert output_path != fp8_only_config
+ with open(fp8_only_config) as f:
+ assert f.read() == original_content
+
+
+def test_preserves_other_config_fields(fp8_only_config):
+ """Non-layer fields in the config should be preserved."""
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ # The transformer_engine section should still be there
+ assert result["example_fp8_tensor_stat_collection"]["transformer_engine"]["LogFp8TensorStats"]["enabled"] is True
+
+
+def test_missing_section_is_skipped(fp8_only_config):
+ """If fp4 section doesn't exist in config, it should be silently skipped."""
+ # fp8_only_config has no fp4 section -- passing fp4_layers should not error
+ output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=[1, 2], fp8_layers=[3, 4])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+ # Only FP8 section should exist and be updated
+ assert "example_fp4_tensor_stat_collection" not in result
+ regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(regex, "model.model.layers.3.self_attention.layernorm_qkv")
+
+
+def test_with_real_fp4_config():
+ """Test with the actual fp4_debugging_stats.yaml file."""
+ config_path = Path(__file__).parent.parent / "fp4_debugging_stats.yaml"
+ if not config_path.exists():
+ pytest.skip("fp4_debugging_stats.yaml not found")
+
+ output_path = update_quant_stats_config(config_file=str(config_path), fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
+ with open(output_path) as f:
+ result = yaml.safe_load(f)
+
+ # FP4 section should be disabled (not yet supported in current TE release)
+ assert result["example_fp4_tensor_stat_collection"]["enabled"] is False
+
+ # FP8 section should still be updated and working
+ fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+ assert re.search(fp8_regex, "model.model.layers.5.self_attention.proj")
+ assert not re.search(fp8_regex, "model.model.layers.2.self_attention.proj")
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py
new file mode 100644
index 0000000000..440e588bb4
--- /dev/null
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_quantized_model_init.py
@@ -0,0 +1,181 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for quantized_model_init with FL1 (first/last layer BF16) and all-FP8.
+
+Verifies that:
+1. All-FP8 + qinit: all decoder layer weights are QuantizedTensors with high-precision init vals
+2. FL1 + qinit: FP8 layers have QuantizedTensor weights, BF16 layers have regular BF16 weights
+3. BF16 layers don't lose precision from an outer quantized_model_init context
+
+Parametrized across all FP8 recipes with automatic xfail for unsupported hardware
+(same pattern as conftest.py and the model-level tests).
+"""
+
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+import transformer_engine.pytorch as te
+from transformer_engine.common import recipe as recipe_module
+from transformer_engine.pytorch import fp8 as te_fp8
+from transformer_engine.pytorch.tensor import QuantizedTensor
+
+
+sys.path.append(Path(__file__).parent.parent.as_posix())
+
+from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
+
+
+# Small model config for fast testing
+_SMALL_CONFIG_KWARGS = {
+ "num_hidden_layers": 4,
+ "hidden_size": 256,
+ "intermediate_size": 512,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 4,
+ "vocab_size": 1024,
+ "max_position_embeddings": 128,
+}
+
+requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+
+# FP8 recipes with hardware support checks — unsupported recipes auto-xfail.
+_FP8_RECIPES = [
+ ("DelayedScaling", recipe_module.DelayedScaling(), te_fp8.check_fp8_support),
+ ("Float8CurrentScaling", recipe_module.Float8CurrentScaling(), te_fp8.check_fp8_support),
+ ("Float8BlockScaling", recipe_module.Float8BlockScaling(), te_fp8.check_fp8_block_scaling_support),
+ ("MXFP8BlockScaling", recipe_module.MXFP8BlockScaling(), te_fp8.check_mxfp8_support),
+]
+
+
+def _parametrize_fp8_recipes():
+ params = []
+ for name, recipe, check_fn in _FP8_RECIPES:
+ supported, reason = check_fn()
+ params.append(pytest.param(recipe, id=name, marks=pytest.mark.xfail(condition=not supported, reason=reason)))
+ return params
+
+
+fp8_recipe_fixture = pytest.fixture(params=_parametrize_fp8_recipes())
+
+
+@fp8_recipe_fixture
+def qinit_recipe(request):
+ """FP8 recipe for quantized_model_init tests, with xfail for unsupported hardware."""
+ return request.param
+
+
+def _has_quantized_weights(layer) -> bool:
+ """Check if a TE TransformerLayer has any QuantizedTensor parameters."""
+ for param in layer.parameters():
+ if isinstance(param.data, QuantizedTensor):
+ return True
+ return False
+
+
+def _has_high_precision_init_val(layer) -> bool:
+ """Check if any parameter in the layer has a high-precision init val."""
+ for param in layer.parameters():
+ if hasattr(param, "get_high_precision_init_val") and param.get_high_precision_init_val() is not None:
+ return True
+ return False
+
+
+@requires_gpu
+def test_all_fp8_qinit(qinit_recipe):
+ """All layers FP8 with quantized_model_init: all weights should be QuantizedTensors."""
+ config = NVLlamaConfig(
+ **_SMALL_CONFIG_KWARGS,
+ attn_input_format="bshd",
+ dtype=torch.bfloat16,
+ )
+
+ with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True):
+ model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe)
+
+ for i, layer in enumerate(model.model.layers):
+ assert _has_quantized_weights(layer), f"Layer {i} should have QuantizedTensor weights"
+ assert _has_high_precision_init_val(layer), f"Layer {i} should have high-precision init vals"
+
+
+@requires_gpu
+def test_fl1_qinit_bf16_layers_not_quantized(qinit_recipe):
+ """FL1 + qinit: BF16 layers (first/last) should NOT have quantized weights."""
+ layer_precision = [None, "fp8", "fp8", None]
+ config = NVLlamaConfig(
+ **_SMALL_CONFIG_KWARGS,
+ attn_input_format="bshd",
+ dtype=torch.bfloat16,
+ layer_precision=layer_precision,
+ use_quantized_model_init=True,
+ )
+
+ with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True):
+ model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe)
+
+ # BF16 layers (0 and 3, 0-indexed) should NOT have quantized weights
+ assert not _has_quantized_weights(model.model.layers[0]), "First layer (BF16) should not have QuantizedTensors"
+ assert not _has_quantized_weights(model.model.layers[3]), "Last layer (BF16) should not have QuantizedTensors"
+
+ # FP8 layers (1 and 2, 0-indexed) should have quantized weights
+ assert _has_quantized_weights(model.model.layers[1]), "FP8 layer 1 should have QuantizedTensors"
+ assert _has_quantized_weights(model.model.layers[2]), "FP8 layer 2 should have QuantizedTensors"
+
+
+@requires_gpu
+def test_fl1_qinit_fp8_layers_preserve_high_precision(qinit_recipe):
+ """FL1 + qinit: FP8 layers should preserve high-precision init vals for master weights."""
+ layer_precision = [None, "fp8", "fp8", None]
+ config = NVLlamaConfig(
+ **_SMALL_CONFIG_KWARGS,
+ attn_input_format="bshd",
+ dtype=torch.bfloat16,
+ layer_precision=layer_precision,
+ use_quantized_model_init=True,
+ )
+
+ with te.quantized_model_init(recipe=qinit_recipe, enabled=True, preserve_high_precision_init_val=True):
+ model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe)
+
+ # FP8 layers should have high-precision init values
+ assert _has_high_precision_init_val(model.model.layers[1]), "FP8 layer should have high-precision init vals"
+ assert _has_high_precision_init_val(model.model.layers[2]), "FP8 layer should have high-precision init vals"
+
+ # BF16 layers should NOT have high-precision init values (they're already BF16)
+ assert not _has_high_precision_init_val(model.model.layers[0]), (
+ "BF16 layer should not have high-precision init vals"
+ )
+ assert not _has_high_precision_init_val(model.model.layers[3]), (
+ "BF16 layer should not have high-precision init vals"
+ )
+
+
+@requires_gpu
+def test_fl1_no_qinit_baseline(qinit_recipe):
+ """FL1 without qinit: all weights should be regular BF16 tensors (baseline)."""
+ layer_precision = [None, "fp8", "fp8", None]
+ config = NVLlamaConfig(
+ **_SMALL_CONFIG_KWARGS,
+ attn_input_format="bshd",
+ dtype=torch.bfloat16,
+ layer_precision=layer_precision,
+ )
+
+ model = NVLlamaForCausalLM(config, fp8_recipe=qinit_recipe)
+
+ for i, layer in enumerate(model.model.layers):
+ assert not _has_quantized_weights(layer), f"Layer {i} should not have QuantizedTensors without qinit"
diff --git a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
index 89e85068de..7640ee68b0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
+++ b/bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
@@ -439,6 +439,59 @@ def test_sanity_fsdp2_cp(tmp_path, recipe_path):
assert torch.isfinite(torch.tensor(final_loss)), f"Final loss {final_loss} is not finite"
+def test_sanity_convergence_fsdp2_te_fused_adam(tmp_path, recipe_path):
+ """Test FSDP2 training with TE FusedAdam for FP32 master weights.
+
+ This test validates:
+ - FusedAdam optimizer initializes correctly with FSDP2-wrapped model
+ - Training converges with FP32 master weights maintained by FusedAdam
+ - FusedAdam handles FP32 master weights at the optimizer level (no MixedPrecisionPolicy needed)
+ """
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.resume_from_checkpoint=false",
+ "use_fp32_master_weights=true",
+ ],
+ )
+
+ final_loss = main_fsdp2(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
+
+
+def test_sanity_convergence_fsdp2_te_fused_adam_fp8(tmp_path, recipe_path):
+ """Test FSDP2 + FusedAdam + FP8 training.
+
+ This test validates FusedAdam works correctly alongside FP8 quantization,
+ matching the approach used in the lingua 7B MXFP8 experiment config.
+ """
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "checkpoint.resume_from_checkpoint=false",
+ "use_fp32_master_weights=true",
+ "fp8_config.enabled=true",
+ "use_sequence_packing=true",
+ "config_kwargs.attn_input_format=thd",
+ ],
+ )
+
+ final_loss = main_fsdp2(sanity_config)
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ assert final_loss < 8.0, f"Final loss {final_loss} is too high, expected < 8.0"
+
+
@requires_fp8
def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
"""Test that FP8 stats logging creates the expected log files."""
@@ -452,8 +505,8 @@ def test_sanity_ddp_fp8_stats_logging(tmp_path, recipe_path):
f"checkpoint.ckpt_dir={tmp_path}",
"+dataset.pad_sequences_to_be_divisible_by=16",
"fp8_config.enabled=true",
- "fp8_stats_config.enabled=true",
- f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={fp8_log_dir}",
"num_train_steps=4",
],
)
@@ -493,8 +546,8 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
f"checkpoint.ckpt_dir={tmp_path}",
"fp8_config.enabled=true",
"+dataset.pad_sequences_to_be_divisible_by=16",
- "fp8_stats_config.enabled=true",
- f"fp8_stats_config.fp8_log_dir={fp8_log_dir}",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={fp8_log_dir}",
"num_train_steps=4",
],
)
@@ -507,6 +560,65 @@ def test_sanity_fsdp2_fp8_stats_logging(tmp_path, recipe_path):
assert (fp8_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+@requires_fp8
+def test_sanity_ddp_fp8_partial_layers_stats_logging(tmp_path, recipe_path):
+ """Test DDP training with layer-wise FP8 stats (layers 1-3 only)."""
+ quant_log_dir = tmp_path / "quant_stats_logs"
+
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb_init_args.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "+dataset.pad_sequences_to_be_divisible_by=16",
+ "fp8_config.enabled=true",
+ "fp8_layers=[1,2,3]",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={quant_log_dir}",
+ "num_train_steps=4",
+ ],
+ )
+
+ main_ddp(sanity_config)
+
+ # Verify the log directory structure was created
+ assert quant_log_dir.exists(), "Quant log directory was not created"
+ assert (quant_log_dir / "rank_0").exists(), "rank_0 directory was not created"
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs").exists(), "nvdlfw_inspect_logs directory was not created"
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), (
+ "nvdlfw_inspect_statistics_logs directory was not created"
+ )
+
+
+@requires_fp8
+def test_sanity_fsdp2_fp8_partial_layers_stats_logging(tmp_path, recipe_path):
+ """Test FSDP2 training with layer-wise FP8 stats (layers 1-3 only)."""
+ quant_log_dir = tmp_path / "quant_stats_logs"
+
+ with initialize_config_dir(config_dir=str(recipe_path / "hydra_config"), version_base="1.2"):
+ sanity_config = compose(
+ config_name="L0_sanity",
+ overrides=[
+ f"+wandb_init_args.dir={tmp_path}",
+ f"checkpoint.ckpt_dir={tmp_path}",
+ "+dataset.pad_sequences_to_be_divisible_by=16",
+ "fp8_config.enabled=true",
+ "fp8_layers=[1,2,3]",
+ "quant_stats_config.enabled=true",
+ f"quant_stats_config.quant_log_dir={quant_log_dir}",
+ "num_train_steps=4",
+ ],
+ )
+
+ main_fsdp2(sanity_config)
+
+ # Verify log structure
+ assert quant_log_dir.exists()
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+ assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+
+
def run_train_cmd(cmd, recipe_path):
"""Run a training command and check for errors.
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
index 413b9262c7..7d589045f1 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_ddp.py
@@ -42,9 +42,9 @@
from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
from dataset import create_bshd_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
-from fp8_debugging import initialize_fp8_debugging
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -66,37 +66,78 @@ def main(args: DictConfig) -> float | None:
torch.distributed.init_process_group(backend="nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
- # TE Debug feature logging
- if args.fp8_stats_config.enabled:
- initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
-
# Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2.
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
+ if args.use_te:
+ config_class = NVLlamaConfig
+ model_class = NVLlamaForCausalLM
+ else:
+ config_class = LlamaConfig
+ model_class = LlamaForCausalLM
+
# --- Model Configuration ---
- # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
+ config = config_class.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
fp8_recipe = None
+ fp4_recipe = None
if args.fp8_config.enabled:
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
-
- fp4_recipe = None
if args.fp4_config.enabled:
- fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
+
+ if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not (
+ args.fp8_config.enabled or args.fp4_config.enabled
+ ):
+ raise ValueError(
+ "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or "
+ "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization."
+ )
# --- Model Initialization ---
- if args.use_te:
- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
- model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
- else:
- config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
- model = LlamaForCausalLM(config)
+ # Optionally use transformer engine to initialize only fp8 versions of weights by setting
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
+ with transformer_engine.pytorch.quantized_model_init(
+ recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
+ ):
+ model = (
+ model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ if model_class is NVLlamaForCausalLM
+ else model_class(config)
+ )
logger.info("Initialized Model:\n%s", model)
+ # Attach quantization recipes to the model (layer precision is already on config).
+ if isinstance(model, NVLlamaForCausalLM):
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
# --- Distributed Wrapping (DDP) ---
- if args.fp8_stats_config.enabled:
+ if args.quant_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
model = model.to(device=device)
@@ -157,9 +198,8 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
# DDP requires no_sync to skip all-reduce until the last microbatch in the accumulation window.
with model.no_sync() if micro_step % args.grad_acc_steps != 0 else nullcontext():
- # Forward pass with mixed precision.
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ # Forward pass - quantization autocast is handled inside the model via set_recipes().
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
index da19daa2a7..d7462f7bd0 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
@@ -34,8 +34,10 @@
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor import DTensor
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
+from transformer_engine.pytorch.optimizers import FusedAdam
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
@@ -48,9 +50,9 @@
)
from dataset import create_bshd_dataloader, create_thd_dataloader
from distributed_config import DistributedConfig
-from fp8_debugging import initialize_fp8_debugging
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -58,6 +60,43 @@
logger.setLevel(logging.INFO)
+def _init_master_weights_from_high_precision(
+ optimizer: FusedAdam, model: torch.nn.Module, device: torch.device
+) -> None:
+ """Initialize optimizer master weights from high-precision init values.
+
+ When quantized_model_init is used with preserve_high_precision_init_val=True, each FP8 parameter
+ stores the original BF16 init values in CPU memory. This function initializes optimizer state
+ for all parameters, then overwrites master weights for quantized params with the preserved
+ high-precision values instead of dequantized FP8 values.
+
+ Follows the TE example:
+ https://github.com/NVIDIA/TransformerEngine/blob/main/examples/pytorch/quantized_model_init/fully_shard.py
+ """
+ count = 0
+ for name, param in model.named_parameters():
+ # Eagerly initialize optimizer state for all parameters.
+ # TE main's FusedAdam handles DTensor + QuantizedTensor natively.
+ optimizer.initialize_state(param, store_param_remainders=False)
+
+ # For quantized params, overwrite master weights with the preserved high-precision
+ # init values (instead of the dequantized FP8 values set by initialize_state).
+ local = param._local_tensor if isinstance(param, DTensor) else param
+ if hasattr(local, "get_high_precision_init_val"):
+ hp_val = local.get_high_precision_init_val()
+ if hp_val is not None:
+ optimizer.set_scaled_state(param, "master_param", hp_val.to(device=device, dtype=torch.float32))
+ local.clear_high_precision_init_val()
+ count += 1
+ logger.debug("Seeded master weight for %s from high-precision init val", name)
+ if count > 0:
+ logger.info("Initialized %d master weight(s) from high-precision init values", count)
+ else:
+ logger.info(
+ "No parameters with high-precision init values found (quantized_model_init may not have been used)"
+ )
+
+
@hydra.main(config_path="hydra_config", config_name="L0_sanity", version_base="1.2")
def main(args: DictConfig) -> float | None:
"""Train Llama3 with TE layers using FSDP2.
@@ -72,33 +111,75 @@ def main(args: DictConfig) -> float | None:
torch.distributed.init_process_group(backend="cpu:gloo,cuda:nccl", device_id=device)
torch.cuda.set_device(dist_config.local_rank)
- # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
- if args.fp8_stats_config.enabled:
- initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
-
device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("dp",))
+ if args.use_te:
+ config_class = NVLlamaConfig
+ model_class = NVLlamaForCausalLM
+ else:
+ config_class = LlamaConfig
+ model_class = LlamaForCausalLM
+
# --- Model Configuration ---
- # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
+ config = config_class.from_pretrained(
+ args.config_name_or_path,
+ dtype=torch.bfloat16,
+ **args.config_kwargs,
+ )
+
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
fp8_recipe = None
+ fp4_recipe = None
if args.fp8_config.enabled:
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
-
- fp4_recipe = None
if args.fp4_config.enabled:
- fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
+
+ if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not (
+ args.fp8_config.enabled or args.fp4_config.enabled
+ ):
+ raise ValueError(
+ "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or "
+ "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization."
+ )
# --- Model Initialization ---
- if args.use_te:
- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
- with torch.device("meta") if args.use_meta_device else nullcontext():
- model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
- else:
- config = LlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
- with torch.device("meta") if args.use_meta_device else nullcontext():
- model = LlamaForCausalLM(config)
+ # Optionally use transformer engine to initialize only fp8 versions of weights by setting
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
+ with (
+ torch.device("meta") if args.use_meta_device else nullcontext(),
+ transformer_engine.pytorch.quantized_model_init(
+ recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
+ ),
+ ):
+ model = (
+ model_class(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+ if model_class is NVLlamaForCausalLM
+ else model_class(config)
+ )
logger.info("Initialized Model:\n%s", model)
@@ -108,6 +189,10 @@ def main(args: DictConfig) -> float | None:
fully_shard(layer, mesh=device_mesh["dp"])
fully_shard(model, mesh=device_mesh["dp"])
+ # Attach quantization recipes to the model (layer precision is already on config).
+ if isinstance(model, NVLlamaForCausalLM):
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
# If we're using meta device, we need to move sharded weights to the cuda device and initialize the parameters.
if args.use_meta_device:
if args.use_te:
@@ -118,12 +203,20 @@ def main(args: DictConfig) -> float | None:
model.apply(model._init_weights)
# Assign names to layers so debug API can identify them
- if args.fp8_stats_config.enabled:
+ if args.quant_stats_config.enabled:
debug_api.infer_and_assign_layer_names(model)
# --- Optimizer & Scheduler ---
# Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
- optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
+ adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True)
+ if args.use_fp32_master_weights:
+ # TE FusedAdam maintains FP32 master copies of BF16 params internally.
+ # 'fused' kwarg is not used by TE's FusedAdam (it's always fused).
+ adamw_kwargs.pop("fused", None)
+ optimizer = FusedAdam(model.parameters(), master_weights=True, **adamw_kwargs) # type: ignore
+ logger.info("Using TE FusedAdam with FP32 master weights")
+ else:
+ optimizer = AdamW(model.parameters(), **adamw_kwargs) # type: ignore
scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
if args.use_torch_compile:
@@ -140,21 +233,32 @@ def main(args: DictConfig) -> float | None:
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
if args.checkpoint.resume_from_checkpoint and ckpt_path:
logger.info("Attempting to load checkpoint from %s", ckpt_path)
- model, optimizer, scheduler, train_dataloader, start_step, epoch = load_checkpoint_fsdp2(
+ model, optimizer, scheduler, _dl, start_step, epoch = load_checkpoint_fsdp2(
model=model,
optimizer=optimizer,
scheduler=scheduler,
ckpt_path=ckpt_path,
dist_config=dist_config,
- dataloader=train_dataloader,
+ dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
process_group=device_mesh.get_group("dp"),
)
+ if _dl is not None:
+ train_dataloader = _dl
logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
else:
logger.info("No checkpoint to load, starting from scratch")
start_step = 0
epoch = 0
+ # When starting from scratch with quantized_model_init + preserve_high_precision_init_val,
+ # seed FP32 master weights from the original high-precision init values (not dequantized FP8).
+ # Skip on resume — checkpoint already has correct master weights, and eager dequantize() can
+ # invalidate QuantizedTensor storage causing FSDP2 forward failures.
+ if args.use_fp32_master_weights and args.fp8_config.quantized_model_init_kwargs.get(
+ "preserve_high_precision_init_val", False
+ ):
+ _init_master_weights_from_high_precision(optimizer, model, device)
+
perf_logger = PerfLogger(dist_config, args, start_step=start_step)
gc.collect()
@@ -170,9 +274,8 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
- # Forward pass with mixed precision.
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ # Forward pass - quantization autocast is handled inside the model via set_recipes().
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
@@ -220,6 +323,7 @@ def main(args: DictConfig) -> float | None:
# Dataloader exhausted, incrementing epoch
epoch += 1
+ logger.warning("Dataloader exhausted at step %s, incrementing epoch to %s", step, epoch)
dataset_or_sampler.set_epoch(epoch)
# --- Cleanup ---
diff --git a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
index eaf1a1b39f..e9db3f2db6 100644
--- a/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
+++ b/bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
@@ -29,14 +29,17 @@
from pathlib import Path
import hydra
+import nvdlfw_inspect.api as debug_api
import nvtx
import torch
import transformer_engine.pytorch
from omegaconf import DictConfig, OmegaConf
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
+from torch.distributed.tensor import DTensor
from torch.optim import AdamW
from transformer_engine.common.recipe import Format
+from transformer_engine.pytorch.optimizers import FusedAdam
from checkpoint import (
_ckpt_futures,
@@ -50,6 +53,7 @@
from distributed_config import DistributedConfig
from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
from perf_logger import PerfLogger
+from quantization import initialize_quant_stats_logging, resolve_layer_precision
from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -57,6 +61,31 @@
logger.setLevel(logging.INFO)
+def _init_master_weights_from_high_precision(
+ optimizer: FusedAdam, model: torch.nn.Module, device: torch.device
+) -> None:
+ """Initialize optimizer master weights from high-precision init values.
+
+ When quantized_model_init is used with preserve_high_precision_init_val=True, each FP8 parameter
+ stores the original BF16 init values in CPU memory. This function initializes optimizer state
+ for all parameters, then overwrites master weights for quantized params with the preserved
+ high-precision values instead of dequantized FP8 values.
+ """
+ count = 0
+ for name, param in model.named_parameters():
+ optimizer.initialize_state(param, store_param_remainders=False)
+ local = param._local_tensor if isinstance(param, DTensor) else param
+ if hasattr(local, "get_high_precision_init_val"):
+ hp_val = local.get_high_precision_init_val()
+ if hp_val is not None:
+ optimizer.set_scaled_state(param, "master_param", hp_val.to(device=device, dtype=torch.float32))
+ local.clear_high_precision_init_val()
+ count += 1
+ logger.debug("Seeded master weight for %s from high-precision init val", name)
+ if count > 0:
+ logger.info("Initialized %d master weight(s) from high-precision init values", count)
+
+
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
def main(args: DictConfig) -> float | None:
"""Train Llama3 with TE layers using FSDP2 with Context Parallelism.
@@ -79,21 +108,60 @@ def main(args: DictConfig) -> float | None:
logger.info("Created device mesh: %s", device_mesh)
# --- Model Configuration ---
- # Create quantization recipes -- only used if FP8/FP4 is enabled in the config.
+ config = NVLlamaConfig.from_pretrained(
+ args.config_name_or_path,
+ dtype=torch.bfloat16,
+ **args.config_kwargs,
+ )
+
+ # Resolve layer-wise quantization assignments and store on config.
+ layer_precision = resolve_layer_precision(
+ num_layers=config.num_hidden_layers,
+ fp8_enabled=args.fp8_config.enabled,
+ fp4_enabled=args.fp4_config.enabled,
+ fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
+ fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
+ )
+ config.layer_precision = layer_precision
+
+ if args.quant_stats_config.enabled:
+ initialize_quant_stats_logging(
+ quant_stats_file=args.quant_stats_config.quant_stats_file,
+ quant_log_dir=args.quant_stats_config.quant_log_dir,
+ rank=dist_config.rank,
+ layer_precision=layer_precision,
+ )
+
+ # Create quantization recipes -- these are only used if FP8/FP4 is enabled in the config.
fp8_recipe = None
+ fp4_recipe = None
if args.fp8_config.enabled:
fp8_recipe = hydra.utils.get_class(args.fp8_config.fp8_recipe)(
fp8_format=Format[args.fp8_config.fp8_format], **args.fp8_config.fp8_recipe_kwargs
)
-
- fp4_recipe = None
if args.fp4_config.enabled:
- fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(**args.fp4_config.fp4_recipe_kwargs)
+ fp4_recipe = hydra.utils.get_class(args.fp4_config.fp4_recipe)(
+ fp4_format=Format[args.fp4_config.fp4_format], **args.fp4_config.fp4_recipe_kwargs
+ )
- # --- Model Initialization ---
- config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
+ if args.fp8_config.quantized_model_init_kwargs.get("enabled", False) and not (
+ args.fp8_config.enabled or args.fp4_config.enabled
+ ):
+ raise ValueError(
+ "fp8_config.quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true or "
+ "fp4_config.enabled=true. Enable at least one quantization format to use quantized model initialization."
+ )
- with torch.device("meta") if args.use_meta_device else nullcontext():
+ # --- Model Initialization ---
+ # Optionally use transformer engine to initialize only fp8 versions of weights by setting
+ # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
+ # and fp8 versions of weights are kept.
+ with (
+ torch.device("meta") if args.use_meta_device else nullcontext(),
+ transformer_engine.pytorch.quantized_model_init(
+ recipe=fp8_recipe, **args.fp8_config.quantized_model_init_kwargs
+ ),
+ ):
model = NVLlamaForCausalLM(config, fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
logger.info("Initialized Model:\n%s", model)
@@ -115,13 +183,28 @@ def main(args: DictConfig) -> float | None:
torch.cuda.Stream(),
)
+ # Attach quantization recipes to the model (layer precision is already on config).
+ model.model.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
if args.use_meta_device:
# TE layers require special handling to initialize the weights from the meta device.
model.init_empty_weights()
+ # Assign names to layers so debug API can identify them
+ if args.quant_stats_config.enabled:
+ debug_api.infer_and_assign_layer_names(model)
+
# --- Optimizer & Scheduler ---
# Convert OmegaConf to regular dict to avoid serialization issues (BIONEMO-2873).
- optimizer = AdamW(model.parameters(), **OmegaConf.to_container(args.adamw_kwargs, resolve=True)) # type: ignore
+ adamw_kwargs = OmegaConf.to_container(args.adamw_kwargs, resolve=True)
+ if args.use_fp32_master_weights:
+ # TE FusedAdam maintains FP32 master copies of BF16 params internally.
+ # 'fused' kwarg is not used by TE's FusedAdam (it's always fused).
+ adamw_kwargs.pop("fused", None)
+ optimizer = FusedAdam(model.parameters(), master_weights=True, **adamw_kwargs) # type: ignore
+ logger.info("Using TE FusedAdam with FP32 master weights")
+ else:
+ optimizer = AdamW(model.parameters(), **adamw_kwargs) # type: ignore
scheduler = get_cosine_annealing_schedule_with_warmup(optimizer, **args.lr_scheduler_kwargs)
if args.use_torch_compile:
@@ -177,6 +260,11 @@ def main(args: DictConfig) -> float | None:
start_step = 0
epoch = 0
+ if args.use_fp32_master_weights and args.fp8_config.quantized_model_init_kwargs.get(
+ "preserve_high_precision_init_val", False
+ ):
+ _init_master_weights_from_high_precision(optimizer, model, device)
+
perf_logger = PerfLogger(dist_config, args, start_step=start_step)
gc.collect()
@@ -192,10 +280,9 @@ def main(args: DictConfig) -> float | None:
micro_step += 1
- # Forward pass with mixed precision.
+ # Forward pass - quantization autocast is handled inside the model via set_recipes().
with nvtx.annotate("Forward pass", color="green"):
- with transformer_engine.pytorch.autocast(enabled=args.fp8_config.enabled, recipe=fp8_recipe):
- outputs = model(**batch)
+ outputs = model(**batch)
# Backward pass - scale loss by grad_acc_steps for proper gradient averaging
loss = outputs.loss / args.grad_acc_steps
diff --git a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
index 62171cd237..0bdb8a23e8 100644
--- a/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
+++ b/bionemo-recipes/recipes/opengenome2_llama_native_te/modeling_llama_te.py
@@ -58,6 +58,7 @@ class NVLlamaConfig(LlamaConfig):
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
attn_input_format: str = "thd"
self_attn_mask_type: str = "padding_causal"
+ layer_precision: list[str | None] | None = None
def __init__(
self,
@@ -223,11 +224,54 @@ def _init_method(x):
self.rotary_emb = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.rotary_emb.inv_freq = LlamaRotaryEmbedding(config=config).inv_freq
+ self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+ self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
+
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
+ def set_recipes(
+ self,
+ fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+ ) -> None:
+ """Attach quantization recipe objects for per-layer autocast.
+
+ Recipes are not serializable and must be set at runtime after model creation
+ and sharding (FSDP/DDP) but before training. The per-layer precision
+ assignments are read from ``self.config.layer_precision``.
+
+ Args:
+ fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+ fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+ """
+ self._fp8_recipe = fp8_recipe
+ self._fp4_recipe = fp4_recipe
+
+ def get_layer_autocast(self, layer_number: int):
+ """Return the appropriate TE autocast context manager for a given layer.
+
+ The context interacts with the outer FP8 autocast in the training script:
+ - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+ - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+ - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+ Args:
+ layer_number: The 0-indexed layer number.
+
+ Returns:
+ A context manager for the layer's quantization mode.
+ """
+ precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+ if precision == "fp8":
+ return nullcontext()
+ elif precision == "fp4":
+ return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+ else:
+ return transformer_engine.pytorch.autocast(enabled=False)
+
def forward(
self,
input_ids: torch.Tensor | None = None,
@@ -304,12 +348,14 @@ def forward(
if te_rope_emb.dtype != torch.float32:
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
- with self.get_autocast_context(None, outer=True):
- for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+ # Outer FP8 autocast enables FP8 compute for the decoder stack. Per-layer overrides (FP4, BF16) are handled
+ # by get_layer_autocast(), which nests inside this context.
+ with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+ for layer_number, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if output_hidden_states:
all_hidden_states = (*all_hidden_states, hidden_states)
- with self.get_autocast_context(layer_idx):
+ with self.get_layer_autocast(layer_number):
hidden_states = decoder_layer(
hidden_states,
attention_mask=None if self.config.attn_input_format == "thd" else attention_mask,
@@ -369,8 +415,10 @@ def get_autocast_context(
if init and self.config.use_quantized_model_init:
if precision in ("fp8", "fp4"):
- return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
- return nullcontext()
+ return transformer_engine.pytorch.quantized_model_init(
+ recipe=recipe, preserve_high_precision_init_val=True
+ )
+ return transformer_engine.pytorch.quantized_model_init(enabled=False)
if precision == "fp8":
if recipe is None:
@@ -589,6 +637,11 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
return 0
return max(self.sequences.values())
+ @property
+ def is_compileable(self) -> bool:
+ """Required by HuggingFace transformers generate() auto-compile check."""
+ return False
+
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorder the cache based on the beam indices."""
if isinstance(self.cache_manager, PagedKVCacheManager):
@@ -597,8 +650,3 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_key_cache = key_cache.index_select(0, beam_idx)
updated_value_cache = value_cache.index_select(0, beam_idx)
self.cache_manager.cache[layer_number] = (updated_key_cache, updated_value_cache)
-
- @property
- def is_compileable(self) -> bool:
- """Return False as this cache is not compatible with torch.compile."""
- return False
diff --git a/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png b/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png
new file mode 100644
index 0000000000..a782e166a1
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-1b-loss-curve.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-1b-step-time.png b/docs/docs/assets/images/llama3/lingua-1b-step-time.png
new file mode 100644
index 0000000000..2b641b539d
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-1b-step-time.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png
new file mode 100644
index 0000000000..ef5721e578
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-1node.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png
new file mode 100644
index 0000000000..e8214ccfea
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-70b-mxfp8-multinode.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png b/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png
new file mode 100644
index 0000000000..a2afdbb77b
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-7b-mxfp8-multinode.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png b/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png
new file mode 100644
index 0000000000..0e0bb9383c
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-8b-mxfp8-1node.png differ
diff --git a/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png b/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png
new file mode 100644
index 0000000000..8bdb78070e
Binary files /dev/null and b/docs/docs/assets/images/llama3/lingua-8b-vs-70b-mxfp8-uplift.png differ
diff --git a/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png
new file mode 100644
index 0000000000..a304fa29b7
Binary files /dev/null and b/docs/docs/assets/images/llama3/llama3_8gpu_tflops.png differ