Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions onnxruntime/core/optimizer/group_query_attention_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,70 @@
return rotary_node_1 == nullptr || rotary_node_2 == nullptr || q_node == nullptr || k_node == nullptr || v_node == nullptr;
}

static bool NodeArgExists(const NodeArg* node_arg) {
return node_arg != nullptr && node_arg->Exists();
}

struct RotaryEmbeddingArgs {
NodeArg* cos_cache_arg = nullptr;
NodeArg* sin_cache_arg = nullptr;
NodeArg* position_ids_arg = nullptr;
int64_t interleaved = 0;
int64_t rotary_embedding_dim = 0;
};

static int64_t GetIntAttributeOrDefault(const Node& node, const std::string& attr_name, int64_t default_value) {
const auto* attr = graph_utils::GetNodeAttribute(node, attr_name);
return attr != nullptr ? attr->i() : default_value;
}

static bool TryGetRotaryEmbeddingArgs(Node& rotary_node, RotaryEmbeddingArgs& args) {
if (rotary_node.OpType() != "RotaryEmbedding") {
return false;
}

if (rotary_node.Domain() != kMSDomain && rotary_node.Domain() != kOnnxDomain) {
return false;
}

args.interleaved = GetIntAttributeOrDefault(rotary_node, "interleaved", 0);
args.rotary_embedding_dim = GetIntAttributeOrDefault(rotary_node, "rotary_embedding_dim", 0);
if ((args.interleaved != 0 && args.interleaved != 1) || args.rotary_embedding_dim < 0) {
return false;
}

auto& input_defs = rotary_node.MutableInputDefs();
if (rotary_node.Domain() == kMSDomain) {
// com.microsoft.RotaryEmbedding inputs:
// input, position_ids, cos_cache, sin_cache
if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) ||
!NodeArgExists(input_defs[3])) {
return false;
}
args.position_ids_arg = input_defs[1];
args.cos_cache_arg = input_defs[2];
args.sin_cache_arg = input_defs[3];
return true;
}

if (rotary_node.Domain() == kOnnxDomain) {
// ONNX RotaryEmbedding inputs:
// X, cos_cache, sin_cache, optional position_ids
// If position_ids is omitted, ONNX RotaryEmbedding uses 3D per-batch caches, which are
// incompatible with GroupQueryAttention's 2D rotary cache inputs.
if (input_defs.size() < 4 || !NodeArgExists(input_defs[1]) || !NodeArgExists(input_defs[2]) ||
!NodeArgExists(input_defs[3])) {
return false;
}
args.cos_cache_arg = input_defs[1];
args.sin_cache_arg = input_defs[2];
args.position_ids_arg = input_defs[3];
return true;
}

return false;
}

