// Patch summary (reviewer commentary; the patch text below is unchanged).
//
// Purpose: let GroupQueryAttentionFusion accept RotaryEmbedding nodes from either kMSDomain or
// kOnnxDomain, and carry position_ids and the "interleaved" attribute over to the fused
// GroupQueryAttention node.
//
// onnxruntime/core/optimizer/group_query_attention_fusion.cc
//   * NodeArgExists: true when the pointer is non-null and NodeArg::Exists() is true.
//   * RotaryEmbeddingArgs: holds cos_cache_arg, sin_cache_arg, position_ids_arg, interleaved and
//     rotary_embedding_dim for one rotary node.
//   * GetIntAttributeOrDefault: looks the attribute up with graph_utils::GetNodeAttribute and returns
//     attr->i(), or default_value when the attribute is absent.
//   * TryGetRotaryEmbeddingArgs: returns false unless OpType() is "RotaryEmbedding" and Domain() is
//     kMSDomain or kOnnxDomain; reads "interleaved" and "rotary_embedding_dim" (both default to 0) and
//     rejects interleaved values other than 0/1 and negative rotary_embedding_dim; requires at least
//     four input defs with inputs 1..3 present. Input mapping differs by domain:
//       kMSDomain   -> [1] position_ids, [2] cos_cache, [3] sin_cache
//       kOnnxDomain -> [1] cos_cache,    [2] sin_cache, [3] position_ids
//   * ApplyImpl hunks: the first accepted rotary node's args are recorded; any later rotary node whose
//     cos/sin/position_ids NodeArg pointers or interleaved/rotary_embedding_dim values differ sets
//     rotary_args_mismatch. The candidate is skipped (continue) on mismatch, on a missing required
//     node, or when cos_cache_arg/sin_cache_arg is null. On fusion the node gets do_rotary and
//     rotary_interleaved attributes, gqa_input_defs becomes a std::vector, and position_ids is
//     appended as an extra input with gqa_input_args[9] = 1.
//
// onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
//   * BuildOnnxRotaryEmbeddingGQAFusionGraph: builds three MatMul nodes, two kOnnxDomain
//     RotaryEmbedding nodes (position_ids is only wired in when include_position_ids is true) and one
//     kMSDomain GroupQueryAttention node with seven inputs.
//   * CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved / CheckOnnxRotaryEmbeddingGQAFused: expect zero
//     RotaryEmbedding nodes, one MatMul, one GroupQueryAttention with ten input defs (7, 8 and 9
//     present), do_rotary == 1 and the expected rotary_interleaved value.
//   * CheckOnnxRotaryEmbeddingGQANotFused: expects two RotaryEmbedding nodes, three MatMul nodes and a
//     seven-input GroupQueryAttention whose do_rotary is absent or 0.
//   * Five TEST_F cases cover: default, interleaved = 1 on both nodes, rotary_embedding_dim = 16,
//     mismatched interleaved (0 vs 1, expected not fused) and no position_ids (expected not fused).
//     The literal 23 passed to TestGraphTransformer is presumably the opset version - TODO confirm.
//
// NOTE(review): hunk line breaks are collapsed in this copy, so it is unlikely to apply as-is;
//   regenerate it from the upstream commit before use.
// NOTE(review): several expressions look like they lost their template argument lists during
//   extraction (e.g. static_cast(1), std::make_unique(), std::vector(...), std::array gqa_input_defs,
//   std::basic_string& file_path) - confirm against the upstream commit.
// NOTE(review): gqa_input_args[9] is written with no visible size check; confirm that
//   MutableInputArgsCount() has at least ten entries for GroupQueryAttention.
// NOTE(review): rotary_embedding_dim is read, validated and compared between the rotary nodes, but
//   the visible hunks never forward it to the fused node - confirm this is intended.
// NOTE(review): TryGetRotaryEmbeddingArgs only succeeds when position_ids is present, so the
//   position_ids_arg != nullptr guards look always-true on the visible fusion path, and the trailing
//   "return false" in TryGetRotaryEmbeddingArgs is unreachable after the earlier domain check.
diff --git a/onnxruntime/core/optimizer/group_query_attention_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_fusion.cc index f6bfd29315c58..134b63ea99c44 100644 --- a/onnxruntime/core/optimizer/group_query_attention_fusion.cc +++ b/onnxruntime/core/optimizer/group_query_attention_fusion.cc @@ -248,6 +248,70 @@ static bool CheckIfAnyOfRequiredGQANodesDoesNotExist(Node* rotary_node_1, Node* return rotary_node_1 == nullptr || rotary_node_2 == nullptr || q_node == nullptr || k_node == nullptr || v_node == nullptr; } +static bool NodeArgExists(const NodeArg* node_arg) { + return node_arg != nullptr && node_arg->Exists(); +} + +struct RotaryEmbeddingArgs { + NodeArg* cos_cache_arg = nullptr; + NodeArg* sin_cache_arg = nullptr; + NodeArg* position_ids_arg = nullptr; + int64_t interleaved = 0; + int64_t rotary_embedding_dim = 0; +}; + +static int64_t GetIntAttributeOrDefault(const Node& node, const std::string& attr_name, int64_t default_value) { + const auto* attr = graph_utils::GetNodeAttribute(node, attr_name); + return attr != nullptr ? 
attr->i() : default_value; +} + +static bool TryGetRotaryEmbeddingArgs(Node& rotary_node, RotaryEmbeddingArgs& args) { + if (rotary_node.OpType() != "RotaryEmbedding") { + return false; + } + + if (rotary_node.Domain() != kMSDomain && rotary_node.Domain() != kOnnxDomain) { + return false; + } + + args.interleaved = GetIntAttributeOrDefault(rotary_node, "interleaved", 0); + args.rotary_embedding_dim = GetIntAttributeOrDefault(rotary_node, "rotary_embedding_dim", 0); + if ((args.interleaved != 0 && args.interleaved != 1) || args.rotary_embedding_dim < 0) { + return false; + } + + auto& input_defs = rotary_node.MutableInputDefs(); + if (rotary_node.Domain() == kMSDomain) { + // com.microsoft.RotaryEmbedding inputs: + // input, position_ids, cos_cache, sin_cache + if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) || + !NodeArgExists(input_defs[3])) { + return false; + } + args.position_ids_arg = input_defs[1]; + args.cos_cache_arg = input_defs[2]; + args.sin_cache_arg = input_defs[3]; + return true; + } + + if (rotary_node.Domain() == kOnnxDomain) { + // ONNX RotaryEmbedding inputs: + // X, cos_cache, sin_cache, optional position_ids + // If position_ids is omitted, ONNX RotaryEmbedding uses 3D per-batch caches, which are + // incompatible with GroupQueryAttention's 2D rotary cache inputs. 
+ if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) || + !NodeArgExists(input_defs[3])) { + return false; + } + args.cos_cache_arg = input_defs[1]; + args.sin_cache_arg = input_defs[2]; + args.position_ids_arg = input_defs[3]; + return true; + } + + return false; +} + static void FusePreGQANodes(Graph& graph, Node* q_node, Node* k_node, Node* v_node, Node* rotary_node_1, Node* rotary_node_2, Node* new_node, NodeArg& new_node_output_arg) { graph_utils::MoveAllNodeInputEdges(graph, *q_node, *new_node); @@ -318,6 +382,11 @@ Status GroupQueryAttentionFusion::ApplyImpl( NodeArg* cos_cache_arg = nullptr; NodeArg* sin_cache_arg = nullptr; + NodeArg* position_ids_arg = nullptr; + int64_t rotary_interleaved = 0; + int64_t rotary_embedding_dim = 0; + bool rotary_args_set = false; + bool rotary_args_mismatch = false; NodeArg* past_key_values_key_arg = node.MutableInputDefs()[3]; NodeArg* past_key_values_value_arg = node.MutableInputDefs()[4]; NodeArg* seqlens_k = node.MutableInputDefs()[5]; @@ -334,7 +403,8 @@ Status GroupQueryAttentionFusion::ApplyImpl( for (auto pre_gqa_node = node.InputNodesBegin(); pre_gqa_node != node.InputNodesEnd(); ++pre_gqa_node) { Node& rotary_or_v_node = *graph.GetNode(pre_gqa_node->Index()); - if (rotary_or_v_node.OpType() == "RotaryEmbedding") { + RotaryEmbeddingArgs rotary_args; + if (TryGetRotaryEmbeddingArgs(rotary_or_v_node, rotary_args)) { if (!rotary_node_1) { rotary_node_1 = &rotary_or_v_node; } else { @@ -357,19 +427,28 @@ Status GroupQueryAttentionFusion::ApplyImpl( } } - if (cos_cache_arg == nullptr) { - cos_cache_arg = rotary_or_v_node.MutableInputDefs()[2]; - } - - if (sin_cache_arg == nullptr) { - sin_cache_arg = rotary_or_v_node.MutableInputDefs()[3]; + if (!rotary_args_set) { + cos_cache_arg = rotary_args.cos_cache_arg; + sin_cache_arg = rotary_args.sin_cache_arg; + position_ids_arg = rotary_args.position_ids_arg; + rotary_interleaved = rotary_args.interleaved; + rotary_embedding_dim = 
rotary_args.rotary_embedding_dim; + rotary_args_set = true; + } else if (cos_cache_arg != rotary_args.cos_cache_arg || + sin_cache_arg != rotary_args.sin_cache_arg || + position_ids_arg != rotary_args.position_ids_arg || + rotary_interleaved != rotary_args.interleaved || + rotary_embedding_dim != rotary_args.rotary_embedding_dim) { + rotary_args_mismatch = true; } } else if (rotary_or_v_node.OpType() == "MatMulNBits" || rotary_or_v_node.OpType() == "MatMul") { v_node = &rotary_or_v_node; } } - if (CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node)) { + if (rotary_args_mismatch || + CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node) || + cos_cache_arg == nullptr || sin_cache_arg == nullptr) { // Some of the required pre-GQA nodes required for fusion were not retrieved, // this can be expected if the model has extra nodes in between MatMuls and rotary embeddings. continue; @@ -489,11 +568,13 @@ Status GroupQueryAttentionFusion::ApplyImpl( FusePreGQANodes(graph, q_node, k_node, v_node, rotary_node_1, rotary_node_2, mat_mul_or_n_bits_new_node, matmul_or_nbits_output); node.GetMutableAttributes()["do_rotary"] = ONNX_NAMESPACE::MakeAttribute("do_rotary", static_cast(1)); + node.GetMutableAttributes()["rotary_interleaved"] = + ONNX_NAMESPACE::MakeAttribute("rotary_interleaved", rotary_interleaved); std::string empty_name; auto& empty_node_arg = graph.GetOrCreateNodeArg(empty_name, nullptr); - const std::array gqa_input_defs{ + std::vector gqa_input_defs{ &matmul_or_nbits_output, &empty_node_arg, &empty_node_arg, @@ -503,10 +584,16 @@ Status GroupQueryAttentionFusion::ApplyImpl( total_seq_len, cos_cache_arg, sin_cache_arg}; + if (position_ids_arg != nullptr) { + gqa_input_defs.push_back(position_ids_arg); + } auto& gqa_input_args = node.MutableInputArgsCount(); gqa_input_args[7] = 1; gqa_input_args[8] = 1; + if (position_ids_arg != nullptr) { + gqa_input_args[9] = 1; + } // Switch GQA 
input defs from unfused into the fused form. auto& gqa_node_input_defs = node.MutableInputDefs(); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index d5aaf0bb2d2ee..ce9f65eaca39f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -607,6 +607,143 @@ static void TestGQAFusion(const std::basic_string& file_path, int mat ASSERT_TRUE(op_to_count["com.microsoft.GroupQueryAttention"] == 1); } +static void BuildOnnxRotaryEmbeddingGQAFusionGraph(ModelTestBuilder& builder, + bool include_position_ids, + int64_t q_interleaved = 0, + int64_t k_interleaved = 0, + int64_t rotary_embedding_dim = 0) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t input_hidden_size = 8; + constexpr int64_t num_heads = 2; + constexpr int64_t kv_num_heads = 1; + const int64_t head_size = rotary_embedding_dim == 0 ? 16 : 32; + const int64_t q_hidden_size = num_heads * head_size; + const int64_t kv_hidden_size = kv_num_heads * head_size; + constexpr int64_t max_sequence_length = 8; + const int64_t half_rotary_dim = (rotary_embedding_dim == 0 ? 
head_size : rotary_embedding_dim) / 2; + + auto make_weight = [&builder](int64_t rows, int64_t cols, float value) { + return builder.MakeInitializer( + {rows, cols}, std::vector(static_cast(rows * cols), MLFloat16(value))); + }; + + NodeArg* input = builder.MakeInput({{batch_size, sequence_length, input_hidden_size}}); + NodeArg* q_weight = make_weight(input_hidden_size, q_hidden_size, 0.5f); + NodeArg* k_weight = make_weight(input_hidden_size, kv_hidden_size, 0.25f); + NodeArg* v_weight = make_weight(input_hidden_size, kv_hidden_size, 0.125f); + + NodeArg* q_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, q_hidden_size}); + NodeArg* k_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + NodeArg* v_matmul_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + builder.AddNode("MatMul", {input, q_weight}, {q_matmul_out}); + builder.AddNode("MatMul", {input, k_weight}, {k_matmul_out}); + builder.AddNode("MatMul", {input, v_weight}, {v_matmul_out}); + + const std::vector cache_shape = include_position_ids + ? std::vector{max_sequence_length, half_rotary_dim} + : std::vector{batch_size, sequence_length, half_rotary_dim}; + NodeArg* cos_cache = builder.MakeInput(cache_shape); + NodeArg* sin_cache = builder.MakeInput(cache_shape); + NodeArg* position_ids = include_position_ids + ? 
builder.MakeInput({{batch_size, sequence_length}}) + : nullptr; + + NodeArg* q_rotary_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, q_hidden_size}); + NodeArg* k_rotary_out = + builder.MakeIntermediate(std::vector{batch_size, sequence_length, kv_hidden_size}); + + std::vector q_rotary_inputs{q_matmul_out, cos_cache, sin_cache}; + std::vector k_rotary_inputs{k_matmul_out, cos_cache, sin_cache}; + if (position_ids != nullptr) { + q_rotary_inputs.push_back(position_ids); + k_rotary_inputs.push_back(position_ids); + } + + Node& q_rotary = builder.AddNode("RotaryEmbedding", q_rotary_inputs, {q_rotary_out}, kOnnxDomain); + q_rotary.AddAttribute("num_heads", num_heads); + q_rotary.AddAttribute("interleaved", q_interleaved); + q_rotary.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); + Node& k_rotary = builder.AddNode("RotaryEmbedding", k_rotary_inputs, {k_rotary_out}, kOnnxDomain); + k_rotary.AddAttribute("num_heads", kv_num_heads); + k_rotary.AddAttribute("interleaved", k_interleaved); + k_rotary.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); + + NodeArg* past_key = + builder.MakeInput({{batch_size, kv_num_heads, max_sequence_length, head_size}}); + NodeArg* past_value = + builder.MakeInput({{batch_size, kv_num_heads, max_sequence_length, head_size}}); + NodeArg* seqlens_k = builder.MakeInput({{batch_size}}); + NodeArg* total_sequence_length = builder.MakeInput({{1}}); + NodeArg* gqa_output = + builder.MakeOutput(std::vector{batch_size, sequence_length, q_hidden_size}); + + Node& gqa = builder.AddNode("GroupQueryAttention", + {q_rotary_out, k_rotary_out, v_matmul_out, past_key, past_value, + seqlens_k, total_sequence_length}, + {gqa_output}, + kMSDomain); + gqa.AddAttribute("num_heads", num_heads); + gqa.AddAttribute("kv_num_heads", kv_num_heads); +} + +static Status CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(Graph& graph, int64_t expected_rotary_interleaved) { + const auto op_to_count = CountOpsInGraph(graph); + 
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 0); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 1); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") { + continue; + } + + TEST_RETURN_IF_NOT(node.InputDefs().size() == 10); + TEST_RETURN_IF_NOT(node.InputDefs()[7] != nullptr && node.InputDefs()[7]->Exists()); + TEST_RETURN_IF_NOT(node.InputDefs()[8] != nullptr && node.InputDefs()[8]->Exists()); + TEST_RETURN_IF_NOT(node.InputDefs()[9] != nullptr && node.InputDefs()[9]->Exists()); + + const auto& attrs = node.GetAttributes(); + auto do_rotary_attr = attrs.find("do_rotary"); + TEST_RETURN_IF_NOT(do_rotary_attr != attrs.end()); + TEST_RETURN_IF_NOT(do_rotary_attr->second.i() == 1); + + auto rotary_interleaved_attr = attrs.find("rotary_interleaved"); + TEST_RETURN_IF_NOT(rotary_interleaved_attr != attrs.end()); + TEST_RETURN_IF_NOT(rotary_interleaved_attr->second.i() == expected_rotary_interleaved); + } + + return Status::OK(); +} + +static Status CheckOnnxRotaryEmbeddingGQAFused(Graph& graph) { + return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, 0); +} + +static Status CheckOnnxRotaryEmbeddingGQANotFused(Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 2); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 3); + TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") { + continue; + } + + TEST_RETURN_IF_NOT(node.InputDefs().size() == 7); + const auto& attrs = node.GetAttributes(); + auto do_rotary_attr = attrs.find("do_rotary"); + TEST_RETURN_IF_NOT(do_rotary_attr == attrs.end() || do_rotary_attr->second.i() == 0); + } + + return Status::OK(); +} + static void TestSkipLayerNormFusion(const 
std::basic_string& file_path, int add_count, int ln_count, int skip_ln_count, int cast_count, logging::Logger* logger) { std::shared_ptr p_model; @@ -796,6 +933,65 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) { TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_different_head_sizes.onnx", 1, 0, logger_.get()); } +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + CheckOnnxRotaryEmbeddingGQAFused)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 1, 1); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, 1); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + post_graph_checker)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingPartialRotaryTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 0, 0, 16); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + CheckOnnxRotaryEmbeddingGQAFused)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedMismatchTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 0, 1); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, 
nullptr, + CheckOnnxRotaryEmbeddingGQANotFused)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingNoPositionIdsTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, false); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_, + std::make_unique(), + TransformerLevel::Level2, 3, nullptr, + CheckOnnxRotaryEmbeddingGQANotFused)); +} + TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) { TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get()); TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get());