diff --git a/.gitignore b/.gitignore index 1daaa46d12..f6e9b94480 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,4 @@ uv.lock .cache/ # vim *.swp +results/ diff --git a/CMakeLists.txt b/CMakeLists.txt index a2395d02f6..6eb3d739ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,6 +244,35 @@ if(WIN32) endif() if(MLX_BUILD_CPU) + # ----------------------------- x86 SIMD -------------------------------- + if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64|i[3-9]86") + include(CheckCXXCompilerFlag) + check_cxx_compiler_flag("-mavx2" HAS_AVX2) + check_cxx_compiler_flag("-mfma" HAS_FMA) + check_cxx_compiler_flag("-mf16c" HAS_F16C) + + if(HAS_AVX2 + AND HAS_FMA + AND HAS_F16C) + message( + STATUS "Compiler supports AVX2/FMA/F16C - enabling AVX2 SIMD backend") + target_compile_options(mlx PRIVATE -mavx2 -mfma -mf16c) + target_compile_definitions(mlx PRIVATE MLX_USE_AVX2) + else() + message( + STATUS "Missing required x86 SIMD support - using base SIMD backend") + if(NOT HAS_AVX2) + message(STATUS " Missing: AVX2") + endif() + if(NOT HAS_FMA) + message(STATUS " Missing: FMA") + endif() + if(NOT HAS_F16C) + message(STATUS " Missing: F16C") + endif() + endif() + endif() + find_library(ACCELERATE_LIBRARY Accelerate) if(ACCELERATE_LIBRARY) message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}") diff --git a/benchmarks/python/blas/bench_gemm.py b/benchmarks/python/blas/bench_gemm.py index ee358a95d8..ee643f3be2 100644 --- a/benchmarks/python/blas/bench_gemm.py +++ b/benchmarks/python/blas/bench_gemm.py @@ -10,18 +10,44 @@ import numpy as np import torch -device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) -device_name = device_name.decode("utf-8").strip("\n") +try: + device_name = ( + subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], stderr=subprocess.DEVNULL + ) + .decode("utf-8") + .strip() + ) +except (subprocess.CalledProcessError, FileNotFoundError): + device_name = "unknown" + +if torch.backends.mps.is_available(): + torch_device = "mps" + torch_sync = torch.mps.synchronize +elif torch.cuda.is_available(): + torch_device = "cuda" + torch_sync = torch.cuda.synchronize +else: + torch_device = "cpu" + torch_sync = lambda: None -N_warmup = 8 -N_iter_bench = 80 -N_iter_func = 5 +FULL_WARMUP = 8 +FULL_ITER_BENCH = 80 +FULL_ITER_FUNC = 5 + +QUICK_WARMUP = 2 +QUICK_ITER_BENCH = 10 +QUICK_ITER_FUNC = 5 + +N_warmup = FULL_WARMUP +N_iter_bench = FULL_ITER_BENCH +N_iter_func = FULL_ITER_FUNC def bench(f, a, b): for i in range(N_warmup): f(a, b) - torch.mps.synchronize() + torch_sync() s = time.perf_counter_ns() for i in range(N_iter_bench): @@ -72,7 +98,7 @@ def gemm_nn_torch(a, b): for i in range(N_iter_func): y = a @ b ys.append(y) - torch.mps.synchronize() + torch_sync() return ys @@ -82,7 +108,7 @@ def gemm_nt_torch(a, b): for i in range(N_iter_func): y = a @ b.transpose(-1, -2) ys.append(y) - torch.mps.synchronize() + torch_sync() return ys @@ -92,7 +118,7 @@ def gemm_tn_torch(a, b): for i in range(N_iter_func): y = a.transpose(-1, -2) @ b ys.append(y) - torch.mps.synchronize() + torch_sync() return ys @@ -102,11 +128,11 @@ def gemm_tt_torch(a, b): for i in range(N_iter_func): y = a.transpose(-1, -2) @ b.transpose(-1, -2) ys.append(y) - torch.mps.synchronize() + torch_sync() return ys -def bench_shape(B, M, N, K, np_dtype, transpose="nn"): +def bench_shape(B, M, N, K, np_dtype, transpose="nn", max_torch_ops=None): shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M) shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K) @@ -116,10 +142,10 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"): a_mx = mx.array(a_np) b_mx = mx.array(b_np) - a_pt = torch.from_numpy(a_np).to("mps") - b_pt = torch.from_numpy(b_np).to("mps") + a_pt = torch.from_numpy(a_np).to(torch_device) + b_pt = torch.from_numpy(b_np).to(torch_device) - torch.mps.synchronize() + torch_sync() f_mx = { "nn": gemm_nn_mlx, @@ -135,7 +161,11 @@ def bench_shape(B, M, N, K, np_dtype, transpose="nn"): "tt": gemm_tt_torch, }[transpose] - time_torch = bench(f_pt, a_pt, b_pt) + gemm_ops = B * M * N * K + time_torch = None + if max_torch_ops is None or gemm_ops <= max_torch_ops: + time_torch = bench(f_pt, a_pt, b_pt) + time_mlx = bench(f_mx, a_mx, b_mx) t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1) @@ -158,34 +188,100 @@ def get_gflop_count(B, M, N, K): return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3) -if __name__ == "__main__": +def main(): + global N_warmup, N_iter_bench, N_iter_func + parser = argparse.ArgumentParser(description="Run gemm benchmarks") + parser.add_argument( + "--quick", + action="store_true", + help="Run fewer iterations and a reduced shape set.", + ) + parser.add_argument( + "--max-torch-ops", + type=int, + default=None, + help="Skip PyTorch timing for cases where B*M*N*K exceeds this value.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print per-shape timing results.", + ) + parser.add_argument( + "--single-threaded", + action="store_true", + help="Set OMP_NUM_THREADS=1 and OPENBLAS_NUM_THREADS=1 for single-threaded PyTorch/NumPy comparison.", + ) + args = parser.parse_args() + + if args.single_threaded: + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["OPENBLAS_NUM_THREADS"] = "1" + + if args.quick: + N_warmup = QUICK_WARMUP + N_iter_bench = QUICK_ITER_BENCH + N_iter_func = QUICK_ITER_FUNC + else: + N_warmup = FULL_WARMUP + N_iter_bench = FULL_ITER_BENCH + N_iter_func = FULL_ITER_FUNC dtypes = ("float32", "float16", "complex64") transposes = ("nn", "nt", "tn") - shapes = ( - (16, 234, 768, 3072), - (1, 64, 64, 25344), - (16, 1024, 1024, 1024), - (1, 1024, 1024, 2048), - (4, 1024, 1024, 4096), - (4, 1024, 4096, 1024), - (1, 4096, 4096, 4096), - ) + if args.quick: + shapes = ( + (16, 234, 768, 3072), + (1, 1024, 1024, 2048), + ) + else: + shapes = ( + (16, 234, 768, 3072), + (1, 64, 64, 25344), + (16, 1024, 1024, 1024), + (1, 1024, 1024, 2048), + (4, 1024, 1024, 4096), + (4, 1024, 4096, 1024), + (1, 4096, 4096, 4096), + ) + + if args.verbose: + print( + f"{'B':>3}, {'M':>4}, {'N':>4}, {'K':>4}, {'dtype':<9}, {'t':<2}, torch_gf, mlx_gf, diff" + ) + print("-" * 66) for dtype in dtypes: for transpose in transposes: for B, M, N, K in shapes: np_dtype = getattr(np, dtype) - time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose) + time_mlx, time_torch = bench_shape( + B, + M, + N, + K, + np_dtype, + transpose, + args.max_torch_ops, + ) gflop_count = get_gflop_count(B, M, N, K) gflops_mx = gflop_count / (time_mlx) - gflops_pt = gflop_count / (time_torch) - diff = gflops_mx / gflops_pt - 1.0 + if args.verbose: + if time_torch is None: + print( + f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, skipped, {gflops_mx:05.3f}, n/a" + ) + else: + gflops_pt = gflop_count / (time_torch) + diff = gflops_mx / gflops_pt - 1.0 + print( + f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" + ) + if gflops_pt >= 2.0 * gflops_mx: + print("ATTENTION ^^^^^^^") - print( - f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%" - ) - if gflops_pt >= 2.0 * gflops_mx: - print("ATTENTION ^^^^^^^") + +if __name__ == "__main__": + main() diff --git a/benchmarks/python/blas/bench_gemv.py b/benchmarks/python/blas/bench_gemv.py index 3cfc5eba41..e0c781562f 100644 --- a/benchmarks/python/blas/bench_gemv.py +++ b/benchmarks/python/blas/bench_gemv.py @@ -1,5 +1,6 @@ # Copyright © 2023 Apple Inc. +import argparse import os import subprocess import time @@ -14,29 +15,57 @@ if not os.path.isdir(results_dir): os.mkdir(results_dir) -device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) -device_name = device_name.decode("utf-8").strip("\n") - -N_warmup = 5 -N_iter_bench = 50 -N_iter_func = 20 - -out_vec_sizes = [128, 512, 2048, 4096] -in_vec_sizes = [128, 512, 2048, 4096] - -benchmark_vector_lens = [] -benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2] -benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2] -benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2] -benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000] - -benchmark_vector_lens.sort() +try: + device_name = ( + subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], stderr=subprocess.DEVNULL + ) + .decode("utf-8") + .strip() + ) +except (subprocess.CalledProcessError, FileNotFoundError): + device_name = "unknown" + +if torch.backends.mps.is_available(): + torch_device = "mps" + torch_sync = torch.mps.synchronize +elif torch.cuda.is_available(): + torch_device = "cuda" + torch_sync = torch.cuda.synchronize +else: + torch_device = "cpu" + torch_sync = lambda: None + +FULL_WARMUP = 5 +FULL_ITER_BENCH = 50 +FULL_ITER_FUNC = 20 + +QUICK_WARMUP = 2 +QUICK_ITER_BENCH = 10 +QUICK_ITER_FUNC = 5 + +FULL_OUT_VEC_SIZES = [128, 512, 2048, 4096] +FULL_IN_VEC_SIZES = [128, 512, 2048, 4096] +FULL_BENCHMARK_VECTOR_LENS = sorted( + [(i + 1) * 4096 for i in range(8)][::2] + + [(i + 1) * 4095 for i in range(8)][::2] + + [(i + 1) * 4097 for i in range(8)][::2] + + [64, 128, 512, 1024, 2048, 11008, 32000] +) + +QUICK_OUT_VEC_SIZES = [512, 2048] +QUICK_IN_VEC_SIZES = [512, 2048] +QUICK_BENCHMARK_VECTOR_LENS = sorted([128, 1024, 4096, 11008]) + +N_warmup = FULL_WARMUP +N_iter_bench = FULL_ITER_BENCH +N_iter_func = FULL_ITER_FUNC def bench(f, m, v): for i in range(N_warmup): f(m, v) - torch.mps.synchronize() + torch_sync() s = time.perf_counter_ns() for i in range(N_iter_bench): @@ -69,7 +98,7 @@ def gemv_torch(m, v): for i in range(N_iter_func): y = m @ v ys.append(y) - torch.mps.synchronize() + torch_sync() return ys @@ -79,11 +108,13 @@ def gemv_t_torch(m, v): for i in range(N_iter_func): y = v @ m ys.append(y) - torch.mps.synchronize() + torch_sync() return ys -def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False): +def bench_lens( + in_vec_len, out_vec_len, np_dtype, transpose=False, max_torch_elements=None +): shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len) shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1) @@ -91,16 +122,19 @@ def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False): vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype) mat_mlx = mx.array(mat_npy) vec_mlx = mx.array(vec_npy) - mat_trc = torch.from_numpy(mat_npy).to("mps") - vec_trc = torch.from_numpy(vec_npy).to("mps") - - torch.mps.synchronize() - - time_torch = ( - bench(gemv_t_torch, mat_trc, vec_trc) - if transpose - else bench(gemv_torch, mat_trc, vec_trc) - ) + mat_trc = torch.from_numpy(mat_npy).to(torch_device) + vec_trc = torch.from_numpy(vec_npy).to(torch_device) + + torch_sync() + + matrix_elements = in_vec_len * out_vec_len + time_torch = None + if max_torch_elements is None or matrix_elements <= max_torch_elements: + time_torch = ( + bench(gemv_t_torch, mat_trc, vec_trc) + if transpose + else bench(gemv_torch, mat_trc, vec_trc) + ) time_mlx = ( bench(gemv_t_mlx, mat_mlx, vec_mlx) if transpose @@ -128,28 +162,55 @@ def get_gflop_count(in_vec_len, out_vec_len): def get_gbyte_size(in_vec_len, out_vec_len, np_dtype): n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len - item_size = 4 if np_dtype == np.float32 else 2 + item_size = np.dtype(np_dtype).itemsize return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3) -def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose): +def bench_with_in_len( + ax, in_vec_len, out_vector_lens, dtype, transpose, max_torch_elements, verbose=False +): np_dtype = getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] pyt_gb_s = [] pyt_gflops = [] + if verbose: + print(f" {'in':>5}, {'out':>5}, mlx_GB/s, trc_GB/s, diff") + for out_vec_len in out_vector_lens: gflop_count = get_gflop_count(in_vec_len, out_vec_len) gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) - time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) + time_mlx, time_torch = bench_lens( + in_vec_len, + out_vec_len, + np_dtype, + transpose, + max_torch_elements, + ) mlx_gb_s.append(gbyte_size / time_mlx) - pyt_gb_s.append(gbyte_size / time_torch) + pyt_gb_s.append(np.nan if time_torch is None else gbyte_size / time_torch) mlx_gflops.append(gflop_count / time_mlx) - pyt_gflops.append(gflop_count / time_torch) + pyt_gflops.append(np.nan if time_torch is None else gflop_count / time_torch) + + mlx_gb_s_value = gbyte_size / time_mlx + if verbose: + if time_torch is None: + print( + f" in={in_vec_len:5d}, out={out_vec_len:5d}, " + f"mlx={mlx_gb_s_value:7.2f} GB/s, torch=skipped" + ) + else: + pyt_gb_s_value = gbyte_size / time_torch + print( + f" in={in_vec_len:5d}, out={out_vec_len:5d}, " + f"mlx={mlx_gb_s_value:7.2f} GB/s, " + f"torch={pyt_gb_s_value:7.2f} GB/s, " + f"diff={mlx_gb_s_value/pyt_gb_s_value - 1:+.1%}" + ) if transpose: title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}" @@ -163,24 +224,51 @@ def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose): ax.legend() -def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose): +def bench_with_out_len( + ax, out_vec_len, in_vector_lens, dtype, transpose, max_torch_elements, verbose=False +): np_dtype = getattr(np, dtype) mlx_gb_s = [] mlx_gflops = [] pyt_gb_s = [] pyt_gflops = [] + if verbose: + print(f" {'in':>5}, {'out':>5}, mlx_GB/s, trc_GB/s, diff") + for in_vec_len in in_vector_lens: gflop_count = get_gflop_count(in_vec_len, out_vec_len) gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype) - time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose) + time_mlx, time_torch = bench_lens( + in_vec_len, + out_vec_len, + np_dtype, + transpose, + max_torch_elements, + ) mlx_gb_s.append(gbyte_size / time_mlx) - pyt_gb_s.append(gbyte_size / time_torch) + pyt_gb_s.append(np.nan if time_torch is None else gbyte_size / time_torch) mlx_gflops.append(gflop_count / time_mlx) - pyt_gflops.append(gflop_count / time_torch) + pyt_gflops.append(np.nan if time_torch is None else gflop_count / time_torch) + + mlx_gb_s_value = gbyte_size / time_mlx + if verbose: + if time_torch is None: + print( + f" in={in_vec_len:5d}, out={out_vec_len:5d}, " + f"mlx={mlx_gb_s_value:7.2f} GB/s, torch=skipped" + ) + else: + pyt_gb_s_value = gbyte_size / time_torch + print( + f" in={in_vec_len:5d}, out={out_vec_len:5d}, " + f"mlx={mlx_gb_s_value:7.2f} GB/s, " + f"torch={pyt_gb_s_value:7.2f} GB/s, " + f"diff={mlx_gb_s_value/pyt_gb_s_value - 1:+.1%}" + ) if transpose: title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])" @@ -194,27 +282,95 @@ def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose): ax.legend() -for transpose in (False, True): - for dtype in ("float32", "float16", "complex64"): - fig, axs = plt.subplots( - len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" - ) - - for i, in_vec_len in enumerate(in_vec_sizes): - bench_with_in_len( - axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose +def main(): + parser = argparse.ArgumentParser(description="Run gemv benchmarks") + parser.add_argument( + "--quick", + action="store_true", + help="Run fewer iterations and a reduced vector-length set.", + ) + parser.add_argument( + "--max-torch-elements", + type=int, + default=None, + help="Skip PyTorch timing for cases where in_vec_len*out_vec_len exceeds this value.", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print per-shape timing results.", + ) + parser.add_argument( + "--single-threaded", + action="store_true", + help="Set OMP_NUM_THREADS=1 and OPENBLAS_NUM_THREADS=1 for single-threaded PyTorch/NumPy comparison.", + ) + args = parser.parse_args() + + if args.single_threaded: + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["OPENBLAS_NUM_THREADS"] = "1" + + global N_warmup, N_iter_bench, N_iter_func + if args.quick: + N_warmup = QUICK_WARMUP + N_iter_bench = QUICK_ITER_BENCH + N_iter_func = QUICK_ITER_FUNC + out_vec_sizes = QUICK_OUT_VEC_SIZES + in_vec_sizes = QUICK_IN_VEC_SIZES + benchmark_vector_lens = QUICK_BENCHMARK_VECTOR_LENS + else: + N_warmup = FULL_WARMUP + N_iter_bench = FULL_ITER_BENCH + N_iter_func = FULL_ITER_FUNC + out_vec_sizes = FULL_OUT_VEC_SIZES + in_vec_sizes = FULL_IN_VEC_SIZES + benchmark_vector_lens = FULL_BENCHMARK_VECTOR_LENS + + for transpose in (False, True): + for dtype in ("float32", "float16", "complex64"): + op_name = "gemv_t" if transpose else "gemv" + print(f"\n{'='*60}") + print(f"{op_name} | {dtype} | device: {torch_device}") + print(f"{'='*60}") + + fig, axs = plt.subplots( + len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained" ) - for i, out_vec_len in enumerate(out_vec_sizes): - bench_with_out_len( - axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose + print(f"--- sweep out_vec_len (fixed in_vec_len) ---") + for i, in_vec_len in enumerate(in_vec_sizes): + bench_with_in_len( + axs[i][0], + in_vec_len, + benchmark_vector_lens, + dtype, + transpose, + args.max_torch_elements, + args.verbose, + ) + + print(f"--- sweep in_vec_len (fixed out_vec_len) ---") + for i, out_vec_len in enumerate(out_vec_sizes): + bench_with_out_len( + axs[i][1], + out_vec_len, + benchmark_vector_lens, + dtype, + transpose, + args.max_torch_elements, + args.verbose, + ) + + fig.suptitle(f"{device_name}: {dtype} {op_name}") + fig.savefig( + os.path.join( + results_dir, + f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf", + ) ) + plt.close(fig) - op_name = "gemv_t" if transpose else "gemv" - fig.suptitle(f"{device_name}: {dtype} {op_name}") - fig.savefig( - os.path.join( - results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf" - ) - ) - plt.close(fig) + +if __name__ == "__main__": + main() diff --git a/mlx/backend/cpu/gemms/aligned_buffer.h b/mlx/backend/cpu/gemms/aligned_buffer.h new file mode 100644 index 0000000000..33966ae2ed --- /dev/null +++ b/mlx/backend/cpu/gemms/aligned_buffer.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include +#include + +namespace mlx::core { + +// 32-byte aligned buffer with grow-only reallocation (for thread_local reuse). +template +class aligned_unique_ptr { + private: + T* ptr_; + size_t size_; + + public: + aligned_unique_ptr() : ptr_(nullptr), size_(0) {} + + explicit aligned_unique_ptr(size_t size) : size_(size) { + ptr_ = static_cast(aligned_alloc(32, size * sizeof(T))); + if (!ptr_) + throw std::bad_alloc(); + } + + ~aligned_unique_ptr() { + if (ptr_) + free(ptr_); + } + + aligned_unique_ptr(aligned_unique_ptr&& other) noexcept + : ptr_(other.ptr_), size_(other.size_) { + other.ptr_ = nullptr; + other.size_ = 0; + } + + aligned_unique_ptr& operator=(aligned_unique_ptr&& other) noexcept { + if (this != &other) { + if (ptr_) + free(ptr_); + ptr_ = other.ptr_; + size_ = other.size_; + other.ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + aligned_unique_ptr(const aligned_unique_ptr&) = delete; + aligned_unique_ptr& operator=(const aligned_unique_ptr&) = delete; + + T* get() const { + return ptr_; + } + + void reset(size_t new_size) { + if (new_size > size_) { + if (ptr_) + free(ptr_); + ptr_ = static_cast(aligned_alloc(32, new_size * sizeof(T))); + if (!ptr_) + throw std::bad_alloc(); + size_ = new_size; + } + } +}; + +} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/avx_gemm_simd.h b/mlx/backend/cpu/gemms/avx_gemm_simd.h new file mode 100644 index 0000000000..d63e3ff908 --- /dev/null +++ b/mlx/backend/cpu/gemms/avx_gemm_simd.h @@ -0,0 +1,433 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include +#include +#include +#include + +#include "mlx/backend/cpu/simd/base_simd.h" + +// GEMM-private AVX2 SIMD helpers for fp16/bf16 matmul +namespace mlx::core::detail { + +// Forward declarations +template +struct Simd; +template +inline Simd load(const T* ptr); +template +inline void store(T* ptr, Simd x); +template +inline Simd broadcast(const T* ptr); +template +inline Simd fma(Simd a, Simd b, Simd c); + +// Simd — wraps __m256 for AVX operations. +using float8 = Simd; + +template <> +struct Simd { + static constexpr int size = 8; + __m256 value; + + Simd() : value(_mm256_setzero_ps()) {} + Simd(float v) : value(_mm256_set1_ps(v)) {} + explicit Simd(__m256 v) : value(v) {} + Simd(const Simd& other) = default; + Simd& operator=(const Simd& other) = default; + operator __m256() const { + return value; + } +}; + +// --- Load/Store (float) --- +template <> +inline float8 load(const float* x) { + return float8(_mm256_loadu_ps(x)); +} +template <> +inline void store(float* dst, float8 x) { + _mm256_storeu_ps(dst, x.value); +} +template <> +inline float8 broadcast(const float* x) { + return float8(_mm256_broadcast_ss(x)); +} + +// --- Arithmetic --- +inline float8 operator+(float8 a, float8 b) { + return float8(_mm256_add_ps(a, b)); +} +inline float8 operator-(float8 a, float8 b) { + return float8(_mm256_sub_ps(a, b)); +} +inline float8 operator*(float8 a, float8 b) { + return float8(_mm256_mul_ps(a, b)); +} +inline float8 operator/(float8 a, float8 b) { + return float8(_mm256_div_ps(a, b)); +} + +// --- FMA --- +template <> +inline float8 fma(float8 a, float8 b, float8 c) { +#ifdef __AVX2__ + return float8(_mm256_fmadd_ps(a, b, c)); +#else + return float8(_mm256_add_ps(_mm256_mul_ps(a, b), c)); +#endif +} + +// --- Horizontal Sum --- +inline float sum(float8 x) { + __m256 val = x.value; + __m128 vlow = _mm256_castps256_ps128(val); + __m128 vhigh = _mm256_extractf128_ps(val, 1); // high 128 + vlow = _mm_add_ps(vlow, vhigh); // add the low 128 + __m128 shuf = _mm_movehdup_ps(vlow); // broadcast elements 3,1 to 2,0 + __m128 sums = _mm_add_ps(vlow, shuf); + shuf = _mm_movehl_ps(shuf, sums); // high half -> low half + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); +} + +// 8x8 block transpose with fp16/bf16 → fp32 conversion. +// Loads 8 rows of 8 half-precision values, converts and transposes to fp32. +template +inline void +transpose_8x8_block(const T* src, float* dst, int src_stride, int dst_stride) { + static_assert( + std::is_same_v || std::is_same_v, + "transpose_8x8_block requires float16_t or bfloat16_t input"); + + if constexpr (std::is_same_v) { +#ifdef __F16C__ + // Load 8 rows of 8 float16 values, convert to fp32 + __m128i row0 = _mm_loadu_si128(reinterpret_cast(src)); + __m128i row1 = + _mm_loadu_si128(reinterpret_cast(src + src_stride)); + __m128i row2 = + _mm_loadu_si128(reinterpret_cast(src + 2 * src_stride)); + __m128i row3 = + _mm_loadu_si128(reinterpret_cast(src + 3 * src_stride)); + __m128i row4 = + _mm_loadu_si128(reinterpret_cast(src + 4 * src_stride)); + __m128i row5 = + _mm_loadu_si128(reinterpret_cast(src + 5 * src_stride)); + __m128i row6 = + _mm_loadu_si128(reinterpret_cast(src + 6 * src_stride)); + __m128i row7 = + _mm_loadu_si128(reinterpret_cast(src + 7 * src_stride)); + + // Convert to fp32 (vcvtph2ps: 1/cycle throughput, 3 cycle latency) + __m256 frow0 = _mm256_cvtph_ps(row0); + __m256 frow1 = _mm256_cvtph_ps(row1); + __m256 frow2 = _mm256_cvtph_ps(row2); + __m256 frow3 = _mm256_cvtph_ps(row3); + __m256 frow4 = _mm256_cvtph_ps(row4); + __m256 frow5 = _mm256_cvtph_ps(row5); + __m256 frow6 = _mm256_cvtph_ps(row6); + __m256 frow7 = _mm256_cvtph_ps(row7); + + // Transpose via unpack / shuffle / permute + __m256 t0 = _mm256_unpacklo_ps(frow0, frow1); + __m256 t1 = _mm256_unpackhi_ps(frow0, frow1); + __m256 t2 = _mm256_unpacklo_ps(frow2, frow3); + __m256 t3 = _mm256_unpackhi_ps(frow2, frow3); + __m256 t4 = _mm256_unpacklo_ps(frow4, frow5); + __m256 t5 = _mm256_unpackhi_ps(frow4, frow5); + __m256 t6 = _mm256_unpacklo_ps(frow6, frow7); + __m256 t7 = _mm256_unpackhi_ps(frow6, frow7); + + __m256 tt0 = _mm256_shuffle_ps(t0, t2, 0x44); + __m256 tt1 = _mm256_shuffle_ps(t0, t2, 0xEE); + __m256 tt2 = _mm256_shuffle_ps(t1, t3, 0x44); + __m256 tt3 = _mm256_shuffle_ps(t1, t3, 0xEE); + __m256 tt4 = _mm256_shuffle_ps(t4, t6, 0x44); + __m256 tt5 = _mm256_shuffle_ps(t4, t6, 0xEE); + __m256 tt6 = _mm256_shuffle_ps(t5, t7, 0x44); + __m256 tt7 = _mm256_shuffle_ps(t5, t7, 0xEE); + + __m256 r0 = _mm256_permute2f128_ps(tt0, tt4, 0x20); + __m256 r1 = _mm256_permute2f128_ps(tt1, tt5, 0x20); + __m256 r2 = _mm256_permute2f128_ps(tt2, tt6, 0x20); + __m256 r3 = _mm256_permute2f128_ps(tt3, tt7, 0x20); + __m256 r4 = _mm256_permute2f128_ps(tt0, tt4, 0x31); + __m256 r5 = _mm256_permute2f128_ps(tt1, tt5, 0x31); + __m256 r6 = _mm256_permute2f128_ps(tt2, tt6, 0x31); + __m256 r7 = _mm256_permute2f128_ps(tt3, tt7, 0x31); + + _mm256_storeu_ps(dst + 0 * dst_stride, r0); + _mm256_storeu_ps(dst + 1 * dst_stride, r1); + _mm256_storeu_ps(dst + 2 * dst_stride, r2); + _mm256_storeu_ps(dst + 3 * dst_stride, r3); + _mm256_storeu_ps(dst + 4 * dst_stride, r4); + _mm256_storeu_ps(dst + 5 * dst_stride, r5); + _mm256_storeu_ps(dst + 6 * dst_stride, r6); + _mm256_storeu_ps(dst + 7 * dst_stride, r7); +#else + // Fallback without F16C + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j++) { + dst[j * dst_stride + i] = static_cast(src[i * src_stride + j]); + } + } +#endif + } else { // bfloat16_t +#ifdef __AVX2__ + // bf16 → fp32: zero-extend to 32-bit, shift left 16 + __m256 rows[8]; + for (int i = 0; i < 8; i++) { + __m128i bf16_vals_u16 = _mm_loadu_si128( + reinterpret_cast(src + i * src_stride)); + __m256i bf16_vals_u32 = _mm256_cvtepu16_epi32(bf16_vals_u16); + __m256i fp32_bits = _mm256_slli_epi32(bf16_vals_u32, 16); + rows[i] = _mm256_castsi256_ps(fp32_bits); + } + + // Transpose the 8 rows using AVX shuffles + __m256 t0 = _mm256_unpacklo_ps(rows[0], rows[1]); + __m256 t1 = _mm256_unpackhi_ps(rows[0], rows[1]); + __m256 t2 = _mm256_unpacklo_ps(rows[2], rows[3]); + __m256 t3 = _mm256_unpackhi_ps(rows[2], rows[3]); + __m256 t4 = _mm256_unpacklo_ps(rows[4], rows[5]); + __m256 t5 = _mm256_unpackhi_ps(rows[4], rows[5]); + __m256 t6 = _mm256_unpacklo_ps(rows[6], rows[7]); + __m256 t7 = _mm256_unpackhi_ps(rows[6], rows[7]); + + __m256 tt0 = _mm256_shuffle_ps(t0, t2, 0x44); + __m256 tt1 = _mm256_shuffle_ps(t0, t2, 0xEE); + __m256 tt2 = _mm256_shuffle_ps(t1, t3, 0x44); + __m256 tt3 = _mm256_shuffle_ps(t1, t3, 0xEE); + __m256 tt4 = _mm256_shuffle_ps(t4, t6, 0x44); + __m256 tt5 = _mm256_shuffle_ps(t4, t6, 0xEE); + __m256 tt6 = _mm256_shuffle_ps(t5, t7, 0x44); + __m256 tt7 = _mm256_shuffle_ps(t5, t7, 0xEE); + + __m256 r0 = _mm256_permute2f128_ps(tt0, tt4, 0x20); + __m256 r1 = _mm256_permute2f128_ps(tt1, tt5, 0x20); + __m256 r2 = _mm256_permute2f128_ps(tt2, tt6, 0x20); + __m256 r3 = _mm256_permute2f128_ps(tt3, tt7, 0x20); + __m256 r4 = _mm256_permute2f128_ps(tt0, tt4, 0x31); + __m256 r5 = _mm256_permute2f128_ps(tt1, tt5, 0x31); + __m256 r6 = _mm256_permute2f128_ps(tt2, tt6, 0x31); + __m256 r7 = _mm256_permute2f128_ps(tt3, tt7, 0x31); + + _mm256_storeu_ps(dst + 0 * dst_stride, r0); + _mm256_storeu_ps(dst + 1 * dst_stride, r1); + _mm256_storeu_ps(dst + 2 * dst_stride, r2); + _mm256_storeu_ps(dst + 3 * dst_stride, r3); + _mm256_storeu_ps(dst + 4 * dst_stride, r4); + _mm256_storeu_ps(dst + 5 * dst_stride, r5); + _mm256_storeu_ps(dst + 6 * dst_stride, r6); + _mm256_storeu_ps(dst + 7 * dst_stride, r7); +#else + // Scalar fallback + for (int i = 0; i < 8; i++) { + for (int j = 0; j < 8; j++) { + dst[j * dst_stride + i] = static_cast(src[i * src_stride + j]); + } + } +#endif + } +} + +// ========================================================================== +// Conversion and Combined Operations (T -> float -> T) +// T = float16_t or bfloat16_t +// ========================================================================== + +// Load 8 half-precision values, convert to float8. +template +inline float8 load_convert_to_float(const T* src) { + static_assert( + std::is_same_v || std::is_same_v, + "load_convert_to_float requires float16_t or bfloat16_t input for this specialization."); + static_assert(sizeof(T) == 2, "Input type T must be 2 bytes."); + + if constexpr (std::is_same_v) { +#ifdef __F16C__ + __m128i f16_vals = _mm_loadu_si128(reinterpret_cast(src)); + return float8(_mm256_cvtph_ps(f16_vals)); +#else + float buffer[8]; + for (int i = 0; i < 8; ++i) + buffer[i] = static_cast(src[i]); + return load(buffer); +#endif + } else { // bfloat16_t +#ifdef __AVX2__ + // bf16 → fp32: zero-extend to 32-bit then shift left 16 + __m128i bf16_vals_u16 = + _mm_loadu_si128(reinterpret_cast(src)); + __m256i bf16_vals_u32 = _mm256_cvtepu16_epi32(bf16_vals_u16); + __m256i fp32_bits = _mm256_slli_epi32(bf16_vals_u32, 16); + return float8(_mm256_castsi256_ps(fp32_bits)); +#else + // Scalar fallback + float buffer[8]; + for (int i = 0; i < 8; ++i) { + uint32_t val_int = + static_cast(reinterpret_cast(src)[i]) + << 16; + std::memcpy(&buffer[i], &val_int, sizeof(float)); + } + return load(buffer); +#endif + } +} + +// fp32 → bf16 with round-to-nearest-even. +#ifdef __AVX2__ +inline __m128i convert_float_to_bfloat16_rne_avx2(__m256 src) { + __m256i val_int = _mm256_castps_si256(src); + __m256i bias = _mm256_set1_epi32(0x7FFF); + __m256i rounded_val = _mm256_add_epi32(val_int, bias); + __m256i bf16_bits_32 = _mm256_srli_epi32(rounded_val, 16); + __m128i bf16_bits_low = _mm256_castsi256_si128(bf16_bits_32); + __m128i bf16_bits_high = _mm256_extracti128_si256(bf16_bits_32, 1); + // Use signed pack to preserve negative values + return _mm_packs_epi32(bf16_bits_low, bf16_bits_high); +} +#endif + +// Store float8, converting back to 8 half-precision values. +template +inline void store_convert_from_float(T* dst, float8 src) { + static_assert( + std::is_same_v || std::is_same_v, + "store_convert_from_float requires float16_t or bfloat16_t output for this specialization."); + static_assert(sizeof(T) == 2, "Output type T must be 2 bytes."); + + if constexpr (std::is_same_v) { +#ifdef __F16C__ + __m128i f16_result = _mm256_cvtps_ph( + src.value, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), f16_result); +#else + float buffer[8]; + store(buffer, src); + for (int i = 0; i < 8; ++i) + dst[i] = static_cast(buffer[i]); +#endif + } else { // bfloat16_t +#ifdef __AVX2__ + __m128i bf16_result = convert_float_to_bfloat16_rne_avx2(src.value); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), bf16_result); +#else + // Scalar fallback with RNE + float buffer[8]; + store(buffer, src); + alignas(16) uint16_t bf16_bits_arr[8]; + for (int i = 0; i < 8; ++i) { + uint32_t val_int; + std::memcpy(&val_int, &buffer[i], sizeof(float)); + + // Handle NaN + if ((val_int & 0x7F800000) == 0x7F800000 && (val_int & 0x007FFFFF) != 0) { + bf16_bits_arr[i] = + 0x7FC0 | static_cast((val_int >> 16) & 0x003F); + } else { + uint32_t rounding_bias = ((val_int >> 16) & 1) + 0x7FFF; + val_int += rounding_bias; + bf16_bits_arr[i] = static_cast(val_int >> 16); + } + } + std::memcpy(dst, bf16_bits_arr, 8 * sizeof(uint16_t)); +#endif + } +} + +// 6×16 AVX2 microkernel: C[6][16] += A[6][kc] * B[kc][16] +// Uses 12 YMM accumulators + 2 B loads + 1 A broadcast = 15 registers. +template +inline void micro_kernel_6x16( + const float* __restrict A_panel, + const float* __restrict B_panel, + float* __restrict C_block, + int ldc, + int kc, + int a_stride, + int b_stride) { + static_assert(MR == 6, "This kernel requires MR=6"); + static_assert(NR == 16, "This kernel requires NR=16"); + + // 12 accumulators + 2 B loads + 1 A broadcast = 15 YMM registers + __m256 c00 = _mm256_loadu_ps(C_block + 0 * ldc); + __m256 c01 = _mm256_loadu_ps(C_block + 0 * ldc + 8); + __m256 c10 = _mm256_loadu_ps(C_block + 1 * ldc); + __m256 c11 = _mm256_loadu_ps(C_block + 1 * ldc + 8); + __m256 c20 = _mm256_loadu_ps(C_block + 2 * ldc); + __m256 c21 = _mm256_loadu_ps(C_block + 2 * ldc + 8); + __m256 c30 = _mm256_loadu_ps(C_block + 3 * ldc); + __m256 c31 = _mm256_loadu_ps(C_block + 3 * ldc + 8); + __m256 c40 = _mm256_loadu_ps(C_block + 4 * ldc); + __m256 c41 = _mm256_loadu_ps(C_block + 4 * ldc + 8); + __m256 c50 = _mm256_loadu_ps(C_block + 5 * ldc); + __m256 c51 = _mm256_loadu_ps(C_block + 5 * ldc + 8); + + // Prefetch B and A data 8 iterations ahead into L1 + constexpr int PF_DIST = 8; + + for (int k = 0; k < kc; ++k) { + const float* b_ptr = B_panel + k * b_stride; + const float* a_ptr = A_panel + k * a_stride; + + // Prefetch next B and A rows into L1 + if (k + PF_DIST < kc) { + _mm_prefetch( + reinterpret_cast(B_panel + (k + PF_DIST) * b_stride), + _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(B_panel + (k + PF_DIST) * b_stride + 8), + _MM_HINT_T0); + _mm_prefetch( + reinterpret_cast(A_panel + (k + PF_DIST) * a_stride), + _MM_HINT_T0); + } + + __m256 b0 = _mm256_loadu_ps(b_ptr); + __m256 b1 = _mm256_loadu_ps(b_ptr + 8); + + __m256 a; + a = _mm256_broadcast_ss(a_ptr + 0); + c00 = _mm256_fmadd_ps(a, b0, c00); + c01 = _mm256_fmadd_ps(a, b1, c01); + + a = _mm256_broadcast_ss(a_ptr + 1); + c10 = _mm256_fmadd_ps(a, b0, c10); + c11 = _mm256_fmadd_ps(a, b1, c11); + + a = _mm256_broadcast_ss(a_ptr + 2); + c20 = _mm256_fmadd_ps(a, b0, c20); + c21 = _mm256_fmadd_ps(a, b1, c21); + + a = _mm256_broadcast_ss(a_ptr + 3); + c30 = _mm256_fmadd_ps(a, b0, c30); + c31 = _mm256_fmadd_ps(a, b1, c31); + + a = _mm256_broadcast_ss(a_ptr + 4); + c40 = _mm256_fmadd_ps(a, b0, c40); + c41 = _mm256_fmadd_ps(a, b1, c41); + + a = _mm256_broadcast_ss(a_ptr + 5); + c50 = _mm256_fmadd_ps(a, b0, c50); + c51 = _mm256_fmadd_ps(a, b1, c51); + } + + _mm256_storeu_ps(C_block + 0 * ldc, c00); + _mm256_storeu_ps(C_block + 0 * ldc + 8, c01); + _mm256_storeu_ps(C_block + 1 * ldc, c10); + _mm256_storeu_ps(C_block + 1 * ldc + 8, c11); + _mm256_storeu_ps(C_block + 2 * ldc, c20); + _mm256_storeu_ps(C_block + 2 * ldc + 8, c21); + _mm256_storeu_ps(C_block + 3 * ldc, c30); + _mm256_storeu_ps(C_block + 3 * ldc + 8, c31); + _mm256_storeu_ps(C_block + 4 * ldc, c40); + _mm256_storeu_ps(C_block + 4 * ldc + 8, c41); + _mm256_storeu_ps(C_block + 5 * ldc, c50); + _mm256_storeu_ps(C_block + 5 * ldc + 8, c51); +} + +} // namespace mlx::core::detail diff --git a/mlx/backend/cpu/gemms/avx_simd_gemm.h b/mlx/backend/cpu/gemms/avx_simd_gemm.h new file mode 100644 index 0000000000..23c2a8f9cd --- /dev/null +++ b/mlx/backend/cpu/gemms/avx_simd_gemm.h @@ -0,0 +1,417 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/cpu/gemms/aligned_buffer.h" +#include "mlx/backend/cpu/gemms/avx_gemm_simd.h" +#include "mlx/backend/cpu/gemms/avx_simd_gemv.h" + +namespace mlx::core { + +template +inline void +pack_transpose_8x8(const T* src, float* dst, int src_stride, int dst_stride) { + detail::transpose_8x8_block(src, dst, src_stride, dst_stride); +} + +// Pack A block (m_block x k_block) into A_packed (MC x KC float, column-major). +template +static void pack_A_block( + const T* A, + float* A_packed, + int M, + int K, + int ldA, + int M_offset, + int K_offset, + int m_block, + int k_block, + bool a_trans) { + static_assert( + std::is_same_v || std::is_same_v, + "T must be float16 or bfloat16"); + constexpr int simd_width = 8; + + // Zero-fill only the portions we access (edge tiles) + if (m_block < MC || k_block < KC) { + for (int k = 0; k < k_block; ++k) { + std::fill(A_packed + k * MC, A_packed + k * MC + m_block, 0.0f); + } + } + + if (!a_trans) { + // A is row-major (M x K). Pack with 8x8 transpose blocks. + for (int k = 0; k < k_block; k += 8) { + int k_chunk = std::min(8, k_block - k); + + if (k_chunk == 8) { + for (int i = 0; i < m_block; i += 8) { + int m_chunk = std::min(8, m_block - i); + + if (m_chunk == 8) { + const T* a_block_start = A + (M_offset + i) * ldA + K_offset + k; + pack_transpose_8x8( + a_block_start, A_packed + k * MC + i, ldA, MC); + } else { + for (int ii = 0; ii < m_chunk; ++ii) { + const T* a_src_row_ptr = + A + (M_offset + i + ii) * ldA + K_offset + k; + for (int kk = 0; kk < k_chunk; ++kk) { + A_packed[(k + kk) * MC + (i + ii)] = + static_cast(a_src_row_ptr[kk]); + } + } + } + } + } else { + for (int i = 0; i < m_block; ++i) { + const T* a_src_row_ptr = A + (M_offset + i) * ldA + K_offset + k; + for (int kk = 0; kk < k_chunk; ++kk) { + A_packed[(k + kk) * MC + i] = static_cast(a_src_row_ptr[kk]); + } + } + } + } + } else { + // A is transposed (K x M row-major). Contiguous copy with SIMD convert. + for (int k = 0; k < k_block; ++k) { + const T* a_src_row_ptr = A + (K_offset + k) * ldA + M_offset; + float* a_dst_col_ptr = A_packed + k * MC; + int i = 0; + for (; i + simd_width <= m_block; i += simd_width) { + detail::float8 a_vec = detail::load_convert_to_float(a_src_row_ptr + i); + detail::store(a_dst_col_ptr + i, a_vec); + } + for (; i < m_block; ++i) { + a_dst_col_ptr[i] = static_cast(a_src_row_ptr[i]); + } + } + } +} + +// Pack B block (k_block x n_block) into B_packed (KC x NC float, row-major). +template +static void pack_B_block( + const T* B, + float* B_packed, + int K, + int N, + int ldB, + int K_offset, + int N_offset, + int k_block, + int n_block, + bool b_trans) { + static_assert( + std::is_same_v || std::is_same_v, + "T must be float16 or bfloat16"); + constexpr int simd_width = 8; + + if (k_block < KC || n_block < NC) { + for (int k = 0; k < k_block; ++k) { + std::fill(B_packed + k * NC, B_packed + k * NC + n_block, 0.0f); + } + } + + if (!b_trans) { + // B is row-major (K x N). Contiguous copy with SIMD convert. + for (int k = 0; k < k_block; ++k) { + const T* b_src_row_ptr = B + (K_offset + k) * ldB + N_offset; + float* b_dst_row_ptr = B_packed + k * NC; + int j = 0; + for (; j + simd_width <= n_block; j += simd_width) { + detail::float8 b_vec = detail::load_convert_to_float(b_src_row_ptr + j); + detail::store(b_dst_row_ptr + j, b_vec); + } + for (; j < n_block; ++j) { + b_dst_row_ptr[j] = static_cast(b_src_row_ptr[j]); + } + } + } else { + // B is transposed (N x K row-major). Pack with 8x8 transpose blocks. + for (int k = 0; k < k_block; k += 8) { + int k_chunk = std::min(8, k_block - k); + + if (k_chunk == 8) { + for (int j = 0; j < n_block; j += 8) { + int n_chunk = std::min(8, n_block - j); + + if (n_chunk == 8) { + const T* b_block_start = B + (N_offset + j) * ldB + K_offset + k; + float tmp_transpose[64]; + pack_transpose_8x8(b_block_start, tmp_transpose, ldB, 8); + for (int kk = 0; kk < 8; ++kk) { + for (int jj = 0; jj < 8; ++jj) { + B_packed[(k + kk) * NC + (j + jj)] = tmp_transpose[kk * 8 + jj]; + } + } + } else { + for (int kk = 0; kk < k_chunk; ++kk) { + float* b_dst_row_ptr = B_packed + (k + kk) * NC + j; + for (int jj = 0; jj < n_chunk; ++jj) { + b_dst_row_ptr[jj] = static_cast( + B[(N_offset + j + jj) * ldB + (K_offset + k + kk)]); + } + } + } + } + } else { + for (int kk = 0; kk < k_chunk; ++kk) { + float* b_dst_row_ptr = B_packed + (k + kk) * NC; + for (int j = 0; j < n_block; ++j) { + b_dst_row_ptr[j] = static_cast( + B[(N_offset + j) * ldB + (K_offset + k + kk)]); + } + } + } + } + } +} + +// Single-threaded fp16/bf16 GEMM with fp32 accumulation. Goto-style +// jc→pc→ic blocking; A and B are packed to fp32 once per panel. +template +void simd_gemm_optimized_higher_precision( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + int M, + int N, + int K, + int ldA, + int ldB, + int ldC, + float alpha, + float beta) { + static_assert( + std::is_same_v || std::is_same_v, + "GEMM kernel requires float16_t or bfloat16_t."); + + // Blocking parameters. + constexpr int MR = 6; + constexpr int NR = 16; + static_assert(NR % 8 == 0, "NR must be multiple of float SIMD width (8)"); + + constexpr int KC_BLOCK = 256; + constexpr int MC_BLOCK = 96; + constexpr int NC_BLOCK = 256; + + static_assert(MC_BLOCK % MR == 0, "MC_BLOCK must be a multiple of MR"); + static_assert(NC_BLOCK % NR == 0, "NC_BLOCK must be a multiple of NR"); + + // Thread-local buffers (grow-only, reused across calls) + thread_local aligned_unique_ptr A_packed_buf(MC_BLOCK * KC_BLOCK); + thread_local aligned_unique_ptr B_packed_buf(KC_BLOCK * NC_BLOCK); + thread_local aligned_unique_ptr C_acc_buf(1); + + A_packed_buf.reset(MC_BLOCK * KC_BLOCK); + B_packed_buf.reset(KC_BLOCK * NC_BLOCK); + C_acc_buf.reset(M * NC_BLOCK); + + float* A_packed = A_packed_buf.get(); + float* B_packed = B_packed_buf.get(); + float* C_acc = C_acc_buf.get(); + + // Scalar fallback for edge tiles (m_micro < MR or n_micro < NR) + auto compute_block_scalar_partial = []( + + const float* A_panel, + const float* B_panel, + float* C_sub, + int ldc_acc, + int m_micro, + int n_micro, + int k_block, + int a_stride, + int b_stride) { + for (int i = 0; i < m_micro; ++i) { + for (int j = 0; j < n_micro; ++j) { + float acc = C_sub[i * ldc_acc + j]; + for (int k = 0; k < k_block; ++k) { + acc += A_panel[i + k * a_stride] * B_panel[k * b_stride + j]; + } + C_sub[i * ldc_acc + j] = acc; + } + } + }; + + constexpr int sw = 8; + + for (int jc = 0; jc < N; jc += NC_BLOCK) { + int nc = std::min(NC_BLOCK, N - jc); + + for (int pc = 0; pc < K; pc += KC_BLOCK) { + int kc = std::min(KC_BLOCK, K - pc); + bool first_k = (pc == 0); + bool last_k = (pc + kc >= K); + + pack_B_block( + b, B_packed, K, N, ldB, pc, jc, kc, nc, b_trans); + + for (int ic = 0; ic < M; ic += MC_BLOCK) { + int mc = std::min(MC_BLOCK, M - ic); + + pack_A_block( + a, A_packed, M, K, ldA, ic, pc, mc, kc, a_trans); + + // Zero C_acc on first K-panel; alpha and beta*C are applied at + // writeback. + if (first_k) { + for (int i = 0; i < mc; ++i) { + std::memset(C_acc + (ic + i) * NC_BLOCK, 0, nc * sizeof(float)); + } + } + + // Microkernel loop + for (int ir = 0; ir < mc; ir += MR) { + int m_micro = std::min(MR, mc - ir); + + for (int jr = 0; jr < nc; jr += NR) { + int n_micro = std::min(NR, nc - jr); + + const float* a_ptr = A_packed + ir; + const float* b_ptr = B_packed + jr; + float* c_ptr = C_acc + (ic + ir) * NC_BLOCK + jr; + + // Prefetch next C_acc tile into L2 + if (jr + NR < nc) { + for (int pi = 0; pi < MR && ir + pi < mc; ++pi) + _mm_prefetch( + reinterpret_cast( + C_acc + (ic + ir + pi) * NC_BLOCK + jr + NR), + _MM_HINT_T1); + } else if (ir + MR < mc) { + for (int pi = 0; pi < MR && ir + MR + pi < mc; ++pi) + _mm_prefetch( + reinterpret_cast( + C_acc + (ic + ir + MR + pi) * NC_BLOCK), + _MM_HINT_T1); + } + + if (m_micro == MR && n_micro == NR) { + detail::micro_kernel_6x16( + a_ptr, b_ptr, c_ptr, NC_BLOCK, kc, MC_BLOCK, NC_BLOCK); + } else { + compute_block_scalar_partial( + a_ptr, + b_ptr, + c_ptr, + NC_BLOCK, + m_micro, + n_micro, + kc, + MC_BLOCK, + NC_BLOCK); + } + } + } + + // Writeback: C = alpha * acc + beta * C + if (last_k) { + bool apply_alpha = (alpha != 1.0f); + bool apply_beta = (beta != 0.0f); + detail::float8 alpha_vec(alpha); + detail::float8 beta_vec(beta); + + for (int i = 0; i < mc; ++i) { + T* c_row = c + (ic + i) * ldC + jc; + float* acc_row = C_acc + (ic + i) * NC_BLOCK; + int j = 0; + for (; j + sw <= nc; j += sw) { + detail::float8 acc = detail::load(acc_row + j); + if (apply_alpha) + acc = alpha_vec * acc; + if (apply_beta) { + detail::float8 cv = detail::load_convert_to_float(c_row + j); + acc = acc + beta_vec * cv; + } + detail::store_convert_from_float(c_row + j, acc); + } + for (; j < nc; ++j) { + float val = acc_row[j]; + if (apply_alpha) + val *= alpha; + if (apply_beta) + val += beta * static_cast(c_row[j]); + c_row[j] = static_cast(val); + } + } + } + } // ic + } // pc + } // jc +} + +// Public interface: validates dimensions and dispatches to the blocked kernel. +template +void simd_gemm( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + size_t M_s, + size_t N_s, + size_t K_s, + float alpha = 1.0f, + float beta = 0.0f) { + static_assert( + std::is_same_v || std::is_same_v, + "simd_gemm requires T = float16_t or bfloat16_t."); + static_assert( + std::is_same_v, "simd_gemm requires AccT = float."); + + if (M_s > static_cast(std::numeric_limits::max()) || + N_s > static_cast(std::numeric_limits::max()) || + K_s > static_cast(std::numeric_limits::max())) { + throw std::overflow_error("Matrix dimensions exceed int limits."); + } + int M = static_cast(M_s); + int N = static_cast(N_s); + int K = static_cast(K_s); + + if (M <= 0 || N <= 0) + return; + + int ldA = (!a_trans) ? K : M; + int ldB = (!b_trans) ? N : K; + int ldC = N; + + // K=0: C = beta * C + if (K <= 0) { + if (beta == 0.0f) { + for (int i = 0; i < M; ++i) { + T zero_val = static_cast(0.0f); + std::fill(c + i * ldC, c + i * ldC + N, zero_val); + } + } else if (beta != 1.0f) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + float c_old_f = static_cast(c[i * ldC + j]); + c[i * ldC + j] = static_cast(beta * c_old_f); + } + } + } + return; + } + + // Dispatch to GEMV for M=1 or N=1 (avoids blocked GEMM overhead) + if (M == 1 || N == 1) { + simd_gemv( + a, b, c, a_trans, b_trans, M, N, K, ldA, ldB, ldC, alpha, beta); + return; + } + + simd_gemm_optimized_higher_precision( + a, b, c, a_trans, b_trans, M, N, K, ldA, ldB, ldC, alpha, beta); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/cpu/gemms/avx_simd_gemv.h b/mlx/backend/cpu/gemms/avx_simd_gemv.h new file mode 100644 index 0000000000..dfda359aeb --- /dev/null +++ b/mlx/backend/cpu/gemms/avx_simd_gemv.h @@ -0,0 +1,193 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/backend/cpu/gemms/aligned_buffer.h" +#include "mlx/backend/cpu/gemms/avx_gemm_simd.h" + +namespace mlx::core { + +// Output-dim block: 4096 fp32 = 16KB, fits in L1 alongside B rows. +constexpr int GEMV_NC_BLOCK = 4096; + +// acc[0:width] += sum_k vec[k] * mat[k * mat_stride + 0:width] +template +static void gemv_outer_product( + const T* vec, + const T* mat, + float* acc, + int K, + int width, + int mat_stride) { + constexpr int sw = 8; + + for (int jc = 0; jc < width; jc += GEMV_NC_BLOCK) { + int nc = std::min(GEMV_NC_BLOCK, width - jc); + float* acc_block = acc + jc; + + for (int k = 0; k < K; k++) { + float v = static_cast(vec[k]); + detail::float8 v_bcast(v); + const T* mat_row = mat + k * mat_stride + jc; + + // Prefetch start of next row for this block + if (k + 1 < K) { + _mm_prefetch( + reinterpret_cast(mat + (k + 1) * mat_stride + jc), + _MM_HINT_T0); + } + + int j = 0; + for (; j + sw <= nc; j += sw) { + detail::float8 m = detail::load_convert_to_float(mat_row + j); + detail::float8 c = detail::load(acc_block + j); + detail::store( + acc_block + j, detail::fma(v_bcast, m, c)); + } + for (; j < nc; j++) { + acc_block[j] += v * static_cast(mat_row[j]); + } + } + } +} + +// acc[i] += dot(mat[i*mat_stride : +K], vec[0:K]); 4-row unroll to share vec +// loads. +template +static void gemv_dot_product( + const T* mat, + const T* vec, + float* acc, + int n_outputs, + int K, + int mat_stride) { + constexpr int sw = 8; + constexpr int UNROLL = 4; + + int i = 0; + for (; i + UNROLL <= n_outputs; i += UNROLL) { + detail::float8 s0, s1, s2, s3; + + const T* r0 = mat + (i + 0) * mat_stride; + const T* r1 = mat + (i + 1) * mat_stride; + const T* r2 = mat + (i + 2) * mat_stride; + const T* r3 = mat + (i + 3) * mat_stride; + + int k = 0; + for (; k + sw <= K; k += sw) { + detail::float8 v = detail::load_convert_to_float(vec + k); + s0 = detail::fma(detail::load_convert_to_float(r0 + k), v, s0); + s1 = detail::fma(detail::load_convert_to_float(r1 + k), v, s1); + s2 = detail::fma(detail::load_convert_to_float(r2 + k), v, s2); + s3 = detail::fma(detail::load_convert_to_float(r3 + k), v, s3); + } + + float d0 = detail::sum(s0); + float d1 = detail::sum(s1); + float d2 = detail::sum(s2); + float d3 = detail::sum(s3); + + for (; k < K; k++) { + float vk = static_cast(vec[k]); + d0 += vk * static_cast(r0[k]); + d1 += vk * static_cast(r1[k]); + d2 += vk * static_cast(r2[k]); + d3 += vk * static_cast(r3[k]); + } + + acc[i + 0] += d0; + acc[i + 1] += d1; + acc[i + 2] += d2; + acc[i + 3] += d3; + } + + for (; i < n_outputs; i++) { + detail::float8 s; + const T* row = mat + i * mat_stride; + + int k = 0; + for (; k + sw <= K; k += sw) { + detail::float8 v = detail::load_convert_to_float(vec + k); + s = detail::fma(detail::load_convert_to_float(row + k), v, s); + } + + float d = detail::sum(s); + for (; k < K; k++) { + d += static_cast(vec[k]) * static_cast(row[k]); + } + acc[i] += d; + } +} + +// C = alpha * op(A) * op(B) + beta * C, for M=1 or N=1. +// Dispatches to outer-product or dot-product core based on shape and transpose. +template +void simd_gemv( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + int M, + int N, + int K, + int ldA, + int ldB, + int ldC, + float alpha, + float beta) { + int out_len = (M == 1) ? N : M; + + // Thread-local fp32 accumulator (grow-only). + thread_local aligned_unique_ptr acc_buf(1); + acc_buf.reset(out_len); + float* acc = acc_buf.get(); + + constexpr int sw = 8; + std::memset(acc, 0, out_len * sizeof(float)); + + // acc += op(A) * op(B) + if (M == 1) { + if (!b_trans) { + gemv_outer_product(a, b, acc, K, N, ldB); + } else { + gemv_dot_product(b, a, acc, N, K, ldB); + } + } else { + if (!a_trans) { + gemv_dot_product(a, b, acc, M, K, ldA); + } else { + gemv_outer_product(b, a, acc, K, M, ldA); + } + } + + // Writeback: C = alpha * acc + beta * C (convert fp32 → T) + bool apply_alpha = (alpha != 1.0f); + bool apply_beta = (beta != 0.0f); + detail::float8 alpha_vec(alpha); + detail::float8 beta_vec(beta); + int j = 0; + for (; j + sw <= out_len; j += sw) { + detail::float8 val = detail::load(acc + j); + if (apply_alpha) + val = alpha_vec * val; + if (apply_beta) { + detail::float8 cv = detail::load_convert_to_float(c + j); + val = val + beta_vec * cv; + } + detail::store_convert_from_float(c + j, val); + } + for (; j < out_len; j++) { + float val = acc[j]; + if (apply_alpha) + val *= alpha; + if (apply_beta) + val += beta * static_cast(c[j]); + c[j] = static_cast(val); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cpu/gemms/simd_bf16.cpp b/mlx/backend/cpu/gemms/simd_bf16.cpp index 58f5964b6e..11ef34f46f 100644 --- a/mlx/backend/cpu/gemms/simd_bf16.cpp +++ b/mlx/backend/cpu/gemms/simd_bf16.cpp @@ -2,7 +2,12 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/gemm.h" + +#ifdef MLX_USE_AVX2 +#include "mlx/backend/cpu/gemms/avx_simd_gemm.h" +#else #include "mlx/backend/cpu/gemms/simd_gemm.h" +#endif namespace mlx::core { diff --git a/mlx/backend/cpu/gemms/simd_fp16.cpp b/mlx/backend/cpu/gemms/simd_fp16.cpp index 93467da868..826a2ca60a 100644 --- a/mlx/backend/cpu/gemms/simd_fp16.cpp +++ b/mlx/backend/cpu/gemms/simd_fp16.cpp @@ -2,7 +2,12 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/cpu/gemm.h" + +#ifdef MLX_USE_AVX2 +#include "mlx/backend/cpu/gemms/avx_simd_gemm.h" +#else #include "mlx/backend/cpu/gemms/simd_gemm.h" +#endif namespace mlx::core {