static void FusePreGQANodes(Graph& graph, Node* q_node, Node* k_node, Node* v_node, Node* rotary_node_1, Node* rotary_node_2, Node* new_node, NodeArg& new_node_output_arg) {
graph_utils::MoveAllNodeInputEdges(graph, *q_node, *new_node);

Expand Down Expand Up @@ -318,6 +382,11 @@

NodeArg* cos_cache_arg = nullptr;
NodeArg* sin_cache_arg = nullptr;
NodeArg* position_ids_arg = nullptr;
int64_t rotary_interleaved = 0;
int64_t rotary_embedding_dim = 0;
bool rotary_args_set = false;
bool rotary_args_mismatch = false;
NodeArg* past_key_values_key_arg = node.MutableInputDefs()[3];
NodeArg* past_key_values_value_arg = node.MutableInputDefs()[4];
NodeArg* seqlens_k = node.MutableInputDefs()[5];
Expand All @@ -334,7 +403,8 @@
for (auto pre_gqa_node = node.InputNodesBegin(); pre_gqa_node != node.InputNodesEnd(); ++pre_gqa_node) {
Node& rotary_or_v_node = *graph.GetNode(pre_gqa_node->Index());

if (rotary_or_v_node.OpType() == "RotaryEmbedding") {
RotaryEmbeddingArgs rotary_args;
if (TryGetRotaryEmbeddingArgs(rotary_or_v_node, rotary_args)) {
if (!rotary_node_1) {
rotary_node_1 = &rotary_or_v_node;
} else {
Expand All @@ -357,19 +427,28 @@
}
}

if (cos_cache_arg == nullptr) {
cos_cache_arg = rotary_or_v_node.MutableInputDefs()[2];
}

if (sin_cache_arg == nullptr) {
sin_cache_arg = rotary_or_v_node.MutableInputDefs()[3];
if (!rotary_args_set) {
cos_cache_arg = rotary_args.cos_cache_arg;
sin_cache_arg = rotary_args.sin_cache_arg;
position_ids_arg = rotary_args.position_ids_arg;
rotary_interleaved = rotary_args.interleaved;
rotary_embedding_dim = rotary_args.rotary_embedding_dim;
rotary_args_set = true;
} else if (cos_cache_arg != rotary_args.cos_cache_arg ||
sin_cache_arg != rotary_args.sin_cache_arg ||
position_ids_arg != rotary_args.position_ids_arg ||
rotary_interleaved != rotary_args.interleaved ||
rotary_embedding_dim != rotary_args.rotary_embedding_dim) {
rotary_args_mismatch = true;
}
} else if (rotary_or_v_node.OpType() == "MatMulNBits" || rotary_or_v_node.OpType() == "MatMul") {
v_node = &rotary_or_v_node;
}
}

if (CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node)) {
if (rotary_args_mismatch ||
CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node) ||
cos_cache_arg == nullptr || sin_cache_arg == nullptr) {
// Some of the required pre-GQA nodes required for fusion were not retrieved,
// this can be expected if the model has extra nodes in between MatMuls and rotary embeddings.
continue;
Expand Down Expand Up @@ -489,11 +568,13 @@
FusePreGQANodes(graph, q_node, k_node, v_node, rotary_node_1, rotary_node_2, mat_mul_or_n_bits_new_node, matmul_or_nbits_output);

node.GetMutableAttributes()["do_rotary"] = ONNX_NAMESPACE::MakeAttribute("do_rotary", static_cast<int64_t>(1));
node.GetMutableAttributes()["rotary_interleaved"] =
ONNX_NAMESPACE::MakeAttribute("rotary_interleaved", rotary_interleaved);

std::string empty_name;
auto& empty_node_arg = graph.GetOrCreateNodeArg(empty_name, nullptr);

const std::array gqa_input_defs{
std::vector<NodeArg*> gqa_input_defs{

Check warning on line 577 in onnxruntime/core/optimizer/group_query_attention_fusion.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/group_query_attention_fusion.cc:577: Add #include <vector> for vector<> [build/include_what_you_use] [4]
&matmul_or_nbits_output,
&empty_node_arg,
&empty_node_arg,
Expand All @@ -503,10 +584,16 @@
total_seq_len,
cos_cache_arg,
sin_cache_arg};
if (position_ids_arg != nullptr) {
gqa_input_defs.push_back(position_ids_arg);
}

auto& gqa_input_args = node.MutableInputArgsCount();
gqa_input_args[7] = 1;
gqa_input_args[8] = 1;
if (position_ids_arg != nullptr) {
gqa_input_args[9] = 1;
}

// Switch GQA input defs from unfused into the fused form.
auto& gqa_node_input_defs = node.MutableInputDefs();
Expand Down
196 changes: 196 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,143 @@ static void TestGQAFusion(const std::basic_string<ORTCHAR_T>& file_path, int mat
ASSERT_TRUE(op_to_count["com.microsoft.GroupQueryAttention"] == 1);
}

static void BuildOnnxRotaryEmbeddingGQAFusionGraph(ModelTestBuilder& builder,
bool include_position_ids,
int64_t q_interleaved = 0,
int64_t k_interleaved = 0,
int64_t rotary_embedding_dim = 0) {
constexpr int64_t batch_size = 1;
constexpr int64_t sequence_length = 2;
constexpr int64_t input_hidden_size = 8;
constexpr int64_t num_heads = 2;
constexpr int64_t kv_num_heads = 1;
const int64_t head_size = rotary_embedding_dim == 0 ? 16 : 32;
const int64_t q_hidden_size = num_heads * head_size;
const int64_t kv_hidden_size = kv_num_heads * head_size;
constexpr int64_t max_sequence_length = 8;
const int64_t half_rotary_dim = (rotary_embedding_dim == 0 ? head_size : rotary_embedding_dim) / 2;

auto make_weight = [&builder](int64_t rows, int64_t cols, float value) {
return builder.MakeInitializer<MLFloat16>(
{rows, cols}, std::vector<MLFloat16>(static_cast<size_t>(rows * cols), MLFloat16(value)));
};

NodeArg* input = builder.MakeInput<MLFloat16>({{batch_size, sequence_length, input_hidden_size}});
NodeArg* q_weight = make_weight(input_hidden_size, q_hidden_size, 0.5f);
NodeArg* k_weight = make_weight(input_hidden_size, kv_hidden_size, 0.25f);
NodeArg* v_weight = make_weight(input_hidden_size, kv_hidden_size, 0.125f);

NodeArg* q_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});
NodeArg* k_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});
NodeArg* v_matmul_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});
builder.AddNode("MatMul", {input, q_weight}, {q_matmul_out});
builder.AddNode("MatMul", {input, k_weight}, {k_matmul_out});
builder.AddNode("MatMul", {input, v_weight}, {v_matmul_out});

