Add rotary embedding onnx domain support #29261

Open
xiaoyu-work wants to merge 3 commits into main from xiaoyu/re

@xiaoyu-work

Contributor

Description

Mobius exports the standard ONNX RotaryEmbedding op. This PR adds support for it.

Motivation and Context

Copilot AI left a comment

Contributor

Pull request overview

This PR updates the GroupQueryAttention fusion optimizer pass to recognize standard (ONNX-domain) RotaryEmbedding nodes when extracting the cos_cache/sin_cache inputs needed to fuse rotary embedding into the com.microsoft.GroupQueryAttention node.

Changes:

  • Adds a helper to retrieve cos_cache/sin_cache NodeArgs for both com.microsoft.RotaryEmbedding and ONNX-domain RotaryEmbedding (different input ordering).
  • Updates the fusion pattern-matching logic to use that helper and to require that rotary cache inputs were successfully identified before fusing.
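The input-ordering difference mentioned in the first bullet can be sketched as follows. This is a minimal illustration, not the PR's code: `SketchNode`, `RotaryArgs`, and `TryGetRotaryArgsSketch` are hypothetical names, inputs are modeled as plain strings instead of NodeArgs, and an empty domain string is assumed to stand for the standard ONNX domain. The contrib-op ordering (input, position_ids, cos_cache, sin_cache) is taken from the com.microsoft operator schema as I understand it.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph node; this is not the onnxruntime Node class.
struct SketchNode {
  std::string op_type;
  std::string domain;               // "" stands for the standard ONNX domain here
  std::vector<std::string> inputs;  // an empty name marks an omitted optional input
};

struct RotaryArgs {
  std::string cos_cache;
  std::string sin_cache;
  std::string position_ids;
};

// Returns true and fills `out` only when the node is a recognized rotary op
// and cos_cache, sin_cache, and position_ids are all present.
inline bool TryGetRotaryArgsSketch(const SketchNode& node, RotaryArgs& out) {
  if (node.op_type != "RotaryEmbedding") return false;

  std::size_t cos_idx = 0, sin_idx = 0, pos_idx = 0;
  if (node.domain == "com.microsoft") {
    // Contrib op ordering: (input, position_ids, cos_cache, sin_cache)
    pos_idx = 1; cos_idx = 2; sin_idx = 3;
  } else if (node.domain.empty()) {
    // ONNX op ordering: (X, cos_cache, sin_cache, position_ids)
    cos_idx = 1; sin_idx = 2; pos_idx = 3;
  } else {
    return false;
  }

  // position_ids is optional on the ONNX op; this sketch bails when it is absent.
  if (node.inputs.size() < 4) return false;
  if (node.inputs[cos_idx].empty() || node.inputs[sin_idx].empty() ||
      node.inputs[pos_idx].empty()) {
    return false;
  }

  out.cos_cache = node.inputs[cos_idx];
  out.sin_cache = node.inputs[sin_idx];
  out.position_ids = node.inputs[pos_idx];
  return true;
}
```

Keeping the index selection in one place means the rest of the pattern matcher does not need to know which domain the rotary node came from.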

Comment thread onnxruntime/core/optimizer/group_query_attention_fusion.cc
Comment thread onnxruntime/core/optimizer/group_query_attention_fusion.cc Outdated
@titaiwangms

Contributor

cc @tianleiwu

@tianleiwu tianleiwu left a comment

Contributor

Summary

The change extends GroupQueryAttentionFusion to also match the standard ONNX-domain RotaryEmbedding (X, cos_cache, sin_cache, position_ids) in addition to com.microsoft.RotaryEmbedding, plumbing position_ids through to GQA input 9 and setting the input-arg-count. The approach is sound:

  • Requiring position_ids to be present for the ONNX-domain path (and bailing otherwise) correctly avoids the 3D per-batch cos/sin cache form that GQA's 2D rotary cache validation cannot consume.
  • The position_ids_arg_mismatch guard and the added cos_cache_arg == nullptr || sin_cache_arg == nullptr checks make the fusion safely skip ambiguous/mixed cases.
  • MutableInputArgsCount()[9] is in-bounds because the GQA schema declares formal inputs up to index 11, so UpdateInputArgCount() sizes the vector accordingly.
  • Both the fused (with position_ids) and non-fused (omitted position_ids) cases are covered by new tests.

Main concern

Rotary interleaved / rotary_embedding_dim attributes are not validated or propagated. GQA's do_rotary path runs non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and the fusion never sets it). A standard ONNX RotaryEmbedding with interleaved=1, or a partial-rotary node with rotary_embedding_dim > 0 (and a correspondingly narrower cos/sin cache), is silently fused into a GQA that applies a different rotation — producing incorrect results with no error. Since this PR specifically targets a standard-ONNX export path where interleaved RoPE is common, the fusion should either verify interleaved == 0 and rotary_embedding_dim == 0 (full rotary) before matching, or propagate interleaved to GQA's rotary_interleaved. Inline comment below. (Note: the pre-existing com.microsoft.RotaryEmbedding path has the same latent gap.)
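To make the mismatch concrete, here is a small sketch of the two pairing schemes applied to one head vector. It is an illustration only: `ApplyRopeSketch` is a hypothetical helper, not an onnxruntime kernel, and it assumes full-width rotary with one cos/sin value per pair. The non-interleaved form rotates element i with element i + half, while the interleaved form rotates elements 2i and 2i + 1.

```cpp
#include <cstddef>
#include <vector>

// Applies a rotary embedding to a single head vector `x`.
// cos_row and sin_row each hold x.size() / 2 values (one per rotated pair).
inline std::vector<float> ApplyRopeSketch(const std::vector<float>& x,
                                          const std::vector<float>& cos_row,
                                          const std::vector<float>& sin_row,
                                          bool interleaved) {
  const std::size_t half = x.size() / 2;
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < half; ++i) {
    // Pick the two elements that form the i-th rotated pair.
    const std::size_t a = interleaved ? 2 * i : i;
    const std::size_t b = interleaved ? 2 * i + 1 : i + half;
    out[a] = x[a] * cos_row[i] - x[b] * sin_row[i];
    out[b] = x[a] * sin_row[i] + x[b] * cos_row[i];
  }
  return out;
}
```

With x = {1, 2, 3, 4}, cos = {0, 0}, and sin = {1, 1} (a 90 degree rotation), the non-interleaved form yields {-3, -4, 1, 2} and the interleaved form yields {-2, 1, -4, 3}, so fusing one into a kernel that computes the other changes the output without raising any error.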

Minor

  • Only position_ids is checked for consistency between the two rotary nodes; cos_cache/sin_cache are taken from whichever rotary node is visited first without verifying the second uses the same caches. In practice Q/K share caches, but an explicit equality guard (mirroring the position_ids_arg_mismatch check) would make the fusion robust to malformed graphs.
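A guard of the kind suggested above could look roughly like this. It is a sketch under assumptions: `RotaryInputsSketch` and `RotaryInputsMatch` are hypothetical names, and pointer identity is used as a stand-in for comparing the graph-owned argument objects that the Q and K rotary nodes reference.

```cpp
// Hypothetical record of the inputs one rotary node contributes to the fusion.
// Pointers model graph-owned argument identities; they are never dereferenced.
struct RotaryInputsSketch {
  const void* cos_cache;
  const void* sin_cache;
  const void* position_ids;
};

// Returns true only when both rotary nodes carry caches and reference the
// same cos_cache, sin_cache, and position_ids arguments.
inline bool RotaryInputsMatch(const RotaryInputsSketch& q,
                              const RotaryInputsSketch& k) {
  if (q.cos_cache == nullptr || q.sin_cache == nullptr) return false;
  return q.cos_cache == k.cos_cache &&
         q.sin_cache == k.sin_cache &&
         q.position_ids == k.position_ids;
}
```

The fusion would skip the pattern whenever this returns false, which treats a cache mismatch the same way as a position_ids mismatch.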

return false;
}
cos_cache_arg = input_defs[1];
sin_cache_arg = input_defs[2];

Contributor

TryGetRotaryEmbeddingArgs matches an ONNX RotaryEmbedding purely by op type, domain, and input presence, but never inspects the interleaved or rotary_embedding_dim attributes. GQA's do_rotary applies non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and is not set anywhere in this pass). So a node with interleaved=1, or partial rotary (rotary_embedding_dim > 0 with a narrower cos/sin cache), would be silently fused into a GQA that computes a different rotation — a silent numerical mismatch rather than a hard failure.

Consider rejecting the match when interleaved != 0 or rotary_embedding_dim != 0 (full rotary only), or propagate interleaved onto the fused GQA's rotary_interleaved attribute. The same gap exists for the com.microsoft.RotaryEmbedding branch above.
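One way to combine both options (reject partial rotary, propagate interleaved) is sketched below. The names `RotaryAttrsSketch` and `ResolveRotaryInterleaved` are hypothetical, and the attribute values are assumed to have already been read from the Q and K rotary nodes with a default of 0 when absent.

```cpp
#include <cstdint>

// Hypothetical snapshot of the two attributes discussed above.
struct RotaryAttrsSketch {
  std::int64_t interleaved;
  std::int64_t rotary_embedding_dim;
};

// Returns false when the pair must not be fused. On success, writes the value
// to set as rotary_interleaved on the fused GQA node.
inline bool ResolveRotaryInterleaved(const RotaryAttrsSketch& q,
                                     const RotaryAttrsSketch& k,
                                     std::int64_t& rotary_interleaved) {
  // Partial rotary uses a narrower cos/sin cache, so this sketch rejects it.
  if (q.rotary_embedding_dim != 0 || k.rotary_embedding_dim != 0) return false;
  // Q and K must agree, since the fused node has a single rotary_interleaved.
  if (q.interleaved != k.interleaved) return false;
  rotary_interleaved = q.interleaved;
  return true;
}
```

Returning the resolved value through an out parameter keeps the accept/reject decision and the attribute to propagate in a single call.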

Contributor Author

Fixed by propagating interleaved to GQA’s rotary_interleaved and rejecting Q/K RotaryEmbedding mismatches.
