diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9d19a95136ad7..38d101786b41a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4949,7 +4949,7 @@ This version of the operator has been available since version 1 of the 'com.micr
use_sparse_mixer : int
Whether to use sparse mixer
weights_prepacked : int
-
Only meaningful when quant_type='int'. Tri-state control over whether the int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS fpA_intB format expected by the runner. -1 (auto): let the execution provider choose its own backward-compatible default; the CUDA EP treats auto as prepacked. 1: the initializers are already prepacked (e.g. produced offline by pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself in PrePack(), matching the behaviour of MatMulNBits and removing the offline pre-pack requirement from exporters. Defaults to -1 (auto) so each execution provider can pick its own backward-compatible default rather than the schema imposing one.
+
Only meaningful when quant_type='int'. Tri-state control over the layout of the int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by -1 and 1 are determined by the execution provider. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.
#### Inputs (6 - 21) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index 6d53211ff40cb..36b68889ae582 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -71,6 +71,7 @@ input tokens → router (top-k softmax) → permute by expert | `expert_weight_bits` (QMoE only) | int | 4 | 4 (INT4/MXFP4) or 8 (INT8/FP8). | | `block_size` (QMoE only) | int | -1 | Group size for INT4/INT8 group-wise quantization. -1 = per-output-channel. | | `quant_type` (QMoE only) | string | `"int"` | `"int"`, `"fp4"`, `"fp8"`, `"wfp4afp8"`. See [§3](#3-quantization-modes). | +| `weights_prepacked` (QMoE only) | int | -1 | Tri-state, only meaningful when `quant_type="int"`. The prepacked layouts selected by `-1` and `1` are **EP-determined**. `-1` (default): the INT4/INT8 `fc1`/`fc2` initializers are already prepacked in the EP's default layout (e.g. from `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). `1`: already prepacked in an alternate EP-selected layout. `0`: the initializers are raw `[E, N, K/pack]` tensors (as produced by `quantize_matmul_{4,8}bits`) and the kernel runs the CUTLASS layout transform in `PrePack()`. **Note:** the CUDA EP INT4/INT8 MoE GEMM always runs the Ampere (SM80) kernel — even on SM90 — so it consumes the SM80 `fpA_intB` layout on all architectures; `-1` and `1` are therefore equivalent for the CUDA EP today, and `1` is reserved for a possible future Hopper-specific layout. See [§5.1](#51-weights-input-2--5--8). | ### 2.2 Type Constraints @@ -228,10 +229,53 @@ extra subtraction. ### 5.1 Weights (input 2 / 5 / 8) -Not transformed at runtime. INT4/INT8 weights must already be packed offline by -`pack_weights_for_cuda_mixed_gemm` (see [§6](#6-weight-formats)). MXFP4 weights -must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights are stored -as raw e4m3 bytes (no packing). +**INT4/INT8** weight layout is controlled by the `weights_prepacked` attribute +([§2.1](#21-attributes)). 
The prepacked layouts selected by `-1` and `1` are +determined by the execution provider: + +- **`weights_prepacked=-1` (default)** — the `fc1`/`fc2` weights are already in + the EP's default prepacked layout (e.g. packed offline by + `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). They are copied to GPU + and consumed as-is. +- **`weights_prepacked=1`** — the `fc1`/`fc2` weights are already in the EP's + **SM90** (Hopper) prepacked layout (reserved; see the note below). +- **`weights_prepacked=0`** — the `fc1`/`fc2` weights are raw, schema-conformant + `[E, N, K/pack]` tensors as produced by `quantize_matmul_{4,8}bits`. `PrePack` + runs the CUTLASS layout transform itself via `PrePackIntExpertWeights`, + removing the offline pre-pack dependency. This makes integer QMoE symmetric + with `MatMulNBits::PrePack_B`. + +> **Single layout on the CUDA EP.** The CUDA EP INT4/INT8 MoE GEMM always +> dispatches to the Ampere (**SM80**) grouped-GEMM kernel — even on SM90 — +> because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA +> warp-specialized specialisation (`isValidHopperMOESpecialisation` is `false`). +> This matches **TensorRT-LLM**, which likewise routes `W4A16`/`W8A16` MoE to the +> SM80 kernel on Hopper; its Hopper TMA-WS mixed-dtype MoE kernel is reserved for +> `W4A8` (FP8 activation) and `WFP4A16` (FP4 weight). Consequently the CUDA EP +> consumes the **SM80 `fpA_intB` layout on every GPU**, `PrePack` always packs +> for SM80, and `weights_prepacked=-1` and `=1` are equivalent today. `1` is +> accepted and reserved for a possible future Hopper-specific layout (e.g. +> `W4A8`). There is therefore no architecture-match constraint: SM80-format +> weights run correctly on SM90 via the SM80 kernel. 
+ +`PrePackIntExpertWeights` loops over the `E` experts and, per expert, applies the +same transpose + row-permutation / column-interleave / bias / pair-interleave +transform as `pack_weights_for_cuda_mixed_gemm` (see [§6.1](#61-int4-group-wise-quant_typeint-expert_weight_bits4)), +always targeting the SM80 layout. SM75+ is required. The source +`[E, N, K/pack]` initializers are released after their shapes are cached +(`fc1_weights_shape_` / `fc2_weights_shape_`), so peak weight memory stays ~1×. +The prepacked GPU buffers (`packed_fc1_weights_` / `packed_fc2_weights_`) are then +preferred by `ComputeInternal`. If prepacking is disabled at the session level +(`session.disable_prepacking`), the buffers stay null and the raw initializer +pointers are read at compute time instead. + +> **Note**: `weights_prepacked=0` is the only path that triggers an in-`PrePack` +> layout transform for INT weights. FP4 / FP8 / WFP4AFP8 weight handling is +> unaffected. + +MXFP4 weights must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights +are stored as raw e4m3 bytes (no packing). + ### 5.2 INT4/INT8 scales + zero-point → bias @@ -287,7 +331,12 @@ This section covers the five distinct weight encodings supported by QMoE. INT4 packing layout within a byte: `[high_nibble | low_nibble] = [elt_1 | elt_0]`. Each INT4 element is in `[-8, 7]` (signed) before bias, `[0, 15]` after the +8 bias. -#### Preprocessing pipeline (offline, `pack_weights_for_cuda_mixed_gemm`) +#### Preprocessing pipeline (offline `pack_weights_for_cuda_mixed_gemm`, or in-`PrePack` via `PrePackIntExpertWeights`) + +This is the layout transform applied either offline by +`pack_weights_for_cuda_mixed_gemm`, or per-expert inside `PrePack` when +`weights_prepacked=0` (see [§5.1](#51-weights-input-2--5--8)). + 1. **Input layout**: `[N, K]` per expert (Out × In), 2 elements per byte for INT4. 2. 
**Transpose & signed conversion**: @@ -405,6 +454,17 @@ weights are interchangeable across SMs: — does not use `pack_weights_for_cuda_mixed_gemm`. - **FP8**: no packing. +> **QMoE uses Group A on every GPU.** The table above describes the layouts the +> `pack_weights_for_cuda_mixed_gemm` *preprocessor* can emit. The QMoE INT4/INT8 +> MoE GEMM, however, always dispatches to the Ampere (SM80) grouped-GEMM kernel — +> even on SM90 — because mixed int-weight + fp16/bf16 activation is not a valid +> Hopper TMA warp-specialized specialisation (the same is true in TensorRT-LLM). +> It therefore consumes the **Group A (SM80) layout on all architectures, +> including Hopper**. For QMoE, always pack INT4/INT8 weights for SM80 (`arch=80`), +> and `PrePackIntExpertWeights` (`weights_prepacked=0`) does exactly that +> regardless of the runtime device SM. Group B (SM90) layout is currently unused +> by QMoE. + --- ## 8. SwiGLU Fusion @@ -830,7 +890,7 @@ will not change the operator interface. |-----------|----------| | [test_moe_cuda.py](onnxruntime/test/python/transformers/test_moe_cuda.py) | Standard MoE on CUDA: FP16/BF16, SiLU/GeLU/SwiGLU, routing, GEMM parity. SwiGLU coverage includes both GPT-OSS (`TestSwigluMoE`: interleaved, alpha=1.702/beta=1.0/limit=7.0) and Standard/Llama-Gemma (`TestStandardSwigluMoE`: concatenated `swiglu_fusion=2`, alpha=1.0/beta=0.0/no limit → `SiLU(Gate)×Value`). | | [test_moe_cpu.py](onnxruntime/test/python/transformers/test_moe_cpu.py) | Standard MoE on CPU (smoke). | -| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. | +| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. 
`TestQMoEIntPrePackSmoke` covers the raw-weight `weights_prepacked=0` in-`PrePack` layout transform (smoke test: asserts finite output, not bit-parity). | | [test_qmoe_cpu.py](onnxruntime/test/python/transformers/test_qmoe_cpu.py) | INT4/INT8 QMoE on CPU (smoke). | | [test_qmoe_fp4_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py) | MXFP4 QMoE: quantization utilities, packing, FP16/BF16, SiLU/SwiGLU, top-k and expert-count variants. End-to-end runs on SM120; on SM<120 the dequant fallback is exercised. | | [test_qmoe_fp8_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py) | FP8 W8A16 QMoE on SM90+ native path and SM<90 dequant fallback. | @@ -954,6 +1014,11 @@ over-aligned by-value parameters. cannot. See [§14.1](#141-msvc-and-tma-grouped-moe-gemm). - **WFP4AFP8 native** requires SM100+ hardware; only the dequant fallback path is validated end-to-end so far. +- **In-`PrePack` INT weight layout transform** (`weights_prepacked=0`) is + currently covered only by a smoke test (`TestQMoEIntPrePackSmoke`), not a + bit-parity check: the existing offline pre-pack harness hardcodes + `force_arch=80` (the same SM80 layout consumed by the CUDA EP on all GPUs), + so a separate parity harness for this path is still pending. - **Hopper W4A8** (INT4 weight + FP8 activation) is not supported — TRT-LLM gates its fast path to SM89 only. 
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0cdbfd7114cf0..0289bca5c84d2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -348,9 +348,6 @@ ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared exte ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails); ORT_RUNTIME_CLASS(EpAssignedSubgraph); ORT_RUNTIME_CLASS(EpAssignedNode); -ORT_RUNTIME_CLASS(ModelPackageOptions); -ORT_RUNTIME_CLASS(ModelPackageContext); -ORT_RUNTIME_CLASS(ModelPackageComponentContext); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -938,9 +935,6 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtInteropApi; typedef struct OrtInteropApi OrtInteropApi; -struct OrtModelPackageApi; -typedef struct OrtModelPackageApi OrtModelPackageApi; - struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -7501,26 +7495,6 @@ struct OrtApi { */ ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); - /** \brief Get the model package API table. - * - * Returns a pointer to the ::OrtModelPackageApi function table, which provides APIs to: - * - create and release model package options and contexts, - * - inspect model package metadata (components/variants), - * - select a component/variant and query selected files/options, - * - create a session from model package selection results. - * - * The returned pointer is owned by ONNX Runtime and is valid for the process lifetime. - * Do not free it. - * - * \note May return NULL if model package support is not available in the current build - * (for example, minimal builds). - * - * \return Pointer to ::OrtModelPackageApi, or NULL if unsupported. - * - * \since Version 1.27. 
- */ - const OrtModelPackageApi*(ORT_API_CALL* GetModelPackageApi)(void); - /** \brief Retrieve an experimental function pointer by name. * * Experimental functions are not part of the stable ABI and may be added or removed between releases without notice. @@ -8645,250 +8619,6 @@ struct OrtInteropApi { /// @} }; -/** \brief API table for model package workflows. - * - * A model package is a directory containing one or more *components* (logical models). - * Each component has one or more *variants*, where each variant targets a single - * execution provider (EP). The package manifest declares the EP name, device type, - * and an optional compatibility string for every variant so that the runtime can - * automatically select the best variant for the hardware and EPs available in the - * caller's session options. - * - * Obtain this table from OrtApi::GetModelPackageApi(). The APIs support: - * - creating model package options that capture EP configuration from OrtSessionOptions, - * - loading a package context (manifest + metadata) from a package root path, - * - querying component/variant metadata including per-variant EP information, - * - selecting a component (which also resolves the best-matching variant), - * - querying the selected variant's name and folder path, - * - creating an OrtSession from the selected component context. 
- * - * Typical flow: - * 1) Create model package options: - * - CreateModelPackageOptionsFromSessionOptions() - * 2) Load package metadata: - * - CreateModelPackageContext() - * 3) Query metadata (optional): - * - ModelPackage_GetSchemaVersion() - * - ModelPackage_GetComponentCount() - * - ModelPackage_GetComponentNames() - * - ModelPackage_GetVariantCount() - * - ModelPackage_GetVariantNames() - * - ModelPackage_GetVariantEpName() - * 4) Select a component and resolve variant: - * - SelectComponent() - * 5) Query selected variant info (optional): - * - ModelPackageComponent_GetSelectedVariantName() - * - ModelPackageComponent_GetSelectedVariantFolderPath() - * 6) Create session: - * - CreateSession() - * - * Ownership: - * - Release objects created by this API with the corresponding release methods: - * ReleaseModelPackageOptions(), ReleaseModelPackageContext(), - * ReleaseModelPackageComponentContext(). - * - * \since Version 1.27. - */ -struct OrtModelPackageApi { - /// \name OrtModelPackageOptions - /// @{ - - /** \brief Create model package options from an existing OrtSessionOptions. - * - * Captures EP configuration (registered execution providers and their devices) from - * the session options for use during variant selection. The resulting OrtModelPackageOptions - * is passed to SelectComponent() to resolve the best variant for the available EPs. - * - * \param[in] env The ORT environment. - * \param[in] session_options Session options containing registered EPs. - * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released - * with ReleaseModelPackageOptions(). - * - * \since Version 1.27. 
- */ - ORT_API2_STATUS(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - - ORT_CLASS_RELEASE(ModelPackageOptions); - /// @} - /// \name OrtModelPackageContext - /// @{ - - /** \brief Create a model package context by parsing the package at the given root path. - * - * Parses the manifest.json and component metadata from the specified directory. - * The returned context provides read-only access to the package structure (components, - * variants, EP declarations). - * - * \param[in] package_root Path to the model package root directory (containing manifest.json). - * \param[out] out Receives the newly created OrtModelPackageContext. Must be released - * with ReleaseModelPackageContext(). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - - ORT_CLASS_RELEASE(ModelPackageContext); - - /** \brief Get the schema version declared in the model package manifest. - * - * \param[in] ctx The model package context. - * \param[out] out_version Receives the schema version number. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - - /** \brief Get the number of components in the model package. - * - * \param[in] ctx The model package context. - * \param[out] out_count Receives the component count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - - /** \brief Get the names of all components in the model package. - * - * Returns a pointer to an array of UTF-8 component name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. 
- * \param[out] out_names Receives a pointer to an array of component name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - - /** \brief Get the number of variants for a given component. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_count Receives the variant count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - - /** \brief Get the names of all variants for a given component. - * - * Returns a pointer to an array of UTF-8 variant name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_variant_names Receives a pointer to an array of variant name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - - /** \brief Get the EP name declared for a (component, variant) pair. - * - * Each variant targets a single EP. `out_ep` receives the EP name string. - * When the variant does not declare an EP, the returned pointer is NULL. - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \since Version 1.27. 
- */ - ORT_API2_STATUS(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - - /** \brief Select a component model and return an opaque component instance. - * - * The variant selection is also performed during this call based on the component metadata and the provided options. - * The returned `OrtModelPackgeComponentContext*` is independent of `context` lifetime and must be released via - * `ReleaseComponentInstance`. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - - ORT_CLASS_RELEASE(ModelPackageComponentContext); - - /** \brief Get the name of the selected variant after SelectComponent has been called. - * - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] out_name Receives the selected variant's name string. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - - /** \brief Get the folder path of the selected variant. - * - * Returns the resolved absolute path to the variant's directory on disk. - * The string is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] folder_path Receives the variant folder path string. - * - * \since Version 1.27. 
- */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - - /// @} - /** \brief Create an OrtSession for a selected file within a component model variant. - * - * The chosen variant (and thus its EP selection) is determined by `context`, which - * was built from an OrtSessionOptions via CreateModelPackageOptionsFromSessionOptions. - * - * Session options precedence: - * 1. session_options == NULL (default path): - * ORT uses the OrtSessionOptions that was captured when `context` was created. - * Any variant-specific session and provider options declared in the package - * metadata are merged on top. - * - * 2. session_options != NULL (advanced path): - * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific - * session and provider options from the package metadata are NOT applied. - * Use this when custom EP setup is required (e.g., shared CUDA streams, - * shared QNN EP contexts, custom allocators). - * - * \since Version 1.27. 
- */ - ORT_API2_STATUS(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* context, - _In_opt_ const OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a999ef5e0faf4..4798d3d4ad1b8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -264,20 +264,6 @@ inline const OrtEpApi& GetEpApi() { return *api; } -/// -/// This returns a reference to the ORT C Model Package API. Used for loading models from model packages. 
-/// -/// ORT C Model Package API reference -inline const OrtModelPackageApi& GetModelPackageApi() { - auto* api = GetApi().GetModelPackageApi(); - if (api == nullptr) { - // minimal build - ORT_CXX_API_THROW("Model Package API is not available in this build", ORT_FAIL); - } - - return *api; -} - /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -678,9 +664,6 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageOptions, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageContext, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageComponentContext, GetModelPackageApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. @@ -807,9 +790,6 @@ struct EpDevice; struct ExternalInitializerInfo; struct Graph; struct Model; -struct ModelPackageOptions; -struct ModelPackageContext; -struct ModelPackageComponentContext; struct Node; struct ModelMetadata; struct TypeInfo; @@ -1806,70 +1786,6 @@ struct ModelCompilationOptions : detail::Base { */ Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); -/** \brief Options for selecting a component from a model package. - * - * Wraps ::OrtModelPackageOptions. Created from an Env and SessionOptions, which captures the - * EP configuration used for variant selection. - */ -struct ModelPackageOptions : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageOptions(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. 
- - ModelPackageOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions - ModelPackageOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions -}; - -/** \brief Context for inspecting and selecting components from a model package. - * - * Wraps ::OrtModelPackageContext. Provides traversal APIs to enumerate components, variants, - * and EP compatibility, as well as component selection. - */ -struct ModelPackageContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - explicit ModelPackageContext(const ORTCHAR_T* package_root); ///< Wraps OrtModelPackageApi::CreateModelPackageContext - - size_t GetComponentCount() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentCount - std::vector GetComponentNames() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentNames - size_t GetVariantCount(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantCount - std::vector GetVariantNames(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantNames - - /// Get the EP name for a variant. Returns nullptr if not declared. - /// Returned string is owned by this context and valid until it is released. - const char* GetVariantEpName(const char* component_name, - const char* variant_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantEpName - - int64_t GetSchemaVersion() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetSchemaVersion - - ModelPackageComponentContext SelectComponent(const char* component_name, - const ModelPackageOptions& options) const; ///< Wraps OrtModelPackageApi::SelectComponent -}; - -/** \brief Context for a selected component within a model package. 
- * - * Wraps ::OrtModelPackageComponentContext. Provides accessors for the selected variant's - * folder path and variant name. - */ -struct ModelPackageComponentContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageComponentContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - std::basic_string GetSelectedVariantFolderPath() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantFolderPath - - std::string GetSelectedVariantName() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantName - - Session CreateSession(const Env& env); ///< Wraps OrtModelPackageApi::CreateSession (default path, NULL session_options) - Session CreateSession(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) - Session CreateSession(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) -}; - /** \brief Wrapper around ::OrtModelMetadata * */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 51f99655121c6..d7439e7b356c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1360,110 +1360,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const Ort return *this; } -// ModelPackageOptions -inline ModelPackageOptions::ModelPackageOptions(const Env& env, const SessionOptions& session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -inline ModelPackageOptions::ModelPackageOptions(const Env& env, ConstSessionOptions session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -// 
ModelPackageContext -inline ModelPackageContext::ModelPackageContext(const ORTCHAR_T* package_root) { - ThrowOnError(GetModelPackageApi().CreateModelPackageContext(package_root, &this->p_)); -} - -inline size_t ModelPackageContext::GetComponentCount() const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentCount(this->p_, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetComponentNames() const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentNames(this->p_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline size_t ModelPackageContext::GetVariantCount(const char* component_name) const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantCount(this->p_, component_name, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetVariantNames(const char* component_name) const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantNames(this->p_, component_name, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline const char* ModelPackageContext::GetVariantEpName(const char* component_name, - const char* variant_name) const { - const char* ep = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantEpName( - this->p_, component_name, variant_name, &ep)); - return ep; -} - -inline int64_t ModelPackageContext::GetSchemaVersion() const { - int64_t version = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetSchemaVersion(this->p_, &version)); - return version; -} - -inline ModelPackageComponentContext ModelPackageContext::SelectComponent( - const char* component_name, const 
ModelPackageOptions& options) const { - OrtModelPackageComponentContext* out = nullptr; - ThrowOnError(GetModelPackageApi().SelectComponent(this->p_, component_name, options, &out)); - return ModelPackageComponentContext{out}; -} - -// ModelPackageComponentContext -inline std::basic_string ModelPackageComponentContext::GetSelectedVariantFolderPath() const { - const ORTCHAR_T* path = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantFolderPath(this->p_, &path)); - return std::basic_string{path}; -} - -inline std::string ModelPackageComponentContext::GetSelectedVariantName() const { - const char* name = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantName(this->p_, &name)); - return (name != nullptr) ? std::string{name} : std::string{}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, nullptr, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - const SessionOptions& session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - ConstSessionOptions session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h index e943a5cf65b11..0dd87c10776d3 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h @@ -39,6 +39,10 @@ // ORT_RUNTIME_CLASS(ExperimentalType); // 
+ORT_RUNTIME_CLASS(ModelPackageOptions); +ORT_RUNTIME_CLASS(ModelPackageContext); +ORT_RUNTIME_CLASS(ModelPackageComponentContext); + // // C: function pointer typedefs and name constants // diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 0123b02818584..57a4e472b6f6d 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -35,3 +35,250 @@ * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t* out) + +// --------------------------------------------------------------------------- +// OrtModelPackageApi +// +// A model package is a directory containing one or more *components* (logical models). +// Each component has one or more *variants*, where each variant targets a single +// execution provider (EP). The package manifest declares the EP name, device type, +// and an optional compatibility string for every variant so that the runtime can +// automatically select the best variant for the hardware and EPs available in the +// caller's session options. +// +// The functions below support: +// - creating model package options that capture EP configuration from OrtSessionOptions, +// - loading a package context (manifest + metadata) from a package root path, +// - querying component/variant metadata including per-variant EP information, +// - selecting a component (which also resolves the best-matching variant), +// - querying the selected variant's name and folder path, +// - creating an OrtSession from the selected component context. 
+// +// Typical flow: +// 1) Create model package options: +// - OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions +// 2) Load package metadata: +// - OrtModelPackageApi_CreateModelPackageContext +// 3) Query metadata (optional): +// - OrtModelPackageApi_ModelPackage_GetSchemaVersion +// - OrtModelPackageApi_ModelPackage_GetComponentCount +// - OrtModelPackageApi_ModelPackage_GetComponentNames +// - OrtModelPackageApi_ModelPackage_GetVariantCount +// - OrtModelPackageApi_ModelPackage_GetVariantNames +// - OrtModelPackageApi_ModelPackage_GetVariantEpName +// 4) Select a component and resolve variant: +// - OrtModelPackageApi_SelectComponent +// 5) Query selected variant info (optional): +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath +// 6) Create session: +// - OrtModelPackageApi_CreateSession +// +// Ownership: +// - Release objects created by this API with the corresponding release functions: +// OrtModelPackageApi_ReleaseModelPackageOptions, +// OrtModelPackageApi_ReleaseModelPackageContext, +// OrtModelPackageApi_ReleaseModelPackageComponentContext. +// +// Opaque handles (OrtModelPackageOptions, OrtModelPackageContext, OrtModelPackageComponentContext) +// are declared in onnxruntime_experimental_c_api.h. +// --------------------------------------------------------------------------- + +/** \brief Create model package options from an existing OrtSessionOptions. + * + * Captures EP configuration (registered execution providers and their devices) from + * the session options for use during variant selection. The resulting OrtModelPackageOptions + * is passed to OrtModelPackageApi_SelectComponent to resolve the best variant for the + * available EPs. + * + * \param[in] env The ORT environment. + * \param[in] session_options Session options containing registered EPs. + * \param[out] out Receives the newly created OrtModelPackageOptions. 
Must be released + * with OrtModelPackageApi_ReleaseModelPackageOptions. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions, + _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtModelPackageOptions** out) + +/** \brief Release an OrtModelPackageOptions created by + * OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageOptions, + _Frees_ptr_opt_ OrtModelPackageOptions* options) + +/** \brief Create a model package context by parsing the package at the given root path. + * + * Parses the manifest.json and component metadata from the specified directory. + * The returned context provides read-only access to the package structure (components, + * variants, EP declarations). + * + * \param[in] package_root Path to the model package root directory (containing manifest.json). + * \param[out] out Receives the newly created OrtModelPackageContext. Must be released + * with OrtModelPackageApi_ReleaseModelPackageContext. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageContext, + _In_ const ORTCHAR_T* package_root, + _Outptr_ OrtModelPackageContext** out) + +/** \brief Release an OrtModelPackageContext created by OrtModelPackageApi_CreateModelPackageContext. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageContext, + _Frees_ptr_opt_ OrtModelPackageContext* ctx) + +/** \brief Get the schema version declared in the model package manifest. + * + * \param[in] ctx The model package context. + * \param[out] out_version Receives the schema version number. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetSchemaVersion, + _In_ const OrtModelPackageContext* ctx, + _Out_ int64_t* out_version) + +/** \brief Get the number of components in the model package. + * + * \param[in] ctx The model package context. + * \param[out] out_count Receives the component count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentCount, + _In_ const OrtModelPackageContext* ctx, + _Out_ size_t* out_count) + +/** \brief Get the names of all components in the model package. + * + * Returns a pointer to an array of UTF-8 component name strings. The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[out] out_names Receives a pointer to an array of component name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentNames, + _In_ const OrtModelPackageContext* ctx, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, + _Out_ size_t* out_count) + +/** \brief Get the number of variants for a given component. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_count Receives the variant count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantCount, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Out_ size_t* out_count) + +/** \brief Get the names of all variants for a given component. + * + * Returns a pointer to an array of UTF-8 variant name strings. 
The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_variant_names Receives a pointer to an array of variant name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantNames, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, + _Out_ size_t* out_count) + +/** \brief Get the EP name declared for a (component, variant) pair. + * + * Each variant targets a single EP. `out_ep` receives the EP name string. + * When the variant does not declare an EP, the returned pointer is NULL. + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantEpName, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _In_ const char* variant_name, + _Outptr_result_maybenull_ const char** out_ep) + +/** \brief Select a component model and return an opaque component instance. + * + * The variant selection is also performed during this call based on the component + * metadata and the provided options. The returned `OrtModelPackageComponentContext*` is + * independent of `context` lifetime and must be released via + * OrtModelPackageApi_ReleaseModelPackageComponentContext. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_SelectComponent, + _In_ const OrtModelPackageContext* context, + _In_ const char* component_name, + _In_ const OrtModelPackageOptions* options, + _Outptr_ OrtModelPackageComponentContext** out) + +/** \brief Release an OrtModelPackageComponentContext created by OrtModelPackageApi_SelectComponent. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageComponentContext, + _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx) + +/** \brief Get the name of the selected variant after OrtModelPackageApi_SelectComponent has been called. + * + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] out_name Receives the selected variant's name string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const char** out_name) + +/** \brief Get the folder path of the selected variant. + * + * Returns the resolved absolute path to the variant's directory on disk. + * The string is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] folder_path Receives the variant folder path string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const ORTCHAR_T** folder_path) + +/** \brief Create an OrtSession for a selected file within a component model variant. 
+ * + * The chosen variant (and thus its EP selection) is determined by `context`, which + * was built from an OrtSessionOptions via OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + * + * Session options precedence: + * 1. session_options == NULL (default path): + * ORT uses the OrtSessionOptions that was captured when `context` was created. + * Any variant-specific session and provider options declared in the package + * metadata are merged on top. + * + * 2. session_options != NULL (advanced path): + * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific + * session and provider options from the package metadata are NOT applied. + * Use this when custom EP setup is required (e.g., shared CUDA streams, + * shared QNN EP contexts, custom allocators). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession, + _In_ const OrtEnv* env, + _In_ OrtModelPackageComponentContext* context, + _In_opt_ const OrtSessionOptions* session_options, + _Outptr_ OrtSession** session) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index c26424a0ab9ec..df14bc8c57f24 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -29,9 +29,6 @@ GraphOptimizationLevel, # noqa: F401 LoraAdapter, # noqa: F401 ModelMetadata, # noqa: F401 - ModelPackageComponentContext, # noqa: F401 - ModelPackageContext, # noqa: F401 - ModelPackageOptions, # noqa: F401 NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index d28aae02ab2f1..6e75bb1b5a7c0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -48,7 +48,11 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, 
unique_word_ids.insert(word_id); } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : unique_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } T score = beam_token_scores[word_id]; // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, @@ -89,7 +93,11 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, } } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : blocked_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } beam_token_scores[word_id] = std::numeric_limits::lowest(); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 6e157d83159f4..fb9638bb09d59 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -28,7 +28,9 @@ struct NextTokenScores { } void SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); + if (token_id < 0 || token_id >= vocab_size) { + return; + } for (int i = 0; i < batch_beam_size; i++) { scores[static_cast(i) * vocab_size + token_id] = score; } diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h index b9e62443145e5..c3b734816cf84 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h @@ -31,6 +31,8 @@ enum class QuantType { W4_AFP8 }; +int get_arch_for_mixed_gemm_weight_preprocess(int arch); + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu index 
a006612ddadc9..7e83bdda72eab 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu @@ -521,6 +521,19 @@ void add_bias_and_interleave_quantized_tensor_inplace_cuda( } } +int get_arch_for_mixed_gemm_weight_preprocess(int arch) { + ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); + if (arch < 80) { + return 75; + } +#ifndef EXCLUDE_SM_90 + if (arch >= 90 && arch < 100) { + return 90; + } +#endif + return 80; +} + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h index a8fb411ed0663..47bbe0c0e10ec 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h @@ -120,11 +120,11 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { } LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { - ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); - if (arch < 80) { + arch = get_arch_for_mixed_gemm_weight_preprocess(arch); + if (arch == 75) { return getLayoutDetailsForArch(quant_type); #ifndef EXCLUDE_SM_90 - } else if (arch >= 90 && arch < 100) { + } else if (arch == 90) { return getLayoutDetailsForArch(quant_type); #endif } else { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index e1ddcac0cea4f..bc34a2e83318a 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -62,18 +62,28 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE this->quant_type_ = op_kernel_info.GetAttrOrDefault("quant_type", "int"); ORT_ENFORCE(quant_type_ == "int" 
|| quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8", "quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'"); - // ``weights_prepacked`` is an optional tri-state attribute that defaults to - // -1 (auto) in the schema, so each EP picks its own backward-compatible - // default rather than the schema imposing one: - // -1 (auto, also the schema default): the EP decides. The CUDA EP's - // backward-compatible default is "prepacked" because all pre-existing - // tooling ships CUTLASS-prepacked weights. - // 1: initializers are already prepacked; the compute path reads them as-is. - // 0: initializers are raw [E, N, K/pack]; the PrePack hook lays them out. + // ``weights_prepacked`` is an optional tri-state attribute (default -1) that + // declares the layout of the int4/int8 fc1/fc2 weight initializers. The + // concrete prepacked layouts selected by -1 and 1 are determined by the + // execution provider. The CUDA EP maps the tri-state as: + // -1 (default): already prepacked in the EP's default int weight layout. + // 1: already prepacked in an alternate EP-selected int weight layout. + // 0: raw [E, N, K/pack] initializers; the PrePack hook lays them out. + // + // Important: the CUDA QMoE int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (see isValidHopperMOESpecialisation). The kernel therefore + // consumes the SM80/Ampere CUTLASS fpA_intB layout on every GPU. As a result + // the EP default (-1) is the SM80 layout regardless of the runtime device SM, + // and SM80-format weights are valid on SM90 (they run via the SM80 kernel). + // For CUDA today, -1 and 1 are equivalent (both SM80 layout), and 1 is + // reserved for a possible future Hopper-specific layout. + // PrePack (weights_prepacked=0) packs for the SM80 layout accordingly. 
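Review note: the tri-state mapping described in the comment above can be sketched in isolation. This is a minimal illustration, not the kernel's code: `IntWeightLayoutSketch` and `ResolveWeightsPrepackedSketch` are hypothetical names, and the sketch assumes only what the comment states, namely that -1 and 1 both mean "already prepacked", 0 means "raw, PrePack must transform", and the CUDA int4/int8 MoE GEMM consumes the SM80 layout on every GPU.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical result type for illustration; not part of ONNX Runtime.
struct IntWeightLayoutSketch {
  bool weights_prepacked;  // false => PrePack must transform raw [E, N, K/pack] weights
  int packing_sm;          // layout generation the int4/int8 MoE GEMM consumes
};

// Hypothetical helper mirroring the attribute handling described in the comment.
IntWeightLayoutSketch ResolveWeightsPrepackedSketch(int64_t mode) {
  if (mode != -1 && mode != 0 && mode != 1) {
    throw std::invalid_argument("weights_prepacked must be -1 (default), 0, or 1, but got " +
                                std::to_string(mode));
  }
  // -1 and 1 are currently equivalent for the CUDA EP: both mean "already prepacked".
  // The layout is SM80 regardless of the runtime device SM.
  return IntWeightLayoutSketch{mode != 0, 80};
}
```

Under these assumptions, only `0` routes the initializers through the PrePack layout transform; any other accepted value leaves them untouched.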
const int64_t weights_prepacked_mode = op_kernel_info.GetAttrOrDefault("weights_prepacked", static_cast(-1)); ORT_ENFORCE(weights_prepacked_mode == -1 || weights_prepacked_mode == 0 || weights_prepacked_mode == 1, - "weights_prepacked must be -1 (auto), 0, or 1, but got ", weights_prepacked_mode); + "weights_prepacked must be -1 (default), 0, or 1, but got ", weights_prepacked_mode); weights_prepacked_ = (weights_prepacked_mode != 0); #if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE) ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer."); @@ -224,6 +234,18 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { // to the runner. const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr && packed_fc2_weights_ != nullptr; + // When ``weights_prepacked == 0`` the raw ``[E, N, K/pack]`` int weights must be + // converted to the CUTLASS fpA_intB layout by PrePack before the runner can consume + // them. If PrePack never ran (e.g. ``session.disable_prepacking`` is set), the prepack + // buffers stay null and falling through to the raw initializer pointers would feed + // non-CUTLASS bytes to the runner, producing silently wrong output. Fail loudly instead. + if (is_int && !weights_prepacked_ && + (packed_fc1_weights_ == nullptr || packed_fc2_weights_ == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE weights_prepacked=0 requires PrePack to run, but the int weight " + "buffers were not produced (is session.disable_prepacking set?). Provide " + "CUTLASS-prepacked weights with weights_prepacked=1, or enable prepacking."); + } const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input(2); const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? 
context->Input(3) : nullptr; const Tensor* fc1_experts_bias_optional = context->Input(4); @@ -844,13 +866,24 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr; const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr; if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) { - fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data; - fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data; + // The native CUTLASS WFP4AFP8 path consumes weights in the repacked FP4 + // layout produced by PrePack. If PrePack never ran (e.g. + // ``session.disable_prepacking`` is set) the repacked buffers stay null and + // falling through to the raw initializer bytes would feed a non-CUTLASS + // layout to the runner, producing silently wrong output. Fail loudly. + if (packed_fp4_fc1_weights_ == nullptr || packed_fp4_fc2_weights_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE wfp4afp8 requires PrePack to run, but the repacked FP4 weight " + "buffers were not produced (is session.disable_prepacking set?). " + "Enable prepacking to use the native WFP4AFP8 path."); + } + fc1_weight_data = packed_fp4_fc1_weights_.get(); + fc2_weight_data = packed_fp4_fc2_weights_.get(); } else if (int_weights_consumed_by_prepack) { // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB // layout that the runner consumes and freed the source initializer // (``is_packed = true``). 
Gate on ``int_weights_consumed_by_prepack`` - // (which already requires ``packed_fc1_weights_ != nullptr``) rather than + // (which already requires both packed weight buffers) rather than // just ``is_int && !weights_prepacked_``: when prepacking is disabled at // the session level (``session.disable_prepacking``) PrePack never runs, // the prepack buffers stay null, and the raw initializer pointers read @@ -1146,6 +1179,9 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al IAllocatorUniquePtr& packed_buf, bool& is_packed) { ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_); + ORT_ENFORCE(sm_ >= 75, + "PrePackIntExpertWeights: quant_type='int' with weights_prepacked=0 requires SM75+ CUDA hardware, got SM", + sm_); const auto& shape = tensor.Shape(); ORT_ENFORCE(shape.NumDimensions() == 3, "PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=", @@ -1158,22 +1194,15 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al const int64_t k_packed = shape[2]; const int64_t k = k_packed * pack_factor; - // Weight packing is architecture-aware (see - // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing - // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that - // skips column interleaving, so it is its own compatibility group. Every - // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares - // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack - // INT8 LDSM and are unsupported. The compute-side runner selects the same - // layout from this clamped arch, so the two cannot drift. 
- // - // SM75 is passed through unchanged (rather than clamped to 80) even though it - // shares SM80's layout: the compute-side dispatch (getLayoutDetailsForTransform) - // still has a distinct SM75 branch, so mirroring it here avoids confusing a - // reader into thinking prepack and dispatch disagree. - ORT_ENFORCE(sm_ >= 75, - "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_); - const int packing_sm = (sm_ == 90 || sm_ == 75) ? sm_ : 80; + // The CUDA QMoE int4/int8 MoE GEMM always dispatches to the Ampere (SM80) + // grouped-GEMM kernel -- even on SM90 -- because mixed int-weight + fp16/bf16 + // is not a valid Hopper TMA warp-specialized specialisation. The kernel thus + // consumes the SM80 CUTLASS fpA_intB layout on every GPU, so the weights must + // always be preprocessed for SM80 regardless of the runtime device SM. + // (Using get_arch_for_mixed_gemm_weight_preprocess(sm_) here would emit the + // SM90 layout on Hopper, which the SM80 kernel cannot consume -> wrong output.) + const int packing_sm = + onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(80); // Per-expert sizes. const size_t per_expert_bytes = static_cast(n) * static_cast(k) / pack_factor; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h index 5722ac41cc470..2bbadc205b5d8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h @@ -46,16 +46,23 @@ class QMoE final : public CudaKernel, public MoEBase { IAllocatorUniquePtr& packed_buf, bool& is_packed); int64_t expert_weight_bits_; bool is_fp16_; - // When true (the schema default), the int4/int8 fc1/fc2 weight - // initializers are already in the CUTLASS fpA_intB layout — produced - // offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the - // compute path reads them as-is. 
When false, the raw schema-conformant - // ``[E, N, K/pack]`` layout (as produced by - // ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook - // via ``PrePackIntExpertWeights``, removing the offline prepack - // dependency. Only meaningful when ``quant_type_ == "int"``. Derived from - // the optional tri-state ``weights_prepacked`` attribute: -1/auto (or - // absent) maps to true on the CUDA EP, 1 maps to true, 0 maps to false. + // When true, the int4/int8 fc1/fc2 weight initializers are already in a + // CUTLASS fpA_intB layout — produced offline e.g. via + // ``pack_weights_for_cuda_mixed_gemm`` — and the compute path reads them + // as-is. When false, the raw schema-conformant ``[E, N, K/pack]`` layout + // (as produced by ``quantize_matmul_{4,8}bits``) is rewritten inside the + // PrePack hook via ``PrePackIntExpertWeights``, removing the offline + // prepack dependency. Only meaningful when ``quant_type_ == "int"``. + // Derived from the optional tri-state ``weights_prepacked`` attribute: + // -1 (default) and 1 both map to true; 0 maps to false. The concrete + // prepacked layouts selected by -1 and 1 are determined by the execution + // provider. For the CUDA EP the int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (matches TensorRT-LLM, which also routes W4A16/W8A16 MoE to + // the SM80 kernel on Hopper). The kernel therefore consumes the SM80 fpA_intB + // layout on every GPU, so -1 and 1 are currently equivalent for the CUDA EP; + // 1 is reserved for a possible future Hopper-specific layout (e.g. W4A8). bool weights_prepacked_ = true; // Cached source weight shapes captured at PrePack time. 
When the // PrePack hook consumed and released the original int4/int8 weight diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 684e050f0201a..02e764d01e05e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,40 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +// WGSL helper function for normalizing on-device indirect dispatch dims. +// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram. +// Mirrors ProgramManager::NormalizeDispatchGroupSize three tiers: +// 1) direct (x, y, z) write when every dim is within the spec limit (65535); +// 2) 2D sqrt collapse when the product fits a square layout; +// 3) 3D cbrt collapse otherwise. +// Consumers are unaffected by the chosen layout: ShaderHelper flattens +// workgroup_id (x, y, z) into a single linear workgroup_idx. +// Caller contract: must register a storage output named exactly +// `indirect_buffer` of array<u32> with at least 3 elements. 
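Review note: the three tiers listed in the comment above can be checked on the host with a small C++ mirror. This is a sketch under stated assumptions: `NormalizeDispatchGroupSizeSketch` is a hypothetical name, it returns the three dims instead of writing to `indirect_buffer`, and it uses single-precision math to stay close to the f32 arithmetic a WGSL helper would use.

```cpp
#include <array>
#include <cmath>
#include <cstdint>

// Hypothetical host-side mirror of the three-tier normalization; illustrative only.
std::array<uint32_t, 3> NormalizeDispatchGroupSizeSketch(uint32_t x, uint32_t y, uint32_t z) {
  constexpr uint32_t kLimit = 65535u;  // per-dimension workgroup limit named in the comment

  // Tier 1: every dim already fits, so keep the requested layout.
  if (x <= kLimit && y <= kLimit && z <= kLimit) {
    return {x, y, z};
  }

  // Tier 2: collapse to a square 2D layout whose area covers the total count.
  const float size = static_cast<float>(x) * static_cast<float>(y) * static_cast<float>(z);
  const uint32_t avg_2d = static_cast<uint32_t>(std::ceil(std::sqrt(size)));
  if (avg_2d <= kLimit) {
    return {avg_2d, avg_2d, 1u};
  }

  // Tier 3: collapse to a cubic 3D layout.
  const uint32_t avg_3d = static_cast<uint32_t>(std::ceil(std::pow(size, 1.0f / 3.0f)));
  return {avg_3d, avg_3d, avg_3d};
}
```

The collapsed layouts may launch more workgroups than requested, which is why the comment stresses that consumers see a single flattened `workgroup_idx`; a consumer would presumably need to guard against indices past the real total.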
+constexpr const char kNormalizeDispatchGroupSizeFn[] = R"( +fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) { + let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension + if (x <= limit && y <= limit && z <= limit) { + indirect_buffer[0] = x; + indirect_buffer[1] = y; + indirect_buffer[2] = z; + return; + } + let size = f32(x) * f32(y) * f32(z); + let dispatch_avg_2d = u32(ceil(sqrt(size))); + if (dispatch_avg_2d <= limit) { + indirect_buffer[0] = dispatch_avg_2d; + indirect_buffer[1] = dispatch_avg_2d; + indirect_buffer[2] = 1u; + return; + } + let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0))); + indirect_buffer[0] = dispatch_avg_3d; + indirect_buffer[1] = dispatch_avg_3d; + indirect_buffer[2] = dispatch_avg_3d; +} +)"; + Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); @@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha if (prepare_indirect_dispatch_) { sh.AddOutput("indirect_buffer", ShaderUsage::None); + sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; } return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", @@ -87,13 +122,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { - // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size. 
- shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n" - << " if (global_idx == 0u) {\n" + shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; + shader.MainFunctionBody() << " if (global_idx == 0u) {\n" << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" - << " indirect_buffer[0] = num_total_seq_length_tile;\n" - << " indirect_buffer[1] = uniforms.num_heads;\n" - << " indirect_buffer[2] = 1u;\n" + << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -120,7 +152,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. 
@@ -176,7 +208,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt {static_cast(parameters.total_sequence_length_)}, {static_cast(parameters.kv_sequence_length_)}, {tile_size}, - {static_cast(parameters.num_heads_)}}); + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.batch_size_)}, + {num_q_tiles}}); return context.RunProgram(program); } @@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } -Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); +Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); + const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const uint32_t tile_size_k_vec = 8; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; - return 
WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), + WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), + WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(q, q)); } -Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, - const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) { +Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, + const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, + Tensor* metadata, const Tensor* seqlen_k, + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { const float alpha = parameters.scale_ == 
0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const bool has_attention_bias = attention_bias != nullptr; const int components = 4; + const int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch}; + bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; + bool is_unidirectional = parameters.is_unidirectional_; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}}); + {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank}, + program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::Rank, 2}}); const uint32_t vectorized_head_size = parameters.head_size_ / components; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte if (use_indirect_dispatch) { program.SetIndirectDispatchTensor(indirect_buffer); } else { - program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * 
((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -294,124 +342,72 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.num_heads_)}, {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); } -Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("metadata", ShaderUsage::UseUniform); - shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); +Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { shader.AddInput("head_sink", ShaderUsage::UseUniform); } - shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); - - const uint32_t tile_size_k_vec = 8u; - - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template", - WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - 
WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); -} - -Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context, - const Tensor* metadata, - const Tensor* qk, - Tensor* out_split_vx, - Tensor* present_value, - const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, - const Tensor* indirect_buffer, - uint32_t num_total_seq_length_tile, - uint32_t num_present_sequence_length_tile, - uint32_t tile_size, - bool use_indirect_dispatch, - uint32_t present_sequence_length, - const Tensor* head_sink) { - const int components = 4; - const bool has_head_sink = head_sink != nullptr; - int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink}; - program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}, - {qk, ProgramTensorMetadataDependency::TypeAndRank}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size] - const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - if (use_indirect_dispatch) { - program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); - } - if (has_head_sink) { - program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); - } - // SetIndirectDispatchTensor must be called after all AddInput calls because it - // appends the indirect buffer as the last program input. 
- if (use_indirect_dispatch) { - program.SetIndirectDispatchTensor(indirect_buffer); - } else { - program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile); - } - program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink) - .SetWorkgroupSize(64) - .AddUniformVariables({{static_cast(parameters.total_sequence_length_)}, - {static_cast(head_size_vec)}, - present_sequence_length, - {static_cast(parameters.n_reps)}, - num_present_sequence_length_tile, - {batch_heads}, - {static_cast(parameters.num_heads_)}}); - - return context.RunProgram(program); -} - -Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { - shader.AddInput("seqlens_k", ShaderUsage::None); - } - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(input, input), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(output, output)); } Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context, const Tensor* out_split_vx, + const Tensor* metadata, Tensor* output, const Tensor* seqlen_k, const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch) { + 
bool use_indirect_dispatch, + const Tensor* head_sink, + uint32_t m_tile) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch}; - program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); + bool has_head_sink = head_sink != nullptr; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, + {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (has_head_sink) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}}); const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - program.SetDispatchGroupSize(batch_heads * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch) + program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) + .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, num_present_sequence_length_tile, {num_head_size_tile}, - {batch_heads}}); + {batch_heads}, + {static_cast(parameters.sequence_length_)}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -446,14 +442,18 @@ Status ApplyFlashAttention(const 
Tensor* Q, const Tensor* K, const Tensor* V, co // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; + // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch. + const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u); + const uint32_t num_q_tiles = (static_cast(parameters.sequence_length_) + m_tile - 1u) / m_tile; + // Create indirect dispatch buffer if using indirect dispatch Tensor* indirect_buffer_ptr = nullptr; Tensor indirect_buffer; - // Prepare indirect dispatch buffer for decode path with static KV cache - const bool use_indirect_dispatch = !kv_empty && - parameters.sequence_length_ == 1 && - parameters.past_present_share_buffer_ && + // Prepare indirect dispatch buffer for split-reduce path with static KV cache. + // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based + // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. + const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && seqlen_k != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { @@ -492,10 +492,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size)); + indirect_buffer_ptr, tile_size, num_q_tiles)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? 
seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); } // Extract present_sequence_length directly from present_key tensor shape @@ -503,7 +503,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); - if (parameters.sequence_length_ > 1) { + // Route between prefill path (FlashAttentionProgram, single kernel) + // and split-reduce decode path (QKV + VxReduce, 2 kernels). + // Split-reduce wins for short Q (sequence_length < 32) across all KV + // cache lengths measured: 1.13x-2.07x faster at total_sequence_length + // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96). + const bool use_split_reduce = parameters.sequence_length_ < 32; + + if (!use_split_reduce) { + // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles) bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -545,7 +553,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t prefill_tile_size = is_apple ? 
128 : tile_size; const uint32_t num_seq_tile = (parameters.sequence_length_ + prefill_tile_size - 1) / prefill_tile_size; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -570,37 +577,31 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return context.RunProgram(program); } - // For decode path (sequence_length == 1) - const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, present_sequence_length}); - const TensorShape qk_shape(qk_dims); - Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); + // Split-reduce path (QKV + VxReduce) const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. 
const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, 2}); + parameters.sequence_length_, num_present_sequence_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k, - parameters, indirect_buffer_ptr, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length)); const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, parameters.head_size_}); + parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, - seqlen_k, parameters, indirect_buffer_ptr, - num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, - use_indirect_dispatch, present_sequence_length, - head_sink)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters, + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value, + &metadata, seqlen_k, + parameters, indirect_buffer_ptr, num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + present_sequence_length, m_tile)); + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch)); + num_present_sequence_length_tile, tile_size, 
use_indirect_dispatch, + head_sink, m_tile)); return Status::OK(); } @@ -621,7 +622,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size) { + uint32_t tile_size, uint32_t num_q_tiles) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -678,6 +679,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, + {static_cast(params.batch_size_)}, + {num_q_tiles}, }); program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index e75b6378f67c6..3da6b33b4dc0e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -35,7 +35,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"tile_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_q_tiles", ProgramUniformVariableDataType::Uint32}); private: bool has_past_; @@ -138,11 +142,14 @@ class FlashAttentionProgram final : public Program { int max_k_step_; }; -class FlashAttentionDecodeQKTProgram final : public Program { +class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeQKTProgram(const std::string& kernel_name, - bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, 
has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeQKVProgram(const std::string& kernel_name, + bool has_attention_bias, uint32_t tile_size, int head_size_vec, + bool use_indirect_dispatch, bool q_BNSH = false, + bool is_unidirectional = false, + uint32_t m_tile = 1) + : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -156,41 +163,23 @@ class FlashAttentionDecodeQKTProgram final : public Program { - public: - FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false) - : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_size_vec", ProgramUniformVariableDataType::Uint32}, - {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_heads", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); - - private: uint32_t tile_size_; int head_size_vec_; bool use_indirect_dispatch_; - bool has_head_sink_; + bool q_BNSH_; + bool is_unidirectional_; + uint32_t m_tile_; }; class FlashAttentionDecodeVxReduceProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool 
use_indirect_dispatch) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -199,12 +188,16 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array; -var inner_qk_values: array, tile_size>; -var tile_qk: array; - -#if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - // Handle broadcasting: if dimension size is 1, use index 0 - let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); - let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - - // Calculate flat offset with broadcasting applied - // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length] - // For decode, new_seq_length is 1, so we can simplify: - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length + - bias_head_idx * total_seq_length + - k_idx; - return attention_bias[offset]; - } -#else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - return q_element_t(0); - } -#endif - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; -#if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; -#else - let total_sequence_length = uniforms.total_sequence_length; -#endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / 
tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let head_idx = batch_head_idx % uniforms.num_heads; - let batch_idx = batch_head_idx / uniforms.num_heads; - if (batch_idx >= uniforms.batch_size) { - return; - } - let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec; - let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; - for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { - if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { - tile_q[local_idx] = q[q_offset + k + local_idx]; - } - workgroupBarrier(); - let q_data = tile_q[local_col] * q_element_t(uniforms.alpha); - if (k + local_col < uniforms.head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data); - } - } - } - workgroupBarrier(); - } - - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - var sum = q_element_t(0); - for (var i = 0u; i < tile_size_k_vec; i++) { - sum += inner_qk_values[local_idx][i]; - } - - sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length); - tile_qk[local_idx] = sum; - output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; - } - workgroupBarrier(); - - if (local_idx == 0u) { - // Calculate the max and sum in current split. 
-    var l_max = f32(-3.4028234663852886e+38f);
-    var l_sum = f32(0);
-    for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
-      l_max = max(l_max, f32(tile_qk[i]));
-    }
-    for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) {
-      l_sum += exp(f32(tile_qk[i]) - l_max);
-    }
-    let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile;
-    metadata[meta_offset] = metadata_value_t(l_max, l_sum);
-  }
-}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
new file mode 100644
index 0000000000000..524a18ca43245
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template
@@ -0,0 +1,197 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#param has_attention_bias
+#param v_head_size_vec
+#param is_unidirectional
+#param m_tile
+#param q_BNSH
+#param sub_tile_count
+#param tile_size
+#param tile_size_k_vec
+#param use_indirect_dispatch
+
+#use .getByOffset .setByOffset
+
+// Fused QK^T + softmax + V multiply shader.
+//
+// Each workgroup processes one KV tile (tile_size rows of present_key/value)
+// for m_tile Q rows. The computation has two phases:
+//
+// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask,
+//          per-tile max/sum for online softmax)
+// Phase 2: Local softmax normalization + V multiply (using local max/sum,
+//          no cross-workgroup dependency)
+//
+// The VxReduce shader performs the final rescaling across tiles.
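Review note (not part of the patch): the header comment states that each workgroup normalizes with its own local max/sum and that VxReduce rescales across tiles. The VxReduce template is not shown in this part of the diff, so the C++ sketch below only illustrates the standard online-softmax merge that such a split design relies on; it is not the actual shader. `TilePartial`, `ComputeTilePartial` and `MergeTilePartials` are hypothetical names, and attention bias, causal masking and `head_sink` are ignored.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical per-tile result: local max, local sum of exp(score - max),
// and the locally normalized softmax(scores) * V for that tile.
struct TilePartial {
  float l_max;
  float l_sum;
  std::vector<float> out;
};

// Mirrors the two phases described in the comment for a single Q row and one KV tile.
TilePartial ComputeTilePartial(const std::vector<float>& scores,
                               const std::vector<std::vector<float>>& values) {
  TilePartial p{-std::numeric_limits<float>::infinity(), 0.0f, {}};
  for (float s : scores) p.l_max = std::max(p.l_max, s);
  for (float s : scores) p.l_sum += std::exp(s - p.l_max);
  p.out.assign(values.empty() ? 0 : values[0].size(), 0.0f);
  for (std::size_t i = 0; i < scores.size(); ++i) {
    const float w = std::exp(scores[i] - p.l_max) / p.l_sum;  // local softmax weight
    for (std::size_t d = 0; d < p.out.size(); ++d) p.out[d] += w * values[i][d];
  }
  return p;
}

// Cross-tile rescaling: weight each tile by l_sum * exp(l_max - global_max),
// then divide by the sum of those weights.
std::vector<float> MergeTilePartials(const std::vector<TilePartial>& tiles) {
  float g_max = -std::numeric_limits<float>::infinity();
  for (const auto& t : tiles) g_max = std::max(g_max, t.l_max);
  float g_sum = 0.0f;
  for (const auto& t : tiles) g_sum += t.l_sum * std::exp(t.l_max - g_max);
  std::vector<float> result(tiles.empty() ? 0 : tiles[0].out.size(), 0.0f);
  for (const auto& t : tiles) {
    const float w = t.l_sum * std::exp(t.l_max - g_max) / g_sum;
    for (std::size_t d = 0; d < result.size(); ++d) result[d] += w * t.out[d];
  }
  return result;
}
```

Merging the partials of two tiles this way should reproduce, up to floating-point error, the result of running softmax over the whole row at once. That equivalence is what allows Phase 2 to avoid any cross-workgroup dependency.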
+ +var tile_q: array, m_tile>; +var inner_qk_values: array, tile_size>, m_tile>; +var tile_qk: array, m_tile>; +var tile_output: array, m_tile>; +var qkv_values: array, sub_tile_count>, m_tile>; +var tile_max: array; +var tile_sum: array; + +#if has_attention_bias + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + + bias_head_idx * uniforms.new_sequence_length * total_seq_length + + q_idx * total_seq_length + + k_idx; + return attention_bias[offset]; + } +#else + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + return q_element_t(0); + } +#endif + +$MAIN { + let local_row = u32(local_idx / tile_size_k_vec); + let local_col = local_idx % tile_size_k_vec; + #if use_indirect_dispatch + let total_sequence_length = u32(seqlens_k[0]) + 1u; + #else + let total_sequence_length = uniforms.total_sequence_length; + #endif + let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] + let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; + let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles)); + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + if (batch_idx >= uniforms.batch_size) { + return; + } + let present_key_offset = batch_head_idx / uniforms.n_reps * 
uniforms.present_sequence_length * uniforms.head_size_vec; + let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + + // ============================================================ + // Phase 1: QK^T computation + // ============================================================ + for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { + let q_idx = q_base + m; +#if q_BNSH + let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec + + head_idx * uniforms.new_sequence_length * uniforms.head_size_vec + + q_idx * uniforms.head_size_vec; +#else + let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec; +#endif + tile_q[m][local_idx] = q.getByOffset(q_offset + k + local_idx); + } + } + workgroupBarrier(); + if (k + local_col < uniforms.head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let k_data = present_key.getByOffset(present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha); + inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data); + } + } + } + } + workgroupBarrier(); + } + + // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + if (local_idx < tile_size && 
total_seq_offset + local_idx < total_sequence_length) { + var sum = q_element_t(0); + for (var i = 0u; i < tile_size_k_vec; i++) { + sum += inner_qk_values[m][local_idx][i]; + } + + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); +#if is_unidirectional + if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + sum = q_element_t(-65504.0f); + } +#endif + tile_qk[m][local_idx] = present_value_element_t(sum); + } + workgroupBarrier(); + + // Compute per-tile max and sum for online softmax + if (local_idx == 0u) { + var l_max = f32(-3.4028234663852886e+38f); + var l_sum = f32(0); + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_max = max(l_max, f32(tile_qk[m][i])); + } + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_sum += exp(f32(tile_qk[m][i]) - l_max); + } + tile_max[m] = l_max; + tile_sum[m] = l_sum; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(l_max, l_sum)); + } + } + workgroupBarrier(); + + // ============================================================ + // Phase 2: Local softmax + V multiply + // ============================================================ + + // Normalize tile_qk with local max/sum + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); + } + } + workgroupBarrier(); + + for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] 
= present_value_value_t(0); + } + workgroupBarrier(); + + if (k + local_col < v_head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let v_data = present_value.getByOffset(present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row]; + } + } + } + } + workgroupBarrier(); + + if (local_idx < tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + for (var i = 0u; i < sub_tile_count; i++) { + tile_output[m][k + local_idx] += qkv_values[m][i][local_idx]; + } + } + } + workgroupBarrier(); + } + + // Write output + let tile_idx = workgroup_idx % num_total_seq_length_tile; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec; + for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) { + out_split_vx.setByOffset(out_base + tile_idx * v_head_size_vec + i, tile_output[m][i]); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template deleted file mode 100644 index 6f1ad1ca41b71..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#param has_head_sink -#param tile_size -#param head_size_vec -#param tile_size_k_vec -#param sub_tile_count -#param use_indirect_dispatch - -// Note that this shader adopts similar algorithm with dp4a generation shader. -// -// This algorithm works to compute dot product of v with qk parallelly, by -// processing on the head_size dimension at each step amongst tile_size_k_vec -// threads, and utilizing the remaining threads in the workgroup to process -// additional rows of |present_value| in parallel (such that the values in -// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads -// also reload |present_value| tile_size/sub_tile_count times to compute partial -// dot products of other |present_value| rows in order to complete all tile_size -// |present_value| rows in this workgroup and also reusing the values in -// tile_qk. -// -// The difference with FlashAttentionDecodeQKTProgram is that the dot products -// go through the rows (total_sequence_length) of |present_value| instead of -// columns (head_size_vec). And each workgroup only calculate current -// tile_size's dot products instead of iterating the whole row -// |total_sequence_length|. That's why this shader is a split shader. The final -// reduce will be done in FlashAttentionDecodeReduceProgram. - -// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx -// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate -// memory. The FlashAttentionDecodeQKT can be merged into split shader and do -// the final softmax adjustment in the reduce shader. However, some issues are -// met that when the total sequence length exceeds some value, the result will -// become garbage. Since it can't be resolved in a short time, leave it as TODO -// to fix it in future. 
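Reviewer sketch (not part of the patch): the reduce shader in this diff merges per-tile partial outputs using the `(max, sum)` metadata written by the fused shader. The merge can be checked on the CPU under a few assumptions. `TilePartial`, `ComputeTile`, and `MergeTiles` are hypothetical names introduced only for illustration. The weight applied to each partial is `local_sum * exp(local_max - global_max) / global_sum`, mirroring the `rescale_f32` expression in the reduce shader; the `local_sum` factor is needed because each partial has already been divided by its own tile sum.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CPU-side model of one tile's result; not an ONNX Runtime type.
struct TilePartial {
  float local_max;            // max score inside the tile
  float local_sum;            // sum of exp(score - local_max) inside the tile
  std::vector<float> output;  // V contribution, normalized by local_sum only
};

// Softmax-weighted sum of V rows for scores in [begin, end), normalized within the tile.
TilePartial ComputeTile(const std::vector<float>& scores,
                        const std::vector<std::vector<float>>& v,
                        std::size_t begin, std::size_t end) {
  TilePartial t{std::numeric_limits<float>::lowest(), 0.0f,
                std::vector<float>(v[0].size(), 0.0f)};
  for (std::size_t i = begin; i < end; ++i) t.local_max = std::max(t.local_max, scores[i]);
  for (std::size_t i = begin; i < end; ++i) t.local_sum += std::exp(scores[i] - t.local_max);
  for (std::size_t i = begin; i < end; ++i) {
    const float p = std::exp(scores[i] - t.local_max) / t.local_sum;
    for (std::size_t h = 0; h < t.output.size(); ++h) t.output[h] += p * v[i][h];
  }
  return t;
}

// Rescale each locally normalized partial so that the result equals a softmax over all tiles.
std::vector<float> MergeTiles(const std::vector<TilePartial>& tiles) {
  float g_max = std::numeric_limits<float>::lowest();
  for (const auto& t : tiles) g_max = std::max(g_max, t.local_max);

  float g_sum = 0.0f;
  for (const auto& t : tiles) g_sum += t.local_sum * std::exp(t.local_max - g_max);

  std::vector<float> out(tiles[0].output.size(), 0.0f);
  for (const auto& t : tiles) {
    const float weight = t.local_sum * std::exp(t.local_max - g_max) / g_sum;
    for (std::size_t h = 0; h < out.size(); ++h) out[h] += t.output[h] * weight;
  }
  return out;
}
```

Treating the whole sequence as a single tile via `ComputeTile(scores, v, 0, n)` gives a reference result that `MergeTiles` should match up to floating-point rounding. The sketch ignores the head sink, attention bias, and causal mask handled by the shaders.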
- -var tile_qk: array; -var tile_output: array; -var qkv_values: array, sub_tile_count>; - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; - #else - let total_sequence_length = uniforms.total_sequence_length; - #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - if (batch_head_idx >= uniforms.batch_heads) { - return; - } - let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; - - // Calculate the global max and sum in qk. - var g_max = f32(-3.4028234663852886e+38f); -#if has_head_sink - let head_idx = batch_head_idx % uniforms.num_heads; - let sink_value = f32(head_sink[head_idx]); - g_max = max(g_max, sink_value); -#endif - var g_sum = f32(0); - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - g_max = max(g_max, metadata[meta_offset].x); - } - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(m_value.x - g_max) * m_value.y; - } -#if has_head_sink - g_sum += exp(sink_value - g_max); -#endif - - if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); - } - - for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) { - var value = present_value_value_t(0); - qkv_values[local_row][local_col] = present_value_value_t(0); - workgroupBarrier(); - - if (k + local_col < head_size_vec) 
{ - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row]; - } - } - } - - qkv_values[local_row][local_col] = value; - workgroupBarrier(); - - if (local_idx < tile_size_k_vec) { - for (var i = 0u; i < sub_tile_count; i++) { - tile_output[k + local_idx] += qkv_values[i][local_idx]; - } - } - workgroupBarrier(); - } - - for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) { - let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; - out_split_vx[out_offset] = tile_output[i]; - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index f909a87724da6..a3ce0b68cb659 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -1,31 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#param has_head_sink +#param m_tile #param seq_tile_size #param tile_size #param use_indirect_dispatch -// Inputs are splits of the GQA output, split into num_total_seq_length_tiles -// rows. This shader needs to add these splits across the row dimension to -// arrive at the final result. The column is head size wide. The reduction -// achieves maximum parallelization by splitting this task first into tile_size -// columns that each workgroup is responsible for. Then within each workgroup -// the task of summation over the num_total_seq_length_tile for the tile_size -// columns is further split in two ways. 
First across the row dimension to have -// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE -// rows. Then across the column dimension where each thread is responsible for 1 -// column of the TILE_SIZE columns the workgroup is responsible for. +#use .getByOffset .setByOffset + +// This shader reduces partial V outputs from the fused QKV shader. +// Each tile produced a locally-normalized V contribution. To get the +// correct global result, we rescale each tile's contribution using +// per-tile metadata (max, sum) with online softmax: +// +// global_max = max(local_max_i for all tiles) +// global_sum = sum(local_sum_i * exp(local_max_i - global_max)) +// output[h] = sum(partial_i[h] * exp(local_max_i - global_max)) / global_sum var tile_input: array, tile_size>; $MAIN { + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile] let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); + let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles)); if (batch_head_idx >= uniforms.batch_heads) { return; } - let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; - var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; #if use_indirect_dispatch @@ -35,23 +39,60 @@ $MAIN { let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif - if (head_size_offset + local_col < uniforms.head_size_vec) { - for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { - if (r + local_row < num_total_seq_length_tile) { - value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col]; + 
for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; + let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile; + + // Compute global max across all tiles + var g_max = f32(-3.4028234663852886e+38f); +#if has_head_sink + let head_idx_for_sink = batch_head_idx % uniforms.num_heads; + let sink_value = f32(head_sink[head_idx_for_sink]); + g_max = max(g_max, sink_value); +#endif + for (var i = 0u; i < num_total_seq_length_tile; i++) { + g_max = max(g_max, metadata.getByOffset(meta_base + i).x); + } + + // Compute global sum with rescaling + var g_sum = f32(0); + for (var i = 0u; i < num_total_seq_length_tile; i++) { + let m_value = metadata.getByOffset(meta_base + i); + g_sum += m_value.y * exp(m_value.x - g_max); + } +#if has_head_sink + g_sum += exp(sink_value - g_max); +#endif + + // Accumulate rescaled partial outputs + var value = output_value_t(0); + if (head_size_offset + local_col < uniforms.head_size_vec) { + for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { + if (r + local_row < num_total_seq_length_tile) { + let tile_meta = metadata.getByOffset(meta_base + r + local_row); + let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum; + value += input.getByOffset(in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col) * output_value_t(output_element_t(rescale_f32)); + } } } - } - tile_input[local_row][local_col] = value; - workgroupBarrier(); + tile_input[local_row][local_col] = value; + workgroupBarrier(); - if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { - value = output_value_t(0); - for (var i = 0u; i < tile_size; i++) { - value += tile_input[i][local_idx]; + if (local_idx < tile_size && head_size_offset + local_idx 
< uniforms.head_size_vec) { + value = output_value_t(0); + for (var i = 0u; i < tile_size; i++) { + value += tile_input[i][local_idx]; + } + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec + + head_size_offset + local_idx; + output.setByOffset(output_id, value); } - let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx; - output[output_id] = value; + workgroupBarrier(); } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index c64bdf45cdcf8..7b09a3a6af080 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -41,12 +41,9 @@ $MAIN { #endif #if prepare_indirect_dispatch - // Prepare indirect dispatch buffer for thread 0 if (global_idx == 0u) { let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; - indirect_buffer[0] = num_total_seq_length_tile; - indirect_buffer[1] = uniforms.num_heads; - indirect_buffer[2] = 1u; + normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 1054fd94ef423..f3f2f521ecab2 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1520,18 +1520,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::STRING, std::string("int")) 
.Attr("weights_prepacked", - "Only meaningful when quant_type='int'. Tri-state control over whether the " - "int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS " - "fpA_intB format expected by the runner. -1 (auto): let the execution provider " - "choose its own backward-compatible default; the CUDA EP treats auto as " - "prepacked. 1: the initializers are already prepacked (e.g. produced offline by " - "pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers " - "are raw, un-prepacked [E, N, K/pack] tensors as produced by " - "quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself " - "in PrePack(), matching the behaviour of MatMulNBits and removing the offline " - "pre-pack requirement from exporters. Defaults to -1 (auto) so each execution " - "provider can pick its own backward-compatible default rather than the schema " - "imposing one.", + "Only meaningful when quant_type='int'. Tri-state control over the layout of the " + "int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by " + "-1 and 1 are determined by the execution provider. 0: the initializers are raw, " + "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.", AttributeProto::INT, static_cast(-1)) .Input(0, diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index 732a5a855d65d..4260ef706befa 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/platform/device_discovery.h" +#include "core/platform/linux/npu_device_discovery.h" #include "core/platform/linux/pci_device_discovery.h" #include @@ -123,7 +124,7 @@ Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& gpu_devices_out) { std::vector gpu_sysfs_path_infos{}; @@ -315,6 +317,119 @@ Status GetGpuDevices(std::vector& gpu_devices_out) { } // namespace +namespace npu_device_discovery { + +Status DetectNpuSysfsPaths(const fs::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out) { + std::error_code error_code{}; + + const bool sysfs_accel_path_exists = fs::exists(sysfs_accel_path, error_code); + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Checking existence of accel sysfs path")); + + if (!sysfs_accel_path_exists) { + npu_sysfs_paths_out = {}; + return Status::OK(); + } + + const auto detect_accel_path = [](const fs::path& sysfs_path, size_t& accel_idx) -> bool { + const auto filename = sysfs_path.filename(); + const auto filename_str = std::string_view{filename.native()}; + + // Look for a filename matching "accelN". N is a number. 
+ constexpr std::string_view prefix = "accel"; + if (filename_str.find(prefix) != 0) { + return false; + } + + size_t parsed_accel_idx{}; + if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_accel_idx)) { + return false; + } + + accel_idx = parsed_accel_idx; + return true; + }; + + std::vector npu_sysfs_paths{}; + + auto dir_iterator = fs::directory_iterator{sysfs_accel_path, error_code}; + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Iterating over accel sysfs devices")); + + for (const auto& dir_item : dir_iterator) { + const auto& dir_item_path = dir_item.path(); + + if (size_t accel_idx{}; detect_accel_path(dir_item_path, accel_idx)) { + NpuSysfsPathInfo path_info{}; + path_info.accel_idx = accel_idx; + path_info.path = dir_item_path; + npu_sysfs_paths.emplace_back(std::move(path_info)); + } + } + + npu_sysfs_paths_out = std::move(npu_sysfs_paths); + return Status::OK(); +} + +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out) { + OrtHardwareDevice npu_device{}; + + const auto& sysfs_path = path_info.path; + + uint16_t vendor_id{}; + const auto vendor_id_path = sysfs_path / "device" / "vendor"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, vendor_id)); + npu_device.vendor_id = vendor_id; + + uint16_t device_id{}; + const auto device_id_path = sysfs_path / "device" / "device"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, device_id)); + npu_device.device_id = device_id; + + npu_device.metadata.Add("accel_idx", MakeString(path_info.accel_idx)); + + std::optional pci_bus_id; + ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id)); + if (pci_bus_id) { + npu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id)); + } + + npu_device.type = OrtHardwareDeviceType_NPU; + + npu_device_out = std::move(npu_device); + + return Status::OK(); +} + +} // namespace npu_device_discovery + +namespace { + +Status GetNpuDevices(std::vector& 
npu_devices_out) { + std::vector npu_sysfs_path_infos{}; + ORT_RETURN_IF_ERROR(npu_device_discovery::DetectNpuSysfsPaths(kSysfsAccelPath, npu_sysfs_path_infos)); + + std::vector npu_devices{}; + npu_devices.reserve(npu_sysfs_path_infos.size()); + + for (const auto& npu_sysfs_path_info : npu_sysfs_path_infos) { + OrtHardwareDevice npu_device{}; + if (auto status = npu_device_discovery::GetNpuDeviceFromSysfs(npu_sysfs_path_info, npu_device); !status.IsOK()) { + LOGS_DEFAULT(WARNING) << MakeString("Failed to detect devices under ", + npu_sysfs_path_info.path, + ": ", + status.ErrorMessage()); + continue; + } + npu_devices.emplace_back(std::move(npu_device)); + } + + npu_devices_out = std::move(npu_devices); + + return Status::OK(); +} +} // namespace + std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { std::unordered_set devices; @@ -334,7 +449,16 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } // get NPU devices - // TODO figure out how to discover these + { + std::vector npu_devices{}; + Status npu_device_discovery_status = GetNpuDevices(npu_devices); + if (npu_device_discovery_status.IsOK()) { + devices.insert(std::make_move_iterator(npu_devices.begin()), + std::make_move_iterator(npu_devices.end())); + } else { + LOGS_DEFAULT(WARNING) << "NPU device discovery failed: " << npu_device_discovery_status.ErrorMessage(); + } + } return devices; } diff --git a/onnxruntime/core/platform/linux/npu_device_discovery.h b/onnxruntime/core/platform/linux/npu_device_discovery.h new file mode 100644 index 0000000000000..e139d33988903 --- /dev/null +++ b/onnxruntime/core/platform/linux/npu_device_discovery.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This header exposes Linux NPU device discovery internals for testing. 
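Reviewer sketch (not part of the patch): `DetectNpuSysfsPaths` keeps only directory entries whose filename is `accel` followed by a number. The matching rule can be illustrated in isolation as below. `ParseAccelIndex` is a hypothetical stand-in for the lambda in the diff, and it assumes that only plain decimal digits are accepted; the exact acceptance rules of `TryParseStringWithClassicLocale` are not shown in this diff and may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper: returns true and sets idx when name is "accel" followed by
// one or more decimal digits. Overflow of very long digit strings is not handled.
bool ParseAccelIndex(const std::string& name, std::size_t& idx) {
  const std::string prefix = "accel";
  if (name.compare(0, prefix.size(), prefix) != 0) return false;  // must start with "accel"
  if (name.size() == prefix.size()) return false;                 // no index present

  std::size_t value = 0;
  for (std::size_t i = prefix.size(); i < name.size(); ++i) {
    const char c = name[i];
    if (c < '0' || c > '9') return false;  // reject anything that is not a digit
    value = value * 10 + static_cast<std::size_t>(c - '0');
  }
  idx = value;
  return true;
}
```

A unit test for the real function would more likely create a temporary directory containing entries such as `accel0` and pass its path to `DetectNpuSysfsPaths`, which is what the test-only header makes possible.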
+ +#pragma once + +#include +#include +#include + +#include "core/common/status.h" +#include "core/session/abi_devices.h" + +namespace onnxruntime { +namespace npu_device_discovery { + +struct NpuSysfsPathInfo { + size_t accel_idx; + std::filesystem::path path; +}; + +// Scans the given sysfs accel directory for NPU accel devices. +Status DetectNpuSysfsPaths(const std::filesystem::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out); + +// Reads vendor/device IDs and populates an OrtHardwareDevice from an accel sysfs path. +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out); + +} // namespace npu_device_discovery +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 2008f998384a3..20ec11bf7deb2 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -64,6 +64,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(ir_version); @@ -81,6 +82,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(execution_provider_ids); ORT_UNUSED_PARAMETER(hardware_device_types); ORT_UNUSED_PARAMETER(hardware_vendor_ids); + ORT_UNUSED_PARAMETER(ep_versions); ORT_UNUSED_PARAMETER(use_fp16); ORT_UNUSED_PARAMETER(captureState); } @@ -124,6 +126,15 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu ORT_UNUSED_PARAMETER(line); } +void Telemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const { + 
ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(status); + ORT_UNUSED_PARAMETER(ep_versions); + ORT_UNUSED_PARAMETER(ep_device_types); +} + void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { ORT_UNUSED_PARAMETER(session_id); @@ -139,6 +150,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -149,6 +161,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, ORT_UNUSED_PARAMETER(hardware_device_id); ORT_UNUSED_PARAMETER(hardware_vendor); ORT_UNUSED_PARAMETER(ep_vendor); + ORT_UNUSED_PARAMETER(ep_version); ORT_UNUSED_PARAMETER(assigned_node_count); ORT_UNUSED_PARAMETER(total_runs_since_last); ORT_UNUSED_PARAMETER(total_run_duration_since_last); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index c2dce68faa2dd..946e34ee35832 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,7 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const; virtual void LogCompileModelStart(uint32_t session_id, @@ -86,6 +87,10 @@ class Telemetry { virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; + virtual void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const; + virtual void LogRuntimePerf(uint32_t session_id, uint32_t 
total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const; @@ -100,6 +105,7 @@ class Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 04b9aaa0eb8ed..342b937ffb656 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -285,6 +285,7 @@ void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -318,6 +319,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio const std::string& loaded_from, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -373,7 +375,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -392,6 +395,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t 
ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { @@ -404,7 +408,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -423,6 +428,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -457,7 +463,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(input_source.c_str(), "inputSource"), TraceLoggingString(output_target.c_str(), "outputTarget"), @@ -466,6 +472,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, 
TraceLoggingBool(embed_ep_context, "embedEpContext"), TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"), TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -507,7 +514,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingHResult(hr, "hResult"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), @@ -516,6 +523,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #else TraceLoggingWrite(telemetry_provider_handle, @@ -525,7 +533,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), TraceLoggingUInt32(status.Category(), "errorCategory"), @@ -533,10 +541,35 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #endif } +void WindowsTelemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + 
const std::string& ep_versions, + const std::string& ep_device_types) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RuntimeInferenceError", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), + TraceLoggingString(ep_device_types.c_str(), "executionProviderDeviceTypes"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { if (global_register_count_ == 0 || enabled_ == false) @@ -559,11 +592,12 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -574,6 +608,7 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, 
uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -588,7 +623,8 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + // schemaVersion 1: added epVersion, runtimeVersion + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(ep_type.c_str(), "executionProviderType"), TraceLoggingString(hardware_device_type.c_str(), "hardwareDeviceType"), @@ -596,9 +632,11 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingUInt32(hardware_device_id, "hardwareDeviceId"), TraceLoggingString(hardware_vendor.c_str(), "hardwareVendor"), TraceLoggingString(ep_vendor.c_str(), "epVendor"), + TraceLoggingString(ep_version.c_str(), "epVersion"), TraceLoggingInt32(assigned_node_count, "assignedNodeCount"), TraceLoggingUInt32(total_runs_since_last, "totalRunsSinceLast"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDurationSinceLast"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -724,8 +762,9 @@ void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -799,8 +838,9 @@ void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), 
TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 8295d042d7ec9..46c262e1479d3 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,7 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const override; void LogCompileModelStart(uint32_t session_id, @@ -79,6 +80,10 @@ class WindowsTelemetry : public Telemetry { void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; + void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const override; + void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const override; @@ -89,6 +94,7 @@ class WindowsTelemetry : public Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override; diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index f462680e02587..f75c5fe12a6f4 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ 
b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -524,14 +524,15 @@ template static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool /*inverse*/) { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal const auto* signal = ctx->Input(0); const auto frame_step = signal::get_scalar_value_from_tensor(ctx->Input(1)); + ORT_RETURN_IF_NOT(frame_step > 0, "frame_step must be greater than zero."); const auto* window = ctx->Input(2); const auto* frame_length_tensor = ctx->Input(3); @@ -596,8 +597,11 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Run each dft of each batch as if it was a real-valued batch size 1 dft operation for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int64_t i = 0; i < n_dfts; i++) { - auto input_frame_begin = - signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components); + const auto frame_start = i * frame_step; + // Defensive check before creating a non-owning tensor view. n_dfts derivation should keep this in bounds. + ORT_RETURN_IF_NOT(frame_start <= signal_size - window_size, "STFT input frame is out of bounds."); + // signal_data is U*, so one increment advances one input sample, including both lanes for complex input. 
+ auto input_frame_begin = signal_data + (batch_idx * signal_size) + frame_start; auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + (i * dft_output_size * output_components); @@ -619,9 +623,9 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside Status STFT::Compute(OpKernelContext* ctx) const { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal shape diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc index d030b68abdb99..458c47bcb58cd 100644 --- a/onnxruntime/core/session/experimental_c_api.cc +++ b/onnxruntime/core/session/experimental_c_api.cc @@ -18,6 +18,14 @@ namespace OrtExperimentalApis { +// Forward declarations driven by the .inc file so the registration table below +// can take the address of every entry, including those defined in other +// translation units linked into onnxruntime_session. +#define ORT_EXPERIMENTAL_API(VER, RET, NAME, ...) \ + RET ORT_API_CALL NAME##_SinceV##VER(__VA_ARGS__) NO_EXCEPTION; +#include "onnxruntime_experimental_c_api.inc" +#undef ORT_EXPERIMENTAL_API + // Test-only experimental function that writes a known sentinel value. // Exists to exercise the experimental API mechanism end-to-end and to serve as a template for future experimental // functions. 
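Reviewer sketch (not part of the patch): the forward-declaration block above expands one macro over every entry in `onnxruntime_experimental_c_api.inc`, so that a table can later take the address of each `NAME_SinceV<VER>` function. The example below illustrates the same X-macro idea in isolation. It is only a sketch under stated assumptions: it substitutes a list macro for the `.inc` include so that it stays self-contained, it drops `ORT_API_CALL` and `NO_EXCEPTION`, and every `Demo*`/`DEMO_*` name, as well as the table layout, is hypothetical rather than taken from ONNX Runtime.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical entry list standing in for the .inc file (assumption: the real
// file invokes ORT_EXPERIMENTAL_API once per entry; names here are made up).
#define DEMO_EXPERIMENTAL_API_LIST(X)          \
  X(28, int, Demo_WriteSentinel, int* out)     \
  X(28, int, Demo_Add, int a, int b, int* out)

// Pass 1: emit a forward declaration for every entry, mirroring
// RET NAME##_SinceV##VER(__VA_ARGS__) from the patch.
#define DEMO_DECLARE(VER, RET, NAME, ...) RET NAME##_SinceV##VER(__VA_ARGS__);
DEMO_EXPERIMENTAL_API_LIST(DEMO_DECLARE)
#undef DEMO_DECLARE

// Type-erased function pointer, cast back to the real signature before calling.
using DemoFnPtr = void (*)();

struct DemoEntry {
  const char* name;
  DemoFnPtr fn;
};

// Pass 2: build a name -> address table from the same list. This is legal only
// because pass 1 already declared every function.
#define DEMO_REGISTER(VER, RET, NAME, ...) \
  {#NAME "_SinceV" #VER, reinterpret_cast<DemoFnPtr>(&NAME##_SinceV##VER)},
static const DemoEntry kDemoTable[] = {DEMO_EXPERIMENTAL_API_LIST(DEMO_REGISTER)};
#undef DEMO_REGISTER

// Definitions may live after the table, or in another translation unit.
int Demo_WriteSentinel_SinceV28(int* out) {
  *out = 42;
  return 0;
}

int Demo_Add_SinceV28(int a, int b, int* out) {
  *out = a + b;
  return 0;
}

// Linear lookup by versioned name; returns nullptr when the name is unknown.
DemoFnPtr DemoLookup(const char* name) {
  for (const auto& entry : kDemoTable) {
    if (std::strcmp(entry.name, name) == 0) return entry.fn;
  }
  return nullptr;
}
```

A caller would look up `"Demo_Add_SinceV28"`, cast the result back to `int (*)(int, int, int*)`, and invoke it. Keeping the list in one place means a new entry gets its declaration and its table slot from a single added line.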
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 73f3e77b93177..9db18fc4ac69e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -78,6 +78,7 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/user_logging_sink.h" @@ -871,7 +872,7 @@ InferenceSession::~InferenceSession() { telemetry_provider.LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } } @@ -2778,6 +2779,7 @@ common::Status InferenceSession::Initialize() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model_->MetaData(), telemetry_.event_name_, execution_providers_.GetIds(), telemetry_.ep_device_types_summary_, telemetry_.ep_device_vendor_ids_summary_, + telemetry_.ep_versions_summary_, model_has_fp16_inputs, false); // Emit one initial EpDeviceUsage event per (EP, device) pair with run counts of 0. 
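Reviewer sketch (not part of the patch): the `ep_versions_summary_` string passed to the session-creation event above is, per this change, a comma-separated list of `ep_type:ep_version` pairs, with the version left empty when the EP reports none (the header's example is `"QNNExecutionProvider:1.2.3,CPUExecutionProvider:"`). The standalone snippet below reproduces that formatting. The `DemoEpInfo` struct and `BuildEpVersionsSummary` function are hypothetical names for illustration only; the patch builds the string inline in `PopulateEpDeviceInfo`.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the two EpDeviceInfo fields used by the summary.
struct DemoEpInfo {
  std::string ep_type;
  std::string ep_version;  // empty when the EP exposes no version metadata
};

// Joins entries as "type:version" separated by commas, in input order.
std::string BuildEpVersionsSummary(const std::vector<DemoEpInfo>& entries) {
  std::ostringstream oss;
  bool first = true;
  for (const auto& entry : entries) {
    if (!first) {
      oss << ',';
    }
    first = false;
    oss << entry.ep_type << ':' << entry.ep_version;
  }
  return oss.str();
}
```

Because the EP type is written even when the version is empty, each comma-separated slot stays aligned by position with the device-type and vendor-id summaries.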
@@ -2787,7 +2789,7 @@ common::Status InferenceSession::Initialize() { env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, 0, 0); + ep_info.ep_version, ep_info.assigned_node_count, 0, 0); } LOGS(*session_logger_, INFO) << "Session successfully initialized."; @@ -3456,7 +3458,7 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } // reset counters @@ -3470,6 +3472,10 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, Telemetry::kRuntimePerfMaxInterval); } } + } else { + // Log runtime error with EP versions + env.GetTelemetryProvider().LogRuntimeInferenceError(session_id_, retval, telemetry_.ep_versions_summary_, + telemetry_.ep_device_types_summary_); } // log evaluation stop to trace logging provider @@ -4087,6 +4093,7 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { telemetry_.ep_device_info_.clear(); telemetry_.ep_device_types_summary_.clear(); telemetry_.ep_device_vendor_ids_summary_.clear(); + telemetry_.ep_versions_summary_.clear(); // First, count nodes assigned to each EP type after graph partitioning. 
// The graph node only carries the EP type string, so when a single EP targets @@ -4124,6 +4131,10 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { Telemetry::EpDeviceInfo entry; entry.ep_type = ep_type; entry.ep_vendor = ep_device->ep_vendor; + auto it = ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version); + if (it != ep_device->ep_metadata.Entries().end()) { + entry.ep_version = it->second; + } if (ep_device->device != nullptr) { entry.hardware_device_type = HardwareDeviceTypeToString(ep_device->device->type); entry.vendor_id = ep_device->device->vendor_id; @@ -4158,11 +4169,13 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { // by position against the existing executionProviderIds field. std::ostringstream types_oss; std::ostringstream vendor_ids_oss; + std::ostringstream versions_oss; bool first = true; for (const auto& entry : telemetry_.ep_device_info_) { if (!first) { types_oss << ','; vendor_ids_oss << ','; + versions_oss << ','; } first = false; types_oss << entry.hardware_device_type; @@ -4170,9 +4183,11 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { vendor_ids_oss << "0x" << std::hex << std::uppercase << std::setw(4) << std::setfill('0') << entry.vendor_id << std::dec << std::nouppercase << std::setfill(' '); + versions_oss << entry.ep_type << ':' << entry.ep_version; } telemetry_.ep_device_types_summary_ = types_oss.str(); telemetry_.ep_device_vendor_ids_summary_ = vendor_ids_oss.str(); + telemetry_.ep_versions_summary_ = versions_oss.str(); } #if !defined(ORT_MINIMAL_BUILD) @@ -4419,6 +4434,7 @@ void InferenceSession::LogAllSessions() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model->MetaData(), session->telemetry_.event_name_, session->execution_providers_.GetIds(), session->telemetry_.ep_device_types_summary_, 
session->telemetry_.ep_device_vendor_ids_summary_, + session->telemetry_.ep_versions_summary_, model_has_fp16_inputs, true); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 11c887f6cdc16..19c18652ab67a 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -1023,12 +1023,14 @@ class InferenceSession { uint32_t device_id = 0; // PCI device ID (0 when unavailable) std::string vendor; // e.g. "Qualcomm" std::string ep_vendor; // e.g. "Qualcomm" (from OrtEpDevice) + std::string ep_version; // e.g. "1.2.3" (from OrtEpFactory::GetVersion, empty when unavailable) int assigned_node_count = 0; // # graph nodes assigned to this EP type }; std::vector ep_device_info_; // Pre-formatted comma-separated summaries used to enrich SessionCreation. std::string ep_device_types_summary_; // "NPU,CPU" std::string ep_device_vendor_ids_summary_; // "0x5143,0x0000" + std::string ep_versions_summary_; // "QNNExecutionProvider:1.2.3,CPUExecutionProvider:" } telemetry_; mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data diff --git a/onnxruntime/core/session/model_package_api.cc b/onnxruntime/core/session/model_package_api.cc index 27abb0f5f7a37..5fd25f9511eb9 100644 --- a/onnxruntime/core/session/model_package_api.cc +++ b/onnxruntime/core/session/model_package_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/model_package_api.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/common/common.h" #include "core/framework/error_code_helper.h" @@ -23,7 +23,9 @@ using namespace onnxruntime; return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, \ "Model package API is not supported in this build") -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, +namespace OrtExperimentalApis { + +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28, _Frees_ptr_opt_ OrtModelPackageOptions* options) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(options); @@ -32,7 +34,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelPackageOptions** out) { @@ -54,7 +56,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOpti API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageContext* ctx) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(ctx); @@ -63,7 +65,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageContext_SinceV28, _In_ const ORTCHAR_T* package_root, _Outptr_ OrtModelPackageContext** out) { API_IMPL_BEGIN @@ -89,7 +91,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ size_t* out_count) 
{ API_IMPL_BEGIN @@ -107,7 +109,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, _Out_ size_t* out_count) { @@ -136,7 +138,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Out_ size_t* out_count) { @@ -158,7 +160,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, @@ -190,7 +192,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, +ORT_API_STATUS_IMPL(OrtModelPackageApi_SelectComponent_SinceV28, _In_ const OrtModelPackageContext* context, _In_ const char* component_name, _In_ const OrtModelPackageOptions* options, @@ -235,7 +237,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageComponentContext* cix) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(cix); @@ -244,7 +246,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, #endif 
} -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const ORTCHAR_T** folder_path) { API_IMPL_BEGIN @@ -269,7 +271,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateSession_SinceV28, _In_ const OrtEnv* env, _In_ OrtModelPackageComponentContext* ctx, _In_opt_ const OrtSessionOptions* session_options, @@ -383,7 +385,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, // ---------- API table ------------------------------------------------------ -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _In_ const char* variant_name, @@ -417,7 +419,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ int64_t* out_version) { API_IMPL_BEGIN @@ -437,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const char** out_name) { API_IMPL_BEGIN @@ -461,40 +463,4 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -// ---------- API table dispatch --------------------------------------------- - 
-static constexpr OrtModelPackageApi ort_model_package_api = { - // Options - &OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, - &OrtModelPackageAPI::ReleaseModelPackageOptions, - - // Context - &OrtModelPackageAPI::CreateModelPackageContext, - &OrtModelPackageAPI::ReleaseModelPackageContext, - - // Package-level queries - &OrtModelPackageAPI::ModelPackage_GetSchemaVersion, - &OrtModelPackageAPI::ModelPackage_GetComponentCount, - &OrtModelPackageAPI::ModelPackage_GetComponentNames, - &OrtModelPackageAPI::ModelPackage_GetVariantCount, - &OrtModelPackageAPI::ModelPackage_GetVariantNames, - &OrtModelPackageAPI::ModelPackage_GetVariantEpName, - - // Variant selection and component queries - &OrtModelPackageAPI::SelectComponent, - &OrtModelPackageAPI::ReleaseModelPackageComponentContext, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, - - // Session - &OrtModelPackageAPI::CreateSession, - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - -static_assert(offsetof(OrtModelPackageApi, CreateSession) / sizeof(void*) == 14, - "Size of initial OrtModelPackageApi cannot change"); - -ORT_API(const OrtModelPackageApi*, OrtModelPackageAPI::GetModelPackageApi) { - return &ort_model_package_api; -} +} // namespace OrtExperimentalApis diff --git a/onnxruntime/core/session/model_package_api.h b/onnxruntime/core/session/model_package_api.h deleted file mode 100644 index 435e3a521c24c..0000000000000 --- a/onnxruntime/core/session/model_package_api.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/session/onnxruntime_c_api.h" - -namespace OrtModelPackageAPI { - -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - -ORT_API(void, ReleaseModelPackageOptions, _Frees_ptr_opt_ OrtModelPackageOptions*); -ORT_API_STATUS_IMPL(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - -ORT_API(void, ReleaseModelPackageContext, _Frees_ptr_opt_ OrtModelPackageContext*); -ORT_API_STATUS_IMPL(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - -ORT_API(void, ReleaseModelPackageComponentContext, - _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - -ORT_API_STATUS_IMPL(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* ctx, - _In_opt_ const 
OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - -ORT_API_STATUS_IMPL(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - -} // namespace OrtModelPackageAPI diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 61a413d92e7fc..f451eaa401497 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -57,7 +57,6 @@ #include "core/session/ort_env.h" #include "core/session/ort_version_check.h" #include "core/session/utils.h" -#include "core/session/model_package_api.h" #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" @@ -3689,10 +3688,6 @@ ORT_API(const OrtCompileApi*, OrtApis::GetCompileApi) { return OrtCompileAPI::GetCompileApi(); } -ORT_API(const OrtModelPackageApi*, OrtApis::GetModelPackageApi) { - return OrtModelPackageAPI::GetModelPackageApi(); -} - ORT_API(void, OrtApis::CreateKeyValuePairs, _Outptr_ OrtKeyValuePairs** out) { auto kvps = std::make_unique(); *out = reinterpret_cast(kvps.release()); @@ -4911,7 +4906,6 @@ static constexpr OrtApi ort_api_1_to_28 = { &OrtApis::GetMemPatternEnabled, &OrtApis::GetSessionExecutionMode, &OrtApis::SessionReleaseCapturedGraph, - &OrtApis::GetModelPackageApi, // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetExperimentalFunction, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 250a2853a4777..61ece2dd9a682 100644 --- 
a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -825,9 +825,6 @@ ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtVal _Outptr_result_maybenull_ const int64_t** shape_data, _Out_ size_t* shape_data_count); -// Model Package API -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - // Experimental API ORT_API(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5b1d590a06234..7220153b4fa17 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -16,7 +16,6 @@ #endif #include #include -#include #include namespace pybind11 { @@ -252,17 +251,8 @@ py::array_t PackWeightsForMixedGemm( cudaDeviceProp device_prop; ThrowIfCudaError(cudaGetDeviceProperties(&device_prop, device_id), "cudaGetDeviceProperties"); sm = device_prop.major * 10 + device_prop.minor; - } else { - // Validate force_arch against the SM versions for which preprocess_weights_for_mixed_gemm_cuda - // has tile/permutation tables. Unknown SMs would silently produce incorrect weight layouts. - static const std::set kSupportedSm = {75, 80, 90}; - if (kSupportedSm.find(sm) == kSupportedSm.end()) { - std::ostringstream oss; - oss << "force_arch=" << sm << " is not a supported SM version. 
" - << "Pass -1 for auto-detect, or one of: 75, 80, 90 (arch > 90 will fallback to 80)."; - throw std::invalid_argument(oss.str()); - } } + sm = ::onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(sm); auto permutation_map_buffer = make_cuda_ptr(32 * sizeof(int32_t)); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 77200ea84778e..2044a128d9540 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -3247,195 +3247,6 @@ including arg name, arg type (contains both type and shape).)pbdoc") #endif }, R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); - - // --- Model Package API --- -#if !defined(ORT_MINIMAL_BUILD) - // Helper to create a PyInferenceSession from a pre-initialized OrtSession* (C API handle). - // PyInferenceSession's owning ctor is protected; this subclass provides access. - struct PyModelPackageSession : PyInferenceSession { - PyModelPackageSession(std::unique_ptr sess) - : PyInferenceSession(std::move(sess)) {} - }; - - // Wrapper classes to manage opaque C handles with proper RAII - struct PyModelPackageContext { - OrtModelPackageContext* ctx_{nullptr}; - PyModelPackageContext(const std::string& package_path) { - auto path = ToPathString(package_path); - const auto* api = Ort::GetApi().GetModelPackageApi(); - Ort::ThrowOnError(api->CreateModelPackageContext(path.c_str(), &ctx_)); - } - ~PyModelPackageContext() { - if (ctx_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageContext(ctx_); - } - } - PyModelPackageContext(const PyModelPackageContext&) = delete; - PyModelPackageContext& operator=(const PyModelPackageContext&) = delete; - }; - - struct PyModelPackageComponentContext { - OrtModelPackageComponentContext* ctx_{nullptr}; - ~PyModelPackageComponentContext() { - if (ctx_) { - const auto* api = 
Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageComponentContext(ctx_); - } - } - PyModelPackageComponentContext(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext& operator=(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext() = default; - }; - - struct PyModelPackageOptions { - OrtModelPackageOptions* opts_{nullptr}; - ~PyModelPackageOptions() { - if (opts_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageOptions(opts_); - } - } - PyModelPackageOptions(const PyModelPackageOptions&) = delete; - PyModelPackageOptions& operator=(const PyModelPackageOptions&) = delete; - PyModelPackageOptions() = default; - }; - - py::class_(m, "ModelPackageContext", - R"pbdoc(Represents an opened model package for inspection and component selection.)pbdoc") - .def(py::init(), py::arg("package_path"), - R"pbdoc(Open a model package from the given directory path.)pbdoc") - .def( - "get_component_names", - [](PyModelPackageContext& self) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetComponentNames(self.ctx_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - R"pbdoc(Get the names of all components in the package.)pbdoc") - .def( - "get_variant_names", - [](PyModelPackageContext& self, const std::string& component_name) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetVariantNames( - self.ctx_, component_name.c_str(), &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - py::arg("component_name"), - R"pbdoc(Get 
the variant names for a given component.)pbdoc") - .def( - "get_variant_ep_name", - [](PyModelPackageContext& self, const std::string& component_name, - const std::string& variant_name) -> std::optional { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* ep = nullptr; - Ort::ThrowOnError(api->ModelPackage_GetVariantEpName( - self.ctx_, component_name.c_str(), variant_name.c_str(), &ep)); - if (ep) return std::string(ep); - return std::nullopt; - }, - py::arg("component_name"), py::arg("variant_name"), - R"pbdoc(Get the EP name for a variant. Returns None if not declared.)pbdoc") - .def( - "get_schema_version", - [](PyModelPackageContext& self) -> int64_t { - const auto* api = Ort::GetApi().GetModelPackageApi(); - int64_t version = 0; - Ort::ThrowOnError(api->ModelPackage_GetSchemaVersion(self.ctx_, &version)); - return version; - }, - R"pbdoc(Get the schema version declared in the model package manifest.)pbdoc") - .def( - "select_component", - [](PyModelPackageContext& self, const std::string& component_name, - PyModelPackageOptions& options) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->SelectComponent( - self.ctx_, component_name.c_str(), options.opts_, &result->ctx_)); - return result; - }, - py::arg("component_name"), py::arg("options"), - R"pbdoc(Select a component and resolve its variant based on the provided options. -Returns a ModelPackageComponentContext for inspecting the selected variant.)pbdoc"); - - py::class_(m, "ModelPackageOptions", - R"pbdoc(Options used for variant selection in a model package. 
-Created from a SessionOptions to capture EP configuration for variant matching.)pbdoc") - .def(py::init([](PySessionOptions& session_options) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->CreateModelPackageOptionsFromSessionOptions( - GetOrtEnv(), &session_options, &result->opts_)); - return result; - }), - py::arg("session_options"), - R"pbdoc(Create model package options from a SessionOptions instance. -The EP configured on the session options is used for variant selection.)pbdoc"); - - py::class_(m, "ModelPackageComponentContext", - R"pbdoc(Represents a selected component within a model package. -Provides access to the resolved variant's files, session options, and metadata.)pbdoc") - .def( - "get_selected_variant_folder_path", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const ORTCHAR_T* path = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantFolderPath(self.ctx_, &path)); - return PathToUTF8String(PathString(path)); - }, - R"pbdoc(Get the folder path of the selected variant.)pbdoc") - .def( - "get_selected_variant_name", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* name = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantName( - self.ctx_, &name)); - return name ? 
std::string(name) : std::string(); - }, - R"pbdoc(Get the name of the selected variant.)pbdoc") - .def( - "create_session", - [](PyModelPackageComponentContext& self, py::object session_options_obj) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - OrtSession* ort_session = nullptr; - if (session_options_obj.is_none()) { - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, nullptr, &ort_session)); - } else { - auto& so = session_options_obj.cast(); - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, &so, &ort_session)); - } - // OrtSession* is a reinterpret_cast of InferenceSession* - auto* inference_session = reinterpret_cast(ort_session); - std::unique_ptr session_ptr(inference_session); - return std::make_unique(std::move(session_ptr)); - }, - py::arg("session_options") = py::none(), - R"pbdoc(Create an InferenceSession from the selected component variant. - -Args: - session_options: Optional SessionOptions override. If None, uses the options - captured during variant selection with per-file options merged on top. - If provided, variant-specific options are NOT applied. 
- -Returns: - An InferenceSession ready for inference.)pbdoc"); -#endif // !defined(ORT_MINIMAL_BUILD) } bool InitArray() { diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index 6fb3f8e6ba82f..34f73eb69b149 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -12,6 +12,7 @@ #include "gtest/gtest.h" #include "core/session/model_package/model_package_context.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/abi_devices.h" #include "test/autoep/test_autoep_utils.h" #include "test/util/include/asserts.h" @@ -22,6 +23,90 @@ extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { namespace { + +// Typed function pointers for every OrtModelPackageApi_* experimental entry, +// resolved once via the experimental name-based lookup. +struct ModelPackageFns { + OrtExperimental_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn + CreateModelPackageOptionsFromSessionOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn + ReleaseModelPackageOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn + CreateModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn + ReleaseModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn + ModelPackage_GetSchemaVersion{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn + ModelPackage_GetComponentCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn + ModelPackage_GetComponentNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn + ModelPackage_GetVariantCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn + 
ModelPackage_GetVariantNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn + ModelPackage_GetVariantEpName{nullptr}; + OrtExperimental_OrtModelPackageApi_SelectComponent_SinceV28_Fn + SelectComponent{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn + ReleaseModelPackageComponentContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantName{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantFolderPath{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateSession_SinceV28_Fn + CreateSession{nullptr}; +}; + +inline const ModelPackageFns& GetModelPackageFns() { + static const ModelPackageFns fns = []() { + const OrtApi* api = &Ort::GetApi(); + ModelPackageFns f; +#define RESOLVE(member, getter) \ + do { \ + f.member = Ort::Experimental::getter(api); \ + if (f.member == nullptr) { \ + throw std::runtime_error(std::string("Failed to resolve experimental " \ + "OrtModelPackageApi_") + \ + #member "_SinceV28"); \ + } \ + } while (0) + RESOLVE(CreateModelPackageOptionsFromSessionOptions, + Get_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn); + RESOLVE(ReleaseModelPackageOptions, + Get_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn); + RESOLVE(CreateModelPackageContext, + Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn); + RESOLVE(ReleaseModelPackageContext, + Get_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn); + RESOLVE(ModelPackage_GetSchemaVersion, + Get_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentCount, + Get_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentNames, + 
Get_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantCount, + Get_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantNames, + Get_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantEpName, + Get_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn); + RESOLVE(SelectComponent, + Get_OrtModelPackageApi_SelectComponent_SinceV28_Fn); + RESOLVE(ReleaseModelPackageComponentContext, + Get_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantName, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantFolderPath, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn); + RESOLVE(CreateSession, + Get_OrtModelPackageApi_CreateSession_SinceV28_Fn); +#undef RESOLVE + return f; + }(); + return fns; +} // ------------------------------------------------------------------ // Helpers to build a test model package on disk // ------------------------------------------------------------------ @@ -233,26 +318,25 @@ std::filesystem::path CreateModelPackageApiTestPackage(bool multi_file_variant = TEST(ModelPackageApiTest, PackageContextQueries) { const auto package_root = CreateModelPackageApiTestPackage(); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr model_pkg_context(nullptr, context_deleter); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), 
&raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); // Query: component count + names size_t component_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); ASSERT_EQ(component_count, 1u); const char* const* component_names = nullptr; size_t component_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentNames( model_pkg_context.get(), &component_names, &component_name_count)); ASSERT_EQ(component_name_count, 1u); ASSERT_NE(component_names, nullptr); @@ -261,13 +345,13 @@ TEST(ModelPackageApiTest, PackageContextQueries) { // Query: variant count + names size_t variant_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantCount( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantCount( model_pkg_context.get(), "model_1", &variant_count)); ASSERT_EQ(variant_count, 2u); const char* const* variant_names = nullptr; size_t variant_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantNames( model_pkg_context.get(), "model_1", &variant_names, &variant_name_count)); ASSERT_EQ(variant_name_count, 2u); @@ -294,17 +378,16 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { - if (p) pkg_api->ReleaseModelPackageOptions(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { + if 
(p) pkg_api.ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr model_pkg_options(nullptr, options_deleter); @@ -312,25 +395,25 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unique_ptr component_context(nullptr, component_context_deleter); OrtModelPackageOptions* raw_options = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); model_pkg_options.reset(raw_options); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); OrtModelPackageComponentContext* raw_component_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(model_pkg_context.get(), - "model_1", - model_pkg_options.get(), - &raw_component_context)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(model_pkg_context.get(), + "model_1", + model_pkg_options.get(), + &raw_component_context)); component_context.reset(raw_component_context); OrtSession* raw_session = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateSession(*ort_env, - component_context.get(), - session_options, - &raw_session)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateSession(*ort_env, + 
component_context.get(), + session_options, + &raw_session)); Ort::Session session(raw_session); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -921,33 +1004,32 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { os << R"({"filename":"mul_1.onnx"})"; } - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(nullptr, context_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); // variant_1 targets example_ep const char* ep1 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", &ep1)); ASSERT_NE(ep1, nullptr); EXPECT_STREQ(ep1, "example_ep"); // variant_2 targets other_ep const char* ep2 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_2", &ep2)); ASSERT_NE(ep2, nullptr); EXPECT_STREQ(ep2, "other_ep"); // Optional out-parameter: callers can pass NULL. 
- ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", nullptr)); std::filesystem::remove_all(package_root, ec); @@ -990,32 +1072,31 @@ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) { std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), 
&raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); // Path looks like .../models/model_1/ -- the folder name is the variant. @@ -1082,28 +1163,27 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); 
OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); // CreateSession iterates the per-file session_options and dispatches each through OrtApis::AddSessionConfigEntry. @@ -1111,7 +1191,7 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn // Pass nullptr for session_options so the metadata-merge path runs (it is skipped when the caller // supplies their own session_options). OrtSession* raw_session = nullptr; - OrtStatus* st = pkg_api->CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); + OrtStatus* st = pkg_api.CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); // Clean up session first to avoid leaks if assertion fails. if (raw_session != nullptr) { Ort::GetApi().ReleaseSession(raw_session); @@ -1131,60 +1211,6 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::filesystem::remove_all(package_root, ec); } -// Test that the C++ RAII wrappers (Ort::ModelPackageContext, etc.) work correctly. 
-TEST(ModelPackageApiTest, CxxWrappers_PackageContextQueries) { - const auto package_root = CreateModelPackageApiTestPackage(); - - Ort::ModelPackageContext ctx(package_root.c_str()); - - // Component queries - EXPECT_EQ(ctx.GetComponentCount(), 1u); - auto component_names = ctx.GetComponentNames(); - ASSERT_EQ(component_names.size(), 1u); - EXPECT_EQ(component_names[0], "model_1"); - - // Variant queries - EXPECT_EQ(ctx.GetVariantCount("model_1"), 2u); - auto variant_names = ctx.GetVariantNames("model_1"); - ASSERT_EQ(variant_names.size(), 2u); - std::unordered_set variant_set(variant_names.begin(), variant_names.end()); - EXPECT_EQ(variant_set.count("variant_1"), 1u); - EXPECT_EQ(variant_set.count("variant_2"), 1u); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageApiTest, CxxWrappers_SelectComponentAndQueryFileAccessors) { - const auto package_root = CreateModelPackageApiTestPackage(); - - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - - Ort::SessionOptions so; - std::unordered_map ep_options; - so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - - Ort::ModelPackageContext ctx(package_root.c_str()); - auto cix = ctx.SelectComponent("model_1", pkg_opts); - - // Folder path should be non-empty - auto folder = cix.GetSelectedVariantFolderPath(); - EXPECT_FALSE(folder.empty()); - - // Selected variant name should not throw - auto variant_name = cix.GetSelectedVariantName(); - EXPECT_FALSE(variant_name.empty()); - - // CreateSession via C++ wrapper - auto session = cix.CreateSession(*ort_env, so); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - // ------------------------------------------------------------------ // Test: GetSelectedVariantFolderPath returns 
correct path even when variant.json is absent. // ------------------------------------------------------------------ @@ -1221,31 +1247,29 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { Ort::SessionOptions so; std::unordered_map ep_options; so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; std::unique_ptr mp_opts(raw_mp_opts, options_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(raw_ctx, context_deleter); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + auto 
component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr comp_ctx(raw_comp_ctx, component_context_deleter); // GetSelectedVariantFolderPath should return the variant directory even without variant.json. const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); const auto result_path = std::filesystem::path(selected_folder); diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index cc50494273aad..38cc3014aa873 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" @@ -778,6 +779,74 @@ TEST(MoETest, MoETest_Mixtral) { 2 /*top_k*/); } +TEST(MoETest, QMoETest_CUDA_Int4_DisablePrepackingFailsLoudly) { + constexpr int min_cuda_arch = 700; + if (!HasCudaEnvironment(min_cuda_arch)) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + constexpr int64_t num_rows = 1; + constexpr int64_t num_experts = 1; + constexpr int64_t hidden_size = 128; + constexpr int64_t inter_size = 128; + constexpr int64_t expert_weight_bits = 4; + constexpr int64_t pack_size = 8 / expert_weight_bits; + + const std::vector input(num_rows * hidden_size, 0.0f); + const std::vector router_probs(num_rows * num_experts, 1.0f); + const std::vector 
fc1_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0); + const std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0); + const std::vector fc1_scales(num_experts * inter_size, 1.0f); + const std::vector fc2_scales(num_experts * hidden_size, 1.0f); + const std::vector dummy_output(num_rows * hidden_size, 0.0f); + + OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain); + cuda_tester.AddAttribute("k", 1); + cuda_tester.AddAttribute("activation_type", "identity"); + cuda_tester.AddAttribute("normalize_routing_weights", 1); + cuda_tester.AddAttribute("expert_weight_bits", expert_weight_bits); + cuda_tester.AddAttribute("quant_type", "int"); + cuda_tester.AddAttribute("weights_prepacked", 0); + + const std::vector input_dims = {num_rows, hidden_size}; + const std::vector router_probs_dims = {num_rows, num_experts}; + const std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + const std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + const std::vector fc1_scales_dims = {num_experts, inter_size}; + const std::vector fc2_scales_dims = {num_experts, hidden_size}; + const std::vector output_dims = {num_rows, hidden_size}; + + cuda_tester.AddInput("input", input_dims, ToFloat16(input)); + cuda_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cuda_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cuda_tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cuda_tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOutput("output", output_dims, 
ToFloat16(dummy_output)); + + SessionOptions session_options; + session_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "1"; + + std::vector> cuda_execution_providers; + cuda_execution_providers.push_back(std::move(cuda_ep)); + cuda_tester.Run(session_options, + OpTester::ExpectResult::kExpectFailure, + "QMoE weights_prepacked=0 requires PrePack to run", + {}, + nullptr, + &cuda_execution_providers); +} + TEST(MoETest, QMoETest_Mixtral_Int4) { // This test uses FC3 (gated SiLU / Mixtral pattern) with dimensions too small for the // CUTLASS kernel (needs hidden_size >= 128, inter_size >= 128). CPU QMoE does not diff --git a/onnxruntime/test/platform/linux/npu_device_discovery_test.cc b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc new file mode 100644 index 0000000000000..6e2170146f39f --- /dev/null +++ b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/platform/linux/npu_device_discovery.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "test/util/include/asserts.h" + +namespace fs = std::filesystem; + +namespace onnxruntime::test { +namespace { + +void WriteFile(const fs::path& path, const std::string& value) { + std::ofstream f(path); + f << value; +} + +class NpuDeviceDiscoveryTest : public ::testing::Test { + protected: + void SetUp() override { + temp_dir_ = fs::temp_directory_path() / "ort_npu_discovery_test"; + fs::remove_all(temp_dir_); + fs::create_directories(temp_dir_); + } + + void TearDown() override { + fs::remove_all(temp_dir_); + } + + fs::path temp_dir_; +}; + +} // namespace + +TEST_F(NpuDeviceDiscoveryTest, ReturnsEmptyForNonexistentPath) { + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_ / "nonexistent", npu_paths)); + EXPECT_TRUE(npu_paths.empty()); +} + +TEST_F(NpuDeviceDiscoveryTest, DetectsAccelDevices) { + fs::create_directories(temp_dir_ / "accel0"); + fs::create_directories(temp_dir_ / "accel12"); + fs::create_directories(temp_dir_ / "renderD128"); + fs::create_directories(temp_dir_ / "accelabc"); + + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_, npu_paths)); + + ASSERT_EQ(npu_paths.size(), 2u); + + std::vector accel_indices; + accel_indices.reserve(npu_paths.size()); + for (const auto& npu_path : npu_paths) { + accel_indices.push_back(npu_path.accel_idx); + } + + std::sort(accel_indices.begin(), accel_indices.end()); + + EXPECT_EQ(accel_indices[0], 0u); + EXPECT_EQ(accel_indices[1], 12u); +} + +TEST_F(NpuDeviceDiscoveryTest, GetNpuDeviceFromSysfsReadsVendorDeviceAndMetadata) { + const auto pci_device_dir = temp_dir_ / "pci_devices" / "0000:65:00.0"; + const auto accel_dir = temp_dir_ / "class_accel" / "accel0"; + + fs::create_directories(pci_device_dir); + fs::create_directories(accel_dir); + + WriteFile(pci_device_dir / "vendor", 
"0x1022"); + WriteFile(pci_device_dir / "device", "0x1502"); + + std::error_code error_code{}; + fs::create_directory_symlink(pci_device_dir, accel_dir / "device", error_code); + ASSERT_FALSE(error_code) << error_code.message(); + + npu_device_discovery::NpuSysfsPathInfo path_info{}; + path_info.accel_idx = 0; + path_info.path = accel_dir; + + OrtHardwareDevice npu_device{}; + ASSERT_STATUS_OK(npu_device_discovery::GetNpuDeviceFromSysfs(path_info, npu_device)); + + EXPECT_EQ(npu_device.type, OrtHardwareDeviceType_NPU); + EXPECT_EQ(npu_device.vendor_id, 0x1022u); + EXPECT_EQ(npu_device.device_id, 0x1502u); + + const auto& entries = npu_device.metadata.Entries(); + EXPECT_NE(entries.find("accel_idx"), entries.end()); + EXPECT_EQ(entries.at("accel_idx"), "0"); + EXPECT_NE(entries.find("pci_bus_id"), entries.end()); + EXPECT_EQ(entries.at("pci_bus_id"), "0000:65:00.0"); +} + +} // namespace onnxruntime::test diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index 966090eed861e..c87e5d7c4c6e1 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -233,6 +233,80 @@ TEST(SignalOpsTest, STFTFloat) { test.Run(); } +static void TestSTFTInvalidFrameStep(int64_t frame_step) { + OpTester test("STFT", kMinOpsetVersion); + + vector signal(64, 1); + test.AddInput("signal", {1, 64, 1}, signal); + test.AddInput("frame_step", {}, {frame_step}); + vector window(16, 1); + test.AddInput("window", {16}, window); + test.AddInput("frame_length", {}, {16}); + + vector output_shape = {1, 7, 9, 2}; + vector expected_output(1 * 7 * 9 * 2, 0.f); + test.AddOutput("output", output_shape, expected_output); + test.Config(OpTester::ExpectResult::kExpectFailure, "frame_step must be greater than zero"); + test.ConfigExcludeEps({kDmlExecutionProvider}); + test.RunWithConfig(); +} + +TEST(SignalOpsTest, STFTFrameStepMustBePositive) { + 
TestSTFTInvalidFrameStep(0);
+  TestSTFTInvalidFrameStep(-1);
+}
+
+template
+static void TestSTFTComplexInputBatched() {
+  OpTester test("STFT", kMinOpsetVersion);
+  test.AddAttribute("onesided", static_cast(false));
+
+  constexpr int64_t batch_size = 2;
+  constexpr int64_t signal_size = 128;
+  constexpr int64_t signal_components = 2;
+  constexpr int64_t frame_length = 32;
+  constexpr int64_t frame_step = 16;
+  constexpr int64_t n_dfts = 7;
+  constexpr int64_t dft_output_size = frame_length;
+  constexpr int64_t output_components = 2;
+
+  vector signal(batch_size * signal_size * signal_components, static_cast(0));
+  for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+    const T signal_value = batch_idx == 0 ? static_cast(1) : static_cast(99);
+    for (int64_t sample_idx = 0; sample_idx < signal_size; ++sample_idx) {
+      signal[(batch_idx * signal_size + sample_idx) * signal_components] = signal_value;
+    }
+  }
+
+  test.AddInput("signal", {batch_size, signal_size, signal_components}, signal);
+  test.AddInput("frame_step", {}, {frame_step});
+  test.AddOptionalInputEdge();
+  test.AddInput("frame_length", {}, {frame_length});
+
+  vector output_shape = {batch_size, n_dfts, dft_output_size, output_components};
+  vector expected_output(batch_size * n_dfts * dft_output_size * output_components, static_cast(0));
+  for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+    const T dc_value = (batch_idx == 0 ? static_cast(1) : static_cast(99)) * static_cast(frame_length);
+    for (int64_t frame_idx = 0; frame_idx < n_dfts; ++frame_idx) {
+      expected_output[((batch_idx * n_dfts + frame_idx) * dft_output_size) * output_components] = dc_value;
+    }
+  }
+
+  test.AddOutput("output", output_shape, expected_output);
+  test.SetOutputAbsErr("output", 0.001f);
+  // DML does not consistently match these CPU STFT validation/regression paths in Windows GPU CI.
+  test.ConfigExcludeEps({kDmlExecutionProvider});
+  test.RunWithConfig();
+}
+
+TEST(SignalOpsTest, STFTFloatComplexInputBatched) {
+  TestSTFTComplexInputBatched();
+}
+
+TEST(SignalOpsTest, STFTDoubleComplexInputBatched) {
+  TestSTFTComplexInputBatched();
+}
+
 TEST(SignalOpsTest, HannWindowFloat) {
   OpTester test("HannWindow", kMinOpsetVersion);
 
diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py
index c5fc826a5a6ed..9677542270a53 100644
--- a/onnxruntime/test/python/transformers/test_moe_cuda.py
+++ b/onnxruntime/test/python/transformers/test_moe_cuda.py
@@ -152,7 +152,10 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True):
         q_weight_reshaped = q_weight.reshape(n, -1)
 
         # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format), and qMoE kernel uses the same format.
-        processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4)
+        # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 --
+        # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1)
+        # would emit the non-interleaved SM90 layout on Hopper and produce wrong results.
+        processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4, 80)
 
         # So we need to DEQUANTIZE back to get `result`.
         # scale is [n, block_per_k]
@@ -232,8 +235,11 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True):
         )
         q_weight_reshaped = q_weight.reshape(n, -1)
 
-        # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format)
-        processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8)
+        # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format).
+        # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 --
+        # so it consumes the SM80 (column-interleaved) layout on every GPU.
Auto-detect (force_arch=-1)
+        # would emit the non-interleaved SM90 layout on Hopper and produce wrong results.
+        processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8, 80)
 
         # Dequantize for reference
         # (q - 128) * scale if using 128 offset? or (q) * scale if symmetric around 0?
@@ -1084,8 +1090,8 @@ def parity_check(self):
         ort_dtype_quant_bits_tolerance_map = {
             "FP32:0": (5e-3, 1e-3),
             "FP16:0": (0.3, 0.05),
-            "FP16:4": (3.0, 1e-2),
-            "FP16:8": (2.0, 1e-2),
+            "FP16:4": (0.5, 1e-2),
+            "FP16:8": (0.5, 1e-2),
             "BF16:0": (1.0, 1e-2),
             "BF16:4": (30.0, 1e-1),
             "BF16:8": (20.0, 1e-1),
diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py
index 993716a4c80b0..c56383d2851d3 100644
--- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py
+++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py
@@ -137,43 +137,6 @@ def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""):
     )
 
 
-def preprocess_weights_for_mixed_gemm(
-    tensor: torch.Tensor, quant_bits: int, sm: int = -1, do_weight_interleave: bool = True
-) -> torch.Tensor:
-    if len(tensor.shape) == 2:
-        tensor = tensor.unsqueeze(0)
-
-    # Input tensor shape is [Experts, n, k_packed]. k_packed is k/2 for 4-bit, k for 8-bit.
-    num_experts = tensor.shape[0]
-    n = tensor.shape[1]
-    k_packed = tensor.shape[2]
-    k = k_packed * 2 if quant_bits == 4 else k_packed
-
-    packed_list = []
-
-    if _pybind and hasattr(_pybind, "pack_weights_for_cuda_mixed_gemm") and torch.cuda.is_available():
-        for i in range(num_experts):
-            if tensor[i].dtype == torch.bfloat16:
-                weight = tensor[i].to(torch.float32).cpu().numpy()
-            else:
-                weight = tensor[i].cpu().numpy()
-            packed = _pybind.pack_weights_for_cuda_mixed_gemm(weight, n, k, quant_bits, sm)
-            # pack_weights_for_cuda_mixed_gemm returns int8 array of shape [packed_size]
-            # We need to reshape it to (k, n/2) for 4-bit, (k, n) for 8-bit.
-            output_rows = k
-            output_cols = n // 2 if quant_bits == 4 else n
-            packed_tensor = torch.from_numpy(packed).to(tensor.device)
-            packed_tensor = packed_tensor.view(torch.uint8).view(output_rows, output_cols)
-            packed_list.append(packed_tensor)
-
-        return torch.stack(packed_list)
-    else:
-        # This shall not happen unless older version of onnxruntime is used.
-        raise ImportError(
-            "onnxruntime._pybind_state.pack_weights_for_cuda_mixed_gemm not found. Cannot preprocess weights."
-        )
-
-
 def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True, asymmetric: bool = False):
     # DEBUG
     # print(f"DEBUG: quant_dequant input shape={weights.shape}, 4bit={is_4_bit_quantization}, asym={asymmetric}")
@@ -2110,7 +2073,7 @@ class TestQMoEIntPrePackSmoke(unittest.TestCase):
     hardware (the other ``test_swiglu_qmoe_parity_*`` cases in this file
     fail on H200 / H100 with max-diff > 1.0 on plain main, by inspection —
     pre-existing). A real parity check can be added once
-    that harness honours the runtime SM.
+    that harness honors the runtime SM.
     """
 
     def _run_one(self, *, hidden_size, inter_size, num_experts, top_k, swiglu_fusion, batch_size):