const std::vector<int64_t> cache_shape = include_position_ids
? std::vector<int64_t>{max_sequence_length, half_rotary_dim}
: std::vector<int64_t>{batch_size, sequence_length, half_rotary_dim};
NodeArg* cos_cache = builder.MakeInput<MLFloat16>(cache_shape);
NodeArg* sin_cache = builder.MakeInput<MLFloat16>(cache_shape);
NodeArg* position_ids = include_position_ids
? builder.MakeInput<int64_t>({{batch_size, sequence_length}})
: nullptr;

NodeArg* q_rotary_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});
NodeArg* k_rotary_out =
builder.MakeIntermediate<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, kv_hidden_size});

std::vector<NodeArg*> q_rotary_inputs{q_matmul_out, cos_cache, sin_cache};
std::vector<NodeArg*> k_rotary_inputs{k_matmul_out, cos_cache, sin_cache};
if (position_ids != nullptr) {
q_rotary_inputs.push_back(position_ids);
k_rotary_inputs.push_back(position_ids);
}

Node& q_rotary = builder.AddNode("RotaryEmbedding", q_rotary_inputs, {q_rotary_out}, kOnnxDomain);
q_rotary.AddAttribute("num_heads", num_heads);
q_rotary.AddAttribute("interleaved", q_interleaved);
q_rotary.AddAttribute("rotary_embedding_dim", rotary_embedding_dim);
Node& k_rotary = builder.AddNode("RotaryEmbedding", k_rotary_inputs, {k_rotary_out}, kOnnxDomain);
k_rotary.AddAttribute("num_heads", kv_num_heads);
k_rotary.AddAttribute("interleaved", k_interleaved);
k_rotary.AddAttribute("rotary_embedding_dim", rotary_embedding_dim);

NodeArg* past_key =
builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
NodeArg* past_value =
builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
NodeArg* seqlens_k = builder.MakeInput<int32_t>({{batch_size}});
NodeArg* total_sequence_length = builder.MakeInput<int32_t>({{1}});
NodeArg* gqa_output =
builder.MakeOutput<MLFloat16>(std::vector<int64_t>{batch_size, sequence_length, q_hidden_size});

Node& gqa = builder.AddNode("GroupQueryAttention",
{q_rotary_out, k_rotary_out, v_matmul_out, past_key, past_value,
seqlens_k, total_sequence_length},
{gqa_output},
kMSDomain);
gqa.AddAttribute("num_heads", num_heads);
gqa.AddAttribute("kv_num_heads", kv_num_heads);
}

static Status CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(Graph& graph, int64_t expected_rotary_interleaved) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 0);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 1);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

TEST_RETURN_IF_NOT(node.InputDefs().size() == 10);
TEST_RETURN_IF_NOT(node.InputDefs()[7] != nullptr && node.InputDefs()[7]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[8] != nullptr && node.InputDefs()[8]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[9] != nullptr && node.InputDefs()[9]->Exists());

const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
TEST_RETURN_IF_NOT(do_rotary_attr != attrs.end());
TEST_RETURN_IF_NOT(do_rotary_attr->second.i() == 1);

auto rotary_interleaved_attr = attrs.find("rotary_interleaved");
TEST_RETURN_IF_NOT(rotary_interleaved_attr != attrs.end());
TEST_RETURN_IF_NOT(rotary_interleaved_attr->second.i() == expected_rotary_interleaved);
}

return Status::OK();
}

static Status CheckOnnxRotaryEmbeddingGQAFused(Graph& graph) {
return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, 0);
}

static Status CheckOnnxRotaryEmbeddingGQANotFused(Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 2);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 3);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

TEST_RETURN_IF_NOT(node.InputDefs().size() == 7);
const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
TEST_RETURN_IF_NOT(do_rotary_attr == attrs.end() || do_rotary_attr->second.i() == 0);
}

return Status::OK();
}

static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count,
int skip_ln_count, int cast_count, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
Expand Down Expand Up @@ -796,6 +933,65 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) {
TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_different_head_sizes.onnx", 1, 0, logger_.get());
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQAFused));
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 1, 1);
};

auto post_graph_checker = [](Graph& graph) {
return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, 1);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
post_graph_checker));
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingPartialRotaryTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 0, 0, 16);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQAFused));
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedMismatchTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, true, 0, 1);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQANotFused));
}

TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingNoPositionIdsTest) {
auto build_test_case = [](ModelTestBuilder& builder) {
BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, false);
};

ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 23, *logger_,
std::make_unique<GroupQueryAttentionFusion>(),
TransformerLevel::Level2, 3, nullptr,
CheckOnnxRotaryEmbeddingGQANotFused));
}

TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get());
Expand Down
Loading