diff --git a/src/include/migraphx/dyn_output.hpp b/src/include/migraphx/dyn_output.hpp index ac3263cde3b..8e4b7f1d529 100644 --- a/src/include/migraphx/dyn_output.hpp +++ b/src/include/migraphx/dyn_output.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,7 +52,7 @@ struct compute_output_shape operator dyn_output() const { return ins_inputs([](const auto& x, shape ins_shape, const std::vector& inputs) { - if(ins_shape.dynamic()) + if(ins_shape.any_of_dynamic()) // some op returns a tuple shape e.g. TopK return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))}; return dyn_output{ins_shape, ins_shape}; }); diff --git a/src/include/migraphx/op/topk.hpp b/src/include/migraphx/op/topk.hpp index 5ff9393e24c..3d72a2e0c66 100644 --- a/src/include/migraphx/op/topk.hpp +++ b/src/include/migraphx/op/topk.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -60,16 +61,33 @@ struct topk shape normalize_compute_shape(std::vector inputs) const { - check_shapes{inputs, *this}.has(1, 2); - auto lens = inputs.at(0).lens(); + check_shapes{inputs, *this, true}.has(1, 2); auto type = inputs.at(0).type(); - lens[axis] = k; + if(inputs.at(0).dynamic()) + { + auto dyn_dims = inputs.at(0).dyn_dims(); + auto min_lens_vec = inputs.at(0).min_lens(); + auto max_lens_vec = inputs.at(0).max_lens(); + auto min_kk = std::min(static_cast(k), min_lens_vec[axis]); + auto max_kk = std::min(static_cast(k), max_lens_vec[axis]); + dyn_dims[axis] = {min_kk, max_kk}; - shape s_val{type, lens}; - shape s_ind{shape::int64_type, lens}; + shape s_val{type, dyn_dims}; + shape s_ind{shape::int64_type, dyn_dims}; + return shape({s_val, s_ind}); + } + 
else + { + auto lens = inputs.at(0).lens(); + // Clamp k to input size: k may be a placeholder (max dim) from parse time + auto kk = std::min(static_cast(k), lens[axis]); + lens[axis] = kk; - return shape({s_val, s_ind}); + shape s_val{type, lens}; + shape s_ind{shape::int64_type, lens}; + return shape({s_val, s_ind}); + } } template @@ -84,13 +102,15 @@ struct topk }; } - argument compute(const shape& output_shape, std::vector args) const + argument compute(const dyn_output& dyn_out, std::vector args) const { + const auto& output_shape = dyn_out.computed_shape; const auto& vec_ss = output_shape.sub_shapes(); argument res_val{vec_ss.front()}; argument res_ind{vec_ss.back()}; auto in_val = args.front(); auto relements = in_val.get_shape().lens()[axis]; + auto actual_k = std::min(static_cast(k), relements); auto make_indices = [&](const auto& m_idx) { return [&](int64_t i) { if(args.size() < 2) @@ -118,20 +138,20 @@ struct topk }); if(this->largest) std::partial_sort(data.begin(), - data.begin() + k, + data.begin() + actual_k, data.end(), compare_pair(std::greater<>{})); else std::partial_sort(data.begin(), - data.begin() + k, + data.begin() + actual_k, data.end(), compare_pair(std::less<>{})); std::transform(data.begin(), - data.begin() + this->k, + data.begin() + actual_k, y.begin(), [](const auto& p) { return p.first; }); std::transform(data.begin(), - data.begin() + this->k, + data.begin() + actual_k, y_ind.begin(), [](const auto& p) { return p.second; }); }); diff --git a/src/onnx/parse_topk.cpp b/src/onnx/parse_topk.cpp index 66ab9f7ad95..81853857136 100644 --- a/src/onnx/parse_topk.cpp +++ b/src/onnx/parse_topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
 * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,18 +40,6 @@ struct parse_topk : op_parser onnx_parser::node_info info, std::vector args) const { - int64_t k = 0; - if(args.size() == 2) - { - auto arg_k = args.at(1)->eval(); - check_arg_empty(arg_k, "PARSE_TopK: k input must be constant"); - k = arg_k.at(); - } - else if(contains(info.attributes, "k")) - { - k = info.attributes.at("k").i(); - } - bool largest = true; if(contains(info.attributes, "largest")) { @@ -64,6 +52,35 @@ struct parse_topk : op_parser axis = parser.parse_value(info.attributes.at("axis")).at(); } + int64_t k = 0; + if(args.size() == 2) + { + auto arg_k = args.at(1)->eval(); + if(arg_k.empty()) + { + // k is not constant: use the input dimension along the topk axis + auto input_shape = args.at(0)->get_shape(); + auto ndim = input_shape.ndim(); + auto norm_axis = axis < 0 ? axis + static_cast(ndim) : axis; + if(input_shape.dynamic()) + { + k = input_shape.dyn_dims().at(norm_axis).get_interval().max; + } + else + { + k = input_shape.lens().at(norm_axis); + } + } + else + { + k = arg_k.at(); + } + } + else if(contains(info.attributes, "k")) + { + k = info.attributes.at("k").i(); + } + auto topk_ret = info.add_instruction( make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); diff --git a/src/program.cpp b/src/program.cpp index 617a215f361..41889d7a1e0 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -671,7 +671,13 @@ std::vector program::eval(const parameter_map& params, if(trace_level > 0) { ctx.finish(); - std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; + // The ins_out map is populated only from the main module's instructions, + // so when dynamic_code_object_op::compute recursively calls generic_eval + // on its runtime sub-module, those instructions are missing from ins_out. 
+ if(ins_out.find(ins) != ins_out.end()) + std::cout << "Run instruction: " << ins_out.at(ins) << std::endl; + else + std::cout << "Run instruction: " << ins->name() << " (submodule)" << std::endl; } timer t{}; auto result = f(); diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index 19411680db8..f4abd88972a 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -43,6 +43,8 @@ struct find_large_topk { auto ins = r.result; auto input = ins->inputs().front(); + if(input->get_shape().dynamic()) + return; auto op = ins->get_operator().to_value(); auto axis = op["axis"].to(); auto k = op["k"].to(); diff --git a/src/targets/gpu/compile_ops.cpp b/src/targets/gpu/compile_ops.cpp index 52272b1d7af..83f1f84acf1 100644 --- a/src/targets/gpu/compile_ops.cpp +++ b/src/targets/gpu/compile_ops.cpp @@ -159,11 +159,12 @@ struct dynamic_code_object_op return results.front(); } - if(output_arg.get_shape().dynamic()) - { - auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); - static_args[static_args.size() - 1] = output_arg.reshape(out_shape); - } + // A static-shape output cannot reach this point, so the dynamic() check was removed. 
+ auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); + static_args[static_args.size() - 1] = output_arg.reshape(out_shape); + // Skip JIT compilation when dynamic shape resolves to 0 elements at runtime + if(args.front().get_shape().elements() == 0) + return static_args.back(); // Rewrite submodule without dynamic shapes to be used as the IR for compilation module static_submod; diff --git a/src/targets/gpu/include/migraphx/gpu/hip.hpp b/src/targets/gpu/include/migraphx/gpu/hip.hpp index d04e81e218c..c1134e13b5c 100644 --- a/src/targets/gpu/include/migraphx/gpu/hip.hpp +++ b/src/targets/gpu/include/migraphx/gpu/hip.hpp @@ -252,7 +252,6 @@ struct hip_allocate_memory { return get_preallocation(ctx, id); } - void finalize(context& ctx, const shape&, const std::vector&) const { argument a = allocate_gpu(s); diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 5eca58aaa13..f19f2716226 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -239,8 +239,8 @@ struct miopen_apply instruction_ref insert_dynamic_code_object_op(instruction_ref ins) const { assert(ins->get_operator().name() == "gpu::precompile_op"); - - if(not ins->get_shape().dynamic()) + // some op returns a tuple shape e.g. TopK + if(not ins->get_shape().any_of_dynamic()) return ins; return mod->replace_instruction( diff --git a/test/ref/topk.cpp b/test/ref/topk.cpp index 5e2ea0e0246..3f71f4f8879 100644 --- a/test/ref/topk.cpp +++ b/test/ref/topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -145,3 +145,92 @@ TEST_CASE(topk_smallest_custom_indices) std::vector gold_ind = {11, 13, 15, 14, 7, 9, 6, 10, 2, 5, 1, 3}; EXPECT(results.second == gold_ind); } + +// Test k > n with dynamic shapes: k=100 placeholder but runtime input has 5 elements +TEST_CASE(topk_k_greater_than_n_dynamic) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + // Dynamic shape: axis 0 ranges from 1 to 100 + std::vector dds = {{1, 100}}; + migraphx::shape s{migraphx::shape::float_type, dds}; + auto data = mm->add_parameter("data", s); + // k=100 is the max placeholder from parse time + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 0}, {"k", 100}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + p.compile(migraphx::make_target("ref")); + + // Runtime: only 5 elements + std::vector input_data = {3.0f, 1.0f, 4.0f, 1.5f, 2.0f}; + migraphx::shape input_fixed{migraphx::shape::float_type, {5}}; + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(input_fixed, input_data.data()); + auto rets = p.eval(pp); + + std::vector ret_val; + rets.front().visit([&](auto v) { ret_val.assign(v.begin(), v.end()); }); + std::vector ret_ind; + rets.back().visit([&](auto v) { ret_ind.assign(v.begin(), v.end()); }); + + // k=100 clamped to n=5, sorted descending + EXPECT(ret_val.size() == 5u); + std::vector gold_val = {4.0f, 3.0f, 2.0f, 1.5f, 1.0f}; + EXPECT(ret_val == gold_val); + std::vector gold_ind = {2, 0, 4, 3, 1}; + EXPECT(ret_ind == gold_ind); +} + +// Test k == n: k equals the axis dimension, should return all elements sorted +TEST_CASE(topk_k_equals_n) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + 
migraphx::shape s{migraphx::shape::float_type, {3, 5}}; + auto data = mm->add_parameter("data", s); + // k=5 equals axis=1 dimension of 5 + auto r = mm->add_instruction(migraphx::make_op("topk", {{"axis", 1}, {"k", 5}, {"largest", 0}}), + data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + + p.compile(migraphx::make_target("ref")); + + std::vector input_data = { + 2.1, + 2.3, + 2.0, + 2.5, + 1.9, + 3.3, + 0.2, + 4.5, + 0.1, + 0.8, + 1.0, + 4.5, + 2.1, + 0.8, + 1.5, + }; + migraphx::parameter_map pp; + pp["data"] = migraphx::argument(s, input_data.data()); + auto rets = p.eval(pp); + + std::vector ret_val; + rets.front().visit([&](auto v) { ret_val.assign(v.begin(), v.end()); }); + std::vector ret_ind; + rets.back().visit([&](auto v) { ret_ind.assign(v.begin(), v.end()); }); + + // All 5 elements returned per row, sorted ascending (smallest first) + EXPECT(ret_val.size() == 15u); + std::vector gold_val = { + 1.9, 2.0, 2.1, 2.3, 2.5, 0.1, 0.2, 0.8, 3.3, 4.5, 0.8, 1.0, 1.5, 2.1, 4.5}; + EXPECT(ret_val == gold_val); + std::vector gold_ind = {4, 2, 0, 1, 3, 3, 1, 4, 0, 2, 3, 0, 4, 2, 1}; + EXPECT(ret_ind == gold_ind); +} diff --git a/test/verify/test_topk_dynamic.cpp b/test/verify/test_topk_dynamic.cpp new file mode 100644 index 00000000000..89df61959fd --- /dev/null +++ b/test/verify/test_topk_dynamic.cpp @@ -0,0 +1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +// Test k > n with dynamic shapes: k=100 placeholder but runtime input has fewer elements +template +struct test_topk_dynamic : verify_program> +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + std::vector dds = {{1, 100}}; + migraphx::shape s{migraphx::shape::float_type, dds}; + auto data = mm->add_parameter("data", s); + auto r = mm->add_instruction( + migraphx::make_op("topk", {{"axis", 0}, {"k", 100}, {"largest", 1}}), data); + auto r0 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), r); + auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), r); + mm->add_return({r0, r1}); + return p; + } + + std::unordered_map get_test_dims() const + { + return {{"data", migraphx::shape{migraphx::shape::float_type, {N}}}}; + } +}; + +template struct test_topk_dynamic<10>;