From eaa953b5b695a00bbf5000667e350c8022a0e522 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 7 May 2026 03:54:57 +0000 Subject: [PATCH 1/2] Fix bugs for MSE Signed-off-by: Chenjie Luo --- modelopt/torch/export/moe_utils.py | 16 ++- modelopt/torch/export/unified_export_hf.py | 19 +++ modelopt/torch/quantization/model_calib.py | 130 +++++++++++++++--- .../nn/modules/tensor_quantizer.py | 16 +++ .../torch/quantization/utils/core_utils.py | 8 +- 5 files changed, 163 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 8981d614843..1472274a106 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -110,11 +110,25 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: and w_quantizer._amax.dim() >= 1 ): amax = w_quantizer._amax + # Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep) + # produces a per-block _amax with shape (num_blocks_total, ...) + # where num_blocks_total = fused_total * blocks_per_row. That + # shape collapses the row axis we want to slice on. Restore the + # row dimension so the dim-0 slicing below splits gate / up + # correctly. No-op when _amax is already aligned with fused_total. + if amax.numel() != fused_total and amax.numel() % fused_total == 0: + amax = amax.contiguous().view(fused_total, amax.numel() // fused_total) amax_dim0 = amax.shape[0] if fused_total % amax_dim0 == 0: slice_start = fused_start * amax_dim0 // fused_total slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total - w_quantizer.amax = amax[slice_start:slice_end].contiguous() + sliced = amax[slice_start:slice_end].contiguous() + # The amax setter refuses shape changes once `_amax` exists, + # so drop the existing buffer before re-registering with the + # sliced shape. + if hasattr(w_quantizer, "_amax"): + delattr(w_quantizer, "_amax") + w_quantizer.amax = sliced else: warnings.warn( f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not " diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a58aa4c9895..0b2acb99d99 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None: mod.revert_weight_conversion = original +def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None: + """Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set. + + Newer transformers reject ``do_sample=False`` mixed with sampling attrs in + ``save_pretrained``'s strict validate. + """ + gc = getattr(model, "generation_config", None) + if gc is None: + return + if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None: + gc.do_sample = True + + def export_speculative_decoding( model: torch.nn.Module, dtype: torch.dtype | None = None, @@ -1228,6 +1241,12 @@ def export_hf_checkpoint( # modeling_utils does `from core_model_loading import revert_weight_conversion`. _patches = _patch_revert_weight_conversion() + # Some upstream HF checkpoints ship a generation_config.json that fails + # transformers' strict validation on save (e.g. ``top_p`` set without + # ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to + # the sampling-attrs intent so save_pretrained can write the file. 
+ _sanitize_generation_config_for_save(model) + try: model.save_pretrained( export_dir, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fe4c3f77ce6..fc64a5c780f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -64,8 +64,98 @@ "max_calibrate", "smoothquant", "svdquant", + "sync_grouped_weight_global_amax", ] + +# Sibling weight-quantizer name groups whose ``global_amax`` should share an +# FP8 scale-of-scales. All members of a group sit under the same parent module +# (e.g. one self-attention or one MLP block) and either consume the same input +# tensor or get fused at deployment, so a divergent global_amax across siblings +# would split their FP8 grids and skew the round. +_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( + # Standard self-attention (skipped for fused qkv_proj — single weight). + ("q_proj", "k_proj", "v_proj"), + # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.). + ("gate_proj", "up_proj"), + # Gated MLP, older Mixtral-style naming. + ("w1", "w3"), +) + + +def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool: + """True for an NVFP4-static weight quantizer that ``max_calibrate`` already + populated with a per-block ``_amax`` and that is currently enabled. + """ + return ( + isinstance(q, TensorQuantizer) + and not q._disabled + and q.is_nvfp4_static + and hasattr(q, "_amax") + and q._amax is not None + ) + + +def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: + """Find groups of Linear-like submodules whose NVFP4-static weight quantizers + should share ``global_amax`` (Q/K/V under one attention parent; gate/up under + one MLP parent). + """ + groups: list[list[nn.Module]] = [] + wq_attr = quantizer_attr_names("weight").weight_quantizer + for parent in model.modules(): + for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: + members: list[nn.Module] = [] + for n in sibling_names: + child = getattr(parent, n, None) + if child is None: + continue + wq = getattr(child, wq_attr, None) + if _is_calibrated_nvfp4_static_weight_quantizer(wq): + members.append(child) + if len(members) >= 2: + groups.append(members) + return groups + + +@torch.no_grad() +def sync_grouped_weight_global_amax(model: nn.Module) -> int: + """Sync ``global_amax`` across sibling NVFP4-static weight quantizers. + + For each group of siblings (Q/K/V projections under one attention parent; + gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the + NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a + consistent FP8 grid across the group during MSE / local-Hessian search. + + Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion` + (whose ``NVFP4StaticQuantizer`` branch performs the same + ``max(stack(global_amax))`` unification at export time). To call it before + MSE, this helper first promotes each grouped weight quantizer to + :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (= + ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in + place. + + Must be called after ``max_calibrate`` has populated each weight + quantizer's ``_amax``. Idempotent. Returns the number of groups synced. 
+ """ + from modelopt.torch.export.quant_utils import preprocess_linear_fusion + + n_groups = 0 + for group in _collect_grouped_linears(model): + # Promote each member's weight quantizer so `preprocess_linear_fusion` + # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads + # `global_amax`, which only exists post-promotion). + wq_attr = quantizer_attr_names("weight").weight_quantizer + for child in group: + wq = getattr(child, wq_attr) + if not isinstance(wq, NVFP4StaticQuantizer): + local_global = reduce_amax(wq._amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global) + preprocess_linear_fusion(group) + n_groups += 1 + return n_groups + + CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -349,6 +439,13 @@ def mse_calibrate( # Step 1: First get initial amax using max calibration max_calibrate(model, forward_loop, distributed_sync) + # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers + # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one + # MLP block) so their FP8 scale-of-scales matches and the per-block FP8 + # round uses a consistent grid. No-op when there are no sibling groups + # (e.g. fused QKV / fused gate_up_proj). + sync_grouped_weight_global_amax(model) + # Step 2: Replace calibrators with MseCalibrator for enabled quantizers # and identify weight quantizers weight_quantizers = [] @@ -360,19 +457,16 @@ def mse_calibrate( # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = module.is_nvfp4_static if is_nvfp4_static: - # Compute and set global_amax - global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + # If sync_grouped_weight_global_amax already promoted this + # quantizer (it's a sibling in a Q/K/V or gate/up group), + # its global_amax has been unified across the group; just + # leave it. Otherwise convert + set local global_amax. + if not isinstance(module, NVFP4StaticQuantizer): + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if fp8_scale_sweep: # Check if backend has a registered custom calibrator factory. @@ -612,6 +706,11 @@ def forward(self, input, *args, **kwargs): print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) + # Sync global_amax across sibling NVFP4-static weight quantizers + # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales + # is consistent across the group. Idempotent; no-op when fused. 
+ sync_grouped_weight_global_amax(model) + # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) weight_quantizers_info = [] @@ -666,14 +765,9 @@ def quant_func(x, amax, quantizer=weight_quantizer): return xq - is_nvfp4_static = ( - weight_quantizer.is_static_block_quant - and weight_quantizer._num_bits == (2, 1) - and weight_quantizer._block_sizes is not None - and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = weight_quantizer.is_nvfp4_static - if is_nvfp4_static: + if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..12649691453 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -514,6 +514,22 @@ def is_mx_format(self): and self.block_sizes.get("scale_bits", None) == (8, 0) ) + @property + def is_nvfp4_static(self): + """Check if this quantizer is configured for NVFP4 static block quantization. + + Format-only check (does not consider whether ``_amax`` has been + populated by calibration). True when the quantizer holds E2M1 weights + with E4M3 per-block scales in a static layout — i.e. the two-level + scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`. + """ + return ( + self.is_static_block_quant + and self._num_bits == (2, 1) + and self._block_sizes is not None + and self._block_sizes.get("scale_bits") == (4, 3) + ) + def is_mxfp(self, bits): """Check if is MXFP4/MXFP6/MXFP8.""" if bits == 4: diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 1a177e04dc8..cea3d4260e4 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: + if module.is_nvfp4_static: initial_amax = module._amax.clone().detach() global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) From ff014782e158a97ea208a6a45246af68b823b6e5 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 7 May 2026 17:10:02 +0000 Subject: [PATCH 2/2] [Quantization] MSE-calibrate every per-expert weight in fused-experts MoE Two-part fix for transformers 5.x fused-experts containers (Qwen3-MoE / Qwen3.5-MoE / Mixtral / DeepSeek / Kimi-K2.x ...) where weight quantizers live in `nn.ModuleList`s (`gate_up_proj_weight_quantizers`, `down_proj_weight_quantizers`): 1. Add `_QuantFusedExperts.iter_weights_for_calibration` that yields per-expert (weight_slice, quantizer) pairs for both projections. The base impl uses singular `*_weight_quantizer` and silently skips fused-experts modules, so weight-only calibration paths never reach per-expert quantizers. 2. 
Refactor `mse_calibrate`: - Add `_bootstrap_uncalibrated_weight_quantizers` after `max_calibrate` to populate `_amax` on quantizers the forward pass didn't reach (dead MoE experts that received no calibration tokens). Runs the existing calibrator on the weight slice surfaced by `iter_weights_for_calibration`. - Replace the singular-only `weight_attr_names` discovery + `getattr`-by- name walk with an `iter_weights_for_calibration` walk done inside each parent module's `enable_weight_access_and_writeback` context, so MSE processes every per-expert quantizer (active and dead) and remains FSDP-safe. Without this, the export-time fallback in `_export_fused_experts` derived separate gate/up amaxes from each half of the fused weight, breaking the gate==up `weight_scale_2` invariant on dead experts. End-to-end check on Qwen3.5-122B-A10B with `nvfp4_experts_only_mse-fp8_cast_kv`: - Before: 1/12288 (layer 38 expert 69) gate \!= up; 0 weights MSE-calibrated - After: 0/12288 mismatches; 24576 weights MSE-calibrated; ~4.2 min Signed-off-by: Chenjie Luo --- modelopt/torch/export/moe_utils.py | 12 +- modelopt/torch/export/unified_export_hf.py | 4 - modelopt/torch/quantization/model_calib.py | 238 +++++++++--------- .../nn/modules/tensor_quantizer.py | 8 +- .../torch/quantization/plugins/huggingface.py | 18 ++ .../plugins/test_fused_experts.py | 199 +++++++++++++++ 6 files changed, 335 insertions(+), 144 deletions(-) diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 1472274a106..e325e5346f1 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -110,12 +110,8 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: and w_quantizer._amax.dim() >= 1 ): amax = w_quantizer._amax - # Static block-quant calibration (e.g. NVFP4 MSE FP8 sweep) - # produces a per-block _amax with shape (num_blocks_total, ...) - # where num_blocks_total = fused_total * blocks_per_row. That - # shape collapses the row axis we want to slice on. Restore the - # row dimension so the dim-0 slicing below splits gate / up - # correctly. No-op when _amax is already aligned with fused_total. + # Per-block _amax (NVFP4 static) collapses the row axis we want + # to slice on; restore it so dim-0 slicing splits gate/up. if amax.numel() != fused_total and amax.numel() % fused_total == 0: amax = amax.contiguous().view(fused_total, amax.numel() // fused_total) amax_dim0 = amax.shape[0] @@ -123,9 +119,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: slice_start = fused_start * amax_dim0 // fused_total slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total sliced = amax[slice_start:slice_end].contiguous() - # The amax setter refuses shape changes once `_amax` exists, - # so drop the existing buffer before re-registering with the - # sliced shape. + # The amax setter refuses shape changes; drop _amax first. if hasattr(w_quantizer, "_amax"): delattr(w_quantizer, "_amax") w_quantizer.amax = sliced diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 0b2acb99d99..73ae63a5a56 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1241,10 +1241,6 @@ def export_hf_checkpoint( # modeling_utils does `from core_model_loading import revert_weight_conversion`. 
_patches = _patch_revert_weight_conversion() - # Some upstream HF checkpoints ship a generation_config.json that fails - # transformers' strict validation on save (e.g. ``top_p`` set without - # ``do_sample=True`` — newer transformers raises). Flip ``do_sample`` to - # the sampling-attrs intent so save_pretrained can write the file. _sanitize_generation_config_for_save(model) try: diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fc64a5c780f..bce49786077 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -52,7 +52,6 @@ promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, - weight_attr_names, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -64,93 +63,104 @@ "max_calibrate", "smoothquant", "svdquant", - "sync_grouped_weight_global_amax", ] -# Sibling weight-quantizer name groups whose ``global_amax`` should share an -# FP8 scale-of-scales. All members of a group sit under the same parent module -# (e.g. one self-attention or one MLP block) and either consume the same input -# tensor or get fused at deployment, so a divergent global_amax across siblings -# would split their FP8 grids and skew the round. -_GROUPED_WEIGHT_QUANTIZER_PATTERNS: tuple[tuple[str, ...], ...] = ( - # Standard self-attention (skipped for fused qkv_proj — single weight). - ("q_proj", "k_proj", "v_proj"), - # Gated MLP, modern naming (Llama / Qwen / Mistral / etc.). - ("gate_proj", "up_proj"), - # Gated MLP, older Mixtral-style naming. - ("w1", "w3"), -) - - -def _is_calibrated_nvfp4_static_weight_quantizer(q) -> bool: - """True for an NVFP4-static weight quantizer that ``max_calibrate`` already - populated with a per-block ``_amax`` and that is currently enabled. - """ +def _is_calibrated_nvfp4_static(q) -> bool: + """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" return ( isinstance(q, TensorQuantizer) and not q._disabled and q.is_nvfp4_static - and hasattr(q, "_amax") - and q._amax is not None + and getattr(q, "_amax", None) is not None ) def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: - """Find groups of Linear-like submodules whose NVFP4-static weight quantizers - should share ``global_amax`` (Q/K/V under one attention parent; gate/up under - one MLP parent). - """ + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" + # Inline: layer_utils → quant_utils → model_calib cycle. + from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS + + # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant + # in export). Single source for the gate/up half avoids parallel lists. + patterns: tuple[tuple[str, ...], ...] 
= (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) groups: list[list[nn.Module]] = [] wq_attr = quantizer_attr_names("weight").weight_quantizer for parent in model.modules(): - for sibling_names in _GROUPED_WEIGHT_QUANTIZER_PATTERNS: - members: list[nn.Module] = [] - for n in sibling_names: - child = getattr(parent, n, None) - if child is None: - continue - wq = getattr(child, wq_attr, None) - if _is_calibrated_nvfp4_static_weight_quantizer(wq): - members.append(child) + for sibling_names in patterns: + members = [ + child + for child in (getattr(parent, n, None) for n in sibling_names) + if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) + ] if len(members) >= 2: groups.append(members) return groups @torch.no_grad() -def sync_grouped_weight_global_amax(model: nn.Module) -> int: - """Sync ``global_amax`` across sibling NVFP4-static weight quantizers. - - For each group of siblings (Q/K/V projections under one attention parent; - gate/up — a.k.a. ``w1``/``w3`` — under one MLP parent) unifies the - NVFP4 ``global_amax`` so the per-block FP8 round picks scales against a - consistent FP8 grid across the group during MSE / local-Hessian search. - - Reuses :func:`modelopt.torch.export.quant_utils.preprocess_linear_fusion` - (whose ``NVFP4StaticQuantizer`` branch performs the same - ``max(stack(global_amax))`` unification at export time). To call it before - MSE, this helper first promotes each grouped weight quantizer to - :class:`NVFP4StaticQuantizer` with its local ``global_amax`` (= - ``reduce_amax(_amax)``); ``preprocess_linear_fusion`` then unifies in - place. - - Must be called after ``max_calibrate`` has populated each weight - quantizer's ``_amax``. Idempotent. Returns the number of groups synced. +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: + """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. + + Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE + doesn't drop them and break the gate==up ``weight_scale_2`` export invariant. + Activation quantizers on those modules remain uncalibrated; emits a warning. + """ + name_to_module = dict(model.named_modules()) + n = 0 + for module in name_to_module.values(): + if not isinstance(module, QuantModule): + continue + with enable_weight_access_and_writeback(module, model, name_to_module): + for weight, q in module.iter_weights_for_calibration(): + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: + continue + if q._calibrator is None: + continue + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): + continue + q.disable_quant() + q.enable_calib() + q(weight) + if q._calibrator.compute_amax() is not None: + q.load_calib_amax() + q.enable_quant() + q.disable_calib() + if hasattr(q._calibrator, "reset"): + q._calibrator.reset() + n += 1 + if n > 0: + warnings.warn( + f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; " + f"their activation quantizers (if any) remain uncalibrated. " + f"Increase calib size/seq len to activate all experts.", + stacklevel=2, + ) + return n + + +@torch.no_grad() +def _sync_grouped_weight_global_amax(model: nn.Module) -> int: + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. + + Run after ``max_calibrate``. 
Sibling discovery is name-based via + ``_collect_grouped_linears``; non-matching architectures (wqkv, fused + qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently + fall back to per-module global_amax. Fused-experts containers already + share a single quantizer across gate/up halves and need no sync. """ + # quant_utils imports back from this module; top-level would cycle. from modelopt.torch.export.quant_utils import preprocess_linear_fusion + wq_attr = quantizer_attr_names("weight").weight_quantizer n_groups = 0 for group in _collect_grouped_linears(model): - # Promote each member's weight quantizer so `preprocess_linear_fusion` - # sees post-conversion NVFP4StaticQuantizers (its NVFP4 branch reads - # `global_amax`, which only exists post-promotion). - wq_attr = quantizer_attr_names("weight").weight_quantizer for child in group: wq = getattr(child, wq_attr) if not isinstance(wq, NVFP4StaticQuantizer): - local_global = reduce_amax(wq._amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(wq, global_amax=local_global) + NVFP4StaticQuantizer.from_tensor_quantizer( + wq, global_amax=reduce_amax(wq._amax, axis=None) + ) preprocess_linear_fusion(group) n_groups += 1 return n_groups @@ -436,37 +446,24 @@ def mse_calibrate( See :class:`MseCalibConfig ` for details on the remaining arguments. """ - # Step 1: First get initial amax using max calibration + # Step 1: max calibrate, bootstrap dead-expert weight quantizers, + # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid. max_calibrate(model, forward_loop, distributed_sync) + _bootstrap_uncalibrated_weight_quantizers(model) + _sync_grouped_weight_global_amax(model) - # Step 1b: Sync global_amax across sibling NVFP4-static weight quantizers - # (q/k/v_proj under one attention block; gate/up — a.k.a. w1/w3 — under one - # MLP block) so their FP8 scale-of-scales matches and the per-block FP8 - # round uses a consistent grid. No-op when there are no sibling groups - # (e.g. fused QKV / fused gate_up_proj). - sync_grouped_weight_global_amax(model) - - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers - # and identify weight quantizers - weight_quantizers = [] - seen_modules = set() - + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() - is_nvfp4_static = module.is_nvfp4_static - if is_nvfp4_static: - # If sync_grouped_weight_global_amax already promoted this - # quantizer (it's a sibling in a Q/K/V or gate/up group), - # its global_amax has been unified across the group; just - # leave it. Otherwise convert + set local global_amax. - if not isinstance(module, NVFP4StaticQuantizer): - global_amax = reduce_amax(initial_amax, axis=None) - NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) + # Promote standalone NVFP4-static quantizers; grouped siblings + # already promoted by _sync_grouped_weight_global_amax above. + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): + global_amax = reduce_amax(initial_amax, axis=None) + NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if fp8_scale_sweep: # Check if backend has a registered custom calibrator factory. 
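
A minimal sketch of the unification that `_sync_grouped_weight_global_amax` relies on from `preprocess_linear_fusion`, assuming each sibling weight quantizer exposes a 0-dim ``global_amax`` tensor after promotion to ``NVFP4StaticQuantizer``; the helper name below is illustrative and not part of this patch:

    import torch

    def unify_global_amax(sibling_weight_quantizers):
        """Share one FP8 scale-of-scales across a Q/K/V or gate/up sibling group (sketch)."""
        # Take the max over the group so every member rounds its per-block
        # FP8 scales against the same grid.
        unified = torch.stack([wq.global_amax for wq in sibling_weight_quantizers]).max()
        for wq in sibling_weight_quantizers:
            wq.global_amax = unified.clone()
        return unified

This is the same max-over-stacked-global_amax unification the export path already performs; running it before the MSE sweep simply moves it earlier so the per-block FP8 round searches against the shared grid.
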
@@ -506,52 +503,48 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) - # Identify weight quantizers by checking if they have corresponding weight parameters + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. name_to_module = dict(model.named_modules()) + seen_modules: set[int] = set() + pbar = tqdm(desc="MSE weight calibration") + n_calibrated = 0 for parent_module in name_to_module.values(): - if parent_module in seen_modules: + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): continue - for weight_name in weight_attr_names(parent_module): - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: - if getattr(weight_quantizer, "_calibrator", None) is not None: - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) - seen_modules.add(parent_module) - - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation - # This prevents massive memory accumulation seen in large models - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( - tqdm(weight_quantizers, desc="MSE weight calibration") - ): - # Enable calibration mode for the weight quantizer - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() + seen_modules.add(id(parent_module)) with enable_weight_access_and_writeback(parent_module, model, name_to_module): - weight = getattr(parent_module, weight_name) - weight_quantizer(weight) + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): + if not ( + isinstance(weight_quantizer, TensorQuantizer) + and weight_quantizer.is_enabled + and getattr(weight_quantizer, "_calibrator", None) is not None + ): + continue + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + weight_quantizer(weight) - # IMMEDIATELY compute amax and reset calibrator to free memory - cal = getattr(weight_quantizer, "_calibrator", None) - if cal is not None and cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() + cal = weight_quantizer._calibrator + if cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete - # This is critical for multi-GPU setups where tensors may be on different devices - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - if cal is not None and hasattr(cal, "reset"): - cal.reset() + if hasattr(cal, "reset"): + cal.reset() - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - torch.cuda.empty_cache() + pbar.update(1) + n_calibrated += 1 + if n_calibrated % 10 == 0 and torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + pbar.close() if torch.cuda.is_available(): for dev_id in range(torch.cuda.device_count()): @@ -706,10 +699,7 @@ def 
forward(self, input, *args, **kwargs): print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) - # Sync global_amax across sibling NVFP4-static weight quantizers - # (q/k/v_proj, gate/up_proj a.k.a. w1/w3) so the FP8 scale-of-scales - # is consistent across the group. Idempotent; no-op when fused. - sync_grouped_weight_global_amax(model) + _sync_grouped_weight_global_amax(model) # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 12649691453..fa540b8fdf5 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -516,13 +516,7 @@ def is_mx_format(self): @property def is_nvfp4_static(self): - """Check if this quantizer is configured for NVFP4 static block quantization. - - Format-only check (does not consider whether ``_amax`` has been - populated by calibration). True when the quantizer holds E2M1 weights - with E4M3 per-block scales in a static layout — i.e. the two-level - scaling NVFP4 path consumed by :class:`NVFP4StaticQuantizer`. - """ + """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check).""" return ( self.is_static_block_quant and self._num_bits == (2, 1) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 77f26b20602..1873ecda528 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -900,6 +900,24 @@ def forward(self, *args, **kwargs): self._down_proj_linear = False return super().forward(*args, **kwargs) + def iter_weights_for_calibration(self): + """Yield ``(weight_slice, quantizer)`` per-expert pairs. + + The base impl uses singular ``*_weight_quantizer`` and skips fused- + experts modules, so weight-only calibration never reaches per-expert + quantizers without this override. + """ + for weight_name, quantizers_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers"), + ("down_proj", "down_proj_weight_quantizers"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + if weight is None or quantizers is None: + continue + for idx, q in enumerate(quantizers): + yield weight[idx], q + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index e0ce2f0c66e..19e1ed49197 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -388,6 +388,110 @@ def _spy_export(wrapper, dtype): if QuantModuleRegistry.get(expert_type) is not None: QuantModuleRegistry.unregister(expert_type) + def test_per_block_amax_reshape_for_fused_export(self, monkeypatch): + """Per-block ``_amax`` (NVFP4 static, row axis collapsed) must be reshaped + before dim-0 slicing so gate's blocks and up's blocks are split correctly. + + Regression for the bug where a flat per-block ``_amax`` of shape + ``(fused_total * blocks_per_row,)`` was sliced naively, producing wrong + per-projection scales. 
The fix reshapes to ``(fused_total, blocks_per_row)`` + before slicing on dim-0 when ``amax.numel() % fused_total == 0``. + """ + from modelopt.torch.export.moe_utils import _export_fused_experts + + experts = _SyntheticFusedExperts() + expert_type = type(experts) + if QuantModuleRegistry.get(expert_type) is None: + QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})( + _QuantFusedExperts + ) + try: + converted = QuantModuleRegistry.convert(experts) + + # Per-block amax: 4 blocks per row. Distinct values per row so we can + # detect whether the reshape correctly preserves the row→block layout. + blocks_per_row = 4 + fused_total = 2 * INTERMEDIATE_DIM # gate_up rows + for idx in range(NUM_EXPERTS): + # Gate rows take values 1..INTERMEDIATE_DIM, up rows 101..101+INTERMEDIATE_DIM. + gate_amax = ( + torch.arange(1, INTERMEDIATE_DIM + 1).float().repeat_interleave(blocks_per_row) + ) + up_amax = ( + torch.arange(101, 101 + INTERMEDIATE_DIM) + .float() + .repeat_interleave(blocks_per_row) + ) + # Flat shape (fused_total * blocks_per_row,) — row axis collapsed. + flat = torch.cat([gate_amax, up_amax]) + assert flat.numel() == fused_total * blocks_per_row + + wq = converted.gate_up_proj_weight_quantizers[idx] + wq._disabled = False + wq.amax = flat + + # down_proj quantizers also need to look calibrated (otherwise + # the export-time fallback would compute amax from each weight + # slice and we'd skip the new reshape branch). Set a 1-D per-row + # amax that matches dim-0 of down_proj (so amax.numel() == fused_total + # for down). That intentionally does NOT exercise the new branch + # for down — we only want to exercise it for gate_up. + dwq = converted.down_proj_weight_quantizers[idx] + dwq._disabled = False + dwq.amax = torch.ones(HIDDEN_DIM) + + seen = {} + + def _spy_export(wrapper, dtype): + w = wrapper.weight.data + wq = wrapper.weight_quantizer + amax = wq._amax.detach().clone() if hasattr(wq, "_amax") else None + for idx in range(NUM_EXPERTS): + g_slice = converted.gate_up_proj.data[idx, :INTERMEDIATE_DIM, :] + u_slice = converted.gate_up_proj.data[idx, INTERMEDIATE_DIM:, :] + if w.shape == g_slice.shape and torch.equal(w, g_slice): + seen[(idx, "gate_proj")] = amax + return + if w.shape == u_slice.shape and torch.equal(w, u_slice): + seen[(idx, "up_proj")] = amax + return + + monkeypatch.setattr( + "modelopt.torch.export.unified_export_hf._export_quantized_weight", + _spy_export, + ) + + _export_fused_experts(converted, torch.float16) + + # gate's amax should contain values 1..INTERMEDIATE_DIM repeated + # blocks_per_row times, reshaped to (INTERMEDIATE_DIM, blocks_per_row); + # up's amax should contain 101..101+INTERMEDIATE_DIM same shape. + for idx in range(NUM_EXPERTS): + g_amax = seen.get((idx, "gate_proj")) + u_amax = seen.get((idx, "up_proj")) + assert g_amax is not None and u_amax is not None, ( + f"Expert {idx}: missing recorded amax" + ) + assert g_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} gate amax dim-0 should be {INTERMEDIATE_DIM} " + f"after reshape+slice, got {g_amax.shape}" + ) + assert u_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} up amax dim-0 should be {INTERMEDIATE_DIM}, got {u_amax.shape}" + ) + # First block of first row carries the marker value. 
+ assert g_amax.flatten()[0].item() == 1.0, ( + f"Expert {idx} gate amax[0,0] should be 1.0 (gate row 0 marker), " + f"got {g_amax.flatten()[0].item()} — reshape probably didn't restore row axis" + ) + assert u_amax.flatten()[0].item() == 101.0, ( + f"Expert {idx} up amax[0,0] should be 101.0 (up row 0 marker), " + f"got {u_amax.flatten()[0].item()} — slice probably didn't separate gate from up" + ) + finally: + if QuantModuleRegistry.get(expert_type) is not None: + QuantModuleRegistry.unregister(expert_type) + # --------------------------------------------------------------------------- # Tests for force_eager_experts_impl_on_the_fly @@ -529,6 +633,101 @@ def forward_loop(m): self._cleanup_registry(expert_type) + def test_bootstrap_populates_dead_expert_quantizers(self): + """`_bootstrap_uncalibrated_weight_quantizers` fills `_amax` on experts the + forward pass never routed to. + + Regression for the dead-expert MSE skip: with partial routing during max + calibration, never-routed experts' weight quantizers stay with + ``_amax=None``; bootstrap must run the calibrator on the per-expert weight + slice (via ``iter_weights_for_calibration``) to populate them so MSE + doesn't skip them. + """ + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import ( + _bootstrap_uncalibrated_weight_quantizers, + ) + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + # Forward loop that routes only to experts 0 and 1 (deterministic). + # Bypasses the router and calls experts directly with crafted indices. + live = {0, 1} + dead = {idx for idx in range(NUM_EXPERTS) if idx not in live} + assert dead, "Test requires at least one dead expert" + + def partial_forward(m): + torch.manual_seed(0) + seq_len = 8 + hidden = torch.randn(seq_len, HIDDEN_DIM) + top_k_index = torch.zeros(seq_len, TOP_K, dtype=torch.long) + top_k_index[:, 0] = 0 + top_k_index[:, 1] = 1 + top_k_weights = torch.ones(seq_len, TOP_K) / TOP_K + with torch.no_grad(): + m.moe.experts(hidden, top_k_index, top_k_weights) + + mtq.quantize(model, quant_cfg, forward_loop=partial_forward) + + experts = model.moe.experts + + # Pre-bootstrap: dead experts have no/zero _amax. + for idx in dead: + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert getattr(gu_q, "_amax", None) is None or torch.all(gu_q._amax == 0), ( + f"Dead expert {idx} gate_up_proj should be uncalibrated pre-bootstrap" + ) + assert getattr(d_q, "_amax", None) is None or torch.all(d_q._amax == 0), ( + f"Dead expert {idx} down_proj should be uncalibrated pre-bootstrap" + ) + + n_bootstrapped = _bootstrap_uncalibrated_weight_quantizers(model) + assert n_bootstrapped >= 2 * len(dead), ( + f"Expected ≥{2 * len(dead)} bootstrapped (gate_up + down per dead expert), " + f"got {n_bootstrapped}" + ) + + # Post-bootstrap: every expert has populated _amax matching max(|weight|). 
+ for idx in range(NUM_EXPERTS): + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert gu_q._amax is not None and not torch.all(gu_q._amax == 0), ( + f"Expert {idx} gate_up_proj _amax not populated after bootstrap" + ) + assert d_q._amax is not None and not torch.all(d_q._amax == 0), ( + f"Expert {idx} down_proj _amax not populated after bootstrap" + ) + + # For dead experts, bootstrap reads max(|weight|). Sanity-check it matches + # the actual weight tensor's per-row max (axis=0 reduces over hidden_dim). + for idx in dead: + expected = experts.gate_up_proj.data[idx].abs().amax(dim=1) + got = experts.gate_up_proj_weight_quantizers[idx]._amax.flatten() + assert torch.allclose(got, expected, atol=1e-4), ( + f"Expert {idx} bootstrap amax should equal per-row max(|weight|); " + f"max diff {(got - expected).abs().max().item()}" + ) + + self._cleanup_registry(expert_type) + # --------------------------------------------------------------------------- # Tests for export enumeration — guards the bug where fused-experts were