diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9d19a95136ad7..38d101786b41a 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -4949,7 +4949,7 @@ This version of the operator has been available since version 1 of the 'com.micr
use_sparse_mixer : int
Whether to use sparse mixer
weights_prepacked : int
-Only meaningful when quant_type='int'. Tri-state control over whether the int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS fpA_intB format expected by the runner. -1 (auto): let the execution provider choose its own backward-compatible default; the CUDA EP treats auto as prepacked. 1: the initializers are already prepacked (e.g. produced offline by pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself in PrePack(), matching the behaviour of MatMulNBits and removing the offline pre-pack requirement from exporters. Defaults to -1 (auto) so each execution provider can pick its own backward-compatible default rather than the schema imposing one.
+Only meaningful when quant_type='int'. Tri-state control over the layout of the int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by -1 and 1 are determined by the execution provider. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.
#### Inputs (6 - 21)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a3b38b17bd874..5fe7dd020d559 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -999,8 +999,10 @@ The **OpSet Version** column uses the following notation:
|Softmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Softplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Softplus|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Softsign|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
+|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
|SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**
or
*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**
or
*in* input:**T**
*out* outputs:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md b/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md
index 34092fe9a0307..60735119fcf1e 100644
--- a/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md
+++ b/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md
@@ -178,6 +178,44 @@ Nodes that do not match any rule fall through to the normal EP capability-based
> **Note — Annotations vs. actual placement:** An annotation expresses a *preference*, not a guarantee. If the target EP does not have a registered kernel for a node (for example, a particular data-type / opset-version combination is not implemented in the CUDA EP), that node will not be placed on the requested device. Instead it falls through to the next EP in the provider list that can handle it.
+### Name-Based Layer Assignment (No Model Modification)
+
+For models that already have structured node names (most HuggingFace exports, ONNX models produced by PyTorch, etc.), you can skip the annotation step entirely. The session option `session.name_based_layer_assignment` performs **substring matching** directly against `Node::Name()`:
+
+```
+device1(pattern1, pattern2, ...); device2(pattern3, pattern4, ...)
+```
+
+- **Substring matching:** A pattern matches if it appears *anywhere* in the node name. For example, `layers.0/` matches `/model/layers.0/self_attn/q_proj/MatMul`.
+- **Longest match wins:** When multiple patterns match the same node name, the longest pattern takes priority. For example, `layers.10/` wins over `layers.1/` for a node named `/model/layers.10/...`.
+- **No `=` prefix:** The exact-match qualifier (`=`) from annotation-based syntax is rejected with an error. All patterns are treated as substrings.
+- **Same device designators:** The device portion uses the same device designators as `session.layer_assignment_settings` (see table above).
+
+```python
+import onnxruntime as ort
+
+opts = ort.SessionOptions()
+
+# Assign layers 0–7 to GPU, layers 8–15 to CPU based on node names
+opts.add_session_config_entry(
+ "session.name_based_layer_assignment",
+ "gpu(layers.0/, layers.1/, layers.2/, layers.3/, layers.4/, layers.5/, layers.6/, layers.7/); "
+ "cpu(layers.8/, layers.9/, layers.10/, layers.11/, layers.12/, layers.13/, layers.14/, layers.15/)"
+)
+
+session = ort.InferenceSession("model.onnx", opts,
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+```
+
+**Tips for writing patterns:**
+- Include the trailing `/` in layer patterns (e.g., `layers.1/` instead of `layers.1`) to avoid `layers.1` accidentally matching `layers.10`, `layers.11`, etc.
+- Use [Netron](https://netron.app/) to inspect your model's node names and identify suitable substrings.
+- Nodes that do not match any pattern fall through to normal EP capability-based assignment (typically CPU).
+
+**Mutual exclusivity with annotation-based matching:** The `session.name_based_layer_assignment` and `session.layer_assignment_settings` options are **mutually exclusive** — setting both will return an error. Use annotation-based matching for models that carry explicit `layer_ann` metadata annotations, or name-based matching for unmodified models with structured node names. If you need fine-grained exceptions (e.g., force one specific node to CPU), add the node's name pattern to the name-based config instead of mixing the two approaches.
+
+**No subgraph inheritance:** Unlike annotation-based matching (where unannotated subgraph nodes inherit their parent's device assignment), name-based matching treats every node independently. Since node names are dense (virtually every node has a name encoding its structural position), inheritance is unnecessary — each node matches on its own name.
+
## Capacity-Aware Partitioning (implemented for CUDA)
When running models on a CUDA GPU with limited memory, you can set a memory budget so ONNX Runtime stops assigning nodes to the CUDA EP once the estimated memory consumption reaches the limit. Nodes are considered in topological order and assignment halts at the first node that would exceed the budget — ONNX Runtime does not search ahead for smaller nodes that might still fit. Remaining nodes are then eligible for assignment by the subsequent EPs in the session's provider list (often CPU, but not necessarily).
@@ -292,26 +330,30 @@ EPs that prefer the NHWC data layout — for example, the CUDA EP when it is cre
Because the first-pass tags are tentative, ONNX Runtime does **not** commit any memory budget for them. The budget is committed only for the nodes that survive the second pass; the cost of a node that is dropped is never counted against the memory limit. This keeps the accumulated memory estimate accurate when `prefer_nhwc` is combined with `session.resource_cuda_partitioning_settings`, so a dropped node does not consume phantom budget that could prematurely halt assignment of later nodes.
-## Combining Both Features
-Layer annotations and capacity-aware partitioning can be used together. When both are configured:
-- Layer annotations provide the initial node-to-device mapping.
+## Combining Features
+Layer annotations OR name-based assignment can be combined with capacity-aware partitioning. Note that annotation-based and name-based matching are **mutually exclusive** — you cannot use both simultaneously.
+
+When a layer assignment option (either annotation-based or name-based) is configured together with the capacity-aware partitioner:
+- The layer assignment option expresses the desired device placement.
- The capacity-aware partitioner enforces the memory budget, potentially overriding assignments that would exceed the GPU memory limit.
-This combination gives you fine-grained control: use annotations to express logical model structure, and let the memory budget act as a safety net.
+This gives you fine-grained control: use annotations or name patterns to express logical model structure, and let the memory budget act as a safety net.
```python
opts = ort.SessionOptions()
+# Name-based assignment (no model modification needed)
opts.add_session_config_entry(
- "session.layer_assignment_settings",
- "gpu(encoder, decoder); cpu(=postprocess)"
+ "session.name_based_layer_assignment",
+ "gpu(layers.0/, layers.1/, layers.2/, layers.3/); cpu(layers.4/, layers.5/, layers.6/, layers.7/)"
)
+# Memory budget as a safety net
opts.add_session_config_entry(
"session.resource_cuda_partitioning_settings",
"4194304,node_memory_stats.csv"
)
-session = ort.InferenceSession("model_annotated.onnx", opts,
+session = ort.InferenceSession("model.onnx", opts,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
```
diff --git a/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md b/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md
new file mode 100644
index 0000000000000..5d5b0136fe2ea
--- /dev/null
+++ b/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md
@@ -0,0 +1,432 @@
+# CUDA Kernel Workspace Buffer Inventory
+
+This document catalogs all CUDA kernels in ONNX Runtime that allocate temporary/workspace buffers via `GetScratchBuffer` or `GetTransientScratchBuffer`. For each kernel, it identifies what information is needed to compute the workspace size and whether that information is available at `GetCapability()` time (for the workspace estimation function design).
+
+## Classification Key
+
+| Symbol | Meaning |
+|--------|---------|
+| ✅ | Fully determinable from shapes + attributes + device properties |
+| ✅* | Determinable via cuDNN/cuBLAS API call (needs handle, available on EP) |
+| ⚠️ | Requires profiling/tactic selection (deterministic but costly at planning time) |
+
+---
+
+## Core CUDA Providers (`onnxruntime/core/providers/cuda/`)
+
+### 1. Attention (LLM — Opset 23/24)
+
+**File:** `llm/attention.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `softmax_lse_buffer` | `B * S_q * num_heads * sizeof(float)` | shapes |
+| `softmax_lse_accum_buffer` | From `get_num_splits_and_buffer_sizes()` | shapes + `multiProcessorCount` |
+| `out_accum_buffer` | From `get_num_splits_and_buffer_sizes()` | shapes + `multiProcessorCount` |
+| `q_bsnh_buffer` | `B * S_q * num_heads * head_size * element_size` | shapes + dtype |
+| `out_bsnh_buffer` | Same as Q | shapes + dtype |
+| `k_bsnh_buffer` / `v_bsnh_buffer` | `B * S_kv * num_kv_heads * head_size * element_size` | shapes + dtype |
+| `seqlens_k_buffer` | `B * sizeof(int)` | batch size |
+| `past_seqlens_buffer` | `B * sizeof(int)` | batch size |
+| `k_expand_buffer` / `v_expand_buffer` | `B * num_heads * S_kv * head_size * element_size` (GQA expansion) | shapes + dtype |
+| `converted_mask_buffer` | `B * S_q * S_kv * sizeof(float)` | shapes |
+| `present_k_scratch` / `present_v_scratch` | present KV cache size | shapes |
+| `workspace_buffer` (math attention) | `B * S_q * num_heads * sizeof(float)` | shapes |
+
+**What's needed to compute:** Input shapes, `num_heads` attribute, `head_size`, dtype, `device_prop.multiProcessorCount`.
+
+**Static determinability:** ✅ All pure arithmetic on shapes + device SM count.
+
+---
+
+### 2. Conv (cuDNN Frontend)
+
+**File:** `nn/conv.cc`, `nn/conv.h`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace` | `s_.cudnn_fe_graph->get_workspace_size()` | cuDNN plan selection |
+| `memory_for_cudnn_conv_results` (Conv8 only) | `y_dims_with_adjusted_pads.Size() * element_size` | output shape + padding |
+
+**What's needed to compute:** Input shapes (NCHW), weight shapes, pads/strides/dilations attributes, cuDNN handle (for `build_plans()`).
+
+**Static determinability:** ✅* — Requires calling `build_plans(handle)` with `HEUR_MODE_A`. The handle is on the EP. All tensor shapes and conv params come from node attributes.
+
+**Note:** `GetTransientScratchBuffer` (32MB) used for algorithm search in Conv8 — this is a one-time cost during first Compute, not a per-run workspace.
+
+---
+
+### 3. ConvTranspose
+
+**File:** `nn/conv_transpose.h`, `nn/conv_transpose_8.h`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace` | `s_.workspace_bytes` (from cuDNN FE or algo selection) | Same as Conv |
+| `AlgoSearchWorkspaceSize` (Conv8 path) | 32MB constant | N/A |
+
+**Static determinability:** ✅* — Same as Conv.
+
+---
+
+### 4. DeformConv
+
+**File:** `nn/deform_conv.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `col_buffer` | `C * kernel_size * col_stride * sizeof(T)` where `col_stride = n_parallel_imgs * output_image_size` | Input shapes, kernel_size, `device_prop.totalGlobalMem` (for chunk sizing) |
+
+**What's needed:** Input shape (N,C,H,W), kernel dims, output_image_size, `totalGlobalMem` (determines `n_parallel_imgs` via `GetNParallelImgs`).
+
+**Static determinability:** ✅ — Pure arithmetic on shapes + device memory size.
+
+---
+
+### 5. BatchNorm
+
+**File:** `nn/batch_norm.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `f_scale`, `f_B`, `f_mean`, `f_var` | `C * sizeof(float)` each | Channel dim (shape[1] or shape[3]) |
+
+**What's needed:** Channel dimension `C` from input shape.
+
+**Static determinability:** ✅ — Trivial: `4 * C * sizeof(float)`.
+
+---
+
+### 6. InstanceNorm
+
+**File:** `nn/instance_norm.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `mean`, `variance` | `N * C * sizeof(T)` | Batch × channels |
+| `unused_scale`, `unused_bias` | `N * C * sizeof(T)` | Same |
+| `scale_data_fp32`, `bias_data_fp32` | `C * sizeof(float)` (if fp16) | Channel dim + dtype |
+
+**What's needed:** Input shape (N, C), dtype.
+
+**Static determinability:** ✅ — Pure arithmetic on shapes.
+
+---
+
+### 7. Dropout
+
+**File:** `nn/dropout.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| mask buffer | `element_count * sizeof(bool)` or `element_count / 16 * sizeof(BitmaskElementType)` | Input element count, bitmask mode |
+
+**Static determinability:** ✅ — Input element count.
+
+---
+
+### 8. Reduction Ops (ReduceSum, ReduceMax, ReduceMean, etc.)
+
+**File:** `reduction/reduction_ops.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace_cuda` | `cudnnGetReductionWorkspaceSize()` | cuDNN handle, input/output tensor descriptors |
+| `indices_cuda` | `cudnnGetReductionIndicesSize()` | Same |
+| `temp_X` | `input_count * sizeof(float)` (type cast) | Input size |
+| `input_data_buffer` | `input_count * sizeof(T)` (for `calculate_sqt_`) | Input size |
+| `exp_result_buffer` | `input_count * sizeof(T)` (for `log_sum_exp_`) | Input size |
+| `log_sum_result_buffer` | `output_count * sizeof(T)` | Output size |
+| `temp_output` | `output_count * sizeof(float)` | Output size |
+
+**What's needed:** Input/output shapes, reduction axes, op variant (LogSumExp, L2, etc.), cuDNN handle.
+
+**Static determinability:** ✅* — cuDNN workspace query needs handle + tensor descriptors (constructible from shapes).
+
+---
+
+### 9. RNN (LSTM/GRU)
+
+**File:** `rnn/cudnn_rnn_base.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace_cuda` | `cudnnGetRNNTempSpaceSizes(fwdInference)` | RNN descriptor, seq_length, batch_size |
+| `reservespace_cuda` | `cudnnGetRNNTempSpaceSizes(training)` | Same |
+| `reorganized_w_data` | `w_size * sizeof(T)` | hidden_size, num_layers, input_size, direction |
+| `x_reversed_data` | `seq_length * batch_size * input_size * sizeof(T)` | Shapes (bidirectional case) |
+| `y_alloc_data` | `output_size * sizeof(T)` | Shapes |
+| `state_buffer_` | RNN state size from cuDNN | cuDNN descriptor |
+
+**What's needed:** seq_length, batch_size, input_size, hidden_size, num_layers, direction attribute, cuDNN handle.
+
+**Static determinability:** ✅* — cuDNN API queries, all inputs available from node attributes/shapes.
+
+---
+
+### 10. TopK
+
+**File:** `math/topk_impl.cuh`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `input_key_buffer` | `dimension * sizeof(T)` | Last-axis dimension |
+| `output_key_buffer` | `dimension * sizeof(T)` | Same |
+| `input_value_buffer` | `dimension * sizeof(int64_t)` | Same |
+| `output_value_buffer` | `dimension * sizeof(int64_t)` | Same |
+| `temp_storage_buffer` | From `cub::DeviceRadixSort::SortPairs` query | dimension |
+
+**What's needed:** Dimension (last axis size), k, dtype.
+
+**Static determinability:** ✅ — CUB temp storage query is deterministic given size.
+
+---
+
+### 11. MatMulInteger
+
+**File:** `math/matmul_integer.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `a_row_buf` | `(output_size / N) * sizeof(int32_t)` | M dimension |
+| `b_col_buf` | `(output_size / M) * sizeof(int32_t)` | N dimension |
+
+**What's needed:** M, N dimensions from MatMul shapes.
+
+**Static determinability:** ✅ — Pure arithmetic.
+
+---
+
+### 12. IntegerGemm (int8 alignment padding)
+
+**File:** `integer_gemm.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `a_padded` | `m * roundoff(lda, 32) * sizeof(int8_t)` (only if lda not 32-aligned) | M, K dims |
+| `b_padded` | `k * roundoff(ldb, 32) * sizeof(int8_t)` (only if ldb not 32-aligned) | K, N dims |
+
+**What's needed:** M, K, N dimensions + alignment check.
+
+**Static determinability:** ✅ — Pure arithmetic.
+
+---
+
+### 13. Compress
+
+**File:** `tensor/compress.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `condition_cumulative_sum_buffer` | `valid_condition_length * sizeof(int32_t)` | Condition tensor size |
+| `temp_buffer` | CUB `DeviceScan::InclusiveSum` temp storage | Condition size |
+
+**Static determinability:** ✅ — Condition shape determines everything.
+
+---
+
+### 14. GatherND
+
+**File:** `tensor/gather_nd.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `sizes_from_slice_dims_buffer` | `num_slice_dims * sizeof(int64_t)` | Indices shape |
+| `input_slice_offsets_buffer` | `num_slices * sizeof(int64_t)` | Indices shape[:-1] product |
+
+**Static determinability:** ✅ — Indices shape.
+
+---
+
+### 15. NonZero
+
+**File:** `tensor/nonzero_op.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `prefix_buffer` | `number_of_blocks * sizeof(int)` | Input element count / block_size |
+| `temp_buffer` | CUB temp storage | Input element count |
+
+**Static determinability:** ✅ — Input element count.
+
+---
+
+### 16. Upsample/Resize
+
+**File:** `tensor/upsample.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| temp buffer (via lambda) | Varies by resize mode | Input/output shapes, mode |
+| `dims_mapping_buffer` | `temp_buffer_size` (coordinate mapping) | Output shape |
+
+**Static determinability:** ✅ — Input/output shapes + mode attribute.
+
+---
+
+### 17. NonMaxSuppression
+
+**File:** `object_detection/non_max_suppression.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| Various (via lambda) | Determined by CUB DeviceSelect internals | num_boxes, num_classes |
+
+**What's needed:** boxes shape (num_batches, num_boxes, 4), scores shape.
+
+**Static determinability:** ✅ — CUB queries are deterministic given sizes.
+
+---
+
+## Contrib CUDA Operators (`onnxruntime/contrib_ops/cuda/`)
+
+### 18. Attention / MultiHeadAttention (Contrib)
+
+**File:** `bert/attention.cc`, `bert/multihead_attention.cc`
+
+**Buffers:** Uses `GetAttentionWorkspaceSize()` helper function.
+
+**Size formula:** Depends on attention algorithm (Flash, MemoryEfficient, FusedRunner, Default):
+- Flash: `qkv_size` (Q+K+V projection)
+- MemoryEfficient: `qkv_size + output_accum (float)`
+- Default (unfused): `qkv_size + 2 * attention_scratch_size`
+
+**What's needed:** B, S_q, S_kv, num_heads, head_size, dtype, which attention algorithm is selected.
+
+**Static determinability:** ✅ — Algorithm selection depends on shapes + SM version (available from device_prop).
+
+---
+
+### 19. MOE (Mixture of Experts)
+
+**File:** `moe/moe.cc`, `moe/moe_quantization.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace` | `moe_runner->getWorkspaceSize(num_rows, hidden, inter, experts, k)` | Shapes + tactic |
+| `expert_scales` | `num_rows * k * sizeof(float)` | Shapes |
+| `expert_indices` | `num_rows * k * sizeof(int)` | Shapes |
+| `permutation_row_map` | `num_rows * k * sizeof(int)` | Shapes |
+
+**What's needed:** num_rows, hidden_size, inter_size, num_experts, k, activation_type, SM version, selected tactic.
+
+**Static determinability:** ⚠️ — `getWorkspaceSize()` depends on profiled best tactic (CUTLASS config). Could use worst-case across tactics as upper bound.
+
+---
+
+### 20. MatMulNBits (Quantized MatMul)
+
+**File:** `quantization/matmul_nbits.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `workspace_buffer` | `weightOnlyGemmRunner_->getWorkspaceSize(m, n, k)` | Dims + tactic |
+| `packed_transposed_weight_space` | `packed_weight_bytes` (transient) | Weight shape |
+| `permutation_map_buffer` | `32 * sizeof(int32_t)` (transient) | Constant |
+
+**What's needed:** M, N, K dimensions, quantization bits, SM version.
+
+**Static determinability:** ⚠️ — Runner workspace depends on profiled tactic. Could use upper bound.
+
+---
+
+### 21. fpA_intB GEMM (FP×INT quantized)
+
+**File:** `llm/fpA_intB_gemm/`
+
+**Buffers:** `virtual size_t getWorkspaceSize(m, n, k)`.
+
+**What's needed:** M, N, K + CUTLASS template specialization.
+
+**Static determinability:** ⚠️ — Depends on selected CUTLASS config/tactic.
+
+---
+
+### 22. Inverse (Matrix Inversion)
+
+**File:** `inverse.cc`
+
+**Buffers allocated:**
+| Buffer | Size formula | Depends on |
+|--------|--------------|-----------|
+| `input_workspace` | `input_count * sizeof(T)` | Matrix dimensions |
+| `matrix_ptrs` | `n_batches * sizeof(T*)` | Batch count |
+| `output_ptrs` | `n_batches * sizeof(T*)` | Batch count |
+| `ml_float_output` | `input_count * sizeof(float)` (if fp16→fp32) | Dims + dtype |
+
+**Static determinability:** ✅ — Pure arithmetic on matrix dimensions.
+
+---
+
+### 23. Generation (Beam Search / Sampling)
+
+**File:** `transformers/generation_device_helper.cc`
+
+**Buffers:** Various pinned + device buffers for beam state.
+
+**What's needed:** batch_size, beam_width, max_length, vocab_size.
+
+**Static determinability:** ✅ — All from session/model config.
+
+---
+
+## Summary: Coverage Analysis
+
+### Workspace estimation approach validation
+
+| Category | # Kernels | Estimation feasibility | Notes |
+|----------|-----------|----------------------|-------|
+| **Shapes only** | 12 | ✅ Exact, trivial | BatchNorm, InstanceNorm, Dropout, TopK, MatMulInteger, IntegerGemm, Compress, GatherND, NonZero, Upsample, Inverse, Generation |
+| **Shapes + device properties** | 3 | ✅ Exact | Attention (SM count), DeformConv (totalGlobalMem), Contrib Attention (SM version) |
+| **Shapes + cuDNN/cuBLAS handle** | 4 | ✅* Exact via API query | Conv, ConvTranspose, Reduction, RNN |
+| **Shapes + tactic profiling** | 3 | ⚠️ Upper bound only | MOE, MatMulNBits, fpA_intB_GEMM |
+
+### Key takeaways
+
+1. **~75% of kernels** (19/25) can produce **exact** workspace estimates at `GetCapability()` time using only shapes + attributes + device properties (+ cuDNN handle for API queries).
+
+2. **~12% of kernels** (3/25) require tactic profiling (CUTLASS/CUB autotuning). For these, options are:
+ - Use worst-case workspace across all tactics (safe upper bound)
+ - Run tactic selection eagerly at estimation time (expensive but exact)
+ - Accept 1.5x multiplier for these few kernels
+
+3. **The cuDNN handle requirement** affects only 4 kernel types (Conv, ConvTranspose, Reduction, RNN). All are standard cuDNN API queries that are fast and deterministic given the handle + tensor descriptors.
+
+4. **No kernel requires actual GPU execution** to determine workspace size — even tactic-based kernels select tactics via CPU-side profiling/heuristics, not by running GPU code.
+
+5. **Largest workspace consumers** in practice:
+ - **Attention** (Flash): dominates in LLM workloads. Exact estimation possible.
+ - **Conv** (cuDNN): dominates in vision workloads. Exact via `build_plans()`.
+ - **MOE**: significant in MoE models. Upper bound via worst-case tactic.
+
+### What the estimation function needs access to (API requirements):
+
+| Access needed | How accessed | Kernels that require it |
+|---------------|-------------|------------------------|
+| `Node_GetInputShape()` | OrtEpApi (generic) | All 25 kernels |
+| `Node_GetAttributeInt/Ints()` | OrtEpApi (generic) | Conv, Attention, RNN, MOE |
+| `device_prop.multiProcessorCount` | Cast `OrtEp*` to concrete EP type | Attention, DeformConv |
+| `device_prop.totalGlobalMem` | Cast `OrtEp*` to concrete EP type | DeformConv |
+| cuDNN handle | Cast `OrtEp*` to concrete EP type | Conv, ConvTranspose, Reduction, RNN |
+| Tactic profiler state (or worst-case constant) | Cast `OrtEp*` to concrete EP type | MOE, MatMulNBits, fpA_intB |
+
+**API surface:** Only `Node_GetInputShape` and `Node_GetAttributeInt/Ints` need to be added to `OrtEpApi` (generic, EP-agnostic). All device-specific state (cuDNN handles, device properties, profiler state) is accessed by casting `OrtEp*` to the EP's concrete type — no public API needed since the estimation function is EP-specific code.
diff --git a/docs/annotated_partitioning/future_directions_constrained_env.md b/docs/annotated_partitioning/future_directions_constrained_env.md
new file mode 100644
index 0000000000000..ab33a6595baf3
--- /dev/null
+++ b/docs/annotated_partitioning/future_directions_constrained_env.md
@@ -0,0 +1,1338 @@
+# Future Directions: Constrained Environment Partitioning
+
+## Context
+
+Today's annotation-based partitioning requires each node to carry a `layer_ann` metadata property. The `LayeringIndex` matches these annotations against user-supplied rules (prefix trie + exact match) to assign nodes to devices. The `IResourceAccountant` optionally enforces memory budgets.
+
+The goal: make ORT as easy to use as ollama for running large models on machines with limited GPU memory — automatic or near-automatic layer offloading without requiring model producers to annotate every node.
+
+---
+
+## Direction 1: Name-Based Substring Matching (No Annotation Step)
+
+### Idea
+
+Skip the annotation metadata entirely. Instead, match directly against **node names** using substrings/patterns from the configuration. MS Foundry models (and most HuggingFace exports) already encode layer structure in node names:
+
+```
+/model/layers.0/self_attn/q_proj/MatMul
+/model/layers.0/self_attn/k_proj/MatMul
+/model/layers.15/mlp/gate_proj/MatMul
+/model/embed_tokens/Gather
+/model/norm/LayerNormalization
+```
+
+A config like `gpu(layers.0, layers.1, ..., layers.15); cpu(layers.16, ..., layers.31)` would partition without any model modification.
+
+### How to Approach
+
+1. **Add a new `SubstringMatcher` for node-name matching.** Today the `LayeringRuleMatcher` supports exact match and prefix match (via a trie that walks from position 0 of the input string). Neither mode works for node names: a node named `/model/layers.5/self_attn/q_proj/MatMul` does not *start with* `layers.5` — the identifying substring appears in the middle. Name-based matching fundamentally requires **substring** search. The existing trie infrastructure is irrelevant here — a new, simpler matching approach is needed (see "Substring Matching Implementation" below).
+
+2. **Config via a separate session option (same grammar, different matcher).** Rather than introducing a new qualifier into the existing `kOrtSessionOptionsLayerAssignmentSettings` syntax, add a parallel session option that uses **the same `device(pattern1, pattern2, ...); ...` grammar** but performs **substring matching** against `Node::Name()` instead of prefix/exact matching against node metadata:
+
+ ```cpp
+ // Existing (annotation-based, matches node metadata 'layer_ann'):
+ static const char* const kOrtSessionOptionsLayerAssignmentSettings =
+ "session.layer_assignment_settings";
+
+ // NEW (name-based, matches Node::Name() via substring):
+ static const char* const kOrtSessionOptionsNameBasedLayerAssignment =
+ "session.name_based_layer_assignment";
+ ```
+
+ Usage stays identical — only the matching target and algorithm differ:
+ ```
+ # Annotation-based (existing, prefix/exact match against node metadata):
+ session.layer_assignment_settings = "cuda(encoder_layer, attention); cpu(embed)"
+
+ # Name-based (new, substring match against Node::Name()):
+ session.name_based_layer_assignment = "cuda(layers.0/, layers.1/); cpu(layers.16/)"
+
+ # Range expressions (future extension, not currently supported):
+ session.name_based_layer_assignment = "cuda(layers.[0-15]); cpu(layers.[16-31])"
+ ```
+
+ This approach:
+ - Keeps the existing parser/grammar unchanged (reuse the `device(pattern1, pattern2, ...); ...` syntax)
+ - Uses a **new `SubstringMatcher`** (not the existing trie-based `LayeringRuleMatcher`) for the actual matching
+ - Makes intent explicit — users opt into name-based matching deliberately
+ - The two options are **mutually exclusive** — setting both returns an error
+ - No risk of breaking existing annotation-based workflows
+
+3. **Build index at load time.** During `InferenceSession::Initialize()`, after graph is loaded but before partitioning:
+ - If config contains name-based rules, iterate all nodes once
+ - Build `NodeIndex → RuleIndex` map using substring matching on `Node::Name()`
+ - Feed this into the existing `LayeringIndex` infrastructure (same downstream flow)
+
+4. **Range expressions (future extension).** The config grammar does **not** support range syntax today. For transformer models with numbered layers, a future extension could add range support:
+ ```
+ cuda(layers.[0-15]); cpu(layers.[16-31])
+ ```
+ This would avoid enumerating 32+ layer prefixes manually, but requires new parsing logic. Until then, users must enumerate each layer prefix explicitly or use a broad prefix like `layers.` that captures all layers for a single device.
+
+### Substring Matching Implementation
+
+The existing `LayeringRuleMatcher` uses a **trie** for prefix matching — it walks the input string from position 0 and checks if any trie path matches a prefix of the input. This only works when the pattern appears at the **start** of the matched string.
+
+For node names, patterns appear in the **middle**:
+```
+Pattern: "layers.5"
+Node name: "/model/layers.5/self_attn/q_proj/MatMul"
+ ^^^^^^^^ — match at position 7, not position 0
+```
+
+The trie is useless here. A new `SubstringMatcher` class is needed.
+
+#### Design: Flat vector + `std::string::find`
+
+The simplest correct approach:
+
+```cpp
+class SubstringMatcher {
+ public:
+ explicit SubstringMatcher(const LayeringRules& rules);
+
+ /// Returns the index of the best matching rule for the given node name.
+ /// "Best" = longest pattern that appears as a substring in the name.
+ std::optional Match(std::string_view node_name) const;
+
+ private:
+ // Sorted by pattern length descending — longest patterns checked first.
+ // First match wins (longest-match priority).
+ struct PatternEntry {
+ std::string pattern;
+ size_t rule_index;
+ };
+ InlinedVector patterns_; // sorted longest-first
+};
+```
+
+**Match algorithm:**
+```cpp
+std::optional SubstringMatcher::Match(std::string_view node_name) const {
+ for (const auto& entry : patterns_) {
+ if (node_name.find(entry.pattern) != std::string_view::npos) {
+ return entry.rule_index;
+ }
+ }
+ return std::nullopt;
+}
+```
+
+**Why longest-match-first ordering:**
+
+Without it, `layers.1` (a substring of `layers.10`, `layers.11`, ..., `layers.19`) would incorrectly match nodes from layers 10–19. By checking longer patterns first, `layers.10` matches before `layers.1` gets a chance. Users should include the path separator for unambiguous matching: `layers.1/` won't match `layers.10/...`.
+
+**Performance:** With ~64 patterns and node names < 200 chars, this is O(P × N) per node where P = number of patterns and N = name length. Total cost for a 1000-node model: ~64 × 200 × 1000 = ~12M character comparisons. This completes in microseconds on modern hardware and runs only once during `Initialize()`. No optimization (Aho-Corasick, etc.) is warranted.
+
+**Priority semantics:**
+
+| Scenario | Behavior |
+|----------|----------|
+| Single match | Return that rule's index |
+| Multiple matches (different lengths) | Longest pattern wins |
+| Multiple matches (same length, different rules) | First rule in config order wins (stable sort by length, preserving config order as tiebreaker) |
+| No match | Return `nullopt` → node goes to fallback EP (CPU) |
+
+**Integration with `LayeringIndex`:**
+
+`LayeringIndex` owns either a `LayeringRuleMatcher` (annotation mode) or a `SubstringMatcher` (name-based mode) — the two are mutually exclusive. The `ProcessGraph` method branches based on which mode is active:
+
+```cpp
+void LayeringIndex::ProcessGraph(const Graph& graph, std::optional parent_layer_id) {
+ for (const auto& node : graph.Nodes()) {
+ std::optional matched_rule_idx;
+
+ if (substring_matcher_) {
+ // Name-based mode: substring match against node name, no inheritance.
+ // Node names are dense, so each node is matched independently.
+ matched_rule_idx = substring_matcher_->Match(node.Name());
+ } else {
+ // Annotation-based mode: prefix/exact match against metadata,
+ // with subgraph inheritance for unannotated nodes.
+ const std::string& annotation = node.GetLayeringAnnotation();
+ if (!annotation.empty()) {
+ matched_rule_idx = matcher_.Match(annotation);
+ }
+ if (!matched_rule_idx && parent_layer_id) {
+ matched_rule_idx = parent_layer_id;
+ }
+ }
+
+ if (matched_rule_idx) {
+ node_to_layering_index_[node.Index()] = *matched_rule_idx;
+ }
+ }
+}
+```
+
+**Why mutual exclusivity (not priority/fallback):**
+
+The two modes have fundamentally different inheritance semantics. Annotations are sparse — nodes without annotations inherit from their subgraph parent to maintain device consistency. Names are dense — virtually every node has a name, so inheritance is unnecessary and would incorrectly override name-based matches in subgraphs. Making the modes mutually exclusive keeps the semantics simple and predictable.
+
+### Advantages
+
+- **Zero model modification** — works with any model that has structured naming
+- **Reuses existing partitioning infrastructure** — only the index-building and matching steps change
+- **User-friendly** — users can inspect node names with Netron and write rules directly
+- **Composable with resource accounting** — can combine name-based assignment with memory budgets
+
+### Risks / Open Questions
+
+- **Name stability**: Node names aren't guaranteed stable across exports. Mitigated by prefix/substring matching rather than exact names.
+
+### Handling Nodes Created by Graph Transformers
+
+#### Pre-partitioning transformers (Level 1)
+
+Level 1 optimizers run **before** partitioning. With annotation-based matching, these transformers propagate annotations to new nodes via the `AddNode(..., annotation_source)` overload, which copies `GetLayeringAnnotation()` from an original node. The name-based approach needs an analogous story.
+
+**Key insight: new node names ARE derivative of original names.** Verified in the codebase:
+
+- `Graph::GenerateNodeName(base_name)` takes a base string and ensures uniqueness by appending `_token_` only on collision.
+- Transformers construct the base name from original node(s):
+ - `layer_norm_fusion.cc`: `GenerateNodeName(mul_node.Name() + "/LayerNormFusion/")`
+ - `matmul_add_fusion.cc`: `GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion")`
+ - `attention_fusion.cc`: `GenerateNodeName("Attention")` ← **exception — generic name**
+
+So if the original node was `/model/layers.5/self_attn/q_proj/MatMul`, the fused node typically becomes something like `/model/layers.5/self_attn/q_proj/MatMul/MatMulAddFusion` — which still contains the original layer prefix and will match substring rules like `layers.5`.
+
+**This means name-based matching is naturally robust to pre-partitioning fusions** — no explicit annotation-copying step is needed, because the substring match against the derivative name still hits the same rules. This is actually **simpler** than the annotation-based approach.
+
+**Edge cases to handle:**
+1. **Generic names** (e.g., `"Attention"` without incorporating the original name): Some fusions create nodes with generic names that don't contain layer-identifying substrings. In general, name-based partitioning can only be used when node names contain representative strings suitable for layer matching. Two options exist:
+ - **Annotation fallback**: Use annotation-based assignment for these nodes, or update the transformer to follow the derivative naming convention.
+ - **Substring rule**: Add a substring pattern (e.g., `cuda(Attention)`) to assign all nodes whose name contains `Attention`. Note that name-based assignment does not support the '=' exact-match qualifier — all patterns are substrings.
+2. **Multiple source nodes**: When a fusion merges N nodes from potentially different layers, the resulting name typically uses one of them as the base. If the merged nodes span layer boundaries, the fused node will match whichever layer the chosen base name belongs to. This mirrors annotation-based behavior (annotation is copied from one source node).
+
+**Recommendation:** No special machinery needed for pre-partitioning transformers. The derivative naming convention already preserves matchability. Document this as a convention that transformer authors should follow: always pass an original node's name as the base to `GenerateNodeName()`.
+
+#### Post-partitioning transformers (Level 2+)
+
+Level 2+ optimizers run **after** partitioning, when EP assignments have already been made. These transformers already copy the EP assignment from original nodes:
+
+```cpp
+new_node.SetExecutionProviderType(original_node.GetExecutionProviderType());
+```
+
+**No action needed for name-based matching here.** By the time Level 2+ transformers run, partitioning is complete. The node names are irrelevant — only the EP assignment matters, and that's already propagated correctly.
+
+---
+
+## Direction 2: Minimize Allocations for Static-Shape Models
+
+### Goal
+
+For models with fully static shapes (common in transformer inference with fixed batch/sequence dimensions), ORT should minimize or eliminate runtime memory allocations. When all tensor shapes are known at `Initialize()` time, the runtime can pre-compute exact memory requirements and pre-allocate everything upfront — no arena overhead, no per-`Run()` allocation calls, deterministic memory usage.
+
+Additionally, by knowing exact memory requirements before execution begins, ORT can **minimize the chance of running into OOM** — if the total memory needed exceeds device capacity, the session can fail at `Initialize()` time with a clear error rather than crashing mid-inference with an opaque allocation failure.
+
+This brings ORT's allocation efficiency on par with specialized runtimes like llama.cpp, while retaining ORT's generality for arbitrary model architectures.
+
+### Dynamic Shapes in Transformer Models
+
+Nearly all transformer models are *exported* with dynamic batch + sequence_length — this is the default in PyTorch `torch.onnx.export`, Hugging Face Optimum, and Olive. However, for constrained deployment the picture is different:
+
+- **Fixed-shape re-export is standard for edge/embedded:** batch=1, seq_len=128 (or a few discrete lengths like 128/256/512). This is standard practice for TensorRT, CoreML, and QNN deployments.
+- **LLM serving keeps shapes dynamic** (variable prompts, KV-cache growth). But these are typically high-VRAM scenarios (A100/H100), not constrained environments.
+- **Vision transformers** (ViT, DINO, etc.) have fixed patch sequences — only batch is dynamic, and fixing batch=1 yields fully static shapes.
+
+**Implication for pre-allocation:** The target audience of `pre_allocate_execution_buffers` — embedded, edge, single-model-per-device — typically *can* use static shapes. Models are re-exported with fixed dimensions as part of the deployment pipeline. The dynamic-shape case (LLM serving with variable seq_len) lives in a different deployment tier where VRAM budget is less critical than throughput and the arena allocator handles repeated allocations efficiently.
+
+**Implication for workspace estimation:** Even with dynamic shapes, `EstimateWorkspace` (Level 1) and `DeclareWorkspaceRequirements` (Level 2) remain valuable for *budget decisions* — they can use worst-case shapes (max batch, max seq_len from model config) to determine how many nodes fit on the device. The estimate doesn't need to match runtime exactly; it needs to be conservative enough to avoid OOM.
+
+### Reference: What llama.cpp Does
+
+llama.cpp exploits that transformer inference has **fully deterministic memory usage**:
+- All weight tensors are known at load time
+- KV cache is pre-allocated for max sequence length
+- Intermediate activation buffers have shapes determined by `(batch, seq_len, hidden_dim)` — all known in advance
+- Workspace/temp buffers are known per-op and pre-planned
+
+This means the runtime computes **exactly** how much memory is needed before running — zero allocation calls during inference.
+
+### What ORT Already Does
+
+Before discussing gaps, it's important to note what ORT already provides:
+
+- **Shape inference runs during `Initialize()`** — specifically during `graph.Resolve()` after graph optimizations but before memory planning. Shape info is populated on all `NodeArg` objects.
+- **Memory pattern pre-allocation** — When `EnableMemPattern` is set and shapes are static, `TensorAllocatorWithMemPattern` computes exact allocation offsets/sizes via `OrtValuePatternPlanner`, then pre-allocates a single large buffer per device via `Reserve()` (bypasses arena, calls device allocator directly). Intermediate buffers are reused based on liveness analysis.
+- **Initializers can bypass arena** — When `use_device_allocator_for_initializers` is set, initializers are loaded via `Reserve()` (direct device allocation) during session state finalization, bypassing the arena's binning/coalescing logic. Without this option, initializers allocate through the arena like any other buffer.
+- **BFC Arena** — Best-Fit-with-Coalescing allocator provides fast (O(log n)) sub-allocation. Allocations are cheap, but it suffers from memory waste due to power-of-two growth and chunk granularity.
+
+**What's already outside the arena:**
+- Initializers → `Reserve()` (direct device allocator)
+- Memory pattern buffers (activations) → `Reserve()` (direct device allocator)
+
+**What's still arena-allocated:**
+- **Runtime temp/workspace buffers** — kernels call `GetScratchBuffer()` during `Compute()`, which allocates from the device allocator (arena by default). These are ephemeral (freed when kernel completes) and their sizes are only known at execution time.
+
+**Note on arena and temp buffer reuse:** While the BFC arena wastes memory due to chunk granularity, it does provide **automatic reuse** for temp buffers during sequential execution — subsequent kernels reuse the same arena memory for their scratch needs without actual device allocations.
+
+**CUDA mempool as an alternative:** ORT supports replacing the BFC arena with native CUDA memory pools (`cudaMallocFromPoolAsync`). This is enabled via the EP-scoped arena configuration key `arena.use_cuda_mempool` (e.g., `"ep.cudapluginexecutionprovider.arena.use_cuda_mempool" = "1"` in session config). This provides stream-aware pooling managed by the CUDA driver, with less memory waste than BFC. Since `GetScratchBuffer()` uses the same device allocator as activations (resolved via `SessionState::GetAllocator(device)` — keyed by `OrtDevice` only, not by purpose), enabling mempool automatically benefits temp buffers too. A separate temp-only allocator would require architectural changes to `AllocatorMap` (currently not feasible without significant refactoring).
+
+**Pre-allocated space for temp buffers (alternative approach):** If workspace sizes can be pre-computed (see Phase A below), temp buffers could be served from the same pre-allocated memory pattern buffer used for activations. Since workspace is live only during its kernel's execution, it participates naturally in liveness-based offset planning — no arena needed at all. This is the more principled solution: solve the workspace size problem first, then temp memory becomes part of the static plan.
+
+### IResourceAccountant Precision
+
+`IResourceAccountant::ComputeResourceCount()` already returns **exact sizes when shapes are static**:
+- Initializer sizes: exact via `GetSizeInBytesFromTensorProto()`
+- Output tensor sizes: exact when all dimensions are known via `GetSizeInBytesFromTensorTypeProto()`
+- Weight deduplication: tracked via `pending_weights_`/`committed_weights_` to avoid double-counting
+
+**Note on actual memory consumption:** While these give exact logical tensor sizes, actual device memory will be rounded up to page/alignment boundaries — either per allocation or for a single large buffer. The reported sizes are a lower bound; real usage includes alignment overhead.
+
+It applies a **1.5x safety multiplier** to account for unknowable temp/workspace allocations. This multiplier exists because temp buffer sizes are discovered only at runtime — no kernel declares its workspace needs in advance. **For static-shape models where `DeclareWorkspaceRequirements` (Phase A below) has been implemented for all relevant kernels, this multiplier becomes unnecessary and should be bypassed** — exact workspace sizes are known at planning time, eliminating the need for a safety margin.
+
+**Per-node accounting (not per-layer)**: `IResourceAccountant` tracks costs at the node/subgraph level — `nodes_costs` in `IndexedSubGraph` is for accounting after an EP claims nodes (single nodes or fused groups), not for layering. It has no concept of "layer" as defined by the layering index.
+
+**Do we need per-layer aggregation?** Likely not as a first-class feature in `IResourceAccountant`.
+
+The current budget enforcement cuts off EP placement at the individual node level — once the cumulative budget is exceeded, subsequent nodes are rejected regardless of layer boundaries. **This is intentional**: atomic rollback of already-placed nodes within a layer was explicitly rejected during the previous implementation phase due to complexity and because it would require re-running `GetCapability()` with different node sets.
+
+The layering index already controls the *order* in which nodes are presented to the EP (layer by layer), so in practice the budget cutoff tends to land near layer boundaries. If exact layer-boundary cuts are desired in the future, it would need to be a separate mechanism (e.g., pre-computing per-layer costs and making accept/reject decisions at the layer level before calling `GetCapability()`), not a change to the accountant.
+
+For debugging/UX purposes, per-layer summaries can be computed externally by summing node costs grouped by their layer annotation — no accountant changes needed.
+
+### The Remaining Gap: Runtime Temp Buffers
+
+The 1.5x multiplier and arena waste both stem from the same root cause: **kernels allocate workspace at runtime without prior declaration.**
+
+Examples of how workspace sizes are determined:
+- **cuDNN Convolution**: Workspace depends on algorithm selection (`cudnn_fe_graph->get_workspace_size()`)
+- **cuDNN RNN**: Queries cuDNN at runtime (`cudnnGetRNNTempSpaceSizes()`)
+- **Attention kernels**: Computed from runtime parameters (`batch_size * seq_len * num_heads * head_size`)
+
+The `SequentialExecutionPlan` tracks only activation tensors and initializers — workspace buffers are ephemeral and never planned statically. There is no `GetWorkspaceSize()` virtual method on `OpKernel`.
+
+To eliminate both the multiplier and arena waste for temp buffers, two problems must be solved:
+1. **Learn temp buffer sizes in advance** — e.g., a `DeclareWorkspaceRequirements(shapes)` method on kernels, queryable during `Initialize()` when shapes are static
+2. **Allocate temp buffers outside the arena** — once sizes are known, include them in the memory pattern plan alongside activations
+
+### What ORT Is Missing
+
+| Capability | llama.cpp | ORT Today | Gap |
+|-----------|-----------|-----------|-----|
+| Static memory plan | Yes — computed at load | Yes — `MemoryPatternGroup` + `Reserve()` | ✓ Activations + initializers already bypass arena |
+| Pre-allocated activation buffers | Yes — fixed slots | Yes — memory patterns with liveness reuse | ✓ Already exists for static shapes |
+| Workspace pre-computation | Yes — known per op | No — kernels discover at runtime | Need `DeclareWorkspaceRequirements()` on kernels |
+| Workspace outside arena | Yes — part of static plan | No — `GetScratchBuffer()` uses arena | Need to include workspace in memory pattern plan |
+| Zero-copy weight transfer | mmap + `cudaMemcpy` at load per layer | mmap + `cudaMemcpy` at load per partition | ✓ Same model — not a gap |
+
+**Note on weight transfer:** Both llama.cpp and ORT use the same approach: **static partitioning at load time, no dynamic weight swapping during inference.**
+
+- **llama.cpp**: User sets `-ngl N` (number of GPU layers). At load time, those N layers' weights are `cudaMemcpy`'d from mmap'd file to GPU. Remaining layers stay in host memory. No runtime swapping — this is performant because there is zero weight transfer overhead during token generation.
+- **ORT (with constrained partitioning)**: The layering index + `IResourceAccountant` determines which nodes run on GPU. At `Initialize()` time, only those nodes' initializers are copied to device from mmap'd external data. Remaining weights stay in host memory.
+
+**Best practice for constrained environments:** Model weights should be stored as **external data on disk** (not embedded in protobuf). This ensures:
+1. ORT memory-maps the file — minimal host memory overhead during loading.
+2. Only GPU-partitioned nodes' weights are copied to device — no OOM as long as partitioning respects the budget.
+3. CPU-partitioned nodes' weights remain accessible via mmap without requiring a separate host allocation.
+
+Since partitioning is decided once at `Initialize()` time and all required device weights are resident before `Run()`, there is no need for dynamic layer loading/offloading during inference.
+
+### How to Approach
+
+#### The Chicken-and-Egg Problem: Workspace Estimation vs EP Assignment
+
+**Problem statement:** To make precise memory budget decisions during `GetCapability()`, `IResourceAccountant` needs workspace sizes per node. But workspace sizes come from kernels, which don't exist until *after* EP assignment (kernels are created during `Compile()`/session state finalization — same reason `PrePack` happens late). At decision time, you can't ask a kernel that doesn't exist yet.
+
+**Why this matters:** The goal is not to fail gracefully — it's to **avoid failure entirely**. Today, models either OOM on device or trigger heavy VRAM thrashing on Windows. The partitioning must be conservative enough to prevent this, while accurate enough to maximize GPU utilization.
+
+**Solution: Two-level estimation with post-assignment verification**
+
+**Level 1 — Static workspace estimation function (at partitioning time):**
+
+When a kernel is registered (via `KernelRegistry_AddKernel`), optionally provide a **static estimation function** — a class-level function that takes node info and returns a conservative workspace estimate without needing a kernel instance:
+
+```c
+// Registered alongside the kernel definition in KernelRegistry_AddKernel:
+typedef OrtStatus*(ORT_API_CALL* OrtKernelWorkspaceEstimateFunc)(
+ _In_ const OrtEpApi* api, // for querying node attributes/shapes/device props
+ _In_ const OrtNode* node, // the specific node being evaluated
+ _In_ const OrtEp* ep, // EP instance (for device properties like SM count)
+ _Out_ size_t* estimated_workspace_bytes);
+```
+
+The function uses `api->Node_GetAttribute*()` and `api->Node_GetInputShape()` to access the node's attributes and input shapes, and `api->Ep_GetDeviceProperty()` for GPU hardware properties — everything needed to compute workspace without a kernel instance.
+
+**Can the estimate be precise (not just conservative)?**
+
+Depends on the kernel:
+
+| Kernel | Workspace depends on | Available at GetCapability()? | Precise estimate? |
+|--------|---------------------|-------------------------------|-------------------|
+| **Attention (Flash)** | shapes + `num_heads` attr + `device_prop.multiProcessorCount` | ✓ All available (EP has device_prop) | **YES — exact** |
+| **Conv (cuDNN)** | cuDNN `build_plans(handle)` with tensor shapes + conv params | ✓ EP has handle; shapes/attrs available from node | **YES — exact** (with `HEUR_MODE_A`) |
+| **GEMM/MatMul** | No workspace | N/A | N/A (returns 0) |
+
+For **attention**, workspace is determined by `get_num_splits_and_buffer_sizes()` which is pure arithmetic given `(batch, seq, heads, head_size, multiProcessorCount)`. The EP already has `multiProcessorCount` from `cudaGetDeviceProperties()` which runs during EP construction (before `GetCapability()`). So the estimate can be **exact**.
+
+For **cuDNN-based ops** (Conv), the workspace depends on which algorithm cuDNN selects via `build_plans(handle)`. However, a cuDNN handle is just a lightweight context object (`cudnnCreate` + `cudnnSetStream`) — the EP already owns one from construction time. With static shapes, all inputs to `build_plans()` are known: tensor dimensions, conv parameters (from node attributes), and the handle. The `CUDNN_HEUR_MODE_A` (fast heuristic) used by ORT is essentially a lookup + arithmetic — not actual GPU profiling. So the estimation function **can call `build_plans()` and get the exact workspace size**. This makes Conv estimates **precise too**.
+
+The reason `build_plans()` currently runs during first `Compute()` is historical: ORT didn't have a pre-execution workspace declaration phase, and shapes weren't known until runtime. With static shapes and the estimation function pattern, this computation can move earlier.
+
+The estimation function accesses the handle by casting `OrtEp*` to the EP's concrete type (safe because the function is EP-specific code registered by that EP):
+
+```cpp
+auto* cuda_ep = static_cast(ep); // plugin path
+cudnnHandle_t handle = cuda_ep->GetCudnnHandle();
+```
+
+**Can it be the same function as DeclareWorkspaceRequirements?**
+
+Not the same function pointer (different signatures — one has a kernel instance, one doesn't). But the **core computation logic can be a shared static helper** called from both:
+
+```cpp
+// Shared static helper (no instance needed):
+static size_t ComputeAttentionWorkspace(int batch, int seq, int heads,
+ int head_size, int num_SMs) {
+ auto [num_splits, slse_size, o_size] = flash::get_num_splits_and_buffer_sizes(
+ batch, seq, seq, heads, head_size, num_SMs);
+ return flash::get_softmax_lse_size(seq, batch, heads) + slse_size + o_size;
+}
+
+// Estimation function (no kernel instance — called during GetCapability):
+OrtStatus* EstimateAttentionWorkspace(const OrtEpApi* api, const OrtNode* node,
+ const OrtEp* ep, size_t* out) {
+ const int64_t* shape; size_t rank;
+ api->Node_GetInputShape(node, 0, &shape, &rank);
+ int64_t num_heads;
+ api->Node_GetAttributeInt(node, "num_heads", &num_heads);
+
+ // EP-specific: cast to concrete type to access device properties
+ auto* cuda_ep = static_cast(ep);
+ int num_SMs = cuda_ep->GetDeviceProp().multiProcessorCount;
+
+ *out = ComputeAttentionWorkspace(shape[0], shape[1], num_heads, shape[3], num_SMs);
+ return nullptr;
+}
+
+// DeclareWorkspaceRequirements (has kernel instance — called during FinalizeSessionState):
+Status Attention::DeclareWorkspaceRequirements(span shapes,
+ InlinedVector& reqs) {
+ int num_SMs = GetDeviceProp().multiProcessorCount;
+ size_t total = ComputeAttentionWorkspace(
+ shapes[0][0], shapes[0][1], num_heads_, head_size_, num_SMs);
+ reqs.push_back({total, kSlotFlashWorkspace});
+ return Status::OK();
+}
+```
+
+Both call the same `ComputeAttentionWorkspace()` — producing **identical results**. The estimation function gets device properties from the EP; the kernel method gets them from its stored EP reference. Same data, same computation, same answer.
+
+**For cuDNN-based ops**, the estimation function can also be precise — it calls `build_plans()` using the EP's handle and the node's shapes/attributes. Level 2 re-check serves as a diagnostic safety net — if the post-fusion total exceeds the budget, a warning is logged indicating that the Level 1 estimate was too optimistic (e.g., cuDNN returning different workspace sizes due to driver version differences or fusion changing the algorithm selection).
+
+**KernelCreateInfo and registration macros:**
+
+Today `KernelCreateInfo` contains `{kernel_def, kernel_create_func, status}`. To add the estimation function:
+
+```cpp
+struct KernelCreateInfo {
+ std::unique_ptr kernel_def;
+ KernelCreateFn kernel_create_func;
+ OrtKernelWorkspaceEstimateFunc workspace_estimate_func; // NEW — may be nullptr
+ Status status;
+};
+```
+
+The existing `ONNX_OPERATOR_TYPED_KERNEL_EX` macro doesn't need changes — it produces `KernelCreateInfo` via `BuildKernelCreateInfo<>()`. A new macro variant adds the estimation function for kernels that implement it:
+
+```cpp
+// New macro: ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE
+// Same as ONNX_OPERATOR_TYPED_KERNEL_EX but also registers a workspace estimation function.
+#define ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE( \
+ name, domain, ver, type, provider, builder, estimate_fn, ...) \
+ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \
+ template <> \
+ KernelCreateInfo \
+ BuildKernelCreateInfo() { \
+ return KernelCreateInfo( \
+ builder.SetName(#name) \
+ .SetDomain(domain) \
+ .SinceVersion(ver) \
+ .Provider(provider) \
+ .Build(), \
+ static_cast( \
+ [](FuncManager&, const OpKernelInfo& info, \
+ std::unique_ptr& out) -> Status { \
+ out = std::make_unique<__VA_ARGS__>(info); \
+ return Status::OK(); \
+ }), \
+ estimate_fn); \
+ }
+```
+
+This requires a new `KernelCreateInfo` constructor overload:
+
+```cpp
+struct KernelCreateInfo {
+ std::unique_ptr kernel_def;
+ KernelCreateFn kernel_create_func;
+ OrtKernelWorkspaceEstimateFunc workspace_estimate_func; // NEW — may be nullptr
+ Status status;
+
+ // Existing constructor (unchanged — sets workspace_estimate_func to nullptr):
+ KernelCreateInfo(std::unique_ptr definition,
+ KernelCreateFn create_func)
+ : kernel_def(std::move(definition)),
+ kernel_create_func(create_func),
+ workspace_estimate_func(nullptr) {}
+
+ // New constructor with estimation function:
+ KernelCreateInfo(std::unique_ptr definition,
+ KernelCreateFn create_func,
+ OrtKernelWorkspaceEstimateFunc estimate_func)
+ : kernel_def(std::move(definition)),
+ kernel_create_func(create_func),
+ workspace_estimate_func(estimate_func) {}
+
+ KernelCreateInfo(KernelCreateInfo&& other) noexcept
+ : kernel_def(std::move(other.kernel_def)),
+ kernel_create_func(std::move(other.kernel_create_func)),
+ workspace_estimate_func(other.workspace_estimate_func) {}
+
+ KernelCreateInfo() = default;
+};
+```
+
+**Usage example** (registering a CUDA Attention kernel with estimation):
+
+```cpp
+// In cuda_contrib_kernels.cc:
+ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE(
+ Attention, // name
+ kMSDomain, // domain
+ 1, // ver
+ float, // type
+ kCudaExecutionProvider, // provider
+ KernelDefBuilder() // builder
+ .TypeConstraint("T", DataTypeImpl::GetTensorType())
+ .InputMemoryType(OrtMemTypeCPUInput, {3, 4}),
+ &cuda::EstimateAttentionWorkspace, // estimate_fn ← NEW argument
+ cuda::Attention); // kernel class (__VA_ARGS__)
+```
+
+Kernels without estimation continue using `ONNX_OPERATOR_TYPED_KERNEL_EX` unchanged — their `workspace_estimate_func` is `nullptr`, and the budget logic applies the 1.5x multiplier as today. The migration is opt-in, kernel by kernel.
+
+**Integration with GetCapability and the resource budget:**
+
+The estimation function is called during budget enforcement — by the EP directly (in-tree) or by the host bridge (plugin). The result is combined with the base cost from `IResourceAccountant`.
+
+**Multiplier handling — non-member helper approach:**
+
+`ComputeResourceCount()` currently applies a 1.5x multiplier to approximate workspace for kernels without estimation functions. With precise workspace estimates available, the multiplier must be skipped. Rather than changing `ComputeResourceCount()`'s signature, we move the multiplier out and into a non-member helper that encapsulates the budget decision:
+
+```cpp
+// Non-member helper (e.g., in resource_accountant_helpers.h):
+// Called by both in-tree GetCapability and the plugin host bridge.
+ResourceCount ComputeNodeCostForBudget(
+ IResourceAccountant& accountant,
+ const Node& node,
+ std::optional workspace_estimate) {
+ // ComputeResourceCount returns base cost: outputs + initializers (dedup'd)
+ // NO multiplier — multiplier is now applied here when needed
+ ResourceCount base_cost = accountant.ComputeResourceCount(node);
+
+ if (workspace_estimate.has_value()) {
+ // Precise workspace known — add it directly, no multiplier
+ return AddResourceCounts(base_cost, *workspace_estimate);
+ }
+ // No workspace estimate — apply heuristic multiplier (1.5x)
+ return ApplyWorkspaceHeuristic(base_cost);
+}
+
+// Multiplier as an explicit utility:
+ResourceCount ApplyWorkspaceHeuristic(ResourceCount base) {
+ size_t bytes = std::get<0>(base);
+ return ResourceCount{static_cast(bytes * 1.5)};
+}
+```
+
+**Design rationale:**
+- `ComputeResourceCount()` signature is **unchanged** — it returns the raw base cost (outputs + initializers with dedup). The 1.5x multiplier moves out of the accountant into this helper.
+- The helper is the **single decision point** for both code paths (in-tree and plugin host bridge). No duplicated logic.
+- `ApplyWorkspaceHeuristic()` makes the multiplier explicit and testable. It can be adjusted (e.g., per-EP or per-op-type) without changing any interface.
+- The helper integrates naturally with the existing budget check pattern:
+
+```cpp
+// Usage in GetCapability (both paths):
+auto total_cost = ComputeNodeCostForBudget(*accountant, node, workspace_estimate);
+auto would_be_consumed = AddResourceCounts(consumed, total_cost);
+
+if (has_budget && ResourceCountExceeds(would_be_consumed, budget)) {
+ accountant->SetStopAssignment();
+ break;
+}
+
+consumed = would_be_consumed;
+sub_graph->SetAccountant(accountant);
+sub_graph->AppendNodeCost(total_cost);
+```
+
+**Why this is clean with committed/uncommitted weights:**
+
+- **Weight dedup is unaffected.** `ComputeResourceCount()` handles pending/committed weight tracking internally. The workspace estimate is purely additive — it's not a weight, so it doesn't participate in dedup.
+- **`AppendNodeCost()` stores the combined total.** When `AccountForNode()` runs later (during `TryAssignNodes()`), it adds the stored cost (base + workspace) to `consumed_amount` and commits the weights. The workspace portion just inflates the per-node cost.
+- **`CommitWeightsForNode()` only touches initializers.** Workspace is a separate addend, not tracked in weight sets.
+- **`ResetForNewPass()` is fine.** The workspace estimate is stateless — recomputed fresh from node shapes each call, no state to carry across passes.
+
+If no estimation function is registered for a kernel, the helper applies the 1.5x multiplier as today (unchanged behavior).
+
+**Example estimation function (CUDA Conv):**
+
+```cpp
+OrtStatus* EstimateConvWorkspace(const OrtEpApi* api, const OrtNode* node,
+ const OrtEp* ep, size_t* out) {
+ // Get input shape (X: NCHW)
+ const int64_t* x_shape = nullptr;
+ size_t x_rank = 0;
+ OrtStatus* status = api->Node_GetInputShape(node, 0, &x_shape, &x_rank);
+ if (status) return status;
+ if (x_rank < 3) {
+ return api->CreateStatus(ORT_INVALID_ARGUMENT,
+ "Conv: input X must be at least rank 3");
+ }
+
+ // Get weight shape (W: [M, C/group, kH, kW, ...])
+ const int64_t* w_shape = nullptr;
+ size_t w_rank = 0;
+ status = api->Node_GetInputShape(node, 1, &w_shape, &w_rank);
+ if (status) return status;
+ if (w_rank != x_rank) {
+ return api->CreateStatus(ORT_INVALID_ARGUMENT,
+ "Conv: weight rank must match input rank");
+ }
+
+ // Get conv attributes (all optional — defaults to empty/zeros per ONNX spec)
+ const int64_t* pads = nullptr;
+ size_t pads_count = 0;
+ status = api->Node_GetAttributeInts(node, "pads", &pads, &pads_count);
+ if (status) return status; // distinguishes "not present" (OK + nullptr) from error
+
+ const int64_t* strides = nullptr;
+ size_t strides_count = 0;
+ status = api->Node_GetAttributeInts(node, "strides", &strides, &strides_count);
+ if (status) return status;
+
+ const int64_t* dilations = nullptr;
+ size_t dilations_count = 0;
+ status = api->Node_GetAttributeInts(node, "dilations", &dilations, &dilations_count);
+ if (status) return status;
+
+ // EP-specific: cast to concrete type to access cuDNN handle
+ auto* cuda_ep = static_cast(ep);
+ cudnnHandle_t handle = cuda_ep->GetCudnnHandle();
+ if (!handle) {
+ return api->CreateStatus(ORT_RUNTIME_EXCEPTION,
+ "Conv: cuDNN handle not available on EP");
+ }
+
+ // Build cuDNN frontend graph and query workspace
+ // (same logic as CreateCudnnFeExecutionPlan but without storing state)
+ auto graph = BuildConvFrontendGraph(x_shape, x_rank, w_shape, w_rank,
+ pads, pads_count, strides, strides_count,
+ dilations, dilations_count);
+ if (!graph) {
+ return api->CreateStatus(ORT_RUNTIME_EXCEPTION,
+ "Conv: failed to build cuDNN frontend graph");
+ }
+
+ auto plan_status = graph->build_plans(handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_ONLY);
+ if (plan_status) {
+ return api->CreateStatus(ORT_RUNTIME_EXCEPTION,
+ plan_status.get_message());
+ }
+
+ *out = graph->get_workspace_size();
+ return nullptr; // success
+}
+```
+
+This produces the **exact same result** as the kernel would compute during `DeclareWorkspaceRequirements` — same handle, same shapes, same algorithm selection. The shared logic is the cuDNN frontend graph construction.
+
+**Why no public API for device properties or library handles:** The estimation function is registered BY the EP for ITS kernels — it's EP-specific code running in the EP's own DLL. It can safely cast `OrtEp*` to its concrete type (e.g., `CudaEp*`) to access device_prop, cuDNN handles, etc. This is the same pattern kernels already use: `static_cast(info.GetExecutionProvider())->GetDeviceProp()`. No generic `Ep_GetCudnnHandle` or `Ep_GetDeviceIntProperty` API is needed — that would be CUDA-specific pollution of the universal EP interface.
+
+**Level 2 — Post-fusion budget re-check (before InsertCast and MemcpyTransformer):**
+
+The `TransformGraph()` pipeline has a natural insertion point after EP-specific optimizers but before the transformers that bake in EP boundaries:
+
+```
+L1 optimizers → Partition (GetCapability) → L2/L3 EP-specific optimizers → [HERE] → InsertCastTransformer → L4 → MemcpyTransformer
+```
+
+At `[HERE]`:
+- Nodes are assigned to EPs ✓
+- EP-specific fusions (ConvRelu, FusedMatMul, etc.) have already been applied ✓
+- The graph reflects the *actual* ops that will become kernels ✓
+- Cast nodes have NOT been inserted yet ✓ (no fp16↔fp32 casts at boundaries)
+- Memcpy nodes have NOT been inserted yet ✓ (boundaries can still move)
+- Kernels do NOT exist yet ✗ (cannot call `DeclareWorkspaceRequirements`)
+
+**Why before InsertCastTransformer:** The InsertCastTransformer inserts fp16↔fp32 Cast nodes at EP boundaries where input/output types don't match. If we offload nodes *after* Cast insertion, we'd leave orphaned Cast nodes at the old boundary and need new ones at the new boundary — similar to the MemcpyTransformer problem. By running before both, any Cast and Memcpy nodes are inserted at the final (post-offload) boundaries.
+
+Since kernels don't exist, we call the **same `EstimateWorkspace()` functions** from Level 1 — but now on the post-fusion graph. This eliminates the only meaningful gap between Level 1 and Level 2: fused ops that didn't exist at `GetCapability()` time now have their own estimation functions registered alongside their kernel definitions.
+
+**Algorithm:**
+
+1. For each node assigned to the constrained EP, look up its `OrtKernelWorkspaceEstimateFunc` from the kernel registry (same registry that will later be used to create the kernel).
+2. Call the estimation function on the (possibly fused) node → get workspace.
+3. Re-run the budget check: `base_cost + workspace` for all assigned nodes.
+4. If total ≤ budget → proceed to InsertCastTransformer/MemcpyTransformer normally.
+5. If total > budget → **log a warning** and proceed. Do NOT attempt to offload nodes.
+
+**Why warn-only (no runtime offload):**
+
+The earlier design attempted tail-node offloading at this stage — walking backward through GPU-assigned nodes and reassigning them to CPU. In practice, this is problematic:
+
+- **For bf16/fp16 models** (the dominant constrained-VRAM use case): CPU EP lacks kernels for most bf16/fp16 compute ops (MatMul, Attention, LayerNorm). The offload loop would hit a non-offloadable node almost immediately and accomplish nothing.
+- **For fp32 CNN models** (where CPU *could* handle offloaded ops): The performance cost of GPU→CPU→GPU data transfers typically outweighs the memory benefit of offloading a few tail nodes.
+- **Complexity vs value**: Offload logic (type checking, contiguous-tail constraint, boundary correctness) adds significant code for a feature that rarely fires and rarely helps.
+
+The correct fix for a Level 2 budget overrun is to **improve Level 1 accuracy** — make the estimation functions precise enough that post-fusion re-check merely confirms (not corrects) the budget. Level 2 serves as a **diagnostic safety net**: if the warning fires, it indicates the Level 1 estimate was too optimistic, and the estimation function for the offending kernel(s) should be improved.
+
+```cpp
+// Pseudo-code in TransformGraph, after L2/L3, before InsertCastTransformer:
+if (level2_total > budget) {
+ size_t overrun = level2_total - budget;
+ LOGS(logger, WARNING)
+ << "Post-fusion budget re-check: EP '" << ep_type
+ << "' exceeds memory budget by " << overrun << " bytes. "
+ << "Level 1 estimation was too optimistic. "
+ << "Consider improving workspace estimation for fused ops. "
+ << "Proceeding — runtime OOM may occur.";
+}
+```
+
+This keeps the pipeline simple: Level 2 is purely observational (re-check + warn), not interventional. If the warning fires in testing, the developer improves the relevant `EstimateWorkspace()` function. In production, the budget was validated at Level 1 and Level 2 divergence should be rare.
+
+**Why Level 2 exists separately from Level 1:**
+
+Level 1 (during `GetCapability()`) operates on the **pre-fusion** graph. It estimates workspace for the original unfused nodes (Conv, Relu separately). Level 2 operates on the **post-fusion** graph (ConvRelu as a single node). The two can diverge when:
+- A fused op's workspace differs from the sum of its parts (common — fusion often reduces workspace)
+- Level 2/3 optimizers add or remove nodes (e.g., constant folding eliminates a node entirely)
+
+For most LLM models (which are repetitive transformer blocks with minimal fusion opportunity), Level 1 and Level 2 will agree. Level 2 matters more for CNN models with heavy fusion (Conv+BN+Relu patterns).
+
+**When static shapes are unavailable:**
+
+If the model has dynamic shapes, the estimation function cannot compute workspace (shapes are unknown at `GetCapability()` time). In this case:
+- The estimation function returns a failure status or a sentinel value indicating "unknown."
+- `ComputeNodeCostForBudget()` falls back to the 1.5x heuristic multiplier on base cost.
+- The user may need to **tune the memory budget by trial and error** — setting a conservative budget and adjusting based on observed OOM or under-utilization. This is analogous to llama.cpp's `-ngl` flag: the user picks a layer count and adjusts based on whether it fits.
+- A future extension could accept user-provided "typical shape hints" (e.g., `max_batch=4, max_seq=2048`) to enable estimation even for dynamic-shape models, but this is out of scope for the initial design.
+
+**Plugin C ABI for Level 1:**
+
+```c
+// Workspace estimation function type (no kernel instance needed):
+typedef OrtStatus*(ORT_API_CALL* OrtKernelWorkspaceEstimateFunc)(
+ _In_ const OrtEpApi* api,
+ _In_ const OrtNode* node,
+ _In_ const OrtEp* ep,
+ _Out_ size_t* estimated_workspace_bytes);
+
+// Extension to KernelRegistry_AddKernel in OrtEpApi:
+ORT_API2_STATUS(KernelRegistry_AddKernelV2,
+ _Inout_ OrtKernelRegistry* registry,
+ _In_ const OrtKernelDef* kernel_def,
+ _In_ OrtKernelCreateFunc create_func,
+ _In_opt_ void* create_func_state,
+ _In_opt_ OrtKernelWorkspaceEstimateFunc workspace_estimate_func); // NEW — may be NULL
+```
+
+**Required OrtEpApi additions for the estimation function to query node info:**
+
+```c
+// Query input shape from a node's NodeArg (populated by shape inference):
+ORT_API2_STATUS(Node_GetInputShape,
+ _In_ const OrtNode* node,
+ _In_ size_t input_index,
+ _Outptr_result_maybenull_ const int64_t** shape, // NULL if dynamic
+ _Out_ size_t* rank);
+
+// Query integer attribute from a node:
+ORT_API2_STATUS(Node_GetAttributeInt,
+ _In_ const OrtNode* node,
+ _In_ const char* attr_name,
+ _Out_ int64_t* value);
+
+// Query integer array attribute:
+ORT_API2_STATUS(Node_GetAttributeInts,
+ _In_ const OrtNode* node,
+ _In_ const char* attr_name,
+ _Outptr_ const int64_t** values,
+ _Out_ size_t* count);
+```
+
+**Device properties and library handles** (cuDNN, cuBLAS, etc.) are accessed by casting `OrtEp*` to the EP's concrete type inside the estimation function — no public API needed (see examples above).
+
+---
+
+### Implementation Across Kernel Types and GetCapability Paths
+
+ORT has **three distinct kernel authoring scenarios** and **two GetCapability architectures**. The workspace estimation and declaration APIs must work correctly in each combination.
+
+#### Three Kernel Types
+
+| Type | Description | Registration mechanism | Examples |
+|------|-------------|----------------------|----------|
+| **In-tree** | C++ kernels compiled into the ORT binary | `BuildKernelCreateInfo<>()` via macros (`ONNX_OPERATOR_TYPED_KERNEL_EX`) | All CPU kernels, legacy CUDA EP kernels |
+| **Plugin (shared source)** | Same C++ source as in-tree, compiled into EP plugin DLL, uses adapter layer | `KernelRegistry_AddKernel` C ABI, with `CudaKernelAdapter` bridging | CUDA EP plugin kernels |
+| **Pure ABI** | Kernels written directly against the C ABI (`OrtKernelImpl`) | `KernelRegistry_AddKernel` C ABI, `OrtKernelImpl` function pointers | Third-party EP plugin kernels |
+
+#### Two GetCapability Architectures
+
+| Architecture | Resource budgeting location | Workspace estimation call site |
+|-------------|---------------------------|-------------------------------|
+| **In-tree** | Inside `CUDAExecutionProvider::GetCapability()` — EP owns the loop, calls `resource_accountant->ComputeResourceCount(node)`, makes accept/reject decisions | EP calls estimation function directly in its loop |
+| **Plugin bridge** | In `PluginExecutionProvider::GetCapability()` (the C++ host wrapper) — plugin EP only proposes candidates, host does budgeting after plugin returns | Host calls estimation function during budget enforcement |
+
+**Critical difference:** In the plugin path, the plugin's `GetCapabilityImpl` returns a list of "I support these nodes" without resource checks. The **host bridge** (`ep_plugin_provider_interfaces.cc`) then iterates those nodes in topological order, calls `resource_accountant->ComputeResourceCount(node)` for each, and enforces the budget — halting assignment when the threshold is exceeded. The plugin never sees the accountant directly.
+
+#### Implementation: `OrtKernelWorkspaceEstimateFunc` (Level 1 — at partitioning time)
+
+**In-tree path:**
+
+```cpp
+// In CUDAExecutionProvider::GetCapability() loop (in-tree only):
+const KernelCreateInfo* kci = kernel_lookup.LookUpKernel(node);
+std::optional workspace_estimate;
+if (kci && kci->workspace_estimate_func) {
+ size_t ws = 0;
+ // In-tree: pass IExecutionProvider* — func casts to CUDAExecutionProvider*
+ kci->workspace_estimate_func(this, node, &ws);
+ workspace_estimate = ws;
+}
+// Use non-member helper for budget decision:
+auto total_cost = ComputeNodeCostForBudget(*resource_accountant, node, workspace_estimate);
+// ... budget check with total_cost ...
+```
+
+The estimation function for in-tree kernels is a static member function. It casts the EP pointer to `CUDAExecutionProvider*` to access `GetDeviceProp()` and `PerThreadDefaultCudnnHandle()` — exactly the same pattern kernels already use.
+
+**Plugin bridge path:**
+
+```cpp
+// In PluginExecutionProvider::GetCapability() host-side budget loop
+// (ep_plugin_provider_interfaces.cc):
+for (const auto& node_grouping : api_graph_support_info.node_groupings) {
+ const Node& internal_node = node_grouping.nodes[0]->GetInternalNode();
+
+ // Look up workspace estimate function from kernel registry
+ const KernelCreateInfo* kci = kernel_lookup.LookUpKernel(internal_node);
+ std::optional workspace_estimate;
+ if (kci && kci->workspace_estimate_func) {
+ size_t ws = 0;
+ // Plugin path: registered via KernelRegistry_AddKernelV2
+ // Function casts OrtEp* to its concrete type internally
+ OrtStatus* est_status = kci->workspace_estimate_func(
+ &ep_api_, ep_node->ToExternal(), ort_ep_.get(), &ws);
+ if (est_status) { OrtApis::ReleaseStatus(est_status); }
+ else { workspace_estimate = ws; }
+ }
+
+ // Same non-member helper as in-tree:
+ auto total_cost = ComputeNodeCostForBudget(*resource_accountant, internal_node,
+ workspace_estimate);
+ // ... budget check with total_cost ...
+}
+```
+
+**Pure ABI path (third-party EP):**
+
+Same as plugin bridge — the estimation function is registered via `KernelRegistry_AddKernelV2` and called by the host during budget enforcement. The kernel author provides the function pointer at registration time:
+
+```c
+// Third-party EP kernel registration:
+OrtStatus* MyConvEstimate(const OrtEpApi* api, const OrtNode* node,
+ const OrtEp* ep, size_t* out) {
+ // Cast to concrete EP type to access device-specific state:
+ auto* my_ep = static_cast(ep);
+ // ... compute workspace from node shapes + my_ep->device_properties ...
+}
+
+// During EP's RegisterKernels callback:
+ep_api->KernelRegistry_AddKernelV2(registry, conv_kernel_def, CreateConvKernel,
+ /*state=*/nullptr, &MyConvEstimate);
+```
+
+#### Implementation: `DeclareWorkspaceRequirements` (Level 2 — after kernel creation)
+
+**In-tree path:**
+
+Straightforward — add a virtual method to `OpKernel`:
+
+```cpp
+// In include/onnxruntime/core/framework/op_kernel.h:
+[[nodiscard]] virtual Status DeclareWorkspaceRequirements(
+ gsl::span input_shapes,
+ InlinedVector& requirements) const {
+ return Status::OK(); // Default: no workspace declared
+}
+```
+
+In-tree kernels override this just like they override `PrePack()`. Called during `FinalizeSessionState()` after kernel instances exist.
+
+**Plugin (shared source) path:**
+
+The `CudaKernelAdapter` already bridges virtual calls to the underlying kernel class. The adapter forwards `DeclareWorkspaceRequirements` to the underlying kernel's implementation:
+
+```cpp
+// In cuda_kernel_adapter.h — adapter already forwards PrePack similarly:
+Status DeclareWorkspaceRequirements(
+ gsl::span input_shapes,
+ InlinedVector& requirements) const override {
+ // The underlying kernel class (compiled in the plugin DLL) implements this directly.
+ // CudaKernelAdapter inherits from T, so T::DeclareWorkspaceRequirements is accessible.
+ return T::DeclareWorkspaceRequirements(input_shapes, requirements);
+}
+```
+
+Since plugin shared-source kernels ARE the same C++ class (just compiled in a different DLL), they implement `DeclareWorkspaceRequirements` as a regular virtual override — no ABI translation needed.
+
+**Pure ABI path (third-party EP):**
+
+Add an optional function pointer to `OrtKernelImpl`:
+
+```c
+// In onnxruntime_ep_c_api.h, extend OrtKernelImpl:
+struct OrtKernelImpl {
+ // ... existing fields (Compute, Release, PrePackWeight, ...) ...
+
+ // NEW — optional workspace declaration (ORT >= 1.XX):
+ ORT_API2_STATUS(DeclareWorkspaceRequirements,
+ _In_ OrtKernelImpl* this_ptr,
+ _In_ const int64_t* const* input_shapes, // array of shape arrays
+ _In_ const size_t* input_ranks, // rank of each input
+ _In_ size_t num_inputs,
+ _Out_ OrtWorkspaceRequirement** requirements, // allocated by kernel
+ _Out_ size_t* num_requirements);
+};
+```
+
+The `PluginEpOpKernel` adapter (in `ep_kernel_registration.cc`) bridges this to the virtual call:
+
+```cpp
+// In PluginEpOpKernel:
+Status DeclareWorkspaceRequirements(
+ gsl::span input_shapes,
+ InlinedVector& requirements) const override {
+ // Version guard (same pattern as PrePack):
+ if (kernel_impl_->ort_version_supported < XX ||
+ kernel_impl_->DeclareWorkspaceRequirements == nullptr) {
+ return Status::OK(); // No declaration — fall back to arena
+ }
+
+ // Convert TensorShape spans to C arrays
+ InlinedVector shape_ptrs;
+ InlinedVector ranks;
+ for (const auto& shape : input_shapes) {
+ shape_ptrs.push_back(shape.GetDims().data());
+ ranks.push_back(shape.NumDimensions());
+ }
+
+ OrtWorkspaceRequirement* reqs = nullptr;
+ size_t num_reqs = 0;
+ ORT_RETURN_IF_ERROR(ToStatusAndRelease(
+ kernel_impl_->DeclareWorkspaceRequirements(
+ kernel_impl_, shape_ptrs.data(), ranks.data(),
+ shape_ptrs.size(), &reqs, &num_reqs)));
+
+ // Convert C results to C++ vector
+ for (size_t i = 0; i < num_reqs; ++i) {
+ requirements.push_back({reqs[i].size_bytes, reqs[i].slot_id});
+ }
+ // Free C allocation (kernel used OrtAllocator or static buffer)
+ return Status::OK();
+}
+```
+
+#### Summary: Where Each Piece Lives
+
+| Component | In-tree | Plugin (shared source) | Pure ABI |
+|-----------|---------|----------------------|----------|
+| **Workspace estimation func** | Static member on kernel class; stored in `KernelCreateInfo::workspace_estimate_func` | Same static function, registered via `KernelRegistry_AddKernelV2` | C function pointer, registered via `KernelRegistry_AddKernelV2` |
+| **Who calls estimation** | EP's `GetCapability()` loop via `ComputeNodeCostForBudget()` helper | Host bridge via same `ComputeNodeCostForBudget()` helper | Host bridge (same) |
+| **DeclareWorkspaceRequirements** | Virtual override on `OpKernel` | Virtual override (same C++ class in plugin DLL) | `OrtKernelImpl::DeclareWorkspaceRequirements` function pointer → `PluginEpOpKernel` adapter |
+| **Who calls DeclareWorkspace** | `FinalizeSessionState()` | `FinalizeSessionState()` (same) | `FinalizeSessionState()` via adapter |
+| **Device property access** | `static_cast(ep)->GetDeviceProp()` | `static_cast(ep)->GetDeviceProp()` | `static_cast(ep)->GetDeviceProps()` |
+| **cuDNN handle access** | `static_cast(ep)->PerThreadDefaultCudnnHandle()` | `static_cast(ep)->GetCudnnHandle()` | N/A (EP-specific) |
+
+#### Key Design Principle
+
+The **estimation function** signature differs between in-tree and plugin paths:
+
+- **In-tree:** `static size_t EstimateWorkspace(const IExecutionProvider* ep, const Node& node)` — C++ types, direct EP access
+- **Plugin/ABI:** `OrtStatus* EstimateWorkspace(const OrtEpApi*, const OrtNode*, const OrtEp*, size_t*)` — C ABI, opaque types
+
+But both compute the same result. For shared-source kernels (compiled both in-tree and as plugin), a single static helper function (e.g., `ComputeAttentionWorkspace()`) is called from both wrappers — ensuring the estimate is identical regardless of build configuration.
+
+---
+
+#### Phase A: Workspace Pre-declaration (`DeclareWorkspaceRequirements`)
+
+The core missing piece. Today no kernel declares its temp buffer needs before `Compute()`. To close this gap, introduce a method on `OpKernel` that returns workspace descriptors — each with a size and a key for later retrieval.
+
+**Analogy to PrePack:** This mechanism is similar to the existing `PrePack()` pattern — both are called once during session state finalization (not during `Run()`), both store results that are reused across all subsequent runs. `PrePack()` pre-processes weight data; `DeclareWorkspaceRequirements()` pre-computes workspace layout.
+
+**Interface:**
+
+```cpp
+struct WorkspaceRequirement {
+ size_t size_bytes; // Size of this workspace buffer
+ int slot_id; // Kernel-defined slot identifier (0, 1, 2, ...)
+ // Unique within a single kernel instance
+};
+
+// Optional override on OpKernel (called during FinalizeSessionState):
+virtual Status DeclareWorkspaceRequirements(
+ gsl::span input_shapes,
+ InlinedVector& requirements) const {
+ return Status::OK(); // Default: no declaration (fall back to arena)
+}
+```
+
+A kernel can declare multiple workspace slots (e.g., attention needs separate Q transpose buffer, output buffer, seqlens buffer). The `slot_id` is defined by the kernel author and is stable across calls — it identifies *which* buffer within that kernel's logic.
+
+**Key constraint:** Multiple nodes may use the same kernel class. Each node instance gets its own set of workspace slots. The unique key for retrieval is `(NodeIndex, slot_id)` — the framework supplies `NodeIndex`, the kernel supplies `slot_id`.
+
+**Memory reuse via liveness-based offset planning:**
+
+Workspace buffers are live only during their kernel's execution step. This means workspaces from non-overlapping steps can share the same physical memory — exactly the same liveness analysis already used for activation tensors. The offset planner assigns overlapping offsets to workspaces whose liveness intervals don't intersect:
+
+```
+Step 0: Node A workspace (slots 0,1) → offsets [0, 4096]
+Step 1: Node B workspace (slot 0) → offset [0] ← reuses Node A's memory
+Step 2: Node C workspace (slots 0,1) → offsets [0, 8192] ← reuses again
+```
+
+Peak workspace memory = max over all steps of (sum of workspace slots for that step), not the sum of all workspaces across all nodes.
+
+**Concurrency model (multiple concurrent `Run()` calls):**
+
+The existing memory pattern system already handles this correctly:
+- The **pattern** (offset/size map) is computed once during `Initialize()` and cached in `SessionState` — shared, read-only.
+- The **actual buffer** is allocated per-`Run()` by each `ExecutionFrame` using the pattern as a blueprint.
+- Each concurrent `Run()` gets its own `ExecutionFrame` with its own workspace buffer — no sharing, no synchronization needed.
+
+Workspace pre-allocation follows the same model:
+- `DeclareWorkspaceRequirements()` is called during `FinalizeSessionState()` → produces a workspace offset plan (shared, immutable).
+- Each `Run()` allocates a workspace buffer of `peak_workspace_size` bytes and uses offsets from the plan.
+- Concurrent runs each get their own buffer — safe without locks.
+
+**Note on CUDA:** In practice, concurrent `Run()` on the same CUDA session is uncommon (users don't typically do this). But the design should remain thread-safe by following the same per-run buffer pattern.
+
+**Single-thread pre-allocation mode (eliminating runtime OOM):**
+
+Even with workspace planning, the per-`Run()` buffer allocation can still OOM if device memory is fragmented or consumed by other processes since `Initialize()`. For constrained environments, this is the last remaining point of failure.
+
+Most constrained-environment users run **single-threaded inference** — one `Run()` at a time. ORT already has a concurrent-run counter (`InferenceSession::current_num_runs_`). If the session is configured to disallow concurrency, the execution buffer (which includes workspace slots) can be **allocated once at initialization and reused for every `Run()` call**.
+
+**Proposed** (not currently implemented): a session option such as `session.pre_allocate_execution_buffers = "1"` would enable this behavior.
+
+When enabled:
+1. After `FinalizeSessionState()` computes the memory pattern (including workspace offsets from `DeclareWorkspaceRequirements`), allocate the peak buffer once: `IAllocator::Alloc(peak_size)` per EP.
+2. Store the pre-allocated buffer pointer on `SessionState`.
+3. Each `Run()` reuses the same buffer — no allocation, no OOM possible.
+4. Enforce `max_concurrent_runs = 1`: if a second `Run()` arrives, fail fast.
+
+```cpp
+if (pre_allocate_mode_ && current_num_runs_.fetch_add(1) > 0) {
+ current_num_runs_.fetch_sub(1);
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+ "Concurrent Run() not allowed with pre-allocated execution buffers.");
+}
+```
+
+**What this guarantees:** If `Initialize()` succeeds, `Run()` cannot OOM — all device memory (weights + intermediates + workspace) is already resident. The budget at partition time accounts for all three: `budget ≥ weights_on_device + peak_execution_buffer`.
+
+**What already exists:** `MemoryPattern` computation is done, `MemoryPatternGroup::GetPeakAllocSize()` gives peak size, `current_num_runs_` counter exists, per-EP allocators exist. The `ExecutionFrame` already uses offset-based placement into a contiguous block — the change is to not free/reallocate that block between calls.
+
+**Scope:** Single-threaded only. For concurrent inference, multiple buffers are needed (defeating the guarantee).
+
+**Interaction with dynamic shapes:** `pre_allocate_execution_buffers` is fundamentally a **static-shape-only** feature. With dynamic shapes, `ExecutionFrame` must allocate buffers on every `Run()` because activation tensor sizes are unknown until the input arrives — there is no way to pre-compute a total buffer size at `Initialize()` time. Even if some kernels' workspace slots are shape-independent, the activation portion (which typically dominates) still requires per-`Run()` allocation, so the OOM-elimination guarantee cannot hold.
+
+Furthermore, the arena allocator already handles repeated allocations efficiently (same-size blocks are recycled without syscalls), so pre-allocating just the workspace portion while leaving activations dynamic would add complexity for negligible gain.
+
+**Summary:** For dynamic-shape models, the value of `DeclareWorkspaceRequirements` is in **budget estimation** (Level 1/Level 2, using worst-case or max-batch sizes to decide how many nodes fit on the device), not in runtime pre-allocation.
+
+**Planning flow (during FinalizeSessionState):**
+
+1. For each kernel in the execution plan (when shapes are static), call `DeclareWorkspaceRequirements()` with the inferred input shapes.
+2. Record `{NodeIndex, slot_id} → size_bytes` in the execution plan.
+3. Run liveness analysis: workspace for node N is live only during step N's execution.
+4. Compute offsets (same algorithm as activation patterns) → yields `peak_workspace_size` and per-slot offsets.
+5. Store workspace pattern as a **separate `WorkspacePattern`** in `SessionState`.
+
+**Why workspace buffers are separate from `MemoryPattern` (activations):**
+
+Although the offset planning algorithm is the same (liveness → assign offsets → compute peak), workspace buffers differ in allocation and retrieval:
+
+| Aspect | MemoryPattern (activations) | WorkspacePattern |
+|--------|---------------------------|------------------|
+| **Addressing** | `MLValueIndex` — framework-assigned, part of graph IR | `(NodeIndex, slot_id)` — kernel-defined, opaque to framework |
+| **Who queries** | Framework automatically when creating output `OrtValue`s | Kernel explicitly via `GetPreallocatedWorkspace(slot_id)` |
+| **Lifetime** | Multi-step — output lives until its last consumer executes | Single-step — live only during the owning kernel's step |
+| **What's returned** | An `OrtValue` (typed tensor with shape metadata) | Raw `void*` — kernel interprets the bytes internally |
+| **Graph visibility** | Framework manages these as edges between nodes | Invisible to graph — internal scratch memory |
+| **Size determination** | Inferred from output shape × element_size | Declared by kernel (may be unrelated to any tensor shape) |
+
+Concretely, this means:
+- `WorkspacePattern` is a new class (not reusing `MemoryPatternGroup`) with its own lookup: `GetOffset(NodeIndex, slot_id) → {offset, size}`.
+- The workspace buffer is allocated separately from the activation buffer. They could share physical memory (workspace is always single-step, so it never overlaps with itself across steps), but keeping them separate simplifies accounting and makes budget tracking unambiguous: `peak_total = peak_activations + peak_workspace`.
+- In pre-allocation mode, both buffers are allocated once at init. In normal mode, both are allocated per-`Run()` from the arena. But they remain distinct allocations with distinct query paths.
+
+**Per-Run retrieval (during Compute):**
+
+Each `ExecutionFrame` allocates a workspace buffer of `peak_workspace_size` via the EP's allocator and provides offset-based access through a dedicated query interface (not the existing OrtValue/MLValue machinery):
+
+**Alternative A: Transparent fallback in GetScratchBuffer**
+
+Modify `GetScratchBuffer(slot_id, size, stream)` to check for a pre-planned buffer first:
+
+```cpp
+template
+IAllocatorUniquePtr GetScratchBuffer(int slot_id, size_t count_or_bytes, Stream* stream) const {
+ // Check if workspace was pre-planned for this node + slot
+ void* preallocated = context_.GetPreallocatedWorkspace(slot_id);
+ if (preallocated) {
+ // Return non-owning pointer (buffer lifetime managed by the frame)
+ return IAllocatorUniquePtr(static_cast(preallocated), [](T*){});
+ }
+ // Fall back to arena (dynamic shapes, or DeclareWorkspaceRequirements not implemented)
+ return IAllocator::MakeUniquePtr(allocator_, count_or_bytes, false, stream);
+}
+```
+
+Pro: Minimal kernel code changes — just add `slot_id` parameter. Con: Overloads `GetScratchBuffer` semantics; non-owning vs owning pointer distinction is subtle.
+
+**Alternative B: Separate retrieval path**
+
+Keep `GetScratchBuffer()` unchanged for arena allocation. Add a new method:
+
+```cpp
+// In OpKernelContext:
+void* GetPreallocatedWorkspace(int slot_id) const;
+// Returns nullptr if not pre-planned → kernel must call GetScratchBuffer() instead
+
+// Kernel usage:
+void* ws = context->GetPreallocatedWorkspace(0);
+if (!ws) {
+ scratch_buffer_ = GetScratchBuffer(workspace_size, stream);
+ ws = scratch_buffer_.get();
+}
+```
+
+Pro: Clear separation, no ambiguity about ownership. Con: Kernels need explicit fallback logic (but this is a one-time pattern per kernel).
+
+**Compatibility with dynamic shapes:** Both alternatives are opt-in. If `DeclareWorkspaceRequirements()` is not overridden or returns empty (dynamic shapes), everything falls back to `GetScratchBuffer()` → arena, exactly as today. Same kernel binary works for both static and dynamic models.
+
+**Incremental adoption:** Start with the highest-impact ops (attention, convolution, GEMM) which account for the majority of workspace. Less common ops continue using the arena with a reduced safety multiplier in `IResourceAccountant`.
+
+**Buffer strategy:** Workspace offsets can share the activation buffer (liveness doesn't overlap — workspace is live only during its step, activations may span steps). Alternatively, a separate workspace buffer is simpler initially and easier to account for in memory limits.
+
+##### EP Plugin C ABI Surface for Workspace Pre-declaration
+
+In the plugin architecture, `DeclareWorkspaceRequirements` crosses the C ABI boundary. This section defines the concrete API additions.
+
+**Declaration side — new optional function pointer on `OrtKernelImpl`:**
+
+```c
+// Added to OrtKernelImpl (optional, like PrePackWeight):
+ORT_API2_STATUS(DeclareWorkspaceRequirements,
+ _In_ OrtKernelImpl* this_ptr,
+ _In_reads_(num_inputs) const int64_t* const* input_shapes, // shape per input
+ _In_reads_(num_inputs) const size_t* input_shape_ranks, // rank per input
+ _In_ size_t num_inputs,
+ _Out_writes_all_(max_slots) OrtWorkspaceSlot* slots, // pre-allocated by ORT
+ _In_ size_t max_slots, // capacity (e.g., 8)
+ _Out_ size_t* num_slots); // actual count filled
+
+// Slot descriptor (C struct, no inheritance):
+typedef struct OrtWorkspaceSlot {
+ int slot_id; // Kernel-defined, stable identifier (0, 1, 2, ...)
+ size_t size_bytes; // Required size for this slot
+} OrtWorkspaceSlot;
+```
+
+If `DeclareWorkspaceRequirements` is NULL on the `OrtKernelImpl`, ORT skips the kernel during workspace planning (falls back to arena at runtime).
+
+**Retrieval side — new function in `OrtEpApi`:**
+
+```c
+// Added to OrtEpApi (called by plugin kernels during Compute):
+ORT_API2_STATUS(KernelContext_GetPreallocatedWorkspace,
+ _In_ const OrtKernelContext* context,
+ _In_ int slot_id,
+ _Outptr_result_maybenull_ void** buffer); // NULL if not pre-planned
+```
+
+Returns a pointer into the pre-allocated workspace buffer at the offset computed during planning. Returns NULL if no workspace was pre-planned for this kernel+slot (dynamic shapes, or kernel didn't declare). The pointer is valid for the duration of the `Compute()` call.
+
+**Slot ID provisioning — how kernels define unique slot_ids:**
+
+Slot IDs are **kernel-author-defined constants**, not dynamically allocated. Each kernel class defines its slots as an enum or set of constants in its implementation:
+
+```cpp
+// Example: CUDA Attention kernel (inside the plugin DLL)
+namespace cuda {
+class AttentionKernel : public OrtKernelImplBase {
+ // Slot IDs are private constants — stable across versions, used as array indices
+ static constexpr size_t kSlotQTranspose = 0;
+ static constexpr size_t kSlotKTranspose = 1;
+ static constexpr size_t kSlotVTranspose = 2;
+ static constexpr size_t kSlotSoftmaxWorkspace = 3;
+ static constexpr size_t kNumSlots = 4;
+
+ OrtStatus* DeclareWorkspaceRequirements(...) override {
+ slots[kSlotQTranspose] = {kSlotQTranspose, batch * heads * seq * head_dim * sizeof(half)};
+ slots[kSlotKTranspose] = {kSlotKTranspose, batch * heads * seq * head_dim * sizeof(half)};
+ slots[kSlotVTranspose] = {kSlotVTranspose, batch * heads * seq * head_dim * sizeof(half)};
+ slots[kSlotSoftmaxWorkspace] = {kSlotSoftmaxWorkspace, cudnn_workspace_size};
+ *num_slots = kNumSlots;
+ return nullptr;
+ }
+
+ OrtStatus* Compute(OrtKernelContext* ctx) override {
+ void* q_buf = nullptr;
+ // Uses pre-planned workspace if available, falls back to arena otherwise
+ api_->KernelContext_GetScratchBuffer(ctx, kSlotQTranspose, q_transpose_size, &q_buf);
+ // ... use q_buf ...
+ }
+};
+} // namespace cuda
+```
+
+**Key design properties:**
+
+| Property | Design Choice | Rationale |
+|----------|--------------|-----------|
+| Slot ID scope | Per kernel *instance* (node) | Same kernel class on different nodes gets separate buffers; ORT disambiguates via `(NodeIndex, slot_id)` |
+| Slot ID assignment | Static constants in kernel code | No registry, no runtime allocation, no cross-kernel coordination needed |
+| Slot ID range | `[0, max_slots)` — small integers | Simple array indexing in the offset plan; `max_slots` = 8 is generous for any single kernel |
+| Uniqueness guarantee | Kernel author's responsibility | Same convention as `input_index` in `PrePackWeight` — the kernel knows its own buffer layout |
+| Stability across versions | Expected (like enum values) | Slot IDs are internal to the kernel; not exposed to users or other kernels |
+
+**Where state lives:**
+
+| State | Location | Lifetime |
+|-------|----------|----------|
+| Slot definitions (id + size) | Returned by `DeclareWorkspaceRequirements` → stored in `ExecutionPlan` | Session lifetime (computed once at `Initialize()`) |
+| Offset map `{(NodeIndex, slot_id) → offset}` | `SessionState::workspace_pattern_` (new field, analogous to `mem_patterns_`) | Session lifetime (shared, read-only) |
+| Peak workspace size per EP/device | `SessionState::workspace_pattern_` | Session lifetime |
+| Actual workspace buffer | `ExecutionFrame` (allocated per-`Run()` via `Reserve()`) | Single `Run()` invocation |
+
+**No global slot registry needed.** Unlike input indices which are defined by the ONNX op schema, slot IDs are entirely internal to the kernel implementation. Two different kernel classes can both use `slot_id=0` without conflict — the framework always qualifies with `NodeIndex`. This means:
+- No coordination between kernel authors
+- No registration step during plugin initialization
+- No versioning concerns (IDs never cross the plugin boundary as semantic values)
+
+#### Phase B: Eliminate Arena for Static-Shape Models
+
+Once workspace is pre-declared, **all** allocations for a static-shape model are known at `Initialize()` time:
+- Initializers → already `Reserve()` (done)
+- Activations → already memory-pattern `Reserve()` (done)
+- Workspace → new, via Phase A
+
+At this point, the BFC arena serves no purpose for the main execution path. The session could:
+1. Pre-allocate exact memory per device (sum of pattern peak + workspace peak)
+2. Use offset-based addressing for all buffers
+3. Disable the arena entirely for this session (save memory waste from chunk granularity)
+
+Runtime temp buffers from ops that don't implement `DeclareWorkspaceRequirements()` can still fall back to a small arena.
+
+### Custom Executable: Purpose and Scope
+
+A minimal custom executable (CLI tool) serves three purposes:
+
+1. **Code example and test bed.** Demonstrates how to configure and exercise the constrained-environment features (name-based partitioning, memory budgets, static allocation mode) end-to-end using the ORT C/C++ API. Acts as a living integration test that exercises the full pipeline without depending on GenAI or external frameworks.
+
+2. **Interactive LLM demo (llama.cpp-style UX).** Loads a transformer ONNX model, manages the decode loop (prompt → KV cache → token sampling → output), and interacts with the user via stdin/stdout. This showcases ORT's ability to run large models on constrained hardware with the same user experience as llama.cpp — but backed by ORT's general-purpose runtime.
+
+3. **Primitive GenAI replacement for testing.** For the narrow case of single-model, single-user, greedy/top-k text generation, the executable can replace GenAI as a simpler alternative that doesn't pull in the full GenAI dependency. It is **not** a production replacement for GenAI (no batching, no beam search, no speculative decoding) — it is a minimal harness for validating that the partitioning and memory features work correctly on real models.
+
+**What the executable handles (application-level):**
+- Token encode/decode (via sentencepiece or tokenizers library)
+- KV cache allocation and rotation (fixed max sequence length)
+- Autoregressive decode loop (feed output token back as next input)
+- Session configuration: name-based layer assignment, memory budget, static shapes
+
+**What ORT handles (session-level, no executable changes needed):**
+- Graph partitioning across devices (Direction 1 + `IResourceAccountant`)
+- Static memory pre-allocation (Phase A + B)
+- Kernel execution, data transfers, stream synchronization
+
+**Feasibility: no fundamental ORT blockers.** The existing session API (`CreateSession` → `Run` with named I/O) is sufficient for an autoregressive decode loop. KV cache management is feeding output tensors back as inputs — the same pattern GenAI uses over the same C API.
+
+| Concern | Status | Notes |
+|---------|--------|-------|
+| Tokenizer | Reuse from ORT Extensions | See tokenizer strategy below |
+| KV cache rotation | Straightforward | Pre-allocate `(batch, heads, max_seq, head_dim)`, feed `past_key_values` outputs back as inputs each step |
+| Decode loop | Trivial | Run session → extract logits → sample token → repeat |
+| Model format | Constraint | Requires decoder-style ONNX export with explicit KV cache I/O (HuggingFace optimum exports provide this) |
+| Partitioning | This design | Direction 1 + `IResourceAccountant` |
+| Static allocation | Phase A+B | Fixed `max_seq_len` makes all decode-phase shapes static |
+
+**Tokenizer strategy — borrowing from ORT Extensions:**
+
+[ORT Extensions](https://github.com/microsoft/onnxruntime-extensions) already implements production-quality tokenizers in C++:
+- **BPE** (GPT-2, LLaMA-3, Phi, Mistral) — `onnxruntime_extensions/tokenizer/bpe_tokenizer.cc`
+- **SentencePiece** (LLaMA-1/2, T5, mT5) — wraps the SentencePiece C++ library
+- **WordPiece** (BERT, DistilBERT) — `onnxruntime_extensions/tokenizer/wordpiece_tokenizer.cc`
+
+For the custom executable, we can **extract the tokenizer C++ code directly** from ORT Extensions rather than taking a full dependency on the extensions DLL. The tokenizer logic is self-contained: it reads a vocabulary/merge file, applies the algorithm (BPE merge loop, SentencePiece unigram, or WordPiece greedy match), and produces token IDs. No ONNX graph execution is involved.
+
+**Practical approach:**
+1. Copy the relevant tokenizer source files (BPE tokenizer is ~500 LOC + vocab loading) into the demo executable's source tree.
+2. Strip the ORT Extensions custom-op registration wrapper — keep only the core `Encode(string) → vector` and `Decode(vector) → string` logic.
+3. Load the tokenizer model file (e.g., `tokenizer.json` from HuggingFace, or `tokenizer.model` for SentencePiece) at startup alongside the ONNX model.
+
+This gives us a battle-tested tokenizer with no additional runtime dependency — just a few source files compiled into the executable. The code is already Apache-2.0 licensed (same as ORT).
+
+The executable would be ~500–1000 LOC (excluding tokenizer): configure session options, set up KV cache tensors, run the generate loop. With the borrowed tokenizer code, the total grows to ~1500–2000 LOC but remains self-contained with zero external dependencies beyond ORT itself.
+
+---
+
+## Recommended Roadmap
+
+```
+Near-term (low effort, high value):
+├── 1. Name-based matching via session.name_based_layer_assignment [DONE]
+│ - Separate session option with substring matching against Node::Name()
+│ - SubstringMatcher with longest-match-wins priority
+│ - Mutually exclusive with annotation-based matching (setting both options returns INVALID_ARGUMENT)
+│
+├── 2. Precise per-node memory estimation
+│ - Static workspace estimation functions registered per kernel type
+│ - IResourceAccountant uses exact output sizes + workspace estimates
+│ - Eliminates 1.5x multiplier for kernels with estimation functions
+│
+Mid-term (medium effort):
+├── 3. Auto-partitioning with memory budget only
+│ - User specifies "6GB GPU budget"
+│ - ORT computes optimal layer split automatically
+│ - Combines (1) + (2)
+│
+├── 4. Static allocation mode
+│ - Pre-allocate all buffers when shapes are known
+│ - Eliminate per-Run() allocation overhead
+│
+Long-term (high effort, ollama-parity):
+├── 5. Layer prefetch pipeline
+│ - Stream weights CPU↔GPU during execution
+│ - Enables running models larger than GPU memory
+│
+└── 6. Integration with GenAI
+ - KV cache-aware memory planning
+ - Continuous batching + layer offload coordination
+```
+
+---
+
+## Key Insight
+
+The fundamental difference between ORT and llama.cpp for this use case is **generality vs specialization**. llama.cpp knows it's running a transformer with sequential layers. ORT handles arbitrary graphs. The trick is to **detect** when a model is transformer-like (sequential layers, static shapes) and engage a specialized execution path — without losing generality for other model types.
+
+Direction 1 (name-based matching) is the lowest-friction win: it makes the existing annotation system accessible without model modification. Direction 2 (static pre-allocation + auto-splitting) is what closes the gap with ollama but requires more infrastructure work, particularly around shape-aware memory planning at partition time.
diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md
index 6d53211ff40cb..36b68889ae582 100644
--- a/docs/contrib_ops/cuda/moe_qmoe.md
+++ b/docs/contrib_ops/cuda/moe_qmoe.md
@@ -71,6 +71,7 @@ input tokens → router (top-k softmax) → permute by expert
| `expert_weight_bits` (QMoE only) | int | 4 | 4 (INT4/MXFP4) or 8 (INT8/FP8). |
| `block_size` (QMoE only) | int | -1 | Group size for INT4/INT8 group-wise quantization. -1 = per-output-channel. |
| `quant_type` (QMoE only) | string | `"int"` | `"int"`, `"fp4"`, `"fp8"`, `"wfp4afp8"`. See [§3](#3-quantization-modes). |
+| `weights_prepacked` (QMoE only) | int | -1 | Tri-state, only meaningful when `quant_type="int"`. The prepacked layouts selected by `-1` and `1` are **EP-determined**. `-1` (default): the INT4/INT8 `fc1`/`fc2` initializers are already prepacked in the EP's default layout (e.g. from `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). `1`: already prepacked in an alternate EP-selected layout. `0`: the initializers are raw `[E, N, K/pack]` tensors (as produced by `quantize_matmul_{4,8}bits`) and the kernel runs the CUTLASS layout transform in `PrePack()`. **Note:** the CUDA EP INT4/INT8 MoE GEMM always runs the Ampere (SM80) kernel — even on SM90 — so it consumes the SM80 `fpA_intB` layout on all architectures; `-1` and `1` are therefore equivalent for the CUDA EP today, and `1` is reserved for a possible future Hopper-specific layout. See [§5.1](#51-weights-input-2--5--8). |
### 2.2 Type Constraints
@@ -228,10 +229,53 @@ extra subtraction.
### 5.1 Weights (input 2 / 5 / 8)
-Not transformed at runtime. INT4/INT8 weights must already be packed offline by
-`pack_weights_for_cuda_mixed_gemm` (see [§6](#6-weight-formats)). MXFP4 weights
-must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights are stored
-as raw e4m3 bytes (no packing).
+**INT4/INT8** weight layout is controlled by the `weights_prepacked` attribute
+([§2.1](#21-attributes)). The prepacked layouts selected by `-1` and `1` are
+determined by the execution provider:
+
+- **`weights_prepacked=-1` (default)** — the `fc1`/`fc2` weights are already in
+ the EP's default prepacked layout (e.g. packed offline by
+ `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). They are copied to GPU
+ and consumed as-is.
+- **`weights_prepacked=1`** — the `fc1`/`fc2` weights are already in the EP's
+ **SM90** (Hopper) prepacked layout (reserved; see the note below).
+- **`weights_prepacked=0`** — the `fc1`/`fc2` weights are raw, schema-conformant
+ `[E, N, K/pack]` tensors as produced by `quantize_matmul_{4,8}bits`. `PrePack`
+ runs the CUTLASS layout transform itself via `PrePackIntExpertWeights`,
+ removing the offline pre-pack dependency. This makes integer QMoE symmetric
+ with `MatMulNBits::PrePack_B`.
+
+> **Single layout on the CUDA EP.** The CUDA EP INT4/INT8 MoE GEMM always
+> dispatches to the Ampere (**SM80**) grouped-GEMM kernel — even on SM90 —
+> because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA
+> warp-specialized specialisation (`isValidHopperMOESpecialisation` is `false`).
+> This matches **TensorRT-LLM**, which likewise routes `W4A16`/`W8A16` MoE to the
+> SM80 kernel on Hopper; its Hopper TMA-WS mixed-dtype MoE kernel is reserved for
+> `W4A8` (FP8 activation) and `WFP4A16` (FP4 weight). Consequently the CUDA EP
+> consumes the **SM80 `fpA_intB` layout on every GPU**, `PrePack` always packs
+> for SM80, and `weights_prepacked=-1` and `=1` are equivalent today. `1` is
+> accepted and reserved for a possible future Hopper-specific layout (e.g.
+> `W4A8`). There is therefore no architecture-match constraint: SM80-format
+> weights run correctly on SM90 via the SM80 kernel.
+
+`PrePackIntExpertWeights` loops over the `E` experts and, per expert, applies the
+same transpose + row-permutation / column-interleave / bias / pair-interleave
+transform as `pack_weights_for_cuda_mixed_gemm` (see [§6.1](#61-int4-group-wise-quant_typeint-expert_weight_bits4)),
+always targeting the SM80 layout. SM75+ is required. The source
+`[E, N, K/pack]` initializers are released after their shapes are cached
+(`fc1_weights_shape_` / `fc2_weights_shape_`), so peak weight memory stays ~1×.
+The prepacked GPU buffers (`packed_fc1_weights_` / `packed_fc2_weights_`) are then
+preferred by `ComputeInternal`. If prepacking is disabled at the session level
+(`session.disable_prepacking`), the buffers stay null and the raw initializer
+pointers are read at compute time instead.
+
+> **Note**: `weights_prepacked=0` is the only path that triggers an in-`PrePack`
+> layout transform for INT weights. FP4 / FP8 / WFP4AFP8 weight handling is
+> unaffected.
+
+MXFP4 weights must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights
+are stored as raw e4m3 bytes (no packing).
+
### 5.2 INT4/INT8 scales + zero-point → bias
@@ -287,7 +331,12 @@ This section covers the five distinct weight encodings supported by QMoE.
INT4 packing layout within a byte: `[high_nibble | low_nibble] = [elt_1 | elt_0]`.
Each INT4 element is in `[-8, 7]` (signed) before bias, `[0, 15]` after the +8 bias.
-#### Preprocessing pipeline (offline, `pack_weights_for_cuda_mixed_gemm`)
+#### Preprocessing pipeline (offline `pack_weights_for_cuda_mixed_gemm`, or in-`PrePack` via `PrePackIntExpertWeights`)
+
+This is the layout transform applied either offline by
+`pack_weights_for_cuda_mixed_gemm`, or per-expert inside `PrePack` when
+`weights_prepacked=0` (see [§5.1](#51-weights-input-2--5--8)).
+
1. **Input layout**: `[N, K]` per expert (Out × In), 2 elements per byte for INT4.
2. **Transpose & signed conversion**:
@@ -405,6 +454,17 @@ weights are interchangeable across SMs:
— does not use `pack_weights_for_cuda_mixed_gemm`.
- **FP8**: no packing.
+> **QMoE uses Group A on every GPU.** The table above describes the layouts the
+> `pack_weights_for_cuda_mixed_gemm` *preprocessor* can emit. The QMoE INT4/INT8
+> MoE GEMM, however, always dispatches to the Ampere (SM80) grouped-GEMM kernel —
+> even on SM90 — because mixed int-weight + fp16/bf16 activation is not a valid
+> Hopper TMA warp-specialized specialisation (the same is true in TensorRT-LLM).
+> It therefore consumes the **Group A (SM80) layout on all architectures,
+> including Hopper**. For QMoE, always pack INT4/INT8 weights for SM80 (`arch=80`),
+> and `PrePackIntExpertWeights` (`weights_prepacked=0`) does exactly that
+> regardless of the runtime device SM. Group B (SM90) layout is currently unused
+> by QMoE.
+
---
## 8. SwiGLU Fusion
@@ -830,7 +890,7 @@ will not change the operator interface.
|-----------|----------|
| [test_moe_cuda.py](onnxruntime/test/python/transformers/test_moe_cuda.py) | Standard MoE on CUDA: FP16/BF16, SiLU/GeLU/SwiGLU, routing, GEMM parity. SwiGLU coverage includes both GPT-OSS (`TestSwigluMoE`: interleaved, alpha=1.702/beta=1.0/limit=7.0) and Standard/Llama-Gemma (`TestStandardSwigluMoE`: concatenated `swiglu_fusion=2`, alpha=1.0/beta=0.0/no limit → `SiLU(Gate)×Value`). |
| [test_moe_cpu.py](onnxruntime/test/python/transformers/test_moe_cpu.py) | Standard MoE on CPU (smoke). |
-| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. |
+| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. `TestQMoEIntPrePackSmoke` covers the raw-weight `weights_prepacked=0` in-`PrePack` layout transform (smoke test: asserts finite output, not bit-parity). |
| [test_qmoe_cpu.py](onnxruntime/test/python/transformers/test_qmoe_cpu.py) | INT4/INT8 QMoE on CPU (smoke). |
| [test_qmoe_fp4_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py) | MXFP4 QMoE: quantization utilities, packing, FP16/BF16, SiLU/SwiGLU, top-k and expert-count variants. End-to-end runs on SM120; on SM<120 the dequant fallback is exercised. |
| [test_qmoe_fp8_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py) | FP8 W8A16 QMoE on SM90+ native path and SM<90 dequant fallback. |
@@ -954,6 +1014,11 @@ over-aligned by-value parameters.
cannot. See [§14.1](#141-msvc-and-tma-grouped-moe-gemm).
- **WFP4AFP8 native** requires SM100+ hardware; only the dequant fallback path
is validated end-to-end so far.
+- **In-`PrePack` INT weight layout transform** (`weights_prepacked=0`) is
+ currently covered only by a smoke test (`TestQMoEIntPrePackSmoke`), not a
+ bit-parity check: the existing offline pre-pack harness hardcodes
+ `force_arch=80` (the same SM80 layout consumed by the CUDA EP on all GPUs),
+ so a separate parity harness for this path is still pending.
- **Hopper W4A8** (INT4 weight + FP8 activation) is not supported — TRT-LLM gates
its fast path to SM89 only.
diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h
index 8cf6420f2d0f7..eee75d399d767 100644
--- a/include/onnxruntime/core/common/status.h
+++ b/include/onnxruntime/core/common/status.h
@@ -29,8 +29,10 @@ enum StatusCategory {
};
/**
- Error code for ONNXRuntime.
-*/
+ * Error code for ONNXRuntime.
+ *
+ * These values must stay in sync with the public C API OrtErrorCode enum values.
+ */
enum StatusCode {
OK = 0,
FAIL = 1,
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 0cdbfd7114cf0..edb42c0a2f596 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -263,21 +263,76 @@ typedef enum OrtLoggingLevel {
ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe).
} OrtLoggingLevel;
+/** \brief Error codes reported by ONNX Runtime.
+ *
+ * The error code associated with an ::OrtStatus.
+ */
typedef enum OrtErrorCode {
+ /**
+ * Success. No error occurred.
+ */
ORT_OK,
+ /**
+ * Generic failure that does not map to a more specific error code. Consult the error message for details.
+ */
ORT_FAIL,
+ /**
+ * A caller-supplied argument was invalid (e.g. NULL pointer, out-of-range value, mismatched shape/rank, or bad
+ * configuration).
+ */
ORT_INVALID_ARGUMENT,
+ /**
+ * A required file (such as a model file) does not exist.
+ */
ORT_NO_SUCHFILE,
+ /**
+ * Legacy/unused but retained for ABI compatibility. Historically returned when a model could not be found by name in
+ * the ONNX Runtime Server (removed in 2022).
+ */
ORT_NO_MODEL,
+ /**
+ * A hardware accelerator or backend engine reported a failure (e.g. a device crash or other device-level error).
+ */
ORT_ENGINE_ERROR,
+ /**
+ * A generic runtime exception was caught. The error message is the primary source of detail.
+ */
ORT_RUNTIME_EXCEPTION,
+ /**
+ * Protobuf parsing or serialization failed.
+ */
ORT_INVALID_PROTOBUF,
+ /**
+ * Invalid session state for the requested operation. Despite the name, this code does not mean "success, model
+ * loaded"; it is returned when the session is in the wrong state for the requested call (e.g. a model is already
+ * loaded, the session is already initialized, or no model has been loaded yet). The name is historical and is
+ * retained for ABI compatibility; consult the error message for the specific condition.
+ */
ORT_MODEL_LOADED,
+ /**
+ * The requested functionality is not implemented in this build.
+ */
ORT_NOT_IMPLEMENTED,
+ /**
+ * The model graph is structurally invalid (e.g. recursive function definitions, invalid tensor dimensions, or
+ * malformed nodes).
+ */
ORT_INVALID_GRAPH,
+ /**
+ * An execution provider reported a generic failure.
+ */
ORT_EP_FAIL,
+ /**
+ * Model loading or session initialization was canceled at the caller's request.
+ */
ORT_MODEL_LOAD_CANCELED,
+ /**
+ * The model requires compilation by an execution provider, but compilation was disabled via session options.
+ */
ORT_MODEL_REQUIRES_COMPILATION,
+ /**
+ * A requested resource could not be found.
+ */
ORT_NOT_FOUND,
} OrtErrorCode;
@@ -348,9 +403,6 @@ ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared exte
ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails);
ORT_RUNTIME_CLASS(EpAssignedSubgraph);
ORT_RUNTIME_CLASS(EpAssignedNode);
-ORT_RUNTIME_CLASS(ModelPackageOptions);
-ORT_RUNTIME_CLASS(ModelPackageContext);
-ORT_RUNTIME_CLASS(ModelPackageComponentContext);
#ifdef _MSC_VER
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -938,9 +990,6 @@ typedef struct OrtCompileApi OrtCompileApi;
struct OrtInteropApi;
typedef struct OrtInteropApi OrtInteropApi;
-struct OrtModelPackageApi;
-typedef struct OrtModelPackageApi OrtModelPackageApi;
-
struct OrtEpApi;
typedef struct OrtEpApi OrtEpApi;
@@ -7501,26 +7550,6 @@ struct OrtApi {
*/
ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
- /** \brief Get the model package API table.
- *
- * Returns a pointer to the ::OrtModelPackageApi function table, which provides APIs to:
- * - create and release model package options and contexts,
- * - inspect model package metadata (components/variants),
- * - select a component/variant and query selected files/options,
- * - create a session from model package selection results.
- *
- * The returned pointer is owned by ONNX Runtime and is valid for the process lifetime.
- * Do not free it.
- *
- * \note May return NULL if model package support is not available in the current build
- * (for example, minimal builds).
- *
- * \return Pointer to ::OrtModelPackageApi, or NULL if unsupported.
- *
- * \since Version 1.27.
- */
- const OrtModelPackageApi*(ORT_API_CALL* GetModelPackageApi)(void);
-
/** \brief Retrieve an experimental function pointer by name.
*
* Experimental functions are not part of the stable ABI and may be added or removed between releases without notice.
@@ -8645,250 +8674,6 @@ struct OrtInteropApi {
/// @}
};
-/** \brief API table for model package workflows.
- *
- * A model package is a directory containing one or more *components* (logical models).
- * Each component has one or more *variants*, where each variant targets a single
- * execution provider (EP). The package manifest declares the EP name, device type,
- * and an optional compatibility string for every variant so that the runtime can
- * automatically select the best variant for the hardware and EPs available in the
- * caller's session options.
- *
- * Obtain this table from OrtApi::GetModelPackageApi(). The APIs support:
- * - creating model package options that capture EP configuration from OrtSessionOptions,
- * - loading a package context (manifest + metadata) from a package root path,
- * - querying component/variant metadata including per-variant EP information,
- * - selecting a component (which also resolves the best-matching variant),
- * - querying the selected variant's name and folder path,
- * - creating an OrtSession from the selected component context.
- *
- * Typical flow:
- * 1) Create model package options:
- * - CreateModelPackageOptionsFromSessionOptions()
- * 2) Load package metadata:
- * - CreateModelPackageContext()
- * 3) Query metadata (optional):
- * - ModelPackage_GetSchemaVersion()
- * - ModelPackage_GetComponentCount()
- * - ModelPackage_GetComponentNames()
- * - ModelPackage_GetVariantCount()
- * - ModelPackage_GetVariantNames()
- * - ModelPackage_GetVariantEpName()
- * 4) Select a component and resolve variant:
- * - SelectComponent()
- * 5) Query selected variant info (optional):
- * - ModelPackageComponent_GetSelectedVariantName()
- * - ModelPackageComponent_GetSelectedVariantFolderPath()
- * 6) Create session:
- * - CreateSession()
- *
- * Ownership:
- * - Release objects created by this API with the corresponding release methods:
- * ReleaseModelPackageOptions(), ReleaseModelPackageContext(),
- * ReleaseModelPackageComponentContext().
- *
- * \since Version 1.27.
- */
-struct OrtModelPackageApi {
- /// \name OrtModelPackageOptions
- /// @{
-
- /** \brief Create model package options from an existing OrtSessionOptions.
- *
- * Captures EP configuration (registered execution providers and their devices) from
- * the session options for use during variant selection. The resulting OrtModelPackageOptions
- * is passed to SelectComponent() to resolve the best variant for the available EPs.
- *
- * \param[in] env The ORT environment.
- * \param[in] session_options Session options containing registered EPs.
- * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released
- * with ReleaseModelPackageOptions().
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateModelPackageOptionsFromSessionOptions,
- _In_ const OrtEnv* env,
- _In_ const OrtSessionOptions* session_options,
- _Outptr_ OrtModelPackageOptions** out);
-
- ORT_CLASS_RELEASE(ModelPackageOptions);
- /// @}
- /// \name OrtModelPackageContext
- /// @{
-
- /** \brief Create a model package context by parsing the package at the given root path.
- *
- * Parses the manifest.json and component metadata from the specified directory.
- * The returned context provides read-only access to the package structure (components,
- * variants, EP declarations).
- *
- * \param[in] package_root Path to the model package root directory (containing manifest.json).
- * \param[out] out Receives the newly created OrtModelPackageContext. Must be released
- * with ReleaseModelPackageContext().
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateModelPackageContext,
- _In_ const ORTCHAR_T* package_root,
- _Outptr_ OrtModelPackageContext** out);
-
- ORT_CLASS_RELEASE(ModelPackageContext);
-
- /** \brief Get the schema version declared in the model package manifest.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_version Receives the schema version number.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetSchemaVersion,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ int64_t* out_version);
-
- /** \brief Get the number of components in the model package.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_count Receives the component count.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetComponentCount,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ size_t* out_count);
-
- /** \brief Get the names of all components in the model package.
- *
- * Returns a pointer to an array of UTF-8 component name strings. The array and its
- * strings are owned by `ctx` and remain valid until the context is released.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_names Receives a pointer to an array of component name strings.
- * \param[out] out_count Receives the number of elements in the array.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetComponentNames,
- _In_ const OrtModelPackageContext* ctx,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
- _Out_ size_t* out_count);
-
- /** \brief Get the number of variants for a given component.
- *
- * \param[in] ctx The model package context.
- * \param[in] component_name Name of the component to query.
- * \param[out] out_count Receives the variant count.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantCount,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Out_ size_t* out_count);
-
- /** \brief Get the names of all variants for a given component.
- *
- * Returns a pointer to an array of UTF-8 variant name strings. The array and its
- * strings are owned by `ctx` and remain valid until the context is released.
- *
- * \param[in] ctx The model package context.
- * \param[in] component_name Name of the component to query.
- * \param[out] out_variant_names Receives a pointer to an array of variant name strings.
- * \param[out] out_count Receives the number of elements in the array.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantNames,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
- _Out_ size_t* out_count);
-
- /** \brief Get the EP name declared for a (component, variant) pair.
- *
- * Each variant targets a single EP. `out_ep` receives the EP name string.
- * When the variant does not declare an EP, the returned pointer is NULL.
- * String memory is owned by `ctx` and remains valid until the context is released.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantEpName,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _In_ const char* variant_name,
- _Outptr_result_maybenull_ const char** out_ep);
-
- /** \brief Select a component model and return an opaque component instance.
- *
- * The variant selection is also performed during this call based on the component metadata and the provided options.
- * The returned `OrtModelPackgeComponentContext*` is independent of `context` lifetime and must be released via
- * `ReleaseComponentInstance`.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(SelectComponent,
- _In_ const OrtModelPackageContext* context,
- _In_ const char* component_name,
- _In_ const OrtModelPackageOptions* options,
- _Outptr_ OrtModelPackageComponentContext** out);
-
- ORT_CLASS_RELEASE(ModelPackageComponentContext);
-
- /** \brief Get the name of the selected variant after SelectComponent has been called.
- *
- * String memory is owned by `ctx` and remains valid until the context is released.
- *
- * \param[in] ctx The component context returned by SelectComponent().
- * \param[out] out_name Receives the selected variant's name string.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantName,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const char** out_name);
-
- /** \brief Get the folder path of the selected variant.
- *
- * Returns the resolved absolute path to the variant's directory on disk.
- * The string is owned by `ctx` and remains valid until the context is released.
- *
- * \param[in] ctx The component context returned by SelectComponent().
- * \param[out] folder_path Receives the variant folder path string.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantFolderPath,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const ORTCHAR_T** folder_path);
-
- /// @}
- /** \brief Create an OrtSession for a selected file within a component model variant.
- *
- * The chosen variant (and thus its EP selection) is determined by `context`, which
- * was built from an OrtSessionOptions via CreateModelPackageOptionsFromSessionOptions.
- *
- * Session options precedence:
- * 1. session_options == NULL (default path):
- * ORT uses the OrtSessionOptions that was captured when `context` was created.
- * Any variant-specific session and provider options declared in the package
- * metadata are merged on top.
- *
- * 2. session_options != NULL (advanced path):
- * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific
- * session and provider options from the package metadata are NOT applied.
- * Use this when custom EP setup is required (e.g., shared CUDA streams,
- * shared QNN EP contexts, custom allocators).
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateSession,
- _In_ const OrtEnv* env,
- _In_ OrtModelPackageComponentContext* context,
- _In_opt_ const OrtSessionOptions* session_options,
- _Outptr_ OrtSession** session);
-
- // End of Version 1.27 - DO NOT MODIFY ABOVE
-};
-
/*
* This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality
* This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index a999ef5e0faf4..4798d3d4ad1b8 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -264,20 +264,6 @@ inline const OrtEpApi& GetEpApi() {
return *api;
}
-///
-/// This returns a reference to the ORT C Model Package API. Used for loading models from model packages.
-///
-/// ORT C Model Package API reference
-inline const OrtModelPackageApi& GetModelPackageApi() {
- auto* api = GetApi().GetModelPackageApi();
- if (api == nullptr) {
- // minimal build
- ORT_CXX_API_THROW("Model Package API is not available in this build", ORT_FAIL);
- }
-
- return *api;
-}
-
/** \brief IEEE 754 half-precision floating point data type
*
* \details This struct is used for converting float to float16 and back
@@ -678,9 +664,6 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageOptions, GetModelPackageApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageContext, GetModelPackageApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageComponentContext, GetModelPackageApi);
// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
// but the struct has V2 in its name to indicate that it is the second version of the options.
@@ -807,9 +790,6 @@ struct EpDevice;
struct ExternalInitializerInfo;
struct Graph;
struct Model;
-struct ModelPackageOptions;
-struct ModelPackageContext;
-struct ModelPackageComponentContext;
struct Node;
struct ModelMetadata;
struct TypeInfo;
@@ -1806,70 +1786,6 @@ struct ModelCompilationOptions : detail::Base {
*/
Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);
-/** \brief Options for selecting a component from a model package.
- *
- * Wraps ::OrtModelPackageOptions. Created from an Env and SessionOptions, which captures the
- * EP configuration used for variant selection.
- */
-struct ModelPackageOptions : detail::Base {
- using Base = detail::Base;
- using Base::Base;
-
- explicit ModelPackageOptions(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- ModelPackageOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions
- ModelPackageOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions
-};
-
-/** \brief Context for inspecting and selecting components from a model package.
- *
- * Wraps ::OrtModelPackageContext. Provides traversal APIs to enumerate components, variants,
- * and EP compatibility, as well as component selection.
- */
-struct ModelPackageContext : detail::Base {
- using Base = detail::Base;
- using Base::Base;
-
- explicit ModelPackageContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- explicit ModelPackageContext(const ORTCHAR_T* package_root); ///< Wraps OrtModelPackageApi::CreateModelPackageContext
-
- size_t GetComponentCount() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentCount
- std::vector GetComponentNames() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentNames
- size_t GetVariantCount(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantCount
- std::vector GetVariantNames(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantNames
-
- /// Get the EP name for a variant. Returns nullptr if not declared.
- /// Returned string is owned by this context and valid until it is released.
- const char* GetVariantEpName(const char* component_name,
- const char* variant_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantEpName
-
- int64_t GetSchemaVersion() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetSchemaVersion
-
- ModelPackageComponentContext SelectComponent(const char* component_name,
- const ModelPackageOptions& options) const; ///< Wraps OrtModelPackageApi::SelectComponent
-};
-
-/** \brief Context for a selected component within a model package.
- *
- * Wraps ::OrtModelPackageComponentContext. Provides accessors for the selected variant's
- * folder path and variant name.
- */
-struct ModelPackageComponentContext : detail::Base {
- using Base = detail::Base;
- using Base::Base;
-
- explicit ModelPackageComponentContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- std::basic_string GetSelectedVariantFolderPath() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantFolderPath
-
- std::string GetSelectedVariantName() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantName
-
- Session CreateSession(const Env& env); ///< Wraps OrtModelPackageApi::CreateSession (default path, NULL session_options)
- Session CreateSession(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path)
- Session CreateSession(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path)
-};
-
/** \brief Wrapper around ::OrtModelMetadata
*
*/
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 51f99655121c6..d7439e7b356c6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1360,110 +1360,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const Ort
return *this;
}
-// ModelPackageOptions
-inline ModelPackageOptions::ModelPackageOptions(const Env& env, const SessionOptions& session_options) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_));
-}
-
-inline ModelPackageOptions::ModelPackageOptions(const Env& env, ConstSessionOptions session_options) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_));
-}
-
-// ModelPackageContext
-inline ModelPackageContext::ModelPackageContext(const ORTCHAR_T* package_root) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageContext(package_root, &this->p_));
-}
-
-inline size_t ModelPackageContext::GetComponentCount() const {
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentCount(this->p_, &count));
- return count;
-}
-
-inline std::vector ModelPackageContext::GetComponentNames() const {
- const char* const* names = nullptr;
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentNames(this->p_, &names, &count));
- std::vector result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
-}
-
-inline size_t ModelPackageContext::GetVariantCount(const char* component_name) const {
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantCount(this->p_, component_name, &count));
- return count;
-}
-
-inline std::vector ModelPackageContext::GetVariantNames(const char* component_name) const {
- const char* const* names = nullptr;
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantNames(this->p_, component_name, &names, &count));
- std::vector result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
-}
-
-inline const char* ModelPackageContext::GetVariantEpName(const char* component_name,
- const char* variant_name) const {
- const char* ep = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantEpName(
- this->p_, component_name, variant_name, &ep));
- return ep;
-}
-
-inline int64_t ModelPackageContext::GetSchemaVersion() const {
- int64_t version = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetSchemaVersion(this->p_, &version));
- return version;
-}
-
-inline ModelPackageComponentContext ModelPackageContext::SelectComponent(
- const char* component_name, const ModelPackageOptions& options) const {
- OrtModelPackageComponentContext* out = nullptr;
- ThrowOnError(GetModelPackageApi().SelectComponent(this->p_, component_name, options, &out));
- return ModelPackageComponentContext{out};
-}
-
-// ModelPackageComponentContext
-inline std::basic_string ModelPackageComponentContext::GetSelectedVariantFolderPath() const {
- const ORTCHAR_T* path = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantFolderPath(this->p_, &path));
- return std::basic_string{path};
-}
-
-inline std::string ModelPackageComponentContext::GetSelectedVariantName() const {
- const char* name = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantName(this->p_, &name));
- return (name != nullptr) ? std::string{name} : std::string{};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, nullptr, &out));
- return Session{out};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env,
- const SessionOptions& session_options) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out));
- return Session{out};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env,
- ConstSessionOptions session_options) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out));
- return Session{out};
-}
-
namespace detail {
template
diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
index e943a5cf65b11..0dd87c10776d3 100644
--- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
@@ -39,6 +39,10 @@
// ORT_RUNTIME_CLASS(ExperimentalType);
//
+ORT_RUNTIME_CLASS(ModelPackageOptions);
+ORT_RUNTIME_CLASS(ModelPackageContext);
+ORT_RUNTIME_CLASS(ModelPackageComponentContext);
+
//
// C: function pointer typedefs and name constants
//
diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
index 0123b02818584..57a4e472b6f6d 100644
--- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
+++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
@@ -35,3 +35,250 @@
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t* out)
+
+// ---------------------------------------------------------------------------
+// OrtModelPackageApi
+//
+// A model package is a directory containing one or more *components* (logical models).
+// Each component has one or more *variants*, where each variant targets a single
+// execution provider (EP). The package manifest declares the EP name, device type,
+// and an optional compatibility string for every variant so that the runtime can
+// automatically select the best variant for the hardware and EPs available in the
+// caller's session options.
+//
+// The functions below support:
+// - creating model package options that capture EP configuration from OrtSessionOptions,
+// - loading a package context (manifest + metadata) from a package root path,
+// - querying component/variant metadata including per-variant EP information,
+// - selecting a component (which also resolves the best-matching variant),
+// - querying the selected variant's name and folder path,
+// - creating an OrtSession from the selected component context.
+//
+// Typical flow:
+// 1) Create model package options:
+// - OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions
+// 2) Load package metadata:
+// - OrtModelPackageApi_CreateModelPackageContext
+// 3) Query metadata (optional):
+// - OrtModelPackageApi_ModelPackage_GetSchemaVersion
+// - OrtModelPackageApi_ModelPackage_GetComponentCount
+// - OrtModelPackageApi_ModelPackage_GetComponentNames
+// - OrtModelPackageApi_ModelPackage_GetVariantCount
+// - OrtModelPackageApi_ModelPackage_GetVariantNames
+// - OrtModelPackageApi_ModelPackage_GetVariantEpName
+// 4) Select a component and resolve variant:
+// - OrtModelPackageApi_SelectComponent
+// 5) Query selected variant info (optional):
+// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName
+// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath
+// 6) Create session:
+// - OrtModelPackageApi_CreateSession
+//
+// Ownership:
+// - Release objects created by this API with the corresponding release functions:
+// OrtModelPackageApi_ReleaseModelPackageOptions,
+// OrtModelPackageApi_ReleaseModelPackageContext,
+// OrtModelPackageApi_ReleaseModelPackageComponentContext.
+//
+// Opaque handles (OrtModelPackageOptions, OrtModelPackageContext, OrtModelPackageComponentContext)
+// are declared in onnxruntime_experimental_c_api.h.
+// ---------------------------------------------------------------------------
+
+/** \brief Create model package options from an existing OrtSessionOptions.
+ *
+ * Captures EP configuration (registered execution providers and their devices) from
+ * the session options for use during variant selection. The resulting OrtModelPackageOptions
+ * is passed to OrtModelPackageApi_SelectComponent to resolve the best variant for the
+ * available EPs.
+ *
+ * \param[in] env The ORT environment.
+ * \param[in] session_options Session options containing registered EPs.
+ * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released
+ * with OrtModelPackageApi_ReleaseModelPackageOptions.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions,
+ _In_ const OrtEnv* env,
+ _In_ const OrtSessionOptions* session_options,
+ _Outptr_ OrtModelPackageOptions** out)
+
+/** \brief Release an OrtModelPackageOptions created by
+ * OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions.
+ */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageOptions,
+ _Frees_ptr_opt_ OrtModelPackageOptions* options)
+
+/** \brief Create a model package context by parsing the package at the given root path.
+ *
+ * Parses the manifest.json and component metadata from the specified directory.
+ * The returned context provides read-only access to the package structure (components,
+ * variants, EP declarations).
+ *
+ * \param[in] package_root Path to the model package root directory (containing manifest.json).
+ * \param[out] out Receives the newly created OrtModelPackageContext. Must be released
+ * with OrtModelPackageApi_ReleaseModelPackageContext.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageContext,
+ _In_ const ORTCHAR_T* package_root,
+ _Outptr_ OrtModelPackageContext** out)
+
+/** \brief Release an OrtModelPackageContext created by OrtModelPackageApi_CreateModelPackageContext. */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageContext,
+ _Frees_ptr_opt_ OrtModelPackageContext* ctx)
+
+/** \brief Get the schema version declared in the model package manifest.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_version Receives the schema version number.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetSchemaVersion,
+ _In_ const OrtModelPackageContext* ctx,
+ _Out_ int64_t* out_version)
+
+/** \brief Get the number of components in the model package.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_count Receives the component count.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentCount,
+ _In_ const OrtModelPackageContext* ctx,
+ _Out_ size_t* out_count)
+
+/** \brief Get the names of all components in the model package.
+ *
+ * Returns a pointer to an array of UTF-8 component name strings. The array and its
+ * strings are owned by `ctx` and remain valid until the context is released.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_names Receives a pointer to an array of component name strings.
+ * \param[out] out_count Receives the number of elements in the array.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentNames,
+ _In_ const OrtModelPackageContext* ctx,
+ _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
+ _Out_ size_t* out_count)
+
+/** \brief Get the number of variants for a given component.
+ *
+ * \param[in] ctx The model package context.
+ * \param[in] component_name Name of the component to query.
+ * \param[out] out_count Receives the variant count.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantCount,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _Out_ size_t* out_count)
+
+/** \brief Get the names of all variants for a given component.
+ *
+ * Returns a pointer to an array of UTF-8 variant name strings. The array and its
+ * strings are owned by `ctx` and remain valid until the context is released.
+ *
+ * \param[in] ctx The model package context.
+ * \param[in] component_name Name of the component to query.
+ * \param[out] out_variant_names Receives a pointer to an array of variant name strings.
+ * \param[out] out_count Receives the number of elements in the array.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantNames,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
+ _Out_ size_t* out_count)
+
+/** \brief Get the EP name declared for a (component, variant) pair.
+ *
+ * Each variant targets a single EP. `out_ep` receives the EP name string.
+ * When the variant does not declare an EP, the returned pointer is NULL.
+ * String memory is owned by `ctx` and remains valid until the context is released.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantEpName,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _In_ const char* variant_name,
+ _Outptr_result_maybenull_ const char** out_ep)
+
+/** \brief Select a component model and return an opaque component instance.
+ *
+ * The variant selection is also performed during this call based on the component
+ * metadata and the provided options. The returned `OrtModelPackageComponentContext*` is
+ * independent of `context` lifetime and must be released via
+ * OrtModelPackageApi_ReleaseModelPackageComponentContext.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_SelectComponent,
+ _In_ const OrtModelPackageContext* context,
+ _In_ const char* component_name,
+ _In_ const OrtModelPackageOptions* options,
+ _Outptr_ OrtModelPackageComponentContext** out)
+
+/** \brief Release an OrtModelPackageComponentContext created by OrtModelPackageApi_SelectComponent. */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageComponentContext,
+ _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx)
+
+/** \brief Get the name of the selected variant after OrtModelPackageApi_SelectComponent has been called.
+ *
+ * String memory is owned by `ctx` and remains valid until the context is released.
+ *
+ * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent.
+ * \param[out] out_name Receives the selected variant's name string.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName,
+ _In_ const OrtModelPackageComponentContext* ctx,
+ _Outptr_ const char** out_name)
+
+/** \brief Get the folder path of the selected variant.
+ *
+ * Returns the resolved absolute path to the variant's directory on disk.
+ * The string is owned by `ctx` and remains valid until the context is released.
+ *
+ * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent.
+ * \param[out] folder_path Receives the variant folder path string.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath,
+ _In_ const OrtModelPackageComponentContext* ctx,
+ _Outptr_ const ORTCHAR_T** folder_path)
+
+/** \brief Create an OrtSession for a selected file within a component model variant.
+ *
+ * The chosen variant (and thus its EP selection) is determined by `context`, which
+ * was built from an OrtSessionOptions via OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions.
+ *
+ * Session options precedence:
+ * 1. session_options == NULL (default path):
+ * ORT uses the OrtSessionOptions that was captured when `context` was created.
+ * Any variant-specific session and provider options declared in the package
+ * metadata are merged on top.
+ *
+ * 2. session_options != NULL (advanced path):
+ * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific
+ * session and provider options from the package metadata are NOT applied.
+ * Use this when custom EP setup is required (e.g., shared CUDA streams,
+ * shared QNN EP contexts, custom allocators).
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession,
+ _In_ const OrtEnv* env,
+ _In_ OrtModelPackageComponentContext* context,
+ _In_opt_ const OrtSessionOptions* session_options,
+ _Outptr_ OrtSession** session)
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 0b6d009e072ad..eee100aeef8df 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -420,6 +420,21 @@ static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings =
///
static const char* const kOrtSessionOptionsLayerAssignmentSettings = "session.layer_assignment_settings";
+///
+/// Name-based layer assignment. Uses the same device(pattern1, pattern2, ...); ... grammar
+/// as kOrtSessionOptionsLayerAssignmentSettings but performs SUBSTRING matching against
+/// Node::Name() instead of prefix/exact matching against node metadata annotations.
+/// The '=' prefix (exact match) from the annotation-based grammar is rejected with an error
+/// — all patterns are treated as substrings.
+/// Longest matching pattern wins when multiple patterns match the same node name.
+/// No subgraph inheritance is applied — each node is matched independently by its name.
+///
+/// MUTUALLY EXCLUSIVE with kOrtSessionOptionsLayerAssignmentSettings. Setting both returns
+/// INVALID_ARGUMENT. Use annotation-based matching for models with explicit layer annotations,
+/// or name-based matching for models with structured node names (HuggingFace, PyTorch exports).
+///
+static const char* const kOrtSessionOptionsNameBasedLayerAssignment = "session.name_based_layer_assignment";
+
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json
index b7f38815f9da5..ca2d4dfed58f7 100644
--- a/js/react_native/e2e/package-lock.json
+++ b/js/react_native/e2e/package-lock.json
@@ -11629,9 +11629,9 @@
}
},
"node_modules/shell-quote": {
- "version": "1.8.3",
- "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz",
- "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==",
+ "version": "1.8.4",
+ "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.4.tgz",
+ "integrity": "sha512-VsC6n6vz1ihYYyZZwX7YZSF5l5x36ca17OC+a69h94YqB7X6XLwf+5MOgynYir2SLFUbl8gIYvBo8K8RoNQ6bQ==",
"license": "MIT",
"engines": {
"node": ">= 0.4"
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index c26424a0ab9ec..df14bc8c57f24 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -29,9 +29,6 @@
GraphOptimizationLevel, # noqa: F401
LoraAdapter, # noqa: F401
ModelMetadata, # noqa: F401
- ModelPackageComponentContext, # noqa: F401
- ModelPackageContext, # noqa: F401
- ModelPackageOptions, # noqa: F401
NodeArg, # noqa: F401
OrtAllocatorType, # noqa: F401
OrtArenaCfg, # noqa: F401
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index d28aae02ab2f1..6e75bb1b5a7c0 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -48,7 +48,11 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences,
unique_word_ids.insert(word_id);
}
+ const int vocab_size = next_token_scores.vocab_size;
for (const int32_t word_id : unique_word_ids) {
+ if (word_id < 0 || word_id >= vocab_size) {
+ continue;
+ }
T score = beam_token_scores[word_id];
// If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability,
@@ -89,7 +93,11 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences,
}
}
+ const int vocab_size = next_token_scores.vocab_size;
for (const int32_t word_id : blocked_word_ids) {
+ if (word_id < 0 || word_id >= vocab_size) {
+ continue;
+ }
beam_token_scores[word_id] = std::numeric_limits::lowest();
}
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 6e157d83159f4..fb9638bb09d59 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -28,7 +28,9 @@ struct NextTokenScores {
}
void SetScore(int token_id, T score) {
- assert(token_id >= 0 && token_id < vocab_size);
+ if (token_id < 0 || token_id >= vocab_size) {
+ return;
+ }
for (int i = 0; i < batch_beam_size; i++) {
scores[static_cast(i) * vocab_size + token_id] = score;
}
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
index b9e62443145e5..c3b734816cf84 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
@@ -31,6 +31,8 @@ enum class QuantType {
W4_AFP8
};
+int get_arch_for_mixed_gemm_weight_preprocess(int arch);
+
void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream,
int arch,
int8_t* preprocessed_quantized_weight,
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
index a006612ddadc9..7e83bdda72eab 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
@@ -521,6 +521,19 @@ void add_bias_and_interleave_quantized_tensor_inplace_cuda(
}
}
+int get_arch_for_mixed_gemm_weight_preprocess(int arch) {
+ ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch);
+ if (arch < 80) {
+ return 75;
+ }
+#ifndef EXCLUDE_SM_90
+ if (arch >= 90 && arch < 100) {
+ return 90;
+ }
+#endif
+ return 80;
+}
+
void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream,
int arch,
int8_t* preprocessed_quantized_weight,
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
index a8fb411ed0663..47bbe0c0e10ec 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
@@ -120,11 +120,11 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) {
}
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) {
- ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch);
- if (arch < 80) {
+ arch = get_arch_for_mixed_gemm_weight_preprocess(arch);
+ if (arch == 75) {
return getLayoutDetailsForArch(quant_type);
#ifndef EXCLUDE_SM_90
- } else if (arch >= 90 && arch < 100) {
+ } else if (arch == 90) {
return getLayoutDetailsForArch(quant_type);
#endif
} else {
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc
index d5155dc5507cb..603ba2b96b81b 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc
@@ -55,6 +55,9 @@ Status MoE::ComputeInternal(OpKernelContext* context) const {
1, // no quantization so pack size is 1
is_fused_swiglu,
0)); // no block-wise quantization for regular MoE
+ ORT_RETURN_IF_NOT(k_ > 0 && k_ <= moe_params.num_experts,
+ "MoE requires 0 < k <= num_experts, got k=", k_,
+ " and num_experts=", moe_params.num_experts);
using CudaT = typename OrtToCudaType::type;
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
index e1ddcac0cea4f..8e78076288015 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
@@ -62,18 +62,28 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE
this->quant_type_ = op_kernel_info.GetAttrOrDefault("quant_type", "int");
ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8",
"quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'");
- // ``weights_prepacked`` is an optional tri-state attribute that defaults to
- // -1 (auto) in the schema, so each EP picks its own backward-compatible
- // default rather than the schema imposing one:
- // -1 (auto, also the schema default): the EP decides. The CUDA EP's
- // backward-compatible default is "prepacked" because all pre-existing
- // tooling ships CUTLASS-prepacked weights.
- // 1: initializers are already prepacked; the compute path reads them as-is.
- // 0: initializers are raw [E, N, K/pack]; the PrePack hook lays them out.
+ // ``weights_prepacked`` is an optional tri-state attribute (default -1) that
+ // declares the layout of the int4/int8 fc1/fc2 weight initializers. The
+ // concrete prepacked layouts selected by -1 and 1 are determined by the
+ // execution provider. The CUDA EP maps the tri-state as:
+ // -1 (default): already prepacked in the EP's default int weight layout.
+ // 1: already prepacked in an alternate EP-selected int weight layout.
+ // 0: raw [E, N, K/pack] initializers; the PrePack hook lays them out.
+ //
+ // Important: the CUDA QMoE int4/int8 MoE GEMM always dispatches to the
+ // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed
+ // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized
+ // specialisation (see isValidHopperMOESpecialisation). The kernel therefore
+ // consumes the SM80/Ampere CUTLASS fpA_intB layout on every GPU. As a result
+ // the EP default (-1) is the SM80 layout regardless of the runtime device SM,
+ // and SM80-format weights are valid on SM90 (they run via the SM80 kernel).
+ // For CUDA today, -1 and 1 are equivalent (both SM80 layout), and 1 is
+ // reserved for a possible future Hopper-specific layout.
+ // PrePack (weights_prepacked=0) packs for the SM80 layout accordingly.
const int64_t weights_prepacked_mode =
op_kernel_info.GetAttrOrDefault("weights_prepacked", static_cast(-1));
ORT_ENFORCE(weights_prepacked_mode == -1 || weights_prepacked_mode == 0 || weights_prepacked_mode == 1,
- "weights_prepacked must be -1 (auto), 0, or 1, but got ", weights_prepacked_mode);
+ "weights_prepacked must be -1 (default), 0, or 1, but got ", weights_prepacked_mode);
weights_prepacked_ = (weights_prepacked_mode != 0);
#if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE)
ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer.");
@@ -224,6 +234,18 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
// to the runner.
const bool int_weights_consumed_by_prepack =
is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr && packed_fc2_weights_ != nullptr;
+ // When ``weights_prepacked == 0`` the raw ``[E, N, K/pack]`` int weights must be
+ // converted to the CUTLASS fpA_intB layout by PrePack before the runner can consume
+ // them. If PrePack never ran (e.g. ``session.disable_prepacking`` is set), the prepack
+ // buffers stay null and falling through to the raw initializer pointers would feed
+ // non-CUTLASS bytes to the runner, producing silently wrong output. Fail loudly instead.
+ if (is_int && !weights_prepacked_ &&
+ (packed_fc1_weights_ == nullptr || packed_fc2_weights_ == nullptr)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "QMoE weights_prepacked=0 requires PrePack to run, but the int weight "
+ "buffers were not produced (is session.disable_prepacking set?). Provide "
+ "CUTLASS-prepacked weights with weights_prepacked=1, or enable prepacking.");
+ }
const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input(2);
const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input(3) : nullptr;
const Tensor* fc1_experts_bias_optional = context->Input(4);
@@ -295,6 +317,9 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
&fc2_shape, fc2_experts_bias_optional, fc2_scales, fc2_zeros,
nullptr, nullptr, nullptr, nullptr,
pack_size, is_fused_swiglu, block_size_));
+ ORT_RETURN_IF_NOT(k_ > 0 && k_ <= moe_params.num_experts,
+ "QMoE requires 0 < k <= num_experts, got k=", k_,
+ " and num_experts=", moe_params.num_experts);
if (uses_fp4_weight_scales) {
constexpr int64_t fp4_block_size = 32;
@@ -844,13 +869,24 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr;
const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr;
if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) {
- fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data;
- fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data;
+ // The native CUTLASS WFP4AFP8 path consumes weights in the repacked FP4
+ // layout produced by PrePack. If PrePack never ran (e.g.
+ // ``session.disable_prepacking`` is set) the repacked buffers stay null and
+ // falling through to the raw initializer bytes would feed a non-CUTLASS
+ // layout to the runner, producing silently wrong output. Fail loudly.
+ if (packed_fp4_fc1_weights_ == nullptr || packed_fp4_fc2_weights_ == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "QMoE wfp4afp8 requires PrePack to run, but the repacked FP4 weight "
+ "buffers were not produced (is session.disable_prepacking set?). "
+ "Enable prepacking to use the native WFP4AFP8 path.");
+ }
+ fc1_weight_data = packed_fp4_fc1_weights_.get();
+ fc2_weight_data = packed_fp4_fc2_weights_.get();
} else if (int_weights_consumed_by_prepack) {
// PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB
// layout that the runner consumes and freed the source initializer
// (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack``
- // (which already requires ``packed_fc1_weights_ != nullptr``) rather than
+ // (which already requires both packed weight buffers) rather than
// just ``is_int && !weights_prepacked_``: when prepacking is disabled at
// the session level (``session.disable_prepacking``) PrePack never runs,
// the prepack buffers stay null, and the raw initializer pointers read
@@ -1146,6 +1182,9 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al
IAllocatorUniquePtr& packed_buf, bool& is_packed) {
ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8,
"PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_);
+ ORT_ENFORCE(sm_ >= 75,
+ "PrePackIntExpertWeights: quant_type='int' with weights_prepacked=0 requires SM75+ CUDA hardware, got SM",
+ sm_);
const auto& shape = tensor.Shape();
ORT_ENFORCE(shape.NumDimensions() == 3,
"PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=",
@@ -1158,22 +1197,15 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al
const int64_t k_packed = shape[2];
const int64_t k = k_packed * pack_factor;
- // Weight packing is architecture-aware (see
- // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing
- // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that
- // skips column interleaving, so it is its own compatibility group. Every
- // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares
- // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack
- // INT8 LDSM and are unsupported. The compute-side runner selects the same
- // layout from this clamped arch, so the two cannot drift.
- //
- // SM75 is passed through unchanged (rather than clamped to 80) even though it
- // shares SM80's layout: the compute-side dispatch (getLayoutDetailsForTransform)
- // still has a distinct SM75 branch, so mirroring it here avoids confusing a
- // reader into thinking prepack and dispatch disagree.
- ORT_ENFORCE(sm_ >= 75,
- "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_);
- const int packing_sm = (sm_ == 90 || sm_ == 75) ? sm_ : 80;
+ // The CUDA QMoE int4/int8 MoE GEMM always dispatches to the Ampere (SM80)
+ // grouped-GEMM kernel -- even on SM90 -- because mixed int-weight + fp16/bf16
+ // is not a valid Hopper TMA warp-specialized specialisation. The kernel thus
+ // consumes the SM80 CUTLASS fpA_intB layout on every GPU, so the weights must
+ // always be preprocessed for SM80 regardless of the runtime device SM.
+ // (Using get_arch_for_mixed_gemm_weight_preprocess(sm_) here would emit the
+ // SM90 layout on Hopper, which the SM80 kernel cannot consume -> wrong output.)
+ const int packing_sm =
+ onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(80);
// Per-expert sizes.
const size_t per_expert_bytes = static_cast(n) * static_cast(k) / pack_factor;
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
index 5722ac41cc470..2bbadc205b5d8 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
@@ -46,16 +46,23 @@ class QMoE final : public CudaKernel, public MoEBase {
IAllocatorUniquePtr& packed_buf, bool& is_packed);
int64_t expert_weight_bits_;
bool is_fp16_;
- // When true (the schema default), the int4/int8 fc1/fc2 weight
- // initializers are already in the CUTLASS fpA_intB layout — produced
- // offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the
- // compute path reads them as-is. When false, the raw schema-conformant
- // ``[E, N, K/pack]`` layout (as produced by
- // ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook
- // via ``PrePackIntExpertWeights``, removing the offline prepack
- // dependency. Only meaningful when ``quant_type_ == "int"``. Derived from
- // the optional tri-state ``weights_prepacked`` attribute: -1/auto (or
- // absent) maps to true on the CUDA EP, 1 maps to true, 0 maps to false.
+ // When true, the int4/int8 fc1/fc2 weight initializers are already in a
+ // CUTLASS fpA_intB layout — produced offline e.g. via
+ // ``pack_weights_for_cuda_mixed_gemm`` — and the compute path reads them
+ // as-is. When false, the raw schema-conformant ``[E, N, K/pack]`` layout
+ // (as produced by ``quantize_matmul_{4,8}bits``) is rewritten inside the
+ // PrePack hook via ``PrePackIntExpertWeights``, removing the offline
+ // prepack dependency. Only meaningful when ``quant_type_ == "int"``.
+ // Derived from the optional tri-state ``weights_prepacked`` attribute:
+ // -1 (default) and 1 both map to true; 0 maps to false. The concrete
+ // prepacked layouts selected by -1 and 1 are determined by the execution
+ // provider. For the CUDA EP the int4/int8 MoE GEMM always dispatches to the
+ // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed
+ // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized
+ // specialisation (matches TensorRT-LLM, which also routes W4A16/W8A16 MoE to
+ // the SM80 kernel on Hopper). The kernel therefore consumes the SM80 fpA_intB
+ // layout on every GPU, so -1 and 1 are currently equivalent for the CUDA EP;
+ // 1 is reserved for a possible future Hopper-specific layout (e.g. W4A8).
bool weights_prepacked_ = true;
// Cached source weight shapes captured at PrePack time. When the
// PrePack hook consumed and released the original int4/int8 weight
diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu
index 61cdf3ab23fca..43420c04b1a8e 100644
--- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu
+++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu
@@ -6,6 +6,7 @@
#include "core/common/narrow.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/cub.cuh"
+#include "core/providers/cuda/cu_inc/topk_warp_sort.cuh"
#include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h"
#include
#include
@@ -19,6 +20,56 @@ int Compute1DGridSize(int num_elements, int block_size) {
return (num_elements + block_size - 1) / block_size;
}
+constexpr float kTopKNormalizeEpsilon = 1e-6f;
+
+__device__ __forceinline__ float SoftmaxScale(float logit, float max_val, float inv_sum) {
+ return (inv_sum > 0.0f) ? expf(logit - max_val) * inv_sum : 0.0f;
+}
+
+__device__ __forceinline__ float SafeInvSum(float sum) {
+ return (sum > 0.0f) ? (1.0f / sum) : 0.0f;
+}
+
+__device__ __forceinline__ float TopKNormalizeDenom(bool normalize_scales, float scale_sum) {
+ return (normalize_scales && scale_sum > kTopKNormalizeEpsilon) ? scale_sum : 1.0f;
+}
+
+__device__ __forceinline__ float WarpReduceMax(float value) {
+ constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize;
+#pragma unroll
+ for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+ value = fmaxf(value, __shfl_xor_sync(0xFFFFFFFF, value, offset));
+ }
+ return value;
+}
+
+__device__ __forceinline__ float WarpReduceSum(float value) {
+ constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize;
+#pragma unroll
+ for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
+ value += __shfl_xor_sync(0xFFFFFFFF, value, offset);
+ }
+ return value;
+}
+
+template
+__device__ __forceinline__ float BlockReduceMax(float value, typename BlockReduce::TempStorage& temp_storage) {
+#if CUDART_VERSION >= 12090
+ return BlockReduce(temp_storage).Reduce(value, ::cuda::maximum());
+#else
+ return BlockReduce(temp_storage).Reduce(value, cub::Max());
+#endif
+}
+
+template
+__device__ __forceinline__ float BlockReduceSum(float value, typename BlockReduce::TempStorage& temp_storage) {
+#if CUDART_VERSION >= 12090
+ return BlockReduce(temp_storage).Reduce(value, ::cuda::std::plus());
+#else
+ return BlockReduce(temp_storage).Reduce(value, cub::Sum());
+#endif
+}
+
template
__global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk_indices,
int num_rows, int num_experts, int k, bool normalize_scales) {
@@ -30,7 +81,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk
int* row_indices = topk_indices + row * k;
// 1. Find max for numerical stability
- float max_val = -FLT_MAX;
+ float max_val = onnxruntime::cuda::topk::kNegativeInfinity;
for (int i = 0; i < num_experts; ++i) {
float val = static_cast(row_logits[i]);
if (val > max_val) max_val = val;
@@ -41,6 +92,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk
for (int i = 0; i < num_experts; ++i) {
sum_exp += expf(static_cast(row_logits[i]) - max_val);
}
+ const float inv_sum = SafeInvSum(sum_exp);
// 3. Compute Softmax and find TopK
// For small k, we can do a simple selection.
@@ -56,7 +108,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk
}
for (int i = 0; i < num_experts; ++i) {
- float prob = expf(static_cast(row_logits[i]) - max_val) / sum_exp;
+ float prob = SoftmaxScale(static_cast(row_logits[i]), max_val, inv_sum);
// Insert into top-k logic
// Simple insertion sort for very small k (e.g. k=2)
@@ -80,7 +132,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk
for (int i = 0; i < k; ++i) {
scale_sum += row_scales[i];
}
- if (scale_sum > 1e-6f) {
+ if (scale_sum > kTopKNormalizeEpsilon) {
for (int i = 0; i < k; ++i) {
row_scales[i] /= scale_sum;
}
@@ -88,6 +140,258 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk
}
}
+// Block-per-row softmax + top-k using a CUB block sort. Each block sorts one
+// row's logits (descending) and reads the first k. A full sort of 256 logits is
+// ~2.5x faster than k rounds of block-argmax on this size (benchmarked), and is
+// the layout onnxruntime-genai's top-k benchmarks also recommend (CUB block
+// merge) for sort sizes up to ~1024. The capacity (kBlockSize*kItemsPerThread)
+// must be >= num_experts; padding lanes carry (-inf, INT_MAX) so valid -inf
+// expert scores sort ahead of padding. Tie-breaking matches the scalar kernel
+// (lower expert index first) via the same packed stable sort key used by the
+// warp merge path.
+
+template
+__global__ void SoftmaxTopKMergeKernel(const T* logits, float* topk_scales, int* topk_indices,
+ int num_rows, int num_experts, int k, bool normalize_scales) {
+ const int row = blockIdx.x;
+ if (row >= num_rows) return;
+
+ const T* row_logits = logits + static_cast(row) * num_experts;
+ const int tid = threadIdx.x;
+
+ using BlockMergeSort = cub::BlockMergeSort;
+ using BlockReduce = cub::BlockReduce;
+ __shared__ union {
+ typename BlockMergeSort::TempStorage merge;
+ typename BlockReduce::TempStorage reduce;
+ } temp;
+ __shared__ float s_topk[64]; // k <= 64
+ __shared__ float s_max;
+ __shared__ float s_sum;
+
+ // Load this thread's packed (logit, expert index) keys in a blocked
+ // arrangement: thread t owns indices [t*ipt, t*ipt+ipt).
+ uint64_t keys[kItemsPerThread];
+ float local_max = onnxruntime::cuda::topk::kNegativeInfinity;
+#pragma unroll
+ for (int j = 0; j < kItemsPerThread; ++j) {
+ const int idx = tid * kItemsPerThread + j;
+ const float logit = (idx < num_experts) ? static_cast(row_logits[idx])
+ : onnxruntime::cuda::topk::kNegativeInfinity;
+ const int index = (idx < num_experts) ? idx : INT_MAX;
+ keys[j] = onnxruntime::cuda::topk::PackStableSortKey(logit, index);
+ local_max = fmaxf(local_max, logit);
+ }
+
+ // Softmax denominator over all experts (needed when normalize_scales is false;
+ // when true it cancels in the top-k normalization but is still correct).
+ const float block_max = BlockReduceMax(local_max, temp.reduce);
+ if (tid == 0) s_max = block_max;
+ // Single barrier: publishes s_max to all threads and also separates the two
+ // BlockReduce uses that share temp.reduce.
+ __syncthreads();
+ const float max_val = s_max;
+
+ float local_sum = 0.0f;
+#pragma unroll
+ for (int j = 0; j < kItemsPerThread; ++j) {
+ const int idx = tid * kItemsPerThread + j;
+ if (idx < num_experts) {
+ local_sum += expf(onnxruntime::cuda::topk::UnpackStableSortScore(keys[j]) - max_val);
+ }
+ }
+ const float block_sum = BlockReduceSum(local_sum, temp.reduce);
+ if (tid == 0) s_sum = block_sum;
+ // Single barrier: publishes s_sum and separates temp.reduce from temp.merge.
+ __syncthreads();
+ const float inv_sum = SafeInvSum(s_sum);
+
+ // Sort packed (logit, index) keys descending. Result stays in a blocked
+ // layout, so sorted rank r lives in thread (r / ipt), item (r % ipt). Sort()
+ // leaves the sorted keys in each thread's registers and temp.merge is not
+ // reused afterwards, so no barrier is needed here; the shared s_topk writes
+ // below are published by the barrier that follows them.
+ BlockMergeSort(temp.merge).Sort(keys, onnxruntime::cuda::topk::Greater());
+
+#pragma unroll
+ for (int j = 0; j < kItemsPerThread; ++j) {
+ const int rank = tid * kItemsPerThread + j;
+ if (rank < k) {
+ const uint64_t key = keys[j];
+ topk_indices[static_cast(row) * k + rank] =
+ onnxruntime::cuda::topk::UnpackStableSortIndex(key);
+ s_topk[rank] = SoftmaxScale(onnxruntime::cuda::topk::UnpackStableSortScore(key), max_val, inv_sum);
+ }
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ if (normalize_scales) {
+ float scale_sum = 0.0f;
+ for (int i = 0; i < k; ++i) scale_sum += s_topk[i];
+ const float denom = TopKNormalizeDenom(normalize_scales, scale_sum);
+ for (int i = 0; i < k; ++i) topk_scales[static_cast(row) * k + i] = s_topk[i] / denom;
+ } else {
+ for (int i = 0; i < k; ++i) topk_scales[static_cast(row) * k + i] = s_topk[i];
+ }
+ }
+}
+
+// Warp-bitonic softmax + top-k for num_experts <= 32. Each warp handles one
+// row, with lane `l` owning expert `l`. The whole softmax reduction and the
+// sort are done with warp shuffles (no shared memory). This is the fastest path
+// for tiny expert counts per the onnxruntime-genai top-k benchmark. Tie-breaking
+// (equal scores prefer the lower expert index) matches SoftmaxTopKMergeKernel.
+template
+__global__ void SoftmaxTopKWarpBitonicKernel(const T* logits, float* topk_scales, int* topk_indices,
+ int num_rows, int num_experts, int k, bool normalize_scales) {
+ const int lane = threadIdx.x;
+ const int row = blockIdx.x * kWarpsPerBlock + threadIdx.y;
+ if (row >= num_rows) return;
+
+ const T* row_logits = logits + static_cast(row) * num_experts;
+ const float logit = (lane < num_experts) ? static_cast(row_logits[lane])
+ : onnxruntime::cuda::topk::kNegativeInfinity;
+
+ const float max_val = WarpReduceMax(logit);
+
+ // Warp-wide exp sum (softmax denominator over all experts).
+ const float sum_exp = WarpReduceSum((lane < num_experts) ? expf(logit - max_val) : 0.0f);
+ const float inv_sum = SafeInvSum(sum_exp);
+
+ // Sort (logit, expert index) descending; sorting by logit is equivalent to
+ // sorting by softmax probability since the mapping is monotonic.
+ float score = logit;
+ int index = (lane < num_experts) ? lane : INT_MAX;
+ onnxruntime::cuda::topk::WarpBitonicSortDescending(score, index);
+
+ // Lane r now holds the rank-r element. Compute the top-k probabilities.
+ float prob = (lane < k) ? SoftmaxScale(score, max_val, inv_sum) : 0.0f;
+
+ if (normalize_scales) {
+ prob /= TopKNormalizeDenom(normalize_scales, WarpReduceSum(prob));
+ }
+
+ if (lane < k) {
+ topk_scales[static_cast(row) * k + lane] = prob;
+ topk_indices[static_cast(row) * k + lane] = index;
+ }
+}
+
+// Warp CUB merge sort softmax + top-k for num_experts <= kBufferSize (64). One
+// warp (32 threads) per block sorts a row's logits held in shared memory. This
+// is the genai-recommended path for sort sizes in (32, 64]. Tie-breaking
+// matches SoftmaxTopKMergeKernel via a packed stable sort key.
+template
+__global__ void SoftmaxTopKWarpMergeKernel(const T* logits, float* topk_scales, int* topk_indices,
+ int num_rows, int num_experts, int k, bool normalize_scales) {
+ constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize;
+ using WarpMergeSorter = onnxruntime::cuda::topk::WarpMergeSorter;
+
+ const int row = blockIdx.x;
+ if (row >= num_rows) return;
+ const int lane = threadIdx.x;
+
+ __shared__ float s_scores[kBufferSize];
+ __shared__ int s_indices[kBufferSize];
+ __shared__ typename WarpMergeSorter::TempStorage temp_storage;
+
+ const T* row_logits = logits + static_cast(row) * num_experts;
+
+ // Load logits into shared memory and compute the warp-wide max.
+ float local_max = onnxruntime::cuda::topk::kNegativeInfinity;
+ for (int i = lane; i < kBufferSize; i += kWarpSize) {
+ const float v = (i < num_experts) ? static_cast(row_logits[i])
+ : onnxruntime::cuda::topk::kNegativeInfinity;
+ s_scores[i] = v;
+ s_indices[i] = (i < num_experts) ? i : INT_MAX;
+ local_max = fmaxf(local_max, v);
+ }
+ const float max_val = WarpReduceMax(local_max);
+
+ // Warp-wide exp sum over all experts.
+ float local_sum = 0.0f;
+ for (int i = lane; i < num_experts; i += kWarpSize) {
+ local_sum += expf(s_scores[i] - max_val);
+ }
+ const float inv_sum = SafeInvSum(WarpReduceSum(local_sum));
+
+ __syncwarp();
+ WarpMergeSorter::Sort(s_scores, s_indices, temp_storage, num_experts);
+ __syncwarp();
+
+ // s_scores[r]/s_indices[r] now hold the rank-r logit/expert index.
+ float scale_sum = 0.0f;
+ if (normalize_scales) {
+ for (int i = lane; i < k; i += kWarpSize) {
+ scale_sum += SoftmaxScale(s_scores[i], max_val, inv_sum);
+ }
+ scale_sum = WarpReduceSum(scale_sum);
+ }
+ const float denom = TopKNormalizeDenom(normalize_scales, scale_sum);
+
+ for (int i = lane; i < k; i += kWarpSize) {
+ const float prob = SoftmaxScale(s_scores[i], max_val, inv_sum);
+ topk_scales[static_cast(row) * k + i] = normalize_scales ? (prob / denom) : prob;
+ topk_indices[static_cast(row) * k + i] = s_indices[i];
+ }
+}
+
+template
+void DispatchSoftmaxTopK(const T* logits, float* topk_scales, int* topk_indices,
+ int num_rows, int num_experts, int k, bool normalize_scales,
+ cudaStream_t stream) {
+ ORT_ENFORCE(k > 0 && k <= num_experts,
+ "SoftmaxTopK requires 0 < k <= num_experts, got k=", k,
+ " and num_experts=", num_experts);
+
+ // Block-per-row CUB merge sort is the fastest path for the common decode case
+ // (one block fully sorts a row). Pick the smallest capacity that covers
+ // num_experts. k must fit in s_topk (<= 64).
+ const dim3 grid(static_cast(num_rows));
+ if (k <= 64 && num_experts <= 1024) {
+ // Tiny expert counts: a single warp sorts a row entirely in registers via
+ // an in-register bitonic sort (no shared memory). Multiple warps per block
+ // process multiple rows for better occupancy.
+ if (num_experts <= onnxruntime::cuda::topk::kWarpBitonicMaxSize) {
+ constexpr int kWarpsPerBlock = 8;
+ const dim3 block(static_cast(onnxruntime::cuda::topk::kWarpSize), kWarpsPerBlock);
+ const dim3 bitonic_grid(static_cast((num_rows + kWarpsPerBlock - 1) / kWarpsPerBlock));
+ SoftmaxTopKWarpBitonicKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ } else if (num_experts <= onnxruntime::cuda::topk::kWarpMergeMaxSize) {
+ // Single warp per row sorts up to 64 logits in shared memory (CUB warp
+ // merge sort), the genai-recommended path for sort sizes in (32, 64].
+ SoftmaxTopKWarpMergeKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ } else if (num_experts <= 128) {
+ SoftmaxTopKMergeKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ } else if (num_experts <= 256) {
+ SoftmaxTopKMergeKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ } else if (num_experts <= 512) {
+ SoftmaxTopKMergeKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ } else /*if (num_experts <= 1024)*/ {
+ SoftmaxTopKMergeKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ return;
+ }
+ } else {
+ // Fall back to the simple one-thread-per-row kernel.
+ const int block = 256;
+ const int grid_1d = Compute1DGridSize(num_rows, block);
+ SoftmaxTopKKernel<<>>(
+ logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ }
+}
+
void LaunchSoftmaxTopK(
const float* logits,
float* topk_scales,
@@ -97,9 +401,7 @@ void LaunchSoftmaxTopK(
int k,
bool normalize_scales,
cudaStream_t stream) {
- int block = 256;
- int grid = Compute1DGridSize(num_rows, block);
- SoftmaxTopKKernel<<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ DispatchSoftmaxTopK(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream);
}
void LaunchSoftmaxTopK(
@@ -111,9 +413,7 @@ void LaunchSoftmaxTopK(
int k,
bool normalize_scales,
cudaStream_t stream) {
- int block = 256;
- int grid = Compute1DGridSize(num_rows, block);
- SoftmaxTopKKernel<<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ DispatchSoftmaxTopK(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream);
}
void LaunchSoftmaxTopK(
@@ -125,9 +425,7 @@ void LaunchSoftmaxTopK(
int k,
bool normalize_scales,
cudaStream_t stream) {
- int block = 256;
- int grid = Compute1DGridSize(num_rows, block);
- SoftmaxTopKKernel<__nv_bfloat16><<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales);
+ DispatchSoftmaxTopK<__nv_bfloat16>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream);
}
template
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 684e050f0201a..02e764d01e05e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -16,6 +16,40 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {
+// WGSL helper function for normalizing on-device indirect dispatch dims.
+// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram.
+// Mirrors ProgramManager::NormalizeDispatchGroupSize three tiers:
+// 1) direct (x, y, z) write when every dim is within the spec limit (65535);
+// 2) 2D sqrt collapse when the product fits a square layout;
+// 3) 3D cbrt collapse otherwise.
+// Consumers are unaffected by the chosen layout: ShaderHelper flattens
+// workgroup_id (x, y, z) into a single linear workgroup_idx.
+// Caller contract: must register a storage output named exactly
+// `indirect_buffer` of array with at least 3 elements.
+constexpr const char kNormalizeDispatchGroupSizeFn[] = R"(
+fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) {
+ let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension
+ if (x <= limit && y <= limit && z <= limit) {
+ indirect_buffer[0] = x;
+ indirect_buffer[1] = y;
+ indirect_buffer[2] = z;
+ return;
+ }
+ let size = f32(x) * f32(y) * f32(z);
+ let dispatch_avg_2d = u32(ceil(sqrt(size)));
+ if (dispatch_avg_2d <= limit) {
+ indirect_buffer[0] = dispatch_avg_2d;
+ indirect_buffer[1] = dispatch_avg_2d;
+ indirect_buffer[2] = 1u;
+ return;
+ }
+ let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0)));
+ indirect_buffer[0] = dispatch_avg_3d;
+ indirect_buffer[1] = dispatch_avg_3d;
+ indirect_buffer[2] = dispatch_avg_3d;
+}
+)";
+
Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
@@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha
if (prepare_indirect_dispatch_) {
sh.AddOutput("indirect_buffer", ShaderUsage::None);
+ sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn;
}
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
@@ -87,13 +122,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add indirect dispatch logic for thread 0
if (prepare_indirect_dispatch_) {
- // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size.
- shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n"
- << " if (global_idx == 0u) {\n"
+ shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn;
+ shader.MainFunctionBody() << " if (global_idx == 0u) {\n"
<< " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n"
- << " indirect_buffer[0] = num_total_seq_length_tile;\n"
- << " indirect_buffer[1] = uniforms.num_heads;\n"
- << " indirect_buffer[2] = 1u;\n"
+ << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n"
<< " }\n\n";
}
@@ -120,7 +152,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
const Tensor* K, const Tensor* past_key, Tensor* present_key,
const Tensor* V, const Tensor* past_value, Tensor* present_value,
- uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) {
+ uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) {
// CopyKVCache takes past key/value and current key/value and copies them to present key and value.
// This makes it so that FlashAttention only needs to look at present key and value, and saves
// number of input buffers in the shader, which we run out of (<=8) without this optimization.
@@ -176,7 +208,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
{static_cast(parameters.total_sequence_length_)},
{static_cast(parameters.kv_sequence_length_)},
{tile_size},
- {static_cast(parameters.num_heads_)}});
+ {static_cast(parameters.num_heads_)},
+ {static_cast(parameters.batch_size_)},
+ {num_q_tiles}});
return context.RunProgram(program);
}
@@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_));
}
-Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
- shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
- shader.AddOutput("output", ShaderUsage::UseUniform);
- shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
+ const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const uint32_t tile_size_k_vec = 8;
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
- return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template",
+ return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
+ WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
+ WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_),
+ WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
+ WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_),
+ WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_),
+ WGSL_TEMPLATE_VARIABLE(metadata, metadata),
+ WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx),
+ WGSL_TEMPLATE_VARIABLE(present_key, present_key),
+ WGSL_TEMPLATE_VARIABLE(present_value, present_value),
+ WGSL_TEMPLATE_VARIABLE(q, q));
}
-Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
- const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
- const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
+Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
+ const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value,
+ Tensor* metadata, const Tensor* seqlen_k,
+ const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) {
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_))
: parameters.scale_;
const bool has_attention_bias = attention_bias != nullptr;
const int components = 4;
+ const int head_size_vec = parameters.v_head_size_ / components;
- FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
+ bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
+ bool is_unidirectional = parameters.is_unidirectional_;
+ FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
- {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+ {present_key, ProgramTensorMetadataDependency::TypeAndRank, components},
+ {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
- program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank},
+ program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
{metadata, ProgramTensorMetadataDependency::Rank, 2}});
const uint32_t vectorized_head_size = parameters.head_size_ / components;
- // Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
@@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
if (use_indirect_dispatch) {
program.SetIndirectDispatchTensor(indirect_buffer);
} else {
- program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
+ program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile);
}
program.SetWorkgroupSize(64)
- .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
+ .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile)
.AddUniformVariables({{static_cast(vectorized_head_size)},
{static_cast(parameters.total_sequence_length_)},
{static_cast(alpha)},
@@ -294,124 +342,72 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
{static_cast(parameters.num_heads_)},
{static_cast(parameters.batch_size_)},
{attn_bias_dim0},
- {attn_bias_dim1}});
+ {attn_bias_dim1},
+ {static_cast(parameters.sequence_length_)}});
return context.RunProgram(program);
}
-Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("metadata", ShaderUsage::UseUniform);
- shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
- shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
+ const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform);
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_head_sink_) {
shader.AddInput("head_sink", ShaderUsage::UseUniform);
}
- shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
-
- const uint32_t tile_size_k_vec = 8u;
-
- return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template",
- WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
- WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
- WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
- WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
-}
-
-Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
- const Tensor* metadata,
- const Tensor* qk,
- Tensor* out_split_vx,
- Tensor* present_value,
- const Tensor* seqlen_k,
- const WebgpuAttentionParameters& parameters,
- const Tensor* indirect_buffer,
- uint32_t num_total_seq_length_tile,
- uint32_t num_present_sequence_length_tile,
- uint32_t tile_size,
- bool use_indirect_dispatch,
- uint32_t present_sequence_length,
- const Tensor* head_sink) {
- const int components = 4;
- const bool has_head_sink = head_sink != nullptr;
- int head_size_vec = parameters.v_head_size_ / components;
- FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink};
- program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
- {qk, ProgramTensorMetadataDependency::TypeAndRank},
- {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
- program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size]
- const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_);
- if (use_indirect_dispatch) {
- program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
- }
- if (has_head_sink) {
- program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
- }
- // SetIndirectDispatchTensor must be called after all AddInput calls because it
- // appends the indirect buffer as the last program input.
- if (use_indirect_dispatch) {
- program.SetIndirectDispatchTensor(indirect_buffer);
- } else {
- program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
- }
- program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
- .SetWorkgroupSize(64)
- .AddUniformVariables({{static_cast(parameters.total_sequence_length_)},
- {static_cast(head_size_vec)},
- present_sequence_length,
- {static_cast(parameters.n_reps)},
- num_present_sequence_length_tile,
- {batch_heads},
- {static_cast(parameters.num_heads_)}});
-
- return context.RunProgram(program);
-}
-
-Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("input", ShaderUsage::UseUniform);
- if (use_indirect_dispatch_) {
- shader.AddInput("seqlens_k", ShaderUsage::None);
- }
- shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
+ WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
+ WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_),
WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
+ WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_),
+ WGSL_TEMPLATE_VARIABLE(input, input),
+ WGSL_TEMPLATE_VARIABLE(metadata, metadata),
+ WGSL_TEMPLATE_VARIABLE(output, output));
}
Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
const Tensor* out_split_vx,
+ const Tensor* metadata,
Tensor* output,
const Tensor* seqlen_k,
const WebgpuAttentionParameters& parameters,
uint32_t num_total_seq_length_tile,
uint32_t num_present_sequence_length_tile,
uint32_t seq_tile_size,
- bool use_indirect_dispatch) {
+ bool use_indirect_dispatch,
+ const Tensor* head_sink,
+ uint32_t m_tile) {
const int components = 4;
constexpr int tile_size = 8;
int tile_head_size = tile_size * components;
- FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch};
- program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
+ bool has_head_sink = head_sink != nullptr;
+ FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile};
+ program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
+ {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}});
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}
+ if (has_head_sink) {
+ program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
+ }
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_);
- program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
- .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
+ program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile)
+ .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile)
.SetWorkgroupSize(tile_size * tile_size)
.AddUniformVariables({{static_cast(parameters.v_head_size_ / components)},
num_total_seq_length_tile,
num_present_sequence_length_tile,
{num_head_size_tile},
- {batch_heads}});
+ {batch_heads},
+ {static_cast(parameters.sequence_length_)},
+ {static_cast(parameters.num_heads_)}});
return context.RunProgram(program);
}
@@ -446,14 +442,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
// Declare query_output at function scope to ensure it persists throughout the function
Tensor query_output;
+ // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch.
+ const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u);
+ const uint32_t num_q_tiles = (static_cast(parameters.sequence_length_) + m_tile - 1u) / m_tile;
+
// Create indirect dispatch buffer if using indirect dispatch
Tensor* indirect_buffer_ptr = nullptr;
Tensor indirect_buffer;
- // Prepare indirect dispatch buffer for decode path with static KV cache
- const bool use_indirect_dispatch = !kv_empty &&
- parameters.sequence_length_ == 1 &&
- parameters.past_present_share_buffer_ &&
+ // Prepare indirect dispatch buffer for split-reduce path with static KV cache.
+ // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based
+ // seqlen_k), so the indirect buffer computes dispatch sizes on GPU.
+ const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
seqlen_k != nullptr &&
context.IsGraphCaptureEnabled();
if (use_indirect_dispatch) {
@@ -492,10 +492,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
Q, seqlen_k,
cos_cache, sin_cache,
&query_output, present_key, present_value,
- indirect_buffer_ptr, tile_size));
+ indirect_buffer_ptr, tile_size, num_q_tiles));
Q = &query_output;
} else {
- ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr));
+ ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles));
}
// Extract present_sequence_length directly from present_key tensor shape
@@ -503,7 +503,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]);
- if (parameters.sequence_length_ > 1) {
+ // Route between prefill path (FlashAttentionProgram, single kernel)
+ // and split-reduce decode path (QKV + VxReduce, 2 kernels).
+ // Split-reduce wins for short Q (sequence_length < 32) across all KV
+ // cache lengths measured: 1.13x-2.07x faster at total_sequence_length
+ // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96).
+ const bool use_split_reduce = parameters.sequence_length_ < 32;
+
+ if (!use_split_reduce) {
+ // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles)
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -545,7 +553,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
const uint32_t prefill_tile_size = is_apple ? 128 : tile_size;
const uint32_t num_seq_tile = (parameters.sequence_length_ + prefill_tile_size - 1) / prefill_tile_size;
- // Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
@@ -570,37 +577,31 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
return context.RunProgram(program);
}
- // For decode path (sequence_length == 1)
- const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_,
- parameters.sequence_length_, present_sequence_length});
- const TensorShape qk_shape(qk_dims);
- Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
+ // Split-reduce path (QKV + VxReduce)
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
// The metadata is used to store the max and sum of each tile.
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
- num_present_sequence_length_tile, 2});
+ parameters.sequence_length_, num_present_sequence_length_tile, 2});
const TensorShape metadata_shape(metadata_dims);
Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape);
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k,
- parameters, indirect_buffer_ptr, num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
- present_sequence_length));
const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_,
- num_present_sequence_length_tile, parameters.head_size_});
+ parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_});
const TensorShape out_split_vx_shape(out_split_vx_dims);
Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value,
- seqlen_k, parameters, indirect_buffer_ptr,
- num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size,
- use_indirect_dispatch, present_sequence_length,
- head_sink));
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
+
+ ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value,
+ &metadata, seqlen_k,
+ parameters, indirect_buffer_ptr, num_total_seq_length_tile,
+ num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
+ present_sequence_length, m_tile));
+
+ ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters,
num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size, use_indirect_dispatch));
+ num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
+ head_sink, m_tile));
return Status::OK();
}
@@ -621,7 +622,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer,
- uint32_t tile_size) {
+ uint32_t tile_size, uint32_t num_q_tiles) {
const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;
@@ -678,6 +679,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
{present_sequence_length},
{tile_size},
{static_cast(dispatch_size)},
+ {static_cast(params.batch_size_)},
+ {num_q_tiles},
});
program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index e75b6378f67c6..3da6b33b4dc0e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -35,7 +35,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program {
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"tile_size", ProgramUniformVariableDataType::Uint32},
- {"num_heads", ProgramUniformVariableDataType::Uint32});
+ {"num_heads", ProgramUniformVariableDataType::Uint32},
+ {"batch_size", ProgramUniformVariableDataType::Uint32},
+ {"num_q_tiles", ProgramUniformVariableDataType::Uint32});
private:
bool has_past_;
@@ -138,11 +142,14 @@ class FlashAttentionProgram final : public Program {
int max_k_step_;
};
-class FlashAttentionDecodeQKTProgram final : public Program {
+class FlashAttentionDecodeQKVProgram final : public Program {
public:
- FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
- bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
- : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
+ FlashAttentionDecodeQKVProgram(const std::string& kernel_name,
+ bool has_attention_bias, uint32_t tile_size, int head_size_vec,
+ bool use_indirect_dispatch, bool q_BNSH = false,
+ bool is_unidirectional = false,
+ uint32_t m_tile = 1)
+ : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -156,41 +163,23 @@ class FlashAttentionDecodeQKTProgram final : public Program {
- public:
- FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
- : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
- }
-
- Status GenerateShaderCode(ShaderHelper& sh) const override;
-
- WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32},
- {"head_size_vec", ProgramUniformVariableDataType::Uint32},
- {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
- {"n_reps", ProgramUniformVariableDataType::Uint32},
- {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
- {"batch_heads", ProgramUniformVariableDataType::Uint32},
- {"num_heads", ProgramUniformVariableDataType::Uint32});
-
- private:
uint32_t tile_size_;
int head_size_vec_;
bool use_indirect_dispatch_;
- bool has_head_sink_;
+ bool q_BNSH_;
+ bool is_unidirectional_;
+ uint32_t m_tile_;
};
class FlashAttentionDecodeVxReduceProgram final : public Program {
public:
- FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
- : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
+ FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1)
+ : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -199,12 +188,16 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array;
-var inner_qk_values: array, tile_size>;
-var tile_qk: array;
-
-#if has_attention_bias
- fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
- {
- // Handle broadcasting: if dimension size is 1, use index 0
- let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
- let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
-
- // Calculate flat offset with broadcasting applied
- // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length]
- // For decode, new_seq_length is 1, so we can simplify:
- let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length +
- bias_head_idx * total_seq_length +
- k_idx;
- return attention_bias[offset];
- }
-#else
- fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
- {
- return q_element_t(0);
- }
-#endif
-
-$MAIN {
- let local_row = u32(local_idx / tile_size_k_vec);
- let local_col = local_idx % tile_size_k_vec;
-#if use_indirect_dispatch
- let total_sequence_length = u32(seqlens_k[0]) + 1u;
-#else
- let total_sequence_length = uniforms.total_sequence_length;
-#endif
- let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
- let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
- let head_idx = batch_head_idx % uniforms.num_heads;
- let batch_idx = batch_head_idx / uniforms.num_heads;
- if (batch_idx >= uniforms.batch_size) {
- return;
- }
- let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec;
- let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
- for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
- if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
- tile_q[local_idx] = q[q_offset + k + local_idx];
- }
- workgroupBarrier();
- let q_data = tile_q[local_col] * q_element_t(uniforms.alpha);
- if (k + local_col < uniforms.head_size_vec) {
- for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
- if (total_seq_offset + row_offset + local_row < total_sequence_length) {
- inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data);
- }
- }
- }
- workgroupBarrier();
- }
-
- if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
- var sum = q_element_t(0);
- for (var i = 0u; i < tile_size_k_vec; i++) {
- sum += inner_qk_values[local_idx][i];
- }
-
- sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length);
- tile_qk[local_idx] = sum;
- output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
- }
- workgroupBarrier();
-
- if (local_idx == 0u) {
- // Calculate the max and sum in current split.
- var l_max = f32(-3.4028234663852886e+38f);
- var l_sum = f32(0);
- for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
- l_max = max(l_max, f32(tile_qk[i]));
- }
- for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
- l_sum += exp(f32(tile_qk[i]) - l_max);
- }
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
- metadata[meta_offset] = metadata_value_t(l_max, l_sum);
- }
-}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
new file mode 100644
index 0000000000000..524a18ca43245
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
@@ -0,0 +1,197 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param has_attention_bias
+#param v_head_size_vec
+#param is_unidirectional
+#param m_tile
+#param q_BNSH
+#param sub_tile_count
+#param tile_size
+#param tile_size_k_vec
+#param use_indirect_dispatch
+
+#use .getByOffset .setByOffset
+
+// Fused QK^T + softmax + V multiply shader.
+//
+// Each workgroup processes one KV tile (tile_size rows of present_key/value)
+// for m_tile Q rows. The computation has two phases:
+//
+// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask,
+// per-tile max/sum for online softmax)
+// Phase 2: Local softmax normalization + V multiply (using local max/sum,
+// no cross-workgroup dependency)
+//
+// The VxReduce shader performs the final rescaling across tiles.
+
+var tile_q: array, m_tile>;
+var inner_qk_values: array, tile_size>, m_tile>;
+var tile_qk: array, m_tile>;
+var tile_output: array, m_tile>;
+var qkv_values: array, sub_tile_count>, m_tile>;
+var tile_max: array;
+var tile_sum: array;
+
+#if has_attention_bias
+ fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
+ {
+ let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
+ let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
+ let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length +
+ bias_head_idx * uniforms.new_sequence_length * total_seq_length +
+ q_idx * total_seq_length +
+ k_idx;
+ return attention_bias[offset];
+ }
+#else
+ fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
+ {
+ return q_element_t(0);
+ }
+#endif
+
+$MAIN {
+ let local_row = u32(local_idx / tile_size_k_vec);
+ let local_col = local_idx % tile_size_k_vec;
+ #if use_indirect_dispatch
+ let total_sequence_length = u32(seqlens_k[0]) + 1u;
+ #else
+ let total_sequence_length = uniforms.total_sequence_length;
+ #endif
+ let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+ let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
+ // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile]
+ let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+ let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles;
+ let q_base = q_tile_idx * m_tile;
+ let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles));
+ let head_idx = batch_head_idx % uniforms.num_heads;
+ let batch_idx = batch_head_idx / uniforms.num_heads;
+ if (batch_idx >= uniforms.batch_size) {
+ return;
+ }
+ let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
+ let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length;
+
+ // ============================================================
+ // Phase 1: QK^T computation
+ // ============================================================
+ for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
+ let q_idx = q_base + m;
+#if q_BNSH
+ let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec +
+ head_idx * uniforms.new_sequence_length * uniforms.head_size_vec +
+ q_idx * uniforms.head_size_vec;
+#else
+ let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec +
+ q_idx * uniforms.num_heads * uniforms.head_size_vec +
+ head_idx * uniforms.head_size_vec;
+#endif
+ tile_q[m][local_idx] = q.getByOffset(q_offset + k + local_idx);
+ }
+ }
+ workgroupBarrier();
+ if (k + local_col < uniforms.head_size_vec) {
+ for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
+ if (total_seq_offset + row_offset + local_row < total_sequence_length) {
+ let k_data = present_key.getByOffset(present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col);
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha);
+ inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data);
+ }
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+
+ // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
+ var sum = q_element_t(0);
+ for (var i = 0u; i < tile_size_k_vec; i++) {
+ sum += inner_qk_values[m][local_idx][i];
+ }
+
+ sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length);
+#if is_unidirectional
+ if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) {
+ sum = q_element_t(-65504.0f);
+ }
+#endif
+ tile_qk[m][local_idx] = present_value_element_t(sum);
+ }
+ workgroupBarrier();
+
+ // Compute per-tile max and sum for online softmax
+ if (local_idx == 0u) {
+ var l_max = f32(-3.4028234663852886e+38f);
+ var l_sum = f32(0);
+ for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
+ l_max = max(l_max, f32(tile_qk[m][i]));
+ }
+ for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
+ l_sum += exp(f32(tile_qk[m][i]) - l_max);
+ }
+ tile_max[m] = l_max;
+ tile_sum[m] = l_sum;
+ let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
+ metadata.setByOffset(meta_offset, metadata_value_t(l_max, l_sum));
+ }
+ }
+ workgroupBarrier();
+
+ // ============================================================
+ // Phase 2: Local softmax + V multiply
+ // ============================================================
+
+ // Normalize tile_qk with local max/sum
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
+ tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]);
+ }
+ }
+ workgroupBarrier();
+
+ for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ qkv_values[m][local_row][local_col] = present_value_value_t(0);
+ }
+ workgroupBarrier();
+
+ if (k + local_col < v_head_size_vec) {
+ for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
+ if (total_seq_offset + row_offset + local_row < total_sequence_length) {
+ let v_data = present_value.getByOffset(present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col);
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row];
+ }
+ }
+ }
+ }
+ workgroupBarrier();
+
+ if (local_idx < tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ for (var i = 0u; i < sub_tile_count; i++) {
+ tile_output[m][k + local_idx] += qkv_values[m][i][local_idx];
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+
+ // Write output
+ let tile_idx = workgroup_idx % num_total_seq_length_tile;
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec;
+ for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) {
+ out_split_vx.setByOffset(out_base + tile_idx * v_head_size_vec + i, tile_output[m][i]);
+ }
+ }
+}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
deleted file mode 100644
index 6f1ad1ca41b71..0000000000000
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#param has_head_sink
-#param tile_size
-#param head_size_vec
-#param tile_size_k_vec
-#param sub_tile_count
-#param use_indirect_dispatch
-
-// Note that this shader adopts similar algorithm with dp4a generation shader.
-//
-// This algorithm works to compute dot product of v with qk parallelly, by
-// processing on the head_size dimension at each step amongst tile_size_k_vec
-// threads, and utilizing the remaining threads in the workgroup to process
-// additional rows of |present_value| in parallel (such that the values in
-// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads
-// also reload |present_value| tile_size/sub_tile_count times to compute partial
-// dot products of other |present_value| rows in order to complete all tile_size
-// |present_value| rows in this workgroup and also reusing the values in
-// tile_qk.
-//
-// The difference with FlashAttentionDecodeQKTProgram is that the dot products
-// go through the rows (total_sequence_length) of |present_value| instead of
-// columns (head_size_vec). And each workgroup only calculate current
-// tile_size's dot products instead of iterating the whole row
-// |total_sequence_length|. That's why this shader is a split shader. The final
-// reduce will be done in FlashAttentionDecodeReduceProgram.
-
-// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx
-// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate
-// memory. The FlashAttentionDecodeQKT can be merged into split shader and do
-// the final softmax adjustment in the reduce shader. However, some issues are
-// met that when the total sequence length exceeds some value, the result will
-// become garbage. Since it can't be resolved in a short time, leave it as TODO
-// to fix it in future.
-
-var tile_qk: array;
-var tile_output: array;
-var qkv_values: array, sub_tile_count>;
-
-$MAIN {
- let local_row = u32(local_idx / tile_size_k_vec);
- let local_col = local_idx % tile_size_k_vec;
- #if use_indirect_dispatch
- let total_sequence_length = u32(seqlens_k[0]) + 1u;
- #else
- let total_sequence_length = uniforms.total_sequence_length;
- #endif
- let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
- let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
- if (batch_head_idx >= uniforms.batch_heads) {
- return;
- }
- let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
-
- // Calculate the global max and sum in qk.
- var g_max = f32(-3.4028234663852886e+38f);
-#if has_head_sink
- let head_idx = batch_head_idx % uniforms.num_heads;
- let sink_value = f32(head_sink[head_idx]);
- g_max = max(g_max, sink_value);
-#endif
- var g_sum = f32(0);
- for (var i = 0u; i < num_total_seq_length_tile; i++)
- {
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
- g_max = max(g_max, metadata[meta_offset].x);
- }
- for (var i = 0u; i < num_total_seq_length_tile; i++)
- {
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
- let m_value = metadata[meta_offset];
- g_sum += exp(m_value.x - g_max) * m_value.y;
- }
-#if has_head_sink
- g_sum += exp(sink_value - g_max);
-#endif
-
- if (total_seq_offset + local_idx < total_sequence_length) {
- tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
- }
-
- for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) {
- var value = present_value_value_t(0);
- qkv_values[local_row][local_col] = present_value_value_t(0);
- workgroupBarrier();
-
- if (k + local_col < head_size_vec) {
- for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
- if (total_seq_offset + row_offset + local_row < total_sequence_length) {
- value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row];
- }
- }
- }
-
- qkv_values[local_row][local_col] = value;
- workgroupBarrier();
-
- if (local_idx < tile_size_k_vec) {
- for (var i = 0u; i < sub_tile_count; i++) {
- tile_output[k + local_idx] += qkv_values[i][local_idx];
- }
- }
- workgroupBarrier();
- }
-
- for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
- let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
- out_split_vx[out_offset] = tile_output[i];
- }
-}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
index f909a87724da6..a3ce0b68cb659 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
@@ -1,31 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#param has_head_sink
+#param m_tile
#param seq_tile_size
#param tile_size
#param use_indirect_dispatch
-// Inputs are splits of the GQA output, split into num_total_seq_length_tiles
-// rows. This shader needs to add these splits across the row dimension to
-// arrive at the final result. The column is head size wide. The reduction
-// achieves maximum parallelization by splitting this task first into tile_size
-// columns that each workgroup is responsible for. Then within each workgroup
-// the task of summation over the num_total_seq_length_tile for the tile_size
-// columns is further split in two ways. First across the row dimension to have
-// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE
-// rows. Then across the column dimension where each thread is responsible for 1
-// column of the TILE_SIZE columns the workgroup is responsible for.
+#use .getByOffset .setByOffset
+
+// This shader reduces partial V outputs from the fused QKV shader.
+// Each tile produced a locally-normalized V contribution. To get the
+// correct global result, we rescale each tile's contribution using
+// per-tile metadata (max, sum) with online softmax:
+//
+// global_max = max(local_max_i for all tiles)
+// global_sum = sum(local_sum_i * exp(local_max_i - global_max))
+// output[h] = sum(partial_i[h] * exp(local_max_i - global_max)) / global_sum
var tile_input: array, tile_size>;
$MAIN {
+ let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
+ // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile]
let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile);
+ let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles;
+ let q_base = q_tile_idx * m_tile;
+ let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles));
if (batch_head_idx >= uniforms.batch_heads) {
return;
}
- let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
- var value = output_value_t(0);
let local_row = u32(local_idx / tile_size);
let local_col = local_idx % tile_size;
#if use_indirect_dispatch
@@ -35,23 +39,60 @@ $MAIN {
let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
#endif
- if (head_size_offset + local_col < uniforms.head_size_vec) {
- for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
- if (r + local_row < num_total_seq_length_tile) {
- value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col];
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
+ let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile;
+
+ // Compute global max across all tiles
+ var g_max = f32(-3.4028234663852886e+38f);
+#if has_head_sink
+ let head_idx_for_sink = batch_head_idx % uniforms.num_heads;
+ let sink_value = f32(head_sink[head_idx_for_sink]);
+ g_max = max(g_max, sink_value);
+#endif
+ for (var i = 0u; i < num_total_seq_length_tile; i++) {
+ g_max = max(g_max, metadata.getByOffset(meta_base + i).x);
+ }
+
+ // Compute global sum with rescaling
+ var g_sum = f32(0);
+ for (var i = 0u; i < num_total_seq_length_tile; i++) {
+ let m_value = metadata.getByOffset(meta_base + i);
+ g_sum += m_value.y * exp(m_value.x - g_max);
+ }
+#if has_head_sink
+ g_sum += exp(sink_value - g_max);
+#endif
+
+ // Accumulate rescaled partial outputs
+ var value = output_value_t(0);
+ if (head_size_offset + local_col < uniforms.head_size_vec) {
+ for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
+ if (r + local_row < num_total_seq_length_tile) {
+ let tile_meta = metadata.getByOffset(meta_base + r + local_row);
+ let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum;
+ value += input.getByOffset(in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col) * output_value_t(output_element_t(rescale_f32));
+ }
}
}
- }
- tile_input[local_row][local_col] = value;
- workgroupBarrier();
+ tile_input[local_row][local_col] = value;
+ workgroupBarrier();
- if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) {
- value = output_value_t(0);
- for (var i = 0u; i < tile_size; i++) {
- value += tile_input[i][local_idx];
+ if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) {
+ value = output_value_t(0);
+ for (var i = 0u; i < tile_size; i++) {
+ value += tile_input[i][local_idx];
+ }
+ let head_idx = batch_head_idx % uniforms.num_heads;
+ let batch_idx = batch_head_idx / uniforms.num_heads;
+ let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec +
+ q_idx * uniforms.num_heads * uniforms.head_size_vec +
+ head_idx * uniforms.head_size_vec +
+ head_size_offset + local_idx;
+ output.setByOffset(output_id, value);
}
- let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx;
- output[output_id] = value;
+ workgroupBarrier();
}
}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
index c64bdf45cdcf8..7b09a3a6af080 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
@@ -41,12 +41,9 @@ $MAIN {
#endif
#if prepare_indirect_dispatch
- // Prepare indirect dispatch buffer for thread 0
if (global_idx == 0u) {
let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size;
- indirect_buffer[0] = num_total_seq_length_tile;
- indirect_buffer[1] = uniforms.num_heads;
- indirect_buffer[2] = 1u;
+ normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);
}
#endif
diff --git a/onnxruntime/core/framework/layering_annotations.cc b/onnxruntime/core/framework/layering_annotations.cc
index b3a41d714137f..f4dddecd207b8 100644
--- a/onnxruntime/core/framework/layering_annotations.cc
+++ b/onnxruntime/core/framework/layering_annotations.cc
@@ -13,6 +13,7 @@
#include "core/framework/execution_providers.h"
#include "core/graph/graph.h"
+#include
#include
namespace onnxruntime {
@@ -335,28 +336,62 @@ LayeringIndex LayeringIndex::Create(const Graph& graph,
EpNameToLayeringIndices ep_map,
LayeringIndexToEpName rule_map,
LayeringRules layering_rules) {
- // 1. Create LayeringIndex instance with pre-computed maps
LayeringIndex index(std::move(layering_rules), std::move(ep_map), std::move(rule_map));
-
- // 2. Traverse the graph and index nodes
index.ProcessGraph(graph, std::nullopt);
+ return index;
+}
+LayeringIndex LayeringIndex::Create(const Graph& graph,
+ EpNameToLayeringIndices ep_map,
+ LayeringIndexToEpName rule_map,
+ LayeringRules layering_rules,
+ SubstringMatcher substring_matcher) {
+ LayeringIndex index(std::move(layering_rules), std::move(ep_map), std::move(rule_map),
+ std::move(substring_matcher));
+ index.ProcessGraph(graph, std::nullopt);
return index;
}
Status LayeringIndex::Create(const Graph& graph,
const std::string& config_string,
+ const std::string& name_based_config_string,
gsl::span ep_devices,
const ExecutionProviders& ep_providers,
const logging::Logger& logger,
std::optional& layering_index) {
- LayeringRules rules;
- ORT_RETURN_IF_ERROR(LayeringRules::FromConfigString(config_string, rules));
+ // Annotation-based and name-based layer assignment are mutually exclusive.
+ if (!config_string.empty() && !name_based_config_string.empty()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Cannot set both 'session.layer_assignment_settings' and "
+ "'session.name_based_layer_assignment'. These options are mutually exclusive. "
+ "Use annotation-based matching for models with explicit layer annotations, "
+ "or name-based matching for models with structured node names.");
+ }
- LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " layering rules from config.";
+ const bool is_name_based = !name_based_config_string.empty();
+ const std::string& active_config = is_name_based ? name_based_config_string : config_string;
+
+ LayeringRules rules;
+ if (!active_config.empty()) {
+ ORT_RETURN_IF_ERROR(LayeringRules::FromConfigString(active_config, rules));
+
+ if (is_name_based) {
+ // Reject '=' (exact-match qualifier) in name-based rules — all patterns must be substrings
+ for (const auto& rule : rules.rules) {
+ if (!rule.prefix_match) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Name-based layer assignment does not support the '=' (exact-match) qualifier. "
+ "All patterns are treated as substrings. Remove the '=' prefix from pattern: '",
+ rule.annotation, "'");
+ }
+ }
+ LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " name-based layering rules from config.";
+ } else {
+ LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " annotation-based layering rules from config.";
+ }
+ }
if (rules.rules.empty()) {
- // Return no index indicating no layering
layering_index.reset();
return Status::OK();
}
@@ -384,9 +419,6 @@ Status LayeringIndex::Create(const Graph& graph,
if (matched_ep) {
const std::string& ep_type = *matched_ep;
ep_map[ep_type].insert(i);
- // Ensure 1:1 mapping from rule index to EP type
- // Note: A rule index refers to a unique entry in LayeringRules::rules vector.
- // So 'i' is unique.
rule_map[i] = ep_type;
matched_rule_count++;
LOGS(logger, VERBOSE) << "Layering Rule " << i << " (" << rule.device << " -> " << rule.annotation
@@ -402,7 +434,17 @@ Status LayeringIndex::Create(const Graph& graph,
LOGS(logger, INFO) << "LayeringIndex created. Matched " << matched_rule_count
<< " out of " << rules.rules.size() << " rules to available Execution Providers.";
- layering_index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules));
+ // Build SubstringMatcher for name-based mode
+ std::optional substring_matcher;
+ if (is_name_based) {
+ substring_matcher.emplace(rules);
+ }
+
+ // Create LayeringIndex — annotation mode uses matcher_ only, name-based uses substring_matcher_ only
+ LayeringIndex index(std::move(rules), std::move(ep_map), std::move(rule_map),
+ std::move(substring_matcher));
+ index.ProcessGraph(graph, std::nullopt);
+ layering_index = std::move(index);
return Status::OK();
}
@@ -423,16 +465,23 @@ void LayeringIndex::ProcessGraph(const Graph& graph, std::optional paren
for (auto& node : graph.Nodes()) {
std::optional matched_rule_idx = std::nullopt;
- // 4. For every node query its annotation
- const std::string& annotation = node.GetLayeringAnnotation();
- if (!annotation.empty()) {
- // If it has an annotation try to match it
- matched_rule_idx = matcher_.Match(annotation);
- }
+ if (substring_matcher_) {
+ // Name-based mode: substring matching against node name, no inheritance.
+ // Node names are dense (virtually every node has one), so inheritance is
+ // unnecessary — each node is matched independently by its own name.
+ matched_rule_idx = substring_matcher_->Match(node.Name());
+ } else {
+ // Annotation-based mode: prefix/exact match against metadata annotation,
+ // with subgraph inheritance for unannotated nodes.
+ const std::string& annotation = node.GetLayeringAnnotation();
+ if (!annotation.empty()) {
+ matched_rule_idx = matcher_.Match(annotation);
+ }
- // 5. If node has no annotation, inherit from subgraph parent node
- if (!matched_rule_idx && parent_layer_id) {
- matched_rule_idx = parent_layer_id;
+ // Inherit from subgraph parent node if no annotation match
+ if (!matched_rule_idx && parent_layer_id) {
+ matched_rule_idx = parent_layer_id;
+ }
}
// Record assignment if we have a match
@@ -485,28 +534,36 @@ void LayeringIndex::Update(const Graph& graph, gsl::span nodes)
continue;
}
- const std::string& annotation = node->GetLayeringAnnotation();
- if (!annotation.empty()) {
- auto matched_rule_idx = matcher_.Match(annotation);
-
- if (matched_rule_idx) {
- const size_t rule_idx = *matched_rule_idx;
-
- // Only assign if this rule maps to a valid EP in our configuration
- if (layering_index_to_ep_name_.count(rule_idx)) {
- // Check if already assigned to a DIFFERENT rule, if so clean up old mapping
- auto prev_assign = current_graph_index.node_to_layering_index_.find(node_index);
- if (prev_assign != current_graph_index.node_to_layering_index_.end()) {
- size_t old_rule = prev_assign->second;
- if (old_rule != rule_idx) {
- current_graph_index.layer_to_node_ids_[old_rule].erase(node_index);
- }
- }
+ std::optional matched_rule_idx;
- ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node_index, rule_idx));
- ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node_index));
- was_updated = true;
+ if (substring_matcher_) {
+ // Name-based mode: substring match against node name
+ matched_rule_idx = substring_matcher_->Match(node->Name());
+ } else {
+ // Annotation-based mode: prefix/exact match against metadata
+ const std::string& annotation = node->GetLayeringAnnotation();
+ if (!annotation.empty()) {
+ matched_rule_idx = matcher_.Match(annotation);
+ }
+ }
+
+ if (matched_rule_idx) {
+ const size_t rule_idx = *matched_rule_idx;
+
+ // Only assign if this rule maps to a valid EP in our configuration
+ if (layering_index_to_ep_name_.count(rule_idx)) {
+ // Check if already assigned to a DIFFERENT rule, if so clean up old mapping
+ auto prev_assign = current_graph_index.node_to_layering_index_.find(node_index);
+ if (prev_assign != current_graph_index.node_to_layering_index_.end()) {
+ size_t old_rule = prev_assign->second;
+ if (old_rule != rule_idx) {
+ current_graph_index.layer_to_node_ids_[old_rule].erase(node_index);
+ }
}
+
+ ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node_index, rule_idx));
+ ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node_index));
+ was_updated = true;
}
}
}
@@ -544,6 +601,30 @@ void LayeringRuleMatcher::UpdateBestMatch(std::optional& current_best, s
}
}
+SubstringMatcher::SubstringMatcher(const LayeringRules& rules) {
+ for (size_t i = 0; i < rules.rules.size(); ++i) {
+ const auto& rule = rules.rules[i];
+ if (!rule.annotation.empty()) {
+ patterns_.push_back({rule.annotation, i});
+ }
+ }
+ // Sort by pattern length descending (longest first).
+ // Stable sort preserves config order as tiebreaker for same-length patterns.
+ std::stable_sort(patterns_.begin(), patterns_.end(),
+ [](const PatternEntry& a, const PatternEntry& b) {
+ return a.pattern.size() > b.pattern.size();
+ });
+}
+
+std::optional SubstringMatcher::Match(std::string_view node_name) const {
+ for (const auto& entry : patterns_) {
+ if (node_name.find(entry.pattern) != std::string_view::npos) {
+ return entry.rule_index;
+ }
+ }
+ return std::nullopt;
+}
+
std::optional>>
LayeringIndex::GetLayeringRulesForThisEp(const std::string& ep_type) const {
auto hit = ep_name_to_layering_indices_.find(ep_type);
diff --git a/onnxruntime/core/framework/layering_annotations.h b/onnxruntime/core/framework/layering_annotations.h
index 5d58e9ace2471..4114527e07d01 100644
--- a/onnxruntime/core/framework/layering_annotations.h
+++ b/onnxruntime/core/framework/layering_annotations.h
@@ -11,6 +11,7 @@
#include "core/common/logging/logging.h"
#include "gsl/gsl"
#include
+#include
#include
#include
#include
@@ -57,6 +58,7 @@ struct LayeringRules {
///
class LayeringRuleMatcher {
public:
+ /// The annotation-based layering rules to index.
explicit LayeringRuleMatcher(const LayeringRules& rules);
///
@@ -83,6 +85,35 @@ class LayeringRuleMatcher {
void UpdateBestMatch(std::optional& current_best, size_t candidate) const;
};
+///
+/// Performs substring matching against node names. Unlike LayeringRuleMatcher (which does
+/// prefix/exact matching from position 0), this matches patterns appearing anywhere in the
+/// input string. Longest matching pattern wins.
+///
+class SubstringMatcher {
+ public:
+ /// The rules whose annotations become substring patterns.
+ /// The '=' prefix (exact match) qualifier is rejected during config parsing — all patterns
+ /// must be substrings.
+ explicit SubstringMatcher(const LayeringRules& rules);
+
+ ///
+ /// Returns the index of the best matching rule for the given node name.
+ /// "Best" = longest pattern that appears as a substring in the name.
+ ///
+ /// the node's name to match against
+ /// index of the matching LayeringRule if a substring match is found
+ std::optional Match(std::string_view node_name) const;
+
+ private:
+ struct PatternEntry {
+ std::string pattern;
+ size_t rule_index;
+ };
+ // Sorted by pattern length descending. First match wins (longest-match priority).
+ InlinedVector patterns_;
+};
+
namespace EpLayeringMatcher {
///
/// Matches a list of available OrtEpDevices against the device string specified in the LayerAnnotation.
@@ -125,12 +156,25 @@ class LayeringIndex {
LayeringIndexToEpName rule_map,
LayeringRules layering_rules);
+ ///
+ /// Creates a fully initialized LayeringIndex with a SubstringMatcher for name-based matching.
+ /// In this mode, annotation matching is disabled and no subgraph inheritance is applied.
+ ///
+ static LayeringIndex Create(const Graph& graph,
+ EpNameToLayeringIndices ep_map,
+ LayeringIndexToEpName rule_map,
+ LayeringRules layering_rules,
+ SubstringMatcher substring_matcher);
+
///
/// Factory method that creates a LayeringIndex by parsing configuration, matching rules against
/// available devices/providers, and indexing the graph.
+ /// Annotation-based and name-based options are mutually exclusive — setting both returns an error.
///
/// The graph to index.
- /// The configuration string containing layering rules.
+ /// The annotation-based configuration string (prefix/exact match on metadata).
+ /// The name-based configuration string (substring match on Node::Name()).
+ /// May be empty if name-based matching is not configured.
/// Available OrtEpDevices to match rules against.
/// Available ExecutionProviders to match rules against (fallback).
/// Logger for reporting information/errors.
@@ -139,6 +183,7 @@ class LayeringIndex {
/// Status indicating success or failure.
static Status Create(const Graph& graph,
const std::string& config_string,
+ const std::string& name_based_config_string,
gsl::span ep_devices,
const ExecutionProviders& ep_providers,
const logging::Logger& logger,
@@ -192,11 +237,17 @@ class LayeringIndex {
LayerIndexToNodes layer_to_node_ids_;
};
- LayeringIndex(LayeringRules layering_rules, EpNameToLayeringIndices ep_name_to_layering_indices, LayeringIndexToEpName layering_index_to_ep_name)
+ LayeringIndex(LayeringRules layering_rules, EpNameToLayeringIndices ep_name_to_layering_indices,
+ LayeringIndexToEpName layering_index_to_ep_name,
+ std::optional substring_matcher = std::nullopt)
: rules_(std::move(layering_rules)),
matcher_(rules_),
ep_name_to_layering_indices_(std::move(ep_name_to_layering_indices)),
- layering_index_to_ep_name_(std::move(layering_index_to_ep_name)) {}
+ layering_index_to_ep_name_(std::move(layering_index_to_ep_name)),
+ substring_matcher_(std::move(substring_matcher)) {}
+
+ // Optional substring matcher for name-based layer assignment
+ std::optional substring_matcher_;
// Graph and sub-graphs mapping to their indices
InlinedHashMap graph_index_;
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index cc0b8533033a6..5d537fa59bfab 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1956,7 +1956,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
+ if (!hasInputShape(ctx, 0)) {
+ return;
+ }
auto& bias_table_shape = getInputShape(ctx, 0);
+ if (bias_table_shape.dim_size() < 2) {
+ fail_shape_inference("RelativePositionBias: bias_table must have rank >= 2");
+ }
TensorShapeProto output_shape;
output_shape.add_dim()->set_dim_value(1);
*output_shape.add_dim() = bias_table_shape.dim(1);
@@ -2219,6 +2225,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
// Output shape: (batch_size, num_heads, seq_len, seq_len)
if (hasInputShape(ctx, 6)) {
auto& token_offset_shape = getInputShape(ctx, 6);
+ if (token_offset_shape.dim_size() < 2) {
+ fail_shape_inference("GatedRelativePositionBias: token_offset must have rank >= 2");
+ }
TensorShapeProto output_shape;
*output_shape.add_dim() = token_offset_shape.dim(0);
output_shape.add_dim()->set_dim_value(num_heads);
@@ -2317,6 +2326,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) {
auto& input_shape = getInputShape(ctx, 0);
auto& weight_shape = getInputShape(ctx, 1);
+ if (input_shape.dim_size() < 2) {
+ fail_shape_inference("CausalConvWithState: input must have rank >= 2");
+ }
+ if (weight_shape.dim_size() < 2) {
+ fail_shape_inference("CausalConvWithState: weight must have rank >= 2");
+ }
int64_t ndim = getAttribute(ctx, "ndim", 1);
TensorShapeProto state_shape;
*state_shape.add_dim() = input_shape.dim(0); // batch_size
@@ -2442,6 +2457,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
auto& query_shape = getInputShape(ctx, 0);
auto& value_shape = getInputShape(ctx, 2);
+ if (query_shape.dim_size() < 3) {
+ fail_shape_inference("LinearAttention: query must have rank >= 3");
+ }
+ if (value_shape.dim_size() < 3) {
+ fail_shape_inference("LinearAttention: value must have rank >= 3");
+ }
TensorShapeProto output_shape;
*output_shape.add_dim() = query_shape.dim(0); // B
*output_shape.add_dim() = query_shape.dim(1); // T
@@ -2459,6 +2480,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) {
auto& query_shape = getInputShape(ctx, 0);
auto& value_shape = getInputShape(ctx, 2);
+ if (query_shape.dim_size() < 3 || value_shape.dim_size() < 3) {
+ // Already validated in Output 0 block above; skip if shapes are invalid.
+ return;
+ }
TensorShapeProto state_shape;
*state_shape.add_dim() = query_shape.dim(0); // B
state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 1054fd94ef423..f3f2f521ecab2 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1520,18 +1520,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::STRING,
std::string("int"))
.Attr("weights_prepacked",
- "Only meaningful when quant_type='int'. Tri-state control over whether the "
- "int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS "
- "fpA_intB format expected by the runner. -1 (auto): let the execution provider "
- "choose its own backward-compatible default; the CUDA EP treats auto as "
- "prepacked. 1: the initializers are already prepacked (e.g. produced offline by "
- "pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers "
- "are raw, un-prepacked [E, N, K/pack] tensors as produced by "
- "quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself "
- "in PrePack(), matching the behaviour of MatMulNBits and removing the offline "
- "pre-pack requirement from exporters. Defaults to -1 (auto) so each execution "
- "provider can pick its own backward-compatible default rather than the schema "
- "imposing one.",
+ "Only meaningful when quant_type='int'. Tri-state control over the layout of the "
+ "int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by "
+ "-1 and 1 are determined by the execution provider. 0: the initializers are raw, "
+ "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.",
AttributeProto::INT,
static_cast(-1))
.Input(0,
diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc
index bf1b3a89d8813..f7e22608002ff 100644
--- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc
+++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc
@@ -134,6 +134,9 @@ void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ct
}
auto& input_shape = ctx.getInputType(0)->tensor_type().shape();
int64_t input_ndim = input_shape.dim_size();
+ if (input_ndim < 1) {
+ fail_shape_inference("SkipLayerNormalization: input must have rank >= 1");
+ }
int axis = static_cast(input_ndim - 1);
if (ctx.getNumOutputs() > 1) {
diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc
index 0334597549609..f083297c16fea 100644
--- a/onnxruntime/core/graph/graph_utils.cc
+++ b/onnxruntime/core/graph/graph_utils.cc
@@ -8,6 +8,7 @@
#include "core/common/logging/logging.h"
#include
+#include
#include
#include
#include
@@ -411,6 +412,14 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
return iter == attrs.end() ? nullptr : &iter->second;
}
+bool IsFullShapeNode(const Node& node) {
+ const auto* start_attr = GetNodeAttribute(node, "start");
+ const auto* end_attr = GetNodeAttribute(node, "end");
+ // end=INT64_MAX is the runtime default meaning "all dimensions" (full shape).
+ return (!start_attr || start_attr->i() == 0) &&
+ (!end_attr || end_attr->i() == std::numeric_limits::max());
+}
+
static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
ONNX_NAMESPACE::TypeProto new_type;
auto* typeproto_tensor = new_type.mutable_tensor_type();
diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h
index 2106da1a96327..5681d4e1d08f0 100644
--- a/onnxruntime/core/graph/graph_utils.h
+++ b/onnxruntime/core/graph/graph_utils.h
@@ -8,6 +8,7 @@
#include "core/graph/onnx_protobuf.h"
#include "core/graph/graph.h"
+#include
#include
#include
@@ -31,6 +32,10 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
/** Returns the attribute of a Node with a given name. */
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
+/** Checks whether a Shape node returns the full tensor shape (all dimensions).
+ * Returns false if start/end attributes restrict the output to a subset of dimensions. */
+bool IsFullShapeNode(const Node& node);
+
/** Add a new initializer to 'graph'.
Checks that new_initializer does not already exist in 'graph' before adding it.
@returns The NodeArg for the new initializer.
diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc
index 7fe7c914fa796..30eaaafccca82 100644
--- a/onnxruntime/core/optimizer/attention_fusion.cc
+++ b/onnxruntime/core/optimizer/attention_fusion.cc
@@ -349,7 +349,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0);
if (sequence_transpose == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
!HasExpectedPerm(*sequence_transpose, {0, 2, 1}) ||
!optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) {
return false;
@@ -357,14 +357,14 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
const Node* input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0);
if (input_reshape == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
!optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) {
return fail("missing input Reshape before sequence transpose");
}
Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape");
if (qkv_reshape == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
!optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) {
return fail("qkv Reshape after MatMul not matched");
}
@@ -379,9 +379,9 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze");
Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose");
if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
!HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) ||
!HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) ||
!HasExpectedAxesInput(graph, *k_squeeze, {2})) {
@@ -391,8 +391,8 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze");
Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze");
if (q_squeeze == nullptr || v_squeeze == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
!HasExpectedAxesInput(graph, *q_squeeze, {0}) ||
!HasExpectedAxesInput(graph, *v_squeeze, {0})) {
return fail("q/v squeeze pattern not matched");
@@ -402,7 +402,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose");
if (q_scale_mul == nullptr || k_transpose == nullptr ||
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
!HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) {
return fail("q scale Mul or k Transpose(0,2,3,1) not matched");
}
@@ -460,7 +460,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose");
if (transpose_3 == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
!HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) ||
!optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) {
return fail("output Transpose(0,2,1,3) not matched");
@@ -468,7 +468,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape");
if (reshape_2 == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
!optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) {
return fail("output Reshape not matched");
}
@@ -497,7 +497,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
if (proj_gemm == nullptr) {
proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape");
if (proj_gemm_input_reshape == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
!optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) {
return fail("projection MatMul/Gemm not matched");
}
@@ -511,7 +511,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape");
if (proj_gemm_output_reshape == nullptr ||
- !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
!optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) {
return fail("normalized projection Gemm output Reshape not matched");
}
@@ -920,11 +920,11 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
}
std::vector q_path{
- {0, 0, "Transpose", {1, 13}, kOnnxDomain},
- {0, 0, "Reshape", {5, 13}, kOnnxDomain},
- {0, 0, "Add", {7, 13}, kOnnxDomain},
+ {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
+ {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
+ {0, 0, "Add", {7, 13, 14}, kOnnxDomain},
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
- {0, 0, "LayerNormalization", {1}, kOnnxDomain}};
+ {0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}};
if (!graph_utils::FindPath(edges[edges.size() - 1]->GetNode(), true, q_path, edges, logger)) {
DEBUG_LOG("Failed to find path for q");
return false;
@@ -953,9 +953,9 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
}
std::vector k_path{
- {0, 1, "Transpose", {1, 13}, kOnnxDomain},
- {0, 0, "Reshape", {5, 13}, kOnnxDomain},
- {0, 0, "Add", {7, 13}, kOnnxDomain},
+ {0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
+ {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
+ {0, 0, "Add", {7, 13, 14}, kOnnxDomain},
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
{0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}};
@@ -1070,8 +1070,8 @@ static bool FuseSubGraphQK(Node& layer_norm,
const logging::Logger& logger) {
// path to q
std::vector q_varience_path{
- {0, 0, "Div", {7, 13}, kOnnxDomain},
- {0, 0, "MatMul", {1, 9}, kOnnxDomain}};
+ {0, 0, "Div", {7, 13, 14}, kOnnxDomain},
+ {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}};
std::vector edges;
if (!graph_utils::FindPath(*(mask_nodes.add), true, q_varience_path, edges, logger)) {
DEBUG_LOG("Failed to find path for q");
@@ -1163,7 +1163,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
// path to q
std::vector q_varience_path{
{0, 2, "MatMul", {1, 9, 13}, kOnnxDomain},
- {0, 0, "Div", {7, 13}, kOnnxDomain}};
+ {0, 0, "Div", {7, 13, 14}, kOnnxDomain}};
std::vector edges;
if (!graph_utils::FindPath(*(mask_nodes.where), true, q_varience_path, edges, logger)) {
DEBUG_LOG("Failed to find path for q");
@@ -1265,14 +1265,14 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm,
std::map& mask_int32_map,
const logging::Logger& logger) {
std::vector