[Quantization] MSE-calibrate every per-expert weight in fused-experts MoE #1421
@@ -52,7 +52,6 @@
     promote_nvfp4_static_quantizers,
     quantizer_attr_names,
     reduce_amax,
-    weight_attr_names,
 )
 from .utils.calib_utils import _GPTQ_HELPER_REGISTRY, GPTQHelper
@@ -66,6 +65,107 @@
    "svdquant",
]


def _is_calibrated_nvfp4_static(q) -> bool:
    """True iff ``q`` is an enabled NVFP4-static weight quantizer with ``_amax`` set."""
    return (
        isinstance(q, TensorQuantizer)
        and not q._disabled
        and q.is_nvfp4_static
        and getattr(q, "_amax", None) is not None
    )


def _collect_grouped_linears(model: nn.Module) -> list[list[nn.Module]]:
    """Collect sibling groups (Q/K/V, gate/up) with calibrated NVFP4-static weight quantizers."""
|
Contributor
if this is only specific to HF it also belongs in

Collaborator (Author)
Same reasoning as in the reply to your other comment about moving HF-specific functions:

    # Inline: layer_utils → quant_utils → model_calib cycle.
    from modelopt.torch.export.layer_utils import _GATE_UP_PAIRS

    # Reuses the existing gate/up pairs and adds Q/K/V (no equivalent constant
    # in export). Single source for the gate/up half avoids parallel lists.
    patterns: tuple[tuple[str, ...], ...] = (("q_proj", "k_proj", "v_proj"), *_GATE_UP_PAIRS)
    groups: list[list[nn.Module]] = []
    wq_attr = quantizer_attr_names("weight").weight_quantizer
    for parent in model.modules():
        for sibling_names in patterns:
            members = [
                child
                for child in (getattr(parent, n, None) for n in sibling_names)
                if child is not None and _is_calibrated_nvfp4_static(getattr(child, wq_attr, None))
            ]
            if len(members) >= 2:
                groups.append(members)
    return groups


@torch.no_grad()
def _bootstrap_uncalibrated_weight_quantizers(model: nn.Module) -> int:
|
Contributor
why do we need this during calibration time and how is this different from what we do at export time

Collaborator (Author)
Good question. Two parts.

Why calibration-time, not export-time: bootstrapping at export would defeat the PR's main feature ("MSE-calibrate every per-expert weight in fused-experts MoE"); dead experts would just go back to max-only. The bootstrap is timed before MSE/sync/promotion to bring dead experts into all three.

Reuse of the existing helper: reasonable as a follow-up DRY win, but the two differ in detail today. Factoring out a shared inner step would save a handful of lines but pulls in another

    """Re-run weight calibration on the weight tensor for quantizers missing ``_amax``.

    Covers MoE experts that ``max_calibrate`` skipped (no routed tokens) so MSE
    doesn't drop them and break the gate==up ``weight_scale_2`` export invariant.
    Activation quantizers on those modules remain uncalibrated; emits a warning.
    """
    name_to_module = dict(model.named_modules())
    n = 0
    for module in name_to_module.values():
        if not isinstance(module, QuantModule):
            continue
        with enable_weight_access_and_writeback(module, model, name_to_module):
            for weight, q in module.iter_weights_for_calibration():
                if not isinstance(q, TensorQuantizer) or q._disabled or q._dynamic:
                    continue
                if q._calibrator is None:
                    continue
                if getattr(q, "_amax", None) is not None and not torch.all(q._amax == 0):
                    continue
                q.disable_quant()
                q.enable_calib()
                q(weight)
                if q._calibrator.compute_amax() is not None:
                    q.load_calib_amax()
                q.enable_quant()
                q.disable_calib()
                if hasattr(q._calibrator, "reset"):
                    q._calibrator.reset()
                n += 1
    if n > 0:
        warnings.warn(
            f"Bootstrapped {n} weight quantizer(s) with no routed calibration tokens; "
            f"their activation quantizers (if any) remain uncalibrated. "
            f"Increase calib size/seq len to activate all experts.",
            stacklevel=2,
        )
    return n
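
To make the timing argument in the thread above concrete, here is a minimal sketch of the Step 1 ordering, using the function names from this diff (`model`, `forward_loop`, and `distributed_sync` are assumed to be in scope; a sketch, not a standalone implementation):

```python
# Calibration-time ordering (mirrors Step 1 of mse_calibrate in this PR).
max_calibrate(model, forward_loop, distributed_sync)  # dead experts end up with no _amax
_bootstrap_uncalibrated_weight_quantizers(model)      # seed amax from the weight tensor
_sync_grouped_weight_global_amax(model)               # unify sibling global_amax
# MSE refinement then covers every expert; an export-time bootstrap would run
# after MSE and leave dead experts max-only.
```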


@torch.no_grad()
def _sync_grouped_weight_global_amax(model: nn.Module) -> int:
|
Contributor
This function is also useful for other algorithms like AWQ and GPTQ. We need to sync the amax for fused layers before some algorithm begins. cc @sychen52 on the design of fused modules.

Contributor
can you move any HF-specific functions to plugins/huggingface.py?

Collaborator (Author)
The helpers in this section operate on generic types (
The HF plugin's role is registering specialized quant modules (like

    """Unify NVFP4 ``global_amax`` across Q/K/V and gate/up sibling weight quantizers.

    Run after ``max_calibrate``. Sibling discovery is name-based via
    ``_collect_grouped_linears``; non-matching architectures (wqkv, fused
    qkv_proj, DeepSeek variants, single-Linear fused gate_up_proj) silently
    fall back to per-module global_amax. Fused-experts containers already
    share a single quantizer across gate/up halves and need no sync.
    """
    # quant_utils imports back from this module; top-level would cycle.
    from modelopt.torch.export.quant_utils import preprocess_linear_fusion

    wq_attr = quantizer_attr_names("weight").weight_quantizer
    n_groups = 0
    for group in _collect_grouped_linears(model):
        for child in group:
            wq = getattr(child, wq_attr)
            if not isinstance(wq, NVFP4StaticQuantizer):
                NVFP4StaticQuantizer.from_tensor_quantizer(
                    wq, global_amax=reduce_amax(wq._amax, axis=None)
                )
        preprocess_linear_fusion(group)
        n_groups += 1
    return n_groups
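
For intuition about what this sync protects, a toy sketch of two-level NVFP4 weight scaling (assumed formulas following the common NVFP4 recipe — E2M1 values with max 6.0, E4M3 block scales with max 448.0 under one global scale; not ModelOpt's actual implementation):

```python
import torch

def toy_nvfp4_scales(w: torch.Tensor, global_amax: torch.Tensor, block: int = 16):
    """Two-level NVFP4 scales for a toy weight under a given global amax."""
    weight_scale_2 = global_amax / (6.0 * 448.0)         # global FP32 scale
    block_amax = w.reshape(-1, block).abs().amax(dim=1)  # per 16-element block
    block_scale = block_amax / 6.0 / weight_scale_2      # stored quantized to FP8 (E4M3)
    return weight_scale_2, block_scale

# Fused gate/up export shares one weight_scale_2, so siblings must agree on
# global_amax before MSE tunes block scales on that shared FP8 grid:
w_gate, w_up = torch.randn(256, 64), torch.randn(256, 64)
shared = torch.maximum(w_gate.abs().amax(), w_up.abs().amax())
s2_gate, _ = toy_nvfp4_scales(w_gate, shared)
s2_up, _ = toy_nvfp4_scales(w_up, shared)
assert s2_gate == s2_up  # identical global scale across the fused pair
```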


CalibratorFactory: TypeAlias = Callable[
    [torch.Tensor, int | tuple | list | None, Callable[..., torch.Tensor]], _Calibrator
]


@@ -346,32 +446,23 @@ def mse_calibrate(
     See :class:`MseCalibConfig <modelopt.torch.quantization.config.MseCalibConfig>` for
     details on the remaining arguments.
     """
-    # Step 1: First get initial amax using max calibration
+    # Step 1: max calibrate, bootstrap dead-expert weight quantizers,
+    # unify grouped NVFP4 global_amax so MSE sees a consistent FP8 grid.
     max_calibrate(model, forward_loop, distributed_sync)
+    _bootstrap_uncalibrated_weight_quantizers(model)
|
Contributor
I wonder why the dead-expert weight quantizers do not break the MAX-calibration export path with the NVFP4 dynamic quantizer. i.e., does

Collaborator (Author)
No. Max calibrate can be covered at export time. @Edwardf0t1 added the logic that covers the unquantized experts.

+    _sync_grouped_weight_global_amax(model)

-    # Step 2: Replace calibrators with MseCalibrator for enabled quantizers
-    # and identify weight quantizers
-    weight_quantizers = []
-    seen_modules = set()
+    # Step 2: replace calibrators with MseCalibrator for enabled quantizers.
     for name, module in list(model.named_modules()):
         if isinstance(module, TensorQuantizer) and not module._disabled:
             if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
                 # Get the initial amax from max calibration
                 initial_amax = module._amax.clone().detach()
+                is_nvfp4_static = module.is_nvfp4_static

-                is_nvfp4_static = (
-                    module.is_static_block_quant
-                    and module._num_bits == (2, 1)
-                    and module._block_sizes is not None
-                    and module._block_sizes.get("scale_bits") == (4, 3)
-                )

-                if is_nvfp4_static:
-                    # Compute and set global_amax
+                # Promote standalone NVFP4-static quantizers; grouped siblings
+                # already promoted by _sync_grouped_weight_global_amax above.
+                if is_nvfp4_static and not isinstance(module, NVFP4StaticQuantizer):
                     global_amax = reduce_amax(initial_amax, axis=None)

                     # Convert to NVFP4StaticQuantizer in-place
                     NVFP4StaticQuantizer.from_tensor_quantizer(module, global_amax=global_amax)

                 if fp8_scale_sweep:

@@ -412,52 +503,48 @@ def mse_calibrate(
                     quant_func=partial(_mse_quant_func, quantizer=module),
                 )

-    # Identify weight quantizers by checking if they have corresponding weight parameters
+    # Step 3: calibrate weight quantizers via iter_weights_for_calibration.

Contributor
this comment is too long and specific to fused experts. Remove the mention of fused experts to make it more generic, and in general reduce the length of AI comments.

     name_to_module = dict(model.named_modules())
+    seen_modules: set[int] = set()
+    pbar = tqdm(desc="MSE weight calibration")
+    n_calibrated = 0
     for parent_module in name_to_module.values():
-        if parent_module in seen_modules:
+        if id(parent_module) in seen_modules or not isinstance(parent_module, QuantModule):
             continue
-        for weight_name in weight_attr_names(parent_module):
-            weight_quantizer_name = quantizer_attr_names(weight_name).weight_quantizer
-            weight_quantizer = getattr(parent_module, weight_quantizer_name, None)
-            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
-                if getattr(weight_quantizer, "_calibrator", None) is not None:
-                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
-        seen_modules.add(parent_module)

-    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
-    # This prevents massive memory accumulation seen in large models
-    for idx, (parent_module, weight_name, weight_quantizer) in enumerate(
-        tqdm(weight_quantizers, desc="MSE weight calibration")
-    ):
-        # Enable calibration mode for the weight quantizer
-        weight_quantizer.disable_quant()
-        weight_quantizer.enable_calib()
+        seen_modules.add(id(parent_module))
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
-            weight_quantizer(weight)
+            for weight, weight_quantizer in parent_module.iter_weights_for_calibration():
+                if not (
+                    isinstance(weight_quantizer, TensorQuantizer)
+                    and weight_quantizer.is_enabled
+                    and getattr(weight_quantizer, "_calibrator", None) is not None
+                ):
+                    continue
+                weight_quantizer.disable_quant()
+                weight_quantizer.enable_calib()
+                weight_quantizer(weight)

-        # IMMEDIATELY compute amax and reset calibrator to free memory
-        cal = getattr(weight_quantizer, "_calibrator", None)
-        if cal is not None and cal.compute_amax() is not None:
-            weight_quantizer.load_calib_amax()
+                cal = weight_quantizer._calibrator
+                if cal.compute_amax() is not None:
+                    weight_quantizer.load_calib_amax()

-        weight_quantizer.enable_quant()
-        weight_quantizer.disable_calib()
+                weight_quantizer.enable_quant()
+                weight_quantizer.disable_calib()

-        # Synchronize ALL CUDA devices before resetting to ensure all async operations complete

Contributor
why is this comment removed?

Collaborator (Author)
The old comment described the previous

-        # This is critical for multi-GPU setups where tensors may be on different devices
-        if torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                if torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))

-        if cal is not None and hasattr(cal, "reset"):
-            cal.reset()
+                if hasattr(cal, "reset"):
+                    cal.reset()

-        if (idx + 1) % 10 == 0 and torch.cuda.is_available():
-            for dev_id in range(torch.cuda.device_count()):
-                torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
-            torch.cuda.empty_cache()
+                pbar.update(1)
+                n_calibrated += 1
+                if n_calibrated % 10 == 0 and torch.cuda.is_available():
+                    for dev_id in range(torch.cuda.device_count()):
+                        torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
+                    torch.cuda.empty_cache()
+    pbar.close()

     if torch.cuda.is_available():
         for dev_id in range(torch.cuda.device_count()):

@@ -612,6 +699,8 @@ def forward(self, input, *args, **kwargs):
     print_rank_0("local_hessian: Running max calibration for all quantizers...")
     max_calibrate(model, forward_loop, distributed_sync)

+    _sync_grouped_weight_global_amax(model)

     # Setup helpers for all quantized linear modules
     name_to_module = dict(model.named_modules())
     weight_quantizers_info = []

@@ -666,14 +755,9 @@ def quant_func(x, amax, quantizer=weight_quantizer):

                     return xq

-                is_nvfp4_static = (
-                    weight_quantizer.is_static_block_quant
-                    and weight_quantizer._num_bits == (2, 1)
-                    and weight_quantizer._block_sizes is not None
-                    and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
-                )
+                is_nvfp4_static = weight_quantizer.is_nvfp4_static

-                if is_nvfp4_static:
+                if is_nvfp4_static and not isinstance(weight_quantizer, NVFP4StaticQuantizer):
                     global_amax = reduce_amax(initial_amax, axis=None)
                     NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)

@@ -900,6 +900,24 @@ def forward(self, *args, **kwargs):
             self._down_proj_linear = False
         return super().forward(*args, **kwargs)

    def iter_weights_for_calibration(self):
|
Contributor
this name is misleading, you are only iterating fused-expert weights, not all weights

Collaborator (Author)
This is an override of the base

        """Yield ``(weight_slice, quantizer)`` per-expert pairs.

        The base impl uses singular ``*_weight_quantizer`` and skips fused-
        experts modules, so weight-only calibration never reaches per-expert
        quantizers without this override.
        """
        for weight_name, quantizers_name in (
            ("gate_up_proj", "gate_up_proj_weight_quantizers"),
            ("down_proj", "down_proj_weight_quantizers"),
        ):
            weight = getattr(self, weight_name, None)
            quantizers = getattr(self, quantizers_name, None)
            if weight is None or quantizers is None:
                continue
            for idx, q in enumerate(quantizers):
                yield weight[idx], q
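
A self-contained toy showing the iterator contract the override implements — one `(weight_slice, quantizer)` pair per expert from a fused 3-D weight (hypothetical shapes and a stub quantizer object; the real module yields TensorQuantizer instances):

```python
import torch

class ToyFusedExperts:
    """Stub fused-experts module: one 3-D weight, one quantizer per expert."""

    def __init__(self, num_experts: int = 4, rows: int = 8, cols: int = 16):
        self.gate_up_proj = torch.randn(num_experts, rows, cols)
        self.gate_up_proj_weight_quantizers = [object() for _ in range(num_experts)]

    def iter_weights_for_calibration(self):
        for idx, q in enumerate(self.gate_up_proj_weight_quantizers):
            yield self.gate_up_proj[idx], q  # 2-D expert slice + its quantizer

m = ToyFusedExperts()
pairs = list(m.iter_weights_for_calibration())
assert len(pairs) == 4 and pairs[0][0].shape == (8, 16)  # one pair per expert
```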
|

    def fold_weight(self, keep_attrs: bool = False):
        """Fold per-expert weight quantizers into the fused 3-D weights.