10 changes: 9 additions & 1 deletion modelopt/torch/export/moe_utils.py
@@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
and w_quantizer._amax.dim() >= 1
):
amax = w_quantizer._amax
# Per-block _amax (NVFP4 static) collapses the row axis we want
# to slice on; restore it so dim-0 slicing splits gate/up.
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)
amax_dim0 = amax.shape[0]
if fused_total % amax_dim0 == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
sliced = amax[slice_start:slice_end].contiguous()
# The amax setter refuses shape changes; drop _amax first.
if hasattr(w_quantizer, "_amax"):
delattr(w_quantizer, "_amax")
w_quantizer.amax = sliced
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
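For the per-block amax reshape-and-slice above, here is a minimal self-contained sketch of the arithmetic; the row count, block count, and gate/up split below are illustrative assumptions, not values from this PR.

import torch

# Hypothetical fused gate_up weight: 8 output rows (4 gate + 4 up) with a
# per-block NVFP4 _amax stored flat, 2 blocks per row (numel = 16).
fused_total = 8
amax = torch.arange(1.0, 17.0)  # flat per-block amax, numel = 16

# Restore the row axis so dim-0 slicing lines up with weight rows.
if amax.numel() != fused_total and amax.numel() % fused_total == 0:
    amax = amax.contiguous().view(fused_total, amax.numel() // fused_total)  # (8, 2)

# Slice the gate half, rows [0, 4), mirroring the fused_start / weight_slice math.
fused_start, n_rows = 0, 4
amax_dim0 = amax.shape[0]
slice_start = fused_start * amax_dim0 // fused_total            # 0
slice_end = (fused_start + n_rows) * amax_dim0 // fused_total   # 4
gate_amax = amax[slice_start:slice_end].contiguous()            # shape (4, 2)
print(gate_amax.shape)  # torch.Size([4, 2])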
15 changes: 15 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
mod.revert_weight_conversion = original


def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None:
"""Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set.

Newer transformers reject ``do_sample=False`` mixed with sampling attrs in
``save_pretrained``'s strict validation.
"""
gc = getattr(model, "generation_config", None)
if gc is None:
return
if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
gc.do_sample = True


def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
@@ -1228,6 +1241,8 @@ def export_hf_checkpoint(
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

_sanitize_generation_config_for_save(model)

try:
model.save_pretrained(
export_dir,
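As a quick illustration of the generation-config sanitizer added above, a hedged sketch of the condition it fixes; the concrete top_k/top_p values are made up, and GenerationConfig is transformers' standard class.

from transformers import GenerationConfig

# A config that newer transformers reject in save_pretrained's validation:
# sampling attributes set while do_sample is False.
gc = GenerationConfig(do_sample=False, top_k=50, top_p=0.9)

# What _sanitize_generation_config_for_save does in effect:
if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None:
    gc.do_sample = True

assert gc.do_sample is True  # validation no longer trips on the sampling attrs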
210 changes: 147 additions & 63 deletions modelopt/torch/quantization/model_calib.py
@@ -52,7 +52,6 @@
promote_nvfp4_static_quantizers,
quantizer_attr_names,
reduce_amax,
weight_attr_names,
)
from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper

@@ -66,6 +65,107 @@
"svdquant",
]


def _is_calibrated_nvfp4_static(q) -> bool:
"""True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set."""
return (
isinstance(q, TensorQuantizer)
and not q._disabled
and q.is_nvfp4_static
and getattr(q, "_amax", None) is not None
)


def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
"""Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
Contributor:
If this is specific only to HF, it also belongs in plugins/huggingface.py.

Collaborator Author:
Same reasoning as in the reply to your other comment about moving HF-specific functions: _is_calibrated_nvfp4_static is a pure TensorQuantizer property check with zero HF dependence, and _collect_grouped_linears uses pattern names that happen to be HF conventions but operates on generic nn.Module. Leaving these in model_calib.py keeps the sibling-sync helpers next to their only caller (_sync_grouped_weight_global_amax in mse_calibrate).

    # Imported inline: a top-level import would form a layer_utils → quant_utils → model_calib cycle.
from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS

# Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant
# in export). Single source for the gate/up half avoids parallel lists.
patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS)
groups: list[list[nn.Module]] = []
wq_attr = quantizer_attr_names("weight").weight_quantizer
for parent in model.modules():
for sibling_names in patterns:
members = [
child
for child in (getattr(parent, n, None) for n in sibling_names)
if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None))
]
if len(members) >= 2:
groups.append(members)
return groups


@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
Contributor:
Why do we need this at calibration time, and how is it different from what we do at export time in _ensure_weight_quantizer_calibrated (mte.quant_utils.py)? Can we either reuse the existing code inside this helper function or just rely on export to fill in the missing amax?

Collaborator Author:
Good question. Two parts:

Why calibration-time, not export-time: _ensure_weight_quantizer_calibrated runs at export and does a max-style fill-in for missing _amax. If we deferred dead-expert handling to export only:

  • MSE's Step 3 walk would skip dead experts (they have _amax=None at MSE time)
  • _sync_grouped_weight_global_amax would skip them (the _is_calibrated_nvfp4_static check requires _amax populated)
  • promote_nvfp4_static_quantizers (in max_calibrate) wouldn't promote them either (it gates on _amax populated)
  • Dead experts would receive only export-time max calibration — no MSE refinement, no sibling-amax sync, no NVFP4-static promotion

That defeats the PR's main feature ("MSE-calibrate every per-expert weight in fused-experts MoE"); dead experts would just go back to max-only. The bootstrap is timed before MSE/sync/promotion to bring dead experts into all three.

Reuse of the existing helper: reasonable as a follow-up DRY win, but they differ in detail today:

  • _ensure_weight_quantizer_calibrated is per-quantizer; for NVFP4StaticQuantizer it uses reduce_block_amax + sets global_amax directly
  • Our bootstrap walks all QuantModules, uses each quantizer's existing calibrator (so any calibrator class works, not just max), runs inside enable_weight_access_and_writeback for FSDP/TP safety, and emits a single aggregate warning

Factoring out a shared inner step would save a handful of lines but pulls in another export → quantization inline import and switches the bootstrap from "calibrator-driven" to "tensor-driven". Happy to do as a small follow-up alongside moving preprocess_linear_fusion (also discussed elsewhere on this PR) — but not blocking the current fix.

"""Re-run weight calibration on the weight tensor for quantizers missing ``_amax``.

Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE
doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
Activation quantizers on those modules remain uncalibrated; emits a warning.
"""
name_to_module = dict(model.named_modules())
n = 0
for module in name_to_module.values():
if not isinstance(module, QuantModule):
continue
with enable_weight_access_and_writeback(module, model, name_to_module):
for weight, q in module.iter_weights_for_calibration():
if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
continue
if q._calibrator is None:
continue
if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
continue
q.disable_quant()
q.enable_calib()
q(weight)
if q._calibrator.compute_amax() is not None:
q.load_calib_amax()
q.enable_quant()
q.disable_calib()
if hasattr(q._calibrator, "reset"):
q._calibrator.reset()
n += 1
if n > 0:
warnings.warn(
f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
f"their activation quantizers (if any) remain uncalibrated. "
f"Increase calib size/seq len to activate all experts.",
stacklevel=2,
)
return n


@torch.no_grad()
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
Contributor:
This function is also useful for other algorithms like AWQ and GPTQ. We need to sync the amax for fused layers before those algorithms begin. cc @sychen52 on the design of fused modules.

Contributor:
Can you move any HF-specific functions to plugins/huggingface.py?

Collaborator Author:
The helpers in this section operate on generic types (TensorQuantizer, QuantModule, NVFP4StaticQuantizer) and have no HF imports or HF-class isinstance checks:

  • _is_calibrated_nvfp4_static — pure check on TensorQuantizer.is_nvfp4_static property; could apply to a Megatron model with NVFP4-static quantizers.
  • _collect_grouped_linears — uses pattern names that happen to be common HF conventions (q_proj, gate_proj, w1, etc.) but operates on any nn.Module. The pattern data is now sourced from export.layer_utils._GATE_UP_PAIRS for the gate/up half.
  • _bootstrap_uncalibrated_weight_quantizers — walks QuantModule instances and uses each module's iter_weights_for_calibration override polymorphically. Works for any registered fused-experts container (Megatron's would be covered too if registered as QuantModule).
  • _sync_grouped_weight_global_amax — same; operates on the result of _collect_grouped_linears.

The HF plugin's role is registering specialized quant modules (like _QuantFusedExperts for HF *MoeExperts classes), not housing every helper that uses common naming conventions. Moving these would add a layering hop (model_calib → plugins/huggingface) for what's actually quantization-level logic invoked from mse_calibrate.

"""Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.

Run after ``max_calibrate``. Sibling discovery is name-based via
``_collect_grouped_linears``; non-matching architectures (wqkv, fused
qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently
fall back to per-module global_amax. Fused-experts containers already
share a single quantizer across gate/up halves and need no sync.
"""
# quant_utils imports back from this module; top-level would cycle.
from modelopt.torch.export.quant_utils import preprocess_linear_fusion

wq_attr = quantizer_attr_names("weight").weight_quantizer
n_groups = 0
for group in _collect_grouped_linears(model):
for child in group:
wq = getattr(child, wq_attr)
if not isinstance(wq, NVFP4StaticQuantizer):
NVFP4StaticQuantizer.from_tensor_quantizer(
wq, global_amax=reduce_amax(wq._amax, axis=None)
)
preprocess_linear_fusion(group)
n_groups += 1
return n_groups


CalibratorFactory: TypeAlias = Callable[
[torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]
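Conceptually, the sibling sync above leaves Q/K/V (or gate/up) sharing one weight global amax; a rough sketch with made-up values, assuming the shared value amounts to the group-wide maximum (an assumption about what preprocess_linear_fusion ends up doing for NVFP4 global_amax, not a literal trace of it).

import torch

# Hypothetical per-sibling global_amax values after max calibration.
per_sibling_global_amax = {
    "q_proj": torch.tensor(0.8),
    "k_proj": torch.tensor(1.3),
    "v_proj": torch.tensor(2.1),
}

# After the sync, all three weight quantizers quantize against a single shared
# value, so downstream fusion sees one consistent FP8 scale grid.
shared = torch.stack(list(per_sibling_global_amax.values())).amax()
print(shared)  # tensor(2.1000)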
@@ -346,32 +446,23 @@ def mse_calibrate(
See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
details on the remaining arguments.
"""
# Step 1: First get initial amax using max calibration
# Step 1: max calibrate, bootstrap dead-expert weight quantizers,
# unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid.
max_calibrate(model, forward_loop, distributed_sync)
_bootstrap_uncalibrated_weight_quantizers(model)
Contributor:
I wonder why the dead-expert weight quantizers do not break the MAX-calibration export path with the NVFP4 dynamic quantizer, i.e., does max_calibrate need this fix as well?
Would it make more sense to run _bootstrap_uncalibrated_weight_quantizers in max_calibrate (after weight_only_quantize/forward, before promote_nvfp4_static_quantizers)? That way other recipes (AWQ, GPTQ, local-hessian) get the fix as well.

Collaborator Author:
No. Max calibrate is covered at export time; @Edwardf0t1 added the logic that covers the unquantized experts.

_sync_grouped_weight_global_amax(model)

# Step 2: Replace calibrators with MseCalibrator for enabled quantizers
# and identify weight quantizers
weight_quantizers = []
seen_modules = set()

# Step 2: replace calibrators with MseCalibrator for enabled quantizers.
for name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
# Get the initial amax from max calibration
initial_amax = module._amax.clone().detach()
is_nvfp4_static = module.is_nvfp4_static

is_nvfp4_static = (
module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)

if is_nvfp4_static:
# Compute and set global_amax
# Promote standalone NVFP4-static quantizers; grouped siblings
# already promoted by _sync_grouped_weight_global_amax above.
if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)

# Convert to NVFP4StaticQuantizer in-place
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

if fp8_scale_sweep:
@@ -412,52 +503,48 @@ def mse_calibrate(
quant_func=partial(_mse_quant_func, quantizer=module),
)

# Identify weight quantizers by checking if they have corresponding weight parameters
# Step 3: calibrate weight quantizers via iter_weights_for_calibration.
Contributor:
This comment is too long and too specific to fused experts. Remove the mention of fused experts to make it more generic, and in general reduce the length of AI-generated comments.

name_to_module = dict(model.named_modules())
seen_modules: set[int] = set()
pbar = tqdm(desc="MSE weight calibration")
n_calibrated = 0
for parent_module in name_to_module.values():
if parent_module in seen_modules:
if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
continue
for weight_name in weight_attr_names(parent_module):
weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
# This prevents massive memory accumulation seen in large models
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
tqdm(weight_quantizers, desc="MSE weight calibration")
):
# Enable calibration mode for the weight quantizer
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
seen_modules.add(id(parent_module))
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)
for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
if not (
isinstance(weight_quantizer, TensorQuantizer)
and weight_quantizer.is_enabled
and getattr(weight_quantizer, "_calibrator", None) is not None
):
continue
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
weight_quantizer(weight)

# IMMEDIATELY compute amax and reset calibrator to free memory
cal = getattr(weight_quantizer, "_calibrator", None)
if cal is not None and cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()
cal = weight_quantizer._calibrator
if cal.compute_amax() is not None:
weight_quantizer.load_calib_amax()

weight_quantizer.enable_quant()
weight_quantizer.disable_calib()
weight_quantizer.enable_quant()
weight_quantizer.disable_calib()

# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
Contributor:
Why was this comment removed?

Collaborator Author:
The old comment described the previous weight_attr_names lookup; with the refactor to parent_module.iter_weights_for_calibration() the call site is self-describing. Trimmed the replacement block to a single line per your other comment.

# This is critical for multi-GPU setups where tensors may be on different devices
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))

if cal is not None and hasattr(cal, "reset"):
cal.reset()
if hasattr(cal, "reset"):
cal.reset()

if (idx + 1) % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.update(1)
n_calibrated += 1
if n_calibrated % 10 == 0 and torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
torch.cuda.empty_cache()
pbar.close()

if torch.cuda.is_available():
for dev_id in range(torch.cuda.device_count()):
@@ -612,6 +699,8 @@ def forward(self, input, *args, **kwargs):
print_rank_0("local_hessian: Running max calibration for all quantizers...")
max_calibrate(model, forward_loop, distributed_sync)

_sync_grouped_weight_global_amax(model)

# Setup helpers for all quantized linear modules
name_to_module = dict(model.named_modules())
weight_quantizers_info = []
@@ -666,14 +755,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

return xq

is_nvfp4_static = (
weight_quantizer.is_static_block_quant
and weight_quantizer._num_bits == (2, 1)
and weight_quantizer._block_sizes is not None
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
)
is_nvfp4_static = weight_quantizer.is_nvfp4_static

if is_nvfp4_static:
if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

10 changes: 10 additions & 0 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -514,6 +514,16 @@ def is_mx_format(self):
and self.block_sizes.get("scale_bits", None) == (8, 0)
)

@property
def is_nvfp4_static(self):
"""True for E2M1 weights + E4M3 per-block scales in static layout (format-only check)."""
return (
self.is_static_block_quant
and self._num_bits == (2, 1)
and self._block_sizes is not None
and self._block_sizes.get("scale_bits") == (4, 3)
)

def is_mxfp(self, bits):
"""Check if is MXFP4/MXFP6/MXFP8."""
if bits == 4:
18 changes: 18 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
self._down_proj_linear = False
return super().forward(*args, **kwargs)

def iter_weights_for_calibration(self):
Contributor:
This name is misleading: you are only iterating fused-expert weights, not all weights.

Suggested change:
- def iter_weights_for_calibration(self):
+ def iter_fused_expert_weights_for_calibration(self):

Collaborator Author:
This is an override of the base QuantModule.iter_weights_for_calibration (defined in modelopt/torch/quantization/nn/modules/quant_module.py:122). Callers like mse_calibrate, weight_only_quantize, and _bootstrap_uncalibrated_weight_quantizers invoke module.iter_weights_for_calibration() polymorphically — they hit the overridden method by class dispatch, not by checking for fused-experts specifically. Renaming this override would break that dispatch: the base method would still be called by name, but _QuantFusedExperts instances would fall back to the base impl (which uses singular *_weight_quantizer and silently skips fused-experts modules — the exact bug this PR fixes).

"""Yield ``(weight_slice, quantizer)`` per-expert pairs.

The base impl uses singular ``*_weight_quantizer`` and skips fused-
experts modules, so weight-only calibration never reaches per-expert
quantizers without this override.
"""
for weight_name, quantizers_name in (
("gate_up_proj", "gate_up_proj_weight_quantizers"),
("down_proj", "down_proj_weight_quantizers"),
):
weight = getattr(self, weight_name, None)
quantizers = getattr(self, quantizers_name, None)
if weight is None or quantizers is None:
continue
for idx, q in enumerate(quantizers):
yield weight[idx], q

def fold_weight(self, keep_attrs: bool = False):
"""Fold per-expert weight quantizers into the fused 3-D weights.

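On the rename question in the iter_weights_for_calibration thread above, a minimal sketch of the method-dispatch pattern the reply relies on; the class names and attributes here are illustrative stand-ins, not modelopt's real hierarchy.

import torch
import torch.nn as nn

class QuantModuleLike(nn.Module):
    """Stand-in base: yields the singular weight / weight_quantizer pair."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.weight_quantizer = nn.Identity()  # stand-in for a TensorQuantizer

    def iter_weights_for_calibration(self):
        yield self.weight, self.weight_quantizer

class FusedExpertsLike(QuantModuleLike):
    """Stand-in override: one (weight_slice, quantizer) pair per expert."""

    def __init__(self, num_experts: int = 2):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.randn(num_experts, 4, 4))
        self.gate_up_proj_weight_quantizers = nn.ModuleList(
            nn.Identity() for _ in range(num_experts)
        )

    def iter_weights_for_calibration(self):
        for idx, q in enumerate(self.gate_up_proj_weight_quantizers):
            yield self.gate_up_proj[idx], q

# Callers iterate polymorphically; renaming only the override would make them
# silently fall back to the base implementation for fused-experts modules.
for m in (QuantModuleLike(), FusedExpertsLike()):
    print(type(m).__name__, sum(1 for _ in m.iter_weights_for_calibration()))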
8 changes: 1 addition & 7 deletions modelopt/torch/quantization/utils/core_utils.py
@@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int:
for _name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
is_nvfp4_static = (
module.is_static_block_quant
and module._num_bits == (2, 1)
and module._block_sizes is not None
and module._block_sizes.get("scale_bits") == (4, 3)
)
if is_nvfp4_static:
if module.is_nvfp4_static:
initial_amax = module._amax.clone().detach()
global_amax = reduce_amax(initial_amax, axis=None)
NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)