From 2c4be91361bb7377d7cb4db5599776685813ceae Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 10 Jun 2026 17:13:55 -0700 Subject: [PATCH 1/8] Move OrtModelPackageApi to the experimental C API (#28990) ### Description Move the existing model package C API off the stable `OrtApi` onto the experimental name-based lookup mechanism added in #28746. Each model package function is registered individually in `include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc` with the `OrtModelPackageApi_` prefix and the `_SinceV28` version suffix, following the lifecycle rules in `docs/design/Experimental_C_API.md`. Headline changes: - `OrtApi::GetModelPackageApi`, the `OrtModelPackageApi` struct, `OrtApis::GetModelPackageApi`, the `OrtModelPackageAPI` namespace, `onnxruntime/core/session/model_package_api.h`, and the C++ wrappers (`Ort::GetModelPackageApi`, `ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackage*)`, `ModelPackageOptions/Context/ComponentContext`) are removed. - Opaque handle types (`OrtModelPackageOptions`, `OrtModelPackageContext`, `OrtModelPackageComponentContext`) move into `onnxruntime_experimental_c_api.h`. - All 15 model package functions are registered in `onnxruntime_experimental_c_api.inc`. Impls move into `namespace OrtExperimentalApis` with `_SinceV28`-suffixed names in `model_package_api.cc`; bodies are unchanged. - `experimental_c_api.cc` gains a forward-decl block (driven by the same `.inc` X-macro) so the auto-generated registration table can take the address of every entry, even those defined in `model_package_api.cc`. - The Python bindings (`PyModelPackageContext` / `PyModelPackageOptions` / `PyModelPackageComponentContext` and their `onnxruntime.__init__` exports) are removed. Per the design doc we start the experimental API in C/C++ only. - `onnxruntime/test/autoep/test_model_package.cc` switches to a local `ModelPackageFns` struct populated through the `Ort::Experimental::Get_OrtModelPackageApi_*_Fn(api)` typed accessors. Consumer usage going forward, in C++: ```cpp #include "onnxruntime_c_api.h" #include "onnxruntime_experimental_c_api.h" const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); if (auto* fn = Ort::Experimental::Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn(ort)) { OrtModelPackageContext* ctx = nullptr; Ort::ThrowOnError(fn(ORT_TSTR("/path/to/pkg"), &ctx)); // ... } ``` ### Motivation and Context The model package API was added to the stable `OrtApi` in 1.27 but has not shipped in a release yet. Now that #28746 has landed the experimental C API framework, the right home for an iterating preview surface like model package is behind `OrtApi::GetExperimentalFunction`, not on the stable struct. Moving it to experimental: - frees us to change signatures (each name is uniquely versioned) without breaking the stable ABI; - gives consumers a clear "is this specific thing available?" contract instead of a struct that *looks* stable but isn't; - lets the surface be promoted to stable cleanly later (move entries to `OrtApi`, drop the `_SinceV` suffix, remove the experimental entries). --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 270 ------------------ .../core/session/onnxruntime_cxx_api.h | 84 ------ .../core/session/onnxruntime_cxx_inline.h | 104 ------- .../session/onnxruntime_experimental_c_api.h | 4 + .../onnxruntime_experimental_c_api.inc | 247 ++++++++++++++++ onnxruntime/__init__.py | 3 - .../core/session/experimental_c_api.cc | 8 + onnxruntime/core/session/model_package_api.cc | 72 ++--- onnxruntime/core/session/model_package_api.h | 76 ----- onnxruntime/core/session/onnxruntime_c_api.cc | 6 - onnxruntime/core/session/ort_apis.h | 3 - .../python/onnxruntime_pybind_state.cc | 189 ------------ onnxruntime/test/autoep/test_model_package.cc | 264 +++++++++-------- 13 files changed, 422 insertions(+), 908 deletions(-) delete mode 100644 onnxruntime/core/session/model_package_api.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0cdbfd7114cf0..0289bca5c84d2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -348,9 +348,6 @@ ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared exte ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails); ORT_RUNTIME_CLASS(EpAssignedSubgraph); ORT_RUNTIME_CLASS(EpAssignedNode); -ORT_RUNTIME_CLASS(ModelPackageOptions); -ORT_RUNTIME_CLASS(ModelPackageContext); -ORT_RUNTIME_CLASS(ModelPackageComponentContext); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -938,9 +935,6 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtInteropApi; typedef struct OrtInteropApi OrtInteropApi; -struct OrtModelPackageApi; -typedef struct OrtModelPackageApi OrtModelPackageApi; - struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -7501,26 +7495,6 @@ struct OrtApi { */ ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); - /** \brief Get the model package API table. - * - * Returns a pointer to the ::OrtModelPackageApi function table, which provides APIs to: - * - create and release model package options and contexts, - * - inspect model package metadata (components/variants), - * - select a component/variant and query selected files/options, - * - create a session from model package selection results. - * - * The returned pointer is owned by ONNX Runtime and is valid for the process lifetime. - * Do not free it. - * - * \note May return NULL if model package support is not available in the current build - * (for example, minimal builds). - * - * \return Pointer to ::OrtModelPackageApi, or NULL if unsupported. - * - * \since Version 1.27. - */ - const OrtModelPackageApi*(ORT_API_CALL* GetModelPackageApi)(void); - /** \brief Retrieve an experimental function pointer by name. * * Experimental functions are not part of the stable ABI and may be added or removed between releases without notice. @@ -8645,250 +8619,6 @@ struct OrtInteropApi { /// @} }; -/** \brief API table for model package workflows. - * - * A model package is a directory containing one or more *components* (logical models). - * Each component has one or more *variants*, where each variant targets a single - * execution provider (EP). The package manifest declares the EP name, device type, - * and an optional compatibility string for every variant so that the runtime can - * automatically select the best variant for the hardware and EPs available in the - * caller's session options. - * - * Obtain this table from OrtApi::GetModelPackageApi(). The APIs support: - * - creating model package options that capture EP configuration from OrtSessionOptions, - * - loading a package context (manifest + metadata) from a package root path, - * - querying component/variant metadata including per-variant EP information, - * - selecting a component (which also resolves the best-matching variant), - * - querying the selected variant's name and folder path, - * - creating an OrtSession from the selected component context. - * - * Typical flow: - * 1) Create model package options: - * - CreateModelPackageOptionsFromSessionOptions() - * 2) Load package metadata: - * - CreateModelPackageContext() - * 3) Query metadata (optional): - * - ModelPackage_GetSchemaVersion() - * - ModelPackage_GetComponentCount() - * - ModelPackage_GetComponentNames() - * - ModelPackage_GetVariantCount() - * - ModelPackage_GetVariantNames() - * - ModelPackage_GetVariantEpName() - * 4) Select a component and resolve variant: - * - SelectComponent() - * 5) Query selected variant info (optional): - * - ModelPackageComponent_GetSelectedVariantName() - * - ModelPackageComponent_GetSelectedVariantFolderPath() - * 6) Create session: - * - CreateSession() - * - * Ownership: - * - Release objects created by this API with the corresponding release methods: - * ReleaseModelPackageOptions(), ReleaseModelPackageContext(), - * ReleaseModelPackageComponentContext(). - * - * \since Version 1.27. - */ -struct OrtModelPackageApi { - /// \name OrtModelPackageOptions - /// @{ - - /** \brief Create model package options from an existing OrtSessionOptions. - * - * Captures EP configuration (registered execution providers and their devices) from - * the session options for use during variant selection. The resulting OrtModelPackageOptions - * is passed to SelectComponent() to resolve the best variant for the available EPs. - * - * \param[in] env The ORT environment. - * \param[in] session_options Session options containing registered EPs. - * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released - * with ReleaseModelPackageOptions(). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - - ORT_CLASS_RELEASE(ModelPackageOptions); - /// @} - /// \name OrtModelPackageContext - /// @{ - - /** \brief Create a model package context by parsing the package at the given root path. - * - * Parses the manifest.json and component metadata from the specified directory. - * The returned context provides read-only access to the package structure (components, - * variants, EP declarations). - * - * \param[in] package_root Path to the model package root directory (containing manifest.json). - * \param[out] out Receives the newly created OrtModelPackageContext. Must be released - * with ReleaseModelPackageContext(). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - - ORT_CLASS_RELEASE(ModelPackageContext); - - /** \brief Get the schema version declared in the model package manifest. - * - * \param[in] ctx The model package context. - * \param[out] out_version Receives the schema version number. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - - /** \brief Get the number of components in the model package. - * - * \param[in] ctx The model package context. - * \param[out] out_count Receives the component count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - - /** \brief Get the names of all components in the model package. - * - * Returns a pointer to an array of UTF-8 component name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. - * \param[out] out_names Receives a pointer to an array of component name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - - /** \brief Get the number of variants for a given component. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_count Receives the variant count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - - /** \brief Get the names of all variants for a given component. - * - * Returns a pointer to an array of UTF-8 variant name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_variant_names Receives a pointer to an array of variant name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - - /** \brief Get the EP name declared for a (component, variant) pair. - * - * Each variant targets a single EP. `out_ep` receives the EP name string. - * When the variant does not declare an EP, the returned pointer is NULL. - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - - /** \brief Select a component model and return an opaque component instance. - * - * The variant selection is also performed during this call based on the component metadata and the provided options. - * The returned `OrtModelPackgeComponentContext*` is independent of `context` lifetime and must be released via - * `ReleaseComponentInstance`. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - - ORT_CLASS_RELEASE(ModelPackageComponentContext); - - /** \brief Get the name of the selected variant after SelectComponent has been called. - * - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] out_name Receives the selected variant's name string. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - - /** \brief Get the folder path of the selected variant. - * - * Returns the resolved absolute path to the variant's directory on disk. - * The string is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] folder_path Receives the variant folder path string. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - - /// @} - /** \brief Create an OrtSession for a selected file within a component model variant. - * - * The chosen variant (and thus its EP selection) is determined by `context`, which - * was built from an OrtSessionOptions via CreateModelPackageOptionsFromSessionOptions. - * - * Session options precedence: - * 1. session_options == NULL (default path): - * ORT uses the OrtSessionOptions that was captured when `context` was created. - * Any variant-specific session and provider options declared in the package - * metadata are merged on top. - * - * 2. session_options != NULL (advanced path): - * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific - * session and provider options from the package metadata are NOT applied. - * Use this when custom EP setup is required (e.g., shared CUDA streams, - * shared QNN EP contexts, custom allocators). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* context, - _In_opt_ const OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a999ef5e0faf4..4798d3d4ad1b8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -264,20 +264,6 @@ inline const OrtEpApi& GetEpApi() { return *api; } -/// -/// This returns a reference to the ORT C Model Package API. Used for loading models from model packages. -/// -/// ORT C Model Package API reference -inline const OrtModelPackageApi& GetModelPackageApi() { - auto* api = GetApi().GetModelPackageApi(); - if (api == nullptr) { - // minimal build - ORT_CXX_API_THROW("Model Package API is not available in this build", ORT_FAIL); - } - - return *api; -} - /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -678,9 +664,6 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageOptions, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageContext, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageComponentContext, GetModelPackageApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. @@ -807,9 +790,6 @@ struct EpDevice; struct ExternalInitializerInfo; struct Graph; struct Model; -struct ModelPackageOptions; -struct ModelPackageContext; -struct ModelPackageComponentContext; struct Node; struct ModelMetadata; struct TypeInfo; @@ -1806,70 +1786,6 @@ struct ModelCompilationOptions : detail::Base { */ Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); -/** \brief Options for selecting a component from a model package. - * - * Wraps ::OrtModelPackageOptions. Created from an Env and SessionOptions, which captures the - * EP configuration used for variant selection. - */ -struct ModelPackageOptions : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageOptions(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - ModelPackageOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions - ModelPackageOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions -}; - -/** \brief Context for inspecting and selecting components from a model package. - * - * Wraps ::OrtModelPackageContext. Provides traversal APIs to enumerate components, variants, - * and EP compatibility, as well as component selection. - */ -struct ModelPackageContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - explicit ModelPackageContext(const ORTCHAR_T* package_root); ///< Wraps OrtModelPackageApi::CreateModelPackageContext - - size_t GetComponentCount() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentCount - std::vector GetComponentNames() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentNames - size_t GetVariantCount(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantCount - std::vector GetVariantNames(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantNames - - /// Get the EP name for a variant. Returns nullptr if not declared. - /// Returned string is owned by this context and valid until it is released. - const char* GetVariantEpName(const char* component_name, - const char* variant_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantEpName - - int64_t GetSchemaVersion() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetSchemaVersion - - ModelPackageComponentContext SelectComponent(const char* component_name, - const ModelPackageOptions& options) const; ///< Wraps OrtModelPackageApi::SelectComponent -}; - -/** \brief Context for a selected component within a model package. - * - * Wraps ::OrtModelPackageComponentContext. Provides accessors for the selected variant's - * folder path and variant name. - */ -struct ModelPackageComponentContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageComponentContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - std::basic_string GetSelectedVariantFolderPath() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantFolderPath - - std::string GetSelectedVariantName() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantName - - Session CreateSession(const Env& env); ///< Wraps OrtModelPackageApi::CreateSession (default path, NULL session_options) - Session CreateSession(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) - Session CreateSession(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) -}; - /** \brief Wrapper around ::OrtModelMetadata * */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 51f99655121c6..d7439e7b356c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1360,110 +1360,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const Ort return *this; } -// ModelPackageOptions -inline ModelPackageOptions::ModelPackageOptions(const Env& env, const SessionOptions& session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -inline ModelPackageOptions::ModelPackageOptions(const Env& env, ConstSessionOptions session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -// ModelPackageContext -inline ModelPackageContext::ModelPackageContext(const ORTCHAR_T* package_root) { - ThrowOnError(GetModelPackageApi().CreateModelPackageContext(package_root, &this->p_)); -} - -inline size_t ModelPackageContext::GetComponentCount() const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentCount(this->p_, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetComponentNames() const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentNames(this->p_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline size_t ModelPackageContext::GetVariantCount(const char* component_name) const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantCount(this->p_, component_name, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetVariantNames(const char* component_name) const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantNames(this->p_, component_name, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline const char* ModelPackageContext::GetVariantEpName(const char* component_name, - const char* variant_name) const { - const char* ep = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantEpName( - this->p_, component_name, variant_name, &ep)); - return ep; -} - -inline int64_t ModelPackageContext::GetSchemaVersion() const { - int64_t version = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetSchemaVersion(this->p_, &version)); - return version; -} - -inline ModelPackageComponentContext ModelPackageContext::SelectComponent( - const char* component_name, const ModelPackageOptions& options) const { - OrtModelPackageComponentContext* out = nullptr; - ThrowOnError(GetModelPackageApi().SelectComponent(this->p_, component_name, options, &out)); - return ModelPackageComponentContext{out}; -} - -// ModelPackageComponentContext -inline std::basic_string ModelPackageComponentContext::GetSelectedVariantFolderPath() const { - const ORTCHAR_T* path = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantFolderPath(this->p_, &path)); - return std::basic_string{path}; -} - -inline std::string ModelPackageComponentContext::GetSelectedVariantName() const { - const char* name = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantName(this->p_, &name)); - return (name != nullptr) ? std::string{name} : std::string{}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, nullptr, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - const SessionOptions& session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - ConstSessionOptions session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h index e943a5cf65b11..0dd87c10776d3 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h @@ -39,6 +39,10 @@ // ORT_RUNTIME_CLASS(ExperimentalType); // +ORT_RUNTIME_CLASS(ModelPackageOptions); +ORT_RUNTIME_CLASS(ModelPackageContext); +ORT_RUNTIME_CLASS(ModelPackageComponentContext); + // // C: function pointer typedefs and name constants // diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 0123b02818584..57a4e472b6f6d 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -35,3 +35,250 @@ * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t* out) + +// --------------------------------------------------------------------------- +// OrtModelPackageApi +// +// A model package is a directory containing one or more *components* (logical models). +// Each component has one or more *variants*, where each variant targets a single +// execution provider (EP). The package manifest declares the EP name, device type, +// and an optional compatibility string for every variant so that the runtime can +// automatically select the best variant for the hardware and EPs available in the +// caller's session options. +// +// The functions below support: +// - creating model package options that capture EP configuration from OrtSessionOptions, +// - loading a package context (manifest + metadata) from a package root path, +// - querying component/variant metadata including per-variant EP information, +// - selecting a component (which also resolves the best-matching variant), +// - querying the selected variant's name and folder path, +// - creating an OrtSession from the selected component context. +// +// Typical flow: +// 1) Create model package options: +// - OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions +// 2) Load package metadata: +// - OrtModelPackageApi_CreateModelPackageContext +// 3) Query metadata (optional): +// - OrtModelPackageApi_ModelPackage_GetSchemaVersion +// - OrtModelPackageApi_ModelPackage_GetComponentCount +// - OrtModelPackageApi_ModelPackage_GetComponentNames +// - OrtModelPackageApi_ModelPackage_GetVariantCount +// - OrtModelPackageApi_ModelPackage_GetVariantNames +// - OrtModelPackageApi_ModelPackage_GetVariantEpName +// 4) Select a component and resolve variant: +// - OrtModelPackageApi_SelectComponent +// 5) Query selected variant info (optional): +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath +// 6) Create session: +// - OrtModelPackageApi_CreateSession +// +// Ownership: +// - Release objects created by this API with the corresponding release functions: +// OrtModelPackageApi_ReleaseModelPackageOptions, +// OrtModelPackageApi_ReleaseModelPackageContext, +// OrtModelPackageApi_ReleaseModelPackageComponentContext. +// +// Opaque handles (OrtModelPackageOptions, OrtModelPackageContext, OrtModelPackageComponentContext) +// are declared in onnxruntime_experimental_c_api.h. +// --------------------------------------------------------------------------- + +/** \brief Create model package options from an existing OrtSessionOptions. + * + * Captures EP configuration (registered execution providers and their devices) from + * the session options for use during variant selection. The resulting OrtModelPackageOptions + * is passed to OrtModelPackageApi_SelectComponent to resolve the best variant for the + * available EPs. + * + * \param[in] env The ORT environment. + * \param[in] session_options Session options containing registered EPs. + * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released + * with OrtModelPackageApi_ReleaseModelPackageOptions. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions, + _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtModelPackageOptions** out) + +/** \brief Release an OrtModelPackageOptions created by + * OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageOptions, + _Frees_ptr_opt_ OrtModelPackageOptions* options) + +/** \brief Create a model package context by parsing the package at the given root path. + * + * Parses the manifest.json and component metadata from the specified directory. + * The returned context provides read-only access to the package structure (components, + * variants, EP declarations). + * + * \param[in] package_root Path to the model package root directory (containing manifest.json). + * \param[out] out Receives the newly created OrtModelPackageContext. Must be released + * with OrtModelPackageApi_ReleaseModelPackageContext. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageContext, + _In_ const ORTCHAR_T* package_root, + _Outptr_ OrtModelPackageContext** out) + +/** \brief Release an OrtModelPackageContext created by OrtModelPackageApi_CreateModelPackageContext. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageContext, + _Frees_ptr_opt_ OrtModelPackageContext* ctx) + +/** \brief Get the schema version declared in the model package manifest. + * + * \param[in] ctx The model package context. + * \param[out] out_version Receives the schema version number. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetSchemaVersion, + _In_ const OrtModelPackageContext* ctx, + _Out_ int64_t* out_version) + +/** \brief Get the number of components in the model package. + * + * \param[in] ctx The model package context. + * \param[out] out_count Receives the component count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentCount, + _In_ const OrtModelPackageContext* ctx, + _Out_ size_t* out_count) + +/** \brief Get the names of all components in the model package. + * + * Returns a pointer to an array of UTF-8 component name strings. The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[out] out_names Receives a pointer to an array of component name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentNames, + _In_ const OrtModelPackageContext* ctx, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, + _Out_ size_t* out_count) + +/** \brief Get the number of variants for a given component. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_count Receives the variant count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantCount, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Out_ size_t* out_count) + +/** \brief Get the names of all variants for a given component. + * + * Returns a pointer to an array of UTF-8 variant name strings. The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_variant_names Receives a pointer to an array of variant name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantNames, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, + _Out_ size_t* out_count) + +/** \brief Get the EP name declared for a (component, variant) pair. + * + * Each variant targets a single EP. `out_ep` receives the EP name string. + * When the variant does not declare an EP, the returned pointer is NULL. + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantEpName, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _In_ const char* variant_name, + _Outptr_result_maybenull_ const char** out_ep) + +/** \brief Select a component model and return an opaque component instance. + * + * The variant selection is also performed during this call based on the component + * metadata and the provided options. The returned `OrtModelPackageComponentContext*` is + * independent of `context` lifetime and must be released via + * OrtModelPackageApi_ReleaseModelPackageComponentContext. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_SelectComponent, + _In_ const OrtModelPackageContext* context, + _In_ const char* component_name, + _In_ const OrtModelPackageOptions* options, + _Outptr_ OrtModelPackageComponentContext** out) + +/** \brief Release an OrtModelPackageComponentContext created by OrtModelPackageApi_SelectComponent. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageComponentContext, + _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx) + +/** \brief Get the name of the selected variant after OrtModelPackageApi_SelectComponent has been called. + * + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] out_name Receives the selected variant's name string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const char** out_name) + +/** \brief Get the folder path of the selected variant. + * + * Returns the resolved absolute path to the variant's directory on disk. + * The string is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] folder_path Receives the variant folder path string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const ORTCHAR_T** folder_path) + +/** \brief Create an OrtSession for a selected file within a component model variant. + * + * The chosen variant (and thus its EP selection) is determined by `context`, which + * was built from an OrtSessionOptions via OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + * + * Session options precedence: + * 1. session_options == NULL (default path): + * ORT uses the OrtSessionOptions that was captured when `context` was created. + * Any variant-specific session and provider options declared in the package + * metadata are merged on top. + * + * 2. session_options != NULL (advanced path): + * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific + * session and provider options from the package metadata are NOT applied. + * Use this when custom EP setup is required (e.g., shared CUDA streams, + * shared QNN EP contexts, custom allocators). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession, + _In_ const OrtEnv* env, + _In_ OrtModelPackageComponentContext* context, + _In_opt_ const OrtSessionOptions* session_options, + _Outptr_ OrtSession** session) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index c26424a0ab9ec..df14bc8c57f24 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -29,9 +29,6 @@ GraphOptimizationLevel, # noqa: F401 LoraAdapter, # noqa: F401 ModelMetadata, # noqa: F401 - ModelPackageComponentContext, # noqa: F401 - ModelPackageContext, # noqa: F401 - ModelPackageOptions, # noqa: F401 NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc index d030b68abdb99..458c47bcb58cd 100644 --- a/onnxruntime/core/session/experimental_c_api.cc +++ b/onnxruntime/core/session/experimental_c_api.cc @@ -18,6 +18,14 @@ namespace OrtExperimentalApis { +// Forward declarations driven by the .inc file so the registration table below +// can take the address of every entry, including those defined in other +// translation units linked into onnxruntime_session. +#define ORT_EXPERIMENTAL_API(VER, RET, NAME, ...) \ + RET ORT_API_CALL NAME##_SinceV##VER(__VA_ARGS__) NO_EXCEPTION; +#include "onnxruntime_experimental_c_api.inc" +#undef ORT_EXPERIMENTAL_API + // Test-only experimental function that writes a known sentinel value. // Exists to exercise the experimental API mechanism end-to-end and to serve as a template for future experimental // functions. diff --git a/onnxruntime/core/session/model_package_api.cc b/onnxruntime/core/session/model_package_api.cc index 27abb0f5f7a37..5fd25f9511eb9 100644 --- a/onnxruntime/core/session/model_package_api.cc +++ b/onnxruntime/core/session/model_package_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/model_package_api.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/common/common.h" #include "core/framework/error_code_helper.h" @@ -23,7 +23,9 @@ using namespace onnxruntime; return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, \ "Model package API is not supported in this build") -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, +namespace OrtExperimentalApis { + +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28, _Frees_ptr_opt_ OrtModelPackageOptions* options) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(options); @@ -32,7 +34,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelPackageOptions** out) { @@ -54,7 +56,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOpti API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageContext* ctx) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(ctx); @@ -63,7 +65,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageContext_SinceV28, _In_ const ORTCHAR_T* package_root, _Outptr_ OrtModelPackageContext** out) { API_IMPL_BEGIN @@ -89,7 +91,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ size_t* out_count) { API_IMPL_BEGIN @@ -107,7 +109,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, _Out_ size_t* out_count) { @@ -136,7 +138,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Out_ size_t* out_count) { @@ -158,7 +160,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, @@ -190,7 +192,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, +ORT_API_STATUS_IMPL(OrtModelPackageApi_SelectComponent_SinceV28, _In_ const OrtModelPackageContext* context, _In_ const char* component_name, _In_ const OrtModelPackageOptions* options, @@ -235,7 +237,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageComponentContext* cix) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(cix); @@ -244,7 +246,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const ORTCHAR_T** folder_path) { API_IMPL_BEGIN @@ -269,7 +271,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateSession_SinceV28, _In_ const OrtEnv* env, _In_ OrtModelPackageComponentContext* ctx, _In_opt_ const OrtSessionOptions* session_options, @@ -383,7 +385,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, // ---------- API table ------------------------------------------------------ -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _In_ const char* variant_name, @@ -417,7 +419,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ int64_t* out_version) { API_IMPL_BEGIN @@ -437,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const char** out_name) { API_IMPL_BEGIN @@ -461,40 +463,4 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -// ---------- API table dispatch --------------------------------------------- - -static constexpr OrtModelPackageApi ort_model_package_api = { - // Options - &OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, - &OrtModelPackageAPI::ReleaseModelPackageOptions, - - // Context - &OrtModelPackageAPI::CreateModelPackageContext, - &OrtModelPackageAPI::ReleaseModelPackageContext, - - // Package-level queries - &OrtModelPackageAPI::ModelPackage_GetSchemaVersion, - &OrtModelPackageAPI::ModelPackage_GetComponentCount, - &OrtModelPackageAPI::ModelPackage_GetComponentNames, - &OrtModelPackageAPI::ModelPackage_GetVariantCount, - &OrtModelPackageAPI::ModelPackage_GetVariantNames, - &OrtModelPackageAPI::ModelPackage_GetVariantEpName, - - // Variant selection and component queries - &OrtModelPackageAPI::SelectComponent, - &OrtModelPackageAPI::ReleaseModelPackageComponentContext, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, - - // Session - &OrtModelPackageAPI::CreateSession, - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - -static_assert(offsetof(OrtModelPackageApi, CreateSession) / sizeof(void*) == 14, - "Size of initial OrtModelPackageApi cannot change"); - -ORT_API(const OrtModelPackageApi*, OrtModelPackageAPI::GetModelPackageApi) { - return &ort_model_package_api; -} +} // namespace OrtExperimentalApis diff --git a/onnxruntime/core/session/model_package_api.h b/onnxruntime/core/session/model_package_api.h deleted file mode 100644 index 435e3a521c24c..0000000000000 --- a/onnxruntime/core/session/model_package_api.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/session/onnxruntime_c_api.h" - -namespace OrtModelPackageAPI { - -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - -ORT_API(void, ReleaseModelPackageOptions, _Frees_ptr_opt_ OrtModelPackageOptions*); -ORT_API_STATUS_IMPL(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - -ORT_API(void, ReleaseModelPackageContext, _Frees_ptr_opt_ OrtModelPackageContext*); -ORT_API_STATUS_IMPL(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - -ORT_API(void, ReleaseModelPackageComponentContext, - _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - -ORT_API_STATUS_IMPL(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* ctx, - _In_opt_ const OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - -ORT_API_STATUS_IMPL(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - -} // namespace OrtModelPackageAPI diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 61a413d92e7fc..f451eaa401497 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -57,7 +57,6 @@ #include "core/session/ort_env.h" #include "core/session/ort_version_check.h" #include "core/session/utils.h" -#include "core/session/model_package_api.h" #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" @@ -3689,10 +3688,6 @@ ORT_API(const OrtCompileApi*, OrtApis::GetCompileApi) { return OrtCompileAPI::GetCompileApi(); } -ORT_API(const OrtModelPackageApi*, OrtApis::GetModelPackageApi) { - return OrtModelPackageAPI::GetModelPackageApi(); -} - ORT_API(void, OrtApis::CreateKeyValuePairs, _Outptr_ OrtKeyValuePairs** out) { auto kvps = std::make_unique(); *out = reinterpret_cast(kvps.release()); @@ -4911,7 +4906,6 @@ static constexpr OrtApi ort_api_1_to_28 = { &OrtApis::GetMemPatternEnabled, &OrtApis::GetSessionExecutionMode, &OrtApis::SessionReleaseCapturedGraph, - &OrtApis::GetModelPackageApi, // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetExperimentalFunction, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 250a2853a4777..61ece2dd9a682 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -825,9 +825,6 @@ ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtVal _Outptr_result_maybenull_ const int64_t** shape_data, _Out_ size_t* shape_data_count); -// Model Package API -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - // Experimental API ORT_API(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 77200ea84778e..2044a128d9540 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -3247,195 +3247,6 @@ including arg name, arg type (contains both type and shape).)pbdoc") #endif }, R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); - - // --- Model Package API --- -#if !defined(ORT_MINIMAL_BUILD) - // Helper to create a PyInferenceSession from a pre-initialized OrtSession* (C API handle). - // PyInferenceSession's owning ctor is protected; this subclass provides access. - struct PyModelPackageSession : PyInferenceSession { - PyModelPackageSession(std::unique_ptr sess) - : PyInferenceSession(std::move(sess)) {} - }; - - // Wrapper classes to manage opaque C handles with proper RAII - struct PyModelPackageContext { - OrtModelPackageContext* ctx_{nullptr}; - PyModelPackageContext(const std::string& package_path) { - auto path = ToPathString(package_path); - const auto* api = Ort::GetApi().GetModelPackageApi(); - Ort::ThrowOnError(api->CreateModelPackageContext(path.c_str(), &ctx_)); - } - ~PyModelPackageContext() { - if (ctx_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageContext(ctx_); - } - } - PyModelPackageContext(const PyModelPackageContext&) = delete; - PyModelPackageContext& operator=(const PyModelPackageContext&) = delete; - }; - - struct PyModelPackageComponentContext { - OrtModelPackageComponentContext* ctx_{nullptr}; - ~PyModelPackageComponentContext() { - if (ctx_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageComponentContext(ctx_); - } - } - PyModelPackageComponentContext(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext& operator=(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext() = default; - }; - - struct PyModelPackageOptions { - OrtModelPackageOptions* opts_{nullptr}; - ~PyModelPackageOptions() { - if (opts_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageOptions(opts_); - } - } - PyModelPackageOptions(const PyModelPackageOptions&) = delete; - PyModelPackageOptions& operator=(const PyModelPackageOptions&) = delete; - PyModelPackageOptions() = default; - }; - - py::class_(m, "ModelPackageContext", - R"pbdoc(Represents an opened model package for inspection and component selection.)pbdoc") - .def(py::init(), py::arg("package_path"), - R"pbdoc(Open a model package from the given directory path.)pbdoc") - .def( - "get_component_names", - [](PyModelPackageContext& self) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetComponentNames(self.ctx_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - R"pbdoc(Get the names of all components in the package.)pbdoc") - .def( - "get_variant_names", - [](PyModelPackageContext& self, const std::string& component_name) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetVariantNames( - self.ctx_, component_name.c_str(), &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - py::arg("component_name"), - R"pbdoc(Get the variant names for a given component.)pbdoc") - .def( - "get_variant_ep_name", - [](PyModelPackageContext& self, const std::string& component_name, - const std::string& variant_name) -> std::optional { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* ep = nullptr; - Ort::ThrowOnError(api->ModelPackage_GetVariantEpName( - self.ctx_, component_name.c_str(), variant_name.c_str(), &ep)); - if (ep) return std::string(ep); - return std::nullopt; - }, - py::arg("component_name"), py::arg("variant_name"), - R"pbdoc(Get the EP name for a variant. Returns None if not declared.)pbdoc") - .def( - "get_schema_version", - [](PyModelPackageContext& self) -> int64_t { - const auto* api = Ort::GetApi().GetModelPackageApi(); - int64_t version = 0; - Ort::ThrowOnError(api->ModelPackage_GetSchemaVersion(self.ctx_, &version)); - return version; - }, - R"pbdoc(Get the schema version declared in the model package manifest.)pbdoc") - .def( - "select_component", - [](PyModelPackageContext& self, const std::string& component_name, - PyModelPackageOptions& options) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->SelectComponent( - self.ctx_, component_name.c_str(), options.opts_, &result->ctx_)); - return result; - }, - py::arg("component_name"), py::arg("options"), - R"pbdoc(Select a component and resolve its variant based on the provided options. -Returns a ModelPackageComponentContext for inspecting the selected variant.)pbdoc"); - - py::class_(m, "ModelPackageOptions", - R"pbdoc(Options used for variant selection in a model package. -Created from a SessionOptions to capture EP configuration for variant matching.)pbdoc") - .def(py::init([](PySessionOptions& session_options) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->CreateModelPackageOptionsFromSessionOptions( - GetOrtEnv(), &session_options, &result->opts_)); - return result; - }), - py::arg("session_options"), - R"pbdoc(Create model package options from a SessionOptions instance. -The EP configured on the session options is used for variant selection.)pbdoc"); - - py::class_(m, "ModelPackageComponentContext", - R"pbdoc(Represents a selected component within a model package. -Provides access to the resolved variant's files, session options, and metadata.)pbdoc") - .def( - "get_selected_variant_folder_path", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const ORTCHAR_T* path = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantFolderPath(self.ctx_, &path)); - return PathToUTF8String(PathString(path)); - }, - R"pbdoc(Get the folder path of the selected variant.)pbdoc") - .def( - "get_selected_variant_name", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* name = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantName( - self.ctx_, &name)); - return name ? std::string(name) : std::string(); - }, - R"pbdoc(Get the name of the selected variant.)pbdoc") - .def( - "create_session", - [](PyModelPackageComponentContext& self, py::object session_options_obj) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - OrtSession* ort_session = nullptr; - if (session_options_obj.is_none()) { - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, nullptr, &ort_session)); - } else { - auto& so = session_options_obj.cast(); - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, &so, &ort_session)); - } - // OrtSession* is a reinterpret_cast of InferenceSession* - auto* inference_session = reinterpret_cast(ort_session); - std::unique_ptr session_ptr(inference_session); - return std::make_unique(std::move(session_ptr)); - }, - py::arg("session_options") = py::none(), - R"pbdoc(Create an InferenceSession from the selected component variant. - -Args: - session_options: Optional SessionOptions override. If None, uses the options - captured during variant selection with per-file options merged on top. - If provided, variant-specific options are NOT applied. - -Returns: - An InferenceSession ready for inference.)pbdoc"); -#endif // !defined(ORT_MINIMAL_BUILD) } bool InitArray() { diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index 6fb3f8e6ba82f..34f73eb69b149 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -12,6 +12,7 @@ #include "gtest/gtest.h" #include "core/session/model_package/model_package_context.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/abi_devices.h" #include "test/autoep/test_autoep_utils.h" #include "test/util/include/asserts.h" @@ -22,6 +23,90 @@ extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { namespace { + +// Typed function pointers for every OrtModelPackageApi_* experimental entry, +// resolved once via the experimental name-based lookup. +struct ModelPackageFns { + OrtExperimental_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn + CreateModelPackageOptionsFromSessionOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn + ReleaseModelPackageOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn + CreateModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn + ReleaseModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn + ModelPackage_GetSchemaVersion{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn + ModelPackage_GetComponentCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn + ModelPackage_GetComponentNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn + ModelPackage_GetVariantCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn + ModelPackage_GetVariantNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn + ModelPackage_GetVariantEpName{nullptr}; + OrtExperimental_OrtModelPackageApi_SelectComponent_SinceV28_Fn + SelectComponent{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn + ReleaseModelPackageComponentContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantName{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantFolderPath{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateSession_SinceV28_Fn + CreateSession{nullptr}; +}; + +inline const ModelPackageFns& GetModelPackageFns() { + static const ModelPackageFns fns = []() { + const OrtApi* api = &Ort::GetApi(); + ModelPackageFns f; +#define RESOLVE(member, getter) \ + do { \ + f.member = Ort::Experimental::getter(api); \ + if (f.member == nullptr) { \ + throw std::runtime_error(std::string("Failed to resolve experimental " \ + "OrtModelPackageApi_") + \ + #member "_SinceV28"); \ + } \ + } while (0) + RESOLVE(CreateModelPackageOptionsFromSessionOptions, + Get_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn); + RESOLVE(ReleaseModelPackageOptions, + Get_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn); + RESOLVE(CreateModelPackageContext, + Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn); + RESOLVE(ReleaseModelPackageContext, + Get_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn); + RESOLVE(ModelPackage_GetSchemaVersion, + Get_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentCount, + Get_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentNames, + Get_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantCount, + Get_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantNames, + Get_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantEpName, + Get_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn); + RESOLVE(SelectComponent, + Get_OrtModelPackageApi_SelectComponent_SinceV28_Fn); + RESOLVE(ReleaseModelPackageComponentContext, + Get_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantName, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantFolderPath, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn); + RESOLVE(CreateSession, + Get_OrtModelPackageApi_CreateSession_SinceV28_Fn); +#undef RESOLVE + return f; + }(); + return fns; +} // ------------------------------------------------------------------ // Helpers to build a test model package on disk // ------------------------------------------------------------------ @@ -233,26 +318,25 @@ std::filesystem::path CreateModelPackageApiTestPackage(bool multi_file_variant = TEST(ModelPackageApiTest, PackageContextQueries) { const auto package_root = CreateModelPackageApiTestPackage(); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr model_pkg_context(nullptr, context_deleter); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); // Query: component count + names size_t component_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); ASSERT_EQ(component_count, 1u); const char* const* component_names = nullptr; size_t component_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentNames( model_pkg_context.get(), &component_names, &component_name_count)); ASSERT_EQ(component_name_count, 1u); ASSERT_NE(component_names, nullptr); @@ -261,13 +345,13 @@ TEST(ModelPackageApiTest, PackageContextQueries) { // Query: variant count + names size_t variant_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantCount( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantCount( model_pkg_context.get(), "model_1", &variant_count)); ASSERT_EQ(variant_count, 2u); const char* const* variant_names = nullptr; size_t variant_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantNames( model_pkg_context.get(), "model_1", &variant_names, &variant_name_count)); ASSERT_EQ(variant_name_count, 2u); @@ -294,17 +378,16 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { - if (p) pkg_api->ReleaseModelPackageOptions(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { + if (p) pkg_api.ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr model_pkg_options(nullptr, options_deleter); @@ -312,25 +395,25 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unique_ptr component_context(nullptr, component_context_deleter); OrtModelPackageOptions* raw_options = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); model_pkg_options.reset(raw_options); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); OrtModelPackageComponentContext* raw_component_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(model_pkg_context.get(), - "model_1", - model_pkg_options.get(), - &raw_component_context)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(model_pkg_context.get(), + "model_1", + model_pkg_options.get(), + &raw_component_context)); component_context.reset(raw_component_context); OrtSession* raw_session = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateSession(*ort_env, - component_context.get(), - session_options, - &raw_session)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateSession(*ort_env, + component_context.get(), + session_options, + &raw_session)); Ort::Session session(raw_session); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -921,33 +1004,32 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { os << R"({"filename":"mul_1.onnx"})"; } - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(nullptr, context_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); // variant_1 targets example_ep const char* ep1 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", &ep1)); ASSERT_NE(ep1, nullptr); EXPECT_STREQ(ep1, "example_ep"); // variant_2 targets other_ep const char* ep2 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_2", &ep2)); ASSERT_NE(ep2, nullptr); EXPECT_STREQ(ep2, "other_ep"); // Optional out-parameter: callers can pass NULL. - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", nullptr)); std::filesystem::remove_all(package_root, ec); @@ -990,32 +1072,31 @@ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) { std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); // Path looks like .../models/model_1/ -- the folder name is the variant. @@ -1082,28 +1163,27 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); // CreateSession iterates the per-file session_options and dispatches each through OrtApis::AddSessionConfigEntry. @@ -1111,7 +1191,7 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn // Pass nullptr for session_options so the metadata-merge path runs (it is skipped when the caller // supplies their own session_options). OrtSession* raw_session = nullptr; - OrtStatus* st = pkg_api->CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); + OrtStatus* st = pkg_api.CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); // Clean up session first to avoid leaks if assertion fails. if (raw_session != nullptr) { Ort::GetApi().ReleaseSession(raw_session); @@ -1131,60 +1211,6 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::filesystem::remove_all(package_root, ec); } -// Test that the C++ RAII wrappers (Ort::ModelPackageContext, etc.) work correctly. -TEST(ModelPackageApiTest, CxxWrappers_PackageContextQueries) { - const auto package_root = CreateModelPackageApiTestPackage(); - - Ort::ModelPackageContext ctx(package_root.c_str()); - - // Component queries - EXPECT_EQ(ctx.GetComponentCount(), 1u); - auto component_names = ctx.GetComponentNames(); - ASSERT_EQ(component_names.size(), 1u); - EXPECT_EQ(component_names[0], "model_1"); - - // Variant queries - EXPECT_EQ(ctx.GetVariantCount("model_1"), 2u); - auto variant_names = ctx.GetVariantNames("model_1"); - ASSERT_EQ(variant_names.size(), 2u); - std::unordered_set variant_set(variant_names.begin(), variant_names.end()); - EXPECT_EQ(variant_set.count("variant_1"), 1u); - EXPECT_EQ(variant_set.count("variant_2"), 1u); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageApiTest, CxxWrappers_SelectComponentAndQueryFileAccessors) { - const auto package_root = CreateModelPackageApiTestPackage(); - - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - - Ort::SessionOptions so; - std::unordered_map ep_options; - so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - - Ort::ModelPackageContext ctx(package_root.c_str()); - auto cix = ctx.SelectComponent("model_1", pkg_opts); - - // Folder path should be non-empty - auto folder = cix.GetSelectedVariantFolderPath(); - EXPECT_FALSE(folder.empty()); - - // Selected variant name should not throw - auto variant_name = cix.GetSelectedVariantName(); - EXPECT_FALSE(variant_name.empty()); - - // CreateSession via C++ wrapper - auto session = cix.CreateSession(*ort_env, so); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - // ------------------------------------------------------------------ // Test: GetSelectedVariantFolderPath returns correct path even when variant.json is absent. // ------------------------------------------------------------------ @@ -1221,31 +1247,29 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { Ort::SessionOptions so; std::unordered_map ep_options; so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; std::unique_ptr mp_opts(raw_mp_opts, options_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(raw_ctx, context_deleter); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr comp_ctx(raw_comp_ctx, component_context_deleter); // GetSelectedVariantFolderPath should return the variant directory even without variant.json. const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); const auto result_path = std::filesystem::path(selected_folder); From 2cf6c6c6939533d888233f25f58b030c74e31c14 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 10 Jun 2026 17:26:13 -0700 Subject: [PATCH 2/8] [CUDA] Fix QMoE int4/int8 weight prepack to always use SM80 layout (#28978) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary The CUDA QMoE INT4/INT8 grouped GEMM always dispatches to the Ampere (SM80) CUTLASS kernel — even on Hopper (SM90) — because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized specialisation. This PR makes weight prepacking always emit the SM80 (column-interleaved) `fpA_intB` layout regardless of the runtime device SM, fixing silently-wrong output on Hopper, and centralizes the arch-clamping logic in a single shared helper. It also cleans up the related tests and tightens MoE parity tolerances that were too loose to catch the layout bug. ## Motivation https://github.com/microsoft/onnxruntime/pull/28749 uses 90 for sm90 weight prepacking. On SM90, `isValidHopperMOESpecialisation()` is `false`, so the grouped MoE GEMM falls back to the SM80 kernel. The weight preprocessor, however, skips column interleaving for `arch == 90`, so an auto-detected (`force_arch=-1`) pack on an H200 produced the non-interleaved SM90 layout that the SM80 kernel cannot consume — yielding wrong results. The previous `PrePackIntExpertWeights` logic clamped to `sm_` (passing SM90 through), and the test that exercised the offline packer used auto-detect, so both could emit the wrong layout. ## Key Changes | Area | Change | |---|---| | `fpA_intB_gemm_preprocessors{.h,_impl.cu}` | Extracted `get_arch_for_mixed_gemm_weight_preprocess(int arch)` as a shared, declared helper (clamps SM to the layout group: `<80→75`, `90→90`, else `80`). | | `fpA_intB_gemm_preprocessors_impl.h` | `getLayoutDetailsForTransform` now routes through the shared helper instead of duplicating the arch-range logic. | | `moe_quantization.cc` (`PrePackIntExpertWeights`) | Always packs INT4/INT8 expert weights for the SM80 layout (`get_arch_for_mixed_gemm_weight_preprocess(80)`) instead of clamping to the runtime `sm_`, since the SM80 kernel runs on every GPU. | | `onnxruntime_pybind_quant.cc` (`PackWeightsForMixedGemm`) | Replaced the ad-hoc `{75,80,90}` allowlist with the shared helper, so `force_arch` is clamped consistently with the runtime dispatch (removes the now-unused `` include). | | `contrib_defs.cc` / `moe_quantization.h` | Updated `weights_prepacked` schema/field docs: layouts for `-1`/`1` are EP-determined; for the CUDA EP `-1` and `1` are equivalent today (both SM80), `1` reserved for a future Hopper-specific layout. | | `test_qmoe_cuda.py` | Removed the dead, never-called `preprocess_weights_for_mixed_gemm` helper; the real path (`quant_dequant_blockwise`) already pins `sm=80`. | | `test_moe_cuda.py` | Pinned the offline packer to `arch=80`, and tightened FP16 QMoE parity tolerance from `atol 3.0 (4-bit)` / `2.0 (8-bit)` to `0.5` now that the layout is correct. | | `docs/` | Regenerated `ContribOperators.md` and updated `moe_qmoe.md` to match the new schema docs and SM80-always packing rationale. | ## Testing Notes On an H200 (SM90), with the CUDA 12.x/13.x Python wheel: ```bash python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py python -m pytest onnxruntime/test/python/transformers/test_moe_cuda.py -k "PhiQMoE or qmoe" ``` - `test_qmoe_cuda.py` SwiGLU parity: SM80 layout → max diff ~0.001 (pass, tol 0.1); the prior SM90 layout produced max diff ~1.2 (fail), confirming the fix. - `test_moe_cuda.py` `TestPhiQMoE` (4-bit and 8-bit, all batch/seq combinations): worst observed `max_diff` ≈ 0.375 with the fixed layout, comfortably under the new `atol=0.5`. - `ruff check` passes on both edited test files. --------- Co-authored-by: tlwu Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- docs/ContribOperators.md | 2 +- docs/contrib_ops/cuda/moe_qmoe.md | 77 +++++++++++++++++-- .../cuda/llm/fpA_intB_gemm_preprocessors.h | 2 + .../llm/fpA_intB_gemm_preprocessors_impl.cu | 13 ++++ .../llm/fpA_intB_gemm_preprocessors_impl.h | 6 +- .../contrib_ops/cuda/moe/moe_quantization.cc | 58 +++++++------- .../contrib_ops/cuda/moe/moe_quantization.h | 27 ++++--- .../core/graph/contrib_ops/contrib_defs.cc | 16 +--- .../python/onnxruntime_pybind_quant.cc | 12 +-- .../test/python/transformers/test_moe_cuda.py | 16 ++-- .../python/transformers/test_qmoe_cuda.py | 39 +--------- 11 files changed, 156 insertions(+), 112 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9d19a95136ad7..38d101786b41a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4949,7 +4949,7 @@ This version of the operator has been available since version 1 of the 'com.micr
use_sparse_mixer : int
Whether to use sparse mixer
weights_prepacked : int
-
Only meaningful when quant_type='int'. Tri-state control over whether the int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS fpA_intB format expected by the runner. -1 (auto): let the execution provider choose its own backward-compatible default; the CUDA EP treats auto as prepacked. 1: the initializers are already prepacked (e.g. produced offline by pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself in PrePack(), matching the behaviour of MatMulNBits and removing the offline pre-pack requirement from exporters. Defaults to -1 (auto) so each execution provider can pick its own backward-compatible default rather than the schema imposing one.
+
Only meaningful when quant_type='int'. Tri-state control over the layout of the int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by -1 and 1 are determined by the execution provider. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.
#### Inputs (6 - 21) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index 6d53211ff40cb..36b68889ae582 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -71,6 +71,7 @@ input tokens → router (top-k softmax) → permute by expert | `expert_weight_bits` (QMoE only) | int | 4 | 4 (INT4/MXFP4) or 8 (INT8/FP8). | | `block_size` (QMoE only) | int | -1 | Group size for INT4/INT8 group-wise quantization. -1 = per-output-channel. | | `quant_type` (QMoE only) | string | `"int"` | `"int"`, `"fp4"`, `"fp8"`, `"wfp4afp8"`. See [§3](#3-quantization-modes). | +| `weights_prepacked` (QMoE only) | int | -1 | Tri-state, only meaningful when `quant_type="int"`. The prepacked layouts selected by `-1` and `1` are **EP-determined**. `-1` (default): the INT4/INT8 `fc1`/`fc2` initializers are already prepacked in the EP's default layout (e.g. from `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). `1`: already prepacked in an alternate EP-selected layout. `0`: the initializers are raw `[E, N, K/pack]` tensors (as produced by `quantize_matmul_{4,8}bits`) and the kernel runs the CUTLASS layout transform in `PrePack()`. **Note:** the CUDA EP INT4/INT8 MoE GEMM always runs the Ampere (SM80) kernel — even on SM90 — so it consumes the SM80 `fpA_intB` layout on all architectures; `-1` and `1` are therefore equivalent for the CUDA EP today, and `1` is reserved for a possible future Hopper-specific layout. See [§5.1](#51-weights-input-2--5--8). | ### 2.2 Type Constraints @@ -228,10 +229,53 @@ extra subtraction. ### 5.1 Weights (input 2 / 5 / 8) -Not transformed at runtime. INT4/INT8 weights must already be packed offline by -`pack_weights_for_cuda_mixed_gemm` (see [§6](#6-weight-formats)). MXFP4 weights -must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights are stored -as raw e4m3 bytes (no packing). +**INT4/INT8** weight layout is controlled by the `weights_prepacked` attribute +([§2.1](#21-attributes)). The prepacked layouts selected by `-1` and `1` are +determined by the execution provider: + +- **`weights_prepacked=-1` (default)** — the `fc1`/`fc2` weights are already in + the EP's default prepacked layout (e.g. packed offline by + `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). They are copied to GPU + and consumed as-is. +- **`weights_prepacked=1`** — the `fc1`/`fc2` weights are already in the EP's + **SM90** (Hopper) prepacked layout (reserved; see the note below). +- **`weights_prepacked=0`** — the `fc1`/`fc2` weights are raw, schema-conformant + `[E, N, K/pack]` tensors as produced by `quantize_matmul_{4,8}bits`. `PrePack` + runs the CUTLASS layout transform itself via `PrePackIntExpertWeights`, + removing the offline pre-pack dependency. This makes integer QMoE symmetric + with `MatMulNBits::PrePack_B`. + +> **Single layout on the CUDA EP.** The CUDA EP INT4/INT8 MoE GEMM always +> dispatches to the Ampere (**SM80**) grouped-GEMM kernel — even on SM90 — +> because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA +> warp-specialized specialisation (`isValidHopperMOESpecialisation` is `false`). +> This matches **TensorRT-LLM**, which likewise routes `W4A16`/`W8A16` MoE to the +> SM80 kernel on Hopper; its Hopper TMA-WS mixed-dtype MoE kernel is reserved for +> `W4A8` (FP8 activation) and `WFP4A16` (FP4 weight). Consequently the CUDA EP +> consumes the **SM80 `fpA_intB` layout on every GPU**, `PrePack` always packs +> for SM80, and `weights_prepacked=-1` and `=1` are equivalent today. `1` is +> accepted and reserved for a possible future Hopper-specific layout (e.g. +> `W4A8`). There is therefore no architecture-match constraint: SM80-format +> weights run correctly on SM90 via the SM80 kernel. + +`PrePackIntExpertWeights` loops over the `E` experts and, per expert, applies the +same transpose + row-permutation / column-interleave / bias / pair-interleave +transform as `pack_weights_for_cuda_mixed_gemm` (see [§6.1](#61-int4-group-wise-quant_typeint-expert_weight_bits4)), +always targeting the SM80 layout. SM75+ is required. The source +`[E, N, K/pack]` initializers are released after their shapes are cached +(`fc1_weights_shape_` / `fc2_weights_shape_`), so peak weight memory stays ~1×. +The prepacked GPU buffers (`packed_fc1_weights_` / `packed_fc2_weights_`) are then +preferred by `ComputeInternal`. If prepacking is disabled at the session level +(`session.disable_prepacking`), the buffers stay null and the raw initializer +pointers are read at compute time instead. + +> **Note**: `weights_prepacked=0` is the only path that triggers an in-`PrePack` +> layout transform for INT weights. FP4 / FP8 / WFP4AFP8 weight handling is +> unaffected. + +MXFP4 weights must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights +are stored as raw e4m3 bytes (no packing). + ### 5.2 INT4/INT8 scales + zero-point → bias @@ -287,7 +331,12 @@ This section covers the five distinct weight encodings supported by QMoE. INT4 packing layout within a byte: `[high_nibble | low_nibble] = [elt_1 | elt_0]`. Each INT4 element is in `[-8, 7]` (signed) before bias, `[0, 15]` after the +8 bias. -#### Preprocessing pipeline (offline, `pack_weights_for_cuda_mixed_gemm`) +#### Preprocessing pipeline (offline `pack_weights_for_cuda_mixed_gemm`, or in-`PrePack` via `PrePackIntExpertWeights`) + +This is the layout transform applied either offline by +`pack_weights_for_cuda_mixed_gemm`, or per-expert inside `PrePack` when +`weights_prepacked=0` (see [§5.1](#51-weights-input-2--5--8)). + 1. **Input layout**: `[N, K]` per expert (Out × In), 2 elements per byte for INT4. 2. **Transpose & signed conversion**: @@ -405,6 +454,17 @@ weights are interchangeable across SMs: — does not use `pack_weights_for_cuda_mixed_gemm`. - **FP8**: no packing. +> **QMoE uses Group A on every GPU.** The table above describes the layouts the +> `pack_weights_for_cuda_mixed_gemm` *preprocessor* can emit. The QMoE INT4/INT8 +> MoE GEMM, however, always dispatches to the Ampere (SM80) grouped-GEMM kernel — +> even on SM90 — because mixed int-weight + fp16/bf16 activation is not a valid +> Hopper TMA warp-specialized specialisation (the same is true in TensorRT-LLM). +> It therefore consumes the **Group A (SM80) layout on all architectures, +> including Hopper**. For QMoE, always pack INT4/INT8 weights for SM80 (`arch=80`), +> and `PrePackIntExpertWeights` (`weights_prepacked=0`) does exactly that +> regardless of the runtime device SM. Group B (SM90) layout is currently unused +> by QMoE. + --- ## 8. SwiGLU Fusion @@ -830,7 +890,7 @@ will not change the operator interface. |-----------|----------| | [test_moe_cuda.py](onnxruntime/test/python/transformers/test_moe_cuda.py) | Standard MoE on CUDA: FP16/BF16, SiLU/GeLU/SwiGLU, routing, GEMM parity. SwiGLU coverage includes both GPT-OSS (`TestSwigluMoE`: interleaved, alpha=1.702/beta=1.0/limit=7.0) and Standard/Llama-Gemma (`TestStandardSwigluMoE`: concatenated `swiglu_fusion=2`, alpha=1.0/beta=0.0/no limit → `SiLU(Gate)×Value`). | | [test_moe_cpu.py](onnxruntime/test/python/transformers/test_moe_cpu.py) | Standard MoE on CPU (smoke). | -| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. | +| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. `TestQMoEIntPrePackSmoke` covers the raw-weight `weights_prepacked=0` in-`PrePack` layout transform (smoke test: asserts finite output, not bit-parity). | | [test_qmoe_cpu.py](onnxruntime/test/python/transformers/test_qmoe_cpu.py) | INT4/INT8 QMoE on CPU (smoke). | | [test_qmoe_fp4_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py) | MXFP4 QMoE: quantization utilities, packing, FP16/BF16, SiLU/SwiGLU, top-k and expert-count variants. End-to-end runs on SM120; on SM<120 the dequant fallback is exercised. | | [test_qmoe_fp8_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py) | FP8 W8A16 QMoE on SM90+ native path and SM<90 dequant fallback. | @@ -954,6 +1014,11 @@ over-aligned by-value parameters. cannot. See [§14.1](#141-msvc-and-tma-grouped-moe-gemm). - **WFP4AFP8 native** requires SM100+ hardware; only the dequant fallback path is validated end-to-end so far. +- **In-`PrePack` INT weight layout transform** (`weights_prepacked=0`) is + currently covered only by a smoke test (`TestQMoEIntPrePackSmoke`), not a + bit-parity check: the existing offline pre-pack harness hardcodes + `force_arch=80` (the same SM80 layout consumed by the CUDA EP on all GPUs), + so a separate parity harness for this path is still pending. - **Hopper W4A8** (INT4 weight + FP8 activation) is not supported — TRT-LLM gates its fast path to SM89 only. diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h index b9e62443145e5..c3b734816cf84 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h @@ -31,6 +31,8 @@ enum class QuantType { W4_AFP8 }; +int get_arch_for_mixed_gemm_weight_preprocess(int arch); + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu index a006612ddadc9..7e83bdda72eab 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu @@ -521,6 +521,19 @@ void add_bias_and_interleave_quantized_tensor_inplace_cuda( } } +int get_arch_for_mixed_gemm_weight_preprocess(int arch) { + ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); + if (arch < 80) { + return 75; + } +#ifndef EXCLUDE_SM_90 + if (arch >= 90 && arch < 100) { + return 90; + } +#endif + return 80; +} + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h index a8fb411ed0663..47bbe0c0e10ec 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h @@ -120,11 +120,11 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { } LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { - ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); - if (arch < 80) { + arch = get_arch_for_mixed_gemm_weight_preprocess(arch); + if (arch == 75) { return getLayoutDetailsForArch(quant_type); #ifndef EXCLUDE_SM_90 - } else if (arch >= 90 && arch < 100) { + } else if (arch == 90) { return getLayoutDetailsForArch(quant_type); #endif } else { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index e1ddcac0cea4f..7d1291e004d78 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -62,18 +62,28 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE this->quant_type_ = op_kernel_info.GetAttrOrDefault("quant_type", "int"); ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8", "quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'"); - // ``weights_prepacked`` is an optional tri-state attribute that defaults to - // -1 (auto) in the schema, so each EP picks its own backward-compatible - // default rather than the schema imposing one: - // -1 (auto, also the schema default): the EP decides. The CUDA EP's - // backward-compatible default is "prepacked" because all pre-existing - // tooling ships CUTLASS-prepacked weights. - // 1: initializers are already prepacked; the compute path reads them as-is. - // 0: initializers are raw [E, N, K/pack]; the PrePack hook lays them out. + // ``weights_prepacked`` is an optional tri-state attribute (default -1) that + // declares the layout of the int4/int8 fc1/fc2 weight initializers. The + // concrete prepacked layouts selected by -1 and 1 are determined by the + // execution provider. The CUDA EP maps the tri-state as: + // -1 (default): already prepacked in the EP's default int weight layout. + // 1: already prepacked in an alternate EP-selected int weight layout. + // 0: raw [E, N, K/pack] initializers; the PrePack hook lays them out. + // + // Important: the CUDA QMoE int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (see isValidHopperMOESpecialisation). The kernel therefore + // consumes the SM80/Ampere CUTLASS fpA_intB layout on every GPU. As a result + // the EP default (-1) is the SM80 layout regardless of the runtime device SM, + // and SM80-format weights are valid on SM90 (they run via the SM80 kernel). + // For CUDA today, -1 and 1 are equivalent (both SM80 layout), and 1 is + // reserved for a possible future Hopper-specific layout. + // PrePack (weights_prepacked=0) packs for the SM80 layout accordingly. const int64_t weights_prepacked_mode = op_kernel_info.GetAttrOrDefault("weights_prepacked", static_cast(-1)); ORT_ENFORCE(weights_prepacked_mode == -1 || weights_prepacked_mode == 0 || weights_prepacked_mode == 1, - "weights_prepacked must be -1 (auto), 0, or 1, but got ", weights_prepacked_mode); + "weights_prepacked must be -1 (default), 0, or 1, but got ", weights_prepacked_mode); weights_prepacked_ = (weights_prepacked_mode != 0); #if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE) ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer."); @@ -850,7 +860,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB // layout that the runner consumes and freed the source initializer // (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack`` - // (which already requires ``packed_fc1_weights_ != nullptr``) rather than + // (which already requires both packed weight buffers) rather than // just ``is_int && !weights_prepacked_``: when prepacking is disabled at // the session level (``session.disable_prepacking``) PrePack never runs, // the prepack buffers stay null, and the raw initializer pointers read @@ -1146,6 +1156,9 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al IAllocatorUniquePtr& packed_buf, bool& is_packed) { ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_); + ORT_ENFORCE(sm_ >= 75, + "PrePackIntExpertWeights: quant_type='int' with weights_prepacked=0 requires SM75+ CUDA hardware, got SM", + sm_); const auto& shape = tensor.Shape(); ORT_ENFORCE(shape.NumDimensions() == 3, "PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=", @@ -1158,22 +1171,15 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al const int64_t k_packed = shape[2]; const int64_t k = k_packed * pack_factor; - // Weight packing is architecture-aware (see - // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing - // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that - // skips column interleaving, so it is its own compatibility group. Every - // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares - // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack - // INT8 LDSM and are unsupported. The compute-side runner selects the same - // layout from this clamped arch, so the two cannot drift. - // - // SM75 is passed through unchanged (rather than clamped to 80) even though it - // shares SM80's layout: the compute-side dispatch (getLayoutDetailsForTransform) - // still has a distinct SM75 branch, so mirroring it here avoids confusing a - // reader into thinking prepack and dispatch disagree. - ORT_ENFORCE(sm_ >= 75, - "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_); - const int packing_sm = (sm_ == 90 || sm_ == 75) ? sm_ : 80; + // The CUDA QMoE int4/int8 MoE GEMM always dispatches to the Ampere (SM80) + // grouped-GEMM kernel -- even on SM90 -- because mixed int-weight + fp16/bf16 + // is not a valid Hopper TMA warp-specialized specialisation. The kernel thus + // consumes the SM80 CUTLASS fpA_intB layout on every GPU, so the weights must + // always be preprocessed for SM80 regardless of the runtime device SM. + // (Using get_arch_for_mixed_gemm_weight_preprocess(sm_) here would emit the + // SM90 layout on Hopper, which the SM80 kernel cannot consume -> wrong output.) + const int packing_sm = + onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(80); // Per-expert sizes. const size_t per_expert_bytes = static_cast(n) * static_cast(k) / pack_factor; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h index 5722ac41cc470..2bbadc205b5d8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h @@ -46,16 +46,23 @@ class QMoE final : public CudaKernel, public MoEBase { IAllocatorUniquePtr& packed_buf, bool& is_packed); int64_t expert_weight_bits_; bool is_fp16_; - // When true (the schema default), the int4/int8 fc1/fc2 weight - // initializers are already in the CUTLASS fpA_intB layout — produced - // offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the - // compute path reads them as-is. When false, the raw schema-conformant - // ``[E, N, K/pack]`` layout (as produced by - // ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook - // via ``PrePackIntExpertWeights``, removing the offline prepack - // dependency. Only meaningful when ``quant_type_ == "int"``. Derived from - // the optional tri-state ``weights_prepacked`` attribute: -1/auto (or - // absent) maps to true on the CUDA EP, 1 maps to true, 0 maps to false. + // When true, the int4/int8 fc1/fc2 weight initializers are already in a + // CUTLASS fpA_intB layout — produced offline e.g. via + // ``pack_weights_for_cuda_mixed_gemm`` — and the compute path reads them + // as-is. When false, the raw schema-conformant ``[E, N, K/pack]`` layout + // (as produced by ``quantize_matmul_{4,8}bits``) is rewritten inside the + // PrePack hook via ``PrePackIntExpertWeights``, removing the offline + // prepack dependency. Only meaningful when ``quant_type_ == "int"``. + // Derived from the optional tri-state ``weights_prepacked`` attribute: + // -1 (default) and 1 both map to true; 0 maps to false. The concrete + // prepacked layouts selected by -1 and 1 are determined by the execution + // provider. For the CUDA EP the int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (matches TensorRT-LLM, which also routes W4A16/W8A16 MoE to + // the SM80 kernel on Hopper). The kernel therefore consumes the SM80 fpA_intB + // layout on every GPU, so -1 and 1 are currently equivalent for the CUDA EP; + // 1 is reserved for a possible future Hopper-specific layout (e.g. W4A8). bool weights_prepacked_ = true; // Cached source weight shapes captured at PrePack time. When the // PrePack hook consumed and released the original int4/int8 weight diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 1054fd94ef423..f3f2f521ecab2 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1520,18 +1520,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::STRING, std::string("int")) .Attr("weights_prepacked", - "Only meaningful when quant_type='int'. Tri-state control over whether the " - "int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS " - "fpA_intB format expected by the runner. -1 (auto): let the execution provider " - "choose its own backward-compatible default; the CUDA EP treats auto as " - "prepacked. 1: the initializers are already prepacked (e.g. produced offline by " - "pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers " - "are raw, un-prepacked [E, N, K/pack] tensors as produced by " - "quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself " - "in PrePack(), matching the behaviour of MatMulNBits and removing the offline " - "pre-pack requirement from exporters. Defaults to -1 (auto) so each execution " - "provider can pick its own backward-compatible default rather than the schema " - "imposing one.", + "Only meaningful when quant_type='int'. Tri-state control over the layout of the " + "int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by " + "-1 and 1 are determined by the execution provider. 0: the initializers are raw, " + "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.", AttributeProto::INT, static_cast(-1)) .Input(0, diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5b1d590a06234..7220153b4fa17 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -16,7 +16,6 @@ #endif #include #include -#include #include namespace pybind11 { @@ -252,17 +251,8 @@ py::array_t PackWeightsForMixedGemm( cudaDeviceProp device_prop; ThrowIfCudaError(cudaGetDeviceProperties(&device_prop, device_id), "cudaGetDeviceProperties"); sm = device_prop.major * 10 + device_prop.minor; - } else { - // Validate force_arch against the SM versions for which preprocess_weights_for_mixed_gemm_cuda - // has tile/permutation tables. Unknown SMs would silently produce incorrect weight layouts. - static const std::set kSupportedSm = {75, 80, 90}; - if (kSupportedSm.find(sm) == kSupportedSm.end()) { - std::ostringstream oss; - oss << "force_arch=" << sm << " is not a supported SM version. " - << "Pass -1 for auto-detect, or one of: 75, 80, 90 (arch > 90 will fallback to 80)."; - throw std::invalid_argument(oss.str()); - } } + sm = ::onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(sm); auto permutation_map_buffer = make_cuda_ptr(32 * sizeof(int32_t)); diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py index c5fc826a5a6ed..9677542270a53 100644 --- a/onnxruntime/test/python/transformers/test_moe_cuda.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -152,7 +152,10 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): q_weight_reshaped = q_weight.reshape(n, -1) # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format), and qMoE kernel uses the same format. - processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4) + # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 -- + # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1) + # would emit the non-interleaved SM90 layout on Hopper and produce wrong results. + processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4, 80) # So we need to DEQUANTIZE back to get `result`. # scale is [n, block_per_k] @@ -232,8 +235,11 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): ) q_weight_reshaped = q_weight.reshape(n, -1) - # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format) - processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8) + # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format). + # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 -- + # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1) + # would emit the non-interleaved SM90 layout on Hopper and produce wrong results. + processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8, 80) # Dequantize for reference # (q - 128) * scale if using 128 offset? or (q) * scale if symmetric around 0? @@ -1084,8 +1090,8 @@ def parity_check(self): ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (0.3, 0.05), - "FP16:4": (3.0, 1e-2), - "FP16:8": (2.0, 1e-2), + "FP16:4": (0.5, 1e-2), + "FP16:8": (0.5, 1e-2), "BF16:0": (1.0, 1e-2), "BF16:4": (30.0, 1e-1), "BF16:8": (20.0, 1e-1), diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py index 993716a4c80b0..c56383d2851d3 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py @@ -137,43 +137,6 @@ def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""): ) -def preprocess_weights_for_mixed_gemm( - tensor: torch.Tensor, quant_bits: int, sm: int = -1, do_weight_interleave: bool = True -) -> torch.Tensor: - if len(tensor.shape) == 2: - tensor = tensor.unsqueeze(0) - - # Input tensor shape is [Experts, n, k_packed]. k_packed is k/2 for 4-bit, k for 8-bit. - num_experts = tensor.shape[0] - n = tensor.shape[1] - k_packed = tensor.shape[2] - k = k_packed * 2 if quant_bits == 4 else k_packed - - packed_list = [] - - if _pybind and hasattr(_pybind, "pack_weights_for_cuda_mixed_gemm") and torch.cuda.is_available(): - for i in range(num_experts): - if tensor[i].dtype == torch.bfloat16: - weight = tensor[i].to(torch.float32).cpu().numpy() - else: - weight = tensor[i].cpu().numpy() - packed = _pybind.pack_weights_for_cuda_mixed_gemm(weight, n, k, quant_bits, sm) - # pack_weights_for_cuda_mixed_gemm returns int8 array of shape [packed_size] - # We need to reshape it to (k, n/2) for 4-bit, (k, n) for 8-bit. - output_rows = k - output_cols = n // 2 if quant_bits == 4 else n - packed_tensor = torch.from_numpy(packed).to(tensor.device) - packed_tensor = packed_tensor.view(torch.uint8).view(output_rows, output_cols) - packed_list.append(packed_tensor) - - return torch.stack(packed_list) - else: - # This shall not happen unless older version of onnxruntime is used. - raise ImportError( - "onnxruntime._pybind_state.pack_weights_for_cuda_mixed_gemm not found. Cannot preprocess weights." - ) - - def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True, asymmetric: bool = False): # DEBUG # print(f"DEBUG: quant_dequant input shape={weights.shape}, 4bit={is_4_bit_quantization}, asym={asymmetric}") @@ -2110,7 +2073,7 @@ class TestQMoEIntPrePackSmoke(unittest.TestCase): hardware (the other ``test_swiglu_qmoe_parity_*`` cases in this file fail on H200 / H100 with max-diff > 1.0 on plain main, by inspection — pre-existing). A real parity check can be added once - that harness honours the runtime SM. + that harness honors the runtime SM. """ def _run_one(self, *, hidden_size, inter_size, num_experts, top_k, swiglu_fusion, batch_size): From f2c8a0965cc991a21a2f6cad1a1874eb997503f1 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jun 2026 10:33:53 +0800 Subject: [PATCH 3/8] webgpu: Fuse FlashAttention decode kernels and extend to any sequence length (#28389) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Extend the FlashAttention decode path to work with any sequence length (not just seq_len=1), with causal masking and `use_seqlen_k` support for static KV cache - Add m_tile optimization to process multiple Q rows per workgroup (m_tile=1/2/4), amortizing K/V loads - Fuse the separate QKT and SplitVx shaders into a single QKV kernel using online softmax, eliminating the intermediate `qk` tensor (`B×H×seq×present_seq`) and reducing dispatch count from 3 to 2 - Route between prefill (FlashAttentionProgram) and split-reduce (fused QKV + VxReduce) paths based on sequence length ## Resolved Issues **Whisper decoding prefill improved from 4.68ms to 1.09ms.** Whisper's decoder attention has a small sequence length but large total sequence length (seq_len=4, total_seq_len=1500). The default prefill shader (FlashAttentionProgram) has low parallelism in this case because each workgroup iterates serially over the full KV cache. The split-reduce path tiles the KV dimension across workgroups, achieving much higher GPU occupancy for this workload shape. ## Details **Fused QKV kernel**: Each workgroup computes QK^T dot products, applies attention bias and causal mask, computes local softmax (per-tile max and sum), normalizes, and multiplies by V — all in one kernel. Per-tile metadata (max, sum) is written for the VxReduce shader to rescale partial outputs using online softmax: `output = Σ(partial_i × local_sum_i × exp(local_max_i - global_max)) / global_sum`. **Path routing** (`use_split_reduce`): The split-reduce path is used when `sequence_length_ < 32`; otherwise the single-kernel FlashAttentionProgram prefill path is used. Microbenchmarks on Phi-4 (32 heads, head_size 128, GQA group 3) show split-reduce is 1.13×-2.07× faster than the fused prefill kernel across `sequence_length ∈ {16, 30, 31}` × `total_sequence_length ∈ {128, 500, 2000}`. The previous heuristic additionally gated on `total_sequence_length_ > 1000`, but that signal is 0 under graph capture (seqlen_k lives on the GPU) and the carve-out is unnecessary because split-reduce is uniformly faster for short Q. ## Test plan - [x] 30/30 MHA unit tests pass - [x] phi4-graph-prune produces correct output - [x] whisper-tiny-int4 produces correct transcription - [x] clang-format clean --- .../webgpu/bert/flash_attention.cc | 267 +++++++++--------- .../contrib_ops/webgpu/bert/flash_attention.h | 59 ++-- .../flash_attention_decode_qkt.wgsl.template | 118 -------- .../flash_attention_decode_qkv.wgsl.template | 197 +++++++++++++ ...sh_attention_decode_split_vx.wgsl.template | 113 -------- ...h_attention_decode_vx_reduce.wgsl.template | 93 ++++-- ..._rotary_embedding_and_copykv.wgsl.template | 5 +- 7 files changed, 426 insertions(+), 426 deletions(-) delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 684e050f0201a..02e764d01e05e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,40 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +// WGSL helper function for normalizing on-device indirect dispatch dims. +// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram. +// Mirrors ProgramManager::NormalizeDispatchGroupSize three tiers: +// 1) direct (x, y, z) write when every dim is within the spec limit (65535); +// 2) 2D sqrt collapse when the product fits a square layout; +// 3) 3D cbrt collapse otherwise. +// Consumers are unaffected by the chosen layout: ShaderHelper flattens +// workgroup_id (x, y, z) into a single linear workgroup_idx. +// Caller contract: must register a storage output named exactly +// `indirect_buffer` of array with at least 3 elements. +constexpr const char kNormalizeDispatchGroupSizeFn[] = R"( +fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) { + let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension + if (x <= limit && y <= limit && z <= limit) { + indirect_buffer[0] = x; + indirect_buffer[1] = y; + indirect_buffer[2] = z; + return; + } + let size = f32(x) * f32(y) * f32(z); + let dispatch_avg_2d = u32(ceil(sqrt(size))); + if (dispatch_avg_2d <= limit) { + indirect_buffer[0] = dispatch_avg_2d; + indirect_buffer[1] = dispatch_avg_2d; + indirect_buffer[2] = 1u; + return; + } + let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0))); + indirect_buffer[0] = dispatch_avg_3d; + indirect_buffer[1] = dispatch_avg_3d; + indirect_buffer[2] = dispatch_avg_3d; +} +)"; + Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); @@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha if (prepare_indirect_dispatch_) { sh.AddOutput("indirect_buffer", ShaderUsage::None); + sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; } return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", @@ -87,13 +122,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { - // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size. - shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n" - << " if (global_idx == 0u) {\n" + shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; + shader.MainFunctionBody() << " if (global_idx == 0u) {\n" << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" - << " indirect_buffer[0] = num_total_seq_length_tile;\n" - << " indirect_buffer[1] = uniforms.num_heads;\n" - << " indirect_buffer[2] = 1u;\n" + << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -120,7 +152,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -176,7 +208,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt {static_cast(parameters.total_sequence_length_)}, {static_cast(parameters.kv_sequence_length_)}, {tile_size}, - {static_cast(parameters.num_heads_)}}); + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.batch_size_)}, + {num_q_tiles}}); return context.RunProgram(program); } @@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } -Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); +Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); + const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const uint32_t tile_size_k_vec = 8; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), + WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), + WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(q, q)); } -Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, - const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) { +Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, + const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, + Tensor* metadata, const Tensor* seqlen_k, + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const bool has_attention_bias = attention_bias != nullptr; const int components = 4; + const int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch}; + bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; + bool is_unidirectional = parameters.is_unidirectional_; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}}); + {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank}, + program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::Rank, 2}}); const uint32_t vectorized_head_size = parameters.head_size_ / components; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte if (use_indirect_dispatch) { program.SetIndirectDispatchTensor(indirect_buffer); } else { - program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -294,124 +342,72 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.num_heads_)}, {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); } -Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("metadata", ShaderUsage::UseUniform); - shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); +Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { shader.AddInput("head_sink", ShaderUsage::UseUniform); } - shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); - - const uint32_t tile_size_k_vec = 8u; - - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template", - WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); -} - -Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context, - const Tensor* metadata, - const Tensor* qk, - Tensor* out_split_vx, - Tensor* present_value, - const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, - const Tensor* indirect_buffer, - uint32_t num_total_seq_length_tile, - uint32_t num_present_sequence_length_tile, - uint32_t tile_size, - bool use_indirect_dispatch, - uint32_t present_sequence_length, - const Tensor* head_sink) { - const int components = 4; - const bool has_head_sink = head_sink != nullptr; - int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink}; - program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}, - {qk, ProgramTensorMetadataDependency::TypeAndRank}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size] - const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - if (use_indirect_dispatch) { - program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); - } - if (has_head_sink) { - program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); - } - // SetIndirectDispatchTensor must be called after all AddInput calls because it - // appends the indirect buffer as the last program input. - if (use_indirect_dispatch) { - program.SetIndirectDispatchTensor(indirect_buffer); - } else { - program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile); - } - program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink) - .SetWorkgroupSize(64) - .AddUniformVariables({{static_cast(parameters.total_sequence_length_)}, - {static_cast(head_size_vec)}, - present_sequence_length, - {static_cast(parameters.n_reps)}, - num_present_sequence_length_tile, - {batch_heads}, - {static_cast(parameters.num_heads_)}}); - - return context.RunProgram(program); -} - -Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { - shader.AddInput("seqlens_k", ShaderUsage::None); - } - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(input, input), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(output, output)); } Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context, const Tensor* out_split_vx, + const Tensor* metadata, Tensor* output, const Tensor* seqlen_k, const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch) { + bool use_indirect_dispatch, + const Tensor* head_sink, + uint32_t m_tile) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch}; - program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); + bool has_head_sink = head_sink != nullptr; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, + {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (has_head_sink) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}}); const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - program.SetDispatchGroupSize(batch_heads * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch) + program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) + .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, num_present_sequence_length_tile, {num_head_size_tile}, - {batch_heads}}); + {batch_heads}, + {static_cast(parameters.sequence_length_)}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -446,14 +442,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; + // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch. + const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u); + const uint32_t num_q_tiles = (static_cast(parameters.sequence_length_) + m_tile - 1u) / m_tile; + // Create indirect dispatch buffer if using indirect dispatch Tensor* indirect_buffer_ptr = nullptr; Tensor indirect_buffer; - // Prepare indirect dispatch buffer for decode path with static KV cache - const bool use_indirect_dispatch = !kv_empty && - parameters.sequence_length_ == 1 && - parameters.past_present_share_buffer_ && + // Prepare indirect dispatch buffer for split-reduce path with static KV cache. + // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based + // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. + const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && seqlen_k != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { @@ -492,10 +492,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size)); + indirect_buffer_ptr, tile_size, num_q_tiles)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); } // Extract present_sequence_length directly from present_key tensor shape @@ -503,7 +503,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); - if (parameters.sequence_length_ > 1) { + // Route between prefill path (FlashAttentionProgram, single kernel) + // and split-reduce decode path (QKV + VxReduce, 2 kernels). + // Split-reduce wins for short Q (sequence_length < 32) across all KV + // cache lengths measured: 1.13x-2.07x faster at total_sequence_length + // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96). + const bool use_split_reduce = parameters.sequence_length_ < 32; + + if (!use_split_reduce) { + // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles) bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -545,7 +553,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t prefill_tile_size = is_apple ? 128 : tile_size; const uint32_t num_seq_tile = (parameters.sequence_length_ + prefill_tile_size - 1) / prefill_tile_size; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -570,37 +577,31 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return context.RunProgram(program); } - // For decode path (sequence_length == 1) - const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, present_sequence_length}); - const TensorShape qk_shape(qk_dims); - Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); + // Split-reduce path (QKV + VxReduce) const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, 2}); + parameters.sequence_length_, num_present_sequence_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k, - parameters, indirect_buffer_ptr, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length)); const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, parameters.head_size_}); + parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, - seqlen_k, parameters, indirect_buffer_ptr, - num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, - use_indirect_dispatch, present_sequence_length, - head_sink)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters, + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value, + &metadata, seqlen_k, + parameters, indirect_buffer_ptr, num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + present_sequence_length, m_tile)); + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch)); + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + head_sink, m_tile)); return Status::OK(); } @@ -621,7 +622,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size) { + uint32_t tile_size, uint32_t num_q_tiles) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -678,6 +679,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, + {static_cast(params.batch_size_)}, + {num_q_tiles}, }); program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index e75b6378f67c6..3da6b33b4dc0e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -35,7 +35,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"tile_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_q_tiles", ProgramUniformVariableDataType::Uint32}); private: bool has_past_; @@ -138,11 +142,14 @@ class FlashAttentionProgram final : public Program { int max_k_step_; }; -class FlashAttentionDecodeQKTProgram final : public Program { +class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeQKTProgram(const std::string& kernel_name, - bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeQKVProgram(const std::string& kernel_name, + bool has_attention_bias, uint32_t tile_size, int head_size_vec, + bool use_indirect_dispatch, bool q_BNSH = false, + bool is_unidirectional = false, + uint32_t m_tile = 1) + : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -156,41 +163,23 @@ class FlashAttentionDecodeQKTProgram final : public Program { - public: - FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false) - : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_size_vec", ProgramUniformVariableDataType::Uint32}, - {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_heads", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); - - private: uint32_t tile_size_; int head_size_vec_; bool use_indirect_dispatch_; - bool has_head_sink_; + bool q_BNSH_; + bool is_unidirectional_; + uint32_t m_tile_; }; class FlashAttentionDecodeVxReduceProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -199,12 +188,16 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array; -var inner_qk_values: array, tile_size>; -var tile_qk: array; - -#if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - // Handle broadcasting: if dimension size is 1, use index 0 - let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); - let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - - // Calculate flat offset with broadcasting applied - // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length] - // For decode, new_seq_length is 1, so we can simplify: - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length + - bias_head_idx * total_seq_length + - k_idx; - return attention_bias[offset]; - } -#else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - return q_element_t(0); - } -#endif - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; -#if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; -#else - let total_sequence_length = uniforms.total_sequence_length; -#endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let head_idx = batch_head_idx % uniforms.num_heads; - let batch_idx = batch_head_idx / uniforms.num_heads; - if (batch_idx >= uniforms.batch_size) { - return; - } - let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec; - let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; - for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { - if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { - tile_q[local_idx] = q[q_offset + k + local_idx]; - } - workgroupBarrier(); - let q_data = tile_q[local_col] * q_element_t(uniforms.alpha); - if (k + local_col < uniforms.head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data); - } - } - } - workgroupBarrier(); - } - - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - var sum = q_element_t(0); - for (var i = 0u; i < tile_size_k_vec; i++) { - sum += inner_qk_values[local_idx][i]; - } - - sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length); - tile_qk[local_idx] = sum; - output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; - } - workgroupBarrier(); - - if (local_idx == 0u) { - // Calculate the max and sum in current split. - var l_max = f32(-3.4028234663852886e+38f); - var l_sum = f32(0); - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_max = max(l_max, f32(tile_qk[i])); - } - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_sum += exp(f32(tile_qk[i]) - l_max); - } - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; - metadata[meta_offset] = metadata_value_t(l_max, l_sum); - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template new file mode 100644 index 0000000000000..524a18ca43245 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_attention_bias +#param v_head_size_vec +#param is_unidirectional +#param m_tile +#param q_BNSH +#param sub_tile_count +#param tile_size +#param tile_size_k_vec +#param use_indirect_dispatch + +#use .getByOffset .setByOffset + +// Fused QK^T + softmax + V multiply shader. +// +// Each workgroup processes one KV tile (tile_size rows of present_key/value) +// for m_tile Q rows. The computation has two phases: +// +// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask, +// per-tile max/sum for online softmax) +// Phase 2: Local softmax normalization + V multiply (using local max/sum, +// no cross-workgroup dependency) +// +// The VxReduce shader performs the final rescaling across tiles. + +var tile_q: array, m_tile>; +var inner_qk_values: array, tile_size>, m_tile>; +var tile_qk: array, m_tile>; +var tile_output: array, m_tile>; +var qkv_values: array, sub_tile_count>, m_tile>; +var tile_max: array; +var tile_sum: array; + +#if has_attention_bias + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + + bias_head_idx * uniforms.new_sequence_length * total_seq_length + + q_idx * total_seq_length + + k_idx; + return attention_bias[offset]; + } +#else + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + return q_element_t(0); + } +#endif + +$MAIN { + let local_row = u32(local_idx / tile_size_k_vec); + let local_col = local_idx % tile_size_k_vec; + #if use_indirect_dispatch + let total_sequence_length = u32(seqlens_k[0]) + 1u; + #else + let total_sequence_length = uniforms.total_sequence_length; + #endif + let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] + let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; + let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles)); + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + if (batch_idx >= uniforms.batch_size) { + return; + } + let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; + let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + + // ============================================================ + // Phase 1: QK^T computation + // ============================================================ + for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { + let q_idx = q_base + m; +#if q_BNSH + let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec + + head_idx * uniforms.new_sequence_length * uniforms.head_size_vec + + q_idx * uniforms.head_size_vec; +#else + let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec; +#endif + tile_q[m][local_idx] = q.getByOffset(q_offset + k + local_idx); + } + } + workgroupBarrier(); + if (k + local_col < uniforms.head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let k_data = present_key.getByOffset(present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha); + inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data); + } + } + } + } + workgroupBarrier(); + } + + // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + var sum = q_element_t(0); + for (var i = 0u; i < tile_size_k_vec; i++) { + sum += inner_qk_values[m][local_idx][i]; + } + + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); +#if is_unidirectional + if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + sum = q_element_t(-65504.0f); + } +#endif + tile_qk[m][local_idx] = present_value_element_t(sum); + } + workgroupBarrier(); + + // Compute per-tile max and sum for online softmax + if (local_idx == 0u) { + var l_max = f32(-3.4028234663852886e+38f); + var l_sum = f32(0); + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_max = max(l_max, f32(tile_qk[m][i])); + } + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_sum += exp(f32(tile_qk[m][i]) - l_max); + } + tile_max[m] = l_max; + tile_sum[m] = l_sum; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(l_max, l_sum)); + } + } + workgroupBarrier(); + + // ============================================================ + // Phase 2: Local softmax + V multiply + // ============================================================ + + // Normalize tile_qk with local max/sum + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); + } + } + workgroupBarrier(); + + for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] = present_value_value_t(0); + } + workgroupBarrier(); + + if (k + local_col < v_head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let v_data = present_value.getByOffset(present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row]; + } + } + } + } + workgroupBarrier(); + + if (local_idx < tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + for (var i = 0u; i < sub_tile_count; i++) { + tile_output[m][k + local_idx] += qkv_values[m][i][local_idx]; + } + } + } + workgroupBarrier(); + } + + // Write output + let tile_idx = workgroup_idx % num_total_seq_length_tile; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec; + for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) { + out_split_vx.setByOffset(out_base + tile_idx * v_head_size_vec + i, tile_output[m][i]); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template deleted file mode 100644 index 6f1ad1ca41b71..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param has_head_sink -#param tile_size -#param head_size_vec -#param tile_size_k_vec -#param sub_tile_count -#param use_indirect_dispatch - -// Note that this shader adopts similar algorithm with dp4a generation shader. -// -// This algorithm works to compute dot product of v with qk parallelly, by -// processing on the head_size dimension at each step amongst tile_size_k_vec -// threads, and utilizing the remaining threads in the workgroup to process -// additional rows of |present_value| in parallel (such that the values in -// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads -// also reload |present_value| tile_size/sub_tile_count times to compute partial -// dot products of other |present_value| rows in order to complete all tile_size -// |present_value| rows in this workgroup and also reusing the values in -// tile_qk. -// -// The difference with FlashAttentionDecodeQKTProgram is that the dot products -// go through the rows (total_sequence_length) of |present_value| instead of -// columns (head_size_vec). And each workgroup only calculate current -// tile_size's dot products instead of iterating the whole row -// |total_sequence_length|. That's why this shader is a split shader. The final -// reduce will be done in FlashAttentionDecodeReduceProgram. - -// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx -// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate -// memory. The FlashAttentionDecodeQKT can be merged into split shader and do -// the final softmax adjustment in the reduce shader. However, some issues are -// met that when the total sequence length exceeds some value, the result will -// become garbage. Since it can't be resolved in a short time, leave it as TODO -// to fix it in future. - -var tile_qk: array; -var tile_output: array; -var qkv_values: array, sub_tile_count>; - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; - #else - let total_sequence_length = uniforms.total_sequence_length; - #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - if (batch_head_idx >= uniforms.batch_heads) { - return; - } - let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; - - // Calculate the global max and sum in qk. - var g_max = f32(-3.4028234663852886e+38f); -#if has_head_sink - let head_idx = batch_head_idx % uniforms.num_heads; - let sink_value = f32(head_sink[head_idx]); - g_max = max(g_max, sink_value); -#endif - var g_sum = f32(0); - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - g_max = max(g_max, metadata[meta_offset].x); - } - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(m_value.x - g_max) * m_value.y; - } -#if has_head_sink - g_sum += exp(sink_value - g_max); -#endif - - if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); - } - - for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) { - var value = present_value_value_t(0); - qkv_values[local_row][local_col] = present_value_value_t(0); - workgroupBarrier(); - - if (k + local_col < head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row]; - } - } - } - - qkv_values[local_row][local_col] = value; - workgroupBarrier(); - - if (local_idx < tile_size_k_vec) { - for (var i = 0u; i < sub_tile_count; i++) { - tile_output[k + local_idx] += qkv_values[i][local_idx]; - } - } - workgroupBarrier(); - } - - for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) { - let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; - out_split_vx[out_offset] = tile_output[i]; - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index f909a87724da6..a3ce0b68cb659 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -1,31 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#param has_head_sink +#param m_tile #param seq_tile_size #param tile_size #param use_indirect_dispatch -// Inputs are splits of the GQA output, split into num_total_seq_length_tiles -// rows. This shader needs to add these splits across the row dimension to -// arrive at the final result. The column is head size wide. The reduction -// achieves maximum parallelization by splitting this task first into tile_size -// columns that each workgroup is responsible for. Then within each workgroup -// the task of summation over the num_total_seq_length_tile for the tile_size -// columns is further split in two ways. First across the row dimension to have -// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE -// rows. Then across the column dimension where each thread is responsible for 1 -// column of the TILE_SIZE columns the workgroup is responsible for. +#use .getByOffset .setByOffset + +// This shader reduces partial V outputs from the fused QKV shader. +// Each tile produced a locally-normalized V contribution. To get the +// correct global result, we rescale each tile's contribution using +// per-tile metadata (max, sum) with online softmax: +// +// global_max = max(local_max_i for all tiles) +// global_sum = sum(local_sum_i * exp(local_max_i - global_max)) +// output[h] = sum(partial_i[h] * exp(local_max_i - global_max)) / global_sum var tile_input: array, tile_size>; $MAIN { + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile] let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); + let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles)); if (batch_head_idx >= uniforms.batch_heads) { return; } - let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; - var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; #if use_indirect_dispatch @@ -35,23 +39,60 @@ $MAIN { let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif - if (head_size_offset + local_col < uniforms.head_size_vec) { - for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { - if (r + local_row < num_total_seq_length_tile) { - value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col]; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; + let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile; + + // Compute global max across all tiles + var g_max = f32(-3.4028234663852886e+38f); +#if has_head_sink + let head_idx_for_sink = batch_head_idx % uniforms.num_heads; + let sink_value = f32(head_sink[head_idx_for_sink]); + g_max = max(g_max, sink_value); +#endif + for (var i = 0u; i < num_total_seq_length_tile; i++) { + g_max = max(g_max, metadata.getByOffset(meta_base + i).x); + } + + // Compute global sum with rescaling + var g_sum = f32(0); + for (var i = 0u; i < num_total_seq_length_tile; i++) { + let m_value = metadata.getByOffset(meta_base + i); + g_sum += m_value.y * exp(m_value.x - g_max); + } +#if has_head_sink + g_sum += exp(sink_value - g_max); +#endif + + // Accumulate rescaled partial outputs + var value = output_value_t(0); + if (head_size_offset + local_col < uniforms.head_size_vec) { + for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { + if (r + local_row < num_total_seq_length_tile) { + let tile_meta = metadata.getByOffset(meta_base + r + local_row); + let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum; + value += input.getByOffset(in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col) * output_value_t(output_element_t(rescale_f32)); + } } } - } - tile_input[local_row][local_col] = value; - workgroupBarrier(); + tile_input[local_row][local_col] = value; + workgroupBarrier(); - if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { - value = output_value_t(0); - for (var i = 0u; i < tile_size; i++) { - value += tile_input[i][local_idx]; + if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { + value = output_value_t(0); + for (var i = 0u; i < tile_size; i++) { + value += tile_input[i][local_idx]; + } + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec + + head_size_offset + local_idx; + output.setByOffset(output_id, value); } - let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx; - output[output_id] = value; + workgroupBarrier(); } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index c64bdf45cdcf8..7b09a3a6af080 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -41,12 +41,9 @@ $MAIN { #endif #if prepare_indirect_dispatch - // Prepare indirect dispatch buffer for thread 0 if (global_idx == 0u) { let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; - indirect_buffer[0] = num_total_seq_length_tile; - indirect_buffer[1] = uniforms.num_heads; - indirect_buffer[2] = 1u; + normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif From 0931c1a038302d7f32278b533a5e455a506c4680 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:00:31 -0600 Subject: [PATCH 4/8] Fix out-of-bounds error (#28991) This pull request introduces important safety checks to prevent out-of-bounds access in the logits processing code for transformers. The main updates ensure that token IDs are validated against the vocabulary size before being used, which improves robustness and prevents potential crashes. **Safety and robustness improvements:** * Added bounds checking for token IDs in the `RepetitionPenaltyLogitsProcessor::Process` method to ensure only valid IDs are used when accessing `beam_token_scores`. * Added bounds checking for token IDs in the `NoRepeatNGramLogitsProcessor::Process` method to prevent out-of-bounds writes to `beam_token_scores`. * Updated the `NextTokenScores::SetScore` method to return early if the provided `token_id` is out of bounds, replacing the previous assert with a safe check. --- .../contrib_ops/cpu/transformers/logits_processor.cc | 8 ++++++++ .../contrib_ops/cpu/transformers/logits_processor.h | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index d28aae02ab2f1..6e75bb1b5a7c0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -48,7 +48,11 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, unique_word_ids.insert(word_id); } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : unique_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } T score = beam_token_scores[word_id]; // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, @@ -89,7 +93,11 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, } } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : blocked_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } beam_token_scores[word_id] = std::numeric_limits::lowest(); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 6e157d83159f4..fb9638bb09d59 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -28,7 +28,9 @@ struct NextTokenScores { } void SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); + if (token_id < 0 || token_id >= vocab_size) { + return; + } for (int i = 0; i < batch_beam_size; i++) { scores[static_cast(i) * vocab_size + token_id] = score; } From 5c7ba1315d75fd162d5f0a8155a0879ff376b0c5 Mon Sep 17 00:00:00 2001 From: Theodore Cooper <63190431+the0cp@users.noreply.github.com> Date: Fri, 12 Jun 2026 00:04:38 +0800 Subject: [PATCH 5/8] fix: Add Linux NPU discovery through sysfs accel devices (#28703) ## Description This PR adds Linux NPU discovery through sysfs accel devices Currently, `DeviceDiscovery::DiscoverDevicesForPlatform()` on Linux discovers CPU and GPU devices, but NPU discovery is still missing. As a result, plugin execution providers that filter devices by `OrtHardwareDeviceType_NPU` do not receive any NPU hardware devices on Linux, even when the NPU is present and exposed by the kernel. This change scans `/sys/class/accel` for `accelN` devices and creates `OrtHardwareDevice` entries with: - `type = OrtHardwareDeviceType_NPU` - PCI `vendor_id` - PCI `device_id` - `accel_idx` metadata - `pci_bus_id` metadata when available This enables Linux systems with NPUs exposed through the accel subsystem, such as AMD Ryzen AI / XDNA devices, to be reported through ORT device discovery and made available to plugin EP factories. ## Changes - Add Linux sysfs discovery for NPU devices under `/sys/class/accel`. - Read NPU PCI vendor and device IDs from the underlying sysfs device path. - Add NPU metadata including `accel_idx` and `pci_bus_id`. - Include discovered NPU devices in `DeviceDiscovery::DiscoverDevicesForPlatform()`. - Add a `kSysfsAccelPath` constant for the Linux accel sysfs path. ## Motivation Linux plugin EPs that target NPUs rely on ORT passing `OrtHardwareDeviceType_NPU` devices into `GetSupportedDevices()`. Without Linux NPU discovery, those EPs cannot claim NPU devices and provider selection policies such as `PREFER_NPU` silently fall back to CPU. Fixes #28660. --- .../core/platform/linux/device_discovery.cc | 128 +++++++++++++++++- .../platform/linux/npu_device_discovery.h | 32 +++++ .../linux/npu_device_discovery_test.cc | 103 ++++++++++++++ 3 files changed, 261 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/platform/linux/npu_device_discovery.h create mode 100644 onnxruntime/test/platform/linux/npu_device_discovery_test.cc diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index 732a5a855d65d..4260ef706befa 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/device_discovery.h" +#include "core/platform/linux/npu_device_discovery.h" #include "core/platform/linux/pci_device_discovery.h" #include @@ -123,7 +124,7 @@ Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& gpu_devices_out) { std::vector gpu_sysfs_path_infos{}; @@ -315,6 +317,119 @@ Status GetGpuDevices(std::vector& gpu_devices_out) { } // namespace +namespace npu_device_discovery { + +Status DetectNpuSysfsPaths(const fs::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out) { + std::error_code error_code{}; + + const bool sysfs_accel_path_exists = fs::exists(sysfs_accel_path, error_code); + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Checking existence of accel sysfs path")); + + if (!sysfs_accel_path_exists) { + npu_sysfs_paths_out = {}; + return Status::OK(); + } + + const auto detect_accel_path = [](const fs::path& sysfs_path, size_t& accel_idx) -> bool { + const auto filename = sysfs_path.filename(); + const auto filename_str = std::string_view{filename.native()}; + + // Look for a filename matching "accelN". N is a number. + constexpr std::string_view prefix = "accel"; + if (filename_str.find(prefix) != 0) { + return false; + } + + size_t parsed_accel_idx{}; + if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_accel_idx)) { + return false; + } + + accel_idx = parsed_accel_idx; + return true; + }; + + std::vector npu_sysfs_paths{}; + + auto dir_iterator = fs::directory_iterator{sysfs_accel_path, error_code}; + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Iterating over accel sysfs devices")); + + for (const auto& dir_item : dir_iterator) { + const auto& dir_item_path = dir_item.path(); + + if (size_t accel_idx{}; detect_accel_path(dir_item_path, accel_idx)) { + NpuSysfsPathInfo path_info{}; + path_info.accel_idx = accel_idx; + path_info.path = dir_item_path; + npu_sysfs_paths.emplace_back(std::move(path_info)); + } + } + + npu_sysfs_paths_out = std::move(npu_sysfs_paths); + return Status::OK(); +} + +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out) { + OrtHardwareDevice npu_device{}; + + const auto& sysfs_path = path_info.path; + + uint16_t vendor_id{}; + const auto vendor_id_path = sysfs_path / "device" / "vendor"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, vendor_id)); + npu_device.vendor_id = vendor_id; + + uint16_t device_id{}; + const auto device_id_path = sysfs_path / "device" / "device"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, device_id)); + npu_device.device_id = device_id; + + npu_device.metadata.Add("accel_idx", MakeString(path_info.accel_idx)); + + std::optional pci_bus_id; + ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id)); + if (pci_bus_id) { + npu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id)); + } + + npu_device.type = OrtHardwareDeviceType_NPU; + + npu_device_out = std::move(npu_device); + + return Status::OK(); +} + +} // namespace npu_device_discovery + +namespace { + +Status GetNpuDevices(std::vector& npu_devices_out) { + std::vector npu_sysfs_path_infos{}; + ORT_RETURN_IF_ERROR(npu_device_discovery::DetectNpuSysfsPaths(kSysfsAccelPath, npu_sysfs_path_infos)); + + std::vector npu_devices{}; + npu_devices.reserve(npu_sysfs_path_infos.size()); + + for (const auto& npu_sysfs_path_info : npu_sysfs_path_infos) { + OrtHardwareDevice npu_device{}; + if (auto status = npu_device_discovery::GetNpuDeviceFromSysfs(npu_sysfs_path_info, npu_device); !status.IsOK()) { + LOGS_DEFAULT(WARNING) << MakeString("Failed to detect devices under ", + npu_sysfs_path_info.path, + ": ", + status.ErrorMessage()); + continue; + } + npu_devices.emplace_back(std::move(npu_device)); + } + + npu_devices_out = std::move(npu_devices); + + return Status::OK(); +} +} // namespace + std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { std::unordered_set devices; @@ -334,7 +449,16 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } // get NPU devices - // TODO figure out how to discover these + { + std::vector npu_devices{}; + Status npu_device_discovery_status = GetNpuDevices(npu_devices); + if (npu_device_discovery_status.IsOK()) { + devices.insert(std::make_move_iterator(npu_devices.begin()), + std::make_move_iterator(npu_devices.end())); + } else { + LOGS_DEFAULT(WARNING) << "NPU device discovery failed: " << npu_device_discovery_status.ErrorMessage(); + } + } return devices; } diff --git a/onnxruntime/core/platform/linux/npu_device_discovery.h b/onnxruntime/core/platform/linux/npu_device_discovery.h new file mode 100644 index 0000000000000..e139d33988903 --- /dev/null +++ b/onnxruntime/core/platform/linux/npu_device_discovery.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This header exposes Linux NPU device discovery internals for testing. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" +#include "core/session/abi_devices.h" + +namespace onnxruntime { +namespace npu_device_discovery { + +struct NpuSysfsPathInfo { + size_t accel_idx; + std::filesystem::path path; +}; + +// Scans the given sysfs accel directory for NPU accel devices. +Status DetectNpuSysfsPaths(const std::filesystem::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out); + +// Reads vendor/device IDs and populates an OrtHardwareDevice from an accel sysfs path. +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out); + +} // namespace npu_device_discovery +} // namespace onnxruntime diff --git a/onnxruntime/test/platform/linux/npu_device_discovery_test.cc b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc new file mode 100644 index 0000000000000..6e2170146f39f --- /dev/null +++ b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/linux/npu_device_discovery.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "test/util/include/asserts.h" + +namespace fs = std::filesystem; + +namespace onnxruntime::test { +namespace { + +void WriteFile(const fs::path& path, const std::string& value) { + std::ofstream f(path); + f << value; +} + +class NpuDeviceDiscoveryTest : public ::testing::Test { + protected: + void SetUp() override { + temp_dir_ = fs::temp_directory_path() / "ort_npu_discovery_test"; + fs::remove_all(temp_dir_); + fs::create_directories(temp_dir_); + } + + void TearDown() override { + fs::remove_all(temp_dir_); + } + + fs::path temp_dir_; +}; + +} // namespace + +TEST_F(NpuDeviceDiscoveryTest, ReturnsEmptyForNonexistentPath) { + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_ / "nonexistent", npu_paths)); + EXPECT_TRUE(npu_paths.empty()); +} + +TEST_F(NpuDeviceDiscoveryTest, DetectsAccelDevices) { + fs::create_directories(temp_dir_ / "accel0"); + fs::create_directories(temp_dir_ / "accel12"); + fs::create_directories(temp_dir_ / "renderD128"); + fs::create_directories(temp_dir_ / "accelabc"); + + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_, npu_paths)); + + ASSERT_EQ(npu_paths.size(), 2u); + + std::vector accel_indices; + accel_indices.reserve(npu_paths.size()); + for (const auto& npu_path : npu_paths) { + accel_indices.push_back(npu_path.accel_idx); + } + + std::sort(accel_indices.begin(), accel_indices.end()); + + EXPECT_EQ(accel_indices[0], 0u); + EXPECT_EQ(accel_indices[1], 12u); +} + +TEST_F(NpuDeviceDiscoveryTest, GetNpuDeviceFromSysfsReadsVendorDeviceAndMetadata) { + const auto pci_device_dir = temp_dir_ / "pci_devices" / "0000:65:00.0"; + const auto accel_dir = temp_dir_ / "class_accel" / "accel0"; + + fs::create_directories(pci_device_dir); + fs::create_directories(accel_dir); + + WriteFile(pci_device_dir / "vendor", "0x1022"); + WriteFile(pci_device_dir / "device", "0x1502"); + + std::error_code error_code{}; + fs::create_directory_symlink(pci_device_dir, accel_dir / "device", error_code); + ASSERT_FALSE(error_code) << error_code.message(); + + npu_device_discovery::NpuSysfsPathInfo path_info{}; + path_info.accel_idx = 0; + path_info.path = accel_dir; + + OrtHardwareDevice npu_device{}; + ASSERT_STATUS_OK(npu_device_discovery::GetNpuDeviceFromSysfs(path_info, npu_device)); + + EXPECT_EQ(npu_device.type, OrtHardwareDeviceType_NPU); + EXPECT_EQ(npu_device.vendor_id, 0x1022u); + EXPECT_EQ(npu_device.device_id, 0x1502u); + + const auto& entries = npu_device.metadata.Entries(); + EXPECT_NE(entries.find("accel_idx"), entries.end()); + EXPECT_EQ(entries.at("accel_idx"), "0"); + EXPECT_NE(entries.find("pci_bus_id"), entries.end()); + EXPECT_EQ(entries.at("pci_bus_id"), "0000:65:00.0"); +} + +} // namespace onnxruntime::test From 6584194ff9417879e4a6b133df37fd10ac0db1a0 Mon Sep 17 00:00:00 2001 From: Darshak Bhatti <47045043+dabhattimsft@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:09:21 -0700 Subject: [PATCH 6/8] Log EP version on inference failure and in EpDeviceUsage event, Log ORT version (#28794) ### Description Adds new telemetry event for inference failure which logs ep versions and types along with runtime error. Adds logging of ORT version in other telemetry events. Adds logging of ep versions in SessionCreation telemetry ### Motivation and Context To better diagnose failures in inference --------- Co-authored-by: Darshak Bhatti Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/core/platform/telemetry.cc | 13 +++++ onnxruntime/core/platform/telemetry.h | 6 ++ .../core/platform/windows/telemetry.cc | 58 ++++++++++++++++--- onnxruntime/core/platform/windows/telemetry.h | 6 ++ onnxruntime/core/session/inference_session.cc | 22 ++++++- onnxruntime/core/session/inference_session.h | 2 + 6 files changed, 95 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 2008f998384a3..20ec11bf7deb2 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -64,6 +64,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(ir_version); @@ -81,6 +82,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(execution_provider_ids); ORT_UNUSED_PARAMETER(hardware_device_types); ORT_UNUSED_PARAMETER(hardware_vendor_ids); + ORT_UNUSED_PARAMETER(ep_versions); ORT_UNUSED_PARAMETER(use_fp16); ORT_UNUSED_PARAMETER(captureState); } @@ -124,6 +126,15 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu ORT_UNUSED_PARAMETER(line); } +void Telemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const { + ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(status); + ORT_UNUSED_PARAMETER(ep_versions); + ORT_UNUSED_PARAMETER(ep_device_types); +} + void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { ORT_UNUSED_PARAMETER(session_id); @@ -139,6 +150,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -149,6 +161,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, ORT_UNUSED_PARAMETER(hardware_device_id); ORT_UNUSED_PARAMETER(hardware_vendor); ORT_UNUSED_PARAMETER(ep_vendor); + ORT_UNUSED_PARAMETER(ep_version); ORT_UNUSED_PARAMETER(assigned_node_count); ORT_UNUSED_PARAMETER(total_runs_since_last); ORT_UNUSED_PARAMETER(total_run_duration_since_last); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index c2dce68faa2dd..946e34ee35832 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,7 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const; virtual void LogCompileModelStart(uint32_t session_id, @@ -86,6 +87,10 @@ class Telemetry { virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; + virtual void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const; + virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const; @@ -100,6 +105,7 @@ class Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 04b9aaa0eb8ed..342b937ffb656 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -285,6 +285,7 @@ void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -318,6 +319,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio const std::string& loaded_from, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -373,7 +375,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -392,6 +395,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { @@ -404,7 +408,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -423,6 +428,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -457,7 +463,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(input_source.c_str(), "inputSource"), TraceLoggingString(output_target.c_str(), "outputTarget"), @@ -466,6 +472,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingBool(embed_ep_context, "embedEpContext"), TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"), TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -507,7 +514,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingHResult(hr, "hResult"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), @@ -516,6 +523,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #else TraceLoggingWrite(telemetry_provider_handle, @@ -525,7 +533,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), TraceLoggingUInt32(status.Category(), "errorCategory"), @@ -533,10 +541,35 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #endif } +void WindowsTelemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RuntimeInferenceError", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), + TraceLoggingString(ep_device_types.c_str(), "executionProviderDeviceTypes"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { if (global_register_count_ == 0 || enabled_ == false) @@ -559,11 +592,12 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -574,6 +608,7 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -588,7 +623,8 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + // schemaVersion 1: added epVersion, runtimeVersion + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(ep_type.c_str(), "executionProviderType"), TraceLoggingString(hardware_device_type.c_str(), "hardwareDeviceType"), @@ -596,9 +632,11 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingUInt32(hardware_device_id, "hardwareDeviceId"), TraceLoggingString(hardware_vendor.c_str(), "hardwareVendor"), TraceLoggingString(ep_vendor.c_str(), "epVendor"), + TraceLoggingString(ep_version.c_str(), "epVersion"), TraceLoggingInt32(assigned_node_count, "assignedNodeCount"), TraceLoggingUInt32(total_runs_since_last, "totalRunsSinceLast"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDurationSinceLast"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -724,8 +762,9 @@ void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -799,8 +838,9 @@ void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 8295d042d7ec9..46c262e1479d3 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,7 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const override; void LogCompileModelStart(uint32_t session_id, @@ -79,6 +80,10 @@ class WindowsTelemetry : public Telemetry { void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; + void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const override; + void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const override; @@ -89,6 +94,7 @@ class WindowsTelemetry : public Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 73f3e77b93177..9db18fc4ac69e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -78,6 +78,7 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/user_logging_sink.h" @@ -871,7 +872,7 @@ InferenceSession::~InferenceSession() { telemetry_provider.LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } } @@ -2778,6 +2779,7 @@ common::Status InferenceSession::Initialize() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model_->MetaData(), telemetry_.event_name_, execution_providers_.GetIds(), telemetry_.ep_device_types_summary_, telemetry_.ep_device_vendor_ids_summary_, + telemetry_.ep_versions_summary_, model_has_fp16_inputs, false); // Emit one initial EpDeviceUsage event per (EP, device) pair with run counts of 0. @@ -2787,7 +2789,7 @@ common::Status InferenceSession::Initialize() { env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, 0, 0); + ep_info.ep_version, ep_info.assigned_node_count, 0, 0); } LOGS(*session_logger_, INFO) << "Session successfully initialized."; @@ -3456,7 +3458,7 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } // reset counters @@ -3470,6 +3472,10 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, Telemetry::kRuntimePerfMaxInterval); } } + } else { + // Log runtime error with EP versions + env.GetTelemetryProvider().LogRuntimeInferenceError(session_id_, retval, telemetry_.ep_versions_summary_, + telemetry_.ep_device_types_summary_); } // log evaluation stop to trace logging provider @@ -4087,6 +4093,7 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { telemetry_.ep_device_info_.clear(); telemetry_.ep_device_types_summary_.clear(); telemetry_.ep_device_vendor_ids_summary_.clear(); + telemetry_.ep_versions_summary_.clear(); // First, count nodes assigned to each EP type after graph partitioning. // The graph node only carries the EP type string, so when a single EP targets @@ -4124,6 +4131,10 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { Telemetry::EpDeviceInfo entry; entry.ep_type = ep_type; entry.ep_vendor = ep_device->ep_vendor; + auto it = ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version); + if (it != ep_device->ep_metadata.Entries().end()) { + entry.ep_version = it->second; + } if (ep_device->device != nullptr) { entry.hardware_device_type = HardwareDeviceTypeToString(ep_device->device->type); entry.vendor_id = ep_device->device->vendor_id; @@ -4158,11 +4169,13 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { // by position against the existing executionProviderIds field. std::ostringstream types_oss; std::ostringstream vendor_ids_oss; + std::ostringstream versions_oss; bool first = true; for (const auto& entry : telemetry_.ep_device_info_) { if (!first) { types_oss << ','; vendor_ids_oss << ','; + versions_oss << ','; } first = false; types_oss << entry.hardware_device_type; @@ -4170,9 +4183,11 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { vendor_ids_oss << "0x" << std::hex << std::uppercase << std::setw(4) << std::setfill('0') << entry.vendor_id << std::dec << std::nouppercase << std::setfill(' '); + versions_oss << entry.ep_type << ':' << entry.ep_version; } telemetry_.ep_device_types_summary_ = types_oss.str(); telemetry_.ep_device_vendor_ids_summary_ = vendor_ids_oss.str(); + telemetry_.ep_versions_summary_ = versions_oss.str(); } #if !defined(ORT_MINIMAL_BUILD) @@ -4419,6 +4434,7 @@ void InferenceSession::LogAllSessions() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model->MetaData(), session->telemetry_.event_name_, session->execution_providers_.GetIds(), session->telemetry_.ep_device_types_summary_, session->telemetry_.ep_device_vendor_ids_summary_, + session->telemetry_.ep_versions_summary_, model_has_fp16_inputs, true); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 11c887f6cdc16..19c18652ab67a 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -1023,12 +1023,14 @@ class InferenceSession { uint32_t device_id = 0; // PCI device ID (0 when unavailable) std::string vendor; // e.g. "Qualcomm" std::string ep_vendor; // e.g. "Qualcomm" (from OrtEpDevice) + std::string ep_version; // e.g. "1.2.3" (from OrtEpFactory::GetVersion, empty when unavailable) int assigned_node_count = 0; // # graph nodes assigned to this EP type }; std::vector ep_device_info_; // Pre-formatted comma-separated summaries used to enrich SessionCreation. std::string ep_device_types_summary_; // "NPU,CPU" std::string ep_device_vendor_ids_summary_; // "0x5143,0x0000" + std::string ep_versions_summary_; // "QNNExecutionProvider:1.2.3,CPUExecutionProvider:" } telemetry_; mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data From 4b5abcf51b605049ebbed07672018fd6e8d82d47 Mon Sep 17 00:00:00 2001 From: Gopalakrishnan Nallasamy Date: Thu, 11 Jun 2026 11:10:27 -0700 Subject: [PATCH 7/8] Fix STFT complex input frame offsets (#28961) ### Description Fix STFT frame pointer arithmetic for complex-valued input so frame starts are computed in input samples, not trailing real/imag components. Since the frame view pointer is `U*`, one pointer increment advances one full real or complex sample. Also add validation that `frame_step` is positive and keep a defensive bounds check before creating non-owning tensor views. Review feedback addressed: simplified the frame pointer arithmetic, fixed the swapped STFT input comments, documented the defensive bounds check, and added double-complex regression coverage. The new STFT validation/regression tests exclude `kDmlExecutionProvider` because these CPU STFT validation/regression paths do not consistently match DirectML behavior in Windows GPU CI. ### Motivation and Context For complex input shaped `[batch_size, signal_length, 2]`, pointer increments already advance by one real/imag pair. Multiplying frame offsets by `signal_components == 2` again can advance past the valid frame start, allowing later frames to read across batches or beyond the input allocation. ### Testing - `git diff --check -- onnxruntime/core/providers/cpu/signal/dft.cc onnxruntime/test/providers/cpu/signal/signal_ops_test.cc` - `.\.venv\Scripts\python.exe tools\ci_build\build.py --config RelWithDebInfo --build --parallel --target onnxruntime_provider_test --build_dir build\Windows` - `.\onnxruntime_provider_test.exe --gtest_filter="SignalOpsTest.STFTFloat:SignalOpsTest.STFTFrameStepMustBePositive:SignalOpsTest.STFTFloatComplexInputBatched:SignalOpsTest.STFTDoubleComplexInputBatched"` from `build\Windows\RelWithDebInfo\RelWithDebInfo` --------- Co-authored-by: Gopalakrishnan Nallasamy --- onnxruntime/core/providers/cpu/signal/dft.cc | 16 ++-- .../providers/cpu/signal/signal_ops_test.cc | 74 +++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index f462680e02587..f75c5fe12a6f4 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -524,14 +524,15 @@ template static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool /*inverse*/) { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal const auto* signal = ctx->Input(0); const auto frame_step = signal::get_scalar_value_from_tensor(ctx->Input(1)); + ORT_RETURN_IF_NOT(frame_step > 0, "frame_step must be greater than zero."); const auto* window = ctx->Input(2); const auto* frame_length_tensor = ctx->Input(3); @@ -596,8 +597,11 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Run each dft of each batch as if it was a real-valued batch size 1 dft operation for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int64_t i = 0; i < n_dfts; i++) { - auto input_frame_begin = - signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components); + const auto frame_start = i * frame_step; + // Defensive check before creating a non-owning tensor view. n_dfts derivation should keep this in bounds. + ORT_RETURN_IF_NOT(frame_start <= signal_size - window_size, "STFT input frame is out of bounds."); + // signal_data is U*, so one increment advances one input sample, including both lanes for complex input. + auto input_frame_begin = signal_data + (batch_idx * signal_size) + frame_start; auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + (i * dft_output_size * output_components); @@ -619,9 +623,9 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside Status STFT::Compute(OpKernelContext* ctx) const { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal shape diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index 966090eed861e..c87e5d7c4c6e1 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -233,6 +233,80 @@ TEST(SignalOpsTest, STFTFloat) { test.Run(); } +static void TestSTFTInvalidFrameStep(int64_t frame_step) { + OpTester test("STFT", kMinOpsetVersion); + + vector signal(64, 1); + test.AddInput("signal", {1, 64, 1}, signal); + test.AddInput("frame_step", {}, {frame_step}); + vector window(16, 1); + test.AddInput("window", {16}, window); + test.AddInput("frame_length", {}, {16}); + + vector output_shape = {1, 7, 9, 2}; + vector expected_output(1 * 7 * 9 * 2, 0.f); + test.AddOutput("output", output_shape, expected_output); + test.Config(OpTester::ExpectResult::kExpectFailure, "frame_step must be greater than zero"); + test.ConfigExcludeEps({kDmlExecutionProvider}); + test.RunWithConfig(); +} + +TEST(SignalOpsTest, STFTFrameStepMustBePositive) { + TestSTFTInvalidFrameStep(0); + TestSTFTInvalidFrameStep(-1); +} + +template +static void TestSTFTComplexInputBatched() { + OpTester test("STFT", kMinOpsetVersion); + test.AddAttribute("onesided", static_cast(false)); + + constexpr int64_t batch_size = 2; + constexpr int64_t signal_size = 128; + constexpr int64_t signal_components = 2; + constexpr int64_t frame_length = 32; + constexpr int64_t frame_step = 16; + constexpr int64_t n_dfts = 7; + constexpr int64_t dft_output_size = frame_length; + constexpr int64_t output_components = 2; + + vector signal(batch_size * signal_size * signal_components, static_cast(0)); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const T signal_value = batch_idx == 0 ? static_cast(1) : static_cast(99); + for (int64_t sample_idx = 0; sample_idx < signal_size; ++sample_idx) { + signal[(batch_idx * signal_size + sample_idx) * signal_components] = signal_value; + } + } + + test.AddInput("signal", {batch_size, signal_size, signal_components}, signal); + test.AddInput("frame_step", {}, {frame_step}); + test.AddOptionalInputEdge(); + test.AddInput("frame_length", {}, {frame_length}); + + vector output_shape = {batch_size, n_dfts, dft_output_size, output_components}; + vector expected_output(batch_size * n_dfts * dft_output_size * output_components, static_cast(0)); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const T dc_value = (batch_idx == 0 ? static_cast(1) : static_cast(99)) * static_cast(frame_length); + for (int64_t frame_idx = 0; frame_idx < n_dfts; ++frame_idx) { + expected_output[((batch_idx * n_dfts + frame_idx) * dft_output_size) * output_components] = dc_value; + } + } + + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.001f); + // DML does not consistently match these CPU STFT validation/regression paths in Windows GPU CI. + test.ConfigExcludeEps({kDmlExecutionProvider}); + test.RunWithConfig(); +} + +TEST(SignalOpsTest, STFTFloatComplexInputBatched) { + TestSTFTComplexInputBatched(); +} + +TEST(SignalOpsTest, STFTDoubleComplexInputBatched) { + TestSTFTComplexInputBatched(); +} + TEST(SignalOpsTest, HannWindowFloat) { OpTester test("HannWindow", kMinOpsetVersion); From cf509d84a906a593e89362e9ff7e2ece1a1e372b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:40:38 -0700 Subject: [PATCH 8/8] QMoE: fail loudly when weights_prepacked=0 but PrePack did not run (#28965) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description When a QMoE model sets `weights_prepacked=0` (raw `[E, N, K/pack]` int weights) and the session has `session.disable_prepacking`, `PrePack()` never runs, so `packed_fc{1,2}_weights_` stay null and `int_weights_consumed_by_prepack` is false. The code then falls through to the raw initializer pointers — but those bytes are not in CUTLASS layout, so the runner consumes them as-if-prepacked and produces silently wrong output with no diagnostic. Changes in `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc` (`QMoE::ComputeInternal`): - **Int path**: Added a defensive `INVALID_ARGUMENT` guard — when `is_int && !weights_prepacked_` but either prepack buffer is null, return a clear error instead of feeding non-CUTLASS bytes to the runner. - **wfp4afp8 native path**: Same fall-through (`packed_fp4_fc{1,2}_weights_ ? ... : raw`) replaced with an explicit guard that errors when the repacked FP4 buffers were not produced. Also added a focused regression test in `onnxruntime/test/contrib_ops/moe_test.cc` covering `quant_type='int'` with `weights_prepacked=0` and `session.disable_prepacking=1`, asserting that QMoE fails with an actionable error instead of producing output. Merged the branch with the latest `main`. ### Motivation and Context A prior fix removed the null-pointer crash on this path but left a misleading-success outcome that is newly user-reachable via the `weights_prepacked=0` contract — the exact silent-failure mode the offline-path work set out to eliminate. These guards convert that into a loud, actionable error. The wfp4afp8 branch shares the same fall-through and is hardened for consistency. The added regression test ensures this fail-loudly behavior remains covered going forward, especially when prepacking is disabled at the session level. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../contrib_ops/cuda/moe/moe_quantization.cc | 27 +++++++- onnxruntime/test/contrib_ops/moe_test.cc | 69 +++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index 7d1291e004d78..bc34a2e83318a 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -234,6 +234,18 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { // to the runner. const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr && packed_fc2_weights_ != nullptr; + // When ``weights_prepacked == 0`` the raw ``[E, N, K/pack]`` int weights must be + // converted to the CUTLASS fpA_intB layout by PrePack before the runner can consume + // them. If PrePack never ran (e.g. ``session.disable_prepacking`` is set), the prepack + // buffers stay null and falling through to the raw initializer pointers would feed + // non-CUTLASS bytes to the runner, producing silently wrong output. Fail loudly instead. + if (is_int && !weights_prepacked_ && + (packed_fc1_weights_ == nullptr || packed_fc2_weights_ == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE weights_prepacked=0 requires PrePack to run, but the int weight " + "buffers were not produced (is session.disable_prepacking set?). Provide " + "CUTLASS-prepacked weights with weights_prepacked=1, or enable prepacking."); + } const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input(2); const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input(3) : nullptr; const Tensor* fc1_experts_bias_optional = context->Input(4); @@ -854,8 +866,19 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr; const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr; if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) { - fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data; - fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data; + // The native CUTLASS WFP4AFP8 path consumes weights in the repacked FP4 + // layout produced by PrePack. If PrePack never ran (e.g. + // ``session.disable_prepacking`` is set) the repacked buffers stay null and + // falling through to the raw initializer bytes would feed a non-CUTLASS + // layout to the runner, producing silently wrong output. Fail loudly. + if (packed_fp4_fc1_weights_ == nullptr || packed_fp4_fc2_weights_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE wfp4afp8 requires PrePack to run, but the repacked FP4 weight " + "buffers were not produced (is session.disable_prepacking set?). " + "Enable prepacking to use the native WFP4AFP8 path."); + } + fc1_weight_data = packed_fp4_fc1_weights_.get(); + fc2_weight_data = packed_fp4_fc2_weights_.get(); } else if (int_weights_consumed_by_prepack) { // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB // layout that the runner consumes and freed the source initializer diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index cc50494273aad..38cc3014aa873 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" @@ -778,6 +779,74 @@ TEST(MoETest, MoETest_Mixtral) { 2 /*top_k*/); } +TEST(MoETest, QMoETest_CUDA_Int4_DisablePrepackingFailsLoudly) { + constexpr int min_cuda_arch = 700; + if (!HasCudaEnvironment(min_cuda_arch)) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + constexpr int64_t num_rows = 1; + constexpr int64_t num_experts = 1; + constexpr int64_t hidden_size = 128; + constexpr int64_t inter_size = 128; + constexpr int64_t expert_weight_bits = 4; + constexpr int64_t pack_size = 8 / expert_weight_bits; + + const std::vector input(num_rows * hidden_size, 0.0f); + const std::vector router_probs(num_rows * num_experts, 1.0f); + const std::vector fc1_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0); + const std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0); + const std::vector fc1_scales(num_experts * inter_size, 1.0f); + const std::vector fc2_scales(num_experts * hidden_size, 1.0f); + const std::vector dummy_output(num_rows * hidden_size, 0.0f); + + OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain); + cuda_tester.AddAttribute("k", 1); + cuda_tester.AddAttribute("activation_type", "identity"); + cuda_tester.AddAttribute("normalize_routing_weights", 1); + cuda_tester.AddAttribute("expert_weight_bits", expert_weight_bits); + cuda_tester.AddAttribute("quant_type", "int"); + cuda_tester.AddAttribute("weights_prepacked", 0); + + const std::vector input_dims = {num_rows, hidden_size}; + const std::vector router_probs_dims = {num_rows, num_experts}; + const std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + const std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + const std::vector fc1_scales_dims = {num_experts, inter_size}; + const std::vector fc2_scales_dims = {num_experts, hidden_size}; + const std::vector output_dims = {num_rows, hidden_size}; + + cuda_tester.AddInput("input", input_dims, ToFloat16(input)); + cuda_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cuda_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cuda_tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cuda_tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOutput("output", output_dims, ToFloat16(dummy_output)); + + SessionOptions session_options; + session_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "1"; + + std::vector> cuda_execution_providers; + cuda_execution_providers.push_back(std::move(cuda_ep)); + cuda_tester.Run(session_options, + OpTester::ExpectResult::kExpectFailure, + "QMoE weights_prepacked=0 requires PrePack to run", + {}, + nullptr, + &cuda_execution_providers); +} + TEST(MoETest, QMoETest_Mixtral_Int4) { // This test uses FC3 (gated SiLU / Mixtral pattern) with dimensions too small for the // CUTLASS kernel (needs hidden_size >= 128, inter_size >= 128). CPU QMoE does not