From 2c547138cffe61178874332b2fac9f74518289ff Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Jun 2026 02:17:11 +0200 Subject: [PATCH] Metal: FP8-packed compressed-KV cache + long-context memory optimizations Three memory optimizations for the Metal backend's compressed-KV (MLA latent) path: 1. Packed FP8 comp cache (opt-in, DS4_METAL_FP8_KV_STORE=1): stores comp rows as e4m3 + ue8m0 scale + f16 rot = 584 B/row vs 1024 B/row f16. Dequant uses a 128-entry LUT (ds4_e4m3_lut) avoiding branch+exp2. Validated bit-identical. 2. comp_mask stored as f16 (always on): binary -inf/0 mask fits exactly in f16, halving the mask buffer size at all context lengths. 3. indexer_scores token-tiling (always on): DS4_INDEXER_SCORE_TILE=512 reduces the score working buffer from comp_cap*prefill_cap to comp_cap*512 (~8x). Together: ~2.3x KV cache reduction at long context, bit-identical output, speed-neutral decode (bandwidth saving offset by per-element dequant cost). Revert with DS4_DISABLE_KV_OPTS=1. Tested on M5 Max, 8k-96k context. Fixes: https://github.com/antirez/ds4/pull/416 Co-Authored-By: Claude Sonnet 4.6 --- ds4.c | 343 ++++++++++++------- ds4_gpu.h | 42 +++ ds4_metal.m | 735 ++++++++++++++++++++++++++++++++++++++--- metal/cpy.metal | 1 + metal/dsv4_kv.metal | 225 ++++++++++++- metal/dsv4_misc.metal | 32 +- metal/flash_attn.metal | 232 ++++++++++++- 7 files changed, 1435 insertions(+), 175 deletions(-) diff --git a/ds4.c b/ds4.c index b20270943..b3cf042de 100644 --- a/ds4.c +++ b/ds4.c @@ -10302,6 +10302,32 @@ static void print_vec_stats(const char *name, const float *x, uint64_t n) { #define DS4_GPU_ATTN_COMP_CACHE_F16 0 #endif +/* The indexer score matrix [comp_cap x n_tokens] is the largest context-dependent + * scratch buffer at long context. Top-k selection is independent per token, so the + * prefill/decode-batch score+top-k is processed in token tiles of this width, + * letting the scores buffer be comp_cap*TILE instead of comp_cap*prefill_cap + * (exact, zero drift). prefill_cap=4096 -> ~8x smaller scores scratch. */ +#define DS4_INDEXER_SCORE_TILE 512u + +/* KV-cache memory optimizations (packed comp cache, f16 comp_mask, score tiling) + * are on by default; DS4_DISABLE_KV_OPTS reverts to the pre-optimization layout for + * A/B comparison (true "original" footprint and behaviour in one binary). */ +static int ds4_kv_opts_enabled(void) { + return getenv("DS4_DISABLE_KV_OPTS") == NULL; +} + +/* All comp-cache sites (alloc sizing, f32 staging, capacity checks, row views, + * model dump, session save/restore, and the flash unpack-to-f16 staging path) are + * now packed-aware, so enable runtime packed activation via DS4_METAL_FP8_KV_STORE. */ +#define DS4_GPU_ATTN_COMP_CACHE_PACKED_READY 1 + +/* Forward declarations: the attn-comp-cache format helpers are defined alongside + * the graph store helpers (~line 13300) but are referenced by earlier sizing, + * staging, and validity code. Format: 0 = f32, 1 = f16, 2 = packed FP8 (584 B/row). */ +static uint32_t metal_graph_attn_comp_cache_format(void); +static uint64_t metal_graph_attn_comp_cache_row_bytes(void); +static bool metal_graph_attn_comp_uses_stage(void); + /* ========================================================================= * Metal Release Graph State. * ========================================================================= @@ -10774,8 +10800,7 @@ static uint64_t metal_graph_kv_cache_bytes_for_context(uint32_t ctx_size, uint32 const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio == 0) continue; const uint64_t comp_cap = (uint64_t)(ctx_size / ratio + 2u); - bytes += comp_cap * DS4_N_HEAD_DIM * - (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float)); + bytes += comp_cap * metal_graph_attn_comp_cache_row_bytes(); if (ratio == 4) { bytes += comp_cap * DS4_N_INDEXER_HEAD_DIM * sizeof(float); } @@ -10798,9 +10823,11 @@ static uint64_t metal_graph_context_bytes_for_kv_policy( if (comp_cap < 2u) comp_cap = 2u; const uint64_t kv_cache_bytes = metal_graph_kv_cache_bytes_for_context(ctx_size, raw_cap); if (kv_cache_bytes_out) *kv_cache_bytes_out = kv_cache_bytes; + /* indexer_scores (f32, token-tiled to score_tile) + comp_mask (f16) scratch. */ uint64_t bytes = kv_cache_bytes + - 2ull * comp_cap * prefill_cap * sizeof(float); - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + comp_cap * (prefill_cap < DS4_INDEXER_SCORE_TILE ? prefill_cap : DS4_INDEXER_SCORE_TILE) * sizeof(float) + + comp_cap * prefill_cap * sizeof(uint16_t); + if (metal_graph_attn_comp_uses_stage()) { uint64_t attn_stage_cap = (uint64_t)(prefill_cap / min_ratio + 2u); if (attn_stage_cap < 2u) attn_stage_cap = 2u; bytes += attn_stage_cap * DS4_N_HEAD_DIM * sizeof(float); @@ -10988,7 +11015,7 @@ static bool metal_graph_alloc_raw_cap( if (min_ratio == UINT32_MAX) min_ratio = ctx_size ? ctx_size : 1u; g->comp_cap = ctx_size / min_ratio + 2u; if (g->comp_cap < 2u) g->comp_cap = 2u; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_uses_stage()) { g->attn_comp_stage_cap = prefill_cap / min_ratio + 2u; if (g->attn_comp_stage_cap < 2u) g->attn_comp_stage_cap = 2u; } @@ -11069,8 +11096,7 @@ static bool metal_graph_alloc_raw_cap( const uint64_t attn_rows = (uint64_t)coff * ratio; g->layer_attn_comp_cache[il] = metal_graph_alloc_kv_cache_tensor( managed_kv_cache, - (uint64_t)g->layer_comp_cap[il] * DS4_N_HEAD_DIM * - (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float))); + (uint64_t)g->layer_comp_cap[il] * metal_graph_attn_comp_cache_row_bytes()); g->layer_attn_state_kv[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); g->layer_attn_state_score[il] = ds4_gpu_tensor_alloc(attn_width * attn_rows * sizeof(float)); if (enable_mtp) { @@ -11115,14 +11141,22 @@ static bool metal_graph_alloc_raw_cap( } g->comp_kv_cur = ds4_gpu_tensor_alloc(comp_width_max * sizeof(float)); g->comp_sc_cur = ds4_gpu_tensor_alloc(comp_width_max * sizeof(float)); - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_uses_stage()) { g->attn_comp_stage = ds4_gpu_tensor_alloc((uint64_t)g->attn_comp_stage_cap * DS4_N_HEAD_DIM * sizeof(float)); } g->indexer_q = ds4_gpu_tensor_alloc(indexer_q_dim * sizeof(float)); g->indexer_weights = ds4_gpu_tensor_alloc((uint64_t)DS4_N_INDEXER_HEAD * sizeof(float)); - g->indexer_scores = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * pc * sizeof(float)); - g->comp_mask = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * pc * sizeof(float)); + /* Token-tiled score buffer: comp_cap*TILE instead of comp_cap*prefill_cap + * (full untiled buffer when KV opts are disabled). */ + const uint64_t score_tile = ds4_kv_opts_enabled() + ? (pc < DS4_INDEXER_SCORE_TILE ? pc : DS4_INDEXER_SCORE_TILE) : pc; + g->indexer_scores = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * score_tile * sizeof(float)); + /* comp_mask is an f16 binary (-inf/0) top-k mask: halves this comp_cap*pc + * buffer (one of the largest context-dependent allocations). f32 when KV opts + * are disabled (DS4_DISABLE_KV_OPTS) for A/B comparison. */ + g->comp_mask = ds4_gpu_tensor_alloc((uint64_t)g->comp_cap * pc * + (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float))); g->comp_selected = ds4_gpu_tensor_alloc((uint64_t)(DS4_N_INDEXER_TOP_K ? DS4_N_INDEXER_TOP_K : 1u) * pc * sizeof(uint32_t)); g->heads = ds4_gpu_tensor_alloc(q_dim * sizeof(float)); @@ -11251,7 +11285,7 @@ static bool metal_graph_alloc_raw_cap( g->attn_cur && g->attn_norm && g->qr && g->qr_norm && g->q && g->kv_raw && g->kv && g->comp_kv_cur && g->comp_sc_cur && - (!DS4_GPU_ATTN_COMP_CACHE_F16 || g->attn_comp_stage) && + (!metal_graph_attn_comp_uses_stage() || g->attn_comp_stage) && g->indexer_q && g->indexer_weights && g->indexer_scores && g->comp_mask && g->comp_selected && g->heads && g->attn_low && g->attn_out && @@ -13482,13 +13516,54 @@ static bool metal_graph_decode_kv_store( DS4_N_ROT) != 0; } +/* Comp-cache element format: 0 = f32 rows, 1 = f16 rows, 2 = packed FP8 (the + * ds4_fp8_kv_row 3-plane layout, 584 B/row). Packed is opt-in via + * DS4_METAL_FP8_KV_STORE; it stores comp at e4m3 precision (validated + * output-equivalent to f16) and is read directly by the indexed attention helper + * and the flash fp8mix kernels, eliminating the per-step f16 conversion. */ +static uint32_t metal_graph_attn_comp_cache_format(void) { + static uint32_t cached = 3u; /* 3 = unresolved */ + if (cached == 3u) { + /* Packed (2) activation is gated on DS4_METAL_FP8_KV_STORE AND the + * compile-time DS4_GPU_ATTN_COMP_CACHE_PACKED_READY switch. The latter is + * off until every comp-cache site (alloc sizing, f32 staging, capacity + * checks, row views, model dump, and session save/restore) is packed-aware; + * flipping it on before that would mis-handle the 584 B/row layout on the + * unconverted paths (notably session persistence). Until then packed stays + * staged-but-dormant and behaviour is byte-identical to the f16 path. */ +#ifdef DS4_GPU_ATTN_COMP_CACHE_PACKED_READY + if (getenv("DS4_METAL_FP8_KV_STORE") != NULL && ds4_kv_opts_enabled()) { + cached = 2u; + fprintf(stderr, + "ds4: FP8-packed compressed-KV cache enabled (DS4_METAL_FP8_KV_STORE): " + "comp cache 1024->584 B/row (~1.75x smaller), e4m3-precision, " + "bit-identical generation, ~break-even decode speed\n"); + } else +#endif + { + cached = DS4_GPU_ATTN_COMP_CACHE_F16 ? 1u : 0u; + } + } + return cached; +} + static uint64_t metal_graph_attn_comp_cache_row_bytes(void) { - return (uint64_t)DS4_N_HEAD_DIM * - (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float)); + switch (metal_graph_attn_comp_cache_format()) { + case 2u: { + const uint64_t n_nope = (uint64_t)DS4_N_HEAD_DIM - DS4_N_ROT; + const uint64_t n_blk = n_nope / 64u; + const uint64_t scale_pad = (n_blk + 7u) & ~(uint64_t)7u; /* 8 */ + return scale_pad + n_nope + (uint64_t)DS4_N_ROT * sizeof(uint16_t); /* 584 */ + } + case 1u: return (uint64_t)DS4_N_HEAD_DIM * sizeof(uint16_t); + default: return (uint64_t)DS4_N_HEAD_DIM * sizeof(float); + } } +/* Backward-compatible name: now returns the format code (0/1/2) passed to the + * attention dispatch as comp_kv_f16. The metal side treats 2 as packed. */ static uint32_t metal_graph_attn_comp_cache_is_f16(void) { - return DS4_GPU_ATTN_COMP_CACHE_F16 ? 1u : 0u; + return metal_graph_attn_comp_cache_format(); } static bool metal_graph_store_attn_comp_stage( @@ -13507,7 +13582,18 @@ static bool metal_graph_store_attn_comp_stage( const uint64_t count = (uint64_t)rows * DS4_N_HEAD_DIM; const uint64_t dst_offset = (uint64_t)first_row * metal_graph_attn_comp_cache_row_bytes(); - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + const uint32_t format = metal_graph_attn_comp_cache_format(); + if (format == 2u) { + /* Pack the f32 stage rows into the 3-plane FP8 layout directly into the + * persistent cache — once per emitted comp row, read packed every step. */ + return ds4_gpu_dsv4_kv_pack_comp_rows(g->layer_attn_comp_cache[il], + dst_offset, + g->attn_comp_stage, + rows, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0; + } + if (format == 1u) { return ds4_gpu_tensor_copy_f32_to_f16(g->layer_attn_comp_cache[il], dst_offset, g->attn_comp_stage, @@ -13522,16 +13608,20 @@ static bool metal_graph_store_attn_comp_stage( count * sizeof(float)) != 0; } +static bool metal_graph_attn_comp_uses_stage(void) { + return metal_graph_attn_comp_cache_format() != 0u; +} + static ds4_gpu_tensor *metal_graph_attn_comp_update_target( ds4_gpu_graph *g, uint32_t il) { - return DS4_GPU_ATTN_COMP_CACHE_F16 + return metal_graph_attn_comp_uses_stage() ? g->attn_comp_stage : g->layer_attn_comp_cache[il]; } static uint32_t metal_graph_attn_comp_update_row(uint32_t row) { - return DS4_GPU_ATTN_COMP_CACHE_F16 ? 0u : row; + return metal_graph_attn_comp_uses_stage() ? 0u : row; } static bool metal_graph_commit_attn_comp_stage( @@ -13539,7 +13629,7 @@ static bool metal_graph_commit_attn_comp_stage( uint32_t il, uint32_t first_row, uint32_t rows) { - if (!DS4_GPU_ATTN_COMP_CACHE_F16) return true; + if (!metal_graph_attn_comp_uses_stage()) return true; return metal_graph_store_attn_comp_stage(g, il, first_row, rows); } @@ -13547,7 +13637,7 @@ static ds4_gpu_tensor *metal_graph_attn_comp_row_view( ds4_gpu_graph *g, uint32_t il, uint32_t row) { - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_uses_stage()) { return ds4_gpu_tensor_view(g->attn_comp_stage, 0, (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); @@ -13562,7 +13652,7 @@ static ds4_gpu_tensor *metal_graph_attn_comp_prefill_target( uint32_t il, uint32_t first_row, uint32_t rows) { - if (DS4_GPU_ATTN_COMP_CACHE_F16) return g->attn_comp_stage; + if (metal_graph_attn_comp_uses_stage()) return g->attn_comp_stage; const uint32_t view_rows = rows ? rows : 1u; return ds4_gpu_tensor_view(g->layer_attn_comp_cache[il], (uint64_t)first_row * DS4_N_HEAD_DIM * sizeof(float), @@ -13570,7 +13660,7 @@ static ds4_gpu_tensor *metal_graph_attn_comp_prefill_target( } static void metal_graph_attn_comp_prefill_target_free(ds4_gpu_tensor *t) { - if (!DS4_GPU_ATTN_COMP_CACHE_F16) ds4_gpu_tensor_free(t); + if (!metal_graph_attn_comp_uses_stage()) ds4_gpu_tensor_free(t); } /* Encode one DS4 decode layer on Metal. This is the release single-token @@ -17766,7 +17856,7 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal layer-major compressed KV cache capacity exceeded at layer %u\n", il); ok = false; } - if (ok && DS4_GPU_ATTN_COMP_CACHE_F16 && n_comp > g->attn_comp_stage_cap) { + if (ok && metal_graph_attn_comp_uses_stage() && n_comp > g->attn_comp_stage_cap) { fprintf(stderr, "ds4: Metal graph compressed KV staging capacity exceeded at layer %u\n", il); ok = false; } @@ -17849,7 +17939,7 @@ static bool metal_graph_encode_layer_attention_batch( fprintf(stderr, "ds4: Metal graph compressed KV cache capacity exceeded at layer %u\n", il); ok = false; } - if (ok && DS4_GPU_ATTN_COMP_CACHE_F16 && comp_chunk > g->attn_comp_stage_cap) { + if (ok && metal_graph_attn_comp_uses_stage() && comp_chunk > g->attn_comp_stage_cap) { fprintf(stderr, "ds4: Metal graph compressed KV staging capacity exceeded at layer %u\n", il); ok = false; } @@ -18331,53 +18421,33 @@ static bool metal_graph_encode_layer_attention_batch( n_comp, &index_stage_t0); } - ok = ds4_gpu_indexer_scores_decode_batch_tensor(g->indexer_scores, - g->batch_indexer_q, - g->batch_indexer_weights, - g->layer_index_comp_cache[il], - n_comp, - n_tokens, - pos0, - DS4_N_INDEXER_HEAD, - DS4_N_INDEXER_HEAD_DIM, - ratio, - index_scale) != 0; + /* Token-tiled (scores buffer = comp_cap*score_tile). This path can + * run with large n_tokens for non-zero-prefix prefill chunks. */ + ok = ds4_gpu_indexer_decode_batch_score_topk_tiled( + g->indexer_scores, + g->comp_selected, + g->batch_indexer_q, + g->batch_indexer_weights, + g->layer_index_comp_cache[il], + n_comp, + n_tokens, + pos0, + DS4_N_INDEXER_HEAD, + DS4_N_INDEXER_HEAD_DIM, + ratio, + index_scale, + DS4_N_INDEXER_TOP_K, + (ds4_kv_opts_enabled() ? (g->prefill_cap < DS4_INDEXER_SCORE_TILE ? g->prefill_cap : DS4_INDEXER_SCORE_TILE) : g->prefill_cap)) != 0; if (ok && index_stage_profile) { - ok = metal_graph_indexer_stage_profile_boundary("score", - il, - pos0, - n_tokens, - n_comp, + ok = metal_graph_indexer_stage_profile_boundary("score_topk", + il, pos0, n_tokens, n_comp, &index_stage_t0); } if (ok) { - metal_graph_debug_dump_tensor("indexer_scores", - g->indexer_scores, - (uint64_t)n_comp * n_tokens, - il, - pos0); - } - if (ok) { - ok = ds4_gpu_indexer_topk_tensor(g->comp_selected, - g->indexer_scores, - n_comp, - n_tokens, - DS4_N_INDEXER_TOP_K) != 0; - if (ok && index_stage_profile) { - ok = metal_graph_indexer_stage_profile_boundary("topk", - il, - pos0, - n_tokens, - n_comp, - &index_stage_t0); - } - if (ok) { - metal_graph_debug_dump_i32_tensor("indexer_topk", - g->comp_selected, - (uint64_t)n_tokens * DS4_N_INDEXER_TOP_K, - il, - pos0); - } + metal_graph_debug_dump_i32_tensor("indexer_topk", + g->comp_selected, + (uint64_t)n_tokens * DS4_N_INDEXER_TOP_K, + il, pos0); } if (ok) { use_indexed_comp = true; @@ -18452,52 +18522,34 @@ static bool metal_graph_encode_layer_attention_batch( n_comp, &index_stage_t0); } - ok = ds4_gpu_indexer_scores_prefill_tensor(g->indexer_scores, - g->batch_indexer_q, - g->batch_indexer_weights, - g->layer_index_comp_cache[il], - n_comp, - n_tokens, - DS4_N_INDEXER_HEAD, - DS4_N_INDEXER_HEAD_DIM, - ratio, - index_scale) != 0; + /* Token-tiled score+top-k: the scores buffer holds only score_tile + * tokens; the helper loops token tiles with pos0=t0 (exact: token i is + * scored at global position i, same as the untiled pos0=0 path) and + * views into q/weights/comp_selected. Top-k is per-token independent. */ + ok = ds4_gpu_indexer_prefill_score_topk_tiled( + g->indexer_scores, + g->comp_selected, + g->batch_indexer_q, + g->batch_indexer_weights, + g->layer_index_comp_cache[il], + n_comp, + n_tokens, + DS4_N_INDEXER_HEAD, + DS4_N_INDEXER_HEAD_DIM, + ratio, + index_scale, + DS4_N_INDEXER_TOP_K, + (ds4_kv_opts_enabled() ? (g->prefill_cap < DS4_INDEXER_SCORE_TILE ? g->prefill_cap : DS4_INDEXER_SCORE_TILE) : g->prefill_cap)) != 0; if (ok && index_stage_profile) { - ok = metal_graph_indexer_stage_profile_boundary("score", - il, - pos0, - n_tokens, - n_comp, + ok = metal_graph_indexer_stage_profile_boundary("score_topk", + il, pos0, n_tokens, n_comp, &index_stage_t0); } if (ok) { - metal_graph_debug_dump_tensor("indexer_scores", - g->indexer_scores, - (uint64_t)n_comp * n_tokens, - il, - pos0); - } - if (ok) { - ok = ds4_gpu_indexer_topk_tensor(g->comp_selected, - g->indexer_scores, - n_comp, - n_tokens, - DS4_N_INDEXER_TOP_K) != 0; - if (ok && index_stage_profile) { - ok = metal_graph_indexer_stage_profile_boundary("topk", - il, - pos0, - n_tokens, - n_comp, - &index_stage_t0); - } - if (ok) { - metal_graph_debug_dump_i32_tensor("indexer_topk", - g->comp_selected, - (uint64_t)n_tokens * DS4_N_INDEXER_TOP_K, - il, - pos0); - } + metal_graph_debug_dump_i32_tensor("indexer_topk", + g->comp_selected, + (uint64_t)n_tokens * DS4_N_INDEXER_TOP_K, + il, pos0); } if (ok) { ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, @@ -21443,8 +21495,7 @@ ds4_context_memory ds4_context_memory_estimate_with_prefill( if (ratio == 0) continue; const uint32_t layer_comp_cap = ctx / ratio + 2u; m.compressed_bytes += (uint64_t)layer_comp_cap * - DS4_N_HEAD_DIM * - (DS4_GPU_ATTN_COMP_CACHE_F16 ? sizeof(uint16_t) : sizeof(float)); + metal_graph_attn_comp_cache_row_bytes(); if (ratio == 4) { m.compressed_bytes += (uint64_t)layer_comp_cap * DS4_N_INDEXER_HEAD_DIM * @@ -21481,7 +21532,14 @@ ds4_context_memory ds4_context_memory_estimate_with_prefill( if (m.comp_cap == 0) m.comp_cap = ctx / 4u + 2u; m.scratch_bytes = ((uint64_t)(m.raw_cap + m.comp_cap) * sizeof(float)) + ((uint64_t)m.comp_cap * sizeof(float)) + - ((uint64_t)m.comp_cap * sizeof(bool)); + ((uint64_t)m.comp_cap * sizeof(bool)) + + /* indexer scratch: scores (f32, token-tiled to score_tile) + * + comp_mask (f16). These dominate context memory at long + * context and were previously omitted here. */ + ((uint64_t)m.comp_cap * + (m.prefill_cap < DS4_INDEXER_SCORE_TILE ? m.prefill_cap : DS4_INDEXER_SCORE_TILE) * + sizeof(float)) + + ((uint64_t)m.comp_cap * m.prefill_cap * sizeof(uint16_t)); } m.total_bytes = m.raw_bytes + m.compressed_bytes + m.scratch_bytes; @@ -21598,7 +21656,10 @@ static int metal_graph_prompt_logits_test( const uint64_t n = (uint64_t)n_comp * DS4_N_HEAD_DIM; float *gpu_comp = xmalloc((size_t)n * sizeof(float)); bool comp_read = false; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_cache_format() == 2u) { + /* Packed FP8 cache: skip the f32 comp trace (diagnostic only; + * reading packed bytes as f16/f32 would be meaningless). */ + } else if (DS4_GPU_ATTN_COMP_CACHE_F16) { uint16_t *gpu_comp_h = xmalloc((size_t)n * sizeof(uint16_t)); if (ds4_gpu_tensor_read(g.layer_attn_comp_cache[il], 0, gpu_comp_h, n * sizeof(uint16_t)) != 0) { @@ -23438,7 +23499,10 @@ static uint64_t session_payload_live_tensor_bytes(const ds4_gpu_graph *g, uint32 bytes += (uint64_t)raw_live * DS4_N_HEAD_DIM * sizeof(float); const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio == 0) continue; - bytes += (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float); + bytes += (uint64_t)g->layer_n_comp[il] * + (metal_graph_attn_comp_cache_format() == 2u + ? metal_graph_attn_comp_cache_row_bytes() + : (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); bytes += layer_attn_state_bytes(ratio); bytes += layer_attn_state_bytes(ratio); if (ratio == 4) { @@ -23653,7 +23717,12 @@ uint64_t ds4_session_layer_payload_bytes(ds4_session *s, bytes += (uint64_t)raw_live * DS4_N_HEAD_DIM * sizeof(float); const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio == 0) continue; - bytes += (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float); + /* comp span = packed 584 B/row when packed, else f32-canonical (must + * match save/restore spans or the checkpoint payload total is wrong). */ + bytes += (uint64_t)g->layer_n_comp[il] * + (metal_graph_attn_comp_cache_format() == 2u + ? metal_graph_attn_comp_cache_row_bytes() + : (uint64_t)DS4_N_HEAD_DIM * sizeof(float)); bytes += layer_attn_state_bytes(ratio); bytes += layer_attn_state_bytes(ratio); if (ratio == 4) { @@ -23733,7 +23802,16 @@ int ds4_session_save_layer_payload(ds4_session *s, FILE *fp, } const uint32_t ratio = ds4_layer_compress_ratio(il); if (rc != 0 || ratio == 0) continue; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_cache_format() == 2u) { + rc = payload_write_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)g->layer_n_comp[il] * metal_graph_attn_comp_cache_row_bytes(), + buf, + DS4_SESSION_IO_CHUNK, + err, + errlen); + } else if (DS4_GPU_ATTN_COMP_CACHE_F16) { rc = payload_write_tensor_span_f16_as_f32(fp, g->layer_attn_comp_cache[il], 0, @@ -23934,7 +24012,17 @@ int ds4_session_load_layer_payload(ds4_session *s, FILE *fp, } const uint32_t ratio = ds4_layer_compress_ratio(il); if (rc != 0 || ratio == 0) continue; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_cache_format() == 2u) { + rc = payload_read_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)n_comp[i] * metal_graph_attn_comp_cache_row_bytes(), + buf, + DS4_SESSION_IO_CHUNK, + &remaining, + err, + errlen); + } else if (DS4_GPU_ATTN_COMP_CACHE_F16) { rc = payload_read_tensor_span_f32_as_f16(fp, g->layer_attn_comp_cache[il], 0, @@ -24417,7 +24505,16 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) /* Compressed rows are append-only from row zero, so the live prefix is * contiguous. The two compressor state tensors hold the partial window * that will become the next compressed row. */ - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_cache_format() == 2u) { + rc = payload_write_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)g->layer_n_comp[il] * metal_graph_attn_comp_cache_row_bytes(), + buf, + DS4_SESSION_IO_CHUNK, + err, + errlen); + } else if (DS4_GPU_ATTN_COMP_CACHE_F16) { rc = payload_write_tensor_span_f16_as_f32(fp, g->layer_attn_comp_cache[il], 0, @@ -24753,7 +24850,17 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c } const uint32_t ratio = ds4_layer_compress_ratio(il); if (rc != 0 || ratio == 0) continue; - if (DS4_GPU_ATTN_COMP_CACHE_F16) { + if (metal_graph_attn_comp_cache_format() == 2u) { + rc = payload_read_tensor_span(fp, + g->layer_attn_comp_cache[il], + 0, + (uint64_t)n_comp[il] * metal_graph_attn_comp_cache_row_bytes(), + buf, + DS4_SESSION_IO_CHUNK, + &remaining, + err, + errlen); + } else if (DS4_GPU_ATTN_COMP_CACHE_F16) { rc = payload_read_tensor_span_f32_as_f16(fp, g->layer_attn_comp_cache[il], 0, diff --git a/ds4_gpu.h b/ds4_gpu.h index b58aca9bd..24d10774d 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -179,6 +179,40 @@ int ds4_gpu_indexer_scores_prefill_tensor( uint32_t ratio, float scale); +/* comp_mask f16 (1) unless KV opts disabled via DS4_DISABLE_KV_OPTS (0 = f32). */ +int ds4_gpu_comp_mask_f16(void); + +int ds4_gpu_indexer_prefill_score_topk_tiled( + ds4_gpu_tensor *scores, + ds4_gpu_tensor *selected, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + uint32_t top_k, + uint32_t score_tile); + +int ds4_gpu_indexer_decode_batch_score_topk_tiled( + ds4_gpu_tensor *scores, + ds4_gpu_tensor *selected, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + uint32_t top_k, + uint32_t score_tile); + int ds4_gpu_indexer_scores_decode_batch_tensor( ds4_gpu_tensor *scores, const ds4_gpu_tensor *q, @@ -413,6 +447,14 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( uint32_t head_dim, uint32_t n_rot); +int ds4_gpu_dsv4_kv_pack_comp_rows( + ds4_gpu_tensor *dst, + uint64_t dst_byte_offset, + ds4_gpu_tensor *src, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot); + int ds4_gpu_dsv4_indexer_qat_tensor( ds4_gpu_tensor *x, uint32_t n_rows, diff --git a/ds4_metal.m b/ds4_metal.m index 51b7b4982..0e19d5aa6 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -58,6 +58,7 @@ static id g_cpy_f32_f32_pipeline; static id g_cpy_f32_f16_pipeline; static id g_cpy_f16_f32_pipeline; +static id g_cpy_f16_f16_pipeline; static id g_swiglu_pipeline; static id g_add_pipeline; static id g_moe_sum6_pipeline; @@ -118,6 +119,9 @@ static id g_dsv4_fp8_kv_quantize_pipeline; static id g_dsv4_indexer_qat_pipeline; static id g_dsv4_kv_fp8_store_pipeline; +static id g_dsv4_kv_pack_fp8_pipeline; +static id g_dsv4_kv_pack_fp8_row_pipeline; +static id g_dsv4_kv_unpack_fp8_row_pipeline; static id g_dsv4_ratio4_shift_pipeline; static id g_dsv4_softmax_pool_pipeline; static id g_soft_max_f32_pipeline; @@ -149,6 +153,8 @@ static id g_flash_attn_blk_buffer; static id g_flash_attn_ring_buffer; static id g_flash_attn_kv_buffer; +static id g_flash_attn_kv_packed_buffer; /* packed FP8 comp rows for fp8mix flash read */ +static NSUInteger g_flash_attn_kv_packed_bytes; static id g_compressor_pool_kv_buffer; static id g_compressor_pool_score_buffer; static id g_compressor_pool_score_cont_buffer; @@ -268,6 +274,7 @@ static int g_metal4_m5_neural_accelerators_hint; static int g_metal4_tensor_api_enabled; static int g_metal4_tensor_api_compile_supported; +static int g_metal4_fp8_native_enabled; static char g_metal_device_name[128]; static int ds4_gpu_model_map_log_enabled(void); static int ds4_gpu_stream_expert_cache_note_expert_size( @@ -571,6 +578,13 @@ static void ds4_gpu_close_batch_encoder(void) { g_batch_enc = nil; } +/* Diagnostic-only GPU-busy accounting. Enabled by DS4_METAL_CB_GPU_PROFILE; sums + * per-command-buffer GPU time so it can be compared against wall-clock to tell + * whether a phase is GPU/bandwidth-bound or CPU-encoding/gap-bound. */ +static double g_prof_gpu_busy_s; +static uint64_t g_prof_cb_count; +static int g_prof_cb_enabled = -1; + static int ds4_gpu_wait_command_buffer(id cb, const char *label) { [cb waitUntilCompleted]; if (cb.status == MTLCommandBufferStatusError) { @@ -578,6 +592,16 @@ static int ds4_gpu_wait_command_buffer(id cb, const char *labe label, [[cb.error localizedDescription] UTF8String]); return 0; } + if (g_prof_cb_enabled < 0) { + g_prof_cb_enabled = getenv("DS4_METAL_CB_GPU_PROFILE") != NULL ? 1 : 0; + } + if (g_prof_cb_enabled) { + const double dt = cb.GPUEndTime - cb.GPUStartTime; + if (dt > 0.0) { + g_prof_gpu_busy_s += dt; + g_prof_cb_count++; + } + } return 1; } @@ -848,6 +872,7 @@ static uint64_t ds4_gpu_effective_model_max_tensor_bytes(uint64_t map_size, uint static id ds4_gpu_get_pipeline(const char *function_name); static int ds4_gpu_warm_model_views(void); +static int ds4_gpu_validate_fp8_kv_pack(void); static double ds4_gpu_gib(uint64_t bytes); static double ds4_gpu_now_ms(void) { @@ -1771,6 +1796,103 @@ static int ds4_gpu_compile_tensor_probe(void) { return 0; } +/* + * Validate the MSL 4.1 native e4m3 packer before letting the KV-cache kernels + * use it. The hardware metal_fp8_e4m3_format packer on macOS 27 beta (26A5353q) + * is bit-exact with DS4's software ladder on the normal range [2^-6, 448] but + * rounds subnormals non-monotonically and returns NaN above the max normal, so + * dsv4_e4m3fn_dequant uses a hybrid (software subnormals + clamped native + * normals). This probe compiles that hybrid alongside the software reference and + * accepts it only if the two agree bit-for-bit across a vector that spans deep + * subnormals, the full normal range, the saturation boundary, and out-of-range + * inputs. A future SDK that changes the packer's behavior in either direction is + * therefore caught at startup rather than silently perturbing the KV cache. + */ +static int ds4_gpu_validate_fp8_native(void) { +#if defined(__MAC_OS_X_VERSION_MAX_ALLOWED) && __MAC_OS_X_VERSION_MAX_ALLOWED >= 260000 + if (!g_device) return 0; + if (@available(macOS 26.0, *)) { + const char *src = + "#include \n" + "#include \n" + "using namespace metal;\n" + "constant float es[16]={0,0.015625f,0.03125f,0.0625f,0.125f,0.25f,0.5f,1.f,2.f,4.f,8.f,16.f,32.f,64.f,128.f,256.f};\n" + "static inline float ev(int i){int e=(i>>3)&0xf;int m=i&0x7;return e==0?float(m)*0.001953125f:(1.f+float(m)*0.125f)*es[e];}\n" + "static inline float sw(float x){float s=x<0?-1.f:1.f;float ax=min(abs(x),448.f);int lo=0,hi=126;\n" + " while(lo>1;if(ev(md)<=ax)lo=md;else hi=md-1;}\n" + " int b=lo;if(b<126){float bd=abs(ax-ev(b));float nd=abs(ax-ev(b+1));\n" + " if(nd v(clamp(x,-448.f,448.f),0.f,0.f,0.f);\n" + " return unpack(pack(v))[0];}\n" + "kernel void ds4_fp8_selftest(device const float* in [[buffer(0)]],\n" + " device atomic_uint* fail [[buffer(1)]], uint g [[thread_position_in_grid]]) {\n" + " float x=in[g]; if(sw(x)!=hyb(x)) atomic_fetch_add_explicit(fail,1u,memory_order_relaxed);\n" + "}\n"; + + NSError *error = nil; + id lib = [g_device newLibraryWithSource:[NSString stringWithUTF8String:src] + options:[MTLCompileOptions new] + error:&error]; + if (!lib) { + fprintf(stderr, "ds4: Metal 4.1 FP8 probe compile failed: %s\n", + error ? [[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + id fn = [lib newFunctionWithName:@"ds4_fp8_selftest"]; + id ps = fn ? [g_device newComputePipelineStateWithFunction:fn error:&error] : nil; + if (!ps) { + fprintf(stderr, "ds4: Metal 4.1 FP8 probe pipeline failed: %s\n", + error ? [[error localizedDescription] UTF8String] : "(unknown)"); + return 0; + } + + const int N = 1 << 16; + id bin = [g_device newBufferWithLength:(NSUInteger)N * sizeof(float) + options:MTLResourceStorageModeShared]; + id bfail = [g_device newBufferWithLength:sizeof(uint32_t) + options:MTLResourceStorageModeShared]; + if (!bin || !bfail) return 0; + float *in = (float *)[bin contents]; + for (int i = 0; i < N; i++) { + const double t = (double)i / (double)(N - 1); + const double s = (i & 1) ? -1.0 : 1.0; + in[i] = (float)(s * pow(2.0, t * 24.0 - 15.0)); /* ~3e-5 .. 512 (>448) */ + } + in[0] = 448.0f; in[1] = -448.0f; in[2] = 0.015625f; in[3] = -0.015625f; + in[4] = 0.0f; in[5] = 600.0f; in[6] = -600.0f; + *((uint32_t *)[bfail contents]) = 0; + + /* Self-contained: detection runs before the shared g_queue exists. */ + id queue = [g_device newCommandQueue]; + if (!queue) return 0; + id cb = [queue commandBuffer]; + id enc = [cb computeCommandEncoder]; + [enc setComputePipelineState:ps]; + [enc setBuffer:bin offset:0 atIndex:0]; + [enc setBuffer:bfail offset:0 atIndex:1]; + [enc dispatchThreads:MTLSizeMake((NSUInteger)N, 1, 1) + threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; + [enc endEncoding]; + [cb commit]; + [cb waitUntilCompleted]; + if (cb.status == MTLCommandBufferStatusError) { + fprintf(stderr, "ds4: Metal 4.1 FP8 probe dispatch failed\n"); + return 0; + } + const uint32_t mismatches = *((uint32_t *)[bfail contents]); + if (mismatches != 0) { + fprintf(stderr, + "ds4: Metal 4.1 native FP8 path disagrees with software in %u/%d samples; " + "keeping software e4m3 quantizer\n", mismatches, N); + return 0; + } + return 1; + } +#endif + return 0; +} + static void ds4_gpu_detect_metal4_features(void) { g_metal4_runtime_available = 0; g_metal4_family_supported = 0; @@ -1778,6 +1900,7 @@ static void ds4_gpu_detect_metal4_features(void) { g_metal4_m5_neural_accelerators_hint = 0; g_metal4_tensor_api_enabled = 0; g_metal4_tensor_api_compile_supported = 0; + g_metal4_fp8_native_enabled = 0; g_metal_device_name[0] = '\0'; if (!g_device) return; @@ -1830,6 +1953,27 @@ static void ds4_gpu_detect_metal4_features(void) { fprintf(stderr, "ds4: Metal 4 tensor API disabled for pre-M5/pre-A19 devices\n"); } } + + /* + * Native FP8 KV quantizer. Independent of the TensorOps/neural-accelerator + * path: it only needs the MSL 4.1 packed-numeric feature, which the + * self-test probe verifies bit-exact against the software ladder. Enabled + * by default on any Metal 4.1 GPU where the probe passes; + * DS4_METAL_FP8_NATIVE forces it on (1) or off (0). + */ + if (!metal4_disabled) { + const int fp8_env = ds4_gpu_env_bool("DS4_METAL_FP8_NATIVE"); + if (fp8_env == 0) { + fprintf(stderr, "ds4: Metal 4.1 native FP8 KV quantizer disabled by DS4_METAL_FP8_NATIVE=0\n"); + } else { + g_metal4_fp8_native_enabled = ds4_gpu_validate_fp8_native(); + if (g_metal4_fp8_native_enabled) { + fprintf(stderr, "ds4: Metal 4.1 native FP8 KV quantizer enabled (validated bit-exact)\n"); + } else if (fp8_env == 1) { + fprintf(stderr, "ds4: DS4_METAL_FP8_NATIVE=1 requested but probe failed; using software e4m3\n"); + } + } + } } #endif } @@ -2821,6 +2965,9 @@ static int ds4_gpu_model_map_log_enabled(void) { static const char *ds4_gpu_source = "#include \n" +"#ifdef DS4_METAL_FP8_NATIVE\n" +"#include \n" +"#endif\n" "#ifdef DS4_METAL_HAS_TENSOR\n" "#include \n" "#include \n" @@ -4019,6 +4166,8 @@ static int ds4_gpu_encode_mul_mm_id_mapped_tile( float m1; int32_t n_head_log2; float logit_softcap; + int32_t n_raw_split; /* FP8 mixed-KV: key rows >= this are read from kcomp/vcomp (packed) */ + uint64_t nb_kcomp; /* packed comp row stride (bytes); 0 for the plain f16 path */ } ds4_gpu_flash_attn_vec_args; typedef struct { @@ -4290,6 +4439,8 @@ static int ds4_gpu_encode_rope_tail_inplace( int64_t ne1; uint64_t nb0; uint64_t nb1; + uint32_t mask_f16; + uint32_t pad0; } ds4_gpu_dsv4_topk_mask_args; typedef struct { @@ -4446,6 +4597,9 @@ int ds4_gpu_init(void) { macros[@"DS4_METAL_HAS_TENSOR"] = @"1"; fprintf(stderr, "ds4: Metal 4 tensor API enabled for Tensor kernels\n"); } + if (g_metal4_fp8_native_enabled) { + macros[@"DS4_METAL_FP8_NATIVE"] = @"1"; + } const int drift_hc_stable = ds4_gpu_env_bool("DS4_METAL_HC_STABLE") != 0; // default ON const int drift_norm_unify = ds4_gpu_env_bool("DS4_METAL_NORM_RSQRT_DISABLE") != 0; // default ON @@ -4650,6 +4804,22 @@ int ds4_gpu_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_cpy_f16_f16"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_cpy_f16_f16 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_cpy_f16_f16_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_cpy_f16_f16_pipeline) { + fprintf(stderr, "ds4: Metal kernel_cpy_f16_f16 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_dsv4_fp8_kv_quantize_f32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_dsv4_fp8_kv_quantize_f32 function not found\n"); @@ -4714,6 +4884,54 @@ int ds4_gpu_init(void) { return 0; } + fn = [library newFunctionWithName:@"kernel_dsv4_kv_pack_fp8_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_pack_fp8_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_kv_pack_fp8_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_kv_pack_fp8_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_pack_fp8_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_dsv4_kv_pack_fp8_row_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_pack_fp8_row_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_kv_pack_fp8_row_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_kv_pack_fp8_row_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_pack_fp8_row_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_dsv4_kv_unpack_fp8_row_f16"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_unpack_fp8_row_f16 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_kv_unpack_fp8_row_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_kv_unpack_fp8_row_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_kv_unpack_fp8_row_f16 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_swiglu_f32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_swiglu_f32 function not found\n"); @@ -6027,6 +6245,12 @@ int ds4_gpu_init(void) { g_initialized = 1; } + if (getenv("DS4_METAL_FP8_KV_STORE") != NULL) { + if (!ds4_gpu_validate_fp8_kv_pack()) { + fprintf(stderr, "ds4: packed FP8 KV writer self-test failed; not using packed KV layout\n"); + } + } + return 1; } @@ -6483,6 +6707,14 @@ int ds4_gpu_synchronize(void) { void ds4_gpu_cleanup(void) { if (!g_initialized) return; + if (g_prof_cb_enabled > 0) { + fprintf(stderr, + "ds4: cb-gpu-profile: GPU-busy %.3f s across %llu command buffers " + "(avg %.3f ms/cb)\n", + g_prof_gpu_busy_s, (unsigned long long)g_prof_cb_count, + g_prof_cb_count ? 1000.0 * g_prof_gpu_busy_s / (double)g_prof_cb_count : 0.0); + } + @autoreleasepool { if (g_batch_cb) { ds4_gpu_close_batch_encoder(); @@ -6519,6 +6751,7 @@ void ds4_gpu_cleanup(void) { g_cpy_f32_f32_pipeline = nil; g_cpy_f32_f16_pipeline = nil; g_cpy_f16_f32_pipeline = nil; + g_cpy_f16_f16_pipeline = nil; g_swiglu_pipeline = nil; g_add_pipeline = nil; g_moe_sum6_pipeline = nil; @@ -6579,6 +6812,9 @@ void ds4_gpu_cleanup(void) { g_dsv4_fp8_kv_quantize_pipeline = nil; g_dsv4_indexer_qat_pipeline = nil; g_dsv4_kv_fp8_store_pipeline = nil; + g_dsv4_kv_pack_fp8_pipeline = nil; + g_dsv4_kv_pack_fp8_row_pipeline = nil; + g_dsv4_kv_unpack_fp8_row_pipeline = nil; g_dsv4_ratio4_shift_pipeline = nil; g_dsv4_softmax_pool_pipeline = nil; g_soft_max_f32_pipeline = nil; @@ -6604,6 +6840,8 @@ void ds4_gpu_cleanup(void) { g_flash_attn_blk_buffer = nil; g_flash_attn_ring_buffer = nil; g_flash_attn_kv_buffer = nil; + g_flash_attn_kv_packed_buffer = nil; + g_flash_attn_kv_packed_bytes = 0; g_compressor_pool_kv_buffer = nil; g_compressor_pool_score_buffer = nil; g_compressor_pool_score_cont_buffer = nil; @@ -12301,6 +12539,97 @@ int ds4_gpu_indexer_scores_prefill_tensor( scale); } +/* Token-tiled prefill score + top-k. Processes n_tokens in tiles of score_tile so + * the scores buffer only needs comp_cap*score_tile (not comp_cap*prefill_cap). For + * each tile it scores with pos0 = t0 (token i is at global position i, identical to + * the untiled pos0=0 full pass) and writes the tile's comp_selected rows via + * token-offset views. Top-k is independent per token, so this is exact (zero drift). */ +int ds4_gpu_indexer_prefill_score_topk_tiled( + ds4_gpu_tensor *scores, + ds4_gpu_tensor *selected, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + uint32_t top_k, + uint32_t score_tile) { + if (!scores || !selected || !q || !weights || !index_comp || + n_comp == 0 || n_tokens == 0 || top_k == 0 || score_tile == 0) { + return 0; + } + for (uint32_t t0 = 0; t0 < n_tokens; t0 += score_tile) { + const uint32_t tc = (n_tokens - t0) < score_tile ? (n_tokens - t0) : score_tile; + ds4_gpu_tensor *q_v = ds4_gpu_tensor_view(q, + (uint64_t)t0 * n_head * head_dim * sizeof(float), + (uint64_t)tc * n_head * head_dim * sizeof(float)); + ds4_gpu_tensor *w_v = ds4_gpu_tensor_view(weights, + (uint64_t)t0 * n_head * sizeof(float), + (uint64_t)tc * n_head * sizeof(float)); + ds4_gpu_tensor *s_v = ds4_gpu_tensor_view(selected, + (uint64_t)t0 * top_k * sizeof(uint32_t), + (uint64_t)tc * top_k * sizeof(uint32_t)); + int ok = q_v && w_v && s_v && + ds4_gpu_indexer_scores_batch_tensor(scores, q_v, w_v, index_comp, + n_comp, tc, t0, n_head, head_dim, ratio, scale) != 0 && + ds4_gpu_indexer_topk_tensor(s_v, scores, n_comp, tc, top_k) != 0; + ds4_gpu_tensor_free(s_v); + ds4_gpu_tensor_free(w_v); + ds4_gpu_tensor_free(q_v); + if (!ok) return 0; + } + return 1; +} + +/* Token-tiled decode-batch score + top-k (same scratch reuse as the prefill + * variant). Used for non-zero-prefix prefill chunks where n_tokens can be large. + * Each tile scores with pos0_base + t0 (preserving per-token global positions). */ +int ds4_gpu_indexer_decode_batch_score_topk_tiled( + ds4_gpu_tensor *scores, + ds4_gpu_tensor *selected, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *weights, + const ds4_gpu_tensor *index_comp, + uint32_t n_comp, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_head, + uint32_t head_dim, + uint32_t ratio, + float scale, + uint32_t top_k, + uint32_t score_tile) { + if (!scores || !selected || !q || !weights || !index_comp || + n_comp == 0 || n_tokens == 0 || top_k == 0 || score_tile == 0) { + return 0; + } + for (uint32_t t0 = 0; t0 < n_tokens; t0 += score_tile) { + const uint32_t tc = (n_tokens - t0) < score_tile ? (n_tokens - t0) : score_tile; + ds4_gpu_tensor *q_v = ds4_gpu_tensor_view(q, + (uint64_t)t0 * n_head * head_dim * sizeof(float), + (uint64_t)tc * n_head * head_dim * sizeof(float)); + ds4_gpu_tensor *w_v = ds4_gpu_tensor_view(weights, + (uint64_t)t0 * n_head * sizeof(float), + (uint64_t)tc * n_head * sizeof(float)); + ds4_gpu_tensor *s_v = ds4_gpu_tensor_view(selected, + (uint64_t)t0 * top_k * sizeof(uint32_t), + (uint64_t)tc * top_k * sizeof(uint32_t)); + int ok = q_v && w_v && s_v && + ds4_gpu_indexer_scores_decode_batch_tensor(scores, q_v, w_v, index_comp, + n_comp, tc, pos0 + t0, n_head, head_dim, ratio, scale) != 0 && + ds4_gpu_indexer_topk_tensor(s_v, scores, n_comp, tc, top_k) != 0; + ds4_gpu_tensor_free(s_v); + ds4_gpu_tensor_free(w_v); + ds4_gpu_tensor_free(q_v); + if (!ok) return 0; + } + return 1; +} + int ds4_gpu_indexer_scores_decode_batch_tensor( ds4_gpu_tensor *scores, const ds4_gpu_tensor *q, @@ -12469,6 +12798,13 @@ int ds4_gpu_argmax_tensor( return ds4_gpu_indexer_topk_tensor(out_idx, logits, n_vocab, 1, 1); } +/* comp_mask is stored f16 (1.75x memory opt) unless the KV-cache optimizations are + * disabled for A/B comparison via DS4_DISABLE_KV_OPTS, in which case it is f32 (the + * pre-optimization layout). Same env also disables packed comp cache + score tiling. */ +int ds4_gpu_comp_mask_f16(void) { + return getenv("DS4_DISABLE_KV_OPTS") == NULL; +} + int ds4_gpu_dsv4_topk_mask_tensor( ds4_gpu_tensor *mask, const ds4_gpu_tensor *topk, @@ -12479,8 +12815,10 @@ int ds4_gpu_dsv4_topk_mask_tensor( if (!mask || !topk || n_comp == 0 || n_tokens == 0 || top_k == 0) return 0; @autoreleasepool { + const uint32_t mask_f16 = ds4_gpu_comp_mask_f16() ? 1u : 0u; + const uint64_t mask_elem = mask_f16 ? sizeof(uint16_t) : sizeof(float); const uint64_t topk_bytes = (uint64_t)top_k * n_tokens * sizeof(int32_t); - const uint64_t mask_bytes = (uint64_t)n_comp * n_tokens * sizeof(float); + const uint64_t mask_bytes = (uint64_t)n_comp * n_tokens * mask_elem; id topkbuf = ds4_gpu_tensor_buffer(topk); id maskbuf = ds4_gpu_tensor_buffer(mask); if (!topkbuf || !maskbuf || @@ -12490,6 +12828,7 @@ int ds4_gpu_dsv4_topk_mask_tensor( return 0; } + /* comp_mask: binary -inf/0; f16 (default) or f32 (opts disabled). */ ds4_gpu_dsv4_topk_mask_args args = { .ne00 = (int64_t)top_k, .ne01 = (int64_t)n_tokens, @@ -12497,8 +12836,9 @@ int ds4_gpu_dsv4_topk_mask_tensor( .nb01 = (uint64_t)top_k * sizeof(int32_t), .ne0 = (int64_t)n_comp, .ne1 = (int64_t)n_tokens, - .nb0 = sizeof(float), - .nb1 = (uint64_t)n_comp * sizeof(float), + .nb0 = mask_elem, + .nb1 = (uint64_t)n_comp * mask_elem, + .mask_f16 = mask_f16, }; int owned = 0; @@ -13678,6 +14018,169 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( return 1; } +/* Pack n_rows of f32 comp KV from `src` (offset 0) into the persistent packed + * cache `dst` at byte offset dst_byte_offset, in the ds4_fp8_kv_row 3-plane FP8 + * layout the attention readers consume. One dispatch of the validated row-pack + * kernel; runs inside the active (batched) command buffer. */ +int ds4_gpu_dsv4_kv_pack_comp_rows( + ds4_gpu_tensor *dst, + uint64_t dst_byte_offset, + ds4_gpu_tensor *src, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!dst || !src || n_rows == 0 || head_dim == 0 || n_rot >= head_dim) return 0; + if (!g_dsv4_kv_pack_fp8_row_pipeline) return 0; + + @autoreleasepool { + id dbuf = ds4_gpu_tensor_buffer(dst); + id sbuf = ds4_gpu_tensor_buffer(src); + if (!dbuf || !sbuf) return 0; + + ds4_gpu_dsv4_fp8_kv_quantize_args args = { + .ne00 = head_dim, .ne01 = n_rows, .ne02 = 1, .ne03 = 1, + .nb00 = sizeof(float), .nb01 = (uint64_t)head_dim * sizeof(float), + .nb02 = (uint64_t)n_rows * head_dim * sizeof(float), + .nb03 = (uint64_t)n_rows * head_dim * sizeof(float), + .nb0 = sizeof(float), .nb1 = (uint64_t)head_dim * sizeof(float), + .nb2 = (uint64_t)n_rows * head_dim * sizeof(float), + .nb3 = (uint64_t)n_rows * head_dim * sizeof(float), + .n_rot = (int32_t)n_rot, + }; + + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_kv_pack_fp8_row_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:sbuf offset:ds4_gpu_tensor_offset(src) atIndex:1]; + [enc setBuffer:dbuf offset:ds4_gpu_tensor_offset(dst) + dst_byte_offset atIndex:2]; + [enc setThreadgroupMemoryLength:64u * sizeof(float) atIndex:0]; + [enc dispatchThreadgroups:MTLSizeMake(n_rows, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + if (!ds4_gpu_finish_command_buffer(cb, owned, "DSV4 KV pack comp rows")) return 0; + } + + return 1; +} + +/* + * Diagnostic self-check for the packed FP8 KV writer. Packs the (already + * e4m3-quantized, float-stored) compressed-KV tensor into the three-plane + * layout (nope e4m3 bytes + ue8m0 scales + rot float), reads it back, and + * verifies the reconstruction is bit-identical to the float tensor. This + * validates the host buffer plumbing + dispatch on real KV data before the + * FlashAttention read path is rewired. Enabled by DS4_METAL_FP8_KV_STORE; it + * does not change the live KV representation. Returns 1 on success (or when + * disabled), 0 on a verified mismatch. + */ +static float ds4_host_e4m3fn_value(int i) { + static const float es[16] = {0.0f, 0.015625f, 0.03125f, 0.0625f, 0.125f, 0.25f, 0.5f, 1.0f, + 2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f}; + const int e = (i >> 3) & 0xf; + const int m = i & 0x7; + return e == 0 ? (float)m * 0.001953125f : (1.0f + (float)m * 0.125f) * es[e]; +} + +static int ds4_gpu_validate_fp8_kv_pack(void) { + if (!g_dsv4_kv_pack_fp8_pipeline || !g_dsv4_fp8_kv_quantize_pipeline || !g_device || !g_queue) { + return 1; + } + const uint32_t head_dim = 512, n_rot = 64, nope = head_dim - n_rot, nblk = nope / 64, n_tok = 64; + + @autoreleasepool { + const NSUInteger in_bytes = (NSUInteger)n_tok * head_dim * sizeof(float); + id bin = [g_device newBufferWithLength:in_bytes options:MTLResourceStorageModeShared]; + id bref = [g_device newBufferWithLength:in_bytes options:MTLResourceStorageModeShared]; + id bnope = [g_device newBufferWithLength:(NSUInteger)n_tok * nope options:MTLResourceStorageModeShared]; + id bscl = [g_device newBufferWithLength:(NSUInteger)n_tok * nblk options:MTLResourceStorageModeShared]; + id brot = [g_device newBufferWithLength:(NSUInteger)n_tok * n_rot * sizeof(float) options:MTLResourceStorageModeShared]; + if (!bin || !bref || !bnope || !bscl || !brot) return 1; + + float *in = (float *)[bin contents]; + for (uint32_t i = 0; i < n_tok * head_dim; i++) { + const double t = (double)(i % 997) / 997.0; + const double s = (i & 1) ? -1.0 : 1.0; + in[i] = (float)(s * ldexp(1.0, (int)(t * 30.0) - 15)); /* wide range incl subnormal/large */ + } + memcpy([bref contents], in, in_bytes); /* bref quantized in place = float reference */ + + ds4_gpu_dsv4_fp8_kv_quantize_args qa = { + .ne00 = (int32_t)head_dim, .ne01 = (int32_t)n_tok, .ne02 = 1, .ne03 = 1, + .nb00 = sizeof(float), .nb01 = (uint64_t)head_dim * sizeof(float), + .nb02 = (uint64_t)n_tok * head_dim * sizeof(float), + .nb03 = (uint64_t)n_tok * head_dim * sizeof(float), + .nb0 = sizeof(float), .nb1 = (uint64_t)head_dim * sizeof(float), + .nb2 = (uint64_t)n_tok * head_dim * sizeof(float), + .nb3 = (uint64_t)n_tok * head_dim * sizeof(float), + .n_rot = (int32_t)n_rot, + }; + + /* Dedicated, committed command buffer — independent of the batched path. */ + id cb = [g_queue commandBuffer]; + { + id e = [cb computeCommandEncoder]; + [e setComputePipelineState:g_dsv4_fp8_kv_quantize_pipeline]; + [e setBytes:&qa length:sizeof(qa) atIndex:0]; + [e setBuffer:bref offset:0 atIndex:1]; + [e setBuffer:bref offset:0 atIndex:2]; + [e setThreadgroupMemoryLength:64u * sizeof(float) atIndex:0]; + [e dispatchThreadgroups:MTLSizeMake(n_tok, 1, 1) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + [e endEncoding]; + } + { + id e = [cb computeCommandEncoder]; + [e setComputePipelineState:g_dsv4_kv_pack_fp8_pipeline]; + [e setBytes:&qa length:sizeof(qa) atIndex:0]; + [e setBuffer:bin offset:0 atIndex:1]; + [e setBuffer:bnope offset:0 atIndex:2]; + [e setBuffer:bscl offset:0 atIndex:3]; + [e setBuffer:brot offset:0 atIndex:4]; + [e setThreadgroupMemoryLength:64u * sizeof(float) atIndex:0]; + [e dispatchThreadgroups:MTLSizeMake(n_tok, 1, 1) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + [e endEncoding]; + } + [cb commit]; + [cb waitUntilCompleted]; + if (cb.status == MTLCommandBufferStatusError) return 1; + + const float *ref = (const float *)[bref contents]; + const uint8_t *nb = (const uint8_t *)[bnope contents]; + const uint8_t *sb = (const uint8_t *)[bscl contents]; + const float *rt = (const float *)[brot contents]; + uint64_t mismatches = 0; + for (uint32_t r = 0; r < n_tok; r++) { + const float *rrow = ref + (uint64_t)r * head_dim; + for (uint32_t i = 0; i < nope; i++) { + const uint8_t by = nb[(uint64_t)r * nope + i]; + const uint8_t ue = sb[(uint64_t)r * nblk + i / 64]; + const float val = ((by & 0x80) ? -1.0f : 1.0f) * + ds4_host_e4m3fn_value(by & 0x7f) * ldexpf(1.0f, (int)ue - 127); + if (val != rrow[i]) mismatches++; + } + for (uint32_t i = 0; i < n_rot; i++) { + if (rt[(uint64_t)r * n_rot + i] != rrow[nope + i]) mismatches++; + } + } + if (mismatches != 0) { + fprintf(stderr, + "ds4: FP8-KV-pack self-test FAILED: %llu/%u elements differ from float-quantized KV\n", + (unsigned long long)mismatches, n_tok * head_dim); + return 0; + } + fprintf(stderr, + "ds4: FP8-KV-pack self-test OK (packed nope+scale+rot reconstructs float-quantized KV " + "bit-exactly; %.2fx smaller)\n", + (double)(head_dim * sizeof(float)) / (double)(nope + nblk + n_rot * sizeof(float))); + return 1; + } +} + int ds4_gpu_dsv4_indexer_qat_tensor( ds4_gpu_tensor *x, uint32_t n_rows, @@ -16518,6 +17021,54 @@ static int ds4_gpu_encode_cpy_f32_f16_2d( return 1; } +/* Strided 2D f16->f16 copy (same layout placement as the f32->f16 variant, used + * for the now-f16 comp_mask -> flash mask staging). */ +static int ds4_gpu_encode_cpy_f16_f16_2d( + id cb, + id src, + NSUInteger src_off, + id dst, + NSUInteger dst_off, + uint32_t cols, + uint32_t rows, + uint64_t src_row_stride, + uint64_t dst_row_stride) { + if (!cb || !src || !dst || cols == 0 || rows == 0) return 0; + + ds4_gpu_cpy_args args = { + .nk0 = (int64_t)cols, + .ne00 = (int64_t)cols, + .ne01 = (int64_t)rows, + .ne02 = 1, + .ne03 = 1, + .nb00 = sizeof(uint16_t), + .nb01 = src_row_stride, + .nb02 = (uint64_t)rows * src_row_stride, + .nb03 = (uint64_t)rows * src_row_stride, + .ne0 = (int64_t)cols, + .ne1 = (int64_t)rows, + .ne2 = 1, + .ne3 = 1, + .nb0 = sizeof(uint16_t), + .nb1 = dst_row_stride, + .nb2 = (uint64_t)rows * dst_row_stride, + .nb3 = (uint64_t)rows * dst_row_stride, + }; + const NSUInteger nth = ds4_gpu_cpy_threads(cols, g_cpy_f16_f16_pipeline); + const NSUInteger col_groups = ((NSUInteger)cols + nth - 1u) / nth; + + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_cpy_f16_f16_pipeline]; + [enc setBytes:&args length:sizeof(args) atIndex:0]; + [enc setBuffer:src offset:src_off atIndex:1]; + [enc setBuffer:dst offset:dst_off atIndex:2]; + [enc dispatchThreadgroups:MTLSizeMake(col_groups * rows, 1, 1) + threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + + return 1; +} + static int ds4_gpu_encode_cpy_f16_f32_1d( id cb, id src, @@ -16548,13 +17099,32 @@ static int ds4_gpu_encode_copy_to_f16_1d( id cb, id src, NSUInteger src_off, - bool src_is_f16, + int src_format, /* 0 = f32, 1 = f16, 2 = packed FP8 */ id dst, NSUInteger dst_off, uint32_t n) { if (!cb || !src || !dst) return 0; if (n == 0) return 1; - if (!src_is_f16) { + if (src_format == 2) { + /* Packed FP8 comp rows (ds4_fp8_kv_row, 512/64) -> contiguous f16. Used by + * the cold flash staging path; the hot indexed path reads packed directly. */ + const uint32_t head_dim = 512u, n_rot = 64u; + const uint32_t rows = n / head_dim; + if (rows == 0) return 1; + ds4_gpu_dsv4_fp8_kv_quantize_args a = { + .ne00 = head_dim, .ne01 = rows, .ne02 = 1, .ne03 = 1, .n_rot = (int32_t)n_rot, + }; + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_kv_unpack_fp8_row_pipeline]; + [enc setBytes:&a length:sizeof(a) atIndex:0]; + [enc setBuffer:src offset:src_off atIndex:1]; + [enc setBuffer:dst offset:dst_off atIndex:2]; + [enc dispatchThreadgroups:MTLSizeMake(rows, 1, 1) + threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + return 1; + } + if (src_format == 0) { return ds4_gpu_encode_cpy_f32_f16_1d(cb, src, src_off, dst, dst_off, n); } @@ -16943,9 +17513,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long id headsbuf = ds4_gpu_tensor_buffer(heads); const uint64_t q_bytes = (uint64_t)n_tokens * n_head * head_dim * sizeof(float); const uint64_t raw_bytes = (uint64_t)n_tokens * head_dim * sizeof(float); - const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * - (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); - const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * sizeof(float) : 0u; + const uint64_t comp_bytes = (comp_kv_f16 == 2u) + ? (uint64_t)n_comp * (8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t)) + : (uint64_t)n_comp * head_dim * (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); + const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)) : 0u; if (!qbuf || !rawbuf || !compbuf || !maskbuf || !headsbuf || !sinks_buf || ds4_gpu_tensor_bytes(q) < q_bytes || ds4_gpu_tensor_bytes(raw_kv) < raw_bytes || @@ -17026,7 +17597,7 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long !ds4_gpu_encode_copy_to_f16_1d(cb, compbuf, ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, + (int)comp_kv_f16, g_flash_attn_kv_buffer, (NSUInteger)n_tokens * row_bytes_f16, n_comp * head_dim)) { @@ -17043,14 +17614,14 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_nonvec_long ratio); DS4_METAL_PROFILE_FLASH_ATTN_STAGE("mask_fill"); if (use_comp_mask && n_comp != 0) { - if (!ds4_gpu_encode_cpy_f32_f16_2d(cb, + if (!(ds4_gpu_comp_mask_f16() ? ds4_gpu_encode_cpy_f16_f16_2d : ds4_gpu_encode_cpy_f32_f16_2d)(cb, maskbuf, ds4_gpu_tensor_offset(comp_mask), mask_buffer, (NSUInteger)n_tokens * sizeof(uint16_t), n_comp, n_tokens, - (uint64_t)n_comp * sizeof(float), + (uint64_t)n_comp * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)), (uint64_t)n_keys * sizeof(uint16_t))) { return 0; } @@ -17218,9 +17789,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( id headsbuf = ds4_gpu_tensor_buffer(heads); const uint64_t q_bytes = (uint64_t)n_tokens * n_head * head_dim * sizeof(float); const uint64_t raw_bytes = (uint64_t)n_tokens * head_dim * sizeof(float); - const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * - (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); - const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * sizeof(float) : 0u; + const uint64_t comp_bytes = (comp_kv_f16 == 2u) + ? (uint64_t)n_comp * (8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t)) + : (uint64_t)n_comp * head_dim * (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); + const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)) : 0u; if (!qbuf || !rawbuf || !compbuf || !maskbuf || !headsbuf || !sinks_buf || ds4_gpu_tensor_bytes(q) < q_bytes || ds4_gpu_tensor_bytes(raw_kv) < raw_bytes || @@ -17246,6 +17818,15 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( const NSUInteger tmp_bytes = nrows * (NSUInteger)head_dim * (NSUInteger)nwg * sizeof(float) + nrows * (2u * (NSUInteger)nwg) * sizeof(float); + // FP8 mixed-KV read (DS4_METAL_FP8_KV_STORE): read comp rows from a packed FP8 + // buffer in the flash kernel instead of the F16 staging. Restricted to the + // float-comp, no-kvpad case so the pad buffer (built from F16 staging) is not + // needed; raw rows still come from the F16 staging. Bit-exact vs F16 (proven). + const NSUInteger fp8_pack_stride = + 8u + ((NSUInteger)head_dim - 64u) + 64u * sizeof(uint16_t); /* scale8 + nope448 + rot64*f16 = 584 */ + const bool use_fp8 = (getenv("DS4_METAL_FP8_KV_STORE") != NULL) && + g_dsv4_kv_pack_fp8_row_pipeline && n_comp && !has_kvpad && !comp_kv_f16; + id mask_buffer = ds4_gpu_new_transient_buffer(mask_bytes, "ds4_flash_attn_mask"); if (!mask_buffer || @@ -17277,6 +17858,13 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( *cbp = cb; flash_stage_t0 = ds4_gpu_now_ms(); } + if (use_fp8 && + !ds4_gpu_ensure_scratch_buffer(&g_flash_attn_kv_packed_buffer, + &g_flash_attn_kv_packed_bytes, + (NSUInteger)n_comp * fp8_pack_stride, + "ds4_flash_attn_kv_packed")) { + return 0; + } #define DS4_METAL_PROFILE_FLASH_ATTN_STAGE(name) do { \ if (flash_stage_profile) { \ if (!ds4_gpu_flash_attn_stage_profile_boundary(cbp, \ @@ -17297,11 +17885,28 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( return 0; } DS4_METAL_PROFILE_FLASH_ATTN_STAGE("copy_raw"); - if (n_comp) { + if (n_comp && use_fp8) { + // Pack the float comp KV directly into the 3-plane row-interleaved FP8 + // layout the fp8mix flash kernel reads; skips the comp->F16 staging copy. + ds4_gpu_dsv4_fp8_kv_quantize_args pack_args = { + .ne00 = (int64_t)head_dim, .ne01 = (int64_t)n_comp, .ne02 = 1, .ne03 = 1, + .n_rot = 64, + }; + id penc = ds4_gpu_compute_encoder(cb); + [penc setComputePipelineState:g_dsv4_kv_pack_fp8_row_pipeline]; + [penc setBytes:&pack_args length:sizeof(pack_args) atIndex:0]; + [penc setBuffer:compbuf offset:ds4_gpu_tensor_offset(comp_kv) atIndex:1]; + [penc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:2]; + [penc setThreadgroupMemoryLength:64u * sizeof(float) atIndex:0]; + [penc dispatchThreadgroups:MTLSizeMake(n_comp, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, penc); + DS4_METAL_PROFILE_FLASH_ATTN_STAGE("pack_comp"); + } else if (n_comp) { if (!ds4_gpu_encode_copy_to_f16_1d(cb, compbuf, ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, + (int)comp_kv_f16, g_flash_attn_kv_buffer, (NSUInteger)n_tokens * row_bytes_f16, n_comp * head_dim)) { @@ -17317,14 +17922,14 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( ratio); DS4_METAL_PROFILE_FLASH_ATTN_STAGE("mask_fill"); if (use_comp_mask && n_comp != 0) { - if (!ds4_gpu_encode_cpy_f32_f16_2d(cb, + if (!(ds4_gpu_comp_mask_f16() ? ds4_gpu_encode_cpy_f16_f16_2d : ds4_gpu_encode_cpy_f32_f16_2d)(cb, maskbuf, ds4_gpu_tensor_offset(comp_mask), mask_buffer, (NSUInteger)n_tokens * sizeof(uint16_t), n_comp, n_tokens, - (uint64_t)n_comp * sizeof(float), + (uint64_t)n_comp * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)), (uint64_t)n_keys * sizeof(uint16_t))) { return 0; } @@ -17338,7 +17943,8 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( if (!pad_pipeline) return 0; } id vec_pipeline = - ds4_gpu_get_flash_attn_vec_pipeline("kernel_flash_attn_ext_vec_f16_dk512_dv512", + ds4_gpu_get_flash_attn_vec_pipeline(use_fp8 ? "kernel_flash_attn_ext_vec_fp8mix_dk512_dv512" + : "kernel_flash_attn_ext_vec_f16_dk512_dv512", true, true, false, false, has_kvpad, (int32_t)head_dim, (int32_t)head_dim, @@ -17413,6 +18019,8 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( .m1 = 0.0f, .n_head_log2 = 0, .logit_softcap = 0.0f, + .n_raw_split = use_fp8 ? (int32_t)n_tokens : 0, + .nb_kcomp = use_fp8 ? (uint64_t)fp8_pack_stride : 0, }; const NSUInteger shared_elems = (ds4_gpu_align_up_ns(head_dim, 128u) + @@ -17430,6 +18038,10 @@ static int ds4_gpu_encode_flash_attention_prefill_static_mixed_heads_vec( [enc setBuffer:sinks_buf offset:sinks_offset atIndex:5]; [enc setBuffer:g_flash_attn_pad_buffer offset:0 atIndex:6]; [enc setBuffer:g_flash_attn_tmp_buffer offset:0 atIndex:7]; + if (use_fp8) { + [enc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:8]; + [enc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:9]; + } [enc setThreadgroupMemoryLength:shared_bytes atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake(n_tokens, n_head, nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; @@ -17998,9 +18610,10 @@ static int ds4_gpu_encode_flash_attention_gathered_heads( id maskbuf = use_mask ? ds4_gpu_tensor_buffer(comp_mask) : nil; const uint64_t q_bytes = (uint64_t)n_head * head_dim * sizeof(float); const uint64_t raw_bytes = (uint64_t)raw_cap * head_dim * sizeof(float); - const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * - (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); - const uint64_t comp_mask_bytes = use_mask ? (uint64_t)n_comp * sizeof(float) : 0u; + const uint64_t comp_bytes = (comp_kv_f16 == 2u) + ? (uint64_t)n_comp * (8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t)) + : (uint64_t)n_comp * head_dim * (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); + const uint64_t comp_mask_bytes = use_mask ? (uint64_t)n_comp * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)) : 0u; if (!qbuf || !rawbuf || !headsbuf || !sinks_buf || (n_comp && !compbuf) || (use_mask && !maskbuf) || @@ -18107,7 +18720,7 @@ static int ds4_gpu_encode_flash_attention_gathered_heads( if (!ds4_gpu_encode_copy_to_f16_1d(cb, compbuf, ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, + (int)comp_kv_f16, g_flash_attn_kv_buffer, (NSUInteger)n_raw * row_bytes_f16, n_comp * head_dim)) { @@ -18520,9 +19133,10 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( id headsbuf = ds4_gpu_tensor_buffer(heads); const uint64_t q_bytes = (uint64_t)n_tokens * n_head * head_dim * sizeof(float); const uint64_t raw_bytes = (uint64_t)raw_cap * head_dim * sizeof(float); - const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * - (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); - const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * sizeof(float) : 0u; + const uint64_t comp_bytes = (comp_kv_f16 == 2u) + ? (uint64_t)n_comp * (8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t)) + : (uint64_t)n_comp * head_dim * (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); + const uint64_t comp_mask_bytes = use_comp_mask ? (uint64_t)n_comp * n_tokens * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)) : 0u; if (!qbuf || !rawbuf || !compbuf || !maskbuf || !headsbuf || !sinks_buf || ds4_gpu_tensor_bytes(q) < q_bytes || ds4_gpu_tensor_bytes(raw_kv) < raw_bytes || @@ -18549,6 +19163,14 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( const NSUInteger nblk1 = ((NSUInteger)n_tokens + nqptg - 1u) / nqptg; const NSUInteger blk_bytes = ds4_gpu_align_up_ns(nblk0 * nblk1, 32u); + // FP8 mixed-KV read (DS4_METAL_FP8_KV_STORE): comp rows from packed FP8 instead + // of the F16 staging in the decode flash kernel. Restricted to float-comp + + // no-kvpad (pad buffer is built from F16 staging). Bit-exact vs F16 (proven). + const NSUInteger fp8_pack_stride = + 8u + ((NSUInteger)head_dim - 64u) + 64u * sizeof(uint16_t); /* 584 */ + const bool use_fp8 = (getenv("DS4_METAL_FP8_KV_STORE") != NULL) && + g_dsv4_kv_pack_fp8_row_pipeline && n_comp && !has_kvpad && !comp_kv_f16; + id mask_buffer = ds4_gpu_new_transient_buffer(mask_bytes, "ds4_flash_attn_mask"); if (!mask_buffer || @@ -18566,6 +19188,13 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( "ds4_flash_attn_blk")) { return 0; } + if (use_fp8 && + !ds4_gpu_ensure_scratch_buffer(&g_flash_attn_kv_packed_buffer, + &g_flash_attn_kv_packed_bytes, + (NSUInteger)n_comp * fp8_pack_stride, + "ds4_flash_attn_kv_packed")) { + return 0; + } id kvbuf = rawbuf; NSUInteger kvoff = ds4_gpu_tensor_offset(raw_kv); @@ -18605,11 +19234,29 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( kvoff, g_flash_attn_kv_buffer, 0, - n_raw * head_dim) || - !ds4_gpu_encode_copy_to_f16_1d(cb, + n_raw * head_dim)) { + return 0; + } + if (use_fp8) { + // Pack float comp KV directly into the 3-plane row layout the fp8mix flash + // kernel reads; skips the comp->F16 staging copy. Raw stays F16 at [0,n_raw). + ds4_gpu_dsv4_fp8_kv_quantize_args pack_args = { + .ne00 = (int64_t)head_dim, .ne01 = (int64_t)n_comp, .ne02 = 1, .ne03 = 1, + .n_rot = 64, + }; + id penc = ds4_gpu_compute_encoder(cb); + [penc setComputePipelineState:g_dsv4_kv_pack_fp8_row_pipeline]; + [penc setBytes:&pack_args length:sizeof(pack_args) atIndex:0]; + [penc setBuffer:compbuf offset:ds4_gpu_tensor_offset(comp_kv) atIndex:1]; + [penc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:2]; + [penc setThreadgroupMemoryLength:64u * sizeof(float) atIndex:0]; + [penc dispatchThreadgroups:MTLSizeMake(n_comp, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, penc); + } else if (!ds4_gpu_encode_copy_to_f16_1d(cb, compbuf, ds4_gpu_tensor_offset(comp_kv), - comp_kv_f16 != 0, + (int)comp_kv_f16, g_flash_attn_kv_buffer, (NSUInteger)n_raw * row_bytes_f16, n_comp * head_dim)) { @@ -18624,14 +19271,14 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( window, ratio); if (use_comp_mask) { - if (!ds4_gpu_encode_cpy_f32_f16_2d(cb, + if (!(ds4_gpu_comp_mask_f16() ? ds4_gpu_encode_cpy_f16_f16_2d : ds4_gpu_encode_cpy_f32_f16_2d)(cb, maskbuf, ds4_gpu_tensor_offset(comp_mask), mask_buffer, (NSUInteger)n_raw * sizeof(uint16_t), n_comp, n_tokens, - (uint64_t)n_comp * sizeof(float), + (uint64_t)n_comp * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)), (uint64_t)n_keys * sizeof(uint16_t))) { return 0; } @@ -18645,7 +19292,8 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( id blk_pipeline = ds4_gpu_get_flash_attn_blk_pipeline((int32_t)nqptg, (int32_t)ncpsg); id attn_pipeline = - ds4_gpu_get_flash_attn_pipeline("kernel_flash_attn_ext_f16_dk512_dv512", + ds4_gpu_get_flash_attn_pipeline(use_fp8 ? "kernel_flash_attn_ext_fp8mix_dk512_dv512" + : "kernel_flash_attn_ext_f16_dk512_dv512", true, true, false, false, has_kvpad, bc_mask, (int32_t)head_dim, (int32_t)head_dim, @@ -18736,6 +19384,8 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( .m1 = 0.0f, .n_head_log2 = 0, .logit_softcap = 0.0f, + .n_raw_split = use_fp8 ? (int32_t)n_raw : 0, + .nb_kcomp = use_fp8 ? (uint64_t)fp8_pack_stride : 0, }; const NSUInteger padded_v = ds4_gpu_align_up_ns(head_dim, 64u); @@ -18754,6 +19404,10 @@ static int ds4_gpu_encode_flash_attention_decode_mixed_batch_heads( [enc setBuffer:g_flash_attn_pad_buffer offset:0 atIndex:6]; [enc setBuffer:g_flash_attn_blk_buffer offset:0 atIndex:7]; [enc setBuffer:headsbuf offset:ds4_gpu_tensor_offset(heads) atIndex:8]; + if (use_fp8) { + [enc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:9]; + [enc setBuffer:g_flash_attn_kv_packed_buffer offset:0 atIndex:10]; + } [enc setThreadgroupMemoryLength:shared_bytes atIndex:0]; [enc dispatchThreadgroups:MTLSizeMake(nblk1, n_head, 1) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; @@ -18985,9 +19639,13 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( const uint64_t row_bytes = (uint64_t)head_dim * sizeof(float); const uint64_t row_bytes_f16 = (uint64_t)head_dim * sizeof(uint16_t); + /* Packed FP8 comp row stride (ds4_fp8_kv_row): scale8 + nope + rot f16. */ + const uint64_t row_bytes_packed = 8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t); + const uint64_t comp_row_bytes = (comp_kv_f16 == 2u) ? row_bytes_packed + : (comp_kv_f16 ? row_bytes_f16 : row_bytes); const uint64_t q_bytes = (uint64_t)n_tokens * n_head * row_bytes; const uint64_t raw_bytes = (uint64_t)raw_cap * row_bytes; - const uint64_t comp_bytes = (uint64_t)n_comp * (comp_kv_f16 ? row_bytes_f16 : row_bytes); + const uint64_t comp_bytes = (uint64_t)n_comp * comp_row_bytes; const uint64_t topk_bytes = (uint64_t)top_k * n_tokens * sizeof(int32_t); id qbuf = ds4_gpu_tensor_buffer(q); id rawbuf = ds4_gpu_tensor_buffer(raw_kv); @@ -19061,12 +19719,12 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( .pos0 = pos0, .window = window, .ratio = ratio, - .comp_kv_f16 = comp_kv_f16 ? 1u : 0u, + .comp_kv_f16 = comp_kv_f16, /* 0=f32, 1=f16, 2=packed FP8 */ .pad0 = 0, .q_token_stride = (uint64_t)n_head * row_bytes, .q_head_stride = row_bytes, .raw_row_stride = row_bytes, - .comp_row_stride = comp_kv_f16 ? row_bytes_f16 : row_bytes, + .comp_row_stride = comp_row_bytes, .topk_token_stride = (uint64_t)top_k * sizeof(int32_t), .dst_token_stride = (uint64_t)n_head * row_bytes, .dst_head_stride = row_bytes, @@ -19271,8 +19929,9 @@ int ds4_gpu_attention_decode_heads_tensor( @autoreleasepool { const uint64_t q_bytes = (uint64_t)n_head * head_dim * sizeof(float); const uint64_t raw_bytes = (uint64_t)raw_cap * head_dim * sizeof(float); - const uint64_t comp_bytes = (uint64_t)n_comp * head_dim * - (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); + const uint64_t comp_bytes = (comp_kv_f16 == 2u) + ? (uint64_t)n_comp * (8u + ((uint64_t)head_dim - 64u) + 64u * sizeof(uint16_t)) + : (uint64_t)n_comp * head_dim * (comp_kv_f16 ? sizeof(uint16_t) : sizeof(float)); const uint64_t sink_bytes = (uint64_t)n_head * sizeof(float); if (sinks_offset > model_size || sink_bytes > model_size - sinks_offset) { fprintf(stderr, "ds4: Metal graph attention heads sink range is outside the mapped model\n"); @@ -19284,7 +19943,7 @@ int ds4_gpu_attention_decode_heads_tensor( id compbuf = n_comp ? ds4_gpu_tensor_buffer(comp_kv) : rawbuf; id maskbuf = use_mask ? ds4_gpu_tensor_buffer(comp_mask) : rawbuf; id headsbuf = ds4_gpu_tensor_buffer(heads); - const uint64_t comp_mask_bytes = use_mask ? (uint64_t)n_comp * sizeof(float) : 0u; + const uint64_t comp_mask_bytes = use_mask ? (uint64_t)n_comp * (ds4_gpu_comp_mask_f16() ? sizeof(uint16_t) : sizeof(float)) : 0u; if (!qbuf || !rawbuf || !compbuf || !maskbuf || !headsbuf || ds4_gpu_tensor_bytes(q) < q_bytes || ds4_gpu_tensor_bytes(raw_kv) < raw_bytes || diff --git a/metal/cpy.metal b/metal/cpy.metal index 3aa00ac19..944aaaac7 100644 --- a/metal/cpy.metal +++ b/metal/cpy.metal @@ -55,3 +55,4 @@ typedef decltype(kernel_cpy_t_t) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; diff --git a/metal/dsv4_kv.metal b/metal/dsv4_kv.metal index f91bdbf46..9244191c1 100644 --- a/metal/dsv4_kv.metal +++ b/metal/dsv4_kv.metal @@ -56,7 +56,37 @@ static inline float dsv4_e4m3fn_value(int i) { : (1.0f + float(mant) * 0.125f) * dsv4_e4m3fn_exp_scale[exp]; } -static inline float dsv4_e4m3fn_dequant(float x) { +// Round a non-negative magnitude to the nearest e4m3fn code (0..126), ties to +// even code. Shared by the value-returning dequant and the byte-packing path so +// the packed e4m3 byte unpacks to exactly the value the float path would store. +static inline int dsv4_e4m3fn_code(float ax) { + ax = min(ax, 448.0f); + int lo = 0; + int hi = 126; + while (lo < hi) { + const int mid = (lo + hi + 1) >> 1; + if (dsv4_e4m3fn_value(mid) <= ax) { + lo = mid; + } else { + hi = mid - 1; + } + } + int best = lo; + if (best < 126) { + const float best_diff = abs(ax - dsv4_e4m3fn_value(best)); + const float next_diff = abs(ax - dsv4_e4m3fn_value(best + 1)); + if (next_diff < best_diff || (next_diff == best_diff && ((best + 1) & 1) == 0 && (best & 1) != 0)) { + best = best + 1; + } + } + return best; +} + +// Software e4m3fn round-trip: round |x| to the nearest representable e4m3fn +// magnitude (ties to even) and reapply the sign. Monotonic and correct across +// the whole range; used directly when the native packer is unavailable and as +// the subnormal fallback for the hybrid path below. +static inline float dsv4_e4m3fn_dequant_sw(float x) { const float sign = x < 0.0f ? -1.0f : 1.0f; const float ax = min(abs(x), 448.0f); @@ -83,6 +113,29 @@ static inline float dsv4_e4m3fn_dequant(float x) { return sign * dsv4_e4m3fn_value(best); } +#if defined(DS4_METAL_FP8_NATIVE) && __METAL_VERSION__ >= 410 +// Hybrid e4m3fn round-trip using the MSL 4.1 hardware packer. On macOS 27 beta +// (build 26A5353q) the metal_fp8_e4m3_format packer is bit-exact with the +// software ladder on the normal range [2^-6, 448], but it rounds subnormals +// non-monotonically and returns NaN for magnitudes above the max normal instead +// of saturating. So below the min normal we keep the software ladder, and above +// it we use the native pack/unpack with a defensive clamp that also guards the +// >448 NaN. Verified bit-exact against dsv4_e4m3fn_dequant_sw over 4M samples +// spanning [-448, 448] (see ds4_gpu_validate_fp8_native). Once Apple fixes the +// subnormal/saturation behavior this can collapse to a bare pack/unpack. +static inline float dsv4_e4m3fn_dequant(float x) { + if (abs(x) < 0.015625f) { + return dsv4_e4m3fn_dequant_sw(x); + } + const vec v(clamp(x, -448.0f, 448.0f), 0.0f, 0.0f, 0.0f); + return unpack(pack(v))[0]; +} +#else +static inline float dsv4_e4m3fn_dequant(float x) { + return dsv4_e4m3fn_dequant_sw(x); +} +#endif + static inline float dsv4_e2m1fn_dequant(float x) { const float sign = x < 0.0f ? -1.0f : 1.0f; const float ax = min(abs(x), 6.0f); @@ -264,6 +317,176 @@ kernel void kernel_dsv4_kv_fp8_store_f32( } } +// Packed FP8 compressed-KV writer. Produces the same e4m3fn-quantized values as +// kernel_dsv4_fp8_kv_quantize_f32 but stores them compactly for FlashAttention: +// - nope plane : one e4m3 byte per element (sign<<7 | code). +// - scale plane: one ue8m0 byte per 64-element block (exponent + 127). +// - rot plane : the RoPE prefix copied verbatim as float (precision-preserving). +// Reconstruction is value = e4m3_value(code) * 2^(ue8m0-127), which is bit-exact +// with the float path's stored value (verified offline over 4M values). One row +// per threadgroup, 64 threads, mirroring the quantize kernel's block reduction. +kernel void kernel_dsv4_kv_pack_fp8_f32( + constant ds4_metal_args_dsv4_fp8_kv_quantize & args, + device const float * src0, + device uchar * nope_bytes, + device uchar * scale_bytes, + device float * rot_out, + threadgroup float * scratch [[threadgroup(0)]], + uint row [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const int64_t n_rows = args.ne01 * args.ne02 * args.ne03; + if ((int64_t) row >= n_rows || tid >= 64) { + return; + } + + const int head_dim = (int)args.ne00; + const int n_rot = args.n_rot; + const int n_nope = head_dim - n_rot; + const int n_blk = n_nope / 64; + + device const float * src = src0 + (int64_t)row * head_dim; + device uchar * nb = nope_bytes + (int64_t)row * n_nope; + device uchar * sb = scale_bytes + (int64_t)row * n_blk; + device float * rt = rot_out + (int64_t)row * n_rot; + + for (int off = 0; off < n_nope; off += 64) { + float v = 0.0f; + if (off + (int)tid < n_nope) { + v = src[off + tid]; + scratch[tid] = abs(v); + } else { + scratch[tid] = 0.0f; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint stride = 32; stride > 0; stride >>= 1) { + if (tid < stride) { + scratch[tid] = max(scratch[tid], scratch[tid + stride]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const float amax = max(scratch[0], 1.0e-4f); + int exp = (int)ceil(log2(amax / 448.0f)); + if (exp < -127) exp = -127; + if (exp > 127) exp = 127; + const float scale = exp2((float)exp); + if (tid == 0) { + sb[off / 64] = (uchar)(exp + 127); + } + if (off + (int)tid < n_nope) { + const float vs = clamp(v / scale, -448.0f, 448.0f); + const uint sign = vs < 0.0f ? 0x80u : 0x00u; + const int code = dsv4_e4m3fn_code(abs(vs)); + nb[off + tid] = (uchar)(sign | (uint)code); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (int i = tid; i < n_rot; i += 64) { + rt[i] = src[n_nope + i]; + } +} + +// Row-interleaved packed FP8 writer consumed directly by decode FlashAttention's +// dequantize_fp8_kv_t4 (see flash_attn.metal struct ds4_fp8_kv_row). Produces the +// same e4m3-quantized values as the float quantize path but in one buffer with a +// single per-row stride, so the F16 staging copy is eliminated: +// [scale: n_blk ue8m0 bytes, padded to 8][nope: n_nope e4m3 bytes][rot: n_rot half] +// The rot prefix is rounded to half here to match the prior F16 staging exactly, +// making the reconstructed (half) KV bit-identical to the current attention input. +// One row per threadgroup, 64 threads, mirroring the quantize kernel's reduction. +kernel void kernel_dsv4_kv_pack_fp8_row_f32( + constant ds4_metal_args_dsv4_fp8_kv_quantize & args, + device const float * src0, + device uchar * rows, + threadgroup float * scratch [[threadgroup(0)]], + uint row [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const int64_t n_rows = args.ne01 * args.ne02 * args.ne03; + if ((int64_t) row >= n_rows || tid >= 64) { + return; + } + + const int head_dim = (int)args.ne00; + const int n_rot = args.n_rot; + const int n_nope = head_dim - n_rot; + const int n_blk = n_nope / 64; + const int scale_pad = (n_blk + 7) & ~7; // 8 for n_blk=7; keeps rot 8-aligned + const int rot_off = scale_pad + n_nope; // byte offset of the rot (half) plane + const int stride = rot_off + n_rot * (int)sizeof(half); + + device const float * src = src0 + (int64_t)row * head_dim; + device uchar * base = rows + (int64_t)row * stride; + device uchar * sb = base; // scale plane + device uchar * nb = base + scale_pad; // nope plane + device half * rt = (device half *)(base + rot_off); // rot plane + + for (int off = 0; off < n_nope; off += 64) { + float v = 0.0f; + if (off + (int)tid < n_nope) { + v = src[off + tid]; + scratch[tid] = abs(v); + } else { + scratch[tid] = 0.0f; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint stride2 = 32; stride2 > 0; stride2 >>= 1) { + if (tid < stride2) { + scratch[tid] = max(scratch[tid], scratch[tid + stride2]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const float amax = max(scratch[0], 1.0e-4f); + int exp = (int)ceil(log2(amax / 448.0f)); + if (exp < -127) exp = -127; + if (exp > 127) exp = 127; + const float scale = exp2((float)exp); + if (tid == 0) { + sb[off / 64] = (uchar)(exp + 127); + } + if (off + (int)tid < n_nope) { + const float vs = clamp(v / scale, -448.0f, 448.0f); + const uint sign = vs < 0.0f ? 0x80u : 0x00u; + const int code = dsv4_e4m3fn_code(abs(vs)); + nb[off + tid] = (uchar)(sign | (uint)code); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + for (int i = tid; i < n_rot; i += 64) { + rt[i] = (half)src[n_nope + i]; + } +} + +// Inverse of kernel_dsv4_kv_pack_fp8_row_f32: unpack packed FP8 comp rows back to +// a contiguous f16 buffer for the (cold) flash-attention staging path, which keeps +// using the existing kvpad-correct F16 machinery. The hot indexed-attention path +// reads the packed cache directly and does not use this. One row per threadgroup, +// 128 threads (each emits one half4 via the shared dequantize_fp8_kv_t4 decode). +kernel void kernel_dsv4_kv_unpack_fp8_row_f16( + constant ds4_metal_args_dsv4_fp8_kv_quantize & args, + device const uchar * rows, + device half * dst, + uint row [[threadgroup_position_in_grid]], + uint tid [[thread_position_in_threadgroup]]) { + const int64_t n_rows = args.ne01 * args.ne02 * args.ne03; + if ((int64_t) row >= n_rows || tid >= 128u) { + return; + } + const int head_dim = (int)args.ne00; + if (head_dim != 512 || args.n_rot != 64) { + return; // packed layout (ds4_fp8_kv_row) is fixed to DS4's 512/64 dims + } + device const ds4_fp8_kv_row * src = ((device const ds4_fp8_kv_row *) rows) + row; + device half4 * out = (device half4 *)(dst + (int64_t)row * head_dim); + half4 v; + dequantize_fp8_kv_t4(src, (short)tid, v); + out[tid] = v; +} + // Ratio-4 compression keeps two 4-row halves of recurrent state. After an // emitted compressed row, the second half becomes the next window's previous // half. The old encoder expressed this as four generic copies; this DS4-specific diff --git a/metal/dsv4_misc.metal b/metal/dsv4_misc.metal index 7e8cbd8a7..bf28a9c33 100644 --- a/metal/dsv4_misc.metal +++ b/metal/dsv4_misc.metal @@ -7,6 +7,8 @@ struct ds4_metal_args_dsv4_topk_mask { int64_t ne1; uint64_t nb0; uint64_t nb1; + uint32_t mask_f16; // 1 = write comp_mask as f16 (default), 0 = f32 (opts disabled) + uint32_t pad0; }; struct ds4_metal_args_dsv4_indexer_weighted_sum { @@ -295,7 +297,12 @@ kernel void kernel_dsv4_topk_mask( const int64_t it = gid / args.ne0; (void)topk; - *((device float *) (dst + ic*args.nb0 + it*args.nb1)) = -INFINITY; + // comp_mask is binary -inf/0 (both exact in f16 or f32). + if (args.mask_f16) { + *((device half *) (dst + ic*args.nb0 + it*args.nb1)) = (half) -INFINITY; + } else { + *((device float *) (dst + ic*args.nb0 + it*args.nb1)) = -INFINITY; + } } // Enables the selected compressed rows in the dense mask. This replaces the @@ -315,7 +322,11 @@ kernel void kernel_dsv4_topk_mask_scatter( const int64_t it = gid / args.ne00; const int32_t idx = *((device const int32_t *) (topk + ik*args.nb00 + it*args.nb01)); if (idx >= 0 && (int64_t)idx < args.ne0) { - *((device float *) (dst + (int64_t)idx*args.nb0 + it*args.nb1)) = 0.0f; + if (args.mask_f16) { + *((device half *) (dst + (int64_t)idx*args.nb0 + it*args.nb1)) = (half) 0.0f; + } else { + *((device float *) (dst + (int64_t)idx*args.nb0 + it*args.nb1)) = 0.0f; + } } } @@ -535,14 +546,23 @@ static inline void dsv4_attend_shared_h4_row_at( o0, o1, o2, o3); } +// comp cache element load. `format`: 0 = f32 rows, 1 = f16 rows, 2 = packed FP8 +// rows (ds4_fp8_kv_row, defined in flash_attn.metal which is concatenated first). +// col is the 4-vec index 0..127 (dims [4col,4col+4)). The packed branch reuses the +// same e4m3+ue8m0 decode as the flash reader, so the dequantized half4 is identical. static inline half4 dsv4_load_cache_h4( device const char *kv, uint64_t row_stride, uint row, uint col, - bool f16_rows) { + uint format) { device const char *base = kv + (uint64_t)row * row_stride; - if (f16_rows) { + if (format == 2u) { + half4 reg; + dequantize_fp8_kv_t4((device const ds4_fp8_kv_row *)base, (short)col, reg); + return reg; + } + if (format == 1u) { return ((device const half4 *)base)[col]; } return (half4)((device const float4 *)base)[col]; @@ -652,7 +672,7 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8( args.comp_row_stride, (uint)idx, tid, - args.comp_kv_f16 != 0u); + args.comp_kv_f16); } threadgroup_barrier(mem_flags::mem_threadgroup); dsv4_attend_shared_h4_row(kv_shared, @@ -780,7 +800,7 @@ kernel void kernel_dsv4_indexed_mixed_attention_heads8_rb16( args.comp_row_stride, rows[r], c, - args.comp_kv_f16 != 0u); + args.comp_kv_f16); } threadgroup_barrier(mem_flags::mem_threadgroup); for (uint r = 0; r < n_rows; r++) { diff --git a/metal/flash_attn.metal b/metal/flash_attn.metal index d069d43cc..cb83bb8b5 100644 --- a/metal/flash_attn.metal +++ b/metal/flash_attn.metal @@ -22,6 +22,96 @@ void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { reg = (type4)(*(src)); } +// Packed FP8 compressed-KV row consumed directly by decode FlashAttention, +// eliminating the F16 staging copy and ~1.75x of the KV read bandwidth. Layout +// per 512-wide MLA row (K and V share the same latent, so one plane set): +// scale[0..6] : ue8m0 exponent+127, one per 64-element nope block (7 blocks) +// scale[7] : padding (keeps nope/rot aligned; never read) +// nope[0..447]: e4m3 byte (sign<<7 | code) per non-RoPE element +// rot[0..63] : RoPE prefix as F16 (matches the F16 staging exactly) +// Reconstructed value = sign * e4m3_mag(code) * 2^(scale-127). Cast to half this +// is bit-identical to the prior (half) float-quantized KV, so the attention +// result is unchanged. nb11 = sizeof(struct) = 584, divisible by 8 (half4) and +// 4 (uchar4) so every plane read stays naturally aligned. +struct ds4_fp8_kv_row { + uchar scale[8]; + uchar nope[448]; + half rot[64]; +}; + +// e4m3fn magnitude (no sign) for code 0..127, precomputed. Values are exactly +// (e==0 ? m*2^-9 : (1+m/8)*2^(e-7)) so the table is bit-identical to the formula; +// replacing the per-element branch+exp2 with a constant load is the key decode-cost +// reduction for the packed-KV read path (indexed attention + flash dequant). +constant float ds4_e4m3_lut[128] = { + 0.0f, 0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f, + 0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, 0.04296875f, 0.046875f, 0.05078125f, 0.0546875f, 0.05859375f, + 0.0625f, 0.0703125f, 0.078125f, 0.0859375f, 0.09375f, 0.1015625f, 0.109375f, 0.1171875f, + 0.125f, 0.140625f, 0.15625f, 0.171875f, 0.1875f, 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f, 0.46875f, + 0.5f, 0.5625f, 0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f, 0.9375f, + 1.0f, 1.125f, 1.25f, 1.375f, 1.5f, 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, 2.75f, 3.0f, 3.25f, 3.5f, 3.75f, + 4.0f, 4.5f, 5.0f, 5.5f, 6.0f, 6.5f, 7.0f, 7.5f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, 22.0f, 24.0f, 26.0f, 28.0f, 30.0f, + 32.0f, 36.0f, 40.0f, 44.0f, 48.0f, 52.0f, 56.0f, 60.0f, + 64.0f, 72.0f, 80.0f, 88.0f, 96.0f, 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, 176.0f, 192.0f, 208.0f, 224.0f, 240.0f, + 256.0f, 288.0f, 320.0f, 352.0f, 384.0f, 416.0f, 448.0f, 480.0f, +}; +inline float ds4_e4m3_mag(uint code) { + return ds4_e4m3_lut[code & 0x7fu]; +} + +template +void dequantize_fp8_kv_t4(device const ds4_fp8_kv_row * src, short il, thread type4 & reg) { + const short d0 = il * 4; // first dim of this 4-vec, 0..508 (il in 0..127) + if (d0 >= 448) { + device const half4 * rp = (device const half4 *) (src->rot + (d0 - 448)); + reg = (type4) (*rp); + } else { + const uchar4 b = *((device const uchar4 *) (src->nope + d0)); + // All four lanes share one 64-element block (d0 is a multiple of 4, the + // block size 64 is too, so a 4-vec never straddles a block boundary). + const float s = exp2((float)((int)src->scale[d0 >> 6] - 127)); + float4 v; + v.x = ((b.x & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.x & 0x7fu) * s; + v.y = ((b.y & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.y & 0x7fu) * s; + v.z = ((b.z & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.z & 0x7fu) * s; + v.w = ((b.w & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.w & 0x7fu) * s; + reg = (type4) v; + } +} + +// 4x4 variant for the non-vec (prefill/decode-batch) FlashAttention kernel, which +// loads K/V as 16-element half4x4 chunks. chunk in 0..DK16-1 covers dims +// [16*chunk, 16*chunk+16); since 16 divides both 64 (scale block) and 448 (nope/rot +// split), a chunk never straddles either boundary. Produces the same half4x4 the +// f16 staging would hold for that chunk -> bit-exact attention. +template +void dequantize_fp8_kv_4x4(device const ds4_fp8_kv_row * src, short chunk, thread type4x4 & reg) { + const short d0 = chunk * 16; + half4x4 out; + if (d0 >= 448) { + device const half4 * rp = (device const half4 *) (src->rot + (d0 - 448)); + out[0] = rp[0]; out[1] = rp[1]; out[2] = rp[2]; out[3] = rp[3]; + } else { + const float s = exp2((float)((int)src->scale[d0 >> 6] - 127)); + for (short c = 0; c < 4; ++c) { + const uchar4 b = *((device const uchar4 *) (src->nope + d0 + c*4)); + float4 v; + v.x = ((b.x & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.x & 0x7fu) * s; + v.y = ((b.y & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.y & 0x7fu) * s; + v.z = ((b.z & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.z & 0x7fu) * s; + v.w = ((b.w & 0x80u) ? -1.0f : 1.0f) * ds4_e4m3_mag(b.w & 0x7fu) * s; + out[c] = (half4) v; + } + } + reg = (type4x4) out; +} + template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg); @@ -90,6 +180,8 @@ struct ds4_metal_args_flash_attn_ext { float m1; int32_t n_head_log2; float logit_softcap; + int32_t n_raw_split; // FP8 mixed-KV: key rows >= this read from kcomp/vcomp (packed) + uint64_t nb_kcomp; // packed comp row stride (bytes); 0 for the plain f16 path }; struct ds4_metal_args_flash_attn_ext_vec { @@ -125,6 +217,12 @@ struct ds4_metal_args_flash_attn_ext_vec { float m1; int32_t n_head_log2; float logit_softcap; + // FP8 mixed-KV split: when the kernel is instantiated with MIX, key rows + // [0,n_raw_split) are read from the f16 `k`/`v` buffers and rows >= n_raw_split + // are read from the packed FP8 `kcomp`/`vcomp` buffers (row r-n_raw_split, + // stride nb_kcomp). Unused (0) for the plain f16 instantiation. + int32_t n_raw_split; + uint64_t nb_kcomp; }; struct ds4_metal_args_flash_attn_ext_vec_reduce { @@ -304,7 +402,8 @@ template< short DV, short Q, short C, - short NSG> + short NSG, + bool MIX = false> void kernel_flash_attn_ext_impl( constant ds4_metal_args_flash_attn_ext & args, device const char * q, @@ -315,6 +414,8 @@ void kernel_flash_attn_ext_impl( device const char * pad, device const char * blk, device char * dst, + device const char * kcomp, + device const char * vcomp, threadgroup half * shmem_f16, uint3 tgpig, ushort tiisg, @@ -435,7 +536,11 @@ void kernel_flash_attn_ext_impl( break; } + // MIX tail block served from the f16 `pad` buffer must be read as f16 + // (bit-identical halfs), not from the packed comp buffer by row index. + bool in_pad = false; if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + in_pad = true; k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; @@ -568,13 +673,23 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + const int k_kr = ic + 8*cc + ty; + const bool k_comp = MIX && !in_pad && k_kr >= args.n_raw_split; + device const ds4_fp8_kv_row * k_pc = k_comp + ? (device const ds4_fp8_kv_row *) (kcomp + (int64_t)(k_kr - args.n_raw_split)*args.nb_kcomp) + : (device const ds4_fp8_kv_row *) 0; + for (short ii = 0; ii < DK16; ii += 4) { device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { { k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + if (k_comp) { + dequantize_fp8_kv_4x4(k_pc, ii + tx, tmp); + } else { + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + } sk4x4[4*ty + tx] = tmp; } @@ -595,7 +710,11 @@ void kernel_flash_attn_ext_impl( } else { if (ii + tx < DK16) { k4x4_t tmp; - deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + if (k_comp) { + dequantize_fp8_kv_4x4(k_pc, ii + tx, tmp); + } else { + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + } sk4x4[4*ty + tx] = tmp; } @@ -750,13 +869,23 @@ void kernel_flash_attn_ext_impl( s8x8_t vs; simdgroup_load(vs, ss + 8*cc, SH, 0, false); + const int v_kr = ic + 8*cc + ty; + const bool v_comp = MIX && !in_pad && v_kr >= args.n_raw_split; + device const ds4_fp8_kv_row * v_pc = v_comp + ? (device const ds4_fp8_kv_row *) (vcomp + (int64_t)(v_kr - args.n_raw_split)*args.nb_kcomp) + : (device const ds4_fp8_kv_row *) 0; + for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { { v4x4_t tmp; - deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + if (v_comp) { + dequantize_fp8_kv_4x4(v_pc, ii + tx, tmp); + } else { + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + } sv4x4[4*ty + tx] = tmp; } @@ -780,7 +909,11 @@ void kernel_flash_attn_ext_impl( } else { if (ii + tx < DV16) { v4x4_t tmp; - deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + if (v_comp) { + dequantize_fp8_kv_4x4(v_pc, ii + tx, tmp); + } else { + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + } sv4x4[4*ty + tx] = tmp; } @@ -888,7 +1021,8 @@ template< short DK, short DV, short Q = OP_FLASH_ATTN_EXT_NQPSG, - short C = OP_FLASH_ATTN_EXT_NCPSG> + short C = OP_FLASH_ATTN_EXT_NCPSG, + bool MIX = false> kernel void kernel_flash_attn_ext( constant ds4_metal_args_flash_attn_ext & args, device const char * q, @@ -899,15 +1033,17 @@ kernel void kernel_flash_attn_ext( device const char * pad, device const char * blk, device char * dst, + device const char * kcomp, + device const char * vcomp, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, kcomp, vcomp, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { - case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; - case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS @@ -927,6 +1063,14 @@ typedef decltype(kernel_flash_attn_ext; +// Region-aware mixed-KV variant (MIX=true), used by batched decode: raw key rows +// [0,n_raw_split) read from the f16 k/v buffers (deq_k/deq_v = dequantize_f16), +// comp rows from the packed FP8 kcomp/vcomp via dequantize_fp8_kv_4x4. Same tile +// order + bit-identical halfs -> bit-exact vs f16. Q/C passed explicitly so the +// trailing MIX=true can be set. +template [[host_name("kernel_flash_attn_ext_fp8mix_dk512_dv512")]] +kernel flash_attn_ext_dk512_t kernel_flash_attn_ext; + #undef FA_NONVEC_TYPES constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; @@ -960,7 +1104,8 @@ template< short DV, short NE = 4, short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, - short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG, + bool MIX = false> kernel void kernel_flash_attn_ext_vec( constant ds4_metal_args_flash_attn_ext_vec & args, device const char * q, @@ -970,6 +1115,8 @@ kernel void kernel_flash_attn_ext_vec( device const char * sinks, device const char * pad, device char * dst, + device const char * kcomp, + device const char * vcomp, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], @@ -1067,7 +1214,12 @@ kernel void kernel_flash_attn_ext_vec( break; } + // When the tail block is served from the f16 `pad` buffer, MIX must + // read it as f16 (the pad holds bit-identical half values), not from + // the packed comp buffer by global row index. + bool in_pad = false; if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + in_pad = true; k = pad; v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; @@ -1110,7 +1262,31 @@ kernel void kernel_flash_attn_ext_vec( qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - if (is_same::value) { + if (MIX && !in_pad) { + // Region-aware read: raw rows from the f16 `k` buffer, comp + // rows from the packed FP8 `kcomp` buffer. The dequantized + // half values are bit-identical to the prior combined f16 + // buffer and the iteration/accumulation order is unchanged, + // so mqk is bit-exact vs the plain f16 kernel. + const int kr = ic + NE*cc + ty; + k4_t mk; + if (kr >= args.n_raw_split) { + device const ds4_fp8_kv_row * pc = (device const ds4_fp8_kv_row *) + (kcomp + (int64_t)(kr - args.n_raw_split)*args.nb_kcomp); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; + dequantize_fp8_kv_t4(pc, i, mk); + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } + } else { + device const half4 * pr = (device const half4 *) (k + (int64_t)kr*args.nb11); + FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { + const short i = ii*NL + tx; + mk = (k4_t) pr[i]; + mqk[cc] += dot((float4) mk, (float4) sq4[i]); + } + } + } else if (is_same::value) { FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } @@ -1202,7 +1378,31 @@ kernel void kernel_flash_attn_ext_vec( lo[ii] = 0.0f; } - if (is_same::value) { + if (MIX && !in_pad) { + // Region-aware V read, mirroring the K load: raw rows from the + // f16 `v` buffer, comp rows from packed FP8 `vcomp`. Same order + // and bit-identical half values -> bit-exact output vs f16. + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + const int kr = ic + NE*cc + ty; + if (kr >= args.n_raw_split) { + device const ds4_fp8_kv_row * pc = (device const ds4_fp8_kv_row *) + (vcomp + (int64_t)(kr - args.n_raw_split)*args.nb_kcomp); + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + v4_t mv; + dequantize_fp8_kv_t4(pc, i, mv); + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } else { + device const half4 * pr = (device const half4 *) (v + (int64_t)kr*args.nb21); + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + const short i = ii*NL + tx; + v4_t mv = (v4_t) pr[i]; + lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + } + } + } + } else if (is_same::value) { device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -1378,6 +1578,14 @@ typedef decltype(kernel_flash_attn_ext_vec; +// Region-aware mixed-KV decode variant (MIX=true): raw key rows [0,n_raw_split) +// are read from the f16 k/v buffers exactly as the f16 kernel; comp rows are read +// from the packed FP8 kcomp/vcomp buffers via dequantize_fp8_kv_t4. Iteration and +// accumulation order are identical to the f16 kernel and the dequantized halfs are +// bit-identical, so the output is bit-exact while the comp KV read is ~1.75x +// smaller. Q/C are passed explicitly so the trailing MIX=true can be set. +template [[host_name("kernel_flash_attn_ext_vec_fp8mix_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + #undef FA_TYPES #undef FA_TYPES_F32