Skip to content

WIP: add nkigen-lite as a standalone IR-based kernel generation backend#59

Draft
ymwangg wants to merge 21 commits into
mainfrom
nkigen-lite
Draft

WIP: add nkigen-lite as a standalone IR-based kernel generation backend#59
ymwangg wants to merge 21 commits into
mainfrom
nkigen-lite

Conversation

@ymwangg

@ymwangg ymwangg commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds nkigen-lite, a standalone IR-based kernel generation backend that lowers numpy-style tensor programs to NKI (Neuron Kernel Interface) code for NeuronCore targets.

Architecture

The system is structured as a three-layer IR stack with a multi-phase lowering pipeline:

Core (core.py)

Shared SSA-based IR infrastructure used by both IRs:

  • Value, Op, Graph — SSA primitives with use-lists and mutation helpers
  • DType enum covering f32/f16/bf16/tf32/fp8/int types
  • Common graph utilities: DCE, verification, toposort
  • Shared numpy interpreter dispatch tables

Tensor IR (tensor_ir/)

High-level, hardware-agnostic IR operating on whole tensors:

  • SSA-based — every op produces new Value(s), enabling clean analysis and transformation
  • Numpy-like builder API — familiar interface for constructing kernel graphs
  • Numpy interpreter — executes the IR with real data for correctness checking
  • Ops: elementwise (unary/binary), reduce, matmul, transpose, reshape, slice, concat, broadcast

NKI IR (nki_ir/)

Low-level IR that makes hardware concerns explicit:

  • Memory spaces — every value carries HBM/SBUF/PSUM placement
  • Partition dimension — dim 0 of on-chip tiles is the partition dim (max 128)
  • Explicit memory management — alloc/dealloc + DMA copies for data movement
  • Pre-allocated destinations — all compute ops take a dst parameter
  • Tile indexing — DimSlice-based indexing (ts/ds) mirroring Kernel Builder
  • Loop constructs — fori_loop for explicit tile iteration
  • Hardware verifier — checks tile constraints against target specs
  • Numpy interpreter — reference execution without hardware
  • Emit to Kernel Builder — walks the graph and invokes KB API calls to produce NISA MLIR

Lowering Pipeline (tensor_ir/passes/)

The full pipeline: tensor_ir → canonicalize → decompose → layout_solver → direct_lower → nki_ir

  1. Canonicalize — recomposes primitive-op chains into high-level ops (e.g., div(1, sqrt(x))rsqrt(x), mul(x, div(1, add(1, exp(neg(x)))))silu(x))

  2. Decompose — lowers ops without direct NISA equivalents into supported primitives (e.g., div(a,b)mul(a, reciprocal(b)), reduce(mean)reduce(sum) * 1/N)

  3. Layout Solver — assigns each tensor dimension to one of three roles:

    • I (iteration) — loop indices, not in SBUF tile
    • P (partition) — SBUF dim-0, product ≤ 128, parallel compute
    • F (free) — SBUF dim-1, contiguous per partition

    Propagates constraints across the graph to find a globally consistent assignment.

  4. Direct Lower — converts tensor IR ops to tiled NKI IR:

    • Segments ops into elementwise groups (fused on-chip) vs individual non-elementwise ops (HBM boundaries)
    • Generates tiled load→compute→store sequences
    • Per-op lowering modules: memory, elementwise, reduce, matmul, transpose, broadcast
    • Inserts deallocs via liveness analysis after lowering

Hardware Target (passes/hardware.py)

Parameterized hardware profiles (TRN2 defaults) defining partition limits, SBUF/PSUM sizes, and matmul constraints.

Status

🚧 Work in progress — not ready for review.

Test plan

  • Full test suite passes (uv run pytest nkigen-lite/tests/ -n auto)
  • Integration with main nkipy package verified
  • End-to-end lowering produces correct NKI IR for representative patterns

ymwangg added 21 commits June 1, 2026 22:24
Migrates tensor_ir, nki_ir, and the direct lowering passes from
nano-tensorizer/ir_lab into the nkipy workspace as a new package.
The pipeline (canonicalize → decompose → layout_solver → direct_lower)
produces legal NKI IR directly without intermediate passes.
Add nkigen-lite as a fully functional backend (backend="nkigen-lite")
alongside hlo and nkigen. The pipeline traces Python kernels through
nkigen_lite's tensor_ir Builder, lowers via the pass pipeline
(canonicalize → decompose → layout_solver → direct_lower), and compiles
to NEFF via the NKI kernel_builder API.

