From 2c4be91361bb7377d7cb4db5599776685813ceae Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Wed, 10 Jun 2026 17:13:55 -0700 Subject: [PATCH 01/15] Move OrtModelPackageApi to the experimental C API (#28990) ### Description Move the existing model package C API off the stable `OrtApi` onto the experimental name-based lookup mechanism added in #28746. Each model package function is registered individually in `include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc` with the `OrtModelPackageApi_` prefix and the `_SinceV28` version suffix, following the lifecycle rules in `docs/design/Experimental_C_API.md`. Headline changes: - `OrtApi::GetModelPackageApi`, the `OrtModelPackageApi` struct, `OrtApis::GetModelPackageApi`, the `OrtModelPackageAPI` namespace, `onnxruntime/core/session/model_package_api.h`, and the C++ wrappers (`Ort::GetModelPackageApi`, `ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackage*)`, `ModelPackageOptions/Context/ComponentContext`) are removed. - Opaque handle types (`OrtModelPackageOptions`, `OrtModelPackageContext`, `OrtModelPackageComponentContext`) move into `onnxruntime_experimental_c_api.h`. - All 15 model package functions are registered in `onnxruntime_experimental_c_api.inc`. Impls move into `namespace OrtExperimentalApis` with `_SinceV28`-suffixed names in `model_package_api.cc`; bodies are unchanged. - `experimental_c_api.cc` gains a forward-decl block (driven by the same `.inc` X-macro) so the auto-generated registration table can take the address of every entry, even those defined in `model_package_api.cc`. - The Python bindings (`PyModelPackageContext` / `PyModelPackageOptions` / `PyModelPackageComponentContext` and their `onnxruntime.__init__` exports) are removed. Per the design doc we start the experimental API in C/C++ only. - `onnxruntime/test/autoep/test_model_package.cc` switches to a local `ModelPackageFns` struct populated through the `Ort::Experimental::Get_OrtModelPackageApi_*_Fn(api)` typed accessors. Consumer usage going forward, in C++: ```cpp #include "onnxruntime_c_api.h" #include "onnxruntime_experimental_c_api.h" const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); if (auto* fn = Ort::Experimental::Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn(ort)) { OrtModelPackageContext* ctx = nullptr; Ort::ThrowOnError(fn(ORT_TSTR("/path/to/pkg"), &ctx)); // ... } ``` ### Motivation and Context The model package API was added to the stable `OrtApi` in 1.27 but has not shipped in a release yet. Now that #28746 has landed the experimental C API framework, the right home for an iterating preview surface like model package is behind `OrtApi::GetExperimentalFunction`, not on the stable struct. Moving it to experimental: - frees us to change signatures (each name is uniquely versioned) without breaking the stable ABI; - gives consumers a clear "is this specific thing available?" contract instead of a struct that *looks* stable but isn't; - lets the surface be promoted to stable cleanly later (move entries to `OrtApi`, drop the `_SinceV` suffix, remove the experimental entries). --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 270 ------------------ .../core/session/onnxruntime_cxx_api.h | 84 ------ .../core/session/onnxruntime_cxx_inline.h | 104 ------- .../session/onnxruntime_experimental_c_api.h | 4 + .../onnxruntime_experimental_c_api.inc | 247 ++++++++++++++++ onnxruntime/__init__.py | 3 - .../core/session/experimental_c_api.cc | 8 + onnxruntime/core/session/model_package_api.cc | 72 ++--- onnxruntime/core/session/model_package_api.h | 76 ----- onnxruntime/core/session/onnxruntime_c_api.cc | 6 - onnxruntime/core/session/ort_apis.h | 3 - .../python/onnxruntime_pybind_state.cc | 189 ------------ onnxruntime/test/autoep/test_model_package.cc | 264 +++++++++-------- 13 files changed, 422 insertions(+), 908 deletions(-) delete mode 100644 onnxruntime/core/session/model_package_api.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0cdbfd7114cf0..0289bca5c84d2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -348,9 +348,6 @@ ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared exte ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails); ORT_RUNTIME_CLASS(EpAssignedSubgraph); ORT_RUNTIME_CLASS(EpAssignedNode); -ORT_RUNTIME_CLASS(ModelPackageOptions); -ORT_RUNTIME_CLASS(ModelPackageContext); -ORT_RUNTIME_CLASS(ModelPackageComponentContext); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -938,9 +935,6 @@ typedef struct OrtCompileApi OrtCompileApi; struct OrtInteropApi; typedef struct OrtInteropApi OrtInteropApi; -struct OrtModelPackageApi; -typedef struct OrtModelPackageApi OrtModelPackageApi; - struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -7501,26 +7495,6 @@ struct OrtApi { */ ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); - /** \brief Get the model package API table. - * - * Returns a pointer to the ::OrtModelPackageApi function table, which provides APIs to: - * - create and release model package options and contexts, - * - inspect model package metadata (components/variants), - * - select a component/variant and query selected files/options, - * - create a session from model package selection results. - * - * The returned pointer is owned by ONNX Runtime and is valid for the process lifetime. - * Do not free it. - * - * \note May return NULL if model package support is not available in the current build - * (for example, minimal builds). - * - * \return Pointer to ::OrtModelPackageApi, or NULL if unsupported. - * - * \since Version 1.27. - */ - const OrtModelPackageApi*(ORT_API_CALL* GetModelPackageApi)(void); - /** \brief Retrieve an experimental function pointer by name. * * Experimental functions are not part of the stable ABI and may be added or removed between releases without notice. @@ -8645,250 +8619,6 @@ struct OrtInteropApi { /// @} }; -/** \brief API table for model package workflows. - * - * A model package is a directory containing one or more *components* (logical models). - * Each component has one or more *variants*, where each variant targets a single - * execution provider (EP). The package manifest declares the EP name, device type, - * and an optional compatibility string for every variant so that the runtime can - * automatically select the best variant for the hardware and EPs available in the - * caller's session options. - * - * Obtain this table from OrtApi::GetModelPackageApi(). The APIs support: - * - creating model package options that capture EP configuration from OrtSessionOptions, - * - loading a package context (manifest + metadata) from a package root path, - * - querying component/variant metadata including per-variant EP information, - * - selecting a component (which also resolves the best-matching variant), - * - querying the selected variant's name and folder path, - * - creating an OrtSession from the selected component context. - * - * Typical flow: - * 1) Create model package options: - * - CreateModelPackageOptionsFromSessionOptions() - * 2) Load package metadata: - * - CreateModelPackageContext() - * 3) Query metadata (optional): - * - ModelPackage_GetSchemaVersion() - * - ModelPackage_GetComponentCount() - * - ModelPackage_GetComponentNames() - * - ModelPackage_GetVariantCount() - * - ModelPackage_GetVariantNames() - * - ModelPackage_GetVariantEpName() - * 4) Select a component and resolve variant: - * - SelectComponent() - * 5) Query selected variant info (optional): - * - ModelPackageComponent_GetSelectedVariantName() - * - ModelPackageComponent_GetSelectedVariantFolderPath() - * 6) Create session: - * - CreateSession() - * - * Ownership: - * - Release objects created by this API with the corresponding release methods: - * ReleaseModelPackageOptions(), ReleaseModelPackageContext(), - * ReleaseModelPackageComponentContext(). - * - * \since Version 1.27. - */ -struct OrtModelPackageApi { - /// \name OrtModelPackageOptions - /// @{ - - /** \brief Create model package options from an existing OrtSessionOptions. - * - * Captures EP configuration (registered execution providers and their devices) from - * the session options for use during variant selection. The resulting OrtModelPackageOptions - * is passed to SelectComponent() to resolve the best variant for the available EPs. - * - * \param[in] env The ORT environment. - * \param[in] session_options Session options containing registered EPs. - * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released - * with ReleaseModelPackageOptions(). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - - ORT_CLASS_RELEASE(ModelPackageOptions); - /// @} - /// \name OrtModelPackageContext - /// @{ - - /** \brief Create a model package context by parsing the package at the given root path. - * - * Parses the manifest.json and component metadata from the specified directory. - * The returned context provides read-only access to the package structure (components, - * variants, EP declarations). - * - * \param[in] package_root Path to the model package root directory (containing manifest.json). - * \param[out] out Receives the newly created OrtModelPackageContext. Must be released - * with ReleaseModelPackageContext(). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - - ORT_CLASS_RELEASE(ModelPackageContext); - - /** \brief Get the schema version declared in the model package manifest. - * - * \param[in] ctx The model package context. - * \param[out] out_version Receives the schema version number. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - - /** \brief Get the number of components in the model package. - * - * \param[in] ctx The model package context. - * \param[out] out_count Receives the component count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - - /** \brief Get the names of all components in the model package. - * - * Returns a pointer to an array of UTF-8 component name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. - * \param[out] out_names Receives a pointer to an array of component name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - - /** \brief Get the number of variants for a given component. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_count Receives the variant count. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - - /** \brief Get the names of all variants for a given component. - * - * Returns a pointer to an array of UTF-8 variant name strings. The array and its - * strings are owned by `ctx` and remain valid until the context is released. - * - * \param[in] ctx The model package context. - * \param[in] component_name Name of the component to query. - * \param[out] out_variant_names Receives a pointer to an array of variant name strings. - * \param[out] out_count Receives the number of elements in the array. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - - /** \brief Get the EP name declared for a (component, variant) pair. - * - * Each variant targets a single EP. `out_ep` receives the EP name string. - * When the variant does not declare an EP, the returned pointer is NULL. - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - - /** \brief Select a component model and return an opaque component instance. - * - * The variant selection is also performed during this call based on the component metadata and the provided options. - * The returned `OrtModelPackgeComponentContext*` is independent of `context` lifetime and must be released via - * `ReleaseComponentInstance`. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - - ORT_CLASS_RELEASE(ModelPackageComponentContext); - - /** \brief Get the name of the selected variant after SelectComponent has been called. - * - * String memory is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] out_name Receives the selected variant's name string. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - - /** \brief Get the folder path of the selected variant. - * - * Returns the resolved absolute path to the variant's directory on disk. - * The string is owned by `ctx` and remains valid until the context is released. - * - * \param[in] ctx The component context returned by SelectComponent(). - * \param[out] folder_path Receives the variant folder path string. - * - * \since Version 1.27. - */ - ORT_API2_STATUS(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - - /// @} - /** \brief Create an OrtSession for a selected file within a component model variant. - * - * The chosen variant (and thus its EP selection) is determined by `context`, which - * was built from an OrtSessionOptions via CreateModelPackageOptionsFromSessionOptions. - * - * Session options precedence: - * 1. session_options == NULL (default path): - * ORT uses the OrtSessionOptions that was captured when `context` was created. - * Any variant-specific session and provider options declared in the package - * metadata are merged on top. - * - * 2. session_options != NULL (advanced path): - * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific - * session and provider options from the package metadata are NOT applied. - * Use this when custom EP setup is required (e.g., shared CUDA streams, - * shared QNN EP contexts, custom allocators). - * - * \since Version 1.27. - */ - ORT_API2_STATUS(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* context, - _In_opt_ const OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a999ef5e0faf4..4798d3d4ad1b8 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -264,20 +264,6 @@ inline const OrtEpApi& GetEpApi() { return *api; } -/// -/// This returns a reference to the ORT C Model Package API. Used for loading models from model packages. -/// -/// ORT C Model Package API reference -inline const OrtModelPackageApi& GetModelPackageApi() { - auto* api = GetApi().GetModelPackageApi(); - if (api == nullptr) { - // minimal build - ORT_CXX_API_THROW("Model Package API is not available in this build", ORT_FAIL); - } - - return *api; -} - /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -678,9 +664,6 @@ ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelDefBuilder, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(KernelRegistry, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(OpSchema, GetEpApi); ORT_DEFINE_RELEASE_FROM_API_STRUCT(ProfilingEvent, GetEpApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageOptions, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageContext, GetModelPackageApi); -ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelPackageComponentContext, GetModelPackageApi); // This is defined explicitly since OrtTensorRTProviderOptionsV2 is not a C API type, // but the struct has V2 in its name to indicate that it is the second version of the options. @@ -807,9 +790,6 @@ struct EpDevice; struct ExternalInitializerInfo; struct Graph; struct Model; -struct ModelPackageOptions; -struct ModelPackageContext; -struct ModelPackageComponentContext; struct Node; struct ModelMetadata; struct TypeInfo; @@ -1806,70 +1786,6 @@ struct ModelCompilationOptions : detail::Base { */ Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); -/** \brief Options for selecting a component from a model package. - * - * Wraps ::OrtModelPackageOptions. Created from an Env and SessionOptions, which captures the - * EP configuration used for variant selection. - */ -struct ModelPackageOptions : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageOptions(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - ModelPackageOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions - ModelPackageOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateModelPackageOptionsFromSessionOptions -}; - -/** \brief Context for inspecting and selecting components from a model package. - * - * Wraps ::OrtModelPackageContext. Provides traversal APIs to enumerate components, variants, - * and EP compatibility, as well as component selection. - */ -struct ModelPackageContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - explicit ModelPackageContext(const ORTCHAR_T* package_root); ///< Wraps OrtModelPackageApi::CreateModelPackageContext - - size_t GetComponentCount() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentCount - std::vector GetComponentNames() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetComponentNames - size_t GetVariantCount(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantCount - std::vector GetVariantNames(const char* component_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantNames - - /// Get the EP name for a variant. Returns nullptr if not declared. - /// Returned string is owned by this context and valid until it is released. - const char* GetVariantEpName(const char* component_name, - const char* variant_name) const; ///< Wraps OrtModelPackageApi::ModelPackage_GetVariantEpName - - int64_t GetSchemaVersion() const; ///< Wraps OrtModelPackageApi::ModelPackage_GetSchemaVersion - - ModelPackageComponentContext SelectComponent(const char* component_name, - const ModelPackageOptions& options) const; ///< Wraps OrtModelPackageApi::SelectComponent -}; - -/** \brief Context for a selected component within a model package. - * - * Wraps ::OrtModelPackageComponentContext. Provides accessors for the selected variant's - * folder path and variant name. - */ -struct ModelPackageComponentContext : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelPackageComponentContext(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used. - - std::basic_string GetSelectedVariantFolderPath() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantFolderPath - - std::string GetSelectedVariantName() const; ///< Wraps OrtModelPackageApi::ModelPackageComponent_GetSelectedVariantName - - Session CreateSession(const Env& env); ///< Wraps OrtModelPackageApi::CreateSession (default path, NULL session_options) - Session CreateSession(const Env& env, const SessionOptions& session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) - Session CreateSession(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtModelPackageApi::CreateSession (advanced path) -}; - /** \brief Wrapper around ::OrtModelMetadata * */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 51f99655121c6..d7439e7b356c6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1360,110 +1360,6 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel(const Ort return *this; } -// ModelPackageOptions -inline ModelPackageOptions::ModelPackageOptions(const Env& env, const SessionOptions& session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -inline ModelPackageOptions::ModelPackageOptions(const Env& env, ConstSessionOptions session_options) { - ThrowOnError(GetModelPackageApi().CreateModelPackageOptionsFromSessionOptions(env, session_options, &this->p_)); -} - -// ModelPackageContext -inline ModelPackageContext::ModelPackageContext(const ORTCHAR_T* package_root) { - ThrowOnError(GetModelPackageApi().CreateModelPackageContext(package_root, &this->p_)); -} - -inline size_t ModelPackageContext::GetComponentCount() const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentCount(this->p_, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetComponentNames() const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetComponentNames(this->p_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline size_t ModelPackageContext::GetVariantCount(const char* component_name) const { - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantCount(this->p_, component_name, &count)); - return count; -} - -inline std::vector ModelPackageContext::GetVariantNames(const char* component_name) const { - const char* const* names = nullptr; - size_t count = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantNames(this->p_, component_name, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; -} - -inline const char* ModelPackageContext::GetVariantEpName(const char* component_name, - const char* variant_name) const { - const char* ep = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackage_GetVariantEpName( - this->p_, component_name, variant_name, &ep)); - return ep; -} - -inline int64_t ModelPackageContext::GetSchemaVersion() const { - int64_t version = 0; - ThrowOnError(GetModelPackageApi().ModelPackage_GetSchemaVersion(this->p_, &version)); - return version; -} - -inline ModelPackageComponentContext ModelPackageContext::SelectComponent( - const char* component_name, const ModelPackageOptions& options) const { - OrtModelPackageComponentContext* out = nullptr; - ThrowOnError(GetModelPackageApi().SelectComponent(this->p_, component_name, options, &out)); - return ModelPackageComponentContext{out}; -} - -// ModelPackageComponentContext -inline std::basic_string ModelPackageComponentContext::GetSelectedVariantFolderPath() const { - const ORTCHAR_T* path = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantFolderPath(this->p_, &path)); - return std::basic_string{path}; -} - -inline std::string ModelPackageComponentContext::GetSelectedVariantName() const { - const char* name = nullptr; - ThrowOnError(GetModelPackageApi().ModelPackageComponent_GetSelectedVariantName(this->p_, &name)); - return (name != nullptr) ? std::string{name} : std::string{}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, nullptr, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - const SessionOptions& session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - -inline Session ModelPackageComponentContext::CreateSession(const Env& env, - ConstSessionOptions session_options) { - OrtSession* out = nullptr; - ThrowOnError(GetModelPackageApi().CreateSession(env, this->p_, session_options, &out)); - return Session{out}; -} - namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h index e943a5cf65b11..0dd87c10776d3 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.h @@ -39,6 +39,10 @@ // ORT_RUNTIME_CLASS(ExperimentalType); // +ORT_RUNTIME_CLASS(ModelPackageOptions); +ORT_RUNTIME_CLASS(ModelPackageContext); +ORT_RUNTIME_CLASS(ModelPackageComponentContext); + // // C: function pointer typedefs and name constants // diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 0123b02818584..57a4e472b6f6d 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -35,3 +35,250 @@ * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t* out) + +// --------------------------------------------------------------------------- +// OrtModelPackageApi +// +// A model package is a directory containing one or more *components* (logical models). +// Each component has one or more *variants*, where each variant targets a single +// execution provider (EP). The package manifest declares the EP name, device type, +// and an optional compatibility string for every variant so that the runtime can +// automatically select the best variant for the hardware and EPs available in the +// caller's session options. +// +// The functions below support: +// - creating model package options that capture EP configuration from OrtSessionOptions, +// - loading a package context (manifest + metadata) from a package root path, +// - querying component/variant metadata including per-variant EP information, +// - selecting a component (which also resolves the best-matching variant), +// - querying the selected variant's name and folder path, +// - creating an OrtSession from the selected component context. +// +// Typical flow: +// 1) Create model package options: +// - OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions +// 2) Load package metadata: +// - OrtModelPackageApi_CreateModelPackageContext +// 3) Query metadata (optional): +// - OrtModelPackageApi_ModelPackage_GetSchemaVersion +// - OrtModelPackageApi_ModelPackage_GetComponentCount +// - OrtModelPackageApi_ModelPackage_GetComponentNames +// - OrtModelPackageApi_ModelPackage_GetVariantCount +// - OrtModelPackageApi_ModelPackage_GetVariantNames +// - OrtModelPackageApi_ModelPackage_GetVariantEpName +// 4) Select a component and resolve variant: +// - OrtModelPackageApi_SelectComponent +// 5) Query selected variant info (optional): +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName +// - OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath +// 6) Create session: +// - OrtModelPackageApi_CreateSession +// +// Ownership: +// - Release objects created by this API with the corresponding release functions: +// OrtModelPackageApi_ReleaseModelPackageOptions, +// OrtModelPackageApi_ReleaseModelPackageContext, +// OrtModelPackageApi_ReleaseModelPackageComponentContext. +// +// Opaque handles (OrtModelPackageOptions, OrtModelPackageContext, OrtModelPackageComponentContext) +// are declared in onnxruntime_experimental_c_api.h. +// --------------------------------------------------------------------------- + +/** \brief Create model package options from an existing OrtSessionOptions. + * + * Captures EP configuration (registered execution providers and their devices) from + * the session options for use during variant selection. The resulting OrtModelPackageOptions + * is passed to OrtModelPackageApi_SelectComponent to resolve the best variant for the + * available EPs. + * + * \param[in] env The ORT environment. + * \param[in] session_options Session options containing registered EPs. + * \param[out] out Receives the newly created OrtModelPackageOptions. Must be released + * with OrtModelPackageApi_ReleaseModelPackageOptions. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions, + _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* session_options, + _Outptr_ OrtModelPackageOptions** out) + +/** \brief Release an OrtModelPackageOptions created by + * OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageOptions, + _Frees_ptr_opt_ OrtModelPackageOptions* options) + +/** \brief Create a model package context by parsing the package at the given root path. + * + * Parses the manifest.json and component metadata from the specified directory. + * The returned context provides read-only access to the package structure (components, + * variants, EP declarations). + * + * \param[in] package_root Path to the model package root directory (containing manifest.json). + * \param[out] out Receives the newly created OrtModelPackageContext. Must be released + * with OrtModelPackageApi_ReleaseModelPackageContext. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateModelPackageContext, + _In_ const ORTCHAR_T* package_root, + _Outptr_ OrtModelPackageContext** out) + +/** \brief Release an OrtModelPackageContext created by OrtModelPackageApi_CreateModelPackageContext. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageContext, + _Frees_ptr_opt_ OrtModelPackageContext* ctx) + +/** \brief Get the schema version declared in the model package manifest. + * + * \param[in] ctx The model package context. + * \param[out] out_version Receives the schema version number. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetSchemaVersion, + _In_ const OrtModelPackageContext* ctx, + _Out_ int64_t* out_version) + +/** \brief Get the number of components in the model package. + * + * \param[in] ctx The model package context. + * \param[out] out_count Receives the component count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentCount, + _In_ const OrtModelPackageContext* ctx, + _Out_ size_t* out_count) + +/** \brief Get the names of all components in the model package. + * + * Returns a pointer to an array of UTF-8 component name strings. The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[out] out_names Receives a pointer to an array of component name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetComponentNames, + _In_ const OrtModelPackageContext* ctx, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, + _Out_ size_t* out_count) + +/** \brief Get the number of variants for a given component. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_count Receives the variant count. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantCount, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Out_ size_t* out_count) + +/** \brief Get the names of all variants for a given component. + * + * Returns a pointer to an array of UTF-8 variant name strings. The array and its + * strings are owned by `ctx` and remain valid until the context is released. + * + * \param[in] ctx The model package context. + * \param[in] component_name Name of the component to query. + * \param[out] out_variant_names Receives a pointer to an array of variant name strings. + * \param[out] out_count Receives the number of elements in the array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantNames, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, + _Out_ size_t* out_count) + +/** \brief Get the EP name declared for a (component, variant) pair. + * + * Each variant targets a single EP. `out_ep` receives the EP name string. + * When the variant does not declare an EP, the returned pointer is NULL. + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVariantEpName, + _In_ const OrtModelPackageContext* ctx, + _In_ const char* component_name, + _In_ const char* variant_name, + _Outptr_result_maybenull_ const char** out_ep) + +/** \brief Select a component model and return an opaque component instance. + * + * The variant selection is also performed during this call based on the component + * metadata and the provided options. The returned `OrtModelPackageComponentContext*` is + * independent of `context` lifetime and must be released via + * OrtModelPackageApi_ReleaseModelPackageComponentContext. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_SelectComponent, + _In_ const OrtModelPackageContext* context, + _In_ const char* component_name, + _In_ const OrtModelPackageOptions* options, + _Outptr_ OrtModelPackageComponentContext** out) + +/** \brief Release an OrtModelPackageComponentContext created by OrtModelPackageApi_SelectComponent. */ +ORT_EXPERIMENTAL_API(28, void, OrtModelPackageApi_ReleaseModelPackageComponentContext, + _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx) + +/** \brief Get the name of the selected variant after OrtModelPackageApi_SelectComponent has been called. + * + * String memory is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] out_name Receives the selected variant's name string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const char** out_name) + +/** \brief Get the folder path of the selected variant. + * + * Returns the resolved absolute path to the variant's directory on disk. + * The string is owned by `ctx` and remains valid until the context is released. + * + * \param[in] ctx The component context returned by OrtModelPackageApi_SelectComponent. + * \param[out] folder_path Receives the variant folder path string. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath, + _In_ const OrtModelPackageComponentContext* ctx, + _Outptr_ const ORTCHAR_T** folder_path) + +/** \brief Create an OrtSession for a selected file within a component model variant. + * + * The chosen variant (and thus its EP selection) is determined by `context`, which + * was built from an OrtSessionOptions via OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions. + * + * Session options precedence: + * 1. session_options == NULL (default path): + * ORT uses the OrtSessionOptions that was captured when `context` was created. + * Any variant-specific session and provider options declared in the package + * metadata are merged on top. + * + * 2. session_options != NULL (advanced path): + * ORT uses the caller-provided OrtSessionOptions as-is. Variant-specific + * session and provider options from the package metadata are NOT applied. + * Use this when custom EP setup is required (e.g., shared CUDA streams, + * shared QNN EP contexts, custom allocators). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_CreateSession, + _In_ const OrtEnv* env, + _In_ OrtModelPackageComponentContext* context, + _In_opt_ const OrtSessionOptions* session_options, + _Outptr_ OrtSession** session) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index c26424a0ab9ec..df14bc8c57f24 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -29,9 +29,6 @@ GraphOptimizationLevel, # noqa: F401 LoraAdapter, # noqa: F401 ModelMetadata, # noqa: F401 - ModelPackageComponentContext, # noqa: F401 - ModelPackageContext, # noqa: F401 - ModelPackageOptions, # noqa: F401 NodeArg, # noqa: F401 OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 diff --git a/onnxruntime/core/session/experimental_c_api.cc b/onnxruntime/core/session/experimental_c_api.cc index d030b68abdb99..458c47bcb58cd 100644 --- a/onnxruntime/core/session/experimental_c_api.cc +++ b/onnxruntime/core/session/experimental_c_api.cc @@ -18,6 +18,14 @@ namespace OrtExperimentalApis { +// Forward declarations driven by the .inc file so the registration table below +// can take the address of every entry, including those defined in other +// translation units linked into onnxruntime_session. +#define ORT_EXPERIMENTAL_API(VER, RET, NAME, ...) \ + RET ORT_API_CALL NAME##_SinceV##VER(__VA_ARGS__) NO_EXCEPTION; +#include "onnxruntime_experimental_c_api.inc" +#undef ORT_EXPERIMENTAL_API + // Test-only experimental function that writes a known sentinel value. // Exists to exercise the experimental API mechanism end-to-end and to serve as a template for future experimental // functions. diff --git a/onnxruntime/core/session/model_package_api.cc b/onnxruntime/core/session/model_package_api.cc index 27abb0f5f7a37..5fd25f9511eb9 100644 --- a/onnxruntime/core/session/model_package_api.cc +++ b/onnxruntime/core/session/model_package_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/model_package_api.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/common/common.h" #include "core/framework/error_code_helper.h" @@ -23,7 +23,9 @@ using namespace onnxruntime; return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, \ "Model package API is not supported in this build") -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, +namespace OrtExperimentalApis { + +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28, _Frees_ptr_opt_ OrtModelPackageOptions* options) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(options); @@ -32,7 +34,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageOptions, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelPackageOptions** out) { @@ -54,7 +56,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOpti API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageContext* ctx) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(ctx); @@ -63,7 +65,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageContext, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateModelPackageContext_SinceV28, _In_ const ORTCHAR_T* package_root, _Outptr_ OrtModelPackageContext** out) { API_IMPL_BEGIN @@ -89,7 +91,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateModelPackageContext, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ size_t* out_count) { API_IMPL_BEGIN @@ -107,7 +109,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, _Out_ size_t* out_count) { @@ -136,7 +138,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetComponentNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Out_ size_t* out_count) { @@ -158,7 +160,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantCount, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, @@ -190,7 +192,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantNames, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, +ORT_API_STATUS_IMPL(OrtModelPackageApi_SelectComponent_SinceV28, _In_ const OrtModelPackageContext* context, _In_ const char* component_name, _In_ const OrtModelPackageOptions* options, @@ -235,7 +237,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::SelectComponent, API_IMPL_END } -ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, +ORT_API(void, OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28, _Frees_ptr_opt_ OrtModelPackageComponentContext* cix) { #if !defined(ORT_MINIMAL_BUILD) delete reinterpret_cast(cix); @@ -244,7 +246,7 @@ ORT_API(void, OrtModelPackageAPI::ReleaseModelPackageComponentContext, #endif } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const ORTCHAR_T** folder_path) { API_IMPL_BEGIN @@ -269,7 +271,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, +ORT_API_STATUS_IMPL(OrtModelPackageApi_CreateSession_SinceV28, _In_ const OrtEnv* env, _In_ OrtModelPackageComponentContext* ctx, _In_opt_ const OrtSessionOptions* session_options, @@ -383,7 +385,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::CreateSession, // ---------- API table ------------------------------------------------------ -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28, _In_ const OrtModelPackageContext* ctx, _In_ const char* component_name, _In_ const char* variant_name, @@ -417,7 +419,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetVariantEpName, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ int64_t* out_version) { API_IMPL_BEGIN @@ -437,7 +439,7 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackage_GetSchemaVersion, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28, _In_ const OrtModelPackageComponentContext* ctx, _Outptr_ const char** out_name) { API_IMPL_BEGIN @@ -461,40 +463,4 @@ ORT_API_STATUS_IMPL(OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariant API_IMPL_END } -// ---------- API table dispatch --------------------------------------------- - -static constexpr OrtModelPackageApi ort_model_package_api = { - // Options - &OrtModelPackageAPI::CreateModelPackageOptionsFromSessionOptions, - &OrtModelPackageAPI::ReleaseModelPackageOptions, - - // Context - &OrtModelPackageAPI::CreateModelPackageContext, - &OrtModelPackageAPI::ReleaseModelPackageContext, - - // Package-level queries - &OrtModelPackageAPI::ModelPackage_GetSchemaVersion, - &OrtModelPackageAPI::ModelPackage_GetComponentCount, - &OrtModelPackageAPI::ModelPackage_GetComponentNames, - &OrtModelPackageAPI::ModelPackage_GetVariantCount, - &OrtModelPackageAPI::ModelPackage_GetVariantNames, - &OrtModelPackageAPI::ModelPackage_GetVariantEpName, - - // Variant selection and component queries - &OrtModelPackageAPI::SelectComponent, - &OrtModelPackageAPI::ReleaseModelPackageComponentContext, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantName, - &OrtModelPackageAPI::ModelPackageComponent_GetSelectedVariantFolderPath, - - // Session - &OrtModelPackageAPI::CreateSession, - - // End of Version 1.27 - DO NOT MODIFY ABOVE -}; - -static_assert(offsetof(OrtModelPackageApi, CreateSession) / sizeof(void*) == 14, - "Size of initial OrtModelPackageApi cannot change"); - -ORT_API(const OrtModelPackageApi*, OrtModelPackageAPI::GetModelPackageApi) { - return &ort_model_package_api; -} +} // namespace OrtExperimentalApis diff --git a/onnxruntime/core/session/model_package_api.h b/onnxruntime/core/session/model_package_api.h deleted file mode 100644 index 435e3a521c24c..0000000000000 --- a/onnxruntime/core/session/model_package_api.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/session/onnxruntime_c_api.h" - -namespace OrtModelPackageAPI { - -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - -ORT_API(void, ReleaseModelPackageOptions, _Frees_ptr_opt_ OrtModelPackageOptions*); -ORT_API_STATUS_IMPL(CreateModelPackageOptionsFromSessionOptions, - _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, - _Outptr_ OrtModelPackageOptions** out); - -ORT_API(void, ReleaseModelPackageContext, _Frees_ptr_opt_ OrtModelPackageContext*); -ORT_API_STATUS_IMPL(CreateModelPackageContext, - _In_ const ORTCHAR_T* package_root, - _Outptr_ OrtModelPackageContext** out); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentCount, - _In_ const OrtModelPackageContext* ctx, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetComponentNames, - _In_ const OrtModelPackageContext* ctx, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantCount, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantNames, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _Outptr_result_buffer_maybenull_(*out_count) const char* const** out_variant_names, - _Out_ size_t* out_count); - -ORT_API_STATUS_IMPL(SelectComponent, - _In_ const OrtModelPackageContext* context, - _In_ const char* component_name, - _In_ const OrtModelPackageOptions* options, - _Outptr_ OrtModelPackageComponentContext** out); - -ORT_API(void, ReleaseModelPackageComponentContext, - _Frees_ptr_opt_ OrtModelPackageComponentContext* ctx); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantFolderPath, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const ORTCHAR_T** folder_path); - -ORT_API_STATUS_IMPL(CreateSession, - _In_ const OrtEnv* env, - _In_ OrtModelPackageComponentContext* ctx, - _In_opt_ const OrtSessionOptions* session_options, - _Outptr_ OrtSession** session); - -ORT_API_STATUS_IMPL(ModelPackage_GetVariantEpName, - _In_ const OrtModelPackageContext* ctx, - _In_ const char* component_name, - _In_ const char* variant_name, - _Outptr_result_maybenull_ const char** out_ep); - -ORT_API_STATUS_IMPL(ModelPackage_GetSchemaVersion, - _In_ const OrtModelPackageContext* ctx, - _Out_ int64_t* out_version); - -ORT_API_STATUS_IMPL(ModelPackageComponent_GetSelectedVariantName, - _In_ const OrtModelPackageComponentContext* ctx, - _Outptr_ const char** out_name); - -} // namespace OrtModelPackageAPI diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 61a413d92e7fc..f451eaa401497 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -57,7 +57,6 @@ #include "core/session/ort_env.h" #include "core/session/ort_version_check.h" #include "core/session/utils.h" -#include "core/session/model_package_api.h" #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE) #include "core/providers/cuda/cuda_provider_factory.h" @@ -3689,10 +3688,6 @@ ORT_API(const OrtCompileApi*, OrtApis::GetCompileApi) { return OrtCompileAPI::GetCompileApi(); } -ORT_API(const OrtModelPackageApi*, OrtApis::GetModelPackageApi) { - return OrtModelPackageAPI::GetModelPackageApi(); -} - ORT_API(void, OrtApis::CreateKeyValuePairs, _Outptr_ OrtKeyValuePairs** out) { auto kvps = std::make_unique(); *out = reinterpret_cast(kvps.release()); @@ -4911,7 +4906,6 @@ static constexpr OrtApi ort_api_1_to_28 = { &OrtApis::GetMemPatternEnabled, &OrtApis::GetSessionExecutionMode, &OrtApis::SessionReleaseCapturedGraph, - &OrtApis::GetModelPackageApi, // End of Version 27 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetExperimentalFunction, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 250a2853a4777..61ece2dd9a682 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -825,9 +825,6 @@ ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtVal _Outptr_result_maybenull_ const int64_t** shape_data, _Out_ size_t* shape_data_count); -// Model Package API -ORT_API(const OrtModelPackageApi*, GetModelPackageApi); - // Experimental API ORT_API(OrtExperimentalFnPtr, GetExperimentalFunction, _In_ const char* name); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 77200ea84778e..2044a128d9540 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -3247,195 +3247,6 @@ including arg name, arg type (contains both type and shape).)pbdoc") #endif }, R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); - - // --- Model Package API --- -#if !defined(ORT_MINIMAL_BUILD) - // Helper to create a PyInferenceSession from a pre-initialized OrtSession* (C API handle). - // PyInferenceSession's owning ctor is protected; this subclass provides access. - struct PyModelPackageSession : PyInferenceSession { - PyModelPackageSession(std::unique_ptr sess) - : PyInferenceSession(std::move(sess)) {} - }; - - // Wrapper classes to manage opaque C handles with proper RAII - struct PyModelPackageContext { - OrtModelPackageContext* ctx_{nullptr}; - PyModelPackageContext(const std::string& package_path) { - auto path = ToPathString(package_path); - const auto* api = Ort::GetApi().GetModelPackageApi(); - Ort::ThrowOnError(api->CreateModelPackageContext(path.c_str(), &ctx_)); - } - ~PyModelPackageContext() { - if (ctx_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageContext(ctx_); - } - } - PyModelPackageContext(const PyModelPackageContext&) = delete; - PyModelPackageContext& operator=(const PyModelPackageContext&) = delete; - }; - - struct PyModelPackageComponentContext { - OrtModelPackageComponentContext* ctx_{nullptr}; - ~PyModelPackageComponentContext() { - if (ctx_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageComponentContext(ctx_); - } - } - PyModelPackageComponentContext(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext& operator=(const PyModelPackageComponentContext&) = delete; - PyModelPackageComponentContext() = default; - }; - - struct PyModelPackageOptions { - OrtModelPackageOptions* opts_{nullptr}; - ~PyModelPackageOptions() { - if (opts_) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - api->ReleaseModelPackageOptions(opts_); - } - } - PyModelPackageOptions(const PyModelPackageOptions&) = delete; - PyModelPackageOptions& operator=(const PyModelPackageOptions&) = delete; - PyModelPackageOptions() = default; - }; - - py::class_(m, "ModelPackageContext", - R"pbdoc(Represents an opened model package for inspection and component selection.)pbdoc") - .def(py::init(), py::arg("package_path"), - R"pbdoc(Open a model package from the given directory path.)pbdoc") - .def( - "get_component_names", - [](PyModelPackageContext& self) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetComponentNames(self.ctx_, &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - R"pbdoc(Get the names of all components in the package.)pbdoc") - .def( - "get_variant_names", - [](PyModelPackageContext& self, const std::string& component_name) -> std::vector { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* const* names = nullptr; - size_t count = 0; - Ort::ThrowOnError(api->ModelPackage_GetVariantNames( - self.ctx_, component_name.c_str(), &names, &count)); - std::vector result; - result.reserve(count); - for (size_t i = 0; i < count; ++i) { - result.emplace_back(names[i]); - } - return result; - }, - py::arg("component_name"), - R"pbdoc(Get the variant names for a given component.)pbdoc") - .def( - "get_variant_ep_name", - [](PyModelPackageContext& self, const std::string& component_name, - const std::string& variant_name) -> std::optional { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* ep = nullptr; - Ort::ThrowOnError(api->ModelPackage_GetVariantEpName( - self.ctx_, component_name.c_str(), variant_name.c_str(), &ep)); - if (ep) return std::string(ep); - return std::nullopt; - }, - py::arg("component_name"), py::arg("variant_name"), - R"pbdoc(Get the EP name for a variant. Returns None if not declared.)pbdoc") - .def( - "get_schema_version", - [](PyModelPackageContext& self) -> int64_t { - const auto* api = Ort::GetApi().GetModelPackageApi(); - int64_t version = 0; - Ort::ThrowOnError(api->ModelPackage_GetSchemaVersion(self.ctx_, &version)); - return version; - }, - R"pbdoc(Get the schema version declared in the model package manifest.)pbdoc") - .def( - "select_component", - [](PyModelPackageContext& self, const std::string& component_name, - PyModelPackageOptions& options) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->SelectComponent( - self.ctx_, component_name.c_str(), options.opts_, &result->ctx_)); - return result; - }, - py::arg("component_name"), py::arg("options"), - R"pbdoc(Select a component and resolve its variant based on the provided options. -Returns a ModelPackageComponentContext for inspecting the selected variant.)pbdoc"); - - py::class_(m, "ModelPackageOptions", - R"pbdoc(Options used for variant selection in a model package. -Created from a SessionOptions to capture EP configuration for variant matching.)pbdoc") - .def(py::init([](PySessionOptions& session_options) { - const auto* api = Ort::GetApi().GetModelPackageApi(); - auto result = std::make_unique(); - Ort::ThrowOnError(api->CreateModelPackageOptionsFromSessionOptions( - GetOrtEnv(), &session_options, &result->opts_)); - return result; - }), - py::arg("session_options"), - R"pbdoc(Create model package options from a SessionOptions instance. -The EP configured on the session options is used for variant selection.)pbdoc"); - - py::class_(m, "ModelPackageComponentContext", - R"pbdoc(Represents a selected component within a model package. -Provides access to the resolved variant's files, session options, and metadata.)pbdoc") - .def( - "get_selected_variant_folder_path", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const ORTCHAR_T* path = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantFolderPath(self.ctx_, &path)); - return PathToUTF8String(PathString(path)); - }, - R"pbdoc(Get the folder path of the selected variant.)pbdoc") - .def( - "get_selected_variant_name", - [](PyModelPackageComponentContext& self) -> std::string { - const auto* api = Ort::GetApi().GetModelPackageApi(); - const char* name = nullptr; - Ort::ThrowOnError(api->ModelPackageComponent_GetSelectedVariantName( - self.ctx_, &name)); - return name ? std::string(name) : std::string(); - }, - R"pbdoc(Get the name of the selected variant.)pbdoc") - .def( - "create_session", - [](PyModelPackageComponentContext& self, py::object session_options_obj) -> std::unique_ptr { - const auto* api = Ort::GetApi().GetModelPackageApi(); - OrtSession* ort_session = nullptr; - if (session_options_obj.is_none()) { - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, nullptr, &ort_session)); - } else { - auto& so = session_options_obj.cast(); - Ort::ThrowOnError(api->CreateSession(GetOrtEnv(), self.ctx_, &so, &ort_session)); - } - // OrtSession* is a reinterpret_cast of InferenceSession* - auto* inference_session = reinterpret_cast(ort_session); - std::unique_ptr session_ptr(inference_session); - return std::make_unique(std::move(session_ptr)); - }, - py::arg("session_options") = py::none(), - R"pbdoc(Create an InferenceSession from the selected component variant. - -Args: - session_options: Optional SessionOptions override. If None, uses the options - captured during variant selection with per-file options merged on top. - If provided, variant-specific options are NOT applied. - -Returns: - An InferenceSession ready for inference.)pbdoc"); -#endif // !defined(ORT_MINIMAL_BUILD) } bool InitArray() { diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index 6fb3f8e6ba82f..34f73eb69b149 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -12,6 +12,7 @@ #include "gtest/gtest.h" #include "core/session/model_package/model_package_context.h" +#include "core/session/onnxruntime_experimental_c_api.h" #include "core/session/abi_devices.h" #include "test/autoep/test_autoep_utils.h" #include "test/util/include/asserts.h" @@ -22,6 +23,90 @@ extern std::unique_ptr ort_env; namespace onnxruntime { namespace test { namespace { + +// Typed function pointers for every OrtModelPackageApi_* experimental entry, +// resolved once via the experimental name-based lookup. +struct ModelPackageFns { + OrtExperimental_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn + CreateModelPackageOptionsFromSessionOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn + ReleaseModelPackageOptions{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn + CreateModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn + ReleaseModelPackageContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn + ModelPackage_GetSchemaVersion{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn + ModelPackage_GetComponentCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn + ModelPackage_GetComponentNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn + ModelPackage_GetVariantCount{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn + ModelPackage_GetVariantNames{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn + ModelPackage_GetVariantEpName{nullptr}; + OrtExperimental_OrtModelPackageApi_SelectComponent_SinceV28_Fn + SelectComponent{nullptr}; + OrtExperimental_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn + ReleaseModelPackageComponentContext{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantName{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn + ModelPackageComponent_GetSelectedVariantFolderPath{nullptr}; + OrtExperimental_OrtModelPackageApi_CreateSession_SinceV28_Fn + CreateSession{nullptr}; +}; + +inline const ModelPackageFns& GetModelPackageFns() { + static const ModelPackageFns fns = []() { + const OrtApi* api = &Ort::GetApi(); + ModelPackageFns f; +#define RESOLVE(member, getter) \ + do { \ + f.member = Ort::Experimental::getter(api); \ + if (f.member == nullptr) { \ + throw std::runtime_error(std::string("Failed to resolve experimental " \ + "OrtModelPackageApi_") + \ + #member "_SinceV28"); \ + } \ + } while (0) + RESOLVE(CreateModelPackageOptionsFromSessionOptions, + Get_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn); + RESOLVE(ReleaseModelPackageOptions, + Get_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn); + RESOLVE(CreateModelPackageContext, + Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn); + RESOLVE(ReleaseModelPackageContext, + Get_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn); + RESOLVE(ModelPackage_GetSchemaVersion, + Get_OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentCount, + Get_OrtModelPackageApi_ModelPackage_GetComponentCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetComponentNames, + Get_OrtModelPackageApi_ModelPackage_GetComponentNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantCount, + Get_OrtModelPackageApi_ModelPackage_GetVariantCount_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantNames, + Get_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_Fn); + RESOLVE(ModelPackage_GetVariantEpName, + Get_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn); + RESOLVE(SelectComponent, + Get_OrtModelPackageApi_SelectComponent_SinceV28_Fn); + RESOLVE(ReleaseModelPackageComponentContext, + Get_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantName, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantName_SinceV28_Fn); + RESOLVE(ModelPackageComponent_GetSelectedVariantFolderPath, + Get_OrtModelPackageApi_ModelPackageComponent_GetSelectedVariantFolderPath_SinceV28_Fn); + RESOLVE(CreateSession, + Get_OrtModelPackageApi_CreateSession_SinceV28_Fn); +#undef RESOLVE + return f; + }(); + return fns; +} // ------------------------------------------------------------------ // Helpers to build a test model package on disk // ------------------------------------------------------------------ @@ -233,26 +318,25 @@ std::filesystem::path CreateModelPackageApiTestPackage(bool multi_file_variant = TEST(ModelPackageApiTest, PackageContextQueries) { const auto package_root = CreateModelPackageApiTestPackage(); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr model_pkg_context(nullptr, context_deleter); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); // Query: component count + names size_t component_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); ASSERT_EQ(component_count, 1u); const char* const* component_names = nullptr; size_t component_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetComponentNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentNames( model_pkg_context.get(), &component_names, &component_name_count)); ASSERT_EQ(component_name_count, 1u); ASSERT_NE(component_names, nullptr); @@ -261,13 +345,13 @@ TEST(ModelPackageApiTest, PackageContextQueries) { // Query: variant count + names size_t variant_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantCount( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantCount( model_pkg_context.get(), "model_1", &variant_count)); ASSERT_EQ(variant_count, 2u); const char* const* variant_names = nullptr; size_t variant_name_count = 0; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantNames( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantNames( model_pkg_context.get(), "model_1", &variant_names, &variant_name_count)); ASSERT_EQ(variant_name_count, 2u); @@ -294,17 +378,16 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { - if (p) pkg_api->ReleaseModelPackageOptions(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { + if (p) pkg_api.ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr model_pkg_options(nullptr, options_deleter); @@ -312,25 +395,25 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS std::unique_ptr component_context(nullptr, component_context_deleter); OrtModelPackageOptions* raw_options = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_options)); model_pkg_options.reset(raw_options); OrtModelPackageContext* raw_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_context)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); OrtModelPackageComponentContext* raw_component_context = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(model_pkg_context.get(), - "model_1", - model_pkg_options.get(), - &raw_component_context)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(model_pkg_context.get(), + "model_1", + model_pkg_options.get(), + &raw_component_context)); component_context.reset(raw_component_context); OrtSession* raw_session = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateSession(*ort_env, - component_context.get(), - session_options, - &raw_session)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateSession(*ort_env, + component_context.get(), + session_options, + &raw_session)); Ort::Session session(raw_session); Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); @@ -921,33 +1004,32 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { os << R"({"filename":"mul_1.onnx"})"; } - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { - if (p) pkg_api->ReleaseModelPackageContext(p); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(nullptr, context_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); // variant_1 targets example_ep const char* ep1 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", &ep1)); ASSERT_NE(ep1, nullptr); EXPECT_STREQ(ep1, "example_ep"); // variant_2 targets other_ep const char* ep2 = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_2", &ep2)); ASSERT_NE(ep2, nullptr); EXPECT_STREQ(ep2, "other_ep"); // Optional out-parameter: callers can pass NULL. - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackage_GetVariantEpName( + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", nullptr)); std::filesystem::remove_all(package_root, ec); @@ -990,32 +1072,31 @@ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) { std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); // Path looks like .../models/model_1/ -- the folder name is the variant. @@ -1082,28 +1163,27 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr mp_opts(nullptr, options_deleter); std::unique_ptr ctx(nullptr, context_deleter); std::unique_ptr comp_ctx(nullptr, component_context_deleter); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, session_options, &raw_mp_opts)); mp_opts.reset(raw_mp_opts); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); // CreateSession iterates the per-file session_options and dispatches each through OrtApis::AddSessionConfigEntry. @@ -1111,7 +1191,7 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn // Pass nullptr for session_options so the metadata-merge path runs (it is skipped when the caller // supplies their own session_options). OrtSession* raw_session = nullptr; - OrtStatus* st = pkg_api->CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); + OrtStatus* st = pkg_api.CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); // Clean up session first to avoid leaks if assertion fails. if (raw_session != nullptr) { Ort::GetApi().ReleaseSession(raw_session); @@ -1131,60 +1211,6 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn std::filesystem::remove_all(package_root, ec); } -// Test that the C++ RAII wrappers (Ort::ModelPackageContext, etc.) work correctly. -TEST(ModelPackageApiTest, CxxWrappers_PackageContextQueries) { - const auto package_root = CreateModelPackageApiTestPackage(); - - Ort::ModelPackageContext ctx(package_root.c_str()); - - // Component queries - EXPECT_EQ(ctx.GetComponentCount(), 1u); - auto component_names = ctx.GetComponentNames(); - ASSERT_EQ(component_names.size(), 1u); - EXPECT_EQ(component_names[0], "model_1"); - - // Variant queries - EXPECT_EQ(ctx.GetVariantCount("model_1"), 2u); - auto variant_names = ctx.GetVariantNames("model_1"); - ASSERT_EQ(variant_names.size(), 2u); - std::unordered_set variant_set(variant_names.begin(), variant_names.end()); - EXPECT_EQ(variant_set.count("variant_1"), 1u); - EXPECT_EQ(variant_set.count("variant_2"), 1u); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageApiTest, CxxWrappers_SelectComponentAndQueryFileAccessors) { - const auto package_root = CreateModelPackageApiTestPackage(); - - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - - Ort::SessionOptions so; - std::unordered_map ep_options; - so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - - Ort::ModelPackageContext ctx(package_root.c_str()); - auto cix = ctx.SelectComponent("model_1", pkg_opts); - - // Folder path should be non-empty - auto folder = cix.GetSelectedVariantFolderPath(); - EXPECT_FALSE(folder.empty()); - - // Selected variant name should not throw - auto variant_name = cix.GetSelectedVariantName(); - EXPECT_FALSE(variant_name.empty()); - - // CreateSession via C++ wrapper - auto session = cix.CreateSession(*ort_env, so); - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - // ------------------------------------------------------------------ // Test: GetSelectedVariantFolderPath returns correct path even when variant.json is absent. // ------------------------------------------------------------------ @@ -1221,31 +1247,29 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { Ort::SessionOptions so; std::unordered_map ep_options; so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - Ort::ModelPackageOptions pkg_opts(*ort_env, so); - const OrtModelPackageApi* pkg_api = Ort::GetApi().GetModelPackageApi(); - ASSERT_NE(pkg_api, nullptr); + const auto& pkg_api = GetModelPackageFns(); OrtModelPackageOptions* raw_mp_opts = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); - auto options_deleter = [pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api->ReleaseModelPackageOptions(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); + auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; std::unique_ptr mp_opts(raw_mp_opts, options_deleter); OrtModelPackageContext* raw_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->CreateModelPackageContext(package_root.c_str(), &raw_ctx)); - auto context_deleter = [pkg_api](OrtModelPackageContext* p) { if (p) pkg_api->ReleaseModelPackageContext(p); }; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; std::unique_ptr ctx(raw_ctx, context_deleter); OrtModelPackageComponentContext* raw_comp_ctx = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); - auto component_context_deleter = [pkg_api](OrtModelPackageComponentContext* p) { - if (p) pkg_api->ReleaseModelPackageComponentContext(p); + ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); + auto component_context_deleter = [&pkg_api](OrtModelPackageComponentContext* p) { + if (p) pkg_api.ReleaseModelPackageComponentContext(p); }; std::unique_ptr comp_ctx(raw_comp_ctx, component_context_deleter); // GetSelectedVariantFolderPath should return the variant directory even without variant.json. const ORTCHAR_T* selected_folder = nullptr; - ASSERT_ORTSTATUS_OK(pkg_api->ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); const auto result_path = std::filesystem::path(selected_folder); From 2cf6c6c6939533d888233f25f58b030c74e31c14 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 10 Jun 2026 17:26:13 -0700 Subject: [PATCH 02/15] [CUDA] Fix QMoE int4/int8 weight prepack to always use SM80 layout (#28978) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary The CUDA QMoE INT4/INT8 grouped GEMM always dispatches to the Ampere (SM80) CUTLASS kernel — even on Hopper (SM90) — because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized specialisation. This PR makes weight prepacking always emit the SM80 (column-interleaved) `fpA_intB` layout regardless of the runtime device SM, fixing silently-wrong output on Hopper, and centralizes the arch-clamping logic in a single shared helper. It also cleans up the related tests and tightens MoE parity tolerances that were too loose to catch the layout bug. ## Motivation https://github.com/microsoft/onnxruntime/pull/28749 uses 90 for sm90 weight prepacking. On SM90, `isValidHopperMOESpecialisation()` is `false`, so the grouped MoE GEMM falls back to the SM80 kernel. The weight preprocessor, however, skips column interleaving for `arch == 90`, so an auto-detected (`force_arch=-1`) pack on an H200 produced the non-interleaved SM90 layout that the SM80 kernel cannot consume — yielding wrong results. The previous `PrePackIntExpertWeights` logic clamped to `sm_` (passing SM90 through), and the test that exercised the offline packer used auto-detect, so both could emit the wrong layout. ## Key Changes | Area | Change | |---|---| | `fpA_intB_gemm_preprocessors{.h,_impl.cu}` | Extracted `get_arch_for_mixed_gemm_weight_preprocess(int arch)` as a shared, declared helper (clamps SM to the layout group: `<80→75`, `90→90`, else `80`). | | `fpA_intB_gemm_preprocessors_impl.h` | `getLayoutDetailsForTransform` now routes through the shared helper instead of duplicating the arch-range logic. | | `moe_quantization.cc` (`PrePackIntExpertWeights`) | Always packs INT4/INT8 expert weights for the SM80 layout (`get_arch_for_mixed_gemm_weight_preprocess(80)`) instead of clamping to the runtime `sm_`, since the SM80 kernel runs on every GPU. | | `onnxruntime_pybind_quant.cc` (`PackWeightsForMixedGemm`) | Replaced the ad-hoc `{75,80,90}` allowlist with the shared helper, so `force_arch` is clamped consistently with the runtime dispatch (removes the now-unused `` include). | | `contrib_defs.cc` / `moe_quantization.h` | Updated `weights_prepacked` schema/field docs: layouts for `-1`/`1` are EP-determined; for the CUDA EP `-1` and `1` are equivalent today (both SM80), `1` reserved for a future Hopper-specific layout. | | `test_qmoe_cuda.py` | Removed the dead, never-called `preprocess_weights_for_mixed_gemm` helper; the real path (`quant_dequant_blockwise`) already pins `sm=80`. | | `test_moe_cuda.py` | Pinned the offline packer to `arch=80`, and tightened FP16 QMoE parity tolerance from `atol 3.0 (4-bit)` / `2.0 (8-bit)` to `0.5` now that the layout is correct. | | `docs/` | Regenerated `ContribOperators.md` and updated `moe_qmoe.md` to match the new schema docs and SM80-always packing rationale. | ## Testing Notes On an H200 (SM90), with the CUDA 12.x/13.x Python wheel: ```bash python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py python -m pytest onnxruntime/test/python/transformers/test_moe_cuda.py -k "PhiQMoE or qmoe" ``` - `test_qmoe_cuda.py` SwiGLU parity: SM80 layout → max diff ~0.001 (pass, tol 0.1); the prior SM90 layout produced max diff ~1.2 (fail), confirming the fix. - `test_moe_cuda.py` `TestPhiQMoE` (4-bit and 8-bit, all batch/seq combinations): worst observed `max_diff` ≈ 0.375 with the fixed layout, comfortably under the new `atol=0.5`. - `ruff check` passes on both edited test files. --------- Co-authored-by: tlwu Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- docs/ContribOperators.md | 2 +- docs/contrib_ops/cuda/moe_qmoe.md | 77 +++++++++++++++++-- .../cuda/llm/fpA_intB_gemm_preprocessors.h | 2 + .../llm/fpA_intB_gemm_preprocessors_impl.cu | 13 ++++ .../llm/fpA_intB_gemm_preprocessors_impl.h | 6 +- .../contrib_ops/cuda/moe/moe_quantization.cc | 58 +++++++------- .../contrib_ops/cuda/moe/moe_quantization.h | 27 ++++--- .../core/graph/contrib_ops/contrib_defs.cc | 16 +--- .../python/onnxruntime_pybind_quant.cc | 12 +-- .../test/python/transformers/test_moe_cuda.py | 16 ++-- .../python/transformers/test_qmoe_cuda.py | 39 +--------- 11 files changed, 156 insertions(+), 112 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 9d19a95136ad7..38d101786b41a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4949,7 +4949,7 @@ This version of the operator has been available since version 1 of the 'com.micr
use_sparse_mixer : int
Whether to use sparse mixer
weights_prepacked : int
-
Only meaningful when quant_type='int'. Tri-state control over whether the int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS fpA_intB format expected by the runner. -1 (auto): let the execution provider choose its own backward-compatible default; the CUDA EP treats auto as prepacked. 1: the initializers are already prepacked (e.g. produced offline by pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself in PrePack(), matching the behaviour of MatMulNBits and removing the offline pre-pack requirement from exporters. Defaults to -1 (auto) so each execution provider can pick its own backward-compatible default rather than the schema imposing one.
+
Only meaningful when quant_type='int'. Tri-state control over the layout of the int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by -1 and 1 are determined by the execution provider. 0: the initializers are raw, un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.
#### Inputs (6 - 21) diff --git a/docs/contrib_ops/cuda/moe_qmoe.md b/docs/contrib_ops/cuda/moe_qmoe.md index 6d53211ff40cb..36b68889ae582 100644 --- a/docs/contrib_ops/cuda/moe_qmoe.md +++ b/docs/contrib_ops/cuda/moe_qmoe.md @@ -71,6 +71,7 @@ input tokens → router (top-k softmax) → permute by expert | `expert_weight_bits` (QMoE only) | int | 4 | 4 (INT4/MXFP4) or 8 (INT8/FP8). | | `block_size` (QMoE only) | int | -1 | Group size for INT4/INT8 group-wise quantization. -1 = per-output-channel. | | `quant_type` (QMoE only) | string | `"int"` | `"int"`, `"fp4"`, `"fp8"`, `"wfp4afp8"`. See [§3](#3-quantization-modes). | +| `weights_prepacked` (QMoE only) | int | -1 | Tri-state, only meaningful when `quant_type="int"`. The prepacked layouts selected by `-1` and `1` are **EP-determined**. `-1` (default): the INT4/INT8 `fc1`/`fc2` initializers are already prepacked in the EP's default layout (e.g. from `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). `1`: already prepacked in an alternate EP-selected layout. `0`: the initializers are raw `[E, N, K/pack]` tensors (as produced by `quantize_matmul_{4,8}bits`) and the kernel runs the CUTLASS layout transform in `PrePack()`. **Note:** the CUDA EP INT4/INT8 MoE GEMM always runs the Ampere (SM80) kernel — even on SM90 — so it consumes the SM80 `fpA_intB` layout on all architectures; `-1` and `1` are therefore equivalent for the CUDA EP today, and `1` is reserved for a possible future Hopper-specific layout. See [§5.1](#51-weights-input-2--5--8). | ### 2.2 Type Constraints @@ -228,10 +229,53 @@ extra subtraction. ### 5.1 Weights (input 2 / 5 / 8) -Not transformed at runtime. INT4/INT8 weights must already be packed offline by -`pack_weights_for_cuda_mixed_gemm` (see [§6](#6-weight-formats)). MXFP4 weights -must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights are stored -as raw e4m3 bytes (no packing). +**INT4/INT8** weight layout is controlled by the `weights_prepacked` attribute +([§2.1](#21-attributes)). The prepacked layouts selected by `-1` and `1` are +determined by the execution provider: + +- **`weights_prepacked=-1` (default)** — the `fc1`/`fc2` weights are already in + the EP's default prepacked layout (e.g. packed offline by + `pack_weights_for_cuda_mixed_gemm` for the CUDA EP). They are copied to GPU + and consumed as-is. +- **`weights_prepacked=1`** — the `fc1`/`fc2` weights are already in the EP's + **SM90** (Hopper) prepacked layout (reserved; see the note below). +- **`weights_prepacked=0`** — the `fc1`/`fc2` weights are raw, schema-conformant + `[E, N, K/pack]` tensors as produced by `quantize_matmul_{4,8}bits`. `PrePack` + runs the CUTLASS layout transform itself via `PrePackIntExpertWeights`, + removing the offline pre-pack dependency. This makes integer QMoE symmetric + with `MatMulNBits::PrePack_B`. + +> **Single layout on the CUDA EP.** The CUDA EP INT4/INT8 MoE GEMM always +> dispatches to the Ampere (**SM80**) grouped-GEMM kernel — even on SM90 — +> because mixed int-weight + fp16/bf16 activation is not a valid Hopper TMA +> warp-specialized specialisation (`isValidHopperMOESpecialisation` is `false`). +> This matches **TensorRT-LLM**, which likewise routes `W4A16`/`W8A16` MoE to the +> SM80 kernel on Hopper; its Hopper TMA-WS mixed-dtype MoE kernel is reserved for +> `W4A8` (FP8 activation) and `WFP4A16` (FP4 weight). Consequently the CUDA EP +> consumes the **SM80 `fpA_intB` layout on every GPU**, `PrePack` always packs +> for SM80, and `weights_prepacked=-1` and `=1` are equivalent today. `1` is +> accepted and reserved for a possible future Hopper-specific layout (e.g. +> `W4A8`). There is therefore no architecture-match constraint: SM80-format +> weights run correctly on SM90 via the SM80 kernel. + +`PrePackIntExpertWeights` loops over the `E` experts and, per expert, applies the +same transpose + row-permutation / column-interleave / bias / pair-interleave +transform as `pack_weights_for_cuda_mixed_gemm` (see [§6.1](#61-int4-group-wise-quant_typeint-expert_weight_bits4)), +always targeting the SM80 layout. SM75+ is required. The source +`[E, N, K/pack]` initializers are released after their shapes are cached +(`fc1_weights_shape_` / `fc2_weights_shape_`), so peak weight memory stays ~1×. +The prepacked GPU buffers (`packed_fc1_weights_` / `packed_fc2_weights_`) are then +preferred by `ComputeInternal`. If prepacking is disabled at the session level +(`session.disable_prepacking`), the buffers stay null and the raw initializer +pointers are read at compute time instead. + +> **Note**: `weights_prepacked=0` is the only path that triggers an in-`PrePack` +> layout transform for INT weights. FP4 / FP8 / WFP4AFP8 weight handling is +> unaffected. + +MXFP4 weights must be packed by `pack_fp4_weights_for_cuda_moe_gemm`. FP8 weights +are stored as raw e4m3 bytes (no packing). + ### 5.2 INT4/INT8 scales + zero-point → bias @@ -287,7 +331,12 @@ This section covers the five distinct weight encodings supported by QMoE. INT4 packing layout within a byte: `[high_nibble | low_nibble] = [elt_1 | elt_0]`. Each INT4 element is in `[-8, 7]` (signed) before bias, `[0, 15]` after the +8 bias. -#### Preprocessing pipeline (offline, `pack_weights_for_cuda_mixed_gemm`) +#### Preprocessing pipeline (offline `pack_weights_for_cuda_mixed_gemm`, or in-`PrePack` via `PrePackIntExpertWeights`) + +This is the layout transform applied either offline by +`pack_weights_for_cuda_mixed_gemm`, or per-expert inside `PrePack` when +`weights_prepacked=0` (see [§5.1](#51-weights-input-2--5--8)). + 1. **Input layout**: `[N, K]` per expert (Out × In), 2 elements per byte for INT4. 2. **Transpose & signed conversion**: @@ -405,6 +454,17 @@ weights are interchangeable across SMs: — does not use `pack_weights_for_cuda_mixed_gemm`. - **FP8**: no packing. +> **QMoE uses Group A on every GPU.** The table above describes the layouts the +> `pack_weights_for_cuda_mixed_gemm` *preprocessor* can emit. The QMoE INT4/INT8 +> MoE GEMM, however, always dispatches to the Ampere (SM80) grouped-GEMM kernel — +> even on SM90 — because mixed int-weight + fp16/bf16 activation is not a valid +> Hopper TMA warp-specialized specialisation (the same is true in TensorRT-LLM). +> It therefore consumes the **Group A (SM80) layout on all architectures, +> including Hopper**. For QMoE, always pack INT4/INT8 weights for SM80 (`arch=80`), +> and `PrePackIntExpertWeights` (`weights_prepacked=0`) does exactly that +> regardless of the runtime device SM. Group B (SM90) layout is currently unused +> by QMoE. + --- ## 8. SwiGLU Fusion @@ -830,7 +890,7 @@ will not change the operator interface. |-----------|----------| | [test_moe_cuda.py](onnxruntime/test/python/transformers/test_moe_cuda.py) | Standard MoE on CUDA: FP16/BF16, SiLU/GeLU/SwiGLU, routing, GEMM parity. SwiGLU coverage includes both GPT-OSS (`TestSwigluMoE`: interleaved, alpha=1.702/beta=1.0/limit=7.0) and Standard/Llama-Gemma (`TestStandardSwigluMoE`: concatenated `swiglu_fusion=2`, alpha=1.0/beta=0.0/no limit → `SiLU(Gate)×Value`). | | [test_moe_cpu.py](onnxruntime/test/python/transformers/test_moe_cpu.py) | Standard MoE on CPU (smoke). | -| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. | +| [test_qmoe_cuda.py](onnxruntime/test/python/transformers/test_qmoe_cuda.py) | INT4/INT8 QMoE — primary regression signal for the production QMoE path. Exercises `pack_weights_for_cuda_mixed_gemm` and dequant-then-matmul reference. `TestQMoEIntPrePackSmoke` covers the raw-weight `weights_prepacked=0` in-`PrePack` layout transform (smoke test: asserts finite output, not bit-parity). | | [test_qmoe_cpu.py](onnxruntime/test/python/transformers/test_qmoe_cpu.py) | INT4/INT8 QMoE on CPU (smoke). | | [test_qmoe_fp4_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py) | MXFP4 QMoE: quantization utilities, packing, FP16/BF16, SiLU/SwiGLU, top-k and expert-count variants. End-to-end runs on SM120; on SM<120 the dequant fallback is exercised. | | [test_qmoe_fp8_cuda.py](onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py) | FP8 W8A16 QMoE on SM90+ native path and SM<90 dequant fallback. | @@ -954,6 +1014,11 @@ over-aligned by-value parameters. cannot. See [§14.1](#141-msvc-and-tma-grouped-moe-gemm). - **WFP4AFP8 native** requires SM100+ hardware; only the dequant fallback path is validated end-to-end so far. +- **In-`PrePack` INT weight layout transform** (`weights_prepacked=0`) is + currently covered only by a smoke test (`TestQMoEIntPrePackSmoke`), not a + bit-parity check: the existing offline pre-pack harness hardcodes + `force_arch=80` (the same SM80 layout consumed by the CUDA EP on all GPUs), + so a separate parity harness for this path is still pending. - **Hopper W4A8** (INT4 weight + FP8 activation) is not supported — TRT-LLM gates its fast path to SM89 only. diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h index b9e62443145e5..c3b734816cf84 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h @@ -31,6 +31,8 @@ enum class QuantType { W4_AFP8 }; +int get_arch_for_mixed_gemm_weight_preprocess(int arch); + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu index a006612ddadc9..7e83bdda72eab 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.cu @@ -521,6 +521,19 @@ void add_bias_and_interleave_quantized_tensor_inplace_cuda( } } +int get_arch_for_mixed_gemm_weight_preprocess(int arch) { + ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); + if (arch < 80) { + return 75; + } +#ifndef EXCLUDE_SM_90 + if (arch >= 90 && arch < 100) { + return 90; + } +#endif + return 80; +} + void preprocess_weights_for_mixed_gemm_cuda(cudaStream_t stream, int arch, int8_t* preprocessed_quantized_weight, diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h index a8fb411ed0663..47bbe0c0e10ec 100644 --- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors_impl.h @@ -120,11 +120,11 @@ LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { } LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { - ORT_ENFORCE(arch >= 75, "Unsupported CUDA architecture: ", arch); - if (arch < 80) { + arch = get_arch_for_mixed_gemm_weight_preprocess(arch); + if (arch == 75) { return getLayoutDetailsForArch(quant_type); #ifndef EXCLUDE_SM_90 - } else if (arch >= 90 && arch < 100) { + } else if (arch == 90) { return getLayoutDetailsForArch(quant_type); #endif } else { diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index e1ddcac0cea4f..7d1291e004d78 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -62,18 +62,28 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE this->quant_type_ = op_kernel_info.GetAttrOrDefault("quant_type", "int"); ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8", "quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'"); - // ``weights_prepacked`` is an optional tri-state attribute that defaults to - // -1 (auto) in the schema, so each EP picks its own backward-compatible - // default rather than the schema imposing one: - // -1 (auto, also the schema default): the EP decides. The CUDA EP's - // backward-compatible default is "prepacked" because all pre-existing - // tooling ships CUTLASS-prepacked weights. - // 1: initializers are already prepacked; the compute path reads them as-is. - // 0: initializers are raw [E, N, K/pack]; the PrePack hook lays them out. + // ``weights_prepacked`` is an optional tri-state attribute (default -1) that + // declares the layout of the int4/int8 fc1/fc2 weight initializers. The + // concrete prepacked layouts selected by -1 and 1 are determined by the + // execution provider. The CUDA EP maps the tri-state as: + // -1 (default): already prepacked in the EP's default int weight layout. + // 1: already prepacked in an alternate EP-selected int weight layout. + // 0: raw [E, N, K/pack] initializers; the PrePack hook lays them out. + // + // Important: the CUDA QMoE int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (see isValidHopperMOESpecialisation). The kernel therefore + // consumes the SM80/Ampere CUTLASS fpA_intB layout on every GPU. As a result + // the EP default (-1) is the SM80 layout regardless of the runtime device SM, + // and SM80-format weights are valid on SM90 (they run via the SM80 kernel). + // For CUDA today, -1 and 1 are equivalent (both SM80 layout), and 1 is + // reserved for a possible future Hopper-specific layout. + // PrePack (weights_prepacked=0) packs for the SM80 layout accordingly. const int64_t weights_prepacked_mode = op_kernel_info.GetAttrOrDefault("weights_prepacked", static_cast(-1)); ORT_ENFORCE(weights_prepacked_mode == -1 || weights_prepacked_mode == 0 || weights_prepacked_mode == 1, - "weights_prepacked must be -1 (auto), 0, or 1, but got ", weights_prepacked_mode); + "weights_prepacked must be -1 (default), 0, or 1, but got ", weights_prepacked_mode); weights_prepacked_ = (weights_prepacked_mode != 0); #if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE) ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer."); @@ -850,7 +860,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB // layout that the runner consumes and freed the source initializer // (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack`` - // (which already requires ``packed_fc1_weights_ != nullptr``) rather than + // (which already requires both packed weight buffers) rather than // just ``is_int && !weights_prepacked_``: when prepacking is disabled at // the session level (``session.disable_prepacking``) PrePack never runs, // the prepack buffers stay null, and the raw initializer pointers read @@ -1146,6 +1156,9 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al IAllocatorUniquePtr& packed_buf, bool& is_packed) { ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_); + ORT_ENFORCE(sm_ >= 75, + "PrePackIntExpertWeights: quant_type='int' with weights_prepacked=0 requires SM75+ CUDA hardware, got SM", + sm_); const auto& shape = tensor.Shape(); ORT_ENFORCE(shape.NumDimensions() == 3, "PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=", @@ -1158,22 +1171,15 @@ void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, Al const int64_t k_packed = shape[2]; const int64_t k = k_packed * pack_factor; - // Weight packing is architecture-aware (see - // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing - // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that - // skips column interleaving, so it is its own compatibility group. Every - // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares - // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack - // INT8 LDSM and are unsupported. The compute-side runner selects the same - // layout from this clamped arch, so the two cannot drift. - // - // SM75 is passed through unchanged (rather than clamped to 80) even though it - // shares SM80's layout: the compute-side dispatch (getLayoutDetailsForTransform) - // still has a distinct SM75 branch, so mirroring it here avoids confusing a - // reader into thinking prepack and dispatch disagree. - ORT_ENFORCE(sm_ >= 75, - "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_); - const int packing_sm = (sm_ == 90 || sm_ == 75) ? sm_ : 80; + // The CUDA QMoE int4/int8 MoE GEMM always dispatches to the Ampere (SM80) + // grouped-GEMM kernel -- even on SM90 -- because mixed int-weight + fp16/bf16 + // is not a valid Hopper TMA warp-specialized specialisation. The kernel thus + // consumes the SM80 CUTLASS fpA_intB layout on every GPU, so the weights must + // always be preprocessed for SM80 regardless of the runtime device SM. + // (Using get_arch_for_mixed_gemm_weight_preprocess(sm_) here would emit the + // SM90 layout on Hopper, which the SM80 kernel cannot consume -> wrong output.) + const int packing_sm = + onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(80); // Per-expert sizes. const size_t per_expert_bytes = static_cast(n) * static_cast(k) / pack_factor; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h index 5722ac41cc470..2bbadc205b5d8 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.h @@ -46,16 +46,23 @@ class QMoE final : public CudaKernel, public MoEBase { IAllocatorUniquePtr& packed_buf, bool& is_packed); int64_t expert_weight_bits_; bool is_fp16_; - // When true (the schema default), the int4/int8 fc1/fc2 weight - // initializers are already in the CUTLASS fpA_intB layout — produced - // offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the - // compute path reads them as-is. When false, the raw schema-conformant - // ``[E, N, K/pack]`` layout (as produced by - // ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook - // via ``PrePackIntExpertWeights``, removing the offline prepack - // dependency. Only meaningful when ``quant_type_ == "int"``. Derived from - // the optional tri-state ``weights_prepacked`` attribute: -1/auto (or - // absent) maps to true on the CUDA EP, 1 maps to true, 0 maps to false. + // When true, the int4/int8 fc1/fc2 weight initializers are already in a + // CUTLASS fpA_intB layout — produced offline e.g. via + // ``pack_weights_for_cuda_mixed_gemm`` — and the compute path reads them + // as-is. When false, the raw schema-conformant ``[E, N, K/pack]`` layout + // (as produced by ``quantize_matmul_{4,8}bits``) is rewritten inside the + // PrePack hook via ``PrePackIntExpertWeights``, removing the offline + // prepack dependency. Only meaningful when ``quant_type_ == "int"``. + // Derived from the optional tri-state ``weights_prepacked`` attribute: + // -1 (default) and 1 both map to true; 0 maps to false. The concrete + // prepacked layouts selected by -1 and 1 are determined by the execution + // provider. For the CUDA EP the int4/int8 MoE GEMM always dispatches to the + // Ampere (SM80) grouped-GEMM kernel -- even on SM90 -- because mixed + // int-weight + fp16/bf16 activation is not a valid Hopper TMA warp-specialized + // specialisation (matches TensorRT-LLM, which also routes W4A16/W8A16 MoE to + // the SM80 kernel on Hopper). The kernel therefore consumes the SM80 fpA_intB + // layout on every GPU, so -1 and 1 are currently equivalent for the CUDA EP; + // 1 is reserved for a possible future Hopper-specific layout (e.g. W4A8). bool weights_prepacked_ = true; // Cached source weight shapes captured at PrePack time. When the // PrePack hook consumed and released the original int4/int8 weight diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 1054fd94ef423..f3f2f521ecab2 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1520,18 +1520,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::STRING, std::string("int")) .Attr("weights_prepacked", - "Only meaningful when quant_type='int'. Tri-state control over whether the " - "int4/int8 fc1/fc2 weight initializers are already laid out in the CUTLASS " - "fpA_intB format expected by the runner. -1 (auto): let the execution provider " - "choose its own backward-compatible default; the CUDA EP treats auto as " - "prepacked. 1: the initializers are already prepacked (e.g. produced offline by " - "pack_weights_for_cuda_mixed_gemm) and are consumed as-is. 0: the initializers " - "are raw, un-prepacked [E, N, K/pack] tensors as produced by " - "quantize_matmul_{4,8}bits; the kernel runs the CUTLASS layout transform itself " - "in PrePack(), matching the behaviour of MatMulNBits and removing the offline " - "pre-pack requirement from exporters. Defaults to -1 (auto) so each execution " - "provider can pick its own backward-compatible default rather than the schema " - "imposing one.", + "Only meaningful when quant_type='int'. Tri-state control over the layout of the " + "int4/int8 fc1/fc2 weight initializers. The concrete prepacked layouts selected by " + "-1 and 1 are determined by the execution provider. 0: the initializers are raw, " + "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits. Defaults to -1.", AttributeProto::INT, static_cast(-1)) .Input(0, diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5b1d590a06234..7220153b4fa17 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -16,7 +16,6 @@ #endif #include #include -#include #include namespace pybind11 { @@ -252,17 +251,8 @@ py::array_t PackWeightsForMixedGemm( cudaDeviceProp device_prop; ThrowIfCudaError(cudaGetDeviceProperties(&device_prop, device_id), "cudaGetDeviceProperties"); sm = device_prop.major * 10 + device_prop.minor; - } else { - // Validate force_arch against the SM versions for which preprocess_weights_for_mixed_gemm_cuda - // has tile/permutation tables. Unknown SMs would silently produce incorrect weight layouts. - static const std::set kSupportedSm = {75, 80, 90}; - if (kSupportedSm.find(sm) == kSupportedSm.end()) { - std::ostringstream oss; - oss << "force_arch=" << sm << " is not a supported SM version. " - << "Pass -1 for auto-detect, or one of: 75, 80, 90 (arch > 90 will fallback to 80)."; - throw std::invalid_argument(oss.str()); - } } + sm = ::onnxruntime::llm::kernels::weight_only::get_arch_for_mixed_gemm_weight_preprocess(sm); auto permutation_map_buffer = make_cuda_ptr(32 * sizeof(int32_t)); diff --git a/onnxruntime/test/python/transformers/test_moe_cuda.py b/onnxruntime/test/python/transformers/test_moe_cuda.py index c5fc826a5a6ed..9677542270a53 100644 --- a/onnxruntime/test/python/transformers/test_moe_cuda.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -152,7 +152,10 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): q_weight_reshaped = q_weight.reshape(n, -1) # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format), and qMoE kernel uses the same format. - processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4) + # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 -- + # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1) + # would emit the non-interleaved SM90 layout on Hopper and produce wrong results. + processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 4, 80) # So we need to DEQUANTIZE back to get `result`. # scale is [n, block_per_k] @@ -232,8 +235,11 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): ) q_weight_reshaped = q_weight.reshape(n, -1) - # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format) - processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8) + # Pack weights for CUDA mixed-gemm kernel (FpA_IntB format). + # Pin arch=80: the QMoE grouped MoE GEMM always runs the Ampere (SM80) kernel -- even on SM90 -- + # so it consumes the SM80 (column-interleaved) layout on every GPU. Auto-detect (force_arch=-1) + # would emit the non-interleaved SM90 layout on Hopper and produce wrong results. + processed_q_weight = _quantize.pack_weights_for_cuda_mixed_gemm(q_weight_reshaped, n, k, 8, 80) # Dequantize for reference # (q - 128) * scale if using 128 offset? or (q) * scale if symmetric around 0? @@ -1084,8 +1090,8 @@ def parity_check(self): ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (0.3, 0.05), - "FP16:4": (3.0, 1e-2), - "FP16:8": (2.0, 1e-2), + "FP16:4": (0.5, 1e-2), + "FP16:8": (0.5, 1e-2), "BF16:0": (1.0, 1e-2), "BF16:4": (30.0, 1e-1), "BF16:8": (20.0, 1e-1), diff --git a/onnxruntime/test/python/transformers/test_qmoe_cuda.py b/onnxruntime/test/python/transformers/test_qmoe_cuda.py index 993716a4c80b0..c56383d2851d3 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cuda.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cuda.py @@ -137,43 +137,6 @@ def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""): ) -def preprocess_weights_for_mixed_gemm( - tensor: torch.Tensor, quant_bits: int, sm: int = -1, do_weight_interleave: bool = True -) -> torch.Tensor: - if len(tensor.shape) == 2: - tensor = tensor.unsqueeze(0) - - # Input tensor shape is [Experts, n, k_packed]. k_packed is k/2 for 4-bit, k for 8-bit. - num_experts = tensor.shape[0] - n = tensor.shape[1] - k_packed = tensor.shape[2] - k = k_packed * 2 if quant_bits == 4 else k_packed - - packed_list = [] - - if _pybind and hasattr(_pybind, "pack_weights_for_cuda_mixed_gemm") and torch.cuda.is_available(): - for i in range(num_experts): - if tensor[i].dtype == torch.bfloat16: - weight = tensor[i].to(torch.float32).cpu().numpy() - else: - weight = tensor[i].cpu().numpy() - packed = _pybind.pack_weights_for_cuda_mixed_gemm(weight, n, k, quant_bits, sm) - # pack_weights_for_cuda_mixed_gemm returns int8 array of shape [packed_size] - # We need to reshape it to (k, n/2) for 4-bit, (k, n) for 8-bit. - output_rows = k - output_cols = n // 2 if quant_bits == 4 else n - packed_tensor = torch.from_numpy(packed).to(tensor.device) - packed_tensor = packed_tensor.view(torch.uint8).view(output_rows, output_cols) - packed_list.append(packed_tensor) - - return torch.stack(packed_list) - else: - # This shall not happen unless older version of onnxruntime is used. - raise ImportError( - "onnxruntime._pybind_state.pack_weights_for_cuda_mixed_gemm not found. Cannot preprocess weights." - ) - - def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True, asymmetric: bool = False): # DEBUG # print(f"DEBUG: quant_dequant input shape={weights.shape}, 4bit={is_4_bit_quantization}, asym={asymmetric}") @@ -2110,7 +2073,7 @@ class TestQMoEIntPrePackSmoke(unittest.TestCase): hardware (the other ``test_swiglu_qmoe_parity_*`` cases in this file fail on H200 / H100 with max-diff > 1.0 on plain main, by inspection — pre-existing). A real parity check can be added once - that harness honours the runtime SM. + that harness honors the runtime SM. """ def _run_one(self, *, hidden_size, inter_size, num_experts, top_k, swiglu_fusion, batch_size): From f2c8a0965cc991a21a2f6cad1a1874eb997503f1 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 11 Jun 2026 10:33:53 +0800 Subject: [PATCH 03/15] webgpu: Fuse FlashAttention decode kernels and extend to any sequence length (#28389) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Extend the FlashAttention decode path to work with any sequence length (not just seq_len=1), with causal masking and `use_seqlen_k` support for static KV cache - Add m_tile optimization to process multiple Q rows per workgroup (m_tile=1/2/4), amortizing K/V loads - Fuse the separate QKT and SplitVx shaders into a single QKV kernel using online softmax, eliminating the intermediate `qk` tensor (`B×H×seq×present_seq`) and reducing dispatch count from 3 to 2 - Route between prefill (FlashAttentionProgram) and split-reduce (fused QKV + VxReduce) paths based on sequence length ## Resolved Issues **Whisper decoding prefill improved from 4.68ms to 1.09ms.** Whisper's decoder attention has a small sequence length but large total sequence length (seq_len=4, total_seq_len=1500). The default prefill shader (FlashAttentionProgram) has low parallelism in this case because each workgroup iterates serially over the full KV cache. The split-reduce path tiles the KV dimension across workgroups, achieving much higher GPU occupancy for this workload shape. ## Details **Fused QKV kernel**: Each workgroup computes QK^T dot products, applies attention bias and causal mask, computes local softmax (per-tile max and sum), normalizes, and multiplies by V — all in one kernel. Per-tile metadata (max, sum) is written for the VxReduce shader to rescale partial outputs using online softmax: `output = Σ(partial_i × local_sum_i × exp(local_max_i - global_max)) / global_sum`. **Path routing** (`use_split_reduce`): The split-reduce path is used when `sequence_length_ < 32`; otherwise the single-kernel FlashAttentionProgram prefill path is used. Microbenchmarks on Phi-4 (32 heads, head_size 128, GQA group 3) show split-reduce is 1.13×-2.07× faster than the fused prefill kernel across `sequence_length ∈ {16, 30, 31}` × `total_sequence_length ∈ {128, 500, 2000}`. The previous heuristic additionally gated on `total_sequence_length_ > 1000`, but that signal is 0 under graph capture (seqlen_k lives on the GPU) and the carve-out is unnecessary because split-reduce is uniformly faster for short Q. ## Test plan - [x] 30/30 MHA unit tests pass - [x] phi4-graph-prune produces correct output - [x] whisper-tiny-int4 produces correct transcription - [x] clang-format clean --- .../webgpu/bert/flash_attention.cc | 267 +++++++++--------- .../contrib_ops/webgpu/bert/flash_attention.h | 59 ++-- .../flash_attention_decode_qkt.wgsl.template | 118 -------- .../flash_attention_decode_qkv.wgsl.template | 197 +++++++++++++ ...sh_attention_decode_split_vx.wgsl.template | 113 -------- ...h_attention_decode_vx_reduce.wgsl.template | 93 ++++-- ..._rotary_embedding_and_copykv.wgsl.template | 5 +- 7 files changed, 426 insertions(+), 426 deletions(-) delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template delete mode 100644 onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 684e050f0201a..02e764d01e05e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -16,6 +16,40 @@ namespace onnxruntime { namespace contrib { namespace webgpu { +// WGSL helper function for normalizing on-device indirect dispatch dims. +// Shared by CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram. +// Mirrors ProgramManager::NormalizeDispatchGroupSize three tiers: +// 1) direct (x, y, z) write when every dim is within the spec limit (65535); +// 2) 2D sqrt collapse when the product fits a square layout; +// 3) 3D cbrt collapse otherwise. +// Consumers are unaffected by the chosen layout: ShaderHelper flattens +// workgroup_id (x, y, z) into a single linear workgroup_idx. +// Caller contract: must register a storage output named exactly +// `indirect_buffer` of array with at least 3 elements. +constexpr const char kNormalizeDispatchGroupSizeFn[] = R"( +fn normalize_dispatch_group_size(x: u32, y: u32, z: u32) { + let limit = 65535u; // WebGPU spec maxComputeWorkgroupsPerDimension + if (x <= limit && y <= limit && z <= limit) { + indirect_buffer[0] = x; + indirect_buffer[1] = y; + indirect_buffer[2] = z; + return; + } + let size = f32(x) * f32(y) * f32(z); + let dispatch_avg_2d = u32(ceil(sqrt(size))); + if (dispatch_avg_2d <= limit) { + indirect_buffer[0] = dispatch_avg_2d; + indirect_buffer[1] = dispatch_avg_2d; + indirect_buffer[2] = 1u; + return; + } + let dispatch_avg_3d = u32(ceil(pow(size, 1.0 / 3.0))); + indirect_buffer[0] = dispatch_avg_3d; + indirect_buffer[1] = dispatch_avg_3d; + indirect_buffer[2] = dispatch_avg_3d; +} +)"; + Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); @@ -28,6 +62,7 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha if (prepare_indirect_dispatch_) { sh.AddOutput("indirect_buffer", ShaderUsage::None); + sh.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; } return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template", @@ -87,13 +122,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { - // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size. - shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n" - << " if (global_idx == 0u) {\n" + shader.AdditionalImplementation() << kNormalizeDispatchGroupSizeFn; + shader.MainFunctionBody() << " if (global_idx == 0u) {\n" << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n" - << " indirect_buffer[0] = num_total_seq_length_tile;\n" - << " indirect_buffer[1] = uniforms.num_heads;\n" - << " indirect_buffer[2] = 1u;\n" + << " normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size);\n" << " }\n\n"; } @@ -120,7 +152,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* K, const Tensor* past_key, Tensor* present_key, const Tensor* V, const Tensor* past_value, Tensor* present_value, - uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) { + uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer, uint32_t num_q_tiles) { // CopyKVCache takes past key/value and current key/value and copies them to present key and value. // This makes it so that FlashAttention only needs to look at present key and value, and saves // number of input buffers in the shader, which we run out of (<=8) without this optimization. @@ -176,7 +208,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt {static_cast(parameters.total_sequence_length_)}, {static_cast(parameters.kv_sequence_length_)}, {tile_size}, - {static_cast(parameters.num_heads_)}}); + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.batch_size_)}, + {num_q_tiles}}); return context.RunProgram(program); } @@ -224,52 +258,66 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(use_shm_path, use_shm_path_)); } -Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); +Status FlashAttentionDecodeQKVProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& q = shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& present_key = shader.AddInput("present_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& present_value = shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_attention_bias_) { shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& out_split_vx = shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); + const auto& metadata = shader.AddOutput("metadata", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const uint32_t tile_size_k_vec = 8; const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec; - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkt.wgsl.template", + return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_qkv.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), + WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), + WGSL_TEMPLATE_PARAMETER(q_BNSH, q_BNSH_), WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_PARAMETER(v_head_size_vec, head_size_vec_), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(out_split_vx, out_split_vx), + WGSL_TEMPLATE_VARIABLE(present_key, present_key), + WGSL_TEMPLATE_VARIABLE(present_value, present_value), + WGSL_TEMPLATE_VARIABLE(q, q)); } -Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, - const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) { +Status ComputeFlashAttentionDecodeQKV(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q, + const Tensor* attention_bias, Tensor* out_split_vx, Tensor* present_key, Tensor* present_value, + Tensor* metadata, const Tensor* seqlen_k, + const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length, uint32_t m_tile) { const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size_)) : parameters.scale_; const bool has_attention_bias = attention_bias != nullptr; const int components = 4; + const int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeQKTProgram program{"FlashAttentionDecodeQKT", has_attention_bias, tile_size, use_indirect_dispatch}; + bool q_BNSH = parameters.qkv_format_ == Q_K_V_BNSH; + bool is_unidirectional = parameters.is_unidirectional_; + FlashAttentionDecodeQKVProgram program{"FlashAttentionDecodeQKV", has_attention_bias, tile_size, head_size_vec, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, - {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}}); + {present_key, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } if (has_attention_bias) { program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::Rank}, + program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, {metadata, ProgramTensorMetadataDependency::Rank, 2}}); const uint32_t vectorized_head_size = parameters.head_size_ / components; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -281,10 +329,10 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte if (use_indirect_dispatch) { program.SetIndirectDispatchTensor(indirect_buffer); } else { - program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_total_seq_length_tile); + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_total_seq_length_tile); } program.SetWorkgroupSize(64) - .CacheHint(tile_size, has_attention_bias, use_indirect_dispatch) + .CacheHint(tile_size, head_size_vec, has_attention_bias, use_indirect_dispatch, q_BNSH, is_unidirectional, m_tile) .AddUniformVariables({{static_cast(vectorized_head_size)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(alpha)}, @@ -294,124 +342,72 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte {static_cast(parameters.num_heads_)}, {static_cast(parameters.batch_size_)}, {attn_bias_dim0}, - {attn_bias_dim1}}); + {attn_bias_dim1}, + {static_cast(parameters.sequence_length_)}}); return context.RunProgram(program); } -Status FlashAttentionDecodeSplitVxProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("metadata", ShaderUsage::UseUniform); - shader.AddInput("qk", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("present_value", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); +Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); + const auto& metadata = shader.AddInput("metadata", ShaderUsage::UseUniform); if (use_indirect_dispatch_) { shader.AddInput("seqlens_k", ShaderUsage::None); } if (has_head_sink_) { shader.AddInput("head_sink", ShaderUsage::UseUniform); } - shader.AddOutput("out_split_vx", ShaderUsage::UseUniform); - - const uint32_t tile_size_k_vec = 8u; - - return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_split_vx.wgsl.template", - WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), - WGSL_TEMPLATE_PARAMETER(head_size_vec, head_size_vec_), - WGSL_TEMPLATE_PARAMETER(sub_tile_count, WorkgroupSizeX() / tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); -} - -Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeContext& context, - const Tensor* metadata, - const Tensor* qk, - Tensor* out_split_vx, - Tensor* present_value, - const Tensor* seqlen_k, - const WebgpuAttentionParameters& parameters, - const Tensor* indirect_buffer, - uint32_t num_total_seq_length_tile, - uint32_t num_present_sequence_length_tile, - uint32_t tile_size, - bool use_indirect_dispatch, - uint32_t present_sequence_length, - const Tensor* head_sink) { - const int components = 4; - const bool has_head_sink = head_sink != nullptr; - int head_size_vec = parameters.v_head_size_ / components; - FlashAttentionDecodeSplitVxProgram program{"FlashAttentionDecodeSplitVx", tile_size, head_size_vec, use_indirect_dispatch, has_head_sink}; - program.AddInputs({{metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}, - {qk, ProgramTensorMetadataDependency::TypeAndRank}, - {present_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); - program.AddOutputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); // [B, N, split_k, head_size] - const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - if (use_indirect_dispatch) { - program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); - } - if (has_head_sink) { - program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); - } - // SetIndirectDispatchTensor must be called after all AddInput calls because it - // appends the indirect buffer as the last program input. - if (use_indirect_dispatch) { - program.SetIndirectDispatchTensor(indirect_buffer); - } else { - program.SetDispatchGroupSize(batch_heads * num_total_seq_length_tile); - } - program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink) - .SetWorkgroupSize(64) - .AddUniformVariables({{static_cast(parameters.total_sequence_length_)}, - {static_cast(head_size_vec)}, - present_sequence_length, - {static_cast(parameters.n_reps)}, - num_present_sequence_length_tile, - {batch_heads}, - {static_cast(parameters.num_heads_)}}); - - return context.RunProgram(program); -} - -Status FlashAttentionDecodeVxReduceProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input", ShaderUsage::UseUniform); - if (use_indirect_dispatch_) { - shader.AddInput("seqlens_k", ShaderUsage::None); - } - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention_decode_vx_reduce.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_head_sink, has_head_sink_), + WGSL_TEMPLATE_PARAMETER(m_tile, m_tile_), WGSL_TEMPLATE_PARAMETER(seq_tile_size, seq_tile_size_), WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), - WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_)); + WGSL_TEMPLATE_PARAMETER(use_indirect_dispatch, use_indirect_dispatch_), + WGSL_TEMPLATE_VARIABLE(input, input), + WGSL_TEMPLATE_VARIABLE(metadata, metadata), + WGSL_TEMPLATE_VARIABLE(output, output)); } Status ComputeFlashAttentionDecodeVxReduce(onnxruntime::webgpu::ComputeContext& context, const Tensor* out_split_vx, + const Tensor* metadata, Tensor* output, const Tensor* seqlen_k, const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t seq_tile_size, - bool use_indirect_dispatch) { + bool use_indirect_dispatch, + const Tensor* head_sink, + uint32_t m_tile) { const int components = 4; constexpr int tile_size = 8; int tile_head_size = tile_size * components; - FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch}; - program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}}); + bool has_head_sink = head_sink != nullptr; + FlashAttentionDecodeVxReduceProgram program{"FlashAttentionDecodeVxReduce", tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile}; + program.AddInputs({{out_split_vx, ProgramTensorMetadataDependency::TypeAndRank, components}, + {metadata, ProgramTensorMetadataDependency::TypeAndRank, 2}}); if (use_indirect_dispatch) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } + if (has_head_sink) { + program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); + } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}}); const uint32_t num_head_size_tile = static_cast((parameters.v_head_size_ + tile_head_size - 1) / tile_head_size); const uint32_t batch_heads = static_cast(parameters.batch_size_ * parameters.num_heads_); - program.SetDispatchGroupSize(batch_heads * num_head_size_tile) - .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch) + program.SetDispatchGroupSize(batch_heads * ((parameters.sequence_length_ + m_tile - 1) / m_tile) * num_head_size_tile) + .CacheHint(tile_size, seq_tile_size, use_indirect_dispatch, has_head_sink, m_tile) .SetWorkgroupSize(tile_size * tile_size) .AddUniformVariables({{static_cast(parameters.v_head_size_ / components)}, num_total_seq_length_tile, num_present_sequence_length_tile, {num_head_size_tile}, - {batch_heads}}); + {batch_heads}, + {static_cast(parameters.sequence_length_)}, + {static_cast(parameters.num_heads_)}}); return context.RunProgram(program); } @@ -446,14 +442,18 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // Declare query_output at function scope to ensure it persists throughout the function Tensor query_output; + // Compute m_tile early so it can be passed to CopyKVCache for indirect dispatch. + const uint32_t m_tile = parameters.sequence_length_ >= 4 ? 4u : (parameters.sequence_length_ >= 2 ? 2u : 1u); + const uint32_t num_q_tiles = (static_cast(parameters.sequence_length_) + m_tile - 1u) / m_tile; + // Create indirect dispatch buffer if using indirect dispatch Tensor* indirect_buffer_ptr = nullptr; Tensor indirect_buffer; - // Prepare indirect dispatch buffer for decode path with static KV cache - const bool use_indirect_dispatch = !kv_empty && - parameters.sequence_length_ == 1 && - parameters.past_present_share_buffer_ && + // Prepare indirect dispatch buffer for split-reduce path with static KV cache. + // When graph capture is enabled, total_sequence_length_ may be 0 (GPU-based + // seqlen_k), so the indirect buffer computes dispatch sizes on GPU. + const bool use_indirect_dispatch = parameters.past_present_share_buffer_ && seqlen_k != nullptr && context.IsGraphCaptureEnabled(); if (use_indirect_dispatch) { @@ -492,10 +492,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co Q, seqlen_k, cos_cache, sin_cache, &query_output, present_key, present_value, - indirect_buffer_ptr, tile_size)); + indirect_buffer_ptr, tile_size, num_q_tiles)); Q = &query_output; } else { - ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr)); + ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr, num_q_tiles)); } // Extract present_sequence_length directly from present_key tensor shape @@ -503,7 +503,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co // (batch_size, num_heads, total_sequence_length/max_sequence_length, head_size) const uint32_t present_sequence_length = static_cast(present_key->Shape()[2]); - if (parameters.sequence_length_ > 1) { + // Route between prefill path (FlashAttentionProgram, single kernel) + // and split-reduce decode path (QKV + VxReduce, 2 kernels). + // Split-reduce wins for short Q (sequence_length < 32) across all KV + // cache lengths measured: 1.13x-2.07x faster at total_sequence_length + // 128 / 500 / 2000 on a representative LLM (32 heads, head_size 96). + const bool use_split_reduce = parameters.sequence_length_ < 32; + + if (!use_split_reduce) { + // Prefill path: FlashAttentionProgram (single kernel with subgroup shuffles) bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_nvidia = context.AdapterInfo().vendor == std::string_view{"nvidia"}; @@ -545,7 +553,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t prefill_tile_size = is_apple ? 128 : tile_size; const uint32_t num_seq_tile = (parameters.sequence_length_ + prefill_tile_size - 1) / prefill_tile_size; - // Get attention bias dimensions for broadcasting uint32_t attn_bias_dim0 = 1; uint32_t attn_bias_dim1 = 1; if (has_attention_bias) { @@ -570,37 +577,31 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co return context.RunProgram(program); } - // For decode path (sequence_length == 1) - const TensorShapeVector qk_dims({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, present_sequence_length}); - const TensorShape qk_shape(qk_dims); - Tensor qk = context.CreateGPUTensor(Q->DataType(), qk_shape); + // Split-reduce path (QKV + VxReduce) const uint32_t num_total_seq_length_tile = (parameters.total_sequence_length_ + tile_size - 1) / tile_size; const uint32_t num_present_sequence_length_tile = (present_sequence_length + tile_size - 1) / tile_size; // The metadata is used to store the max and sum of each tile. const TensorShapeVector metadata_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, 2}); + parameters.sequence_length_, num_present_sequence_length_tile, 2}); const TensorShape metadata_shape(metadata_dims); Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType(), metadata_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k, - parameters, indirect_buffer_ptr, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch, - present_sequence_length)); const TensorShapeVector out_split_vx_dims({parameters.batch_size_, parameters.num_heads_, - num_present_sequence_length_tile, parameters.head_size_}); + parameters.sequence_length_, num_present_sequence_length_tile, parameters.head_size_}); const TensorShape out_split_vx_shape(out_split_vx_dims); Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value, - seqlen_k, parameters, indirect_buffer_ptr, - num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, - use_indirect_dispatch, present_sequence_length, - head_sink)); - ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, output, seqlen_k, parameters, + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKV(context, Q, attention_bias, &out_split_vx, present_key, present_value, + &metadata, seqlen_k, + parameters, indirect_buffer_ptr, num_total_seq_length_tile, + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + present_sequence_length, m_tile)); + + ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeVxReduce(context, &out_split_vx, &metadata, output, seqlen_k, parameters, num_total_seq_length_tile, - num_present_sequence_length_tile, tile_size, use_indirect_dispatch)); + num_present_sequence_length_tile, tile_size, use_indirect_dispatch, + head_sink, m_tile)); return Status::OK(); } @@ -621,7 +622,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput Tensor* present_key, Tensor* present_value, Tensor* indirect_buffer, - uint32_t tile_size) { + uint32_t tile_size, uint32_t num_q_tiles) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -678,6 +679,8 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput {present_sequence_length}, {tile_size}, {static_cast(dispatch_size)}, + {static_cast(params.batch_size_)}, + {num_q_tiles}, }); program.SetDispatchGroupSize((dispatch_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index e75b6378f67c6..3da6b33b4dc0e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -35,7 +35,9 @@ class SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"tile_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_q_tiles", ProgramUniformVariableDataType::Uint32}); private: bool has_past_; @@ -138,11 +142,14 @@ class FlashAttentionProgram final : public Program { int max_k_step_; }; -class FlashAttentionDecodeQKTProgram final : public Program { +class FlashAttentionDecodeQKVProgram final : public Program { public: - FlashAttentionDecodeQKTProgram(const std::string& kernel_name, - bool has_attention_bias, uint32_t tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeQKVProgram(const std::string& kernel_name, + bool has_attention_bias, uint32_t tile_size, int head_size_vec, + bool use_indirect_dispatch, bool q_BNSH = false, + bool is_unidirectional = false, + uint32_t m_tile = 1) + : Program{kernel_name}, has_attention_bias_(has_attention_bias), tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), q_BNSH_(q_BNSH), is_unidirectional_(is_unidirectional), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -156,41 +163,23 @@ class FlashAttentionDecodeQKTProgram final : public Program { - public: - FlashAttentionDecodeSplitVxProgram(const std::string& kernel_name, uint32_t tile_size, int head_size_vec, bool use_indirect_dispatch, bool has_head_sink = false) - : Program{kernel_name}, tile_size_(tile_size), head_size_vec_(head_size_vec), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink) { - } - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"total_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_size_vec", ProgramUniformVariableDataType::Uint32}, - {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"num_present_sequence_length_tile", ProgramUniformVariableDataType::Uint32}, - {"batch_heads", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}); - - private: uint32_t tile_size_; int head_size_vec_; bool use_indirect_dispatch_; - bool has_head_sink_; + bool q_BNSH_; + bool is_unidirectional_; + uint32_t m_tile_; }; class FlashAttentionDecodeVxReduceProgram final : public Program { public: - FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch) - : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch) { + FlashAttentionDecodeVxReduceProgram(const std::string& kernel_name, uint32_t tile_size, uint32_t seq_tile_size, bool use_indirect_dispatch, bool has_head_sink = false, uint32_t m_tile = 1) + : Program{kernel_name}, tile_size_(tile_size), seq_tile_size_(seq_tile_size), use_indirect_dispatch_(use_indirect_dispatch), has_head_sink_(has_head_sink), m_tile_(m_tile) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -199,12 +188,16 @@ class FlashAttentionDecodeVxReduceProgram final : public Program tile_q: array; -var inner_qk_values: array, tile_size>; -var tile_qk: array; - -#if has_attention_bias - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - // Handle broadcasting: if dimension size is 1, use index 0 - let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); - let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); - - // Calculate flat offset with broadcasting applied - // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, new_seq_length, total_seq_length] - // For decode, new_seq_length is 1, so we can simplify: - let offset = bias_batch_idx * uniforms.attn_bias_dim1 * total_seq_length + - bias_head_idx * total_seq_length + - k_idx; - return attention_bias[offset]; - } -#else - fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t - { - return q_element_t(0); - } -#endif - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; -#if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; -#else - let total_sequence_length = uniforms.total_sequence_length; -#endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - let head_idx = batch_head_idx % uniforms.num_heads; - let batch_idx = batch_head_idx / uniforms.num_heads; - if (batch_idx >= uniforms.batch_size) { - return; - } - let q_offset = batch_idx * uniforms.num_heads * uniforms.head_size_vec + head_idx * uniforms.head_size_vec; - let present_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; - for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { - if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { - tile_q[local_idx] = q[q_offset + k + local_idx]; - } - workgroupBarrier(); - let q_data = tile_q[local_col] * q_element_t(uniforms.alpha); - if (k + local_col < uniforms.head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - inner_qk_values[row_offset + local_row][local_col] += dot(present_key[present_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col], q_data); - } - } - } - workgroupBarrier(); - } - - if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { - var sum = q_element_t(0); - for (var i = 0u; i < tile_size_k_vec; i++) { - sum += inner_qk_values[local_idx][i]; - } - - sum = sum + loadAttentionBias(batch_idx, head_idx, 0u, total_seq_offset + local_idx, total_sequence_length); - tile_qk[local_idx] = sum; - output[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx] = sum; - } - workgroupBarrier(); - - if (local_idx == 0u) { - // Calculate the max and sum in current split. - var l_max = f32(-3.4028234663852886e+38f); - var l_sum = f32(0); - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_max = max(l_max, f32(tile_qk[i])); - } - for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { - l_sum += exp(f32(tile_qk[i]) - l_max); - } - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; - metadata[meta_offset] = metadata_value_t(l_max, l_sum); - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template new file mode 100644 index 0000000000000..524a18ca43245 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkv.wgsl.template @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_attention_bias +#param v_head_size_vec +#param is_unidirectional +#param m_tile +#param q_BNSH +#param sub_tile_count +#param tile_size +#param tile_size_k_vec +#param use_indirect_dispatch + +#use .getByOffset .setByOffset + +// Fused QK^T + softmax + V multiply shader. +// +// Each workgroup processes one KV tile (tile_size rows of present_key/value) +// for m_tile Q rows. The computation has two phases: +// +// Phase 1: QK^T (dot product of Q with K, attention bias, causal mask, +// per-tile max/sum for online softmax) +// Phase 2: Local softmax normalization + V multiply (using local max/sum, +// no cross-workgroup dependency) +// +// The VxReduce shader performs the final rescaling across tiles. + +var tile_q: array, m_tile>; +var inner_qk_values: array, tile_size>, m_tile>; +var tile_qk: array, m_tile>; +var tile_output: array, m_tile>; +var qkv_values: array, sub_tile_count>, m_tile>; +var tile_max: array; +var tile_sum: array; + +#if has_attention_bias + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0); + let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1); + let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.new_sequence_length * total_seq_length + + bias_head_idx * uniforms.new_sequence_length * total_seq_length + + q_idx * total_seq_length + + k_idx; + return attention_bias[offset]; + } +#else + fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32, total_seq_length: u32) -> q_element_t + { + return q_element_t(0); + } +#endif + +$MAIN { + let local_row = u32(local_idx / tile_size_k_vec); + let local_col = local_idx % tile_size_k_vec; + #if use_indirect_dispatch + let total_sequence_length = u32(seqlens_k[0]) + 1u; + #else + let total_sequence_length = uniforms.total_sequence_length; + #endif + let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_total_seq_length_tile] + let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; + let q_tile_idx = (workgroup_idx / num_total_seq_length_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (num_total_seq_length_tile * num_q_tiles)); + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + if (batch_idx >= uniforms.batch_size) { + return; + } + let present_key_offset = batch_head_idx / uniforms.n_reps * uniforms.present_sequence_length * uniforms.head_size_vec; + let present_value_offset = u32(batch_head_idx / uniforms.n_reps) * v_head_size_vec * uniforms.present_sequence_length; + + // ============================================================ + // Phase 1: QK^T computation + // ============================================================ + for (var k: u32 = 0u; k < uniforms.head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size_k_vec && k + local_idx < uniforms.head_size_vec) { + let q_idx = q_base + m; +#if q_BNSH + let q_offset = batch_idx * uniforms.num_heads * uniforms.new_sequence_length * uniforms.head_size_vec + + head_idx * uniforms.new_sequence_length * uniforms.head_size_vec + + q_idx * uniforms.head_size_vec; +#else + let q_offset = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec; +#endif + tile_q[m][local_idx] = q.getByOffset(q_offset + k + local_idx); + } + } + workgroupBarrier(); + if (k + local_col < uniforms.head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let k_data = present_key.getByOffset(present_key_offset + (total_seq_offset + row_offset + local_row) * uniforms.head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_data = tile_q[m][local_col] * q_element_t(uniforms.alpha); + inner_qk_values[m][row_offset + local_row][local_col] += dot(k_data, q_data); + } + } + } + } + workgroupBarrier(); + } + + // Reduce inner_qk_values to tile_qk, apply attention bias and causal mask + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + var sum = q_element_t(0); + for (var i = 0u; i < tile_size_k_vec; i++) { + sum += inner_qk_values[m][local_idx][i]; + } + + sum = sum + loadAttentionBias(batch_idx, head_idx, q_idx, total_seq_offset + local_idx, total_sequence_length); +#if is_unidirectional + if (total_seq_offset + local_idx > total_sequence_length - uniforms.new_sequence_length + q_idx) { + sum = q_element_t(-65504.0f); + } +#endif + tile_qk[m][local_idx] = present_value_element_t(sum); + } + workgroupBarrier(); + + // Compute per-tile max and sum for online softmax + if (local_idx == 0u) { + var l_max = f32(-3.4028234663852886e+38f); + var l_sum = f32(0); + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_max = max(l_max, f32(tile_qk[m][i])); + } + for (var i = 0u; i < tile_size && (total_seq_offset + i) < total_sequence_length; i++) { + l_sum += exp(f32(tile_qk[m][i]) - l_max); + } + tile_max[m] = l_max; + tile_sum[m] = l_sum; + let meta_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile + workgroup_idx % num_total_seq_length_tile; + metadata.setByOffset(meta_offset, metadata_value_t(l_max, l_sum)); + } + } + workgroupBarrier(); + + // ============================================================ + // Phase 2: Local softmax + V multiply + // ============================================================ + + // Normalize tile_qk with local max/sum + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + if (local_idx < tile_size && total_seq_offset + local_idx < total_sequence_length) { + tile_qk[m][local_idx] = present_value_element_t(exp(f32(tile_qk[m][local_idx]) - tile_max[m]) / tile_sum[m]); + } + } + workgroupBarrier(); + + for (var k: u32 = 0u; k < v_head_size_vec; k += tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] = present_value_value_t(0); + } + workgroupBarrier(); + + if (k + local_col < v_head_size_vec) { + for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + if (total_seq_offset + row_offset + local_row < total_sequence_length) { + let v_data = present_value.getByOffset(present_value_offset + (total_seq_offset + row_offset + local_row) * v_head_size_vec + k + local_col); + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + qkv_values[m][local_row][local_col] += v_data * tile_qk[m][row_offset + local_row]; + } + } + } + } + workgroupBarrier(); + + if (local_idx < tile_size_k_vec) { + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + for (var i = 0u; i < sub_tile_count; i++) { + tile_output[m][k + local_idx] += qkv_values[m][i][local_idx]; + } + } + } + workgroupBarrier(); + } + + // Write output + let tile_idx = workgroup_idx % num_total_seq_length_tile; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let out_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * v_head_size_vec; + for (var i = local_idx; i < v_head_size_vec; i += workgroup_size_x) { + out_split_vx.setByOffset(out_base + tile_idx * v_head_size_vec + i, tile_output[m][i]); + } + } +} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template deleted file mode 100644 index 6f1ad1ca41b71..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#param has_head_sink -#param tile_size -#param head_size_vec -#param tile_size_k_vec -#param sub_tile_count -#param use_indirect_dispatch - -// Note that this shader adopts similar algorithm with dp4a generation shader. -// -// This algorithm works to compute dot product of v with qk parallelly, by -// processing on the head_size dimension at each step amongst tile_size_k_vec -// threads, and utilizing the remaining threads in the workgroup to process -// additional rows of |present_value| in parallel (such that the values in -// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads -// also reload |present_value| tile_size/sub_tile_count times to compute partial -// dot products of other |present_value| rows in order to complete all tile_size -// |present_value| rows in this workgroup and also reusing the values in -// tile_qk. -// -// The difference with FlashAttentionDecodeQKTProgram is that the dot products -// go through the rows (total_sequence_length) of |present_value| instead of -// columns (head_size_vec). And each workgroup only calculate current -// tile_size's dot products instead of iterating the whole row -// |total_sequence_length|. That's why this shader is a split shader. The final -// reduce will be done in FlashAttentionDecodeReduceProgram. - -// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx -// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate -// memory. The FlashAttentionDecodeQKT can be merged into split shader and do -// the final softmax adjustment in the reduce shader. However, some issues are -// met that when the total sequence length exceeds some value, the result will -// become garbage. Since it can't be resolved in a short time, leave it as TODO -// to fix it in future. - -var tile_qk: array; -var tile_output: array; -var qkv_values: array, sub_tile_count>; - -$MAIN { - let local_row = u32(local_idx / tile_size_k_vec); - let local_col = local_idx % tile_size_k_vec; - #if use_indirect_dispatch - let total_sequence_length = u32(seqlens_k[0]) + 1u; - #else - let total_sequence_length = uniforms.total_sequence_length; - #endif - let num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; - let total_seq_offset = (workgroup_idx % num_total_seq_length_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / num_total_seq_length_tile); - if (batch_head_idx >= uniforms.batch_heads) { - return; - } - let present_offset = u32(batch_head_idx / uniforms.n_reps) * head_size_vec * uniforms.present_sequence_length; - - // Calculate the global max and sum in qk. - var g_max = f32(-3.4028234663852886e+38f); -#if has_head_sink - let head_idx = batch_head_idx % uniforms.num_heads; - let sink_value = f32(head_sink[head_idx]); - g_max = max(g_max, sink_value); -#endif - var g_sum = f32(0); - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - g_max = max(g_max, metadata[meta_offset].x); - } - for (var i = 0u; i < num_total_seq_length_tile; i++) - { - let meta_offset = batch_head_idx * uniforms.num_present_sequence_length_tile + i; - let m_value = metadata[meta_offset]; - g_sum += exp(m_value.x - g_max) * m_value.y; - } -#if has_head_sink - g_sum += exp(sink_value - g_max); -#endif - - if (total_seq_offset + local_idx < total_sequence_length) { - tile_qk[local_idx] = present_value_element_t(exp(f32(qk[batch_head_idx * uniforms.present_sequence_length + total_seq_offset + local_idx]) - g_max) / g_sum); - } - - for (var k: u32 = 0u; k < head_size_vec; k += tile_size_k_vec) { - var value = present_value_value_t(0); - qkv_values[local_row][local_col] = present_value_value_t(0); - workgroupBarrier(); - - if (k + local_col < head_size_vec) { - for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - if (total_seq_offset + row_offset + local_row < total_sequence_length) { - value += present_value[present_offset + (total_seq_offset + row_offset + local_row) * head_size_vec + k + local_col] * tile_qk[row_offset + local_row]; - } - } - } - - qkv_values[local_row][local_col] = value; - workgroupBarrier(); - - if (local_idx < tile_size_k_vec) { - for (var i = 0u; i < sub_tile_count; i++) { - tile_output[k + local_idx] += qkv_values[i][local_idx]; - } - } - workgroupBarrier(); - } - - for (var i = local_idx; i < head_size_vec; i += workgroup_size_x) { - let out_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * head_size_vec + (workgroup_idx % num_total_seq_length_tile) * head_size_vec + i; - out_split_vx[out_offset] = tile_output[i]; - } -} diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index f909a87724da6..a3ce0b68cb659 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -1,31 +1,35 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#param has_head_sink +#param m_tile #param seq_tile_size #param tile_size #param use_indirect_dispatch -// Inputs are splits of the GQA output, split into num_total_seq_length_tiles -// rows. This shader needs to add these splits across the row dimension to -// arrive at the final result. The column is head size wide. The reduction -// achieves maximum parallelization by splitting this task first into tile_size -// columns that each workgroup is responsible for. Then within each workgroup -// the task of summation over the num_total_seq_length_tile for the tile_size -// columns is further split in two ways. First across the row dimension to have -// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE -// rows. Then across the column dimension where each thread is responsible for 1 -// column of the TILE_SIZE columns the workgroup is responsible for. +#use .getByOffset .setByOffset + +// This shader reduces partial V outputs from the fused QKV shader. +// Each tile produced a locally-normalized V contribution. To get the +// correct global result, we rescale each tile's contribution using +// per-tile metadata (max, sum) with online softmax: +// +// global_max = max(local_max_i for all tiles) +// global_sum = sum(local_sum_i * exp(local_max_i - global_max)) +// output[h] = sum(partial_i[h] * exp(local_max_i - global_max)) / global_sum var tile_input: array, tile_size>; $MAIN { + let num_q_tiles = (uniforms.new_sequence_length + m_tile - 1) / m_tile; + // Workgroup layout: [batch_heads, num_q_tiles, num_head_size_tile] let head_size_offset = (workgroup_idx % uniforms.num_head_size_tile) * tile_size; - let batch_head_idx = u32(workgroup_idx / uniforms.num_head_size_tile); + let q_tile_idx = (workgroup_idx / uniforms.num_head_size_tile) % num_q_tiles; + let q_base = q_tile_idx * m_tile; + let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * num_q_tiles)); if (batch_head_idx >= uniforms.batch_heads) { return; } - let in_offset = batch_head_idx * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; - var value = output_value_t(0); let local_row = u32(local_idx / tile_size); let local_col = local_idx % tile_size; #if use_indirect_dispatch @@ -35,23 +39,60 @@ $MAIN { let num_total_seq_length_tile = uniforms.num_total_seq_length_tile; #endif - if (head_size_offset + local_col < uniforms.head_size_vec) { - for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { - if (r + local_row < num_total_seq_length_tile) { - value += input[in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col]; + for (var m = 0u; m < m_tile && q_base + m < uniforms.new_sequence_length; m++) { + let q_idx = q_base + m; + let in_offset = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile * uniforms.head_size_vec; + let meta_base = (batch_head_idx * uniforms.new_sequence_length + q_idx) * uniforms.num_present_sequence_length_tile; + + // Compute global max across all tiles + var g_max = f32(-3.4028234663852886e+38f); +#if has_head_sink + let head_idx_for_sink = batch_head_idx % uniforms.num_heads; + let sink_value = f32(head_sink[head_idx_for_sink]); + g_max = max(g_max, sink_value); +#endif + for (var i = 0u; i < num_total_seq_length_tile; i++) { + g_max = max(g_max, metadata.getByOffset(meta_base + i).x); + } + + // Compute global sum with rescaling + var g_sum = f32(0); + for (var i = 0u; i < num_total_seq_length_tile; i++) { + let m_value = metadata.getByOffset(meta_base + i); + g_sum += m_value.y * exp(m_value.x - g_max); + } +#if has_head_sink + g_sum += exp(sink_value - g_max); +#endif + + // Accumulate rescaled partial outputs + var value = output_value_t(0); + if (head_size_offset + local_col < uniforms.head_size_vec) { + for (var r = 0u; r < num_total_seq_length_tile; r += tile_size) { + if (r + local_row < num_total_seq_length_tile) { + let tile_meta = metadata.getByOffset(meta_base + r + local_row); + let rescale_f32 = tile_meta.y * exp(tile_meta.x - g_max) / g_sum; + value += input.getByOffset(in_offset + (r + local_row) * uniforms.head_size_vec + head_size_offset + local_col) * output_value_t(output_element_t(rescale_f32)); + } } } - } - tile_input[local_row][local_col] = value; - workgroupBarrier(); + tile_input[local_row][local_col] = value; + workgroupBarrier(); - if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { - value = output_value_t(0); - for (var i = 0u; i < tile_size; i++) { - value += tile_input[i][local_idx]; + if (local_idx < tile_size && head_size_offset + local_idx < uniforms.head_size_vec) { + value = output_value_t(0); + for (var i = 0u; i < tile_size; i++) { + value += tile_input[i][local_idx]; + } + let head_idx = batch_head_idx % uniforms.num_heads; + let batch_idx = batch_head_idx / uniforms.num_heads; + let output_id = batch_idx * uniforms.new_sequence_length * uniforms.num_heads * uniforms.head_size_vec + + q_idx * uniforms.num_heads * uniforms.head_size_vec + + head_idx * uniforms.head_size_vec + + head_size_offset + local_idx; + output.setByOffset(output_id, value); } - let output_id = batch_head_idx * uniforms.head_size_vec + head_size_offset + local_idx; - output[output_id] = value; + workgroupBarrier(); } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template index c64bdf45cdcf8..7b09a3a6af080 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template @@ -41,12 +41,9 @@ $MAIN { #endif #if prepare_indirect_dispatch - // Prepare indirect dispatch buffer for thread 0 if (global_idx == 0u) { let num_total_seq_length_tile = (total_seqlen + uniforms.tile_size - 1u) / uniforms.tile_size; - indirect_buffer[0] = num_total_seq_length_tile; - indirect_buffer[1] = uniforms.num_heads; - indirect_buffer[2] = 1u; + normalize_dispatch_group_size(num_total_seq_length_tile, uniforms.num_heads * uniforms.num_q_tiles, uniforms.batch_size); } #endif From 0931c1a038302d7f32278b533a5e455a506c4680 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:00:31 -0600 Subject: [PATCH 04/15] Fix out-of-bounds error (#28991) This pull request introduces important safety checks to prevent out-of-bounds access in the logits processing code for transformers. The main updates ensure that token IDs are validated against the vocabulary size before being used, which improves robustness and prevents potential crashes. **Safety and robustness improvements:** * Added bounds checking for token IDs in the `RepetitionPenaltyLogitsProcessor::Process` method to ensure only valid IDs are used when accessing `beam_token_scores`. * Added bounds checking for token IDs in the `NoRepeatNGramLogitsProcessor::Process` method to prevent out-of-bounds writes to `beam_token_scores`. * Updated the `NextTokenScores::SetScore` method to return early if the provided `token_id` is out of bounds, replacing the previous assert with a safe check. --- .../contrib_ops/cpu/transformers/logits_processor.cc | 8 ++++++++ .../contrib_ops/cpu/transformers/logits_processor.h | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index d28aae02ab2f1..6e75bb1b5a7c0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -48,7 +48,11 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, unique_word_ids.insert(word_id); } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : unique_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } T score = beam_token_scores[word_id]; // If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability, @@ -89,7 +93,11 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, } } + const int vocab_size = next_token_scores.vocab_size; for (const int32_t word_id : blocked_word_ids) { + if (word_id < 0 || word_id >= vocab_size) { + continue; + } beam_token_scores[word_id] = std::numeric_limits::lowest(); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 6e157d83159f4..fb9638bb09d59 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -28,7 +28,9 @@ struct NextTokenScores { } void SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); + if (token_id < 0 || token_id >= vocab_size) { + return; + } for (int i = 0; i < batch_beam_size; i++) { scores[static_cast(i) * vocab_size + token_id] = score; } From 5c7ba1315d75fd162d5f0a8155a0879ff376b0c5 Mon Sep 17 00:00:00 2001 From: Theodore Cooper <63190431+the0cp@users.noreply.github.com> Date: Fri, 12 Jun 2026 00:04:38 +0800 Subject: [PATCH 05/15] fix: Add Linux NPU discovery through sysfs accel devices (#28703) ## Description This PR adds Linux NPU discovery through sysfs accel devices Currently, `DeviceDiscovery::DiscoverDevicesForPlatform()` on Linux discovers CPU and GPU devices, but NPU discovery is still missing. As a result, plugin execution providers that filter devices by `OrtHardwareDeviceType_NPU` do not receive any NPU hardware devices on Linux, even when the NPU is present and exposed by the kernel. This change scans `/sys/class/accel` for `accelN` devices and creates `OrtHardwareDevice` entries with: - `type = OrtHardwareDeviceType_NPU` - PCI `vendor_id` - PCI `device_id` - `accel_idx` metadata - `pci_bus_id` metadata when available This enables Linux systems with NPUs exposed through the accel subsystem, such as AMD Ryzen AI / XDNA devices, to be reported through ORT device discovery and made available to plugin EP factories. ## Changes - Add Linux sysfs discovery for NPU devices under `/sys/class/accel`. - Read NPU PCI vendor and device IDs from the underlying sysfs device path. - Add NPU metadata including `accel_idx` and `pci_bus_id`. - Include discovered NPU devices in `DeviceDiscovery::DiscoverDevicesForPlatform()`. - Add a `kSysfsAccelPath` constant for the Linux accel sysfs path. ## Motivation Linux plugin EPs that target NPUs rely on ORT passing `OrtHardwareDeviceType_NPU` devices into `GetSupportedDevices()`. Without Linux NPU discovery, those EPs cannot claim NPU devices and provider selection policies such as `PREFER_NPU` silently fall back to CPU. Fixes #28660. --- .../core/platform/linux/device_discovery.cc | 128 +++++++++++++++++- .../platform/linux/npu_device_discovery.h | 32 +++++ .../linux/npu_device_discovery_test.cc | 103 ++++++++++++++ 3 files changed, 261 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/platform/linux/npu_device_discovery.h create mode 100644 onnxruntime/test/platform/linux/npu_device_discovery_test.cc diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index 732a5a855d65d..4260ef706befa 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/device_discovery.h" +#include "core/platform/linux/npu_device_discovery.h" #include "core/platform/linux/pci_device_discovery.h" #include @@ -123,7 +124,7 @@ Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& gpu_devices_out) { std::vector gpu_sysfs_path_infos{}; @@ -315,6 +317,119 @@ Status GetGpuDevices(std::vector& gpu_devices_out) { } // namespace +namespace npu_device_discovery { + +Status DetectNpuSysfsPaths(const fs::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out) { + std::error_code error_code{}; + + const bool sysfs_accel_path_exists = fs::exists(sysfs_accel_path, error_code); + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Checking existence of accel sysfs path")); + + if (!sysfs_accel_path_exists) { + npu_sysfs_paths_out = {}; + return Status::OK(); + } + + const auto detect_accel_path = [](const fs::path& sysfs_path, size_t& accel_idx) -> bool { + const auto filename = sysfs_path.filename(); + const auto filename_str = std::string_view{filename.native()}; + + // Look for a filename matching "accelN". N is a number. + constexpr std::string_view prefix = "accel"; + if (filename_str.find(prefix) != 0) { + return false; + } + + size_t parsed_accel_idx{}; + if (!TryParseStringWithClassicLocale(filename_str.substr(prefix.size()), parsed_accel_idx)) { + return false; + } + + accel_idx = parsed_accel_idx; + return true; + }; + + std::vector npu_sysfs_paths{}; + + auto dir_iterator = fs::directory_iterator{sysfs_accel_path, error_code}; + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code, sysfs_accel_path, "Iterating over accel sysfs devices")); + + for (const auto& dir_item : dir_iterator) { + const auto& dir_item_path = dir_item.path(); + + if (size_t accel_idx{}; detect_accel_path(dir_item_path, accel_idx)) { + NpuSysfsPathInfo path_info{}; + path_info.accel_idx = accel_idx; + path_info.path = dir_item_path; + npu_sysfs_paths.emplace_back(std::move(path_info)); + } + } + + npu_sysfs_paths_out = std::move(npu_sysfs_paths); + return Status::OK(); +} + +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out) { + OrtHardwareDevice npu_device{}; + + const auto& sysfs_path = path_info.path; + + uint16_t vendor_id{}; + const auto vendor_id_path = sysfs_path / "device" / "vendor"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(vendor_id_path, vendor_id)); + npu_device.vendor_id = vendor_id; + + uint16_t device_id{}; + const auto device_id_path = sysfs_path / "device" / "device"; + ORT_RETURN_IF_ERROR(ReadValueFromFile(device_id_path, device_id)); + npu_device.device_id = device_id; + + npu_device.metadata.Add("accel_idx", MakeString(path_info.accel_idx)); + + std::optional pci_bus_id; + ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id)); + if (pci_bus_id) { + npu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id)); + } + + npu_device.type = OrtHardwareDeviceType_NPU; + + npu_device_out = std::move(npu_device); + + return Status::OK(); +} + +} // namespace npu_device_discovery + +namespace { + +Status GetNpuDevices(std::vector& npu_devices_out) { + std::vector npu_sysfs_path_infos{}; + ORT_RETURN_IF_ERROR(npu_device_discovery::DetectNpuSysfsPaths(kSysfsAccelPath, npu_sysfs_path_infos)); + + std::vector npu_devices{}; + npu_devices.reserve(npu_sysfs_path_infos.size()); + + for (const auto& npu_sysfs_path_info : npu_sysfs_path_infos) { + OrtHardwareDevice npu_device{}; + if (auto status = npu_device_discovery::GetNpuDeviceFromSysfs(npu_sysfs_path_info, npu_device); !status.IsOK()) { + LOGS_DEFAULT(WARNING) << MakeString("Failed to detect devices under ", + npu_sysfs_path_info.path, + ": ", + status.ErrorMessage()); + continue; + } + npu_devices.emplace_back(std::move(npu_device)); + } + + npu_devices_out = std::move(npu_devices); + + return Status::OK(); +} +} // namespace + std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatform() { std::unordered_set devices; @@ -334,7 +449,16 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } // get NPU devices - // TODO figure out how to discover these + { + std::vector npu_devices{}; + Status npu_device_discovery_status = GetNpuDevices(npu_devices); + if (npu_device_discovery_status.IsOK()) { + devices.insert(std::make_move_iterator(npu_devices.begin()), + std::make_move_iterator(npu_devices.end())); + } else { + LOGS_DEFAULT(WARNING) << "NPU device discovery failed: " << npu_device_discovery_status.ErrorMessage(); + } + } return devices; } diff --git a/onnxruntime/core/platform/linux/npu_device_discovery.h b/onnxruntime/core/platform/linux/npu_device_discovery.h new file mode 100644 index 0000000000000..e139d33988903 --- /dev/null +++ b/onnxruntime/core/platform/linux/npu_device_discovery.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This header exposes Linux NPU device discovery internals for testing. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" +#include "core/session/abi_devices.h" + +namespace onnxruntime { +namespace npu_device_discovery { + +struct NpuSysfsPathInfo { + size_t accel_idx; + std::filesystem::path path; +}; + +// Scans the given sysfs accel directory for NPU accel devices. +Status DetectNpuSysfsPaths(const std::filesystem::path& sysfs_accel_path, + std::vector& npu_sysfs_paths_out); + +// Reads vendor/device IDs and populates an OrtHardwareDevice from an accel sysfs path. +Status GetNpuDeviceFromSysfs(const NpuSysfsPathInfo& path_info, + OrtHardwareDevice& npu_device_out); + +} // namespace npu_device_discovery +} // namespace onnxruntime diff --git a/onnxruntime/test/platform/linux/npu_device_discovery_test.cc b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc new file mode 100644 index 0000000000000..6e2170146f39f --- /dev/null +++ b/onnxruntime/test/platform/linux/npu_device_discovery_test.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/platform/linux/npu_device_discovery.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "test/util/include/asserts.h" + +namespace fs = std::filesystem; + +namespace onnxruntime::test { +namespace { + +void WriteFile(const fs::path& path, const std::string& value) { + std::ofstream f(path); + f << value; +} + +class NpuDeviceDiscoveryTest : public ::testing::Test { + protected: + void SetUp() override { + temp_dir_ = fs::temp_directory_path() / "ort_npu_discovery_test"; + fs::remove_all(temp_dir_); + fs::create_directories(temp_dir_); + } + + void TearDown() override { + fs::remove_all(temp_dir_); + } + + fs::path temp_dir_; +}; + +} // namespace + +TEST_F(NpuDeviceDiscoveryTest, ReturnsEmptyForNonexistentPath) { + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_ / "nonexistent", npu_paths)); + EXPECT_TRUE(npu_paths.empty()); +} + +TEST_F(NpuDeviceDiscoveryTest, DetectsAccelDevices) { + fs::create_directories(temp_dir_ / "accel0"); + fs::create_directories(temp_dir_ / "accel12"); + fs::create_directories(temp_dir_ / "renderD128"); + fs::create_directories(temp_dir_ / "accelabc"); + + std::vector npu_paths; + ASSERT_STATUS_OK(npu_device_discovery::DetectNpuSysfsPaths(temp_dir_, npu_paths)); + + ASSERT_EQ(npu_paths.size(), 2u); + + std::vector accel_indices; + accel_indices.reserve(npu_paths.size()); + for (const auto& npu_path : npu_paths) { + accel_indices.push_back(npu_path.accel_idx); + } + + std::sort(accel_indices.begin(), accel_indices.end()); + + EXPECT_EQ(accel_indices[0], 0u); + EXPECT_EQ(accel_indices[1], 12u); +} + +TEST_F(NpuDeviceDiscoveryTest, GetNpuDeviceFromSysfsReadsVendorDeviceAndMetadata) { + const auto pci_device_dir = temp_dir_ / "pci_devices" / "0000:65:00.0"; + const auto accel_dir = temp_dir_ / "class_accel" / "accel0"; + + fs::create_directories(pci_device_dir); + fs::create_directories(accel_dir); + + WriteFile(pci_device_dir / "vendor", "0x1022"); + WriteFile(pci_device_dir / "device", "0x1502"); + + std::error_code error_code{}; + fs::create_directory_symlink(pci_device_dir, accel_dir / "device", error_code); + ASSERT_FALSE(error_code) << error_code.message(); + + npu_device_discovery::NpuSysfsPathInfo path_info{}; + path_info.accel_idx = 0; + path_info.path = accel_dir; + + OrtHardwareDevice npu_device{}; + ASSERT_STATUS_OK(npu_device_discovery::GetNpuDeviceFromSysfs(path_info, npu_device)); + + EXPECT_EQ(npu_device.type, OrtHardwareDeviceType_NPU); + EXPECT_EQ(npu_device.vendor_id, 0x1022u); + EXPECT_EQ(npu_device.device_id, 0x1502u); + + const auto& entries = npu_device.metadata.Entries(); + EXPECT_NE(entries.find("accel_idx"), entries.end()); + EXPECT_EQ(entries.at("accel_idx"), "0"); + EXPECT_NE(entries.find("pci_bus_id"), entries.end()); + EXPECT_EQ(entries.at("pci_bus_id"), "0000:65:00.0"); +} + +} // namespace onnxruntime::test From 6584194ff9417879e4a6b133df37fd10ac0db1a0 Mon Sep 17 00:00:00 2001 From: Darshak Bhatti <47045043+dabhattimsft@users.noreply.github.com> Date: Thu, 11 Jun 2026 10:09:21 -0700 Subject: [PATCH 06/15] Log EP version on inference failure and in EpDeviceUsage event, Log ORT version (#28794) ### Description Adds new telemetry event for inference failure which logs ep versions and types along with runtime error. Adds logging of ORT version in other telemetry events. Adds logging of ep versions in SessionCreation telemetry ### Motivation and Context To better diagnose failures in inference --------- Co-authored-by: Darshak Bhatti Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/core/platform/telemetry.cc | 13 +++++ onnxruntime/core/platform/telemetry.h | 6 ++ .../core/platform/windows/telemetry.cc | 58 ++++++++++++++++--- onnxruntime/core/platform/windows/telemetry.h | 6 ++ onnxruntime/core/session/inference_session.cc | 22 ++++++- onnxruntime/core/session/inference_session.h | 2 + 6 files changed, 95 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 2008f998384a3..20ec11bf7deb2 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -64,6 +64,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(ir_version); @@ -81,6 +82,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(execution_provider_ids); ORT_UNUSED_PARAMETER(hardware_device_types); ORT_UNUSED_PARAMETER(hardware_vendor_ids); + ORT_UNUSED_PARAMETER(ep_versions); ORT_UNUSED_PARAMETER(use_fp16); ORT_UNUSED_PARAMETER(captureState); } @@ -124,6 +126,15 @@ void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& statu ORT_UNUSED_PARAMETER(line); } +void Telemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const { + ORT_UNUSED_PARAMETER(session_id); + ORT_UNUSED_PARAMETER(status); + ORT_UNUSED_PARAMETER(ep_versions); + ORT_UNUSED_PARAMETER(ep_device_types); +} + void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { ORT_UNUSED_PARAMETER(session_id); @@ -139,6 +150,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -149,6 +161,7 @@ void Telemetry::LogEpDeviceUsage(uint32_t session_id, ORT_UNUSED_PARAMETER(hardware_device_id); ORT_UNUSED_PARAMETER(hardware_vendor); ORT_UNUSED_PARAMETER(ep_vendor); + ORT_UNUSED_PARAMETER(ep_version); ORT_UNUSED_PARAMETER(assigned_node_count); ORT_UNUSED_PARAMETER(total_runs_since_last); ORT_UNUSED_PARAMETER(total_run_duration_since_last); diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index c2dce68faa2dd..946e34ee35832 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -66,6 +66,7 @@ class Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const; virtual void LogCompileModelStart(uint32_t session_id, @@ -86,6 +87,10 @@ class Telemetry { virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; + virtual void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const; + virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const; @@ -100,6 +105,7 @@ class Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const; diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 04b9aaa0eb8ed..342b937ffb656 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -285,6 +285,7 @@ void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -318,6 +319,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio const std::string& loaded_from, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -373,7 +375,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -392,6 +395,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { @@ -404,7 +408,8 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info // schemaVersion 1: added hardwareDeviceTypes and hardwareVendorIds - TraceLoggingUInt8(1, "schemaVersion"), + // schemaVersion 2: added executionProviderVersions + TraceLoggingUInt8(2, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingInt64(ir_version, "irVersion"), TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), @@ -423,6 +428,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), TraceLoggingString(hardware_device_types.c_str(), "hardwareDeviceTypes"), TraceLoggingString(hardware_vendor_ids.c_str(), "hardwareVendorIds"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), TraceLoggingString(service_names.c_str(), "serviceNames"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -457,7 +463,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(input_source.c_str(), "inputSource"), TraceLoggingString(output_target.c_str(), "outputTarget"), @@ -466,6 +472,7 @@ void WindowsTelemetry::LogCompileModelStart(uint32_t session_id, TraceLoggingBool(embed_ep_context, "embedEpContext"), TraceLoggingBool(has_external_initializers_file, "hasExternalInitializersFile"), TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -507,7 +514,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingHResult(hr, "hResult"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), @@ -516,6 +523,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #else TraceLoggingWrite(telemetry_provider_handle, @@ -525,7 +533,7 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_ERROR), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(status.Code(), "errorCode"), TraceLoggingUInt32(status.Category(), "errorCategory"), @@ -533,10 +541,35 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #endif } +void WindowsTelemetry::LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const { + if (global_register_count_ == 0 || enabled_ == false) + return; + + TraceLoggingWrite(telemetry_provider_handle, + "RuntimeInferenceError", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingLevel(WINEVENT_LEVEL_ERROR), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingUInt32(status.Code(), "errorCode"), + TraceLoggingUInt32(status.Category(), "errorCategory"), + TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), + TraceLoggingString(ep_versions.c_str(), "executionProviderVersions"), + TraceLoggingString(ep_device_types.c_str(), "executionProviderDeviceTypes"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); +} + void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const { if (global_register_count_ == 0 || enabled_ == false) @@ -559,11 +592,12 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -574,6 +608,7 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const { @@ -588,7 +623,8 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + // schemaVersion 1: added epVersion, runtimeVersion + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(ep_type.c_str(), "executionProviderType"), TraceLoggingString(hardware_device_type.c_str(), "hardwareDeviceType"), @@ -596,9 +632,11 @@ void WindowsTelemetry::LogEpDeviceUsage(uint32_t session_id, TraceLoggingUInt32(hardware_device_id, "hardwareDeviceId"), TraceLoggingString(hardware_vendor.c_str(), "hardwareVendor"), TraceLoggingString(ep_vendor.c_str(), "epVendor"), + TraceLoggingString(ep_version.c_str(), "epVersion"), TraceLoggingInt32(assigned_node_count, "assignedNodeCount"), TraceLoggingUInt32(total_runs_since_last, "totalRunsSinceLast"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDurationSinceLast"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -724,8 +762,9 @@ void WindowsTelemetry::LogModelLoadStart(uint32_t session_id) const { TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } @@ -799,8 +838,9 @@ void WindowsTelemetry::LogRegisterEpLibraryStart(const std::string& registration TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingLevel(WINEVENT_LEVEL_INFO), // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt8(1, "schemaVersion"), TraceLoggingString(registration_name.c_str(), "registrationName"), + TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 8295d042d7ec9..46c262e1479d3 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -59,6 +59,7 @@ class WindowsTelemetry : public Telemetry { const std::string& loadedFrom, const std::vector& execution_provider_ids, const std::string& hardware_device_types, const std::string& hardware_vendor_ids, + const std::string& ep_versions, bool use_fp16, bool captureState) const override; void LogCompileModelStart(uint32_t session_id, @@ -79,6 +80,10 @@ class WindowsTelemetry : public Telemetry { void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; + void LogRuntimeInferenceError(uint32_t session_id, const common::Status& status, + const std::string& ep_versions, + const std::string& ep_device_types) const override; + void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last, const std::unordered_map& duration_per_batch_size) const override; @@ -89,6 +94,7 @@ class WindowsTelemetry : public Telemetry { uint32_t hardware_device_id, const std::string& hardware_vendor, const std::string& ep_vendor, + const std::string& ep_version, int assigned_node_count, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 73f3e77b93177..9db18fc4ac69e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -78,6 +78,7 @@ #include "core/session/environment.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/user_logging_sink.h" @@ -871,7 +872,7 @@ InferenceSession::~InferenceSession() { telemetry_provider.LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } } @@ -2778,6 +2779,7 @@ common::Status InferenceSession::Initialize() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model_->MetaData(), telemetry_.event_name_, execution_providers_.GetIds(), telemetry_.ep_device_types_summary_, telemetry_.ep_device_vendor_ids_summary_, + telemetry_.ep_versions_summary_, model_has_fp16_inputs, false); // Emit one initial EpDeviceUsage event per (EP, device) pair with run counts of 0. @@ -2787,7 +2789,7 @@ common::Status InferenceSession::Initialize() { env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, 0, 0); + ep_info.ep_version, ep_info.assigned_node_count, 0, 0); } LOGS(*session_logger_, INFO) << "Session successfully initialized."; @@ -3456,7 +3458,7 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, env.GetTelemetryProvider().LogEpDeviceUsage( session_id_, ep_info.ep_type, ep_info.hardware_device_type, ep_info.vendor_id, ep_info.device_id, ep_info.vendor, ep_info.ep_vendor, - ep_info.assigned_node_count, + ep_info.ep_version, ep_info.assigned_node_count, telemetry_.total_runs_since_last_, telemetry_.total_run_duration_since_last_); } // reset counters @@ -3470,6 +3472,10 @@ Status InferenceSession::RunImpl(const RunOptions& run_options, Telemetry::kRuntimePerfMaxInterval); } } + } else { + // Log runtime error with EP versions + env.GetTelemetryProvider().LogRuntimeInferenceError(session_id_, retval, telemetry_.ep_versions_summary_, + telemetry_.ep_device_types_summary_); } // log evaluation stop to trace logging provider @@ -4087,6 +4093,7 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { telemetry_.ep_device_info_.clear(); telemetry_.ep_device_types_summary_.clear(); telemetry_.ep_device_vendor_ids_summary_.clear(); + telemetry_.ep_versions_summary_.clear(); // First, count nodes assigned to each EP type after graph partitioning. // The graph node only carries the EP type string, so when a single EP targets @@ -4124,6 +4131,10 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { Telemetry::EpDeviceInfo entry; entry.ep_type = ep_type; entry.ep_vendor = ep_device->ep_vendor; + auto it = ep_device->ep_metadata.Entries().find(kOrtEpDevice_EpMetadataKey_Version); + if (it != ep_device->ep_metadata.Entries().end()) { + entry.ep_version = it->second; + } if (ep_device->device != nullptr) { entry.hardware_device_type = HardwareDeviceTypeToString(ep_device->device->type); entry.vendor_id = ep_device->device->vendor_id; @@ -4158,11 +4169,13 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { // by position against the existing executionProviderIds field. std::ostringstream types_oss; std::ostringstream vendor_ids_oss; + std::ostringstream versions_oss; bool first = true; for (const auto& entry : telemetry_.ep_device_info_) { if (!first) { types_oss << ','; vendor_ids_oss << ','; + versions_oss << ','; } first = false; types_oss << entry.hardware_device_type; @@ -4170,9 +4183,11 @@ void InferenceSession::PopulateEpDeviceInfo(const onnxruntime::Graph& graph) { vendor_ids_oss << "0x" << std::hex << std::uppercase << std::setw(4) << std::setfill('0') << entry.vendor_id << std::dec << std::nouppercase << std::setfill(' '); + versions_oss << entry.ep_type << ':' << entry.ep_version; } telemetry_.ep_device_types_summary_ = types_oss.str(); telemetry_.ep_device_vendor_ids_summary_ = vendor_ids_oss.str(); + telemetry_.ep_versions_summary_ = versions_oss.str(); } #if !defined(ORT_MINIMAL_BUILD) @@ -4419,6 +4434,7 @@ void InferenceSession::LogAllSessions() { graph.DomainToVersionMap(), model_file_name, graph.Name(), model_weight_type, model_graph_hash, model_weight_hash, model->MetaData(), session->telemetry_.event_name_, session->execution_providers_.GetIds(), session->telemetry_.ep_device_types_summary_, session->telemetry_.ep_device_vendor_ids_summary_, + session->telemetry_.ep_versions_summary_, model_has_fp16_inputs, true); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 11c887f6cdc16..19c18652ab67a 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -1023,12 +1023,14 @@ class InferenceSession { uint32_t device_id = 0; // PCI device ID (0 when unavailable) std::string vendor; // e.g. "Qualcomm" std::string ep_vendor; // e.g. "Qualcomm" (from OrtEpDevice) + std::string ep_version; // e.g. "1.2.3" (from OrtEpFactory::GetVersion, empty when unavailable) int assigned_node_count = 0; // # graph nodes assigned to this EP type }; std::vector ep_device_info_; // Pre-formatted comma-separated summaries used to enrich SessionCreation. std::string ep_device_types_summary_; // "NPU,CPU" std::string ep_device_vendor_ids_summary_; // "0x5143,0x0000" + std::string ep_versions_summary_; // "QNNExecutionProvider:1.2.3,CPUExecutionProvider:" } telemetry_; mutable std::mutex telemetry_mutex_; // to ensure thread-safe access to telemetry data From 4b5abcf51b605049ebbed07672018fd6e8d82d47 Mon Sep 17 00:00:00 2001 From: Gopalakrishnan Nallasamy Date: Thu, 11 Jun 2026 11:10:27 -0700 Subject: [PATCH 07/15] Fix STFT complex input frame offsets (#28961) ### Description Fix STFT frame pointer arithmetic for complex-valued input so frame starts are computed in input samples, not trailing real/imag components. Since the frame view pointer is `U*`, one pointer increment advances one full real or complex sample. Also add validation that `frame_step` is positive and keep a defensive bounds check before creating non-owning tensor views. Review feedback addressed: simplified the frame pointer arithmetic, fixed the swapped STFT input comments, documented the defensive bounds check, and added double-complex regression coverage. The new STFT validation/regression tests exclude `kDmlExecutionProvider` because these CPU STFT validation/regression paths do not consistently match DirectML behavior in Windows GPU CI. ### Motivation and Context For complex input shaped `[batch_size, signal_length, 2]`, pointer increments already advance by one real/imag pair. Multiplying frame offsets by `signal_components == 2` again can advance past the valid frame start, allowing later frames to read across batches or beyond the input allocation. ### Testing - `git diff --check -- onnxruntime/core/providers/cpu/signal/dft.cc onnxruntime/test/providers/cpu/signal/signal_ops_test.cc` - `.\.venv\Scripts\python.exe tools\ci_build\build.py --config RelWithDebInfo --build --parallel --target onnxruntime_provider_test --build_dir build\Windows` - `.\onnxruntime_provider_test.exe --gtest_filter="SignalOpsTest.STFTFloat:SignalOpsTest.STFTFrameStepMustBePositive:SignalOpsTest.STFTFloatComplexInputBatched:SignalOpsTest.STFTDoubleComplexInputBatched"` from `build\Windows\RelWithDebInfo\RelWithDebInfo` --------- Co-authored-by: Gopalakrishnan Nallasamy --- onnxruntime/core/providers/cpu/signal/dft.cc | 16 ++-- .../providers/cpu/signal/signal_ops_test.cc | 74 +++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index f462680e02587..f75c5fe12a6f4 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -524,14 +524,15 @@ template static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_onesided, bool /*inverse*/) { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal const auto* signal = ctx->Input(0); const auto frame_step = signal::get_scalar_value_from_tensor(ctx->Input(1)); + ORT_RETURN_IF_NOT(frame_step > 0, "frame_step must be greater than zero."); const auto* window = ctx->Input(2); const auto* frame_length_tensor = ctx->Input(3); @@ -596,8 +597,11 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Run each dft of each batch as if it was a real-valued batch size 1 dft operation for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int64_t i = 0; i < n_dfts; i++) { - auto input_frame_begin = - signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components); + const auto frame_start = i * frame_step; + // Defensive check before creating a non-owning tensor view. n_dfts derivation should keep this in bounds. + ORT_RETURN_IF_NOT(frame_start <= signal_size - window_size, "STFT input frame is out of bounds."); + // signal_data is U*, so one increment advances one input sample, including both lanes for complex input. + auto input_frame_begin = signal_data + (batch_idx * signal_size) + frame_start; auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + (i * dft_output_size * output_components); @@ -619,9 +623,9 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside Status STFT::Compute(OpKernelContext* ctx) const { // Attr("onesided"): default = 1 // Input(0, "signal") type = T1 - // Input(1, "frame_length") type = T2 + // Input(1, "frame_step") type = T2 // Input(2, "window") type = T1, optional - // Input(3, "frame_step") type = T2 + // Input(3, "frame_length") type = T2 // Output(0, "output") type = T1 // Get signal shape diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index 966090eed861e..c87e5d7c4c6e1 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -233,6 +233,80 @@ TEST(SignalOpsTest, STFTFloat) { test.Run(); } +static void TestSTFTInvalidFrameStep(int64_t frame_step) { + OpTester test("STFT", kMinOpsetVersion); + + vector signal(64, 1); + test.AddInput("signal", {1, 64, 1}, signal); + test.AddInput("frame_step", {}, {frame_step}); + vector window(16, 1); + test.AddInput("window", {16}, window); + test.AddInput("frame_length", {}, {16}); + + vector output_shape = {1, 7, 9, 2}; + vector expected_output(1 * 7 * 9 * 2, 0.f); + test.AddOutput("output", output_shape, expected_output); + test.Config(OpTester::ExpectResult::kExpectFailure, "frame_step must be greater than zero"); + test.ConfigExcludeEps({kDmlExecutionProvider}); + test.RunWithConfig(); +} + +TEST(SignalOpsTest, STFTFrameStepMustBePositive) { + TestSTFTInvalidFrameStep(0); + TestSTFTInvalidFrameStep(-1); +} + +template +static void TestSTFTComplexInputBatched() { + OpTester test("STFT", kMinOpsetVersion); + test.AddAttribute("onesided", static_cast(false)); + + constexpr int64_t batch_size = 2; + constexpr int64_t signal_size = 128; + constexpr int64_t signal_components = 2; + constexpr int64_t frame_length = 32; + constexpr int64_t frame_step = 16; + constexpr int64_t n_dfts = 7; + constexpr int64_t dft_output_size = frame_length; + constexpr int64_t output_components = 2; + + vector signal(batch_size * signal_size * signal_components, static_cast(0)); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const T signal_value = batch_idx == 0 ? static_cast(1) : static_cast(99); + for (int64_t sample_idx = 0; sample_idx < signal_size; ++sample_idx) { + signal[(batch_idx * signal_size + sample_idx) * signal_components] = signal_value; + } + } + + test.AddInput("signal", {batch_size, signal_size, signal_components}, signal); + test.AddInput("frame_step", {}, {frame_step}); + test.AddOptionalInputEdge(); + test.AddInput("frame_length", {}, {frame_length}); + + vector output_shape = {batch_size, n_dfts, dft_output_size, output_components}; + vector expected_output(batch_size * n_dfts * dft_output_size * output_components, static_cast(0)); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const T dc_value = (batch_idx == 0 ? static_cast(1) : static_cast(99)) * static_cast(frame_length); + for (int64_t frame_idx = 0; frame_idx < n_dfts; ++frame_idx) { + expected_output[((batch_idx * n_dfts + frame_idx) * dft_output_size) * output_components] = dc_value; + } + } + + test.AddOutput("output", output_shape, expected_output); + test.SetOutputAbsErr("output", 0.001f); + // DML does not consistently match these CPU STFT validation/regression paths in Windows GPU CI. + test.ConfigExcludeEps({kDmlExecutionProvider}); + test.RunWithConfig(); +} + +TEST(SignalOpsTest, STFTFloatComplexInputBatched) { + TestSTFTComplexInputBatched(); +} + +TEST(SignalOpsTest, STFTDoubleComplexInputBatched) { + TestSTFTComplexInputBatched(); +} + TEST(SignalOpsTest, HannWindowFloat) { OpTester test("HannWindow", kMinOpsetVersion); From cf509d84a906a593e89362e9ff7e2ece1a1e372b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:40:38 -0700 Subject: [PATCH 08/15] QMoE: fail loudly when weights_prepacked=0 but PrePack did not run (#28965) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description When a QMoE model sets `weights_prepacked=0` (raw `[E, N, K/pack]` int weights) and the session has `session.disable_prepacking`, `PrePack()` never runs, so `packed_fc{1,2}_weights_` stay null and `int_weights_consumed_by_prepack` is false. The code then falls through to the raw initializer pointers — but those bytes are not in CUTLASS layout, so the runner consumes them as-if-prepacked and produces silently wrong output with no diagnostic. Changes in `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc` (`QMoE::ComputeInternal`): - **Int path**: Added a defensive `INVALID_ARGUMENT` guard — when `is_int && !weights_prepacked_` but either prepack buffer is null, return a clear error instead of feeding non-CUTLASS bytes to the runner. - **wfp4afp8 native path**: Same fall-through (`packed_fp4_fc{1,2}_weights_ ? ... : raw`) replaced with an explicit guard that errors when the repacked FP4 buffers were not produced. Also added a focused regression test in `onnxruntime/test/contrib_ops/moe_test.cc` covering `quant_type='int'` with `weights_prepacked=0` and `session.disable_prepacking=1`, asserting that QMoE fails with an actionable error instead of producing output. Merged the branch with the latest `main`. ### Motivation and Context A prior fix removed the null-pointer crash on this path but left a misleading-success outcome that is newly user-reachable via the `weights_prepacked=0` contract — the exact silent-failure mode the offline-path work set out to eliminate. These guards convert that into a loud, actionable error. The wfp4afp8 branch shares the same fall-through and is hardened for consistency. The added regression test ensures this fail-loudly behavior remains covered going forward, especially when prepacking is disabled at the session level. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../contrib_ops/cuda/moe/moe_quantization.cc | 27 +++++++- onnxruntime/test/contrib_ops/moe_test.cc | 69 +++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index 7d1291e004d78..bc34a2e83318a 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -234,6 +234,18 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { // to the runner. const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr && packed_fc2_weights_ != nullptr; + // When ``weights_prepacked == 0`` the raw ``[E, N, K/pack]`` int weights must be + // converted to the CUTLASS fpA_intB layout by PrePack before the runner can consume + // them. If PrePack never ran (e.g. ``session.disable_prepacking`` is set), the prepack + // buffers stay null and falling through to the raw initializer pointers would feed + // non-CUTLASS bytes to the runner, producing silently wrong output. Fail loudly instead. + if (is_int && !weights_prepacked_ && + (packed_fc1_weights_ == nullptr || packed_fc2_weights_ == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE weights_prepacked=0 requires PrePack to run, but the int weight " + "buffers were not produced (is session.disable_prepacking set?). Provide " + "CUTLASS-prepacked weights with weights_prepacked=1, or enable prepacking."); + } const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input(2); const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input(3) : nullptr; const Tensor* fc1_experts_bias_optional = context->Input(4); @@ -854,8 +866,19 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr; const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr; if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) { - fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data; - fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data; + // The native CUTLASS WFP4AFP8 path consumes weights in the repacked FP4 + // layout produced by PrePack. If PrePack never ran (e.g. + // ``session.disable_prepacking`` is set) the repacked buffers stay null and + // falling through to the raw initializer bytes would feed a non-CUTLASS + // layout to the runner, producing silently wrong output. Fail loudly. + if (packed_fp4_fc1_weights_ == nullptr || packed_fp4_fc2_weights_ == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "QMoE wfp4afp8 requires PrePack to run, but the repacked FP4 weight " + "buffers were not produced (is session.disable_prepacking set?). " + "Enable prepacking to use the native WFP4AFP8 path."); + } + fc1_weight_data = packed_fp4_fc1_weights_.get(); + fc2_weight_data = packed_fp4_fc2_weights_.get(); } else if (int_weights_consumed_by_prepack) { // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB // layout that the runner consumes and freed the source initializer diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index cc50494273aad..38cc3014aa873 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" @@ -778,6 +779,74 @@ TEST(MoETest, MoETest_Mixtral) { 2 /*top_k*/); } +TEST(MoETest, QMoETest_CUDA_Int4_DisablePrepackingFailsLoudly) { + constexpr int min_cuda_arch = 700; + if (!HasCudaEnvironment(min_cuda_arch)) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA execution provider not available"; + } + + constexpr int64_t num_rows = 1; + constexpr int64_t num_experts = 1; + constexpr int64_t hidden_size = 128; + constexpr int64_t inter_size = 128; + constexpr int64_t expert_weight_bits = 4; + constexpr int64_t pack_size = 8 / expert_weight_bits; + + const std::vector input(num_rows * hidden_size, 0.0f); + const std::vector router_probs(num_rows * num_experts, 1.0f); + const std::vector fc1_experts_weights(num_experts * inter_size * (hidden_size / pack_size), 0); + const std::vector fc2_experts_weights(num_experts * hidden_size * (inter_size / pack_size), 0); + const std::vector fc1_scales(num_experts * inter_size, 1.0f); + const std::vector fc2_scales(num_experts * hidden_size, 1.0f); + const std::vector dummy_output(num_rows * hidden_size, 0.0f); + + OpTester cuda_tester("QMoE", 1, onnxruntime::kMSDomain); + cuda_tester.AddAttribute("k", 1); + cuda_tester.AddAttribute("activation_type", "identity"); + cuda_tester.AddAttribute("normalize_routing_weights", 1); + cuda_tester.AddAttribute("expert_weight_bits", expert_weight_bits); + cuda_tester.AddAttribute("quant_type", "int"); + cuda_tester.AddAttribute("weights_prepacked", 0); + + const std::vector input_dims = {num_rows, hidden_size}; + const std::vector router_probs_dims = {num_rows, num_experts}; + const std::vector fc1_experts_weights_dims = {num_experts, inter_size, hidden_size / pack_size}; + const std::vector fc2_experts_weights_dims = {num_experts, hidden_size, inter_size / pack_size}; + const std::vector fc1_scales_dims = {num_experts, inter_size}; + const std::vector fc2_scales_dims = {num_experts, hidden_size}; + const std::vector output_dims = {num_rows, hidden_size}; + + cuda_tester.AddInput("input", input_dims, ToFloat16(input)); + cuda_tester.AddInput("router_probs", router_probs_dims, ToFloat16(router_probs)); + cuda_tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + cuda_tester.AddInput("fc1_scales", fc1_scales_dims, ToFloat16(fc1_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + cuda_tester.AddInput("fc2_scales", fc2_scales_dims, ToFloat16(fc2_scales)); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOptionalInputEdge(); + cuda_tester.AddOutput("output", output_dims, ToFloat16(dummy_output)); + + SessionOptions session_options; + session_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "1"; + + std::vector> cuda_execution_providers; + cuda_execution_providers.push_back(std::move(cuda_ep)); + cuda_tester.Run(session_options, + OpTester::ExpectResult::kExpectFailure, + "QMoE weights_prepacked=0 requires PrePack to run", + {}, + nullptr, + &cuda_execution_providers); +} + TEST(MoETest, QMoETest_Mixtral_Int4) { // This test uses FC3 (gated SiLU / Mixtral pattern) with dimensions too small for the // CUTLASS kernel (needs hidden_size >= 128, inter_size >= 128). CPU QMoE does not From b823aecca44f3d80ce018cc53189a32c4dddf94a Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 11 Jun 2026 16:39:14 -0600 Subject: [PATCH 09/15] Fix oob dereference (#29012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This pull request strengthens shape inference validation for several custom BERT-related ONNX operators by adding explicit rank checks for input tensors. These changes ensure that input tensors meet minimum rank requirements, improving error messaging and preventing incorrect shape propagation. **Enhanced shape validation for custom ONNX operators:** *RelativePositionBias and GatedRelativePositionBias:* - Added checks to ensure `bias_table` (for `RelativePositionBias`) and `token_offset` (for `GatedRelativePositionBias`) inputs have rank ≥ 2, with clear error messages if not. [[1]](diffhunk://#diff-8bf31275168b1e4a2aecd6760acf0ef92347134b003dfdc687c5d3cec4a178ecR1959-R1965) [[2]](diffhunk://#diff-8bf31275168b1e4a2aecd6760acf0ef92347134b003dfdc687c5d3cec4a178ecR2228-R2230) *CausalConvWithState:* - Added checks to ensure both `input` and `weight` tensors have rank ≥ 2, failing shape inference with descriptive errors if violated. *LinearAttention:* - Added checks to ensure `query` and `value` tensors have rank ≥ 3 for both output and state shape inference, with early returns or errors if requirements are not met. [[1]](diffhunk://#diff-8bf31275168b1e4a2aecd6760acf0ef92347134b003dfdc687c5d3cec4a178ecR2460-R2465) [[2]](diffhunk://#diff-8bf31275168b1e4a2aecd6760acf0ef92347134b003dfdc687c5d3cec4a178ecR2483-R2486) *SkipLayerNormalization:* - Added a check to ensure the `input` tensor has rank ≥ 1, improving error reporting for invalid input shapes. --- .../core/graph/contrib_ops/bert_defs.cc | 25 +++++++++++++++++++ .../contrib_ops/shape_inference_functions.cc | 3 +++ 2 files changed, 28 insertions(+) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index cc0b8533033a6..5d537fa59bfab 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1956,7 +1956,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("U", {"tensor(int64)"}, "Constrain sequence_length to int tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasInputShape(ctx, 0)) { + return; + } auto& bias_table_shape = getInputShape(ctx, 0); + if (bias_table_shape.dim_size() < 2) { + fail_shape_inference("RelativePositionBias: bias_table must have rank >= 2"); + } TensorShapeProto output_shape; output_shape.add_dim()->set_dim_value(1); *output_shape.add_dim() = bias_table_shape.dim(1); @@ -2219,6 +2225,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( // Output shape: (batch_size, num_heads, seq_len, seq_len) if (hasInputShape(ctx, 6)) { auto& token_offset_shape = getInputShape(ctx, 6); + if (token_offset_shape.dim_size() < 2) { + fail_shape_inference("GatedRelativePositionBias: token_offset must have rank >= 2"); + } TensorShapeProto output_shape; *output_shape.add_dim() = token_offset_shape.dim(0); output_shape.add_dim()->set_dim_value(num_heads); @@ -2317,6 +2326,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { auto& input_shape = getInputShape(ctx, 0); auto& weight_shape = getInputShape(ctx, 1); + if (input_shape.dim_size() < 2) { + fail_shape_inference("CausalConvWithState: input must have rank >= 2"); + } + if (weight_shape.dim_size() < 2) { + fail_shape_inference("CausalConvWithState: weight must have rank >= 2"); + } int64_t ndim = getAttribute(ctx, "ndim", 1); TensorShapeProto state_shape; *state_shape.add_dim() = input_shape.dim(0); // batch_size @@ -2442,6 +2457,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); + if (query_shape.dim_size() < 3) { + fail_shape_inference("LinearAttention: query must have rank >= 3"); + } + if (value_shape.dim_size() < 3) { + fail_shape_inference("LinearAttention: value must have rank >= 3"); + } TensorShapeProto output_shape; *output_shape.add_dim() = query_shape.dim(0); // B *output_shape.add_dim() = query_shape.dim(1); // T @@ -2459,6 +2480,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); + if (query_shape.dim_size() < 3 || value_shape.dim_size() < 3) { + // Already validated in Output 0 block above; skip if shapes are invalid. + return; + } TensorShapeProto state_shape; *state_shape.add_dim() = query_shape.dim(0); // B state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv diff --git a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc index bf1b3a89d8813..f7e22608002ff 100644 --- a/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc +++ b/onnxruntime/core/graph/contrib_ops/shape_inference_functions.cc @@ -134,6 +134,9 @@ void SkipLayerNormalizationShapeInference(::ONNX_NAMESPACE::InferenceContext& ct } auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); int64_t input_ndim = input_shape.dim_size(); + if (input_ndim < 1) { + fail_shape_inference("SkipLayerNormalization: input must have rank >= 1"); + } int axis = static_cast(input_ndim - 1); if (ctx.getNumOutputs() > 1) { From dbf95cfb0e3c4f1ea17ea4fbcb941ee33cb74e59 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 12 Jun 2026 00:11:22 -0700 Subject: [PATCH 10/15] [CUDA] Optimize QMoE SoftmaxTopK router for small-batch decode (#28980) ### Description Optimizes the CUDA QMoE router top-k (`LaunchSoftmaxTopK`) for small-batch / autoregressive decode by replacing the old one-thread-per-row hot path with parallel CUB and warp-level top-k kernels. The dispatch now uses the fastest specialized path for common MoE expert counts while preserving the existing softmax normalization and deterministic lower-index tie-breaking semantics. This PR also factors the warp-level top-k sorting code into a reusable CUDA helper header and adds direct CUDA-internal tests so the new routing paths are covered independently of higher-level QMoE tests. ### Motivation and Context The previous router path launched a 256-thread block per row but did all top-k work in a single thread. In decode scenarios such as `num_rows == 1`, that made the router latency-bound on a serial scan of all expert logits and turned `SoftmaxTopKKernel` into a major MoE decode bottleneck. For a Qwen3-style MoE workload with 256 experts, top-8 routing, and 40 MoE layers, the original router accounted for roughly 50% of decode GPU time. Moving the work to block/warp-parallel kernels removes that bottleneck while keeping the same output ordering and scaling behavior. ### Key Changes | Area | Change | |---|---| | QMoE router dispatch | Adds `DispatchSoftmaxTopK` routing for `k <= 64` and `num_experts <= 1024`, with a fallback to the original scalar kernel for larger or uncommon shapes. | | Tiny expert counts | Adds `SoftmaxTopKWarpBitonicKernel` for `num_experts <= 32`, using one warp per row and in-register bitonic sorting via warp shuffles. | | Small expert counts | Adds `SoftmaxTopKWarpMergeKernel` for `32 < num_experts <= 64`, using a single warp and CUB warp merge sort. | | Larger common MoE counts | Uses `SoftmaxTopKMergeKernel` with CUB block merge sort for `num_experts <= 128`, `256`, `512`, and `1024`. | | Reusable top-k helpers | Adds `onnxruntime/core/providers/cuda/cu_inc/topk_warp_sort.cuh` with reusable warp bitonic and warp merge sort helpers. | | Stable tie-breaking | Packs `(score, index)` into a `uint64_t` stable sort key for the CUB merge paths, matching onnxruntime-genai's lower-index tie-breaking and avoiding compound comparators. | | Softmax cleanup | Factors shared softmax scale, safe reciprocal, top-k normalization, warp reduction, and CUB block reduction helpers to keep the optimized kernels consistent. | | Tests | Adds CUDA-internal `SoftmaxTopK_*` tests covering warp bitonic, warp merge, block merge, stable ties, normalization, `float`, `half`, and `bfloat16`. | ### Performance H200 measurements for the target QMoE decode scenario showed the router cost dropping from roughly `5.56 ms/token` to `0.17 ms/token`, improving end-to-end Qwen3.6-35B-A3B INT4 decode throughput from about `80 tok/s` to `113 tok/s`. Additional profiling of the `32 < num_experts <= 64` warp merge path showed the packed `uint64_t` stable sort key is consistently faster than a `{float, int}` struct comparator on H200: | Experts | Sort-only packed/struct | Full softmax+top-k packed/struct | |---:|---:|---:| | 33 | 0.680x | 0.704x | | 48 | 0.672x | 0.695x | | 64 | 0.673x | 0.696x | ### Testing - `lintrunner -a` - `ninja onnxruntime_providers_cuda_ut` - `ninja onnxruntime_provider_test` - `GTEST_FILTER='CUDA_EP_Unittest.SoftmaxTopK_*' ./onnxruntime_provider_test --gtest_filter='CUDA_EP_Unittest.All'` - `onnxruntime/test/python/transformers/test_qmoe_cuda.py -k parity` (`44 passed`) --- onnxruntime/contrib_ops/cuda/moe/moe.cc | 3 + .../contrib_ops/cuda/moe/moe_quantization.cc | 3 + .../contrib_ops/cuda/moe/qmoe_kernels.cu | 322 +++++++++++++++++- .../providers/cuda/cu_inc/topk_warp_sort.cuh | 189 ++++++++++ .../cuda_kernels/softmax_topk_kernel_test.cc | 273 +++++++++++++++ 5 files changed, 778 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/cu_inc/topk_warp_sort.cuh create mode 100644 onnxruntime/test/contrib_ops/cuda_kernels/softmax_topk_kernel_test.cc diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index d5155dc5507cb..603ba2b96b81b 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -55,6 +55,9 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { 1, // no quantization so pack size is 1 is_fused_swiglu, 0)); // no block-wise quantization for regular MoE + ORT_RETURN_IF_NOT(k_ > 0 && k_ <= moe_params.num_experts, + "MoE requires 0 < k <= num_experts, got k=", k_, + " and num_experts=", moe_params.num_experts); using CudaT = typename OrtToCudaType::type; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc index bc34a2e83318a..8e78076288015 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc @@ -317,6 +317,9 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { &fc2_shape, fc2_experts_bias_optional, fc2_scales, fc2_zeros, nullptr, nullptr, nullptr, nullptr, pack_size, is_fused_swiglu, block_size_)); + ORT_RETURN_IF_NOT(k_ > 0 && k_ <= moe_params.num_experts, + "QMoE requires 0 < k <= num_experts, got k=", k_, + " and num_experts=", moe_params.num_experts); if (uses_fp4_weight_scales) { constexpr int64_t fp4_block_size = 32; diff --git a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu index 61cdf3ab23fca..43420c04b1a8e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu @@ -6,6 +6,7 @@ #include "core/common/narrow.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/cub.cuh" +#include "core/providers/cuda/cu_inc/topk_warp_sort.cuh" #include "contrib_ops/cuda/llm/moe_gemm/moe_kernels.h" #include #include @@ -19,6 +20,56 @@ int Compute1DGridSize(int num_elements, int block_size) { return (num_elements + block_size - 1) / block_size; } +constexpr float kTopKNormalizeEpsilon = 1e-6f; + +__device__ __forceinline__ float SoftmaxScale(float logit, float max_val, float inv_sum) { + return (inv_sum > 0.0f) ? expf(logit - max_val) * inv_sum : 0.0f; +} + +__device__ __forceinline__ float SafeInvSum(float sum) { + return (sum > 0.0f) ? (1.0f / sum) : 0.0f; +} + +__device__ __forceinline__ float TopKNormalizeDenom(bool normalize_scales, float scale_sum) { + return (normalize_scales && scale_sum > kTopKNormalizeEpsilon) ? scale_sum : 1.0f; +} + +__device__ __forceinline__ float WarpReduceMax(float value) { + constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize; +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + value = fmaxf(value, __shfl_xor_sync(0xFFFFFFFF, value, offset)); + } + return value; +} + +__device__ __forceinline__ float WarpReduceSum(float value) { + constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize; +#pragma unroll + for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) { + value += __shfl_xor_sync(0xFFFFFFFF, value, offset); + } + return value; +} + +template +__device__ __forceinline__ float BlockReduceMax(float value, typename BlockReduce::TempStorage& temp_storage) { +#if CUDART_VERSION >= 12090 + return BlockReduce(temp_storage).Reduce(value, ::cuda::maximum()); +#else + return BlockReduce(temp_storage).Reduce(value, cub::Max()); +#endif +} + +template +__device__ __forceinline__ float BlockReduceSum(float value, typename BlockReduce::TempStorage& temp_storage) { +#if CUDART_VERSION >= 12090 + return BlockReduce(temp_storage).Reduce(value, ::cuda::std::plus()); +#else + return BlockReduce(temp_storage).Reduce(value, cub::Sum()); +#endif +} + template __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk_indices, int num_rows, int num_experts, int k, bool normalize_scales) { @@ -30,7 +81,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk int* row_indices = topk_indices + row * k; // 1. Find max for numerical stability - float max_val = -FLT_MAX; + float max_val = onnxruntime::cuda::topk::kNegativeInfinity; for (int i = 0; i < num_experts; ++i) { float val = static_cast(row_logits[i]); if (val > max_val) max_val = val; @@ -41,6 +92,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk for (int i = 0; i < num_experts; ++i) { sum_exp += expf(static_cast(row_logits[i]) - max_val); } + const float inv_sum = SafeInvSum(sum_exp); // 3. Compute Softmax and find TopK // For small k, we can do a simple selection. @@ -56,7 +108,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk } for (int i = 0; i < num_experts; ++i) { - float prob = expf(static_cast(row_logits[i]) - max_val) / sum_exp; + float prob = SoftmaxScale(static_cast(row_logits[i]), max_val, inv_sum); // Insert into top-k logic // Simple insertion sort for very small k (e.g. k=2) @@ -80,7 +132,7 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk for (int i = 0; i < k; ++i) { scale_sum += row_scales[i]; } - if (scale_sum > 1e-6f) { + if (scale_sum > kTopKNormalizeEpsilon) { for (int i = 0; i < k; ++i) { row_scales[i] /= scale_sum; } @@ -88,6 +140,258 @@ __global__ void SoftmaxTopKKernel(const T* logits, float* topk_scales, int* topk } } +// Block-per-row softmax + top-k using a CUB block sort. Each block sorts one +// row's logits (descending) and reads the first k. A full sort of 256 logits is +// ~2.5x faster than k rounds of block-argmax on this size (benchmarked), and is +// the layout onnxruntime-genai's top-k benchmarks also recommend (CUB block +// merge) for sort sizes up to ~1024. The capacity (kBlockSize*kItemsPerThread) +// must be >= num_experts; padding lanes carry (-inf, INT_MAX) so valid -inf +// expert scores sort ahead of padding. Tie-breaking matches the scalar kernel +// (lower expert index first) via the same packed stable sort key used by the +// warp merge path. + +template +__global__ void SoftmaxTopKMergeKernel(const T* logits, float* topk_scales, int* topk_indices, + int num_rows, int num_experts, int k, bool normalize_scales) { + const int row = blockIdx.x; + if (row >= num_rows) return; + + const T* row_logits = logits + static_cast(row) * num_experts; + const int tid = threadIdx.x; + + using BlockMergeSort = cub::BlockMergeSort; + using BlockReduce = cub::BlockReduce; + __shared__ union { + typename BlockMergeSort::TempStorage merge; + typename BlockReduce::TempStorage reduce; + } temp; + __shared__ float s_topk[64]; // k <= 64 + __shared__ float s_max; + __shared__ float s_sum; + + // Load this thread's packed (logit, expert index) keys in a blocked + // arrangement: thread t owns indices [t*ipt, t*ipt+ipt). + uint64_t keys[kItemsPerThread]; + float local_max = onnxruntime::cuda::topk::kNegativeInfinity; +#pragma unroll + for (int j = 0; j < kItemsPerThread; ++j) { + const int idx = tid * kItemsPerThread + j; + const float logit = (idx < num_experts) ? static_cast(row_logits[idx]) + : onnxruntime::cuda::topk::kNegativeInfinity; + const int index = (idx < num_experts) ? idx : INT_MAX; + keys[j] = onnxruntime::cuda::topk::PackStableSortKey(logit, index); + local_max = fmaxf(local_max, logit); + } + + // Softmax denominator over all experts (needed when normalize_scales is false; + // when true it cancels in the top-k normalization but is still correct). + const float block_max = BlockReduceMax(local_max, temp.reduce); + if (tid == 0) s_max = block_max; + // Single barrier: publishes s_max to all threads and also separates the two + // BlockReduce uses that share temp.reduce. + __syncthreads(); + const float max_val = s_max; + + float local_sum = 0.0f; +#pragma unroll + for (int j = 0; j < kItemsPerThread; ++j) { + const int idx = tid * kItemsPerThread + j; + if (idx < num_experts) { + local_sum += expf(onnxruntime::cuda::topk::UnpackStableSortScore(keys[j]) - max_val); + } + } + const float block_sum = BlockReduceSum(local_sum, temp.reduce); + if (tid == 0) s_sum = block_sum; + // Single barrier: publishes s_sum and separates temp.reduce from temp.merge. + __syncthreads(); + const float inv_sum = SafeInvSum(s_sum); + + // Sort packed (logit, index) keys descending. Result stays in a blocked + // layout, so sorted rank r lives in thread (r / ipt), item (r % ipt). Sort() + // leaves the sorted keys in each thread's registers and temp.merge is not + // reused afterwards, so no barrier is needed here; the shared s_topk writes + // below are published by the barrier that follows them. + BlockMergeSort(temp.merge).Sort(keys, onnxruntime::cuda::topk::Greater()); + +#pragma unroll + for (int j = 0; j < kItemsPerThread; ++j) { + const int rank = tid * kItemsPerThread + j; + if (rank < k) { + const uint64_t key = keys[j]; + topk_indices[static_cast(row) * k + rank] = + onnxruntime::cuda::topk::UnpackStableSortIndex(key); + s_topk[rank] = SoftmaxScale(onnxruntime::cuda::topk::UnpackStableSortScore(key), max_val, inv_sum); + } + } + __syncthreads(); + + if (tid == 0) { + if (normalize_scales) { + float scale_sum = 0.0f; + for (int i = 0; i < k; ++i) scale_sum += s_topk[i]; + const float denom = TopKNormalizeDenom(normalize_scales, scale_sum); + for (int i = 0; i < k; ++i) topk_scales[static_cast(row) * k + i] = s_topk[i] / denom; + } else { + for (int i = 0; i < k; ++i) topk_scales[static_cast(row) * k + i] = s_topk[i]; + } + } +} + +// Warp-bitonic softmax + top-k for num_experts <= 32. Each warp handles one +// row, with lane `l` owning expert `l`. The whole softmax reduction and the +// sort are done with warp shuffles (no shared memory). This is the fastest path +// for tiny expert counts per the onnxruntime-genai top-k benchmark. Tie-breaking +// (equal scores prefer the lower expert index) matches SoftmaxTopKMergeKernel. +template +__global__ void SoftmaxTopKWarpBitonicKernel(const T* logits, float* topk_scales, int* topk_indices, + int num_rows, int num_experts, int k, bool normalize_scales) { + const int lane = threadIdx.x; + const int row = blockIdx.x * kWarpsPerBlock + threadIdx.y; + if (row >= num_rows) return; + + const T* row_logits = logits + static_cast(row) * num_experts; + const float logit = (lane < num_experts) ? static_cast(row_logits[lane]) + : onnxruntime::cuda::topk::kNegativeInfinity; + + const float max_val = WarpReduceMax(logit); + + // Warp-wide exp sum (softmax denominator over all experts). + const float sum_exp = WarpReduceSum((lane < num_experts) ? expf(logit - max_val) : 0.0f); + const float inv_sum = SafeInvSum(sum_exp); + + // Sort (logit, expert index) descending; sorting by logit is equivalent to + // sorting by softmax probability since the mapping is monotonic. + float score = logit; + int index = (lane < num_experts) ? lane : INT_MAX; + onnxruntime::cuda::topk::WarpBitonicSortDescending(score, index); + + // Lane r now holds the rank-r element. Compute the top-k probabilities. + float prob = (lane < k) ? SoftmaxScale(score, max_val, inv_sum) : 0.0f; + + if (normalize_scales) { + prob /= TopKNormalizeDenom(normalize_scales, WarpReduceSum(prob)); + } + + if (lane < k) { + topk_scales[static_cast(row) * k + lane] = prob; + topk_indices[static_cast(row) * k + lane] = index; + } +} + +// Warp CUB merge sort softmax + top-k for num_experts <= kBufferSize (64). One +// warp (32 threads) per block sorts a row's logits held in shared memory. This +// is the genai-recommended path for sort sizes in (32, 64]. Tie-breaking +// matches SoftmaxTopKMergeKernel via a packed stable sort key. +template +__global__ void SoftmaxTopKWarpMergeKernel(const T* logits, float* topk_scales, int* topk_indices, + int num_rows, int num_experts, int k, bool normalize_scales) { + constexpr int kWarpSize = onnxruntime::cuda::topk::kWarpSize; + using WarpMergeSorter = onnxruntime::cuda::topk::WarpMergeSorter; + + const int row = blockIdx.x; + if (row >= num_rows) return; + const int lane = threadIdx.x; + + __shared__ float s_scores[kBufferSize]; + __shared__ int s_indices[kBufferSize]; + __shared__ typename WarpMergeSorter::TempStorage temp_storage; + + const T* row_logits = logits + static_cast(row) * num_experts; + + // Load logits into shared memory and compute the warp-wide max. + float local_max = onnxruntime::cuda::topk::kNegativeInfinity; + for (int i = lane; i < kBufferSize; i += kWarpSize) { + const float v = (i < num_experts) ? static_cast(row_logits[i]) + : onnxruntime::cuda::topk::kNegativeInfinity; + s_scores[i] = v; + s_indices[i] = (i < num_experts) ? i : INT_MAX; + local_max = fmaxf(local_max, v); + } + const float max_val = WarpReduceMax(local_max); + + // Warp-wide exp sum over all experts. + float local_sum = 0.0f; + for (int i = lane; i < num_experts; i += kWarpSize) { + local_sum += expf(s_scores[i] - max_val); + } + const float inv_sum = SafeInvSum(WarpReduceSum(local_sum)); + + __syncwarp(); + WarpMergeSorter::Sort(s_scores, s_indices, temp_storage, num_experts); + __syncwarp(); + + // s_scores[r]/s_indices[r] now hold the rank-r logit/expert index. + float scale_sum = 0.0f; + if (normalize_scales) { + for (int i = lane; i < k; i += kWarpSize) { + scale_sum += SoftmaxScale(s_scores[i], max_val, inv_sum); + } + scale_sum = WarpReduceSum(scale_sum); + } + const float denom = TopKNormalizeDenom(normalize_scales, scale_sum); + + for (int i = lane; i < k; i += kWarpSize) { + const float prob = SoftmaxScale(s_scores[i], max_val, inv_sum); + topk_scales[static_cast(row) * k + i] = normalize_scales ? (prob / denom) : prob; + topk_indices[static_cast(row) * k + i] = s_indices[i]; + } +} + +template +void DispatchSoftmaxTopK(const T* logits, float* topk_scales, int* topk_indices, + int num_rows, int num_experts, int k, bool normalize_scales, + cudaStream_t stream) { + ORT_ENFORCE(k > 0 && k <= num_experts, + "SoftmaxTopK requires 0 < k <= num_experts, got k=", k, + " and num_experts=", num_experts); + + // Block-per-row CUB merge sort is the fastest path for the common decode case + // (one block fully sorts a row). Pick the smallest capacity that covers + // num_experts. k must fit in s_topk (<= 64). + const dim3 grid(static_cast(num_rows)); + if (k <= 64 && num_experts <= 1024) { + // Tiny expert counts: a single warp sorts a row entirely in registers via + // an in-register bitonic sort (no shared memory). Multiple warps per block + // process multiple rows for better occupancy. + if (num_experts <= onnxruntime::cuda::topk::kWarpBitonicMaxSize) { + constexpr int kWarpsPerBlock = 8; + const dim3 block(static_cast(onnxruntime::cuda::topk::kWarpSize), kWarpsPerBlock); + const dim3 bitonic_grid(static_cast((num_rows + kWarpsPerBlock - 1) / kWarpsPerBlock)); + SoftmaxTopKWarpBitonicKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } else if (num_experts <= onnxruntime::cuda::topk::kWarpMergeMaxSize) { + // Single warp per row sorts up to 64 logits in shared memory (CUB warp + // merge sort), the genai-recommended path for sort sizes in (32, 64]. + SoftmaxTopKWarpMergeKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } else if (num_experts <= 128) { + SoftmaxTopKMergeKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } else if (num_experts <= 256) { + SoftmaxTopKMergeKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } else if (num_experts <= 512) { + SoftmaxTopKMergeKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } else /*if (num_experts <= 1024)*/ { + SoftmaxTopKMergeKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + return; + } + } else { + // Fall back to the simple one-thread-per-row kernel. + const int block = 256; + const int grid_1d = Compute1DGridSize(num_rows, block); + SoftmaxTopKKernel<<>>( + logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + } +} + void LaunchSoftmaxTopK( const float* logits, float* topk_scales, @@ -97,9 +401,7 @@ void LaunchSoftmaxTopK( int k, bool normalize_scales, cudaStream_t stream) { - int block = 256; - int grid = Compute1DGridSize(num_rows, block); - SoftmaxTopKKernel<<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + DispatchSoftmaxTopK(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream); } void LaunchSoftmaxTopK( @@ -111,9 +413,7 @@ void LaunchSoftmaxTopK( int k, bool normalize_scales, cudaStream_t stream) { - int block = 256; - int grid = Compute1DGridSize(num_rows, block); - SoftmaxTopKKernel<<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + DispatchSoftmaxTopK(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream); } void LaunchSoftmaxTopK( @@ -125,9 +425,7 @@ void LaunchSoftmaxTopK( int k, bool normalize_scales, cudaStream_t stream) { - int block = 256; - int grid = Compute1DGridSize(num_rows, block); - SoftmaxTopKKernel<__nv_bfloat16><<>>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales); + DispatchSoftmaxTopK<__nv_bfloat16>(logits, topk_scales, topk_indices, num_rows, num_experts, k, normalize_scales, stream); } template diff --git a/onnxruntime/core/providers/cuda/cu_inc/topk_warp_sort.cuh b/onnxruntime/core/providers/cuda/cu_inc/topk_warp_sort.cuh new file mode 100644 index 0000000000000..c0fe36bb65ddc --- /dev/null +++ b/onnxruntime/core/providers/cuda/cu_inc/topk_warp_sort.cuh @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Reusable warp-level Top-K sorting primitives for CUDA. +// +// These helpers sort (score, index) pairs in descending order. Ties on the +// score are broken deterministically by preferring the smaller index, matching +// the tie-breaking used by the onnxruntime-genai Top-K kernels (the +// `STABLE_TOPK` path in cuda_topk_warp_sort_helper.cuh). +// +// Two primitives are provided, mirroring the algorithms that the genai offline +// benchmark found fastest for small sort sizes: +// * WarpBitonicSortDescending : best for sort sizes up to 32. Each lane holds +// a single (score, index) pair entirely in registers and exchanges data +// via warp shuffles, avoiding shared memory. +// * WarpMergeSorter : best for sort sizes up to 64 (CUB warp merge +// sort). A single warp sorts up to `BufferSize` pairs held in shared memory. +// +// They are intentionally operator-agnostic so they can be reused outside the +// MoE Top-K path. + +#pragma once + +#include +#include +#include +#include + +#include "core/providers/cuda/cu_inc/cub.cuh" + +namespace onnxruntime { +namespace cuda { +namespace topk { + +constexpr int kWarpSize = 32; + +// Compile-time threshold guidance based on the onnxruntime-genai offline +// benchmark (NVIDIA H200, CUDA 12.8). Use WarpBitonicSortDescending for sort +// sizes up to kWarpBitonicMaxSize, and the CUB warp merge sort for sizes up to +// kWarpMergeMaxSize. Larger sizes should fall back to a block-wide sort. +constexpr int kWarpBitonicMaxSize = 32; +constexpr int kWarpMergeMaxSize = 64; +constexpr float kNegativeInfinity = -std::numeric_limits::infinity(); + +__device__ __forceinline__ int LaneId() { + int lane_id; + asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ int LinearThreadIdInBlock() { + return threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); +} + +/** + * @brief In-register, warp-wide bitonic sort of kWarpSize (32) (score, index) + * pairs, producing a descending order. + * + * Each lane in the warp contributes exactly one (score, index) pair. After the + * call, the warp's pairs are sorted so that lane 0 holds the largest score. + * Ties on the score are broken in favor of the smaller index. Data is exchanged + * with __shfl_sync, so no shared memory is required. + * + * Lanes that do not hold a valid element should pass score = kNegativeInfinity + * and index = INT_MAX so that valid -inf scores sort ahead of padding. + */ +__device__ inline void WarpBitonicSortDescending(float& score, int& index) { + const int lane_id = LaneId(); + + // Build the bitonic sorting network in stages. + for (int k = 2; k <= kWarpSize; k <<= 1) { + for (int j = k >> 1; j > 0; j >>= 1) { + const int paired_lane = lane_id ^ j; + const float paired_score = __shfl_sync(0xFFFFFFFF, score, paired_lane); + const int paired_index = __shfl_sync(0xFFFFFFFF, index, paired_lane); + + // A standard bitonic network sorts ascending when (lane_id & k) == 0; we + // invert the swap condition to produce an overall descending sort. + const bool direction = ((lane_id & k) == 0); + + // Tie-break: equal scores prefer the smaller index. + const bool is_mine_greater = + (score > paired_score) || (score == paired_score && index < paired_index); + + const float s_max = is_mine_greater ? score : paired_score; + const int i_max = is_mine_greater ? index : paired_index; + const float s_min = is_mine_greater ? paired_score : score; + const int i_min = is_mine_greater ? paired_index : index; + + if (direction) { + score = (lane_id < paired_lane) ? s_max : s_min; + index = (lane_id < paired_lane) ? i_max : i_min; + } else { + score = (lane_id < paired_lane) ? s_min : s_max; + index = (lane_id < paired_lane) ? i_min : i_max; + } + } + } +} + +// Convert a (score, index) pair into a single unsigned integer key. Descending +// integer order then gives descending float score order, with equal scores +// preferring the smaller original index. This matches the stable Top-K packing +// used by onnxruntime-genai while avoiding a compound comparator in CUB. +__device__ __forceinline__ uint64_t PackStableSortKey(float score, int index) { + const uint32_t score_bits = __float_as_uint(score); + const uint32_t sortable_score = + (score_bits & 0x80000000u) ? (~score_bits) : (score_bits | 0x80000000u); + const uint32_t inverted_index = UINT_MAX - static_cast(index); + return (static_cast(sortable_score) << 32) | inverted_index; +} + +__device__ __forceinline__ float UnpackStableSortScore(uint64_t key) { + const uint32_t sortable_score = static_cast(key >> 32); + const uint32_t score_bits = + (sortable_score & 0x80000000u) ? (sortable_score & 0x7fffffffu) : ~sortable_score; + return __uint_as_float(score_bits); +} + +__device__ __forceinline__ int UnpackStableSortIndex(uint64_t key) { + const uint32_t inverted_index = static_cast(key & 0xffffffffu); + return static_cast(UINT_MAX - inverted_index); +} + +template +struct Greater { + __device__ __host__ __forceinline__ bool operator()(const T& a, const T& b) const { + return a > b; + } +}; + +/** + * @brief Single-warp CUB merge sort of up to `BufferSize` (score, index) pairs + * held in shared memory, producing a descending order. + * + * Only the first warp of the calling block performs work; the caller is + * responsible for any __syncthreads() needed before (to publish the shared + * memory inputs) and after (to consume the sorted outputs). On return, + * smem_scores[r]/smem_indices[r] hold the element of rank r (rank 0 == largest). + * + * @tparam BufferSize Maximum number of pairs to sort. Must be <= 256. + */ +template +struct WarpMergeSorter { + static_assert(BufferSize > 0 && BufferSize <= 256, "BufferSize must be in (0, 256]."); + + static constexpr int kItemsPerThread = (BufferSize + kWarpSize - 1) / kWarpSize; + using SortT = cub::WarpMergeSort; + using TempStorage = typename SortT::TempStorage; + + // num_valid_items elements are read from shared memory; the remainder are + // padded with (kNegativeInfinity, INT_MAX) so valid -inf scores sort ahead of padding. + __device__ static void Sort(float* smem_scores, int* smem_indices, + TempStorage& temp_storage, int num_valid_items) { + const int thread_id = LinearThreadIdInBlock(); + if (thread_id >= kWarpSize) { + return; + } + + const int lane_id = thread_id; + + uint64_t items[kItemsPerThread]; +#pragma unroll + for (int i = 0; i < kItemsPerThread; ++i) { + const int idx = lane_id + i * kWarpSize; + if (idx < num_valid_items) { + items[i] = PackStableSortKey(smem_scores[idx], smem_indices[idx]); + } else { + items[i] = PackStableSortKey(kNegativeInfinity, INT_MAX); + } + } + + SortT(temp_storage).Sort(items, Greater()); + + // Blocked write-back: rank r lives at smem[r]. +#pragma unroll + for (int i = 0; i < kItemsPerThread; ++i) { + const int idx = lane_id * kItemsPerThread + i; + if (idx < BufferSize) { + smem_scores[idx] = UnpackStableSortScore(items[i]); + smem_indices[idx] = UnpackStableSortIndex(items[i]); + } + } + } +}; + +} // namespace topk +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/cuda_kernels/softmax_topk_kernel_test.cc b/onnxruntime/test/contrib_ops/cuda_kernels/softmax_topk_kernel_test.cc new file mode 100644 index 0000000000000..345983b876b12 --- /dev/null +++ b/onnxruntime/test/contrib_ops/cuda_kernels/softmax_topk_kernel_test.cc @@ -0,0 +1,273 @@ +#include +#include +#include +#include + +#include "contrib_ops/cuda/moe/qmoe_kernels.h" +#include "core/providers/cuda/cuda_common.h" + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace test { +namespace { + +struct CudaBuffer { + void* data = nullptr; + size_t bytes = 0; + + explicit CudaBuffer(size_t size_in_bytes) : bytes(size_in_bytes) { + CUDA_CALL_THROW(cudaMalloc(&data, bytes)); + } + + ~CudaBuffer() { + if (data != nullptr) { + cudaFree(data); + } + } + + template + T* As() { + return reinterpret_cast(data); + } + + void CopyFromHost(const void* src) { + CUDA_CALL_THROW(cudaMemcpy(data, src, bytes, cudaMemcpyHostToDevice)); + } + + void CopyToHost(void* dst) const { + CUDA_CALL_THROW(cudaMemcpy(dst, data, bytes, cudaMemcpyDeviceToHost)); + } +}; + +struct ExpectedTopK { + std::vector scales; + std::vector indices; +}; + +template +float ToFloat(T value) { + return static_cast(value); +} + +template +T FromFloat(float value) { + return static_cast(value); +} + +template +ExpectedTopK ReferenceSoftmaxTopK(const std::vector& logits, int num_rows, int num_experts, int k, + bool normalize_scales) { + ExpectedTopK expected; + expected.scales.resize(static_cast(num_rows) * k); + expected.indices.resize(static_cast(num_rows) * k); + + for (int row = 0; row < num_rows; ++row) { + const T* row_logits = logits.data() + static_cast(row) * num_experts; + float max_logit = -std::numeric_limits::infinity(); + for (int expert = 0; expert < num_experts; ++expert) { + max_logit = std::max(max_logit, ToFloat(row_logits[expert])); + } + + float sum_exp = 0.0f; + for (int expert = 0; expert < num_experts; ++expert) { + sum_exp += std::exp(ToFloat(row_logits[expert]) - max_logit); + } + + std::vector order(num_experts); + std::iota(order.begin(), order.end(), 0); + std::stable_sort(order.begin(), order.end(), [&](int lhs, int rhs) { + const float lhs_logit = ToFloat(row_logits[lhs]); + const float rhs_logit = ToFloat(row_logits[rhs]); + return lhs_logit > rhs_logit || (lhs_logit == rhs_logit && lhs < rhs); + }); + + float topk_sum = 0.0f; + for (int rank = 0; rank < k; ++rank) { + const int expert = order[rank]; + const float scale = std::exp(ToFloat(row_logits[expert]) - max_logit) / sum_exp; + expected.scales[static_cast(row) * k + rank] = scale; + expected.indices[static_cast(row) * k + rank] = expert; + topk_sum += scale; + } + + if (normalize_scales && topk_sum > 1e-6f) { + for (int rank = 0; rank < k; ++rank) { + expected.scales[static_cast(row) * k + rank] /= topk_sum; + } + } + } + + return expected; +} + +template +void RunSoftmaxTopKTest(int num_rows, int num_experts, int k, bool normalize_scales, + const std::vector& logits_float, float tolerance) { + ASSERT_EQ(logits_float.size(), static_cast(num_rows) * num_experts); + + std::vector logits(logits_float.size()); + std::transform(logits_float.begin(), logits_float.end(), logits.begin(), [](float value) { + return FromFloat(value); + }); + + CudaBuffer d_logits(logits.size() * sizeof(T)); + CudaBuffer d_scales(static_cast(num_rows) * k * sizeof(float)); + CudaBuffer d_indices(static_cast(num_rows) * k * sizeof(int)); + d_logits.CopyFromHost(logits.data()); + + cudaStream_t stream = nullptr; + CUDA_CALL_THROW(cudaStreamCreate(&stream)); + onnxruntime::contrib::cuda::LaunchSoftmaxTopK( + d_logits.As(), d_scales.As(), d_indices.As(), num_rows, num_experts, k, normalize_scales, stream); + CUDA_CALL_THROW(cudaGetLastError()); + CUDA_CALL_THROW(cudaStreamSynchronize(stream)); + CUDA_CALL_THROW(cudaStreamDestroy(stream)); + + std::vector actual_scales(static_cast(num_rows) * k); + std::vector actual_indices(static_cast(num_rows) * k); + d_scales.CopyToHost(actual_scales.data()); + d_indices.CopyToHost(actual_indices.data()); + + const ExpectedTopK expected = ReferenceSoftmaxTopK(logits, num_rows, num_experts, k, normalize_scales); + ASSERT_EQ(actual_indices, expected.indices); + ASSERT_EQ(actual_scales.size(), expected.scales.size()); + for (size_t i = 0; i < actual_scales.size(); ++i) { + EXPECT_NEAR(actual_scales[i], expected.scales[i], tolerance) << "at flattened top-k index " << i; + } +} + +std::vector MakeLogits(int num_rows, int num_experts) { + std::vector logits(static_cast(num_rows) * num_experts); + for (int row = 0; row < num_rows; ++row) { + for (int expert = 0; expert < num_experts; ++expert) { + const int v = (expert * 17 + row * 11) % 29; + logits[static_cast(row) * num_experts + expert] = 0.125f * static_cast(v - 14); + } + } + return logits; +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpBitonicStableTiesFloat) { + constexpr int num_rows = 9; + constexpr int num_experts = 8; + constexpr int k = 4; + std::vector logits = { + 4.0f, 1.0f, 4.0f, 0.0f, 3.0f, 4.0f, -2.0f, 4.0f, + -1.0f, 2.0f, 2.0f, 5.0f, 5.0f, 5.0f, 0.5f, -3.0f, + 0.0f, 0.0f, 0.0f, -1.0f, -1.0f, 2.0f, 2.0f, 2.0f, + 3.0f, 1.0f, 2.0f, 3.0f, 3.0f, -4.0f, -5.0f, 0.0f, + 1.5f, 1.5f, 1.0f, 2.0f, 0.0f, 2.0f, 2.0f, -2.0f, + -6.0f, -5.0f, -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, + 7.0f, 7.0f, 6.0f, 6.0f, 5.0f, 5.0f, 4.0f, 4.0f, + -1.0f, 8.0f, -1.0f, 8.0f, 3.0f, 8.0f, 2.0f, 1.0f, + 0.25f, 0.5f, 0.75f, 1.0f, 1.25f, 1.5f, 1.75f, 2.0f}; + + RunSoftmaxTopKTest(num_rows, num_experts, k, false, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpBitonicBoundaryNormalizeHalf) { + constexpr int num_rows = 10; + constexpr int num_experts = 32; + constexpr int k = 8; + auto logits = MakeLogits(num_rows, num_experts); + logits[3 * num_experts + 0] = 6.0f; + logits[3 * num_experts + 5] = 6.0f; + logits[3 * num_experts + 17] = 6.0f; + + RunSoftmaxTopKTest(num_rows, num_experts, k, true, logits, 2e-3f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpMergeStableTiesFloat) { + constexpr int num_rows = 5; + constexpr int num_experts = 64; + constexpr int k = 8; + auto logits = MakeLogits(num_rows, num_experts); + for (int expert : {2, 13, 31, 63}) { + logits[static_cast(1) * num_experts + expert] = 5.0f; + } + for (int expert : {0, 16, 32, 48}) { + logits[static_cast(4) * num_experts + expert] = 4.5f; + } + + RunSoftmaxTopKTest(num_rows, num_experts, k, false, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpMergeNormalizeBFloat16) { + constexpr int num_rows = 6; + constexpr int num_experts = 64; + constexpr int k = 16; + auto logits = MakeLogits(num_rows, num_experts); + logits[2 * num_experts + 4] = 7.0f; + logits[2 * num_experts + 9] = 7.0f; + logits[2 * num_experts + 47] = 7.0f; + + RunSoftmaxTopKTest<__nv_bfloat16>(num_rows, num_experts, k, true, logits, 2e-2f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_BlockMergeStillMatchesReference) { + constexpr int num_rows = 4; + constexpr int num_experts = 128; + constexpr int k = 8; + auto logits = MakeLogits(num_rows, num_experts); + logits[0] = 9.0f; + logits[7] = 9.0f; + logits[65] = 9.0f; + + RunSoftmaxTopKTest(num_rows, num_experts, k, true, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpBitonicHandlesNegativeInfinityPadding) { + constexpr int num_rows = 2; + constexpr int num_experts = 8; + constexpr int k = 8; + constexpr float neg_inf = -std::numeric_limits::infinity(); + std::vector logits = { + 3.0f, neg_inf, 2.0f, neg_inf, 1.0f, neg_inf, 0.0f, neg_inf, + neg_inf, 4.0f, neg_inf, neg_inf, 3.0f, neg_inf, neg_inf, 2.0f}; + + RunSoftmaxTopKTest(num_rows, num_experts, k, false, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_WarpMergeHandlesNegativeInfinityPadding) { + constexpr int num_rows = 2; + constexpr int num_experts = 40; + constexpr int k = 40; + constexpr float neg_inf = -std::numeric_limits::infinity(); + std::vector logits(static_cast(num_rows) * num_experts, neg_inf); + for (int expert = 0; expert < 10; ++expert) { + logits[expert * 3] = static_cast(expert) * 0.25f; + logits[static_cast(1) * num_experts + expert * 2] = 3.0f - static_cast(expert) * 0.125f; + } + + RunSoftmaxTopKTest(num_rows, num_experts, k, true, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_BlockMergeHandlesNegativeInfinityPadding) { + constexpr int num_rows = 2; + constexpr int num_experts = 80; + constexpr int k = 64; + constexpr float neg_inf = -std::numeric_limits::infinity(); + std::vector logits(static_cast(num_rows) * num_experts, neg_inf); + for (int expert = 0; expert < 20; ++expert) { + logits[expert * 2] = static_cast(expert) * 0.125f; + logits[static_cast(1) * num_experts + expert * 3] = 5.0f - static_cast(expert) * 0.25f; + } + + RunSoftmaxTopKTest(num_rows, num_experts, k, true, logits, 1e-6f); +} + +TEST(CUDA_EP_Unittest, SoftmaxTopK_RejectsKGreaterThanExperts) { + EXPECT_THROW(onnxruntime::contrib::cuda::LaunchSoftmaxTopK( + static_cast(nullptr), nullptr, nullptr, 1, 8, 9, false, nullptr), + OnnxRuntimeException); +} + +} // namespace +} // namespace test +} // namespace onnxruntime From 0172bcd0277df3a4240c719078cbd973b7bd103e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 12 Jun 2026 10:45:24 -0400 Subject: [PATCH 11/15] Bump shell-quote from 1.8.3 to 1.8.4 in /js/react_native/e2e (#29022) Bumps [shell-quote](https://github.com/ljharb/shell-quote) from 1.8.3 to 1.8.4.
Changelog

Sourced from shell-quote's changelog.

v1.8.4 - 2026-05-22

Commits

  • [Fix] quote: validate object-token shapes 4378a6e
  • [Dev Deps] update @ljharb/eslint-config, auto-changelog, eslint, npmignore 22ebec0
  • [Tests] increase coverage 9f3caa3
  • [readme] replace runkit CI badge with shields.io check-runs badge 3344a04
  • [Dev Deps] update @ljharb/eslint-config 699c511
Commits
  • ff166e2 v1.8.4
  • 4378a6e [Fix] quote: validate object-token shapes
  • 22ebec0 [Dev Deps] update @ljharb/eslint-config, auto-changelog, eslint, `npmig...
  • 9f3caa3 [Tests] increase coverage
  • 3344a04 [readme] replace runkit CI badge with shields.io check-runs badge
  • 699c511 [Dev Deps] update @ljharb/eslint-config
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=shell-quote&package-manager=npm_and_yarn&previous-version=1.8.3&new-version=1.8.4)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index b7f38815f9da5..ca2d4dfed58f7 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -11629,9 +11629,9 @@ } }, "node_modules/shell-quote": { - "version": "1.8.3", - "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.3.tgz", - "integrity": "sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==", + "version": "1.8.4", + "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.4.tgz", + "integrity": "sha512-VsC6n6vz1ihYYyZZwX7YZSF5l5x36ca17OC+a69h94YqB7X6XLwf+5MOgynYir2SLFUbl8gIYvBo8K8RoNQ6bQ==", "license": "MIT", "engines": { "node": ">= 0.4" From 7db789300561c51f857e14c7da5258b26fe4be6b Mon Sep 17 00:00:00 2001 From: Lee Yongjun Date: Sat, 13 Jun 2026 00:37:27 +0900 Subject: [PATCH 12/15] Fill CUDA opset gap for Softplus and Softsign to opset 22 (#28982) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes part of #27729. The CUDA EP registered Softplus and Softsign only for opset 1 (HFD). Opset 22 extended the type constraint to include BFloat16, but the CUDA EP was never updated. The BFloat16 compute kernels were already instantiated via `UNARY_ACTIVATION_OP_HFD_WITH_BF16`; only the EP registration was missing. Changes: - Cap opset-1 kernels as versioned [1, 21] (Softplus, Softsign) - Register opset-22 kernels with BFloat16 support (HFDX) - Add opset-22 tests: float (CPU+CUDA) and BFloat16 (CUDA, sm≥530) - Update `docs/OperatorKernels.md` --- docs/OperatorKernels.md | 6 +- .../providers/cuda/activation/activations.cc | 7 +- .../providers/cuda/cuda_execution_provider.cc | 40 +++++++---- .../cpu/activation/activation_op_test.cc | 66 +++++++++++++++++++ 4 files changed, 103 insertions(+), 16 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index a3b38b17bd874..5fe7dd020d559 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -999,8 +999,10 @@ The **OpSet Version** column uses the following notation: |Softmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 10]|**T** = tensor(double), tensor(float), tensor(float16)| -|Softplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Softsign|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Softplus|*in* X:**T**
*out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)| +|Softsign|*in* input:**T**
*out* output:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| +|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)| |SpaceToDepth|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Split|*in* input:**T**
*in* split:**T**
*out* outputs...:**T**

or

*in* input:**T**
*in* split:**tensor(int64)**
*out* outputs:**T**

or

*in* input:**T**
*out* outputs:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cuda/activation/activations.cc b/onnxruntime/core/providers/cuda/activation/activations.cc index de84928553ffa..0d029fa8b847c 100644 --- a/onnxruntime/core/providers/cuda/activation/activations.cc +++ b/onnxruntime/core/providers/cuda/activation/activations.cc @@ -79,8 +79,8 @@ UNARY_ACTIVATION_OP_VERSIONED_HFD(Relu, 6, 12); UNARY_ACTIVATION_OP_HFD_WITH_BF16(Selu, 6); UNARY_ACTIVATION_OP_HFD_WITH_BF16(Sigmoid, 13); UNARY_ACTIVATION_OP_VERSIONED_HFD(Sigmoid, 6, 12); -UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softplus, 1); -UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softsign, 1); +UNARY_ACTIVATION_OP_VERSIONED_HFD(Softplus, 1, 21); +UNARY_ACTIVATION_OP_VERSIONED_HFD(Softsign, 1, 21); UNARY_ACTIVATION_OP_HFD_WITH_BF16(Tanh, 13); UNARY_ACTIVATION_OP_VERSIONED_HFD(Tanh, 6, 12); UNARY_ACTIVATION_OP_HFD_WITH_BF16(ThresholdedRelu, 10); @@ -91,6 +91,9 @@ UNARY_ACTIVATION_OP_HFD_WITH_BF16(LeakyRelu, 16); // Opset-22 adds BFloat16 to allowed types for the HardSigmoid / HardSwish operators UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSigmoid, 22); UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSwish, 22); +// Opset-22 adds BFloat16 to allowed types for the Softplus / Softsign operators +UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softplus, 22); +UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softsign, 22); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index e9339956a387b..528db1970ef25 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -646,15 +646,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, float, Sigmoid); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, double, Sigmoid); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sigmoid); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Softsign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Softsign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Softsign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, float, Softsign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, double, Softsign); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, MLFloat16, Softsign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, float, Tanh); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, double, Tanh); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Tanh); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Softplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, Softplus); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Softplus); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, float, Softplus); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, double, Softplus); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 21, MLFloat16, Softplus); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, float, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); @@ -1716,6 +1716,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Sin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Sin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Sin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Softplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Softplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Softplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Softplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, Softsign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, Softsign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, Softsign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, Softsign); // Opset 23. class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, 23, float, Attention); @@ -1926,15 +1934,15 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2996,6 +3004,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 23 BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index e6bfa46d79b2b..05b7bbe7648fe 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -697,6 +697,37 @@ TEST_F(ActivationOpTest, Softplus) { }); } +TEST_F(ActivationOpTest, Softplus_Opset22) { + TestActivationOp("Softplus", {{-1.0f, 0.0f, 1.0f, -5.0f, 5.0f, -100.0f, 100.0f}}, [](float x) { + if (x > 0) + return x + log1pf(expf(-x)); + else + return log1pf(expf(x)); }, {}, {}, /*is_tensorrt_supported=*/true, /*opset_version=*/22); +} + +#if defined(USE_CUDA) +TEST_F(ActivationOpTest, Softplus_bfloat16_Opset22) { + if (!HasCudaEnvironment(530)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; + return; + } + + OpTester test("Softplus", 22); + std::vector X = {-1.0f, 0.0f, 1.0f, -5.0f, 5.0f, -100.0f, 100.0f}; + std::vector Y; + for (float x : X) { + Y.push_back(x > 0 ? x + log1pf(expf(-x)) : log1pf(expf(x))); + } + std::vector dims{static_cast(X.size())}; + + test.AddInput("X", dims, FloatsToBFloat16s(X)); + test.AddOutput("Y", dims, FloatsToBFloat16s(Y)); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif // USE_CUDA + TEST_F(ActivationOpNoInfTest, Softsign) { if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { GTEST_SKIP() << "Skipping the test"; @@ -738,6 +769,41 @@ TEST_F(ActivationOpNoInfTest, Softsign) { {}, {}, false); // Disable TensorRT because result mismatches } +TEST_F(ActivationOpNoInfTest, Softsign_Opset22) { + if constexpr (!SessionOptions::DEFAULT_USE_PER_SESSION_THREADS) { + GTEST_SKIP() << "Skipping the test"; + } + + TestActivationOp( + "Softsign", + {{-1.0f, 0.0f, 1.0f, -5.0f, 5.0f, -100.0f, 100.0f}}, + [](float x) { return x / (1 + std::abs(x)); }, + {}, {}, /*is_tensorrt_supported=*/false, /*opset_version=*/22); +} + +#if defined(USE_CUDA) +TEST_F(ActivationOpNoInfTest, Softsign_bfloat16_Opset22) { + if (!HasCudaEnvironment(530)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; + return; + } + + OpTester test("Softsign", 22); + std::vector X = {-1.0f, 0.0f, 1.0f, -5.0f, 5.0f, -100.0f, 100.0f}; + std::vector Y; + for (float x : X) { + Y.push_back(x / (1 + std::abs(x))); + } + std::vector dims{static_cast(X.size())}; + + test.AddInput("X", dims, FloatsToBFloat16s(X)); + test.AddOutput("Y", dims, FloatsToBFloat16s(Y)); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif // USE_CUDA + #if defined(ENABLE_TRAINING_OPS) TEST(ReluGradInferenceTest, Basic) { const std::vector x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; From 0b278bbf9c124de8de0d353f474a4fc5a36381fd Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 12 Jun 2026 10:35:28 -0700 Subject: [PATCH 13/15] Update optimizer opset version checks for latest ONNX opset 26 (#28966) This pull request expands support for additional ONNX opset versions in the attention fusion optimization code, making the optimizer compatible with newer and more diverse ONNX models. The changes primarily update the accepted opset versions for various operators such as `Transpose`, `Reshape`, `Squeeze`, `Unsqueeze`, `Shape`, and others across multiple functions. This ensures broader model compatibility and improves the robustness of the fusion logic. **Expanded opset version support for attention fusion:** * Updated accepted opset versions for key operators (`Transpose`, `Reshape`, `Squeeze`, `Unsqueeze`, `Shape`, `Add`, `Mul`, `Sub`, `Div`, `Cast`, etc.) in the main attention fusion logic (`attention_fusion.cc`), allowing matching and fusion of newer ONNX models using these operators at opsets up to 25. [[1]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L352-R367) [[2]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L382-R384) [[3]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L394-R395) [[4]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L405-R405) [[5]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L463-R471) [[6]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L500-R500) [[7]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L514-R514) [[8]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L923-R927) [[9]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L956-R958) [[10]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1073-R1074) [[11]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1166-R1166) [[12]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1268-R1275) **Helper and mask subgraph matching improvements:** * Broadened opset version checks for subgraph matching in helper functions, including those for Gemm subgraphs, unidirectional mask subgraphs, input mask subgraphs, and past subgraph matching, to support additional opset versions and operator variants. [[1]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L77-R84) [[2]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L169-R171) [[3]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L378-R379) [[4]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L395-R402) [[5]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L457-R458) [[6]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L485-R487) [[7]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L635-R637) [[8]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L769-R769) [[9]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L794-R796) [[10]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L812-R814) [[11]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L890-R890) These changes collectively future-proof the attention fusion optimizer for a wider range of ONNX models and operator versions, reducing the likelihood of unsupported patterns during optimization. --- onnxruntime/core/graph/graph_utils.cc | 9 + onnxruntime/core/graph/graph_utils.h | 5 + .../core/optimizer/attention_fusion.cc | 58 +-- .../core/optimizer/attention_fusion_helper.h | 98 ++--- .../core/optimizer/embed_layer_norm_fusion.cc | 71 ++-- .../group_query_attention_pre_norm_fusion.cc | 4 +- onnxruntime/core/optimizer/reshape_fusion.cc | 16 +- .../test/optimizer/graph_transform_test.cc | 351 +++++++++++++++++- .../graph_transform_test_layernorm.cc | 243 ++++++++++++ ...up_query_attention_pre_norm_fusion_test.cc | 15 + 10 files changed, 743 insertions(+), 127 deletions(-) diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 0334597549609..f083297c16fea 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -8,6 +8,7 @@ #include "core/common/logging/logging.h" #include +#include #include #include #include @@ -411,6 +412,14 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s return iter == attrs.end() ? nullptr : &iter->second; } +bool IsFullShapeNode(const Node& node) { + const auto* start_attr = GetNodeAttribute(node, "start"); + const auto* end_attr = GetNodeAttribute(node, "end"); + // end=INT64_MAX is the runtime default meaning "all dimensions" (full shape). + return (!start_attr || start_attr->i() == 0) && + (!end_attr || end_attr->i() == std::numeric_limits::max()); +} + static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { ONNX_NAMESPACE::TypeProto new_type; auto* typeproto_tensor = new_type.mutable_tensor_type(); diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 2106da1a96327..5681d4e1d08f0 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -8,6 +8,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/graph/graph.h" +#include #include #include @@ -31,6 +32,10 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, /** Returns the attribute of a Node with a given name. */ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name); +/** Checks whether a Shape node returns the full tensor shape (all dimensions). + * Returns false if start/end attributes restrict the output to a subset of dimensions. */ +bool IsFullShapeNode(const Node& node); + /** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. @returns The NodeArg for the new initializer. diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 7fe7c914fa796..30eaaafccca82 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -349,7 +349,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0); if (sequence_transpose == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) || !HasExpectedPerm(*sequence_transpose, {0, 2, 1}) || !optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) { return false; @@ -357,14 +357,14 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, const Node* input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0); if (input_reshape == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) || !optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) { return fail("missing input Reshape before sequence transpose"); } Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape"); if (qkv_reshape == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) || !optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) { return fail("qkv Reshape after MatMul not matched"); } @@ -379,9 +379,9 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze"); Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose"); if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) || - !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) || - !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) || !HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) || !HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) || !HasExpectedAxesInput(graph, *k_squeeze, {2})) { @@ -391,8 +391,8 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze"); Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze"); if (q_squeeze == nullptr || v_squeeze == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) || - !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) || !HasExpectedAxesInput(graph, *q_squeeze, {0}) || !HasExpectedAxesInput(graph, *v_squeeze, {0})) { return fail("q/v squeeze pattern not matched"); @@ -402,7 +402,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose"); if (q_scale_mul == nullptr || k_transpose == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) || - !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) || !HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) { return fail("q scale Mul or k Transpose(0,2,3,1) not matched"); } @@ -460,7 +460,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose"); if (transpose_3 == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) || !HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) || !optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) { return fail("output Transpose(0,2,1,3) not matched"); @@ -468,7 +468,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape"); if (reshape_2 == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) || !optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) { return fail("output Reshape not matched"); } @@ -497,7 +497,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, if (proj_gemm == nullptr) { proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape"); if (proj_gemm_input_reshape == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) || !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) { return fail("projection MatMul/Gemm not matched"); } @@ -511,7 +511,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul, proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape"); if (proj_gemm_output_reshape == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) || !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) { return fail("normalized projection Gemm output Reshape not matched"); } @@ -920,11 +920,11 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, } std::vector q_path{ - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, - {0, 0, "LayerNormalization", {1}, kOnnxDomain}}; + {0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}}; if (!graph_utils::FindPath(edges[edges.size() - 1]->GetNode(), true, q_path, edges, logger)) { DEBUG_LOG("Failed to find path for q"); return false; @@ -953,9 +953,9 @@ static bool FuseSubGraphQKImpl(Node& layer_norm, } std::vector k_path{ - {0, 1, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, {0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}}; @@ -1070,8 +1070,8 @@ static bool FuseSubGraphQK(Node& layer_norm, const logging::Logger& logger) { // path to q std::vector q_varience_path{ - {0, 0, "Div", {7, 13}, kOnnxDomain}, - {0, 0, "MatMul", {1, 9}, kOnnxDomain}}; + {0, 0, "Div", {7, 13, 14}, kOnnxDomain}, + {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(*(mask_nodes.add), true, q_varience_path, edges, logger)) { DEBUG_LOG("Failed to find path for q"); @@ -1163,7 +1163,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm, // path to q std::vector q_varience_path{ {0, 2, "MatMul", {1, 9, 13}, kOnnxDomain}, - {0, 0, "Div", {7, 13}, kOnnxDomain}}; + {0, 0, "Div", {7, 13, 14}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(*(mask_nodes.where), true, q_varience_path, edges, logger)) { DEBUG_LOG("Failed to find path for q"); @@ -1265,14 +1265,14 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm, std::map& mask_int32_map, const logging::Logger& logger) { std::vector parent_path{ - {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, - {0, 1, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, {0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}}; diff --git a/onnxruntime/core/optimizer/attention_fusion_helper.h b/onnxruntime/core/optimizer/attention_fusion_helper.h index a328a6d451a89..b7bd36c8df9e4 100644 --- a/onnxruntime/core/optimizer/attention_fusion_helper.h +++ b/onnxruntime/core/optimizer/attention_fusion_helper.h @@ -74,14 +74,14 @@ bool MatchGemmSubgraph(Graph& graph, DEBUG_LOG("Start MatchGemmSubgraph"); // GPT Attention fusion supports opset version 9 or later. std::vector parent_path{ - {0, dst_arg_index, "Reshape", {5, 13}, kOnnxDomain}, + {0, dst_arg_index, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gemm", {9, 11, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 1, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Slice", {1, 10, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(node_after_gemm_reshape, true, parent_path, edges, logger)) { @@ -98,6 +98,14 @@ bool MatchGemmSubgraph(Graph& graph, const Node& slice = edges[6]->GetNode(); const Node& shape_before_slice = edges[7]->GetNode(); + // The downstream Slice/Squeeze/Gather nodes assume Shape returns the full tensor shape so + // that indices map directly to tensor dimensions. A partial shape (opset 15+ start/end + // attributes) would produce incorrect index mapping. + if (!graph_utils::IsFullShapeNode(shape_before_slice)) { + DEBUG_LOG("Shape node has non-default start/end attributes"); + return false; + } + const auto& subgraph_input = shape_before_slice.InputDefs()[0]; if (reshape_before_gemm.InputDefs()[0]->Name() != subgraph_input->Name()) { DEBUG_LOG("Input of reshape_before_gemm is not the input of subgraph"); @@ -166,9 +174,9 @@ bool MatchGemmSubgraph(Graph& graph, // Match: [Input] ----> Shape --> Gather (indices=0 or 1) --> Unsqueeze (axes=0) ----> Concat ( , , ) for (int i = 0; i < 2; i++) { std::vector gather_path1{ - {0, i, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, i, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(concat_after_gather, true, gather_path1, edges, logger)) { DEBUG_LOG("Faild to match gemm gather path"); @@ -375,8 +383,8 @@ bool ValidateUnidirMask(const Graph& graph, const NodeArg& mask, bool& is_unidir bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnidirMaskResult& result, bool use_shared_node, const logging::Logger& logger) { DEBUG_LOG("Start MatchUnidirMaskSubgraph"); std::vector root_path{ - {0, 0, "Where", {9}, kOnnxDomain}, - {0, 1, "Div", {7, 13}, kOnnxDomain}}; + {0, 0, "Where", {9, 16}, kOnnxDomain}, + {0, 1, "Div", {7, 13, 14}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(add_node, true, root_path, edges, logger)) { @@ -392,14 +400,14 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } std::vector path1{ - {0, 0, "Cast", {9, 13}, kOnnxDomain}, + {0, 0, "Cast", {9, 13, 19, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Slice", {10, 11, 13}, kOnnxDomain}, // Last Slice {0, 0, "Slice", {10, 11, 13}, kOnnxDomain}, // Mask Slice - {0, 1, "Unsqueeze", {9, 11, 13}, kOnnxDomain}, - {0, 0, "Sub", {7, 13}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 1, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Sub", {7, 13, 14}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Slice", {10, 11, 13}, kOnnxDomain}, // Slice 1 - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(where_node, true, path1, edges, logger)) { DEBUG_LOG("Faild to match path 1 for unidirectional mask"); @@ -454,8 +462,8 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } std::vector slice_ends_path{ - {0, 2, "Unsqueeze", {9, 11, 13}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}}; + {0, 2, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(last_slice, true, slice_ends_path, edges, logger) || edges[1]->GetNode().Index() != squeeze1.Index()) { @@ -482,9 +490,9 @@ bool MatchUnidirMaskSubgraph(const Graph& graph, const Node& add_node, MatchUnid } std::vector path4{ - {0, 1, "Squeeze", {1, 11, 13}, kOnnxDomain}, + {0, 1, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Slice", {10, 11, 13}, kOnnxDomain}, // Slice 2 - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(sub, true, path4, edges, logger)) { DEBUG_LOG("Faild to match path 4 for unidirectional mask"); @@ -632,9 +640,9 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& qkv_matmul, Attentio } std::vector mask_path{ - {0, 0, "Add", {7, 13}, kOnnxDomain}, - {0, 1, "Mul", {7, 13}, kOnnxDomain}, - {0, 0, "Sub", {7, 13}, kOnnxDomain}}; + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, + {0, 1, "Mul", {7, 13, 14}, kOnnxDomain}, + {0, 0, "Sub", {7, 13, 14}, kOnnxDomain}}; if (!graph_utils::FindPath(softmax, true, mask_path, edges, logger)) { DEBUG_LOG("Failed to find path for mask"); @@ -766,7 +774,7 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No // expand has another input Shape <-- qk_MatMul std::vector shape_path{ - {0, 1, "Shape", {1, 13}, kOnnxDomain}, + {0, 1, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(expand, true, shape_path, edges, logger)) { DEBUG_LOG("Failed to find shape path"); @@ -791,9 +799,9 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No // reshape node's shape input std::vector reshape_shape_path_1{ {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(reshape, true, reshape_shape_path_1, edges, logger)) { DEBUG_LOG("Failed to find reshape shape path 1"); return false; @@ -809,9 +817,9 @@ bool MatchInputMaskSubgraph(const Graph& graph, const Node& layer_norm, const No } std::vector reshape_shape_path_2{ - {0, 3, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 3, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(concat, true, reshape_shape_path_2, edges, logger)) { DEBUG_LOG("Failed to find reshape shape path 2"); return false; @@ -887,7 +895,7 @@ bool MatchPastSubgraph(Graph& graph, const Node& k_concat, const Node& v_concat, MatchPastResult& result, const logging::Logger& logger) { DEBUG_LOG("Start MatchPastSubgraph"); std::vector past_k_path{ - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; if (transpose_optimized_pattern) { @@ -904,13 +912,13 @@ bool MatchPastSubgraph(Graph& graph, const Node& k_concat, const Node& v_concat, const Node& past_k_gather = edges[i++]->GetNode(); std::vector present_k_path{ - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Concat", {4, 11, 13}, kOnnxDomain}}; if (transpose_optimized_pattern) { present_k_path = { - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Concat", {4, 11, 13}, kOnnxDomain}}; } @@ -924,7 +932,7 @@ bool MatchPastSubgraph(Graph& graph, const Node& k_concat, const Node& v_concat, const Node& present_concat = edges[i++]->GetNode(); std::vector present_past_v_path{ - {0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 1, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Concat", {4, 11, 13}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(present_concat, true, present_past_v_path, edges, logger)) { @@ -1024,7 +1032,7 @@ bool CheckDistilBertReshapeShape(const Graph& graph, const Node& reshape, int64_ // lazy check: record unqueeze first and then check in the mask path std::vector shape_path{ {0, 1, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}}; + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(reshape, true, shape_path, edges, logger)) { DEBUG_LOG("Failed to find shape path"); @@ -1339,9 +1347,9 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: } std::vector path1{ - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "MatMul", {1, 9}, kOnnxDomain}}; + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}}; std::vector edges; if (!graph_utils::FindPath(*gemm1_result.input_node, true, path1, edges, logger)) { @@ -1361,9 +1369,9 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: bool has_past = graph_utils::IsSupportedOptypeVersionAndDomain(*v_concat, "Concat", {4, 11, 13}, kOnnxDomain); std::vector path2{ - {0, 1, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {2, 0, "Split", {2, 11, 13}, kOnnxDomain}}; + {0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {2, 0, "Split", {2, 11, 13, 18}, kOnnxDomain}}; if (!graph_utils::FindPath(has_past ? *v_concat : qkv_matmul, true, path2, edges, logger)) { DEBUG_LOG("Faild to find path v to Split"); @@ -1413,9 +1421,9 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: // path to q std::vector q_path{ {0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}, - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {0, 0, "Split", {2, 11, 13}, kOnnxDomain}}; + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Split", {2, 11, 13, 18}, kOnnxDomain}}; const Node* qk_div = unidir_mask_result.div_node; if (!graph_utils::FindPath(*qk_div, true, q_path, edges, logger)) { @@ -1447,7 +1455,7 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: return false; } - if (graph_utils::IsSupportedOptypeVersionAndDomain(*k_concat, "Transpose", {1, 13, 21}, kOnnxDomain)) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*k_concat, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain)) { transpose_optimized_pattern = true; DEBUG_LOG("Using transpose optimized pattern"); opt_k_transpose = k_concat; @@ -1471,9 +1479,9 @@ bool FuseGptAttention(Node& layer_norm, Graph& graph, int64_t hidden_size, std:: // path to k std::vector k_path{ - {0, 1, "Transpose", {1, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}, - {1, 0, "Split", {2, 11, 13}, kOnnxDomain}}; + {0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}, + {1, 0, "Split", {2, 11, 13, 18}, kOnnxDomain}}; if (!graph_utils::FindPath(has_past ? *k_concat : qk_matmul, true, k_path, edges, logger)) { DEBUG_LOG("Failed to find path for k"); diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 606e91ce91bbb..e2a3ef2a4fe2d 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -115,9 +115,9 @@ static bool MatchInputToConcatSubgraph( const NodeIndex expected_gather_node_1_index) { std::vector expand_parent_path1{ {0, index, "Concat", {4, 11, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}, + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}, }; std::vector edges; @@ -138,6 +138,14 @@ static bool MatchInputToConcatSubgraph( } } + // The Gather(index=0) below assumes Shape returns the full tensor shape. A partial shape + // (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension. + const Node& shape_node_path1 = edges[shape_index]->GetNode(); + if (!graph_utils::IsFullShapeNode(shape_node_path1)) { + DEBUG_LOG("Shape node in path 1 has non-default start/end attributes."); + return false; + } + Node& concat_node = *graph.GetNode(edges[0]->GetNode().Index()); Node& gather_node_0 = *graph.GetNode(edges[2]->GetNode().Index()); Node& shape_node_0 = *graph.GetNode(edges[3]->GetNode().Index()); @@ -147,9 +155,9 @@ static bool MatchInputToConcatSubgraph( } std::vector concat_parent_path{ - {0, 1, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 1, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (!graph_utils::FindPath(concat_node, true, concat_parent_path, edges, logger)) { DEBUG_LOG("Failed to find path 2 of position shape."); @@ -167,6 +175,13 @@ static bool MatchInputToConcatSubgraph( Node& gather_node_1 = *graph.GetNode(edges[1]->GetNode().Index()); Node& shape_node_1 = *graph.GetNode(edges[2]->GetNode().Index()); + // The Gather(index=1) below assumes Shape returns the full tensor shape. A partial shape + // (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension. + if (!graph_utils::IsFullShapeNode(shape_node_1)) { + DEBUG_LOG("Shape node in path 2 has non-default start/end attributes."); + return false; + } + // The gather node (with second input indices==1) is also shared by other subgraph if (expected_gather_node_1_index != gather_node_1.Index()) { DEBUG_LOG("Gather node in path 2 is not linked to another subgraph."); @@ -231,42 +246,42 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( // --> Cast --> Unsqueeze --> Expand --> Gather std::vector parent_path_1{ {0, 1, "Expand", {8, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Cast", {9, 13}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Cast", {9, 13, 19, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "NonZero", {9, 13}, kOnnxDomain}, - {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9, 20, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; // Look for Path 2 (Path 1 with no cast): std::vector parent_path_2{ {0, 1, "Expand", {8, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Transpose", {1, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "NonZero", {9, 13}, kOnnxDomain}, - {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9, 20, 21, 23, 24, 25}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; // Path 3 Pattern: // Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand std::vector parent_path_3{ {0, 1, "Expand", {8, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Range", {1, 11}, kOnnxDomain}, - {0, 1, "Cast", {9, 13}, kOnnxDomain}, + {0, 1, "Cast", {9, 13, 19, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; // Path 4 pattern (Path 3 with no "Cast"): std::vector parent_path_4{ {0, 1, "Expand", {8, 13}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11, 13}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11, 13, 21, 23, 24, 25}, kOnnxDomain}, {0, 0, "Range", {1, 11}, kOnnxDomain}, {0, 1, "Gather", {1, 11, 13}, kOnnxDomain}, - {0, 0, "Shape", {1, 13}, kOnnxDomain}}; + {0, 0, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}; // Match one of the three path patterns. if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) && !graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) && @@ -318,7 +333,7 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( // Match Shape --> Expand path. std::vector pg_edges_2; - if (!graph_utils::FindPath(expand_node, true, {graph_utils::EdgeEndToMatch{0, 1, "Shape", {1, 13}, kOnnxDomain}}, pg_edges_2, logger)) { + if (!graph_utils::FindPath(expand_node, true, {graph_utils::EdgeEndToMatch{0, 1, "Shape", {1, 13, 15, 19, 21, 23, 24, 25}, kOnnxDomain}}, pg_edges_2, logger)) { DEBUG_LOG("Failed to match Shape node. "); return false; } @@ -338,9 +353,9 @@ static bool MatchPositionEmbeddingSubgraphsFromGather( // -------------------- std::vector pg_edges_2; std::vector path_to_match_1{ - {0, 1, "Where", {9}, kOnnxDomain}, - {0, 0, "Equal", {1, 7, 11, 13}, kOnnxDomain}, - {0, 0, "Reshape", {5, 13}, kOnnxDomain}}; + {0, 1, "Where", {9, 16}, kOnnxDomain}, + {0, 0, "Equal", {1, 7, 11, 13, 19}, kOnnxDomain}, + {0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain}}; if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) || !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) || @@ -559,7 +574,7 @@ static bool FuseSubGraph(Graph& graph, // Trace back to find Gather --> Add --> LayerNormalization std::vector word_embedding_path{ - {0, 0, "Add", {7, 13}, kOnnxDomain}, + {0, 0, "Add", {7, 13, 14}, kOnnxDomain}, {0, 0, "Gather", {1, 11, 13}, kOnnxDomain}}; if (!graph_utils::FindPath(layer_norm_add_node, true, word_embedding_path, edges, logger)) { return false; @@ -843,7 +858,7 @@ Status EmbedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l std::vector edges; // Find Add --> LayerNormalization - if (!graph_utils::FindPath(layer_norm_node, true, {graph_utils::EdgeEndToMatch{0, 0, "Add", {7, 13}, kOnnxDomain}}, edges, logger)) { + if (!graph_utils::FindPath(layer_norm_node, true, {graph_utils::EdgeEndToMatch{0, 0, "Add", {7, 13, 14}, kOnnxDomain}}, edges, logger)) { continue; } Node& layer_norm_add_node = *graph.GetNode(edges[0]->GetNode().Index()); diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc index 909229822f134..3271c4cc9a22f 100644 --- a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc @@ -79,7 +79,7 @@ bool MatchPreNormReshapeChain(Graph& graph, Node* reshape_outer = graph.GetMutableProducerNode(consumer_input->Name()); if (reshape_outer == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_outer, "Reshape", {5, 13, 14, 19, 21, 23})) { + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_outer, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25})) { return false; } if (reshape_outer->GetOutputEdgesCount() != 1) { @@ -174,7 +174,7 @@ bool MatchPreNormReshapeChain(Graph& graph, } Node* reshape_inner = graph.GetMutableProducerNode(sln->InputDefs()[0]->Name()); if (reshape_inner == nullptr || - !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_inner, "Reshape", {5, 13, 14, 19, 21, 23})) { + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_inner, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25})) { return false; } if (reshape_inner->GetOutputEdgesCount() != 1) { diff --git a/onnxruntime/core/optimizer/reshape_fusion.cc b/onnxruntime/core/optimizer/reshape_fusion.cc index f50c0e2e635bc..65f477d9c7479 100644 --- a/onnxruntime/core/optimizer/reshape_fusion.cc +++ b/onnxruntime/core/optimizer/reshape_fusion.cc @@ -15,7 +15,10 @@ namespace onnxruntime { bool GetAxesFromUnsqueezeNode(const Graph& graph, const Node& unsqueeze, InlinedVector& axes) { if (graph_utils::MatchesOpSinceVersion(unsqueeze, {1, 11})) { return graph_utils::GetRepeatedNodeAttributeValues(unsqueeze, "axes", axes); - } else if (graph_utils::MatchesOpSinceVersion(unsqueeze, {13})) { + } + + // Opset 13+ moved axes from attribute to input[1]. + if (unsqueeze.InputDefs().size() > 1) { const NodeArg* axes_node_arg = unsqueeze.InputDefs()[1]; return optimizer_utils::AppendTensorFromInitializer(graph, *axes_node_arg, axes, true); } @@ -169,12 +172,11 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node const Node& gather = edges[1]->GetNode(); const Node& shape = edges[2]->GetNode(); - if (graph_utils::MatchesOpSinceVersion(shape, {15})) { - const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape, "start"); - const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape, "end"); - if (!((!start_attr || static_cast(start_attr->i()) == 0) && (!end_attr))) { - return false; - } + // The fusion assumes Shape returns the full tensor shape so that Gather indices correspond + // directly to tensor dimensions. A partial shape (opset 15+ start/end attributes) would shift + // the index mapping and produce incorrect results. + if (!graph_utils::IsFullShapeNode(shape)) { + return false; } InlinedVector axes; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f14327af4b8dd..5c76074d6b119 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -15,6 +15,7 @@ #include "onnx/defs/parser.h" #include "onnx/defs/printer.h" +#include "onnx/defs/schema.h" #include "core/common/span_utils.h" #include "core/framework/data_types.h" @@ -6597,6 +6598,26 @@ TEST_F(GraphTransformationTests, AttentionFusionDistilBertTest) { EXPECT_EQ(op_to_count["Shape"], 0); } +// These tests verify that attention fusions fire at the CURRENT max ONNX opset. +// When the ONNX opset advances (e.g., via submodule update), nodes will report a new +// SinceVersion(). If the optimizer's version lists are not updated, the fusion will +// fail to match and these tests will fail. +// +// To fix: update the EdgeEndToMatch version lists in the indicated optimizer file +// to include the new opset version for each affected op type. +// +// NOTE: GPT-2 and DistilBert attention models use attribute-based Unsqueeze/Squeeze +// in their mask subgraphs (opset <= 12 only). Those patterns cannot be converted to +// current opset without rewriting the mask matching logic. The MobileCLIP test +// (programmatically built) covers the current-opset attention fusion paths. + +static int GetCurrentOnnxOpset() { + const auto& map = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().Map(); + auto it = map.find(ONNX_NAMESPACE::ONNX_DOMAIN); + EXPECT_TRUE(it != map.end()) << "ONNX domain not found in OpSchemaRegistry"; + return it != map.end() ? it->second.second : 0; +} + enum class MobileClipProjectionType { MatMulAdd, GemmWithReshapes, @@ -6883,6 +6904,21 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) { std::make_unique()); } +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCurrentOpsetTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + TransformerTester(build_test_case, + CheckMobileClipAttentionFusedSession, + TransformerLevel::Level1, + TransformerLevel::Level2, + GetCurrentOnnxOpset(), + 1e-3, + 0.0, + std::make_unique()); +} + TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) { auto build_test_case = [](ModelTestBuilder& builder) { BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); @@ -6975,6 +7011,255 @@ TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionRewriteFa CheckMobileClipAttentionUnfusedMatMulGraph)); } +// Current-opset regression tests for fusion optimizers. +// These construct minimal graphs at the current ONNX opset and verify the optimizer fires. +// If ONNX bumps an op's since_version, the optimizer's version list will miss it and the test fails. + +TEST_F(GraphTransformationTests, GeluFusionCurrentOpsetTest) { + // Pattern: x -> Div(sqrt2) -> Erf -> Add(1) -> Mul(x) -> Mul(0.5) -> output + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 3, 64}}); + auto* sqrt2 = builder.MakeInitializer({}, {1.4142135f}); + auto* one = builder.MakeInitializer({}, {1.0f}); + auto* half = builder.MakeInitializer({}, {0.5f}); + + auto* div_out = builder.MakeIntermediate(); + auto* erf_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* mul1_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Div", {input, sqrt2}, {div_out}); + builder.AddNode("Erf", {div_out}, {erf_out}); + builder.AddNode("Add", {erf_out, one}, {add_out}); + builder.AddNode("Mul", {input, add_out}, {mul1_out}); + builder.AddNode("Mul", {mul1_out, half}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["Gelu"] == 1 || op_to_count["com.microsoft.Gelu"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["Div"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Erf"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Gelu fusion failed at opset ", current_opset, + ". Remaining ops: Div=", op_to_count["Div"], + " Erf=", op_to_count["Erf"], + " Add=", op_to_count["Add"], + " Mul=", op_to_count["Mul"], + ". Either update version lists in " + "onnxruntime/core/optimizer/gelu_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, FastGeluFusionCurrentOpsetTest) { + // FastGelu pattern: x*x*0.044715 + x -> mul(sqrt(2/pi)) -> tanh -> add(1) -> mul(0.5) -> mul(x) + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 3, 64}}); + auto* coeff = builder.MakeInitializer({}, {0.044715f}); + auto* sqrt2pi = builder.MakeInitializer({}, {0.7978845f}); // sqrt(2/pi) + auto* one = builder.MakeInitializer({}, {1.0f}); + auto* half = builder.MakeInitializer({}, {0.5f}); + auto* three = builder.MakeInitializer({}, {3.0f}); + + auto* pow_out = builder.MakeIntermediate(); + auto* mul1_out = builder.MakeIntermediate(); + auto* add1_out = builder.MakeIntermediate(); + auto* mul2_out = builder.MakeIntermediate(); + auto* tanh_out = builder.MakeIntermediate(); + auto* add2_out = builder.MakeIntermediate(); + auto* mul_half_out = builder.MakeIntermediate(); + auto* mul_final_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, three}, {pow_out}); + builder.AddNode("Mul", {pow_out, coeff}, {mul1_out}); + builder.AddNode("Add", {input, mul1_out}, {add1_out}); + builder.AddNode("Mul", {add1_out, sqrt2pi}, {mul2_out}); + builder.AddNode("Tanh", {mul2_out}, {tanh_out}); + builder.AddNode("Add", {tanh_out, one}, {add2_out}); + builder.AddNode("Mul", {input, half}, {mul_half_out}); + builder.AddNode("Mul", {mul_half_out, add2_out}, {mul_final_out}); + builder.AddNode("Identity", {mul_final_out}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["com.microsoft.FastGelu"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["Tanh"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Pow"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "FastGelu fusion failed at opset ", current_opset, + ". Remaining ops: Tanh=", op_to_count["Tanh"], + " Pow=", op_to_count["Pow"], + " Mul=", op_to_count["Mul"], + " Add=", op_to_count["Add"], + ". Either update version lists in " + "onnxruntime/core/optimizer/fast_gelu_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, BiasGeluFusionCurrentOpsetTest) { + // BiasGelu pattern: Add(input, bias) -> Gelu -> output (requires opset >= 20 for ONNX Gelu) + int current_opset = GetCurrentOnnxOpset(); + if (current_opset < 20) { + GTEST_SKIP() << "BiasGelu with ONNX Gelu requires opset >= 20"; + } + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 3, 64}}); + auto* bias = builder.MakeInitializer({64}, -0.5f, 0.5f); + auto* add_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Add", {input, bias}, {add_out}); + builder.AddNode("Gelu", {add_out}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["com.microsoft.BiasGelu"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["Add"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gelu"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "BiasGelu fusion failed at opset ", current_opset, + ". Remaining ops: Add=", op_to_count["Add"], + " Gelu=", op_to_count["Gelu"], + ". Either update version lists in " + "onnxruntime/core/optimizer/bias_gelu_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, MatMulAddFusionCurrentOpsetTest) { + // MatMul + Add -> Gemm fusion + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 64}}); + auto* weights = builder.MakeInitializer({64, 32}, -1.0f, 1.0f); + auto* bias = builder.MakeInitializer({32}, -0.5f, 0.5f); + auto* matmul_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("MatMul", {input, weights}, {matmul_out}); + builder.AddNode("Add", {matmul_out, bias}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["Gemm"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "MatMulAdd fusion failed at opset ", current_opset, + ". Remaining ops: MatMul=", op_to_count["MatMul"], + " Add=", op_to_count["Add"], + ". Either update version lists in " + "onnxruntime/core/optimizer/matmul_add_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, DivMulFusionCurrentOpsetTest) { + // 1/x * y -> y/x fusion + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input_x = builder.MakeInput({{2, 64}}); + auto* input_y = builder.MakeInput({{2, 64}}); + auto* one = builder.MakeInitializer({}, {1.0f}); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Div", {one, input_x}, {div_out}); + builder.AddNode("Mul", {div_out, input_y}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["Div"] == 1 && op_to_count["Mul"] == 0) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "DivMul fusion failed at opset ", current_opset, + ". Remaining ops: Div=", op_to_count["Div"], + " Mul=", op_to_count["Mul"], + ". Either update version lists in " + "onnxruntime/core/optimizer/div_mul_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + auto rule_transformer = std::make_unique("DivMulFusionCurrentOpset"); + ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique())); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::move(rule_transformer), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, QuickGeluFusionCurrentOpsetTest) { + // x * Sigmoid(alpha * x) -> QuickGelu(x, alpha) fusion + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 3, 64}}); + auto* alpha = builder.MakeInitializer({}, {1.702f}); + auto* mul1_out = builder.MakeIntermediate(); + auto* sigmoid_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Mul", {input, alpha}, {mul1_out}); + builder.AddNode("Sigmoid", {mul1_out}, {sigmoid_out}); + builder.AddNode("Mul", {input, sigmoid_out}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["com.microsoft.QuickGelu"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Sigmoid"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "QuickGelu fusion failed at opset ", current_opset, + ". Remaining ops: Mul=", op_to_count["Mul"], + " Sigmoid=", op_to_count["Sigmoid"], + ". Either update version lists in " + "onnxruntime/core/optimizer/quick_gelu_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + TEST_F(GraphTransformationTests, GeluFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model; @@ -8219,8 +8504,14 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) { return Status::OK(); }; - const std::vector opsets{11, 12, 13, 14, 15, 18}; - bool shape_test_for_opset15 = false; + // Include the current max opset to ensure the fusion stays up-to-date. + // If this test fails at the current opset, update version lists in + // onnxruntime/core/optimizer/reshape_fusion.cc. + std::vector opsets{11, 12, 13, 14, 15, 18, 19, 21, 23, 24, 25}; + int current_opset = GetCurrentOnnxOpset(); + if (std::find(opsets.begin(), opsets.end(), current_opset) == opsets.end()) { + opsets.push_back(current_opset); + } for (auto& opset : opsets) { auto build_test_case = [&](ModelTestBuilder& builder) { @@ -8245,14 +8536,7 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) { builder.AddNode("Add", {input_arg0, input_arg1}, {add_out}); if (opset_version >= 15) { - if (shape_test_for_opset15) { - auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out}); - shape_1.AddAttribute("start", (int64_t)1); - shape_1.AddAttribute("end", (int64_t)2); - } else { - builder.AddNode("Shape", {add_out}, {shape_out}).AddAttribute("start", (int64_t)0); - shape_test_for_opset15 = true; - } + builder.AddNode("Shape", {add_out}, {shape_out}).AddAttribute("start", (int64_t)0); } else { builder.AddNode("Shape", {add_out}, {shape_out}); } @@ -8271,13 +8555,48 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) { builder.AddNode("Reshape", {add_out, concattraining1_out}, {out}); }; + // Test that the fusion fires for every opset. std::unique_ptr transformer = std::make_unique(); - if (opset >= 15 && shape_test_for_opset15) { - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, pre_graph_checker)); - } else { - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1, - pre_graph_checker, post_graph_checker)); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + + // For opset >= 15, also test that partial Shape (start=1, end=2) prevents fusion. + if (opset >= 15) { + auto build_partial_shape_case = [&](ModelTestBuilder& builder) { + auto* input_arg0 = builder.MakeInput({{batch_size, seq_lenth, hidden_size}}); + auto* input_arg1 = builder.MakeInput({{hidden_size}}); + auto* scalar_int_0 = builder.MakeInitializer({}, {0}); + auto* scalar_int_1 = builder.MakeInitializer({}, {1}); + auto* single_value_1d_int_0 = builder.MakeInitializer({1}, {0}); + auto* single_value_1d_int_16 = builder.MakeInitializer({1}, {16}); + auto* single_value_1d_int_64 = builder.MakeInitializer({1}, {64}); + auto* add_out = builder.MakeIntermediate(); + auto* shape_out = builder.MakeIntermediate(); + auto* gather_out_0 = builder.MakeIntermediate(); + auto* gather_out_1 = builder.MakeIntermediate(); + auto* unsqueeze_out_0 = builder.MakeIntermediate(); + auto* unsqueeze_out_1 = builder.MakeIntermediate(); + auto* concattraining1_out = builder.MakeIntermediate(); + auto* concattraining1_length = builder.MakeIntermediate(); + auto* out = builder.MakeOutput(); + + builder.AddNode("Add", {input_arg0, input_arg1}, {add_out}); + auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out}); + shape_1.AddAttribute("start", (int64_t)1); + shape_1.AddAttribute("end", (int64_t)2); + builder.AddNode("Gather", {shape_out, scalar_int_0}, {gather_out_0}); + builder.AddNode("Gather", {shape_out, scalar_int_1}, {gather_out_1}); + builder.AddNode("Unsqueeze", {gather_out_0, single_value_1d_int_0}, {unsqueeze_out_0}); + builder.AddNode("Unsqueeze", {gather_out_1, single_value_1d_int_0}, {unsqueeze_out_1}); + builder.AddNode("ConcatTraining", {unsqueeze_out_0, unsqueeze_out_1, single_value_1d_int_16, single_value_1d_int_64}, + {concattraining1_out, concattraining1_length}, "com.microsoft") + .AddAttribute("axis", static_cast(0)); + builder.AddNode("Reshape", {add_out, concattraining1_out}, {out}); + }; + + std::unique_ptr transformer_no_fuse = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_partial_shape_case, opset, *logger_, std::move(transformer_no_fuse), + TransformerLevel::Level1, 1, pre_graph_checker, pre_graph_checker)); } } } diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 6ce2357f648d8..263003d212b3a 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -9,6 +9,8 @@ #include "gtest/gtest.h" +#include "onnx/defs/schema.h" + #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" @@ -36,6 +38,13 @@ namespace test { #ifndef DISABLE_CONTRIB_OPS +static int GetCurrentOnnxOpset() { + const auto& map = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().Map(); + auto it = map.find(ONNX_NAMESPACE::ONNX_DOMAIN); + EXPECT_TRUE(it != map.end()) << "ONNX domain not found in OpSchemaRegistry"; + return it != map.end() ? it->second.second : 0; +} + TEST_F(GraphTransformationTests, LayerNormFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; std::shared_ptr p_model; @@ -625,6 +634,111 @@ static void TestSkipLayerNormFusion(const std::basic_string& file_pat ASSERT_TRUE(op_to_count["Cast"] == cast_count); } +// Current-opset regression tests for LayerNorm and SkipLayerNorm fusions. +// These construct minimal graphs at the current ONNX opset and verify the optimizer fires. + +TEST_F(GraphTransformationTests, LayerNormFusionCurrentOpsetTest) { + // LayerNorm pattern: ReduceMean -> Sub -> Pow(2) -> ReduceMean -> Add(eps) -> Sqrt -> Div -> Mul(gamma) -> Add(beta) + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + constexpr int64_t hidden_size = 64; + auto* input = builder.MakeInput({{2, 3, hidden_size}}); + auto* gamma = builder.MakeInitializer({hidden_size}, -1.0f, 1.0f); + auto* beta = builder.MakeInitializer({hidden_size}, -0.5f, 0.5f); + auto* two = builder.MakeInitializer({}, {2.0f}); + auto* eps = builder.MakeInitializer({}, {1e-5f}); + + auto* axes = builder.MakeInitializer({1}, {-1}); + auto* mean1_out = builder.MakeIntermediate(); + auto* sub_out = builder.MakeIntermediate(); + auto* pow_out = builder.MakeIntermediate(); + auto* mean2_out = builder.MakeIntermediate(); + auto* add_eps_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* mul_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("ReduceMean", {input, axes}, {mean1_out}); + builder.AddNode("Sub", {input, mean1_out}, {sub_out}); + builder.AddNode("Pow", {sub_out, two}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out, axes}, {mean2_out}); + builder.AddNode("Add", {mean2_out, eps}, {add_eps_out}); + builder.AddNode("Sqrt", {add_eps_out}, {sqrt_out}); + builder.AddNode("Div", {sub_out, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, gamma}, {mul_out}); + builder.AddNode("Add", {mul_out, beta}, {output}); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["LayerNormalization"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["ReduceMean"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Sub"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Pow"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Sqrt"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Div"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "LayerNorm fusion failed at opset ", current_opset, + ". Remaining ops: ReduceMean=", op_to_count["ReduceMean"], + " Sub=", op_to_count["Sub"], + " Pow=", op_to_count["Pow"], + " Sqrt=", op_to_count["Sqrt"], + " Div=", op_to_count["Div"], + ". Either update version lists in " + "onnxruntime/core/optimizer/layer_norm_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + const InlinedHashSet no_limit_empty_ep_list = {}; + // LayerNorm fusion at Level1 when opset >= 17 (ONNX LayerNormalization available). + // At Level2, it skips if fuse_in_level_1 is true. + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(no_limit_empty_ep_list, TransformerLevel::Level1), + TransformerLevel::Level1, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SkipLayerNormFusionCurrentOpsetTest) { + // SkipLayerNorm pattern: Add(input, skip) -> LayerNormalization(gamma, beta) + int current_opset = GetCurrentOnnxOpset(); + auto build_test_case = [](ModelTestBuilder& builder) { + constexpr int64_t hidden_size = 64; + auto* input = builder.MakeInput({{2, 3, hidden_size}}); + auto* skip = builder.MakeInput({{2, 3, hidden_size}}); + auto* gamma = builder.MakeInitializer({hidden_size}, -1.0f, 1.0f); + auto* beta = builder.MakeInitializer({hidden_size}, -0.5f, 0.5f); + + auto* add_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Add", {input, skip}, {add_out}); + builder.AddNode("LayerNormalization", {add_out, gamma, beta}, {output}) + .AddAttribute("axis", static_cast(-1)); + }; + + auto post_graph_checker = [current_opset](Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + if (op_to_count["com.microsoft.SkipLayerNormalization"] == 1) { + TEST_RETURN_IF_NOT(op_to_count["Add"] == 0); + TEST_RETURN_IF_NOT(op_to_count["LayerNormalization"] == 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "SkipLayerNorm fusion failed at opset ", current_opset, + ". Remaining ops: Add=", op_to_count["Add"], + " LayerNormalization=", op_to_count["LayerNormalization"], + ". Either update version lists in " + "onnxruntime/core/optimizer/skip_layer_norm_fusion.cc" + " or skip this opset in the test if the fusion is not expected to apply."); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, current_opset, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1.onnx", 0, 0, 1, 0, logger_.get()); TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2.onnx", 0, 0, 1, 0, logger_.get()); @@ -1279,6 +1393,135 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2) { ASSERT_TRUE(op_to_count["com.microsoft.EmbedLayerNormalization"] == 1); } +// These tests verify that EmbedLayerNorm fusion fires at the CURRENT max ONNX opset. +// When the ONNX opset advances, nodes will report a new SinceVersion(). If the +// optimizer's version lists are not updated, the fusion will fail to match. +// +// To fix: update the EdgeEndToMatch version lists in +// onnxruntime/core/optimizer/embed_layer_norm_fusion.cc + +// Loads a model and upgrades it to the current ONNX opset. Currently handles converting +// Squeeze/Unsqueeze/Reduce* ops from attribute-based axes to input-based (opset 13+/18+). +// Extend this function if additional op conversions are needed for new test models. +static void LoadModelAtCurrentOpset(const ORTCHAR_T* base_model_uri, + std::shared_ptr& p_model, + const logging::Logger& logger) { + int current_opset = GetCurrentOnnxOpset(); + ONNX_NAMESPACE::ModelProto model_proto; + { + std::shared_ptr base_model; + ASSERT_STATUS_OK(Model::Load(base_model_uri, base_model, nullptr, logger)); + model_proto = base_model->ToProto(); + } + for (auto& opset_import : *model_proto.mutable_opset_import()) { + if (opset_import.domain().empty() || opset_import.domain() == "ai.onnx") { + opset_import.set_version(current_opset); + break; + } + } + + // Convert attribute-based Squeeze/Unsqueeze to input-based for opset 13+ compatibility. + // Also convert ReduceSum/ReduceMean axes attribute to input for opset 18+ compatibility. + auto* graph_proto = model_proto.mutable_graph(); + int next_init_id = 0; + for (auto& node : *graph_proto->mutable_node()) { + bool is_squeeze_unsqueeze = (node.op_type() == "Squeeze" || node.op_type() == "Unsqueeze"); + bool is_reduce = (node.op_type() == "ReduceSum" || node.op_type() == "ReduceMean" || + node.op_type() == "ReduceMax" || node.op_type() == "ReduceMin" || + node.op_type() == "ReduceProd"); + if (!is_squeeze_unsqueeze && !is_reduce) { + continue; + } + int axes_attr_idx = -1; + for (int i = 0; i < node.attribute_size(); ++i) { + if (node.attribute(i).name() == "axes") { + axes_attr_idx = i; + break; + } + } + if (axes_attr_idx < 0) { + continue; + } + const auto& axes_attr = node.attribute(axes_attr_idx); + std::string init_name = "__axes_init_" + std::to_string(next_init_id++); + auto* init = graph_proto->add_initializer(); + init->set_name(init_name); + init->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + init->add_dims(axes_attr.ints_size()); + for (int i = 0; i < axes_attr.ints_size(); ++i) { + init->add_int64_data(axes_attr.ints(i)); + } + // For Squeeze/Unsqueeze, axes is input[1]. For Reduce ops, axes is also input[1]. + node.add_input(init_name); + node.mutable_attribute()->SwapElements(axes_attr_idx, node.attribute_size() - 1); + node.mutable_attribute()->RemoveLast(); + } + + ASSERT_STATUS_OK(Model::Load(std::move(model_proto), p_model, nullptr, logger)); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat1CurrentOpset) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format1.onnx"; + std::shared_ptr p_model; + ASSERT_NO_FATAL_FAILURE(LoadModelAtCurrentOpset(model_uri, p_model, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1) + << "EmbedLayerNorm fusion (format 1) failed at opset " << GetCurrentOnnxOpset() << ". " + << "Update version lists in onnxruntime/core/optimizer/embed_layer_norm_fusion.cc " + << "(MatchInputToConcatSubgraph, MatchPositionEmbeddingSubgraph)."; + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Add"], 0); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat2CurrentOpset) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format2.onnx"; + std::shared_ptr p_model; + ASSERT_NO_FATAL_FAILURE(LoadModelAtCurrentOpset(model_uri, p_model, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1) + << "EmbedLayerNorm fusion (format 2: NonZero+Transpose+Squeeze path) failed at opset " + << GetCurrentOnnxOpset() << ". " + << "Update version lists in onnxruntime/core/optimizer/embed_layer_norm_fusion.cc " + << "(parent_path_1, parent_path_2)."; + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); +} + +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat3CurrentOpset) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format3.onnx"; + std::shared_ptr p_model; + ASSERT_NO_FATAL_FAILURE(LoadModelAtCurrentOpset(model_uri, p_model, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.EmbedLayerNormalization"], 1) + << "EmbedLayerNorm fusion (format 3: Range-based path) failed at opset " + << GetCurrentOnnxOpset() << ". " + << "Update version lists in onnxruntime/core/optimizer/embed_layer_norm_fusion.cc " + << "(parent_path_3, parent_path_4)."; + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); +} + static void EmbedLayerNormFusionFormat3(const std::basic_string& file_path, logging::Logger* logger) { std::shared_ptr p_model; ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, *logger).IsOK()); diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc index b330475c01486..3efaba78004d6 100644 --- a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -16,6 +16,7 @@ #include "test/optimizer/webgpu_fusion_test_util.h" #include "gtest/gtest.h" +#include "onnx/defs/schema.h" namespace onnxruntime { namespace test { @@ -258,6 +259,20 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatter TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); } +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPatternCurrentOpset) { + // Uses the current max ONNX opset to catch version-list drift. + // If this fails, update version lists in + // onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc. + const auto& map = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance().Map(); + auto it = map.find(ONNX_NAMESPACE::ONNX_DOMAIN); + ASSERT_TRUE(it != map.end()) << "ONNX domain not found in OpSchemaRegistry"; + int current_opset = it->second.second; + auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/current_opset, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionMatchesUnfusedWebGpuResults) { if (!DefaultWebGpuExecutionProvider()) { GTEST_SKIP() << "WebGPU EP unavailable in this build."; From 9c5e3be2cd6e53c8a89dae2e88baaa51c8cf63fd Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 12 Jun 2026 18:07:20 +0000 Subject: [PATCH 14/15] Add `OrtErrorCode` documentation (#29018) ### Description Add documentation for `OrtErrorCode` enum and its values. ### Motivation and Context Provide some documentation about what the error codes mean. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- include/onnxruntime/core/common/status.h | 6 +- .../core/session/onnxruntime_c_api.h | 55 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index 8cf6420f2d0f7..eee75d399d767 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -29,8 +29,10 @@ enum StatusCategory { }; /** - Error code for ONNXRuntime. -*/ + * Error code for ONNXRuntime. + * + * These values must stay in sync with the public C API OrtErrorCode enum values. + */ enum StatusCode { OK = 0, FAIL = 1, diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 0289bca5c84d2..edb42c0a2f596 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -263,21 +263,76 @@ typedef enum OrtLoggingLevel { ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe). } OrtLoggingLevel; +/** \brief Error codes reported by ONNX Runtime. + * + * The error code associated with an ::OrtStatus. + */ typedef enum OrtErrorCode { + /** + * Success. No error occurred. + */ ORT_OK, + /** + * Generic failure that does not map to a more specific error code. Consult the error message for details. + */ ORT_FAIL, + /** + * A caller-supplied argument was invalid (e.g. NULL pointer, out-of-range value, mismatched shape/rank, or bad + * configuration). + */ ORT_INVALID_ARGUMENT, + /** + * A required file (such as a model file) does not exist. + */ ORT_NO_SUCHFILE, + /** + * Legacy/unused but retained for ABI compatibility. Historically returned when a model could not be found by name in + * the ONNX Runtime Server (removed in 2022). + */ ORT_NO_MODEL, + /** + * A hardware accelerator or backend engine reported a failure (e.g. a device crash or other device-level error). + */ ORT_ENGINE_ERROR, + /** + * A generic runtime exception was caught. The error message is the primary source of detail. + */ ORT_RUNTIME_EXCEPTION, + /** + * Protobuf parsing or serialization failed. + */ ORT_INVALID_PROTOBUF, + /** + * Invalid session state for the requested operation. Despite the name, this code does not mean "success, model + * loaded"; it is returned when the session is in the wrong state for the requested call (e.g. a model is already + * loaded, the session is already initialized, or no model has been loaded yet). The name is historical and is + * retained for ABI compatibility; consult the error message for the specific condition. + */ ORT_MODEL_LOADED, + /** + * The requested functionality is not implemented in this build. + */ ORT_NOT_IMPLEMENTED, + /** + * The model graph is structurally invalid (e.g. recursive function definitions, invalid tensor dimensions, or + * malformed nodes). + */ ORT_INVALID_GRAPH, + /** + * An execution provider reported a generic failure. + */ ORT_EP_FAIL, + /** + * Model loading or session initialization was canceled at the caller's request. + */ ORT_MODEL_LOAD_CANCELED, + /** + * The model requires compilation by an execution provider, but compilation was disabled via session options. + */ ORT_MODEL_REQUIRES_COMPILATION, + /** + * A requested resource could not be found. + */ ORT_NOT_FOUND, } OrtErrorCode; From 29616be9d03e2eb6c02d4ed98d99ca915243dcf4 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 12 Jun 2026 13:20:14 -0700 Subject: [PATCH 15/15] Implement Name Based partitioning and update documents (#28903) This pull request introduces a new "name-based layer assignment" feature for ONNX Runtime, allowing device assignment of model nodes based on substring matching of node names, as an alternative to annotation-based matching. The implementation ensures that name-based and annotation-based assignment modes are mutually exclusive, and updates documentation, configuration keys, and core logic to support this new capability. **Key changes:** #### Name-based Layer Assignment Feature - Added support for a new session option, `session.name_based_layer_assignment`, which enables device assignment using substring matching against node names (rather than metadata annotations). The longest matching pattern wins, and all patterns are treated as substrings (the `=` exact-match qualifier is disallowed in this mode). [[1]](diffhunk://#diff-10b3051b9e36eccfc7ca0f2d44ce78a9980ca573cde0f931ffd1456da2c681daR181-R218) [[2]](diffhunk://#diff-62d211e77c575a2fec6492c9fbfe25743fac9d6d72be7c007e7f6eb8dbecc7e7R423-R437) - Implemented the `SubstringMatcher` class and integrated it into the `LayeringIndex` logic, enabling efficient substring-based rule matching for node assignment. The matcher sorts patterns by length and returns the first (longest) match. [[1]](diffhunk://#diff-a8f614056d63b5b3325eea1d855afc96550c977c16d8fdba641012a79194b7b5R604-R627) [[2]](diffhunk://#diff-b64b395fbb0afa67dcf493f97493f3df353486c3c8344df89f25df057ca3840fR88-R116) #### Mutual Exclusivity and API Changes - Enforced mutual exclusivity between annotation-based (`session.layer_assignment_settings`) and name-based (`session.name_based_layer_assignment`) assignment options. Attempting to set both now results in an error, with clear messaging. [[1]](diffhunk://#diff-a8f614056d63b5b3325eea1d855afc96550c977c16d8fdba641012a79194b7b5L338-L359) [[2]](diffhunk://#diff-b64b395fbb0afa67dcf493f97493f3df353486c3c8344df89f25df057ca3840fR159-R177) - Updated the `LayeringIndex::Create` API and related logic to accept both configuration strings, select the active mode, and construct the appropriate matcher. [[1]](diffhunk://#diff-a8f614056d63b5b3325eea1d855afc96550c977c16d8fdba641012a79194b7b5L338-L359) [[2]](diffhunk://#diff-b64b395fbb0afa67dcf493f97493f3df353486c3c8344df89f25df057ca3840fR186) #### Documentation Updates - Expanded the documentation to describe the new name-based assignment feature, provide usage examples, highlight best practices for pattern writing, and explain the mutual exclusivity and lack of subgraph inheritance in name-based mode. [[1]](diffhunk://#diff-10b3051b9e36eccfc7ca0f2d44ce78a9980ca573cde0f931ffd1456da2c681daR181-R218) [[2]](diffhunk://#diff-10b3051b9e36eccfc7ca0f2d44ce78a9980ca573cde0f931ffd1456da2c681daL295-R357) #### Core Logic and Maintenance - Refactored the `LayeringIndex` and related methods to support both matching modes, updating node assignment and update logic to branch appropriately based on the selected mode. [[1]](diffhunk://#diff-a8f614056d63b5b3325eea1d855afc96550c977c16d8fdba641012a79194b7b5L426-R485) [[2]](diffhunk://#diff-a8f614056d63b5b3325eea1d855afc96550c977c16d8fdba641012a79194b7b5R537-R548) - Added and documented the new configuration key `kOrtSessionOptionsNameBasedLayerAssignment` in the public API headers. These changes provide a more flexible and user-friendly way to partition models across devices, especially for models with structured node names but lacking explicit annotations. --- ...ningWithAnnotationsAndMemoryConstraints.md | 56 +- .../cuda_kernel_workspace_inventory.md | 432 ++++++ .../future_directions_constrained_env.md | 1338 +++++++++++++++++ .../onnxruntime_session_options_config_keys.h | 15 + .../core/framework/layering_annotations.cc | 161 +- .../core/framework/layering_annotations.h | 57 +- onnxruntime/core/session/inference_session.cc | 5 +- .../framework/layering_annotations_test.cc | 259 +++- .../test/framework/session_state_test.cc | 2 +- 9 files changed, 2271 insertions(+), 54 deletions(-) create mode 100644 docs/annotated_partitioning/cuda_kernel_workspace_inventory.md create mode 100644 docs/annotated_partitioning/future_directions_constrained_env.md diff --git a/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md b/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md index 34092fe9a0307..60735119fcf1e 100644 --- a/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md +++ b/docs/annotated_partitioning/PartitioningWithAnnotationsAndMemoryConstraints.md @@ -178,6 +178,44 @@ Nodes that do not match any rule fall through to the normal EP capability-based > **Note — Annotations vs. actual placement:** An annotation expresses a *preference*, not a guarantee. If the target EP does not have a registered kernel for a node (for example, a particular data-type / opset-version combination is not implemented in the CUDA EP), that node will not be placed on the requested device. Instead it falls through to the next EP in the provider list that can handle it. +### Name-Based Layer Assignment (No Model Modification) + +For models that already have structured node names (most HuggingFace exports, ONNX models produced by PyTorch, etc.), you can skip the annotation step entirely. The session option `session.name_based_layer_assignment` performs **substring matching** directly against `Node::Name()`: + +``` +device1(pattern1, pattern2, ...); device2(pattern3, pattern4, ...) +``` + +- **Substring matching:** A pattern matches if it appears *anywhere* in the node name. For example, `layers.0/` matches `/model/layers.0/self_attn/q_proj/MatMul`. +- **Longest match wins:** When multiple patterns match the same node name, the longest pattern takes priority. For example, `layers.10/` wins over `layers.1/` for a node named `/model/layers.10/...`. +- **No `=` prefix:** The exact-match qualifier (`=`) from annotation-based syntax is rejected with an error. All patterns are treated as substrings. +- **Same device designators:** The device portion uses the same device designators as `session.layer_assignment_settings` (see table above). + +```python +import onnxruntime as ort + +opts = ort.SessionOptions() + +# Assign layers 0–7 to GPU, layers 8–15 to CPU based on node names +opts.add_session_config_entry( + "session.name_based_layer_assignment", + "gpu(layers.0/, layers.1/, layers.2/, layers.3/, layers.4/, layers.5/, layers.6/, layers.7/); " + "cpu(layers.8/, layers.9/, layers.10/, layers.11/, layers.12/, layers.13/, layers.14/, layers.15/)" +) + +session = ort.InferenceSession("model.onnx", opts, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) +``` + +**Tips for writing patterns:** +- Include the trailing `/` in layer patterns (e.g., `layers.1/` instead of `layers.1`) to avoid `layers.1` accidentally matching `layers.10`, `layers.11`, etc. +- Use [Netron](https://netron.app/) to inspect your model's node names and identify suitable substrings. +- Nodes that do not match any pattern fall through to normal EP capability-based assignment (typically CPU). + +**Mutual exclusivity with annotation-based matching:** The `session.name_based_layer_assignment` and `session.layer_assignment_settings` options are **mutually exclusive** — setting both will return an error. Use annotation-based matching for models that carry explicit `layer_ann` metadata annotations, or name-based matching for unmodified models with structured node names. If you need fine-grained exceptions (e.g., force one specific node to CPU), add the node's name pattern to the name-based config instead of mixing the two approaches. + +**No subgraph inheritance:** Unlike annotation-based matching (where unannotated subgraph nodes inherit their parent's device assignment), name-based matching treats every node independently. Since node names are dense (virtually every node has a name encoding its structural position), inheritance is unnecessary — each node matches on its own name. + ## Capacity-Aware Partitioning (implemented for CUDA) When running models on a CUDA GPU with limited memory, you can set a memory budget so ONNX Runtime stops assigning nodes to the CUDA EP once the estimated memory consumption reaches the limit. Nodes are considered in topological order and assignment halts at the first node that would exceed the budget — ONNX Runtime does not search ahead for smaller nodes that might still fit. Remaining nodes are then eligible for assignment by the subsequent EPs in the session's provider list (often CPU, but not necessarily). @@ -292,26 +330,30 @@ EPs that prefer the NHWC data layout — for example, the CUDA EP when it is cre Because the first-pass tags are tentative, ONNX Runtime does **not** commit any memory budget for them. The budget is committed only for the nodes that survive the second pass; the cost of a node that is dropped is never counted against the memory limit. This keeps the accumulated memory estimate accurate when `prefer_nhwc` is combined with `session.resource_cuda_partitioning_settings`, so a dropped node does not consume phantom budget that could prematurely halt assignment of later nodes. -## Combining Both Features -Layer annotations and capacity-aware partitioning can be used together. When both are configured: -- Layer annotations provide the initial node-to-device mapping. +## Combining Features +Layer annotations OR name-based assignment can be combined with capacity-aware partitioning. Note that annotation-based and name-based matching are **mutually exclusive** — you cannot use both simultaneously. + +When a layer assignment option (either annotation-based or name-based) is configured together with the capacity-aware partitioner: +- The layer assignment option expresses the desired device placement. - The capacity-aware partitioner enforces the memory budget, potentially overriding assignments that would exceed the GPU memory limit. -This combination gives you fine-grained control: use annotations to express logical model structure, and let the memory budget act as a safety net. +This gives you fine-grained control: use annotations or name patterns to express logical model structure, and let the memory budget act as a safety net. ```python opts = ort.SessionOptions() +# Name-based assignment (no model modification needed) opts.add_session_config_entry( - "session.layer_assignment_settings", - "gpu(encoder, decoder); cpu(=postprocess)" + "session.name_based_layer_assignment", + "gpu(layers.0/, layers.1/, layers.2/, layers.3/); cpu(layers.4/, layers.5/, layers.6/, layers.7/)" ) +# Memory budget as a safety net opts.add_session_config_entry( "session.resource_cuda_partitioning_settings", "4194304,node_memory_stats.csv" ) -session = ort.InferenceSession("model_annotated.onnx", opts, +session = ort.InferenceSession("model.onnx", opts, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) ``` diff --git a/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md b/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md new file mode 100644 index 0000000000000..5d5b0136fe2ea --- /dev/null +++ b/docs/annotated_partitioning/cuda_kernel_workspace_inventory.md @@ -0,0 +1,432 @@ +# CUDA Kernel Workspace Buffer Inventory + +This document catalogs all CUDA kernels in ONNX Runtime that allocate temporary/workspace buffers via `GetScratchBuffer` or `GetTransientScratchBuffer`. For each kernel, it identifies what information is needed to compute the workspace size and whether that information is available at `GetCapability()` time (for the workspace estimation function design). + +## Classification Key + +| Symbol | Meaning | +|--------|---------| +| ✅ | Fully determinable from shapes + attributes + device properties | +| ✅* | Determinable via cuDNN/cuBLAS API call (needs handle, available on EP) | +| ⚠️ | Requires profiling/tactic selection (deterministic but costly at planning time) | + +--- + +## Core CUDA Providers (`onnxruntime/core/providers/cuda/`) + +### 1. Attention (LLM — Opset 23/24) + +**File:** `llm/attention.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `softmax_lse_buffer` | `B * S_q * num_heads * sizeof(float)` | shapes | +| `softmax_lse_accum_buffer` | From `get_num_splits_and_buffer_sizes()` | shapes + `multiProcessorCount` | +| `out_accum_buffer` | From `get_num_splits_and_buffer_sizes()` | shapes + `multiProcessorCount` | +| `q_bsnh_buffer` | `B * S_q * num_heads * head_size * element_size` | shapes + dtype | +| `out_bsnh_buffer` | Same as Q | shapes + dtype | +| `k_bsnh_buffer` / `v_bsnh_buffer` | `B * S_kv * num_kv_heads * head_size * element_size` | shapes + dtype | +| `seqlens_k_buffer` | `B * sizeof(int)` | batch size | +| `past_seqlens_buffer` | `B * sizeof(int)` | batch size | +| `k_expand_buffer` / `v_expand_buffer` | `B * num_heads * S_kv * head_size * element_size` (GQA expansion) | shapes + dtype | +| `converted_mask_buffer` | `B * S_q * S_kv * sizeof(float)` | shapes | +| `present_k_scratch` / `present_v_scratch` | present KV cache size | shapes | +| `workspace_buffer` (math attention) | `B * S_q * num_heads * sizeof(float)` | shapes | + +**What's needed to compute:** Input shapes, `num_heads` attribute, `head_size`, dtype, `device_prop.multiProcessorCount`. + +**Static determinability:** ✅ All pure arithmetic on shapes + device SM count. + +--- + +### 2. Conv (cuDNN Frontend) + +**File:** `nn/conv.cc`, `nn/conv.h` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace` | `s_.cudnn_fe_graph->get_workspace_size()` | cuDNN plan selection | +| `memory_for_cudnn_conv_results` (Conv8 only) | `y_dims_with_adjusted_pads.Size() * element_size` | output shape + padding | + +**What's needed to compute:** Input shapes (NCHW), weight shapes, pads/strides/dilations attributes, cuDNN handle (for `build_plans()`). + +**Static determinability:** ✅* — Requires calling `build_plans(handle)` with `HEUR_MODE_A`. The handle is on the EP. All tensor shapes and conv params come from node attributes. + +**Note:** `GetTransientScratchBuffer` (32MB) used for algorithm search in Conv8 — this is a one-time cost during first Compute, not a per-run workspace. + +--- + +### 3. ConvTranspose + +**File:** `nn/conv_transpose.h`, `nn/conv_transpose_8.h` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace` | `s_.workspace_bytes` (from cuDNN FE or algo selection) | Same as Conv | +| `AlgoSearchWorkspaceSize` (Conv8 path) | 32MB constant | N/A | + +**Static determinability:** ✅* — Same as Conv. + +--- + +### 4. DeformConv + +**File:** `nn/deform_conv.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `col_buffer` | `C * kernel_size * col_stride * sizeof(T)` where `col_stride = n_parallel_imgs * output_image_size` | Input shapes, kernel_size, `device_prop.totalGlobalMem` (for chunk sizing) | + +**What's needed:** Input shape (N,C,H,W), kernel dims, output_image_size, `totalGlobalMem` (determines `n_parallel_imgs` via `GetNParallelImgs`). + +**Static determinability:** ✅ — Pure arithmetic on shapes + device memory size. + +--- + +### 5. BatchNorm + +**File:** `nn/batch_norm.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `f_scale`, `f_B`, `f_mean`, `f_var` | `C * sizeof(float)` each | Channel dim (shape[1] or shape[3]) | + +**What's needed:** Channel dimension `C` from input shape. + +**Static determinability:** ✅ — Trivial: `4 * C * sizeof(float)`. + +--- + +### 6. InstanceNorm + +**File:** `nn/instance_norm.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `mean`, `variance` | `N * C * sizeof(T)` | Batch × channels | +| `unused_scale`, `unused_bias` | `N * C * sizeof(T)` | Same | +| `scale_data_fp32`, `bias_data_fp32` | `C * sizeof(float)` (if fp16) | Channel dim + dtype | + +**What's needed:** Input shape (N, C), dtype. + +**Static determinability:** ✅ — Pure arithmetic on shapes. + +--- + +### 7. Dropout + +**File:** `nn/dropout.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| mask buffer | `element_count * sizeof(bool)` or `element_count / 16 * sizeof(BitmaskElementType)` | Input element count, bitmask mode | + +**Static determinability:** ✅ — Input element count. + +--- + +### 8. Reduction Ops (ReduceSum, ReduceMax, ReduceMean, etc.) + +**File:** `reduction/reduction_ops.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace_cuda` | `cudnnGetReductionWorkspaceSize()` | cuDNN handle, input/output tensor descriptors | +| `indices_cuda` | `cudnnGetReductionIndicesSize()` | Same | +| `temp_X` | `input_count * sizeof(float)` (type cast) | Input size | +| `input_data_buffer` | `input_count * sizeof(T)` (for `calculate_sqt_`) | Input size | +| `exp_result_buffer` | `input_count * sizeof(T)` (for `log_sum_exp_`) | Input size | +| `log_sum_result_buffer` | `output_count * sizeof(T)` | Output size | +| `temp_output` | `output_count * sizeof(float)` | Output size | + +**What's needed:** Input/output shapes, reduction axes, op variant (LogSumExp, L2, etc.), cuDNN handle. + +**Static determinability:** ✅* — cuDNN workspace query needs handle + tensor descriptors (constructible from shapes). + +--- + +### 9. RNN (LSTM/GRU) + +**File:** `rnn/cudnn_rnn_base.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace_cuda` | `cudnnGetRNNTempSpaceSizes(fwdInference)` | RNN descriptor, seq_length, batch_size | +| `reservespace_cuda` | `cudnnGetRNNTempSpaceSizes(training)` | Same | +| `reorganized_w_data` | `w_size * sizeof(T)` | hidden_size, num_layers, input_size, direction | +| `x_reversed_data` | `seq_length * batch_size * input_size * sizeof(T)` | Shapes (bidirectional case) | +| `y_alloc_data` | `output_size * sizeof(T)` | Shapes | +| `state_buffer_` | RNN state size from cuDNN | cuDNN descriptor | + +**What's needed:** seq_length, batch_size, input_size, hidden_size, num_layers, direction attribute, cuDNN handle. + +**Static determinability:** ✅* — cuDNN API queries, all inputs available from node attributes/shapes. + +--- + +### 10. TopK + +**File:** `math/topk_impl.cuh` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `input_key_buffer` | `dimension * sizeof(T)` | Last-axis dimension | +| `output_key_buffer` | `dimension * sizeof(T)` | Same | +| `input_value_buffer` | `dimension * sizeof(int64_t)` | Same | +| `output_value_buffer` | `dimension * sizeof(int64_t)` | Same | +| `temp_storage_buffer` | From `cub::DeviceRadixSort::SortPairs` query | dimension | + +**What's needed:** Dimension (last axis size), k, dtype. + +**Static determinability:** ✅ — CUB temp storage query is deterministic given size. + +--- + +### 11. MatMulInteger + +**File:** `math/matmul_integer.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `a_row_buf` | `(output_size / N) * sizeof(int32_t)` | M dimension | +| `b_col_buf` | `(output_size / M) * sizeof(int32_t)` | N dimension | + +**What's needed:** M, N dimensions from MatMul shapes. + +**Static determinability:** ✅ — Pure arithmetic. + +--- + +### 12. IntegerGemm (int8 alignment padding) + +**File:** `integer_gemm.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `a_padded` | `m * roundoff(lda, 32) * sizeof(int8_t)` (only if lda not 32-aligned) | M, K dims | +| `b_padded` | `k * roundoff(ldb, 32) * sizeof(int8_t)` (only if ldb not 32-aligned) | K, N dims | + +**What's needed:** M, K, N dimensions + alignment check. + +**Static determinability:** ✅ — Pure arithmetic. + +--- + +### 13. Compress + +**File:** `tensor/compress.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `condition_cumulative_sum_buffer` | `valid_condition_length * sizeof(int32_t)` | Condition tensor size | +| `temp_buffer` | CUB `DeviceScan::InclusiveSum` temp storage | Condition size | + +**Static determinability:** ✅ — Condition shape determines everything. + +--- + +### 14. GatherND + +**File:** `tensor/gather_nd.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `sizes_from_slice_dims_buffer` | `num_slice_dims * sizeof(int64_t)` | Indices shape | +| `input_slice_offsets_buffer` | `num_slices * sizeof(int64_t)` | Indices shape[:-1] product | + +**Static determinability:** ✅ — Indices shape. + +--- + +### 15. NonZero + +**File:** `tensor/nonzero_op.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `prefix_buffer` | `number_of_blocks * sizeof(int)` | Input element count / block_size | +| `temp_buffer` | CUB temp storage | Input element count | + +**Static determinability:** ✅ — Input element count. + +--- + +### 16. Upsample/Resize + +**File:** `tensor/upsample.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| temp buffer (via lambda) | Varies by resize mode | Input/output shapes, mode | +| `dims_mapping_buffer` | `temp_buffer_size` (coordinate mapping) | Output shape | + +**Static determinability:** ✅ — Input/output shapes + mode attribute. + +--- + +### 17. NonMaxSuppression + +**File:** `object_detection/non_max_suppression.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| Various (via lambda) | Determined by CUB DeviceSelect internals | num_boxes, num_classes | + +**What's needed:** boxes shape (num_batches, num_boxes, 4), scores shape. + +**Static determinability:** ✅ — CUB queries are deterministic given sizes. + +--- + +## Contrib CUDA Operators (`onnxruntime/contrib_ops/cuda/`) + +### 18. Attention / MultiHeadAttention (Contrib) + +**File:** `bert/attention.cc`, `bert/multihead_attention.cc` + +**Buffers:** Uses `GetAttentionWorkspaceSize()` helper function. + +**Size formula:** Depends on attention algorithm (Flash, MemoryEfficient, FusedRunner, Default): +- Flash: `qkv_size` (Q+K+V projection) +- MemoryEfficient: `qkv_size + output_accum (float)` +- Default (unfused): `qkv_size + 2 * attention_scratch_size` + +**What's needed:** B, S_q, S_kv, num_heads, head_size, dtype, which attention algorithm is selected. + +**Static determinability:** ✅ — Algorithm selection depends on shapes + SM version (available from device_prop). + +--- + +### 19. MOE (Mixture of Experts) + +**File:** `moe/moe.cc`, `moe/moe_quantization.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace` | `moe_runner->getWorkspaceSize(num_rows, hidden, inter, experts, k)` | Shapes + tactic | +| `expert_scales` | `num_rows * k * sizeof(float)` | Shapes | +| `expert_indices` | `num_rows * k * sizeof(int)` | Shapes | +| `permutation_row_map` | `num_rows * k * sizeof(int)` | Shapes | + +**What's needed:** num_rows, hidden_size, inter_size, num_experts, k, activation_type, SM version, selected tactic. + +**Static determinability:** ⚠️ — `getWorkspaceSize()` depends on profiled best tactic (CUTLASS config). Could use worst-case across tactics as upper bound. + +--- + +### 20. MatMulNBits (Quantized MatMul) + +**File:** `quantization/matmul_nbits.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `workspace_buffer` | `weightOnlyGemmRunner_->getWorkspaceSize(m, n, k)` | Dims + tactic | +| `packed_transposed_weight_space` | `packed_weight_bytes` (transient) | Weight shape | +| `permutation_map_buffer` | `32 * sizeof(int32_t)` (transient) | Constant | + +**What's needed:** M, N, K dimensions, quantization bits, SM version. + +**Static determinability:** ⚠️ — Runner workspace depends on profiled tactic. Could use upper bound. + +--- + +### 21. fpA_intB GEMM (FP×INT quantized) + +**File:** `llm/fpA_intB_gemm/` + +**Buffers:** `virtual size_t getWorkspaceSize(m, n, k)`. + +**What's needed:** M, N, K + CUTLASS template specialization. + +**Static determinability:** ⚠️ — Depends on selected CUTLASS config/tactic. + +--- + +### 22. Inverse (Matrix Inversion) + +**File:** `inverse.cc` + +**Buffers allocated:** +| Buffer | Size formula | Depends on | +|--------|--------------|-----------| +| `input_workspace` | `input_count * sizeof(T)` | Matrix dimensions | +| `matrix_ptrs` | `n_batches * sizeof(T*)` | Batch count | +| `output_ptrs` | `n_batches * sizeof(T*)` | Batch count | +| `ml_float_output` | `input_count * sizeof(float)` (if fp16→fp32) | Dims + dtype | + +**Static determinability:** ✅ — Pure arithmetic on matrix dimensions. + +--- + +### 23. Generation (Beam Search / Sampling) + +**File:** `transformers/generation_device_helper.cc` + +**Buffers:** Various pinned + device buffers for beam state. + +**What's needed:** batch_size, beam_width, max_length, vocab_size. + +**Static determinability:** ✅ — All from session/model config. + +--- + +## Summary: Coverage Analysis + +### Workspace estimation approach validation + +| Category | # Kernels | Estimation feasibility | Notes | +|----------|-----------|----------------------|-------| +| **Shapes only** | 12 | ✅ Exact, trivial | BatchNorm, InstanceNorm, Dropout, TopK, MatMulInteger, IntegerGemm, Compress, GatherND, NonZero, Upsample, Inverse, Generation | +| **Shapes + device properties** | 3 | ✅ Exact | Attention (SM count), DeformConv (totalGlobalMem), Contrib Attention (SM version) | +| **Shapes + cuDNN/cuBLAS handle** | 4 | ✅* Exact via API query | Conv, ConvTranspose, Reduction, RNN | +| **Shapes + tactic profiling** | 3 | ⚠️ Upper bound only | MOE, MatMulNBits, fpA_intB_GEMM | + +### Key takeaways + +1. **~75% of kernels** (19/25) can produce **exact** workspace estimates at `GetCapability()` time using only shapes + attributes + device properties (+ cuDNN handle for API queries). + +2. **~12% of kernels** (3/25) require tactic profiling (CUTLASS/CUB autotuning). For these, options are: + - Use worst-case workspace across all tactics (safe upper bound) + - Run tactic selection eagerly at estimation time (expensive but exact) + - Accept 1.5x multiplier for these few kernels + +3. **The cuDNN handle requirement** affects only 4 kernel types (Conv, ConvTranspose, Reduction, RNN). All are standard cuDNN API queries that are fast and deterministic given the handle + tensor descriptors. + +4. **No kernel requires actual GPU execution** to determine workspace size — even tactic-based kernels select tactics via CPU-side profiling/heuristics, not by running GPU code. + +5. **Largest workspace consumers** in practice: + - **Attention** (Flash): dominates in LLM workloads. Exact estimation possible. + - **Conv** (cuDNN): dominates in vision workloads. Exact via `build_plans()`. + - **MOE**: significant in MoE models. Upper bound via worst-case tactic. + +### What the estimation function needs access to (API requirements): + +| Access needed | How accessed | Kernels that require it | +|---------------|-------------|------------------------| +| `Node_GetInputShape()` | OrtEpApi (generic) | All 25 kernels | +| `Node_GetAttributeInt/Ints()` | OrtEpApi (generic) | Conv, Attention, RNN, MOE | +| `device_prop.multiProcessorCount` | Cast `OrtEp*` to concrete EP type | Attention, DeformConv | +| `device_prop.totalGlobalMem` | Cast `OrtEp*` to concrete EP type | DeformConv | +| cuDNN handle | Cast `OrtEp*` to concrete EP type | Conv, ConvTranspose, Reduction, RNN | +| Tactic profiler state (or worst-case constant) | Cast `OrtEp*` to concrete EP type | MOE, MatMulNBits, fpA_intB | + +**API surface:** Only `Node_GetInputShape` and `Node_GetAttributeInt/Ints` need to be added to `OrtEpApi` (generic, EP-agnostic). All device-specific state (cuDNN handles, device properties, profiler state) is accessed by casting `OrtEp*` to the EP's concrete type — no public API needed since the estimation function is EP-specific code. diff --git a/docs/annotated_partitioning/future_directions_constrained_env.md b/docs/annotated_partitioning/future_directions_constrained_env.md new file mode 100644 index 0000000000000..ab33a6595baf3 --- /dev/null +++ b/docs/annotated_partitioning/future_directions_constrained_env.md @@ -0,0 +1,1338 @@ +# Future Directions: Constrained Environment Partitioning + +## Context + +Today's annotation-based partitioning requires each node to carry a `layer_ann` metadata property. The `LayeringIndex` matches these annotations against user-supplied rules (prefix trie + exact match) to assign nodes to devices. The `IResourceAccountant` optionally enforces memory budgets. + +The goal: make ORT as easy to use as ollama for running large models on machines with limited GPU memory — automatic or near-automatic layer offloading without requiring model producers to annotate every node. + +--- + +## Direction 1: Name-Based Substring Matching (No Annotation Step) + +### Idea + +Skip the annotation metadata entirely. Instead, match directly against **node names** using substrings/patterns from the configuration. MS Foundry models (and most HuggingFace exports) already encode layer structure in node names: + +``` +/model/layers.0/self_attn/q_proj/MatMul +/model/layers.0/self_attn/k_proj/MatMul +/model/layers.15/mlp/gate_proj/MatMul +/model/embed_tokens/Gather +/model/norm/LayerNormalization +``` + +A config like `gpu(layers.0, layers.1, ..., layers.15); cpu(layers.16, ..., layers.31)` would partition without any model modification. + +### How to Approach + +1. **Add a new `SubstringMatcher` for node-name matching.** Today the `LayeringRuleMatcher` supports exact match and prefix match (via a trie that walks from position 0 of the input string). Neither mode works for node names: a node named `/model/layers.5/self_attn/q_proj/MatMul` does not *start with* `layers.5` — the identifying substring appears in the middle. Name-based matching fundamentally requires **substring** search. The existing trie infrastructure is irrelevant here — a new, simpler matching approach is needed (see "Substring Matching Implementation" below). + +2. **Config via a separate session option (same grammar, different matcher).** Rather than introducing a new qualifier into the existing `kOrtSessionOptionsLayerAssignmentSettings` syntax, add a parallel session option that uses **the same `device(pattern1, pattern2, ...); ...` grammar** but performs **substring matching** against `Node::Name()` instead of prefix/exact matching against node metadata: + + ```cpp + // Existing (annotation-based, matches node metadata 'layer_ann'): + static const char* const kOrtSessionOptionsLayerAssignmentSettings = + "session.layer_assignment_settings"; + + // NEW (name-based, matches Node::Name() via substring): + static const char* const kOrtSessionOptionsNameBasedLayerAssignment = + "session.name_based_layer_assignment"; + ``` + + Usage stays identical — only the matching target and algorithm differ: + ``` + # Annotation-based (existing, prefix/exact match against node metadata): + session.layer_assignment_settings = "cuda(encoder_layer, attention); cpu(embed)" + + # Name-based (new, substring match against Node::Name()): + session.name_based_layer_assignment = "cuda(layers.0/, layers.1/); cpu(layers.16/)" + + # Range expressions (future extension, not currently supported): + session.name_based_layer_assignment = "cuda(layers.[0-15]); cpu(layers.[16-31])" + ``` + + This approach: + - Keeps the existing parser/grammar unchanged (reuse the `device(pattern1, pattern2, ...); ...` syntax) + - Uses a **new `SubstringMatcher`** (not the existing trie-based `LayeringRuleMatcher`) for the actual matching + - Makes intent explicit — users opt into name-based matching deliberately + - The two options are **mutually exclusive** — setting both returns an error + - No risk of breaking existing annotation-based workflows + +3. **Build index at load time.** During `InferenceSession::Initialize()`, after graph is loaded but before partitioning: + - If config contains name-based rules, iterate all nodes once + - Build `NodeIndex → RuleIndex` map using substring matching on `Node::Name()` + - Feed this into the existing `LayeringIndex` infrastructure (same downstream flow) + +4. **Range expressions (future extension).** The config grammar does **not** support range syntax today. For transformer models with numbered layers, a future extension could add range support: + ``` + cuda(layers.[0-15]); cpu(layers.[16-31]) + ``` + This would avoid enumerating 32+ layer prefixes manually, but requires new parsing logic. Until then, users must enumerate each layer prefix explicitly or use a broad prefix like `layers.` that captures all layers for a single device. + +### Substring Matching Implementation + +The existing `LayeringRuleMatcher` uses a **trie** for prefix matching — it walks the input string from position 0 and checks if any trie path matches a prefix of the input. This only works when the pattern appears at the **start** of the matched string. + +For node names, patterns appear in the **middle**: +``` +Pattern: "layers.5" +Node name: "/model/layers.5/self_attn/q_proj/MatMul" + ^^^^^^^^ — match at position 7, not position 0 +``` + +The trie is useless here. A new `SubstringMatcher` class is needed. + +#### Design: Flat vector + `std::string::find` + +The simplest correct approach: + +```cpp +class SubstringMatcher { + public: + explicit SubstringMatcher(const LayeringRules& rules); + + /// Returns the index of the best matching rule for the given node name. + /// "Best" = longest pattern that appears as a substring in the name. + std::optional Match(std::string_view node_name) const; + + private: + // Sorted by pattern length descending — longest patterns checked first. + // First match wins (longest-match priority). + struct PatternEntry { + std::string pattern; + size_t rule_index; + }; + InlinedVector patterns_; // sorted longest-first +}; +``` + +**Match algorithm:** +```cpp +std::optional SubstringMatcher::Match(std::string_view node_name) const { + for (const auto& entry : patterns_) { + if (node_name.find(entry.pattern) != std::string_view::npos) { + return entry.rule_index; + } + } + return std::nullopt; +} +``` + +**Why longest-match-first ordering:** + +Without it, `layers.1` (a substring of `layers.10`, `layers.11`, ..., `layers.19`) would incorrectly match nodes from layers 10–19. By checking longer patterns first, `layers.10` matches before `layers.1` gets a chance. Users should include the path separator for unambiguous matching: `layers.1/` won't match `layers.10/...`. + +**Performance:** With ~64 patterns and node names < 200 chars, this is O(P × N) per node where P = number of patterns and N = name length. Total cost for a 1000-node model: ~64 × 200 × 1000 = ~12M character comparisons. This completes in microseconds on modern hardware and runs only once during `Initialize()`. No optimization (Aho-Corasick, etc.) is warranted. + +**Priority semantics:** + +| Scenario | Behavior | +|----------|----------| +| Single match | Return that rule's index | +| Multiple matches (different lengths) | Longest pattern wins | +| Multiple matches (same length, different rules) | First rule in config order wins (stable sort by length, preserving config order as tiebreaker) | +| No match | Return `nullopt` → node goes to fallback EP (CPU) | + +**Integration with `LayeringIndex`:** + +`LayeringIndex` owns either a `LayeringRuleMatcher` (annotation mode) or a `SubstringMatcher` (name-based mode) — the two are mutually exclusive. The `ProcessGraph` method branches based on which mode is active: + +```cpp +void LayeringIndex::ProcessGraph(const Graph& graph, std::optional parent_layer_id) { + for (const auto& node : graph.Nodes()) { + std::optional matched_rule_idx; + + if (substring_matcher_) { + // Name-based mode: substring match against node name, no inheritance. + // Node names are dense, so each node is matched independently. + matched_rule_idx = substring_matcher_->Match(node.Name()); + } else { + // Annotation-based mode: prefix/exact match against metadata, + // with subgraph inheritance for unannotated nodes. + const std::string& annotation = node.GetLayeringAnnotation(); + if (!annotation.empty()) { + matched_rule_idx = matcher_.Match(annotation); + } + if (!matched_rule_idx && parent_layer_id) { + matched_rule_idx = parent_layer_id; + } + } + + if (matched_rule_idx) { + node_to_layering_index_[node.Index()] = *matched_rule_idx; + } + } +} +``` + +**Why mutual exclusivity (not priority/fallback):** + +The two modes have fundamentally different inheritance semantics. Annotations are sparse — nodes without annotations inherit from their subgraph parent to maintain device consistency. Names are dense — virtually every node has a name, so inheritance is unnecessary and would incorrectly override name-based matches in subgraphs. Making the modes mutually exclusive keeps the semantics simple and predictable. + +### Advantages + +- **Zero model modification** — works with any model that has structured naming +- **Reuses existing partitioning infrastructure** — only the index-building and matching steps change +- **User-friendly** — users can inspect node names with Netron and write rules directly +- **Composable with resource accounting** — can combine name-based assignment with memory budgets + +### Risks / Open Questions + +- **Name stability**: Node names aren't guaranteed stable across exports. Mitigated by prefix/substring matching rather than exact names. + +### Handling Nodes Created by Graph Transformers + +#### Pre-partitioning transformers (Level 1) + +Level 1 optimizers run **before** partitioning. With annotation-based matching, these transformers propagate annotations to new nodes via the `AddNode(..., annotation_source)` overload, which copies `GetLayeringAnnotation()` from an original node. The name-based approach needs an analogous story. + +**Key insight: new node names ARE derivative of original names.** Verified in the codebase: + +- `Graph::GenerateNodeName(base_name)` takes a base string and ensures uniqueness by appending `_token_` only on collision. +- Transformers construct the base name from original node(s): + - `layer_norm_fusion.cc`: `GenerateNodeName(mul_node.Name() + "/LayerNormFusion/")` + - `matmul_add_fusion.cc`: `GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion")` + - `attention_fusion.cc`: `GenerateNodeName("Attention")` ← **exception — generic name** + +So if the original node was `/model/layers.5/self_attn/q_proj/MatMul`, the fused node typically becomes something like `/model/layers.5/self_attn/q_proj/MatMul/MatMulAddFusion` — which still contains the original layer prefix and will match substring rules like `layers.5`. + +**This means name-based matching is naturally robust to pre-partitioning fusions** — no explicit annotation-copying step is needed, because the substring match against the derivative name still hits the same rules. This is actually **simpler** than the annotation-based approach. + +**Edge cases to handle:** +1. **Generic names** (e.g., `"Attention"` without incorporating the original name): Some fusions create nodes with generic names that don't contain layer-identifying substrings. In general, name-based partitioning can only be used when node names contain representative strings suitable for layer matching. Two options exist: + - **Annotation fallback**: Use annotation-based assignment for these nodes, or update the transformer to follow the derivative naming convention. + - **Substring rule**: Add a substring pattern (e.g., `cuda(Attention)`) to assign all nodes whose name contains `Attention`. Note that name-based assignment does not support the '=' exact-match qualifier — all patterns are substrings. +2. **Multiple source nodes**: When a fusion merges N nodes from potentially different layers, the resulting name typically uses one of them as the base. If the merged nodes span layer boundaries, the fused node will match whichever layer the chosen base name belongs to. This mirrors annotation-based behavior (annotation is copied from one source node). + +**Recommendation:** No special machinery needed for pre-partitioning transformers. The derivative naming convention already preserves matchability. Document this as a convention that transformer authors should follow: always pass an original node's name as the base to `GenerateNodeName()`. + +#### Post-partitioning transformers (Level 2+) + +Level 2+ optimizers run **after** partitioning, when EP assignments have already been made. These transformers already copy the EP assignment from original nodes: + +```cpp +new_node.SetExecutionProviderType(original_node.GetExecutionProviderType()); +``` + +**No action needed for name-based matching here.** By the time Level 2+ transformers run, partitioning is complete. The node names are irrelevant — only the EP assignment matters, and that's already propagated correctly. + +--- + +## Direction 2: Minimize Allocations for Static-Shape Models + +### Goal + +For models with fully static shapes (common in transformer inference with fixed batch/sequence dimensions), ORT should minimize or eliminate runtime memory allocations. When all tensor shapes are known at `Initialize()` time, the runtime can pre-compute exact memory requirements and pre-allocate everything upfront — no arena overhead, no per-`Run()` allocation calls, deterministic memory usage. + +Additionally, by knowing exact memory requirements before execution begins, ORT can **minimize the chance of running into OOM** — if the total memory needed exceeds device capacity, the session can fail at `Initialize()` time with a clear error rather than crashing mid-inference with an opaque allocation failure. + +This brings ORT's allocation efficiency on par with specialized runtimes like llama.cpp, while retaining ORT's generality for arbitrary model architectures. + +### Dynamic Shapes in Transformer Models + +Nearly all transformer models are *exported* with dynamic batch + sequence_length — this is the default in PyTorch `torch.onnx.export`, Hugging Face Optimum, and Olive. However, for constrained deployment the picture is different: + +- **Fixed-shape re-export is standard for edge/embedded:** batch=1, seq_len=128 (or a few discrete lengths like 128/256/512). This is standard practice for TensorRT, CoreML, and QNN deployments. +- **LLM serving keeps shapes dynamic** (variable prompts, KV-cache growth). But these are typically high-VRAM scenarios (A100/H100), not constrained environments. +- **Vision transformers** (ViT, DINO, etc.) have fixed patch sequences — only batch is dynamic, and fixing batch=1 yields fully static shapes. + +**Implication for pre-allocation:** The target audience of `pre_allocate_execution_buffers` — embedded, edge, single-model-per-device — typically *can* use static shapes. Models are re-exported with fixed dimensions as part of the deployment pipeline. The dynamic-shape case (LLM serving with variable seq_len) lives in a different deployment tier where VRAM budget is less critical than throughput and the arena allocator handles repeated allocations efficiently. + +**Implication for workspace estimation:** Even with dynamic shapes, `EstimateWorkspace` (Level 1) and `DeclareWorkspaceRequirements` (Level 2) remain valuable for *budget decisions* — they can use worst-case shapes (max batch, max seq_len from model config) to determine how many nodes fit on the device. The estimate doesn't need to match runtime exactly; it needs to be conservative enough to avoid OOM. + +### Reference: What llama.cpp Does + +llama.cpp exploits that transformer inference has **fully deterministic memory usage**: +- All weight tensors are known at load time +- KV cache is pre-allocated for max sequence length +- Intermediate activation buffers have shapes determined by `(batch, seq_len, hidden_dim)` — all known in advance +- Workspace/temp buffers are known per-op and pre-planned + +This means the runtime computes **exactly** how much memory is needed before running — zero allocation calls during inference. + +### What ORT Already Does + +Before discussing gaps, it's important to note what ORT already provides: + +- **Shape inference runs during `Initialize()`** — specifically during `graph.Resolve()` after graph optimizations but before memory planning. Shape info is populated on all `NodeArg` objects. +- **Memory pattern pre-allocation** — When `EnableMemPattern` is set and shapes are static, `TensorAllocatorWithMemPattern` computes exact allocation offsets/sizes via `OrtValuePatternPlanner`, then pre-allocates a single large buffer per device via `Reserve()` (bypasses arena, calls device allocator directly). Intermediate buffers are reused based on liveness analysis. +- **Initializers can bypass arena** — When `use_device_allocator_for_initializers` is set, initializers are loaded via `Reserve()` (direct device allocation) during session state finalization, bypassing the arena's binning/coalescing logic. Without this option, initializers allocate through the arena like any other buffer. +- **BFC Arena** — Best-Fit-with-Coalescing allocator provides fast (O(log n)) sub-allocation. Allocations are cheap, but it suffers from memory waste due to power-of-two growth and chunk granularity. + +**What's already outside the arena:** +- Initializers → `Reserve()` (direct device allocator) +- Memory pattern buffers (activations) → `Reserve()` (direct device allocator) + +**What's still arena-allocated:** +- **Runtime temp/workspace buffers** — kernels call `GetScratchBuffer()` during `Compute()`, which allocates from the device allocator (arena by default). These are ephemeral (freed when kernel completes) and their sizes are only known at execution time. + +**Note on arena and temp buffer reuse:** While the BFC arena wastes memory due to chunk granularity, it does provide **automatic reuse** for temp buffers during sequential execution — subsequent kernels reuse the same arena memory for their scratch needs without actual device allocations. + +**CUDA mempool as an alternative:** ORT supports replacing the BFC arena with native CUDA memory pools (`cudaMallocFromPoolAsync`). This is enabled via the EP-scoped arena configuration key `arena.use_cuda_mempool` (e.g., `"ep.cudapluginexecutionprovider.arena.use_cuda_mempool" = "1"` in session config). This provides stream-aware pooling managed by the CUDA driver, with less memory waste than BFC. Since `GetScratchBuffer()` uses the same device allocator as activations (resolved via `SessionState::GetAllocator(device)` — keyed by `OrtDevice` only, not by purpose), enabling mempool automatically benefits temp buffers too. A separate temp-only allocator would require architectural changes to `AllocatorMap` (currently not feasible without significant refactoring). + +**Pre-allocated space for temp buffers (alternative approach):** If workspace sizes can be pre-computed (see Phase A below), temp buffers could be served from the same pre-allocated memory pattern buffer used for activations. Since workspace is live only during its kernel's execution, it participates naturally in liveness-based offset planning — no arena needed at all. This is the more principled solution: solve the workspace size problem first, then temp memory becomes part of the static plan. + +### IResourceAccountant Precision + +`IResourceAccountant::ComputeResourceCount()` already returns **exact sizes when shapes are static**: +- Initializer sizes: exact via `GetSizeInBytesFromTensorProto()` +- Output tensor sizes: exact when all dimensions are known via `GetSizeInBytesFromTensorTypeProto()` +- Weight deduplication: tracked via `pending_weights_`/`committed_weights_` to avoid double-counting + +**Note on actual memory consumption:** While these give exact logical tensor sizes, actual device memory will be rounded up to page/alignment boundaries — either per allocation or for a single large buffer. The reported sizes are a lower bound; real usage includes alignment overhead. + +It applies a **1.5x safety multiplier** to account for unknowable temp/workspace allocations. This multiplier exists because temp buffer sizes are discovered only at runtime — no kernel declares its workspace needs in advance. **For static-shape models where `DeclareWorkspaceRequirements` (Phase A below) has been implemented for all relevant kernels, this multiplier becomes unnecessary and should be bypassed** — exact workspace sizes are known at planning time, eliminating the need for a safety margin. + +**Per-node accounting (not per-layer)**: `IResourceAccountant` tracks costs at the node/subgraph level — `nodes_costs` in `IndexedSubGraph` is for accounting after an EP claims nodes (single nodes or fused groups), not for layering. It has no concept of "layer" as defined by the layering index. + +**Do we need per-layer aggregation?** Likely not as a first-class feature in `IResourceAccountant`. + +The current budget enforcement cuts off EP placement at the individual node level — once the cumulative budget is exceeded, subsequent nodes are rejected regardless of layer boundaries. **This is intentional**: atomic rollback of already-placed nodes within a layer was explicitly rejected during the previous implementation phase due to complexity and because it would require re-running `GetCapability()` with different node sets. + +The layering index already controls the *order* in which nodes are presented to the EP (layer by layer), so in practice the budget cutoff tends to land near layer boundaries. If exact layer-boundary cuts are desired in the future, it would need to be a separate mechanism (e.g., pre-computing per-layer costs and making accept/reject decisions at the layer level before calling `GetCapability()`), not a change to the accountant. + +For debugging/UX purposes, per-layer summaries can be computed externally by summing node costs grouped by their layer annotation — no accountant changes needed. + +### The Remaining Gap: Runtime Temp Buffers + +The 1.5x multiplier and arena waste both stem from the same root cause: **kernels allocate workspace at runtime without prior declaration.** + +Examples of how workspace sizes are determined: +- **cuDNN Convolution**: Workspace depends on algorithm selection (`cudnn_fe_graph->get_workspace_size()`) +- **cuDNN RNN**: Queries cuDNN at runtime (`cudnnGetRNNTempSpaceSizes()`) +- **Attention kernels**: Computed from runtime parameters (`batch_size * seq_len * num_heads * head_size`) + +The `SequentialExecutionPlan` tracks only activation tensors and initializers — workspace buffers are ephemeral and never planned statically. There is no `GetWorkspaceSize()` virtual method on `OpKernel`. + +To eliminate both the multiplier and arena waste for temp buffers, two problems must be solved: +1. **Learn temp buffer sizes in advance** — e.g., a `DeclareWorkspaceRequirements(shapes)` method on kernels, queryable during `Initialize()` when shapes are static +2. **Allocate temp buffers outside the arena** — once sizes are known, include them in the memory pattern plan alongside activations + +### What ORT Is Missing + +| Capability | llama.cpp | ORT Today | Gap | +|-----------|-----------|-----------|-----| +| Static memory plan | Yes — computed at load | Yes — `MemoryPatternGroup` + `Reserve()` | ✓ Activations + initializers already bypass arena | +| Pre-allocated activation buffers | Yes — fixed slots | Yes — memory patterns with liveness reuse | ✓ Already exists for static shapes | +| Workspace pre-computation | Yes — known per op | No — kernels discover at runtime | Need `DeclareWorkspaceRequirements()` on kernels | +| Workspace outside arena | Yes — part of static plan | No — `GetScratchBuffer()` uses arena | Need to include workspace in memory pattern plan | +| Zero-copy weight transfer | mmap + `cudaMemcpy` at load per layer | mmap + `cudaMemcpy` at load per partition | ✓ Same model — not a gap | + +**Note on weight transfer:** Both llama.cpp and ORT use the same approach: **static partitioning at load time, no dynamic weight swapping during inference.** + +- **llama.cpp**: User sets `-ngl N` (number of GPU layers). At load time, those N layers' weights are `cudaMemcpy`'d from mmap'd file to GPU. Remaining layers stay in host memory. No runtime swapping — this is performant because there is zero weight transfer overhead during token generation. +- **ORT (with constrained partitioning)**: The layering index + `IResourceAccountant` determines which nodes run on GPU. At `Initialize()` time, only those nodes' initializers are copied to device from mmap'd external data. Remaining weights stay in host memory. + +**Best practice for constrained environments:** Model weights should be stored as **external data on disk** (not embedded in protobuf). This ensures: +1. ORT memory-maps the file — minimal host memory overhead during loading. +2. Only GPU-partitioned nodes' weights are copied to device — no OOM as long as partitioning respects the budget. +3. CPU-partitioned nodes' weights remain accessible via mmap without requiring a separate host allocation. + +Since partitioning is decided once at `Initialize()` time and all required device weights are resident before `Run()`, there is no need for dynamic layer loading/offloading during inference. + +### How to Approach + +#### The Chicken-and-Egg Problem: Workspace Estimation vs EP Assignment + +**Problem statement:** To make precise memory budget decisions during `GetCapability()`, `IResourceAccountant` needs workspace sizes per node. But workspace sizes come from kernels, which don't exist until *after* EP assignment (kernels are created during `Compile()`/session state finalization — same reason `PrePack` happens late). At decision time, you can't ask a kernel that doesn't exist yet. + +**Why this matters:** The goal is not to fail gracefully — it's to **avoid failure entirely**. Today, models either OOM on device or trigger heavy VRAM thrashing on Windows. The partitioning must be conservative enough to prevent this, while accurate enough to maximize GPU utilization. + +**Solution: Two-level estimation with post-assignment verification** + +**Level 1 — Static workspace estimation function (at partitioning time):** + +When a kernel is registered (via `KernelRegistry_AddKernel`), optionally provide a **static estimation function** — a class-level function that takes node info and returns a conservative workspace estimate without needing a kernel instance: + +```c +// Registered alongside the kernel definition in KernelRegistry_AddKernel: +typedef OrtStatus*(ORT_API_CALL* OrtKernelWorkspaceEstimateFunc)( + _In_ const OrtEpApi* api, // for querying node attributes/shapes/device props + _In_ const OrtNode* node, // the specific node being evaluated + _In_ const OrtEp* ep, // EP instance (for device properties like SM count) + _Out_ size_t* estimated_workspace_bytes); +``` + +The function uses `api->Node_GetAttribute*()` and `api->Node_GetInputShape()` to access the node's attributes and input shapes, and `api->Ep_GetDeviceProperty()` for GPU hardware properties — everything needed to compute workspace without a kernel instance. + +**Can the estimate be precise (not just conservative)?** + +Depends on the kernel: + +| Kernel | Workspace depends on | Available at GetCapability()? | Precise estimate? | +|--------|---------------------|-------------------------------|-------------------| +| **Attention (Flash)** | shapes + `num_heads` attr + `device_prop.multiProcessorCount` | ✓ All available (EP has device_prop) | **YES — exact** | +| **Conv (cuDNN)** | cuDNN `build_plans(handle)` with tensor shapes + conv params | ✓ EP has handle; shapes/attrs available from node | **YES — exact** (with `HEUR_MODE_A`) | +| **GEMM/MatMul** | No workspace | N/A | N/A (returns 0) | + +For **attention**, workspace is determined by `get_num_splits_and_buffer_sizes()` which is pure arithmetic given `(batch, seq, heads, head_size, multiProcessorCount)`. The EP already has `multiProcessorCount` from `cudaGetDeviceProperties()` which runs during EP construction (before `GetCapability()`). So the estimate can be **exact**. + +For **cuDNN-based ops** (Conv), the workspace depends on which algorithm cuDNN selects via `build_plans(handle)`. However, a cuDNN handle is just a lightweight context object (`cudnnCreate` + `cudnnSetStream`) — the EP already owns one from construction time. With static shapes, all inputs to `build_plans()` are known: tensor dimensions, conv parameters (from node attributes), and the handle. The `CUDNN_HEUR_MODE_A` (fast heuristic) used by ORT is essentially a lookup + arithmetic — not actual GPU profiling. So the estimation function **can call `build_plans()` and get the exact workspace size**. This makes Conv estimates **precise too**. + +The reason `build_plans()` currently runs during first `Compute()` is historical: ORT didn't have a pre-execution workspace declaration phase, and shapes weren't known until runtime. With static shapes and the estimation function pattern, this computation can move earlier. + +The estimation function accesses the handle by casting `OrtEp*` to the EP's concrete type (safe because the function is EP-specific code registered by that EP): + +```cpp +auto* cuda_ep = static_cast(ep); // plugin path +cudnnHandle_t handle = cuda_ep->GetCudnnHandle(); +``` + +**Can it be the same function as DeclareWorkspaceRequirements?** + +Not the same function pointer (different signatures — one has a kernel instance, one doesn't). But the **core computation logic can be a shared static helper** called from both: + +```cpp +// Shared static helper (no instance needed): +static size_t ComputeAttentionWorkspace(int batch, int seq, int heads, + int head_size, int num_SMs) { + auto [num_splits, slse_size, o_size] = flash::get_num_splits_and_buffer_sizes( + batch, seq, seq, heads, head_size, num_SMs); + return flash::get_softmax_lse_size(seq, batch, heads) + slse_size + o_size; +} + +// Estimation function (no kernel instance — called during GetCapability): +OrtStatus* EstimateAttentionWorkspace(const OrtEpApi* api, const OrtNode* node, + const OrtEp* ep, size_t* out) { + const int64_t* shape; size_t rank; + api->Node_GetInputShape(node, 0, &shape, &rank); + int64_t num_heads; + api->Node_GetAttributeInt(node, "num_heads", &num_heads); + + // EP-specific: cast to concrete type to access device properties + auto* cuda_ep = static_cast(ep); + int num_SMs = cuda_ep->GetDeviceProp().multiProcessorCount; + + *out = ComputeAttentionWorkspace(shape[0], shape[1], num_heads, shape[3], num_SMs); + return nullptr; +} + +// DeclareWorkspaceRequirements (has kernel instance — called during FinalizeSessionState): +Status Attention::DeclareWorkspaceRequirements(span shapes, + InlinedVector& reqs) { + int num_SMs = GetDeviceProp().multiProcessorCount; + size_t total = ComputeAttentionWorkspace( + shapes[0][0], shapes[0][1], num_heads_, head_size_, num_SMs); + reqs.push_back({total, kSlotFlashWorkspace}); + return Status::OK(); +} +``` + +Both call the same `ComputeAttentionWorkspace()` — producing **identical results**. The estimation function gets device properties from the EP; the kernel method gets them from its stored EP reference. Same data, same computation, same answer. + +**For cuDNN-based ops**, the estimation function can also be precise — it calls `build_plans()` using the EP's handle and the node's shapes/attributes. Level 2 re-check serves as a diagnostic safety net — if the post-fusion total exceeds the budget, a warning is logged indicating that the Level 1 estimate was too optimistic (e.g., cuDNN returning different workspace sizes due to driver version differences or fusion changing the algorithm selection). + +**KernelCreateInfo and registration macros:** + +Today `KernelCreateInfo` contains `{kernel_def, kernel_create_func, status}`. To add the estimation function: + +```cpp +struct KernelCreateInfo { + std::unique_ptr kernel_def; + KernelCreateFn kernel_create_func; + OrtKernelWorkspaceEstimateFunc workspace_estimate_func; // NEW — may be nullptr + Status status; +}; +``` + +The existing `ONNX_OPERATOR_TYPED_KERNEL_EX` macro doesn't need changes — it produces `KernelCreateInfo` via `BuildKernelCreateInfo<>()`. A new macro variant adds the estimation function for kernels that implement it: + +```cpp +// New macro: ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE +// Same as ONNX_OPERATOR_TYPED_KERNEL_EX but also registers a workspace estimation function. +#define ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE( \ + name, domain, ver, type, provider, builder, estimate_fn, ...) \ + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \ + template <> \ + KernelCreateInfo \ + BuildKernelCreateInfo() { \ + return KernelCreateInfo( \ + builder.SetName(#name) \ + .SetDomain(domain) \ + .SinceVersion(ver) \ + .Provider(provider) \ + .Build(), \ + static_cast( \ + [](FuncManager&, const OpKernelInfo& info, \ + std::unique_ptr& out) -> Status { \ + out = std::make_unique<__VA_ARGS__>(info); \ + return Status::OK(); \ + }), \ + estimate_fn); \ + } +``` + +This requires a new `KernelCreateInfo` constructor overload: + +```cpp +struct KernelCreateInfo { + std::unique_ptr kernel_def; + KernelCreateFn kernel_create_func; + OrtKernelWorkspaceEstimateFunc workspace_estimate_func; // NEW — may be nullptr + Status status; + + // Existing constructor (unchanged — sets workspace_estimate_func to nullptr): + KernelCreateInfo(std::unique_ptr definition, + KernelCreateFn create_func) + : kernel_def(std::move(definition)), + kernel_create_func(create_func), + workspace_estimate_func(nullptr) {} + + // New constructor with estimation function: + KernelCreateInfo(std::unique_ptr definition, + KernelCreateFn create_func, + OrtKernelWorkspaceEstimateFunc estimate_func) + : kernel_def(std::move(definition)), + kernel_create_func(create_func), + workspace_estimate_func(estimate_func) {} + + KernelCreateInfo(KernelCreateInfo&& other) noexcept + : kernel_def(std::move(other.kernel_def)), + kernel_create_func(std::move(other.kernel_create_func)), + workspace_estimate_func(other.workspace_estimate_func) {} + + KernelCreateInfo() = default; +}; +``` + +**Usage example** (registering a CUDA Attention kernel with estimation): + +```cpp +// In cuda_contrib_kernels.cc: +ONNX_OPERATOR_TYPED_KERNEL_EX_WITH_ESTIMATE( + Attention, // name + kMSDomain, // domain + 1, // ver + float, // type + kCudaExecutionProvider, // provider + KernelDefBuilder() // builder + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, {3, 4}), + &cuda::EstimateAttentionWorkspace, // estimate_fn ← NEW argument + cuda::Attention); // kernel class (__VA_ARGS__) +``` + +Kernels without estimation continue using `ONNX_OPERATOR_TYPED_KERNEL_EX` unchanged — their `workspace_estimate_func` is `nullptr`, and the budget logic applies the 1.5x multiplier as today. The migration is opt-in, kernel by kernel. + +**Integration with GetCapability and the resource budget:** + +The estimation function is called during budget enforcement — by the EP directly (in-tree) or by the host bridge (plugin). The result is combined with the base cost from `IResourceAccountant`. + +**Multiplier handling — non-member helper approach:** + +`ComputeResourceCount()` currently applies a 1.5x multiplier to approximate workspace for kernels without estimation functions. With precise workspace estimates available, the multiplier must be skipped. Rather than changing `ComputeResourceCount()`'s signature, we move the multiplier out and into a non-member helper that encapsulates the budget decision: + +```cpp +// Non-member helper (e.g., in resource_accountant_helpers.h): +// Called by both in-tree GetCapability and the plugin host bridge. +ResourceCount ComputeNodeCostForBudget( + IResourceAccountant& accountant, + const Node& node, + std::optional workspace_estimate) { + // ComputeResourceCount returns base cost: outputs + initializers (dedup'd) + // NO multiplier — multiplier is now applied here when needed + ResourceCount base_cost = accountant.ComputeResourceCount(node); + + if (workspace_estimate.has_value()) { + // Precise workspace known — add it directly, no multiplier + return AddResourceCounts(base_cost, *workspace_estimate); + } + // No workspace estimate — apply heuristic multiplier (1.5x) + return ApplyWorkspaceHeuristic(base_cost); +} + +// Multiplier as an explicit utility: +ResourceCount ApplyWorkspaceHeuristic(ResourceCount base) { + size_t bytes = std::get<0>(base); + return ResourceCount{static_cast(bytes * 1.5)}; +} +``` + +**Design rationale:** +- `ComputeResourceCount()` signature is **unchanged** — it returns the raw base cost (outputs + initializers with dedup). The 1.5x multiplier moves out of the accountant into this helper. +- The helper is the **single decision point** for both code paths (in-tree and plugin host bridge). No duplicated logic. +- `ApplyWorkspaceHeuristic()` makes the multiplier explicit and testable. It can be adjusted (e.g., per-EP or per-op-type) without changing any interface. +- The helper integrates naturally with the existing budget check pattern: + +```cpp +// Usage in GetCapability (both paths): +auto total_cost = ComputeNodeCostForBudget(*accountant, node, workspace_estimate); +auto would_be_consumed = AddResourceCounts(consumed, total_cost); + +if (has_budget && ResourceCountExceeds(would_be_consumed, budget)) { + accountant->SetStopAssignment(); + break; +} + +consumed = would_be_consumed; +sub_graph->SetAccountant(accountant); +sub_graph->AppendNodeCost(total_cost); +``` + +**Why this is clean with committed/uncommitted weights:** + +- **Weight dedup is unaffected.** `ComputeResourceCount()` handles pending/committed weight tracking internally. The workspace estimate is purely additive — it's not a weight, so it doesn't participate in dedup. +- **`AppendNodeCost()` stores the combined total.** When `AccountForNode()` runs later (during `TryAssignNodes()`), it adds the stored cost (base + workspace) to `consumed_amount` and commits the weights. The workspace portion just inflates the per-node cost. +- **`CommitWeightsForNode()` only touches initializers.** Workspace is a separate addend, not tracked in weight sets. +- **`ResetForNewPass()` is fine.** The workspace estimate is stateless — recomputed fresh from node shapes each call, no state to carry across passes. + +If no estimation function is registered for a kernel, the helper applies the 1.5x multiplier as today (unchanged behavior). + +**Example estimation function (CUDA Conv):** + +```cpp +OrtStatus* EstimateConvWorkspace(const OrtEpApi* api, const OrtNode* node, + const OrtEp* ep, size_t* out) { + // Get input shape (X: NCHW) + const int64_t* x_shape = nullptr; + size_t x_rank = 0; + OrtStatus* status = api->Node_GetInputShape(node, 0, &x_shape, &x_rank); + if (status) return status; + if (x_rank < 3) { + return api->CreateStatus(ORT_INVALID_ARGUMENT, + "Conv: input X must be at least rank 3"); + } + + // Get weight shape (W: [M, C/group, kH, kW, ...]) + const int64_t* w_shape = nullptr; + size_t w_rank = 0; + status = api->Node_GetInputShape(node, 1, &w_shape, &w_rank); + if (status) return status; + if (w_rank != x_rank) { + return api->CreateStatus(ORT_INVALID_ARGUMENT, + "Conv: weight rank must match input rank"); + } + + // Get conv attributes (all optional — defaults to empty/zeros per ONNX spec) + const int64_t* pads = nullptr; + size_t pads_count = 0; + status = api->Node_GetAttributeInts(node, "pads", &pads, &pads_count); + if (status) return status; // distinguishes "not present" (OK + nullptr) from error + + const int64_t* strides = nullptr; + size_t strides_count = 0; + status = api->Node_GetAttributeInts(node, "strides", &strides, &strides_count); + if (status) return status; + + const int64_t* dilations = nullptr; + size_t dilations_count = 0; + status = api->Node_GetAttributeInts(node, "dilations", &dilations, &dilations_count); + if (status) return status; + + // EP-specific: cast to concrete type to access cuDNN handle + auto* cuda_ep = static_cast(ep); + cudnnHandle_t handle = cuda_ep->GetCudnnHandle(); + if (!handle) { + return api->CreateStatus(ORT_RUNTIME_EXCEPTION, + "Conv: cuDNN handle not available on EP"); + } + + // Build cuDNN frontend graph and query workspace + // (same logic as CreateCudnnFeExecutionPlan but without storing state) + auto graph = BuildConvFrontendGraph(x_shape, x_rank, w_shape, w_rank, + pads, pads_count, strides, strides_count, + dilations, dilations_count); + if (!graph) { + return api->CreateStatus(ORT_RUNTIME_EXCEPTION, + "Conv: failed to build cuDNN frontend graph"); + } + + auto plan_status = graph->build_plans(handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_ONLY); + if (plan_status) { + return api->CreateStatus(ORT_RUNTIME_EXCEPTION, + plan_status.get_message()); + } + + *out = graph->get_workspace_size(); + return nullptr; // success +} +``` + +This produces the **exact same result** as the kernel would compute during `DeclareWorkspaceRequirements` — same handle, same shapes, same algorithm selection. The shared logic is the cuDNN frontend graph construction. + +**Why no public API for device properties or library handles:** The estimation function is registered BY the EP for ITS kernels — it's EP-specific code running in the EP's own DLL. It can safely cast `OrtEp*` to its concrete type (e.g., `CudaEp*`) to access device_prop, cuDNN handles, etc. This is the same pattern kernels already use: `static_cast(info.GetExecutionProvider())->GetDeviceProp()`. No generic `Ep_GetCudnnHandle` or `Ep_GetDeviceIntProperty` API is needed — that would be CUDA-specific pollution of the universal EP interface. + +**Level 2 — Post-fusion budget re-check (before InsertCast and MemcpyTransformer):** + +The `TransformGraph()` pipeline has a natural insertion point after EP-specific optimizers but before the transformers that bake in EP boundaries: + +``` +L1 optimizers → Partition (GetCapability) → L2/L3 EP-specific optimizers → [HERE] → InsertCastTransformer → L4 → MemcpyTransformer +``` + +At `[HERE]`: +- Nodes are assigned to EPs ✓ +- EP-specific fusions (ConvRelu, FusedMatMul, etc.) have already been applied ✓ +- The graph reflects the *actual* ops that will become kernels ✓ +- Cast nodes have NOT been inserted yet ✓ (no fp16↔fp32 casts at boundaries) +- Memcpy nodes have NOT been inserted yet ✓ (boundaries can still move) +- Kernels do NOT exist yet ✗ (cannot call `DeclareWorkspaceRequirements`) + +**Why before InsertCastTransformer:** The InsertCastTransformer inserts fp16↔fp32 Cast nodes at EP boundaries where input/output types don't match. If we offload nodes *after* Cast insertion, we'd leave orphaned Cast nodes at the old boundary and need new ones at the new boundary — similar to the MemcpyTransformer problem. By running before both, any Cast and Memcpy nodes are inserted at the final (post-offload) boundaries. + +Since kernels don't exist, we call the **same `EstimateWorkspace()` functions** from Level 1 — but now on the post-fusion graph. This eliminates the only meaningful gap between Level 1 and Level 2: fused ops that didn't exist at `GetCapability()` time now have their own estimation functions registered alongside their kernel definitions. + +**Algorithm:** + +1. For each node assigned to the constrained EP, look up its `OrtKernelWorkspaceEstimateFunc` from the kernel registry (same registry that will later be used to create the kernel). +2. Call the estimation function on the (possibly fused) node → get workspace. +3. Re-run the budget check: `base_cost + workspace` for all assigned nodes. +4. If total ≤ budget → proceed to InsertCastTransformer/MemcpyTransformer normally. +5. If total > budget → **log a warning** and proceed. Do NOT attempt to offload nodes. + +**Why warn-only (no runtime offload):** + +The earlier design attempted tail-node offloading at this stage — walking backward through GPU-assigned nodes and reassigning them to CPU. In practice, this is problematic: + +- **For bf16/fp16 models** (the dominant constrained-VRAM use case): CPU EP lacks kernels for most bf16/fp16 compute ops (MatMul, Attention, LayerNorm). The offload loop would hit a non-offloadable node almost immediately and accomplish nothing. +- **For fp32 CNN models** (where CPU *could* handle offloaded ops): The performance cost of GPU→CPU→GPU data transfers typically outweighs the memory benefit of offloading a few tail nodes. +- **Complexity vs value**: Offload logic (type checking, contiguous-tail constraint, boundary correctness) adds significant code for a feature that rarely fires and rarely helps. + +The correct fix for a Level 2 budget overrun is to **improve Level 1 accuracy** — make the estimation functions precise enough that post-fusion re-check merely confirms (not corrects) the budget. Level 2 serves as a **diagnostic safety net**: if the warning fires, it indicates the Level 1 estimate was too optimistic, and the estimation function for the offending kernel(s) should be improved. + +```cpp +// Pseudo-code in TransformGraph, after L2/L3, before InsertCastTransformer: +if (level2_total > budget) { + size_t overrun = level2_total - budget; + LOGS(logger, WARNING) + << "Post-fusion budget re-check: EP '" << ep_type + << "' exceeds memory budget by " << overrun << " bytes. " + << "Level 1 estimation was too optimistic. " + << "Consider improving workspace estimation for fused ops. " + << "Proceeding — runtime OOM may occur."; +} +``` + +This keeps the pipeline simple: Level 2 is purely observational (re-check + warn), not interventional. If the warning fires in testing, the developer improves the relevant `EstimateWorkspace()` function. In production, the budget was validated at Level 1 and Level 2 divergence should be rare. + +**Why Level 2 exists separately from Level 1:** + +Level 1 (during `GetCapability()`) operates on the **pre-fusion** graph. It estimates workspace for the original unfused nodes (Conv, Relu separately). Level 2 operates on the **post-fusion** graph (ConvRelu as a single node). The two can diverge when: +- A fused op's workspace differs from the sum of its parts (common — fusion often reduces workspace) +- Level 2/3 optimizers add or remove nodes (e.g., constant folding eliminates a node entirely) + +For most LLM models (which are repetitive transformer blocks with minimal fusion opportunity), Level 1 and Level 2 will agree. Level 2 matters more for CNN models with heavy fusion (Conv+BN+Relu patterns). + +**When static shapes are unavailable:** + +If the model has dynamic shapes, the estimation function cannot compute workspace (shapes are unknown at `GetCapability()` time). In this case: +- The estimation function returns a failure status or a sentinel value indicating "unknown." +- `ComputeNodeCostForBudget()` falls back to the 1.5x heuristic multiplier on base cost. +- The user may need to **tune the memory budget by trial and error** — setting a conservative budget and adjusting based on observed OOM or under-utilization. This is analogous to llama.cpp's `-ngl` flag: the user picks a layer count and adjusts based on whether it fits. +- A future extension could accept user-provided "typical shape hints" (e.g., `max_batch=4, max_seq=2048`) to enable estimation even for dynamic-shape models, but this is out of scope for the initial design. + +**Plugin C ABI for Level 1:** + +```c +// Workspace estimation function type (no kernel instance needed): +typedef OrtStatus*(ORT_API_CALL* OrtKernelWorkspaceEstimateFunc)( + _In_ const OrtEpApi* api, + _In_ const OrtNode* node, + _In_ const OrtEp* ep, + _Out_ size_t* estimated_workspace_bytes); + +// Extension to KernelRegistry_AddKernel in OrtEpApi: +ORT_API2_STATUS(KernelRegistry_AddKernelV2, + _Inout_ OrtKernelRegistry* registry, + _In_ const OrtKernelDef* kernel_def, + _In_ OrtKernelCreateFunc create_func, + _In_opt_ void* create_func_state, + _In_opt_ OrtKernelWorkspaceEstimateFunc workspace_estimate_func); // NEW — may be NULL +``` + +**Required OrtEpApi additions for the estimation function to query node info:** + +```c +// Query input shape from a node's NodeArg (populated by shape inference): +ORT_API2_STATUS(Node_GetInputShape, + _In_ const OrtNode* node, + _In_ size_t input_index, + _Outptr_result_maybenull_ const int64_t** shape, // NULL if dynamic + _Out_ size_t* rank); + +// Query integer attribute from a node: +ORT_API2_STATUS(Node_GetAttributeInt, + _In_ const OrtNode* node, + _In_ const char* attr_name, + _Out_ int64_t* value); + +// Query integer array attribute: +ORT_API2_STATUS(Node_GetAttributeInts, + _In_ const OrtNode* node, + _In_ const char* attr_name, + _Outptr_ const int64_t** values, + _Out_ size_t* count); +``` + +**Device properties and library handles** (cuDNN, cuBLAS, etc.) are accessed by casting `OrtEp*` to the EP's concrete type inside the estimation function — no public API needed (see examples above). + +--- + +### Implementation Across Kernel Types and GetCapability Paths + +ORT has **three distinct kernel authoring scenarios** and **two GetCapability architectures**. The workspace estimation and declaration APIs must work correctly in each combination. + +#### Three Kernel Types + +| Type | Description | Registration mechanism | Examples | +|------|-------------|----------------------|----------| +| **In-tree** | C++ kernels compiled into the ORT binary | `BuildKernelCreateInfo<>()` via macros (`ONNX_OPERATOR_TYPED_KERNEL_EX`) | All CPU kernels, legacy CUDA EP kernels | +| **Plugin (shared source)** | Same C++ source as in-tree, compiled into EP plugin DLL, uses adapter layer | `KernelRegistry_AddKernel` C ABI, with `CudaKernelAdapter` bridging | CUDA EP plugin kernels | +| **Pure ABI** | Kernels written directly against the C ABI (`OrtKernelImpl`) | `KernelRegistry_AddKernel` C ABI, `OrtKernelImpl` function pointers | Third-party EP plugin kernels | + +#### Two GetCapability Architectures + +| Architecture | Resource budgeting location | Workspace estimation call site | +|-------------|---------------------------|-------------------------------| +| **In-tree** | Inside `CUDAExecutionProvider::GetCapability()` — EP owns the loop, calls `resource_accountant->ComputeResourceCount(node)`, makes accept/reject decisions | EP calls estimation function directly in its loop | +| **Plugin bridge** | In `PluginExecutionProvider::GetCapability()` (the C++ host wrapper) — plugin EP only proposes candidates, host does budgeting after plugin returns | Host calls estimation function during budget enforcement | + +**Critical difference:** In the plugin path, the plugin's `GetCapabilityImpl` returns a list of "I support these nodes" without resource checks. The **host bridge** (`ep_plugin_provider_interfaces.cc`) then iterates those nodes in topological order, calls `resource_accountant->ComputeResourceCount(node)` for each, and enforces the budget — halting assignment when the threshold is exceeded. The plugin never sees the accountant directly. + +#### Implementation: `OrtKernelWorkspaceEstimateFunc` (Level 1 — at partitioning time) + +**In-tree path:** + +```cpp +// In CUDAExecutionProvider::GetCapability() loop (in-tree only): +const KernelCreateInfo* kci = kernel_lookup.LookUpKernel(node); +std::optional workspace_estimate; +if (kci && kci->workspace_estimate_func) { + size_t ws = 0; + // In-tree: pass IExecutionProvider* — func casts to CUDAExecutionProvider* + kci->workspace_estimate_func(this, node, &ws); + workspace_estimate = ws; +} +// Use non-member helper for budget decision: +auto total_cost = ComputeNodeCostForBudget(*resource_accountant, node, workspace_estimate); +// ... budget check with total_cost ... +``` + +The estimation function for in-tree kernels is a static member function. It casts the EP pointer to `CUDAExecutionProvider*` to access `GetDeviceProp()` and `PerThreadDefaultCudnnHandle()` — exactly the same pattern kernels already use. + +**Plugin bridge path:** + +```cpp +// In PluginExecutionProvider::GetCapability() host-side budget loop +// (ep_plugin_provider_interfaces.cc): +for (const auto& node_grouping : api_graph_support_info.node_groupings) { + const Node& internal_node = node_grouping.nodes[0]->GetInternalNode(); + + // Look up workspace estimate function from kernel registry + const KernelCreateInfo* kci = kernel_lookup.LookUpKernel(internal_node); + std::optional workspace_estimate; + if (kci && kci->workspace_estimate_func) { + size_t ws = 0; + // Plugin path: registered via KernelRegistry_AddKernelV2 + // Function casts OrtEp* to its concrete type internally + OrtStatus* est_status = kci->workspace_estimate_func( + &ep_api_, ep_node->ToExternal(), ort_ep_.get(), &ws); + if (est_status) { OrtApis::ReleaseStatus(est_status); } + else { workspace_estimate = ws; } + } + + // Same non-member helper as in-tree: + auto total_cost = ComputeNodeCostForBudget(*resource_accountant, internal_node, + workspace_estimate); + // ... budget check with total_cost ... +} +``` + +**Pure ABI path (third-party EP):** + +Same as plugin bridge — the estimation function is registered via `KernelRegistry_AddKernelV2` and called by the host during budget enforcement. The kernel author provides the function pointer at registration time: + +```c +// Third-party EP kernel registration: +OrtStatus* MyConvEstimate(const OrtEpApi* api, const OrtNode* node, + const OrtEp* ep, size_t* out) { + // Cast to concrete EP type to access device-specific state: + auto* my_ep = static_cast(ep); + // ... compute workspace from node shapes + my_ep->device_properties ... +} + +// During EP's RegisterKernels callback: +ep_api->KernelRegistry_AddKernelV2(registry, conv_kernel_def, CreateConvKernel, + /*state=*/nullptr, &MyConvEstimate); +``` + +#### Implementation: `DeclareWorkspaceRequirements` (Level 2 — after kernel creation) + +**In-tree path:** + +Straightforward — add a virtual method to `OpKernel`: + +```cpp +// In include/onnxruntime/core/framework/op_kernel.h: +[[nodiscard]] virtual Status DeclareWorkspaceRequirements( + gsl::span input_shapes, + InlinedVector& requirements) const { + return Status::OK(); // Default: no workspace declared +} +``` + +In-tree kernels override this just like they override `PrePack()`. Called during `FinalizeSessionState()` after kernel instances exist. + +**Plugin (shared source) path:** + +The `CudaKernelAdapter` already bridges virtual calls to the underlying kernel class. The adapter forwards `DeclareWorkspaceRequirements` to the underlying kernel's implementation: + +```cpp +// In cuda_kernel_adapter.h — adapter already forwards PrePack similarly: +Status DeclareWorkspaceRequirements( + gsl::span input_shapes, + InlinedVector& requirements) const override { + // The underlying kernel class (compiled in the plugin DLL) implements this directly. + // CudaKernelAdapter inherits from T, so T::DeclareWorkspaceRequirements is accessible. + return T::DeclareWorkspaceRequirements(input_shapes, requirements); +} +``` + +Since plugin shared-source kernels ARE the same C++ class (just compiled in a different DLL), they implement `DeclareWorkspaceRequirements` as a regular virtual override — no ABI translation needed. + +**Pure ABI path (third-party EP):** + +Add an optional function pointer to `OrtKernelImpl`: + +```c +// In onnxruntime_ep_c_api.h, extend OrtKernelImpl: +struct OrtKernelImpl { + // ... existing fields (Compute, Release, PrePackWeight, ...) ... + + // NEW — optional workspace declaration (ORT >= 1.XX): + ORT_API2_STATUS(DeclareWorkspaceRequirements, + _In_ OrtKernelImpl* this_ptr, + _In_ const int64_t* const* input_shapes, // array of shape arrays + _In_ const size_t* input_ranks, // rank of each input + _In_ size_t num_inputs, + _Out_ OrtWorkspaceRequirement** requirements, // allocated by kernel + _Out_ size_t* num_requirements); +}; +``` + +The `PluginEpOpKernel` adapter (in `ep_kernel_registration.cc`) bridges this to the virtual call: + +```cpp +// In PluginEpOpKernel: +Status DeclareWorkspaceRequirements( + gsl::span input_shapes, + InlinedVector& requirements) const override { + // Version guard (same pattern as PrePack): + if (kernel_impl_->ort_version_supported < XX || + kernel_impl_->DeclareWorkspaceRequirements == nullptr) { + return Status::OK(); // No declaration — fall back to arena + } + + // Convert TensorShape spans to C arrays + InlinedVector shape_ptrs; + InlinedVector ranks; + for (const auto& shape : input_shapes) { + shape_ptrs.push_back(shape.GetDims().data()); + ranks.push_back(shape.NumDimensions()); + } + + OrtWorkspaceRequirement* reqs = nullptr; + size_t num_reqs = 0; + ORT_RETURN_IF_ERROR(ToStatusAndRelease( + kernel_impl_->DeclareWorkspaceRequirements( + kernel_impl_, shape_ptrs.data(), ranks.data(), + shape_ptrs.size(), &reqs, &num_reqs))); + + // Convert C results to C++ vector + for (size_t i = 0; i < num_reqs; ++i) { + requirements.push_back({reqs[i].size_bytes, reqs[i].slot_id}); + } + // Free C allocation (kernel used OrtAllocator or static buffer) + return Status::OK(); +} +``` + +#### Summary: Where Each Piece Lives + +| Component | In-tree | Plugin (shared source) | Pure ABI | +|-----------|---------|----------------------|----------| +| **Workspace estimation func** | Static member on kernel class; stored in `KernelCreateInfo::workspace_estimate_func` | Same static function, registered via `KernelRegistry_AddKernelV2` | C function pointer, registered via `KernelRegistry_AddKernelV2` | +| **Who calls estimation** | EP's `GetCapability()` loop via `ComputeNodeCostForBudget()` helper | Host bridge via same `ComputeNodeCostForBudget()` helper | Host bridge (same) | +| **DeclareWorkspaceRequirements** | Virtual override on `OpKernel` | Virtual override (same C++ class in plugin DLL) | `OrtKernelImpl::DeclareWorkspaceRequirements` function pointer → `PluginEpOpKernel` adapter | +| **Who calls DeclareWorkspace** | `FinalizeSessionState()` | `FinalizeSessionState()` (same) | `FinalizeSessionState()` via adapter | +| **Device property access** | `static_cast(ep)->GetDeviceProp()` | `static_cast(ep)->GetDeviceProp()` | `static_cast(ep)->GetDeviceProps()` | +| **cuDNN handle access** | `static_cast(ep)->PerThreadDefaultCudnnHandle()` | `static_cast(ep)->GetCudnnHandle()` | N/A (EP-specific) | + +#### Key Design Principle + +The **estimation function** signature differs between in-tree and plugin paths: + +- **In-tree:** `static size_t EstimateWorkspace(const IExecutionProvider* ep, const Node& node)` — C++ types, direct EP access +- **Plugin/ABI:** `OrtStatus* EstimateWorkspace(const OrtEpApi*, const OrtNode*, const OrtEp*, size_t*)` — C ABI, opaque types + +But both compute the same result. For shared-source kernels (compiled both in-tree and as plugin), a single static helper function (e.g., `ComputeAttentionWorkspace()`) is called from both wrappers — ensuring the estimate is identical regardless of build configuration. + +--- + +#### Phase A: Workspace Pre-declaration (`DeclareWorkspaceRequirements`) + +The core missing piece. Today no kernel declares its temp buffer needs before `Compute()`. To close this gap, introduce a method on `OpKernel` that returns workspace descriptors — each with a size and a key for later retrieval. + +**Analogy to PrePack:** This mechanism is similar to the existing `PrePack()` pattern — both are called once during session state finalization (not during `Run()`), both store results that are reused across all subsequent runs. `PrePack()` pre-processes weight data; `DeclareWorkspaceRequirements()` pre-computes workspace layout. + +**Interface:** + +```cpp +struct WorkspaceRequirement { + size_t size_bytes; // Size of this workspace buffer + int slot_id; // Kernel-defined slot identifier (0, 1, 2, ...) + // Unique within a single kernel instance +}; + +// Optional override on OpKernel (called during FinalizeSessionState): +virtual Status DeclareWorkspaceRequirements( + gsl::span input_shapes, + InlinedVector& requirements) const { + return Status::OK(); // Default: no declaration (fall back to arena) +} +``` + +A kernel can declare multiple workspace slots (e.g., attention needs separate Q transpose buffer, output buffer, seqlens buffer). The `slot_id` is defined by the kernel author and is stable across calls — it identifies *which* buffer within that kernel's logic. + +**Key constraint:** Multiple nodes may use the same kernel class. Each node instance gets its own set of workspace slots. The unique key for retrieval is `(NodeIndex, slot_id)` — the framework supplies `NodeIndex`, the kernel supplies `slot_id`. + +**Memory reuse via liveness-based offset planning:** + +Workspace buffers are live only during their kernel's execution step. This means workspaces from non-overlapping steps can share the same physical memory — exactly the same liveness analysis already used for activation tensors. The offset planner assigns overlapping offsets to workspaces whose liveness intervals don't intersect: + +``` +Step 0: Node A workspace (slots 0,1) → offsets [0, 4096] +Step 1: Node B workspace (slot 0) → offset [0] ← reuses Node A's memory +Step 2: Node C workspace (slots 0,1) → offsets [0, 8192] ← reuses again +``` + +Peak workspace memory = max over all steps of (sum of workspace slots for that step), not the sum of all workspaces across all nodes. + +**Concurrency model (multiple concurrent `Run()` calls):** + +The existing memory pattern system already handles this correctly: +- The **pattern** (offset/size map) is computed once during `Initialize()` and cached in `SessionState` — shared, read-only. +- The **actual buffer** is allocated per-`Run()` by each `ExecutionFrame` using the pattern as a blueprint. +- Each concurrent `Run()` gets its own `ExecutionFrame` with its own workspace buffer — no sharing, no synchronization needed. + +Workspace pre-allocation follows the same model: +- `DeclareWorkspaceRequirements()` is called during `FinalizeSessionState()` → produces a workspace offset plan (shared, immutable). +- Each `Run()` allocates a workspace buffer of `peak_workspace_size` bytes and uses offsets from the plan. +- Concurrent runs each get their own buffer — safe without locks. + +**Note on CUDA:** In practice, concurrent `Run()` on the same CUDA session is uncommon (users don't typically do this). But the design should remain thread-safe by following the same per-run buffer pattern. + +**Single-thread pre-allocation mode (eliminating runtime OOM):** + +Even with workspace planning, the per-`Run()` buffer allocation can still OOM if device memory is fragmented or consumed by other processes since `Initialize()`. For constrained environments, this is the last remaining point of failure. + +Most constrained-environment users run **single-threaded inference** — one `Run()` at a time. ORT already has a concurrent-run counter (`InferenceSession::current_num_runs_`). If the session is configured to disallow concurrency, the execution buffer (which includes workspace slots) can be **allocated once at initialization and reused for every `Run()` call**. + +**Proposed** (not currently implemented): a session option such as `session.pre_allocate_execution_buffers = "1"` would enable this behavior. + +When enabled: +1. After `FinalizeSessionState()` computes the memory pattern (including workspace offsets from `DeclareWorkspaceRequirements`), allocate the peak buffer once: `IAllocator::Alloc(peak_size)` per EP. +2. Store the pre-allocated buffer pointer on `SessionState`. +3. Each `Run()` reuses the same buffer — no allocation, no OOM possible. +4. Enforce `max_concurrent_runs = 1`: if a second `Run()` arrives, fail fast. + +```cpp +if (pre_allocate_mode_ && current_num_runs_.fetch_add(1) > 0) { + current_num_runs_.fetch_sub(1); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Concurrent Run() not allowed with pre-allocated execution buffers."); +} +``` + +**What this guarantees:** If `Initialize()` succeeds, `Run()` cannot OOM — all device memory (weights + intermediates + workspace) is already resident. The budget at partition time accounts for all three: `budget ≥ weights_on_device + peak_execution_buffer`. + +**What already exists:** `MemoryPattern` computation is done, `MemoryPatternGroup::GetPeakAllocSize()` gives peak size, `current_num_runs_` counter exists, per-EP allocators exist. The `ExecutionFrame` already uses offset-based placement into a contiguous block — the change is to not free/reallocate that block between calls. + +**Scope:** Single-threaded only. For concurrent inference, multiple buffers are needed (defeating the guarantee). + +**Interaction with dynamic shapes:** `pre_allocate_execution_buffers` is fundamentally a **static-shape-only** feature. With dynamic shapes, `ExecutionFrame` must allocate buffers on every `Run()` because activation tensor sizes are unknown until the input arrives — there is no way to pre-compute a total buffer size at `Initialize()` time. Even if some kernels' workspace slots are shape-independent, the activation portion (which typically dominates) still requires per-`Run()` allocation, so the OOM-elimination guarantee cannot hold. + +Furthermore, the arena allocator already handles repeated allocations efficiently (same-size blocks are recycled without syscalls), so pre-allocating just the workspace portion while leaving activations dynamic would add complexity for negligible gain. + +**Summary:** For dynamic-shape models, the value of `DeclareWorkspaceRequirements` is in **budget estimation** (Level 1/Level 2, using worst-case or max-batch sizes to decide how many nodes fit on the device), not in runtime pre-allocation. + +**Planning flow (during FinalizeSessionState):** + +1. For each kernel in the execution plan (when shapes are static), call `DeclareWorkspaceRequirements()` with the inferred input shapes. +2. Record `{NodeIndex, slot_id} → size_bytes` in the execution plan. +3. Run liveness analysis: workspace for node N is live only during step N's execution. +4. Compute offsets (same algorithm as activation patterns) → yields `peak_workspace_size` and per-slot offsets. +5. Store workspace pattern as a **separate `WorkspacePattern`** in `SessionState`. + +**Why workspace buffers are separate from `MemoryPattern` (activations):** + +Although the offset planning algorithm is the same (liveness → assign offsets → compute peak), workspace buffers differ in allocation and retrieval: + +| Aspect | MemoryPattern (activations) | WorkspacePattern | +|--------|---------------------------|------------------| +| **Addressing** | `MLValueIndex` — framework-assigned, part of graph IR | `(NodeIndex, slot_id)` — kernel-defined, opaque to framework | +| **Who queries** | Framework automatically when creating output `OrtValue`s | Kernel explicitly via `GetPreallocatedWorkspace(slot_id)` | +| **Lifetime** | Multi-step — output lives until its last consumer executes | Single-step — live only during the owning kernel's step | +| **What's returned** | An `OrtValue` (typed tensor with shape metadata) | Raw `void*` — kernel interprets the bytes internally | +| **Graph visibility** | Framework manages these as edges between nodes | Invisible to graph — internal scratch memory | +| **Size determination** | Inferred from output shape × element_size | Declared by kernel (may be unrelated to any tensor shape) | + +Concretely, this means: +- `WorkspacePattern` is a new class (not reusing `MemoryPatternGroup`) with its own lookup: `GetOffset(NodeIndex, slot_id) → {offset, size}`. +- The workspace buffer is allocated separately from the activation buffer. They could share physical memory (workspace is always single-step, so it never overlaps with itself across steps), but keeping them separate simplifies accounting and makes budget tracking unambiguous: `peak_total = peak_activations + peak_workspace`. +- In pre-allocation mode, both buffers are allocated once at init. In normal mode, both are allocated per-`Run()` from the arena. But they remain distinct allocations with distinct query paths. + +**Per-Run retrieval (during Compute):** + +Each `ExecutionFrame` allocates a workspace buffer of `peak_workspace_size` via the EP's allocator and provides offset-based access through a dedicated query interface (not the existing OrtValue/MLValue machinery): + +**Alternative A: Transparent fallback in GetScratchBuffer** + +Modify `GetScratchBuffer(slot_id, size, stream)` to check for a pre-planned buffer first: + +```cpp +template +IAllocatorUniquePtr GetScratchBuffer(int slot_id, size_t count_or_bytes, Stream* stream) const { + // Check if workspace was pre-planned for this node + slot + void* preallocated = context_.GetPreallocatedWorkspace(slot_id); + if (preallocated) { + // Return non-owning pointer (buffer lifetime managed by the frame) + return IAllocatorUniquePtr(static_cast(preallocated), [](T*){}); + } + // Fall back to arena (dynamic shapes, or DeclareWorkspaceRequirements not implemented) + return IAllocator::MakeUniquePtr(allocator_, count_or_bytes, false, stream); +} +``` + +Pro: Minimal kernel code changes — just add `slot_id` parameter. Con: Overloads `GetScratchBuffer` semantics; non-owning vs owning pointer distinction is subtle. + +**Alternative B: Separate retrieval path** + +Keep `GetScratchBuffer()` unchanged for arena allocation. Add a new method: + +```cpp +// In OpKernelContext: +void* GetPreallocatedWorkspace(int slot_id) const; +// Returns nullptr if not pre-planned → kernel must call GetScratchBuffer() instead + +// Kernel usage: +void* ws = context->GetPreallocatedWorkspace(0); +if (!ws) { + scratch_buffer_ = GetScratchBuffer(workspace_size, stream); + ws = scratch_buffer_.get(); +} +``` + +Pro: Clear separation, no ambiguity about ownership. Con: Kernels need explicit fallback logic (but this is a one-time pattern per kernel). + +**Compatibility with dynamic shapes:** Both alternatives are opt-in. If `DeclareWorkspaceRequirements()` is not overridden or returns empty (dynamic shapes), everything falls back to `GetScratchBuffer()` → arena, exactly as today. Same kernel binary works for both static and dynamic models. + +**Incremental adoption:** Start with the highest-impact ops (attention, convolution, GEMM) which account for the majority of workspace. Less common ops continue using the arena with a reduced safety multiplier in `IResourceAccountant`. + +**Buffer strategy:** Workspace offsets can share the activation buffer (liveness doesn't overlap — workspace is live only during its step, activations may span steps). Alternatively, a separate workspace buffer is simpler initially and easier to account for in memory limits. + +##### EP Plugin C ABI Surface for Workspace Pre-declaration + +In the plugin architecture, `DeclareWorkspaceRequirements` crosses the C ABI boundary. This section defines the concrete API additions. + +**Declaration side — new optional function pointer on `OrtKernelImpl`:** + +```c +// Added to OrtKernelImpl (optional, like PrePackWeight): +ORT_API2_STATUS(DeclareWorkspaceRequirements, + _In_ OrtKernelImpl* this_ptr, + _In_reads_(num_inputs) const int64_t* const* input_shapes, // shape per input + _In_reads_(num_inputs) const size_t* input_shape_ranks, // rank per input + _In_ size_t num_inputs, + _Out_writes_all_(max_slots) OrtWorkspaceSlot* slots, // pre-allocated by ORT + _In_ size_t max_slots, // capacity (e.g., 8) + _Out_ size_t* num_slots); // actual count filled + +// Slot descriptor (C struct, no inheritance): +typedef struct OrtWorkspaceSlot { + int slot_id; // Kernel-defined, stable identifier (0, 1, 2, ...) + size_t size_bytes; // Required size for this slot +} OrtWorkspaceSlot; +``` + +If `DeclareWorkspaceRequirements` is NULL on the `OrtKernelImpl`, ORT skips the kernel during workspace planning (falls back to arena at runtime). + +**Retrieval side — new function in `OrtEpApi`:** + +```c +// Added to OrtEpApi (called by plugin kernels during Compute): +ORT_API2_STATUS(KernelContext_GetPreallocatedWorkspace, + _In_ const OrtKernelContext* context, + _In_ int slot_id, + _Outptr_result_maybenull_ void** buffer); // NULL if not pre-planned +``` + +Returns a pointer into the pre-allocated workspace buffer at the offset computed during planning. Returns NULL if no workspace was pre-planned for this kernel+slot (dynamic shapes, or kernel didn't declare). The pointer is valid for the duration of the `Compute()` call. + +**Slot ID provisioning — how kernels define unique slot_ids:** + +Slot IDs are **kernel-author-defined constants**, not dynamically allocated. Each kernel class defines its slots as an enum or set of constants in its implementation: + +```cpp +// Example: CUDA Attention kernel (inside the plugin DLL) +namespace cuda { +class AttentionKernel : public OrtKernelImplBase { + // Slot IDs are private constants — stable across versions, used as array indices + static constexpr size_t kSlotQTranspose = 0; + static constexpr size_t kSlotKTranspose = 1; + static constexpr size_t kSlotVTranspose = 2; + static constexpr size_t kSlotSoftmaxWorkspace = 3; + static constexpr size_t kNumSlots = 4; + + OrtStatus* DeclareWorkspaceRequirements(...) override { + slots[kSlotQTranspose] = {kSlotQTranspose, batch * heads * seq * head_dim * sizeof(half)}; + slots[kSlotKTranspose] = {kSlotKTranspose, batch * heads * seq * head_dim * sizeof(half)}; + slots[kSlotVTranspose] = {kSlotVTranspose, batch * heads * seq * head_dim * sizeof(half)}; + slots[kSlotSoftmaxWorkspace] = {kSlotSoftmaxWorkspace, cudnn_workspace_size}; + *num_slots = kNumSlots; + return nullptr; + } + + OrtStatus* Compute(OrtKernelContext* ctx) override { + void* q_buf = nullptr; + // Uses pre-planned workspace if available, falls back to arena otherwise + api_->KernelContext_GetScratchBuffer(ctx, kSlotQTranspose, q_transpose_size, &q_buf); + // ... use q_buf ... + } +}; +} // namespace cuda +``` + +**Key design properties:** + +| Property | Design Choice | Rationale | +|----------|--------------|-----------| +| Slot ID scope | Per kernel *instance* (node) | Same kernel class on different nodes gets separate buffers; ORT disambiguates via `(NodeIndex, slot_id)` | +| Slot ID assignment | Static constants in kernel code | No registry, no runtime allocation, no cross-kernel coordination needed | +| Slot ID range | `[0, max_slots)` — small integers | Simple array indexing in the offset plan; `max_slots` = 8 is generous for any single kernel | +| Uniqueness guarantee | Kernel author's responsibility | Same convention as `input_index` in `PrePackWeight` — the kernel knows its own buffer layout | +| Stability across versions | Expected (like enum values) | Slot IDs are internal to the kernel; not exposed to users or other kernels | + +**Where state lives:** + +| State | Location | Lifetime | +|-------|----------|----------| +| Slot definitions (id + size) | Returned by `DeclareWorkspaceRequirements` → stored in `ExecutionPlan` | Session lifetime (computed once at `Initialize()`) | +| Offset map `{(NodeIndex, slot_id) → offset}` | `SessionState::workspace_pattern_` (new field, analogous to `mem_patterns_`) | Session lifetime (shared, read-only) | +| Peak workspace size per EP/device | `SessionState::workspace_pattern_` | Session lifetime | +| Actual workspace buffer | `ExecutionFrame` (allocated per-`Run()` via `Reserve()`) | Single `Run()` invocation | + +**No global slot registry needed.** Unlike input indices which are defined by the ONNX op schema, slot IDs are entirely internal to the kernel implementation. Two different kernel classes can both use `slot_id=0` without conflict — the framework always qualifies with `NodeIndex`. This means: +- No coordination between kernel authors +- No registration step during plugin initialization +- No versioning concerns (IDs never cross the plugin boundary as semantic values) + +#### Phase B: Eliminate Arena for Static-Shape Models + +Once workspace is pre-declared, **all** allocations for a static-shape model are known at `Initialize()` time: +- Initializers → already `Reserve()` (done) +- Activations → already memory-pattern `Reserve()` (done) +- Workspace → new, via Phase A + +At this point, the BFC arena serves no purpose for the main execution path. The session could: +1. Pre-allocate exact memory per device (sum of pattern peak + workspace peak) +2. Use offset-based addressing for all buffers +3. Disable the arena entirely for this session (save memory waste from chunk granularity) + +Runtime temp buffers from ops that don't implement `DeclareWorkspaceRequirements()` can still fall back to a small arena. + +### Custom Executable: Purpose and Scope + +A minimal custom executable (CLI tool) serves three purposes: + +1. **Code example and test bed.** Demonstrates how to configure and exercise the constrained-environment features (name-based partitioning, memory budgets, static allocation mode) end-to-end using the ORT C/C++ API. Acts as a living integration test that exercises the full pipeline without depending on GenAI or external frameworks. + +2. **Interactive LLM demo (llama.cpp-style UX).** Loads a transformer ONNX model, manages the decode loop (prompt → KV cache → token sampling → output), and interacts with the user via stdin/stdout. This showcases ORT's ability to run large models on constrained hardware with the same user experience as llama.cpp — but backed by ORT's general-purpose runtime. + +3. **Primitive GenAI replacement for testing.** For the narrow case of single-model, single-user, greedy/top-k text generation, the executable can replace GenAI as a simpler alternative that doesn't pull in the full GenAI dependency. It is **not** a production replacement for GenAI (no batching, no beam search, no speculative decoding) — it is a minimal harness for validating that the partitioning and memory features work correctly on real models. + +**What the executable handles (application-level):** +- Token encode/decode (via sentencepiece or tokenizers library) +- KV cache allocation and rotation (fixed max sequence length) +- Autoregressive decode loop (feed output token back as next input) +- Session configuration: name-based layer assignment, memory budget, static shapes + +**What ORT handles (session-level, no executable changes needed):** +- Graph partitioning across devices (Direction 1 + `IResourceAccountant`) +- Static memory pre-allocation (Phase A + B) +- Kernel execution, data transfers, stream synchronization + +**Feasibility: no fundamental ORT blockers.** The existing session API (`CreateSession` → `Run` with named I/O) is sufficient for an autoregressive decode loop. KV cache management is feeding output tensors back as inputs — the same pattern GenAI uses over the same C API. + +| Concern | Status | Notes | +|---------|--------|-------| +| Tokenizer | Reuse from ORT Extensions | See tokenizer strategy below | +| KV cache rotation | Straightforward | Pre-allocate `(batch, heads, max_seq, head_dim)`, feed `past_key_values` outputs back as inputs each step | +| Decode loop | Trivial | Run session → extract logits → sample token → repeat | +| Model format | Constraint | Requires decoder-style ONNX export with explicit KV cache I/O (HuggingFace optimum exports provide this) | +| Partitioning | This design | Direction 1 + `IResourceAccountant` | +| Static allocation | Phase A+B | Fixed `max_seq_len` makes all decode-phase shapes static | + +**Tokenizer strategy — borrowing from ORT Extensions:** + +[ORT Extensions](https://github.com/microsoft/onnxruntime-extensions) already implements production-quality tokenizers in C++: +- **BPE** (GPT-2, LLaMA-3, Phi, Mistral) — `onnxruntime_extensions/tokenizer/bpe_tokenizer.cc` +- **SentencePiece** (LLaMA-1/2, T5, mT5) — wraps the SentencePiece C++ library +- **WordPiece** (BERT, DistilBERT) — `onnxruntime_extensions/tokenizer/wordpiece_tokenizer.cc` + +For the custom executable, we can **extract the tokenizer C++ code directly** from ORT Extensions rather than taking a full dependency on the extensions DLL. The tokenizer logic is self-contained: it reads a vocabulary/merge file, applies the algorithm (BPE merge loop, SentencePiece unigram, or WordPiece greedy match), and produces token IDs. No ONNX graph execution is involved. + +**Practical approach:** +1. Copy the relevant tokenizer source files (BPE tokenizer is ~500 LOC + vocab loading) into the demo executable's source tree. +2. Strip the ORT Extensions custom-op registration wrapper — keep only the core `Encode(string) → vector` and `Decode(vector) → string` logic. +3. Load the tokenizer model file (e.g., `tokenizer.json` from HuggingFace, or `tokenizer.model` for SentencePiece) at startup alongside the ONNX model. + +This gives us a battle-tested tokenizer with no additional runtime dependency — just a few source files compiled into the executable. The code is already Apache-2.0 licensed (same as ORT). + +The executable would be ~500–1000 LOC (excluding tokenizer): configure session options, set up KV cache tensors, run the generate loop. With the borrowed tokenizer code, the total grows to ~1500–2000 LOC but remains self-contained with zero external dependencies beyond ORT itself. + +--- + +## Recommended Roadmap + +``` +Near-term (low effort, high value): +├── 1. Name-based matching via session.name_based_layer_assignment [DONE] +│ - Separate session option with substring matching against Node::Name() +│ - SubstringMatcher with longest-match-wins priority +│ - Mutually exclusive with annotation-based matching (setting both options returns INVALID_ARGUMENT) +│ +├── 2. Precise per-node memory estimation +│ - Static workspace estimation functions registered per kernel type +│ - IResourceAccountant uses exact output sizes + workspace estimates +│ - Eliminates 1.5x multiplier for kernels with estimation functions +│ +Mid-term (medium effort): +├── 3. Auto-partitioning with memory budget only +│ - User specifies "6GB GPU budget" +│ - ORT computes optimal layer split automatically +│ - Combines (1) + (2) +│ +├── 4. Static allocation mode +│ - Pre-allocate all buffers when shapes are known +│ - Eliminate per-Run() allocation overhead +│ +Long-term (high effort, ollama-parity): +├── 5. Layer prefetch pipeline +│ - Stream weights CPU↔GPU during execution +│ - Enables running models larger than GPU memory +│ +└── 6. Integration with GenAI + - KV cache-aware memory planning + - Continuous batching + layer offload coordination +``` + +--- + +## Key Insight + +The fundamental difference between ORT and llama.cpp for this use case is **generality vs specialization**. llama.cpp knows it's running a transformer with sequential layers. ORT handles arbitrary graphs. The trick is to **detect** when a model is transformer-like (sequential layers, static shapes) and engage a specialized execution path — without losing generality for other model types. + +Direction 1 (name-based matching) is the lowest-friction win: it makes the existing annotation system accessible without model modification. Direction 2 (static pre-allocation + auto-splitting) is what closes the gap with ollama but requires more infrastructure work, particularly around shape-aware memory planning at partition time. diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 0b6d009e072ad..eee100aeef8df 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -420,6 +420,21 @@ static const char* const kOrtSessionOptionsResourceCudaPartitioningSettings = /// static const char* const kOrtSessionOptionsLayerAssignmentSettings = "session.layer_assignment_settings"; +/// +/// Name-based layer assignment. Uses the same device(pattern1, pattern2, ...); ... grammar +/// as kOrtSessionOptionsLayerAssignmentSettings but performs SUBSTRING matching against +/// Node::Name() instead of prefix/exact matching against node metadata annotations. +/// The '=' prefix (exact match) from the annotation-based grammar is rejected with an error +/// — all patterns are treated as substrings. +/// Longest matching pattern wins when multiple patterns match the same node name. +/// No subgraph inheritance is applied — each node is matched independently by its name. +/// +/// MUTUALLY EXCLUSIVE with kOrtSessionOptionsLayerAssignmentSettings. Setting both returns +/// INVALID_ARGUMENT. Use annotation-based matching for models with explicit layer annotations, +/// or name-based matching for models with structured node names (HuggingFace, PyTorch exports). +/// +static const char* const kOrtSessionOptionsNameBasedLayerAssignment = "session.name_based_layer_assignment"; + // Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) diff --git a/onnxruntime/core/framework/layering_annotations.cc b/onnxruntime/core/framework/layering_annotations.cc index b3a41d714137f..f4dddecd207b8 100644 --- a/onnxruntime/core/framework/layering_annotations.cc +++ b/onnxruntime/core/framework/layering_annotations.cc @@ -13,6 +13,7 @@ #include "core/framework/execution_providers.h" #include "core/graph/graph.h" +#include #include namespace onnxruntime { @@ -335,28 +336,62 @@ LayeringIndex LayeringIndex::Create(const Graph& graph, EpNameToLayeringIndices ep_map, LayeringIndexToEpName rule_map, LayeringRules layering_rules) { - // 1. Create LayeringIndex instance with pre-computed maps LayeringIndex index(std::move(layering_rules), std::move(ep_map), std::move(rule_map)); - - // 2. Traverse the graph and index nodes index.ProcessGraph(graph, std::nullopt); + return index; +} +LayeringIndex LayeringIndex::Create(const Graph& graph, + EpNameToLayeringIndices ep_map, + LayeringIndexToEpName rule_map, + LayeringRules layering_rules, + SubstringMatcher substring_matcher) { + LayeringIndex index(std::move(layering_rules), std::move(ep_map), std::move(rule_map), + std::move(substring_matcher)); + index.ProcessGraph(graph, std::nullopt); return index; } Status LayeringIndex::Create(const Graph& graph, const std::string& config_string, + const std::string& name_based_config_string, gsl::span ep_devices, const ExecutionProviders& ep_providers, const logging::Logger& logger, std::optional& layering_index) { - LayeringRules rules; - ORT_RETURN_IF_ERROR(LayeringRules::FromConfigString(config_string, rules)); + // Annotation-based and name-based layer assignment are mutually exclusive. + if (!config_string.empty() && !name_based_config_string.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Cannot set both 'session.layer_assignment_settings' and " + "'session.name_based_layer_assignment'. These options are mutually exclusive. " + "Use annotation-based matching for models with explicit layer annotations, " + "or name-based matching for models with structured node names."); + } - LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " layering rules from config."; + const bool is_name_based = !name_based_config_string.empty(); + const std::string& active_config = is_name_based ? name_based_config_string : config_string; + + LayeringRules rules; + if (!active_config.empty()) { + ORT_RETURN_IF_ERROR(LayeringRules::FromConfigString(active_config, rules)); + + if (is_name_based) { + // Reject '=' (exact-match qualifier) in name-based rules — all patterns must be substrings + for (const auto& rule : rules.rules) { + if (!rule.prefix_match) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Name-based layer assignment does not support the '=' (exact-match) qualifier. " + "All patterns are treated as substrings. Remove the '=' prefix from pattern: '", + rule.annotation, "'"); + } + } + LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " name-based layering rules from config."; + } else { + LOGS(logger, INFO) << "Parsed " << rules.rules.size() << " annotation-based layering rules from config."; + } + } if (rules.rules.empty()) { - // Return no index indicating no layering layering_index.reset(); return Status::OK(); } @@ -384,9 +419,6 @@ Status LayeringIndex::Create(const Graph& graph, if (matched_ep) { const std::string& ep_type = *matched_ep; ep_map[ep_type].insert(i); - // Ensure 1:1 mapping from rule index to EP type - // Note: A rule index refers to a unique entry in LayeringRules::rules vector. - // So 'i' is unique. rule_map[i] = ep_type; matched_rule_count++; LOGS(logger, VERBOSE) << "Layering Rule " << i << " (" << rule.device << " -> " << rule.annotation @@ -402,7 +434,17 @@ Status LayeringIndex::Create(const Graph& graph, LOGS(logger, INFO) << "LayeringIndex created. Matched " << matched_rule_count << " out of " << rules.rules.size() << " rules to available Execution Providers."; - layering_index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(rules)); + // Build SubstringMatcher for name-based mode + std::optional substring_matcher; + if (is_name_based) { + substring_matcher.emplace(rules); + } + + // Create LayeringIndex — annotation mode uses matcher_ only, name-based uses substring_matcher_ only + LayeringIndex index(std::move(rules), std::move(ep_map), std::move(rule_map), + std::move(substring_matcher)); + index.ProcessGraph(graph, std::nullopt); + layering_index = std::move(index); return Status::OK(); } @@ -423,16 +465,23 @@ void LayeringIndex::ProcessGraph(const Graph& graph, std::optional paren for (auto& node : graph.Nodes()) { std::optional matched_rule_idx = std::nullopt; - // 4. For every node query its annotation - const std::string& annotation = node.GetLayeringAnnotation(); - if (!annotation.empty()) { - // If it has an annotation try to match it - matched_rule_idx = matcher_.Match(annotation); - } + if (substring_matcher_) { + // Name-based mode: substring matching against node name, no inheritance. + // Node names are dense (virtually every node has one), so inheritance is + // unnecessary — each node is matched independently by its own name. + matched_rule_idx = substring_matcher_->Match(node.Name()); + } else { + // Annotation-based mode: prefix/exact match against metadata annotation, + // with subgraph inheritance for unannotated nodes. + const std::string& annotation = node.GetLayeringAnnotation(); + if (!annotation.empty()) { + matched_rule_idx = matcher_.Match(annotation); + } - // 5. If node has no annotation, inherit from subgraph parent node - if (!matched_rule_idx && parent_layer_id) { - matched_rule_idx = parent_layer_id; + // Inherit from subgraph parent node if no annotation match + if (!matched_rule_idx && parent_layer_id) { + matched_rule_idx = parent_layer_id; + } } // Record assignment if we have a match @@ -485,28 +534,36 @@ void LayeringIndex::Update(const Graph& graph, gsl::span nodes) continue; } - const std::string& annotation = node->GetLayeringAnnotation(); - if (!annotation.empty()) { - auto matched_rule_idx = matcher_.Match(annotation); - - if (matched_rule_idx) { - const size_t rule_idx = *matched_rule_idx; - - // Only assign if this rule maps to a valid EP in our configuration - if (layering_index_to_ep_name_.count(rule_idx)) { - // Check if already assigned to a DIFFERENT rule, if so clean up old mapping - auto prev_assign = current_graph_index.node_to_layering_index_.find(node_index); - if (prev_assign != current_graph_index.node_to_layering_index_.end()) { - size_t old_rule = prev_assign->second; - if (old_rule != rule_idx) { - current_graph_index.layer_to_node_ids_[old_rule].erase(node_index); - } - } + std::optional matched_rule_idx; - ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node_index, rule_idx)); - ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node_index)); - was_updated = true; + if (substring_matcher_) { + // Name-based mode: substring match against node name + matched_rule_idx = substring_matcher_->Match(node->Name()); + } else { + // Annotation-based mode: prefix/exact match against metadata + const std::string& annotation = node->GetLayeringAnnotation(); + if (!annotation.empty()) { + matched_rule_idx = matcher_.Match(annotation); + } + } + + if (matched_rule_idx) { + const size_t rule_idx = *matched_rule_idx; + + // Only assign if this rule maps to a valid EP in our configuration + if (layering_index_to_ep_name_.count(rule_idx)) { + // Check if already assigned to a DIFFERENT rule, if so clean up old mapping + auto prev_assign = current_graph_index.node_to_layering_index_.find(node_index); + if (prev_assign != current_graph_index.node_to_layering_index_.end()) { + size_t old_rule = prev_assign->second; + if (old_rule != rule_idx) { + current_graph_index.layer_to_node_ids_[old_rule].erase(node_index); + } } + + ORT_IGNORE_RETURN_VALUE(current_graph_index.node_to_layering_index_.insert_or_assign(node_index, rule_idx)); + ORT_IGNORE_RETURN_VALUE(current_graph_index.layer_to_node_ids_[rule_idx].insert(node_index)); + was_updated = true; } } } @@ -544,6 +601,30 @@ void LayeringRuleMatcher::UpdateBestMatch(std::optional& current_best, s } } +SubstringMatcher::SubstringMatcher(const LayeringRules& rules) { + for (size_t i = 0; i < rules.rules.size(); ++i) { + const auto& rule = rules.rules[i]; + if (!rule.annotation.empty()) { + patterns_.push_back({rule.annotation, i}); + } + } + // Sort by pattern length descending (longest first). + // Stable sort preserves config order as tiebreaker for same-length patterns. + std::stable_sort(patterns_.begin(), patterns_.end(), + [](const PatternEntry& a, const PatternEntry& b) { + return a.pattern.size() > b.pattern.size(); + }); +} + +std::optional SubstringMatcher::Match(std::string_view node_name) const { + for (const auto& entry : patterns_) { + if (node_name.find(entry.pattern) != std::string_view::npos) { + return entry.rule_index; + } + } + return std::nullopt; +} + std::optional>> LayeringIndex::GetLayeringRulesForThisEp(const std::string& ep_type) const { auto hit = ep_name_to_layering_indices_.find(ep_type); diff --git a/onnxruntime/core/framework/layering_annotations.h b/onnxruntime/core/framework/layering_annotations.h index 5d58e9ace2471..4114527e07d01 100644 --- a/onnxruntime/core/framework/layering_annotations.h +++ b/onnxruntime/core/framework/layering_annotations.h @@ -11,6 +11,7 @@ #include "core/common/logging/logging.h" #include "gsl/gsl" #include +#include #include #include #include @@ -57,6 +58,7 @@ struct LayeringRules { /// class LayeringRuleMatcher { public: + /// The annotation-based layering rules to index. explicit LayeringRuleMatcher(const LayeringRules& rules); /// @@ -83,6 +85,35 @@ class LayeringRuleMatcher { void UpdateBestMatch(std::optional& current_best, size_t candidate) const; }; +/// +/// Performs substring matching against node names. Unlike LayeringRuleMatcher (which does +/// prefix/exact matching from position 0), this matches patterns appearing anywhere in the +/// input string. Longest matching pattern wins. +/// +class SubstringMatcher { + public: + /// The rules whose annotations become substring patterns. + /// The '=' prefix (exact match) qualifier is rejected during config parsing — all patterns + /// must be substrings. + explicit SubstringMatcher(const LayeringRules& rules); + + /// + /// Returns the index of the best matching rule for the given node name. + /// "Best" = longest pattern that appears as a substring in the name. + /// + /// the node's name to match against + /// index of the matching LayeringRule if a substring match is found + std::optional Match(std::string_view node_name) const; + + private: + struct PatternEntry { + std::string pattern; + size_t rule_index; + }; + // Sorted by pattern length descending. First match wins (longest-match priority). + InlinedVector patterns_; +}; + namespace EpLayeringMatcher { /// /// Matches a list of available OrtEpDevices against the device string specified in the LayerAnnotation. @@ -125,12 +156,25 @@ class LayeringIndex { LayeringIndexToEpName rule_map, LayeringRules layering_rules); + /// + /// Creates a fully initialized LayeringIndex with a SubstringMatcher for name-based matching. + /// In this mode, annotation matching is disabled and no subgraph inheritance is applied. + /// + static LayeringIndex Create(const Graph& graph, + EpNameToLayeringIndices ep_map, + LayeringIndexToEpName rule_map, + LayeringRules layering_rules, + SubstringMatcher substring_matcher); + /// /// Factory method that creates a LayeringIndex by parsing configuration, matching rules against /// available devices/providers, and indexing the graph. + /// Annotation-based and name-based options are mutually exclusive — setting both returns an error. /// /// The graph to index. - /// The configuration string containing layering rules. + /// The annotation-based configuration string (prefix/exact match on metadata). + /// The name-based configuration string (substring match on Node::Name()). + /// May be empty if name-based matching is not configured. /// Available OrtEpDevices to match rules against. /// Available ExecutionProviders to match rules against (fallback). /// Logger for reporting information/errors. @@ -139,6 +183,7 @@ class LayeringIndex { /// Status indicating success or failure. static Status Create(const Graph& graph, const std::string& config_string, + const std::string& name_based_config_string, gsl::span ep_devices, const ExecutionProviders& ep_providers, const logging::Logger& logger, @@ -192,11 +237,17 @@ class LayeringIndex { LayerIndexToNodes layer_to_node_ids_; }; - LayeringIndex(LayeringRules layering_rules, EpNameToLayeringIndices ep_name_to_layering_indices, LayeringIndexToEpName layering_index_to_ep_name) + LayeringIndex(LayeringRules layering_rules, EpNameToLayeringIndices ep_name_to_layering_indices, + LayeringIndexToEpName layering_index_to_ep_name, + std::optional substring_matcher = std::nullopt) : rules_(std::move(layering_rules)), matcher_(rules_), ep_name_to_layering_indices_(std::move(ep_name_to_layering_indices)), - layering_index_to_ep_name_(std::move(layering_index_to_ep_name)) {} + layering_index_to_ep_name_(std::move(layering_index_to_ep_name)), + substring_matcher_(std::move(substring_matcher)) {} + + // Optional substring matcher for name-based layer assignment + std::optional substring_matcher_; // Graph and sub-graphs mapping to their indices InlinedHashMap graph_index_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 9db18fc4ac69e..9d48e98a8bc16 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1652,8 +1652,9 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::optional layering_index_storage; const auto layering_config = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsLayerAssignmentSettings, ""); - if (!layering_config.empty()) { - ORT_RETURN_IF_ERROR_SESSIONID_(LayeringIndex::Create(graph, layering_config, {}, execution_providers_, + const auto name_based_config = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsNameBasedLayerAssignment, ""); + if (!layering_config.empty() || !name_based_config.empty()) { + ORT_RETURN_IF_ERROR_SESSIONID_(LayeringIndex::Create(graph, layering_config, name_based_config, {}, execution_providers_, *session_logger_, layering_index_storage)); if (layering_index_storage) { layering_index = &layering_index_storage.value(); diff --git a/onnxruntime/test/framework/layering_annotations_test.cc b/onnxruntime/test/framework/layering_annotations_test.cc index 7c7bd6a230a9d..a698b2c16abb1 100644 --- a/onnxruntime/test/framework/layering_annotations_test.cc +++ b/onnxruntime/test/framework/layering_annotations_test.cc @@ -12,6 +12,7 @@ #include "core/graph/constants.h" #include "core/graph/model.h" // For Model, Graph #include "gtest/gtest.h" +#include "gmock/gmock.h" #include "test/util/include/asserts.h" #include "test/util/include/test_environment.h" @@ -1786,7 +1787,263 @@ TEST(LayeringIndexPartitionerTest, MultipleRulesForSameEp) { EXPECT_TRUE(found[3]); // node3 - unassigned } +// ===================== SubstringMatcher Tests ===================== + +TEST(SubstringMatcherTest, BasicSubstringMatch) { + LayeringRules rules; + rules.rules.push_back({"gpu", "layers.0/", true}); // Index 0 + rules.rules.push_back({"gpu", "layers.1/", true}); // Index 1 + rules.rules.push_back({"cpu", "embed_tokens", true}); // Index 2 + + SubstringMatcher matcher(rules); + + // Match in the middle of a node name + { + auto result = matcher.Match("/model/layers.0/self_attn/q_proj/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + { + auto result = matcher.Match("/model/layers.1/mlp/gate_proj/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); + } + { + auto result = matcher.Match("/model/embed_tokens/Gather"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 2u); + } + // No match + { + auto result = matcher.Match("/model/norm/LayerNormalization"); + EXPECT_FALSE(result.has_value()); + } +} + +TEST(SubstringMatcherTest, LongestMatchWins) { + LayeringRules rules; + // Shorter pattern first in config order + rules.rules.push_back({"cpu", "layers.1", true}); // Index 0 — would match layers.10, layers.11, etc. + rules.rules.push_back({"gpu", "layers.10", true}); // Index 1 — longer, should win for layers.10 + + SubstringMatcher matcher(rules); + + // "layers.10" is longer — should match even though "layers.1" also appears as a substring + { + auto result = matcher.Match("/model/layers.10/self_attn/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); // layers.10 wins (longer) + } + // "layers.1/" — only "layers.1" matches (layers.10 does not appear here) + { + auto result = matcher.Match("/model/layers.1/self_attn/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } +} + +TEST(SubstringMatcherTest, TrailingSlashDisambiguates) { + LayeringRules rules; + rules.rules.push_back({"gpu", "layers.1/", true}); // Index 0 — won't match layers.10/ + rules.rules.push_back({"cpu", "layers.10/", true}); // Index 1 + + SubstringMatcher matcher(rules); + + { + auto result = matcher.Match("/model/layers.1/self_attn/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + { + auto result = matcher.Match("/model/layers.10/self_attn/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); // "layers.10/" matches, not "layers.1/" + } +} + +TEST(SubstringMatcherTest, BasicRuleIndices) { + LayeringRules rules; + rules.rules.push_back({"gpu", "layers.0/", true}); // Index 0 + rules.rules.push_back({"cpu", "embed", true}); // Index 1 + + SubstringMatcher matcher(rules); + + { + auto result = matcher.Match("/model/layers.0/MatMul"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 0u); + } + { + auto result = matcher.Match("/model/embed_tokens/Gather"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, 1u); + } +} + +TEST(SubstringMatcherTest, EmptyNameNoMatch) { + LayeringRules rules; + rules.rules.push_back({"gpu", "layers.0/", true}); + + SubstringMatcher matcher(rules); + + auto result = matcher.Match(""); + EXPECT_FALSE(result.has_value()); +} + +TEST(LayeringIndexTest, NameBasedMatchingIntegration) { + // Test that LayeringIndex uses SubstringMatcher for unannotated nodes when configured + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + // Create nodes with transformer-style names (no annotations) + NodeArg* input = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* mid1 = &graph.GetOrCreateNodeArg("mid1", &type_proto); + NodeArg* mid2 = &graph.GetOrCreateNodeArg("mid2", &type_proto); + NodeArg* output = &graph.GetOrCreateNodeArg("output", &type_proto); + + Node& node0 = graph.AddNode("/model/layers.0/self_attn/MatMul", "Abs", "", {input}, {mid1}); + Node& node1 = graph.AddNode("/model/layers.1/mlp/MatMul", "Abs", "", {mid1}, {mid2}); + Node& node2 = graph.AddNode("/model/embed_tokens/Gather", "Abs", "", {mid2}, {output}); + + ASSERT_STATUS_OK(graph.Resolve()); + + // Build name-based rules for layers.0 -> gpu, layers.1 -> cpu + LayeringRules name_rules; + name_rules.rules.push_back({"gpu", "layers.0/", true}); // merged index 0 + name_rules.rules.push_back({"cpu", "layers.1/", true}); // merged index 1 + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["GpuEP"].insert(0); + ep_map["CpuEP"].insert(1); + + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "GpuEP"; + rule_map[1] = "CpuEP"; + + // Create index with SubstringMatcher + SubstringMatcher substring_matcher(name_rules); + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), std::move(name_rules), + std::move(substring_matcher)); + + // Node with "layers.0/" in name → rule 0 + auto assign0 = index.GetNodeAssignment(graph, node0.Index()); + ASSERT_TRUE(assign0.has_value()); + EXPECT_EQ(*assign0, 0u); + + // Node with "layers.1/" in name → rule 1 + auto assign1 = index.GetNodeAssignment(graph, node1.Index()); + ASSERT_TRUE(assign1.has_value()); + EXPECT_EQ(*assign1, 1u); + + // Node with "embed_tokens" — no matching rule → unassigned + auto assign2 = index.GetNodeAssignment(graph, node2.Index()); + EXPECT_FALSE(assign2.has_value()); +} + +TEST(LayeringIndexTest, MutualExclusivityRejectsBothConfigs) { + // Setting both annotation-based and name-based configs must return INVALID_ARGUMENT. + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + NodeArg* input = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* output = &graph.GetOrCreateNodeArg("output", &type_proto); + graph.AddNode("node0", "Abs", "", {input}, {output}); + ASSERT_STATUS_OK(graph.Resolve()); + + ExecutionProviders providers; + std::optional layering_index; + + // Both configs set — should fail + auto status = LayeringIndex::Create(graph, + /*config_string=*/"cpu(layer1)", + /*name_based_config_string=*/"gpu(layers.0/)", + {}, providers, + DefaultLoggingManager().DefaultLogger(), + layering_index); + ASSERT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), ::testing::HasSubstr("mutually exclusive")); +} + +TEST(LayeringIndexTest, UpdateAppliesSubstringMatcherToNewNodes) { + // Verifies that LayeringIndex::Update() applies substring_matcher_ fallback + // for new nodes that have no annotation but whose Name() matches a name-based pattern. + // This covers the post-layout-transform incremental update path. + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 12; + Model model("test_model", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + // Initial graph with one node assigned via name-based matching + NodeArg* input = &graph.GetOrCreateNodeArg("input", &type_proto); + NodeArg* mid = &graph.GetOrCreateNodeArg("mid", &type_proto); + NodeArg* output = &graph.GetOrCreateNodeArg("output", &type_proto); + + graph.AddNode("/model/layers.0/self_attn/MatMul", "Abs", "", {input}, {mid}); + graph.AddNode("/model/norm/LayerNorm", "Abs", "", {mid}, {output}); + + ASSERT_STATUS_OK(graph.Resolve()); + + // Name-based rules: layers.0/ -> gpu (index 0) + LayeringRules name_rules; + name_rules.rules.push_back({"gpu", "layers.0/", true}); + + LayeringIndex::EpNameToLayeringIndices ep_map; + ep_map["GpuEP"].insert(0); + LayeringIndex::LayeringIndexToEpName rule_map; + rule_map[0] = "GpuEP"; + + SubstringMatcher substring_matcher(name_rules); + auto index = LayeringIndex::Create(graph, std::move(ep_map), std::move(rule_map), + std::move(name_rules), std::move(substring_matcher)); + + // Simulate layout transform adding new nodes with structured names + NodeArg* new_out1 = &graph.GetOrCreateNodeArg("new_out1", &type_proto); + NodeArg* new_out2 = &graph.GetOrCreateNodeArg("new_out2", &type_proto); + + // New node whose name matches "layers.0/" pattern + Node& new_matching = graph.AddNode("/model/layers.0/self_attn/Transpose", "Abs", "", {output}, {new_out1}); + + // New node whose name does NOT match any pattern + Node& new_unmatched = graph.AddNode("/model/lm_head/MatMul", "Abs", "", {new_out1}, {new_out2}); + + ASSERT_STATUS_OK(graph.Resolve()); + + // Call Update() with the new nodes (the incremental path) + std::vector new_nodes = {new_matching.Index(), new_unmatched.Index()}; + index.Update(graph, new_nodes); + + // new_matching should be assigned via substring match + auto assign_match = index.GetNodeAssignment(graph, new_matching.Index()); + ASSERT_TRUE(assign_match.has_value()); + EXPECT_EQ(*assign_match, 0u); + + // new_unmatched should remain unassigned + auto assign_no = index.GetNodeAssignment(graph, new_unmatched.Index()); + EXPECT_FALSE(assign_no.has_value()); +} + } // namespace test } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) \ No newline at end of file +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 50e801525f237..35d698cfef726 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -497,7 +497,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, LayeringIndex* layering_index = nullptr; std::optional layering_index_storage; if (!layering_config.empty()) { - ASSERT_STATUS_OK(LayeringIndex::Create(graph, layering_config, {}, execution_providers, + ASSERT_STATUS_OK(LayeringIndex::Create(graph, layering_config, /*name_based_config_string=*/"", {}, execution_providers, default_logger, layering_index_storage)); if (layering_index_storage.has_value()) { layering_index = &layering_index_storage.value();