Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocm-cmake@dfaa4ddba4dbb2e1c6e9964ce610e2a12fd93f39 --build
ROCm/composable_kernel@ad0db05b040bacda751c65c705261b8a0a7ed25d --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On -DBUILD_TESTING=Off
https://gitlab.com/libeigen/eigen/-/archive/5.0.1/eigen-5.0.1.tar.gz -DBUILD_TESTING=Off -DEIGEN_BUILD_DOC=Off
ROCm/rocMLIR@1f6c4198f74ae62d2f0f72ec5b4f1cc0e3251774 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
ROCm/rocMLIR@c88de61ecf53278491df4f4ae815a212a8468dea -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ add_library(migraphx
target.cpp
tmp_dir.cpp
truncate_float.cpp
fast_mm.cpp
value.cpp
verify_args.cpp
)
Expand Down
2 changes: 2 additions & 0 deletions src/driver/passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fast_mm.hpp>
#include <migraphx/fuse_attention.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
Expand Down Expand Up @@ -74,6 +75,7 @@ static std::unordered_map<std::string, pass> create_passes_lookup()
eliminate_data_type{},
eliminate_identity{},
eliminate_pad{},
fast_mm{},
fuse_attention{},
fuse_pointwise{},
fuse_reduce{},
Expand Down
111 changes: 111 additions & 0 deletions src/fast_mm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fast_mm.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

void fast_mm::apply(module& m) const
{
for(auto ins : iterator_for(m))
{
if(ins->name() != "convolution")
continue;

const auto out_type = ins->get_shape().type();
if(out_type != shape::float_type)
continue;

const auto& out_shape = ins->get_shape();
if(out_shape.dynamic())
continue;

auto inputs = ins->inputs();
auto x = inputs[0];
auto w = inputs[1];
if(not w->can_eval())
continue;

// The hi/lo split below assumes a single input-channel group.
auto op_val = ins->get_operator().to_value();
if(op_val.contains("group") and op_val.at("group").to<int>() != 1)
continue;

const auto& w_shape = w->get_shape();

// Skip when conv is too small to benefit from fp16. These also tend
// to be precision-sensitive (often follow upstream reductions whose
// small magnitudes mean fp16 input rounding dominates absolute error).
std::size_t reduction = std::accumulate(
w_shape.lens().begin() + 1, w_shape.lens().end(), 1, std::multiplies<>());
if(reduction < skip_small_k)
continue;

// W = W_hi + W_lo where W_hi = fp16-rounded W and W_lo = fp16-rounded
// residual. All folds at compile time since W is constant.
auto w_hi_h =
m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), w);
auto w_hi_f = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::float_type}}), w_hi_h);
auto w_lo_f = m.insert_instruction(ins, make_op("sub"), w, w_hi_f);
auto w_lo_h = m.insert_instruction(
ins, make_op("convert", {{"target_type", shape::half_type}}), w_lo_f);
auto w_concat = m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), w_hi_h, w_lo_h);

auto x_h =
m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), x);

// Duplicate X along the input-channel axis without copying: insert a
// size-1 axis, broadcast it to 2, then reshape to merge back into the
// channel dim. Same semantics as concat(X_h, X_h) along axis 1.
const auto& x_lens = x_h->get_shape().lens();
std::vector<std::size_t> bc_lens(x_lens.size() + 1);
bc_lens[0] = x_lens[0];
bc_lens[1] = 2;
std::copy(x_lens.begin() + 1, x_lens.end(), bc_lens.begin() + 2);
std::vector<std::int64_t> reshape_dims(x_lens.begin(), x_lens.end());
reshape_dims[1] *= 2;

auto x_unsq = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {1}}}), x_h);
auto x_bc =
m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", bc_lens}}), x_unsq);
auto x_doubled =
m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), x_bc);

auto half_conv = m.insert_instruction(ins, ins->get_operator(), x_doubled, w_concat);
auto converted =
m.insert_instruction(ins, make_op("convert", {{"target_type", out_type}}), half_conv);

m.replace_instruction(ins, converted);
}
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
45 changes: 45 additions & 0 deletions src/include/migraphx/fast_mm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FAST_MM_HPP
#define MIGRAPHX_GUARD_RTGLIB_FAST_MM_HPP

#include <string>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