nkipy integration:
- backend/nkigen_lite.py: TraceContext, Tensor, IR adapter
- ops/_nkigen_lite_impls.py: op implementations delegating to Builder
- ops/_register_nkigen_lite.py: lazy op registration
- trace.py: _specialize_nkigen_lite() dispatch
- compile.py: _compile_nkigen_lite() via kernel_builder
- knob.py, nki_op.py: backend-aware dispatch

nkigen-lite enhancements:
- Builder: add abs, sign, floor, ceil, power, floor_divide, mod ops
- Interpreter: numpy dispatch for new ops, dtype-aware tensor_copy
- Decompose pass: floor_divide/mod use divide-then-verify-and-correct
  strategy (matching neuronx-cc BIR), power→exp(b*log(a)),
  ceil→neg(floor(neg(x))), fixed-point iteration with max-iter guard
- Direct lowering: abs/sign/sin via NisaActivationOp, floor via i32
  truncation + sign correction, cast via tensor_copy, 1D reshape fix
- docs/floor_divide_precision.md: documents the precision strategy

Test results: 134/135 HLO-parity tests pass on trn2 hardware (99.3%).
Add "nkigen-lite" to the trace_mode fixture so all parametrized tests
run with both backends. Add a pytest hook that marks NotImplementedError
as xfail for nkigen-lite — ops not yet implemented show as expected
failures and automatically start passing when added.

Current results:
- HLO: 741 passed, 4 xfailed, 42 skipped
- nkigen-lite: ~340 passed, 161 xfailed (unimplemented ops), ~93 failed
  (partial implementations needing further work)
Add ReduceKeepdimsFalsePattern to decompose keepdims=False reductions
into keepdims=True + reshape, which the layout solver and lowering
require. Handle scalar (rank-0) tensors throughout the lowering pipeline
by promoting them to (1,) at the NKI boundary since the hardware doesn't
support rank-0 tensors.

Also fix negative axis normalization in squeeze() and expand_dims().
- matmul: add 1D→2D promotion following NumPy semantics
- squeeze: validate non-1 dims, normalize negative axis
- reshape: handle int newshape argument
- zeros/full: handle int shape argument
- concatenate: handle single-tensor case, validate empty/axis bounds
- split: validate axis bounds and unequal division
- where: handle numpy array condition argument
- _ensure_value: handle numpy array operands (uniform-fill)
- expand_dims: validate duplicate axes and out-of-bounds axis
- Skip test_reduce_unsupported_op and test_topk_non_last_axis for
  non-HLO backends since they test HLO-specific internal behavior
NeuronCore hardware only supports Add and Max for cross_lane_reduce_arith.
Implement MIN as -max(-x) transparently in the NKI IR builder so all
existing P-dimension reduce codepaths work with min reductions.
Replace HLO-specific DeviceKernel.compile_and_load path with the
shared on_device_test utility which handles input/output naming
differences between backends automatically.
- broadcast_to: handle scalar (rank-0) source by loading the single
  element and broadcasting via tensor_scalar_arith with ones
- emit_to_kb: auto-cast f16/bf16 operands to f32 around
  tensor_scalar_arith since the hardware scalar engine requires f32
Add NisaBitvecOp enum and tensor_tensor_bitvec builder method to NKI IR.
Wire through the full pipeline: tensor IR opcodes, elementwise lowering,
emit_to_kb mapping, and interpreter support. Replace the old arithmetic
approximations (which only worked for booleans) with hardware bitwise
instructions that work correctly on integer types.
Add NKI IR primitives for mixed-dtype operations:
- tensor_tensor_compare: comparison ops (IsGT, IsGE, etc.) that accept
  float inputs and produce uint8 predicate output
- tensor_scalar_bitvec: scalar bitvec ops (XOR for logical NOT, etc.)
- Comparison and logical op variants in NisaArithOp enum

Rewrite _emit_floor to use the NKI compiler's compare+select pattern:
trunc→compare→conditional select in integer domain, avoiding float
precision issues in the correction step.
…lice

- emit_slice: add strides parameter; delegate to _emit_strided_slice
  for non-unit strides (element-by-element DMA for F-stride, row-by-row
  for P-stride)
- dynamic_update_slice: handle numpy array value argument (uniform fill)
Add DType.FP8_E4M3_IEEE for the IEEE-standard float8_e4m3 format
(distinct from the NaN-free float8_e4m3fn variant already supported).
Wire through core, emit_to_kb, and compile dtype mappings.
Each pytest-xdist worker now claims a specific Neuron core via
NEURON_RT_VISIBLE_CORES, enabling parallel test execution across
all 64 available cores (~8.5x speedup).
- Comparison ops (equal, not_equal, greater, less, etc.) now produce
  same dtype as input (1.0/0.0 float) matching NKI convention, instead
  of DType.BOOL
