High-performance CUDA kernels for core ML primitives, written from scratch. Each kernel is benchmarked against cuBLAS/cuDNN baselines and profiled with Nsight Compute to understand the hardware bottlenecks at each optimization stage.
Hardware: NVIDIA GTX 1650 Super (Turing, SM 7.5, 4 GB GDDR6, 192 GB/s, 4.4 TFLOPS FP32)
CUDA: 12.x, C++ Standard: 17, Test framework: GoogleTest
| Primitive | Variants | Status |
|---|---|---|
| GEMM | Naive, Tiled (shared memory) | Done |
| Softmax | Naive 1D/2D, Online 1D/2D (single-pass), Fused (online + shared memory reduction) | Done |
| Attention | Fused MHA | In progress |
All benchmarks run at M=N=K (square matrices). Timing uses CUDA events, averaged over 50 iterations after 10 warmup runs. Host-device data transfers are excluded so that only kernel execution time is measured.
| Size | Naive | Tiled | cuBLAS | Tiled % of cuBLAS | Tiled % peak |
|---|---|---|---|---|---|
| 512 | 4.771 ms | 0.567 ms | 0.134 ms | 23.6% | 10.8% |
| 1024 | 35.204 ms | 4.173 ms | 0.780 ms | 18.7% | 11.7% |
| 2048 | 239.842 ms | 28.464 ms | 3.753 ms | 13.2% | 13.7% |
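Tiling delivers a consistent ~8.4x speedup over naive at every size tested, but its share of cuBLAS performance drops from 23.6% at 512 to 13.2% at 2048.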
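The "Tiled % peak" column follows from counting 2N³ FLOPs per square GEMM (one multiply and one add per inner-product term). The C++ sketch below reproduces that arithmetic, assuming the 4.4 TFLOPS FP32 peak listed above. `gemm_gflops` and `percent_of_peak` are illustrative names, not part of the repository.

```cpp
#include <cassert>

// Effective throughput of a square GEMM (M = N = K = n). Each of the n*n outputs
// needs n multiply-adds, counted as 2 FLOPs each: 2 * n^3 FLOPs in total.
double gemm_gflops(long long n, double time_ms) {
    assert(n > 0 && time_ms > 0.0);
    const double dn = static_cast<double>(n);
    const double flops = 2.0 * dn * dn * dn;
    return flops / (time_ms * 1e-3) / 1e9;  // FLOP/s -> GFLOP/s
}

// Percentage of FP32 peak. 4400 GFLOP/s is the 4.4 TFLOPS figure quoted above
// for the GTX 1650 Super (an assumption if you run on other hardware).
double percent_of_peak(double gflops, double peak_gflops = 4400.0) {
    assert(peak_gflops > 0.0);
    return 100.0 * gflops / peak_gflops;
}
```

For example, `percent_of_peak(gemm_gflops(1024, 4.173))` evaluates to about 11.7, the value reported for the tiled kernel at 1024.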
Nsight Compute analysis (1024×1024, naive kernel):
- DRAM throughput: ~12%; the kernel is severely memory-latency bound
- L1 hit rate: ~4%; almost no data reuse, so nearly every load misses the cache
- Warp stall (long scoreboard): dominant; threads spend most cycles waiting on global memory

After tiling, the L1 hit rate rises to ~60% and warp stalls drop significantly, confirming that shared-memory reuse is the primary lever for this kernel.
| Size | Naive | Online | Fused | Fused vs Naive | Online vs Naive |
|---|---|---|---|---|---|
| 512×512 | 0.409 ms | 0.250 ms | 0.027 ms | 15.1x | 1.6x |
| 1024×1024 | 0.819 ms | 0.506 ms | 0.074 ms | 11.1x | 1.6x |
| 2048×2048 | 1.638 ms | 1.013 ms | 0.248 ms | 6.6x | 1.6x |
The online kernel consistently achieves ~1.6x over naive by eliminating one full pass over the data. The fused kernel additionally spreads each row across 256 threads with parallel shared-memory reductions, adding a further ~4x to ~9x on top of the online kernel and reaching up to 15x over naive at small sizes, where the parallelism fits well within one block.
Before cloning and building the project, make sure the following dependencies are installed:
- CUDA Toolkit 12.x
- CMake ≥ 3.18
- GCC ≥ 9
```bash
mkdir build && cd build
cmake ..
make -j$(nproc)
```

GoogleTest is fetched automatically by CMake.
```bash
# run all correctness tests
cd build && ctest --output-on-failure
# run a specific suite
./test_gemm
./test_softmax
```

All kernels are validated against a CPU reference using a custom tolerance function, `|out - ref| <= atol + rtol * |ref|`, which accounts for non-associative floating-point accumulation (for example, the summation order across K in GEMM).
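As a sketch of the tolerance rule, here is a minimal C++ helper. The name `all_close` and the default `atol`/`rtol` values are assumptions chosen for illustration; the repository's own helper and thresholds may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Element-wise check with a combined tolerance: |out - ref| <= atol + rtol * |ref|.
// The default atol/rtol values are illustrative assumptions, not the repo's values.
bool all_close(const std::vector<float>& out, const std::vector<float>& ref,
               float atol = 1e-5f, float rtol = 1e-4f) {
    assert(out.size() == ref.size());
    for (std::size_t i = 0; i < out.size(); ++i) {
        const float diff = std::fabs(out[i] - ref[i]);
        // Written as !(diff <= bound) so that a NaN difference also fails.
        if (!(diff <= atol + rtol * std::fabs(ref[i]))) {
            return false;
        }
    }
    return true;
}
```

The absolute term keeps near-zero outputs from failing on tiny rounding noise, while the relative term scales with the magnitude of the reference, which matters for GEMM outputs whose size grows with K.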
```bash
cd build
./bench_gemm --show-results
./bench_softmax --show-results
```

The `--show-results` flag prints the benchmark results to the console. Omit it for a silent run (for example, when profiling a kernel).
Profiles are collected with Nsight Compute and saved as .ncu-rep files in profiles/.
The current profiles in profiles/ have been collected using the following command:
```bash
ncu --kernel-name {kernel_name} \
    --set full \
    -o profiles/{kernel_name} \
    ./build/{benchmark_file}
```

To dig deeper into kernel performance and bottlenecks, open a saved profile in the Nsight Compute GUI:

```bash
ncu-ui profiles/{kernel_name}.ncu-rep
```

Throughout the GEMM description, we use the notation $A \in \mathbb{R}^{M \times K}$, $B \in \mathbb{R}^{K \times N}$ and $C = AB \in \mathbb{R}^{M \times N}$.
Naive: one thread computes one output element via a sequential dot product over K. Every thread independently loads a full row of A and a full column of B from global memory with no reuse.
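For reference, the naive decomposition maps onto a plain CPU triple loop. The sketch below assumes row-major storage; the body of the `(row, col)` loop is the work one CUDA thread performs. `gemm_naive` is an illustrative name, not the repository's CPU reference.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU sketch of the naive kernel's work split (row-major A: MxK, B: KxN, C: MxN).
// Each (row, col) iteration corresponds to one thread: a sequential dot product
// over K that reads A and B directly, with no reuse between "threads".
std::vector<float> gemm_naive(const std::vector<float>& A, const std::vector<float>& B,
                              int M, int N, int K) {
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    for (int row = 0; row < M; ++row) {
        for (int col = 0; col < N; ++col) {  // one "thread" per output element
            float acc = 0.0f;
            for (int k = 0; k < K; ++k) {
                acc += A[static_cast<std::size_t>(row) * K + k] *
                       B[static_cast<std::size_t>(k) * N + col];
            }
            C[static_cast<std::size_t>(row) * N + col] = acc;
        }
    }
    return C;
}
```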
Tiled: threads in a 16×16 block collaboratively load 16×16 tiles of A and B into shared memory, then compute partial dot products from those tiles. Each element of A and B is read from global memory once per block instead of once per output element that uses it, reducing global memory traffic by a factor equal to the tile width (16).
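The tiled scheme can be emulated on the CPU to make the data flow explicit. In this sketch (hypothetical `gemm_tiled`, row-major storage, zero-padding for sizes that are not multiples of 16), each `(bi, bj)` iteration plays the role of one 16×16 thread block and `sA`/`sB` stand in for its shared-memory tiles; the comments mark where a CUDA kernel would call `__syncthreads()`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int TILE = 16;  // tile width, matching the 16x16 blocks described above

// CPU emulation of a shared-memory tiled GEMM (row-major A: MxK, B: KxN, C: MxN).
// Each (bi, bj) iteration plays one thread block; sA/sB stand in for __shared__ tiles.
std::vector<float> gemm_tiled(const std::vector<float>& A, const std::vector<float>& B,
                              int M, int N, int K) {
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(K) * N);
    std::vector<float> C(static_cast<std::size_t>(M) * N, 0.0f);
    float sA[TILE][TILE];
    float sB[TILE][TILE];
    for (int bi = 0; bi < M; bi += TILE) {
        for (int bj = 0; bj < N; bj += TILE) {
            float acc[TILE][TILE] = {};  // one accumulator per thread (ty, tx)
            for (int bk = 0; bk < K; bk += TILE) {
                // Cooperative load: thread (ty, tx) fetches one element of each tile,
                // zero-padding anything outside the matrix bounds.
                for (int ty = 0; ty < TILE; ++ty) {
                    for (int tx = 0; tx < TILE; ++tx) {
                        const int a_row = bi + ty, a_col = bk + tx;
                        const int b_row = bk + ty, b_col = bj + tx;
                        sA[ty][tx] = (a_row < M && a_col < K)
                            ? A[static_cast<std::size_t>(a_row) * K + a_col] : 0.0f;
                        sB[ty][tx] = (b_row < K && b_col < N)
                            ? B[static_cast<std::size_t>(b_row) * N + b_col] : 0.0f;
                    }
                }
                // Kernel: __syncthreads() so the whole tile is loaded before use.
                for (int ty = 0; ty < TILE; ++ty)
                    for (int tx = 0; tx < TILE; ++tx)
                        for (int k = 0; k < TILE; ++k)
                            acc[ty][tx] += sA[ty][k] * sB[k][tx];  // reads touch only the tiles
                // Kernel: __syncthreads() again before the next tile overwrites sA/sB.
            }
            for (int ty = 0; ty < TILE; ++ty)
                for (int tx = 0; tx < TILE; ++tx)
                    if (bi + ty < M && bj + tx < N)
                        C[static_cast<std::size_t>(bi + ty) * N + (bj + tx)] = acc[ty][tx];
        }
    }
    return C;
}
```

Each element of A is copied into a tile once per column of blocks (N/16 times) rather than once per output column (N times), which is where the factor-of-16 reduction in global traffic comes from.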
The roofline chart from Nsight Compute shows the naive kernel sitting far below the memory-bandwidth ceiling. The kernel is memory-latency bound rather than bandwidth bound: threads stall waiting on individual global loads that miss in L1 and L2. Tiling raises the L1 hit rate and moves the kernel closer to the bandwidth roof.
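A back-of-the-envelope roofline estimate makes the same point. It assumes every operand is fetched from DRAM (no cache hits) and uses the 192 GB/s and 4.4 TFLOPS figures from the hardware line above, with $T = 16$ the tile width. The naive kernel loads two floats (8 bytes) per multiply-add; a tiled block loads $2T^2$ floats per tile step and performs $2T^3$ FLOPs on them:

$$
\text{ridge point} = \frac{4.4 \times 10^{12}\ \text{FLOP/s}}{192 \times 10^{9}\ \text{B/s}} \approx 22.9\ \text{FLOP/B}
$$

$$
I_{\text{naive}} = \frac{2\ \text{FLOP}}{2 \times 4\ \text{B}} = 0.25\ \text{FLOP/B}, \qquad
I_{\text{tiled}} = \frac{2T^3\ \text{FLOP}}{2T^2 \times 4\ \text{B}} = \frac{T}{4} = 4\ \text{FLOP/B}
$$

Under this simplified model both kernels sit left of the ridge point, so memory rather than compute limits them, and 16×16 tiling caps the kernel at roughly $4 \times 192 \approx 768$ GFLOP/s, about 17% of FP32 peak. Real kernels deviate from these numbers because L1 and L2 serve part of the traffic.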
The challenge in a parallel softmax is that the standard formula, $\operatorname{softmax}(x)_i = e^{x_i - m} / \sum_j e^{x_j - m}$ with $m = \max_j x_j$, requires two passes over the data (one for the row maximum, one for the sum of exponentials), and the normalization constant is a reduction across all elements of the row.
Naive: one thread per row, two sequential loops. It is correct but exploits no parallelism within a row. The row maximum is subtracted from every element before exponentiation to avoid overflow in the exponential.
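A minimal CPU sketch of the naive scheme, assuming a row-major `rows × cols` input: the loops are what a single thread runs for its row (the two reduction passes, then the write of the normalized outputs). `softmax_naive` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Row-wise softmax of a rows x cols row-major matrix.
// In the naive kernel, one thread runs these loops for its row.
std::vector<float> softmax_naive(const std::vector<float>& x, int rows, int cols) {
    assert(x.size() == static_cast<std::size_t>(rows) * cols);
    std::vector<float> y(x.size());
    for (int r = 0; r < rows; ++r) {
        const float* in = x.data() + static_cast<std::size_t>(r) * cols;
        float* out = y.data() + static_cast<std::size_t>(r) * cols;
        // Pass 1: row maximum, subtracted below so exp() cannot overflow.
        float m = -std::numeric_limits<float>::infinity();
        for (int j = 0; j < cols; ++j) m = std::max(m, in[j]);
        // Pass 2: normalization constant, the sum of shifted exponentials.
        float d = 0.0f;
        for (int j = 0; j < cols; ++j) d += std::exp(in[j] - m);
        // Write the normalized outputs.
        for (int j = 0; j < cols; ++j) out[j] = std::exp(in[j] - m) / d;
    }
    return y;
}
```

Subtracting the row maximum is what lets an input such as `{1000, 1000}` return `{0.5, 0.5}` instead of overflowing `exp(1000)` to infinity and producing NaN.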
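Online (single-pass, using the exponential property): the running sum can be rescaled whenever a new maximum is encountered. After reading element $x_j$:

$$
m_j = \max(m_{j-1}, x_j), \qquad d_j = d_{j-1}\, e^{m_{j-1} - m_j} + e^{x_j - m_j},
$$

with $m_0 = -\infty$ and $d_0 = 0$, so that $d_n = \sum_{i=1}^{n} e^{x_i - m_n}$ at the end of a row of length $n$.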
This computes the maximum and the normalizer in a single pass instead of two, removing one full read of each row from global memory. The same trick is used in FlashAttention to avoid materializing the full attention score matrix.
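The update rule translates directly into code. This C++ sketch (hypothetical `OnlineState`, `online_update` and `softmax_online`; finite inputs assumed) computes the maximum and the normalizer in one read pass over the row, then writes the outputs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Running state of the online softmax for one row (finite inputs assumed).
struct OnlineState {
    float m = -std::numeric_limits<float>::infinity();  // running maximum m_j
    float d = 0.0f;                                      // running sum d_j of exp(x_i - m_j)
};

// One step of the recurrence: when x raises the maximum, the old sum is rescaled
// by exp(m_old - m_new) so every term is expressed relative to the new maximum.
inline void online_update(OnlineState& s, float x) {
    const float m_new = std::fmax(s.m, x);
    s.d = s.d * std::exp(s.m - m_new) + std::exp(x - m_new);
    s.m = m_new;
}

// Row-wise softmax: one read pass for (m, d), then one pass to write the outputs.
std::vector<float> softmax_online(const std::vector<float>& x, int rows, int cols) {
    assert(x.size() == static_cast<std::size_t>(rows) * cols);
    std::vector<float> y(x.size());
    for (int r = 0; r < rows; ++r) {
        const float* in = x.data() + static_cast<std::size_t>(r) * cols;
        float* out = y.data() + static_cast<std::size_t>(r) * cols;
        OnlineState s;
        for (int j = 0; j < cols; ++j) online_update(s, in[j]);  // max and sum together
        for (int j = 0; j < cols; ++j) out[j] = std::exp(in[j] - s.m) / s.d;
    }
    return y;
}
```

When the maximum does not change, the rescaling factor is `exp(0) = 1`, so the update reduces to an ordinary running sum.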
Fused: online softmax split across 256 threads per row, with two tree reductions in shared memory:
- Each thread computes a local maximum and a local running sum of exponentials over its chunk of the row, using the online update.
- A first tree reduction finds the global maximum of the row across all threads.
- Each thread rescales its local sum to the global maximum, and a second reduction adds the rescaled sums.
- Threads write the normalized outputs in place.
Because the reductions span several warps, the block synchronizes all threads with `__syncthreads()` between reduction steps, so each partial result written to shared memory is visible before other threads read it.
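The four steps above can be emulated sequentially on the CPU. In this sketch (hypothetical `softmax_row_fused`, one row per call), each `for (int t ...)` loop stands for all 256 threads executing in parallel, the two passes over `s_red` are the shared-memory tree reductions, and the comments mark where the kernel needs `__syncthreads()`. The strided chunking (thread `t` handles elements `t`, `t + 256`, `t + 512`, and so on) is an assumption; the actual kernel may split the row differently.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Threads per row, as in the fused kernel described above.
constexpr int THREADS = 256;
static_assert((THREADS & (THREADS - 1)) == 0, "tree reductions assume a power of two");

// Sequential emulation of ONE thread block normalizing ONE row.
// Every "for (int t ...)" loop stands for all THREADS threads running in parallel.
void softmax_row_fused(const float* in, float* out, int cols) {
    assert(cols > 0);
    std::vector<float> local_m(THREADS), local_d(THREADS);  // per-thread registers
    std::vector<float> s_red(THREADS);                      // stand-in for the __shared__ buffer

    // Step 1: each thread runs the online update over its strided chunk.
    for (int t = 0; t < THREADS; ++t) {
        float m = -std::numeric_limits<float>::infinity();
        float d = 0.0f;
        for (int j = t; j < cols; j += THREADS) {
            const float m_new = std::fmax(m, in[j]);
            d = d * std::exp(m - m_new) + std::exp(in[j] - m_new);
            m = m_new;
        }
        local_m[t] = m;
        local_d[t] = d;
        s_red[t] = m;
    }
    // __syncthreads();

    // Step 2: tree reduction for the row maximum (result ends up in s_red[0]).
    for (int stride = THREADS / 2; stride > 0; stride /= 2) {
        for (int t = 0; t < stride; ++t) s_red[t] = std::fmax(s_red[t], s_red[t + stride]);
        // __syncthreads();
    }
    const float row_max = s_red[0];  // every thread reads the reduced value
    // __syncthreads();  (before s_red is reused)

    // Step 3: rescale each local sum to the global maximum, then sum-reduce.
    // Threads whose chunk was empty (local_d == 0) contribute exactly 0.
    for (int t = 0; t < THREADS; ++t)
        s_red[t] = (local_d[t] == 0.0f) ? 0.0f : local_d[t] * std::exp(local_m[t] - row_max);
    // __syncthreads();
    for (int stride = THREADS / 2; stride > 0; stride /= 2) {
        for (int t = 0; t < stride; ++t) s_red[t] += s_red[t + stride];
        // __syncthreads();
    }
    const float row_sum = s_red[0];

    // Step 4: each thread writes its chunk of normalized outputs.
    for (int t = 0; t < THREADS; ++t)
        for (int j = t; j < cols; j += THREADS) out[j] = std::exp(in[j] - row_max) / row_sum;
}
```

Keeping each thread's local maximum in a register (here `local_m`) is what makes the rescaling in step 3 possible after the shared buffer has been overwritten by the max reduction.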
- Pierre SCHWEITZER (schwp)