diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9d19a95136ad7..38d101786b41a 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -4949,7 +4949,7 @@ This version of the operator has been available since version 1 of the 'com.micr
use_sparse_mixer : int
Whether to use sparse mixer
weights_prepacked : int
-Only meaningful when quant_type='int'. Tri-state control over whether the int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS fpA_intB format expected by the runner. -1 (auto): let the execution provider choose its own backward-compatible default; the CUDA EP treats auto as prepacked. 1: the initializers are already prepacked (e.g. produced offline by pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself in PrePack(), matching the behaviour of MatMulNBits and removing the offline pre-pack requirement from exporters. Defaults to -1 (auto) so each execution provider can pick its own backward-compatible default rather than the schema imposing one.
+Only meaningful when quant_type='int'. Tri-state control over the layout of the int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by -1 and 1 are determined by the execution provider. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.
#### Inputs (6 - 21)
diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md
index 6d53211ff40cb..36b68889ae582 100644
--- a/docs/contrib_ops/cuda/moe_qmoe.md
+++ b/docs/contrib_ops/cuda/moe_qmoe.md
@@ -71,6 +71,7 @@ input tokens → router (top-k softmax) → permute by expert
| `expert_weight_bits` (QMoE only) | int | 4 | 4 (INT4/MXFP4) or 8 (INT8/FP8). |
| `block_size` (QMoE only) | int | -1 | Group size for INT4/INT8 group-wise quantization. -1 = per-output-channel. |
| `quant_type` (QMoE only) | string | `"int"` | `"int"`, `"fp4"`, `"fp8"`, `"wfp4afp8"`. See [§3](#3-quantization-modes). |
+| `weights_prepacked` (QMoE only) | int | -1 | Tri-state, only meaningful when `quant_type="int"`. The prepacked layouts selected by `-1` and `1` are **EP-determined**. `-1` (default): the INT4/INT8 `fc1`/`fc2` initializers are already prepacked in the EP's default layout (e.g. from `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). `1`: already prepacked in an alternate EP-selected layout. `0`: the initializers are raw `[E, N, K/pack]` tensors (as produced by `quantize_matmul_{4,8}bits`) and the kernel runs the CUTLASS layout transform in `PrePack()`. **Note:** the CUDA EP INT4/INT8 MoE GEMM always runs the Ampere (SM80) kernel — even on SM90 — so it consumes the SM80 `fpA_intB` layout on all architectures; `-1` and `1` are therefore equivalent for the CUDA EP today, and `1` is reserved for a possible future Hopper-specific layout. See [§5.1](#51-weights-input-2--5--8). |
### 2.2 Type Constraints
@@ -228,10 +229,53 @@ extra subtraction.
### 5.1 Weights (input 2 / 5 / 8)
-Not transformed at runtime. INT4/INT8 weights must already be packed offline by
-`pack_weights_for_cuda_mixed_gemm` (see [§6](#6-weight-formats)). MXFP4 weights
-must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights are stored
-as raw e4m3 bytes (no packing).
+**INT4/INT8** weight layout is controlled by the `weights_prepacked` attribute
+([§2.1](#21-attributes)). The prepacked layouts selected by `-1` and `1` are
+determined by the execution provider:
+
+- **`weights_prepacked=-1` (default)** — the `fc1`/`fc2` weights are already in
+ the EP's default prepacked layout (e.g. packed offline by
+ `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). They are copied to GPU
+ and consumed as-is.
+- **`weights_prepacked=1`** — reserved on the CUDA EP for a possible future
+ Hopper (SM90) layout; today it is handled exactly like `-1` (see the note below).
+- **`weights_prepacked=0`** — the `fc1`/`fc2` weights are raw, schema-conformant
+ `[E, N, K/pack]` tensors as produced by `quantize_matmul_{4,8}bits`. `PrePack`
+ runs the CUTLASS layout transform itself via `PrePackIntExpertWeights`,
+ removing the offline pre-pack dependency. This makes integer QMoE symmetric
+ with `MatMulNBits::PrePack_B`.
+
+> **Single layout on the CUDA EP.** The CUDA EP INT4/INT8 MoE GEMM always
+> dispatches to the Ampere (**SM80**) grouped-GEMM kernel — even on SM90 —
+> because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA
+> warp-specialized specialisation (`isValidHopperMOESpecialisation` is `false`).
+> This matches **TensorRT-LLM**, which likewise routes `W4A16`/`W8A16` MoE to the
+> SM80 kernel on Hopper; its Hopper TMA-WS mixed-dtype MoE kernel is reserved for
+> `W4A8` (FP8 activation) and `WFP4A16` (FP4 weight). Consequently the CUDA EP
+> consumes the **SM80 `fpA_intB` layout on every GPU**, `PrePack` always packs
+> for SM80, and `weights_prepacked=-1` and `=1` are equivalent today. `1` is
+> accepted and reserved for a possible future Hopper-specific layout (e.g.
+> `W4A8`). There is therefore no architecture-match constraint: SM80-format
+> weights run correctly on SM90 via the SM80 kernel.
+
+`PrePackIntExpertWeights` loops over the `E` experts and, per expert, applies the
+same transpose + row-permutation / column-interleave / bias / pair-interleave
+transform as `pack_weights_for_cuda_mixed_gemm` (see [§6.1](#61-int4-group-wise-quant_typeint-expert_weight_bits4)),
+always targeting the SM80 layout. SM75+ is required. The source
+`[E, N, K/pack]` initializers are released after their shapes are cached
+(`fc1_weights_shape_` / `fc2_weights_shape_`), so peak weight memory stays ~1×.
+The prepacked GPU buffers (`packed_fc1_weights_` / `packed_fc2_weights_`) are then
+preferred by `ComputeInternal`. If prepacking is disabled at the session level
+(`session.disable_prepacking`), the buffers stay null and the raw initializer
+pointers are read at compute time instead.
+
+> **Note**: `weights_prepacked=0` is the only path that triggers an in-`PrePack`
+> layout transform for INT weights. FP4 / FP8 / WFP4AFP8 weight handling is
+> unaffected.
+
+MXFP4 weights must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights
+are stored as raw e4m3 bytes (no packing).
+
### 5.2 INT4/INT8 scales + zero-point → bias
@@ -287,7 +331,12 @@ This section covers the five distinct weight encodings supported by QMoE.
INT4 packing layout within a byte: `[high_nibble | low_nibble] = [elt_1 | elt_0]`.
Each INT4 element is in `[-8, 7]` (signed) before bias, `[0, 15]` after the +8 bias.
-#### Preprocessing pipeline (offline, `pack_weights_for_cuda_mixed_gemm`)
+#### Preprocessing pipeline (offline `pack_weights_for_cuda_mixed_gemm`, or in-`PrePack` via `PrePackIntExpertWeights`)
+
+This is the layout transform applied either offline by
+`pack_weights_for_cuda_mixed_gemm`, or per-expert inside `PrePack` when
+`weights_prepacked=0` (see [§5.1](#51-weights-input-2--5--8)).
+
1. **Input layout**: `[N, K]` per expert (Out × In), 2 elements per byte for INT4.
2. **Transpose & signed conversion**:
@@ -405,6 +454,17 @@ weights are interchangeable across SMs:
— does not use `pack_weights_for_cuda_mixed_gemm`.
- **FP8**: no packing.
+> **QMoE uses Group A on every GPU.** The table above describes the layouts the
+> `pack_weights_for_cuda_mixed_gemm` *preprocessor* can emit. The QMoE INT4/INT8
+> MoE GEMM, however, always dispatches to the Ampere (SM80) grouped-GEMM kernel —
+> even on SM90 — because mixed int-weight + fp16/bf16 activation is not a valid
+> Hopper TMA warp-specialized specialisation (the same is true in TensorRT-LLM).
+> It therefore consumes the **Group A (SM80) layout on all architectures,
+> including Hopper**. For QMoE, always pack INT4/INT8 weights for SM80 (`arch=80`);
+> `PrePackIntExpertWeights` (`weights_prepacked=0`) does exactly that
+> regardless of the runtime device SM. Group B (SM90) layout is currently unused
+> by QMoE.
+
---
## 8. SwiGLU Fusion
@@ -830,7 +890,7 @@ will not change the operator interface.
|-----------|----------|
| [test_moe_cuda.py](onnxruntime/test/python/transformers/test_moe_cuda.py) | Standard MoE on CUDA: FP16/BF16, SiLU/GeLU/SwiGLU, routing, GEMM parity. SwiGLU coverage includes both GPT-OSS (`TestSwigluMoE`: interleaved, alpha=1.702/beta=1.0/limit=7.0) and Standard/Llama-Gemma (`TestStandardSwigluMoE`: concatenated `swiglu_fusion=2`, alpha=1.0/beta=0.0/no limit → `SiLU(Gate)×Value`). |
| [test_moe_cpu.py](onnxruntime/test/python/transformers/test_moe_cpu.py) | Standard MoE on CPU (smoke). |
-| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. |
+| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. `TestQMoEIntPrePackSmoke` covers the raw-weight `weights_prepacked=0` in-`PrePack` layout transform (smoke test: asserts finite output, not bit-parity). |
| [test_qmoe_cpu.py](onnxruntime/test/python/transformers/test_qmoe_cpu.py) | INT4/INT8 QMoE on CPU (smoke). |
| [test_qmoe_fp4_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py) | MXFP4 QMoE: quantization utilities, packing, FP16/BF16, SiLU/SwiGLU, top-k and expert-count variants. End-to-end runs on SM120; on SM<120 the dequant fallback is exercised. |
| [test_qmoe_fp8_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py) | FP8 W8A16 QMoE on SM90+ native path and SM<90 dequant fallback. |
@@ -954,6 +1014,11 @@ over-aligned by-value parameters.
cannot. See [§14.1](#141-msvc-and-tma-grouped-moe-gemm).
- **WFP4AFP8 native** requires SM100+ hardware; only the dequant fallback path
is validated end-to-end so far.
+- **In-`PrePack` INT weight layout transform** (`weights_prepacked=0`) is
+ currently covered only by a smoke test (`TestQMoEIntPrePackSmoke`), not a
+ bit-parity check: the existing offline pre-pack harness hardcodes
+ `force_arch=80` (the same SM80 layout consumed by the CUDA EP on all GPUs),
+ so a separate parity harness for this path is still pending.
- **Hopper W4A8** (INT4 weight + FP8 activation) is not supported — TRT-LLM gates
its fast path to SM89 only.
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 0cdbfd7114cf0..0289bca5c84d2 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -348,9 +348,6 @@ ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared exte
ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails);
ORT_RUNTIME_CLASS(EpAssignedSubgraph);
ORT_RUNTIME_CLASS(EpAssignedNode);
-ORT_RUNTIME_CLASS(ModelPackageOptions);
-ORT_RUNTIME_CLASS(ModelPackageContext);
-ORT_RUNTIME_CLASS(ModelPackageComponentContext);
#ifdef _MSC_VER
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
@@ -938,9 +935,6 @@ typedef struct OrtCompileApi OrtCompileApi;
struct OrtInteropApi;
typedef struct OrtInteropApi OrtInteropApi;
-struct OrtModelPackageApi;
-typedef struct OrtModelPackageApi OrtModelPackageApi;
-
struct OrtEpApi;
typedef struct OrtEpApi OrtEpApi;
@@ -7501,26 +7495,6 @@ struct OrtApi {
*/
ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
- /** \brief Get the model package API table.
- *
- * Returns a pointer to the ::OrtModelPackageApi function table, which provides APIs to:
- * - create and release model package options and contexts,
- * - inspect model package metadata (components/variants),
- * - select a component/variant and query selected files/options,
- * - create a session from model package selection results.
- *
- * The returned pointer is owned by ONNX Runtime and is valid for the process lifetime.
- * Do not free it.
- *
- * \note May return NULL if model package support is not available in the current build
- * (for example, minimal builds).
- *
- * \return Pointer to ::OrtModelPackageApi, or NULL if unsupported.
- *
- * \since Version 1.27.
- */
- const OrtModelPackageApi*(ORT_API_CALL* GetModelPackageApi)(void);
-
/** \brief Retrieve an experimental function pointer by name.
*
* Experimental functions are not part of the stable ABI and may be added or removed between releases without notice.
@@ -8645,250 +8619,6 @@ struct OrtInteropApi {
/// @}
};
-/** \brief API table for model package workflows.
- *
- * A model package is a directory containing one or more *components* (logical models).
- * Each component has one or more *variants*, where each variant targets a single
- * execution provider (EP). The package manifest declares the EP name, device type,
- * and an optional compatibility string for every variant so that the runtime can
- * automatically select the best variant for the hardware and EPs available in the
- * caller's session options.
- *
- * Obtain this table from OrtApi::GetModelPackageApi(). The APIs support:
- * - creating model package options that capture EP configuration from OrtSessionOptions,
- * - loading a package context (manifest + metadata) from a package root path,
- * - querying component/variant metadata including per-variant EP information,
- * - selecting a component (which also resolves the best-matching variant),
- * - querying the selected variant's name and folder path,
- * - creating an OrtSession from the selected component context.
- *
- * Typical flow:
- * 1) Create model package options:
- * - CreateModelPackageOptionsFromSessionOptions()
- * 2) Load package metadata:
- * - CreateModelPackageContext()
- * 3) Query metadata (optional):
- * - ModelPackage_GetSchemaVersion()
- * - ModelPackage_GetComponentCount()
- * - ModelPackage_GetComponentNames()
- * - ModelPackage_GetVariantCount()
- * - ModelPackage_GetVariantNames()
- * - ModelPackage_GetVariantEpName()
- * 4) Select a component and resolve variant:
- * - SelectComponent()
- * 5) Query selected variant info (optional):
- * - ModelPackageComponent_GetSelectedVariantName()
- * - ModelPackageComponent_GetSelectedVariantFolderPath()
- * 6) Create session:
- * - CreateSession()
- *
- * Ownership:
- * - Release objects created by this API with the corresponding release methods:
- * ReleaseModelPackageOptions(), ReleaseModelPackageContext(),
- * ReleaseModelPackageComponentContext().
- *
- * \since Version 1.27.
- */
-struct OrtModelPackageApi {
- /// \name OrtModelPackageOptions
- /// @{
-
- /** \brief Create model package options from an existing OrtSessionOptions.
- *
- * Captures EP configuration (registered execution providers and their devices) from
- * the session options for use during variant selection. The resulting OrtModelPackageOptions
- * is passed to SelectComponent() to resolve the best variant for the available EPs.
- *
- * \param[in] env The ORT environment.
- * \param[in] session_options Session options containing registered EPs.
- * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released
- * with ReleaseModelPackageOptions().
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateModelPackageOptionsFromSessionOptions,
- _In_ const OrtEnv* env,
- _In_ const OrtSessionOptions* session_options,
- _Outptr_ OrtModelPackageOptions** out);
-
- ORT_CLASS_RELEASE(ModelPackageOptions);
- /// @}
- /// \name OrtModelPackageContext
- /// @{
-
- /** \brief Create a model package context by parsing the package at the given root path.
- *
- * Parses the manifest.json and component metadata from the specified directory.
- * The returned context provides read-only access to the package structure (components,
- * variants, EP declarations).
- *
- * \param[in] package_root Path to the model package root directory (containing manifest.json).
- * \param[out] out Receives the newly created OrtModelPackageContext. Must be released
- * with ReleaseModelPackageContext().
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateModelPackageContext,
- _In_ const ORTCHAR_T* package_root,
- _Outptr_ OrtModelPackageContext** out);
-
- ORT_CLASS_RELEASE(ModelPackageContext);
-
- /** \brief Get the schema version declared in the model package manifest.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_version Receives the schema version number.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetSchemaVersion,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ int64_t* out_version);
-
- /** \brief Get the number of components in the model package.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_count Receives the component count.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetComponentCount,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ size_t* out_count);
-
- /** \brief Get the names of all components in the model package.
- *
- * Returns a pointer to an array of UTF-8 component name strings. The array and its
- * strings are owned by `ctx` and remain valid until the context is released.
- *
- * \param[in] ctx The model package context.
- * \param[out] out_names Receives a pointer to an array of component name strings.
- * \param[out] out_count Receives the number of elements in the array.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetComponentNames,
- _In_ const OrtModelPackageContext* ctx,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
- _Out_ size_t* out_count);
-
- /** \brief Get the number of variants for a given component.
- *
- * \param[in] ctx The model package context.
- * \param[in] component_name Name of the component to query.
- * \param[out] out_count Receives the variant count.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantCount,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Out_ size_t* out_count);
-
- /** \brief Get the names of all variants for a given component.
- *
- * Returns a pointer to an array of UTF-8 variant name strings. The array and its
- * strings are owned by `ctx` and remain valid until the context is released.
- *
- * \param[in] ctx The model package context.
- * \param[in] component_name Name of the component to query.
- * \param[out] out_variant_names Receives a pointer to an array of variant name strings.
- * \param[out] out_count Receives the number of elements in the array.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantNames,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
- _Out_ size_t* out_count);
-
- /** \brief Get the EP name declared for a (component, variant) pair.
- *
- * Each variant targets a single EP. `out_ep` receives the EP name string.
- * When the variant does not declare an EP, the returned pointer is NULL.
- * String memory is owned by `ctx` and remains valid until the context is released.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackage_GetVariantEpName,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _In_ const char* variant_name,
- _Outptr_result_maybenull_ const char** out_ep);
-
- /** \brief Select a component model and return an opaque component instance.
- *
- * The variant selection is also performed during this call based on the component metadata and the provided options.
- * The returned `OrtModelPackgeComponentContext*` is independent of `context` lifetime and must be released via
- * `ReleaseComponentInstance`.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(SelectComponent,
- _In_ const OrtModelPackageContext* context,
- _In_ const char* component_name,
- _In_ const OrtModelPackageOptions* options,
- _Outptr_ OrtModelPackageComponentContext** out);
-
- ORT_CLASS_RELEASE(ModelPackageComponentContext);
-
- /** \brief Get the name of the selected variant after SelectComponent has been called.
- *
- * String memory is owned by `ctx` and remains valid until the context is released.
- *
- * \param[in] ctx The component context returned by SelectComponent().
- * \param[out] out_name Receives the selected variant's name string.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantName,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const char** out_name);
-
- /** \brief Get the folder path of the selected variant.
- *
- * Returns the resolved absolute path to the variant's directory on disk.
- * The string is owned by `ctx` and remains valid until the context is released.
- *
- * \param[in] ctx The component context returned by SelectComponent().
- * \param[out] folder_path Receives the variant folder path string.
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantFolderPath,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const ORTCHAR_T** folder_path);
-
- /// @}
- /** \brief Create an OrtSession for a selected file within a component model variant.
- *
- * The chosen variant (and thus its EP selection) is determined by `context`, which
- * was built from an OrtSessionOptions via CreateModelPackageOptionsFromSessionOptions.
- *
- * Session options precedence:
- * 1. session_options == NULL (default path):
- * ORT uses the OrtSessionOptions that was captured when `context` was created.
- * Any variant-specific session and provider options declared in the package
- * metadata are merged on top.
- *
- * 2. session_options != NULL (advanced path):
- * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific
- * session and provider options from the package metadata are NOT applied.
- * Use this when custom EP setup is required (e.g., shared CUDA streams,
- * shared QNN EP contexts, custom allocators).
- *
- * \since Version 1.27.
- */
- ORT_API2_STATUS(CreateSession,
- _In_ const OrtEnv* env,
- _In_ OrtModelPackageComponentContext* context,
- _In_opt_ const OrtSessionOptions* session_options,
- _Outptr_ OrtSession** session);
-
- // End of Version 1.27 - DO NOT MODIFY ABOVE
-};
-
/*
* This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality
* This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index a999ef5e0faf4..4798d3d4ad1b8 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -264,20 +264,6 @@ inline const OrtEpApi& GetEpApi() {
return *api;
}
-/// <summary>
-/// This returns a reference to the ORT C Model Package API. Used for loading models from model packages.
-/// </summary>
-/// <returns>ORT C Model Package API reference</returns>
-inline const OrtModelPackageApi& GetModelPackageApi() {
- auto* api = GetApi().GetModelPackageApi();
- if (api == nullptr) {
- // minimal build
- ORT_CXX_API_THROW("Model Package API is not available in this build", ORT_FAIL);
- }
-
- return *api;
-}
-
/** \brief IEEE 754 half-precision floating point data type
*
* \details This struct is used for converting float to float16 and back
@@ -678,9 +664,6 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageOptions, GetModelPackageApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageContext, GetModelPackageApi);
-ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageComponentContext, GetModelPackageApi);
// This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type,
// but the struct has V2 in its name to indicate that it is the second version of the options.
@@ -807,9 +790,6 @@ struct EpDevice;
struct ExternalInitializerInfo;
struct Graph;
struct Model;
-struct ModelPackageOptions;
-struct ModelPackageContext;
-struct ModelPackageComponentContext;
struct Node;
struct ModelMetadata;
struct TypeInfo;
@@ -1806,70 +1786,6 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
*/
Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options);
-/** \brief Options for selecting a component from a model package.
- *
- * Wraps ::OrtModelPackageOptions. Created from an Env and SessionOptions, which captures the
- * EP configuration used for variant selection.
- */
-struct ModelPackageOptions : detail::Base<OrtModelPackageOptions> {
- using Base = detail::Base<OrtModelPackageOptions>;
- using Base::Base;
-
- explicit ModelPackageOptions(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- ModelPackageOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions
- ModelPackageOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions
-};
-
-/** \brief Context for inspecting and selecting components from a model package.
- *
- * Wraps ::OrtModelPackageContext. Provides traversal APIs to enumerate components, variants,
- * and EP compatibility, as well as component selection.
- */
-struct ModelPackageContext : detail::Base<OrtModelPackageContext> {
- using Base = detail::Base<OrtModelPackageContext>;
- using Base::Base;
-
- explicit ModelPackageContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- explicit ModelPackageContext(const ORTCHAR_T* package_root); ///< Wraps OrtModelPackageApi::CreateModelPackageContext
-
- size_t GetComponentCount() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentCount
- std::vector<std::string> GetComponentNames() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentNames
- size_t GetVariantCount(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantCount
- std::vector<std::string> GetVariantNames(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantNames
-
- /// Get the EP name for a variant. Returns nullptr if not declared.
- /// Returned string is owned by this context and valid until it is released.
- const char* GetVariantEpName(const char* component_name,
- const char* variant_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantEpName
-
- int64_t GetSchemaVersion() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetSchemaVersion
-
- ModelPackageComponentContext SelectComponent(const char* component_name,
- const ModelPackageOptions& options) const; ///< Wraps OrtModelPackageApi::SelectComponent
-};
-
-/** \brief Context for a selected component within a model package.
- *
- * Wraps ::OrtModelPackageComponentContext. Provides accessors for the selected variant's
- * folder path and variant name.
- */
-struct ModelPackageComponentContext : detail::Base<OrtModelPackageComponentContext> {
- using Base = detail::Base<OrtModelPackageComponentContext>;
- using Base::Base;
-
- explicit ModelPackageComponentContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used.
-
- std::basic_string<ORTCHAR_T> GetSelectedVariantFolderPath() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantFolderPath
-
- std::string GetSelectedVariantName() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantName
-
- Session CreateSession(const Env& env); ///< Wraps OrtModelPackageApi::CreateSession (default path, NULL session_options)
- Session CreateSession(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path)
- Session CreateSession(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path)
-};
-
/** \brief Wrapper around ::OrtModelMetadata
*
*/
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 51f99655121c6..d7439e7b356c6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1360,110 +1360,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const Ort
return *this;
}
-// ModelPackageOptions
-inline ModelPackageOptions::ModelPackageOptions(const Env& env, const SessionOptions& session_options) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_));
-}
-
-inline ModelPackageOptions::ModelPackageOptions(const Env& env, ConstSessionOptions session_options) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_));
-}
-
-// ModelPackageContext
-inline ModelPackageContext::ModelPackageContext(const ORTCHAR_T* package_root) {
- ThrowOnError(GetModelPackageApi().CreateModelPackageContext(package_root, &this->p_));
-}
-
-inline size_t ModelPackageContext::GetComponentCount() const {
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentCount(this->p_, &count));
- return count;
-}
-
-inline std::vector<std::string> ModelPackageContext::GetComponentNames() const {
- const char* const* names = nullptr;
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentNames(this->p_, &names, &count));
- std::vector<std::string> result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
-}
-
-inline size_t ModelPackageContext::GetVariantCount(const char* component_name) const {
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantCount(this->p_, component_name, &count));
- return count;
-}
-
-inline std::vector<std::string> ModelPackageContext::GetVariantNames(const char* component_name) const {
- const char* const* names = nullptr;
- size_t count = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantNames(this->p_, component_name, &names, &count));
- std::vector<std::string> result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
-}
-
-inline const char* ModelPackageContext::GetVariantEpName(const char* component_name,
- const char* variant_name) const {
- const char* ep = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantEpName(
- this->p_, component_name, variant_name, &ep));
- return ep;
-}
-
-inline int64_t ModelPackageContext::GetSchemaVersion() const {
- int64_t version = 0;
- ThrowOnError(GetModelPackageApi().ModelPackage_GetSchemaVersion(this->p_, &version));
- return version;
-}
-
-inline ModelPackageComponentContext ModelPackageContext::SelectComponent(
- const char* component_name, const ModelPackageOptions& options) const {
- OrtModelPackageComponentContext* out = nullptr;
- ThrowOnError(GetModelPackageApi().SelectComponent(this->p_, component_name, options, &out));
- return ModelPackageComponentContext{out};
-}
-
-// ModelPackageComponentContext
-inline std::basic_string<ORTCHAR_T> ModelPackageComponentContext::GetSelectedVariantFolderPath() const {
- const ORTCHAR_T* path = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantFolderPath(this->p_, &path));
- return std::basic_string<ORTCHAR_T>{path};
-}
-
-inline std::string ModelPackageComponentContext::GetSelectedVariantName() const {
- const char* name = nullptr;
- ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantName(this->p_, &name));
- return (name != nullptr) ? std::string{name} : std::string{};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, nullptr, &out));
- return Session{out};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env,
- const SessionOptions& session_options) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out));
- return Session{out};
-}
-
-inline Session ModelPackageComponentContext::CreateSession(const Env& env,
- ConstSessionOptions session_options) {
- OrtSession* out = nullptr;
- ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out));
- return Session{out};
-}
-
namespace detail {
template
diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
index e943a5cf65b11..0dd87c10776d3 100644
--- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h
@@ -39,6 +39,10 @@
// ORT_RUNTIME_CLASS(ExperimentalType);
//
+ORT_RUNTIME_CLASS(ModelPackageOptions);
+ORT_RUNTIME_CLASS(ModelPackageContext);
+ORT_RUNTIME_CLASS(ModelPackageComponentContext);
+
//
// C: function pointer typedefs and name constants
//
diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
index 0123b02818584..57a4e472b6f6d 100644
--- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
+++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc
@@ -35,3 +35,250 @@
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t* out)
+
+// ---------------------------------------------------------------------------
+// OrtModelPackageApi
+//
+// A model package is a directory containing one or more *components* (logical models).
+// Each component has one or more *variants*, where each variant targets a single
+// execution provider (EP). The package manifest declares the EP name, device type,
+// and an optional compatibility string for every variant so that the runtime can
+// automatically select the best variant for the hardware and EPs available in the
+// caller's session options.
+//
+// The functions below support:
+// - creating model package options that capture EP configuration from OrtSessionOptions,
+// - loading a package context (manifest + metadata) from a package root path,
+// - querying component/variant metadata including per-variant EP information,
+// - selecting a component (which also resolves the best-matching variant),
+// - querying the selected variant's name and folder path,
+// - creating an OrtSession from the selected component context.
+//
+// Typical flow:
+// 1) Create model package options:
+// - OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions
+// 2) Load package metadata:
+// - OrtModelPackageApi_CreateModelPackageContext
+// 3) Query metadata (optional):
+// - OrtModelPackageApi_ModelPackage_GetSchemaVersion
+// - OrtModelPackageApi_ModelPackage_GetComponentCount
+// - OrtModelPackageApi_ModelPackage_GetComponentNames
+// - OrtModelPackageApi_ModelPackage_GetVariantCount
+// - OrtModelPackageApi_ModelPackage_GetVariantNames
+// - OrtModelPackageApi_ModelPackage_GetVariantEpName
+// 4) Select a component and resolve variant:
+// - OrtModelPackageApi_SelectComponent
+// 5) Query selected variant info (optional):
+// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName
+// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath
+// 6) Create session:
+// - OrtModelPackageApi_CreateSession
+//
+// Ownership:
+// - Release objects created by this API with the corresponding release functions:
+// OrtModelPackageApi_ReleaseModelPackageOptions,
+// OrtModelPackageApi_ReleaseModelPackageContext,
+// OrtModelPackageApi_ReleaseModelPackageComponentContext.
+//
+// Opaque handles (OrtModelPackageOptions, OrtModelPackageContext, OrtModelPackageComponentContext)
+// are declared in onnxruntime_experimental_c_api.h.
+// ---------------------------------------------------------------------------
+
+/** \brief Create model package options from an existing OrtSessionOptions.
+ *
+ * Captures EP configuration (registered execution providers and their devices) from
+ * the session options for use during variant selection. The resulting OrtModelPackageOptions
+ * is passed to OrtModelPackageApi_SelectComponent to resolve the best variant for the
+ * available EPs.
+ *
+ * \param[in] env The ORT environment.
+ * \param[in] session_options Session options containing registered EPs.
+ * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released
+ * with OrtModelPackageApi_ReleaseModelPackageOptions.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions,
+ _In_ const OrtEnv* env,
+ _In_ const OrtSessionOptions* session_options,
+ _Outptr_ OrtModelPackageOptions** out)
+
+/** \brief Release an OrtModelPackageOptions created by
+ * OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions.
+ */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageOptions,
+ _Frees_ptr_opt_ OrtModelPackageOptions* options)
+
+/** \brief Create a model package context by parsing the package at the given root path.
+ *
+ * Parses the manifest.json and component metadata from the specified directory.
+ * The returned context provides read-only access to the package structure (components,
+ * variants, EP declarations).
+ *
+ * \param[in] package_root Path to the model package root directory (containing manifest.json).
+ * \param[out] out Receives the newly created OrtModelPackageContext. Must be released
+ * with OrtModelPackageApi_ReleaseModelPackageContext.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageContext,
+ _In_ const ORTCHAR_T* package_root,
+ _Outptr_ OrtModelPackageContext** out)
+
+/** \brief Release an OrtModelPackageContext created by OrtModelPackageApi_CreateModelPackageContext. */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageContext,
+ _Frees_ptr_opt_ OrtModelPackageContext* ctx)
+
+/** \brief Get the schema version declared in the model package manifest.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_version Receives the schema version number.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetSchemaVersion,
+ _In_ const OrtModelPackageContext* ctx,
+ _Out_ int64_t* out_version)
+
+/** \brief Get the number of components in the model package.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_count Receives the component count.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentCount,
+ _In_ const OrtModelPackageContext* ctx,
+ _Out_ size_t* out_count)
+
+/** \brief Get the names of all components in the model package.
+ *
+ * Returns a pointer to an array of UTF-8 component name strings. The array and its
+ * strings are owned by `ctx` and remain valid until the context is released.
+ *
+ * \param[in] ctx The model package context.
+ * \param[out] out_names Receives a pointer to an array of component name strings.
+ * \param[out] out_count Receives the number of elements in the array.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentNames,
+ _In_ const OrtModelPackageContext* ctx,
+ _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
+ _Out_ size_t* out_count)
+
+/** \brief Get the number of variants for a given component.
+ *
+ * \param[in] ctx The model package context.
+ * \param[in] component_name Name of the component to query.
+ * \param[out] out_count Receives the variant count.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantCount,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _Out_ size_t* out_count)
+
+/** \brief Get the names of all variants for a given component.
+ *
+ * Returns a pointer to an array of UTF-8 variant name strings. The array and its
+ * strings are owned by `ctx` and remain valid until the context is released.
+ *
+ * \param[in] ctx The model package context.
+ * \param[in] component_name Name of the component to query.
+ * \param[out] out_variant_names Receives a pointer to an array of variant name strings.
+ * \param[out] out_count Receives the number of elements in the array.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantNames,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
+ _Out_ size_t* out_count)
+
+/** \brief Get the EP name declared for a (component, variant) pair.
+ *
+ * Each variant targets a single EP. `out_ep` receives the EP name string.
+ * When the variant does not declare an EP, the returned pointer is NULL.
+ * String memory is owned by `ctx` and remains valid until the context is released.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantEpName,
+ _In_ const OrtModelPackageContext* ctx,
+ _In_ const char* component_name,
+ _In_ const char* variant_name,
+ _Outptr_result_maybenull_ const char** out_ep)
+
+/** \brief Select a component model and return an opaque component instance.
+ *
+ * The variant selection is also performed during this call based on the component
+ * metadata and the provided options. The returned `OrtModelPackageComponentContext*` is
+ * independent of `context` lifetime and must be released via
+ * OrtModelPackageApi_ReleaseModelPackageComponentContext.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_SelectComponent,
+ _In_ const OrtModelPackageContext* context,
+ _In_ const char* component_name,
+ _In_ const OrtModelPackageOptions* options,
+ _Outptr_ OrtModelPackageComponentContext** out)
+
+/** \brief Release an OrtModelPackageComponentContext created by OrtModelPackageApi_SelectComponent. */
+ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageComponentContext,
+ _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx)
+
+/** \brief Get the name of the selected variant after OrtModelPackageApi_SelectComponent has been called.
+ *
+ * String memory is owned by `ctx` and remains valid until the context is released.
+ *
+ * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent.
+ * \param[out] out_name Receives the selected variant's name string.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName,
+ _In_ const OrtModelPackageComponentContext* ctx,
+ _Outptr_ const char** out_name)
+
+/** \brief Get the folder path of the selected variant.
+ *
+ * Returns the resolved absolute path to the variant's directory on disk.
+ * The string is owned by `ctx` and remains valid until the context is released.
+ *
+ * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent.
+ * \param[out] folder_path Receives the variant folder path string.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath,
+ _In_ const OrtModelPackageComponentContext* ctx,
+ _Outptr_ const ORTCHAR_T** folder_path)
+
+/** \brief Create an OrtSession for a selected file within a component model variant.
+ *
+ * The chosen variant (and thus its EP selection) is determined by `context`, which
+ * OrtModelPackageApi_SelectComponent resolved from the OrtModelPackageOptions passed to it.
+ *
+ * Session options precedence:
+ * 1. session_options == NULL (default path):
+ * ORT uses the OrtSessionOptions that was captured when `context` was created.
+ * Any variant-specific session and provider options declared in the package
+ * metadata are merged on top.
+ *
+ * 2. session_options != NULL (advanced path):
+ * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific
+ * session and provider options from the package metadata are NOT applied.
+ * Use this when custom EP setup is required (e.g., shared CUDA streams,
+ * shared QNN EP contexts, custom allocators).
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession,
+ _In_ const OrtEnv* env,
+ _In_ OrtModelPackageComponentContext* context,
+ _In_opt_ const OrtSessionOptions* session_options,
+ _Outptr_ OrtSession** session)
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index c26424a0ab9ec..df14bc8c57f24 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -29,9 +29,6 @@
GraphOptimizationLevel, # noqa: F401
LoraAdapter, # noqa: F401
ModelMetadata, # noqa: F401
- ModelPackageComponentContext, # noqa: F401
- ModelPackageContext, # noqa: F401
- ModelPackageOptions, # noqa: F401
NodeArg, # noqa: F401
OrtAllocatorType, # noqa: F401
OrtArenaCfg, # noqa: F401
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index d28aae02ab2f1..6e75bb1b5a7c0 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -48,7 +48,11 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences,
unique_word_ids.insert(word_id);
}
+ const int vocab_size = next_token_scores.vocab_size;
for (const int32_t word_id : unique_word_ids) {
+ if (word_id < 0 || word_id >= vocab_size) {
+ continue;
+ }
T score = beam_token_scores[word_id];
// If score < 0, then repetition penalty > 1.0 has to be multiplied to reduce the previous token probability,
@@ -89,7 +93,11 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences,
}
}
+ const int vocab_size = next_token_scores.vocab_size;
for (const int32_t word_id : blocked_word_ids) {
+ if (word_id < 0 || word_id >= vocab_size) {
+ continue;
+ }
beam_token_scores[word_id] = std::numeric_limits<T>::lowest();
}
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 6e157d83159f4..fb9638bb09d59 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -28,7 +28,9 @@ struct NextTokenScores {
}
void SetScore(int token_id, T score) {
- assert(token_id >= 0 && token_id < vocab_size);
+ if (token_id < 0 || token_id >= vocab_size) {
+ return;
+ }
for (int i = 0; i < batch_beam_size; i++) {
scores[static_cast(i) * vocab_size + token_id] = score;
}
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
index b9e62443145e5..c3b734816cf84 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h
@@ -31,6 +31,8 @@ enum class QuantType {
W4_AFP8
};
+int get_arch_for_mixed_gemm_weight_preprocess(int arch);
+
void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream,
int arch,
int8_t* preprocessed_quantized_weight,
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
index a006612ddadc9..7e83bdda72eab 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu
@@ -521,6 +521,19 @@ void add_bias_and_interleave_quantized_tensor_inplace_cuda(
}
}
+int get_arch_for_mixed_gemm_weight_preprocess(int arch) {
+ ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch);
+ if (arch < 80) {
+ return 75;
+ }
+#ifndef EXCLUDE_SM_90
+ if (arch >= 90 && arch < 100) {
+ return 90;
+ }
+#endif
+ return 80;
+}
+
void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream,
int arch,
int8_t* preprocessed_quantized_weight,
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
index a8fb411ed0663..47bbe0c0e10ec 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h
@@ -120,11 +120,11 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) {
}
LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) {
- ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch);
- if (arch < 80) {
+ arch = get_arch_for_mixed_gemm_weight_preprocess(arch);
+ if (arch == 75) {
return getLayoutDetailsForArch(quant_type);
#ifndef EXCLUDE_SM_90
- } else if (arch >= 90 && arch < 100) {
+ } else if (arch == 90) {
return getLayoutDetailsForArch(quant_type);
#endif
} else {
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
index e1ddcac0cea4f..bc34a2e83318a 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
@@ -62,18 +62,28 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE
this->quant_type_ = op_kernel_info.GetAttrOrDefault<std::string>("quant_type", "int");
ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8",
"quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'");
- // ``weights_prepacked`` is an optional tri-state attribute that defaults to
- // -1 (auto) in the schema, so each EP picks its own backward-compatible
- // default rather than the schema imposing one:
- // -1 (auto, also the schema default): the EP decides. The CUDA EP's
- // backward-compatible default is "prepacked" because all pre-existing
- // tooling ships CUTLASS-prepacked weights.
- // 1: initializers are already prepacked; the compute path reads them as-is.
- // 0: initializers are raw [E, N, K/pack]; the PrePack hook lays them out.
+ // ``weights_prepacked`` is an optional tri-state attribute (default -1) that
+ // declares the layout of the int4/int8 fc1/fc2 weight initializers. The
+ // concrete prepacked layouts selected by -1 and 1 are determined by the
+ // execution provider. The CUDA EP maps the tri-state as:
+ // -1 (default): already prepacked in the EP's default int weight layout.
+ // 1: already prepacked in an alternate EP-selected int weight layout.
+ // 0: raw [E, N, K/pack] initializers; the PrePack hook lays them out.
+ //
+ // Important: the CUDA QMoE int4/int8 MoE GEMM always dispatches to the
+ // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed
+ // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized
+ // specialisation (see isValidHopperMOESpecialisation). The kernel therefore
+ // consumes the SM80/Ampere CUTLASS fpA_intB layout on every GPU. As a result
+ // the EP default (-1) is the SM80 layout regardless of the runtime device SM,
+ // and SM80-format weights are valid on SM90 (they run via the SM80 kernel).
+ // For CUDA today, -1 and 1 are equivalent (both SM80 layout), and 1 is
+ // reserved for a possible future Hopper-specific layout.
+ // PrePack (weights_prepacked=0) packs for the SM80 layout accordingly.
const int64_t weights_prepacked_mode =
op_kernel_info.GetAttrOrDefault<int64_t>("weights_prepacked", static_cast<int64_t>(-1));
ORT_ENFORCE(weights_prepacked_mode == -1 || weights_prepacked_mode == 0 || weights_prepacked_mode == 1,
- "weights_prepacked must be -1 (auto), 0, or 1, but got ", weights_prepacked_mode);
+ "weights_prepacked must be -1 (default), 0, or 1, but got ", weights_prepacked_mode);
weights_prepacked_ = (weights_prepacked_mode != 0);
#if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE)
ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer.");
@@ -224,6 +234,18 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
// to the runner.
const bool int_weights_consumed_by_prepack =
is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr && packed_fc2_weights_ != nullptr;
+ // When ``weights_prepacked == 0`` the raw ``[E, N, K/pack]`` int weights must be
+ // converted to the CUTLASS fpA_intB layout by PrePack before the runner can consume
+ // them. If PrePack never ran (e.g. ``session.disable_prepacking`` is set), the prepack
+ // buffers stay null and falling through to the raw initializer pointers would feed
+ // non-CUTLASS bytes to the runner, producing silently wrong output. Fail loudly instead.
+ if (is_int && !weights_prepacked_ &&
+ (packed_fc1_weights_ == nullptr || packed_fc2_weights_ == nullptr)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "QMoE weights_prepacked=0 requires PrePack to run, but the int weight "
+ "buffers were not produced (is session.disable_prepacking set?). Provide "
+ "CUTLASS-prepacked weights with weights_prepacked=1, or enable prepacking.");
+ }
const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(2);
const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input<Tensor>(3) : nullptr;
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
@@ -844,13 +866,24 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr;
const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr;
if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) {
- fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data;
- fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data;
+ // The native CUTLASS WFP4AFP8 path consumes weights in the repacked FP4
+ // layout produced by PrePack. If PrePack never ran (e.g.
+ // ``session.disable_prepacking`` is set) the repacked buffers stay null and
+ // falling through to the raw initializer bytes would feed a non-CUTLASS
+ // layout to the runner, producing silently wrong output. Fail loudly.
+ if (packed_fp4_fc1_weights_ == nullptr || packed_fp4_fc2_weights_ == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "QMoE wfp4afp8 requires PrePack to run, but the repacked FP4 weight "
+ "buffers were not produced (is session.disable_prepacking set?). "
+ "Enable prepacking to use the native WFP4AFP8 path.");
+ }
+ fc1_weight_data = packed_fp4_fc1_weights_.get();
+ fc2_weight_data = packed_fp4_fc2_weights_.get();
} else if (int_weights_consumed_by_prepack) {
// PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB
// layout that the runner consumes and freed the source initializer
// (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack``
- // (which already requires ``packed_fc1_weights_ != nullptr``) rather than
+ // (which already requires both packed weight buffers) rather than
// just ``is_int && !weights_prepacked_``: when prepacking is disabled at
// the session level (``session.disable_prepacking``) PrePack never runs,
// the prepack buffers stay null, and the raw initializer pointers read
@@ -1146,6 +1179,9 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al
IAllocatorUniquePtr& packed_buf, bool& is_packed) {
ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8,
"PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_);
+ ORT_ENFORCE(sm_ >= 75,
+ "PrePackIntExpertWeights: quant_type='int' with weights_prepacked=0 requires SM75+ CUDA hardware, got SM",
+ sm_);
const auto& shape = tensor.Shape();
ORT_ENFORCE(shape.NumDimensions() == 3,
"PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=",
@@ -1158,22 +1194,15 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al
const int64_t k_packed = shape[2];
const int64_t k = k_packed * pack_factor;
- // Weight packing is architecture-aware (see
- // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing
- // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that
- // skips column interleaving, so it is its own compatibility group. Every
- // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares
- // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack
- // INT8 LDSM and are unsupported. The compute-side runner selects the same
- // layout from this clamped arch, so the two cannot drift.
- //
- // SM75 is passed through unchanged (rather than clamped to 80) even though it
- // shares SM80's layout: the compute-side dispatch (getLayoutDetailsForTransform)
- // still has a distinct SM75 branch, so mirroring it here avoids confusing a
- // reader into thinking prepack and dispatch disagree.
- ORT_ENFORCE(sm_ >= 75,
- "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_);
- const int packing_sm = (sm_ == 90 || sm_ == 75) ? sm_ : 80;
+ // The CUDA QMoE int4/int8 MoE GEMM always dispatches to the Ampere (SM80)
+ // grouped-GEMM kernel -- even on SM90 -- because mixed int-weight + fp16/bf16
+ // is not a valid Hopper TMA warp-specialized specialisation. The kernel thus
+ // consumes the SM80 CUTLASS fpA_intB layout on every GPU, so the weights must
+ // always be preprocessed for SM80 regardless of the runtime device SM.
+ // (Using get_arch_for_mixed_gemm_weight_preprocess(sm_) here would emit the
+ // SM90 layout on Hopper, which the SM80 kernel cannot consume -> wrong output.)
+ const int packing_sm =
+ onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(80);
// Per-expert sizes.
const size_t per_expert_bytes = static_cast<size_t>(n) * static_cast<size_t>(k) / pack_factor;
diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
index 5722ac41cc470..2bbadc205b5d8 100644
--- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
+++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
@@ -46,16 +46,23 @@ class QMoE final : public CudaKernel, public MoEBase {
IAllocatorUniquePtr& packed_buf, bool& is_packed);
int64_t expert_weight_bits_;
bool is_fp16_;
- // When true (the schema default), the int4/int8 fc1/fc2 weight
- // initializers are already in the CUTLASS fpA_intB layout — produced
- // offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the
- // compute path reads them as-is. When false, the raw schema-conformant
- // ``[E, N, K/pack]`` layout (as produced by
- // ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook
- // via ``PrePackIntExpertWeights``, removing the offline prepack
- // dependency. Only meaningful when ``quant_type_ == "int"``. Derived from
- // the optional tri-state ``weights_prepacked`` attribute: -1/auto (or
- // absent) maps to true on the CUDA EP, 1 maps to true, 0 maps to false.
+ // When true, the int4/int8 fc1/fc2 weight initializers are already in a
+ // CUTLASS fpA_intB layout — produced offline e.g. via
+ // ``pack_weights_for_cuda_mixed_gemm`` — and the compute path reads them
+ // as-is. When false, the raw schema-conformant ``[E, N, K/pack]`` layout
+ // (as produced by ``quantize_matmul_{4,8}bits``) is rewritten inside the
+ // PrePack hook via ``PrePackIntExpertWeights``, removing the offline
+ // prepack dependency. Only meaningful when ``quant_type_ == "int"``.
+ // Derived from the optional tri-state ``weights_prepacked`` attribute:
+ // -1 (default) and 1 both map to true; 0 maps to false. The concrete
+ // prepacked layouts selected by -1 and 1 are determined by the execution
+ // provider. For the CUDA EP the int4/int8 MoE GEMM always dispatches to the
+ // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed
+ // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized
+ // specialisation (matches TensorRT-LLM, which also routes W4A16/W8A16 MoE to
+ // the SM80 kernel on Hopper). The kernel therefore consumes the SM80 fpA_intB
+ // layout on every GPU, so -1 and 1 are currently equivalent for the CUDA EP;
+ // 1 is reserved for a possible future Hopper-specific layout (e.g. W4A8).
bool weights_prepacked_ = true;
// Cached source weight shapes captured at PrePack time. When the
// PrePack hook consumed and released the original int4/int8 weight
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
index 684e050f0201a..02e764d01e05e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
@@ -16,6 +16,40 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {
+// WGSL helper function for normalizing on-device indirect dispatch dims.
+// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram.
+// Mirrors the three tiers of ProgramManager::NormalizeDispatchGroupSize:
+// 1) direct (x, y, z) write when every dim is within the spec limit (65535);
+// 2) 2D sqrt collapse when the product fits a square layout;
+// 3) 3D cbrt collapse otherwise.
+// Consumers are unaffected by the chosen layout: ShaderHelper flattens
+// workgroup_id (x, y, z) into a single linear workgroup_idx.
+// Caller contract: must register a storage output named exactly
+// `indirect_buffer` of array<u32> with at least 3 elements.
+constexpr const char kNormalizeDispatchGroupSizeFn[] = R"(
+fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) {
+ let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension
+ if (x <= limit && y <= limit && z <= limit) {
+ indirect_buffer[0] = x;
+ indirect_buffer[1] = y;
+ indirect_buffer[2] = z;
+ return;
+ }
+ let size = f32(x) * f32(y) * f32(z);
+ let dispatch_avg_2d = u32(ceil(sqrt(size)));
+ if (dispatch_avg_2d <= limit) {
+ indirect_buffer[0] = dispatch_avg_2d;
+ indirect_buffer[1] = dispatch_avg_2d;
+ indirect_buffer[2] = 1u;
+ return;
+ }
+ let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0)));
+ indirect_buffer[0] = dispatch_avg_3d;
+ indirect_buffer[1] = dispatch_avg_3d;
+ indirect_buffer[2] = dispatch_avg_3d;
+}
+)";
+
Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform);
const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform);
@@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha
if (prepare_indirect_dispatch_) {
sh.AddOutput("indirect_buffer", ShaderUsage::None);
+ sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn;
}
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
@@ -87,13 +122,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add indirect dispatch logic for thread 0
if (prepare_indirect_dispatch_) {
- // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size.
- shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n"
- << " if (global_idx == 0u) {\n"
+ shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn;
+ shader.MainFunctionBody() << " if (global_idx == 0u) {\n"
<< " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n"
- << " indirect_buffer[0] = num_total_seq_length_tile;\n"
- << " indirect_buffer[1] = uniforms.num_heads;\n"
- << " indirect_buffer[2] = 1u;\n"
+ << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n"
<< " }\n\n";
}
@@ -120,7 +152,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
const Tensor* K, const Tensor* past_key, Tensor* present_key,
const Tensor* V, const Tensor* past_value, Tensor* present_value,
- uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) {
+ uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) {
// CopyKVCache takes past key/value and current key/value and copies them to present key and value.
// This makes it so that FlashAttention only needs to look at present key and value, and saves
// number of input buffers in the shader, which we run out of (<=8) without this optimization.
@@ -176,7 +208,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
{tile_size},
- {static_cast<uint32_t>(parameters.num_heads_)}});
+ {static_cast<uint32_t>(parameters.num_heads_)},
+ {static_cast<uint32_t>(parameters.batch_size_)},
+ {num_q_tiles}});
return context.RunProgram(program);
}
@@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const {
WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_));
}
-Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
- shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_attention_bias_) {
shader.AddInput("attention_bias", ShaderUsage::UseUniform);
}
- shader.AddOutput("output", ShaderUsage::UseUniform);
- shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
+ const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const uint32_t tile_size_k_vec = 8;
const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec;
- return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template",
+ return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template",
WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_),
+ WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_),
+ WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_),
+ WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_),
WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
+ WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_),
+ WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_),
+ WGSL_TEMPLATE_VARIABLE(metadata, metadata),
+ WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx),
+ WGSL_TEMPLATE_VARIABLE(present_key, present_key),
+ WGSL_TEMPLATE_VARIABLE(present_value, present_value),
+ WGSL_TEMPLATE_VARIABLE(q, q));
}
-Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
- const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
- const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
+Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
+ const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value,
+ Tensor* metadata, const Tensor* seqlen_k,
+ const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) {
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
: parameters.scale_;
const bool has_attention_bias = attention_bias != nullptr;
const int components = 4;
+ const int head_size_vec = parameters.v_head_size_ / components;
- FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch};
+ bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH;
+ bool is_unidirectional = parameters.is_unidirectional_;
+ FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components},
- {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}});
+ {present_key, ProgramTensorMetadataDependency::TypeAndRank, components},
+ {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}
if (has_attention_bias) {
program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank});
}
- program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank},
+ program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
{metadata, ProgramTensorMetadataDependency::Rank, 2}});
const uint32_t vectorized_head_size = parameters.head_size_ / components;
- // Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
@@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
if (use_indirect_dispatch) {
program.SetIndirectDispatchTensor(indirect_buffer);
} else {
- program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile);
+ program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile);
}
program.SetWorkgroupSize(64)
- .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch)
+ .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile)
.AddUniformVariables({{static_cast<uint32_t>(vectorized_head_size)},
{static_cast<uint32_t>(parameters.total_sequence_length_)},
{static_cast<float>(alpha)},
@@ -294,124 +342,72 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
{static_cast<uint32_t>(parameters.num_heads_)},
{static_cast<uint32_t>(parameters.batch_size_)},
{attn_bias_dim0},
- {attn_bias_dim1}});
+ {attn_bias_dim1},
+ {static_cast<uint32_t>(parameters.sequence_length_)}});
return context.RunProgram(program);
}
-Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("metadata", ShaderUsage::UseUniform);
- shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
- shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
+ const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
+ const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform);
if (use_indirect_dispatch_) {
shader.AddInput("seqlens_k", ShaderUsage::None);
}
if (has_head_sink_) {
shader.AddInput("head_sink", ShaderUsage::UseUniform);
}
- shader.AddOutput("out_split_vx", ShaderUsage::UseUniform);
-
- const uint32_t tile_size_k_vec = 8u;
-
- return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template",
- WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
- WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_),
- WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
- WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
-}
-
-Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context,
- const Tensor* metadata,
- const Tensor* qk,
- Tensor* out_split_vx,
- Tensor* present_value,
- const Tensor* seqlen_k,
- const WebgpuAttentionParameters& parameters,
- const Tensor* indirect_buffer,
- uint32_t num_total_seq_length_tile,
- uint32_t num_present_sequence_length_tile,
- uint32_t tile_size,
- bool use_indirect_dispatch,
- uint32_t present_sequence_length,
- const Tensor* head_sink) {
- const int components = 4;
- const bool has_head_sink = head_sink != nullptr;
- int head_size_vec = parameters.v_head_size_ / components;
- FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink};
- program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2},
- {qk, ProgramTensorMetadataDependency::TypeAndRank},
- {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}});
- program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size]
- const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
- if (use_indirect_dispatch) {
- program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
- }
- if (has_head_sink) {
- program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
- }
- // SetIndirectDispatchTensor must be called after all AddInput calls because it
- // appends the indirect buffer as the last program input.
- if (use_indirect_dispatch) {
- program.SetIndirectDispatchTensor(indirect_buffer);
- } else {
- program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile);
- }
- program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
- .SetWorkgroupSize(64)
- .AddUniformVariables({{static_cast<uint32_t>(parameters.total_sequence_length_)},
- {static_cast<uint32_t>(head_size_vec)},
- present_sequence_length,
- {static_cast<uint32_t>(parameters.n_reps)},
- num_present_sequence_length_tile,
- {batch_heads},
- {static_cast<uint32_t>(parameters.num_heads_)}});
-
- return context.RunProgram(program);
-}
-
-Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const {
- shader.AddInput("input", ShaderUsage::UseUniform);
- if (use_indirect_dispatch_) {
- shader.AddInput("seqlens_k", ShaderUsage::None);
- }
- shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
+ const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template",
+ WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_),
+ WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_),
WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_),
WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_),
- WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_));
+ WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_),
+ WGSL_TEMPLATE_VARIABLE(input, input),
+ WGSL_TEMPLATE_VARIABLE(metadata, metadata),
+ WGSL_TEMPLATE_VARIABLE(output, output));
}
Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context,
const Tensor* out_split_vx,
+ const Tensor* metadata,
Tensor* output,
const Tensor* seqlen_k,
const WebgpuAttentionParameters& parameters,
uint32_t num_total_seq_length_tile,
uint32_t num_present_sequence_length_tile,
uint32_t seq_tile_size,
- bool use_indirect_dispatch) {
+ bool use_indirect_dispatch,
+ const Tensor* head_sink,
+ uint32_t m_tile) {
const int components = 4;
constexpr int tile_size = 8;
int tile_head_size = tile_size * components;
- FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch};
- program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}});
+ bool has_head_sink = head_sink != nullptr;
+ FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile};
+ program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components},
+ {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}});
if (use_indirect_dispatch) {
program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None});
}
+ if (has_head_sink) {
+ program.AddInput({head_sink, ProgramTensorMetadataDependency::Type});
+ }
program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}});
const uint32_t num_head_size_tile = static_cast<uint32_t>((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size);
const uint32_t batch_heads = static_cast<uint32_t>(parameters.batch_size_ * parameters.num_heads_);
- program.SetDispatchGroupSize(batch_heads * num_head_size_tile)
- .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch)
+ program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile)
+ .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile)
.SetWorkgroupSize(tile_size * tile_size)
.AddUniformVariables({{static_cast<uint32_t>(parameters.v_head_size_ / components)},
num_total_seq_length_tile,
num_present_sequence_length_tile,
{num_head_size_tile},
- {batch_heads}});
+ {batch_heads},
+ {static_cast<uint32_t>(parameters.sequence_length_)},
+ {static_cast<uint32_t>(parameters.num_heads_)}});
return context.RunProgram(program);
}
@@ -446,14 +442,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
// Declare query_output at function scope to ensure it persists throughout the function
Tensor query_output;
+ // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch.
+ const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u);
+ const uint32_t num_q_tiles = (static_cast<uint32_t>(parameters.sequence_length_) + m_tile - 1u) / m_tile;
+
// Create indirect dispatch buffer if using indirect dispatch
Tensor* indirect_buffer_ptr = nullptr;
Tensor indirect_buffer;
- // Prepare indirect dispatch buffer for decode path with static KV cache
- const bool use_indirect_dispatch = !kv_empty &&
- parameters.sequence_length_ == 1 &&
- parameters.past_present_share_buffer_ &&
+ // Prepare indirect dispatch buffer for split-reduce path with static KV cache.
+ // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based
+ // seqlen_k), so dispatch sizes are computed on the GPU into the indirect buffer.
+ const bool use_indirect_dispatch = parameters.past_present_share_buffer_ &&
seqlen_k != nullptr &&
context.IsGraphCaptureEnabled();
if (use_indirect_dispatch) {
@@ -492,10 +492,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
Q, seqlen_k,
cos_cache, sin_cache,
&query_output, present_key, present_value,
- indirect_buffer_ptr, tile_size));
+ indirect_buffer_ptr, tile_size, num_q_tiles));
Q = &query_output;
} else {
- ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr));
+ ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles));
}
// Extract present_sequence_length directly from present_key tensor shape
@@ -503,7 +503,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
// (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size)
const uint32_t present_sequence_length = static_cast<uint32_t>(present_key->Shape()[2]);
- if (parameters.sequence_length_ > 1) {
+ // Route between prefill path (FlashAttentionProgram, single kernel)
+ // and split-reduce decode path (QKV + VxReduce, 2 kernels).
+ // Split-reduce wins for short Q (sequence_length < 32) across all KV
+ // cache lengths measured: 1.13x-2.07x faster at total_sequence_length
+ // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96).
+ const bool use_split_reduce = parameters.sequence_length_ < 32;
+
+ if (!use_split_reduce) {
+ // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles)
bool has_attention_bias = attention_bias != nullptr;
bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"};
bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"};
@@ -545,7 +553,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
const uint32_t prefill_tile_size = is_apple ? 128 : tile_size;
const uint32_t num_seq_tile = (parameters.sequence_length_ + prefill_tile_size - 1) / prefill_tile_size;
- // Get attention bias dimensions for broadcasting
uint32_t attn_bias_dim0 = 1;
uint32_t attn_bias_dim1 = 1;
if (has_attention_bias) {
@@ -570,37 +577,31 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
return context.RunProgram(program);
}
- // For decode path (sequence_length == 1)
- const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_,
- parameters.sequence_length_, present_sequence_length});
- const TensorShape qk_shape(qk_dims);
- Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape);
+ // Split-reduce path (QKV + VxReduce)
const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size;
const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size;
// The metadata is used to store the max and sum of each tile.
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_,
- num_present_sequence_length_tile, 2});
+ parameters.sequence_length_, num_present_sequence_length_tile, 2});
const TensorShape metadata_shape(metadata_dims);
Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType<float>(), metadata_shape);
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k,
- parameters, indirect_buffer_ptr, num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
- present_sequence_length));
const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_,
- num_present_sequence_length_tile, parameters.head_size_});
+ parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_});
const TensorShape out_split_vx_shape(out_split_vx_dims);
Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value,
- seqlen_k, parameters, indirect_buffer_ptr,
- num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size,
- use_indirect_dispatch, present_sequence_length,
- head_sink));
- ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters,
+
+ ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value,
+ &metadata, seqlen_k,
+ parameters, indirect_buffer_ptr, num_total_seq_length_tile,
+ num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
+ present_sequence_length, m_tile));
+
+ ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters,
num_total_seq_length_tile,
- num_present_sequence_length_tile, tile_size, use_indirect_dispatch));
+ num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
+ head_sink, m_tile));
return Status::OK();
}
@@ -621,7 +622,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
Tensor* present_key,
Tensor* present_value,
Tensor* indirect_buffer,
- uint32_t tile_size) {
+ uint32_t tile_size, uint32_t num_q_tiles) {
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;
@@ -678,6 +679,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
{present_sequence_length},
{tile_size},
{static_cast<uint32_t>(dispatch_size)},
+ {static_cast<uint32_t>(params.batch_size_)},
+ {num_q_tiles},
});
program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
index e75b6378f67c6..3da6b33b4dc0e 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
@@ -35,7 +35,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program {
{"total_sequence_length", ProgramUniformVariableDataType::Uint32},
{"kv_sequence_length", ProgramUniformVariableDataType::Uint32},
{"tile_size", ProgramUniformVariableDataType::Uint32},
- {"num_heads", ProgramUniformVariableDataType::Uint32});
+ {"num_heads", ProgramUniformVariableDataType::Uint32},
+ {"batch_size", ProgramUniformVariableDataType::Uint32},
+ {"num_q_tiles", ProgramUniformVariableDataType::Uint32});
private:
bool has_past_;
@@ -138,11 +142,14 @@ class FlashAttentionProgram final : public Program {
int max_k_step_;
};
-class FlashAttentionDecodeQKTProgram final : public Program<FlashAttentionDecodeQKTProgram> {
+class FlashAttentionDecodeQKVProgram final : public Program<FlashAttentionDecodeQKVProgram> {
public:
- FlashAttentionDecodeQKTProgram(const std::string& kernel_name,
- bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch)
- : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
+ FlashAttentionDecodeQKVProgram(const std::string& kernel_name,
+ bool has_attention_bias, uint32_t tile_size, int head_size_vec,
+ bool use_indirect_dispatch, bool q_BNSH = false,
+ bool is_unidirectional = false,
+ uint32_t m_tile = 1)
+ : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -156,41 +163,23 @@ class FlashAttentionDecodeQKTProgram final : public Program {
- public:
- FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false)
- : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) {
- }
-
- Status GenerateShaderCode(ShaderHelper& sh) const override;
-
- WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32},
- {"head_size_vec", ProgramUniformVariableDataType::Uint32},
- {"present_sequence_length", ProgramUniformVariableDataType::Uint32},
- {"n_reps", ProgramUniformVariableDataType::Uint32},
- {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32},
- {"batch_heads", ProgramUniformVariableDataType::Uint32},
- {"num_heads", ProgramUniformVariableDataType::Uint32});
-
- private:
uint32_t tile_size_;
int head_size_vec_;
bool use_indirect_dispatch_;
- bool has_head_sink_;
+ bool q_BNSH_;
+ bool is_unidirectional_;
+ uint32_t m_tile_;
};
class FlashAttentionDecodeVxReduceProgram final : public Program<FlashAttentionDecodeVxReduceProgram> {
public:
- FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch)
- : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) {
+ FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1)
+ : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) {
}
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -199,12 +188,16 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array;
-var inner_qk_values: array, tile_size>;
-var tile_qk: array;
-
-#if has_attention_bias
- fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
- {
- // Handle broadcasting: if dimension size is 1, use index 0
- let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
- let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
-
- // Calculate flat offset with broadcasting applied
- // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length]
- // For decode, new_seq_length is 1, so we can simplify:
- let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length +
- bias_head_idx * total_seq_length +
- k_idx;
- return attention_bias[offset];
- }
-#else
- fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
- {
- return q_element_t(0);
- }
-#endif
-
-$MAIN {
- let local_row = u32(local_idx / tile_size_k_vec);
- let local_col = local_idx % tile_size_k_vec;
-#if use_indirect_dispatch
- let total_sequence_length = u32(seqlens_k[0]) + 1u;
-#else
- let total_sequence_length = uniforms.total_sequence_length;
-#endif
- let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
- let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
- let head_idx = batch_head_idx % uniforms.num_heads;
- let batch_idx = batch_head_idx / uniforms.num_heads;
- if (batch_idx >= uniforms.batch_size) {
- return;
- }
- let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec;
- let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
- for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
- if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
- tile_q[local_idx] = q[q_offset + k + local_idx];
- }
- workgroupBarrier();
- let q_data = tile_q[local_col] * q_element_t(uniforms.alpha);
- if (k + local_col < uniforms.head_size_vec) {
- for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
- if (total_seq_offset + row_offset + local_row < total_sequence_length) {
- inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data);
- }
- }
- }
- workgroupBarrier();
- }
-
- if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
- var sum = q_element_t(0);
- for (var i = 0u; i < tile_size_k_vec; i++) {
- sum += inner_qk_values[local_idx][i];
- }
-
- sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length);
- tile_qk[local_idx] = sum;
- output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum;
- }
- workgroupBarrier();
-
- if (local_idx == 0u) {
- // Calculate the max and sum in current split.
- var l_max = f32(-3.4028234663852886e+38f);
- var l_sum = f32(0);
- for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
- l_max = max(l_max, f32(tile_qk[i]));
- }
- for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
- l_sum += exp(f32(tile_qk[i]) - l_max);
- }
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
- metadata[meta_offset] = metadata_value_t(l_max, l_sum);
- }
-}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
new file mode 100644
index 0000000000000..524a18ca43245
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
@@ -0,0 +1,197 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param has_attention_bias
+#param v_head_size_vec
+#param is_unidirectional
+#param m_tile
+#param q_BNSH
+#param sub_tile_count
+#param tile_size
+#param tile_size_k_vec
+#param use_indirect_dispatch
+
+#use .getByOffset .setByOffset
+
+// Fused QK^T + softmax + V multiply shader.
+//
+// Each workgroup processes one KV tile (tile_size rows of present_key/value)
+// for m_tile Q rows. The computation has two phases:
+//
+// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask,
+// per-tile max/sum for online softmax)
+// Phase 2: Local softmax normalization + V multiply (using local max/sum,
+// no cross-workgroup dependency)
+//
+// The VxReduce shader performs the final rescaling across tiles.
+
+var tile_q: array, m_tile>;
+var inner_qk_values: array, tile_size>, m_tile>;
+var tile_qk: array, m_tile>;
+var tile_output: array, m_tile>;
+var qkv_values: array, sub_tile_count>, m_tile>;
+var tile_max: array;
+var tile_sum: array;
+
+#if has_attention_bias
+ fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
+ {
+ let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
+ let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
+ let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length +
+ bias_head_idx * uniforms.new_sequence_length * total_seq_length +
+ q_idx * total_seq_length +
+ k_idx;
+ return attention_bias[offset];
+ }
+#else
+ fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t
+ {
+ return q_element_t(0);
+ }
+#endif
+
+$MAIN {
+ let local_row = u32(local_idx / tile_size_k_vec);
+ let local_col = local_idx % tile_size_k_vec;
+ #if use_indirect_dispatch
+ let total_sequence_length = u32(seqlens_k[0]) + 1u;
+ #else
+ let total_sequence_length = uniforms.total_sequence_length;
+ #endif
+ let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
+ let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
+ // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile]
+ let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
+ let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles;
+ let q_base = q_tile_idx * m_tile;
+ let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles));
+ let head_idx = batch_head_idx % uniforms.num_heads;
+ let batch_idx = batch_head_idx / uniforms.num_heads;
+ if (batch_idx >= uniforms.batch_size) {
+ return;
+ }
+ let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec;
+ let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length;
+
+ // ============================================================
+ // Phase 1: QK^T computation
+ // ============================================================
+ for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) {
+ let q_idx = q_base + m;
+#if q_BNSH
+ let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec +
+ head_idx * uniforms.new_sequence_length * uniforms.head_size_vec +
+ q_idx * uniforms.head_size_vec;
+#else
+ let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec +
+ q_idx * uniforms.num_heads * uniforms.head_size_vec +
+ head_idx * uniforms.head_size_vec;
+#endif
+ tile_q[m][local_idx] = q.getByOffset(q_offset + k + local_idx);
+ }
+ }
+ workgroupBarrier();
+ if (k + local_col < uniforms.head_size_vec) {
+ for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
+ if (total_seq_offset + row_offset + local_row < total_sequence_length) {
+ let k_data = present_key.getByOffset(present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col);
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha);
+ inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data);
+ }
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+
+ // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
+ var sum = q_element_t(0);
+ for (var i = 0u; i < tile_size_k_vec; i++) {
+ sum += inner_qk_values[m][local_idx][i];
+ }
+
+ sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length);
+#if is_unidirectional
+ if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) {
+ sum = q_element_t(-65504.0f);
+ }
+#endif
+ tile_qk[m][local_idx] = present_value_element_t(sum);
+ }
+ workgroupBarrier();
+
+ // Compute per-tile max and sum for online softmax
+ if (local_idx == 0u) {
+ var l_max = f32(-3.4028234663852886e+38f);
+ var l_sum = f32(0);
+ for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
+ l_max = max(l_max, f32(tile_qk[m][i]));
+ }
+ for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
+ l_sum += exp(f32(tile_qk[m][i]) - l_max);
+ }
+ tile_max[m] = l_max;
+ tile_sum[m] = l_sum;
+ let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
+ metadata.setByOffset(meta_offset, metadata_value_t(l_max, l_sum));
+ }
+ }
+ workgroupBarrier();
+
+ // ============================================================
+ // Phase 2: Local softmax + V multiply
+ // ============================================================
+
+ // Normalize tile_qk with local max/sum
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) {
+ tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]);
+ }
+ }
+ workgroupBarrier();
+
+ for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ qkv_values[m][local_row][local_col] = present_value_value_t(0);
+ }
+ workgroupBarrier();
+
+ if (k + local_col < v_head_size_vec) {
+ for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
+ if (total_seq_offset + row_offset + local_row < total_sequence_length) {
+ let v_data = present_value.getByOffset(present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col);
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row];
+ }
+ }
+ }
+ }
+ workgroupBarrier();
+
+ if (local_idx < tile_size_k_vec) {
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ for (var i = 0u; i < sub_tile_count; i++) {
+ tile_output[m][k + local_idx] += qkv_values[m][i][local_idx];
+ }
+ }
+ }
+ workgroupBarrier();
+ }
+
+ // Write output
+ let tile_idx = workgroup_idx % num_total_seq_length_tile;
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec;
+ for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) {
+ out_split_vx.setByOffset(out_base + tile_idx * v_head_size_vec + i, tile_output[m][i]);
+ }
+ }
+}
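
Reviewer note on the fused shader above: the flat workgroup index is unpacked in the order [batch_heads, num_q_tiles, num_total_seq_length_tile], and the `is_unidirectional` branch masks a key position when it lies past the query row's own absolute position. The C++ sketch below mirrors that arithmetic on the host so it can be checked in isolation. It is only an illustration under stated assumptions: `TileCoords`, `CeilDiv`, `DecomposeWorkgroupIdx`, `IsCausallyMasked`, and `CheckTileIndexing` are hypothetical names, not part of this patch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical container for the three coordinates the shader derives.
struct TileCoords {
  std::uint32_t batch_head_idx;
  std::uint32_t q_tile_idx;
  std::uint32_t kv_tile_idx;
};

// Ceiling division, as used for num_q_tiles and num_total_seq_length_tile.
constexpr std::uint32_t CeilDiv(std::uint32_t a, std::uint32_t b) {
  return (a + b - 1u) / b;
}

// Unpack a flat index laid out as [batch_heads, num_q_tiles, num_kv_tiles].
constexpr TileCoords DecomposeWorkgroupIdx(std::uint32_t workgroup_idx,
                                           std::uint32_t num_q_tiles,
                                           std::uint32_t num_kv_tiles) {
  return TileCoords{workgroup_idx / (num_kv_tiles * num_q_tiles),
                    (workgroup_idx / num_kv_tiles) % num_q_tiles,
                    workgroup_idx % num_kv_tiles};
}

// Mirrors the is_unidirectional test: query row q_idx sits at absolute
// position (total - new + q_idx) and must not see keys after it.
constexpr bool IsCausallyMasked(std::uint32_t k_pos, std::uint32_t q_idx,
                                std::uint32_t total_sequence_length,
                                std::uint32_t new_sequence_length) {
  return k_pos > total_sequence_length - new_sequence_length + q_idx;
}

// Small usage check with made-up sizes: 3 new Q rows, m_tile = 2,
// 40 total positions, tile_size = 16.
inline void CheckTileIndexing() {
  const std::uint32_t num_q_tiles = CeilDiv(3u, 2u);
  const std::uint32_t num_kv_tiles = CeilDiv(40u, 16u);
  const TileCoords c = DecomposeWorkgroupIdx(7u, num_q_tiles, num_kv_tiles);
  assert(c.batch_head_idx == 1u);
  assert(c.q_tile_idx == 0u);
  assert(c.kv_tile_idx == 1u);
  // Re-packing the coordinates recovers the flat index.
  assert((c.batch_head_idx * num_q_tiles + c.q_tile_idx) * num_kv_tiles + c.kv_tile_idx == 7u);
}
```

Keeping the KV tile as the fastest-varying coordinate matches how the shader computes `total_seq_offset` from `workgroup_idx % num_total_seq_length_tile`.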
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
deleted file mode 100644
index 6f1ad1ca41b71..0000000000000
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#param has_head_sink
-#param tile_size
-#param head_size_vec
-#param tile_size_k_vec
-#param sub_tile_count
-#param use_indirect_dispatch
-
-// Note that this shader adopts similar algorithm with dp4a generation shader.
-//
-// This algorithm works to compute dot product of v with qk parallelly, by
-// processing on the head_size dimension at each step amongst tile_size_k_vec
-// threads, and utilizing the remaining threads in the workgroup to process
-// additional rows of |present_value| in parallel (such that the values in
-// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads
-// also reload |present_value| tile_size/sub_tile_count times to compute partial
-// dot products of other |present_value| rows in order to complete all tile_size
-// |present_value| rows in this workgroup and also reusing the values in
-// tile_qk.
-//
-// The difference with FlashAttentionDecodeQKTProgram is that the dot products
-// go through the rows (total_sequence_length) of |present_value| instead of
-// columns (head_size_vec). And each workgroup only calculate current
-// tile_size's dot products instead of iterating the whole row
-// |total_sequence_length|. That's why this shader is a split shader. The final
-// reduce will be done in FlashAttentionDecodeReduceProgram.
-
-// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx
-// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate
-// memory. The FlashAttentionDecodeQKT can be merged into split shader and do
-// the final softmax adjustment in the reduce shader. However, some issues are
-// met that when the total sequence length exceeds some value, the result will
-// become garbage. Since it can't be resolved in a short time, leave it as TODO
-// to fix it in future.
-
-var tile_qk: array;
-var tile_output: array;
-var qkv_values: array, sub_tile_count>;
-
-$MAIN {
- let local_row = u32(local_idx / tile_size_k_vec);
- let local_col = local_idx % tile_size_k_vec;
- #if use_indirect_dispatch
- let total_sequence_length = u32(seqlens_k[0]) + 1u;
- #else
- let total_sequence_length = uniforms.total_sequence_length;
- #endif
- let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size;
- let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile);
- if (batch_head_idx >= uniforms.batch_heads) {
- return;
- }
- let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length;
-
- // Calculate the global max and sum in qk.
- var g_max = f32(-3.4028234663852886e+38f);
-#if has_head_sink
- let head_idx = batch_head_idx % uniforms.num_heads;
- let sink_value = f32(head_sink[head_idx]);
- g_max = max(g_max, sink_value);
-#endif
- var g_sum = f32(0);
- for (var i = 0u; i < num_total_seq_length_tile; i++)
- {
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
- g_max = max(g_max, metadata[meta_offset].x);
- }
- for (var i = 0u; i < num_total_seq_length_tile; i++)
- {
- let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i;
- let m_value = metadata[meta_offset];
- g_sum += exp(m_value.x - g_max) * m_value.y;
- }
-#if has_head_sink
- g_sum += exp(sink_value - g_max);
-#endif
-
- if (total_seq_offset + local_idx < total_sequence_length) {
- tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum);
- }
-
- for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) {
- var value = present_value_value_t(0);
- qkv_values[local_row][local_col] = present_value_value_t(0);
- workgroupBarrier();
-
- if (k + local_col < head_size_vec) {
- for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) {
- if (total_seq_offset + row_offset + local_row < total_sequence_length) {
- value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row];
- }
- }
- }
-
- qkv_values[local_row][local_col] = value;
- workgroupBarrier();
-
- if (local_idx < tile_size_k_vec) {
- for (var i = 0u; i < sub_tile_count; i++) {
- tile_output[k + local_idx] += qkv_values[i][local_idx];
- }
- }
- workgroupBarrier();
- }
-
- for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) {
- let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i;
- out_split_vx[out_offset] = tile_output[i];
- }
-}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
index f909a87724da6..a3ce0b68cb659 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template
@@ -1,31 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+#param has_head_sink
+#param m_tile
#param seq_tile_size
#param tile_size
#param use_indirect_dispatch
-// Inputs are splits of the GQA output, split into num_total_seq_length_tiles
-// rows. This shader needs to add these splits across the row dimension to
-// arrive at the final result. The column is head size wide. The reduction
-// achieves maximum parallelization by splitting this task first into tile_size
-// columns that each workgroup is responsible for. Then within each workgroup
-// the task of summation over the num_total_seq_length_tile for the tile_size
-// columns is further split in two ways. First across the row dimension to have
-// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE
-// rows. Then across the column dimension where each thread is responsible for 1
-// column of the TILE_SIZE columns the workgroup is responsible for.
+#use .getByOffset .setByOffset
+
+// This shader reduces partial V outputs from the fused QKV shader.
+// Each tile produced a locally-normalized V contribution. To get the
+// correct global result, we rescale each tile's contribution using
+// per-tile metadata (max, sum) with online softmax:
+//
+// global_max = max(local_max_i for all tiles)
+// global_sum = sum(local_sum_i * exp(local_max_i - global_max))
+// output[h] = sum(partial_i[h] * local_sum_i * exp(local_max_i - global_max)) / global_sum
var tile_input: array, tile_size>;
$MAIN {
+ let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile;
+ // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile]
let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size;
- let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile);
+ let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles;
+ let q_base = q_tile_idx * m_tile;
+ let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles));
if (batch_head_idx >= uniforms.batch_heads) {
return;
}
- let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
- var value = output_value_t(0);
let local_row = u32(local_idx / tile_size);
let local_col = local_idx % tile_size;
#if use_indirect_dispatch
@@ -35,23 +39,60 @@ $MAIN {
let num_total_seq_length_tile = uniforms.num_total_seq_length_tile;
#endif
- if (head_size_offset + local_col < uniforms.head_size_vec) {
- for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
- if (r + local_row < num_total_seq_length_tile) {
- value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col];
+ for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) {
+ let q_idx = q_base + m;
+ let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec;
+ let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile;
+
+ // Compute global max across all tiles
+ var g_max = f32(-3.4028234663852886e+38f);
+#if has_head_sink
+ let head_idx_for_sink = batch_head_idx % uniforms.num_heads;
+ let sink_value = f32(head_sink[head_idx_for_sink]);
+ g_max = max(g_max, sink_value);
+#endif
+ for (var i = 0u; i < num_total_seq_length_tile; i++) {
+ g_max = max(g_max, metadata.getByOffset(meta_base + i).x);
+ }
+
+ // Compute global sum with rescaling
+ var g_sum = f32(0);
+ for (var i = 0u; i < num_total_seq_length_tile; i++) {
+ let m_value = metadata.getByOffset(meta_base + i);
+ g_sum += m_value.y * exp(m_value.x - g_max);
+ }
+#if has_head_sink
+ g_sum += exp(sink_value - g_max);
+#endif
+
+ // Accumulate rescaled partial outputs
+ var value = output_value_t(0);
+ if (head_size_offset + local_col < uniforms.head_size_vec) {
+ for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) {
+ if (r + local_row < num_total_seq_length_tile) {
+ let tile_meta = metadata.getByOffset(meta_base + r + local_row);
+ let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum;
+ value += input.getByOffset(in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col) * output_value_t(output_element_t(rescale_f32));
+ }
}
}
- }
- tile_input[local_row][local_col] = value;
- workgroupBarrier();
+ tile_input[local_row][local_col] = value;
+ workgroupBarrier();
- if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) {
- value = output_value_t(0);
- for (var i = 0u; i < tile_size; i++) {
- value += tile_input[i][local_idx];
+ if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) {
+ value = output_value_t(0);
+ for (var i = 0u; i < tile_size; i++) {
+ value += tile_input[i][local_idx];
+ }
+ let head_idx = batch_head_idx % uniforms.num_heads;
+ let batch_idx = batch_head_idx / uniforms.num_heads;
+ let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec +
+ q_idx * uniforms.num_heads * uniforms.head_size_vec +
+ head_idx * uniforms.head_size_vec +
+ head_size_offset + local_idx;
+ output.setByOffset(output_id, value);
}
- let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx;
- output[output_id] = value;
+ workgroupBarrier();
}
}
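
Reviewer note on the reduce shader above: because each tile's partial output is already divided by its own local sum, the merge has to multiply that sum back in before rescaling, which is what `tile_meta.y * exp(tile_meta.x - g_max) / g_sum` does. The C++ sketch below checks that identity numerically for a single output element. It is a minimal illustration, not the shader's implementation: `TilePartial`, `MakeTile`, `MergeTiles`, `DirectAttention`, and `MergedMatchesDirect` are hypothetical names, and the head sink term is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical per-tile result: score max, sum of exp(score - max), and one
// output element that is already normalized by the local sum.
struct TilePartial {
  float local_max;
  float local_sum;
  float partial;
};

// Builds the locally normalized result for scores[begin, end).
TilePartial MakeTile(const std::vector<float>& scores, const std::vector<float>& values,
                     std::size_t begin, std::size_t end) {
  float local_max = std::numeric_limits<float>::lowest();
  for (std::size_t i = begin; i < end; ++i) local_max = std::max(local_max, scores[i]);
  float local_sum = 0.0f;
  float weighted = 0.0f;
  for (std::size_t i = begin; i < end; ++i) {
    const float e = std::exp(scores[i] - local_max);
    local_sum += e;
    weighted += e * values[i];
  }
  return TilePartial{local_max, local_sum, weighted / local_sum};
}

// Merges tiles with the rescale factor local_sum * exp(local_max - global_max) / global_sum.
float MergeTiles(const std::vector<TilePartial>& tiles) {
  float global_max = std::numeric_limits<float>::lowest();
  for (const TilePartial& t : tiles) global_max = std::max(global_max, t.local_max);
  float global_sum = 0.0f;
  for (const TilePartial& t : tiles) global_sum += t.local_sum * std::exp(t.local_max - global_max);
  float out = 0.0f;
  for (const TilePartial& t : tiles) {
    out += t.partial * t.local_sum * std::exp(t.local_max - global_max) / global_sum;
  }
  return out;
}

// Reference: softmax over all scores at once (a single tile covering everything).
float DirectAttention(const std::vector<float>& scores, const std::vector<float>& values) {
  return MakeTile(scores, values, 0, scores.size()).partial;
}

// Splits six made-up scores into three tiles of two and compares both paths.
bool MergedMatchesDirect() {
  const std::vector<float> scores{0.5f, -1.0f, 2.0f, 3.5f, 0.0f, 1.25f};
  const std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  std::vector<TilePartial> tiles;
  for (std::size_t begin = 0; begin < scores.size(); begin += 2) {
    tiles.push_back(MakeTile(scores, values, begin, begin + 2));
  }
  return std::fabs(MergeTiles(tiles) - DirectAttention(scores, values)) < 1e-4f;
}
```

Dropping the `local_sum` factor from `MergeTiles` would weight every tile by its max alone, so tiles with many tokens would be under-counted.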
diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
index c64bdf45cdcf8..7b09a3a6af080 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
+++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template
@@ -41,12 +41,9 @@ $MAIN {
#endif
#if prepare_indirect_dispatch
- // Prepare indirect dispatch buffer for thread 0
if (global_idx == 0u) {
let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size;
- indirect_buffer[0] = num_total_seq_length_tile;
- indirect_buffer[1] = uniforms.num_heads;
- indirect_buffer[2] = 1u;
+ normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);
}
#endif
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 1054fd94ef423..f3f2f521ecab2 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1520,18 +1520,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::STRING,
std::string("int"))
.Attr("weights_prepacked",
- "Only meaningful when quant_type='int'. Tri-state control over whether the "
- "int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS "
- "fpA_intB format expected by the runner. -1 (auto): let the execution provider "
- "choose its own backward-compatible default; the CUDA EP treats auto as "
- "prepacked. 1: the initializers are already prepacked (e.g. produced offline by "
- "pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers "
- "are raw, un-prepacked [E, N, K/pack] tensors as produced by "
- "quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself "
- "in PrePack(), matching the behaviour of MatMulNBits and removing the offline "
- "pre-pack requirement from exporters. Defaults to -1 (auto) so each execution "
- "provider can pick its own backward-compatible default rather than the schema "
- "imposing one.",
+ "Only meaningful when quant_type='int'. Tri-state control over the layout of the "
+ "int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by "
+ "-1 and 1 are determined by the execution provider. 0: the initializers are raw, "
+ "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.",
AttributeProto::INT,
static_cast<int64_t>(-1))
.Input(0,
diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc
index 732a5a855d65d..4260ef706befa 100644
--- a/onnxruntime/core/platform/linux/device_discovery.cc
+++ b/onnxruntime/core/platform/linux/device_discovery.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/platform/device_discovery.h"
+#include "core/platform/linux/npu_device_discovery.h"
#include "core/platform/linux/pci_device_discovery.h"
#include
@@ -123,7 +124,7 @@ Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& gpu_devices_out) {
std::vector gpu_sysfs_path_infos{};
@@ -315,6 +317,119 @@ Status GetGpuDevices(std::vector& gpu_devices_out) {
} // namespace
+namespace npu_device_discovery {
+
+Status DetectNpuSysfsPaths(const fs::path& sysfs_accel_path,
+ std::vector<NpuSysfsPathInfo>& npu_sysfs_paths_out) {
+ std::error_code error_code{};
+
+ const bool sysfs_accel_path_exists = fs::exists(sysfs_accel_path, error_code);
+ ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Checking existence of accel sysfs path"));
+
+ if (!sysfs_accel_path_exists) {
+ npu_sysfs_paths_out = {};
+ return Status::OK();
+ }
+
+ const auto detect_accel_path = [](const fs::path& sysfs_path, size_t& accel_idx) -> bool {
+ const auto filename = sysfs_path.filename();
+ const auto filename_str = std::string_view{filename.native()};
+
+ // Look for a filename matching "accelN", where N is a non-negative integer.
+ constexpr std::string_view prefix = "accel";
+ if (filename_str.find(prefix) != 0) {
+ return false;
+ }
+
+ size_t parsed_accel_idx{};
+ if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_accel_idx)) {
+ return false;
+ }
+
+ accel_idx = parsed_accel_idx;
+ return true;
+ };
+
+ std::vector<NpuSysfsPathInfo> npu_sysfs_paths{};
+
+ auto dir_iterator = fs::directory_iterator{sysfs_accel_path, error_code};
+ ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Iterating over accel sysfs devices"));
+
+ for (const auto& dir_item : dir_iterator) {
+ const auto& dir_item_path = dir_item.path();
+
+ if (size_t accel_idx{}; detect_accel_path(dir_item_path, accel_idx)) {
+ NpuSysfsPathInfo path_info{};
+ path_info.accel_idx = accel_idx;
+ path_info.path = dir_item_path;
+ npu_sysfs_paths.emplace_back(std::move(path_info));
+ }
+ }
+
+ npu_sysfs_paths_out = std::move(npu_sysfs_paths);
+ return Status::OK();
+}
+
+Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info,
+ OrtHardwareDevice& npu_device_out) {
+ OrtHardwareDevice npu_device{};
+
+ const auto& sysfs_path = path_info.path;
+
+ uint16_t vendor_id{};
+ const auto vendor_id_path = sysfs_path / "device" / "vendor";
+ ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, vendor_id));
+ npu_device.vendor_id = vendor_id;
+
+ uint16_t device_id{};
+ const auto device_id_path = sysfs_path / "device" / "device";
+ ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, device_id));
+ npu_device.device_id = device_id;
+
+ npu_device.metadata.Add("accel_idx", MakeString(path_info.accel_idx));
+
+ std::optional<std::string> pci_bus_id;
+ ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id));
+ if (pci_bus_id) {
+ npu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id));
+ }
+
+ npu_device.type = OrtHardwareDeviceType_NPU;
+
+ npu_device_out = std::move(npu_device);
+
+ return Status::OK();
+}
+
+} // namespace npu_device_discovery
+
+namespace {
+
+Status GetNpuDevices(std::vector<OrtHardwareDevice>& npu_devices_out) {
+ std::vector<npu_device_discovery::NpuSysfsPathInfo> npu_sysfs_path_infos{};
+ ORT_RETURN_IF_ERROR(npu_device_discovery::DetectNpuSysfsPaths(kSysfsAccelPath, npu_sysfs_path_infos));
+
+ std::vector<OrtHardwareDevice> npu_devices{};
+ npu_devices.reserve(npu_sysfs_path_infos.size());
+
+ for (const auto& npu_sysfs_path_info : npu_sysfs_path_infos) {
+ OrtHardwareDevice npu_device{};
+ if (auto status = npu_device_discovery::GetNpuDeviceFromSysfs(npu_sysfs_path_info, npu_device); !status.IsOK()) {
+ LOGS_DEFAULT(WARNING) << MakeString("Failed to detect devices under ",
+ npu_sysfs_path_info.path,
+ ": ",
+ status.ErrorMessage());
+ continue;
+ }
+ npu_devices.emplace_back(std::move(npu_device));
+ }
+
+ npu_devices_out = std::move(npu_devices);
+
+ return Status::OK();
+}
+} // namespace
+
std::unordered_set<OrtHardwareDevice> DeviceDiscovery::DiscoverDevicesForPlatform() {
std::unordered_set<OrtHardwareDevice> devices;
@@ -334,7 +449,16 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor
}
// get NPU devices
- // TODO figure out how to discover these
+ {
+ std::vector<OrtHardwareDevice> npu_devices{};
+ Status npu_device_discovery_status = GetNpuDevices(npu_devices);
+ if (npu_device_discovery_status.IsOK()) {
+ devices.insert(std::make_move_iterator(npu_devices.begin()),
+ std::make_move_iterator(npu_devices.end()));
+ } else {
+ LOGS_DEFAULT(WARNING) << "NPU device discovery failed: " << npu_device_discovery_status.ErrorMessage();
+ }
+ }
return devices;
}
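
Reviewer note on the accel path detection above: the lambda accepts a directory entry only when its name is the prefix "accel" followed by something that parses as an index. The sketch below shows one way to express that rule in standalone C++17. It swaps in `std::from_chars` for the patch's `TryParseStringWithClassicLocale`, so edge-case behaviour may differ from the real helper, and `ParseAccelIndex` and `CheckAccelNames` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <charconv>
#include <cstddef>
#include <optional>
#include <string_view>
#include <system_error>

// Hypothetical helper: returns N for names of the form "accelN", otherwise nullopt.
std::optional<std::size_t> ParseAccelIndex(std::string_view name) {
  constexpr std::string_view prefix = "accel";
  // substr clamps the length, so short names simply fail the comparison.
  if (name.substr(0, prefix.size()) != prefix) {
    return std::nullopt;
  }

  const std::string_view digits = name.substr(prefix.size());
  if (digits.empty()) {
    return std::nullopt;
  }

  std::size_t idx{};
  const char* first = digits.data();
  const char* last = first + digits.size();
  const auto [ptr, ec] = std::from_chars(first, last, idx);
  // Require that the whole suffix was consumed, so "accel1x" is rejected.
  if (ec != std::errc{} || ptr != last) {
    return std::nullopt;
  }
  return idx;
}

// Usage check with a few representative sysfs entry names.
inline void CheckAccelNames() {
  assert(ParseAccelIndex("accel0").has_value());
  assert(*ParseAccelIndex("accel0") == 0u);
  assert(!ParseAccelIndex("accel").has_value());
  assert(!ParseAccelIndex("card0").has_value());
}
```

Returning `std::optional` instead of a bool plus out-parameter is a stylistic choice for the sketch; the patch's lambda keeps the out-parameter form used by the surrounding code.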
diff --git a/onnxruntime/core/platform/linux/npu_device_discovery.h b/onnxruntime/core/platform/linux/npu_device_discovery.h
new file mode 100644
index 0000000000000..e139d33988903
--- /dev/null
+++ b/onnxruntime/core/platform/linux/npu_device_discovery.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This header exposes Linux NPU device discovery internals for testing.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "core/common/status.h"
+#include "core/session/abi_devices.h"
+
+namespace onnxruntime {
+namespace npu_device_discovery {
+
+struct NpuSysfsPathInfo {
+ size_t accel_idx;
+ std::filesystem::path path;
+};
+
+// Scans the given sysfs accel directory for NPU accel devices.
+Status DetectNpuSysfsPaths(const std::filesystem::path& sysfs_accel_path,
+ std::vector<NpuSysfsPathInfo>& npu_sysfs_paths_out);
+
+// Reads vendor/device IDs and populates an OrtHardwareDevice from an accel sysfs path.
+Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info,
+ OrtHardwareDevice& npu_device_out);
+
+} // namespace npu_device_discovery
+} // namespace onnxruntime
diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc
index 2008f998384a3..20ec11bf7deb2 100644
--- a/onnxruntime/core/platform/telemetry.cc
+++ b/onnxruntime/core/platform/telemetry.cc
@@ -64,6 +64,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
const std::string& loadedFrom, const std::vector& execution_provider_ids,
const std::string& hardware_device_types,
const std::string& hardware_vendor_ids,
+ const std::string& ep_versions,
bool use_fp16, bool captureState) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(ir_version);
@@ -81,6 +82,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
ORT_UNUSED_PARAMETER(execution_provider_ids);
ORT_UNUSED_PARAMETER(hardware_device_types);
ORT_UNUSED_PARAMETER(hardware_vendor_ids);
+ ORT_UNUSED_PARAMETER(ep_versions);
ORT_UNUSED_PARAMETER(use_fp16);
ORT_UNUSED_PARAMETER(captureState);
}
@@ -124,6 +126,15 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu
ORT_UNUSED_PARAMETER(line);
}
+void Telemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status,
+ const std::string& ep_versions,
+ const std::string& ep_device_types) const {
+ ORT_UNUSED_PARAMETER(session_id);
+ ORT_UNUSED_PARAMETER(status);
+ ORT_UNUSED_PARAMETER(ep_versions);
+ ORT_UNUSED_PARAMETER(ep_device_types);
+}
+
void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
const std::unordered_map& duration_per_batch_size) const {
ORT_UNUSED_PARAMETER(session_id);
@@ -139,6 +150,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id,
uint32_t hardware_device_id,
const std::string& hardware_vendor,
const std::string& ep_vendor,
+ const std::string& ep_version,
int assigned_node_count,
uint32_t total_runs_since_last,
int64_t total_run_duration_since_last) const {
@@ -149,6 +161,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id,
ORT_UNUSED_PARAMETER(hardware_device_id);
ORT_UNUSED_PARAMETER(hardware_vendor);
ORT_UNUSED_PARAMETER(ep_vendor);
+ ORT_UNUSED_PARAMETER(ep_version);
ORT_UNUSED_PARAMETER(assigned_node_count);
ORT_UNUSED_PARAMETER(total_runs_since_last);
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h
index c2dce68faa2dd..946e34ee35832 100644
--- a/onnxruntime/core/platform/telemetry.h
+++ b/onnxruntime/core/platform/telemetry.h
@@ -66,6 +66,7 @@ class Telemetry {
const std::string& loadedFrom, const std::vector& execution_provider_ids,
const std::string& hardware_device_types,
const std::string& hardware_vendor_ids,
+ const std::string& ep_versions,
bool use_fp16, bool captureState) const;
virtual void LogCompileModelStart(uint32_t session_id,
@@ -86,6 +87,10 @@ class Telemetry {
virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const;
+ virtual void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status,
+ const std::string& ep_versions,
+ const std::string& ep_device_types) const;
+
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
const std::unordered_map& duration_per_batch_size) const;
@@ -100,6 +105,7 @@ class Telemetry {
uint32_t hardware_device_id,
const std::string& hardware_vendor,
const std::string& ep_vendor,
+ const std::string& ep_version,
int assigned_node_count,
uint32_t total_runs_since_last,
int64_t total_run_duration_since_last) const;
diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc
index 04b9aaa0eb8ed..342b937ffb656 100644
--- a/onnxruntime/core/platform/windows/telemetry.cc
+++ b/onnxruntime/core/platform/windows/telemetry.cc
@@ -285,6 +285,7 @@ void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const {
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -318,6 +319,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
const std::string& loaded_from, const std::vector& execution_provider_ids,
const std::string& hardware_device_types,
const std::string& hardware_vendor_ids,
+ const std::string& ep_versions,
bool use_fp16, bool captureState) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
@@ -373,7 +375,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
// schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds
- TraceLoggingUInt8(1, "schemaVersion"),
+ // schemaVersion 2: added executionProviderVersions
+ TraceLoggingUInt8(2, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
@@ -392,6 +395,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"),
TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"),
+ TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"),
TraceLoggingString(service_names.c_str(), "serviceNames"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
} else {
@@ -404,7 +408,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
// schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds
- TraceLoggingUInt8(1, "schemaVersion"),
+ // schemaVersion 2: added executionProviderVersions
+ TraceLoggingUInt8(2, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
@@ -423,6 +428,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"),
TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"),
+ TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"),
TraceLoggingString(service_names.c_str(), "serviceNames"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -457,7 +463,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id,
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingString(input_source.c_str(), "inputSource"),
TraceLoggingString(output_target.c_str(), "outputTarget"),
@@ -466,6 +472,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id,
TraceLoggingBool(embed_ep_context, "embedEpContext"),
TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -507,7 +514,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingHResult(hr, "hResult"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(status.Code(), "errorCode"),
@@ -516,6 +523,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingString(file, "file"),
TraceLoggingString(function, "function"),
TraceLoggingInt32(line, "line"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
#else
TraceLoggingWrite(telemetry_provider_handle,
@@ -525,7 +533,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(status.Code(), "errorCode"),
TraceLoggingUInt32(status.Category(), "errorCategory"),
@@ -533,10 +541,35 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status
TraceLoggingString(file, "file"),
TraceLoggingString(function, "function"),
TraceLoggingInt32(line, "line"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
#endif
}
+void WindowsTelemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status,
+ const std::string& ep_versions,
+ const std::string& ep_device_types) const {
+ if (global_register_count_ == 0 || enabled_ == false)
+ return;
+
+ TraceLoggingWrite(telemetry_provider_handle,
+ "RuntimeInferenceError",
+ TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
+ TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
+ TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
+ TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
+ // Telemetry info
+ TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt32(session_id, "sessionId"),
+ TraceLoggingUInt32(status.Code(), "errorCode"),
+ TraceLoggingUInt32(status.Category(), "errorCategory"),
+ TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"),
+ TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"),
+ TraceLoggingString(ep_device_types.c_str(), "executionProviderDeviceTypes"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
+ TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
+}
+
void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
const std::unordered_map<int64_t, int64_t>& duration_per_batch_size) const {
if (global_register_count_ == 0 || enabled_ == false)
@@ -559,11 +592,12 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(total_runs_since_last, "totalRuns"),
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"),
TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -574,6 +608,7 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id,
uint32_t hardware_device_id,
const std::string& hardware_vendor,
const std::string& ep_vendor,
+ const std::string& ep_version,
int assigned_node_count,
uint32_t total_runs_since_last,
int64_t total_run_duration_since_last) const {
@@ -588,7 +623,8 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id,
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ // schemaVersion 1: added epVersion, runtimeVersion
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingString(ep_type.c_str(), "executionProviderType"),
TraceLoggingString(hardware_device_type.c_str(), "hardwareDeviceType"),
@@ -596,9 +632,11 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id,
TraceLoggingUInt32(hardware_device_id, "hardwareDeviceId"),
TraceLoggingString(hardware_vendor.c_str(), "hardwareVendor"),
TraceLoggingString(ep_vendor.c_str(), "epVendor"),
+ TraceLoggingString(ep_version.c_str(), "epVersion"),
TraceLoggingInt32(assigned_node_count, "assignedNodeCount"),
TraceLoggingUInt32(total_runs_since_last, "totalRunsSinceLast"),
TraceLoggingInt64(total_run_duration_since_last, "totalRunDurationSinceLast"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -724,8 +762,9 @@ void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const {
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
@@ -799,8 +838,9 @@ void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
- TraceLoggingUInt8(0, "schemaVersion"),
+ TraceLoggingUInt8(1, "schemaVersion"),
TraceLoggingString(registration_name.c_str(), "registrationName"),
+ TraceLoggingString(ORT_VERSION, "runtimeVersion"),
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"));
}
diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h
index 8295d042d7ec9..46c262e1479d3 100644
--- a/onnxruntime/core/platform/windows/telemetry.h
+++ b/onnxruntime/core/platform/windows/telemetry.h
@@ -59,6 +59,7 @@ class WindowsTelemetry : public Telemetry {
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
const std::string& hardware_device_types,
const std::string& hardware_vendor_ids,
+ const std::string& ep_versions,
bool use_fp16, bool captureState) const override;
void LogCompileModelStart(uint32_t session_id,
@@ -79,6 +80,10 @@ class WindowsTelemetry : public Telemetry {
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const override;
+ void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status,
+ const std::string& ep_versions,
+ const std::string& ep_device_types) const override;
+
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last,
const std::unordered_map<int64_t, int64_t>& duration_per_batch_size) const override;
@@ -89,6 +94,7 @@ class WindowsTelemetry : public Telemetry {
uint32_t hardware_device_id,
const std::string& hardware_vendor,
const std::string& ep_vendor,
+ const std::string& ep_version,
int assigned_node_count,
uint32_t total_runs_since_last,
int64_t total_run_duration_since_last) const override;
diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc
index f462680e02587..f75c5fe12a6f4 100644
--- a/onnxruntime/core/providers/cpu/signal/dft.cc
+++ b/onnxruntime/core/providers/cpu/signal/dft.cc
@@ -524,14 +524,15 @@ template
static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool /*inverse*/) {
// Attr("onesided"): default = 1
// Input(0, "signal") type = T1
- // Input(1, "frame_length") type = T2
+ // Input(1, "frame_step") type = T2
// Input(2, "window") type = T1, optional
- // Input(3, "frame_step") type = T2
+ // Input(3, "frame_length") type = T2
// Output(0, "output") type = T1
// Get signal
const auto* signal = ctx->Input<Tensor>(0);
const auto frame_step = signal::get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(1));
+ ORT_RETURN_IF_NOT(frame_step > 0, "frame_step must be greater than zero.");
const auto* window = ctx->Input<Tensor>(2);
const auto* frame_length_tensor = ctx->Input<Tensor>(3);
@@ -596,8 +597,11 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside
// Run each dft of each batch as if it was a real-valued batch size 1 dft operation
for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int64_t i = 0; i < n_dfts; i++) {
- auto input_frame_begin =
- signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components);
+ const auto frame_start = i * frame_step;
+ // Defensive check before creating a non-owning tensor view. n_dfts derivation should keep this in bounds.
+ ORT_RETURN_IF_NOT(frame_start <= signal_size - window_size, "STFT input frame is out of bounds.");
+ // signal_data is U*, so one increment advances one input sample, including both lanes for complex input.
+ auto input_frame_begin = signal_data + (batch_idx * signal_size) + frame_start;
auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) +
(i * dft_output_size * output_components);
@@ -619,9 +623,9 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside
Status STFT::Compute(OpKernelContext* ctx) const {
// Attr("onesided"): default = 1
// Input(0, "signal") type = T1
- // Input(1, "frame_length") type = T2
+ // Input(1, "frame_step") type = T2
// Input(2, "window") type = T1, optional
- // Input(3, "frame_step") type = T2
+ // Input(3, "frame_length") type = T2
// Output(0, "output") type = T1
// Get signal shape
diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc
index d030b68abdb99..458c47bcb58cd 100644
--- a/onnxruntime/core/session/experimental_c_api.cc
+++ b/onnxruntime/core/session/experimental_c_api.cc
@@ -18,6 +18,14 @@
namespace OrtExperimentalApis {
+// Forward declarations driven by the .inc file so the registration table below
+// can take the address of every entry, including those defined in other
+// translation units linked into onnxruntime_session.
+#define ORT_EXPERIMENTAL_API(VER, RET, NAME, ...) \
+ RET ORT_API_CALL NAME##_SinceV##VER(__VA_ARGS__) NO_EXCEPTION;
+#include "onnxruntime_experimental_c_api.inc"
+#undef ORT_EXPERIMENTAL_API
+
// Test-only experimental function that writes a known sentinel value.
// Exists to exercise the experimental API mechanism end-to-end and to serve as a template for future experimental
// functions.
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 73f3e77b93177..9db18fc4ac69e 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -78,6 +78,7 @@
#include "core/session/environment.h"
#include "core/session/IOBinding.h"
#include "core/session/inference_session_utils.h"
+#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_run_options_config_keys.h"
#include "core/session/user_logging_sink.h"
@@ -871,7 +872,7 @@ InferenceSession::~InferenceSession() {
telemetry_provider.LogEpDeviceUsage(
session_id_, ep_info.ep_type, ep_info.hardware_device_type,
ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor,
- ep_info.assigned_node_count,
+ ep_info.ep_version, ep_info.assigned_node_count,
telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_);
}
}
@@ -2778,6 +2779,7 @@ common::Status InferenceSession::Initialize() {
graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash,
model_->MetaData(), telemetry_.event_name_, execution_providers_.GetIds(),
telemetry_.ep_device_types_summary_, telemetry_.ep_device_vendor_ids_summary_,
+ telemetry_.ep_versions_summary_,
model_has_fp16_inputs, false);
// Emit one initial EpDeviceUsage event per (EP, device) pair with run counts of 0.
@@ -2787,7 +2789,7 @@ common::Status InferenceSession::Initialize() {
env.GetTelemetryProvider().LogEpDeviceUsage(
session_id_, ep_info.ep_type, ep_info.hardware_device_type,
ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor,
- ep_info.assigned_node_count, 0, 0);
+ ep_info.ep_version, ep_info.assigned_node_count, 0, 0);
}
LOGS(*session_logger_, INFO) << "Session successfully initialized.";
@@ -3456,7 +3458,7 @@ Status InferenceSession::RunImpl(const RunOptions& run_options,
env.GetTelemetryProvider().LogEpDeviceUsage(
session_id_, ep_info.ep_type, ep_info.hardware_device_type,
ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor,
- ep_info.assigned_node_count,
+ ep_info.ep_version, ep_info.assigned_node_count,
telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_);
}
// reset counters
@@ -3470,6 +3472,10 @@ Status InferenceSession::RunImpl(const RunOptions& run_options,
Telemetry::kRuntimePerfMaxInterval);
}
}
+ } else {
+ // Log runtime error with EP versions
+ env.GetTelemetryProvider().LogRuntimeInferenceError(session_id_, retval, telemetry_.ep_versions_summary_,
+ telemetry_.ep_device_types_summary_);
}
// log evaluation stop to trace logging provider
@@ -4087,6 +4093,7 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) {
telemetry_.ep_device_info_.clear();
telemetry_.ep_device_types_summary_.clear();
telemetry_.ep_device_vendor_ids_summary_.clear();
+ telemetry_.ep_versions_summary_.clear();
// First, count nodes assigned to each EP type after graph partitioning.
// The graph node only carries the EP type string, so when a single EP targets
@@ -4124,6 +4131,10 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) {
Telemetry::EpDeviceInfo entry;
entry.ep_type = ep_type;
entry.ep_vendor = ep_device->ep_vendor;
+ auto it = ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version);
+ if (it != ep_device->ep_metadata.Entries().end()) {
+ entry.ep_version = it->second;
+ }
if (ep_device->device != nullptr) {
entry.hardware_device_type = HardwareDeviceTypeToString(ep_device->device->type);
entry.vendor_id = ep_device->device->vendor_id;
@@ -4158,11 +4169,13 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) {
// by position against the existing executionProviderIds field.
std::ostringstream types_oss;
std::ostringstream vendor_ids_oss;
+ std::ostringstream versions_oss;
bool first = true;
for (const auto& entry : telemetry_.ep_device_info_) {
if (!first) {
types_oss << ',';
vendor_ids_oss << ',';
+ versions_oss << ',';
}
first = false;
types_oss << entry.hardware_device_type;
@@ -4170,9 +4183,11 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) {
vendor_ids_oss << "0x" << std::hex << std::uppercase << std::setw(4)
<< std::setfill('0') << entry.vendor_id
<< std::dec << std::nouppercase << std::setfill(' ');
+ versions_oss << entry.ep_type << ':' << entry.ep_version;
}
telemetry_.ep_device_types_summary_ = types_oss.str();
telemetry_.ep_device_vendor_ids_summary_ = vendor_ids_oss.str();
+ telemetry_.ep_versions_summary_ = versions_oss.str();
}
#if !defined(ORT_MINIMAL_BUILD)
@@ -4419,6 +4434,7 @@ void InferenceSession::LogAllSessions() {
graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash,
model->MetaData(), session->telemetry_.event_name_, session->execution_providers_.GetIds(),
session->telemetry_.ep_device_types_summary_, session->telemetry_.ep_device_vendor_ids_summary_,
+ session->telemetry_.ep_versions_summary_,
model_has_fp16_inputs, true);
}
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 11c887f6cdc16..19c18652ab67a 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -1023,12 +1023,14 @@ class InferenceSession {
uint32_t device_id = 0; // PCI device ID (0 when unavailable)
std::string vendor; // e.g. "Qualcomm"
std::string ep_vendor; // e.g. "Qualcomm" (from OrtEpDevice)
+ std::string ep_version; // e.g. "1.2.3" (from OrtEpFactory::GetVersion, empty when unavailable)
int assigned_node_count = 0; // # graph nodes assigned to this EP type
};
std::vector<EpDeviceInfo> ep_device_info_;
// Pre-formatted comma-separated summaries used to enrich SessionCreation.
std::string ep_device_types_summary_; // "NPU,CPU"
std::string ep_device_vendor_ids_summary_; // "0x5143,0x0000"
+ std::string ep_versions_summary_; // "QNNExecutionProvider:1.2.3,CPUExecutionProvider:"
} telemetry_;
mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data
diff --git a/onnxruntime/core/session/model_package_api.cc b/onnxruntime/core/session/model_package_api.cc
index 27abb0f5f7a37..5fd25f9511eb9 100644
--- a/onnxruntime/core/session/model_package_api.cc
+++ b/onnxruntime/core/session/model_package_api.cc
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "core/session/model_package_api.h"
+#include "core/session/onnxruntime_experimental_c_api.h"
#include "core/common/common.h"
#include "core/framework/error_code_helper.h"
@@ -23,7 +23,9 @@ using namespace onnxruntime;
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, \
"Model package API is not supported in this build")
-ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions,
+namespace OrtExperimentalApis {
+
+ORT_API(void, OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28,
_Frees_ptr_opt_ OrtModelPackageOptions* options) {
#if !defined(ORT_MINIMAL_BUILD)
delete reinterpret_cast(options);
@@ -32,7 +34,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions,
#endif
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28,
_In_ const OrtEnv* env,
_In_ const OrtSessionOptions* session_options,
_Outptr_ OrtModelPackageOptions** out) {
@@ -54,7 +56,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOpti
API_IMPL_END
}
-ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext,
+ORT_API(void, OrtModelPackageApi_ReleaseModelPackageContext_SinceV28,
_Frees_ptr_opt_ OrtModelPackageContext* ctx) {
#if !defined(ORT_MINIMAL_BUILD)
delete reinterpret_cast(ctx);
@@ -63,7 +65,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext,
#endif
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageContext_SinceV28,
_In_ const ORTCHAR_T* package_root,
_Outptr_ OrtModelPackageContext** out) {
API_IMPL_BEGIN
@@ -89,7 +91,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_Out_ size_t* out_count) {
API_IMPL_BEGIN
@@ -107,7 +109,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
_Out_ size_t* out_count) {
@@ -136,7 +138,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_In_ const char* component_name,
_Out_ size_t* out_count) {
@@ -158,7 +160,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_In_ const char* component_name,
_Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
@@ -190,7 +192,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_SelectComponent_SinceV28,
_In_ const OrtModelPackageContext* context,
_In_ const char* component_name,
_In_ const OrtModelPackageOptions* options,
@@ -235,7 +237,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent,
API_IMPL_END
}
-ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext,
+ORT_API(void, OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28,
_Frees_ptr_opt_ OrtModelPackageComponentContext* cix) {
#if !defined(ORT_MINIMAL_BUILD)
delete reinterpret_cast(cix);
@@ -244,7 +246,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext,
#endif
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28,
_In_ const OrtModelPackageComponentContext* ctx,
_Outptr_ const ORTCHAR_T** folder_path) {
API_IMPL_BEGIN
@@ -269,7 +271,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateSession_SinceV28,
_In_ const OrtEnv* env,
_In_ OrtModelPackageComponentContext* ctx,
_In_opt_ const OrtSessionOptions* session_options,
@@ -383,7 +385,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession,
// ---------- API table ------------------------------------------------------
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_In_ const char* component_name,
_In_ const char* variant_name,
@@ -417,7 +419,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28,
_In_ const OrtModelPackageContext* ctx,
_Out_ int64_t* out_version) {
API_IMPL_BEGIN
@@ -437,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion,
API_IMPL_END
}
-ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName,
+ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28,
_In_ const OrtModelPackageComponentContext* ctx,
_Outptr_ const char** out_name) {
API_IMPL_BEGIN
@@ -461,40 +463,4 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant
API_IMPL_END
}
-// ---------- API table dispatch ---------------------------------------------
-
-static constexpr OrtModelPackageApi ort_model_package_api = {
- // Options
- &OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions,
- &OrtModelPackageAPI::ReleaseModelPackageOptions,
-
- // Context
- &OrtModelPackageAPI::CreateModelPackageContext,
- &OrtModelPackageAPI::ReleaseModelPackageContext,
-
- // Package-level queries
- &OrtModelPackageAPI::ModelPackage_GetSchemaVersion,
- &OrtModelPackageAPI::ModelPackage_GetComponentCount,
- &OrtModelPackageAPI::ModelPackage_GetComponentNames,
- &OrtModelPackageAPI::ModelPackage_GetVariantCount,
- &OrtModelPackageAPI::ModelPackage_GetVariantNames,
- &OrtModelPackageAPI::ModelPackage_GetVariantEpName,
-
- // Variant selection and component queries
- &OrtModelPackageAPI::SelectComponent,
- &OrtModelPackageAPI::ReleaseModelPackageComponentContext,
- &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName,
- &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath,
-
- // Session
- &OrtModelPackageAPI::CreateSession,
-
- // End of Version 1.27 - DO NOT MODIFY ABOVE
-};
-
-static_assert(offsetof(OrtModelPackageApi, CreateSession) / sizeof(void*) == 14,
- "Size of initial OrtModelPackageApi cannot change");
-
-ORT_API(const OrtModelPackageApi*, OrtModelPackageAPI::GetModelPackageApi) {
- return &ort_model_package_api;
-}
+} // namespace OrtExperimentalApis
diff --git a/onnxruntime/core/session/model_package_api.h b/onnxruntime/core/session/model_package_api.h
deleted file mode 100644
index 435e3a521c24c..0000000000000
--- a/onnxruntime/core/session/model_package_api.h
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#pragma once
-
-#include "core/session/onnxruntime_c_api.h"
-
-namespace OrtModelPackageAPI {
-
-ORT_API(const OrtModelPackageApi*, GetModelPackageApi);
-
-ORT_API(void, ReleaseModelPackageOptions, _Frees_ptr_opt_ OrtModelPackageOptions*);
-ORT_API_STATUS_IMPL(CreateModelPackageOptionsFromSessionOptions,
- _In_ const OrtEnv* env,
- _In_ const OrtSessionOptions* session_options,
- _Outptr_ OrtModelPackageOptions** out);
-
-ORT_API(void, ReleaseModelPackageContext, _Frees_ptr_opt_ OrtModelPackageContext*);
-ORT_API_STATUS_IMPL(CreateModelPackageContext,
- _In_ const ORTCHAR_T* package_root,
- _Outptr_ OrtModelPackageContext** out);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetComponentCount,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ size_t* out_count);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetComponentNames,
- _In_ const OrtModelPackageContext* ctx,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names,
- _Out_ size_t* out_count);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetVariantCount,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Out_ size_t* out_count);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetVariantNames,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names,
- _Out_ size_t* out_count);
-
-ORT_API_STATUS_IMPL(SelectComponent,
- _In_ const OrtModelPackageContext* context,
- _In_ const char* component_name,
- _In_ const OrtModelPackageOptions* options,
- _Outptr_ OrtModelPackageComponentContext** out);
-
-ORT_API(void, ReleaseModelPackageComponentContext,
- _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx);
-
-ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantFolderPath,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const ORTCHAR_T** folder_path);
-
-ORT_API_STATUS_IMPL(CreateSession,
- _In_ const OrtEnv* env,
- _In_ OrtModelPackageComponentContext* ctx,
- _In_opt_ const OrtSessionOptions* session_options,
- _Outptr_ OrtSession** session);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetVariantEpName,
- _In_ const OrtModelPackageContext* ctx,
- _In_ const char* component_name,
- _In_ const char* variant_name,
- _Outptr_result_maybenull_ const char** out_ep);
-
-ORT_API_STATUS_IMPL(ModelPackage_GetSchemaVersion,
- _In_ const OrtModelPackageContext* ctx,
- _Out_ int64_t* out_version);
-
-ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantName,
- _In_ const OrtModelPackageComponentContext* ctx,
- _Outptr_ const char** out_name);
-
-} // namespace OrtModelPackageAPI
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 61a413d92e7fc..f451eaa401497 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -57,7 +57,6 @@
#include "core/session/ort_env.h"
#include "core/session/ort_version_check.h"
#include "core/session/utils.h"
-#include "core/session/model_package_api.h"
#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
#include "core/providers/cuda/cuda_provider_factory.h"
@@ -3689,10 +3688,6 @@ ORT_API(const OrtCompileApi*, OrtApis::GetCompileApi) {
return OrtCompileAPI::GetCompileApi();
}
-ORT_API(const OrtModelPackageApi*, OrtApis::GetModelPackageApi) {
- return OrtModelPackageAPI::GetModelPackageApi();
-}
-
ORT_API(void, OrtApis::CreateKeyValuePairs, _Outptr_ OrtKeyValuePairs** out) {
auto kvps = std::make_unique<OrtKeyValuePairs>();
*out = reinterpret_cast<OrtKeyValuePairs*>(kvps.release());
@@ -4911,7 +4906,6 @@ static constexpr OrtApi ort_api_1_to_28 = {
&OrtApis::GetMemPatternEnabled,
&OrtApis::GetSessionExecutionMode,
&OrtApis::SessionReleaseCapturedGraph,
- &OrtApis::GetModelPackageApi,
// End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::GetExperimentalFunction,
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 250a2853a4777..61ece2dd9a682 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -825,9 +825,6 @@ ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtVal
_Outptr_result_maybenull_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
-// Model Package API
-ORT_API(const OrtModelPackageApi*, GetModelPackageApi);
-
// Experimental API
ORT_API(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name);
diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc
index 5b1d590a06234..7220153b4fa17 100644
--- a/onnxruntime/python/onnxruntime_pybind_quant.cc
+++ b/onnxruntime/python/onnxruntime_pybind_quant.cc
@@ -16,7 +16,6 @@
#endif
#include
#include
-#include
#include
namespace pybind11 {
@@ -252,17 +251,8 @@ py::array_t PackWeightsForMixedGemm(
cudaDeviceProp device_prop;
ThrowIfCudaError(cudaGetDeviceProperties(&device_prop, device_id), "cudaGetDeviceProperties");
sm = device_prop.major * 10 + device_prop.minor;
- } else {
- // Validate force_arch against the SM versions for which preprocess_weights_for_mixed_gemm_cuda
- // has tile/permutation tables. Unknown SMs would silently produce incorrect weight layouts.
- static const std::set kSupportedSm = {75, 80, 90};
- if (kSupportedSm.find(sm) == kSupportedSm.end()) {
- std::ostringstream oss;
- oss << "force_arch=" << sm << " is not a supported SM version. "
- << "Pass -1 for auto-detect, or one of: 75, 80, 90 (arch > 90 will fallback to 80).";
- throw std::invalid_argument(oss.str());
- }
}
+ sm = ::onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(sm);
auto permutation_map_buffer = make_cuda_ptr(32 * sizeof(int32_t));
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 77200ea84778e..2044a128d9540 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -3247,195 +3247,6 @@ including arg name, arg type (contains both type and shape).)pbdoc")
#endif
},
R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc");
-
- // --- Model Package API ---
-#if !defined(ORT_MINIMAL_BUILD)
- // Helper to create a PyInferenceSession from a pre-initialized OrtSession* (C API handle).
- // PyInferenceSession's owning ctor is protected; this subclass provides access.
- struct PyModelPackageSession : PyInferenceSession {
- PyModelPackageSession(std::unique_ptr sess)
- : PyInferenceSession(std::move(sess)) {}
- };
-
- // Wrapper classes to manage opaque C handles with proper RAII
- struct PyModelPackageContext {
- OrtModelPackageContext* ctx_{nullptr};
- PyModelPackageContext(const std::string& package_path) {
- auto path = ToPathString(package_path);
- const auto* api = Ort::GetApi().GetModelPackageApi();
- Ort::ThrowOnError(api->CreateModelPackageContext(path.c_str(), &ctx_));
- }
- ~PyModelPackageContext() {
- if (ctx_) {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- api->ReleaseModelPackageContext(ctx_);
- }
- }
- PyModelPackageContext(const PyModelPackageContext&) = delete;
- PyModelPackageContext& operator=(const PyModelPackageContext&) = delete;
- };
-
- struct PyModelPackageComponentContext {
- OrtModelPackageComponentContext* ctx_{nullptr};
- ~PyModelPackageComponentContext() {
- if (ctx_) {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- api->ReleaseModelPackageComponentContext(ctx_);
- }
- }
- PyModelPackageComponentContext(const PyModelPackageComponentContext&) = delete;
- PyModelPackageComponentContext& operator=(const PyModelPackageComponentContext&) = delete;
- PyModelPackageComponentContext() = default;
- };
-
- struct PyModelPackageOptions {
- OrtModelPackageOptions* opts_{nullptr};
- ~PyModelPackageOptions() {
- if (opts_) {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- api->ReleaseModelPackageOptions(opts_);
- }
- }
- PyModelPackageOptions(const PyModelPackageOptions&) = delete;
- PyModelPackageOptions& operator=(const PyModelPackageOptions&) = delete;
- PyModelPackageOptions() = default;
- };
-
- py::class_(m, "ModelPackageContext",
- R"pbdoc(Represents an opened model package for inspection and component selection.)pbdoc")
- .def(py::init(), py::arg("package_path"),
- R"pbdoc(Open a model package from the given directory path.)pbdoc")
- .def(
- "get_component_names",
- [](PyModelPackageContext& self) -> std::vector {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- const char* const* names = nullptr;
- size_t count = 0;
- Ort::ThrowOnError(api->ModelPackage_GetComponentNames(self.ctx_, &names, &count));
- std::vector result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
- },
- R"pbdoc(Get the names of all components in the package.)pbdoc")
- .def(
- "get_variant_names",
- [](PyModelPackageContext& self, const std::string& component_name) -> std::vector {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- const char* const* names = nullptr;
- size_t count = 0;
- Ort::ThrowOnError(api->ModelPackage_GetVariantNames(
- self.ctx_, component_name.c_str(), &names, &count));
- std::vector result;
- result.reserve(count);
- for (size_t i = 0; i < count; ++i) {
- result.emplace_back(names[i]);
- }
- return result;
- },
- py::arg("component_name"),
- R"pbdoc(Get the variant names for a given component.)pbdoc")
- .def(
- "get_variant_ep_name",
- [](PyModelPackageContext& self, const std::string& component_name,
- const std::string& variant_name) -> std::optional {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- const char* ep = nullptr;
- Ort::ThrowOnError(api->ModelPackage_GetVariantEpName(
- self.ctx_, component_name.c_str(), variant_name.c_str(), &ep));
- if (ep) return std::string(ep);
- return std::nullopt;
- },
- py::arg("component_name"), py::arg("variant_name"),
- R"pbdoc(Get the EP name for a variant. Returns None if not declared.)pbdoc")
- .def(
- "get_schema_version",
- [](PyModelPackageContext& self) -> int64_t {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- int64_t version = 0;
- Ort::ThrowOnError(api->ModelPackage_GetSchemaVersion(self.ctx_, &version));
- return version;
- },
- R"pbdoc(Get the schema version declared in the model package manifest.)pbdoc")
- .def(
- "select_component",
- [](PyModelPackageContext& self, const std::string& component_name,
- PyModelPackageOptions& options) -> std::unique_ptr {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- auto result = std::make_unique();
- Ort::ThrowOnError(api->SelectComponent(
- self.ctx_, component_name.c_str(), options.opts_, &result->ctx_));
- return result;
- },
- py::arg("component_name"), py::arg("options"),
- R"pbdoc(Select a component and resolve its variant based on the provided options.
-Returns a ModelPackageComponentContext for inspecting the selected variant.)pbdoc");
-
- py::class_(m, "ModelPackageOptions",
- R"pbdoc(Options used for variant selection in a model package.
-Created from a SessionOptions to capture EP configuration for variant matching.)pbdoc")
- .def(py::init([](PySessionOptions& session_options) {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- auto result = std::make_unique();
- Ort::ThrowOnError(api->CreateModelPackageOptionsFromSessionOptions(
- GetOrtEnv(), &session_options, &result->opts_));
- return result;
- }),
- py::arg("session_options"),
- R"pbdoc(Create model package options from a SessionOptions instance.
-The EP configured on the session options is used for variant selection.)pbdoc");
-
- py::class_(m, "ModelPackageComponentContext",
- R"pbdoc(Represents a selected component within a model package.
-Provides access to the resolved variant's files, session options, and metadata.)pbdoc")
- .def(
- "get_selected_variant_folder_path",
- [](PyModelPackageComponentContext& self) -> std::string {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- const ORTCHAR_T* path = nullptr;
- Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantFolderPath(self.ctx_, &path));
- return PathToUTF8String(PathString(path));
- },
- R"pbdoc(Get the folder path of the selected variant.)pbdoc")
- .def(
- "get_selected_variant_name",
- [](PyModelPackageComponentContext& self) -> std::string {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- const char* name = nullptr;
- Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantName(
- self.ctx_, &name));
- return name ? std::string(name) : std::string();
- },
- R"pbdoc(Get the name of the selected variant.)pbdoc")
- .def(
- "create_session",
- [](PyModelPackageComponentContext& self, py::object session_options_obj) -> std::unique_ptr {
- const auto* api = Ort::GetApi().GetModelPackageApi();
- OrtSession* ort_session = nullptr;
- if (session_options_obj.is_none()) {
- Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, nullptr, &ort_session));
- } else {
- auto& so = session_options_obj.cast();
- Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, &so, &ort_session));
- }
- // OrtSession* is a reinterpret_cast of InferenceSession*
- auto* inference_session = reinterpret_cast(ort_session);
- std::unique_ptr session_ptr(inference_session);
- return std::make_unique(std::move(session_ptr));
- },
- py::arg("session_options") = py::none(),
- R"pbdoc(Create an InferenceSession from the selected component variant.
-
-Args:
- session_options: Optional SessionOptions override. If None, uses the options
- captured during variant selection with per-file options merged on top.
- If provided, variant-specific options are NOT applied.
-
-Returns:
- An InferenceSession ready for inference.)pbdoc");
-#endif // !defined(ORT_MINIMAL_BUILD)
}
bool InitArray() {
diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc
index 6fb3f8e6ba82f..34f73eb69b149 100644
--- a/onnxruntime/test/autoep/test_model_package.cc
+++ b/onnxruntime/test/autoep/test_model_package.cc
@@ -12,6 +12,7 @@
#include "gtest/gtest.h"
#include "core/session/model_package/model_package_context.h"
+#include "core/session/onnxruntime_experimental_c_api.h"
#include "core/session/abi_devices.h"
#include "test/autoep/test_autoep_utils.h"
#include "test/util/include/asserts.h"
@@ -22,6 +23,90 @@ extern std::unique_ptr ort_env;
namespace onnxruntime {
namespace test {
namespace {
+
+// Typed function pointers for every OrtModelPackageApi_* experimental entry,
+// resolved once via the experimental name-based lookup.
+struct ModelPackageFns {
+ OrtExperimental_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn
+ CreateModelPackageOptionsFromSessionOptions{nullptr};
+ OrtExperimental_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn
+ ReleaseModelPackageOptions{nullptr};
+ OrtExperimental_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn
+ CreateModelPackageContext{nullptr};
+ OrtExperimental_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn
+ ReleaseModelPackageContext{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn
+ ModelPackage_GetSchemaVersion{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn
+ ModelPackage_GetComponentCount{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn
+ ModelPackage_GetComponentNames{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn
+ ModelPackage_GetVariantCount{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn
+ ModelPackage_GetVariantNames{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn
+ ModelPackage_GetVariantEpName{nullptr};
+ OrtExperimental_OrtModelPackageApi_SelectComponent_SinceV28_Fn
+ SelectComponent{nullptr};
+ OrtExperimental_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn
+ ReleaseModelPackageComponentContext{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn
+ ModelPackageComponent_GetSelectedVariantName{nullptr};
+ OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn
+ ModelPackageComponent_GetSelectedVariantFolderPath{nullptr};
+ OrtExperimental_OrtModelPackageApi_CreateSession_SinceV28_Fn
+ CreateSession{nullptr};
+};
+
+inline const ModelPackageFns& GetModelPackageFns() {
+ static const ModelPackageFns fns = []() {
+ const OrtApi* api = &Ort::GetApi();
+ ModelPackageFns f;
+#define RESOLVE(member, getter) \
+ do { \
+ f.member = Ort::Experimental::getter(api); \
+ if (f.member == nullptr) { \
+ throw std::runtime_error(std::string("Failed to resolve experimental " \
+ "OrtModelPackageApi_") + \
+ #member "_SinceV28"); \
+ } \
+ } while (0)
+ RESOLVE(CreateModelPackageOptionsFromSessionOptions,
+ Get_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn);
+ RESOLVE(ReleaseModelPackageOptions,
+ Get_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn);
+ RESOLVE(CreateModelPackageContext,
+ Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn);
+ RESOLVE(ReleaseModelPackageContext,
+ Get_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetSchemaVersion,
+ Get_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetComponentCount,
+ Get_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetComponentNames,
+ Get_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetVariantCount,
+ Get_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetVariantNames,
+ Get_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn);
+ RESOLVE(ModelPackage_GetVariantEpName,
+ Get_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn);
+ RESOLVE(SelectComponent,
+ Get_OrtModelPackageApi_SelectComponent_SinceV28_Fn);
+ RESOLVE(ReleaseModelPackageComponentContext,
+ Get_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn);
+ RESOLVE(ModelPackageComponent_GetSelectedVariantName,
+ Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn);
+ RESOLVE(ModelPackageComponent_GetSelectedVariantFolderPath,
+ Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn);
+ RESOLVE(CreateSession,
+ Get_OrtModelPackageApi_CreateSession_SinceV28_Fn);
+#undef RESOLVE
+ return f;
+ }();
+ return fns;
+}
// ------------------------------------------------------------------
// Helpers to build a test model package on disk
// ------------------------------------------------------------------
@@ -233,26 +318,25 @@ std::filesystem::path CreateModelPackageApiTestPackage(bool multi_file_variant =
TEST(ModelPackageApiTest, PackageContextQueries) {
const auto package_root = CreateModelPackageApiTestPackage();
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) {
- if (p) pkg_api->ReleaseModelPackageContext(p);
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) {
+ if (p) pkg_api.ReleaseModelPackageContext(p);
};
std::unique_ptr model_pkg_context(nullptr, context_deleter);
OrtModelPackageContext* raw_context = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context));
model_pkg_context.reset(raw_context);
// Query: component count + names
size_t component_count = 0;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count));
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count));
ASSERT_EQ(component_count, 1u);
const char* const* component_names = nullptr;
size_t component_name_count = 0;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentNames(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentNames(
model_pkg_context.get(), &component_names, &component_name_count));
ASSERT_EQ(component_name_count, 1u);
ASSERT_NE(component_names, nullptr);
@@ -261,13 +345,13 @@ TEST(ModelPackageApiTest, PackageContextQueries) {
// Query: variant count + names
size_t variant_count = 0;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantCount(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantCount(
model_pkg_context.get(), "model_1", &variant_count));
ASSERT_EQ(variant_count, 2u);
const char* const* variant_names = nullptr;
size_t variant_name_count = 0;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantNames(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantNames(
model_pkg_context.get(), "model_1", &variant_names, &variant_name_count));
ASSERT_EQ(variant_name_count, 2u);
@@ -294,17 +378,16 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS
std::unordered_map ep_options;
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
- auto options_deleter = [pkg_api](OrtModelPackageOptions* p) {
- if (p) pkg_api->ReleaseModelPackageOptions(p);
+ auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) {
+ if (p) pkg_api.ReleaseModelPackageOptions(p);
};
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) {
- if (p) pkg_api->ReleaseModelPackageContext(p);
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) {
+ if (p) pkg_api.ReleaseModelPackageContext(p);
};
- auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) {
- if (p) pkg_api->ReleaseModelPackageComponentContext(p);
+ auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) {
+ if (p) pkg_api.ReleaseModelPackageComponentContext(p);
};
std::unique_ptr model_pkg_options(nullptr, options_deleter);
@@ -312,25 +395,25 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS
std::unique_ptr component_context(nullptr, component_context_deleter);
OrtModelPackageOptions* raw_options = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options));
model_pkg_options.reset(raw_options);
OrtModelPackageContext* raw_context = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context));
model_pkg_context.reset(raw_context);
OrtModelPackageComponentContext* raw_component_context = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(model_pkg_context.get(),
- "model_1",
- model_pkg_options.get(),
- &raw_component_context));
+ ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(model_pkg_context.get(),
+ "model_1",
+ model_pkg_options.get(),
+ &raw_component_context));
component_context.reset(raw_component_context);
OrtSession* raw_session = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateSession(*ort_env,
- component_context.get(),
- session_options,
- &raw_session));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateSession(*ort_env,
+ component_context.get(),
+ session_options,
+ &raw_session));
Ort::Session session(raw_session);
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -921,33 +1004,32 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) {
os << R"({"filename":"mul_1.onnx"})";
}
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) {
- if (p) pkg_api->ReleaseModelPackageContext(p);
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) {
+ if (p) pkg_api.ReleaseModelPackageContext(p);
};
std::unique_ptr ctx(nullptr, context_deleter);
OrtModelPackageContext* raw_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx));
ctx.reset(raw_ctx);
// variant_1 targets example_ep
const char* ep1 = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName(
ctx.get(), "model_1", "variant_1", &ep1));
ASSERT_NE(ep1, nullptr);
EXPECT_STREQ(ep1, "example_ep");
// variant_2 targets other_ep
const char* ep2 = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName(
ctx.get(), "model_1", "variant_2", &ep2));
ASSERT_NE(ep2, nullptr);
EXPECT_STREQ(ep2, "other_ep");
// Optional out-parameter: callers can pass NULL.
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName(
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName(
ctx.get(), "model_1", "variant_1", nullptr));
std::filesystem::remove_all(package_root, ec);
@@ -990,32 +1072,31 @@ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) {
std::unordered_map ep_options;
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
- auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); };
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); };
- auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) {
- if (p) pkg_api->ReleaseModelPackageComponentContext(p);
+ auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); };
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); };
+ auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) {
+ if (p) pkg_api.ReleaseModelPackageComponentContext(p);
};
std::unique_ptr mp_opts(nullptr, options_deleter);
std::unique_ptr ctx(nullptr, context_deleter);
std::unique_ptr comp_ctx(nullptr, component_context_deleter);
OrtModelPackageOptions* raw_mp_opts = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts));
mp_opts.reset(raw_mp_opts);
OrtModelPackageContext* raw_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx));
ctx.reset(raw_ctx);
OrtModelPackageComponentContext* raw_comp_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
+ ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
comp_ctx.reset(raw_comp_ctx);
const ORTCHAR_T* selected_folder = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder));
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder));
ASSERT_NE(selected_folder, nullptr);
// Path looks like .../models/model_1/ -- the folder name is the variant.
@@ -1082,28 +1163,27 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn
std::unordered_map ep_options;
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
- auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); };
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); };
- auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) {
- if (p) pkg_api->ReleaseModelPackageComponentContext(p);
+ auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); };
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); };
+ auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) {
+ if (p) pkg_api.ReleaseModelPackageComponentContext(p);
};
std::unique_ptr mp_opts(nullptr, options_deleter);
std::unique_ptr ctx(nullptr, context_deleter);
std::unique_ptr comp_ctx(nullptr, component_context_deleter);
OrtModelPackageOptions* raw_mp_opts = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts));
mp_opts.reset(raw_mp_opts);
OrtModelPackageContext* raw_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx));
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx));
ctx.reset(raw_ctx);
OrtModelPackageComponentContext* raw_comp_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
+ ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
comp_ctx.reset(raw_comp_ctx);
// CreateSession iterates the per-file session_options and dispatches each through OrtApis::AddSessionConfigEntry.
@@ -1111,7 +1191,7 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn
// Pass nullptr for session_options so the metadata-merge path runs (it is skipped when the caller
// supplies their own session_options).
OrtSession* raw_session = nullptr;
- OrtStatus* st = pkg_api->CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session);
+ OrtStatus* st = pkg_api.CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session);
// Clean up session first to avoid leaks if assertion fails.
if (raw_session != nullptr) {
Ort::GetApi().ReleaseSession(raw_session);
@@ -1131,60 +1211,6 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn
std::filesystem::remove_all(package_root, ec);
}
-// Test that the C++ RAII wrappers (Ort::ModelPackageContext, etc.) work correctly.
-TEST(ModelPackageApiTest, CxxWrappers_PackageContextQueries) {
- const auto package_root = CreateModelPackageApiTestPackage();
-
- Ort::ModelPackageContext ctx(package_root.c_str());
-
- // Component queries
- EXPECT_EQ(ctx.GetComponentCount(), 1u);
- auto component_names = ctx.GetComponentNames();
- ASSERT_EQ(component_names.size(), 1u);
- EXPECT_EQ(component_names[0], "model_1");
-
- // Variant queries
- EXPECT_EQ(ctx.GetVariantCount("model_1"), 2u);
- auto variant_names = ctx.GetVariantNames("model_1");
- ASSERT_EQ(variant_names.size(), 2u);
- std::unordered_set variant_set(variant_names.begin(), variant_names.end());
- EXPECT_EQ(variant_set.count("variant_1"), 1u);
- EXPECT_EQ(variant_set.count("variant_2"), 1u);
-
- std::error_code ec;
- std::filesystem::remove_all(package_root, ec);
-}
-
-TEST(ModelPackageApiTest, CxxWrappers_SelectComponentAndQueryFileAccessors) {
- const auto package_root = CreateModelPackageApiTestPackage();
-
- RegisteredEpDeviceUniquePtr example_ep;
- ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep));
- Ort::ConstEpDevice plugin_ep_device(example_ep.get());
-
- Ort::SessionOptions so;
- std::unordered_map ep_options;
- so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- Ort::ModelPackageOptions pkg_opts(*ort_env, so);
-
- Ort::ModelPackageContext ctx(package_root.c_str());
- auto cix = ctx.SelectComponent("model_1", pkg_opts);
-
- // Folder path should be non-empty
- auto folder = cix.GetSelectedVariantFolderPath();
- EXPECT_FALSE(folder.empty());
-
- // Selected variant name should not throw
- auto variant_name = cix.GetSelectedVariantName();
- EXPECT_FALSE(variant_name.empty());
-
- // CreateSession via C++ wrapper
- auto session = cix.CreateSession(*ort_env, so);
-
- std::error_code ec;
- std::filesystem::remove_all(package_root, ec);
-}
-
// ------------------------------------------------------------------
// Test: GetSelectedVariantFolderPath returns correct path even when variant.json is absent.
// ------------------------------------------------------------------
@@ -1221,31 +1247,29 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) {
Ort::SessionOptions so;
std::unordered_map ep_options;
so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
- Ort::ModelPackageOptions pkg_opts(*ort_env, so);
- const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi();
- ASSERT_NE(pkg_api, nullptr);
+ const auto& pkg_api = GetModelPackageFns();
OrtModelPackageOptions* raw_mp_opts = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts));
- auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); };
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts));
+ auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); };
std::unique_ptr mp_opts(raw_mp_opts, options_deleter);
OrtModelPackageContext* raw_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx));
- auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); };
+ ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx));
+ auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); };
std::unique_ptr ctx(raw_ctx, context_deleter);
OrtModelPackageComponentContext* raw_comp_ctx = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
- auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) {
- if (p) pkg_api->ReleaseModelPackageComponentContext(p);
+ ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx));
+ auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) {
+ if (p) pkg_api.ReleaseModelPackageComponentContext(p);
};
std::unique_ptr comp_ctx(raw_comp_ctx, component_context_deleter);
// GetSelectedVariantFolderPath should return the variant directory even without variant.json.
const ORTCHAR_T* selected_folder = nullptr;
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder));
+ ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder));
ASSERT_NE(selected_folder, nullptr);
const auto result_path = std::filesystem::path(selected_folder);
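
Reviewer note on the `GetModelPackageFns()` helper introduced in this file: the pattern is "resolve every entry point by name once, cache the typed pointers in a struct, and fail loudly if any lookup returns null". The sketch below shows that pattern in isolation. Everything in it (`LookupByName`, `DemoFns`, `ResolveOrThrow`, the `Demo_*_SinceV1` names) is hypothetical and stands in for the experimental lookup; it is not the onnxruntime API.

```cpp
#include <cstring>
#include <stdexcept>
#include <string>

// Hypothetical entry points standing in for experimental API functions.
using AddFn = int (*)(int, int);
using NegFn = int (*)(int);

static int AddImpl(int a, int b) { return a + b; }
static int NegImpl(int a) { return -a; }

// Hypothetical name-based lookup: returns an untyped function pointer,
// or nullptr when the name is unknown.
using GenericFn = void (*)();
inline GenericFn LookupByName(const char* name) {
  if (std::strcmp(name, "Demo_Add_SinceV1") == 0) return reinterpret_cast<GenericFn>(&AddImpl);
  if (std::strcmp(name, "Demo_Neg_SinceV1") == 0) return reinterpret_cast<GenericFn>(&NegImpl);
  return nullptr;
}

// Table of typed function pointers, one member per entry point.
struct DemoFns {
  AddFn Add{nullptr};
  NegFn Neg{nullptr};
};

// Resolve one entry and cast it back to its real type; throw if it is missing.
template <typename Fn>
Fn ResolveOrThrow(const char* name) {
  GenericFn raw = LookupByName(name);
  if (raw == nullptr) {
    throw std::runtime_error(std::string("Failed to resolve ") + name);
  }
  return reinterpret_cast<Fn>(raw);
}

// The function-local static is initialized exactly once (thread-safe since C++11),
// so every caller shares the same fully resolved table.
inline const DemoFns& GetDemoFns() {
  static const DemoFns fns = []() {
    DemoFns f;
    f.Add = ResolveOrThrow<AddFn>("Demo_Add_SinceV1");
    f.Neg = ResolveOrThrow<NegFn>("Demo_Neg_SinceV1");
    return f;
  }();
  return fns;
}
```

Callers then take `const auto& fns = GetDemoFns();` and call `fns.Add(2, 3)`, which mirrors how the tests above switch from `pkg_api->Fn(...)` to `pkg_api.Fn(...)`. Because the table lives for the rest of the program, capturing it by reference in a `unique_ptr` deleter lambda is safe.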
diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc
index cc50494273aad..38cc3014aa873 100644
--- a/onnxruntime/test/contrib_ops/moe_test.cc
+++ b/onnxruntime/test/contrib_ops/moe_test.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "gtest/gtest.h"
+#include "core/session/onnxruntime_session_options_config_keys.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
@@ -778,6 +779,74 @@ TEST(MoETest, MoETest_Mixtral) {
2 /*top_k*/);
}
+TEST(MoETest, QMoETest_CUDA_Int4_DisablePrepackingFailsLoudly) {
+ constexpr int min_cuda_arch = 700;
+ if (!HasCudaEnvironment(min_cuda_arch)) {
+ GTEST_SKIP() << "CUDA device with architecture >= " << min_cuda_arch << " not available";
+ }
+
+ auto cuda_ep = DefaultCudaExecutionProvider();
+ if (!cuda_ep) {
+ GTEST_SKIP() << "CUDA execution provider not available";
+ }
+
+ constexpr int64_t num_rows = 1;
+ constexpr int64_t num_experts = 1;
+ constexpr int64_t hidden_size = 128;
+ constexpr int64_t inter_size = 128;
+ constexpr int64_t expert_weight_bits = 4;
+ constexpr int64_t pack_size = 8 / expert_weight_bits;
+
+ const std::vector input(num_rows * hidden_size, 0.0f);
+ const std::vector router_probs(num_rows * num_experts, 1.0f);
+ const std::vector fc1_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0);
+ const std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0);
+ const std::vector fc1_scales(num_experts * inter_size, 1.0f);
+ const std::vector fc2_scales(num_experts * hidden_size, 1.0f);
+ const std::vector dummy_output(num_rows * hidden_size, 0.0f);
+
+ OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain);
+ cuda_tester.AddAttribute("k", 1);
+ cuda_tester.AddAttribute("activation_type", "identity");
+ cuda_tester.AddAttribute("normalize_routing_weights", 1);
+ cuda_tester.AddAttribute("expert_weight_bits", expert_weight_bits);
+ cuda_tester.AddAttribute("quant_type", "int");
+ cuda_tester.AddAttribute("weights_prepacked", 0);
+
+ const std::vector<int64_t> input_dims = {num_rows, hidden_size};
+ const std::vector<int64_t> router_probs_dims = {num_rows, num_experts};
+ const std::vector<int64_t> fc1_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size};
+ const std::vector<int64_t> fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size};
+ const std::vector<int64_t> fc1_scales_dims = {num_experts, inter_size};
+ const std::vector<int64_t> fc2_scales_dims = {num_experts, hidden_size};
+ const std::vector<int64_t> output_dims = {num_rows, hidden_size};
+
+ cuda_tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input));
+ cuda_tester.AddInput<MLFloat16>("router_probs", router_probs_dims, ToFloat16(router_probs));
+ cuda_tester.AddInput<uint8_t>("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights);
+ cuda_tester.AddInput<MLFloat16>("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales));
+ cuda_tester.AddOptionalInputEdge<MLFloat16>();
+ cuda_tester.AddInput<uint8_t>("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights);
+ cuda_tester.AddInput<MLFloat16>("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales));
+ cuda_tester.AddOptionalInputEdge<MLFloat16>();
+ cuda_tester.AddOptionalInputEdge<uint8_t>();
+ cuda_tester.AddOptionalInputEdge<MLFloat16>();
+ cuda_tester.AddOptionalInputEdge<MLFloat16>();
+ cuda_tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(dummy_output));
+
+ SessionOptions session_options;
+ session_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "1";
+
+ std::vector<std::unique_ptr<IExecutionProvider>> cuda_execution_providers;
+ cuda_execution_providers.push_back(std::move(cuda_ep));
+ cuda_tester.Run(session_options,
+ OpTester::ExpectResult::kExpectFailure,
+ "QMoE weights_prepacked=0 requires PrePack to run",
+ {},
+ nullptr,
+ &cuda_execution_providers);
+}
+
TEST(MoETest, QMoETest_Mixtral_Int4) {
// This test uses FC3 (gated SiLU / Mixtral pattern) with dimensions too small for the
// CUTLASS kernel (needs hidden_size >= 128, inter_size >= 128). CPU QMoE does not
diff --git a/onnxruntime/test/platform/linux/npu_device_discovery_test.cc b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc
new file mode 100644
index 0000000000000..6e2170146f39f
--- /dev/null
+++ b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc
@@ -0,0 +1,103 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/platform/linux/npu_device_discovery.h"
+
+#include <algorithm>
+#include <filesystem>
+#include <fstream>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/include/asserts.h"
+
+namespace fs = std::filesystem;
+
+namespace onnxruntime::test {
+namespace {
+
+void WriteFile(const fs::path& path, const std::string& value) {
+ std::ofstream f(path);
+ f << value;
+}
+
+class NpuDeviceDiscoveryTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ temp_dir_ = fs::temp_directory_path() / "ort_npu_discovery_test";
+ fs::remove_all(temp_dir_);
+ fs::create_directories(temp_dir_);
+ }
+
+ void TearDown() override {
+ fs::remove_all(temp_dir_);
+ }
+
+ fs::path temp_dir_;
+};
+
+} // namespace
+
+TEST_F(NpuDeviceDiscoveryTest, ReturnsEmptyForNonexistentPath) {
+ std::vector<npu_device_discovery::NpuSysfsPathInfo> npu_paths;
+ ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_ / "nonexistent", npu_paths));
+ EXPECT_TRUE(npu_paths.empty());
+}
+
+TEST_F(NpuDeviceDiscoveryTest, DetectsAccelDevices) {
+ fs::create_directories(temp_dir_ / "accel0");
+ fs::create_directories(temp_dir_ / "accel12");
+ fs::create_directories(temp_dir_ / "renderD128");
+ fs::create_directories(temp_dir_ / "accelabc");
+
+ std::vector<npu_device_discovery::NpuSysfsPathInfo> npu_paths;
+ ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_, npu_paths));
+
+ ASSERT_EQ(npu_paths.size(), 2u);
+
+ std::vector<size_t> accel_indices;
+ accel_indices.reserve(npu_paths.size());
+ for (const auto& npu_path : npu_paths) {
+ accel_indices.push_back(npu_path.accel_idx);
+ }
+
+ std::sort(accel_indices.begin(), accel_indices.end());
+
+ EXPECT_EQ(accel_indices[0], 0u);
+ EXPECT_EQ(accel_indices[1], 12u);
+}
+
+TEST_F(NpuDeviceDiscoveryTest, GetNpuDeviceFromSysfsReadsVendorDeviceAndMetadata) {
+ const auto pci_device_dir = temp_dir_ / "pci_devices" / "0000:65:00.0";
+ const auto accel_dir = temp_dir_ / "class_accel" / "accel0";
+
+ fs::create_directories(pci_device_dir);
+ fs::create_directories(accel_dir);
+
+ WriteFile(pci_device_dir / "vendor", "0x1022");
+ WriteFile(pci_device_dir / "device", "0x1502");
+
+ std::error_code error_code{};
+ fs::create_directory_symlink(pci_device_dir, accel_dir / "device", error_code);
+ ASSERT_FALSE(error_code) << error_code.message();
+
+ npu_device_discovery::NpuSysfsPathInfo path_info{};
+ path_info.accel_idx = 0;
+ path_info.path = accel_dir;
+
+ OrtHardwareDevice npu_device{};
+ ASSERT_STATUS_OK(npu_device_discovery::GetNpuDeviceFromSysfs(path_info, npu_device));
+
+ EXPECT_EQ(npu_device.type, OrtHardwareDeviceType_NPU);
+ EXPECT_EQ(npu_device.vendor_id, 0x1022u);
+ EXPECT_EQ(npu_device.device_id, 0x1502u);
+
+ const auto& entries = npu_device.metadata.Entries();
+ EXPECT_NE(entries.find("accel_idx"), entries.end());
+ EXPECT_EQ(entries.at("accel_idx"), "0");
+ EXPECT_NE(entries.find("pci_bus_id"), entries.end());
+ EXPECT_EQ(entries.at("pci_bus_id"), "0000:65:00.0");
+}
+
+} // namespace onnxruntime::test
diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
index 966090eed861e..c87e5d7c4c6e1 100644
--- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
@@ -233,6 +233,80 @@ TEST(SignalOpsTest, STFTFloat) {
test.Run();
}
+static void TestSTFTInvalidFrameStep(int64_t frame_step) {
+ OpTester test("STFT", kMinOpsetVersion);
+
+ vector<float> signal(64, 1);
+ test.AddInput<float>("signal", {1, 64, 1}, signal);
+ test.AddInput<int64_t>("frame_step", {}, {frame_step});
+ vector<float> window(16, 1);
+ test.AddInput<float>("window", {16}, window);
+ test.AddInput<int64_t>("frame_length", {}, {16});
+
+ vector<int64_t> output_shape = {1, 7, 9, 2};
+ vector<float> expected_output(1 * 7 * 9 * 2, 0.f);
+ test.AddOutput<float>("output", output_shape, expected_output);
+ test.Config(OpTester::ExpectResult::kExpectFailure, "frame_step must be greater than zero");
+ test.ConfigExcludeEps({kDmlExecutionProvider});
+ test.RunWithConfig();
+}
+
+TEST(SignalOpsTest, STFTFrameStepMustBePositive) {
+ TestSTFTInvalidFrameStep(0);
+ TestSTFTInvalidFrameStep(-1);
+}
+
+template <typename T>
+static void TestSTFTComplexInputBatched() {
+ OpTester test("STFT", kMinOpsetVersion);
+ test.AddAttribute<int64_t>("onesided", static_cast<int64_t>(false));
+
+ constexpr int64_t batch_size = 2;
+ constexpr int64_t signal_size = 128;
+ constexpr int64_t signal_components = 2;
+ constexpr int64_t frame_length = 32;
+ constexpr int64_t frame_step = 16;
+ constexpr int64_t n_dfts = 7;
+ constexpr int64_t dft_output_size = frame_length;
+ constexpr int64_t output_components = 2;
+
+ vector<T> signal(batch_size * signal_size * signal_components, static_cast<T>(0));
+ for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+ const T signal_value = batch_idx == 0 ? static_cast<T>(1) : static_cast<T>(99);
+ for (int64_t sample_idx = 0; sample_idx < signal_size; ++sample_idx) {
+ signal[(batch_idx * signal_size + sample_idx) * signal_components] = signal_value;
+ }
+ }
+
+ test.AddInput<T>("signal", {batch_size, signal_size, signal_components}, signal);
+ test.AddInput<int64_t>("frame_step", {}, {frame_step});
+ test.AddOptionalInputEdge<T>();
+ test.AddInput<int64_t>("frame_length", {}, {frame_length});
+
+ vector<int64_t> output_shape = {batch_size, n_dfts, dft_output_size, output_components};
+ vector<T> expected_output(batch_size * n_dfts * dft_output_size * output_components, static_cast<T>(0));
+ for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+ const T dc_value = (batch_idx == 0 ? static_cast<T>(1) : static_cast<T>(99)) * static_cast<T>(frame_length);
+ for (int64_t frame_idx = 0; frame_idx < n_dfts; ++frame_idx) {
+ expected_output[((batch_idx * n_dfts + frame_idx) * dft_output_size) * output_components] = dc_value;
+ }
+ }
+
+ test.AddOutput<T>("output", output_shape, expected_output);
+ test.SetOutputAbsErr("output", 0.001f);
+ // DML does not consistently match these CPU STFT validation/regression paths in Windows GPU CI.
+ test.ConfigExcludeEps({kDmlExecutionProvider});
+ test.RunWithConfig();
+}
+
+TEST(SignalOpsTest, STFTFloatComplexInputBatched) {
+ TestSTFTComplexInputBatched<float>();
+}
+
+TEST(SignalOpsTest, STFTDoubleComplexInputBatched) {
+ TestSTFTComplexInputBatched<double>();
+}
+
TEST(SignalOpsTest, HannWindowFloat) {
OpTester test("HannWindow", kMinOpsetVersion);
diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py
index c5fc826a5a6ed..9677542270a53 100644
--- a/onnxruntime/test/python/transformers/test_moe_cuda.py
+++ b/onnxruntime/test/python/transformers/test_moe_cuda.py
@@ -152,7 +152,10 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True):
q_weight_reshaped = q_weight.reshape(n, -1)
# Pack weights for CUDA mixed-gemm kernel (FpA_IntB format), and qMoE kernel uses the same format.
- processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4)
+ # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 --
+ # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1)
+ # would emit the non-interleaved SM90 layout on Hopper and produce wrong results.
+ processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4, 80)
# So we need to DEQUANTIZE back to get `result`.
# scale is [n, block_per_k]
@@ -232,8 +235,11 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True):
)
q_weight_reshaped = q_weight.reshape(n, -1)
- # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format)
- processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8)
+ # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format).
+ # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 --
+ # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1)
+ # would emit the non-interleaved SM90 layout on Hopper and produce wrong results.
+ processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8, 80)
# Dequantize for reference
# (q - 128) * scale if using 128 offset? or (q) * scale if symmetric around 0?
@@ -1084,8 +1090,8 @@ def parity_check(self):
ort_dtype_quant_bits_tolerance_map = {
"FP32:0": (5e-3, 1e-3),
"FP16:0": (0.3, 0.05),
- "FP16:4": (3.0, 1e-2),
- "FP16:8": (2.0, 1e-2),
+ "FP16:4": (0.5, 1e-2),
+ "FP16:8": (0.5, 1e-2),
"BF16:0": (1.0, 1e-2),
"BF16:4": (30.0, 1e-1),
"BF16:8": (20.0, 1e-1),
diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py
index 993716a4c80b0..c56383d2851d3 100644
--- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py
+++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py
@@ -137,43 +137,6 @@ def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""):
)
-def preprocess_weights_for_mixed_gemm(
- tensor: torch.Tensor, quant_bits: int, sm: int = -1, do_weight_interleave: bool = True
-) -> torch.Tensor:
- if len(tensor.shape) == 2:
- tensor = tensor.unsqueeze(0)
-
- # Input tensor shape is [Experts, n, k_packed]. k_packed is k/2 for 4-bit, k for 8-bit.
- num_experts = tensor.shape[0]
- n = tensor.shape[1]
- k_packed = tensor.shape[2]
- k = k_packed * 2 if quant_bits == 4 else k_packed
-
- packed_list = []
-
- if _pybind and hasattr(_pybind, "pack_weights_for_cuda_mixed_gemm") and torch.cuda.is_available():
- for i in range(num_experts):
- if tensor[i].dtype == torch.bfloat16:
- weight = tensor[i].to(torch.float32).cpu().numpy()
- else:
- weight = tensor[i].cpu().numpy()
- packed = _pybind.pack_weights_for_cuda_mixed_gemm(weight, n, k, quant_bits, sm)
- # pack_weights_for_cuda_mixed_gemm returns int8 array of shape [packed_size]
- # We need to reshape it to (k, n/2) for 4-bit, (k, n) for 8-bit.
- output_rows = k
- output_cols = n // 2 if quant_bits == 4 else n
- packed_tensor = torch.from_numpy(packed).to(tensor.device)
- packed_tensor = packed_tensor.view(torch.uint8).view(output_rows, output_cols)
- packed_list.append(packed_tensor)
-
- return torch.stack(packed_list)
- else:
- # This shall not happen unless older version of onnxruntime is used.
- raise ImportError(
- "onnxruntime._pybind_state.pack_weights_for_cuda_mixed_gemm not found. Cannot preprocess weights."
- )
-
-
def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True, asymmetric: bool = False):
# DEBUG
# print(f"DEBUG: quant_dequant input shape={weights.shape}, 4bit={is_4_bit_quantization}, asym={asymmetric}")
@@ -2110,7 +2073,7 @@ class TestQMoEIntPrePackSmoke(unittest.TestCase):
hardware (the other ``test_swiglu_qmoe_parity_*`` cases in this file
fail on H200 / H100 with max-diff > 1.0 on plain main, by
inspection — pre-existing). A real parity check can be added once
- that harness honours the runtime SM.
+ that harness honors the runtime SM.
"""
def _run_one(self, *, hidden_size, inter_size, num_experts, top_k, swiglu_fusion, batch_size):