From 5155868b30f217dce9d2efa229c902dc96a3fff5 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 2 Jun 2026 21:24:49 +0000 Subject: [PATCH 1/3] Add Layer norm linalg/xegpu example --- examples/xegpu/layer_norm.py | 331 ++++++++++++++++++ .../mlir_gen/gpu_layer_norm_payload.py | 127 +++++++ lighthouse/schedule/xegpu/__init__.py | 2 + .../schedule/xegpu/layer_norm_schedule.py | 293 ++++++++++++++++ 4 files changed, 753 insertions(+) create mode 100644 examples/xegpu/layer_norm.py create mode 100644 lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py create mode 100644 lighthouse/schedule/xegpu/layer_norm_schedule.py diff --git a/examples/xegpu/layer_norm.py b/examples/xegpu/layer_norm.py new file mode 100644 index 00000000..aac57fd8 --- /dev/null +++ b/examples/xegpu/layer_norm.py @@ -0,0 +1,331 @@ +# RUN: %PYTHON %s --dump-kernel=xegpu-wg | FileCheck %s +# CHECK: module attributes {gpu.container_module} { + +""" +XeGPU layer_norm benchmark. +""" + +import argparse +from typing import Optional +from functools import cached_property + +import numpy as np +from mlir import ir + +from lighthouse import dialects as lh_dialects +from lighthouse.execution.runner import Runner +from lighthouse.pipeline.driver import TransformDriver +from lighthouse.execution import GPUMemoryManager +from lighthouse.utils.numpy import mlir_to_numpy_dtype +from lighthouse.ingress.mlir_gen import get_mlir_elem_type +from lighthouse.ingress.mlir_gen.gpu_layer_norm_payload import ( + generate_gpu_layer_norm_payload, +) +from lighthouse.schedule.xegpu import layer_norm_schedule, xegpu_to_binary + + +def layer_norm_complexity(M: int, N: int, nbytes: int): + """ + Complexity of layer_norm operation. + + Per row of length N: + - N adds for mean reduction + - N subs + N muls + N adds for variance reduction + - N subs + N muls (inv_std) + N muls (gamma) + N adds (beta) + 1 rsqrt + Total ~ 8 FLOPs per element. 
+ """ + flop_count = M * N * 8 + memory_reads = M * N * nbytes + 2 * N * nbytes # input + gamma + beta + memory_writes = M * N * nbytes + return flop_count, memory_reads, memory_writes + + +def check_correctness( + input_arr: np.ndarray, + gamma_arr: np.ndarray, + beta_arr: np.ndarray, + output_arr: np.ndarray, + eps: float, + verbose: int = 0, +) -> bool: + x = input_arr.astype(np.float32) + mean = np.mean(x, axis=1, keepdims=True) + var = np.mean((x - mean) ** 2, axis=1, keepdims=True) + inv_std = 1.0 / np.sqrt(var + eps) + output_ref = (x - mean) * inv_std * gamma_arr.astype(np.float32) + beta_arr.astype( + np.float32 + ) + + output = output_arr.astype(np.float32) + + if verbose > 1: + print("Reference solution (first 5 rows):") + print(output_ref[:5]) + print("Computed solution (first 5 rows):") + print(output[:5]) + + values_ok = np.allclose(output, output_ref, rtol=1e-3, atol=1e-4) + + if verbose: + if values_ok: + print("PASSED") + else: + max_diff = np.abs(output - output_ref).max() + print(f"FAILED! Max abs diff: {max_diff:.6e}") + return values_ok + + +class XeGPULayerNorm: + """ + Layer norm workload on XeGPU. 
+ + Computes layer normalization along the last dimension (rows): + mean_i = (1/N) * sum_j x[i, j] + var_i = (1/N) * sum_j (x[i, j] - mean_i)^2 + out[i, j] = (x[i, j] - mean_i) / sqrt(var_i + eps) * gamma[j] + beta[j] + """ + + def __init__( + self, + M: int, + N: int, + dtype: str = "f32", + eps: float = 1e-5, + ): + self.M = M + self.N = N + self.eps = eps + self.shape = (M, N) + self.bias_shape = (N,) + assert dtype == "f32", "Only f32 type is supported for layer_norm" + self.elem_type = get_mlir_elem_type(dtype) + self.dtype = mlir_to_numpy_dtype(self.elem_type) + self.memory_manager_class = GPUMemoryManager + self.payload_function_name = "payload" + + @cached_property + def _initial_host_arrays(self) -> tuple[np.ndarray]: + """Generate initial values on host with numpy.""" + np.random.seed(42) + input_arr = np.random.uniform(-1.0, 1.0, self.shape).astype(self.dtype) + gamma_arr = np.random.uniform(0.5, 1.5, self.bias_shape).astype(self.dtype) + beta_arr = np.random.uniform(-0.1, 0.1, self.bias_shape).astype(self.dtype) + output_arr = np.zeros(self.shape, dtype=self.dtype) + return (output_arr, input_arr, gamma_arr, beta_arr) + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + return layer_norm_complexity(self.M, self.N, nbytes) + + def payload_module(self) -> ir.Module: + """Generate MLIR module for layer_norm payload.""" + return generate_gpu_layer_norm_payload( + func_name=self.payload_function_name, + M=self.M, + N=self.N, + dtype=self.elem_type, + eps=self.eps, + ) + + def schedule_modules( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> list[ir.Module]: + """Generate transform schedule for layer_norm.""" + schedules = [] + schedules.append(Runner.get_bench_wrapper_schedule(self.payload_function_name)) + + schedules.append( + layer_norm_schedule( + stop_at_stage=stop_at_stage, + parameters=parameters, + ) + ) + + if stop_at_stage and stop_at_stage != "final": + return 
schedules + + schedules.append(xegpu_to_binary()) + + return schedules + + def shared_libs(self) -> list[str]: + return ["libmlir_levelzero_runtime.so"] + + +def parse_cli(): + parser = argparse.ArgumentParser( + description="LayerNorm using MLIR XeGPU", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--sizes", + type=int, + nargs=2, + default=[1024, 512], + help="M,N matrix sizes (MxN)", + ) + parser.add_argument( + "--wg-rows", + type=int, + default=64, + help="Number of rows per workgroup.", + ) + parser.add_argument( + "--sg-rows", + type=int, + default=8, + help="Number of rows per subgroup.", + ) + parser.add_argument( + "--subgroup-size", + type=int, + default=16, + help="Subgroup size.", + ) + parser.add_argument( + "--reduction-step-size", + type=int, + default=16, + help="Step size for reduction loop tiling.", + ) + parser.add_argument( + "--eps", + type=float, + default=1e-5, + help="Epsilon added to variance for numerical stability.", + ) + parser.add_argument( + "--nruns", + type=int, + default=1000, + help="Number of runs to average the execution time.", + ) + parser.add_argument( + "--nwarmup", + type=int, + default=20, + help="Number of warm-up iterations before benchmarking.", + ) + parser.add_argument( + "--check-result", + action="store_true", + help="Check the result of the layer_norm computation.", + ) + parser.add_argument( + "--dump-kernel", + type=str, + choices=[ + "initial", + "tiled", + "vectorized", + "bufferized", + "gpu-outlining", + "xegpu-initial", + "xegpu-wg", + "final", + ], + help="Dump kernel IR at different stages of lowering and exit without " + "executing the kernel.", + ) + parser.add_argument( + "--dump-schedule", + action="store_true", + help="Dump transform schedule.", + ) + parser.add_argument( + "--verbose", + "-v", + action="count", + default=0, + help="Increase output verbosity.", + ) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_cli() 
+ + params = { + "sizes": args.sizes, + "wg_rows": args.wg_rows, + "sg_rows": args.sg_rows, + "subgroup_size": args.subgroup_size, + "reduction_step_size": args.reduction_step_size, + } + + M, N = args.sizes + dtype = "f32" + + with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPULayerNorm(M=M, N=N, dtype=dtype, eps=args.eps) + + if args.dump_kernel or args.dump_schedule: + pipeline = TransformDriver( + wload.schedule_modules( + stop_at_stage=args.dump_kernel, parameters=params + ) + ) + payload = pipeline.apply(wload.payload_module()) + if args.dump_kernel: + print(payload) + if args.dump_schedule: + for schedule_module in wload.schedule_modules(parameters=params): + print(schedule_module) + else: + pipeline = TransformDriver(wload.schedule_modules(parameters=params)) + payload = pipeline.apply(wload.payload_module()) + runner = Runner( + payload, + mem_manager_cls=wload.memory_manager_class, + shared_libs=wload.shared_libs(), + ) + if args.check_result: + result_host_copy = np.zeros(wload.shape, dtype=wload.dtype) + argument_access_callback = Runner.get_gpu_argument_access_callback( + result_host_copy, arg_index=0 + ) + + runner.execute( + host_input_buffers=wload._initial_host_arrays, + payload_function_name=wload.payload_function_name, + argument_access_callback=argument_access_callback, + ) + + _, input_arr, gamma_arr, beta_arr = wload._initial_host_arrays + success = check_correctness( + input_arr, + gamma_arr, + beta_arr, + result_host_copy, + eps=wload.eps, + verbose=args.verbose, + ) + if not success: + raise ValueError("Result mismatch!") + else: + print("Result is correct. 
Proceeding to benchmark...") + + times = runner.benchmark( + host_input_buffers=wload._initial_host_arrays, + nruns=args.nruns, + nwarmup=args.nwarmup, + ) + times *= 1e6 # convert to microseconds + elapsed = np.mean(times) + flop_count = wload.get_complexity()[0] + gflops = flop_count / (elapsed * 1e-6) / 1e9 + + def list2str(a): + return ",".join(map(str, a)) + + print( + f"sizes={list2str(args.sizes)} " + f"dt={dtype} " + f"wg-rows={args.wg_rows} " + f"sg-rows={args.sg_rows} " + f"subgroup-size={args.subgroup_size} " + f"time(us): {elapsed:.2f} " + f"GFLOPS: {gflops:.2f} " + ) diff --git a/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py new file mode 100644 index 00000000..c72d1772 --- /dev/null +++ b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py @@ -0,0 +1,127 @@ +"""Generate MLIR payload for GPU layer_norm operation.""" + +from mlir import ir +from mlir.dialects import linalg, bufferization, tensor, arith, math + +from lighthouse.utils.mlir import func_cif +from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs +from lighthouse.ingress.mlir_gen.utils import ( + emit_buf_to_tensor, + affine_map, + parallel, + reduction, +) + + +def generate_gpu_layer_norm_payload( + func_name: str, + M: int, + N: int, + dtype: ir.Type, + eps: float = 1e-5, +) -> ir.Module: + """ + Generate MLIR module for layer_norm payload. 
+ + Computes layer normalization along the last dimension (rows): + mean_i = (1/N) * sum_j x[i, j] + var_i = (1/N) * sum_j (x[i, j] - mean_i)^2 + out[i, j] = (x[i, j] - mean_i) / sqrt(var_i + eps) * gamma[j] + beta[j] + + Args: + func_name: Name of the payload function + M: Number of rows + N: Number of columns (normalization dimension) + dtype: MLIR element type (e.g., F32Type) + eps: Small constant added to variance for numerical stability + + Returns: + MLIR module containing the layer_norm payload function + """ + mod = ir.Module.create() + shape = (M, N) + reduce_shape = (M,) + bias_shape = (N,) + memref_t = ir.MemRefType.get(shape, dtype) + bias_memref_t = ir.MemRefType.get(bias_shape, dtype) + + # Affine maps used by the linalg.generic ops below. + # Iteration space is (i, j); reductions reduce over j. + par_map_2d = affine_map(2, [ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)]) + red_map_2d = affine_map(2, [ir.AffineDimExpr.get(0)]) + bias_map_2d = affine_map(2, [ir.AffineDimExpr.get(1)]) + + inv_N = 1.0 / float(N) + + with ir.InsertionPoint(mod.body): + # Function signature: payload(output, input, gamma, beta) + @func_cif(memref_t, memref_t, bias_memref_t, bias_memref_t, name=func_name) + def payload(output, input_arg, gamma_arg, beta_arg): + emit_buf_to_tensor(output, restrict=True, writable=True) + input_tensor = emit_buf_to_tensor(input_arg, restrict=True) + gamma_tensor = emit_buf_to_tensor(gamma_arg, restrict=True) + beta_tensor = emit_buf_to_tensor(beta_arg, restrict=True) + + zero = arith.constant(dtype, 0.0) + inv_n_const = arith.constant(dtype, inv_N) + eps_const = arith.constant(dtype, eps) + + # 1) Mean reduction: mean_sum[i] = sum_j x[i, j] + mean_init = tensor.empty(reduce_shape, dtype) + mean_acc = linalg.fill(zero, outs=[mean_init]) + + @linalg.generic( + [input_tensor], + [mean_acc], + [par_map_2d, red_map_2d], + [parallel, reduction], + ) + def mean_sum(x, acc): + return arith.AddFOp(x, acc) + + # 2) Variance reduction: var_sum[i]
= sum_j (x[i, j] - mean_i)^2 + # where mean_i = mean_sum[i] * (1/N) + var_init = tensor.empty(reduce_shape, dtype) + var_acc = linalg.fill(zero, outs=[var_init]) + + @linalg.generic( + [input_tensor, mean_sum], + [var_acc], + [par_map_2d, red_map_2d, red_map_2d], + [parallel, reduction], + ) + def var_sum(x, m_sum, acc): + mean = arith.MulFOp(m_sum, inv_n_const).result + centered = arith.SubFOp(x, mean).result + sq = arith.MulFOp(centered, centered).result + return arith.AddFOp(sq, acc) + + # 3) Final elementwise: + # out[i, j] = (x[i, j] - mean_i) * rsqrt(var_i + eps) * gamma[j] + beta[j] + out_init = tensor.empty(shape, dtype) + + @linalg.generic( + [input_tensor, mean_sum, var_sum, gamma_tensor, beta_tensor], + [out_init], + [par_map_2d, red_map_2d, red_map_2d, bias_map_2d, bias_map_2d, par_map_2d], + [parallel, parallel], + ) + def normalized(x, m_sum, v_sum, g, b, _out): + mean = arith.MulFOp(m_sum, inv_n_const).result + var = arith.MulFOp(v_sum, inv_n_const).result + var_eps = arith.AddFOp(var, eps_const).result + inv_std = math.rsqrt(var_eps) + centered = arith.SubFOp(x, mean).result + scaled = arith.MulFOp(centered, inv_std).result + weighted = arith.MulFOp(scaled, g).result + return arith.AddFOp(weighted, b) + + bufferization.materialize_in_destination( + None, normalized, output, restrict=True, writable=True + ) + + # Emit utility functions for GPU memory management + emit_gpu_util_funcs(dtype, rank=2) + emit_gpu_util_funcs(dtype, rank=1) + + return mod diff --git a/lighthouse/schedule/xegpu/__init__.py b/lighthouse/schedule/xegpu/__init__.py index 23d9ef0c..17b8e803 100644 --- a/lighthouse/schedule/xegpu/__init__.py +++ b/lighthouse/schedule/xegpu/__init__.py @@ -1,8 +1,10 @@ from .xegpu_to_binary import xegpu_to_binary from .mlp_schedule import mlp_schedule from .softmax_schedule import softmax_schedule +from .layer_norm_schedule import layer_norm_schedule __all__ = [ + "layer_norm_schedule", "mlp_schedule", "softmax_schedule", "xegpu_to_binary",
diff --git a/lighthouse/schedule/xegpu/layer_norm_schedule.py b/lighthouse/schedule/xegpu/layer_norm_schedule.py new file mode 100644 index 00000000..3c42bef1 --- /dev/null +++ b/lighthouse/schedule/xegpu/layer_norm_schedule.py @@ -0,0 +1,293 @@ +"""Generate MLIR transform schedule for XeGPU layer_norm operation.""" + +from typing import Optional + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured, loop, xegpu +from mlir.dialects.transform import bufferization as transform_bufferization +from mlir.dialects.bufferization import LayoutMapOption + +from lighthouse.pipeline.helper import ( + apply_registered_pass, + canonicalize, + match, + match_and_split, + PipelineInterrupt, +) +from lighthouse.schedule import schedule_boilerplate +from lighthouse.dialects.transform import transform_ext + + +def layer_norm_schedule( + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, +) -> ir.Module: + """ + Generate transform schedule for layer_norm operation. + + The schedule performs the following transformations: + 1. Tile the outer parallel dimension (rows) using forall + 2. Tile the inner reductions (mean / variance) using for + 3. Vectorize operations + 4. Bufferize tensors + 5. Convert to GPU dialect + 6. Lower to XeGPU operations + + Args: + stop_at_stage: Optional stage name to stop early (for debugging) + parameters: Dictionary with scheduling parameters: + - wg_rows: Number of rows per workgroup + - sg_rows: Number of rows per subgroup + - subgroup_size: Size of subgroup + - sizes: Tuple with the sizes of the input tensors (e.g. 
(M, N)) + - reduction_step_size: Step size for tiling the reduction loops + """ + assert parameters is not None, "Schedule parameters must be provided" + + with schedule_boilerplate() as (schedule, named_seq): + anytype = transform.AnyOpType.get() + func = match(named_seq.bodyTarget, ops={"func.func"}) + payload_mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + + try: + bundle_xegpu_layer_norm_schedule( + payload_mod, + parameters=parameters, + stop_at_stage=stop_at_stage, + ) + except PipelineInterrupt: + pass + finally: + transform.yield_() + + return schedule + + +def bundle_xegpu_layer_norm_schedule( + mod: ir.Value, + parameters: dict, + stop_at_stage: str = "", +) -> ir.Value: + """Schedule for lowering layer_norm payload to xegpu wg level. + + The payload (see ``generate_gpu_layer_norm_payload``) consists of: + - linalg.fill (init mean accumulator) + - linalg.generic (mean reduction) + - linalg.fill (init var accumulator) + - linalg.generic (var reduction) + - linalg.generic (final normalize: elementwise) + """ + + if stop_at_stage == "initial": + raise PipelineInterrupt() + + anytype = transform.AnyOpType.get() + reduction_step_size = parameters["reduction_step_size"] + + # Get the payload function by anchoring on the last linalg.generic + # (the elementwise normalize op, which is the only op with 2 parallel iterators). + all_generics = match(mod, ops={"linalg.generic"}) + # Split: 3 generics in total (mean reduction, var reduction, normalize). + gen_ops = transform.split_handle( + (anytype,) * 3, all_generics + ) + mean_reduction = gen_ops[0] + var_reduction = gen_ops[1] + normalize_op = gen_ops[2] + + tiled_op, forall_op = structured.structured_tile_using_forall( + anytype, + anytype, + normalize_op, + num_threads=[], + tile_sizes=[], + static_tile_sizes=(parameters["wg_rows"],), + ) + + # Fuse the two reductions into the forall. 
+ _, forall_op = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=var_reduction, + containing_op=forall_op, + ) + _, forall_op = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=mean_reduction, + containing_op=forall_op, + ) + + func = transform.get_parent_op( + anytype, + forall_op, + op_name="func.func", + deduplicate=True, + ) + + # Fuse the (M,)-sized accumulator init fills (mean_acc, var_acc) into the + # forall as well. Fusing the reductions only slices the fill *results* + # inside the loop; the full-size fills themselves stay at function scope. + # If left outside, bufferization turns them into a function-scope + # memref that is passed into gpu.launch as a host pointer, which the + # device then dereferences -> GPU page fault. Privatizing them into the + # forall makes each a per-workgroup (wg_rows,) init inside the kernel. + fill_ops = match(func, ops={"linalg.fill"}) + _, forall_op = structured.structured_fuse_into_containing_op( + anytype, + anytype, + producer_op=fill_ops, + containing_op=forall_op, + ) + + # Drop the dead originals of the reductions (canonicalize removes them) + # before re-matching, otherwise we'd see 5+ generics instead of 3. + transform.apply_cse(func) + canonicalize(func) + + # Re-match the linalg.generic ops after fusion: 3 generics remain + # (mean reduction, var reduction, normalize). + linalg_ops = match_and_split(func, ops={"linalg.generic"}, nhandles=3) + mean_reduction = linalg_ops[0] + var_reduction = linalg_ops[1] + normalize_op = linalg_ops[2] + + # Tile the elementwise normalize along its inner (column) dim. + _, normalize_loop = structured.TileUsingForOp( + normalize_op, sizes=[0, reduction_step_size] + ).results + + # Tile the variance reduction along its reduction dim. 
+ structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=var_reduction, + tile_sizes=[0, reduction_step_size], + ) + + # Tile the mean reduction along its reduction dim. + structured.structured_tile_reduction_using_for( + [anytype], + anytype, + anytype, + anytype, + target=mean_reduction, + tile_sizes=[0, reduction_step_size], + ) + + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "tiled": + raise PipelineInterrupt() + + # vectorize + func = structured.VectorizeChildrenAndApplyPatternsOp( + func, + fold_type_extensions_into_contract=True, + ).result + transform.apply_cse(func) + canonicalize(func) + + if stop_at_stage == "vectorized": + raise PipelineInterrupt() + + # bufferize + mod = apply_registered_pass(mod, "eliminate-empty-tensors") + identity_layout = LayoutMapOption.IdentityLayoutMap + mod = transform_bufferization.OneShotBufferizeOp( + mod, + allow_return_allocs_from_loops=True, + bufferize_function_boundaries=True, + function_boundary_type_conversion=identity_layout, + ).result + mod = apply_registered_pass(mod, "fold-memref-alias-ops") + transform.apply_cse(mod) + canonicalize(mod) + + # promote memref.alloc to memref.alloca in payload function + func = match(mod, ops={"func.func"}) + func = apply_registered_pass( + func, + "promote-buffers-to-stack", + options={ + "max-alloc-size-in-bytes": "8192", + "max-rank-of-allocated-memref": "2", + }, + ) + + if stop_at_stage == "bufferized": + raise PipelineInterrupt() + + # convert forall to parallel + wg_loops = match_and_split(mod, ops={"scf.forall"}) + for wg_loop in wg_loops: + wg_loop = loop.loop_forall_to_parallel([anytype], wg_loop) + func = transform.get_parent_op(anytype, wg_loop) + + # convert scf.parallel to gpu.launch + func = apply_registered_pass(func, "gpu-map-parallel-loops") + func = apply_registered_pass(func, "convert-parallel-loops-to-gpu") + func = apply_registered_pass(func, "lower-affine") + transform.apply_cse(func) 
+ canonicalize(func) + + # set the number of threads for the gpu.launch operation + launch_op = match_and_split(func, ops={"gpu.launch"}) + num_subgroups = parameters["wg_rows"] // parameters["sg_rows"] + num_threads = num_subgroups * parameters["subgroup_size"] + xegpu.set_gpu_launch_threads(launch_op[0], threads=[num_threads, 1, 1]) + + # outline gpu func + func = apply_registered_pass(func, "lower-affine") + canonicalize(func) + func = apply_registered_pass(func, "gpu-launch-sink-index-computations") + mod = apply_registered_pass(mod, "gpu-kernel-outlining") + transform.apply_cse(mod) + + if stop_at_stage == "gpu-outlining": + raise PipelineInterrupt() + + # set xevm target + mod = apply_registered_pass( + mod, + "xevm-attach-target", + options={"O": "3", "chip": "bmg"}, + ) + + # for each gpu function in the gpu module, change memref.alloca address + # space to 3 (SLM) and convert vector to xegpu. + gpu_mod_ops = match_and_split(mod, ops={"gpu.module"}) + for gpu_mod in gpu_mod_ops: + gpu_func = match(gpu_mod, ops={"gpu.func"}) + allocas = match(gpu_func, ops={"memref.alloca"}) + transform_ext.update_address_space(allocas, address_space=3) + gpu_func = apply_registered_pass(gpu_func, "convert-vector-to-xegpu") + transform.apply_cse(gpu_func) + + transform.apply_cse(mod) + canonicalize(mod) + + if stop_at_stage == "xegpu-initial": + raise PipelineInterrupt() + + # Set layout attributes for xegpu.store_nd and xegpu.store_matrix ops. 
+ sg_layout = [parameters["sg_rows"], 1] + sg_data = [parameters["sg_rows"], parameters["reduction_step_size"]] + store_nd_ops = match(gpu_func, ops={"xegpu.store_nd"}) + xegpu.set_anchor_layout(store_nd_ops, sg_layout=sg_layout, sg_data=sg_data) + store_matrix_ops = match(gpu_func, ops={"xegpu.store_matrix"}) + xegpu.set_anchor_layout(store_matrix_ops, sg_layout=sg_layout, sg_data=sg_data) + + if stop_at_stage == "xegpu-wg": + raise PipelineInterrupt() + + return mod From 70755f48ee76e9b234508a192126a7e655a2a8d3 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Mon, 15 Jun 2026 22:54:12 +0000 Subject: [PATCH 2/3] Address feedback --- examples/xegpu/layer_norm.py | 9 ++++++++- lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py | 5 ----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/xegpu/layer_norm.py b/examples/xegpu/layer_norm.py index aac57fd8..8a6f52d9 100644 --- a/examples/xegpu/layer_norm.py +++ b/examples/xegpu/layer_norm.py @@ -119,13 +119,20 @@ def get_complexity(self) -> tuple[int, int, int]: def payload_module(self) -> ir.Module: """Generate MLIR module for layer_norm payload.""" - return generate_gpu_layer_norm_payload( + mod = generate_gpu_layer_norm_payload( func_name=self.payload_function_name, M=self.M, N=self.N, dtype=self.elem_type, eps=self.eps, ) + # Emit the memory management utility functions into the payload + # module: the 2D input/output and the 1D gamma/beta bias vectors. 
+ ranks_and_types = [(2, self.elem_type), (1, self.elem_type)] + self.memory_manager_class.emit_memory_management_funcs( + mod, ranks_and_types=ranks_and_types + ) + return mod def schedule_modules( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None diff --git a/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py index c72d1772..ddc1e26e 100644 --- a/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py @@ -4,7 +4,6 @@ from mlir.dialects import linalg, bufferization, tensor, arith, math from lighthouse.utils.mlir import func_cif -from lighthouse.ingress.mlir_gen.gpu_utils import emit_gpu_util_funcs from lighthouse.ingress.mlir_gen.utils import ( emit_buf_to_tensor, affine_map, @@ -120,8 +119,4 @@ def normalized(x, m_sum, v_sum, g, b, _out): None, normalized, output, restrict=True, writable=True ) - # Emit utility functions for GPU memory management - emit_gpu_util_funcs(dtype, rank=2) - emit_gpu_util_funcs(dtype, rank=1) - return mod From fde048de2bdeb5890c0647caa8aa3b4e6794ea3d Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 16 Jun 2026 16:05:31 +0000 Subject: [PATCH 3/3] Fix format --- lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py | 9 ++++++++- lighthouse/schedule/xegpu/layer_norm_schedule.py | 4 +--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py index ddc1e26e..aa30b13c 100644 --- a/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py +++ b/lighthouse/ingress/mlir_gen/gpu_layer_norm_payload.py @@ -102,7 +102,14 @@ def var_sum(x, m_sum, acc): @linalg.generic( [input_tensor, mean_sum, var_sum, gamma_tensor, beta_tensor], [out_init], - [par_map_2d, red_map_2d, red_map_2d, bias_map_2d, bias_map_2d, par_map_2d], + [ + par_map_2d, + red_map_2d, + red_map_2d, + bias_map_2d, + bias_map_2d, + 
par_map_2d, + ], [parallel, parallel], ) def normalized(x, m_sum, v_sum, g, b, _out): diff --git a/lighthouse/schedule/xegpu/layer_norm_schedule.py b/lighthouse/schedule/xegpu/layer_norm_schedule.py index 3c42bef1..f8d35a92 100644 --- a/lighthouse/schedule/xegpu/layer_norm_schedule.py +++ b/lighthouse/schedule/xegpu/layer_norm_schedule.py @@ -94,9 +94,7 @@ def bundle_xegpu_layer_norm_schedule( # (the elementwise normalize op, which is the only op with 2 parallel iterators). all_generics = match(mod, ops={"linalg.generic"}) # Split: 3 generics in total (mean reduction, var reduction, normalize). - gen_ops = transform.split_handle( - (anytype,) * 3, all_generics - ) + gen_ops = transform.split_handle((anytype,) * 3, all_generics) mean_reduction = gen_ops[0] var_reduction = gen_ops[1] normalize_op = gen_ops[2]