Skip to content

[Relax][ONNX] Fix LayerNormalization no-bias zero tensor shape and dtype#19772

Merged
tlopex merged 2 commits into
apache:mainfrom
javierdejesusda:fix/19691-onnx-layernorm-nobias-shape
Jun 15, 2026
Merged

[Relax][ONNX] Fix LayerNormalization no-bias zero tensor shape and dtype#19772
tlopex merged 2 commits into
apache:mainfrom
javierdejesusda:fix/19691-onnx-layernorm-nobias-shape

Conversation

@javierdejesusda

Copy link
Copy Markdown
Contributor

Root cause

In the ONNX LayerNormalization spec the bias B is optional; when omitted it should behave as
zeros shaped and typed like the scale W. In LayerNormalization._impl_v17, the synthesized zero
bias instead took its shape from data.struct_info.shape[1] (an unrelated data dim) and hardcoded
dtype="float32". For input [2, 3, 4, 8] with scale [8] and axis=-1 this builds a bias of
shape (3,) while gamma is (8,), so relax.op.nn.layer_norm raises a size-mismatch
InternalError. The float32 hardcode also breaks fp16/bf16 no-bias models, since gamma, beta, and
data must share a dtype. PyTorch's nn.LayerNorm(..., bias=False) exports exactly this no-bias form.

Fix

Derive both the shape and dtype of the synthesized zero bias from the scale, matching the ONNX
semantics for an omitted B and the existing torch frontend
(relax.const(np.zeros(shape), x.struct_info.dtype)):

if bias is None:
    bias = relax.const(_np.zeros(gamma_shape, dtype=scale.struct_info.dtype))

gamma_shape and the _np/get_const_tuple imports are already present. Deriving the dtype from
the scale (rather than the issue's float32-only suggestion) is what also fixes the fp16/bf16 case.

Test plan

Added non-square no-bias regression cases to test_frontend_onnx.py::test_layer_norm (the previous
no-bias case was square, which masked the bug): float32 [2,3,4,8]/scale [8] and float16 with
full check_correctness, plus a bf16 importer-only case (ORT's CPU provider has no bf16
LayerNormalization kernel).

Fixes #19691

When the optional bias input of LayerNormalization is omitted, the zero
bias was built from data.struct_info.shape[1] and hardcoded to float32
instead of following the scale (gamma) tensor. For a non-square input
such as [2, 3, 4, 8] with scale [8], this produced a bias of shape (3,)
while gamma is (8,), so relax.op.nn.layer_norm raised an InternalError
on the size mismatch. For a half-precision model with no bias, the
float32 bias was rejected because gamma, beta, and data must share one
dtype.

Synthesize the zero bias from gamma_shape and the scale dtype, matching
ONNX semantics where an omitted B is treated as zeros shaped and typed
like the scale. Add non-square no-bias regression cases: an fp16 case
checked end to end and a bf16 case checked through the importer, since
ONNX Runtime's CPU provider has no bf16 LayerNormalization kernel.

Fixes apache#19691

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request fixes an issue in the ONNX frontend's LayerNormalization importer when no bias is provided. It updates the fallback bias creation to use the shape and data type of the scale (gamma_shape and scale.struct_info.dtype) instead of hardcoding a float32 array based on the input's second dimension. It also adds corresponding unit tests for non-square inputs, float16, and bfloat16 data types. A review comment correctly points out that using bfloat16 directly in _np.zeros will raise a TypeError in standard NumPy environments, and suggests constructing the NumPy array as float32 and letting relax.const handle the conversion to the target dtype.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread python/tvm/relax/frontend/onnx/onnx_frontend.py Outdated
np.zeros rejects TVM dtype strings that NumPy lacks natively, so
np.zeros(gamma_shape, dtype="bfloat16") raises "data type 'bfloat16' not
understood". relax.const imports ml_dtypes and casts internally, but its
np.zeros argument is evaluated first, so that import is too late. Build
the zeros array with a native dtype and pass the target dtype to
relax.const, matching the existing torch frontend convention.
@tlopex tlopex merged commit 44b55a0 into apache:main Jun 15, 2026
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug][Relax][ONNX] LayerNormalization without bias synthesizes a wrong-shape zero tensor and fails

2 participants