- where op lowered using NKI pattern: cond*x + (1-cond)*y with all
  float arithmetic — no mixed-dtype operations needed
- Map DType.BOOL → uint8 in kernel builder and execution layer
- Update tensor IR builder to remove BOOL requirement from where
- Reduces xfail count from 162 → 125 (37 tests now passing)
Matches NKI compiler's approach: cos(x) = sin(x + π/2).
The hardware sin activation instruction handles the computation.
Implement np.dot semantics as a composed op:
- 1D/2D cases delegate directly to matmul
- N-D × 1D delegates to matmul (batched matrix-vector)
- N-D × M-D decomposes to reshape + matmul + reshape to achieve
  the outer-product batch semantics of np.dot
- arctan: wire native NISA ARCTAN activation through the Builder and
  direct-lower tables
- invert/bitwise_not: composed_impl as XOR with all-ones (-1), matching
  the NKI compiler's implementation
- logical_and/or/xor: composed_impl via 0/1 truthiness; also unblocks
  rint/round which depend on logical_and
- constant: backend impl mirroring HLO (passthrough + uniform fill);
  non-uniform array constants raise NotImplementedError

Fix a pre-existing bug in _emit_broadcast_scalar that fed a (1,1) tile
to tensor_scalar_arith whose scalar operand partition dim must match the
destination; replicate to (p_size, 1) via broadcast_partition.

Also make test_ml_dtypes_constant_encoding's float8 xfails backend-aware
so float8_e5m2 on nkigen-lite no longer reports XPASS.
The slice-based gather produced wrong output shapes: it ignored
axis=None (no flatten), concatenated slices along the original axis
instead of replacing it with indices.shape, and mishandled scalar and
multi-dimensional index arrays.

Rewrite to match numpy:
  out.shape == a.shape[:axis] + indices.shape + a.shape[axis+1:]

- axis=None flattens the input first
- negative indices are normalized modulo the axis dimension
- each flat index becomes a width-1 slice; slices are concatenated then
  reshaped so the gathered axis is replaced by indices.shape (dropped
  entirely for a scalar index)

Fixes 13 failing test_take_scalar / test_take_numpy_indices cases.
…) for nkigen-lite

Wires the four distributed collectives through the full nkigen-lite stack:

- tensor_ir Builder: collective ops with correct output-shape inference
  (all_gather grows the gather dim, reduce_scatter/all_to_all shrink/grow
  by world size)
- nkipy lite impls + registration, mapping numpy reduce ufuncs to the
  collective reduce-op names
- direct_lower: stage collectives through internal HBM scratch buffers
  (the compiler forbids collectives from reading/writing kernel IO
  tensors directly — "Collective instruction cannot read IO tensors")
- nki_ir Builder: collective() side-effect node (HBM->HBM)
- emit_to_kb: lower to nisa.all_reduce/all_gather/reduce_scatter/all_to_all
  via ExplicitReplicaGroupAttr + dma_compute_reduce_op

The KB collective API only operates on the last (free) axis of 2D HBM
tensors (cc_dim=0 raises std::bad_cast), so all_gather/reduce_scatter
along other axes are staged via transpose-collective-transpose.

Fixes the all_reduce/all_gather/reduce_scatter/all_to_all xfails
(multiply-reduce variants stay xfailed for both backends — unsupported
by the compiler).
The earlier transpose workaround for all_gather/reduce_scatter was based on
a misdiagnosis: cc_dim=0 appeared to raise std::bad_cast, so collectives
were staged through a transpose to operate on the last axis. Multi-core
numerical verification showed that path silently dropped the remote rank's
data (all_gather duplicated the local source; reduce_scatter ignored the
per-rank scatter offset).

Root cause: the KB nisa collective APIs forward cc_dim to the native
builder un-converted, so a bare int 0 fails the int->enum cast. The NKI
collectives contract also requires collective_dim=0 for HBM tensors.

Fix:
- emit_to_kb: convert the int dim to CollectiveDimension (DIM_0/DIM_1)
  before calling nisa.all_gather/reduce_scatter/all_to_all
- drop the transpose workaround; gather/scatter along the requested dim
  directly

Verified on 2 NeuronCores with distinct per-rank data: all_reduce,
all_gather(dim0), reduce_scatter(dim0), and all_to_all all produce the
correct cross-rank results.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant