diff --git a/ark/include/ark/model.hpp b/ark/include/ark/model.hpp index caec2da2..26048e18 100644 --- a/ark/include/ark/model.hpp +++ b/ark/include/ark/model.hpp @@ -207,6 +207,13 @@ class Model : public ModelGraph { // Sigmoid activation Tensor sigmoid(Tensor input, Tensor output = NullTensor, const std::string &name = ""); + // Layer normalization + Tensor layernorm(Tensor input, Tensor gamma, Tensor beta, + Tensor output = NullTensor, + const std::string &name = ""); + // Softmax + Tensor softmax(Tensor input, Tensor output = NullTensor, + const std::string &name = ""); // Performs rotary position embedding (RoPE) on the `input` tensor Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor, const std::string &name = ""); diff --git a/ark/include/kernels/ark_kernels.h b/ark/include/kernels/ark_kernels.h index bf849a95..5b1400e1 100644 --- a/ark/include/kernels/ark_kernels.h +++ b/ark/include/kernels/ark_kernels.h @@ -21,6 +21,7 @@ #include "noop.h" #include "reduce.h" #include "scalar.h" +#include "softmax.h" #include "transpose.h" #endif // ARK_KERNELS_H_ diff --git a/ark/include/kernels/layernorm.h b/ark/include/kernels/layernorm.h index 5bc17235..2aa0a124 100644 --- a/ark/include/kernels/layernorm.h +++ b/ark/include/kernels/layernorm.h @@ -17,31 +17,35 @@ struct LayerNormShapeChecker { "Dimension C of input and output do not match"); static_assert(InShape::H == OutShape::H, "Dimension H of input and output do not match"); - static_assert(OutShape::W == OutShape::W, + static_assert(InShape::W == OutShape::W, "Dimension W of input and output do not match"); }; // Perform layer normalization on input and write the result on output. +// When HasGammaBeta is true, applies affine transform: gamma * normalized + beta. +// gamma and beta are 1-D tensors of size W (the normalization dimension). +// +// Optimized: single global memory read (register cache), float accumulation, +// multi-element-per-thread unrolling for reduced loop overhead. template + typename DataType, int NelemPerThread, bool HasGammaBeta> struct LayerNorm { using UnitOp = UnitOp; static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); - static DEVICE void run(DataType *out, const DataType *in, int uop_idx, - int smem_per_warp) { + static DEVICE void run(DataType *out, const DataType *in, + const DataType *gamma, const DataType *beta, + int uop_idx, int smem_per_warp) { using InOutChk = LayerNormShapeChecker; constexpr int NonReduceDimLength = UnitOutDims::NCH; - // The reduction dimension of the final stage. - // Assume this division is always exact. static_assert( (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); - // If we reshape the input into a 2D matrix (NCH x W), NumThreads - // threads compute NCH rows, and each row's sum is computed by - // ThreadsPerRow threads. If ThreadsPerRow is larger than warp size, we - // need to use shared memory to reduce the result of each warp. + static_assert(UnitOp::NumThreads % NonReduceDimLength == 0, + "NumThreads must be evenly divisible by " + "NonReduceDimLength for correct physical " + "thread-to-row assignment"); constexpr int ThreadsPerRow = (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; @@ -65,63 +69,107 @@ struct LayerNorm { UnitOp::sync_threads(); - DataType mean; - DataType cmp; - ReduceTypeMean::template identity<1>(&mean); - ReduceTypeMean::template identity<1>(&cmp); + // Compute warp_offset for multi-row shared memory partitioning + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for tile dimensions"); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); + constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + static_assert(WarpsPerRow * NonReduceDimLength <= Arch::ThreadsPerWarp, + "Too many warps for ReduceSharedStorage capacity"); + int row_in_tile = tid / PhysicalThreadsPerRow; + int warp_offset = row_in_tile * WarpsPerRow; + + constexpr int OuterIters = + (InShape::W + ThreadsPerRow - 1) / ThreadsPerRow; + constexpr int MaxElemsPerThread = OuterIters * NelemPerThread; + + // --- Pass 1: Read input ONCE from global memory, cache in registers, + // accumulate sum for mean (float accumulation) --- + float cached[MaxElemsPerThread]; + float sum = 0.0f; + int num_elems = 0; +#pragma unroll for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - DataType in_val = in[idx_in]; - DataType val = type::Sub::compute(in_val, cmp); - DataType tmp = type::Add::compute(mean, val); - cmp = type::Sub::compute(type::Sub::compute(tmp, mean), val); - mean = tmp; +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float val = type::Cast::compute(in[idx_in_base + idx_w + j]); + cached[num_elems] = val; + sum += val; + num_elems++; + } + } } - // final reduction on shared memory using warp shuffle. - mean = warpsReduce( - mean, tid, smem_per_warp); - // get the average result. - ReduceTypeMean::template postReduce<1>(&mean, &mean, UnitOutDims::W); - DataType variance; - ReduceTypeMean::template identity<1>(&variance); - ReduceTypeMean::template identity<1>(&cmp); - // get the variance - for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - DataType in_val = in[idx_in]; - DataType val = type::Sub::compute( - type::Mul::compute(type::Sub::compute(in_val, mean), - type::Sub::compute(in_val, mean)), - cmp); - DataType tmp = type::Add::compute(variance, val); - cmp = type::Sub::compute(type::Sub::compute(tmp, variance), val); - variance = tmp; + + // Reduce sum across physical threads (each thread already accumulated + // NelemPerThread elements locally, so we reduce PhysicalThreadsPerRow threads). + sum = warpsReduce( + sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float fmean = sum / static_cast(InShape::W); + + // --- Pass 2: Compute variance from cached registers (no global read) --- + float var_sum = 0.0f; +#pragma unroll + for (int i = 0; i < MaxElemsPerThread; i++) { + if (i < num_elems) { + float diff = cached[i] - fmean; + var_sum += diff * diff; + } } - variance = warpsReduce( - variance, tid, smem_per_warp); - ReduceTypeMean::template postReduce<1>(&variance, &variance, - UnitOutDims::W); - variance = type::Rsqrt::compute( - type::Add::compute(variance, type::Cast::compute(1e-5f))); - // the output is input / sqrt(mean_square) + + var_sum = warpsReduce( + var_sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + float inv_std = rsqrtf(var_sum / static_cast(InShape::W) + 1e-5f); + + // --- Pass 3: Normalize and write output (from registers) --- + int wi = 0; +#pragma unroll for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { - int idx_in = idx_in_base + idx_w; - int idx_out = idx_out_base + idx_w; - out[idx_out] = type::Mul::compute( - type::Sub::compute(in[idx_in], mean), variance); +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float normalized = (cached[wi] - fmean) * inv_std; + if constexpr (HasGammaBeta) { + normalized = normalized * + type::Cast::compute(gamma[idx_w + j]) + + type::Cast::compute(beta[idx_w + j]); + } + out[idx_out_base + idx_w + j] = type::Cast::compute(normalized); + wi++; + } + } } } }; +// Free function for layernorm without gamma/beta. Currently unused by the op +// layer (which always uses layernorm_affine), but retained for kernel-level API +// completeness and potential future use (e.g., non-affine LayerNorm op). template + int NelemPerThread = 1, typename DataType> DEVICE void layernorm(DataType *out, const DataType *in, int uop_idx, int smem_per_warp) { - constexpr int NelemPerThread = 1; LayerNorm::run(out, in, uop_idx, - smem_per_warp); + SmemBytes, DataType, NelemPerThread, false>::run( + out, in, nullptr, nullptr, uop_idx, smem_per_warp); +} + +// Free function for layernorm with gamma/beta affine transform. +template +DEVICE void layernorm_affine(DataType *out, const DataType *in, + const DataType *gamma, const DataType *beta, + int uop_idx, int smem_per_warp) { + LayerNorm::run( + out, in, gamma, beta, uop_idx, smem_per_warp); } } // namespace ark diff --git a/ark/include/kernels/reduce.h b/ark/include/kernels/reduce.h index 62af5840..5461e922 100644 --- a/ark/include/kernels/reduce.h +++ b/ark/include/kernels/reduce.h @@ -50,21 +50,28 @@ DEVICE bf16 warpReduce(bf16 val) { } // Reduce single-precision `val` within multiple warps. +// @param warp_offset Offset into shared storage to avoid aliasing when +// multiple independent row groups share the same shared memory. template -DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) { +DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp, + int warp_offset = 0) { val = warpReduce(val); if constexpr (LanesNum > Arch::ThreadsPerWarp) { + // Barrier before shared memory write to prevent a back-to-back + // warpsReduce call from overwriting storage before all warps + // finish reading the previous result. + UnitOp::sync_threads(); ReduceSharedStorage *shared = UnitOp::template shared_memory>( smem_per_warp); int laneId = tid & (Arch::ThreadsPerWarp - 1); int warpId = tid >> math::log2_up::value; if (laneId == 0) { - shared->storage[warpId] = val; + shared->storage[warpId + warp_offset] = val; } UnitOp::sync_threads(); if (laneId < (LanesNum >> math::log2_up::value)) { - val = shared->storage[laneId]; + val = shared->storage[laneId + warp_offset]; } else { ReduceType::template identity<1>(&val); } @@ -357,13 +364,11 @@ struct WwiseReduce { ReduceShapeChecker; constexpr int InConsecBytes = sizeof(DataType) * InShape::W; constexpr int NelemPerThread = - (InConsecBytes % 16 == 0) - ? 16 / sizeof(DataType) - : (InConsecBytes % 8 == 0) - ? 8 / sizeof(DataType) - : (InConsecBytes % 4 == 0) - ? 4 / sizeof(DataType) - : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1; + (InConsecBytes % 16 == 0) ? 16 / sizeof(DataType) + : (InConsecBytes % 8 == 0) ? 8 / sizeof(DataType) + : (InConsecBytes % 4 == 0) ? 4 / sizeof(DataType) + : (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) + : 1; constexpr int NonReduceDimLength = UnitOutDims::N * UnitOutDims::C * UnitOutDims::H; @@ -411,32 +416,65 @@ struct WwiseReduce { if constexpr (NelemPerThread > 8) { #pragma unroll for (int i = 8; i < NelemPerThread; i += 8) { - ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]); + ReduceType::template reduce<8>(&reduced[0], &reduced[0], + &reduced[i]); } - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], &reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 8) { - ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]); - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<4>(&reduced[0], &reduced[0], + &reduced[4]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 4) { - ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]); - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<2>(&reduced[0], &reduced[0], + &reduced[2]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } else if constexpr (NelemPerThread == 2) { - ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]); + ReduceType::template reduce<1>(&reduced[0], &reduced[0], + &reduced[1]); } if constexpr (InShape::W % ThreadsPerRow != 0) { UnitOp::sync_threads(); } - // final reduction on shared memory using warp shuffle. - reduced[0] = warpsReduce( - reduced[0], tid, smem_per_warp); + // final reduction using warp shuffle. + // PhysicalThreadsPerRow = actual number of HW threads per row. + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for the tile dimensions. " + "Increase NumWarps or decrease Tile H dimension."); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); + if constexpr (PhysicalThreadsPerRow <= Arch::ThreadsPerWarp) { + // All threads for one row are within a single warp. + reduced[0] = + warpReduce(reduced[0]); + } else { + // Threads for one row span multiple warps — need shared memory. + // Each row needs its own section of shared storage to avoid + // aliasing when multiple rows reduce in parallel. + constexpr int WarpsPerRow = + PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + int row_in_tile = tid / PhysicalThreadsPerRow; + reduced[0] = warpsReduce( + reduced[0], tid % PhysicalThreadsPerRow, smem_per_warp, + row_in_tile * WarpsPerRow); + } - // write the result to output. - if (tid % ThreadsPerRow == 0) { + // write the result to output — first thread of each row group. + if (tid % PhysicalThreadsPerRow == 0) { ReduceType::template postReduce<1>(&out[idx_out], &reduced[0], InShape::W); } diff --git a/ark/include/kernels/softmax.h b/ark/include/kernels/softmax.h new file mode 100644 index 00000000..31038a64 --- /dev/null +++ b/ark/include/kernels/softmax.h @@ -0,0 +1,159 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_KERNELS_SOFTMAX_H_ +#define ARK_KERNELS_SOFTMAX_H_ + +#include "reduce.h" + +namespace ark { + +// Static checkers for Softmax shapes. +template +struct SoftmaxShapeChecker { + static_assert(InShape::N == OutShape::N, + "Dimension N of input and output do not match"); + static_assert(InShape::C == OutShape::C, + "Dimension C of input and output do not match"); + static_assert(InShape::H == OutShape::H, + "Dimension H of input and output do not match"); + static_assert(InShape::W == OutShape::W, + "Dimension W of input and output do not match"); +}; + +// Monolithic softmax along the last dimension (W). +// Optimized: single global memory read (register cache), float accumulation, +// fused passes for reduced memory traffic. +// Pass 1: read input → cache in registers, find max +// Pass 2: from registers, compute exp(x - max) and sum (store in registers) +// Pass 3: divide by sum, write output (single global write) +template +struct Softmax { + using UnitOp = UnitOp; + + static_assert(NelemPerThread > 0, "NelemPerThread must be positive"); + + static DEVICE void run(DataType *out, const DataType *in, + int uop_idx, int smem_per_warp) { + using InOutChk = SoftmaxShapeChecker; + + constexpr int NonReduceDimLength = UnitOutDims::NCH; + static_assert( + (UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0); + static_assert(UnitOp::NumThreads % NonReduceDimLength == 0, + "NumThreads must be evenly divisible by " + "NonReduceDimLength for correct physical " + "thread-to-row assignment"); + constexpr int ThreadsPerRow = + (UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength; + + int tid = UnitOp::thread_id(); + int tid_w = (tid * NelemPerThread) % ThreadsPerRow; + int tid_h = ((tid * NelemPerThread) / ThreadsPerRow) % UnitOutDims::H; + int tid_c = ((tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::H) % + UnitOutDims::C; + int tid_n = (tid * NelemPerThread) / ThreadsPerRow / UnitOutDims::CH; + + int un = UnitOp::uop_idx_n(uop_idx); + int uc = UnitOp::uop_idx_c(uop_idx); + int uh = UnitOp::uop_idx_h(uop_idx); + + int idx_out_base = (tid_h + uh * UnitOutDims::H) * OutDims::W + + (tid_c + uc * UnitOutDims::C) * OutDims::HW + + (tid_n + un * UnitOutDims::N) * OutDims::CHW; + int idx_in_base = (tid_h + uh * UnitOutDims::H) * InDims::W + + (tid_c + uc * UnitOutDims::C) * InDims::HW + + (tid_n + un * UnitOutDims::N) * InDims::CHW; + + UnitOp::sync_threads(); + + // Compute warp_offset for multi-row shared memory partitioning + constexpr int PhysicalThreadsPerRow = + UnitOp::NumThreads / NonReduceDimLength; + static_assert(PhysicalThreadsPerRow > 0, + "Not enough threads for tile dimensions"); + static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp || + PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0, + "PhysicalThreadsPerRow must be <= warp size or a " + "multiple of warp size"); + constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp; + static_assert(WarpsPerRow * NonReduceDimLength <= Arch::ThreadsPerWarp, + "Too many warps for ReduceSharedStorage capacity"); + int row_in_tile = tid / PhysicalThreadsPerRow; + int warp_offset = row_in_tile * WarpsPerRow; + + constexpr int OuterIters = + (InShape::W + ThreadsPerRow - 1) / ThreadsPerRow; + constexpr int MaxElemsPerThread = OuterIters * NelemPerThread; + + // --- Pass 1: Read input ONCE, cache in registers, find max --- + float cached[MaxElemsPerThread]; + float max_val = type::Constant::lowest(); + int num_elems = 0; +#pragma unroll + for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + float val = type::Cast::compute(in[idx_in_base + idx_w + j]); + cached[num_elems] = val; + if (val > max_val) max_val = val; + num_elems++; + } + } + } + + // Reduce max across warps in float precision + max_val = warpsReduce( + max_val, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + + // --- Pass 2: Compute exp(x - max) from registers, accumulate sum --- + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < MaxElemsPerThread; i++) { + if (i < num_elems) { + float exp_val = expf(cached[i] - max_val); + cached[i] = exp_val; // reuse cache for exp values + sum += exp_val; + } + } + + // Reduce sum across warps in float precision + sum = warpsReduce( + sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset); + // Note: if all inputs are -inf, sum==0 and output is NaN + // (matches PyTorch behavior). + float inv_sum = 1.0f / sum; + + // --- Pass 3: Divide by sum and write output (single global write) --- + int wi = 0; +#pragma unroll + for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) { +#pragma unroll + for (int j = 0; j < NelemPerThread; j++) { + if (idx_w + j < InShape::W) { + out[idx_out_base + idx_w + j] = + type::Cast::compute(cached[wi] * inv_sum); + wi++; + } + } + } + } +}; + +// Free function wrapper for softmax. +template +DEVICE void softmax(DataType *out, const DataType *in, int uop_idx, + int smem_per_warp) { + Softmax::run( + out, in, uop_idx, smem_per_warp); +} + +} // namespace ark + +#endif // ARK_KERNELS_SOFTMAX_H_ diff --git a/ark/model/model_op.cpp b/ark/model/model_op.cpp index e27b6fff..0cf43269 100644 --- a/ark/model/model_op.cpp +++ b/ark/model/model_op.cpp @@ -13,6 +13,7 @@ #include "ops/ops_communication.hpp" #include "ops/ops_copy.hpp" #include "ops/ops_embedding.hpp" +#include "ops/ops_layernorm.hpp" #include "ops/ops_math.hpp" #include "ops/ops_matmul.hpp" #include "ops/ops_noop.hpp" @@ -22,6 +23,7 @@ #include "ops/ops_reshape.hpp" #include "ops/ops_rope.hpp" #include "ops/ops_scalar.hpp" +#include "ops/ops_softmax.hpp" #include "ops/ops_tensor.hpp" #include "ops/ops_transpose.hpp" #include "utils/utils_string.hpp" @@ -59,6 +61,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Embedding); MODEL_OP_TYPE_REGISTER(Exp); MODEL_OP_TYPE_REGISTER(Gelu); + MODEL_OP_TYPE_REGISTER(LayerNorm); MODEL_OP_TYPE_REGISTER(Matmul); MODEL_OP_TYPE_REGISTER(MatmulGelu); MODEL_OP_TYPE_REGISTER(MatmulScale); @@ -81,6 +84,7 @@ const ModelOpType ModelOpT::from_name(const std::string &type_name) { MODEL_OP_TYPE_REGISTER(Send); MODEL_OP_TYPE_REGISTER(SendDone); MODEL_OP_TYPE_REGISTER(Sigmoid); + MODEL_OP_TYPE_REGISTER(Softmax); MODEL_OP_TYPE_REGISTER(Sqrt); MODEL_OP_TYPE_REGISTER(Sub); MODEL_OP_TYPE_REGISTER(Tensor); diff --git a/ark/ops/ops_layernorm.cpp b/ark/ops/ops_layernorm.cpp new file mode 100644 index 00000000..eab547c7 --- /dev/null +++ b/ark/ops/ops_layernorm.cpp @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_layernorm.hpp" + +#include "logging.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpLayerNorm::ModelOpLayerNorm(ModelTensorRef input, ModelTensorRef gamma, + ModelTensorRef beta, ModelTensorRef output) + : ModelOp("LayerNorm") { + check_null(input); + check_null(gamma); + check_null(beta); + + // gamma and beta must be 1-D tensors matching the last dimension of input + DimType norm_dim = input->shape()[-1]; + if (gamma->shape().nelems() != norm_dim) { + ERR(ModelError, "gamma size ", gamma->shape().nelems(), + " does not match last dimension of input ", norm_dim); + } + if (beta->shape().nelems() != norm_dim) { + ERR(ModelError, "beta size ", beta->shape().nelems(), + " does not match last dimension of input ", norm_dim); + } + check_match_data_type(input, gamma); + check_match_data_type(input, beta); + + if (output) { + check_match_data_type(input, output); + check_match_shape(output, input->shape()); + } else { + output = std::make_shared( + input->data_type(), std::make_shared(), + input->shape()); + } + ModelTensorRef result = std::make_shared(*output); + read_tensors_ = {input, gamma, beta}; + write_tensors_ = {output}; + result_tensors_ = {result}; + verify(); +} + +std::string ModelOpLayerNorm::impl_name(const Json &config) const { + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); + int num_warps = config.at("NumWarps"); + int sram_bytes = config.at("SramBytes"); + Dims unit_out_dims(config.at("Tile").get>()); + + std::vector template_args = { + vec_string(read_tensors_[0]->strides().dims4()), + vec_string(read_tensors_[0]->shape().dims4()), + vec_string(write_tensors_[0]->strides().dims4()), + vec_string(write_tensors_[0]->shape().dims4()), + vec_string(unit_out_dims.dims4()), + std::to_string(num_warps), + std::to_string(sram_bytes), + }; + + // Add NelemPerThread if specified and > 1 + if (config.contains("NelemPerThread")) { + int nelem = config.at("NelemPerThread"); + if (nelem > 1) { + template_args.push_back(std::to_string(nelem)); + } + } + + return function_name_string("layernorm_affine", template_args); +} + +std::vector ModelOpLayerNorm::impl_args( + [[maybe_unused]] const Json &config) const { + // Order must match kernel function signature: + // layernorm_affine(out, in, gamma, beta, uop_idx, smem_per_warp) + return {result_tensors_[0], read_tensors_[0], read_tensors_[1], + read_tensors_[2]}; +} + +Json ModelOpLayerNorm::default_config( + [[maybe_unused]] const ArchRef arch) const { + Json config; + config["NumWarps"] = 1; + config["SramBytes"] = 256; + // The tile must cover the entire W dimension since layernorm reduces along W. + // Each task processes one complete row. + auto shape = result_tensors_[0]->shape().dims4(); + config["Tile"] = {1, 1, 1, static_cast(shape[3])}; + // One task per row (N * C * H) + config["NumTasks"] = shape[0] * shape[1] * shape[2]; + return config; +} + +Tensor Model::layernorm(Tensor input, Tensor gamma, Tensor beta, Tensor output, + const std::string &name) { + return impl_ + ->create_op(name, input.ref_, gamma.ref_, beta.ref_, + output.ref_) + ->result_tensors()[0]; +} + +} // namespace ark diff --git a/ark/ops/ops_layernorm.hpp b/ark/ops/ops_layernorm.hpp new file mode 100644 index 00000000..5db37ce9 --- /dev/null +++ b/ark/ops/ops_layernorm.hpp @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_OPS_LAYERNORM_HPP_ +#define ARK_OPS_LAYERNORM_HPP_ + +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpLayerNorm : public ModelOp { + public: + ModelOpLayerNorm() = default; + ModelOpLayerNorm(ModelTensorRef input, ModelTensorRef gamma, + ModelTensorRef beta, ModelTensorRef output); + + std::string impl_name(const Json &config) const override; + + std::vector impl_args(const Json &config) const override; + + Json default_config(const ArchRef arch = ARCH_ANY) const override; +}; + +} // namespace ark + +#endif // ARK_OPS_LAYERNORM_HPP_ diff --git a/ark/ops/ops_layernorm_test.cpp b/ark/ops/ops_layernorm_test.cpp new file mode 100644 index 00000000..0dfd63bd --- /dev/null +++ b/ark/ops/ops_layernorm_test.cpp @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include "ops_test_common.hpp" + +// Baseline: LayerNorm with affine (gamma, beta). +// For each row (last dim = W), compute: +// mean = sum(x) / W +// var = sum((x - mean)^2) / W +// out = gamma * (x - mean) / sqrt(var + eps) + beta +// eps = 1e-5 (standard). +template +void baseline_layernorm(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &input_shapes, int) { + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + T *gamma = static_cast(inputs[1]); + T *beta = static_cast(inputs[2]); + + ark::Dims osh = output_shapes[0]; + ark::DimType total = osh.nelems(); + ark::DimType W = osh[-1]; + ark::DimType num_rows = total / W; + constexpr float eps = 1e-5f; + + for (ark::DimType row = 0; row < num_rows; ++row) { + T *row_in = input + row * W; + T *row_out = out + row * W; + + // mean + float mean = 0; + for (ark::DimType j = 0; j < W; ++j) { + mean += static_cast(row_in[j]); + } + mean /= static_cast(W); + + // variance + float var = 0; + for (ark::DimType j = 0; j < W; ++j) { + float diff = static_cast(row_in[j]) - mean; + var += diff * diff; + } + var /= static_cast(W); + + float inv_std = 1.0f / std::sqrt(var + eps); + + // normalize + affine + for (ark::DimType j = 0; j < W; ++j) { + float normed = + (static_cast(row_in[j]) - mean) * inv_std; + row_out[j] = + T(static_cast(gamma[j]) * normed + + static_cast(beta[j])); + } + } +} + +ark::unittest::State test_layernorm_fp32() { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor gamma = m.tensor({1024}, ark::FP32); + ark::Tensor beta = m.tensor({1024}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_fp32", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_fp16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 768}, ark::FP16); + ark::Tensor gamma = m.tensor({768}, ark::FP16); + ark::Tensor beta = m.tensor({768}, ark::FP16); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_fp16", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_bf16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 768}, ark::BF16); + ark::Tensor gamma = m.tensor({768}, ark::BF16); + ark::Tensor beta = m.tensor({768}, ark::BF16); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_bf16", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_batch() { + // Higher-dimensional input: [B, S, D] + ark::Model m; + ark::Tensor input = m.tensor({2, 8, 512}, ark::FP32); + ark::Tensor gamma = m.tensor({512}, ark::FP32); + ark::Tensor beta = m.tensor({512}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_batch", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_small_row() { + // Small last dim — tests edge case where W < warp size + ark::Model m; + ark::Tensor input = m.tensor({8, 16}, ark::FP32); + ark::Tensor gamma = m.tensor({16}, ark::FP32); + ark::Tensor beta = m.tensor({16}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_small_row", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_w1() { + // W=1 boundary: variance=0, epsilon dominates + ark::Model m; + ark::Tensor input = m.tensor({4, 1}, ark::FP32); + ark::Tensor gamma = m.tensor({1}, ark::FP32); + ark::Tensor beta = m.tensor({1}, ark::FP32); + ark::Tensor out = m.layernorm(input, gamma, beta); + + auto result = ark::op_test("layernorm_w1", m, {input, gamma, beta}, + {out}, baseline_layernorm); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-4f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_layernorm_invalid() { + // gamma/beta shape mismatch + { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor gamma = m.tensor({512}, ark::FP32); // wrong size + ark::Tensor beta = m.tensor({1024}, ark::FP32); + UNITTEST_THROW(m.layernorm(input, gamma, beta), ark::ModelError); + } + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_layernorm_fp32); + UNITTEST(test_layernorm_fp16); + UNITTEST(test_layernorm_bf16); + UNITTEST(test_layernorm_batch); + UNITTEST(test_layernorm_small_row); + UNITTEST(test_layernorm_w1); + UNITTEST(test_layernorm_invalid); + return ark::unittest::SUCCESS; +} diff --git a/ark/ops/ops_softmax.cpp b/ark/ops/ops_softmax.cpp new file mode 100644 index 00000000..135253c4 --- /dev/null +++ b/ark/ops/ops_softmax.cpp @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include "ops_softmax.hpp" + +#include "logging.hpp" +#include "ops_common.hpp" + +namespace ark { + +ModelOpSoftmax::ModelOpSoftmax(ModelTensorRef input, ModelTensorRef output) + : ModelOp("Softmax") { + check_null(input); + + if (output) { + check_match_data_type(input, output); + check_match_shape(output, input->shape()); + } else { + output = std::make_shared( + input->data_type(), std::make_shared(), + input->shape()); + } + ModelTensorRef result = std::make_shared(*output); + read_tensors_ = {input}; + write_tensors_ = {output}; + result_tensors_ = {result}; + verify(); +} + +std::string ModelOpSoftmax::impl_name(const Json &config) const { + check_fields_config(config, {"NumWarps", "SramBytes", "Tile"}); + int num_warps = config.at("NumWarps"); + int sram_bytes = config.at("SramBytes"); + Dims unit_out_dims(config.at("Tile").get>()); + + std::vector template_args = { + vec_string(read_tensors_[0]->strides().dims4()), + vec_string(read_tensors_[0]->shape().dims4()), + vec_string(write_tensors_[0]->strides().dims4()), + vec_string(write_tensors_[0]->shape().dims4()), + vec_string(unit_out_dims.dims4()), + std::to_string(num_warps), + std::to_string(sram_bytes), + }; + + // Add NelemPerThread if specified and > 1 + if (config.contains("NelemPerThread")) { + int nelem = config.at("NelemPerThread"); + if (nelem > 1) { + template_args.push_back(std::to_string(nelem)); + } + } + + return function_name_string("softmax", template_args); +} + +std::vector ModelOpSoftmax::impl_args( + [[maybe_unused]] const Json &config) const { + // Order: out, in + return {result_tensors_[0], read_tensors_[0]}; +} + +Json ModelOpSoftmax::default_config( + [[maybe_unused]] const ArchRef arch) const { + Json config; + config["NumWarps"] = 1; + config["SramBytes"] = 256; + auto shape = result_tensors_[0]->shape().dims4(); + config["Tile"] = {1, 1, 1, static_cast(shape[3])}; + config["NumTasks"] = shape[0] * shape[1] * shape[2]; + return config; +} + +Tensor Model::softmax(Tensor input, Tensor output, const std::string &name) { + return impl_ + ->create_op(name, input.ref_, output.ref_) + ->result_tensors()[0]; +} + +} // namespace ark diff --git a/ark/ops/ops_softmax.hpp b/ark/ops/ops_softmax.hpp new file mode 100644 index 00000000..00ecd550 --- /dev/null +++ b/ark/ops/ops_softmax.hpp @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef ARK_OPS_SOFTMAX_HPP_ +#define ARK_OPS_SOFTMAX_HPP_ + +#include "model/model_op.hpp" + +namespace ark { + +class ModelOpSoftmax : public ModelOp { + public: + ModelOpSoftmax() = default; + ModelOpSoftmax(ModelTensorRef input, ModelTensorRef output); + + std::string impl_name(const Json &config) const override; + + std::vector impl_args(const Json &config) const override; + + Json default_config(const ArchRef arch = ARCH_ANY) const override; +}; + +} // namespace ark + +#endif // ARK_OPS_SOFTMAX_HPP_ diff --git a/ark/ops/ops_softmax_test.cpp b/ark/ops/ops_softmax_test.cpp new file mode 100644 index 00000000..204cf202 --- /dev/null +++ b/ark/ops/ops_softmax_test.cpp @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include + +#include "ops_test_common.hpp" + +// Baseline: row-wise softmax. +// For each row (last dim = W): +// max_val = max(row) +// exp_sum = sum(exp(x - max_val)) +// out[j] = exp(x[j] - max_val) / exp_sum +template +void baseline_softmax(std::vector &outputs, + const std::vector &output_shapes, + const std::vector &inputs, + const std::vector &, int) { + T *out = static_cast(outputs[0]); + T *input = static_cast(inputs[0]); + + ark::Dims osh = output_shapes[0]; + ark::DimType total = osh.nelems(); + ark::DimType W = osh[-1]; + ark::DimType num_rows = total / W; + + for (ark::DimType row = 0; row < num_rows; ++row) { + T *row_in = input + row * W; + T *row_out = out + row * W; + + // pass 1: max + float max_val = static_cast(row_in[0]); + for (ark::DimType j = 1; j < W; ++j) { + float v = static_cast(row_in[j]); + if (v > max_val) max_val = v; + } + + // pass 2: exp and sum + float exp_sum = 0; + for (ark::DimType j = 0; j < W; ++j) { + float e = std::exp(static_cast(row_in[j]) - max_val); + exp_sum += e; + } + + // pass 3: normalize + for (ark::DimType j = 0; j < W; ++j) { + float e = std::exp(static_cast(row_in[j]) - max_val); + row_out[j] = T(e / exp_sum); + } + } +} + +ark::unittest::State test_softmax_fp32() { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = + ark::op_test("softmax_fp32", m, {input}, {out}, baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_fp16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 512}, ark::FP16); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_fp16", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-3f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_bf16() { + ark::Model m; + ark::Tensor input = m.tensor({2, 512}, ark::BF16); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_bf16", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 5e-2f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_batch() { + // Higher-dimensional: [B, H, S, S] — attention pattern + ark::Model m; + ark::Tensor input = m.tensor({2, 12, 64, 64}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_batch", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_small_row() { + // Small last dim — tests edge case where W < warp size + ark::Model m; + ark::Tensor input = m.tensor({8, 16}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_small_row", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_w1() { + // W=1 boundary: softmax output must be exactly 1.0 + ark::Model m; + ark::Tensor input = m.tensor({4, 1}, ark::FP32); + ark::Tensor out = m.softmax(input); + + auto result = ark::op_test("softmax_w1", m, {input}, {out}, + baseline_softmax); + UNITTEST_LOG(result); + UNITTEST_TRUE(result.max_diff[0] < 1e-5f); + return ark::unittest::SUCCESS; +} + +ark::unittest::State test_softmax_invalid() { + // Output shape mismatch + { + ark::Model m; + ark::Tensor input = m.tensor({4, 1024}, ark::FP32); + ark::Tensor bad_out = m.tensor({4, 512}, ark::FP32); // wrong W + UNITTEST_THROW(m.softmax(input, bad_out), ark::ModelError); + } + return ark::unittest::SUCCESS; +} + +int main() { + ark::init(); + UNITTEST(test_softmax_fp32); + UNITTEST(test_softmax_fp16); + UNITTEST(test_softmax_bf16); + UNITTEST(test_softmax_batch); + UNITTEST(test_softmax_small_row); + UNITTEST(test_softmax_w1); + UNITTEST(test_softmax_invalid); + return ark::unittest::SUCCESS; +}