Skip to content

[minimax_m3] gate AR+RMSNorm fusion on ATOM_ENABLE_ALLREDUCE_RMSNORM_…#1344

Draft
zejunchen-zejun wants to merge 1 commit into
wuhuikx/atom-m3-bf16-to-mainfrom
zejun/disable_rmsnorm_ar_fusion
Draft

[minimax_m3] gate AR+RMSNorm fusion on ATOM_ENABLE_ALLREDUCE_RMSNORM_…#1344
zejunchen-zejun wants to merge 1 commit into
wuhuikx/atom-m3-bf16-to-mainfrom
zejun/disable_rmsnorm_ar_fusion

Conversation

@zejunchen-zejun

Copy link
Copy Markdown
Collaborator

…FUSION

Previously M3's target-model fusion was hardcoded on (fused_allreduce_gemma_rms_norm fused all-reduce + residual-add + Gemma RMSNorm whenever TP>1), ignoring the ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION env that the draft / deepseek_v2 / glm4_moe already honor — so the env couldn't disable M3's fusion for A/B testing.

Make both halves of the fusion move together under the flag (default on):

  • layernorm.py fused_allreduce_gemma_rms_norm: also require envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION; when off it returns a plain residual-add + Gemma RMSNorm (expects an already-all-reduced input).
  • minimax_m3.py: the RowParallel o_proj (dense + sparse attn), MoE experts, shared_experts and dense MLP down_proj now use reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, so when the env is off they do their own all-reduce and the norm runs unfused.

Default (=1) is byte-equivalent to before. Set ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=0 to disable M3 target fusion (linears all-reduce, norm unfused).

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

…FUSION

Previously M3's target-model fusion was hardcoded on (fused_allreduce_gemma_rms_norm
fused all-reduce + residual-add + Gemma RMSNorm whenever TP>1), ignoring the
ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION env that the draft / deepseek_v2 / glm4_moe
already honor — so the env couldn't disable M3's fusion for A/B testing.

Make both halves of the fusion move together under the flag (default on):
- layernorm.py fused_allreduce_gemma_rms_norm: also require
  envs.ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION; when off it returns a plain
  residual-add + Gemma RMSNorm (expects an already-all-reduced input).
- minimax_m3.py: the RowParallel o_proj (dense + sparse attn), MoE experts,
  shared_experts and dense MLP down_proj now use
  reduce_results=not ENABLE_ALLREDUCE_RMSNORM_FUSION, so when the env is off they
  do their own all-reduce and the norm runs unfused.

Default (=1) is byte-equivalent to before. Set ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION=0
to disable M3 target fusion (linears all-reduce, norm unfused).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@zufayu zufayu requested review from JiaoliangYu and ZhangLirong-amd and removed request for ZhangLirong-amd June 26, 2026 06:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant