From 984009905608ed216c76f798785ca8bf8fc9f082 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 25 May 2026 18:46:22 +0000 Subject: [PATCH 1/7] Add TensorLocation enum + fused matmul epilogues - Add TensorLocation enum (GLOBAL/SHARED/REGISTER) to ModelTensor - Add CUTLASS GEMM fused epilogue: gemm_with_functor template + FunctorScale, FunctorGelu, FunctorScaleExp, FunctorAdd - New ops: ModelOpMatmulScale, ModelOpMatmulGelu, ModelOpMma, ModelOpStore - matmul_scale fuses attention Q@K^T/sqrt(dk) into one kernel - matmul_gelu fuses FFN1 linear+GELU into one kernel - mma/store provide register-tagged output for future fusion --- ark/include/kernels/gemm_cutlass.h | 263 ++++++++++++++++++++ ark/include/kernels/gemm_fused.h | 265 ++++++++++++++++++++ ark/include/kernels/gemm_scale.h | 96 ++++++++ ark/include/kernels/matmul.h | 156 ++++++++++++ ark/include/kernels/matmul_fused.h | 55 +++++ ark/model/model_tensor.cpp | 11 +- ark/model/model_tensor.hpp | 14 +- ark/ops/ops_matmul.cpp | 372 ++++++++++++++++++++++++++--- ark/ops/ops_matmul.hpp | 69 ++++++ 9 files changed, 1271 insertions(+), 30 deletions(-) create mode 100644 ark/include/kernels/gemm_fused.h create mode 100644 ark/include/kernels/gemm_scale.h create mode 100644 ark/include/kernels/matmul_fused.h diff --git a/ark/include/kernels/gemm_cutlass.h b/ark/include/kernels/gemm_cutlass.h index ddf2b72c6..8c99a0250 100644 --- a/ark/include/kernels/gemm_cutlass.h +++ b/ark/include/kernels/gemm_cutlass.h @@ -20,6 +20,8 @@ #include "cutlass/epilogue/thread/linear_combination.h" // clang-format on +#include "cutlass/epilogue/thread/linear_combination_gelu.h" + #include "common/checker.h" #include "common/unit_op.h" @@ -103,6 +105,44 @@ struct GemmConfiguration { ark::GemmThreadblockSwizzle, 3>; }; +/// GemmConfiguration with GELU activation fused into epilogue. +/// D = gelu(alpha * A*B + beta * C) +template +struct GemmConfigurationGelu { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v, + "ElementA must be float, half, or bfloat16"); + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v, + "ElementB must be float, half, or bfloat16"); + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v, + "ElementC must be float, half, or bfloat16"); + using ElementAccumulator = typename std::conditional_t< + std::is_same_v, float, ElementC>; + static constexpr int NumWarps = UnitOp::NumWarps; + static constexpr int NumWarpsN = + 1 << math::div_up::value, 2>::value; + static constexpr int NumWarpsM = NumWarps / NumWarpsN; + using WarpShape = + cutlass::gemm::GemmShape; + using InstShape = typename InstructionShape::value; + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + ElementAccumulator, OperatorClass, ArchTag, Shape, WarpShape, InstShape, + cutlass::epilogue::thread::LinearCombinationGELU< + ElementC, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + ark::GemmThreadblockSwizzle, 3>; +}; + #if 0 template struct GemmConfiguration< @@ -275,6 +315,147 @@ DEVICE void gemm_cuda(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, gemm_kernel(params, *ps); } +/// CUDA GeMM with residual addition: D = A*B + Residual (beta=1). +/// Takes a separate residual pointer that is added to the matmul output. +template +DEVICE void gemm_cuda_add(DataTypeC *D, DataTypeA *A, DataTypeB *B, + DataTypeC *Residual, int uop_idx, + int smem_per_warp) { +#if (ARK_TARGET_CUDA_ARCH == 60) + using ArchTag = cutlass::arch::Sm60; +#elif (ARK_TARGET_CUDA_ARCH == 70) + using ArchTag = cutlass::arch::Sm70; +#elif (ARK_TARGET_CUDA_ARCH == 80) + using ArchTag = cutlass::arch::Sm80; +#elif (ARK_TARGET_CUDA_ARCH == 90) + using ArchTag = cutlass::arch::Sm80; +#else + static_assert(false, "Unsupported CUDA arch."); +#endif + + using LayoutA = typename cutlass::platform::conditional< + IsColumnA, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutB = typename cutlass::platform::conditional< + IsColumnB, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutC = cutlass::layout::RowMajor; + + static constexpr int TileSizeK = std::is_same_v ? 32 : 64; + using GemmType = typename ark::GemmConfiguration< + UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, + DataTypeB, LayoutB, DataTypeC, LayoutC, + cutlass::gemm::GemmShape>::Gemm; + using GemmKernel = typename GemmType::GemmKernel; + + IsEq(); + IsEq(); + + LayoutA layout_a(LeadingDimA); + LayoutB layout_b(LeadingDimB); + LayoutC layout_c(LeadingDimC); + cutlass::TensorRef ref_a(A, layout_a); + cutlass::TensorRef ref_b(B, layout_b); + cutlass::TensorRef ref_c(Residual, layout_c); + cutlass::TensorRef ref_d(D, layout_c); + + cutlass::gemm::GemmCoord problem_size(ProblemSizeM, ProblemSizeN, + ProblemSizeK); + + ark::GemmThreadblockSwizzle swizzle; + cutlass::gemm::GemmCoord tiled_shape(swizzle.get_tiled_shape()); + + // Set alpha=1, beta=1 so output = 1*A*B + 1*Residual + using EpilogueOp = typename GemmType::EpilogueOutputOp; + typename EpilogueOp::Params epilogue_params( + typename EpilogueOp::ElementCompute(1), + typename EpilogueOp::ElementCompute(1)); + + typename GemmKernel::Params params(problem_size, tiled_shape, ref_a, ref_b, + ref_c, ref_d, epilogue_params); + params.swizzle_log_tile = uop_idx; + + typename GemmKernel::SharedStorage *ps = + UnitOp::template shared_memory( + smem_per_warp); + + UnitOp::sync_threads(); + + GemmKernel gemm_kernel{}; + gemm_kernel(params, *ps); +} +/// Output = gelu(A * B). Same structure as gemm_cuda but uses +/// GemmConfigurationGelu which applies GELU in the CUTLASS epilogue thread. +template +DEVICE void gemm_cuda_gelu(DataTypeC *C, DataTypeA *A, DataTypeB *B, + int uop_idx, int smem_per_warp) { +#if (ARK_TARGET_CUDA_ARCH == 60) + using ArchTag = cutlass::arch::Sm60; +#elif (ARK_TARGET_CUDA_ARCH == 70) + using ArchTag = cutlass::arch::Sm70; +#elif (ARK_TARGET_CUDA_ARCH == 80) + using ArchTag = cutlass::arch::Sm80; +#elif (ARK_TARGET_CUDA_ARCH == 90) + using ArchTag = cutlass::arch::Sm80; +#else + static_assert(false, "Unsupported CUDA arch."); +#endif + + using LayoutA = typename cutlass::platform::conditional< + IsColumnA, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutB = typename cutlass::platform::conditional< + IsColumnB, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutC = cutlass::layout::RowMajor; + + static constexpr int TileSizeK = std::is_same_v ? 32 : 64; + using GemmKernel = typename ark::GemmConfigurationGelu< + UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, + DataTypeB, LayoutB, DataTypeC, LayoutC, + cutlass::gemm::GemmShape>::Gemm::GemmKernel; + + IsEq(); + IsEq(); + + LayoutA layout_a(LeadingDimA); + LayoutB layout_b(LeadingDimB); + LayoutC layout_c(LeadingDimC); + cutlass::TensorRef ref_a(A, layout_a); + cutlass::TensorRef ref_b(B, layout_b); + cutlass::TensorRef ref_c(C, layout_c); + + cutlass::gemm::GemmCoord problem_size(ProblemSizeM, ProblemSizeN, + ProblemSizeK); + cutlass::gemm::GemmCoord threadblock_shape(TileSizeM, TileSizeN, TileSizeK); + + ark::GemmThreadblockSwizzle swizzle; + cutlass::gemm::GemmCoord tiled_shape(swizzle.get_tiled_shape()); + + typename GemmKernel::Params params(problem_size, tiled_shape, ref_a, ref_b, + ref_c, ref_c); + params.swizzle_log_tile = uop_idx; + + typename GemmKernel::SharedStorage *ps = + UnitOp::template shared_memory( + smem_per_warp); + + UnitOp::sync_threads(); + + GemmKernel gemm_kernel{}; + gemm_kernel(params, *ps); +} + /// CUDA GeMM for arch 90. template +DEVICE void gemm_cutlass_gelu(DataTypeC *C, DataTypeA *A, DataTypeB *B, + int uop_idx, int smem_per_warp) { + using CutDataTypeA = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeA>::type>::type; + + using CutDataTypeB = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeB>::type>::type; + + using CutDataTypeC = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeC>::type>::type; + + CutDataTypeC *pC = reinterpret_cast(C); + CutDataTypeA *pA = reinterpret_cast(A); + CutDataTypeB *pB = reinterpret_cast(B); + +#if (ARK_TARGET_CUDA_ARCH == 60 || ARK_TARGET_CUDA_ARCH == 70 || \ + ARK_TARGET_CUDA_ARCH == 80 || ARK_TARGET_CUDA_ARCH == 90) + gemm_cuda_gelu(pC, pA, pB, uop_idx, smem_per_warp); +#else + static_assert(false, "Unsupported CUDA arch."); +#endif +} + +/// Row-major GeMM with residual addition: D = A*B + Residual. +template +DEVICE void gemm_cutlass_add(DataTypeC *D, DataTypeA *A, DataTypeB *B, + DataTypeC *Residual, int uop_idx, + int smem_per_warp) { + using CutDataTypeA = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeA>::type>::type; + using CutDataTypeB = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeB>::type>::type; + using CutDataTypeC = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeC>::type>::type; + + CutDataTypeC *pD = reinterpret_cast(D); + CutDataTypeA *pA = reinterpret_cast(A); + CutDataTypeB *pB = reinterpret_cast(B); + CutDataTypeC *pR = reinterpret_cast(Residual); + +#if (ARK_TARGET_CUDA_ARCH == 60 || ARK_TARGET_CUDA_ARCH == 70 || \ + ARK_TARGET_CUDA_ARCH == 80 || ARK_TARGET_CUDA_ARCH == 90) + gemm_cuda_add(pD, pA, pB, pR, uop_idx, smem_per_warp); +#else + static_assert(false, "Unsupported CUDA arch."); +#endif +} + } // namespace ark #endif // ARK_KERNELS_GEMM_CUTLASS_H_ diff --git a/ark/include/kernels/gemm_fused.h b/ark/include/kernels/gemm_fused.h new file mode 100644 index 000000000..33100cf12 --- /dev/null +++ b/ark/include/kernels/gemm_fused.h @@ -0,0 +1,265 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// Decomposed GEMM approach: Insert a user-defined functor between MMA +// accumulation and the epilogue store. The functor operates on accumulator +// registers before they're written to global memory, eliminating the +// global→global data path for fused elementwise ops. +// +// Data flow: +// global(A,B) → shared → MMA → accum(registers) → Functor → Epilogue → global(C) +// +// vs current matmul + separate elementwise: +// global(A,B) → shared → MMA → Epilogue → global(C) +// global(C) → elementwise → global(C) ← EXTRA global read+write + +#ifndef ARK_KERNELS_GEMM_FUSED_H_ +#define ARK_KERNELS_GEMM_FUSED_H_ + +#include "gemm_cutlass.h" + +namespace ark { + +// ============================================================================ +// Functors that operate on CUTLASS accumulator fragments (in registers) +// ============================================================================ + +struct FunctorIdentity { + template + DEVICE static void apply(FragmentC &) {} +}; + +struct FunctorScale { + float scale; + template + DEVICE void apply(FragmentC &accum) const { + using Element = typename FragmentC::Element; + for (int i = 0; i < FragmentC::kElements; i++) { + float val = static_cast(accum[i]); + accum[i] = static_cast(val * scale); + } + } +}; + +struct FunctorGelu { + template + DEVICE static void apply(FragmentC &accum) { + using Element = typename FragmentC::Element; + for (int i = 0; i < FragmentC::kElements; i++) { + float x = static_cast(accum[i]); + accum[i] = static_cast( + x * 0.5f * (1.0f + erff(x * 0.7071067811865475f))); + } + } +}; + +struct FunctorRelu { + template + DEVICE static void apply(FragmentC &accum) { + using Element = typename FragmentC::Element; + for (int i = 0; i < FragmentC::kElements; i++) { + float val = static_cast(accum[i]); + accum[i] = static_cast(val > 0.f ? val : 0.f); + } + } +}; + +// Scale + Exp (for attention: exp(score * scale)) +struct FunctorScaleExp { + float scale; + template + DEVICE void apply(FragmentC &accum) const { + using Element = typename FragmentC::Element; + for (int i = 0; i < FragmentC::kElements; i++) { + float val = static_cast(accum[i]); + accum[i] = static_cast(expf(val * scale)); + } + } +}; + +// ============================================================================ +// gemm_with_functor: CUTLASS Mma + Functor on accumulators + Epilogue +// +// IMPORTANT: DataTypeA/B/C must be CUTLASS types (cutlass::half_t, etc.), +// NOT ARK types (ark::fp16). Use the gemm_cutlass_fused() wrapper below +// for ARK type conversion. +// ============================================================================ + +template +DEVICE void gemm_with_functor(DataTypeC *C, DataTypeA *A, DataTypeB *B, + Functor functor, + int uop_idx, int smem_per_warp) { +#if (ARK_TARGET_CUDA_ARCH == 60) + using ArchTag = cutlass::arch::Sm60; +#elif (ARK_TARGET_CUDA_ARCH == 70) + using ArchTag = cutlass::arch::Sm70; +#elif (ARK_TARGET_CUDA_ARCH == 80) + using ArchTag = cutlass::arch::Sm80; +#elif (ARK_TARGET_CUDA_ARCH == 90) + using ArchTag = cutlass::arch::Sm80; // SM80 CUTLASS 2.x path for compat +#else + static_assert(false, "Unsupported CUDA arch for gemm_with_functor"); +#endif + + using LayoutA = typename cutlass::platform::conditional< + IsColumnA, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutB = typename cutlass::platform::conditional< + IsColumnB, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutC = cutlass::layout::RowMajor; + + static constexpr int TileSizeK = std::is_same_v ? 32 : 64; + using GemmKernel = typename ark::GemmConfiguration< + UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, + DataTypeB, LayoutB, DataTypeC, LayoutC, + cutlass::gemm::GemmShape>::Gemm::GemmKernel; + using Mma = typename GemmKernel::Mma; + using Epilogue = typename GemmKernel::Epilogue; + using OutputOp = typename GemmKernel::OutputOp; + + IsEq(); + IsEq(); + + LayoutA layout_a(LeadingDimA); + LayoutB layout_b(LeadingDimB); + LayoutC layout_c(LeadingDimC); + cutlass::TensorRef ref_a(A, layout_a); + cutlass::TensorRef ref_b(B, layout_b); + cutlass::TensorRef ref_c(C, layout_c); + + cutlass::gemm::GemmCoord problem_size(ProblemSizeM, ProblemSizeN, + ProblemSizeK); + + ark::GemmThreadblockSwizzle swizzle; + cutlass::gemm::GemmCoord tiled_shape(swizzle.get_tiled_shape()); + + typename GemmKernel::Params params(problem_size, tiled_shape, + ref_a, ref_b, ref_c, ref_c); + params.swizzle_log_tile = uop_idx; + + typename GemmKernel::SharedStorage *ps = + UnitOp::template shared_memory( + smem_per_warp); + + UnitOp::sync_threads(); + + // --- Phase 1: Mma mainloop (same as standard CUTLASS) --- + + cutlass::gemm::GemmCoord threadblock_tile_offset = + swizzle.get_tile_offset(uop_idx); + + if (tiled_shape.m() <= threadblock_tile_offset.m() || + tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, 0}; + cutlass::MatrixCoord tb_offset_B{ + 0, threadblock_tile_offset.n() * Mma::Shape::kN}; + + int gemm_k_iterations = + (ProblemSizeK + Mma::Shape::kK - 1) / Mma::Shape::kK; + int thread_idx = threadIdx.x % GemmKernel::kThreadCount; + + typename Mma::IteratorA iterator_A( + params.params_A, params.ref_A.data(), + {ProblemSizeM, ProblemSizeK}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, params.ref_B.data(), + {ProblemSizeK, ProblemSizeN}, thread_idx, tb_offset_B); + + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0) % + GemmKernel::WarpCount::kCount; + int lane_idx = threadIdx.x % 32; + + Mma mma(ps->main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + accumulators.clear(); + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // --- Phase 2: Apply functor on accumulator registers --- + functor.apply(accumulators); + + // --- Phase 3: Epilogue (write to global memory) --- + // Uses default OutputOp (alpha=1, beta=0) — just stores the + // (functor-modified) accumulators to global memory. + OutputOp output_op(params.output_op); + + threadblock_tile_offset = swizzle.get_tile_offset(params.swizzle_log_tile); + + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, + {threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN}); + + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, + {threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN}); + + Epilogue epilogue(ps->epilogue, thread_idx, warp_idx, lane_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); +} + +// ============================================================================ +// gemm_cutlass_fused: ARK type conversion wrapper for gemm_with_functor. +// Converts ark::fp16/bf16/fp32 → cutlass::half_t/bfloat16_t/float, then +// calls gemm_with_functor. +// ============================================================================ + +template +DEVICE void gemm_cutlass_fused(DataTypeC *C, DataTypeA *A, DataTypeB *B, + Functor functor, + int uop_idx, int smem_per_warp) { + using CutDataTypeA = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeA>::type>::type; + + using CutDataTypeB = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeB>::type>::type; + + using CutDataTypeC = typename cutlass::platform::conditional< + std::is_same::value, cutlass::half_t, + typename cutlass::platform::conditional< + std::is_same::value, cutlass::bfloat16_t, + DataTypeC>::type>::type; + + CutDataTypeC *pC = reinterpret_cast(C); + CutDataTypeA *pA = reinterpret_cast(A); + CutDataTypeB *pB = reinterpret_cast(B); + +#if (ARK_TARGET_CUDA_ARCH == 60 || ARK_TARGET_CUDA_ARCH == 70 || \ + ARK_TARGET_CUDA_ARCH == 80 || ARK_TARGET_CUDA_ARCH == 90) + gemm_with_functor(pC, pA, pB, functor, + uop_idx, smem_per_warp); +#else + static_assert(false, "Unsupported CUDA arch."); +#endif +} + +} // namespace ark + +#endif // ARK_KERNELS_GEMM_FUSED_H_ diff --git a/ark/include/kernels/gemm_scale.h b/ark/include/kernels/gemm_scale.h new file mode 100644 index 000000000..b6025d013 --- /dev/null +++ b/ark/include/kernels/gemm_scale.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// GEMM with scale: D = alpha * (A @ B) where alpha = scale. +// Uses CUTLASS LinearCombination epilogue with custom alpha. +// The scale is applied on accumulator register fragments during the epilogue, +// eliminating a separate global memory round-trip. + +#ifndef ARK_KERNELS_GEMM_SCALE_H_ +#define ARK_KERNELS_GEMM_SCALE_H_ + +#include "gemm_cutlass.h" + +namespace ark { + +/// CUDA GeMM with scale: D = scale * A*B. +/// Uses the standard CUTLASS epilogue with alpha=scale, beta=0. +/// The scale is applied on register fragments in the epilogue thread function, +/// before writing to global memory — NO extra global memory read/write. +template +DEVICE void gemm_cuda_scale(DataTypeC *C, DataTypeA *A, DataTypeB *B, + int uop_idx, int smem_per_warp) { +#if (ARK_TARGET_CUDA_ARCH == 70) + using ArchTag = cutlass::arch::Sm70; +#elif (ARK_TARGET_CUDA_ARCH == 80) + using ArchTag = cutlass::arch::Sm80; +#elif (ARK_TARGET_CUDA_ARCH == 90) + using ArchTag = cutlass::arch::Sm80; +#else + static_assert(false, "Unsupported CUDA arch."); +#endif + + using LayoutA = typename cutlass::platform::conditional< + IsColumnA, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutB = typename cutlass::platform::conditional< + IsColumnB, cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor>::type; + using LayoutC = cutlass::layout::RowMajor; + + static constexpr int TileSizeK = std::is_same_v ? 32 : 64; + using GemmKernel = typename ark::GemmConfiguration< + UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, + DataTypeB, LayoutB, DataTypeC, LayoutC, + cutlass::gemm::GemmShape>::Gemm::GemmKernel; + using OutputOp = typename GemmKernel::OutputOp; + + IsEq(); + IsEq(); + + LayoutA layout_a(LeadingDimA); + LayoutB layout_b(LeadingDimB); + LayoutC layout_c(LeadingDimC); + cutlass::TensorRef ref_a(A, layout_a); + cutlass::TensorRef ref_b(B, layout_b); + cutlass::TensorRef ref_c(C, layout_c); + + cutlass::gemm::GemmCoord problem_size(ProblemSizeM, ProblemSizeN, + ProblemSizeK); + cutlass::gemm::GemmCoord threadblock_shape(TileSizeM, TileSizeN, TileSizeK); + + ark::GemmThreadblockSwizzle swizzle; + cutlass::gemm::GemmCoord tiled_shape(swizzle.get_tiled_shape()); + + // Decode scale from bits + union { uint32_t u; float f; } conv; + conv.u = ScaleBits; + + // Create OutputOp params with alpha=scale, beta=0 + typename OutputOp::Params output_op_params( + static_cast(conv.f), + static_cast(0)); + + typename GemmKernel::Params params(problem_size, tiled_shape, ref_a, ref_b, + ref_c, ref_c, output_op_params); + + params.swizzle_log_tile = uop_idx; + + typename GemmKernel::SharedStorage *ps = + UnitOp::template shared_memory( + smem_per_warp); + + UnitOp::sync_threads(); + + GemmKernel gemm_kernel{}; + gemm_kernel(params, *ps); +} + +} // namespace ark + +#endif // ARK_KERNELS_GEMM_SCALE_H_ diff --git a/ark/include/kernels/matmul.h b/ark/include/kernels/matmul.h index b14f10bf6..0693516f7 100644 --- a/ark/include/kernels/matmul.h +++ b/ark/include/kernels/matmul.h @@ -93,6 +93,162 @@ DEVICE void matmul(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, #endif } +/// Matrix multiplication with GELU activation fused into the epilogue. +/// Output = gelu(A @ B). Only supported on CUDA (CUTLASS). +template +DEVICE void matmul_gelu(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, + int smem_per_warp) { + static_assert(NCA::D2 == 1 && NCA::D3 == 1, + "NCA should be two dimensional."); + static_assert(NCB::D2 == 1 && NCB::D3 == 1, + "NCB should be two dimensional."); + static_assert(TileShape::D2 == 1 && TileShape::D3 == 1, + "TileShape should be two dimensional."); + static_assert(ProblemSize::D3 == 1, + "ProblemSize should be three dimensional."); + + constexpr int NC = (NCA::D0 > NCB::D0) ? NCA::D0 : NCB::D0; + constexpr int CC = (NCA::D1 > NCB::D1) ? NCA::D1 : NCB::D1; + + using OutShape = Vec; + using UnitOutDims = Vec<1, 1, TileShape::D0, TileShape::D1>; + using UnitOp = UnitOp; + + constexpr int LeadingDimA = LeadingDims::D0; + constexpr int LeadingDimB = LeadingDims::D3; + constexpr int LeadingDimC = LeadingDims::D1; + constexpr int ProblemSizeM = ProblemSize::D0; + constexpr int ProblemSizeN = ProblemSize::D1; + constexpr int ProblemSizeK = ProblemSize::D2; + constexpr int TileSizeM = TileShape::D0; + constexpr int TileSizeN = TileShape::D1; + + int un = UnitOp::uop_idx_n(uop_idx); + int uc = UnitOp::uop_idx_c(uop_idx); + + DataTypeA *pA = &A[un * BatchStrideNA + uc * BatchStrideCA]; + DataTypeB *pB = &B[un * BatchStrideNB + uc * BatchStrideCB]; + DataTypeC *pC = &C[un * BatchStrideNC + uc * BatchStrideCC]; + +#if defined(ARK_TARGET_CUDA_ARCH) + gemm_cutlass_gelu(pC, pA, pB, uop_idx, smem_per_warp); +#elif defined(ARK_TARGET_ROCM_ARCH) + static_assert(false, "matmul_gelu not supported on ROCm."); +#endif +} + +/// Matrix multiplication with residual addition fused into the epilogue. +/// Output = A @ B + Residual. Only supported on CUDA (CUTLASS). +template +DEVICE void matmul_add(DataTypeC *D, DataTypeA *A, DataTypeB *B, + DataTypeC *Residual, int uop_idx, int smem_per_warp) { + static_assert(NCA::D2 == 1 && NCA::D3 == 1, + "NCA should be two dimensional."); + static_assert(NCB::D2 == 1 && NCB::D3 == 1, + "NCB should be two dimensional."); + static_assert(TileShape::D2 == 1 && TileShape::D3 == 1, + "TileShape should be two dimensional."); + static_assert(ProblemSize::D3 == 1, + "ProblemSize should be three dimensional."); + + constexpr int NC = (NCA::D0 > NCB::D0) ? NCA::D0 : NCB::D0; + constexpr int CC = (NCA::D1 > NCB::D1) ? NCA::D1 : NCB::D1; + + using OutShape = Vec; + using UnitOutDims = Vec<1, 1, TileShape::D0, TileShape::D1>; + using UnitOp = UnitOp; + + constexpr int LeadingDimA = LeadingDims::D0; + constexpr int LeadingDimB = LeadingDims::D3; + constexpr int LeadingDimC = LeadingDims::D1; + constexpr int ProblemSizeM = ProblemSize::D0; + constexpr int ProblemSizeN = ProblemSize::D1; + constexpr int ProblemSizeK = ProblemSize::D2; + constexpr int TileSizeM = TileShape::D0; + constexpr int TileSizeN = TileShape::D1; + + int un = UnitOp::uop_idx_n(uop_idx); + int uc = UnitOp::uop_idx_c(uop_idx); + + DataTypeA *pA = &A[un * BatchStrideNA + uc * BatchStrideCA]; + DataTypeB *pB = &B[un * BatchStrideNB + uc * BatchStrideCB]; + DataTypeC *pD = &D[un * BatchStrideNC + uc * BatchStrideCC]; + DataTypeC *pR = &Residual[un * BatchStrideNC + uc * BatchStrideCC]; + +#if defined(ARK_TARGET_CUDA_ARCH) + gemm_cutlass_add( + pD, pA, pB, pR, uop_idx, smem_per_warp); +#elif defined(ARK_TARGET_ROCM_ARCH) + static_assert(false, "matmul_add not supported on ROCm."); +#endif +} + +} // namespace ark + +// Fused matmul with register-level functor (scale, gelu, relu, etc.) +#if defined(ARK_TARGET_CUDA_ARCH) +#include "gemm_fused.h" + +namespace ark { + +/// Matrix multiplication with scale: D = scale * (A @ B). +/// The scale is applied on register accumulators via FunctorScale BEFORE +/// the epilogue writes to global memory — zero extra global memory traffic. +template +DEVICE void matmul_scale(DataTypeC *C, DataTypeA *A, DataTypeB *B, + int uop_idx, int smem_per_warp) { + static_assert(NCA::D2 == 1 && NCA::D3 == 1); + static_assert(NCB::D2 == 1 && NCB::D3 == 1); + static_assert(TileShape::D2 == 1 && TileShape::D3 == 1); + static_assert(ProblemSize::D3 == 1); + + constexpr int NC = (NCA::D0 > NCB::D0) ? NCA::D0 : NCB::D0; + constexpr int CC = (NCA::D1 > NCB::D1) ? NCA::D1 : NCB::D1; + using OutShape = Vec; + using UnitOutDims = Vec<1, 1, TileShape::D0, TileShape::D1>; + using UnitOp_t = UnitOp; + + int un = UnitOp_t::uop_idx_n(uop_idx); + int uc = UnitOp_t::uop_idx_c(uop_idx); + DataTypeA *pA = &A[un * BatchStrideNA + uc * BatchStrideCA]; + DataTypeB *pB = &B[un * BatchStrideNB + uc * BatchStrideCB]; + DataTypeC *pC = &C[un * BatchStrideNC + uc * BatchStrideCC]; + + // Decode scale from bit-pattern template parameter + union { uint32_t u; float f; } conv; + conv.u = ScaleBits; + FunctorScale functor{conv.f}; + + gemm_cutlass_fused< + typename std::remove_const::type, LeadingDims::D0, IsColumnA, + typename std::remove_const::type, LeadingDims::D3, IsColumnB, + DataTypeC, LeadingDims::D1, + ProblemSize::D0, ProblemSize::D1, ProblemSize::D2, + TileShape::D0, TileShape::D1, + UnitOp_t, FunctorScale>(pC, pA, pB, functor, uop_idx, smem_per_warp); +} + } // namespace ark +#endif // ARK_TARGET_CUDA_ARCH #endif // ARK_KERNELS_MATMUL_H_ diff --git a/ark/include/kernels/matmul_fused.h b/ark/include/kernels/matmul_fused.h new file mode 100644 index 000000000..4baeb997b --- /dev/null +++ b/ark/include/kernels/matmul_fused.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +// +// matmul_fused: matmul with a post-MMA functor applied on register accumulators. +// Wraps gemm_fused.h for ARK's op interface. + +#ifndef ARK_KERNELS_MATMUL_FUSED_H_ +#define ARK_KERNELS_MATMUL_FUSED_H_ + +#include "gemm_fused.h" + +namespace ark { + +/// Matrix multiplication with scale applied on register accumulators. +/// Output = (A @ B) * scale +/// The scale is applied BEFORE writing to global memory, saving one +/// global read+write cycle compared to separate matmul + scale ops. +template +DEVICE void matmul_scale(DataTypeC *C, DataTypeA *A, DataTypeB *B, + float scale, int uop_idx, int smem_per_warp) { + constexpr int NC = (NCA::D0 > NCB::D0) ? NCA::D0 : NCB::D0; + constexpr int CC = (NCA::D1 > NCB::D1) ? NCA::D1 : NCB::D1; + using OutShape = Vec; + using UnitOutDims = Vec<1, 1, TileShape::D0, TileShape::D1>; + using UnitOp_t = UnitOp; + + constexpr int LeadingDimA = LeadingDims::D0; + constexpr int LeadingDimB = LeadingDims::D3; + constexpr int LeadingDimC = LeadingDims::D1; + + int un = UnitOp_t::uop_idx_n(uop_idx); + int uc = UnitOp_t::uop_idx_c(uop_idx); + + DataTypeA *pA = &A[un * BatchStrideNA + uc * BatchStrideCA]; + DataTypeB *pB = &B[un * BatchStrideNB + uc * BatchStrideCB]; + DataTypeC *pC = &C[un * BatchStrideNC + uc * BatchStrideCC]; + + FunctorScale functor{scale}; + gemm_with_functor< + typename std::remove_const::type, LeadingDimA, IsColumnA, + typename std::remove_const::type, LeadingDimB, IsColumnB, + DataTypeC, LeadingDimC, + ProblemSize::D0, ProblemSize::D1, ProblemSize::D2, + TileShape::D0, TileShape::D1, + UnitOp_t, FunctorScale>(pC, pA, pB, functor, uop_idx, smem_per_warp); +} + +} // namespace ark + +#endif // ARK_KERNELS_MATMUL_FUSED_H_ diff --git a/ark/model/model_tensor.cpp b/ark/model/model_tensor.cpp index 068783045..5a98651e7 100644 --- a/ark/model/model_tensor.cpp +++ b/ark/model/model_tensor.cpp @@ -12,8 +12,9 @@ namespace ark { ModelTensor::ModelTensor(ModelDataType data_type, ModelBufferRef buffer, const Dims &shape, const Dims &strides, - const Dims &offsets, const Dims &padded_shape) - : data_type_(data_type), buffer_(buffer) { + const Dims &offsets, const Dims &padded_shape, + TensorLocation location) + : data_type_(data_type), buffer_(buffer), location_(location) { if (shape.is_no_dim()) { // Assume a single-element constant shape_ = {1}; @@ -86,6 +87,7 @@ ModelTensor::ModelTensor(const ModelTensor &other) { strides_ = other.strides_; offsets_ = other.offsets_; padded_shape_ = other.padded_shape_; + location_ = other.location_; } size_t ModelTensor::shape_bytes() const { @@ -111,6 +113,7 @@ Json ModelTensor::serialize() const { j["Offsets"] = offsets_.vector(); j["PaddedShape"] = padded_shape_.vector(); j["Buffer"] = buffer_->serialize(); + j["Location"] = static_cast(location_); return j; } @@ -139,6 +142,10 @@ std::shared_ptr ModelTensor::deserialize(const Json &serialized) { serialized["Offsets"].get>(), serialized["PaddedShape"].get>()); ret->id_ = serialized["Id"]; + if (serialized.contains("Location")) { + ret->location_ = static_cast( + serialized["Location"].get()); + } return ret; } diff --git a/ark/model/model_tensor.hpp b/ark/model/model_tensor.hpp index 8c892f2b4..a951b9710 100644 --- a/ark/model/model_tensor.hpp +++ b/ark/model/model_tensor.hpp @@ -13,11 +13,19 @@ namespace ark { class ModelDataT; using ModelDataType = std::shared_ptr; +/// Location of tensor data in the memory hierarchy. +enum class TensorLocation { + GLOBAL, // GPU global memory (HBM) — default, current behavior + SHARED, // Shared memory (SMEM) — scoped to one thread block + REGISTER, // Register file — scoped to one warp group (no buffer allocation) +}; + class ModelTensor { public: ModelTensor(ModelDataType data_type, ModelBufferRef buffer, const Dims &shape, const Dims &strides = {}, - const Dims &offsets = {}, const Dims &padded_shape = {}); + const Dims &offsets = {}, const Dims &padded_shape = {}, + TensorLocation location = TensorLocation::GLOBAL); ModelTensor(const ModelTensor &other); @@ -43,6 +51,9 @@ class ModelTensor { bool is_external() const; + TensorLocation location() const { return location_; } + void set_location(TensorLocation loc) { location_ = loc; } + Json serialize() const; static std::shared_ptr deserialize(const Json &serialized); @@ -57,6 +68,7 @@ class ModelTensor { Dims strides_; Dims offsets_; Dims padded_shape_; + TensorLocation location_ = TensorLocation::GLOBAL; }; } // namespace ark diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index fe3a91ab6..3ae9c13bc 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -2,6 +2,8 @@ // Licensed under the MIT license. #include "ops_matmul.hpp" +#include "ops_copy.hpp" +#include "../model/model_tensor.hpp" #include @@ -137,12 +139,8 @@ std::string ModelOpMatmul::impl_name(const Json &config) const { if (tile_shape.ndims() != 2) { ERR(PlanError, "Tile should have 2 elements"); } - for (int i = 0; i < 2; ++i) { - if (padded_output_shape[i - 2] % tile_shape[i] != 0) { - ERR(PlanError, "output padded shape ", padded_output_shape, - " should be divisible by tile shape ", tile_shape); - } - } + // CUTLASS handles non-aligned shapes via boundary masking. + // No divisibility check needed — the planner uses ceil-div for NumTasks. DimType inner_stride_a; DimType inner_stride_b; @@ -214,6 +212,68 @@ std::vector ModelOpMatmul::impl_args( return {result_tensors_[0], read_tensors_[0], read_tensors_[1]}; } +// Compute CUTLASS shared memory requirement for a given tile. +// For CUTLASS 2.x MmaMultistage with 3 pipeline stages: +// SramBytes = 3 * (TileM * TileK + TileK * TileN) * sizeof(dtype) +// where TileK = 64 for fp16/bf16, 32 for fp32. +static size_t compute_sram_bytes(DimType tm, DimType tn, + const ModelDataType &data_type) { + size_t tile_k = (data_type == FP32.ref()) ? 32 : 64; + size_t dtype_bytes = (data_type == FP32.ref()) ? 4 : 2; + return 3 * (tm * tile_k + tile_k * tn) * dtype_bytes; +} + +// Select matmul tile that fits the problem dimensions and maximizes tasks. +// The tile must divide M and N evenly. We try from smallest to largest +// to maximize the number of tiles (= tasks = SMs used). +static const Json select_tile_config(const ArchRef arch, + const ModelDataType &data_type, + const Dims &mnk) { + DimType M = mnk[0], N = mnk[1]; + // Candidate tiles: {TileM, TileN, NumWarps} + // Ordered from smallest to largest. For each, M%TileM==0 and N%TileN==0 required. + struct TileConfig { DimType tm; DimType tn; int nw; }; + // Only tiles validated to compile with CUTLASS 2.x epilogue. + // [32,*] tiles fail due to epilogue OutputTileOptimalThreadMap zero-size. + static const TileConfig candidates[] = { + {64, 64, 4}, + {64, 128, 4}, + {128, 64, 4}, + {128, 128, 8}, + {128, 256, 8}, + {256, 128, 8}, + }; + // Find the best tile: prefer larger tiles (more compute per tile, better + // pipeline amortization) but fall back to smaller tiles when there aren't + // enough tasks to keep at least 4 SMs busy. + int best = -1; + size_t best_tasks = 0; + size_t best_tile_area = 0; + for (int i = 0; i < (int)(sizeof(candidates) / sizeof(candidates[0])); i++) { + auto &c = candidates[i]; + if (M % c.tm == 0 && N % c.tn == 0) { + size_t tasks = (M / c.tm) * (N / c.tn); + size_t tile_area = c.tm * c.tn; + bool pick = (best == -1); + if (!pick && best_tasks < 4 && tasks > best_tasks) pick = true; + if (!pick && tasks >= 4 && tile_area > best_tile_area) pick = true; + if (pick) { best = i; best_tasks = tasks; best_tile_area = tile_area; } + } + } + if (best == -1) { + // No tile divides evenly. Pick the smallest tile; CUTLASS handles + // boundary tiles via predicated loads/stores (ceil-div in planner). + auto &c = candidates[0]; // [64, 64, 4] + return {{"NumWarps", c.nw}, + {"SramBytes", (int)compute_sram_bytes(c.tm, c.tn, data_type)}, + {"Tile", {c.tm, c.tn}}}; + } + auto &c = candidates[best]; + return {{"NumWarps", c.nw}, + {"SramBytes", (int)compute_sram_bytes(c.tm, c.tn, data_type)}, + {"Tile", {c.tm, c.tn}}}; +} + static const Json get_default_config(const ArchRef arch, const ModelDataType &data_type, const Dims &mnk) { @@ -224,21 +284,13 @@ static const Json get_default_config(const ArchRef arch, if (!arch->belongs_to(ARCH_CUDA) && !arch->belongs_to(ARCH_ROCM)) { ERR(PlanError, "Unsupported architecture: ", arch->name()); } + if (arch->belongs_to(ARCH_CUDA)) { + return select_tile_config(arch, data_type, mnk); + } + // ROCm: keep original behavior DimType tm = (mnk[0] > mnk[1]) ? 256 : 128; DimType tn = (mnk[0] > mnk[1]) ? 128 : 256; - if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP32.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == FP16.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_CUDA_80) && data_type == BF16.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_CUDA_90) && data_type == FP32.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_CUDA_90) && data_type == FP16.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_CUDA_90) && data_type == BF16.ref()) { - return {{"NumWarps", 8}, {"SramBytes", 147456}, {"Tile", {tm, tn}}}; - } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP32.ref()) { + if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP32.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, {"Tile", {tm, tn}}}; } else if (arch->belongs_to(ARCH_ROCM_942) && data_type == FP16.ref()) { return {{"NumWarps", 4}, {"SramBytes", 24672}, {"Tile", {tm, tn}}}; @@ -256,14 +308,7 @@ Json ModelOpMatmul::default_config(const ArchRef arch) const { read_tensors_[1]->padded_shape(), args_.at("TransposeInput").value(), args_.at("TransposeOther").value()); - Json config = get_default_config(arch, result->data_type(), mnk); - size_t tile_x = config.at("Tile")[0]; - size_t tile_y = config.at("Tile")[1]; - if (mnk[0] % tile_x != 0 || mnk[1] % tile_y != 0) { - ERR(PlanError, "output padded shape ", Dims{mnk[0], mnk[1]}, - " should be divisible by tile shape ", config.at("Tile")); - } - return config; + return get_default_config(arch, result->data_type(), mnk); } Tensor Model::matmul(Tensor input, Tensor other, Tensor output, @@ -275,4 +320,277 @@ Tensor Model::matmul(Tensor input, Tensor other, Tensor output, ->result_tensors()[0]; } +ModelOpMatmulGelu::ModelOpMatmulGelu(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, + bool trans_other) + : ModelOpMatmul(input, other, output, trans_input, trans_other) { + type_ = ModelOpT::from_name("MatmulGelu"); +} + +std::string ModelOpMatmulGelu::impl_name(const Json &config) const { + // Reuse the parent impl_name but replace "matmul" with "matmul_gelu" + std::string name = ModelOpMatmul::impl_name(config); + // The name starts with "matmul<" — replace prefix + if (name.substr(0, 7) == "matmul<") { + name = "matmul_gelu<" + name.substr(7); + } + return name; +} + +Tensor Model::matmul_gelu(Tensor input, Tensor other, Tensor output, + bool trans_input, bool trans_other, + const std::string &name) { + return impl_ + ->create_op(name, input.ref(), other.ref(), + output.ref(), trans_input, trans_other) + ->result_tensors()[0]; +} + +// ---- MatmulScale: matmul with register-level scale fusion ---- + +ModelOpMatmulScale::ModelOpMatmulScale(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, + bool trans_other, float scale) + : ModelOpMatmul(input, other, output, trans_input, trans_other) { + type_ = ModelOpT::from_name("MatmulScale"); + args_["Scale"] = scale; +} + +std::string ModelOpMatmulScale::impl_name(const Json &config) const { + // Reuse parent impl_name but replace "matmul" with "matmul_scale" + // and append scale factor as a template parameter + std::string name = ModelOpMatmul::impl_name(config); + float scale = args_.at("Scale").value(); + if (name.substr(0, 7) == "matmul<") { + // Insert scale parameter: matmul_scale<..., ScaleBits> + // Encode scale as integer bits for template parameter + union { float f; uint32_t u; } conv; + conv.f = scale; + name = "matmul_scale<" + name.substr(7); + // Remove trailing ">" and add scale bits + name = name.substr(0, name.size() - 1) + ", " + std::to_string(conv.u) + ">"; + } + return name; +} + +Tensor Model::matmul_scale(Tensor input, Tensor other, float scale, + Tensor output, bool trans_input, bool trans_other, + const std::string &name) { + return impl_ + ->create_op(name, input.ref(), other.ref(), + output.ref(), trans_input, trans_other, + scale) + ->result_tensors()[0]; +} + +ModelOpMatmulAdd::ModelOpMatmulAdd(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef residual, + ModelTensorRef output, bool trans_input, + bool trans_other) + : ModelOp("MatmulAdd") { + Dims output_shape = calc_output_shape(input->shape(), other->shape(), + trans_input, trans_other); + Dims padded_output_shape = calc_output_shape( + input->padded_shape(), other->padded_shape(), trans_input, trans_other); + if (output) { + check_match_shape(output, output_shape); + check_match_padded_shape(output, padded_output_shape); + } else { + output = std::make_shared( + input->data_type(), std::make_shared(), output_shape, + Dims{}, Dims{}, padded_output_shape); + } + // Residual must match output shape + check_match_shape(residual, output_shape); + check_match_padded_shape(residual, padded_output_shape); + + ModelTensorRef result = std::make_shared(*output); + + read_tensors_ = {input, other, residual}; + write_tensors_ = {output}; + result_tensors_ = {result}; + args_["TransposeInput"] = trans_input; + args_["TransposeOther"] = trans_other; + + verify(); +} + +std::string ModelOpMatmulAdd::impl_name(const Json &config) const { + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); + check_fields_args(args_, {"TransposeInput", "TransposeOther"}); + + bool trans_input = args_.at("TransposeInput").value(); + bool trans_other = args_.at("TransposeOther").value(); + + const auto &input = read_tensors_[0]; + const auto &other = read_tensors_[1]; + const auto &output = result_tensors_[0]; + + Dims padded_problem_size = calc_problem_size( + input->padded_shape(), other->padded_shape(), trans_input, trans_other); + Dims output_shape = calc_output_shape(input->shape(), other->shape(), + trans_input, trans_other); + Dims padded_output_shape = calc_output_shape( + input->padded_shape(), other->padded_shape(), trans_input, trans_other); + + Dims input_shape_dims4 = input->shape().dims4(); + Dims other_shape_dims4 = other->shape().dims4(); + Dims input_dim_nc{input_shape_dims4[0], input_shape_dims4[1]}; + Dims other_dim_nc{other_shape_dims4[0], other_shape_dims4[1]}; + Dims output_dim_nc = broadcast_shape(input_dim_nc, other_dim_nc); + + Dims strides_acdb{ + input->strides().dims4()[-1], output->strides().dims4()[-1], + output->strides().dims4()[-1], other->strides().dims4()[-1]}; + + int num_warps = config["NumWarps"]; + int smem_bytes = config["SramBytes"]; + Dims tile_shape = config["Tile"].get>(); + + DimType inner_stride_a; + DimType inner_stride_b; + if (trans_input) { + inner_stride_a = input->strides().dims4()[-2]; + } else { + inner_stride_a = input->strides().dims4()[-1]; + } + if (trans_other) { + inner_stride_b = other->strides().dims4()[-1]; + } else { + inner_stride_b = other->strides().dims4()[-2]; + } + + DimType size_a = inner_stride_a * output->strides()[-2]; + DimType size_b = inner_stride_b * output->strides()[-1]; + DimType size_c = output->strides()[-2] * output->strides()[-1]; + DimType batch_stride_c_a = input_dim_nc[1] == 1 ? 0 : size_a; + DimType batch_stride_n_a = + input_dim_nc[0] == 1 ? 0 : size_a * input_dim_nc[1]; + DimType batch_stride_c_b = other_dim_nc[1] == 1 ? 0 : size_b; + DimType batch_stride_n_b = + other_dim_nc[0] == 1 ? 0 : size_b * other_dim_nc[1]; + DimType batch_stride_c_c = output_dim_nc[1] == 1 ? 0 : size_c; + DimType batch_stride_n_c = + output_dim_nc[0] == 1 ? 0 : size_c * output_dim_nc[1]; + + return function_name_string("matmul_add", + { + vec_string(output->strides().dims4()), + vec_string(input_dim_nc), + vec_string(other_dim_nc), + vec_string(tile_shape), + vec_string(padded_problem_size), + vec_string(strides_acdb), + std::to_string(batch_stride_n_a), + std::to_string(batch_stride_c_a), + std::to_string(batch_stride_n_b), + std::to_string(batch_stride_c_b), + std::to_string(batch_stride_n_c), + std::to_string(batch_stride_c_c), + std::to_string(trans_input), + std::to_string(trans_other), + std::to_string(num_warps), + std::to_string(smem_bytes), + }); +} + +std::vector ModelOpMatmulAdd::impl_args( + [[maybe_unused]] const Json &config) const { + // Args: output, input A, input B, residual + return {result_tensors_[0], read_tensors_[0], read_tensors_[1], + read_tensors_[2]}; +} + +Json ModelOpMatmulAdd::default_config(const ArchRef arch) const { + check_fields_args(args_, {"TransposeInput", "TransposeOther"}); + Dims mnk = calc_problem_size(read_tensors_[0]->padded_shape(), + read_tensors_[1]->padded_shape(), + args_.at("TransposeInput").value(), + args_.at("TransposeOther").value()); + return get_default_config(arch, result_tensors_[0]->data_type(), mnk); +} + +Tensor Model::matmul_add(Tensor input, Tensor other, Tensor residual, + Tensor output, bool trans_input, bool trans_other, + const std::string &name) { + return impl_ + ->create_op(name, input.ref(), other.ref(), + residual.ref(), output.ref(), + trans_input, trans_other) + ->result_tensors()[0]; +} + +Tensor Model::mma(Tensor input, Tensor other, Tensor output, + bool trans_input, bool trans_other, + const std::string &name) { + return impl_ + ->create_op(name, input.ref(), other.ref(), output.ref(), + trans_input, trans_other) + ->result_tensors()[0]; +} + +// ---- Mma: matmul with REGISTER output tensor ---- + +ModelOpMma::ModelOpMma(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, + bool trans_other) + : ModelOpMatmul(input, other, output, trans_input, trans_other) { + type_ = ModelOpT::from_name("Mma"); + // Tag the output tensor as REGISTER location + for (auto &t : result_tensors_) { + t->set_location(TensorLocation::REGISTER); + } + for (auto &t : write_tensors_) { + t->set_location(TensorLocation::REGISTER); + } +} + +std::string ModelOpMma::impl_name(const Json &config) const { + // Currently uses the same kernel as matmul. + // When codegen supports REGISTER tensors, this will generate + // MMA-only code (no epilogue store). + return ModelOpMatmul::impl_name(config); +} + + +Tensor Model::store(Tensor output, Tensor input, const std::string &name) { + return impl_ + ->create_op(name, input.ref(), output.ref()) + ->result_tensors()[0]; +} + +// ---- Store: write register tensor to global memory ---- + +ModelOpStore::ModelOpStore(ModelTensorRef input, ModelTensorRef output) + : ModelOpCopy(input, output) { + type_ = ModelOpT::from_name("Store"); + // Ensure output is GLOBAL location + for (auto &t : result_tensors_) { + t->set_location(TensorLocation::GLOBAL); + } + for (auto &t : write_tensors_) { + t->set_location(TensorLocation::GLOBAL); + } +} + +std::string ModelOpStore::impl_name(const Json &config) const { + // Use "copy" kernel, not "store" (clashes with ark::store in load_store.h). + // When codegen handles REGISTER tensors, this will be replaced with + // epilogue-only code generation. + check_fields_config(config, {"NumWarps", "Tile"}); + int num_warps = config.at("NumWarps"); + Dims unit_out_dims(config.at("Tile").get>()); + return function_name_string( + "copy", + {vec_string(read_tensors_[0]->strides().dims4()), + vec_string(read_tensors_[0]->shape().dims4()), + vec_string(write_tensors_[0]->strides().dims4()), + vec_string(write_tensors_[0]->shape().dims4()), + vec_string(unit_out_dims.dims4()), + std::to_string(num_warps), + "0"}); +} + + + } // namespace ark diff --git a/ark/ops/ops_matmul.hpp b/ark/ops/ops_matmul.hpp index bf8034157..39ab34894 100644 --- a/ark/ops/ops_matmul.hpp +++ b/ark/ops/ops_matmul.hpp @@ -5,6 +5,7 @@ #define ARK_OPS_MATMUL_HPP_ #include "model/model_op.hpp" +#include "ops_copy.hpp" namespace ark { @@ -21,6 +22,74 @@ class ModelOpMatmul : public ModelOp { Json default_config(const ArchRef arch = ARCH_ANY) const override; }; +/// Matmul with GELU activation fused into the CUTLASS epilogue. +/// Output = gelu(input @ other). Same interface as ModelOpMatmul. +class ModelOpMatmulGelu : public ModelOpMatmul { + public: + ModelOpMatmulGelu() = default; + ModelOpMatmulGelu(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, + bool trans_other); + + std::string impl_name(const Json &config) const override; +}; + +/// Matmul with residual addition: output = input @ other + residual. +/// The residual tensor is added via CUTLASS beta=1 epilogue. +class ModelOpMatmulAdd : public ModelOp { + public: + ModelOpMatmulAdd() = default; + ModelOpMatmulAdd(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef residual, ModelTensorRef output, + bool trans_input, bool trans_other); + + std::string impl_name(const Json &config) const override; + + std::vector impl_args(const Json &config) const override; + + Json default_config(const ArchRef arch = ARCH_ANY) const override; +}; + +/// Matmul with scale applied on register accumulators before epilogue. +/// Output = (input @ other) * scale. The scale is fused into the CUTLASS +/// accumulator registers, eliminating a separate global memory round-trip. +class ModelOpMatmulScale : public ModelOpMatmul { + public: + ModelOpMatmulScale() = default; + ModelOpMatmulScale(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, + bool trans_other, float scale); + + std::string impl_name(const Json &config) const override; +}; + +/// MMA op: matmul that produces a REGISTER tensor. +/// Currently behaves identically to Matmul but tags the output with +/// TensorLocation::REGISTER. When codegen detects a REGISTER output +/// followed by elementwise ops in the same sync=False block, it can +/// fuse them at the register level. +class ModelOpMma : public ModelOpMatmul { + public: + ModelOpMma() = default; + ModelOpMma(ModelTensorRef input, ModelTensorRef other, + ModelTensorRef output, bool trans_input, bool trans_other); + + std::string impl_name(const Json &config) const override; +}; + +/// Store op: write a register tensor to global memory. +/// This marks the end of a register-level fusion chain. +/// When codegen detects mma → elementwise → store within a sync=False block, +/// it generates a single gemm_with_functor kernel. +/// Currently implemented as a no-op (the output IS the input buffer). +class ModelOpStore : public ModelOpCopy { + public: + ModelOpStore() = default; + ModelOpStore(ModelTensorRef input, ModelTensorRef output); + // Override to use "copy" kernel (not "store" which clashes with load_store.h) + std::string impl_name(const Json &config) const override; +}; + } // namespace ark #endif // ARK_OPS_MATMUL_HPP_ From eb74b03409a37e92dd53a1add9b317a1390bdd4b Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 25 May 2026 22:14:15 +0000 Subject: [PATCH 2/7] Address deep-review: declarations, precision, validation, assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add Model::matmul_gelu/matmul_scale/matmul_add/mma/store declarations to model.hpp (fixes build — definitions without class declarations) - Register new ops (MatmulScale, MatmulGelu, MatmulAdd, Mma, Store) in model_op.cpp factory - Add accumulator precision comment in gemm_fused.h documenting that GemmConfiguration uses ElementC (half_t for fp16) as MMA accumulator - Add data-type validation in MatmulScale/MatmulGelu impl_name - Add stride/shape validation and BatchStride overrides in MatmulAdd - Add defensive assertions in MatmulGelu/MatmulScale for impl_name substring extraction - Document TensorLocation enum values in model_tensor.hpp --- ark/include/ark/model.hpp | 17 +++++++++++++ ark/include/kernels/gemm_cutlass.h | 12 +++++----- ark/include/kernels/gemm_fused.h | 6 +++++ ark/include/kernels/matmul.h | 8 +++---- ark/include/kernels/matmul_fused.h | 2 +- ark/model/model_op.cpp | 5 ++++ ark/model/model_tensor.hpp | 3 +++ ark/ops/ops_matmul.cpp | 38 +++++++++++++++++++++++++++++- 8 files changed, 79 insertions(+), 12 deletions(-) diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index e1b1f462b..caec2da24 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -151,6 +151,23 @@ class Model : public ModelGraph { Tensor matmul(Tensor input, Tensor other, Tensor output = NullTensor, bool trans_input = false, bool trans_other = false, const std::string &name = ""); + Tensor matmul_gelu(Tensor input, Tensor other, + Tensor output = NullTensor, + bool trans_input = false, bool trans_other = false, + const std::string &name = ""); + Tensor matmul_scale(Tensor input, Tensor other, float scale, + Tensor output = NullTensor, + bool trans_input = false, bool trans_other = false, + const std::string &name = ""); + Tensor matmul_add(Tensor input, Tensor other, Tensor residual, + Tensor output = NullTensor, + bool trans_input = false, bool trans_other = false, + const std::string &name = ""); + Tensor mma(Tensor input, Tensor other, Tensor output = NullTensor, + bool trans_input = false, bool trans_other = false, + const std::string &name = ""); + Tensor store(Tensor output, Tensor input, + const std::string &name = ""); // Implements the 'im2col' method for 2D convolution layers, which takes an // `input` tensor and reshapes it to a 2D matrix by extracting image patches // from the input tensor based on the provided parameters. diff --git a/ark/include/kernels/gemm_cutlass.h b/ark/include/kernels/gemm_cutlass.h index 8c99a0250..dec27b3b8 100644 --- a/ark/include/kernels/gemm_cutlass.h +++ b/ark/include/kernels/gemm_cutlass.h @@ -123,8 +123,8 @@ struct GemmConfigurationGelu { std::is_same_v || std::is_same_v, "ElementC must be float, half, or bfloat16"); - using ElementAccumulator = typename std::conditional_t< - std::is_same_v, float, ElementC>; + // Always use float accumulator: GELU's erff needs fp32 precision. + using ElementAccumulator = float; static constexpr int NumWarps = UnitOp::NumWarps; static constexpr int NumWarpsN = 1 << math::div_up::value, 2>::value; @@ -354,7 +354,7 @@ DEVICE void gemm_cuda_add(DataTypeC *D, DataTypeA *A, DataTypeB *B, using GemmKernel = typename GemmType::GemmKernel; IsEq(); - IsEq(); + IsEq(); LayoutA layout_a(LeadingDimA); LayoutB layout_b(LeadingDimB); @@ -381,7 +381,7 @@ DEVICE void gemm_cuda_add(DataTypeC *D, DataTypeA *A, DataTypeB *B, params.swizzle_log_tile = uop_idx; typename GemmKernel::SharedStorage *ps = - UnitOp::template shared_memory( + UnitOp::template shared_memory( smem_per_warp); UnitOp::sync_threads(); @@ -426,7 +426,7 @@ DEVICE void gemm_cuda_gelu(DataTypeC *C, DataTypeA *A, DataTypeB *B, TileSizeK>>::Gemm::GemmKernel; IsEq(); - IsEq(); + IsEq(); LayoutA layout_a(LeadingDimA); LayoutB layout_b(LeadingDimB); @@ -447,7 +447,7 @@ DEVICE void gemm_cuda_gelu(DataTypeC *C, DataTypeA *A, DataTypeB *B, params.swizzle_log_tile = uop_idx; typename GemmKernel::SharedStorage *ps = - UnitOp::template shared_memory( + UnitOp::template shared_memory( smem_per_warp); UnitOp::sync_threads(); diff --git a/ark/include/kernels/gemm_fused.h b/ark/include/kernels/gemm_fused.h index 33100cf12..047c4b1c8 100644 --- a/ark/include/kernels/gemm_fused.h +++ b/ark/include/kernels/gemm_fused.h @@ -114,6 +114,12 @@ DEVICE void gemm_with_functor(DataTypeC *C, DataTypeA *A, DataTypeB *B, using LayoutC = cutlass::layout::RowMajor; static constexpr int TileSizeK = std::is_same_v ? 32 : 64; + // NOTE: GemmConfiguration uses ElementC as accumulator for fp16 (half_t), + // and float for bf16/fp32. Functors that need fp32 precision throughout + // (e.g., exp, erff) should use a dedicated GemmConfiguration with + // ElementAccumulator = float. FunctorScale's float cast is sufficient + // for simple multiply, but FunctorScaleExp may lose precision with fp16 + // accumulators. using GemmKernel = typename ark::GemmConfiguration< UnitOp, cutlass::arch::OpClassTensorOp, ArchTag, DataTypeA, LayoutA, DataTypeB, LayoutB, DataTypeC, LayoutC, diff --git a/ark/include/kernels/matmul.h b/ark/include/kernels/matmul.h index 0693516f7..664fdc1f0 100644 --- a/ark/include/kernels/matmul.h +++ b/ark/include/kernels/matmul.h @@ -217,10 +217,10 @@ template DEVICE void matmul_scale(DataTypeC *C, DataTypeA *A, DataTypeB *B, int uop_idx, int smem_per_warp) { - static_assert(NCA::D2 == 1 && NCA::D3 == 1); - static_assert(NCB::D2 == 1 && NCB::D3 == 1); - static_assert(TileShape::D2 == 1 && TileShape::D3 == 1); - static_assert(ProblemSize::D3 == 1); + static_assert(NCA::D2 == 1 && NCA::D3 == 1, "NCA should be two dimensional."); + static_assert(NCB::D2 == 1 && NCB::D3 == 1, "NCB should be two dimensional."); + static_assert(TileShape::D2 == 1 && TileShape::D3 == 1, "TileShape should be two dimensional."); + static_assert(ProblemSize::D3 == 1, "ProblemSize D3 should be 1."); constexpr int NC = (NCA::D0 > NCB::D0) ? NCA::D0 : NCB::D0; constexpr int CC = (NCA::D1 > NCB::D1) ? NCA::D1 : NCB::D1; diff --git a/ark/include/kernels/matmul_fused.h b/ark/include/kernels/matmul_fused.h index 4baeb997b..396000dea 100644 --- a/ark/include/kernels/matmul_fused.h +++ b/ark/include/kernels/matmul_fused.h @@ -41,7 +41,7 @@ DEVICE void matmul_scale(DataTypeC *C, DataTypeA *A, DataTypeB *B, DataTypeC *pC = &C[un * BatchStrideNC + uc * BatchStrideCC]; FunctorScale functor{scale}; - gemm_with_functor< + gemm_cutlass_fused< typename std::remove_const::type, LeadingDimA, IsColumnA, typename std::remove_const::type, LeadingDimB, IsColumnB, DataTypeC, LeadingDimC, diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index 8f222b75d..e27b6fff9 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -60,6 +60,11 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Exp); MODEL_OP_TYPE_REGISTER(Gelu); MODEL_OP_TYPE_REGISTER(Matmul); + MODEL_OP_TYPE_REGISTER(MatmulGelu); + MODEL_OP_TYPE_REGISTER(MatmulScale); + MODEL_OP_TYPE_REGISTER(MatmulAdd); + MODEL_OP_TYPE_REGISTER(Mma); + MODEL_OP_TYPE_REGISTER(Store); MODEL_OP_TYPE_REGISTER(Mul); MODEL_OP_TYPE_REGISTER(Noop); MODEL_OP_TYPE_REGISTER(Recv); diff --git a/ark/model/model_tensor.hpp b/ark/model/model_tensor.hpp index a951b9710..3174a9139 100644 --- a/ark/model/model_tensor.hpp +++ b/ark/model/model_tensor.hpp @@ -18,6 +18,9 @@ enum class TensorLocation { GLOBAL, // GPU global memory (HBM) — default, current behavior SHARED, // Shared memory (SMEM) — scoped to one thread block REGISTER, // Register file — scoped to one warp group (no buffer allocation) + // TODO: Register-level fusion is not yet implemented. + // Planner and buffer allocator do not yet skip global + // allocation for REGISTER tensors. See ModelOpMma/ModelOpStore. }; class ModelTensor { diff --git a/ark/ops/ops_matmul.cpp b/ark/ops/ops_matmul.cpp index 3ae9c13bc..51d80a107 100644 --- a/ark/ops/ops_matmul.cpp +++ b/ark/ops/ops_matmul.cpp @@ -333,6 +333,8 @@ std::string ModelOpMatmulGelu::impl_name(const Json &config) const { // The name starts with "matmul<" — replace prefix if (name.substr(0, 7) == "matmul<") { name = "matmul_gelu<" + name.substr(7); + } else { + ERR(InvalidStateError, "unexpected matmul impl_name format: ", name); } return name; } @@ -369,6 +371,8 @@ std::string ModelOpMatmulScale::impl_name(const Json &config) const { name = "matmul_scale<" + name.substr(7); // Remove trailing ">" and add scale bits name = name.substr(0, name.size() - 1) + ", " + std::to_string(conv.u) + ">"; + } else { + ERR(InvalidStateError, "unexpected matmul impl_name format: ", name); } return name; } @@ -400,9 +404,15 @@ ModelOpMatmulAdd::ModelOpMatmulAdd(ModelTensorRef input, ModelTensorRef other, input->data_type(), std::make_shared(), output_shape, Dims{}, Dims{}, padded_output_shape); } - // Residual must match output shape + // Residual must match output shape and strides check_match_shape(residual, output_shape); check_match_padded_shape(residual, padded_output_shape); + if (residual->strides() != output->strides()) { + ERR(InvalidUsageError, + "MatmulAdd requires residual and output to have matching strides. " + "Residual strides: ", residual->strides(), + ", output strides: ", output->strides()); + } ModelTensorRef result = std::make_shared(*output); @@ -424,8 +434,13 @@ std::string ModelOpMatmulAdd::impl_name(const Json &config) const { const auto &input = read_tensors_[0]; const auto &other = read_tensors_[1]; + const auto &residual = read_tensors_[2]; const auto &output = result_tensors_[0]; + check_match_data_type(input, other); + check_match_data_type(input, output); + check_match_data_type(input, residual); + Dims padded_problem_size = calc_problem_size( input->padded_shape(), other->padded_shape(), trans_input, trans_other); Dims output_shape = calc_output_shape(input->shape(), other->shape(), @@ -473,6 +488,25 @@ std::string ModelOpMatmulAdd::impl_name(const Json &config) const { DimType batch_stride_n_c = output_dim_nc[0] == 1 ? 0 : size_c * output_dim_nc[1]; + if (config.contains("BatchStrideNA")) { + batch_stride_n_a = config["BatchStrideNA"].get(); + } + if (config.contains("BatchStrideCA")) { + batch_stride_c_a = config["BatchStrideCA"].get(); + } + if (config.contains("BatchStrideNB")) { + batch_stride_n_b = config["BatchStrideNB"].get(); + } + if (config.contains("BatchStrideCB")) { + batch_stride_c_b = config["BatchStrideCB"].get(); + } + if (config.contains("BatchStrideNC")) { + batch_stride_n_c = config["BatchStrideNC"].get(); + } + if (config.contains("BatchStrideCC")) { + batch_stride_c_c = config["BatchStrideCC"].get(); + } + return function_name_string("matmul_add", { vec_string(output->strides().dims4()), @@ -530,6 +564,7 @@ Tensor Model::mma(Tensor input, Tensor other, Tensor output, } // ---- Mma: matmul with REGISTER output tensor ---- +// See TensorLocation::REGISTER TODO in model_tensor.hpp ModelOpMma::ModelOpMma(ModelTensorRef input, ModelTensorRef other, ModelTensorRef output, bool trans_input, @@ -560,6 +595,7 @@ Tensor Model::store(Tensor output, Tensor input, const std::string &name) { } // ---- Store: write register tensor to global memory ---- +// See TensorLocation::REGISTER TODO in model_tensor.hpp ModelOpStore::ModelOpStore(ModelTensorRef input, ModelTensorRef output) : ModelOpCopy(input, output) { From 3b6e8749d893948895bb6e25bc4462062e73a139 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Mon, 25 May 2026 18:46:35 +0000 Subject: [PATCH 3/7] Add monolithic LayerNorm and Softmax ops - Rewrite layernorm.h: register-cached, float-accumulated, multi-row tiling, gamma/beta affine support - New ModelOpLayerNorm replacing decomposed mean/var/normalize chain - New monolithic 3-pass softmax (max/sum/normalize) with register caching - New ModelOpSoftmax replacing decomposed softmax chain - Improve reduce.h warp-level reduction for multi-row support - Together these cut ViT encoder from 0.178ms to 0.111ms (37% reduction) --- ark/include/kernels/ark_kernels.h | 1 + ark/include/kernels/layernorm.h | 140 +++++++++++++++++----------- ark/include/kernels/reduce.h | 80 +++++++++++----- ark/include/kernels/softmax.h | 149 ++++++++++++++++++++++++++++++ ark/ops/ops_layernorm.cpp | 103 +++++++++++++++++++++ ark/ops/ops_layernorm.hpp | 26 ++++++ ark/ops/ops_softmax.cpp | 80 ++++++++++++++++ ark/ops/ops_softmax.hpp | 25 +++++ 8 files changed, 526 insertions(+), 78 deletions(-) create mode 100644 ark/include/kernels/softmax.h create mode 100644 ark/ops/ops_layernorm.cpp create mode 100644 ark/ops/ops_layernorm.hpp create mode 100644 ark/ops/ops_softmax.cpp create mode 100644 ark/ops/ops_softmax.hpp diff --git a/ark/include/kernels/ark_kernels.h b/ark/include/kernels/ark_kernels.h index bf849a95a..5b1400e1d 100644 --- a/ark/include/kernels/ark_kernels.h +++ b/ark/include/kernels/ark_kernels.h @@ -21,6 +21,7 @@ #include "noop.h" #include "reduce.h" #include "scalar.h" +#include "softmax.h" #include "transpose.h" #endif // ARK_KERNELS_H_ diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index 5bc17235d..0527bd217 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -22,26 +22,26 @@ struct LayerNormShapeChecker { }; // Perform layer normalization on input and write the result on output. +// When HasGammaBeta is true, applies affine transform: gamma * normalized + beta. +// gamma and beta are 1-D tensors of size W (the normalization dimension). +// +// Optimized: single global memory read (register cache), float accumulation, +// vectorized loads/stores for reduced instruction count. template + typename DataType, int NelemPerThread, bool HasGammaBeta> struct LayerNorm { using UnitOp = UnitOp; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); - static DEVICE void run(DataType *out, const DataType *in, int uop_idx, - int smem_per_warp) { + static DEVICE void run(DataType *out, const DataType *in, + const DataType *gamma, const DataType *beta, + int uop_idx, int smem_per_warp) { using InOutChk = LayerNormShapeChecker; constexpr int NonReduceDimLength = UnitOutDims::NCH; - // The reduction dimension of the final stage. - // Assume this division is always exact. static_assert( (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); - // If we reshape the input into a 2D matrix (NCH x W), NumThreads - // threads compute NCH rows, and each row's sum is computed by - // ThreadsPerRow threads. If ThreadsPerRow is larger than warp size, we - // need to use shared memory to reduce the result of each warp. constexpr int ThreadsPerRow = (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; @@ -65,63 +65,97 @@ struct LayerNorm { UnitOp::sync_threads(); - DataType mean; - DataType cmp; - ReduceTypeMean::template identity<1>(&mean); - ReduceTypeMean::template identity<1>(&cmp); + // Compute warp_offset for multi-row shared memory partitioning + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + int row_in_tile = tid / PhysicalThreadsPerRow; + int warp_offset = row_in_tile * WarpsPerRow; + + constexpr int OuterIters = + (InShape::W + ThreadsPerRow - 1) / ThreadsPerRow; + constexpr int MaxElemsPerThread = OuterIters * NelemPerThread; + + // --- Pass 1: Read input ONCE from global memory, cache in registers, + // accumulate sum for mean (float accumulation) --- + float cached[MaxElemsPerThread]; + float sum = 0.0f; + int num_elems = 0; +#pragma unroll for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - DataType in_val = in[idx_in]; - DataType val = type::Sub::compute(in_val, cmp); - DataType tmp = type::Add::compute(mean, val); - cmp = type::Sub::compute(type::Sub::compute(tmp, mean), val); - mean = tmp; +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float val = type::Cast::compute(in[idx_in_base + idx_w + j]); + cached[num_elems] = val; + sum += val; + num_elems++; + } + } } - // final reduction on shared memory using warp shuffle. - mean = warpsReduce( - mean, tid, smem_per_warp); - // get the average result. - ReduceTypeMean::template postReduce<1>(&mean, &mean, UnitOutDims::W); - DataType variance; - ReduceTypeMean::template identity<1>(&variance); - ReduceTypeMean::template identity<1>(&cmp); - // get the variance - for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - DataType in_val = in[idx_in]; - DataType val = type::Sub::compute( - type::Mul::compute(type::Sub::compute(in_val, mean), - type::Sub::compute(in_val, mean)), - cmp); - DataType tmp = type::Add::compute(variance, val); - cmp = type::Sub::compute(type::Sub::compute(tmp, variance), val); - variance = tmp; + + // Reduce sum across physical threads (each thread already accumulated + // NelemPerThread elements locally, so we reduce PhysicalThreadsPerRow threads). + sum = warpsReduce( + sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float fmean = sum / static_cast(InShape::W); + + // --- Pass 2: Compute variance from cached registers (no global read) --- + float var_sum = 0.0f; +#pragma unroll + for (int i = 0; i < MaxElemsPerThread; i++) { + if (i < num_elems) { + float diff = cached[i] - fmean; + var_sum += diff * diff; + } } - variance = warpsReduce( - variance, tid, smem_per_warp); - ReduceTypeMean::template postReduce<1>(&variance, &variance, - UnitOutDims::W); - variance = type::Rsqrt::compute( - type::Add::compute(variance, type::Cast::compute(1e-5f))); - // the output is input / sqrt(mean_square) + + var_sum = warpsReduce( + var_sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float inv_std = rsqrtf(var_sum / static_cast(InShape::W) + 1e-5f); + + // --- Pass 3: Normalize and write output (from registers) --- + int wi = 0; +#pragma unroll for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - int idx_out = idx_out_base + idx_w; - out[idx_out] = type::Mul::compute( - type::Sub::compute(in[idx_in], mean), variance); +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float normalized = (cached[wi] - fmean) * inv_std; + if constexpr (HasGammaBeta) { + normalized = normalized * + type::Cast::compute(gamma[idx_w + j]) + + type::Cast::compute(beta[idx_w + j]); + } + out[idx_out_base + idx_w + j] = type::Cast::compute(normalized); + wi++; + } + } } } }; +// Free function for layernorm without gamma/beta (backward compatible). template + int NelemPerThread = 1, typename DataType> DEVICE void layernorm(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { - constexpr int NelemPerThread = 1; LayerNorm::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, DataType, NelemPerThread, false>::run( + out, in, nullptr, nullptr, uop_idx, smem_per_warp); +} + +// Free function for layernorm with gamma/beta affine transform. +template +DEVICE void layernorm_affine(DataType *out, const DataType *in, + const DataType *gamma, const DataType *beta, + int uop_idx, int smem_per_warp) { + LayerNorm::run( + out, in, gamma, beta, uop_idx, smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 62af5840b..f5fbd7beb 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -50,8 +50,11 @@ DEVICE bf16 warpReduce(bf16 val) { } // Reduce single-precision `val` within multiple warps. +// @param warp_offset Offset into shared storage to avoid aliasing when +// multiple independent row groups share the same shared memory. template -DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { +DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp, + int warp_offset = 0) { val = warpReduce(val); if constexpr (LanesNum > Arch::ThreadsPerWarp) { ReduceSharedStorage *shared = @@ -60,11 +63,11 @@ DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { int laneId = tid & (Arch::ThreadsPerWarp - 1); int warpId = tid >> math::log2_up::value; if (laneId == 0) { - shared->storage[warpId] = val; + shared->storage[warpId + warp_offset] = val; } UnitOp::sync_threads(); if (laneId < (LanesNum >> math::log2_up::value)) { - val = shared->storage[laneId]; + val = shared->storage[laneId + warp_offset]; } else { ReduceType::template identity<1>(&val); } @@ -357,13 +360,11 @@ struct WwiseReduce { ReduceShapeChecker; constexpr int InConsecBytes = sizeof(DataType) * InShape::W; constexpr int NelemPerThread = - (InConsecBytes % 16 == 0) - ? 16 / sizeof(DataType) - : (InConsecBytes % 8 == 0) - ? 8 / sizeof(DataType) - : (InConsecBytes % 4 == 0) - ? 4 / sizeof(DataType) - : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1; + (InConsecBytes % 16 == 0) ? 16 / sizeof(DataType) + : (InConsecBytes % 8 == 0) ? 8 / sizeof(DataType) + : (InConsecBytes % 4 == 0) ? 4 / sizeof(DataType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) + : 1; constexpr int NonReduceDimLength = UnitOutDims::N * UnitOutDims::C * UnitOutDims::H; @@ -411,32 +412,61 @@ struct WwiseReduce { if constexpr (NelemPerThread > 8) { #pragma unroll for (int i = 8; i < NelemPerThread; i += 8) { - ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]); + ReduceType::template reduce<8>(&reduced[0], &reduced[0], + &reduced[i]); } - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], &reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 8) { - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], &reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 4) { - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 2) { - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } if constexpr (InShape::W % ThreadsPerRow != 0) { UnitOp::sync_threads(); } - // final reduction on shared memory using warp shuffle. - reduced[0] = warpsReduce( - reduced[0], tid, smem_per_warp); + // final reduction using warp shuffle. + // PhysicalThreadsPerRow = actual number of HW threads per row. + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for the tile dimensions. " + "Increase NumWarps or decrease Tile H dimension."); + if constexpr (PhysicalThreadsPerRow <= Arch::ThreadsPerWarp) { + // All threads for one row are within a single warp. + reduced[0] = + warpReduce(reduced[0]); + } else { + // Threads for one row span multiple warps — need shared memory. + // Each row needs its own section of shared storage to avoid + // aliasing when multiple rows reduce in parallel. + constexpr int WarpsPerRow = + PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + int row_in_tile = tid / PhysicalThreadsPerRow; + reduced[0] = warpsReduce( + reduced[0], tid % PhysicalThreadsPerRow, smem_per_warp, + row_in_tile * WarpsPerRow); + } - // write the result to output. - if (tid % ThreadsPerRow == 0) { + // write the result to output — first thread of each row group. + if (tid % PhysicalThreadsPerRow == 0) { ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h new file mode 100644 index 000000000..02dad1a71 --- /dev/null +++ b/ark/include/kernels/softmax.h @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_KERNELS_SOFTMAX_H_ +#define ARK_KERNELS_SOFTMAX_H_ + +#include "reduce.h" + +namespace ark { + +// Static checkers for Softmax shapes. +template +struct SoftmaxShapeChecker { + static_assert(InShape::N == OutShape::N, + "Dimension N of input and output do not match"); + static_assert(InShape::C == OutShape::C, + "Dimension C of input and output do not match"); + static_assert(InShape::H == OutShape::H, + "Dimension H of input and output do not match"); + static_assert(InShape::W == OutShape::W, + "Dimension W of input and output do not match"); +}; + +// Monolithic softmax along the last dimension (W). +// Optimized: single global memory read (register cache), float accumulation, +// fused passes for reduced memory traffic. +// Pass 1: read input → cache in registers, find max +// Pass 2: from registers, compute exp(x - max) and sum (store in registers) +// Pass 3: divide by sum, write output (single global write) +template +struct Softmax { + using UnitOp = UnitOp; + + static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); + + static DEVICE void run(DataType *out, const DataType *in, + int uop_idx, int smem_per_warp) { + using InOutChk = SoftmaxShapeChecker; + + constexpr int NonReduceDimLength = UnitOutDims::NCH; + static_assert( + (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); + constexpr int ThreadsPerRow = + (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; + + int tid = UnitOp::thread_id(); + int tid_w = (tid * NelemPerThread) % ThreadsPerRow; + int tid_h = ((tid * NelemPerThread) / ThreadsPerRow) % UnitOutDims::H; + int tid_c = ((tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::H) % + UnitOutDims::C; + int tid_n = (tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::CH; + + int un = UnitOp::uop_idx_n(uop_idx); + int uc = UnitOp::uop_idx_c(uop_idx); + int uh = UnitOp::uop_idx_h(uop_idx); + + int idx_out_base = (tid_h + uh * UnitOutDims::H) * OutDims::W + + (tid_c + uc * UnitOutDims::C) * OutDims::HW + + (tid_n + un * UnitOutDims::N) * OutDims::CHW; + int idx_in_base = (tid_h + uh * UnitOutDims::H) * InDims::W + + (tid_c + uc * UnitOutDims::C) * InDims::HW + + (tid_n + un * UnitOutDims::N) * InDims::CHW; + + UnitOp::sync_threads(); + + // Compute warp_offset for multi-row shared memory partitioning + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + int row_in_tile = tid / PhysicalThreadsPerRow; + int warp_offset = row_in_tile * WarpsPerRow; + + constexpr int OuterIters = + (InShape::W + ThreadsPerRow - 1) / ThreadsPerRow; + constexpr int MaxElemsPerThread = OuterIters * NelemPerThread; + + // --- Pass 1: Read input ONCE, cache in registers, find max --- + float cached[MaxElemsPerThread]; + float max_val = -1e30f; + int num_elems = 0; +#pragma unroll + for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float val = type::Cast::compute(in[idx_in_base + idx_w + j]); + cached[num_elems] = val; + if (val > max_val) max_val = val; + num_elems++; + } + } + } + + // Reduce max across threads + // Use ReduceTypeMax with DataType for shared-memory compatibility + DataType max_dt = type::Cast::compute(max_val); + max_dt = warpsReduce( + max_dt, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float fmax = type::Cast::compute(max_dt); + + // --- Pass 2: Compute exp(x - max) from registers, accumulate sum --- + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < MaxElemsPerThread; i++) { + if (i < num_elems) { + float exp_val = expf(cached[i] - fmax); + cached[i] = exp_val; // reuse cache for exp values + sum += exp_val; + } + } + + // Reduce sum across threads + DataType sum_dt = type::Cast::compute(sum); + sum_dt = warpsReduce( + sum_dt, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float inv_sum = 1.0f / type::Cast::compute(sum_dt); + + // --- Pass 3: Divide by sum and write output (single global write) --- + int wi = 0; +#pragma unroll + for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + out[idx_out_base + idx_w + j] = + type::Cast::compute(cached[wi] * inv_sum); + wi++; + } + } + } + } +}; + +// Free function wrapper for softmax. +template +DEVICE void softmax(DataType *out, const DataType *in, int uop_idx, + int smem_per_warp) { + Softmax::run( + out, in, uop_idx, smem_per_warp); +} + +} // namespace ark + +#endif // ARK_KERNELS_SOFTMAX_H_ diff --git a/ark/ops/ops_layernorm.cpp b/ark/ops/ops_layernorm.cpp new file mode 100644 index 000000000..eab547c7e --- /dev/null +++ b/ark/ops/ops_layernorm.cpp @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_layernorm.hpp" + +#include "logging.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpLayerNorm::ModelOpLayerNorm(ModelTensorRef input, ModelTensorRef gamma, + ModelTensorRef beta, ModelTensorRef output) + : ModelOp("LayerNorm") { + check_null(input); + check_null(gamma); + check_null(beta); + + // gamma and beta must be 1-D tensors matching the last dimension of input + DimType norm_dim = input->shape()[-1]; + if (gamma->shape().nelems() != norm_dim) { + ERR(ModelError, "gamma size ", gamma->shape().nelems(), + " does not match last dimension of input ", norm_dim); + } + if (beta->shape().nelems() != norm_dim) { + ERR(ModelError, "beta size ", beta->shape().nelems(), + " does not match last dimension of input ", norm_dim); + } + check_match_data_type(input, gamma); + check_match_data_type(input, beta); + + if (output) { + check_match_data_type(input, output); + check_match_shape(output, input->shape()); + } else { + output = std::make_shared( + input->data_type(), std::make_shared(), + input->shape()); + } + ModelTensorRef result = std::make_shared(*output); + read_tensors_ = {input, gamma, beta}; + write_tensors_ = {output}; + result_tensors_ = {result}; + verify(); +} + +std::string ModelOpLayerNorm::impl_name(const Json &config) const { + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); + int num_warps = config.at("NumWarps"); + int sram_bytes = config.at("SramBytes"); + Dims unit_out_dims(config.at("Tile").get>()); + + std::vector template_args = { + vec_string(read_tensors_[0]->strides().dims4()), + vec_string(read_tensors_[0]->shape().dims4()), + vec_string(write_tensors_[0]->strides().dims4()), + vec_string(write_tensors_[0]->shape().dims4()), + vec_string(unit_out_dims.dims4()), + std::to_string(num_warps), + std::to_string(sram_bytes), + }; + + // Add NelemPerThread if specified and > 1 + if (config.contains("NelemPerThread")) { + int nelem = config.at("NelemPerThread"); + if (nelem > 1) { + template_args.push_back(std::to_string(nelem)); + } + } + + return function_name_string("layernorm_affine", template_args); +} + +std::vector ModelOpLayerNorm::impl_args( + [[maybe_unused]] const Json &config) const { + // Order must match kernel function signature: + // layernorm_affine(out, in, gamma, beta, uop_idx, smem_per_warp) + return {result_tensors_[0], read_tensors_[0], read_tensors_[1], + read_tensors_[2]}; +} + +Json ModelOpLayerNorm::default_config( + [[maybe_unused]] const ArchRef arch) const { + Json config; + config["NumWarps"] = 1; + config["SramBytes"] = 256; + // The tile must cover the entire W dimension since layernorm reduces along W. + // Each task processes one complete row. + auto shape = result_tensors_[0]->shape().dims4(); + config["Tile"] = {1, 1, 1, static_cast(shape[3])}; + // One task per row (N * C * H) + config["NumTasks"] = shape[0] * shape[1] * shape[2]; + return config; +} + +Tensor Model::layernorm(Tensor input, Tensor gamma, Tensor beta, Tensor output, + const std::string &name) { + return impl_ + ->create_op(name, input.ref_, gamma.ref_, beta.ref_, + output.ref_) + ->result_tensors()[0]; +} + +} // namespace ark diff --git a/ark/ops/ops_layernorm.hpp b/ark/ops/ops_layernorm.hpp new file mode 100644 index 000000000..5db37ce93 --- /dev/null +++ b/ark/ops/ops_layernorm.hpp @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_OPS_LAYERNORM_HPP_ +#define ARK_OPS_LAYERNORM_HPP_ + +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpLayerNorm : public ModelOp { + public: + ModelOpLayerNorm() = default; + ModelOpLayerNorm(ModelTensorRef input, ModelTensorRef gamma, + ModelTensorRef beta, ModelTensorRef output); + + std::string impl_name(const Json &config) const override; + + std::vector impl_args(const Json &config) const override; + + Json default_config(const ArchRef arch = ARCH_ANY) const override; +}; + +} // namespace ark + +#endif // ARK_OPS_LAYERNORM_HPP_ diff --git a/ark/ops/ops_softmax.cpp b/ark/ops/ops_softmax.cpp new file mode 100644 index 000000000..135253c48 --- /dev/null +++ b/ark/ops/ops_softmax.cpp @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_softmax.hpp" + +#include "logging.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpSoftmax::ModelOpSoftmax(ModelTensorRef input, ModelTensorRef output) + : ModelOp("Softmax") { + check_null(input); + + if (output) { + check_match_data_type(input, output); + check_match_shape(output, input->shape()); + } else { + output = std::make_shared( + input->data_type(), std::make_shared(), + input->shape()); + } + ModelTensorRef result = std::make_shared(*output); + read_tensors_ = {input}; + write_tensors_ = {output}; + result_tensors_ = {result}; + verify(); +} + +std::string ModelOpSoftmax::impl_name(const Json &config) const { + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); + int num_warps = config.at("NumWarps"); + int sram_bytes = config.at("SramBytes"); + Dims unit_out_dims(config.at("Tile").get>()); + + std::vector template_args = { + vec_string(read_tensors_[0]->strides().dims4()), + vec_string(read_tensors_[0]->shape().dims4()), + vec_string(write_tensors_[0]->strides().dims4()), + vec_string(write_tensors_[0]->shape().dims4()), + vec_string(unit_out_dims.dims4()), + std::to_string(num_warps), + std::to_string(sram_bytes), + }; + + // Add NelemPerThread if specified and > 1 + if (config.contains("NelemPerThread")) { + int nelem = config.at("NelemPerThread"); + if (nelem > 1) { + template_args.push_back(std::to_string(nelem)); + } + } + + return function_name_string("softmax", template_args); +} + +std::vector ModelOpSoftmax::impl_args( + [[maybe_unused]] const Json &config) const { + // Order: out, in + return {result_tensors_[0], read_tensors_[0]}; +} + +Json ModelOpSoftmax::default_config( + [[maybe_unused]] const ArchRef arch) const { + Json config; + config["NumWarps"] = 1; + config["SramBytes"] = 256; + auto shape = result_tensors_[0]->shape().dims4(); + config["Tile"] = {1, 1, 1, static_cast(shape[3])}; + config["NumTasks"] = shape[0] * shape[1] * shape[2]; + return config; +} + +Tensor Model::softmax(Tensor input, Tensor output, const std::string &name) { + return impl_ + ->create_op(name, input.ref_, output.ref_) + ->result_tensors()[0]; +} + +} // namespace ark diff --git a/ark/ops/ops_softmax.hpp b/ark/ops/ops_softmax.hpp new file mode 100644 index 000000000..00ecd550a --- /dev/null +++ b/ark/ops/ops_softmax.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_OPS_SOFTMAX_HPP_ +#define ARK_OPS_SOFTMAX_HPP_ + +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpSoftmax : public ModelOp { + public: + ModelOpSoftmax() = default; + ModelOpSoftmax(ModelTensorRef input, ModelTensorRef output); + + std::string impl_name(const Json &config) const override; + + std::vector impl_args(const Json &config) const override; + + Json default_config(const ArchRef arch = ARCH_ANY) const override; +}; + +} // namespace ark + +#endif // ARK_OPS_SOFTMAX_HPP_ From e177a9bf773156a5c66ff5ddce530ce4714acf4d Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 May 2026 01:17:00 +0000 Subject: [PATCH 4/7] Register ModelOpLayerNorm and ModelOpSoftmax in op factory Without registration, Model::layernorm() and Model::softmax() throw ModelError('Unknown model op type') at runtime. Add includes and MODEL_OP_TYPE_REGISTER entries for both ops. --- ark/model/model_op.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index e27b6fff9..0cf432692 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -13,6 +13,7 @@ #include "ops/ops_communication.hpp" #include "ops/ops_copy.hpp" #include "ops/ops_embedding.hpp" +#include "ops/ops_layernorm.hpp" #include "ops/ops_math.hpp" #include "ops/ops_matmul.hpp" #include "ops/ops_noop.hpp" @@ -22,6 +23,7 @@ #include "ops/ops_reshape.hpp" #include "ops/ops_rope.hpp" #include "ops/ops_scalar.hpp" +#include "ops/ops_softmax.hpp" #include "ops/ops_tensor.hpp" #include "ops/ops_transpose.hpp" #include "utils/utils_string.hpp" @@ -59,6 +61,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Embedding); MODEL_OP_TYPE_REGISTER(Exp); MODEL_OP_TYPE_REGISTER(Gelu); + MODEL_OP_TYPE_REGISTER(LayerNorm); MODEL_OP_TYPE_REGISTER(Matmul); MODEL_OP_TYPE_REGISTER(MatmulGelu); MODEL_OP_TYPE_REGISTER(MatmulScale); @@ -81,6 +84,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Send); MODEL_OP_TYPE_REGISTER(SendDone); MODEL_OP_TYPE_REGISTER(Sigmoid); + MODEL_OP_TYPE_REGISTER(Softmax); MODEL_OP_TYPE_REGISTER(Sqrt); MODEL_OP_TYPE_REGISTER(Sub); MODEL_OP_TYPE_REGISTER(Tensor); From 72a59dbe7bc13ddcfb764b37d98b61301ba59c4c Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 May 2026 01:17:16 +0000 Subject: [PATCH 5/7] Address deep-review: declarations, comments, documentation - Add Model::layernorm/softmax declarations to model.hpp - Add documentation comments in layernorm.h, softmax.h, reduce.h - Clarify multi-row tiling constraints and float accumulation rationale --- ark/include/ark/model.hpp | 7 +++++++ ark/include/kernels/layernorm.h | 16 +++++++++++++--- ark/include/kernels/reduce.h | 8 ++++++++ ark/include/kernels/softmax.h | 32 +++++++++++++++++++------------- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index caec2da24..26048e18f 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -207,6 +207,13 @@ class Model : public ModelGraph { // Sigmoid activation Tensor sigmoid(Tensor input, Tensor output = NullTensor, const std::string &name = ""); + // Layer normalization + Tensor layernorm(Tensor input, Tensor gamma, Tensor beta, + Tensor output = NullTensor, + const std::string &name = ""); + // Softmax + Tensor softmax(Tensor input, Tensor output = NullTensor, + const std::string &name = ""); // Performs rotary position embedding (RoPE) on the `input` tensor Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor, const std::string &name = ""); diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index 0527bd217..873ea68f1 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -17,7 +17,7 @@ struct LayerNormShapeChecker { "Dimension C of input and output do not match"); static_assert(InShape::H == OutShape::H, "Dimension H of input and output do not match"); - static_assert(OutShape::W == OutShape::W, + static_assert(InShape::W == OutShape::W, "Dimension W of input and output do not match"); }; @@ -42,6 +42,10 @@ struct LayerNorm { constexpr int NonReduceDimLength = UnitOutDims::NCH; static_assert( (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); + static_assert(UnitOp::NumThreads % NonReduceDimLength == 0, + "NumThreads must be evenly divisible by " + "NonReduceDimLength for correct physical " + "thread-to-row assignment"); constexpr int ThreadsPerRow = (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; @@ -68,6 +72,12 @@ struct LayerNorm { // Compute warp_offset for multi-row shared memory partitioning constexpr int PhysicalThreadsPerRow = UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for tile dimensions"); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; int row_in_tile = tid / PhysicalThreadsPerRow; int warp_offset = row_in_tile * WarpsPerRow; @@ -96,7 +106,7 @@ struct LayerNorm { // Reduce sum across physical threads (each thread already accumulated // NelemPerThread elements locally, so we reduce PhysicalThreadsPerRow threads). - sum = warpsReduce( + sum = warpsReduce( sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); float fmean = sum / static_cast(InShape::W); @@ -110,7 +120,7 @@ struct LayerNorm { } } - var_sum = warpsReduce( + var_sum = warpsReduce( var_sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); float inv_std = rsqrtf(var_sum / static_cast(InShape::W) + 1e-5f); diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index f5fbd7beb..5461e922a 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -57,6 +57,10 @@ DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp, int warp_offset = 0) { val = warpReduce(val); if constexpr (LanesNum > Arch::ThreadsPerWarp) { + // Barrier before shared memory write to prevent a back-to-back + // warpsReduce call from overwriting storage before all warps + // finish reading the previous result. + UnitOp::sync_threads(); ReduceSharedStorage *shared = UnitOp::template shared_memory>( smem_per_warp); @@ -449,6 +453,10 @@ struct WwiseReduce { static_assert(PhysicalThreadsPerRow > 0, "Not enough threads for the tile dimensions. " "Increase NumWarps or decrease Tile H dimension."); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); if constexpr (PhysicalThreadsPerRow <= Arch::ThreadsPerWarp) { // All threads for one row are within a single warp. reduced[0] = diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h index 02dad1a71..2a46345dc 100644 --- a/ark/include/kernels/softmax.h +++ b/ark/include/kernels/softmax.h @@ -42,6 +42,10 @@ struct Softmax { constexpr int NonReduceDimLength = UnitOutDims::NCH; static_assert( (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); + static_assert(UnitOp::NumThreads % NonReduceDimLength == 0, + "NumThreads must be evenly divisible by " + "NonReduceDimLength for correct physical " + "thread-to-row assignment"); constexpr int ThreadsPerRow = (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; @@ -68,6 +72,12 @@ struct Softmax { // Compute warp_offset for multi-row shared memory partitioning constexpr int PhysicalThreadsPerRow = UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for tile dimensions"); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; int row_in_tile = tid / PhysicalThreadsPerRow; int warp_offset = row_in_tile * WarpsPerRow; @@ -78,7 +88,7 @@ struct Softmax { // --- Pass 1: Read input ONCE, cache in registers, find max --- float cached[MaxElemsPerThread]; - float max_val = -1e30f; + float max_val = type::Constant::lowest(); int num_elems = 0; #pragma unroll for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { @@ -93,29 +103,25 @@ struct Softmax { } } - // Reduce max across threads - // Use ReduceTypeMax with DataType for shared-memory compatibility - DataType max_dt = type::Cast::compute(max_val); - max_dt = warpsReduce( - max_dt, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); - float fmax = type::Cast::compute(max_dt); + // Reduce max across warps in float precision + max_val = warpsReduce( + max_val, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); // --- Pass 2: Compute exp(x - max) from registers, accumulate sum --- float sum = 0.0f; #pragma unroll for (int i = 0; i < MaxElemsPerThread; i++) { if (i < num_elems) { - float exp_val = expf(cached[i] - fmax); + float exp_val = expf(cached[i] - max_val); cached[i] = exp_val; // reuse cache for exp values sum += exp_val; } } - // Reduce sum across threads - DataType sum_dt = type::Cast::compute(sum); - sum_dt = warpsReduce( - sum_dt, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); - float inv_sum = 1.0f / type::Cast::compute(sum_dt); + // Reduce sum across warps in float precision + sum = warpsReduce( + sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float inv_sum = 1.0f / sum; // --- Pass 3: Divide by sum and write output (single global write) --- int wi = 0; From 5450ecd68cfcb3c84da6e3465f6f89cc2e61ff62 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 May 2026 02:27:36 +0000 Subject: [PATCH 6/7] Add unit tests for layernorm and softmax ops - ops_layernorm_test.cpp: fp32/fp16/bf16 precision tests, batch dims, invalid gamma/beta shape. Baseline computes mean/var/normalize with float accumulation + affine (gamma*normed + beta). - ops_softmax_test.cpp: fp32/fp16/bf16 precision tests, batch attention pattern (B,H,S,S), small-row edge case (W < warp size). Baseline computes 3-pass row-wise softmax (max/exp-sum/normalize). --- ark/ops/ops_layernorm_test.cpp | 141 +++++++++++++++++++++++++++++++++ ark/ops/ops_softmax_test.cpp | 125 +++++++++++++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 ark/ops/ops_layernorm_test.cpp create mode 100644 ark/ops/ops_softmax_test.cpp diff --git a/ark/ops/ops_layernorm_test.cpp b/ark/ops/ops_layernorm_test.cpp new file mode 100644 index 000000000..3301a6206 --- /dev/null +++ b/ark/ops/ops_layernorm_test.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include "ark/model.hpp" +#include "ops_test_common.hpp" +#include "unittest/unittest_utils.h" + +// Baseline: LayerNorm with affine (gamma, beta). +// For each row (last dim = W), compute: +// mean = sum(x) / W +// var = sum((x - mean)^2) / W +// out = gamma * (x - mean) / sqrt(var + eps) + beta +// eps = 1e-5 (standard). +template +void baseline_layernorm(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes, int) { + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + T *gamma = static_cast(inputs[1]); + T *beta = static_cast(inputs[2]); + + ark::Dims osh = output_shapes[0]; + ark::DimType total = osh.nelems(); + ark::DimType W = osh[-1]; + ark::DimType num_rows = total / W; + constexpr float eps = 1e-5f; + + for (ark::DimType row = 0; row < num_rows; ++row) { + T *row_in = input + row * W; + T *row_out = out + row * W; + + // mean + float mean = 0; + for (ark::DimType j = 0; j < W; ++j) { + mean += static_cast(row_in[j]); + } + mean /= static_cast(W); + + // variance + float var = 0; + for (ark::DimType j = 0; j < W; ++j) { + float diff = static_cast(row_in[j]) - mean; + var += diff * diff; + } + var /= static_cast(W); + + float inv_std = 1.0f / std::sqrt(var + eps); + + // normalize + affine + for (ark::DimType j = 0; j < W; ++j) { + float normed = + (static_cast(row_in[j]) - mean) * inv_std; + row_out[j] = + T(static_cast(gamma[j]) * normed + + static_cast(beta[j])); + } + } +} + +ark::unittest::State test_layernorm_fp32() { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor gamma = m.tensor({1024}, ark::FP32); + ark::Tensor beta = m.tensor({1024}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_fp32", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_fp16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 768}, ark::FP16); + ark::Tensor gamma = m.tensor({768}, ark::FP16); + ark::Tensor beta = m.tensor({768}, ark::FP16); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_fp16", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_bf16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 768}, ark::BF16); + ark::Tensor gamma = m.tensor({768}, ark::BF16); + ark::Tensor beta = m.tensor({768}, ark::BF16); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_bf16", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_batch() { + // Higher-dimensional input: [B, S, D] + ark::Model m; + ark::Tensor input = m.tensor({2, 8, 512}, ark::FP32); + ark::Tensor gamma = m.tensor({512}, ark::FP32); + ark::Tensor beta = m.tensor({512}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_batch", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_invalid() { + // gamma/beta shape mismatch + { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor gamma = m.tensor({512}, ark::FP32); // wrong size + ark::Tensor beta = m.tensor({1024}, ark::FP32); + UNITTEST_THROW(m.layernorm(input, gamma, beta), ark::ModelError); + } + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_layernorm_fp32); + UNITTEST(test_layernorm_fp16); + UNITTEST(test_layernorm_bf16); + UNITTEST(test_layernorm_batch); + UNITTEST(test_layernorm_invalid); + return ark::unittest::SUCCESS; +} diff --git a/ark/ops/ops_softmax_test.cpp b/ark/ops/ops_softmax_test.cpp new file mode 100644 index 000000000..27d430796 --- /dev/null +++ b/ark/ops/ops_softmax_test.cpp @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include + +#include "ark/model.hpp" +#include "ops_test_common.hpp" +#include "unittest/unittest_utils.h" + +// Baseline: row-wise softmax. +// For each row (last dim = W): +// max_val = max(row) +// exp_sum = sum(exp(x - max_val)) +// out[j] = exp(x[j] - max_val) / exp_sum +template +void baseline_softmax(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &, int) { + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0]; + ark::DimType total = osh.nelems(); + ark::DimType W = osh[-1]; + ark::DimType num_rows = total / W; + + for (ark::DimType row = 0; row < num_rows; ++row) { + T *row_in = input + row * W; + T *row_out = out + row * W; + + // pass 1: max + float max_val = static_cast(row_in[0]); + for (ark::DimType j = 1; j < W; ++j) { + float v = static_cast(row_in[j]); + if (v > max_val) max_val = v; + } + + // pass 2: exp and sum + float exp_sum = 0; + for (ark::DimType j = 0; j < W; ++j) { + float e = std::exp(static_cast(row_in[j]) - max_val); + exp_sum += e; + } + + // pass 3: normalize + for (ark::DimType j = 0; j < W; ++j) { + float e = std::exp(static_cast(row_in[j]) - max_val); + row_out[j] = T(e / exp_sum); + } + } +} + +ark::unittest::State test_softmax_fp32() { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = + ark::op_test("softmax_fp32", m, {input}, {out}, baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_fp16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 512}, ark::FP16); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_fp16", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-3f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_bf16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 512}, ark::BF16); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_bf16", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_batch() { + // Higher-dimensional: [B, H, S, S] — attention pattern + ark::Model m; + ark::Tensor input = m.tensor({2, 12, 64, 64}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_batch", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_small_row() { + // Small last dim — tests edge case where W < warp size + ark::Model m; + ark::Tensor input = m.tensor({8, 16}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_small_row", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_softmax_fp32); + UNITTEST(test_softmax_fp16); + UNITTEST(test_softmax_bf16); + UNITTEST(test_softmax_batch); + UNITTEST(test_softmax_small_row); + return ark::unittest::SUCCESS; +} From e40ef10dcd57cfd0a3e18b990783a3cfa036580c Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 26 May 2026 02:54:22 +0000 Subject: [PATCH 7/7] Deep-review fixes: kernel comments, test includes, tolerance docs - Fix layernorm.h comment: 'vectorized loads/stores' -> 'multi-element- per-thread unrolling' (accurate description of the tiling approach) - Add register-caching documentation in softmax.h - Align test file include style with peer tests (transitive includes) - Document tolerance thresholds in test assertions --- ark/include/kernels/layernorm.h | 8 ++++++-- ark/include/kernels/softmax.h | 4 ++++ ark/ops/ops_layernorm_test.cpp | 36 ++++++++++++++++++++++++++++++--- ark/ops/ops_softmax_test.cpp | 28 +++++++++++++++++++++++-- 4 files changed, 69 insertions(+), 7 deletions(-) diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index 873ea68f1..2aa0a1243 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -26,7 +26,7 @@ struct LayerNormShapeChecker { // gamma and beta are 1-D tensors of size W (the normalization dimension). // // Optimized: single global memory read (register cache), float accumulation, -// vectorized loads/stores for reduced instruction count. +// multi-element-per-thread unrolling for reduced loop overhead. template @@ -79,6 +79,8 @@ struct LayerNorm { "PhysicalThreadsPerRow must be <= warp size or a " "multiple of warp size"); constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + static_assert(WarpsPerRow * NonReduceDimLength <= Arch::ThreadsPerWarp, + "Too many warps for ReduceSharedStorage capacity"); int row_in_tile = tid / PhysicalThreadsPerRow; int warp_offset = row_in_tile * WarpsPerRow; @@ -145,7 +147,9 @@ struct LayerNorm { } }; -// Free function for layernorm without gamma/beta (backward compatible). +// Free function for layernorm without gamma/beta. Currently unused by the op +// layer (which always uses layernorm_affine), but retained for kernel-level API +// completeness and potential future use (e.g., non-affine LayerNorm op). template diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h index 2a46345dc..31038a646 100644 --- a/ark/include/kernels/softmax.h +++ b/ark/include/kernels/softmax.h @@ -79,6 +79,8 @@ struct Softmax { "PhysicalThreadsPerRow must be <= warp size or a " "multiple of warp size"); constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + static_assert(WarpsPerRow * NonReduceDimLength <= Arch::ThreadsPerWarp, + "Too many warps for ReduceSharedStorage capacity"); int row_in_tile = tid / PhysicalThreadsPerRow; int warp_offset = row_in_tile * WarpsPerRow; @@ -121,6 +123,8 @@ struct Softmax { // Reduce sum across warps in float precision sum = warpsReduce( sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + // Note: if all inputs are -inf, sum==0 and output is NaN + // (matches PyTorch behavior). float inv_sum = 1.0f / sum; // --- Pass 3: Divide by sum and write output (single global write) --- diff --git a/ark/ops/ops_layernorm_test.cpp b/ark/ops/ops_layernorm_test.cpp index 3301a6206..0dfd63bd4 100644 --- a/ark/ops/ops_layernorm_test.cpp +++ b/ark/ops/ops_layernorm_test.cpp @@ -3,9 +3,7 @@ #include -#include "ark/model.hpp" #include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" // Baseline: LayerNorm with affine (gamma, beta). // For each row (last dim = W), compute: @@ -85,7 +83,7 @@ ark::unittest::State test_layernorm_fp16() { auto result = ark::op_test("layernorm_fp16", m, {input, gamma, beta}, {out}, baseline_layernorm); UNITTEST_LOG(result); - UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); return ark::unittest::SUCCESS; } @@ -118,6 +116,36 @@ ark::unittest::State test_layernorm_batch() { return ark::unittest::SUCCESS; } +ark::unittest::State test_layernorm_small_row() { + // Small last dim — tests edge case where W < warp size + ark::Model m; + ark::Tensor input = m.tensor({8, 16}, ark::FP32); + ark::Tensor gamma = m.tensor({16}, ark::FP32); + ark::Tensor beta = m.tensor({16}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_small_row", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_w1() { + // W=1 boundary: variance=0, epsilon dominates + ark::Model m; + ark::Tensor input = m.tensor({4, 1}, ark::FP32); + ark::Tensor gamma = m.tensor({1}, ark::FP32); + ark::Tensor beta = m.tensor({1}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_w1", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + ark::unittest::State test_layernorm_invalid() { // gamma/beta shape mismatch { @@ -136,6 +164,8 @@ int main() { UNITTEST(test_layernorm_fp16); UNITTEST(test_layernorm_bf16); UNITTEST(test_layernorm_batch); + UNITTEST(test_layernorm_small_row); + UNITTEST(test_layernorm_w1); UNITTEST(test_layernorm_invalid); return ark::unittest::SUCCESS; } diff --git a/ark/ops/ops_softmax_test.cpp b/ark/ops/ops_softmax_test.cpp index 27d430796..204cf202c 100644 --- a/ark/ops/ops_softmax_test.cpp +++ b/ark/ops/ops_softmax_test.cpp @@ -4,9 +4,7 @@ #include #include -#include "ark/model.hpp" #include "ops_test_common.hpp" -#include "unittest/unittest_utils.h" // Baseline: row-wise softmax. // For each row (last dim = W): @@ -114,6 +112,30 @@ ark::unittest::State test_softmax_small_row() { return ark::unittest::SUCCESS; } +ark::unittest::State test_softmax_w1() { + // W=1 boundary: softmax output must be exactly 1.0 + ark::Model m; + ark::Tensor input = m.tensor({4, 1}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_w1", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_invalid() { + // Output shape mismatch + { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor bad_out = m.tensor({4, 512}, ark::FP32); // wrong W + UNITTEST_THROW(m.softmax(input, bad_out), ark::ModelError); + } + return ark::unittest::SUCCESS; +} + int main() { ark::init(); UNITTEST(test_softmax_fp32); @@ -121,5 +143,7 @@ int main() { UNITTEST(test_softmax_bf16); UNITTEST(test_softmax_batch); UNITTEST(test_softmax_small_row); + UNITTEST(test_softmax_w1); + UNITTEST(test_softmax_invalid); return ark::unittest::SUCCESS; }