From 561d510af00340384c9827359bc52e4a1a33ee54 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 17 Jun 2026 18:51:47 +0000 Subject: [PATCH 1/6] Add nanoGPT model script --- examples/xegpu/nanoGPT.py | 1361 +++++++++++++++++++++++++++++++++++++ 1 file changed, 1361 insertions(+) create mode 100644 examples/xegpu/nanoGPT.py diff --git a/examples/xegpu/nanoGPT.py b/examples/xegpu/nanoGPT.py new file mode 100644 index 00000000..b435f4a4 --- /dev/null +++ b/examples/xegpu/nanoGPT.py @@ -0,0 +1,1361 @@ +"""nano-GPT / GPT-2-style forward pass on the Intel GPU (XeGPU), with +FLASH multi-head attention. + +This is a nanoGPT block stack: each transformer block is + a = x + attn_proj( MultiHeadAttention( ln1(x) ) ) # attention sublayer + y = a + ffn( ln2(a) ) # MLP sublayer + ffn(z) = Linear(C, 4C) -> ReLU -> Linear(4C, C) +and the full model is + x = token_emb + pos_emb # embeddings (done host-side) + for _ in range(n_layer): x = Block(x) + x = ln_f(x); logits = x @ lm_head +TRUE multi-head: H heads of head_size = C/H = 64, computed by ONE +fused flash-attention kernel per block. + +The attention is the FLASH/FUSED kernel : standard +attention is built on 4D tensors (Z, H, n_ctx, head_size) at the linalg level, +then a transform-dialect schedule rewrites the whole Q@K^T -> softmax -> @V region +into ONE kernel that tiles the K/V reduction dim and carries a running max/sum (the +flash-attention online-softmax), so the full T x T scores matrix is never +materialized. Everything else (layernorm, the q/k/v/proj/ffn/lm_head matmuls, the +casts/bias/residual elementwise ops) is lowered as its own XeGPU kernel; the whole +model is ONE MLIR module with on-device buffers handing off between kernels. + +Config: n_layer=6, C=256, H=4 (head_size=64), hidden=1024, vocab=256, T=256. + +Builds the FULL model (n_layer blocks -> ln_f -> lm_head), with FUSED multi-head +NON-CAUSAL attention per block. + +Bridging the model's 2D (T,C) activations to the fused kernel's multi-head +(H,T,hs) layout uses NO on-device transpose kernel: each q/k/v projection buffer +is presented as a (H,T,hs) STRIDED memref VIEW (memref.expand_shape + +memref.transpose -- pure layout, zero compute), and the fused schedule's +(1,wg_rows,0,0) tiling peels the head dim into the work-group GRID so each wg reads +2D strided slices -> 2D load_nd . + +Run: + .venv/bin/python examples/xegpu/gpt2.py [--gpt-layers N] [--check] + .venv/bin/python examples/xegpu/gpt2.py [--dump STAGE] +""" + +import sys +import numpy as np +from mlir import ir +from mlir.dialects import linalg, bufferization, tensor, arith, math, gpu, memref +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.transform.vector import ( + apply_patterns_vector_cast_away_vector_leading_one_dim, + apply_patterns_vector_drop_unit_dims_with_shape_cast, +) +from mlir.dialects.bufferization import LayoutMapOption + +from lighthouse import dialects as lh_dialects +import lighthouse.transform as lh_transform +from lighthouse.dialects.transform import transform_ext +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, + match_and_split, + PipelineInterrupt, +) +from lighthouse.schedule import schedule_boilerplate +from lighthouse.ingress.mlir_gen.utils import ( + emit_buf_to_tensor, + affine_map, + parallel, + reduction, +) +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.named import times_weights +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution.runner import Runner +from lighthouse.execution import GPUMemoryManager +from lighthouse.schedule.xegpu import xegpu_to_binary, XeGPUParameterSelector +from lighthouse.schedule.xegpu.mlp_schedule import xegpu_wg_annotation_for_mlp_layer + + +# ============================================================================= +# HOW THIS FILE IS ORGANIZED +# ============================================================================= +# Compiling a model to the GPU here happens in THREE stages. The code is grouped +# to match them: +# +# 1. PAYLOAD -- "WHAT to compute". Build an MLIR module describing the math +# as high-level `linalg` ops (matmul, layernorm, softmax, elementwise). This +# is hardware-agnostic; it does not say how to run on the GPU yet. +# -> class `Builder` (emits one op at a time) and the `build_*_payload` +# functions (assemble ops into ffn / attn / block / full-gpt). +# +# 2. SCHEDULE -- "HOW to lower it to the GPU". Build a SECOND MLIR module (a +# "transform dialect" script) that rewrites the payload: tile each op into +# GPU work-groups, vectorize, bufferize, outline each op into its own GPU +# kernel, and attach XeGPU layout/target attributes. +# -> `build_combined_schedule` / `_bundle` (the orchestrator) plus the +# `_tile_one_matmul` / `_tile_one_layernorm` / `_tile_one_fused_attention_region` +# helpers. +# *** THIS is the part that "schedules the passes". *** +# +# 3. DRIVER -- "run it". `main()` applies the schedule to the payload +# (TransformDriver), JIT-compiles + runs it on the GPU (Runner), and checks +# the result against a plain-numpy reference (`numpy_ref_*`). +# +# KEY IDEA -- one module, many separate kernels: the whole model is ONE MLIR +# module, but each op becomes its OWN GPU kernel (no cross-op fusion). Data passes +# between kernels through device buffers (`gpu.alloc`) that stay on the GPU -- no +# round-trip to the host between ops. +# ============================================================================= + + +def F32(): # 32-bit float (used for accumulation / norms) + return ir.F32Type.get() + + +def F16(): # 16-bit float (required by the GPU matmul units) + return ir.F16Type.get() + + +# ============================================================================= +# STAGE 1 -- PAYLOAD: describe WHAT to compute (hardware-agnostic linalg ops) +# ============================================================================= +# Each Builder method emits ONE high-level op that writes its result into a fresh +# on-device buffer (`gpu.alloc`), and returns a tensor "view" of that buffer for +# the next op to read. Because each op writes a distinct device buffer, each will +# become its OWN GPU kernel later; the buffers are the on-device handoff between +# kernels (kernel N writes buffer B, kernel N+1 reads B -- no host round-trip). +# +# dtype convention: the GPU matmul (DPAS) hardware needs f16 inputs and produces +# an f32 result. LayerNorm/softmax run in f32. So between a norm/softmax and a +# matmul we insert an explicit f32->f16 `cast` op (its own kernel). +# ============================================================================= +class Builder: + """Emits the model's ops and remembers the order/kind of each one. + + `kinds` is the crucial bookkeeping: an ordered list, one entry per op emitted, + recording its "class" so the SCHEDULE (stage 2) can later tile and annotate + each kernel correctly. Classes: + 'mm' = matmul (linalg.matmul) -> DPAS systolic-array kernel + 'ln' = layernorm (3 generics + 2 fills) -> reduction kernel (uses shared mem) + 'fa' = flash multi-head attention -> ONE kernel (QK^T->softmax->@V, + online-softmax over K/V tiles); see attention_4d + the fused-attention + schedule helpers. (Softmax lives INSIDE this kernel, not as its own.) + 'ew' = elementwise (cast / bias / relu / residual) -> simple row-parallel kernel + The op build order in the payload == the order of `kinds` == the order the + kernels appear in the final module, which is how the schedule matches them up. + """ + + def __init__(self, T): + self.T = T + self.f32, self.f16 = F32(), F16() + self.kinds = [] # ordered kernel classes (see docstring) + self.to_dealloc = [] # device buffers to gpu.dealloc at the end + + def _buf(self, shape, dtype): + # Allocate a DEVICE buffer (lives in GPU memory). Returns the memref. + b = gpu.alloc(ir.MemRefType.get(shape, dtype), None, [], [], []) + self.to_dealloc.append(b) + return b + + def _par(self, rank=2): + # Identity affine map (d0,d1,...) -> (d0,d1,...): a plain elementwise + # access pattern where output[i,j] depends on input[i,j]. + return affine_map(rank, [ir.AffineDimExpr.get(i) for i in range(rank)]) + + # ---- matmul: a(M,K) f16 @ b(K,N) f16 -> (M,N) f32 buffer ---- + def matmul(self, a, b, M, N, out_buf=None): + # Standard C = A @ B. `times_weights` emits linalg.matmul; we first fill the + # accumulator with 0. f16 inputs, f32 output -- matches the DPAS hardware. + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + acc = linalg.fill(arith.constant(self.f32, 0.0), outs=[out_t]) + res = times_weights(a, b, acc) + bufferization.materialize_in_destination( + None, res, buf, restrict=True, writable=True + ) + self.kinds.append("mm") + if out_buf is not None: # caller gave the final output buffer + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- layernorm(x (M,N) f32, gamma,beta (N,)) -> (M,N) f32 buffer ---- + def layernorm(self, x, gamma, beta, M, N, eps=1e-5): + # LayerNorm normalizes each ROW (length N) to mean 0 / variance 1, then + # scales by gamma and shifts by beta. Built from 3 linalg.generic ops: + # (1) mean_sum[i] = sum_j x[i,j] (row reduction) + # (2) var_sum[i] = sum_j (x[i,j]-mean_i)^2 (row reduction) + # (3) out[i,j] = (x[i,j]-mean_i)*rsqrt(var_i+eps)*gamma[j] + beta[j] + # Affine maps describe each operand's access pattern: + # par2 (d0,d1)->(d0,d1) : full 2-D elementwise + # red2 (d0,d1)->(d0) : reduce over j -> one value per row + # bias2 (d0,d1)->(d1) : gamma/beta indexed by column only + f32 = self.f32 + par2, red2 = self._par(), affine_map(2, [ir.AffineDimExpr.get(0)]) + bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) + inv_n = arith.constant(f32, 1.0 / float(N)) + eps_c = arith.constant(f32, eps) + zero = arith.constant(f32, 0.0) + # (1) row sums -> mean_sum (linalg.fill zeroes the accumulator first) + mean_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) + + @linalg.generic([x], [mean_acc], [par2, red2], [parallel, reduction]) + def mean_sum(v, acc): + return arith.AddFOp(v, acc) + + # (2) sum of squared deviations -> var_sum (mean_i = mean_sum_i / N) + var_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) + + @linalg.generic( + [x, mean_sum], [var_acc], [par2, red2, red2], [parallel, reduction] + ) + def var_sum(v, m_sum, acc): + mean = arith.MulFOp(m_sum, inv_n).result + c = arith.SubFOp(v, mean).result + return arith.AddFOp(arith.MulFOp(c, c).result, acc) + + # (3) normalize + scale + shift -> output + buf = self._buf((M, N), f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic( + [x, mean_sum, var_sum, gamma, beta], + [out_t], + [par2, red2, red2, bias2, bias2, par2], + [parallel, parallel], + ) + def normed(v, m_sum, v_sum, g, b, _o): + mean = arith.MulFOp(m_sum, inv_n).result + var = arith.MulFOp(v_sum, inv_n).result + inv_std = math.rsqrt(arith.AddFOp(var, eps_c).result) + c = arith.SubFOp(v, mean).result + return arith.AddFOp( + arith.MulFOp(arith.MulFOp(c, inv_std).result, g).result, b + ) + + bufferization.materialize_in_destination( + None, normed, buf, restrict=True, writable=True + ) + self.kinds.append("ln") + return emit_buf_to_tensor(buf, restrict=True) + + # ---- softmax(x (M,N) f32) over last dim -> (M,N) f32 buffer ---- + # ---- elementwise cast f32 -> f16 ---- + def cast_f16(self, x, M, N): + par2 = self._par() + buf = self._buf((M, N), self.f16) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) + def c(s, _o): + return arith.TruncFOp(self.f16, s) + + bufferization.materialize_in_destination( + None, c, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + return emit_buf_to_tensor(buf, restrict=True) + + # ---- bias add (+ optional relu): out = max(x + bias, 0)? x (M,N) f32, bias (N,) ---- + def bias(self, x, bias_vec, M, N, relu=False, out_buf=None): + par2 = self._par() + bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) + zero = arith.constant(self.f32, 0.0) + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic( + [x, bias_vec], [out_t], [par2, bias2, par2], [parallel, parallel] + ) + def b(v, bb, _o): + s = arith.AddFOp(v, bb).result + if relu: + return arith.MaximumFOp(s, zero) + return arith.AddFOp(s, zero) # identity wrap so the op has a body + + bufferization.materialize_in_destination( + None, b, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + if out_buf is not None: + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- residual add: out = a + b (both (M,N) f32) ---- + def add(self, a, b, M, N, out_buf=None): + par2 = self._par() + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([a, b], [out_t], [par2, par2, par2], [parallel, parallel]) + def r(x, y, _o): + return arith.AddFOp(x, y) + + bufferization.materialize_in_destination( + None, r, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + if out_buf is not None: + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- cast f32 (T,C) -> f16 (T,C), returning the MEMREF buffer (for views) ---- + def cast_f16_buf(self, x, T, C): + par2 = self._par() + buf = self._buf((T, C), self.f16) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) + def c(s, _o): + return arith.TruncFOp(self.f16, s) + + bufferization.materialize_in_destination( + None, c, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + return buf + + # ---- view a (T, H*hs) MEMREF as (H, T, hs) -- NO kernel, NO data move ---- + def _heads_view_of(self, buf2d, T, H, hs): + # We present the 2D + # (T, H*hs) projection buffer as a (H,T,hs) STRIDED memref VIEW: + # (T,H*hs) --memref.expand_shape--> (T,H,hs) [strides C,hs,1] + # --memref.transpose [1,0,2]--> (H,T,hs) [strides hs,C,1] + # Both are pure layout ops (no compute, no kinds entry). When the fused + # schedule tiles (1,wg_rows,0,0), the grid peels head h -> a 2D + # memref> -> 2D load_nd (XeGPU supports + # such strided block loads). + C = H * hs + et = buf2d.type.element_type + exp_t = ir.MemRefType.get((T, H, hs), et) + e = memref.expand_shape( + exp_t, buf2d, [[0], [1, 2]], [], static_output_shape=[T, H, hs] + ) + d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3)) + perm = ir.AffineMap.get(3, 0, [d1, d0, d2]) # (H,T,hs) <- (T,H,hs) + layout = ir.StridedLayoutAttr.get(0, [hs, C, 1]) + res_t = ir.MemRefType.get((H, T, hs), et, layout=layout) + return memref.transpose(res_t, e, perm) + + def heads_view(self, buf2d, T, H, hs): + return emit_buf_to_tensor(self._heads_view_of(buf2d, T, H, hs), restrict=True) + + # ---- fused multi-head attention core on 3D (H,T,hs) f16 -> (H,T,hs) f16 ---- + # (named attention_4d because it is the canonical 4D (Z,H,T,hs) attention + # algorithm with the batch dim Z=1 FOLDED OUT: one sequence, so (1,H,T,hs) + # collapses to (H,T,hs) and linalg.batch_matmul treats H as the batch axis.) + def attention_4d(self, Qh, Kh, Vh, H, T, hs, out_view, out_view_memref): + # Emits the SAME linalg op sequence as generate_gpu_attention_payload + # (batch_matmul QK^T -> scale-mul -> softmax -> batch_matmul @V), so the + # fused-attention schedule's matchers/rewrite apply verbatim. After the + # per-region fused tiling, ALL these ops fuse into ONE scf.forall -> ONE + # GPU kernel (the flash/online-softmax kernel). Counts as one 'fa'. + # Inputs Qh/Kh/Vh are (H,T,hs) f16 strided VIEWS (heads_view); the @V result + # is materialized into `out_view`, a (H,T,hs) strided view of a (T,C) buffer, + # so the merge back to 2D is also a free view (no from_heads kernel). + f16 = self.f16 + scale = 1.0 / (hs**0.5) + zero = arith.constant(f16, 0.0) + # K^T: (H,T,hs) -> (H,hs,T). Lowers to a 2D vector.transpose per head (the + # grid peels H), exactly like the standalone -- f16 is fine here. + Kt = linalg.transpose( + Kh, outs=[tensor.empty((H, hs, T), f16)], permutation=[0, 2, 1] + ) + qkt_init = linalg.fill(zero, outs=[tensor.empty((H, T, T), f16)]) + qkt = linalg.batch_matmul(Qh, Kt, outs=[qkt_init]) + sc = arith.constant(f16, scale) + scale_t = linalg.fill(sc, outs=[tensor.empty((H, T, T), f16)]) + scaled = linalg.mul(qkt, scale_t, outs=[tensor.empty((H, T, T), f16)]) + aw = linalg.softmax( + result=[ir.RankedTensorType.get((H, T, T), f16)], + input=scaled, + output=tensor.empty((H, T, T), f16), + dimension=2, + ) + # @V: (H,T,T) @ (H,T,hs) -> (H,T,hs) f16, materialized into the (T,C) view. + out_filled = linalg.fill(zero, outs=[out_view]) + out = linalg.batch_matmul(aw, Vh, outs=[out_filled]) + bufferization.materialize_in_destination( + None, out, out_view_memref, restrict=True, writable=True + ) + self.kinds.append("fa") + + # ---- fused multi-head attention(ln_f32 (T,C) f32) -> (T,C) f16, NON-CAUSAL ---- + def fused_attention(self, x, wq, wk, wv, T, C, H): + # True multi-head attention via the flash kernel, with NO on-device + # head-transpose kernel. Flow: + # x(f32) -cast-> f16 -q/k/v proj-> (T,C) f16 buffers -heads_view (free)-> + # (H,T,hs) strided views -> attention_4d (fused flash kernel) -> @V written + # into a (T,C) f16 buffer via its (H,T,hs) view -> return that (T,C) f16. + hs = C // H + x16 = self.cast_f16(x, T, C) # ew + qbuf = self.cast_f16_buf( + self.matmul(x16, wq, T, C), T, C + ) # mm, ew -> (T,C) f16 memref + kbuf = self.cast_f16_buf(self.matmul(x16, wk, T, C), T, C) # mm, ew + vbuf = self.cast_f16_buf(self.matmul(x16, wv, T, C), T, C) # mm, ew + Qh = self.heads_view(qbuf, T, H, hs) # (H,T,hs) strided view (free) + Kh = self.heads_view(kbuf, T, H, hs) + Vh = self.heads_view(vbuf, T, H, hs) + # Output (T,C) f16 buffer, viewed as (H,T,hs) for the @V store. + out_buf = self._buf((T, C), self.f16) + out_view_memref = self._heads_view_of(out_buf, T, H, hs) + out_view = emit_buf_to_tensor(out_view_memref, restrict=True, writable=True) + self.attention_4d( + Qh, Kh, Vh, H, T, hs, out_view, out_view_memref + ) # fa, writes out_buf + return emit_buf_to_tensor(out_buf, restrict=True) # (T,C) f16 + + +# --------------------------------------------------------------------------- +# PAYLOAD ASSEMBLY -- wire the Builder ops into a complete MLIR function. +# Each build_*_payload creates one `func.func` (the "payload") whose arguments +# are the input + all weights (as device memrefs) and whose body is the op graph. +# `emit_buf_to_tensor` turns a memref argument into a tensor the ops can read; +# `func_cif` makes the function callable from C/the Runner. Returns (module, +# kinds) where `kinds` drives the schedule. +# --------------------------------------------------------------------------- + + +def _mha(q, k, v, H, causal=False): + """Multi-head attention over (T,C) q/k/v (already projected), per-head, with an + optional causal mask. Returns (T,C). Mirrors the fused kernel's math, which is + non-causal, so `causal` defaults to False.""" + T, C = q.shape + hs = C // H + scale = 1.0 / (hs**0.5) + mask = np.triu(np.full((T, T), -np.inf, np.float32), k=1) if causal else 0.0 + attn = np.zeros((T, C), np.float32) + for h in range(H): + sl = slice(h * hs, (h + 1) * hs) + scores = (q[:, sl] @ k[:, sl].T) * scale + mask + scores = scores - scores.max(-1, keepdims=True) + e = np.exp(scores) + w = e / e.sum(-1, keepdims=True) + attn[:, sl] = _f16(w) @ v[:, sl] + return attn + + +def _emit_block_fused(bld, x, w, T, C, hidden, H, eps, out_buf=None): + """Like _emit_block but the attention sublayer is the FUSED multi-head flash + kernel (NON-CAUSAL, no mask). `w` weight keys: g1,b1n, wq,wk,wv, wp,bp, + g2,b2n, w1,bb1,w2,bb2. wq/wk/wv/wp are full (C,C).""" + # ---- attention sublayer: a = x + proj(fused_attn(ln1(x))) ---- + ln1 = bld.layernorm(x, w["g1"], w["b1n"], T, C, eps) + attn16 = bld.fused_attention(ln1, w["wq"], w["wk"], w["wv"], T, C, H) # f16 (T,C) + proj = bld.matmul(attn16, w["wp"], T, C) + proj = bld.bias(proj, w["bp"], T, C, relu=False) + a = bld.add(x, proj, T, C) + # ---- FFN sublayer: y = a + ffn(ln2(a)) ---- + ln2 = bld.layernorm(a, w["g2"], w["b2n"], T, C, eps) + ln2_16 = bld.cast_f16(ln2, T, C) + h = bld.matmul(ln2_16, w["w1"], T, hidden) + h = bld.bias(h, w["bb1"], T, hidden, relu=True) + h16 = bld.cast_f16(h, T, hidden) + o = bld.matmul(h16, w["w2"], T, C) + o = bld.bias(o, w["bb2"], T, C, relu=False) + return bld.add(a, o, T, C, out_buf=out_buf) + + +def numpy_ref_block_fused(x, w, H, eps=1e-5, causal=False): + """Multi-head block reference (matches _emit_block_fused).""" + ln1 = _f16(_ln(x, w["g1"], w["b1n"], eps)) + q = _f16(ln1 @ w["wq"].astype(np.float32)) + k = _f16(ln1 @ w["wk"].astype(np.float32)) + v = _f16(ln1 @ w["wv"].astype(np.float32)) + attn = _mha(q, k, v, H, causal) + proj = _f16(attn) @ w["wp"].astype(np.float32) + w["bp"] + a = x + proj + ln2 = _f16(_ln(a, w["g2"], w["b2n"], eps)) + hh = np.maximum(_f16(ln2) @ w["w1"].astype(np.float32) + w["bb1"], 0.0) + o = _f16(hh) @ w["w2"].astype(np.float32) + w["bb2"] + return a + o + + +def build_gpt_fused_payload(func_name, T, C, hidden, vocab, n_layer, H, eps=1e-5): + """Full nanoGPT.py forward as ONE module, with FUSED multi-head attention per block + (R3/R4). Multi-head (H heads, flash attention), + NON-CAUSAL (no mask), wq/wk/wv/wp are (C,C). Embeddings done host-side.""" + f32, f16 = F32(), F16() + mod = ir.Module.create() + x_t = ir.MemRefType.get((T, C), f32) # input activations (256,256) f32 + g_t = ir.MemRefType.get((C,), f32) # layernorm gamma/beta vectors (256,) f32 + wqkv_t = ir.MemRefType.get((C, C), f16) # q/k/v projection weights (256,256) f16 + wproj_t = ir.MemRefType.get( + (C, C), f16 + ) # attention output proj weight (256,256) f16 + bvec_t = ir.MemRefType.get((C,), f32) # bias vectors (256,) f32 + w1_t = ir.MemRefType.get((C, hidden), f16) # FFN up-projection (256,1024) f16 + b1_t = ir.MemRefType.get((hidden,), f32) # FFN hidden bias (1024,) f32 + w2_t = ir.MemRefType.get((hidden, C), f16) # FFN down-projection (1024,256) f16 + lmw_t = ir.MemRefType.get((C, vocab), f16) # lm_head weight (256,256) f16 + lmb_t = ir.MemRefType.get((vocab,), f32) # lm_head bias (256,) f32 + out_t = ir.MemRefType.get((T, vocab), f32) # output logits (256,256) f32 + # per-layer arg types: g1,b1n, wq,wk,wv, wp,bp, g2,b2n, w1,bb1,w2,bb2 (13) -- NO mask. + per_layer = [ + g_t, + g_t, + wqkv_t, + wqkv_t, + wqkv_t, + wproj_t, + bvec_t, + g_t, + g_t, + w1_t, + b1_t, + w2_t, + bvec_t, + ] + from lighthouse.utils.mlir import func_cif + + fargs = [out_t, x_t] + for _ in range(n_layer): + fargs += per_layer + fargs += [g_t, g_t, lmw_t, lmb_t] # ln_f gamma/beta, lm_head W,b (no mask) + bld = Builder(T) + with ir.InsertionPoint(mod.body): + + @func_cif(*fargs, name=func_name) + def payload(*args): + output = args[0] + emit_buf_to_tensor(output, restrict=True, writable=True) + x = emit_buf_to_tensor(args[1], restrict=True) + idx = 2 + layer_w = [] + keys = [ + "g1", + "b1n", + "wq", + "wk", + "wv", + "wp", + "bp", + "g2", + "b2n", + "w1", + "bb1", + "w2", + "bb2", + ] + for _ in range(n_layer): + w = { + k: emit_buf_to_tensor(args[idx + i], restrict=True) + for i, k in enumerate(keys) + } + idx += len(keys) + layer_w.append(w) + gf_g = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + gf_b = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + lmw = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + lmb = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + + h = x + for w in layer_w: + h = _emit_block_fused(bld, h, w, T, C, hidden, H, eps) + hf = bld.layernorm(h, gf_g, gf_b, T, C, eps) + hf16 = bld.cast_f16(hf, T, C) + logits = bld.matmul(hf16, lmw, T, vocab) + bld.bias(logits, lmb, T, vocab, relu=False, out_buf=output) + for b in bld.to_dealloc: + gpu.dealloc(None, [], b) + + emit_gpu_util_funcs(f32, rank=2) + emit_gpu_util_funcs(f32, rank=1) + emit_gpu_util_funcs(f16, rank=2) + return mod, bld.kinds + + +def numpy_ref_gpt_fused(x, layer_w, gf_g, gf_b, lmw, lmb, H, eps=1e-5): + """Non-causal multi-head full-gpt reference.""" + h = x + for w in layer_w: + h = numpy_ref_block_fused(h, w, H, eps) + hf = _ln(h, gf_g, gf_b, eps) + return _f16(hf) @ lmw.astype(np.float32) + lmb + + +# ============================================================================= +# STAGE 2 -- SCHEDULE: describe HOW to lower the payload to the GPU +# ============================================================================= +# A "schedule" here is itself an MLIR module written in the TRANSFORM dialect: a +# little program of rewrite ops that the transform interpreter runs over the +# payload module. It does NOT compute anything; it REWRITES the payload from +# high-level linalg ops down to GPU (XeGPU) kernels. +# +# We can't reuse the repo's per-op schedules (layer_norm_schedule, mlp_schedule, +# softmax_schedule) directly, because each assumes the module contains ONLY its +# op. Our module is mixed (matmul + layernorm + softmax + elementwise), so we +# build ONE COMBINED schedule that handles all op classes. The strategy: +# +# (a) TILE each op into its own parallel loop nest (`scf.forall` = the GPU +# work-group grid). Different op classes tile differently: +# - matmul -> `_tile_one_matmul` (work-group tile + k-loop tile; the +# DPAS tile sizes come from `mm_params`) +# - layernorm-> `_tile_one_layernorm` (tile rows, fuse the 2 reductions + +# 2 zero-fills into the loop) +# - fused attn-> `_tile_one_fused_attention_region` (tile @V batch_matmul into +# a forall, fuse QK^T/scale/softmax/@V in; flash rewrite later) +# - elementwise -> a single `structured_tile_using_forall` over rows +# (b) SHARED TAIL (same for every kernel): vectorize -> bufferize (tensors -> +# memrefs) -> convert the forall grids to `gpu.launch` -> OUTLINE each into +# its own `gpu.module`/`gpu.func` kernel -> attach the XeVM target. +# (c) ANNOTATE each kernel with XeGPU layout attributes (how data maps to +# sub-groups / DPAS tiles). +# +# `kinds` (from the Builder) tells the schedule the class and order of every +# kernel, so steps (a) and (c) can treat each one correctly. See memory.md +# parts 6-10 for the subtle correctness rules baked into the helpers below. +# ============================================================================= +def _tile_one_matmul(matmul_op, anytype, mm_params): + """Tile ONE matmul for DPAS: a work-group `forall` tile (wg_m x wg_n) with any + elementwise consumer fused in, then an inner reduction (k) loop. Tile sizes + come from `mm_params` (chosen by xegpu_parameter_selector for the GPU).""" + wg_tile = [mm_params["wg_m"], mm_params["wg_n"]] + consumers = transform_ext.get_tileable_consumers(matmul_op) + leaf = transform_ext.extract_handle(consumers, -1) + _, [wg_loop], _ = lh_transform.tile( + leaf, + tile_sizes=wg_tile, + fuse_producers=True, + use_forall=True, + apply_cleanup=False, + ) + wg_matmul = match(wg_loop, ops={"linalg.matmul"}) + lh_transform.tile(wg_matmul, tile_sizes=[0, 0, mm_params["k_tile"]]) + + +def _tile_one_fused_attention_region(anytype, qkt_bmm, pv_bmm, softmax_op, fa_params): + """Tile + fuse ONE attention region (QK^T -> scale -> softmax -> @V) into a + SINGLE scf.forall, so it vectorizes/bufferizes into one kernel body that + `replace_with_fused_attention` later rewrites into the flash loop. + + Operates on PRE-SPLIT, per-region + handles (qkt_bmm, pv_bmm, softmax_op) so it is region-local and works at any + multiplicity. All further producers are pulled in via get_producer_of_operand + (SSA-walk = inherently scoped to this region).""" + prod = transform.get_producer_of_operand + + def fuse(p, c): + return structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=p, containing_op=c + )[1] + + wg_rows = fa_params["wg_rows"] + # 1. Tile the @V batch_matmul in (batch=1, M=wg_rows) -> forall grid. + tiled_pv, forall = structured.structured_tile_using_forall( + anytype, + anytype, + pv_bmm, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(1, wg_rows, 0, 0), + ) + func = transform.get_parent_op( + anytype, forall, op_name="func.func", deduplicate=True + ) + # 2. Fuse the @V output init fill (producer of forall operand 0). + forall = fuse(prod(anytype, forall, operand_number=0), forall) + transform.apply_cse(func) + canonicalize(func) + # 3. Decompose this region's softmax. linalg.softmax -> 4 generics + 2 fills: + # max = reduce_max(scaled) [+ -inf fill] + # num = exp(scaled - max) + # den = reduce_sum(num) [+ 0 fill] + # div = num / den (feeds @V) + structured.structured_decompose_interface(anytype, softmax_op) + transform.apply_cse(func) + canonicalize(func) + # Grab the whole producer chain UP FRONT via SSA walk (region-local; no count + # matching). Fusing op X invalidates only X's handle, so collect all, then fuse + # each once in consumer->producer topological order. + # tiled_pv operand 0 is the aw extract_slice inside the forall; hop through it + # to the func-scope softmax `div` that it slices. + aw_slice = prod(anytype, tiled_pv, operand_number=0) + div = prod(anytype, aw_slice, operand_number=0) # num / den (softmax out) + num = prod(anytype, div, operand_number=0) # exp generic + den = prod(anytype, div, operand_number=1) # sum-reduce generic + den_fill = prod(anytype, den, operand_number=1) # 0 fill (sum acc) + mx = prod(anytype, num, operand_number=1) # max-reduce generic + mx_fill = prod(anytype, mx, operand_number=1) # -inf fill (max acc) + scaled = prod(anytype, num, operand_number=0) # linalg.mul (qkt*scale) + scale_fill = prod(anytype, scaled, operand_number=1) # scale-constant fill + qkt = prod(anytype, scaled, operand_number=0) # QK^T batch_matmul + kt = prod(anytype, qkt, operand_number=1) # K^T transpose + qkt_fill = prod(anytype, qkt, operand_number=2) # 0 fill (qkt acc) + for p in ( + div, + den, + num, + mx, + scaled, + qkt, + den_fill, + mx_fill, + scale_fill, + qkt_fill, + kt, + ): + forall = fuse(p, forall) + transform.apply_cse(func) + canonicalize(func) + return func, forall + + +def _fuse_attention_in_region(anytype, forall, fa_params): + """After the shared bufferize+vectorize, rewrite ONE attention region's + vector.contract pair (QK^T, @V) into the flash loop via the transform + op. Scoped to `forall` so counts are exact at any multiplicity.""" + contract_ops = match_and_split(forall, ops={"vector.contract"}, nhandles=2) + first_contract, second_contract = contract_ops[0], contract_ops[1] + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) + v_load = transform.get_producer_of_operand( + anytype, second_contract, operand_number=1 + ) + mulf_op = match_and_split(forall, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + # NB: the merged fused-attention op is non-causal only -- there is + # no `causal` parameter yet, so the model runs as non-causal attention. + transform_ext.replace_with_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=fa_params["inner_loop_tile_size"], + ) + + +def xegpu_fa_annotation(gf, anytype, fa_params): + """Attach XeGPU layouts to ONE fused-attention gpu.func.""" + num_subgroups = fa_params["wg_rows"] // fa_params["sg_rows"] + n_head = fa_params["n_head"] + q_sg_layout = [num_subgroups, 1] + q_sg_data = [16, n_head] + q_inst_data = [8, 16] + k_sg_layout = [num_subgroups, 1] + k_sg_data = [16, n_head] + k_inst_data = [16, 16] + v_sg_layout, v_sg_data, v_inst_data = k_sg_layout, k_sg_data, k_inst_data + kt_sg_layout = [1, num_subgroups] + kt_sg_data = [n_head, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + out_sg_layout, out_sg_data, out_inst_data = q_sg_layout, q_sg_data, q_inst_data + l128_sg_layout = [num_subgroups, 1] + l128_sg_data = [16, 16] + l128_inst_data = [8, 16] + qk_sg_layout, qk_sg_data, qk_inst_data = ( + l128_sg_layout, + l128_sg_data, + l128_inst_data, + ) + + store_nd_op = match_and_split(gf, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) + load_nd_ops = match_and_split(gf, ops={"xegpu.load_nd"}, nhandles=9) + xegpu.set_anchor_layout( + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data + ) + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data, + ) + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + ) + dpas_ops = match_and_split(gf, ops={"xegpu.dpas"}, nhandles=8) + for i in range(4): + d = dpas_ops[i] + xegpu.set_anchor_layout( + d, sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data, index=0 + ) + xegpu.set_anchor_layout( + d, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1, + ) + xegpu.set_anchor_layout( + d, + sg_layout=l128_sg_layout, + sg_data=l128_sg_data, + inst_data=l128_inst_data, + index=2, + ) + for i in range(4, 8): + d = dpas_ops[i] + xegpu.set_anchor_layout( + d, + sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0, + ) + xegpu.set_anchor_layout( + d, sg_layout=v_sg_layout, sg_data=v_sg_data, inst_data=v_inst_data, index=1 + ) + xegpu.set_anchor_layout( + d, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2, + ) + + +def build_combined_schedule( + mm_params, ln_params, kinds, stop_at_stage="", fa_params=None +): + """Build the transform-dialect schedule module for a payload with op classes + `kinds`. Counts how many of each class there are, then delegates to `_bundle` + (wrapped in transform boilerplate). `stop_at_stage` lets callers halt early + for debugging (--dump ).""" + n_mm = kinds.count("mm") + n_ln = kinds.count("ln") + n_sm = kinds.count("sm") + n_ew = kinds.count("ew") + with schedule_boilerplate() as (schedule, named_seq): + anytype = transform.AnyOpType.get() + func0 = match(named_seq.bodyTarget, ops={"func.func"}) + mod = transform.get_parent_op( + anytype, func0, op_name="builtin.module", deduplicate=True + ) + try: + _bundle( + mod, + mm_params, + ln_params, + kinds, + n_mm, + n_ln, + n_sm, + n_ew, + stop_at_stage, + fa_params=fa_params, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + return schedule + + +def _bundle( + mod, + mm_params, + ln_params, + kinds, + n_mm, + n_ln, + n_sm, + n_ew, + stop_at_stage="", + fa_params=None, +): + """THE PASS ORCHESTRATOR -- emits the actual sequence of transform ops. + + Runs in 3 phases over the whole payload module: + TILE -- tile every op into a GPU work-group `forall` (per op class) + SHARED TAIL -- vectorize, bufferize, forall->gpu.launch, outline kernels, + attach the XeVM target, lower vector ops to XeGPU + ANNOTATE -- attach XeGPU sub-group/DPAS layout to each kernel + `stop_at_stage` raises PipelineInterrupt to halt after a phase (for --dump). + Reading the inline comments here is the best way to understand "which part of + the code schedules the passes" -- it is this function, top to bottom.""" + anytype = transform.AnyOpType.get() + rss = ln_params["reduction_step_size"] + wg_rows = ln_params["wg_rows"] + nkernels = len(kinds) + n_fa = kinds.count("fa") + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + # ===== TILE each op-class into its own forall ===== + # KEY PROBLEM: match(linalg.generic) is NOT scoped -- once an op is tiled into + # a forall, its generic is STILL matched (it's just nested), so we can't + # re-match "the remaining bare generics" by count. SOLUTION: split ALL generic + # handles ONCE up front (their build order is deterministic), then tile each + # using its preserved handle. A handle to op X stays valid across tiling of + # OTHER ops. We tile the simple EW generics first (no fusion/cleanup, so ln + # handles survive), then the layernorms (which fuse + cleanup). + # + # Generic build order: each layernorm contributes [mean, var, normalize] (3), + # in block build order; each elementwise contributes 1. We reconstruct the + # per-op handle slices from `kinds`. + # 'fa' softmax generics do NOT exist yet (fa is tiled last, softmax still + # un-decomposed), so they are not in this pool. The fa core's linalg.transpose + # /linalg.mul/batch_matmul are not linalg.generic, so also excluded. (The head + # reshape is a pure memref VIEW -- no generic, no kernel; see Builder.heads_view.) + ngen_total = 3 * n_ln + n_ew + gen_handles = transform.split_handle( + (anytype,) * ngen_total, match(mod, ops={"linalg.generic"}) + ) + # Walk kinds to assign generic handles to ops. + ln_slices, ew_handles = [], [] + gi = 0 + for k in kinds: + if k == "ln": + ln_slices.append( + (gen_handles[gi], gen_handles[gi + 1], gen_handles[gi + 2]) + ) + gi += 3 + elif k == "ew": + ew_handles.append(gen_handles[gi]) + gi += 1 + # mm / sm / fa contribute no bare linalg.generic here + + # 1) Tile layernorms FIRST, using preserved (mean,var,normalize) handles. + # Doing this BEFORE EW/matmul tiling keeps the bare linalg.fill pool exactly + # predictable: 2*(untiled lns) + n_mm (matmul accumulator fills). EW tiling + # can introduce its own init fills, so we must finish ln fill-fusion first. + for i, (mean_red, var_red, normalize) in enumerate(ln_slices): + ln_untiled = n_ln - i + _tile_one_layernorm( + mod, + anytype, + wg_rows, + rss, + mean_red, + var_red, + normalize, + ln_untiled, + n_mm, + ln_params["T"], + ) + + # 2) Tile EW generics into own foralls (handles preserved across ln tiling). + for eg in ew_handles: + structured.structured_tile_using_forall( + anytype, + anytype, + eg, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(wg_rows,), + ) + + # 4) Matmuls (their EW producers already wrapped in foralls) + mms = match_and_split(mod, ops={"linalg.matmul"}, nhandles=n_mm) + for mm in mms: + _tile_one_matmul(mm, anytype, mm_params) + + # 5) Fused-attention regions. Done LAST so the generic pre-split above ran while + # each fa softmax was still ONE linalg.softmax (its decomposition generics + # don't exist yet, so they can't inflate ngen_total). Pre-split the 2*n_fa + # batch_matmuls (build order [QK^T, @V] per region) + n_fa softmaxes by count, + # then tile+fuse each region into ONE forall (decompose happens in-region). + if n_fa: + fa_bmms = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2 * n_fa) + fa_softmaxes = match_and_split(mod, ops={"linalg.softmax"}, nhandles=n_fa) + for r in range(n_fa): + _tile_one_fused_attention_region( + anytype, fa_bmms[2 * r], fa_bmms[2 * r + 1], fa_softmaxes[r], fa_params + ) + + func = match(mod, ops={"func.func"}) + lh_transform.cleanup(func) + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # ===== SHARED TAIL ===== + func = structured.structured_vectorize_children_and_apply_patterns( + anytype, func, fold_type_extensions_into_contract=True + ) + lh_transform.cleanup(func) + # Fused-attention regions carry a batch-of-1 dim from the (1,wg_rows,0,0) tiling; + # drop leading unit dims so the QK^T/@V vector.contracts become 2D, as the flash + # rewrite expects. + if n_fa: + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() + transform.apply_cse(func) + canonicalize(func) + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap, + ).result + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # ===== FUSED-ATTENTION REWRITE (after bufferize+vectorize, before gpu.launch) ===== + # Re-find each attention forall by kinds index (forall IR order == kinds order, + # the invariant the launch/gpu_mods loops below also rely on) and rewrite its + # QK^T/@V vector.contract pair into the flash online-softmax loop. Must run + # BEFORE forall->gpu.launch so the producer-walks for q/k/v loads stay in-region. + if n_fa: + all_foralls = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) + for idx, kind in enumerate(kinds): + if kind == "fa": + _fuse_attention_in_region(anytype, all_foralls[idx], fa_params) + func = match(mod, ops={"func.func"}) + transform.apply_cse(func) + canonicalize(func) + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + wg_loops = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) + for wg_loop in wg_loops: + loop.loop_forall_to_parallel([anytype], wg_loop) + func = match(mod, ops={"func.func"}) + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # launch threads per kernel, in IR (build) order = `kinds`. + launches = match_and_split(mod, ops={"gpu.launch"}, nhandles=nkernels) + mm_threads = ( + (mm_params["wg_m"] // mm_params["sg_m"]) + * (mm_params["wg_n"] // mm_params["sg_n"]) + * 16 + ) + sm_threads = (ln_params["wg_rows"] // ln_params["sg_rows"]) * ln_params[ + "subgroup_size" + ] + fa_threads = ( + (fa_params["wg_rows"] // fa_params["sg_rows"]) * fa_params["subgroup_size"] + if fa_params + else 0 + ) + for launch, kind in zip(launches, kinds): + nt = {"mm": mm_threads, "fa": fa_threads}.get(kind, sm_threads) + xegpu.set_gpu_launch_threads(launch, threads=[nt, 1, 1]) + + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + mod = apply_registered_pass( + mod, "xevm-attach-target", options={"O": "3", "chip": "pvc"} + ) + + # per-gpu.module convert-vector-to-xegpu. ONLY ln/sm need SLM allocas (their + # cross-lane reductions go through shared local memory -> store_matrix). The + # ew kernels (cast/bias/residual) are pure row-parallel: forcing their allocas + # to SLM creates store_matrix paths that fail to lower. So SLM-ify ln/sm only; + # leave ew (and mm) as store_nd. + gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) + sg_layout = [ln_params["sg_rows"], 1] + sg_data = [ln_params["sg_rows"], rss] + for gm, kind in zip(gpu_mods, kinds): + gf = match(gm, ops={"gpu.func"}) + if kind in ("ln", "sm"): + allocas = match(gf, ops={"memref.alloca"}) + transform_ext.update_address_space(allocas, address_space=3) + gf = apply_registered_pass(gf, "convert-vector-to-xegpu") + transform.apply_cse(gf) + if kind == "fa": + # flash kernel carries state in iter_args (no SLM); hoist invariants. + gf = apply_registered_pass(gf, "loop-invariant-code-motion") + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # ===== PER-KERNEL ANNOTATION ===== + # mm -> full mlp wg annotation + # ln -> store_nd (1) + store_matrix (the SLM reduction stores) + # sm -> store_nd (1) + store_matrix (4) + # ew -> store_nd (1) only (pure row-parallel, no SLM) + gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) + for gm, kind in zip(gpu_mods, kinds): + gf = match(gm, ops={"gpu.func"}) + if kind == "mm": + xegpu_wg_annotation_for_mlp_layer(gf, **mm_params) + elif kind == "fa": + xegpu_fa_annotation(gf, anytype, fa_params) + else: + # ln/sm/ew: anchor-layout their store_nd, and (ln/sm) their SLM + # store_matrix. Pass the whole match handle to set_anchor_layout (it + # accepts a multi-handle) -- avoids guessing exact store counts. + xegpu.set_anchor_layout( + match(gf, ops={"xegpu.store_nd"}), sg_layout=sg_layout, sg_data=sg_data + ) + if kind in ("ln", "sm"): + xegpu.set_anchor_layout( + match(gf, ops={"xegpu.store_matrix"}), + sg_layout=sg_layout, + sg_data=sg_data, + ) + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + return mod + + +def _tile_one_layernorm( + mod, anytype, wg_rows, rss, mean_red, var_red, normalize, ln_untiled, n_mm, T_ROWS +): + """Tile ONE layernorm into its own forall, using PRESERVED handles to its 3 + generics (mean_red, var_red, normalize). Handles to other ops stay valid. + + The 2 accumulator fills are matched by their producer relationship: we match + all fills and fuse the ones that feed this ln. To avoid touching matmul fills, + we rely on fuse_into_containing pulling only genuine producers of the forall. + """ + _, ln_forall = structured.structured_tile_using_forall( + anytype, + anytype, + normalize, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(wg_rows,), + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=var_red, containing_op=ln_forall + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=mean_red, containing_op=ln_forall + ) + # Fuse this ln's 2 accumulator fills into the forall. Robustly select ONLY the + # layernorm accumulator fills (NOT matmul fills) by filtering on result type: + # ln accumulators are rank-1 tensor; matmul accumulators are rank-2. + # This avoids fragile positional counting across the whole block. There are + # 2*ln_untiled such rank-1 fills (this ln + other untiled lns); this ln's are + # the FIRST 2 in IR order. + ln_func = transform.get_parent_op( + anytype, ln_forall, op_name="func.func", deduplicate=True + ) + reduce_t = ir.RankedTensorType.get((T_ROWS,), F32()) # ln accumulator type (T,) + fill_match = structured.MatchOp( + anytype, ln_func, ops=["linalg.fill"], filter_result_type=reduce_t + ) + n_ln_fills = 2 * ln_untiled + fills = transform.split_handle((anytype,) * n_ln_fills, fill_match.results[0]) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=fills[1], containing_op=ln_forall + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=fills[0], containing_op=ln_forall + ) + # Fusion leaves the full-size original fills DEAD at func scope (fusion only + # slices a copy inside the forall). They must be removed or the next ln finds + # too many. Use canonicalize (which does DCE of the dead originals) at FUNC + # scope, but NEVER apply_cse at func scope -- CSE would merge the identical + # live zero-fills ACROSS layernorms. CSE the duplicate GENERICS inside the + # forall only (scoped), so the re-match below finds exactly 3. + transform.apply_cse(ln_forall) + canonicalize(ln_func) + # tile this ln's reductions+normalize (now inside the forall). Re-match the + # 3 generics INSIDE the forall (scoped to ln_forall, so unambiguous: exactly 3). + g2 = match_and_split(ln_forall, ops={"linalg.generic"}, nhandles=3) + structured.TileUsingForOp(g2[2], sizes=[0, rss]) + structured.structured_tile_reduction_using_for( + [anytype], anytype, anytype, anytype, target=g2[1], tile_sizes=[0, rss] + ) + structured.structured_tile_reduction_using_for( + [anytype], anytype, anytype, anytype, target=g2[0], tile_sizes=[0, rss] + ) + transform.apply_cse(ln_forall) + canonicalize(ln_forall) + + +# ============================================================================= +# NUMPY REFERENCE -- the same math in plain numpy, to CHECK the GPU result. +# These mirror what the model computes. `_f16` rounds through float16 to model +# the GPU's f16 matmul precision, so the comparison tolerance can be tight. +# ============================================================================= +def _ln(x, gamma, beta, eps=1e-5): + mu = x.mean(-1, keepdims=True) + var = x.var(-1, keepdims=True) + return (x - mu) / np.sqrt(var + eps) * gamma + beta + + +def _f16(a): + # round f32 -> f16 -> f32: models the precision loss of the GPU's f16 matmul. + return a.astype(np.float16).astype(np.float32) + + +def main(): + """Entry point. Builds the FULL gpt model (n_layer blocks -> ln_f -> lm_head), + flash multi-head NON-CAUSAL attention per block. Flags: + --gpt-layers N : number of transformer layers (default 6) + --check : run on the GPU and compare to the numpy reference + --dump STAGE : print IR at a stage and exit, one of + initial | schedule | tiled | vectorized | + bufferized | inner-tiled | gpu-outlining | + xegpu-initial | xegpu-wg | final + + Flow: build payload module -> build combined schedule (which folds in the fused + attention rewrite) -> TransformDriver lowers it to XeGPU + xegpu_to_binary makes + the GPU binary -> Runner JIT-runs it -> compare to the numpy reference.""" + dump = None + check = "--check" in sys.argv + if "--dump" in sys.argv: + dump = sys.argv[sys.argv.index("--dump") + 1] + + # Kernel-friendly shapes: T=C=256 (q/k/v/proj matmuls), + # hidden=1024, vocab=256, n_layer=6. True multi-head: H heads of + # head_size=C/H=64 -- the fused flash kernel handles head_size=64 fine. + T, C, hidden = 256, 256, 1024 + vocab, n_layer = 256, 6 + H = 4 # attention heads (hs = C/H = 64) + if "--gpt-layers" in sys.argv: + n_layer = int(sys.argv[sys.argv.index("--gpt-layers") + 1]) + # mm/sm params drive the non-attention kernels (matmul, layernorm); fa_params + # drives the fused attention kernel. + param_selector = XeGPUParameterSelector() + mm_params = dict(param_selector.get_parameters((T, C, C))) + mm_params["gpu_specs"] = param_selector.gpu_specs + ln_params = { + "wg_rows": 64, + "sg_rows": 8, + "subgroup_size": 16, + "reduction_step_size": 16, + "T": T, + } + fa_params = { + "batch_size": 1, + "num_heads": H, + "n_ctx": T, + "n_head": C // H, + "wg_rows": 128, + "sg_rows": 16, + "subgroup_size": 16, + "inner_loop_tile_size": 64, + } + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + mod, kinds = build_gpt_fused_payload("payload", T, C, hidden, vocab, n_layer, H) + if dump == "initial": + print(mod) + print("KINDS:", kinds) + return + + sched = build_combined_schedule( + dict(mm_params), + dict(ln_params), + kinds, + stop_at_stage=(dump or ""), + fa_params=dict(fa_params), + ) + if dump == "schedule": + print(sched) + return + schedules = [sched] + if not dump or dump == "final": + schedules.append(xegpu_to_binary()) + payload = TransformDriver(schedules).apply(mod) + if dump: + print(payload) + return + print(f"LOWERED OK: 'gpt-fused' to {len(kinds)} kernels in one module") + + if not check: + return + runner = Runner( + payload, + mem_manager_cls=GPUMemoryManager, + shared_libs=["libmlir_levelzero_runtime.so"], + ) + np.random.seed(0) + out = np.zeros((T, vocab), np.float32) + cb = Runner.get_gpu_argument_access_callback(out, arg_index=0) + sc = 0.05 # small weight scale -> O(1) activations so f16 stays accurate + + # full model, fused multi-head attn per block. + # host "embeddings": simulate token+pos embedding sum as the input x. + x = (np.random.randn(T, C) * 0.5).astype(np.float32) + layers = [] + host = [out, x] + for _ in range(n_layer): + lw = dict( + g1=np.ones(C, np.float32), + b1n=np.zeros(C, np.float32), + wq=(np.random.randn(C, C) * sc).astype(np.float16), + wk=(np.random.randn(C, C) * sc).astype(np.float16), + wv=(np.random.randn(C, C) * sc).astype(np.float16), + wp=(np.random.randn(C, C) * sc).astype(np.float16), + bp=np.zeros(C, np.float32), + g2=np.ones(C, np.float32), + b2n=np.zeros(C, np.float32), + w1=(np.random.randn(C, hidden) * sc).astype(np.float16), + bb1=np.zeros(hidden, np.float32), + w2=(np.random.randn(hidden, C) * sc).astype(np.float16), + bb2=np.zeros(C, np.float32), + ) + layers.append(lw) + host += [ + lw["g1"], + lw["b1n"], + lw["wq"], + lw["wk"], + lw["wv"], + lw["wp"], + lw["bp"], + lw["g2"], + lw["b2n"], + lw["w1"], + lw["bb1"], + lw["w2"], + lw["bb2"], + ] + gf_g = np.ones(C, np.float32) + gf_b = np.zeros(C, np.float32) + lmw = (np.random.randn(C, vocab) * sc).astype(np.float16) + lmb = np.zeros(vocab, np.float32) + host += [gf_g, gf_b, lmw, lmb] + runner.execute( + host_input_buffers=host, + payload_function_name="payload", + argument_access_callback=cb, + ) + ref = numpy_ref_gpt_fused(x, layers, gf_g, gf_b, lmw, lmb, H) + + rel = np.abs(out - ref).max() / (np.abs(ref).max() + 1e-6) + print(f"max abs diff={np.abs(out - ref).max():.4f} rel={rel:.6f}") + print("PASSED" if rel < 5e-2 else "FAILED") + + +if __name__ == "__main__": + main() From 8387cd7ae53f219388f2fdde2b7db5b253d220ac Mon Sep 17 00:00:00 2001 From: nbpatel Date: Wed, 17 Jun 2026 21:36:08 +0000 Subject: [PATCH 2/6] Add lit RUN/CHECK header to nanoGPT example The example was discovered by lit but had no RUN line, causing the pre-commit suite to report UNRESOLVED ("Test has no 'RUN:' line"). Add a RUN line that dumps the xegpu-wg IR (1 layer, host-side, no GPU) and FileChecks the module header, matching the sibling xegpu examples. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/xegpu/nanoGPT.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/xegpu/nanoGPT.py b/examples/xegpu/nanoGPT.py index b435f4a4..962ded56 100644 --- a/examples/xegpu/nanoGPT.py +++ b/examples/xegpu/nanoGPT.py @@ -1,3 +1,6 @@ +# RUN: %PYTHON %s --dump xegpu-wg --gpt-layers 1 | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + """nano-GPT / GPT-2-style forward pass on the Intel GPU (XeGPU), with FLASH multi-head attention. From 771b1ff2a4882a576a47e181086b2b095b6e2d3a Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 18 Jun 2026 15:36:03 +0000 Subject: [PATCH 3/6] Split nanoGPT example into payload, schedule, and driver modules --- examples/xegpu/nanoGPT.py | 1157 +---------------- .../ingress/mlir_gen/gpu_nanoGPT_payload.py | 459 +++++++ lighthouse/schedule/xegpu/nanoGPT_schedule.py | 646 +++++++++ 3 files changed, 1140 insertions(+), 1122 deletions(-) create mode 100644 lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py create mode 100644 lighthouse/schedule/xegpu/nanoGPT_schedule.py diff --git a/examples/xegpu/nanoGPT.py b/examples/xegpu/nanoGPT.py index 962ded56..e814e87b 100644 --- a/examples/xegpu/nanoGPT.py +++ b/examples/xegpu/nanoGPT.py @@ -2,7 +2,7 @@ # CHECK: module attributes {gpu.container_module} { """nano-GPT / GPT-2-style forward pass on the Intel GPU (XeGPU), with -FLASH multi-head attention. +FLASH multi-head attention -- the DRIVER ("run it") entry point. This is a nanoGPT block stack: each transformer block is a = x + attn_proj( MultiHeadAttention( ln1(x) ) ) # attention sublayer @@ -36,388 +36,56 @@ (1,wg_rows,0,0) tiling peels the head dim into the work-group GRID so each wg reads 2D strided slices -> 2D load_nd . +HOW THIS EXAMPLE IS ORGANIZED -- compiling a model to the GPU here happens in +THREE stages: + + 1. PAYLOAD ("WHAT to compute") -> lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py + (the `Builder` class + `build_gpt_fused_payload`). Hardware-agnostic linalg. + 2. SCHEDULE ("HOW to lower it") -> lighthouse/schedule/xegpu/nanoGPT_schedule.py + (`build_combined_schedule`). A transform-dialect module that tiles each op + into GPU work-groups, vectorizes, bufferizes, outlines each op into its own + GPU kernel, and attaches XeGPU layout/target attributes. + 3. DRIVER ("run it") -> this file. `main()` applies the schedule to the payload + (TransformDriver), JIT-compiles + runs it on the GPU (Runner), and checks the + result against the plain-numpy reference below. + +KEY IDEA -- one module, many separate kernels: the whole model is ONE MLIR module, +but each op becomes its OWN GPU kernel (no cross-op fusion). Data passes between +kernels through device buffers (`gpu.alloc`) that stay on the GPU -- no round-trip +to the host between ops. + Run: - .venv/bin/python examples/xegpu/gpt2.py [--gpt-layers N] [--check] - .venv/bin/python examples/xegpu/gpt2.py [--dump STAGE] + .venv/bin/python examples/xegpu/nanoGPT.py [--gpt-layers N] [--check] + .venv/bin/python examples/xegpu/nanoGPT.py [--dump STAGE] """ import sys import numpy as np from mlir import ir -from mlir.dialects import linalg, bufferization, tensor, arith, math, gpu, memref -from mlir.dialects import transform -from mlir.dialects.transform import structured, loop, xegpu -from mlir.dialects.transform import bufferization as transform_bufferization -from mlir.dialects.transform.vector import ( - apply_patterns_vector_cast_away_vector_leading_one_dim, - apply_patterns_vector_drop_unit_dims_with_shape_cast, -) -from mlir.dialects.bufferization import LayoutMapOption from lighthouse import dialects as lh_dialects -import lighthouse.transform as lh_transform -from lighthouse.dialects.transform import transform_ext -from lighthouse.pipeline.helper import ( - apply_registered_pass, - canonicalize, - match, - match_and_split, - PipelineInterrupt, -) -from lighthouse.schedule import schedule_boilerplate -from lighthouse.ingress.mlir_gen.utils import ( - emit_buf_to_tensor, - affine_map, - parallel, - reduction, -) -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs -from lighthouse.ingress.mlir_gen.named import times_weights from lighthouse.pipeline.driver import TransformDriver from lighthouse.execution.runner import Runner from lighthouse.execution import GPUMemoryManager from lighthouse.schedule.xegpu import xegpu_to_binary, XeGPUParameterSelector -from lighthouse.schedule.xegpu.mlp_schedule import xegpu_wg_annotation_for_mlp_layer +from lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload import build_gpt_fused_payload +from lighthouse.schedule.xegpu.nanoGPT_schedule import build_combined_schedule # ============================================================================= -# HOW THIS FILE IS ORGANIZED -# ============================================================================= -# Compiling a model to the GPU here happens in THREE stages. The code is grouped -# to match them: -# -# 1. PAYLOAD -- "WHAT to compute". Build an MLIR module describing the math -# as high-level `linalg` ops (matmul, layernorm, softmax, elementwise). This -# is hardware-agnostic; it does not say how to run on the GPU yet. -# -> class `Builder` (emits one op at a time) and the `build_*_payload` -# functions (assemble ops into ffn / attn / block / full-gpt). -# -# 2. SCHEDULE -- "HOW to lower it to the GPU". Build a SECOND MLIR module (a -# "transform dialect" script) that rewrites the payload: tile each op into -# GPU work-groups, vectorize, bufferize, outline each op into its own GPU -# kernel, and attach XeGPU layout/target attributes. -# -> `build_combined_schedule` / `_bundle` (the orchestrator) plus the -# `_tile_one_matmul` / `_tile_one_layernorm` / `_tile_one_fused_attention_region` -# helpers. -# *** THIS is the part that "schedules the passes". *** -# -# 3. DRIVER -- "run it". `main()` applies the schedule to the payload -# (TransformDriver), JIT-compiles + runs it on the GPU (Runner), and checks -# the result against a plain-numpy reference (`numpy_ref_*`). -# -# KEY IDEA -- one module, many separate kernels: the whole model is ONE MLIR -# module, but each op becomes its OWN GPU kernel (no cross-op fusion). Data passes -# between kernels through device buffers (`gpu.alloc`) that stay on the GPU -- no -# round-trip to the host between ops. -# ============================================================================= - - -def F32(): # 32-bit float (used for accumulation / norms) - return ir.F32Type.get() - - -def F16(): # 16-bit float (required by the GPU matmul units) - return ir.F16Type.get() - - -# ============================================================================= -# STAGE 1 -- PAYLOAD: describe WHAT to compute (hardware-agnostic linalg ops) -# ============================================================================= -# Each Builder method emits ONE high-level op that writes its result into a fresh -# on-device buffer (`gpu.alloc`), and returns a tensor "view" of that buffer for -# the next op to read. Because each op writes a distinct device buffer, each will -# become its OWN GPU kernel later; the buffers are the on-device handoff between -# kernels (kernel N writes buffer B, kernel N+1 reads B -- no host round-trip). -# -# dtype convention: the GPU matmul (DPAS) hardware needs f16 inputs and produces -# an f32 result. LayerNorm/softmax run in f32. So between a norm/softmax and a -# matmul we insert an explicit f32->f16 `cast` op (its own kernel). +# NUMPY REFERENCE -- the same math in plain numpy, to CHECK the GPU result. +# These mirror what the model computes. `_f16` rounds through float16 to model +# the GPU's f16 matmul precision, so the comparison tolerance can be tight. # ============================================================================= -class Builder: - """Emits the model's ops and remembers the order/kind of each one. - - `kinds` is the crucial bookkeeping: an ordered list, one entry per op emitted, - recording its "class" so the SCHEDULE (stage 2) can later tile and annotate - each kernel correctly. Classes: - 'mm' = matmul (linalg.matmul) -> DPAS systolic-array kernel - 'ln' = layernorm (3 generics + 2 fills) -> reduction kernel (uses shared mem) - 'fa' = flash multi-head attention -> ONE kernel (QK^T->softmax->@V, - online-softmax over K/V tiles); see attention_4d + the fused-attention - schedule helpers. (Softmax lives INSIDE this kernel, not as its own.) - 'ew' = elementwise (cast / bias / relu / residual) -> simple row-parallel kernel - The op build order in the payload == the order of `kinds` == the order the - kernels appear in the final module, which is how the schedule matches them up. - """ - - def __init__(self, T): - self.T = T - self.f32, self.f16 = F32(), F16() - self.kinds = [] # ordered kernel classes (see docstring) - self.to_dealloc = [] # device buffers to gpu.dealloc at the end - - def _buf(self, shape, dtype): - # Allocate a DEVICE buffer (lives in GPU memory). Returns the memref. - b = gpu.alloc(ir.MemRefType.get(shape, dtype), None, [], [], []) - self.to_dealloc.append(b) - return b - - def _par(self, rank=2): - # Identity affine map (d0,d1,...) -> (d0,d1,...): a plain elementwise - # access pattern where output[i,j] depends on input[i,j]. - return affine_map(rank, [ir.AffineDimExpr.get(i) for i in range(rank)]) - - # ---- matmul: a(M,K) f16 @ b(K,N) f16 -> (M,N) f32 buffer ---- - def matmul(self, a, b, M, N, out_buf=None): - # Standard C = A @ B. `times_weights` emits linalg.matmul; we first fill the - # accumulator with 0. f16 inputs, f32 output -- matches the DPAS hardware. - buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - acc = linalg.fill(arith.constant(self.f32, 0.0), outs=[out_t]) - res = times_weights(a, b, acc) - bufferization.materialize_in_destination( - None, res, buf, restrict=True, writable=True - ) - self.kinds.append("mm") - if out_buf is not None: # caller gave the final output buffer - return None - return emit_buf_to_tensor(buf, restrict=True) - - # ---- layernorm(x (M,N) f32, gamma,beta (N,)) -> (M,N) f32 buffer ---- - def layernorm(self, x, gamma, beta, M, N, eps=1e-5): - # LayerNorm normalizes each ROW (length N) to mean 0 / variance 1, then - # scales by gamma and shifts by beta. Built from 3 linalg.generic ops: - # (1) mean_sum[i] = sum_j x[i,j] (row reduction) - # (2) var_sum[i] = sum_j (x[i,j]-mean_i)^2 (row reduction) - # (3) out[i,j] = (x[i,j]-mean_i)*rsqrt(var_i+eps)*gamma[j] + beta[j] - # Affine maps describe each operand's access pattern: - # par2 (d0,d1)->(d0,d1) : full 2-D elementwise - # red2 (d0,d1)->(d0) : reduce over j -> one value per row - # bias2 (d0,d1)->(d1) : gamma/beta indexed by column only - f32 = self.f32 - par2, red2 = self._par(), affine_map(2, [ir.AffineDimExpr.get(0)]) - bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) - inv_n = arith.constant(f32, 1.0 / float(N)) - eps_c = arith.constant(f32, eps) - zero = arith.constant(f32, 0.0) - # (1) row sums -> mean_sum (linalg.fill zeroes the accumulator first) - mean_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) - - @linalg.generic([x], [mean_acc], [par2, red2], [parallel, reduction]) - def mean_sum(v, acc): - return arith.AddFOp(v, acc) - - # (2) sum of squared deviations -> var_sum (mean_i = mean_sum_i / N) - var_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) - - @linalg.generic( - [x, mean_sum], [var_acc], [par2, red2, red2], [parallel, reduction] - ) - def var_sum(v, m_sum, acc): - mean = arith.MulFOp(m_sum, inv_n).result - c = arith.SubFOp(v, mean).result - return arith.AddFOp(arith.MulFOp(c, c).result, acc) - - # (3) normalize + scale + shift -> output - buf = self._buf((M, N), f32) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - - @linalg.generic( - [x, mean_sum, var_sum, gamma, beta], - [out_t], - [par2, red2, red2, bias2, bias2, par2], - [parallel, parallel], - ) - def normed(v, m_sum, v_sum, g, b, _o): - mean = arith.MulFOp(m_sum, inv_n).result - var = arith.MulFOp(v_sum, inv_n).result - inv_std = math.rsqrt(arith.AddFOp(var, eps_c).result) - c = arith.SubFOp(v, mean).result - return arith.AddFOp( - arith.MulFOp(arith.MulFOp(c, inv_std).result, g).result, b - ) - - bufferization.materialize_in_destination( - None, normed, buf, restrict=True, writable=True - ) - self.kinds.append("ln") - return emit_buf_to_tensor(buf, restrict=True) - - # ---- softmax(x (M,N) f32) over last dim -> (M,N) f32 buffer ---- - # ---- elementwise cast f32 -> f16 ---- - def cast_f16(self, x, M, N): - par2 = self._par() - buf = self._buf((M, N), self.f16) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - - @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) - def c(s, _o): - return arith.TruncFOp(self.f16, s) - - bufferization.materialize_in_destination( - None, c, buf, restrict=True, writable=True - ) - self.kinds.append("ew") - return emit_buf_to_tensor(buf, restrict=True) - - # ---- bias add (+ optional relu): out = max(x + bias, 0)? x (M,N) f32, bias (N,) ---- - def bias(self, x, bias_vec, M, N, relu=False, out_buf=None): - par2 = self._par() - bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) - zero = arith.constant(self.f32, 0.0) - buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - - @linalg.generic( - [x, bias_vec], [out_t], [par2, bias2, par2], [parallel, parallel] - ) - def b(v, bb, _o): - s = arith.AddFOp(v, bb).result - if relu: - return arith.MaximumFOp(s, zero) - return arith.AddFOp(s, zero) # identity wrap so the op has a body - - bufferization.materialize_in_destination( - None, b, buf, restrict=True, writable=True - ) - self.kinds.append("ew") - if out_buf is not None: - return None - return emit_buf_to_tensor(buf, restrict=True) - - # ---- residual add: out = a + b (both (M,N) f32) ---- - def add(self, a, b, M, N, out_buf=None): - par2 = self._par() - buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - - @linalg.generic([a, b], [out_t], [par2, par2, par2], [parallel, parallel]) - def r(x, y, _o): - return arith.AddFOp(x, y) - - bufferization.materialize_in_destination( - None, r, buf, restrict=True, writable=True - ) - self.kinds.append("ew") - if out_buf is not None: - return None - return emit_buf_to_tensor(buf, restrict=True) - - # ---- cast f32 (T,C) -> f16 (T,C), returning the MEMREF buffer (for views) ---- - def cast_f16_buf(self, x, T, C): - par2 = self._par() - buf = self._buf((T, C), self.f16) - out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) - - @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) - def c(s, _o): - return arith.TruncFOp(self.f16, s) - - bufferization.materialize_in_destination( - None, c, buf, restrict=True, writable=True - ) - self.kinds.append("ew") - return buf - - # ---- view a (T, H*hs) MEMREF as (H, T, hs) -- NO kernel, NO data move ---- - def _heads_view_of(self, buf2d, T, H, hs): - # We present the 2D - # (T, H*hs) projection buffer as a (H,T,hs) STRIDED memref VIEW: - # (T,H*hs) --memref.expand_shape--> (T,H,hs) [strides C,hs,1] - # --memref.transpose [1,0,2]--> (H,T,hs) [strides hs,C,1] - # Both are pure layout ops (no compute, no kinds entry). When the fused - # schedule tiles (1,wg_rows,0,0), the grid peels head h -> a 2D - # memref> -> 2D load_nd (XeGPU supports - # such strided block loads). - C = H * hs - et = buf2d.type.element_type - exp_t = ir.MemRefType.get((T, H, hs), et) - e = memref.expand_shape( - exp_t, buf2d, [[0], [1, 2]], [], static_output_shape=[T, H, hs] - ) - d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3)) - perm = ir.AffineMap.get(3, 0, [d1, d0, d2]) # (H,T,hs) <- (T,H,hs) - layout = ir.StridedLayoutAttr.get(0, [hs, C, 1]) - res_t = ir.MemRefType.get((H, T, hs), et, layout=layout) - return memref.transpose(res_t, e, perm) - - def heads_view(self, buf2d, T, H, hs): - return emit_buf_to_tensor(self._heads_view_of(buf2d, T, H, hs), restrict=True) - - # ---- fused multi-head attention core on 3D (H,T,hs) f16 -> (H,T,hs) f16 ---- - # (named attention_4d because it is the canonical 4D (Z,H,T,hs) attention - # algorithm with the batch dim Z=1 FOLDED OUT: one sequence, so (1,H,T,hs) - # collapses to (H,T,hs) and linalg.batch_matmul treats H as the batch axis.) - def attention_4d(self, Qh, Kh, Vh, H, T, hs, out_view, out_view_memref): - # Emits the SAME linalg op sequence as generate_gpu_attention_payload - # (batch_matmul QK^T -> scale-mul -> softmax -> batch_matmul @V), so the - # fused-attention schedule's matchers/rewrite apply verbatim. After the - # per-region fused tiling, ALL these ops fuse into ONE scf.forall -> ONE - # GPU kernel (the flash/online-softmax kernel). Counts as one 'fa'. - # Inputs Qh/Kh/Vh are (H,T,hs) f16 strided VIEWS (heads_view); the @V result - # is materialized into `out_view`, a (H,T,hs) strided view of a (T,C) buffer, - # so the merge back to 2D is also a free view (no from_heads kernel). - f16 = self.f16 - scale = 1.0 / (hs**0.5) - zero = arith.constant(f16, 0.0) - # K^T: (H,T,hs) -> (H,hs,T). Lowers to a 2D vector.transpose per head (the - # grid peels H), exactly like the standalone -- f16 is fine here. - Kt = linalg.transpose( - Kh, outs=[tensor.empty((H, hs, T), f16)], permutation=[0, 2, 1] - ) - qkt_init = linalg.fill(zero, outs=[tensor.empty((H, T, T), f16)]) - qkt = linalg.batch_matmul(Qh, Kt, outs=[qkt_init]) - sc = arith.constant(f16, scale) - scale_t = linalg.fill(sc, outs=[tensor.empty((H, T, T), f16)]) - scaled = linalg.mul(qkt, scale_t, outs=[tensor.empty((H, T, T), f16)]) - aw = linalg.softmax( - result=[ir.RankedTensorType.get((H, T, T), f16)], - input=scaled, - output=tensor.empty((H, T, T), f16), - dimension=2, - ) - # @V: (H,T,T) @ (H,T,hs) -> (H,T,hs) f16, materialized into the (T,C) view. - out_filled = linalg.fill(zero, outs=[out_view]) - out = linalg.batch_matmul(aw, Vh, outs=[out_filled]) - bufferization.materialize_in_destination( - None, out, out_view_memref, restrict=True, writable=True - ) - self.kinds.append("fa") - - # ---- fused multi-head attention(ln_f32 (T,C) f32) -> (T,C) f16, NON-CAUSAL ---- - def fused_attention(self, x, wq, wk, wv, T, C, H): - # True multi-head attention via the flash kernel, with NO on-device - # head-transpose kernel. Flow: - # x(f32) -cast-> f16 -q/k/v proj-> (T,C) f16 buffers -heads_view (free)-> - # (H,T,hs) strided views -> attention_4d (fused flash kernel) -> @V written - # into a (T,C) f16 buffer via its (H,T,hs) view -> return that (T,C) f16. - hs = C // H - x16 = self.cast_f16(x, T, C) # ew - qbuf = self.cast_f16_buf( - self.matmul(x16, wq, T, C), T, C - ) # mm, ew -> (T,C) f16 memref - kbuf = self.cast_f16_buf(self.matmul(x16, wk, T, C), T, C) # mm, ew - vbuf = self.cast_f16_buf(self.matmul(x16, wv, T, C), T, C) # mm, ew - Qh = self.heads_view(qbuf, T, H, hs) # (H,T,hs) strided view (free) - Kh = self.heads_view(kbuf, T, H, hs) - Vh = self.heads_view(vbuf, T, H, hs) - # Output (T,C) f16 buffer, viewed as (H,T,hs) for the @V store. - out_buf = self._buf((T, C), self.f16) - out_view_memref = self._heads_view_of(out_buf, T, H, hs) - out_view = emit_buf_to_tensor(out_view_memref, restrict=True, writable=True) - self.attention_4d( - Qh, Kh, Vh, H, T, hs, out_view, out_view_memref - ) # fa, writes out_buf - return emit_buf_to_tensor(out_buf, restrict=True) # (T,C) f16 +def _ln(x, gamma, beta, eps=1e-5): + mu = x.mean(-1, keepdims=True) + var = x.var(-1, keepdims=True) + return (x - mu) / np.sqrt(var + eps) * gamma + beta -# --------------------------------------------------------------------------- -# PAYLOAD ASSEMBLY -- wire the Builder ops into a complete MLIR function. -# Each build_*_payload creates one `func.func` (the "payload") whose arguments -# are the input + all weights (as device memrefs) and whose body is the op graph. -# `emit_buf_to_tensor` turns a memref argument into a tensor the ops can read; -# `func_cif` makes the function callable from C/the Runner. Returns (module, -# kinds) where `kinds` drives the schedule. -# --------------------------------------------------------------------------- +def _f16(a): + # round f32 -> f16 -> f32: models the precision loss of the GPU's f16 matmul. + return a.astype(np.float16).astype(np.float32) def _mha(q, k, v, H, causal=False): @@ -439,29 +107,8 @@ def _mha(q, k, v, H, causal=False): return attn -def _emit_block_fused(bld, x, w, T, C, hidden, H, eps, out_buf=None): - """Like _emit_block but the attention sublayer is the FUSED multi-head flash - kernel (NON-CAUSAL, no mask). `w` weight keys: g1,b1n, wq,wk,wv, wp,bp, - g2,b2n, w1,bb1,w2,bb2. wq/wk/wv/wp are full (C,C).""" - # ---- attention sublayer: a = x + proj(fused_attn(ln1(x))) ---- - ln1 = bld.layernorm(x, w["g1"], w["b1n"], T, C, eps) - attn16 = bld.fused_attention(ln1, w["wq"], w["wk"], w["wv"], T, C, H) # f16 (T,C) - proj = bld.matmul(attn16, w["wp"], T, C) - proj = bld.bias(proj, w["bp"], T, C, relu=False) - a = bld.add(x, proj, T, C) - # ---- FFN sublayer: y = a + ffn(ln2(a)) ---- - ln2 = bld.layernorm(a, w["g2"], w["b2n"], T, C, eps) - ln2_16 = bld.cast_f16(ln2, T, C) - h = bld.matmul(ln2_16, w["w1"], T, hidden) - h = bld.bias(h, w["bb1"], T, hidden, relu=True) - h16 = bld.cast_f16(h, T, hidden) - o = bld.matmul(h16, w["w2"], T, C) - o = bld.bias(o, w["bb2"], T, C, relu=False) - return bld.add(a, o, T, C, out_buf=out_buf) - - def numpy_ref_block_fused(x, w, H, eps=1e-5, causal=False): - """Multi-head block reference (matches _emit_block_fused).""" + """Multi-head block reference (matches _emit_block_fused in the payload).""" ln1 = _f16(_ln(x, w["g1"], w["b1n"], eps)) q = _f16(ln1 @ w["wq"].astype(np.float32)) k = _f16(ln1 @ w["wk"].astype(np.float32)) @@ -475,106 +122,8 @@ def numpy_ref_block_fused(x, w, H, eps=1e-5, causal=False): return a + o -def build_gpt_fused_payload(func_name, T, C, hidden, vocab, n_layer, H, eps=1e-5): - """Full nanoGPT.py forward as ONE module, with FUSED multi-head attention per block - (R3/R4). Multi-head (H heads, flash attention), - NON-CAUSAL (no mask), wq/wk/wv/wp are (C,C). Embeddings done host-side.""" - f32, f16 = F32(), F16() - mod = ir.Module.create() - x_t = ir.MemRefType.get((T, C), f32) # input activations (256,256) f32 - g_t = ir.MemRefType.get((C,), f32) # layernorm gamma/beta vectors (256,) f32 - wqkv_t = ir.MemRefType.get((C, C), f16) # q/k/v projection weights (256,256) f16 - wproj_t = ir.MemRefType.get( - (C, C), f16 - ) # attention output proj weight (256,256) f16 - bvec_t = ir.MemRefType.get((C,), f32) # bias vectors (256,) f32 - w1_t = ir.MemRefType.get((C, hidden), f16) # FFN up-projection (256,1024) f16 - b1_t = ir.MemRefType.get((hidden,), f32) # FFN hidden bias (1024,) f32 - w2_t = ir.MemRefType.get((hidden, C), f16) # FFN down-projection (1024,256) f16 - lmw_t = ir.MemRefType.get((C, vocab), f16) # lm_head weight (256,256) f16 - lmb_t = ir.MemRefType.get((vocab,), f32) # lm_head bias (256,) f32 - out_t = ir.MemRefType.get((T, vocab), f32) # output logits (256,256) f32 - # per-layer arg types: g1,b1n, wq,wk,wv, wp,bp, g2,b2n, w1,bb1,w2,bb2 (13) -- NO mask. - per_layer = [ - g_t, - g_t, - wqkv_t, - wqkv_t, - wqkv_t, - wproj_t, - bvec_t, - g_t, - g_t, - w1_t, - b1_t, - w2_t, - bvec_t, - ] - from lighthouse.utils.mlir import func_cif - - fargs = [out_t, x_t] - for _ in range(n_layer): - fargs += per_layer - fargs += [g_t, g_t, lmw_t, lmb_t] # ln_f gamma/beta, lm_head W,b (no mask) - bld = Builder(T) - with ir.InsertionPoint(mod.body): - - @func_cif(*fargs, name=func_name) - def payload(*args): - output = args[0] - emit_buf_to_tensor(output, restrict=True, writable=True) - x = emit_buf_to_tensor(args[1], restrict=True) - idx = 2 - layer_w = [] - keys = [ - "g1", - "b1n", - "wq", - "wk", - "wv", - "wp", - "bp", - "g2", - "b2n", - "w1", - "bb1", - "w2", - "bb2", - ] - for _ in range(n_layer): - w = { - k: emit_buf_to_tensor(args[idx + i], restrict=True) - for i, k in enumerate(keys) - } - idx += len(keys) - layer_w.append(w) - gf_g = emit_buf_to_tensor(args[idx], restrict=True) - idx += 1 - gf_b = emit_buf_to_tensor(args[idx], restrict=True) - idx += 1 - lmw = emit_buf_to_tensor(args[idx], restrict=True) - idx += 1 - lmb = emit_buf_to_tensor(args[idx], restrict=True) - idx += 1 - - h = x - for w in layer_w: - h = _emit_block_fused(bld, h, w, T, C, hidden, H, eps) - hf = bld.layernorm(h, gf_g, gf_b, T, C, eps) - hf16 = bld.cast_f16(hf, T, C) - logits = bld.matmul(hf16, lmw, T, vocab) - bld.bias(logits, lmb, T, vocab, relu=False, out_buf=output) - for b in bld.to_dealloc: - gpu.dealloc(None, [], b) - - emit_gpu_util_funcs(f32, rank=2) - emit_gpu_util_funcs(f32, rank=1) - emit_gpu_util_funcs(f16, rank=2) - return mod, bld.kinds - - def numpy_ref_gpt_fused(x, layer_w, gf_g, gf_b, lmw, lmb, H, eps=1e-5): - """Non-causal multi-head full-gpt reference.""" + """Non-causal multi-head full-gpt reference (matches build_gpt_fused_payload).""" h = x for w in layer_w: h = numpy_ref_block_fused(h, w, H, eps) @@ -582,642 +131,6 @@ def numpy_ref_gpt_fused(x, layer_w, gf_g, gf_b, lmw, lmb, H, eps=1e-5): return _f16(hf) @ lmw.astype(np.float32) + lmb -# ============================================================================= -# STAGE 2 -- SCHEDULE: describe HOW to lower the payload to the GPU -# ============================================================================= -# A "schedule" here is itself an MLIR module written in the TRANSFORM dialect: a -# little program of rewrite ops that the transform interpreter runs over the -# payload module. It does NOT compute anything; it REWRITES the payload from -# high-level linalg ops down to GPU (XeGPU) kernels. -# -# We can't reuse the repo's per-op schedules (layer_norm_schedule, mlp_schedule, -# softmax_schedule) directly, because each assumes the module contains ONLY its -# op. Our module is mixed (matmul + layernorm + softmax + elementwise), so we -# build ONE COMBINED schedule that handles all op classes. The strategy: -# -# (a) TILE each op into its own parallel loop nest (`scf.forall` = the GPU -# work-group grid). Different op classes tile differently: -# - matmul -> `_tile_one_matmul` (work-group tile + k-loop tile; the -# DPAS tile sizes come from `mm_params`) -# - layernorm-> `_tile_one_layernorm` (tile rows, fuse the 2 reductions + -# 2 zero-fills into the loop) -# - fused attn-> `_tile_one_fused_attention_region` (tile @V batch_matmul into -# a forall, fuse QK^T/scale/softmax/@V in; flash rewrite later) -# - elementwise -> a single `structured_tile_using_forall` over rows -# (b) SHARED TAIL (same for every kernel): vectorize -> bufferize (tensors -> -# memrefs) -> convert the forall grids to `gpu.launch` -> OUTLINE each into -# its own `gpu.module`/`gpu.func` kernel -> attach the XeVM target. -# (c) ANNOTATE each kernel with XeGPU layout attributes (how data maps to -# sub-groups / DPAS tiles). -# -# `kinds` (from the Builder) tells the schedule the class and order of every -# kernel, so steps (a) and (c) can treat each one correctly. See memory.md -# parts 6-10 for the subtle correctness rules baked into the helpers below. -# ============================================================================= -def _tile_one_matmul(matmul_op, anytype, mm_params): - """Tile ONE matmul for DPAS: a work-group `forall` tile (wg_m x wg_n) with any - elementwise consumer fused in, then an inner reduction (k) loop. Tile sizes - come from `mm_params` (chosen by xegpu_parameter_selector for the GPU).""" - wg_tile = [mm_params["wg_m"], mm_params["wg_n"]] - consumers = transform_ext.get_tileable_consumers(matmul_op) - leaf = transform_ext.extract_handle(consumers, -1) - _, [wg_loop], _ = lh_transform.tile( - leaf, - tile_sizes=wg_tile, - fuse_producers=True, - use_forall=True, - apply_cleanup=False, - ) - wg_matmul = match(wg_loop, ops={"linalg.matmul"}) - lh_transform.tile(wg_matmul, tile_sizes=[0, 0, mm_params["k_tile"]]) - - -def _tile_one_fused_attention_region(anytype, qkt_bmm, pv_bmm, softmax_op, fa_params): - """Tile + fuse ONE attention region (QK^T -> scale -> softmax -> @V) into a - SINGLE scf.forall, so it vectorizes/bufferizes into one kernel body that - `replace_with_fused_attention` later rewrites into the flash loop. - - Operates on PRE-SPLIT, per-region - handles (qkt_bmm, pv_bmm, softmax_op) so it is region-local and works at any - multiplicity. All further producers are pulled in via get_producer_of_operand - (SSA-walk = inherently scoped to this region).""" - prod = transform.get_producer_of_operand - - def fuse(p, c): - return structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op=p, containing_op=c - )[1] - - wg_rows = fa_params["wg_rows"] - # 1. Tile the @V batch_matmul in (batch=1, M=wg_rows) -> forall grid. - tiled_pv, forall = structured.structured_tile_using_forall( - anytype, - anytype, - pv_bmm, - num_threads=[], - tile_sizes=[], - static_tile_sizes=(1, wg_rows, 0, 0), - ) - func = transform.get_parent_op( - anytype, forall, op_name="func.func", deduplicate=True - ) - # 2. Fuse the @V output init fill (producer of forall operand 0). - forall = fuse(prod(anytype, forall, operand_number=0), forall) - transform.apply_cse(func) - canonicalize(func) - # 3. Decompose this region's softmax. linalg.softmax -> 4 generics + 2 fills: - # max = reduce_max(scaled) [+ -inf fill] - # num = exp(scaled - max) - # den = reduce_sum(num) [+ 0 fill] - # div = num / den (feeds @V) - structured.structured_decompose_interface(anytype, softmax_op) - transform.apply_cse(func) - canonicalize(func) - # Grab the whole producer chain UP FRONT via SSA walk (region-local; no count - # matching). Fusing op X invalidates only X's handle, so collect all, then fuse - # each once in consumer->producer topological order. - # tiled_pv operand 0 is the aw extract_slice inside the forall; hop through it - # to the func-scope softmax `div` that it slices. - aw_slice = prod(anytype, tiled_pv, operand_number=0) - div = prod(anytype, aw_slice, operand_number=0) # num / den (softmax out) - num = prod(anytype, div, operand_number=0) # exp generic - den = prod(anytype, div, operand_number=1) # sum-reduce generic - den_fill = prod(anytype, den, operand_number=1) # 0 fill (sum acc) - mx = prod(anytype, num, operand_number=1) # max-reduce generic - mx_fill = prod(anytype, mx, operand_number=1) # -inf fill (max acc) - scaled = prod(anytype, num, operand_number=0) # linalg.mul (qkt*scale) - scale_fill = prod(anytype, scaled, operand_number=1) # scale-constant fill - qkt = prod(anytype, scaled, operand_number=0) # QK^T batch_matmul - kt = prod(anytype, qkt, operand_number=1) # K^T transpose - qkt_fill = prod(anytype, qkt, operand_number=2) # 0 fill (qkt acc) - for p in ( - div, - den, - num, - mx, - scaled, - qkt, - den_fill, - mx_fill, - scale_fill, - qkt_fill, - kt, - ): - forall = fuse(p, forall) - transform.apply_cse(func) - canonicalize(func) - return func, forall - - -def _fuse_attention_in_region(anytype, forall, fa_params): - """After the shared bufferize+vectorize, rewrite ONE attention region's - vector.contract pair (QK^T, @V) into the flash loop via the transform - op. Scoped to `forall` so counts are exact at any multiplicity.""" - contract_ops = match_and_split(forall, ops={"vector.contract"}, nhandles=2) - first_contract, second_contract = contract_ops[0], contract_ops[1] - q_load = transform.get_producer_of_operand( - anytype, first_contract, operand_number=0 - ) - k_load = transform.get_producer_of_operand( - anytype, first_contract, operand_number=1 - ) - v_load = transform.get_producer_of_operand( - anytype, second_contract, operand_number=1 - ) - mulf_op = match_and_split(forall, ops={"arith.mulf"}, nhandles=1)[0] - scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) - # NB: the merged fused-attention op is non-causal only -- there is - # no `causal` parameter yet, so the model runs as non-causal attention. - transform_ext.replace_with_fused_attention( - q_load=q_load, - k_load=k_load, - v_load=v_load, - scale=scale, - output=second_contract, - tile_size=fa_params["inner_loop_tile_size"], - ) - - -def xegpu_fa_annotation(gf, anytype, fa_params): - """Attach XeGPU layouts to ONE fused-attention gpu.func.""" - num_subgroups = fa_params["wg_rows"] // fa_params["sg_rows"] - n_head = fa_params["n_head"] - q_sg_layout = [num_subgroups, 1] - q_sg_data = [16, n_head] - q_inst_data = [8, 16] - k_sg_layout = [num_subgroups, 1] - k_sg_data = [16, n_head] - k_inst_data = [16, 16] - v_sg_layout, v_sg_data, v_inst_data = k_sg_layout, k_sg_data, k_inst_data - kt_sg_layout = [1, num_subgroups] - kt_sg_data = [n_head, 16] - kt_inst_data = [16, 16] - kt_order = [0, 1] - out_sg_layout, out_sg_data, out_inst_data = q_sg_layout, q_sg_data, q_inst_data - l128_sg_layout = [num_subgroups, 1] - l128_sg_data = [16, 16] - l128_inst_data = [8, 16] - qk_sg_layout, qk_sg_data, qk_inst_data = ( - l128_sg_layout, - l128_sg_data, - l128_inst_data, - ) - - store_nd_op = match_and_split(gf, ops={"xegpu.store_nd"}, nhandles=1)[0] - xegpu.set_anchor_layout( - store_nd_op, - sg_layout=out_sg_layout, - sg_data=out_sg_data, - inst_data=out_inst_data, - ) - load_nd_ops = match_and_split(gf, ops={"xegpu.load_nd"}, nhandles=9) - xegpu.set_anchor_layout( - load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data - ) - for i in range(1, 5): - xegpu.set_anchor_layout( - load_nd_ops[i], - sg_layout=k_sg_layout, - sg_data=k_sg_data, - inst_data=k_inst_data, - ) - for i in range(5, 9): - xegpu.set_anchor_layout( - load_nd_ops[i], - sg_layout=v_sg_layout, - sg_data=v_sg_data, - inst_data=v_inst_data, - ) - dpas_ops = match_and_split(gf, ops={"xegpu.dpas"}, nhandles=8) - for i in range(4): - d = dpas_ops[i] - xegpu.set_anchor_layout( - d, sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data, index=0 - ) - xegpu.set_anchor_layout( - d, - sg_layout=kt_sg_layout, - sg_data=kt_sg_data, - inst_data=kt_inst_data, - order=kt_order, - index=1, - ) - xegpu.set_anchor_layout( - d, - sg_layout=l128_sg_layout, - sg_data=l128_sg_data, - inst_data=l128_inst_data, - index=2, - ) - for i in range(4, 8): - d = dpas_ops[i] - xegpu.set_anchor_layout( - d, - sg_layout=qk_sg_layout, - sg_data=qk_sg_data, - inst_data=qk_inst_data, - index=0, - ) - xegpu.set_anchor_layout( - d, sg_layout=v_sg_layout, sg_data=v_sg_data, inst_data=v_inst_data, index=1 - ) - xegpu.set_anchor_layout( - d, - sg_layout=out_sg_layout, - sg_data=out_sg_data, - inst_data=out_inst_data, - index=2, - ) - - -def build_combined_schedule( - mm_params, ln_params, kinds, stop_at_stage="", fa_params=None -): - """Build the transform-dialect schedule module for a payload with op classes - `kinds`. Counts how many of each class there are, then delegates to `_bundle` - (wrapped in transform boilerplate). `stop_at_stage` lets callers halt early - for debugging (--dump ).""" - n_mm = kinds.count("mm") - n_ln = kinds.count("ln") - n_sm = kinds.count("sm") - n_ew = kinds.count("ew") - with schedule_boilerplate() as (schedule, named_seq): - anytype = transform.AnyOpType.get() - func0 = match(named_seq.bodyTarget, ops={"func.func"}) - mod = transform.get_parent_op( - anytype, func0, op_name="builtin.module", deduplicate=True - ) - try: - _bundle( - mod, - mm_params, - ln_params, - kinds, - n_mm, - n_ln, - n_sm, - n_ew, - stop_at_stage, - fa_params=fa_params, - ) - except PipelineInterrupt: - pass - finally: - transform.yield_() - return schedule - - -def _bundle( - mod, - mm_params, - ln_params, - kinds, - n_mm, - n_ln, - n_sm, - n_ew, - stop_at_stage="", - fa_params=None, -): - """THE PASS ORCHESTRATOR -- emits the actual sequence of transform ops. - - Runs in 3 phases over the whole payload module: - TILE -- tile every op into a GPU work-group `forall` (per op class) - SHARED TAIL -- vectorize, bufferize, forall->gpu.launch, outline kernels, - attach the XeVM target, lower vector ops to XeGPU - ANNOTATE -- attach XeGPU sub-group/DPAS layout to each kernel - `stop_at_stage` raises PipelineInterrupt to halt after a phase (for --dump). - Reading the inline comments here is the best way to understand "which part of - the code schedules the passes" -- it is this function, top to bottom.""" - anytype = transform.AnyOpType.get() - rss = ln_params["reduction_step_size"] - wg_rows = ln_params["wg_rows"] - nkernels = len(kinds) - n_fa = kinds.count("fa") - - if stop_at_stage == "initial": - raise PipelineInterrupt() - - # ===== TILE each op-class into its own forall ===== - # KEY PROBLEM: match(linalg.generic) is NOT scoped -- once an op is tiled into - # a forall, its generic is STILL matched (it's just nested), so we can't - # re-match "the remaining bare generics" by count. SOLUTION: split ALL generic - # handles ONCE up front (their build order is deterministic), then tile each - # using its preserved handle. A handle to op X stays valid across tiling of - # OTHER ops. We tile the simple EW generics first (no fusion/cleanup, so ln - # handles survive), then the layernorms (which fuse + cleanup). - # - # Generic build order: each layernorm contributes [mean, var, normalize] (3), - # in block build order; each elementwise contributes 1. We reconstruct the - # per-op handle slices from `kinds`. - # 'fa' softmax generics do NOT exist yet (fa is tiled last, softmax still - # un-decomposed), so they are not in this pool. The fa core's linalg.transpose - # /linalg.mul/batch_matmul are not linalg.generic, so also excluded. (The head - # reshape is a pure memref VIEW -- no generic, no kernel; see Builder.heads_view.) - ngen_total = 3 * n_ln + n_ew - gen_handles = transform.split_handle( - (anytype,) * ngen_total, match(mod, ops={"linalg.generic"}) - ) - # Walk kinds to assign generic handles to ops. - ln_slices, ew_handles = [], [] - gi = 0 - for k in kinds: - if k == "ln": - ln_slices.append( - (gen_handles[gi], gen_handles[gi + 1], gen_handles[gi + 2]) - ) - gi += 3 - elif k == "ew": - ew_handles.append(gen_handles[gi]) - gi += 1 - # mm / sm / fa contribute no bare linalg.generic here - - # 1) Tile layernorms FIRST, using preserved (mean,var,normalize) handles. - # Doing this BEFORE EW/matmul tiling keeps the bare linalg.fill pool exactly - # predictable: 2*(untiled lns) + n_mm (matmul accumulator fills). EW tiling - # can introduce its own init fills, so we must finish ln fill-fusion first. - for i, (mean_red, var_red, normalize) in enumerate(ln_slices): - ln_untiled = n_ln - i - _tile_one_layernorm( - mod, - anytype, - wg_rows, - rss, - mean_red, - var_red, - normalize, - ln_untiled, - n_mm, - ln_params["T"], - ) - - # 2) Tile EW generics into own foralls (handles preserved across ln tiling). - for eg in ew_handles: - structured.structured_tile_using_forall( - anytype, - anytype, - eg, - num_threads=[], - tile_sizes=[], - static_tile_sizes=(wg_rows,), - ) - - # 4) Matmuls (their EW producers already wrapped in foralls) - mms = match_and_split(mod, ops={"linalg.matmul"}, nhandles=n_mm) - for mm in mms: - _tile_one_matmul(mm, anytype, mm_params) - - # 5) Fused-attention regions. Done LAST so the generic pre-split above ran while - # each fa softmax was still ONE linalg.softmax (its decomposition generics - # don't exist yet, so they can't inflate ngen_total). Pre-split the 2*n_fa - # batch_matmuls (build order [QK^T, @V] per region) + n_fa softmaxes by count, - # then tile+fuse each region into ONE forall (decompose happens in-region). - if n_fa: - fa_bmms = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2 * n_fa) - fa_softmaxes = match_and_split(mod, ops={"linalg.softmax"}, nhandles=n_fa) - for r in range(n_fa): - _tile_one_fused_attention_region( - anytype, fa_bmms[2 * r], fa_bmms[2 * r + 1], fa_softmaxes[r], fa_params - ) - - func = match(mod, ops={"func.func"}) - lh_transform.cleanup(func) - if stop_at_stage == "tiled": - raise PipelineInterrupt() - - # ===== SHARED TAIL ===== - func = structured.structured_vectorize_children_and_apply_patterns( - anytype, func, fold_type_extensions_into_contract=True - ) - lh_transform.cleanup(func) - # Fused-attention regions carry a batch-of-1 dim from the (1,wg_rows,0,0) tiling; - # drop leading unit dims so the QK^T/@V vector.contracts become 2D, as the flash - # rewrite expects. - if n_fa: - with ir.InsertionPoint(transform.apply_patterns(func).patterns): - apply_patterns_vector_cast_away_vector_leading_one_dim() - apply_patterns_vector_drop_unit_dims_with_shape_cast() - transform.apply_cse(func) - canonicalize(func) - if stop_at_stage == "vectorized": - raise PipelineInterrupt() - - mod = apply_registered_pass(mod, "eliminate-empty-tensors") - mod = transform_bufferization.OneShotBufferizeOp( - mod, - allow_return_allocs_from_loops=True, - bufferize_function_boundaries=True, - function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap, - ).result - mod = apply_registered_pass(mod, "fold-memref-alias-ops") - transform.apply_cse(mod) - canonicalize(mod) - - func = match(mod, ops={"func.func"}) - func = apply_registered_pass( - func, - "promote-buffers-to-stack", - options={ - "max-alloc-size-in-bytes": "8192", - "max-rank-of-allocated-memref": "2", - }, - ) - if stop_at_stage == "bufferized": - raise PipelineInterrupt() - - # ===== FUSED-ATTENTION REWRITE (after bufferize+vectorize, before gpu.launch) ===== - # Re-find each attention forall by kinds index (forall IR order == kinds order, - # the invariant the launch/gpu_mods loops below also rely on) and rewrite its - # QK^T/@V vector.contract pair into the flash online-softmax loop. Must run - # BEFORE forall->gpu.launch so the producer-walks for q/k/v loads stay in-region. - if n_fa: - all_foralls = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) - for idx, kind in enumerate(kinds): - if kind == "fa": - _fuse_attention_in_region(anytype, all_foralls[idx], fa_params) - func = match(mod, ops={"func.func"}) - transform.apply_cse(func) - canonicalize(func) - if stop_at_stage == "inner-tiled": - raise PipelineInterrupt() - - wg_loops = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) - for wg_loop in wg_loops: - loop.loop_forall_to_parallel([anytype], wg_loop) - func = match(mod, ops={"func.func"}) - func = apply_registered_pass(func, "gpu-map-parallel-loops") - func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") - func = apply_registered_pass(func, "lower-affine") - transform.apply_cse(func) - canonicalize(func) - - # launch threads per kernel, in IR (build) order = `kinds`. - launches = match_and_split(mod, ops={"gpu.launch"}, nhandles=nkernels) - mm_threads = ( - (mm_params["wg_m"] // mm_params["sg_m"]) - * (mm_params["wg_n"] // mm_params["sg_n"]) - * 16 - ) - sm_threads = (ln_params["wg_rows"] // ln_params["sg_rows"]) * ln_params[ - "subgroup_size" - ] - fa_threads = ( - (fa_params["wg_rows"] // fa_params["sg_rows"]) * fa_params["subgroup_size"] - if fa_params - else 0 - ) - for launch, kind in zip(launches, kinds): - nt = {"mm": mm_threads, "fa": fa_threads}.get(kind, sm_threads) - xegpu.set_gpu_launch_threads(launch, threads=[nt, 1, 1]) - - func = apply_registered_pass(func, "lower-affine") - canonicalize(func) - func = apply_registered_pass(func, "gpu-launch-sink-index-computations") - mod = apply_registered_pass(mod, "gpu-kernel-outlining") - transform.apply_cse(mod) - if stop_at_stage == "gpu-outlining": - raise PipelineInterrupt() - - mod = apply_registered_pass( - mod, "xevm-attach-target", options={"O": "3", "chip": "pvc"} - ) - - # per-gpu.module convert-vector-to-xegpu. ONLY ln/sm need SLM allocas (their - # cross-lane reductions go through shared local memory -> store_matrix). The - # ew kernels (cast/bias/residual) are pure row-parallel: forcing their allocas - # to SLM creates store_matrix paths that fail to lower. So SLM-ify ln/sm only; - # leave ew (and mm) as store_nd. - gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) - sg_layout = [ln_params["sg_rows"], 1] - sg_data = [ln_params["sg_rows"], rss] - for gm, kind in zip(gpu_mods, kinds): - gf = match(gm, ops={"gpu.func"}) - if kind in ("ln", "sm"): - allocas = match(gf, ops={"memref.alloca"}) - transform_ext.update_address_space(allocas, address_space=3) - gf = apply_registered_pass(gf, "convert-vector-to-xegpu") - transform.apply_cse(gf) - if kind == "fa": - # flash kernel carries state in iter_args (no SLM); hoist invariants. - gf = apply_registered_pass(gf, "loop-invariant-code-motion") - transform.apply_cse(mod) - canonicalize(mod) - if stop_at_stage == "xegpu-initial": - raise PipelineInterrupt() - - # ===== PER-KERNEL ANNOTATION ===== - # mm -> full mlp wg annotation - # ln -> store_nd (1) + store_matrix (the SLM reduction stores) - # sm -> store_nd (1) + store_matrix (4) - # ew -> store_nd (1) only (pure row-parallel, no SLM) - gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) - for gm, kind in zip(gpu_mods, kinds): - gf = match(gm, ops={"gpu.func"}) - if kind == "mm": - xegpu_wg_annotation_for_mlp_layer(gf, **mm_params) - elif kind == "fa": - xegpu_fa_annotation(gf, anytype, fa_params) - else: - # ln/sm/ew: anchor-layout their store_nd, and (ln/sm) their SLM - # store_matrix. Pass the whole match handle to set_anchor_layout (it - # accepts a multi-handle) -- avoids guessing exact store counts. - xegpu.set_anchor_layout( - match(gf, ops={"xegpu.store_nd"}), sg_layout=sg_layout, sg_data=sg_data - ) - if kind in ("ln", "sm"): - xegpu.set_anchor_layout( - match(gf, ops={"xegpu.store_matrix"}), - sg_layout=sg_layout, - sg_data=sg_data, - ) - if stop_at_stage == "xegpu-wg": - raise PipelineInterrupt() - return mod - - -def _tile_one_layernorm( - mod, anytype, wg_rows, rss, mean_red, var_red, normalize, ln_untiled, n_mm, T_ROWS -): - """Tile ONE layernorm into its own forall, using PRESERVED handles to its 3 - generics (mean_red, var_red, normalize). Handles to other ops stay valid. - - The 2 accumulator fills are matched by their producer relationship: we match - all fills and fuse the ones that feed this ln. To avoid touching matmul fills, - we rely on fuse_into_containing pulling only genuine producers of the forall. - """ - _, ln_forall = structured.structured_tile_using_forall( - anytype, - anytype, - normalize, - num_threads=[], - tile_sizes=[], - static_tile_sizes=(wg_rows,), - ) - _, ln_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op=var_red, containing_op=ln_forall - ) - _, ln_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op=mean_red, containing_op=ln_forall - ) - # Fuse this ln's 2 accumulator fills into the forall. Robustly select ONLY the - # layernorm accumulator fills (NOT matmul fills) by filtering on result type: - # ln accumulators are rank-1 tensor; matmul accumulators are rank-2. - # This avoids fragile positional counting across the whole block. There are - # 2*ln_untiled such rank-1 fills (this ln + other untiled lns); this ln's are - # the FIRST 2 in IR order. - ln_func = transform.get_parent_op( - anytype, ln_forall, op_name="func.func", deduplicate=True - ) - reduce_t = ir.RankedTensorType.get((T_ROWS,), F32()) # ln accumulator type (T,) - fill_match = structured.MatchOp( - anytype, ln_func, ops=["linalg.fill"], filter_result_type=reduce_t - ) - n_ln_fills = 2 * ln_untiled - fills = transform.split_handle((anytype,) * n_ln_fills, fill_match.results[0]) - _, ln_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op=fills[1], containing_op=ln_forall - ) - _, ln_forall = structured.structured_fuse_into_containing_op( - anytype, anytype, producer_op=fills[0], containing_op=ln_forall - ) - # Fusion leaves the full-size original fills DEAD at func scope (fusion only - # slices a copy inside the forall). They must be removed or the next ln finds - # too many. Use canonicalize (which does DCE of the dead originals) at FUNC - # scope, but NEVER apply_cse at func scope -- CSE would merge the identical - # live zero-fills ACROSS layernorms. CSE the duplicate GENERICS inside the - # forall only (scoped), so the re-match below finds exactly 3. - transform.apply_cse(ln_forall) - canonicalize(ln_func) - # tile this ln's reductions+normalize (now inside the forall). Re-match the - # 3 generics INSIDE the forall (scoped to ln_forall, so unambiguous: exactly 3). - g2 = match_and_split(ln_forall, ops={"linalg.generic"}, nhandles=3) - structured.TileUsingForOp(g2[2], sizes=[0, rss]) - structured.structured_tile_reduction_using_for( - [anytype], anytype, anytype, anytype, target=g2[1], tile_sizes=[0, rss] - ) - structured.structured_tile_reduction_using_for( - [anytype], anytype, anytype, anytype, target=g2[0], tile_sizes=[0, rss] - ) - transform.apply_cse(ln_forall) - canonicalize(ln_forall) - - -# ============================================================================= -# NUMPY REFERENCE -- the same math in plain numpy, to CHECK the GPU result. -# These mirror what the model computes. `_f16` rounds through float16 to model -# the GPU's f16 matmul precision, so the comparison tolerance can be tight. -# ============================================================================= -def _ln(x, gamma, beta, eps=1e-5): - mu = x.mean(-1, keepdims=True) - var = x.var(-1, keepdims=True) - return (x - mu) / np.sqrt(var + eps) * gamma + beta - - -def _f16(a): - # round f32 -> f16 -> f32: models the precision loss of the GPU's f16 matmul. - return a.astype(np.float16).astype(np.float32) - - def main(): """Entry point. Builds the FULL gpt model (n_layer blocks -> ln_f -> lm_head), flash multi-head NON-CAUSAL attention per block. Flags: diff --git a/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py b/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py new file mode 100644 index 00000000..3e358bd5 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py @@ -0,0 +1,459 @@ +"""Generate MLIR payload for a GPU nano-GPT / GPT-2-style forward pass. + +This is STAGE 1 -- the PAYLOAD ("WHAT to compute"). It builds an MLIR module +describing the model math as high-level `linalg` ops (matmul, layernorm, softmax, +elementwise). It is hardware-agnostic; HOW it lowers to the GPU lives in +`lighthouse.schedule.xegpu.nanoGPT_schedule`. + + -> class `Builder` (emits one op at a time) and `build_gpt_fused_payload` + (assembles ops into ffn / attn / block / full-gpt). + +Architecture: each transformer block is + a = x + attn_proj( MultiHeadAttention( ln1(x) ) ) # attention sublayer + y = a + ffn( ln2(a) ) # MLP sublayer + ffn(z) = Linear(C, 4C) -> ReLU -> Linear(4C, C) +and the full model is + x = token_emb + pos_emb # embeddings (done host-side) + for _ in range(n_layer): x = Block(x) + x = ln_f(x); logits = x @ lm_head +TRUE multi-head: H heads of head_size = C/H = 64, computed by ONE fused +flash-attention kernel per block (NON-CAUSAL for now). +""" + +from mlir import ir +from mlir.dialects import linalg, bufferization, tensor, arith, math, gpu, memref + +from lighthouse.ingress.mlir_gen.utils import ( + emit_buf_to_tensor, + affine_map, + parallel, + reduction, +) +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.named import times_weights + + +def F32(): # 32-bit float (used for accumulation / norms) + return ir.F32Type.get() + + +def F16(): # 16-bit float (required by the GPU matmul units) + return ir.F16Type.get() + + +# ============================================================================= +# PAYLOAD: describe WHAT to compute (hardware-agnostic linalg ops) +# ============================================================================= +# Each Builder method emits ONE high-level op that writes its result into a fresh +# on-device buffer (`gpu.alloc`), and returns a tensor "view" of that buffer for +# the next op to read. Because each op writes a distinct device buffer, each will +# become its OWN GPU kernel later; the buffers are the on-device handoff between +# kernels (kernel N writes buffer B, kernel N+1 reads B -- no host round-trip). +# +# dtype convention: the GPU matmul (DPAS) hardware needs f16 inputs and produces +# an f32 result. LayerNorm/softmax run in f32. So between a norm/softmax and a +# matmul we insert an explicit f32->f16 `cast` op (its own kernel). +# ============================================================================= +class Builder: + """Emits the model's ops and remembers the order/kind of each one. + + `kinds` is the crucial bookkeeping: an ordered list, one entry per op emitted, + recording its "class" so the SCHEDULE (stage 2) can later tile and annotate + each kernel correctly. Classes: + 'mm' = matmul (linalg.matmul) -> DPAS systolic-array kernel + 'ln' = layernorm (3 generics + 2 fills) -> reduction kernel (uses shared mem) + 'fa' = flash multi-head attention -> ONE kernel (QK^T->softmax->@V, + online-softmax over K/V tiles); see attention_4d + the fused-attention + schedule helpers. (Softmax lives INSIDE this kernel, not as its own.) + 'ew' = elementwise (cast / bias / relu / residual) -> simple row-parallel kernel + The op build order in the payload == the order of `kinds` == the order the + kernels appear in the final module, which is how the schedule matches them up. + """ + + def __init__(self, T): + self.T = T + self.f32, self.f16 = F32(), F16() + self.kinds = [] # ordered kernel classes (see docstring) + self.to_dealloc = [] # device buffers to gpu.dealloc at the end + + def _buf(self, shape, dtype): + # Allocate a DEVICE buffer (lives in GPU memory). Returns the memref. + b = gpu.alloc(ir.MemRefType.get(shape, dtype), None, [], [], []) + self.to_dealloc.append(b) + return b + + def _par(self, rank=2): + # Identity affine map (d0,d1,...) -> (d0,d1,...): a plain elementwise + # access pattern where output[i,j] depends on input[i,j]. + return affine_map(rank, [ir.AffineDimExpr.get(i) for i in range(rank)]) + + # ---- matmul: a(M,K) f16 @ b(K,N) f16 -> (M,N) f32 buffer ---- + def matmul(self, a, b, M, N, out_buf=None): + # Standard C = A @ B. `times_weights` emits linalg.matmul; we first fill the + # accumulator with 0. f16 inputs, f32 output -- matches the DPAS hardware. + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + acc = linalg.fill(arith.constant(self.f32, 0.0), outs=[out_t]) + res = times_weights(a, b, acc) + bufferization.materialize_in_destination( + None, res, buf, restrict=True, writable=True + ) + self.kinds.append("mm") + if out_buf is not None: # caller gave the final output buffer + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- layernorm(x (M,N) f32, gamma,beta (N,)) -> (M,N) f32 buffer ---- + def layernorm(self, x, gamma, beta, M, N, eps=1e-5): + # LayerNorm normalizes each ROW (length N) to mean 0 / variance 1, then + # scales by gamma and shifts by beta. Built from 3 linalg.generic ops: + # (1) mean_sum[i] = sum_j x[i,j] (row reduction) + # (2) var_sum[i] = sum_j (x[i,j]-mean_i)^2 (row reduction) + # (3) out[i,j] = (x[i,j]-mean_i)*rsqrt(var_i+eps)*gamma[j] + beta[j] + # Affine maps describe each operand's access pattern: + # par2 (d0,d1)->(d0,d1) : full 2-D elementwise + # red2 (d0,d1)->(d0) : reduce over j -> one value per row + # bias2 (d0,d1)->(d1) : gamma/beta indexed by column only + f32 = self.f32 + par2, red2 = self._par(), affine_map(2, [ir.AffineDimExpr.get(0)]) + bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) + inv_n = arith.constant(f32, 1.0 / float(N)) + eps_c = arith.constant(f32, eps) + zero = arith.constant(f32, 0.0) + # (1) row sums -> mean_sum (linalg.fill zeroes the accumulator first) + mean_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) + + @linalg.generic([x], [mean_acc], [par2, red2], [parallel, reduction]) + def mean_sum(v, acc): + return arith.AddFOp(v, acc) + + # (2) sum of squared deviations -> var_sum (mean_i = mean_sum_i / N) + var_acc = linalg.fill(zero, outs=[tensor.empty((M,), f32)]) + + @linalg.generic( + [x, mean_sum], [var_acc], [par2, red2, red2], [parallel, reduction] + ) + def var_sum(v, m_sum, acc): + mean = arith.MulFOp(m_sum, inv_n).result + c = arith.SubFOp(v, mean).result + return arith.AddFOp(arith.MulFOp(c, c).result, acc) + + # (3) normalize + scale + shift -> output + buf = self._buf((M, N), f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic( + [x, mean_sum, var_sum, gamma, beta], + [out_t], + [par2, red2, red2, bias2, bias2, par2], + [parallel, parallel], + ) + def normed(v, m_sum, v_sum, g, b, _o): + mean = arith.MulFOp(m_sum, inv_n).result + var = arith.MulFOp(v_sum, inv_n).result + inv_std = math.rsqrt(arith.AddFOp(var, eps_c).result) + c = arith.SubFOp(v, mean).result + return arith.AddFOp( + arith.MulFOp(arith.MulFOp(c, inv_std).result, g).result, b + ) + + bufferization.materialize_in_destination( + None, normed, buf, restrict=True, writable=True + ) + self.kinds.append("ln") + return emit_buf_to_tensor(buf, restrict=True) + + # ---- elementwise cast f32 -> f16 ---- + def cast_f16(self, x, M, N): + par2 = self._par() + buf = self._buf((M, N), self.f16) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) + def c(s, _o): + return arith.TruncFOp(self.f16, s) + + bufferization.materialize_in_destination( + None, c, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + return emit_buf_to_tensor(buf, restrict=True) + + # ---- bias add (+ optional relu): out = max(x + bias, 0)? x (M,N) f32, bias (N,) ---- + def bias(self, x, bias_vec, M, N, relu=False, out_buf=None): + par2 = self._par() + bias2 = affine_map(2, [ir.AffineDimExpr.get(1)]) + zero = arith.constant(self.f32, 0.0) + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic( + [x, bias_vec], [out_t], [par2, bias2, par2], [parallel, parallel] + ) + def b(v, bb, _o): + s = arith.AddFOp(v, bb).result + if relu: + return arith.MaximumFOp(s, zero) + return arith.AddFOp(s, zero) # identity wrap so the op has a body + + bufferization.materialize_in_destination( + None, b, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + if out_buf is not None: + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- residual add: out = a + b (both (M,N) f32) ---- + def add(self, a, b, M, N, out_buf=None): + par2 = self._par() + buf = out_buf if out_buf is not None else self._buf((M, N), self.f32) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([a, b], [out_t], [par2, par2, par2], [parallel, parallel]) + def r(x, y, _o): + return arith.AddFOp(x, y) + + bufferization.materialize_in_destination( + None, r, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + if out_buf is not None: + return None + return emit_buf_to_tensor(buf, restrict=True) + + # ---- cast f32 (T,C) -> f16 (T,C), returning the MEMREF buffer (for views) ---- + def cast_f16_buf(self, x, T, C): + par2 = self._par() + buf = self._buf((T, C), self.f16) + out_t = emit_buf_to_tensor(buf, restrict=True, writable=True) + + @linalg.generic([x], [out_t], [par2, par2], [parallel, parallel]) + def c(s, _o): + return arith.TruncFOp(self.f16, s) + + bufferization.materialize_in_destination( + None, c, buf, restrict=True, writable=True + ) + self.kinds.append("ew") + return buf + + # ---- view a (T, H*hs) MEMREF as (H, T, hs) -- NO kernel, NO data move ---- + def _heads_view_of(self, buf2d, T, H, hs): + # We present the 2D + # (T, H*hs) projection buffer as a (H,T,hs) STRIDED memref VIEW: + # (T,H*hs) --memref.expand_shape--> (T,H,hs) [strides C,hs,1] + # --memref.transpose [1,0,2]--> (H,T,hs) [strides hs,C,1] + # Both are pure layout ops (no compute, no kinds entry). When the fused + # schedule tiles (1,wg_rows,0,0), the grid peels head h -> a 2D + # memref> -> 2D load_nd (XeGPU supports + # such strided block loads). + C = H * hs + et = buf2d.type.element_type + exp_t = ir.MemRefType.get((T, H, hs), et) + e = memref.expand_shape( + exp_t, buf2d, [[0], [1, 2]], [], static_output_shape=[T, H, hs] + ) + d0, d1, d2 = (ir.AffineDimExpr.get(i) for i in range(3)) + perm = ir.AffineMap.get(3, 0, [d1, d0, d2]) # (H,T,hs) <- (T,H,hs) + layout = ir.StridedLayoutAttr.get(0, [hs, C, 1]) + res_t = ir.MemRefType.get((H, T, hs), et, layout=layout) + return memref.transpose(res_t, e, perm) + + def heads_view(self, buf2d, T, H, hs): + return emit_buf_to_tensor(self._heads_view_of(buf2d, T, H, hs), restrict=True) + + # ---- fused multi-head attention core on 3D (H,T,hs) f16 -> (H,T,hs) f16 ---- + # (named attention_4d because it is the canonical 4D (Z,H,T,hs) attention + # algorithm with the batch dim Z=1 FOLDED OUT: one sequence, so (1,H,T,hs) + # collapses to (H,T,hs) and linalg.batch_matmul treats H as the batch axis.) + def attention_4d(self, Qh, Kh, Vh, H, T, hs, out_view, out_view_memref): + # Emits the SAME linalg op sequence as generate_gpu_attention_payload + # (batch_matmul QK^T -> scale-mul -> softmax -> batch_matmul @V), so the + # fused-attention schedule's matchers/rewrite apply verbatim. After the + # per-region fused tiling, ALL these ops fuse into ONE scf.forall -> ONE + # GPU kernel (the flash/online-softmax kernel). Counts as one 'fa'. + # Inputs Qh/Kh/Vh are (H,T,hs) f16 strided VIEWS (heads_view); the @V result + # is materialized into `out_view`, a (H,T,hs) strided view of a (T,C) buffer, + # so the merge back to 2D is also a free view (no from_heads kernel). + f16 = self.f16 + scale = 1.0 / (hs**0.5) + zero = arith.constant(f16, 0.0) + # K^T: (H,T,hs) -> (H,hs,T). Lowers to a 2D vector.transpose per head (the + # grid peels H), exactly like the standalone -- f16 is fine here. + Kt = linalg.transpose( + Kh, outs=[tensor.empty((H, hs, T), f16)], permutation=[0, 2, 1] + ) + qkt_init = linalg.fill(zero, outs=[tensor.empty((H, T, T), f16)]) + qkt = linalg.batch_matmul(Qh, Kt, outs=[qkt_init]) + sc = arith.constant(f16, scale) + scale_t = linalg.fill(sc, outs=[tensor.empty((H, T, T), f16)]) + scaled = linalg.mul(qkt, scale_t, outs=[tensor.empty((H, T, T), f16)]) + aw = linalg.softmax( + result=[ir.RankedTensorType.get((H, T, T), f16)], + input=scaled, + output=tensor.empty((H, T, T), f16), + dimension=2, + ) + # @V: (H,T,T) @ (H,T,hs) -> (H,T,hs) f16, materialized into the (T,C) view. + out_filled = linalg.fill(zero, outs=[out_view]) + out = linalg.batch_matmul(aw, Vh, outs=[out_filled]) + bufferization.materialize_in_destination( + None, out, out_view_memref, restrict=True, writable=True + ) + self.kinds.append("fa") + + # ---- fused multi-head attention(ln_f32 (T,C) f32) -> (T,C) f16, NON-CAUSAL ---- + def fused_attention(self, x, wq, wk, wv, T, C, H): + # True multi-head attention via the flash kernel, with NO on-device + # head-transpose kernel. Flow: + # x(f32) -cast-> f16 -q/k/v proj-> (T,C) f16 buffers -heads_view (free)-> + # (H,T,hs) strided views -> attention_4d (fused flash kernel) -> @V written + # into a (T,C) f16 buffer via its (H,T,hs) view -> return that (T,C) f16. + hs = C // H + x16 = self.cast_f16(x, T, C) # ew + qbuf = self.cast_f16_buf( + self.matmul(x16, wq, T, C), T, C + ) # mm, ew -> (T,C) f16 memref + kbuf = self.cast_f16_buf(self.matmul(x16, wk, T, C), T, C) # mm, ew + vbuf = self.cast_f16_buf(self.matmul(x16, wv, T, C), T, C) # mm, ew + Qh = self.heads_view(qbuf, T, H, hs) # (H,T,hs) strided view (free) + Kh = self.heads_view(kbuf, T, H, hs) + Vh = self.heads_view(vbuf, T, H, hs) + # Output (T,C) f16 buffer, viewed as (H,T,hs) for the @V store. + out_buf = self._buf((T, C), self.f16) + out_view_memref = self._heads_view_of(out_buf, T, H, hs) + out_view = emit_buf_to_tensor(out_view_memref, restrict=True, writable=True) + self.attention_4d( + Qh, Kh, Vh, H, T, hs, out_view, out_view_memref + ) # fa, writes out_buf + return emit_buf_to_tensor(out_buf, restrict=True) # (T,C) f16 + + +# --------------------------------------------------------------------------- +# PAYLOAD ASSEMBLY -- wire the Builder ops into a complete MLIR function. +# `build_gpt_fused_payload` creates one `func.func` (the "payload") whose arguments +# are the input + all weights (as device memrefs) and whose body is the op graph. +# `emit_buf_to_tensor` turns a memref argument into a tensor the ops can read; +# `func_cif` makes the function callable from C/the Runner. Returns (module, +# kinds) where `kinds` drives the schedule. +# --------------------------------------------------------------------------- + + +def _emit_block_fused(bld, x, w, T, C, hidden, H, eps, out_buf=None): + """Emit ONE transformer block whose attention sublayer is the FUSED multi-head + flash kernel (NON-CAUSAL, no mask). `w` weight keys: g1,b1n, wq,wk,wv, wp,bp, + g2,b2n, w1,bb1,w2,bb2. wq/wk/wv/wp are full (C,C).""" + # ---- attention sublayer: a = x + proj(fused_attn(ln1(x))) ---- + ln1 = bld.layernorm(x, w["g1"], w["b1n"], T, C, eps) + attn16 = bld.fused_attention(ln1, w["wq"], w["wk"], w["wv"], T, C, H) # f16 (T,C) + proj = bld.matmul(attn16, w["wp"], T, C) + proj = bld.bias(proj, w["bp"], T, C, relu=False) + a = bld.add(x, proj, T, C) + # ---- FFN sublayer: y = a + ffn(ln2(a)) ---- + ln2 = bld.layernorm(a, w["g2"], w["b2n"], T, C, eps) + ln2_16 = bld.cast_f16(ln2, T, C) + h = bld.matmul(ln2_16, w["w1"], T, hidden) + h = bld.bias(h, w["bb1"], T, hidden, relu=True) + h16 = bld.cast_f16(h, T, hidden) + o = bld.matmul(h16, w["w2"], T, C) + o = bld.bias(o, w["bb2"], T, C, relu=False) + return bld.add(a, o, T, C, out_buf=out_buf) + + +def build_gpt_fused_payload(func_name, T, C, hidden, vocab, n_layer, H, eps=1e-5): + """Full nanoGPT forward as ONE module, with FUSED multi-head attention per block. + Multi-head (H heads, flash attention), NON-CAUSAL (no mask), wq/wk/wv/wp are + (C,C). Embeddings done host-side. Returns (module, kinds).""" + f32, f16 = F32(), F16() + mod = ir.Module.create() + x_t = ir.MemRefType.get((T, C), f32) # input activations (256,256) f32 + g_t = ir.MemRefType.get((C,), f32) # layernorm gamma/beta vectors (256,) f32 + wqkv_t = ir.MemRefType.get((C, C), f16) # q/k/v projection weights (256,256) f16 + wproj_t = ir.MemRefType.get( + (C, C), f16 + ) # attention output proj weight (256,256) f16 + bvec_t = ir.MemRefType.get((C,), f32) # bias vectors (256,) f32 + w1_t = ir.MemRefType.get((C, hidden), f16) # FFN up-projection (256,1024) f16 + b1_t = ir.MemRefType.get((hidden,), f32) # FFN hidden bias (1024,) f32 + w2_t = ir.MemRefType.get((hidden, C), f16) # FFN down-projection (1024,256) f16 + lmw_t = ir.MemRefType.get((C, vocab), f16) # lm_head weight (256,256) f16 + lmb_t = ir.MemRefType.get((vocab,), f32) # lm_head bias (256,) f32 + out_t = ir.MemRefType.get((T, vocab), f32) # output logits (256,256) f32 + # per-layer arg types: g1,b1n, wq,wk,wv, wp,bp, g2,b2n, w1,bb1,w2,bb2 (13) -- NO mask. + per_layer = [ + g_t, + g_t, + wqkv_t, + wqkv_t, + wqkv_t, + wproj_t, + bvec_t, + g_t, + g_t, + w1_t, + b1_t, + w2_t, + bvec_t, + ] + from lighthouse.utils.mlir import func_cif + + fargs = [out_t, x_t] + for _ in range(n_layer): + fargs += per_layer + fargs += [g_t, g_t, lmw_t, lmb_t] # ln_f gamma/beta, lm_head W,b (no mask) + bld = Builder(T) + with ir.InsertionPoint(mod.body): + + @func_cif(*fargs, name=func_name) + def payload(*args): + output = args[0] + emit_buf_to_tensor(output, restrict=True, writable=True) + x = emit_buf_to_tensor(args[1], restrict=True) + idx = 2 + layer_w = [] + keys = [ + "g1", + "b1n", + "wq", + "wk", + "wv", + "wp", + "bp", + "g2", + "b2n", + "w1", + "bb1", + "w2", + "bb2", + ] + for _ in range(n_layer): + w = { + k: emit_buf_to_tensor(args[idx + i], restrict=True) + for i, k in enumerate(keys) + } + idx += len(keys) + layer_w.append(w) + gf_g = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + gf_b = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + lmw = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + lmb = emit_buf_to_tensor(args[idx], restrict=True) + idx += 1 + + h = x + for w in layer_w: + h = _emit_block_fused(bld, h, w, T, C, hidden, H, eps) + hf = bld.layernorm(h, gf_g, gf_b, T, C, eps) + hf16 = bld.cast_f16(hf, T, C) + logits = bld.matmul(hf16, lmw, T, vocab) + bld.bias(logits, lmb, T, vocab, relu=False, out_buf=output) + for b in bld.to_dealloc: + gpu.dealloc(None, [], b) + + emit_gpu_util_funcs(f32, rank=2) + emit_gpu_util_funcs(f32, rank=1) + emit_gpu_util_funcs(f16, rank=2) + return mod, bld.kinds diff --git a/lighthouse/schedule/xegpu/nanoGPT_schedule.py b/lighthouse/schedule/xegpu/nanoGPT_schedule.py new file mode 100644 index 00000000..62c99576 --- /dev/null +++ b/lighthouse/schedule/xegpu/nanoGPT_schedule.py @@ -0,0 +1,646 @@ +"""Transform-dialect SCHEDULE for the GPU nano-GPT forward pass. + +This is STAGE 2 -- the SCHEDULE ("HOW to lower it"). A "schedule" here is itself an +MLIR module written in the TRANSFORM dialect: a little program of rewrite ops that +the transform interpreter runs over the payload module (built in +`lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload`). It does NOT compute anything; +it REWRITES the payload from high-level linalg ops down to GPU (XeGPU) kernels. + + -> `build_combined_schedule` / `_bundle` (the orchestrator) plus the + `_tile_one_matmul` / `_tile_one_layernorm` / `_tile_one_fused_attention_region` + helpers. + +We can't reuse the repo's per-op schedules (layer_norm_schedule, mlp_schedule, +softmax_schedule) directly, because each assumes the module contains ONLY its op. +The nanoGPT module is mixed (matmul + layernorm + softmax + elementwise), so we +build ONE COMBINED schedule that handles all op classes. The strategy: + + (a) TILE each op into its own parallel loop nest (`scf.forall` = the GPU + work-group grid). Different op classes tile differently: + - matmul -> `_tile_one_matmul` (work-group tile + k-loop tile; the + DPAS tile sizes come from `mm_params`) + - layernorm-> `_tile_one_layernorm` (tile rows, fuse the 2 reductions + + 2 zero-fills into the loop) + - fused attn-> `_tile_one_fused_attention_region` (tile @V batch_matmul into + a forall, fuse QK^T/scale/softmax/@V in; flash rewrite later) + - elementwise -> a single `structured_tile_using_forall` over rows + (b) SHARED TAIL (same for every kernel): vectorize -> bufferize (tensors -> + memrefs) -> convert the forall grids to `gpu.launch` -> OUTLINE each into + its own `gpu.module`/`gpu.func` kernel -> attach the XeVM target. + (c) ANNOTATE each kernel with XeGPU layout attributes (how data maps to + sub-groups / DPAS tiles). + +`kinds` (from the Builder) tells the schedule the class and order of every kernel, +so steps (a) and (c) can treat each one correctly. +""" + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.transform.vector import ( + apply_patterns_vector_cast_away_vector_leading_one_dim, + apply_patterns_vector_drop_unit_dims_with_shape_cast, +) +from mlir.dialects.bufferization import LayoutMapOption + +import lighthouse.transform as lh_transform +from lighthouse.dialects.transform import transform_ext +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, + match_and_split, + PipelineInterrupt, +) +from lighthouse.schedule import schedule_boilerplate +from lighthouse.schedule.xegpu.mlp_schedule import xegpu_wg_annotation_for_mlp_layer +from lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload import F32 + + +def _tile_one_matmul(matmul_op, anytype, mm_params): + """Tile ONE matmul for DPAS: a work-group `forall` tile (wg_m x wg_n) with any + elementwise consumer fused in, then an inner reduction (k) loop. Tile sizes + come from `mm_params` (chosen by xegpu_parameter_selector for the GPU).""" + wg_tile = [mm_params["wg_m"], mm_params["wg_n"]] + consumers = transform_ext.get_tileable_consumers(matmul_op) + leaf = transform_ext.extract_handle(consumers, -1) + _, [wg_loop], _ = lh_transform.tile( + leaf, + tile_sizes=wg_tile, + fuse_producers=True, + use_forall=True, + apply_cleanup=False, + ) + wg_matmul = match(wg_loop, ops={"linalg.matmul"}) + lh_transform.tile(wg_matmul, tile_sizes=[0, 0, mm_params["k_tile"]]) + + +def _tile_one_fused_attention_region(anytype, qkt_bmm, pv_bmm, softmax_op, fa_params): + """Tile + fuse ONE attention region (QK^T -> scale -> softmax -> @V) into a + SINGLE scf.forall, so it vectorizes/bufferizes into one kernel body that + `replace_with_fused_attention` later rewrites into the flash loop. + + Operates on PRE-SPLIT, per-region + handles (qkt_bmm, pv_bmm, softmax_op) so it is region-local and works at any + multiplicity. All further producers are pulled in via get_producer_of_operand + (SSA-walk = inherently scoped to this region).""" + prod = transform.get_producer_of_operand + + def fuse(p, c): + return structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=p, containing_op=c + )[1] + + wg_rows = fa_params["wg_rows"] + # 1. Tile the @V batch_matmul in (batch=1, M=wg_rows) -> forall grid. + tiled_pv, forall = structured.structured_tile_using_forall( + anytype, + anytype, + pv_bmm, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(1, wg_rows, 0, 0), + ) + func = transform.get_parent_op( + anytype, forall, op_name="func.func", deduplicate=True + ) + # 2. Fuse the @V output init fill (producer of forall operand 0). + forall = fuse(prod(anytype, forall, operand_number=0), forall) + transform.apply_cse(func) + canonicalize(func) + # 3. Decompose this region's softmax. linalg.softmax -> 4 generics + 2 fills: + # max = reduce_max(scaled) [+ -inf fill] + # num = exp(scaled - max) + # den = reduce_sum(num) [+ 0 fill] + # div = num / den (feeds @V) + structured.structured_decompose_interface(anytype, softmax_op) + transform.apply_cse(func) + canonicalize(func) + # Grab the whole producer chain UP FRONT via SSA walk (region-local; no count + # matching). Fusing op X invalidates only X's handle, so collect all, then fuse + # each once in consumer->producer topological order. + # tiled_pv operand 0 is the aw extract_slice inside the forall; hop through it + # to the func-scope softmax `div` that it slices. + aw_slice = prod(anytype, tiled_pv, operand_number=0) + div = prod(anytype, aw_slice, operand_number=0) # num / den (softmax out) + num = prod(anytype, div, operand_number=0) # exp generic + den = prod(anytype, div, operand_number=1) # sum-reduce generic + den_fill = prod(anytype, den, operand_number=1) # 0 fill (sum acc) + mx = prod(anytype, num, operand_number=1) # max-reduce generic + mx_fill = prod(anytype, mx, operand_number=1) # -inf fill (max acc) + scaled = prod(anytype, num, operand_number=0) # linalg.mul (qkt*scale) + scale_fill = prod(anytype, scaled, operand_number=1) # scale-constant fill + qkt = prod(anytype, scaled, operand_number=0) # QK^T batch_matmul + kt = prod(anytype, qkt, operand_number=1) # K^T transpose + qkt_fill = prod(anytype, qkt, operand_number=2) # 0 fill (qkt acc) + for p in ( + div, + den, + num, + mx, + scaled, + qkt, + den_fill, + mx_fill, + scale_fill, + qkt_fill, + kt, + ): + forall = fuse(p, forall) + transform.apply_cse(func) + canonicalize(func) + return func, forall + + +def _fuse_attention_in_region(anytype, forall, fa_params): + """After the shared bufferize+vectorize, rewrite ONE attention region's + vector.contract pair (QK^T, @V) into the flash loop via the transform + op. Scoped to `forall` so counts are exact at any multiplicity.""" + contract_ops = match_and_split(forall, ops={"vector.contract"}, nhandles=2) + first_contract, second_contract = contract_ops[0], contract_ops[1] + q_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=0 + ) + k_load = transform.get_producer_of_operand( + anytype, first_contract, operand_number=1 + ) + v_load = transform.get_producer_of_operand( + anytype, second_contract, operand_number=1 + ) + mulf_op = match_and_split(forall, ops={"arith.mulf"}, nhandles=1)[0] + scale = transform.get_producer_of_operand(anytype, mulf_op, operand_number=1) + # NB: the merged fused-attention op is non-causal only -- there is + # no `causal` parameter yet, so the model runs as non-causal attention. + transform_ext.replace_with_fused_attention( + q_load=q_load, + k_load=k_load, + v_load=v_load, + scale=scale, + output=second_contract, + tile_size=fa_params["inner_loop_tile_size"], + ) + + +def xegpu_fa_annotation(gf, anytype, fa_params): + """Attach XeGPU layouts to ONE fused-attention gpu.func.""" + num_subgroups = fa_params["wg_rows"] // fa_params["sg_rows"] + n_head = fa_params["n_head"] + q_sg_layout = [num_subgroups, 1] + q_sg_data = [16, n_head] + q_inst_data = [8, 16] + k_sg_layout = [num_subgroups, 1] + k_sg_data = [16, n_head] + k_inst_data = [16, 16] + v_sg_layout, v_sg_data, v_inst_data = k_sg_layout, k_sg_data, k_inst_data + kt_sg_layout = [1, num_subgroups] + kt_sg_data = [n_head, 16] + kt_inst_data = [16, 16] + kt_order = [0, 1] + out_sg_layout, out_sg_data, out_inst_data = q_sg_layout, q_sg_data, q_inst_data + l128_sg_layout = [num_subgroups, 1] + l128_sg_data = [16, 16] + l128_inst_data = [8, 16] + qk_sg_layout, qk_sg_data, qk_inst_data = ( + l128_sg_layout, + l128_sg_data, + l128_inst_data, + ) + + store_nd_op = match_and_split(gf, ops={"xegpu.store_nd"}, nhandles=1)[0] + xegpu.set_anchor_layout( + store_nd_op, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + ) + load_nd_ops = match_and_split(gf, ops={"xegpu.load_nd"}, nhandles=9) + xegpu.set_anchor_layout( + load_nd_ops[0], sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data + ) + for i in range(1, 5): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=k_sg_layout, + sg_data=k_sg_data, + inst_data=k_inst_data, + ) + for i in range(5, 9): + xegpu.set_anchor_layout( + load_nd_ops[i], + sg_layout=v_sg_layout, + sg_data=v_sg_data, + inst_data=v_inst_data, + ) + dpas_ops = match_and_split(gf, ops={"xegpu.dpas"}, nhandles=8) + for i in range(4): + d = dpas_ops[i] + xegpu.set_anchor_layout( + d, sg_layout=q_sg_layout, sg_data=q_sg_data, inst_data=q_inst_data, index=0 + ) + xegpu.set_anchor_layout( + d, + sg_layout=kt_sg_layout, + sg_data=kt_sg_data, + inst_data=kt_inst_data, + order=kt_order, + index=1, + ) + xegpu.set_anchor_layout( + d, + sg_layout=l128_sg_layout, + sg_data=l128_sg_data, + inst_data=l128_inst_data, + index=2, + ) + for i in range(4, 8): + d = dpas_ops[i] + xegpu.set_anchor_layout( + d, + sg_layout=qk_sg_layout, + sg_data=qk_sg_data, + inst_data=qk_inst_data, + index=0, + ) + xegpu.set_anchor_layout( + d, sg_layout=v_sg_layout, sg_data=v_sg_data, inst_data=v_inst_data, index=1 + ) + xegpu.set_anchor_layout( + d, + sg_layout=out_sg_layout, + sg_data=out_sg_data, + inst_data=out_inst_data, + index=2, + ) + + +def build_combined_schedule( + mm_params, ln_params, kinds, stop_at_stage="", fa_params=None +): + """Build the transform-dialect schedule module for a payload with op classes + `kinds`. Counts how many of each class there are, then delegates to `_bundle` + (wrapped in transform boilerplate). `stop_at_stage` lets callers halt early + for debugging (--dump ).""" + n_mm = kinds.count("mm") + n_ln = kinds.count("ln") + n_sm = kinds.count("sm") + n_ew = kinds.count("ew") + with schedule_boilerplate() as (schedule, named_seq): + anytype = transform.AnyOpType.get() + func0 = match(named_seq.bodyTarget, ops={"func.func"}) + mod = transform.get_parent_op( + anytype, func0, op_name="builtin.module", deduplicate=True + ) + try: + _bundle( + mod, + mm_params, + ln_params, + kinds, + n_mm, + n_ln, + n_sm, + n_ew, + stop_at_stage, + fa_params=fa_params, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + return schedule + + +def _bundle( + mod, + mm_params, + ln_params, + kinds, + n_mm, + n_ln, + n_sm, + n_ew, + stop_at_stage="", + fa_params=None, +): + """THE PASS ORCHESTRATOR -- emits the actual sequence of transform ops. + + Runs in 3 phases over the whole payload module: + TILE -- tile every op into a GPU work-group `forall` (per op class) + SHARED TAIL -- vectorize, bufferize, forall->gpu.launch, outline kernels, + attach the XeVM target, lower vector ops to XeGPU + ANNOTATE -- attach XeGPU sub-group/DPAS layout to each kernel + `stop_at_stage` raises PipelineInterrupt to halt after a phase (for --dump). + Reading the inline comments here is the best way to understand "which part of + the code schedules the passes" -- it is this function, top to bottom.""" + anytype = transform.AnyOpType.get() + rss = ln_params["reduction_step_size"] + wg_rows = ln_params["wg_rows"] + nkernels = len(kinds) + n_fa = kinds.count("fa") + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + # ===== TILE each op-class into its own forall ===== + # KEY PROBLEM: match(linalg.generic) is NOT scoped -- once an op is tiled into + # a forall, its generic is STILL matched (it's just nested), so we can't + # re-match "the remaining bare generics" by count. SOLUTION: split ALL generic + # handles ONCE up front (their build order is deterministic), then tile each + # using its preserved handle. A handle to op X stays valid across tiling of + # OTHER ops. We tile the simple EW generics first (no fusion/cleanup, so ln + # handles survive), then the layernorms (which fuse + cleanup). + # + # Generic build order: each layernorm contributes [mean, var, normalize] (3), + # in block build order; each elementwise contributes 1. We reconstruct the + # per-op handle slices from `kinds`. + # 'fa' softmax generics do NOT exist yet (fa is tiled last, softmax still + # un-decomposed), so they are not in this pool. The fa core's linalg.transpose + # /linalg.mul/batch_matmul are not linalg.generic, so also excluded. (The head + # reshape is a pure memref VIEW -- no generic, no kernel; see Builder.heads_view.) + ngen_total = 3 * n_ln + n_ew + gen_handles = transform.split_handle( + (anytype,) * ngen_total, match(mod, ops={"linalg.generic"}) + ) + # Walk kinds to assign generic handles to ops. + ln_slices, ew_handles = [], [] + gi = 0 + for k in kinds: + if k == "ln": + ln_slices.append( + (gen_handles[gi], gen_handles[gi + 1], gen_handles[gi + 2]) + ) + gi += 3 + elif k == "ew": + ew_handles.append(gen_handles[gi]) + gi += 1 + # mm / sm / fa contribute no bare linalg.generic here + + # 1) Tile layernorms FIRST, using preserved (mean,var,normalize) handles. + # Doing this BEFORE EW/matmul tiling keeps the bare linalg.fill pool exactly + # predictable: 2*(untiled lns) + n_mm (matmul accumulator fills). EW tiling + # can introduce its own init fills, so we must finish ln fill-fusion first. + for i, (mean_red, var_red, normalize) in enumerate(ln_slices): + ln_untiled = n_ln - i + _tile_one_layernorm( + mod, + anytype, + wg_rows, + rss, + mean_red, + var_red, + normalize, + ln_untiled, + n_mm, + ln_params["T"], + ) + + # 2) Tile EW generics into own foralls (handles preserved across ln tiling). + for eg in ew_handles: + structured.structured_tile_using_forall( + anytype, + anytype, + eg, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(wg_rows,), + ) + + # 4) Matmuls (their EW producers already wrapped in foralls) + mms = match_and_split(mod, ops={"linalg.matmul"}, nhandles=n_mm) + for mm in mms: + _tile_one_matmul(mm, anytype, mm_params) + + # 5) Fused-attention regions. Done LAST so the generic pre-split above ran while + # each fa softmax was still ONE linalg.softmax (its decomposition generics + # don't exist yet, so they can't inflate ngen_total). Pre-split the 2*n_fa + # batch_matmuls (build order [QK^T, @V] per region) + n_fa softmaxes by count, + # then tile+fuse each region into ONE forall (decompose happens in-region). + if n_fa: + fa_bmms = match_and_split(mod, ops={"linalg.batch_matmul"}, nhandles=2 * n_fa) + fa_softmaxes = match_and_split(mod, ops={"linalg.softmax"}, nhandles=n_fa) + for r in range(n_fa): + _tile_one_fused_attention_region( + anytype, fa_bmms[2 * r], fa_bmms[2 * r + 1], fa_softmaxes[r], fa_params + ) + + func = match(mod, ops={"func.func"}) + lh_transform.cleanup(func) + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # ===== SHARED TAIL ===== + func = structured.structured_vectorize_children_and_apply_patterns( + anytype, func, fold_type_extensions_into_contract=True + ) + lh_transform.cleanup(func) + # Fused-attention regions carry a batch-of-1 dim from the (1,wg_rows,0,0) tiling; + # drop leading unit dims so the QK^T/@V vector.contracts become 2D, as the flash + # rewrite expects. + if n_fa: + with ir.InsertionPoint(transform.apply_patterns(func).patterns): + apply_patterns_vector_cast_away_vector_leading_one_dim() + apply_patterns_vector_drop_unit_dims_with_shape_cast() + transform.apply_cse(func) + canonicalize(func) + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap, + ).result + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # ===== FUSED-ATTENTION REWRITE (after bufferize+vectorize, before gpu.launch) ===== + # Re-find each attention forall by kinds index (forall IR order == kinds order, + # the invariant the launch/gpu_mods loops below also rely on) and rewrite its + # QK^T/@V vector.contract pair into the flash online-softmax loop. Must run + # BEFORE forall->gpu.launch so the producer-walks for q/k/v loads stay in-region. + if n_fa: + all_foralls = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) + for idx, kind in enumerate(kinds): + if kind == "fa": + _fuse_attention_in_region(anytype, all_foralls[idx], fa_params) + func = match(mod, ops={"func.func"}) + transform.apply_cse(func) + canonicalize(func) + if stop_at_stage == "inner-tiled": + raise PipelineInterrupt() + + wg_loops = match_and_split(mod, ops={"scf.forall"}, nhandles=nkernels) + for wg_loop in wg_loops: + loop.loop_forall_to_parallel([anytype], wg_loop) + func = match(mod, ops={"func.func"}) + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) + canonicalize(func) + + # launch threads per kernel, in IR (build) order = `kinds`. + launches = match_and_split(mod, ops={"gpu.launch"}, nhandles=nkernels) + mm_threads = ( + (mm_params["wg_m"] // mm_params["sg_m"]) + * (mm_params["wg_n"] // mm_params["sg_n"]) + * 16 + ) + sm_threads = (ln_params["wg_rows"] // ln_params["sg_rows"]) * ln_params[ + "subgroup_size" + ] + fa_threads = ( + (fa_params["wg_rows"] // fa_params["sg_rows"]) * fa_params["subgroup_size"] + if fa_params + else 0 + ) + for launch, kind in zip(launches, kinds): + nt = {"mm": mm_threads, "fa": fa_threads}.get(kind, sm_threads) + xegpu.set_gpu_launch_threads(launch, threads=[nt, 1, 1]) + + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + mod = apply_registered_pass( + mod, "xevm-attach-target", options={"O": "3", "chip": "pvc"} + ) + + # per-gpu.module convert-vector-to-xegpu. ONLY ln/sm need SLM allocas (their + # cross-lane reductions go through shared local memory -> store_matrix). The + # ew kernels (cast/bias/residual) are pure row-parallel: forcing their allocas + # to SLM creates store_matrix paths that fail to lower. So SLM-ify ln/sm only; + # leave ew (and mm) as store_nd. + gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) + sg_layout = [ln_params["sg_rows"], 1] + sg_data = [ln_params["sg_rows"], rss] + for gm, kind in zip(gpu_mods, kinds): + gf = match(gm, ops={"gpu.func"}) + if kind in ("ln", "sm"): + allocas = match(gf, ops={"memref.alloca"}) + transform_ext.update_address_space(allocas, address_space=3) + gf = apply_registered_pass(gf, "convert-vector-to-xegpu") + transform.apply_cse(gf) + if kind == "fa": + # flash kernel carries state in iter_args (no SLM); hoist invariants. + gf = apply_registered_pass(gf, "loop-invariant-code-motion") + transform.apply_cse(mod) + canonicalize(mod) + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # ===== PER-KERNEL ANNOTATION ===== + # mm -> full mlp wg annotation + # ln -> store_nd (1) + store_matrix (the SLM reduction stores) + # sm -> store_nd (1) + store_matrix (4) + # ew -> store_nd (1) only (pure row-parallel, no SLM) + gpu_mods = match_and_split(mod, ops={"gpu.module"}, nhandles=nkernels) + for gm, kind in zip(gpu_mods, kinds): + gf = match(gm, ops={"gpu.func"}) + if kind == "mm": + xegpu_wg_annotation_for_mlp_layer(gf, **mm_params) + elif kind == "fa": + xegpu_fa_annotation(gf, anytype, fa_params) + else: + # ln/sm/ew: anchor-layout their store_nd, and (ln/sm) their SLM + # store_matrix. Pass the whole match handle to set_anchor_layout (it + # accepts a multi-handle) -- avoids guessing exact store counts. + xegpu.set_anchor_layout( + match(gf, ops={"xegpu.store_nd"}), sg_layout=sg_layout, sg_data=sg_data + ) + if kind in ("ln", "sm"): + xegpu.set_anchor_layout( + match(gf, ops={"xegpu.store_matrix"}), + sg_layout=sg_layout, + sg_data=sg_data, + ) + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + return mod + + +def _tile_one_layernorm( + mod, anytype, wg_rows, rss, mean_red, var_red, normalize, ln_untiled, n_mm, T_ROWS +): + """Tile ONE layernorm into its own forall, using PRESERVED handles to its 3 + generics (mean_red, var_red, normalize). Handles to other ops stay valid. + + The 2 accumulator fills are matched by their producer relationship: we match + all fills and fuse the ones that feed this ln. To avoid touching matmul fills, + we rely on fuse_into_containing pulling only genuine producers of the forall. + """ + _, ln_forall = structured.structured_tile_using_forall( + anytype, + anytype, + normalize, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(wg_rows,), + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=var_red, containing_op=ln_forall + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=mean_red, containing_op=ln_forall + ) + # Fuse this ln's 2 accumulator fills into the forall. Robustly select ONLY the + # layernorm accumulator fills (NOT matmul fills) by filtering on result type: + # ln accumulators are rank-1 tensor; matmul accumulators are rank-2. + # This avoids fragile positional counting across the whole block. There are + # 2*ln_untiled such rank-1 fills (this ln + other untiled lns); this ln's are + # the FIRST 2 in IR order. + ln_func = transform.get_parent_op( + anytype, ln_forall, op_name="func.func", deduplicate=True + ) + reduce_t = ir.RankedTensorType.get((T_ROWS,), F32()) # ln accumulator type (T,) + fill_match = structured.MatchOp( + anytype, ln_func, ops=["linalg.fill"], filter_result_type=reduce_t + ) + n_ln_fills = 2 * ln_untiled + fills = transform.split_handle((anytype,) * n_ln_fills, fill_match.results[0]) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=fills[1], containing_op=ln_forall + ) + _, ln_forall = structured.structured_fuse_into_containing_op( + anytype, anytype, producer_op=fills[0], containing_op=ln_forall + ) + # Fusion leaves the full-size original fills DEAD at func scope (fusion only + # slices a copy inside the forall). They must be removed or the next ln finds + # too many. Use canonicalize (which does DCE of the dead originals) at FUNC + # scope, but NEVER apply_cse at func scope -- CSE would merge the identical + # live zero-fills ACROSS layernorms. CSE the duplicate GENERICS inside the + # forall only (scoped), so the re-match below finds exactly 3. + transform.apply_cse(ln_forall) + canonicalize(ln_func) + # tile this ln's reductions+normalize (now inside the forall). Re-match the + # 3 generics INSIDE the forall (scoped to ln_forall, so unambiguous: exactly 3). + g2 = match_and_split(ln_forall, ops={"linalg.generic"}, nhandles=3) + structured.TileUsingForOp(g2[2], sizes=[0, rss]) + structured.structured_tile_reduction_using_for( + [anytype], anytype, anytype, anytype, target=g2[1], tile_sizes=[0, rss] + ) + structured.structured_tile_reduction_using_for( + [anytype], anytype, anytype, anytype, target=g2[0], tile_sizes=[0, rss] + ) + transform.apply_cse(ln_forall) + canonicalize(ln_forall) From 70a5f2381875ab2abb3cded634116029bb417ce1 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 18 Jun 2026 21:34:49 +0000 Subject: [PATCH 4/6] Clean up --- examples/xegpu/nanoGPT.py | 6 ++++++ lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py | 4 ---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/xegpu/nanoGPT.py b/examples/xegpu/nanoGPT.py index e814e87b..83683997 100644 --- a/examples/xegpu/nanoGPT.py +++ b/examples/xegpu/nanoGPT.py @@ -69,6 +69,7 @@ from lighthouse.execution import GPUMemoryManager from lighthouse.schedule.xegpu import xegpu_to_binary, XeGPUParameterSelector from lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload import build_gpt_fused_payload +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs from lighthouse.schedule.xegpu.nanoGPT_schedule import build_combined_schedule @@ -183,6 +184,11 @@ def main(): with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() mod, kinds = build_gpt_fused_payload("payload", T, C, hidden, vocab, n_layer, H) + with ir.InsertionPoint(mod.body): + f32, f16 = ir.F32Type.get(), ir.F16Type.get() + emit_gpu_util_funcs(f32, rank=2) + emit_gpu_util_funcs(f32, rank=1) + emit_gpu_util_funcs(f16, rank=2) if dump == "initial": print(mod) print("KINDS:", kinds) diff --git a/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py b/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py index 3e358bd5..15d72a24 100644 --- a/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py @@ -29,7 +29,6 @@ parallel, reduction, ) -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs from lighthouse.ingress.mlir_gen.named import times_weights @@ -453,7 +452,4 @@ def payload(*args): for b in bld.to_dealloc: gpu.dealloc(None, [], b) - emit_gpu_util_funcs(f32, rank=2) - emit_gpu_util_funcs(f32, rank=1) - emit_gpu_util_funcs(f16, rank=2) return mod, bld.kinds From cc88e9ee12db5fdeeedb3cb614f6a71053179dac Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 18 Jun 2026 21:48:48 +0000 Subject: [PATCH 5/6] move files --- examples/xegpu/nanoGPT.py | 14 ++++---------- .../xegpu/nanoGPT_payload.py | 6 +++++- .../xegpu/nanoGPT_schedule.py | 4 ++-- 3 files changed, 11 insertions(+), 13 deletions(-) rename lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py => examples/xegpu/nanoGPT_payload.py (98%) rename {lighthouse/schedule => examples}/xegpu/nanoGPT_schedule.py (99%) diff --git a/examples/xegpu/nanoGPT.py b/examples/xegpu/nanoGPT.py index 83683997..610c959c 100644 --- a/examples/xegpu/nanoGPT.py +++ b/examples/xegpu/nanoGPT.py @@ -39,9 +39,9 @@ HOW THIS EXAMPLE IS ORGANIZED -- compiling a model to the GPU here happens in THREE stages: - 1. PAYLOAD ("WHAT to compute") -> lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py + 1. PAYLOAD ("WHAT to compute") -> examples/xegpu/nanoGPT_payload.py (the `Builder` class + `build_gpt_fused_payload`). Hardware-agnostic linalg. - 2. SCHEDULE ("HOW to lower it") -> lighthouse/schedule/xegpu/nanoGPT_schedule.py + 2. SCHEDULE ("HOW to lower it") -> examples/xegpu/nanoGPT_schedule.py (`build_combined_schedule`). A transform-dialect module that tiles each op into GPU work-groups, vectorizes, bufferizes, outlines each op into its own GPU kernel, and attaches XeGPU layout/target attributes. @@ -68,9 +68,8 @@ from lighthouse.execution.runner import Runner from lighthouse.execution import GPUMemoryManager from lighthouse.schedule.xegpu import xegpu_to_binary, XeGPUParameterSelector -from lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload import build_gpt_fused_payload -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs -from lighthouse.schedule.xegpu.nanoGPT_schedule import build_combined_schedule +from nanoGPT_payload import build_gpt_fused_payload +from nanoGPT_schedule import build_combined_schedule # ============================================================================= @@ -184,11 +183,6 @@ def main(): with ir.Context(), ir.Location.unknown(): lh_dialects.register_and_load() mod, kinds = build_gpt_fused_payload("payload", T, C, hidden, vocab, n_layer, H) - with ir.InsertionPoint(mod.body): - f32, f16 = ir.F32Type.get(), ir.F16Type.get() - emit_gpu_util_funcs(f32, rank=2) - emit_gpu_util_funcs(f32, rank=1) - emit_gpu_util_funcs(f16, rank=2) if dump == "initial": print(mod) print("KINDS:", kinds) diff --git a/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py b/examples/xegpu/nanoGPT_payload.py similarity index 98% rename from lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py rename to examples/xegpu/nanoGPT_payload.py index 15d72a24..f61e89c4 100644 --- a/lighthouse/ingress/mlir_gen/gpu_nanoGPT_payload.py +++ b/examples/xegpu/nanoGPT_payload.py @@ -3,7 +3,7 @@ This is STAGE 1 -- the PAYLOAD ("WHAT to compute"). It builds an MLIR module describing the model math as high-level `linalg` ops (matmul, layernorm, softmax, elementwise). It is hardware-agnostic; HOW it lowers to the GPU lives in -`lighthouse.schedule.xegpu.nanoGPT_schedule`. +`nanoGPT_schedule`. -> class `Builder` (emits one op at a time) and `build_gpt_fused_payload` (assembles ops into ffn / attn / block / full-gpt). @@ -29,6 +29,7 @@ parallel, reduction, ) +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs from lighthouse.ingress.mlir_gen.named import times_weights @@ -452,4 +453,7 @@ def payload(*args): for b in bld.to_dealloc: gpu.dealloc(None, [], b) + emit_gpu_util_funcs(f32, rank=2) + emit_gpu_util_funcs(f32, rank=1) + emit_gpu_util_funcs(f16, rank=2) return mod, bld.kinds diff --git a/lighthouse/schedule/xegpu/nanoGPT_schedule.py b/examples/xegpu/nanoGPT_schedule.py similarity index 99% rename from lighthouse/schedule/xegpu/nanoGPT_schedule.py rename to examples/xegpu/nanoGPT_schedule.py index 62c99576..7c835552 100644 --- a/lighthouse/schedule/xegpu/nanoGPT_schedule.py +++ b/examples/xegpu/nanoGPT_schedule.py @@ -3,7 +3,7 @@ This is STAGE 2 -- the SCHEDULE ("HOW to lower it"). A "schedule" here is itself an MLIR module written in the TRANSFORM dialect: a little program of rewrite ops that the transform interpreter runs over the payload module (built in -`lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload`). It does NOT compute anything; +`nanoGPT_payload`). It does NOT compute anything; it REWRITES the payload from high-level linalg ops down to GPU (XeGPU) kernels. -> `build_combined_schedule` / `_bundle` (the orchestrator) plus the @@ -55,7 +55,7 @@ ) from lighthouse.schedule import schedule_boilerplate from lighthouse.schedule.xegpu.mlp_schedule import xegpu_wg_annotation_for_mlp_layer -from lighthouse.ingress.mlir_gen.gpu_nanoGPT_payload import F32 +from nanoGPT_payload import F32 def _tile_one_matmul(matmul_op, anytype, mm_params): From bfa4b3089d40cc71ebf615c4289799b2eaeab211 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 18 Jun 2026 21:55:17 +0000 Subject: [PATCH 6/6] Fix CI --- examples/xegpu/lit.local.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/xegpu/lit.local.cfg b/examples/xegpu/lit.local.cfg index 62c63b32..3305a483 100644 --- a/examples/xegpu/lit.local.cfg +++ b/examples/xegpu/lit.local.cfg @@ -1,6 +1,8 @@ config.excludes = [ "csv_logger.py", "genetic_algorithm.py", + "nanoGPT_payload.py", + "nanoGPT_schedule.py", "tune_matmul_ga.py", "tune_utils.py", ]