Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ark/include/ark/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ class Model : public ModelGraph {
// Sigmoid activation
Tensor sigmoid(Tensor input, Tensor output = NullTensor,
const std::string &name = "");
// Layer normalization
Tensor layernorm(Tensor input, Tensor gamma, Tensor beta,
Tensor output = NullTensor,
const std::string &name = "");
// Softmax
Tensor softmax(Tensor input, Tensor output = NullTensor,
const std::string &name = "");
// Performs rotary position embedding (RoPE) on the `input` tensor
Tensor rope(Tensor input, Tensor other, Tensor output = NullTensor,
const std::string &name = "");
Expand Down
1 change: 1 addition & 0 deletions ark/include/kernels/ark_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "noop.h"
#include "reduce.h"
#include "scalar.h"
#include "softmax.h"
#include "transpose.h"

#endif // ARK_KERNELS_H_
Expand Down
156 changes: 102 additions & 54 deletions ark/include/kernels/layernorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,35 @@ struct LayerNormShapeChecker {
"Dimension C of input and output do not match");
static_assert(InShape::H == OutShape::H,
"Dimension H of input and output do not match");
static_assert(OutShape::W == OutShape::W,
static_assert(InShape::W == OutShape::W,
"Dimension W of input and output do not match");
};

// Perform layer normalization on input and write the result on output.
// When HasGammaBeta is true, applies affine transform: gamma * normalized + beta.
// gamma and beta are 1-D tensors of size W (the normalization dimension).
//
// Optimized: single global memory read (register cache), float accumulation,
// multi-element-per-thread unrolling for reduced loop overhead.
template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumWarps, int SmemBytes,
typename DataType, int NelemPerThread>
typename DataType, int NelemPerThread, bool HasGammaBeta>
struct LayerNorm {
using UnitOp = UnitOp<OutDims, OutShape, UnitOutDims, NumWarps, SmemBytes>;

static_assert(NelemPerThread > 0, "NelemPerThread must be positive");
static DEVICE void run(DataType *out, const DataType *in, int uop_idx,
int smem_per_warp) {
static DEVICE void run(DataType *out, const DataType *in,
const DataType *gamma, const DataType *beta,
int uop_idx, int smem_per_warp) {
using InOutChk = LayerNormShapeChecker<InShape, OutShape>;

constexpr int NonReduceDimLength = UnitOutDims::NCH;
// The reduction dimension of the final stage.
// Assume this division is always exact.
static_assert(
(UnitOp::NumThreads * NelemPerThread) % NonReduceDimLength == 0);
// If we reshape the input into a 2D matrix (NCH x W), NumThreads
// threads compute NCH rows, and each row's sum is computed by
// ThreadsPerRow threads. If ThreadsPerRow is larger than warp size, we
// need to use shared memory to reduce the result of each warp.
static_assert(UnitOp::NumThreads % NonReduceDimLength == 0,
"NumThreads must be evenly divisible by "
"NonReduceDimLength for correct physical "
"thread-to-row assignment");
constexpr int ThreadsPerRow =
(UnitOp::NumThreads * NelemPerThread) / NonReduceDimLength;

Expand All @@ -65,63 +69,107 @@ struct LayerNorm {

UnitOp::sync_threads();

DataType mean;
DataType cmp;
ReduceTypeMean::template identity<1>(&mean);
ReduceTypeMean::template identity<1>(&cmp);
// Compute warp_offset for multi-row shared memory partitioning
constexpr int PhysicalThreadsPerRow =
UnitOp::NumThreads / NonReduceDimLength;
static_assert(PhysicalThreadsPerRow > 0,
"Not enough threads for tile dimensions");
static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp ||
PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0,
"PhysicalThreadsPerRow must be <= warp size or a "
"multiple of warp size");
constexpr int WarpsPerRow = PhysicalThreadsPerRow / Arch::ThreadsPerWarp;
static_assert(WarpsPerRow * NonReduceDimLength <= Arch::ThreadsPerWarp,
"Too many warps for ReduceSharedStorage capacity");
int row_in_tile = tid / PhysicalThreadsPerRow;
int warp_offset = row_in_tile * WarpsPerRow;

constexpr int OuterIters =
(InShape::W + ThreadsPerRow - 1) / ThreadsPerRow;
constexpr int MaxElemsPerThread = OuterIters * NelemPerThread;

// --- Pass 1: Read input ONCE from global memory, cache in registers,
// accumulate sum for mean (float accumulation) ---
float cached[MaxElemsPerThread];
float sum = 0.0f;
int num_elems = 0;
#pragma unroll
for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) {
int idx_in = idx_in_base + idx_w;
DataType in_val = in[idx_in];
DataType val = type::Sub::compute(in_val, cmp);
DataType tmp = type::Add::compute(mean, val);
cmp = type::Sub::compute(type::Sub::compute(tmp, mean), val);
mean = tmp;
#pragma unroll
for (int j = 0; j < NelemPerThread; j++) {
if (idx_w + j < InShape::W) {
float val = type::Cast::compute<float>(in[idx_in_base + idx_w + j]);
cached[num_elems] = val;
sum += val;
num_elems++;
}
}
}
// final reduction on shared memory using warp shuffle.
mean = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
mean, tid, smem_per_warp);
// get the average result.
ReduceTypeMean::template postReduce<1>(&mean, &mean, UnitOutDims::W);
DataType variance;
ReduceTypeMean::template identity<1>(&variance);
ReduceTypeMean::template identity<1>(&cmp);
// get the variance
for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) {
int idx_in = idx_in_base + idx_w;
DataType in_val = in[idx_in];
DataType val = type::Sub::compute(
type::Mul::compute(type::Sub::compute(in_val, mean),
type::Sub::compute(in_val, mean)),
cmp);
DataType tmp = type::Add::compute(variance, val);
cmp = type::Sub::compute(type::Sub::compute(tmp, variance), val);
variance = tmp;

// Reduce sum across physical threads (each thread already accumulated
// NelemPerThread elements locally, so we reduce PhysicalThreadsPerRow threads).
sum = warpsReduce<ReduceTypeSum, UnitOp, PhysicalThreadsPerRow>(
sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset);
float fmean = sum / static_cast<float>(InShape::W);

// --- Pass 2: Compute variance from cached registers (no global read) ---
float var_sum = 0.0f;
#pragma unroll
for (int i = 0; i < MaxElemsPerThread; i++) {
if (i < num_elems) {
float diff = cached[i] - fmean;
var_sum += diff * diff;
}
}
variance = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
variance, tid, smem_per_warp);
ReduceTypeMean::template postReduce<1>(&variance, &variance,
UnitOutDims::W);
variance = type::Rsqrt::compute(
type::Add::compute(variance, type::Cast::compute<DataType>(1e-5f)));
// the output is input / sqrt(mean_square)

var_sum = warpsReduce<ReduceTypeSum, UnitOp, PhysicalThreadsPerRow>(
var_sum, tid % PhysicalThreadsPerRow, smem_per_warp, warp_offset);
float inv_std = rsqrtf(var_sum / static_cast<float>(InShape::W) + 1e-5f);

// --- Pass 3: Normalize and write output (from registers) ---
int wi = 0;
#pragma unroll
for (int idx_w = tid_w; idx_w < InShape::W; idx_w += ThreadsPerRow) {
int idx_in = idx_in_base + idx_w;
int idx_out = idx_out_base + idx_w;
out[idx_out] = type::Mul::compute(
type::Sub::compute(in[idx_in], mean), variance);
#pragma unroll
for (int j = 0; j < NelemPerThread; j++) {
if (idx_w + j < InShape::W) {
float normalized = (cached[wi] - fmean) * inv_std;
if constexpr (HasGammaBeta) {
normalized = normalized *
type::Cast::compute<float>(gamma[idx_w + j]) +
type::Cast::compute<float>(beta[idx_w + j]);
}
out[idx_out_base + idx_w + j] = type::Cast::compute<DataType>(normalized);
wi++;
}
}
}
}
};

