Skip to content

perf(rms_norm): use fused reduce_l2_norm path (~48× faster)#1

Closed
sbryngelson wants to merge 2 commits into
comp-physics:mainfrom
sbryngelson:perf/rms-norm-fast-reduce
Closed

perf(rms_norm): use fused reduce_l2_norm path (~48× faster)#1
sbryngelson wants to merge 2 commits into
comp-physics:mainfrom
sbryngelson:perf/rms-norm-fast-reduce

Conversation

@sbryngelson

Copy link
Copy Markdown
Member

Summary

rms_norm lowered to a reshape to [M,D,1,1] followed by reduce_sum over the
channel axis. That reduction falls off the ANE's fast reduction tile past
~256 rows, so rms_norm ran 16–27× slower than layer_norm — even though
RMS-norm is structurally cheaper (one reduction, no centering) — and scaled
super-linearly with row count.

Measured on M5 (H17s), gamma=1:

shape before after speedup
256×256 160 µs ~110 µs 1.5×
512×512 1791 µs ~110 µs 16×
1024×1024 6882 µs 168 µs 41×
2048×1024 13665 µs 257 µs 53×

Fix

Re-lower through the same fused reduce_l2_norm over the last axis that
l2_norm already uses, since the two are mathematically identical:

rms(x) = x / sqrt(mean(x²)+eps) · g  =  x · √D / sqrt(sum(x²)) · g     (reduce_l2_norm = sqrt(sum x²))

eps becomes a safe-divide floor on the norm (same pattern as l2_norm), and
the √D rescale is folded into the gamma weight, so the emitted op count drops
as well (no [M,D,1,1] reshape, no separate square/reduce_sum/scale/rsqrt chain).

Correctness / tests

  • Max error vs fp32 reference RMS ≤ 0.007 across shapes (existing
    test_nn_blocks::rms_norm_linear_silu tolerance is 0.03).
  • tests/test_builder_guards.py (2D guard) unchanged and passing.
  • Full suite: 527 passed on M5/H17s.

RMS-norm is ubiquitous in modern transformers (LLaMA/Qwen/etc.), so this is a
high-impact, low-risk lowering fix.

The rms_norm lowering reshaped to [M,D,1,1] and ran reduce_sum over the
channel axis, which falls off the ANE's fast reduction tile past ~256 rows:
it ran 16-27x slower than layer_norm (despite RMS being structurally cheaper)
and scaled super-linearly with row count (6882us at 1024x1024, 13.7ms at
2048x1024).

Re-lower through the same fused reduce_l2_norm over the last axis that l2_norm
already uses, since the two are mathematically identical:
    rms(x) = x / sqrt(mean(x^2)+eps) * g = x * sqrt(D) / sqrt(sum(x^2)) * g
eps becomes a safe-divide floor on the norm (as in l2_norm) and the sqrt(D)
rescale is folded into the gamma weight, so the op count drops too.

Measured on M5 (H17s), gamma=1:
    1024x1024: 6882us -> 168us  (41x), max err vs fp32 RMS 0.007
    2048x1024: 13665us -> 257us (53x)
Full pytest suite: 527 passed.
sbryngelson added a commit that referenced this pull request Jun 17, 2026
Drop ASCII-art dividers, colourise example output, ASCII-only source
@sbryngelson sbryngelson deleted the perf/rms-norm-fast-reduce branch June 19, 2026 22:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant