diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 8981d614843..e325e5346f1 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -110,11 +110,19 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None: and w_quantizer._amax.dim() >= 1 ): amax = w_quantizer._amax + # Per-block _amax (NVFP4 static) collapses the row axis we want + # to slice on; restore it so dim-0 slicing splits gate/up. + if amax.numel() != fused_total and amax.numel() % fused_total == 0: + amax = amax.contiguous().view(fused_total, amax.numel() // fused_total) amax_dim0 = amax.shape[0] if fused_total % amax_dim0 == 0: slice_start = fused_start * amax_dim0 // fused_total slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total - w_quantizer.amax = amax[slice_start:slice_end].contiguous() + sliced = amax[slice_start:slice_end].contiguous() + # The amax setter refuses shape changes; drop _amax first. + if hasattr(w_quantizer, "_amax"): + delattr(w_quantizer, "_amax") + w_quantizer.amax = sliced else: warnings.warn( f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not " diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a58aa4c9895..73ae63a5a56 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -1134,6 +1134,19 @@ def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None: mod.revert_weight_conversion = original +def _sanitize_generation_config_for_save(model: torch.nn.Module) -> None: + """Force ``do_sample=True`` when generation_config has ``top_k``/``top_p`` set. + + Newer transformers reject ``do_sample=False`` mixed with sampling attrs in + ``save_pretrained``'s strict validate. + """ + gc = getattr(model, "generation_config", None) + if gc is None: + return + if getattr(gc, "top_k", None) is not None or getattr(gc, "top_p", None) is not None: + gc.do_sample = True + + def export_speculative_decoding( model: torch.nn.Module, dtype: torch.dtype | None = None, @@ -1228,6 +1241,8 @@ def export_hf_checkpoint( # modeling_utils does `from core_model_loading import revert_weight_conversion`. _patches = _patch_revert_weight_conversion() + _sanitize_generation_config_for_save(model) + try: model.save_pretrained( export_dir, diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index fe4c3f77ce6..bce49786077 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -52,7 +52,6 @@ promote_nvfp4_static_quantizers, quantizer_attr_names, reduce_amax, - weight_attr_names, ) from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper @@ -66,6 +65,107 @@ "svdquant", ] + +def _is_calibrated_nvfp4_static(q) -> bool: + """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set.""" + return ( + isinstance(q, TensorQuantizer) + and not q._disabled + and q.is_nvfp4_static + and getattr(q, "_amax", None) is not None + ) + + +def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]: + """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers.""" + # Inline: layer_utils → quant_utils → model_calib cycle. + from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS + + # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant + # in export). Single source for the gate/up half avoids parallel lists. + patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS) + groups: list[list[nn.Module]] = [] + wq_attr = quantizer_attr_names("weight").weight_quantizer + for parent in model.modules(): + for sibling_names in patterns: + members = [ + child + for child in (getattr(parent, n, None) for n in sibling_names) + if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None)) + ] + if len(members) >= 2: + groups.append(members) + return groups + + +@torch.no_grad() +def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int: + """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``. + + Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE + doesn't drop them and break the gate==up ``weight_scale_2`` export invariant. + Activation quantizers on those modules remain uncalibrated; emits a warning. + """ + name_to_module = dict(model.named_modules()) + n = 0 + for module in name_to_module.values(): + if not isinstance(module, QuantModule): + continue + with enable_weight_access_and_writeback(module, model, name_to_module): + for weight, q in module.iter_weights_for_calibration(): + if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic: + continue + if q._calibrator is None: + continue + if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0): + continue + q.disable_quant() + q.enable_calib() + q(weight) + if q._calibrator.compute_amax() is not None: + q.load_calib_amax() + q.enable_quant() + q.disable_calib() + if hasattr(q._calibrator, "reset"): + q._calibrator.reset() + n += 1 + if n > 0: + warnings.warn( + f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; " + f"their activation quantizers (if any) remain uncalibrated. " + f"Increase calib size/seq len to activate all experts.", + stacklevel=2, + ) + return n + + +@torch.no_grad() +def _sync_grouped_weight_global_amax(model: nn.Module) -> int: + """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers. + + Run after ``max_calibrate``. Sibling discovery is name-based via + ``_collect_grouped_linears``; non-matching architectures (wqkv, fused + qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently + fall back to per-module global_amax. Fused-experts containers already + share a single quantizer across gate/up halves and need no sync. + """ + # quant_utils imports back from this module; top-level would cycle. + from modelopt.torch.export.quant_utils import preprocess_linear_fusion + + wq_attr = quantizer_attr_names("weight").weight_quantizer + n_groups = 0 + for group in _collect_grouped_linears(model): + for child in group: + wq = getattr(child, wq_attr) + if not isinstance(wq, NVFP4StaticQuantizer): + NVFP4StaticQuantizer.from_tensor_quantizer( + wq, global_amax=reduce_amax(wq._amax, axis=None) + ) + preprocess_linear_fusion(group) + n_groups += 1 + return n_groups + + CalibratorFactory: TypeAlias = Callable[ [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator ] @@ -346,32 +446,23 @@ def mse_calibrate( See :class:`MseCalibConfig ` for details on the remaining arguments. """ - # Step 1: First get initial amax using max calibration + # Step 1: max calibrate, bootstrap dead-expert weight quantizers, + # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid. max_calibrate(model, forward_loop, distributed_sync) + _bootstrap_uncalibrated_weight_quantizers(model) + _sync_grouped_weight_global_amax(model) - # Step 2: Replace calibrators with MseCalibrator for enabled quantizers - # and identify weight quantizers - weight_quantizers = [] - seen_modules = set() - + # Step 2: replace calibrators with MseCalibrator for enabled quantizers. for name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - # Get the initial amax from max calibration initial_amax = module._amax.clone().detach() + is_nvfp4_static = module.is_nvfp4_static - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - - if is_nvfp4_static: - # Compute and set global_amax + # Promote standalone NVFP4-static quantizers; grouped siblings + # already promoted by _sync_grouped_weight_global_amax above. + if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) - - # Convert to NVFP4StaticQuantizer in-place NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) if fp8_scale_sweep: @@ -412,52 +503,48 @@ def mse_calibrate( quant_func=partial(_mse_quant_func, quantizer=module), ) - # Identify weight quantizers by checking if they have corresponding weight parameters + # Step 3: calibrate weight quantizers via iter_weights_for_calibration. name_to_module = dict(model.named_modules()) + seen_modules: set[int] = set() + pbar = tqdm(desc="MSE weight calibration") + n_calibrated = 0 for parent_module in name_to_module.values(): - if parent_module in seen_modules: + if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule): continue - for weight_name in weight_attr_names(parent_module): - weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer - weight_quantizer = getattr(parent_module, weight_quantizer_name, None) - if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled: - if getattr(weight_quantizer, "_calibrator", None) is not None: - weight_quantizers.append((parent_module, weight_name, weight_quantizer)) - seen_modules.add(parent_module) - - # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation - # This prevents massive memory accumulation seen in large models - for idx, (parent_module, weight_name, weight_quantizer) in enumerate( - tqdm(weight_quantizers, desc="MSE weight calibration") - ): - # Enable calibration mode for the weight quantizer - weight_quantizer.disable_quant() - weight_quantizer.enable_calib() + seen_modules.add(id(parent_module)) with enable_weight_access_and_writeback(parent_module, model, name_to_module): - weight = getattr(parent_module, weight_name) - weight_quantizer(weight) + for weight, weight_quantizer in parent_module.iter_weights_for_calibration(): + if not ( + isinstance(weight_quantizer, TensorQuantizer) + and weight_quantizer.is_enabled + and getattr(weight_quantizer, "_calibrator", None) is not None + ): + continue + weight_quantizer.disable_quant() + weight_quantizer.enable_calib() + weight_quantizer(weight) - # IMMEDIATELY compute amax and reset calibrator to free memory - cal = getattr(weight_quantizer, "_calibrator", None) - if cal is not None and cal.compute_amax() is not None: - weight_quantizer.load_calib_amax() + cal = weight_quantizer._calibrator + if cal.compute_amax() is not None: + weight_quantizer.load_calib_amax() - weight_quantizer.enable_quant() - weight_quantizer.disable_calib() + weight_quantizer.enable_quant() + weight_quantizer.disable_calib() - # Synchronize ALL CUDA devices before resetting to ensure all async operations complete - # This is critical for multi-GPU setups where tensors may be on different devices - if torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + if torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - if cal is not None and hasattr(cal, "reset"): - cal.reset() + if hasattr(cal, "reset"): + cal.reset() - if (idx + 1) % 10 == 0 and torch.cuda.is_available(): - for dev_id in range(torch.cuda.device_count()): - torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) - torch.cuda.empty_cache() + pbar.update(1) + n_calibrated += 1 + if n_calibrated % 10 == 0 and torch.cuda.is_available(): + for dev_id in range(torch.cuda.device_count()): + torch.cuda.synchronize(torch.device(f"cuda:{dev_id}")) + torch.cuda.empty_cache() + pbar.close() if torch.cuda.is_available(): for dev_id in range(torch.cuda.device_count()): @@ -612,6 +699,8 @@ def forward(self, input, *args, **kwargs): print_rank_0("local_hessian: Running max calibration for all quantizers...") max_calibrate(model, forward_loop, distributed_sync) + _sync_grouped_weight_global_amax(model) + # Setup helpers for all quantized linear modules name_to_module = dict(model.named_modules()) weight_quantizers_info = [] @@ -666,14 +755,9 @@ def quant_func(x, amax, quantizer=weight_quantizer): return xq - is_nvfp4_static = ( - weight_quantizer.is_static_block_quant - and weight_quantizer._num_bits == (2, 1) - and weight_quantizer._block_sizes is not None - and weight_quantizer._block_sizes.get("scale_bits") == (4, 3) - ) + is_nvfp4_static = weight_quantizer.is_nvfp4_static - if is_nvfp4_static: + if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer): global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 3ff7401ec3e..fa540b8fdf5 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -514,6 +514,16 @@ def is_mx_format(self): and self.block_sizes.get("scale_bits", None) == (8, 0) ) + @property + def is_nvfp4_static(self): + """True for E2M1 weights + E4M3 per-block scales in static layout (format-only check).""" + return ( + self.is_static_block_quant + and self._num_bits == (2, 1) + and self._block_sizes is not None + and self._block_sizes.get("scale_bits") == (4, 3) + ) + def is_mxfp(self, bits): """Check if is MXFP4/MXFP6/MXFP8.""" if bits == 4: diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 77f26b20602..1873ecda528 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -900,6 +900,24 @@ def forward(self, *args, **kwargs): self._down_proj_linear = False return super().forward(*args, **kwargs) + def iter_weights_for_calibration(self): + """Yield ``(weight_slice, quantizer)`` per-expert pairs. + + The base impl uses singular ``*_weight_quantizer`` and skips fused- + experts modules, so weight-only calibration never reaches per-expert + quantizers without this override. + """ + for weight_name, quantizers_name in ( + ("gate_up_proj", "gate_up_proj_weight_quantizers"), + ("down_proj", "down_proj_weight_quantizers"), + ): + weight = getattr(self, weight_name, None) + quantizers = getattr(self, quantizers_name, None) + if weight is None or quantizers is None: + continue + for idx, q in enumerate(quantizers): + yield weight[idx], q + def fold_weight(self, keep_attrs: bool = False): """Fold per-expert weight quantizers into the fused 3-D weights. diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 1a177e04dc8..cea3d4260e4 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -957,13 +957,7 @@ def promote_nvfp4_static_quantizers(model: nn.Module) -> int: for _name, module in list(model.named_modules()): if isinstance(module, TensorQuantizer) and not module._disabled: if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"): - is_nvfp4_static = ( - module.is_static_block_quant - and module._num_bits == (2, 1) - and module._block_sizes is not None - and module._block_sizes.get("scale_bits") == (4, 3) - ) - if is_nvfp4_static: + if module.is_nvfp4_static: initial_amax = module._amax.clone().detach() global_amax = reduce_amax(initial_amax, axis=None) NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax) diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index e0ce2f0c66e..19e1ed49197 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -388,6 +388,110 @@ def _spy_export(wrapper, dtype): if QuantModuleRegistry.get(expert_type) is not None: QuantModuleRegistry.unregister(expert_type) + def test_per_block_amax_reshape_for_fused_export(self, monkeypatch): + """Per-block ``_amax`` (NVFP4 static, row axis collapsed) must be reshaped + before dim-0 slicing so gate's blocks and up's blocks are split correctly. + + Regression for the bug where a flat per-block ``_amax`` of shape + ``(fused_total * blocks_per_row,)`` was sliced naively, producing wrong + per-projection scales. The fix reshapes to ``(fused_total, blocks_per_row)`` + before slicing on dim-0 when ``amax.numel() % fused_total == 0``. + """ + from modelopt.torch.export.moe_utils import _export_fused_experts + + experts = _SyntheticFusedExperts() + expert_type = type(experts) + if QuantModuleRegistry.get(expert_type) is None: + QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})( + _QuantFusedExperts + ) + try: + converted = QuantModuleRegistry.convert(experts) + + # Per-block amax: 4 blocks per row. Distinct values per row so we can + # detect whether the reshape correctly preserves the row→block layout. + blocks_per_row = 4 + fused_total = 2 * INTERMEDIATE_DIM # gate_up rows + for idx in range(NUM_EXPERTS): + # Gate rows take values 1..INTERMEDIATE_DIM, up rows 101..101+INTERMEDIATE_DIM. + gate_amax = ( + torch.arange(1, INTERMEDIATE_DIM + 1).float().repeat_interleave(blocks_per_row) + ) + up_amax = ( + torch.arange(101, 101 + INTERMEDIATE_DIM) + .float() + .repeat_interleave(blocks_per_row) + ) + # Flat shape (fused_total * blocks_per_row,) — row axis collapsed. + flat = torch.cat([gate_amax, up_amax]) + assert flat.numel() == fused_total * blocks_per_row + + wq = converted.gate_up_proj_weight_quantizers[idx] + wq._disabled = False + wq.amax = flat + + # down_proj quantizers also need to look calibrated (otherwise + # the export-time fallback would compute amax from each weight + # slice and we'd skip the new reshape branch). Set a 1-D per-row + # amax that matches dim-0 of down_proj (so amax.numel() == fused_total + # for down). That intentionally does NOT exercise the new branch + # for down — we only want to exercise it for gate_up. + dwq = converted.down_proj_weight_quantizers[idx] + dwq._disabled = False + dwq.amax = torch.ones(HIDDEN_DIM) + + seen = {} + + def _spy_export(wrapper, dtype): + w = wrapper.weight.data + wq = wrapper.weight_quantizer + amax = wq._amax.detach().clone() if hasattr(wq, "_amax") else None + for idx in range(NUM_EXPERTS): + g_slice = converted.gate_up_proj.data[idx, :INTERMEDIATE_DIM, :] + u_slice = converted.gate_up_proj.data[idx, INTERMEDIATE_DIM:, :] + if w.shape == g_slice.shape and torch.equal(w, g_slice): + seen[(idx, "gate_proj")] = amax + return + if w.shape == u_slice.shape and torch.equal(w, u_slice): + seen[(idx, "up_proj")] = amax + return + + monkeypatch.setattr( + "modelopt.torch.export.unified_export_hf._export_quantized_weight", + _spy_export, + ) + + _export_fused_experts(converted, torch.float16) + + # gate's amax should contain values 1..INTERMEDIATE_DIM repeated + # blocks_per_row times, reshaped to (INTERMEDIATE_DIM, blocks_per_row); + # up's amax should contain 101..101+INTERMEDIATE_DIM same shape. + for idx in range(NUM_EXPERTS): + g_amax = seen.get((idx, "gate_proj")) + u_amax = seen.get((idx, "up_proj")) + assert g_amax is not None and u_amax is not None, ( + f"Expert {idx}: missing recorded amax" + ) + assert g_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} gate amax dim-0 should be {INTERMEDIATE_DIM} " + f"after reshape+slice, got {g_amax.shape}" + ) + assert u_amax.shape[0] == INTERMEDIATE_DIM, ( + f"Expert {idx} up amax dim-0 should be {INTERMEDIATE_DIM}, got {u_amax.shape}" + ) + # First block of first row carries the marker value. + assert g_amax.flatten()[0].item() == 1.0, ( + f"Expert {idx} gate amax[0,0] should be 1.0 (gate row 0 marker), " + f"got {g_amax.flatten()[0].item()} — reshape probably didn't restore row axis" + ) + assert u_amax.flatten()[0].item() == 101.0, ( + f"Expert {idx} up amax[0,0] should be 101.0 (up row 0 marker), " + f"got {u_amax.flatten()[0].item()} — slice probably didn't separate gate from up" + ) + finally: + if QuantModuleRegistry.get(expert_type) is not None: + QuantModuleRegistry.unregister(expert_type) + # --------------------------------------------------------------------------- # Tests for force_eager_experts_impl_on_the_fly @@ -529,6 +633,101 @@ def forward_loop(m): self._cleanup_registry(expert_type) + def test_bootstrap_populates_dead_expert_quantizers(self): + """`_bootstrap_uncalibrated_weight_quantizers` fills `_amax` on experts the + forward pass never routed to. + + Regression for the dead-expert MSE skip: with partial routing during max + calibration, never-routed experts' weight quantizers stay with + ``_amax=None``; bootstrap must run the calibrator on the per-expert weight + slice (via ``iter_weights_for_calibration``) to populate them so MSE + doesn't skip them. + """ + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.model_calib import ( + _bootstrap_uncalibrated_weight_quantizers, + ) + + model = _TinyMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + quant_cfg = { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*gate_up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + # Forward loop that routes only to experts 0 and 1 (deterministic). + # Bypasses the router and calls experts directly with crafted indices. + live = {0, 1} + dead = {idx for idx in range(NUM_EXPERTS) if idx not in live} + assert dead, "Test requires at least one dead expert" + + def partial_forward(m): + torch.manual_seed(0) + seq_len = 8 + hidden = torch.randn(seq_len, HIDDEN_DIM) + top_k_index = torch.zeros(seq_len, TOP_K, dtype=torch.long) + top_k_index[:, 0] = 0 + top_k_index[:, 1] = 1 + top_k_weights = torch.ones(seq_len, TOP_K) / TOP_K + with torch.no_grad(): + m.moe.experts(hidden, top_k_index, top_k_weights) + + mtq.quantize(model, quant_cfg, forward_loop=partial_forward) + + experts = model.moe.experts + + # Pre-bootstrap: dead experts have no/zero _amax. + for idx in dead: + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert getattr(gu_q, "_amax", None) is None or torch.all(gu_q._amax == 0), ( + f"Dead expert {idx} gate_up_proj should be uncalibrated pre-bootstrap" + ) + assert getattr(d_q, "_amax", None) is None or torch.all(d_q._amax == 0), ( + f"Dead expert {idx} down_proj should be uncalibrated pre-bootstrap" + ) + + n_bootstrapped = _bootstrap_uncalibrated_weight_quantizers(model) + assert n_bootstrapped >= 2 * len(dead), ( + f"Expected ≥{2 * len(dead)} bootstrapped (gate_up + down per dead expert), " + f"got {n_bootstrapped}" + ) + + # Post-bootstrap: every expert has populated _amax matching max(|weight|). + for idx in range(NUM_EXPERTS): + gu_q = experts.gate_up_proj_weight_quantizers[idx] + d_q = experts.down_proj_weight_quantizers[idx] + assert gu_q._amax is not None and not torch.all(gu_q._amax == 0), ( + f"Expert {idx} gate_up_proj _amax not populated after bootstrap" + ) + assert d_q._amax is not None and not torch.all(d_q._amax == 0), ( + f"Expert {idx} down_proj _amax not populated after bootstrap" + ) + + # For dead experts, bootstrap reads max(|weight|). Sanity-check it matches + # the actual weight tensor's per-row max (axis=0 reduces over hidden_dim). + for idx in dead: + expected = experts.gate_up_proj.data[idx].abs().amax(dim=1) + got = experts.gate_up_proj_weight_quantizers[idx]._amax.flatten() + assert torch.allclose(got, expected, atol=1e-4), ( + f"Expert {idx} bootstrap amax should equal per-row max(|weight|); " + f"max diff {(got - expected).abs().max().item()}" + ) + + self._cleanup_registry(expert_type) + # --------------------------------------------------------------------------- # Tests for export enumeration — guards the bug where fused-experts were