[AIMIGRAPHX-1017] Skip Q/DQ for Attention Ops #4900
This PR is currently in draft: I am looking for feedback on the premise of this change before I work on and clean up the actual implementation.
Pull request overview
This PR updates the FP8 quantization pipeline to avoid inserting Q/DQ pairs into dot -> softmax -> dot attention subgraphs so those regions can be fused later (improving performance and reducing scratch usage), and factors the attention-pattern matcher into a reusable header.
Changes:
- Extracts a reusable `match::dot_softmax_dot` matcher and uses it in GPU prefusion matching.
- Adds a `skip_instructions` mechanism to `capture_arguments_pass` and wires it through FP8 quantization.
- Detects attention regions in `quantize_fp8()` and skips capture/QDQ insertion for those instructions.
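
To illustrate why Q/DQ pairs inside an attention region can get in the way of later fusion, here is a minimal standalone sketch. It is not MIGraphX code: `op_chain`, `insert_qdq`, and `has_adjacent_attention` are hypothetical names, the program is reduced to a linear list of op names, and it assumes a Q/DQ pair is placed directly in front of every targeted op.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Toy model: a program is just a linear chain of op names (hypothetical).
using op_chain = std::vector<std::string>;

// Place a quantize/dequantize pair in front of every op whose name is in
// `targets`. This only mimics where Q/DQ ends up; no numerics are modelled.
inline op_chain insert_qdq(const op_chain& chain, const std::unordered_set<std::string>& targets)
{
    assert(!targets.empty()); // nothing to quantize otherwise
    op_chain out;
    for(const auto& op : chain)
    {
        if(targets.count(op) != 0)
        {
            out.push_back("quantizelinear");
            out.push_back("dequantizelinear");
        }
        out.push_back(op);
    }
    return out;
}

// True if dot, softmax, dot appear as directly adjacent ops.
inline bool has_adjacent_attention(const op_chain& chain)
{
    for(std::size_t i = 0; i + 2 < chain.size(); ++i)
    {
        if(chain[i] == "dot" && chain[i + 1] == "softmax" && chain[i + 2] == "dot")
            return true;
    }
    return false;
}
```

In this toy model, `{"dot", "softmax", "dot"}` is recognized before insertion, but after `insert_qdq` the second `dot` is fed by `dequantizelinear` rather than `softmax`, so the adjacency check no longer succeeds.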
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/targets/gpu/prefuse_ops.cpp | Switches attention prefusion matching to the new shared dot_softmax_dot matcher. |
| src/quantize_8bits.cpp | Skips inserting capture ops (and therefore Q/DQ) for a provided set of instructions. |
| src/quantization.cpp | Detects attention regions and passes them into the capture pass to skip Q/DQ insertion. |
| src/include/migraphx/quantize_8bits.hpp | Extends capture_arguments_pass API to carry a skip set. |
| src/include/migraphx/match/dot_softmax_dot.hpp | Introduces a reusable matcher for undecomposed attention (dot -> softmax -> dot). |

    struct MIGRAPHX_EXPORT capture_arguments_pass
    {
        std::unordered_set<std::string> ins_names = {"dot", "convolution"};
        std::function<void(std::size_t, std::vector<argument>)> f{};
        std::size_t* param_index = nullptr;
        std::unordered_set<instruction_ref> skip_instructions{};
        std::string name() const { return "capture_arguments"; }
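
The effect of the new `skip_instructions` member can be sketched with a small standalone model. This is an assumption-laden illustration, not the pass itself: `toy_instruction`, `toy_capture_options`, and `select_capture_targets` are hypothetical names, and plain indices stand in for `instruction_ref`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an IR instruction; not the MIGraphX instruction type.
struct toy_instruction
{
    std::string op;                  // operator name, e.g. "dot"
    std::vector<std::size_t> inputs; // indices of producer instructions
};

// Hypothetical mirror of the two selection-related members shown above.
struct toy_capture_options
{
    std::unordered_set<std::string> ins_names = {"dot", "convolution"};
    std::unordered_set<std::size_t> skip_instructions{};
};

// Return the instructions whose inputs would get a capture op: the op name
// must be listed in ins_names and the instruction must not be in the skip set.
inline std::vector<std::size_t> select_capture_targets(const std::vector<toy_instruction>& prog,
                                                       const toy_capture_options& opts)
{
    std::vector<std::size_t> targets;
    for(std::size_t i = 0; i < prog.size(); ++i)
    {
        for(auto in : prog[i].inputs)
            assert(in < i); // producers must come before consumers
        if(opts.ins_names.count(prog[i].op) == 0)
            continue;
        if(opts.skip_instructions.count(i) != 0)
            continue;
        targets.push_back(i);
    }
    return targets;
}
```

With an empty skip set every `dot` is selected; adding the two attention `dot` indices to `skip_instructions` leaves only the unrelated `dot` as a capture target.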

    /// Match the (undecomposed) `dot -> softmax -> dot` attention pattern, with
    /// optional `mul` (scale), `add` (bias), or `where` (mask) ops between the
    /// first dot and the softmax. This is the form before `rewrite_reduce`
    /// decomposes softmax into its `div(exp(sub(x, max)), sum(exp(...)))` chain.
    ///
    /// `gemm_pred` is applied to both dot operations; pass `match::any()` to
    /// match any dot. `bias_pred` is applied to the optional `add` (bias) op.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files

    @@            Coverage Diff            @@
    ##           develop    #4900    +/-   ##
    ========================================
      Coverage    92.88%   92.88%
    ========================================
      Files          587      588     +1
      Lines        30348    30365    +17
    ========================================
    + Hits         28187    28204    +17
      Misses        2161     2161

    std::unordered_set<instruction_ref> skip_instructions;
    for(auto ins : iterator_for(*mm))
    {
        auto r = match::match_instruction(*mm, ins, match::dot_softmax_dot());
I'm assuming we always expect this to be in the `dot_softmax_dot` form at this point, so we can match on it and use the result?
Yes, that is the assumption. Let me know if you think this may not always be true.
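
The detection loop discussed in this thread can be sketched on a toy graph. Everything here is hypothetical and simplified: `toy_instruction`, `match_attention`, and `collect_skip_set` are illustrative names, indices replace `instruction_ref`, the data path is assumed to be input 0 of every op, and the whole matched region is added to the skip set.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for an IR instruction; not the MIGraphX instruction type.
struct toy_instruction
{
    std::string op;
    std::vector<std::size_t> inputs; // input 0 is assumed to be the data path
};

using toy_program = std::vector<toy_instruction>;

// Try to match dot -> [mul|add|where]* -> softmax -> dot, walking backwards
// from the second dot at index `last`. Returns the matched indices or nullopt.
inline std::optional<std::vector<std::size_t>> match_attention(const toy_program& prog,
                                                               std::size_t last)
{
    assert(last < prog.size());
    if(prog[last].op != "dot" || prog[last].inputs.empty())
        return std::nullopt;
    std::size_t cur = prog[last].inputs[0];
    if(prog[cur].op != "softmax" || prog[cur].inputs.empty())
        return std::nullopt;
    std::vector<std::size_t> region = {last, cur};
    cur = prog[cur].inputs[0];
    // Optional scale, bias, or mask ops between the first dot and the softmax.
    while(prog[cur].op == "mul" || prog[cur].op == "add" || prog[cur].op == "where")
    {
        if(prog[cur].inputs.empty())
            return std::nullopt;
        region.push_back(cur);
        cur = prog[cur].inputs[0];
    }
    if(prog[cur].op != "dot")
        return std::nullopt;
    region.push_back(cur);
    return region;
}

// Visit every instruction and gather all matched attention regions.
inline std::set<std::size_t> collect_skip_set(const toy_program& prog)
{
    std::set<std::size_t> skip;
    for(std::size_t i = 0; i < prog.size(); ++i)
    {
        if(auto r = match_attention(prog, i))
            skip.insert(r->begin(), r->end());
    }
    return skip;
}
```

A `dot` that merely consumes the attention output does not match, because its first input is another `dot` rather than a `softmax`, so it stays eligible for capture.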
TedThemistokleous left a comment
Some questions, mostly about the matcher predicate for the no-input-args case.
The rest makes sense to me, though: you're moving things to a separate file to reuse them for the quant step. I'm just not sure `match::any()` is needed, as it's too broad compared to something a bit more specific.
If we're assuming we already have gemm->softmax->gemm set up here, then we can write something specific. Let me know if this is incorrect, though, and you're using `any()` as a way to just grab everything.

        return match::name("dot")(gemm_pred.bind("gemm2"))(match::arg(0)(softmax));
    }

    inline auto dot_softmax_dot() { return dot_softmax_dot(match::any(), match::any()); }
Does this have to be `match::any()`? This seems too broad here.
I'm a little confused about writing something more specific. What constraints can you put on `gemm_pred` and `bias_pred` here?
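
As a sketch of the mechanics under discussion: in the excerpt, `gemm_pred` is applied on top of `match::name("dot")`, so passing an always-true predicate still only accepts dots. The standalone model below mimics that layering. `toy_op`, `pred`, `any`, and `count_dots` are hypothetical names, and the `k` field is an invented attribute used only to show what a narrower predicate might look like.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy op with one made-up attribute (hypothetical, for illustration only).
struct toy_op
{
    std::string name;
    std::size_t k;
};

using pred = std::function<bool(const toy_op&)>;

// Always-true predicate, playing the role of an "accept anything" matcher.
inline pred any()
{
    return [](const toy_op&) { return true; };
}

// The name check is fixed; gemm_pred only adds a further restriction.
inline std::size_t count_dots(const std::vector<toy_op>& ops, const pred& gemm_pred)
{
    assert(static_cast<bool>(gemm_pred)); // predicate must be callable
    std::size_t n = 0;
    for(const auto& op : ops)
    {
        if(op.name == "dot" && gemm_pred(op))
            ++n;
    }
    return n;
}

// Zero-argument overload: forwards the always-true predicate.
inline std::size_t count_dots(const std::vector<toy_op>& ops) { return count_dots(ops, any()); }
```

In this model the zero-argument overload counts every `dot` and nothing else, while a caller that needs a restriction can pass its own predicate through the two-argument form.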
Motivation
Skips inserting Q/DQ pairs for attention patterns so they can be fused later on.
Technical Details
Before:
After:
Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`