// Free function for layernorm without gamma/beta. Currently unused by the op
// layer (which always uses layernorm_affine), but retained for kernel-level API
// completeness and potential future use (e.g., non-affine LayerNorm op).
template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumWarps, int SmemBytes,
typename DataType>
int NelemPerThread = 1, typename DataType>
DEVICE void layernorm(DataType *out, const DataType *in, int uop_idx,
int smem_per_warp) {
constexpr int NelemPerThread = 1;
LayerNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumWarps,
SmemBytes, DataType, NelemPerThread>::run(out, in, uop_idx,
smem_per_warp);
SmemBytes, DataType, NelemPerThread, false>::run(
out, in, nullptr, nullptr, uop_idx, smem_per_warp);
}

// Free function for layernorm with gamma/beta affine transform.
template <typename InDims, typename InShape, typename OutDims,
typename OutShape, typename UnitOutDims, int NumWarps, int SmemBytes,
int NelemPerThread = 1, typename DataType>
DEVICE void layernorm_affine(DataType *out, const DataType *in,
const DataType *gamma, const DataType *beta,
int uop_idx, int smem_per_warp) {
LayerNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumWarps,
SmemBytes, DataType, NelemPerThread, true>::run(
out, in, gamma, beta, uop_idx, smem_per_warp);
}

} // namespace ark
Expand Down
88 changes: 63 additions & 25 deletions ark/include/kernels/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,28 @@ DEVICE bf16 warpReduce(bf16 val) {
}

// Reduce single-precision `val` within multiple warps.
// @param warp_offset Offset into shared storage to avoid aliasing when
// multiple independent row groups share the same shared memory.
template <typename ReduceType, typename UnitOp, int LanesNum, typename DataType>
DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp) {
DEVICE DataType warpsReduce(DataType val, int tid, int smem_per_warp,
int warp_offset = 0) {
val = warpReduce<ReduceType, LanesNum>(val);
if constexpr (LanesNum > Arch::ThreadsPerWarp) {
// Barrier before shared memory write to prevent a back-to-back
// warpsReduce call from overwriting storage before all warps
// finish reading the previous result.
UnitOp::sync_threads();
ReduceSharedStorage<DataType> *shared =
UnitOp::template shared_memory<ReduceSharedStorage<DataType>>(
smem_per_warp);
int laneId = tid & (Arch::ThreadsPerWarp - 1);
int warpId = tid >> math::log2_up<Arch::ThreadsPerWarp>::value;
if (laneId == 0) {
shared->storage[warpId] = val;
shared->storage[warpId + warp_offset] = val;
}
UnitOp::sync_threads();
if (laneId < (LanesNum >> math::log2_up<Arch::ThreadsPerWarp>::value)) {
val = shared->storage[laneId];
val = shared->storage[laneId + warp_offset];
} else {
ReduceType::template identity<1>(&val);
}
Expand Down Expand Up @@ -357,13 +364,11 @@ struct WwiseReduce {
ReduceShapeChecker<InShape, OutShape, UnitOutDims, Axis>;
constexpr int InConsecBytes = sizeof(DataType) * InShape::W;
constexpr int NelemPerThread =
(InConsecBytes % 16 == 0)
? 16 / sizeof(DataType)
: (InConsecBytes % 8 == 0)
? 8 / sizeof(DataType)
: (InConsecBytes % 4 == 0)
? 4 / sizeof(DataType)
: (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType) : 1;
(InConsecBytes % 16 == 0) ? 16 / sizeof(DataType)
: (InConsecBytes % 8 == 0) ? 8 / sizeof(DataType)
: (InConsecBytes % 4 == 0) ? 4 / sizeof(DataType)
: (InConsecBytes % 2 == 0) ? 2 / sizeof(DataType)
: 1;

constexpr int NonReduceDimLength =
UnitOutDims::N * UnitOutDims::C * UnitOutDims::H;
Expand Down Expand Up @@ -411,32 +416,65 @@ struct WwiseReduce {
if constexpr (NelemPerThread > 8) {
#pragma unroll
for (int i = 8; i < NelemPerThread; i += 8) {
ReduceType::template reduce<8>(&reduced[0], &reduced[0], &reduced[i]);
ReduceType::template reduce<8>(&reduced[0], &reduced[0],
&reduced[i]);
}
ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]);
ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]);
ReduceType::template reduce<4>(&reduced[0], &reduced[0],
&reduced[4]);
ReduceType::template reduce<2>(&reduced[0], &reduced[0],
&reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0],
&reduced[1]);
} else if constexpr (NelemPerThread == 8) {
ReduceType::template reduce<4>(&reduced[0], &reduced[0], &reduced[4]);
ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]);
ReduceType::template reduce<4>(&reduced[0], &reduced[0],
&reduced[4]);
ReduceType::template reduce<2>(&reduced[0], &reduced[0],
&reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0],
&reduced[1]);
} else if constexpr (NelemPerThread == 4) {
ReduceType::template reduce<2>(&reduced[0], &reduced[0], &reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]);
ReduceType::template reduce<2>(&reduced[0], &reduced[0],
&reduced[2]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0],
&reduced[1]);
} else if constexpr (NelemPerThread == 2) {
ReduceType::template reduce<1>(&reduced[0], &reduced[0], &reduced[1]);
ReduceType::template reduce<1>(&reduced[0], &reduced[0],
&reduced[1]);
}

if constexpr (InShape::W % ThreadsPerRow != 0) {
UnitOp::sync_threads();
}

// final reduction on shared memory using warp shuffle.
reduced[0] = warpsReduce<ReduceType, UnitOp, UnitOp::NumThreads>(
reduced[0], tid, smem_per_warp);
// final reduction using warp shuffle.
// PhysicalThreadsPerRow = actual number of HW threads per row.
constexpr int PhysicalThreadsPerRow =
UnitOp::NumThreads / NonReduceDimLength;
static_assert(PhysicalThreadsPerRow > 0,
"Not enough threads for the tile dimensions. "
"Increase NumWarps or decrease Tile H dimension.");
static_assert(PhysicalThreadsPerRow <= Arch::ThreadsPerWarp ||
PhysicalThreadsPerRow % Arch::ThreadsPerWarp == 0,
"PhysicalThreadsPerRow must be <= warp size or a "
"multiple of warp size");
if constexpr (PhysicalThreadsPerRow <= Arch::ThreadsPerWarp) {
// All threads for one row are within a single warp.
reduced[0] =
warpReduce<ReduceType, PhysicalThreadsPerRow>(reduced[0]);
} else {
// Threads for one row span multiple warps — need shared memory.
// Each row needs its own section of shared storage to avoid
// aliasing when multiple rows reduce in parallel.
constexpr int WarpsPerRow =
PhysicalThreadsPerRow / Arch::ThreadsPerWarp;
int row_in_tile = tid / PhysicalThreadsPerRow;
reduced[0] = warpsReduce<ReduceType, UnitOp, PhysicalThreadsPerRow>(
reduced[0], tid % PhysicalThreadsPerRow, smem_per_warp,
row_in_tile * WarpsPerRow);
}

// write the result to output.
if (tid % ThreadsPerRow == 0) {
// write the result to output — first thread of each row group.
if (tid % PhysicalThreadsPerRow == 0) {
ReduceType::template postReduce<1>(&out[idx_out], &reduced[0],
InShape::W);
}
Expand Down
Loading
Loading