Add dynamic shape support for TopK#4880
Conversation
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #4880 +/- ##
===========================================
+ Coverage 92.86% 92.87% +0.01%
===========================================
Files 585 585
Lines 30152 30213 +61
===========================================
+ Hits 27998 28059 +61
Misses 2154 2154
🚀 New features to boost your workflow:
|
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This PR extends MIGraphX’s TopK support to handle dynamic input shapes and cases where k is not a compile-time constant (by using a placeholder k derived from the input shape), and adjusts dynamic-output handling in evaluation/GPU lowering to better support tuple-shaped outputs like TopK.
Changes:
- Update ONNX TopK parsing to allow non-constant
kby deriving a placeholderkfrom the input shape (static lens or dynamic max lens). - Update
op::topkto support dynamic shapes at shape-inference/eval time viadyn_output, including runtime clamping ofk. - Adjust GPU/lowering + dyn-output plumbing for tuple/dynamic shapes, and add tests for dynamic TopK scenarios.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| test/verify/test_topk_dynamic.cpp | Adds a verify test exercising TopK with a dynamic input shape and placeholder k. |
| test/ref/topk.cpp | Adds reference tests for k > n with dynamic input and for k == n behavior. |
| src/targets/gpu/lowering.cpp | Switches dynamic-shape detection to any_of_dynamic() to account for tuple outputs. |
| src/targets/gpu/include/migraphx/gpu/hip.hpp | Minor whitespace-only adjustment. |
| src/targets/gpu/compile_ops.cpp | Alters dynamic code-object execution to always compute/reshape runtime output shape and adds a skip-on-empty heuristic. |
| src/rewrite_topk.cpp | Disables large-TopK rewrite when input shape is dynamic. |
| src/program.cpp | Makes tracing more robust when evaluating submodules from dynamic code-object ops. |
| src/onnx/parse_topk.cpp | Removes the “k must be constant” restriction; derives placeholder k from input shape when needed. |
| src/include/migraphx/op/topk.hpp | Adds dynamic-shape support and switches compute to dyn_output, clamping k at runtime. |
| src/include/migraphx/dyn_output.hpp | Uses any_of_dynamic() so tuple outputs with dynamic subshapes get runtime-computed shapes. |
| auto input_shape = args.at(0)->get_shape(); | ||
| auto ndim = input_shape.ndim(); | ||
| auto norm_axis = axis < 0 ? axis + static_cast<int64_t>(ndim) : axis; | ||
| if(input_shape.dynamic()) | ||
| { |
| // k is not constant: use the input dimension along the topk axis | ||
| auto input_shape = args.at(0)->get_shape(); | ||
| auto ndim = input_shape.ndim(); | ||
| auto norm_axis = axis < 0 ? axis + static_cast<int64_t>(ndim) : axis; | ||
| if(input_shape.dynamic()) | ||
| { | ||
| k = input_shape.dyn_dims().at(norm_axis).get_interval().max; | ||
| } | ||
| else | ||
| { | ||
| k = input_shape.lens().at(norm_axis); |
| auto out_shape = pre_op.compute_shape(to_shapes(static_args), module_args); | ||
| static_args[static_args.size() - 1] = output_arg.reshape(out_shape); | ||
| // Skip JIT compilation when dynamic shape resolves to 0 elements at runtime | ||
| if(args.front().get_shape().elements() == 0) | ||
| return static_args.back(); |
| auto topk_ret = info.add_instruction( | ||
| make_op("topk", {{"k", k}, {"axis", axis}, {"largest", largest}}), args.at(0)); | ||
|
|
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
|
Hi @pfultz2 |
| if(arg_k.empty()) | ||
| { | ||
| // k is not constant: use the input dimension along the topk axis |
There was a problem hiding this comment.
Rather than a placeholder value of max_len make the k attribute in topk a std::optional<int64_t>. This would be more obvious what is meant.
| auto ins = r.result; | ||
| auto input = ins->inputs().front(); | ||
| if(input->get_shape().dynamic()) | ||
| return; |
There was a problem hiding this comment.
Use not_dynamic_shape in the matcher rather than checking here.
Motivation
The topk operator previously required a constant k and static input shapes, which will block the SSDMobileNetV2 model from running.
AMDMIGraphX/src/onnx/parse_topk.cpp
Line 47 in 93b8849
This PR adds dynamic input support for the topk op.
Affected model: SSDMobileNetV2
Technical Details
.\bin\migraphx-driver.exe perf .\bin\topk_shape_derived_k.onnx can run success.
onnx file
topk_shape_derived_k.zip
Add a
CHANGELOG.mdentry for any option other thanNot Applicable