struct MIGRAPHX_EXPORT fast_mm
{
std::size_t skip_small_k = 64;
std::string name() const { return "fast_mm"; }
void apply(module& m) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
5 changes: 5 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fp8_ocp_to_fnuz.hpp>
#include <migraphx/fast_mm.hpp>
#include <migraphx/fuse_attention.hpp>
#include <migraphx/fuse_concat.hpp>
#include <migraphx/fuse_horizontal.hpp>
Expand Down Expand Up @@ -98,6 +99,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
auto& ctx = any_cast<context>(gctx);
ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
ctx.load_problem_cache();
auto gfx_name = ctx.get_current_device().get_gfx_name();
const bool missing_fp32_mma = starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12");

// clang-format off
return
Expand Down Expand Up @@ -133,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
dead_code_elimination{},
enable_pass(missing_fp32_mma and options.fast_math, fast_mm{}),
dead_code_elimination{},
fuse_horizontal{},
dead_code_elimination{},
prefuse_ops{&ctx},
Expand Down
154 changes: 154 additions & 0 deletions test/fast_mm_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fast_mm.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>

#include <test.hpp>

static void run_pass(migraphx::module& m, migraphx::fast_mm fmm = {})
{ migraphx::run_passes(m, {fmm, migraphx::dead_code_elimination{}}); }

TEST_CASE(fp32_convolution_const_weights_rewritten)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
std::vector<float> w_data(ws.elements(), 0.5f);

migraphx::module m1;
{
auto x = m1.add_parameter("x", xs);
auto w = m1.add_literal(migraphx::literal{ws, w_data});
auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
m1.add_return({conv});
}
run_pass(m1, {.skip_small_k = 0});

migraphx::module m2;
{
auto x = m2.add_parameter("x", xs);
auto w = m2.add_literal(migraphx::literal{ws, w_data});

auto w_hi_h = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w);
auto w_hi_f = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), w_hi_h);
auto w_lo_f = m2.add_instruction(migraphx::make_op("sub"), w, w_hi_f);
auto w_lo_h = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w_lo_f);
auto w_concat =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w_hi_h, w_lo_h);

auto x_h = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
auto x_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), x_h);
auto x_bc = m2.add_instruction(
migraphx::make_op("multibroadcast",
{{"out_lens", std::vector<std::size_t>{1, 2, 3, 8, 8}}}),
x_unsq);
auto x_doubled = m2.add_instruction(
migraphx::make_op("reshape", {{"dims", std::vector<std::int64_t>{1, 6, 8, 8}}}), x_bc);

auto conv = m2.add_instruction(migraphx::make_op("convolution"), x_doubled, w_concat);
auto out = m2.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), conv);
m2.add_return({out});
}
EXPECT(m1 == m2);
}

TEST_CASE(fp32_convolution_tiny_unchanged)
{
// 11 outputs * 8 reduction = 88 ops — too small to benefit from fp16
// acceleration, and tiny conv outputs are precision-sensitive.
migraphx::shape xs{migraphx::shape::float_type, {1, 8, 1, 1}};
migraphx::shape ws{migraphx::shape::float_type, {11, 8, 1, 1}};
std::vector<float> w_data(ws.elements(), 0.5f);

migraphx::module m1;
{
auto x = m1.add_parameter("x", xs);
auto w = m1.add_literal(migraphx::literal{ws, w_data});
auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
m1.add_return({conv});
}
auto m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(fp32_convolution_param_weights_unchanged)
{
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};

migraphx::module m1;
{
auto x = m1.add_parameter("x", xs);
auto w = m1.add_parameter("w", ws);
auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
m1.add_return({conv});
}
auto m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(fp16_convolution_unchanged)
{
migraphx::shape xs{migraphx::shape::half_type, {1, 3, 8, 8}};
migraphx::shape ws{migraphx::shape::half_type, {4, 3, 3, 3}};

migraphx::module m1;
{
auto x = m1.add_parameter("x", xs);
auto w = m1.add_parameter("w", ws);
auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
m1.add_return({conv});
}
auto m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

TEST_CASE(non_convolution_unchanged)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};

migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto y = m1.add_parameter("y", s);
auto add = m1.add_instruction(migraphx::make_op("add"), x, y);
m1.add_return({add});
}
auto m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }