diff --git a/requirements.txt b/requirements.txt
index d9603f74f8e..a56063feb0e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -30,4 +30,4 @@ sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On
 ROCm/rocm-cmake@dfaa4ddba4dbb2e1c6e9964ce610e2a12fd93f39 --build
 ROCm/composable_kernel@ad0db05b040bacda751c65c705261b8a0a7ed25d --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On -DBUILD_TESTING=Off
 https://gitlab.com/libeigen/eigen/-/archive/5.0.1/eigen-5.0.1.tar.gz -DBUILD_TESTING=Off -DEIGEN_BUILD_DOC=Off
-ROCm/rocMLIR@1f6c4198f74ae62d2f0f72ec5b4f1cc0e3251774 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
+ROCm/rocMLIR@c88de61ecf53278491df4f4ae815a212a8468dea -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 7a1d00a4faa..1c3faa0761b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -145,6 +145,7 @@ add_library(migraphx
     target.cpp
     tmp_dir.cpp
     truncate_float.cpp
+    fast_mm.cpp
     value.cpp
     verify_args.cpp
 )
diff --git a/src/driver/passes.cpp b/src/driver/passes.cpp
index 0c910a00376..b28fb632bc5 100644
--- a/src/driver/passes.cpp
+++ b/src/driver/passes.cpp
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -74,6 +75,7 @@ static std::unordered_map create_passes_lookup()
         eliminate_data_type{},
         eliminate_identity{},
         eliminate_pad{},
+        fast_mm{},
         fuse_attention{},
         fuse_pointwise{},
         fuse_reduce{},
diff --git a/src/fast_mm.cpp b/src/fast_mm.cpp
new file mode 100644
index 00000000000..a3fc5e917fd
--- /dev/null
+++ b/src/fast_mm.cpp
@@ -0,0 +1,117 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+// Rewrite each static fp32 convolution (group 1) whose weights can be evaluated
+// at compile time into one fp16 convolution over doubled input channels: the
+// weights become concat(W_hi, W_lo) along axis 1 and the fp16 input is repeated
+// along the same axis, so a single convolution accumulates both the W_hi and
+// W_lo contributions. The result is converted back to the original output type.
+void fast_mm::apply(module& m) const
+{
+    for(auto ins : iterator_for(m))
+    {
+        if(ins->name() != "convolution")
+            continue;
+
+        const auto out_type = ins->get_shape().type();
+        if(out_type != shape::float_type)
+            continue;
+
+        const auto& out_shape = ins->get_shape();
+        if(out_shape.dynamic())
+            continue;
+
+        auto inputs = ins->inputs();
+        auto x = inputs[0];
+        auto w = inputs[1];
+        if(not w->can_eval())
+            continue;
+
+        // The hi/lo split below assumes a single input-channel group.
+        auto op_val = ins->get_operator().to_value();
+        if(op_val.contains("group") and op_val.at("group").to() != 1)
+            continue;
+
+        const auto& w_shape = w->get_shape();
+
+        // Skip when conv is too small to benefit from fp16. These also tend
+        // to be precision-sensitive (often follow upstream reductions whose
+        // small magnitudes mean fp16 input rounding dominates absolute error).
+        // The std::size_t seed keeps the product from being accumulated as int.
+        std::size_t reduction = std::accumulate(
+            w_shape.lens().begin() + 1, w_shape.lens().end(), std::size_t{1}, std::multiplies<>());
+        if(reduction < skip_small_k)
+            continue;
+
+        // W = W_hi + W_lo where W_hi = fp16-rounded W and W_lo = fp16-rounded
+        // residual. All folds at compile time since W is constant.
+        auto w_hi_h =
+            m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), w);
+        auto w_hi_f = m.insert_instruction(
+            ins, make_op("convert", {{"target_type", shape::float_type}}), w_hi_h);
+        auto w_lo_f = m.insert_instruction(ins, make_op("sub"), w, w_hi_f);
+        auto w_lo_h = m.insert_instruction(
+            ins, make_op("convert", {{"target_type", shape::half_type}}), w_lo_f);
+        auto w_concat = m.insert_instruction(ins, make_op("concat", {{"axis", 1}}), w_hi_h, w_lo_h);
+
+        auto x_h =
+            m.insert_instruction(ins, make_op("convert", {{"target_type", shape::half_type}}), x);
+
+        // Duplicate X along the input-channel axis without copying: insert a
+        // size-1 axis, broadcast it to 2, then reshape to merge back into the
+        // channel dim. Same semantics as concat(X_h, X_h) along axis 1.
+        const auto& x_lens = x_h->get_shape().lens();
+        std::vector bc_lens(x_lens.size() + 1);
+        bc_lens[0] = x_lens[0];
+        bc_lens[1] = 2;
+        std::copy(x_lens.begin() + 1, x_lens.end(), bc_lens.begin() + 2);
+        std::vector reshape_dims(x_lens.begin(), x_lens.end());
+        reshape_dims[1] *= 2;
+
+        auto x_unsq = m.insert_instruction(ins, make_op("unsqueeze", {{"axes", {1}}}), x_h);
+        auto x_bc =
+            m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", bc_lens}}), x_unsq);
+        auto x_doubled =
+            m.insert_instruction(ins, make_op("reshape", {{"dims", reshape_dims}}), x_bc);
+
+        auto half_conv = m.insert_instruction(ins, ins->get_operator(), x_doubled, w_concat);
+        auto converted =
+            m.insert_instruction(ins, make_op("convert", {{"target_type", out_type}}), half_conv);
+
+        m.replace_instruction(ins, converted);
+    }
+}
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/src/include/migraphx/fast_mm.hpp b/src/include/migraphx/fast_mm.hpp
new file mode 100644
index 00000000000..04266c298ce
--- /dev/null
+++ b/src/include/migraphx/fast_mm.hpp
@@ -0,0 +1,51 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_RTGLIB_FAST_MM_HPP
+#define MIGRAPHX_GUARD_RTGLIB_FAST_MM_HPP
+
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+struct module;
+
+/**
+ * Rewrite fp32 convolutions with constant weights as fp16 convolutions that
+ * use a hi/lo fp16 split of the weights over doubled input channels.
+ */
+struct MIGRAPHX_EXPORT fast_mm
+{
+    /// Convolutions whose weight reduction size (the product of all weight
+    /// dimensions after the first) is below this value are left unchanged.
+    std::size_t skip_small_k = 64;
+    std::string name() const { return "fast_mm"; }
+    void apply(module& m) const;
+};
+
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
+
+#endif
diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp
index 3ed3e72033d..701df0a0e7e 100644
--- a/src/targets/gpu/target.cpp
+++ b/src/targets/gpu/target.cpp
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -98,6 +99,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti
     auto& ctx = any_cast(gctx);
     ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
     ctx.load_problem_cache();
+    auto gfx_name = ctx.get_current_device().get_gfx_name();
+    const bool missing_fp32_mma = starts_with(gfx_name, "gfx11") or starts_with(gfx_name, "gfx12");
 
     // clang-format off
     return
@@ -133,6 +136,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti
         optimize_module{},
         layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
         dead_code_elimination{},
+        enable_pass(missing_fp32_mma and options.fast_math, fast_mm{}),
+        dead_code_elimination{},
         fuse_horizontal{},
         dead_code_elimination{},
         prefuse_ops{&ctx},
diff --git a/test/fast_mm_test.cpp b/test/fast_mm_test.cpp
new file mode 100644
index 00000000000..03aea5ae932
--- /dev/null
+++ b/test/fast_mm_test.cpp
@@ -0,0 +1,180 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+// Run fast_mm followed by dead-code elimination on the module.
+static void run_pass(migraphx::module& m, migraphx::fast_mm fmm = {})
+{ migraphx::run_passes(m, {fmm, migraphx::dead_code_elimination{}}); }
+
+TEST_CASE(fp32_convolution_const_weights_rewritten)
+{
+    migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
+    std::vector w_data(ws.elements(), 0.5f);
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", xs);
+        auto w = m1.add_literal(migraphx::literal{ws, w_data});
+        auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
+        m1.add_return({conv});
+    }
+    // skip_small_k = 0 disables the size threshold so this small convolution is rewritten.
+    run_pass(m1, {.skip_small_k = 0});
+
+    // Expected graph: hi/lo fp16 weight split plus the fp16 input duplicated on axis 1.
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", xs);
+        auto w = m2.add_literal(migraphx::literal{ws, w_data});
+
+        auto w_hi_h = m2.add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w);
+        auto w_hi_f = m2.add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), w_hi_h);
+        auto w_lo_f = m2.add_instruction(migraphx::make_op("sub"), w, w_hi_f);
+        auto w_lo_h = m2.add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), w_lo_f);
+        auto w_concat =
+            m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w_hi_h, w_lo_h);
+
+        auto x_h = m2.add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), x);
+        auto x_unsq = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1}}}), x_h);
+        auto x_bc = m2.add_instruction(
+            migraphx::make_op("multibroadcast",
+                              {{"out_lens", std::vector{1, 2, 3, 8, 8}}}),
+            x_unsq);
+        auto x_doubled = m2.add_instruction(
+            migraphx::make_op("reshape", {{"dims", std::vector{1, 6, 8, 8}}}), x_bc);
+
+        auto conv = m2.add_instruction(migraphx::make_op("convolution"), x_doubled, w_concat);
+        auto out = m2.add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), conv);
+        m2.add_return({out});
+    }
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(fp32_convolution_tiny_unchanged)
+{
+    // Weight reduction size is 8 * 1 * 1 = 8, below the default skip_small_k
+    // of 64, so the pass must leave this convolution in fp32.
+    migraphx::shape xs{migraphx::shape::float_type, {1, 8, 1, 1}};
+    migraphx::shape ws{migraphx::shape::float_type, {11, 8, 1, 1}};
+    std::vector w_data(ws.elements(), 0.5f);
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", xs);
+        auto w = m1.add_literal(migraphx::literal{ws, w_data});
+        auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
+        m1.add_return({conv});
+    }
+    auto m2 = m1;
+    run_pass(m1);
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(fp32_convolution_grouped_unchanged)
+{
+    // The hi/lo channel split is only applied when group == 1, so a grouped
+    // convolution must be left untouched even with constant weights.
+    migraphx::shape xs{migraphx::shape::float_type, {1, 4, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 2, 3, 3}};
+    std::vector w_data(ws.elements(), 0.5f);
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", xs);
+        auto w = m1.add_literal(migraphx::literal{ws, w_data});
+        auto conv = m1.add_instruction(migraphx::make_op("convolution", {{"group", 2}}), x, w);
+        m1.add_return({conv});
+    }
+    auto m2 = m1;
+    run_pass(m1, {.skip_small_k = 0});
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(fp32_convolution_param_weights_unchanged)
+{
+    // skip_small_k = 0 so the size threshold cannot be what leaves the graph unchanged.
+    migraphx::shape xs{migraphx::shape::float_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::float_type, {4, 3, 3, 3}};
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", xs);
+        auto w = m1.add_parameter("w", ws);
+        auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
+        m1.add_return({conv});
+    }
+    auto m2 = m1;
+    run_pass(m1, {.skip_small_k = 0});
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(fp16_convolution_unchanged)
+{
+    // skip_small_k = 0 so the size threshold cannot be what leaves the graph unchanged.
+    migraphx::shape xs{migraphx::shape::half_type, {1, 3, 8, 8}};
+    migraphx::shape ws{migraphx::shape::half_type, {4, 3, 3, 3}};
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", xs);
+        auto w = m1.add_parameter("w", ws);
+        auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w);
+        m1.add_return({conv});
+    }
+    auto m2 = m1;
+    run_pass(m1, {.skip_small_k = 0});
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(non_convolution_unchanged)
+{
+    // Instructions other than convolution are not touched by the pass.
+    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", s);
+        auto y = m1.add_parameter("y", s);
+        auto add = m1.add_instruction(migraphx::make_op("add"), x, y);
+        m1.add_return({add});
+    }
+    auto m2 = m1;
+    run_pass(m1);
+    EXPECT(m1 == m2);
+}
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }