diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 6fb895cd1800c..21229c4a4f2de 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -13,10 +13,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS # Standalone model package library (parsing/inspection with no ORT dependency). # Compiled as a static library and linked into onnxruntime_session. -# NOTE: ORT intentionally uses the library's internal C++ types directly (model_package::ParsePackage, -# model_package_internal.h) rather than going through its public C API (ModelPackage_*). This avoids -# double-wrapping (ORT C API -> standalone C API -> C++ internals). The public C API exists for -# external consumers (GenAI, FL) who link against the standalone library independently. +# ORT uses the standalone library's public C API (model_package.h) and translates +# the C handles into ORT-internal C++ types inside core/session/model_package/. set(MODEL_PACKAGE_LIB_DIR "${REPO_ROOT}/model_package") if(NOT onnxruntime_MINIMAL_BUILD) set(MODEL_PACKAGE_BUILD_SHARED OFF CACHE BOOL "" FORCE) @@ -59,7 +57,7 @@ onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxrun target_link_libraries(onnxruntime_session PRIVATE onnxruntime_lora) if(TARGET model_package) target_link_libraries(onnxruntime_session PRIVATE model_package) - target_include_directories(onnxruntime_session PRIVATE ${MODEL_PACKAGE_LIB_DIR}/include ${MODEL_PACKAGE_LIB_DIR}/src) + target_include_directories(onnxruntime_session PRIVATE ${MODEL_PACKAGE_LIB_DIR}/include) endif() if(onnxruntime_ENABLE_INSTRUMENT) target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT) diff --git a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc index 57a4e472b6f6d..aae1dd2f8e401 100644 --- a/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc +++ 
b/include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc @@ -66,6 +66,7 @@ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtApi_ExperimentalApiTest, _Out_ int64_t // - OrtModelPackageApi_ModelPackage_GetVariantCount // - OrtModelPackageApi_ModelPackage_GetVariantNames // - OrtModelPackageApi_ModelPackage_GetVariantEpName +// - OrtModelPackageApi_ModelPackage_ResolveStringRef // 4) Select a component and resolve variant: // - OrtModelPackageApi_SelectComponent // 5) Query selected variant info (optional): @@ -212,6 +213,33 @@ ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_GetVarian _In_ const char* variant_name, _Outptr_result_maybenull_ const char** out_ep) +/** \brief Resolve a path reference declared inside the package against the model_package rules. + * + * Handles the path forms a package may use: + * - "sha256:<hex>" or "sha256:<hex>/<tail>": a content-addressed shared asset. Resolves to + * the asset's on-disk directory (honoring manifest shared_assets overrides), optionally + * joined with the confined tail. Errors if the asset is not declared/discovered. + * - any other value: a relative path resolved against `base_dir` (or the package root when + * `base_dir` is NULL), with portable-layout confinement (no absolute paths, no ".."). + * + * When `must_exist` is non-zero the resolved path must exist on disk. `out_path` is owned by + * `ctx` and remains valid until the next call to this function on the same context. + * + * \param[in] ctx The package context returned by OrtModelPackageApi_CreateModelPackageContext. + * \param[in] base_dir Base directory for relative inputs. May be NULL to use the package root. + * \param[in] input The path reference to resolve. + * \param[in] must_exist Non-zero to require that the resolved path exists on disk. + * \param[out] out_path Receives the resolved path string. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ +ORT_EXPERIMENTAL_API(28, OrtStatusPtr, OrtModelPackageApi_ModelPackage_ResolveStringRef, + _In_ const OrtModelPackageContext* ctx, + _In_opt_ const char* base_dir, + _In_ const char* input, + _In_ int must_exist, + _Outptr_ const char** out_path) + /** \brief Select a component model and return an opaque component instance. * * The variant selection is also performed during this call based on the component diff --git a/model_package/CMakeLists.txt b/model_package/CMakeLists.txt index 326a1e541696a..4b296cdca96a0 100644 --- a/model_package/CMakeLists.txt +++ b/model_package/CMakeLists.txt @@ -52,8 +52,13 @@ endif() # ───────────────────────────────────────────────────────────────────────────── set(MODEL_PACKAGE_SOURCES - src/api.cc - src/parser.cc + src/asset_hasher.cc + src/authoring.cc + src/commit_prune_validate.cc + src/manifest_parser.cc + src/model_package_impl.cc + src/path_resolver.cc + src/sha256.cc ) if(MODEL_PACKAGE_BUILD_SHARED) @@ -91,6 +96,45 @@ install(TARGETS model_package RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) -install(FILES include/model_package_api.h +install(FILES include/model_package_api.h include/model_package.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} ) + +# ───────────────────────────────────────────────────────────────────────────── +# Tests +# ───────────────────────────────────────────────────────────────────────────── + +if(MODEL_PACKAGE_BUILD_TESTS) + enable_testing() + add_executable(test_inspection tests/test_inspection.cc) + target_link_libraries(test_inspection PRIVATE model_package nlohmann_json::nlohmann_json) + target_include_directories(test_inspection PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ) + add_test(NAME inspection COMMAND test_inspection) + + add_executable(test_asset_hashing tests/test_asset_hashing.cc) + target_link_libraries(test_asset_hashing PRIVATE model_package nlohmann_json::nlohmann_json) + 
target_include_directories(test_asset_hashing PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ) + add_test(NAME asset_hashing COMMAND test_asset_hashing) + + add_executable(test_authoring tests/test_authoring.cc) + target_link_libraries(test_authoring PRIVATE model_package nlohmann_json::nlohmann_json) + target_include_directories(test_authoring PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ) + add_test(NAME authoring COMMAND test_authoring) + + add_executable(test_commit tests/test_commit.cc) + target_link_libraries(test_commit PRIVATE model_package nlohmann_json::nlohmann_json) + target_include_directories(test_commit PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ) + add_test(NAME commit COMMAND test_commit) +endif() diff --git a/model_package/README.md b/model_package/README.md index 604720916d764..dcbb9252c71ed 100644 --- a/model_package/README.md +++ b/model_package/README.md @@ -1,78 +1,625 @@ # Model Package Library -A standalone C library for parsing and inspecting ONNX Runtime Model Packages. +A standalone C library for **reading, authoring, validating, and committing** +ONNX Runtime model packages. The library has no dependency on ONNX Runtime +itself, so any consumer (ORT, publisher tools, ...) can compile it in +without dragging in a session runtime. It is distributed and consumed as +**source** (see [Versioning and compatibility](#versioning-and-compatibility)). -**No dependency on ONNX Runtime.** This library can be consumed independently by any component (ORT, GenAI, FL, or external tools). +The library owns three things: -## What it does +1. The **on-disk layout** of a model package (directory + manifest + shared + assets). +2. The **schema** of `manifest.json` and `component.json`, including the + `executor_info` extension point. +3. 
The **resolution rules** for paths and content-addressed shared assets, + including portable vs installed confinement. -- Parses model package directory structures (`manifest.json`, `metadata.json`, `variant.json`) -- Provides read-only access to: - - Components and their variants - - EP compatibility declarations (opaque strings) - - Model file paths within variants - - Session/provider options per file - - Consumer metadata (opaque JSON) +It deliberately does **not** know about ONNX, execution providers, sessions, +or the JSON payload that lives under any `executor_info[""]` slot. +Each consumer owns its own slot and parses it itself. -## What it does NOT do +--- -- Variant selection (requires runtime EP factory validation → stays in ORT) -- Session creation (requires ORT `InferenceSession`) -- Any interpretation of `compatibility_string` tokens +## On-disk layout -## Building +A package is a directory containing a top-level `manifest.json`. Components +live under the package root, either declared inline in the manifest or as +external `component.json` files. Variants are directories under their +component. Shared assets are content-addressed directories under +`shared_assets/`. -```bash -cmake -B build -S . -cmake --build build +``` +package_root/ +├── manifest.json # required +├── decoder/ # external component (directory) +│ ├── component.json # required when external +│ └── cpu/ # variant_directory +│ ├── model.onnx +│ └── ort_info.json # executor_info["ort"], external form +├── encoder/ # inline component (no component.json) +│ └── cuda/ +│ └── model.onnx +└── shared_assets/ + └── sha256-<64hex>/ # content-addressed asset directory + ├── tokenizer.json + └── chat_template.jinja ``` -Options: -- `-DMODEL_PACKAGE_BUILD_SHARED=ON|OFF` — Build as shared (default) or static library -- `-DMODEL_PACKAGE_BUILD_TESTS=ON` — Build tests (default OFF) +- The package root must be a directory. A single file is **not** a package. +- A package has at least one component. 
A component has at least one variant. +- A variant always corresponds to a directory on disk (`variant_directory`). + Files inside that directory are referenced by `executor_info` payloads, not + by the manifest. +- `shared_assets/` is optional and only needs to exist if at least one + shared asset is published. -## C API Usage +### Portable vs installed layout -```c -#include "model_package_api.h" - -ModelPackageContext* ctx = NULL; -ModelPackageStatus* status = ModelPackage_CreateContext("/path/to/package", &ctx); -if (status != NULL) { - printf("Error: %s\n", ModelPackage_GetErrorMessage(status)); - ModelPackage_ReleaseStatus(status); - return; +`manifest.layout` declares how the package may use paths: + +- `"portable"` (default): every path is a `package_root`-relative POSIX path + with no `..` segments and no absolute paths. The package is self-contained + and movable. This is the format you ship. +- `"installed"`: absolute paths and `..` segments are allowed. This is for + packages that have been "installed" onto a system that links shared assets + to a system-wide cache, or that reference pre-existing files outside the + package root. + +The library enforces these rules at parse time. +`ModelPackageOpenOptions.allow_external_paths` can additionally relax portable +confinement for read operations, but the parser still rejects absolute paths +inside the manifest unless `layout == "installed"`. + +--- + +## `manifest.json` + +```jsonc +{ + "schema_version": "1.0", // required, "<major>.<minor>" 
(major gates compat) + "package_name": "phi-4-mini", // optional, free-form + "package_version":"4.0.0", // optional, free-form + "description": "Phi-4 mini reasoning model.", // optional + "layout": "portable", // optional: "portable" (default) | "installed" + + "components": { // required, at least one entry + "decoder": "decoder", // external — path relative to package_root + "encoder": { /* inline component body */ } + }, + + "shared_assets": { // optional + "sha256:<64hex>": "shared_assets/sha256-<64hex>" // optional path override + }, + + "additional_metadata": { /* free-form */ } // optional } +``` + +Field reference: + +| Field | Type | Required | Notes | +| -------------------- | --------------- | -------- | ----- | +| `schema_version` | string | yes | `"<major>.<minor>"` (e.g. `"1.0"`). The library accepts any package whose **major** is in its supported range and any **minor**; a major outside the range is a `MODEL_PACKAGE_ERR_VERSION` error. A bare integer is accepted as `"<major>.0"`. Major gates compatibility; minor tells consumers which optional fields may be present. | +| `package_name` | string | no | Human label. Not used for resolution. | +| `package_version` | string | no | Human label. Not used for resolution. | +| `description` | string | no | Free-form. | +| `layout` | string | no | `"portable"` (default) or `"installed"`. | +| `components` | object | yes | Map of component name → component value. See below. | +| `shared_assets` | object | no | Map of `sha256:<hex>` URI → path override (string). | +| `additional_metadata`| any JSON value | no | Opaque to this library. Round-tripped verbatim. | + +By default the parser rejects unknown top-level keys (`strict_unknown_fields`, +on by default). Disable it via `ModelPackageOpenOptions` to round-trip +manifests authored against a newer schema. 
+ +### Components -size_t count = 0; -ModelPackage_GetComponentCount(ctx, &count); +The value under `components[name]` is either: -for (size_t i = 0; i < count; i++) { - const char* name = NULL; - ModelPackage_GetComponentName(ctx, i, &name); - printf("Component: %s\n", name); +- **A string** — the path to an external component, resolved against + `package_root`. The path may be: + - **A directory.** The loader appends `component.json` and reads that + file. The filename is fixed in this form (must be exactly + `component.json`). + - **A file.** Loaded directly. The filename is not enforced and may be + anything (e.g. `decoder.json`). Useful when one directory holds + multiple component definitions. +- **A JSON object** — an inline component body matching the + [component schema](#componentjson) below. + +The component's "directory" is: + +- For an inline component, the package root itself. +- For an external component pointed at by a directory path, that directory. +- For an external component pointed at by a file path, the file's parent. + +Variant paths in the component body are resolved against this directory. + +### Shared assets + +`shared_assets[uri]` is an **override**: it says "the asset with this URI +lives at this path", overriding the default convention of +`<package_root>/shared_assets/sha256-<hex>/`. Overrides are eagerly rejected +in portable layout when they would escape `package_root` (e.g. absolute paths, +`..` segments). + +Variants reference shared assets only by embedding `sha256:<hex>[/sub/path]` +strings inside their `executor_info` payloads. Consumers resolve those +references through [`ModelPackage_ResolveStringRef`](#path-resolution-rules). +The library never parses `executor_info` payloads, so it has no manifest-level +list of which variant uses which asset. + +--- + +## `component.json` + +When a component is external, `component.json` is the file referenced from +the manifest. When inline, the same body is embedded directly in +`manifest.components[name]`. 
+ +```jsonc +{ + "component_name": "decoder", // optional, descriptive only + "variants": { // required, may be empty + "cpu": { /* variant body */ }, + "cuda": { /* variant body */ } + }, + "additional_metadata": { /* free-form */ } // optional } +``` + +Field reference: + +| Field | Type | Required | Notes | +| -------------------- | ------ | -------- | ----- | +| `component_name` | string | no | Sanity-checked as a string; not used for lookup. The map key in `components` wins. | +| `variants` | object | yes | Map of variant name → variant body. May be empty (placeholder component). | +| `additional_metadata`| any | no | Free-form. | -ModelPackage_ReleaseContext(ctx); +--- + +## Variant body + +A variant binds a single (EP, device, compatibility) triple to a single +on-disk directory plus zero or more per-consumer `executor_info` payloads. + +```jsonc +{ + "variant_directory": "cuda", // optional — defaults to variant name + "ep": "CUDAExecutionProvider", // optional + "device": "gpu", // optional ("cpu" | "gpu" | "npu") + "compatibility_string": "", // optional, opaque to library + "executor_info": { // optional + "ort": "ort_info.json", // string → external file + "other": { "filename": "model.onnx" } // object → inline JSON + }, + "additional_metadata": { /* free-form */ } // optional +} ``` -## Integration with ORT +Field reference: -ORT compiles this library as part of its build and wraps the C API through `OrtModelPackageApi`, adding: -- Variant selection via EP factory compatibility validation -- Session creation with merged options +| Field | Type | Required | Notes | +| ---------------------- | ---------------- | -------- | ----- | +| `variant_directory` | string | no | Path relative to the component directory. Defaults to the variant name. If declared but missing on disk, parse fails. | +| `ep` | string | no | Single ONNX Runtime EP name (e.g. `CPUExecutionProvider`). | +| `device` | string | no | Lower-case `cpu` / `gpu` / `npu`. 
ORT uses this for variant selection. | +| `compatibility_string` | string | no | Opaque to the library. ORT hands it to the EP's `ValidateCompiledModelCompatibilityInfo` callback. | +| `executor_info` | object | no | Map of consumer namespace → string (external file) or object (inline JSON). | +| `additional_metadata` | any | no | Free-form. | -## Package Format +#### `variant_directory` +- Always interpreted as a directory. +- Resolved against the **component directory** (not the package root). +- The library does not validate the directory's contents; consumers resolve + their own file references relative to it. + +#### `executor_info` + +This is the extension point that lets ORT and any future consumer share a +package without colliding. Keys are consumer namespaces; values are either: + +- **A string** — a path to a JSON file. Resolved against the variant + directory. The file must exist (in strict mode) and parse as JSON. +- **An inline JSON object** — embedded directly in the manifest. + +The library round-trips the payload but never interprets it. See +[`onnxruntime/core/session/model_package/README.md`](../onnxruntime/core/session/model_package/README.md) +for the `"ort"` namespace schema. + +Consumers can embed `sha256:[/sub/path]` references inside their +`executor_info` payload and resolve them through +`ModelPackage_ResolveStringRef`. The library does not maintain a per-variant +list of consumed assets; see [Shared assets](#shared-assets) for how URIs +enter the resolvable set. + +--- + +## Shared assets + +Shared assets are **directories** identified by a content hash. Two packages +that ship the same tokenizer will reuse the same asset directory on disk in +an installed layout, dedup-ing storage and downloads. + +### Canonical asset URI + +`ModelPackage_ComputeDirectoryHash(source_dir)` computes the canonical URI: + +1. Walk `source_dir` recursively, collecting regular files. Empty + subdirectories are ignored. +2. 
Reject symlinks (portability hazard). +3. For each file, compute `sha256(file_bytes)` → per-file hex digest. +4. Build a manifest text of lines ` \n` + sorted lexicographically by path. Paths use forward slashes, no leading + `./`. Non-ASCII paths must be NFC-normalized by the caller. +5. `asset_uri = "sha256:" + sha256(manifest_text)`, lowercase hex. + +The scheme hashes **both** file contents and file names, so renaming a file +inside an asset changes the URI. The on-disk directory name follows the +convention `sha256-` (dash, not colon) to keep the path filesystem-safe. + +### Default location + +`/shared_assets/sha256-/`. Override per-asset by adding an +entry to `manifest.shared_assets`. + +### How URIs enter the resolvable set + +At Open time the library populates the resolvable shared-asset table from +three sources, in order. Within each tier an already-seen URI is skipped: + +1. **Manifest overrides.** Every entry under `manifest.shared_assets` lands + first. These can also point at non-default paths (subject to the + layout's portability rules). +2. **On-disk discovery.** The library lists `/shared_assets/` + and admits each `sha256-` subdirectory it finds (sorted + lexicographically). The resolved path is the default + `/shared_assets/sha256-/`. A missing `shared_assets/` + directory is fine. +3. **Pending authoring stages.** Any `copy_in=true` source registered via + `ModelPackage_AddSharedAsset` is surfaced at its staged source path so + `ResolveStringRef` works before `Commit`. + +This means the manifest does not need to enumerate the assets that ship in +the conventional `shared_assets/` directory. The override list is only +needed when an asset lives outside the default convention. 
+ +### Adding a shared asset programmatically + +```c +const char* uri = NULL; +ModelPackageStatus* st = ModelPackage_AddSharedAsset( + pkg, + "/path/to/tokenizer", // source_dir + NULL, // expected_uri_or_null (reproducible-build check) + /*copy_in=*/true, // stage for copy at Commit time + &uri); ``` -package_root/ -├── manifest.json # schema_version, components list -└── models/ - └── / - ├── metadata.json # variants + EP compatibility declarations - └── / - ├── variant.json # files list, consumer_metadata - └── model.onnx # (or other model files) + +`copy_in == false` stores an override path in the manifest and is rejected +eagerly in portable layout (the path is unlikely to be portable). `copy_in +== true` stages the source for copy when `ModelPackage_Commit()` runs. + +--- + +## Path resolution rules + +`ModelPackage_ResolveStringRef(pkg, base_dir, input, must_exist, &out)` is +the canonical path resolver. It accepts: + +| Input form | Resolution | +| --------------------------- | ---------- | +| `sha256:<hex>` | Returns the on-disk directory for that shared asset. Error if the asset isn't registered. | +| `sha256:<hex>/sub/path` | Returns `<asset_dir>/sub/path`. The subpath is confined to the asset folder (no absolute, no `..`). | +| Relative path | Resolved against `base_dir` (or `package_root` when `base_dir` is NULL). | +| Absolute path / `..` segments | Allowed only in `installed` layout or when the package was opened with `allow_external_paths = true`. | + +In portable layout the resolver enforces that the resolved path stays +underneath `package_root`. Symlinks are followed by default +(`follow_symlinks`). + +`out_path` is a NUL-terminated thread-local pointer; copy it if it must +outlive the next `ResolveStringRef` call on the same thread. + +--- + +## C API quick tour + +All public entry points are declared in `include/model_package.h`. 
Reading a +package and walking the info tree: + +```c +#include "model_package.h" + +ModelPackage* pkg = NULL; +if (ModelPackageStatus* st = ModelPackage_Open("/path/to/pkg", NULL, &pkg)) { + fprintf(stderr, "open failed: %s\n", ModelPackageStatus_Message(st)); + ModelPackageStatus_Release(st); + return 1; +} + +const ModelPackageInfo* info = ModelPackage_Info(pkg); +printf("schema=%lld.%lld layout=%s\n", + (long long)info->schema_version_major, (long long)info->schema_version_minor, info->layout); +for (size_t i = 0; i < info->num_components; ++i) { + const ModelComponentInfo* c = &info->components[i]; + printf("component %s (%zu variants)\n", c->name, c->num_variants); + for (size_t v = 0; v < c->num_variants; ++v) { + const ModelVariantInfo* var = &c->variants[v]; + printf(" variant %s dir=%s ep=%s\n", + var->name, + var->variant_directory ? var->variant_directory : "(unset)", + var->ep ? var->ep : "(unset)"); + for (size_t e = 0; e < var->num_executor_infos; ++e) { + const ModelExecutorInfoEntry* ei = &var->executor_infos[e]; + printf(" executor_info[%s] = %s\n", ei->namespace_key, ei->json); + } + } +} + +ModelPackage_Close(pkg); +``` + +Authoring a new package from scratch: + +```c +ModelPackage* pkg = NULL; +ModelPackage_New(&pkg); +ModelPackage_SetMetadata(pkg, "phi-4-mini", "4.0.0", "Phi-4 mini."); + +ModelPackage_SetComponentInline(pkg, "decoder", "{\"variants\": {}}"); +ModelPackage_SetVariant(pkg, "decoder", "cpu", + "{\"variant_directory\":\"decoder/cpu\"," + " \"ep\":\"CPUExecutionProvider\"," + " \"device\":\"cpu\"}"); +ModelPackage_SetVariantExecutorInfoInline( + pkg, "decoder", "cpu", "ort", "{\"model_file\":\"model.onnx\"}"); + +const char* asset_uri = NULL; +ModelPackage_AddSharedAsset(pkg, "/src/tokenizer", NULL, /*copy_in=*/true, &asset_uri); +// asset_uri is owned by pkg; copy it if you need it past the next mutation. 
+ +ModelPackage_Commit(pkg, "/path/to/new_pkg", MODEL_PACKAGE_WRITE_PRESERVE); +ModelPackage_Close(pkg); ``` -Single-component shorthand (metadata.json at root, no manifest.json) is also supported. +### Lifetime contract + +Every `const char*` and every `const ModelPackageInfo*` (plus sub-arrays) +returned by the read API is owned by the `ModelPackage` handle and remains +valid **until the next mutation of that scope** or until +`ModelPackage_Close()`. Any `Set*` / `Remove*` / `Add*` / `Commit` call +invalidates cached pointers in the mutated scope; re-read `Info()` after +mutating. + +`ModelPackage_AddSharedAsset`'s `out_uri` follows the same "valid until next +mutation" rule. + +`ModelPackage_ResolveStringRef` and `ModelPackage_ComputeDirectoryHash` +return pointers into a per-thread scratch slot; copy before the next call on +the same thread. + +### Commit modes + +`ModelPackage_Commit(pkg, dest, mode)`: + +- `dest == NULL` → in-place commit at `package_root`. +- `dest != NULL` → write a self-contained "save as". `dest` must be empty or + nonexistent. On success the package's root is updated to `dest`, so + subsequent in-place commits go there. + +`mode`: + +- `MODEL_PACKAGE_WRITE_PRESERVE` (default) — each component and + `executor_info` entry keeps its current inline-or-external shape. +- `MODEL_PACKAGE_WRITE_DENSE` — flatten every external component back inline + into `manifest.json`. Useful for single-file authoring inspection. + +### Prune + +`ModelPackage_Prune(pkg)` reclaims storage that the library itself manages: + +- Tracked orphan variant and component directories left behind by + `RemoveVariant`, `RemoveComponent`, `SetVariant`, or + `SetComponentExternal`. +- Stale `.tmp.` staging directories from interrupted commits, after + a short grace window. + +`Prune` deliberately never removes `shared_assets/sha256-/` directories. 
+Consumers freely embed `sha256:` references inside their own `executor_info` +payloads, and the library cannot prove an asset is unused without parsing +every consumer's namespace. Use `ModelPackage_RemoveSharedAsset(uri)` to +delete a shared asset explicitly when the caller knows it is unreferenced. + +Only paths registered through this API and strictly inside `package_root` +are touched. + +### Validate + +`ModelPackage_Validate(pkg, flags, &report_json)` runs a configurable set of +structural checks and returns a JSON report +`{"errors": [...], "warnings": [...]}`: + +| Flag | Checks | +| --------------------------------------- | ------ | +| `MODEL_PACKAGE_VALIDATE_SCHEMA` | Required keys, types, value ranges. | +| `MODEL_PACKAGE_VALIDATE_PATHS` | Every recorded path resolves under the configured layout. | +| `MODEL_PACKAGE_VALIDATE_ASSET_REHASH` | Recompute every asset directory hash and compare to its URI (slow). | +| `MODEL_PACKAGE_VALIDATE_UNKNOWN_FIELDS` | Surface unknown JSON fields as warnings. | +| `MODEL_PACKAGE_VALIDATE_ALL` | All of the above. | + +Errors cause a non-NULL status return; warnings alone return success. + +--- + +## Versioning and compatibility + +### Distributed as source + +The library is meant to be **vendored and compiled into each consumer's own +binary** (ORT, publisher tooling, third-party loaders). No prebuilt shared +library (`.so`/`.dll`) is published as the supported interface. + +A direct consequence is that the public POD structs in `model_package.h` have +**no binary boundary** to defend: within any single build there is exactly one +definition of every struct, so there is nothing for two separately-compiled +artifacts to disagree about. The library therefore carries **none** of the usual +ABI machinery — no per-struct `struct_size`/`cbSize`, no `abi_version`, no +library SOVERSION, and no offset `static_assert`s. 
Collections are exposed as +plain array members (`components`/`num_components`, `variants`/`num_variants`, +…) rather than count+index accessors, since accessors only earn their keep when +the library owns the struct stride across a binary boundary. + +The **only** compatibility contract is the on-disk data format, expressed by +`schema_version`. Everything a consumer needs to know about which fields and +objects a package may contain follows from that one value. + +### `schema_version` + +`schema_version` is a `"<major>.<minor>"` string in `manifest.json` (a bare +integer `N` is accepted and treated as `N.0`). It is parsed into +`ModelPackageInfo.schema_version_major` and `schema_version_minor`. + +- **major** — the data contract. Incremented only for a **breaking** change + (a field removed, renamed, retyped, or given new semantics). A consumer that + understands major *N* can read any `N.x` package. +- **minor** — additive evolution within a major. Incremented when a new + **optional** field or object is added. It never removes or reinterprets + anything, so it is fully backward- and forward-compatible within the major. + +Consumers should branch **solely on `schema_version_major` / `schema_version_minor`** +to decide which optional fields a package may carry — not on the presence or +absence of individual fields, and never on any library version. + +### What the parser enforces + +Each build declares the majors it understands as a closed range +(`kMinSupportedSchemaMajor … kMaxSupportedSchemaMajor` in `manifest_parser.cc`) +plus the highest minor it authored (`kMaxKnownSchemaMinor`): + +- **Unsupported major** → `ModelPackage_Open` fails with + `MODEL_PACKAGE_ERR_VERSION`. A consumer never silently misreads a package + whose contract it does not understand. 
+- **Any minor is accepted.** When the minor is **newer** than this build knows + (`minor > kMaxKnownSchemaMinor`), unknown-field strictness is relaxed for that + package so the additive fields a newer authoring tool wrote are **tolerated** + (read through, preserved on round-trip via the JSON getters) instead of + rejected. An older library can therefore load a newer-minor package and ignore + the fields it does not recognize. + +### Supporting a major version bump + +When a breaking change requires a new major, deployed packages do **not** have to +be rewritten and consumers do **not** have to upgrade in lockstep. The library is +designed to support a **range** of majors simultaneously: + +1. Bump `kMaxSupportedSchemaMajor` and add the new major's parse/serialize path, + keeping the existing major's path in place. The supported range now spans both. +2. Existing `N.x` packages keep loading unchanged through the old path; new + `(N+1).x` packages load through the new path. +3. Consumers branch on `schema_version_major` to pick the field set they read. + Code that only supports major *N* simply declines `(N+1).x` packages (the open + call returns `MODEL_PACKAGE_ERR_VERSION` for it) rather than misreading them. +4. A major is dropped from the supported range only when its packages are no + longer in circulation — an explicit, opt-in deprecation, never an implicit + break. + +This keeps already-published packages valid for as long as the library advertises +their major, which is the backward-compatibility guarantee external publishers +depend on. + +### How a major bump maps onto the structs + +A natural question is how a single C struct can represent two majors with +different fields. It can't — and it never has to, because **there is only one +struct definition in any given build**. The "old major" exists only as JSON on +disk; it is never a second C type in the consumer's binary. 
Since the library is +compiled from source, every consumer compiles exactly one definition of +`ModelPackageInfo`/`ModelVariantInfo`/etc. — the current one. Reconciling an +old-major package with that one definition is a **parse-time** job, not a +struct-layout one. + +The single struct is the **superset / newest** shape, and divergence between +majors is absorbed in three places: + +1. **Additive differences (common).** A field a new major added is present in the + struct and is simply `NULL`/`0`/empty when an older-major package lacks it — + the same mechanism as a minor bump. The consumer treats absence as "not + provided". + +2. **Parse-time normalization (preferred).** When a new major is added, its + parser path is added alongside the existing one, and **both populate the same + struct**. An older-major package is mapped up to the current in-memory model + (defaults filled, renamed fields mapped to their current names) before the + consumer sees it, so reads are uniform. `schema_version_major` then records the + *source* contract — useful for write-back and provenance — rather than + selecting a layout. + +3. **Non-migratable changes (rare).** A field whose *type* changes, or one + removed with no equivalent, cannot reuse the same name (C gives one field one + type). Add a new field for the new representation, populate the old field only + for old-major packages and the new field only for new-major packages, and let + the consumer branch on `schema_version_major`: + + ```c + // e.g. major 1 stored a single compatibility string; major 2 stores a list + const char* compatibility_string; // set when schema_version_major == 1 + const char* const* compatibilities; // set when schema_version_major == 2 + size_t num_compatibilities; + ``` + +**Escape hatch.** If a major bump is sweeping enough that the superset becomes +unwieldy, the standard move is **per-major typed structs** (e.g. 
a +`ModelPackageInfoV2` returned by a versioned accessor) — a deliberate API +expansion reserved for a wholesale redesign, not the default. In practice: prefer +normalizing old majors up to the newest struct at parse time; fall back to extra +nullable fields plus `schema_version_major` branching only when a change cannot be +auto-migrated. + +--- + +## What the library deliberately does NOT do + +- **Variant selection.** Picking which variant best matches the EPs the + caller has available requires EP factory introspection and is owned by the + executor. ORT's selector lives in + `onnxruntime/core/session/model_package/` and uses each EP's + `ValidateCompiledModelCompatibilityInfo` callback. +- **Session creation.** Building an `OrtSession` is ORT's job. +- **Interpreting `executor_info` payloads.** Each consumer namespace owns + its own slot. The library only validates that values are either strings + (paths) or objects. +- **Interpreting `compatibility_string`.** The format is owned by the EP + declared in `ep`. The library never parses it. + +--- + +## Building + +```bash +cmake -B build -S . [-DMODEL_PACKAGE_BUILD_TESTS=ON] +cmake --build build -j +ctest --test-dir build --output-on-failure # requires BUILD_TESTS=ON +``` + +CMake options: + +- `MODEL_PACKAGE_BUILD_SHARED` (default `ON`) — shared vs static. +- `MODEL_PACKAGE_BUILD_TESTS` (default `OFF`) — build the unit-test + executables (`test_asset_hashing`, `test_inspection`, `test_authoring`, + `test_commit`). + +The only build-time dependency is a vendored copy of nlohmann/json (header +only). + +--- + +## See also + +- `onnxruntime/core/session/model_package/README.md` — how ORT consumes this + library and the `executor_info["ort"]` schema. +- `model_package_redesign.md` in the `archive` repo — original design + rationale (extension fields, content addressing, portable vs installed, + shared-asset overrides). 
diff --git a/model_package/include/model_package.h b/model_package/include/model_package.h new file mode 100644 index 0000000000000..3b456852adb37 --- /dev/null +++ b/model_package/include/model_package.h @@ -0,0 +1,350 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file model_package.h +/// \brief Public C API for the ONNX Runtime Model Package library. +/// +/// A model package is a directory with a top-level `manifest.json` that +/// declares a set of components; each component declares a set of variants; +/// each variant points at a directory containing the model files and may +/// carry executor-specific configuration under per-namespace +/// `executor_info` entries. +/// +/// Error handling: functions that can fail return `ModelPackageStatus*`. +/// A `nullptr` return indicates success. Use `ModelPackageStatus_Message`, +/// `ModelPackageStatus_Code`, and `ModelPackageStatus_Release` to inspect and +/// release statuses. +/// +/// Object lifetime: unless a function documents otherwise, every `const char*` +/// and every `const ModelPackageInfo*` (and its sub-arrays) returned by this API +/// is owned by the `ModelPackage` handle and remains valid until the next +/// mutation of that scope or until the package is closed. Mutations invalidate +/// cached pointers in the mutated scope and its descendants; callers must +/// re-read `ModelPackage_Info()` after any mutation. + +#pragma once + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +#include "model_package_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// The library is consumed as source (compiled into each consumer's own binary), +// so the structs below have no binary boundary to maintain: there is no +// struct_size/ABI versioning. Compatibility with on-disk packages is governed +// solely by `schema_version` (see ModelPackageInfo). 
+ +// ───────────────────────────────────────────────────────────────────────────── +// Opaque handle +// ───────────────────────────────────────────────────────────────────────────── + +typedef struct ModelPackage ModelPackage; + +// ───────────────────────────────────────────────────────────────────────────── +// Status helpers +// ───────────────────────────────────────────────────────────────────────────── + +/// Get the error message from a status object. Returns NULL if `status` is NULL. +/// The returned string is owned by the status object. +MODEL_PACKAGE_API const char* ModelPackageStatus_Message(const ModelPackageStatus*); +/// Get the categorical error code. Returns `MODEL_PACKAGE_OK` when `status` is NULL. +MODEL_PACKAGE_API ModelPackageErrorCode ModelPackageStatus_Code(const ModelPackageStatus*); +/// Release a status object. Safe to call with NULL. +MODEL_PACKAGE_API void ModelPackageStatus_Release(ModelPackageStatus*); + +// ───────────────────────────────────────────────────────────────────────────── +// Lifecycle +// ───────────────────────────────────────────────────────────────────────────── + +typedef struct ModelPackageOpenOptions { + bool allow_external_paths; ///< default false; unlocks absolute paths and `..` segments + bool follow_symlinks; ///< default true + bool strict_unknown_fields; ///< default true; relax to round-trip newer schemas +} ModelPackageOpenOptions; + +/// Open an existing model package directory. `opts` may be NULL for defaults. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_Open(const char* package_root, + const ModelPackageOpenOptions* opts, + ModelPackage** out); + +/// Create a new empty in-memory package for from-scratch authoring. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_New(ModelPackage** out); + +/// Release a ModelPackage handle and all its caches. Safe on NULL. 
+MODEL_PACKAGE_API void ModelPackage_Close(ModelPackage* pkg); + +// ───────────────────────────────────────────────────────────────────────────── +// Data model — POD structs read from ModelPackage_Info() +// ───────────────────────────────────────────────────────────────────────────── + +typedef struct ModelExecutorInfoEntry { + const char* namespace_key; ///< executor namespace name (e.g. "ort") + const char* json; ///< canonical JSON value as string (object, array, etc.) +} ModelExecutorInfoEntry; + +typedef struct ModelVariantInfo { + const char* name; + /// Resolved absolute path to the variant's on-disk directory, or NULL when + /// no directory has been declared and the default location does not exist. + const char* variant_directory; + const char* ep; ///< NULL when unset + const char* device; ///< NULL when unset + const char* compatibility_string; ///< NULL when unset + const char* additional_metadata_json; ///< NULL when unset + size_t num_executor_infos; + const ModelExecutorInfoEntry* executor_infos; +} ModelVariantInfo; + +typedef struct ModelComponentInfo { + const char* name; + const char* additional_metadata_json; ///< NULL when unset + size_t num_variants; + const ModelVariantInfo* variants; +} ModelComponentInfo; + +typedef struct ModelSharedAssetInfo { + const char* uri; ///< "sha256:<hex>" + const char* resolved_path; ///< absolute on-disk directory path +} ModelSharedAssetInfo; + +typedef struct ModelPackageInfo { + int64_t schema_version_major; ///< parsed from on-disk "<major>.<minor>"; gates compatibility + int64_t schema_version_minor; ///< informational; indicates which optional fields may be present + const char* package_name; ///< NULL when unset + const char* package_version; ///< NULL when unset + const char* description; ///< NULL when unset + const char* layout; ///< "portable" or "installed" + const char* additional_metadata_json; ///< NULL when unset + size_t num_components; + const ModelComponentInfo* components; + size_t num_shared_assets; + 
const ModelSharedAssetInfo* shared_assets; +} ModelPackageInfo; + +/// Return the package-level info tree. Pointer is owned by the package and is +/// invalidated by any mutation. +MODEL_PACKAGE_API const ModelPackageInfo* ModelPackage_Info(const ModelPackage* pkg); + +// ───────────────────────────────────────────────────────────────────────────── +// Convenience lookups +// ───────────────────────────────────────────────────────────────────────────── + +/// Find a component by name. Returns NULL when not found. +MODEL_PACKAGE_API const ModelComponentInfo* ModelPackage_FindComponent(const ModelPackageInfo*, + const char* name); +/// Find a variant within a component by name. Returns NULL when not found. +MODEL_PACKAGE_API const ModelVariantInfo* ModelComponentInfo_FindVariant(const ModelComponentInfo*, + const char* name); +/// Find an executor_info entry by namespace. Returns NULL when not declared. +MODEL_PACKAGE_API const ModelExecutorInfoEntry* ModelVariantInfo_FindExecutorInfo( + const ModelVariantInfo*, const char* namespace_key); + +// ───────────────────────────────────────────────────────────────────────────── +// Round-trip JSON getters +// ───────────────────────────────────────────────────────────────────────────── + +/// Get the canonical schema-shaped JSON for the named component. Preserves +/// fields unknown to this build. The returned pointer is owned by the package. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetComponentJson(const ModelPackage*, + const char* component_name, + const char** out_json); + +/// Get the canonical schema-shaped JSON for the named variant. 
+MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetVariantJson(const ModelPackage*, + const char* component_name, + const char* variant_name, + const char** out_json); + +// ───────────────────────────────────────────────────────────────────────────── +// Asset resolution + hashing +// ───────────────────────────────────────────────────────────────────────────── + +/// Resolve a `sha256:<hex>` URI to an on-disk directory. Errors with +/// `MODEL_PACKAGE_ERR_ASSET_MISSING` when not resolvable. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_ResolveAssetUri(const ModelPackage*, + const char* uri, + const char** out_path); + +/// Resolve a string reference using the model package's path resolution rules. +/// `input` may be: +/// - `sha256:<hex>` -> shared-asset folder +/// - `sha256:<hex>/sub/path` -> file or subdir inside a shared-asset folder +/// (sub/path is resolved with portable-mode +/// confinement under the asset folder: no +/// absolute, no `..`) +/// - relative path -> resolved against `base_dir` (or +/// `package_root` when `base_dir == NULL`), +/// confined to `package_root` in portable layout +/// - absolute path / `..` segments -> only allowed in installed layout, or in +/// any layout when the package was opened with +/// `ModelPackageOpenOptions.allow_external_paths` +/// +/// When `must_exist` is true, a missing target fails with `MODEL_PACKAGE_ERR_NOT_FOUND`; +/// when false, the lexically-normalized path is returned even if it does not exist. +/// On success `*out_path` points to a NUL-terminated thread-local string; copy +/// it if you need it to outlive the next `ModelPackage_ResolveStringRef` call on +/// the same thread. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_ResolveStringRef(const ModelPackage*, + const char* base_dir, + const char* input, + bool must_exist, + const char** out_path); + +/// Compute the canonical `sha256:<hex>` URI for a directory. 
On success, +/// `*out_uri` is set to a NUL-terminated string owned by an internal +/// thread-local slot; the caller must copy if it must outlive the next call +/// on the same thread. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_ComputeDirectoryHash(const char* source_dir, + const char** out_uri); + +// ───────────────────────────────────────────────────────────────────────────── +// Authoring — mutation API +// ───────────────────────────────────────────────────────────────────────────── +// +// Each mutation invalidates info pointers in the mutated scope and its +// descendants. Strict unknown-field rejection follows the open-time option +// `strict_unknown_fields` (default true). + +/// Set or replace an inline component. `component_json` must be a JSON object +/// matching the component schema. An existing component with the same name is +/// replaced. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetComponentInline(ModelPackage*, + const char* name, + const char* component_json); + +/// Set or replace an external component. `path` is recorded in the manifest +/// (relative to package_root, or absolute in installed layout). If the file +/// exists, it is loaded; otherwise the component is initialized empty +/// (`{"variants": {}}`). `path` may be a directory (resolves to +/// `<path>/component.json`). +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetComponentExternal(ModelPackage*, + const char* name, + const char* path); + +/// Remove a component by name. No-op when the name is not present. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_RemoveComponent(ModelPackage*, const char* name); + +/// Upsert a variant inside a component. `variant_json` must be a JSON object +/// matching the variant schema. The library does not validate that +/// `variant_directory` exists on disk; executors are responsible for resolving +/// their own file references at load time. 
+MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetVariant(ModelPackage*, + const char* component_name, + const char* variant_name, + const char* variant_json); + +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_RemoveVariant(ModelPackage*, + const char* component_name, + const char* variant_name); + +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetVariantExecutorInfoInline(ModelPackage*, + const char* component, + const char* variant, + const char* ns, + const char* info_json); + +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetVariantExecutorInfoExternal(ModelPackage*, + const char* component, + const char* variant, + const char* ns, + const char* path); + +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_RemoveVariantExecutorInfo(ModelPackage*, + const char* component, + const char* variant, + const char* ns); + +/// Add a content-addressed shared asset. When `expected_uri_or_null` is +/// non-NULL, the computed URI must match (reproducible-build check). With +/// `copy_in == false`, an override path is stored in the manifest; this is +/// rejected eagerly in portable layout. With `copy_in == true`, the source +/// directory is staged for copy at `_Commit` time. +/// +/// `out_uri` is set to a NUL-terminated string owned by the package. The +/// pointer is only guaranteed to remain valid until the next mutation +/// (any ModelPackage_Set*, ModelPackage_Remove*, ModelPackage_AddSharedAsset, +/// or ModelPackage_Commit call), since those calls may rebuild the +/// shared-asset table or rehash the pending-copies map. Callers that need to +/// retain the URI must copy it into their own storage. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_AddSharedAsset(ModelPackage*, + const char* source_dir, + const char* expected_uri_or_null, + bool copy_in, + const char** out_uri); + +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_RemoveSharedAsset(ModelPackage*, const char* uri); + +/// Set or clear package-level metadata. 
Any argument may be NULL to leave the +/// existing value untouched. Passing an empty string clears the field. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetMetadata(ModelPackage*, + const char* name_or_null, + const char* version_or_null, + const char* description_or_null); + +/// Set the layout. Valid values: "portable" or "installed". +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetLayout(ModelPackage*, const char* layout); + +/// Set or clear `additional_metadata` at a given scope. +/// scope = "manifest" — component_or_null and variant_or_null must be NULL +/// scope = "component" — component_or_null is required, variant_or_null is NULL +/// scope = "variant" — component_or_null and variant_or_null are both required +/// `json_or_null == NULL` clears the field at that scope. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_SetAdditionalMetadataJson(ModelPackage*, + const char* scope, + const char* component_or_null, + const char* variant_or_null, + const char* json_or_null); + +// ───────────────────────────────────────────────────────────────────────────── +// Commit / Prune / Validate +// ───────────────────────────────────────────────────────────────────────────── + +typedef enum { + MODEL_PACKAGE_WRITE_PRESERVE = 0, ///< each component/executor-info keeps its current shape + MODEL_PACKAGE_WRITE_DENSE = 1, ///< flatten all external components inline +} ModelPackageWriteMode; + +/// Persist the in-memory model to disk. `dest_root_or_null == NULL` commits +/// in-place at `package_root`. Otherwise `dest_root` must be empty or +/// nonexistent and the entire package is materialized there (self-contained +/// "save as"). On a successful dest_root commit, the package's root is +/// updated to `dest_root` so subsequent in-place commits go there. 
+MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_Commit(ModelPackage*, + const char* dest_root_or_null, + ModelPackageWriteMode mode); + +/// Reclaim stale `.tmp.` staging directories under +/// `/shared_assets/` (left by interrupted commits, after a grace +/// window) and tracked orphan variant/component directories left behind by +/// RemoveVariant, RemoveComponent, SetVariant or SetComponentExternal. Only +/// paths registered through this API and inside `package_root` are touched. +/// Content-addressed shared-asset (`sha256-`) directories are never removed +/// — use ModelPackage_RemoveSharedAsset to reclaim those. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_Prune(ModelPackage*); + +typedef enum { + MODEL_PACKAGE_VALIDATE_SCHEMA = 1 << 0, + MODEL_PACKAGE_VALIDATE_PATHS = 1 << 1, + MODEL_PACKAGE_VALIDATE_ASSET_REHASH = 1 << 2, + MODEL_PACKAGE_VALIDATE_UNKNOWN_FIELDS = 1 << 3, + MODEL_PACKAGE_VALIDATE_ALL = ~0, +} ModelPackageValidateFlags; + +/// Run structural and reachability checks. `*out_report_json` is set to a +/// JSON string owned by the package describing findings: +/// `{"errors": [{"code": "...", "message": "..."}, ...], +/// "warnings": [...]}` +/// Returns non-NULL status when any error-level finding fired; warnings alone +/// still return success. +MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_Validate(ModelPackage*, + int flags, + const char** out_report_json); + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/model_package/include/model_package_api.h b/model_package/include/model_package_api.h index ca840c8a33e0e..36e678feed0f6 100644 --- a/model_package/include/model_package_api.h +++ b/model_package/include/model_package_api.h @@ -2,18 +2,15 @@ // Licensed under the MIT License. /// \file model_package_api.h -/// \brief Standalone C API for parsing and inspecting ONNX Runtime Model Packages. +/// \brief Core types shared by the model_package public API surface. /// -/// This library has no dependency on ONNX Runtime. 
It provides read-only access to -/// model package structure: components, variants, EP compatibility declarations, -/// model files, session/provider options, and consumer metadata. +/// This header defines the export macro, the opaque `ModelPackageStatus` type, +/// and the `ModelPackageErrorCode` enum used by every entry point in the +/// library. The actual API entry points live in `model_package.h`. /// -/// Error handling: Functions that can fail return `ModelPackageStatus*`. -/// A nullptr return indicates success. On failure, use `ModelPackage_GetErrorMessage()` -/// to retrieve the error string, and `ModelPackage_ReleaseStatus()` to free it. -/// -/// Lifetime: All `const char*` pointers returned by this API are owned by the -/// `ModelPackageContext` and remain valid until it is released. +/// Error handling: functions that can fail return `ModelPackageStatus*`. A +/// `nullptr` return indicates success. Use the `ModelPackageStatus_*` helpers +/// in `model_package.h` to inspect and release statuses. #pragma once @@ -45,110 +42,31 @@ extern "C" { #endif // ───────────────────────────────────────────────────────────────────────────── -// Opaque types +// Opaque status type // ───────────────────────────────────────────────────────────────────────────── /// Opaque status type. nullptr indicates success. typedef struct ModelPackageStatus ModelPackageStatus; -/// Opaque context holding a parsed model package. -typedef struct ModelPackageContext ModelPackageContext; - -// ───────────────────────────────────────────────────────────────────────────── -// Status API -// ───────────────────────────────────────────────────────────────────────────── - -/// Release a status object. Safe to call with nullptr. -MODEL_PACKAGE_API void ModelPackage_ReleaseStatus(ModelPackageStatus* status); - -/// Get the error message from a status object. Returns nullptr if status is nullptr. -/// The returned string is owned by the status object. 
-MODEL_PACKAGE_API const char* ModelPackage_GetErrorMessage(const ModelPackageStatus* status); - -// ───────────────────────────────────────────────────────────────────────────── -// Context lifecycle -// ───────────────────────────────────────────────────────────────────────────── - -/// Parse a model package from a directory path and create a context. -/// -/// \param[in] package_root_path Null-terminated UTF-8 path to the package root directory. -/// \param[out] out_context On success, receives the created context. Caller must release -/// via ModelPackage_ReleaseContext(). -/// \return nullptr on success, or a status object describing the error. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_CreateContext( - const char* package_root_path, - ModelPackageContext** out_context); - -/// Release a model package context and all associated resources. -/// Safe to call with nullptr. -MODEL_PACKAGE_API void ModelPackage_ReleaseContext(ModelPackageContext* context); - -// ───────────────────────────────────────────────────────────────────────────── -// Package-level queries -// ───────────────────────────────────────────────────────────────────────────── - -/// Get the schema version declared in manifest.json. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetSchemaVersion( - const ModelPackageContext* context, - int64_t* out_version); - -// ───────────────────────────────────────────────────────────────────────────── -// Component queries -// ───────────────────────────────────────────────────────────────────────────── - -/// Get the number of components in the package. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetComponentCount( - const ModelPackageContext* context, - size_t* out_count); - -/// Get the name of a component by index. -/// -/// \param[in] context The package context. -/// \param[in] component_idx Zero-based index (must be < component count). -/// \param[out] out_name Receives a pointer to the component name string. 
-/// Lifetime is tied to the context. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetComponentName( - const ModelPackageContext* context, - size_t component_idx, - const char** out_name); - -// ───────────────────────────────────────────────────────────────────────────── -// Variant queries // ───────────────────────────────────────────────────────────────────────────── - -/// Get the number of variants for a component. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetVariantCount( - const ModelPackageContext* context, - const char* component_name, - size_t* out_count); - -/// Get the name of a variant by index. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetVariantName( - const ModelPackageContext* context, - const char* component_name, - size_t variant_idx, - const char** out_name); - -/// Get the folder path for a variant (resolved absolute path). -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetVariantFolderPath( - const ModelPackageContext* context, - const char* component_name, - const char* variant_name, - const char** out_path); - -// ───────────────────────────────────────────────────────────────────────────── -// EP compatibility queries -// ───────────────────────────────────────────────────────────────────────────── - -/// Get the EP name declared for a variant. -/// -/// Each variant targets a single EP. When the variant does not declare an EP, -/// the returned pointer is set to nullptr. -MODEL_PACKAGE_API ModelPackageStatus* ModelPackage_GetVariantEpName( - const ModelPackageContext* context, - const char* component_name, - const char* variant_name, - const char** out_ep); +// Error codes +// ───────────────────────────────────────────────────────────────────────────── + +/// Categorical error codes attached to every non-OK ModelPackageStatus. +/// Stable additive enum: new codes will be appended at the end; existing +/// values will not be renumbered. 
+typedef enum ModelPackageErrorCode { + MODEL_PACKAGE_OK = 0, + MODEL_PACKAGE_ERR_IO = 1, ///< Filesystem read/write/sync failure. + MODEL_PACKAGE_ERR_SCHEMA = 2, ///< JSON value has wrong shape or wrong type. + MODEL_PACKAGE_ERR_VERSION = 3, ///< Unsupported schema_version. + MODEL_PACKAGE_ERR_PATH_CONFINEMENT = 4, ///< Path resolution escaped the allowed base. + MODEL_PACKAGE_ERR_ASSET_MISSING = 5, ///< Declared shared asset not resolvable. + MODEL_PACKAGE_ERR_ASSET_HASH_MISMATCH = 6, ///< Existing asset directory failed rehash. + MODEL_PACKAGE_ERR_NOT_FOUND = 7, ///< Named entity not present. + MODEL_PACKAGE_ERR_INVALID_ARG = 8, ///< Null pointer or otherwise invalid argument. + MODEL_PACKAGE_ERR_STATE = 9 ///< Operation not legal in current state. +} ModelPackageErrorCode; #ifdef __cplusplus } // extern "C" diff --git a/model_package/src/api.cc b/model_package/src/api.cc deleted file mode 100644 index 103bff8e1a4a3..0000000000000 --- a/model_package/src/api.cc +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "model_package_api.h" -#include "model_package_internal.h" -#include "parser.h" - -#include -#include - -// ───────────────────────────────────────────────────────────────────────────── -// Status implementation -// ───────────────────────────────────────────────────────────────────────────── - -struct ModelPackageStatus { - std::string message; -}; - -static ModelPackageStatus* MakeError(std::string msg) { - return new (std::nothrow) ModelPackageStatus{std::move(msg)}; -} - -// ───────────────────────────────────────────────────────────────────────────── -// Context is the public opaque type wrapping ContextImpl -// ───────────────────────────────────────────────────────────────────────────── - -struct ModelPackageContext { - model_package::ContextImpl impl; -}; - -// ───────────────────────────────────────────────────────────────────────────── -// ContextImpl lookup helpers -// ───────────────────────────────────────────────────────────────────────────── - -namespace model_package { - -const Component* ContextImpl::FindComponent(const char* name) const { - for (const auto& c : package_info.components) { - if (c.name == name) return &c; - } - return nullptr; -} - -const Variant* ContextImpl::FindVariant(const char* component_name, const char* variant_name) const { - const auto* comp = FindComponent(component_name); - if (!comp) return nullptr; - for (const auto& v : comp->variants) { - if (v.name == variant_name) return &v; - } - return nullptr; -} - -} // namespace model_package - -// ───────────────────────────────────────────────────────────────────────────── -// Validation macro -// ───────────────────────────────────────────────────────────────────────────── - -#define RETURN_IF_NULL(ptr, param_name) \ - do { \ - if ((ptr) == nullptr) \ - return MakeError(std::string(param_name) + " must not be null."); \ - } while (0) - -// ───────────────────────────────────────────────────────────────────────────── -// C API implementation -// 
───────────────────────────────────────────────────────────────────────────── - -extern "C" { - -void ModelPackage_ReleaseStatus(ModelPackageStatus* status) { - delete status; -} - -const char* ModelPackage_GetErrorMessage(const ModelPackageStatus* status) { - if (status == nullptr) return nullptr; - return status->message.c_str(); -} - -ModelPackageStatus* ModelPackage_CreateContext( - const char* package_root_path, - ModelPackageContext** out_context) { - RETURN_IF_NULL(package_root_path, "package_root_path"); - RETURN_IF_NULL(out_context, "out_context"); - - *out_context = nullptr; - - auto ctx = std::make_unique(); - std::string error; - - if (!model_package::ParsePackage( - std::filesystem::path(std::string(package_root_path)), - ctx->impl.package_info, error)) { - return MakeError(std::move(error)); - } - - // Build component names cache. - ctx->impl.component_names_cache.clear(); - for (const auto& c : ctx->impl.package_info.components) { - ctx->impl.component_names_cache.push_back(c.name); - } - - // Build variant names cache. 
- for (const auto& c : ctx->impl.package_info.components) { - auto& names = ctx->impl.variant_names_cache[c.name]; - names.clear(); - for (const auto& v : c.variants) { - names.push_back(v.name); - } - } - - *out_context = ctx.release(); - return nullptr; -} - -void ModelPackage_ReleaseContext(ModelPackageContext* context) { - delete context; -} - -ModelPackageStatus* ModelPackage_GetSchemaVersion( - const ModelPackageContext* context, - int64_t* out_version) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(out_version, "out_version"); - *out_version = context->impl.package_info.schema_version; - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetComponentCount( - const ModelPackageContext* context, - size_t* out_count) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(out_count, "out_count"); - *out_count = context->impl.package_info.components.size(); - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetComponentName( - const ModelPackageContext* context, - size_t component_idx, - const char** out_name) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(out_name, "out_name"); - - if (component_idx >= context->impl.component_names_cache.size()) { - return MakeError("component_idx out of range: " + std::to_string(component_idx)); - } - - *out_name = context->impl.component_names_cache[component_idx].c_str(); - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetVariantCount( - const ModelPackageContext* context, - const char* component_name, - size_t* out_count) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(component_name, "component_name"); - RETURN_IF_NULL(out_count, "out_count"); - - const auto* comp = context->impl.FindComponent(component_name); - if (!comp) { - return MakeError(std::string("Component not found: '") + component_name + "'."); - } - - *out_count = comp->variants.size(); - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetVariantName( - const ModelPackageContext* context, - const char* 
component_name, - size_t variant_idx, - const char** out_name) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(component_name, "component_name"); - RETURN_IF_NULL(out_name, "out_name"); - - auto it = context->impl.variant_names_cache.find(component_name); - if (it == context->impl.variant_names_cache.end()) { - return MakeError(std::string("Component not found: '") + component_name + "'."); - } - - if (variant_idx >= it->second.size()) { - return MakeError("variant_idx out of range: " + std::to_string(variant_idx)); - } - - *out_name = it->second[variant_idx].c_str(); - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetVariantFolderPath( - const ModelPackageContext* context, - const char* component_name, - const char* variant_name, - const char** out_path) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(component_name, "component_name"); - RETURN_IF_NULL(variant_name, "variant_name"); - RETURN_IF_NULL(out_path, "out_path"); - - const auto* variant = context->impl.FindVariant(component_name, variant_name); - if (!variant) { - return MakeError(std::string("Variant '") + variant_name + "' not found in component '" + - component_name + "'."); - } - - // Cache the path string for stable pointer. 
- std::string cache_key = std::string(component_name) + "/" + variant_name; - auto& cached = const_cast(context)->impl.folder_path_strings_cache[cache_key]; - if (cached.empty()) { - cached = variant->folder_path.string(); - } - *out_path = cached.c_str(); - return nullptr; -} - -ModelPackageStatus* ModelPackage_GetVariantEpName( - const ModelPackageContext* context, - const char* component_name, - const char* variant_name, - const char** out_ep) { - RETURN_IF_NULL(context, "context"); - RETURN_IF_NULL(component_name, "component_name"); - RETURN_IF_NULL(variant_name, "variant_name"); - - const auto* variant = context->impl.FindVariant(component_name, variant_name); - if (!variant) { - return MakeError(std::string("Variant '") + variant_name + "' not found in component '" + - component_name + "'."); - } - - if (out_ep) { - if (variant->ep_compatibility.ep.has_value()) { - *out_ep = variant->ep_compatibility.ep->c_str(); - } else { - *out_ep = nullptr; - } - } - return nullptr; -} - -} // extern "C" diff --git a/model_package/src/asset_hasher.cc b/model_package/src/asset_hasher.cc new file mode 100644 index 0000000000000..b41019d11a757 --- /dev/null +++ b/model_package/src/asset_hasher.cc @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "asset_hasher.h" + +#include +#include +#include +#include + +#include "sha256.h" +#include "status_impl.h" + +namespace fs = std::filesystem; + +namespace model_package { + +using model_package::MakeStatus; + +namespace { + +std::string ToPosix(const fs::path& rel) { + std::string s = rel.generic_string(); // generic_string uses '/' + // Strip leading "./" if any (lexical normalization edge case). + if (s.size() >= 2 && s[0] == '.' 
&& s[1] == '/') s.erase(0, 2); + return s; +} + +} // namespace + +ModelPackageStatus* ComputeDirectoryAssetUri(const fs::path& source_dir, + std::string* out_uri) { + if (!out_uri) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, "ComputeDirectoryAssetUri: out_uri is null."); + } + std::error_code ec; + if (!fs::exists(source_dir, ec) || !fs::is_directory(source_dir, ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + "ComputeDirectoryAssetUri: '" + source_dir.string() + "' is not a directory."); + } + + // Collect (relative_posix_path, absolute_path) pairs. + std::vector> entries; + + auto walker = fs::recursive_directory_iterator( + source_dir, fs::directory_options::none, ec); + if (ec) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "ComputeDirectoryAssetUri: cannot iterate '" + source_dir.string() + + "': " + ec.message()); + } + for (; walker != fs::recursive_directory_iterator(); walker.increment(ec)) { + if (ec) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "ComputeDirectoryAssetUri: iteration error: " + ec.message()); + } + const fs::directory_entry& de = *walker; + if (de.is_symlink(ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "ComputeDirectoryAssetUri: symlink not allowed: '" + de.path().string() + "'."); + } + if (de.is_regular_file(ec)) { + fs::path rel = fs::relative(de.path(), source_dir, ec); + if (ec) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "ComputeDirectoryAssetUri: relative path failed: " + ec.message()); + } + entries.emplace_back(ToPosix(rel), de.path()); + } else if (!de.is_directory(ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "ComputeDirectoryAssetUri: unsupported file kind: '" + + de.path().string() + "' (only regular files and directories allowed)."); + } + } + + std::sort(entries.begin(), entries.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + + std::string manifest_text; + manifest_text.reserve(entries.size() * 96); + for (const auto& entry : entries) { + std::string file_hex 
= Sha256::HashFileHex(entry.second.string()); + if (file_hex.empty()) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "ComputeDirectoryAssetUri: failed to hash file '" + entry.second.string() + "'."); + } + manifest_text.append(file_hex); + manifest_text.append(" "); + manifest_text.append(entry.first); + manifest_text.append("\n"); + } + + *out_uri = "sha256:" + Sha256::HashStringHex(manifest_text); + return nullptr; +} + +} // namespace model_package diff --git a/model_package/src/asset_hasher.h b/model_package/src/asset_hasher.h new file mode 100644 index 0000000000000..f9bd6eb1c5d9b --- /dev/null +++ b/model_package/src/asset_hasher.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file asset_hasher.h +/// \brief Directory Merkle hash for content-addressed shared assets. + +#pragma once + +#include +#include + +#include "model_package_api.h" + +namespace model_package { + +/// Compute the canonical asset URI for a directory: +/// 1. Walk recursively, collect regular files (ignore empty dirs). +/// 2. Reject symlinks (ERR_SCHEMA: portability hazard). +/// 3. For each file, compute sha256(file_bytes) → per-file hex. +/// 4. Build manifest text: ` \n` lines, +/// sorted lexicographically by path. Paths are POSIX (`/`), no leading +/// `./`. NFC normalization is the caller's responsibility for non-ASCII +/// paths; ASCII is identity. +/// 5. asset_uri = "sha256:" + sha256(manifest_text), lowercase hex. +/// +/// On success, *out_uri is set to the URI string. +ModelPackageStatus* ComputeDirectoryAssetUri(const std::filesystem::path& source_dir, + std::string* out_uri); + +} // namespace model_package diff --git a/model_package/src/authoring.cc b/model_package/src/authoring.cc new file mode 100644 index 0000000000000..4c7a9e9c1f6d5 --- /dev/null +++ b/model_package/src/authoring.cc @@ -0,0 +1,593 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +/// \file authoring.cc +/// \brief Mutation (authoring) API implementation. + +#include "model_package.h" + +#include +#include +#include +#include +#include +#include + +#include "asset_hasher.h" +#include "manifest_parser.h" +#include "model_package_impl.h" +#include "path_resolver.h" +#include "status_impl.h" + +namespace fs = std::filesystem; +namespace mp = model_package; +using model_package::MakeStatus; +using nlohmann::ordered_json; + +namespace { + +// Schema version stamped into newly authored packages, written as a "." +// string. Keep in sync with the parser's supported major + highest known minor +// (manifest_parser.cc: kMaxSupportedSchemaMajor / kMaxKnownSchemaMinor). +constexpr const char* kAuthoredSchemaVersion = "1.0"; + +ModelPackageStatus* NullArg(const char* name) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + std::string("model_package: '") + name + "' must not be null."); +} + +ModelPackageStatus* ParseJsonString(const char* json, const char* where, ordered_json* out) { + try { + *out = ordered_json::parse(json); + } catch (const ordered_json::parse_error& e) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + std::string(where) + ": JSON parse error: " + e.what()); + } + return nullptr; +} + +ModelPackageStatus* ExpectObject(const ordered_json& j, const char* where) { + if (!j.is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + std::string(where) + ": expected a JSON object."); + } + return nullptr; +} + +void RebuildComponentIndex(ModelPackage* pkg) { + pkg->component_index_by_name.clear(); + for (size_t i = 0; i < pkg->components.size(); ++i) { + pkg->component_index_by_name[pkg->components[i]->name] = i; + } +} + +mp::ComponentRecord* FindComponentRecord(ModelPackage* pkg, const std::string& name) { + auto it = pkg->component_index_by_name.find(name); + if (it == pkg->component_index_by_name.end()) return nullptr; + return pkg->components[it->second].get(); +} + +mp::VariantRecord* 
FindVariantRecord(mp::ComponentRecord* comp, const std::string& name) { + for (auto& v : comp->variants) { + if (v->name == name) return v.get(); + } + return nullptr; +} + +ModelPackageStatus* RefreshSharedAssetsHelper(ModelPackage* pkg) { + return mp::RefreshSharedAssets(pkg, mp::PathOptionsFor(pkg)); +} + +ModelPackageStatus* PostMutate(ModelPackage* pkg, bool refresh_assets = true) { + mp::DropViewCache(pkg); + if (refresh_assets) { + if (auto* s = RefreshSharedAssetsHelper(pkg)) return s; + } + if (auto* s = mp::RefreshPackageMetadata(pkg)) return s; + return mp::RefreshExecutorInfoCache(pkg, /*strict_missing_external=*/false); +} + +ordered_json& EnsureManifestComponentsObject(ModelPackage* pkg) { + if (!pkg->manifest.contains("components")) { + pkg->manifest["components"] = ordered_json::object(); + } + return pkg->manifest["components"]; +} + +} // namespace + +extern "C" { + +// ───────────────────────────────────────────────────────────────────────────── +// ModelPackage_New +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_New(ModelPackage** out) { + if (!out) return NullArg("out"); + auto pkg = std::make_unique(); + pkg->manifest = ordered_json::object(); + // Authored at this build's schema version, written as a "." string. 
+ pkg->manifest["schema_version"] = kAuthoredSchemaVersion; + pkg->manifest["layout"] = "portable"; + pkg->manifest["components"] = ordered_json::object(); + pkg->layout = "portable"; + pkg->strict_unknown_fields = true; + pkg->follow_symlinks = true; + pkg->allow_external_paths = false; + pkg->package_root = fs::path(); + if (auto* s = mp::RefreshPackageMetadata(pkg.get())) return s; + *out = pkg.release(); + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Components +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_SetComponentInline(ModelPackage* pkg, + const char* name, + const char* component_json) { + if (!pkg) return NullArg("pkg"); + if (!name) return NullArg("name"); + if (!component_json) return NullArg("component_json"); + + ordered_json body; + if (auto* s = ParseJsonString(component_json, + ("component '" + std::string(name) + "'").c_str(), &body)) return s; + if (auto* s = ExpectObject(body, ("component '" + std::string(name) + "'").c_str())) return s; + + auto opts = mp::PathOptionsFor(pkg); + auto rec = std::make_unique(); + rec->storage = mp::ComponentStorage::kInline; + rec->component_dir = pkg->package_root; + if (auto* s = mp::ParseComponentBody(pkg->package_root, opts, pkg->strict_unknown_fields, + name, body, pkg->package_root, rec.get())) return s; + + EnsureManifestComponentsObject(pkg)[name] = body; + + if (auto* existing = FindComponentRecord(pkg, name)) { + size_t idx = pkg->component_index_by_name[name]; + mp::RecordOrphanComponent(pkg, *pkg->components[idx]); + pkg->components[idx] = std::move(rec); + } else { + pkg->components.push_back(std::move(rec)); + } + RebuildComponentIndex(pkg); + return PostMutate(pkg); +} + +ModelPackageStatus* ModelPackage_SetComponentExternal(ModelPackage* pkg, + const char* name, + const char* path) { + if (!pkg) return NullArg("pkg"); + if (!name) return NullArg("name"); + if 
(!path) return NullArg("path"); + if (pkg->package_root.empty()) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "SetComponentExternal requires a package_root (use _Open or _Commit " + "with a dest_root first; or rely on _Commit(dest_root) to materialize)."); + } + + auto opts = mp::PathOptionsFor(pkg); + fs::path resolved; + // Allow the file/dir to not exist yet (we'll initialize empty). + if (auto* s = mp::ResolvePath(pkg->package_root, pkg->package_root, path, opts, + /*must_exist=*/false, &resolved)) return s; + std::error_code ec; + fs::path component_dir; + fs::path file_path; + if (fs::exists(resolved, ec) && fs::is_directory(resolved, ec)) { + file_path = resolved / "component.json"; + component_dir = resolved; + } else { + file_path = resolved; + component_dir = resolved.parent_path(); + } + ordered_json body; + if (fs::exists(file_path, ec)) { + std::ifstream f(file_path, std::ios::binary); + std::ostringstream buf; + buf << f.rdbuf(); + if (auto* s = ParseJsonString(buf.str().c_str(), + ("component '" + std::string(name) + "'").c_str(), &body)) return s; + } else { + body = ordered_json::object(); + body["variants"] = ordered_json::object(); + } + if (auto* s = ExpectObject(body, ("component '" + std::string(name) + "'").c_str())) return s; + + auto rec = std::make_unique(); + rec->storage = mp::ComponentStorage::kExternal; + rec->external_path = file_path; + rec->component_dir = component_dir; + if (auto* s = mp::ParseComponentBody(pkg->package_root, opts, pkg->strict_unknown_fields, + name, body, component_dir, rec.get())) return s; + + EnsureManifestComponentsObject(pkg)[name] = std::string(path); + + if (FindComponentRecord(pkg, name)) { + size_t idx = pkg->component_index_by_name[name]; + mp::RecordOrphanComponent(pkg, *pkg->components[idx]); + pkg->components[idx] = std::move(rec); + } else { + pkg->components.push_back(std::move(rec)); + } + RebuildComponentIndex(pkg); + return PostMutate(pkg); +} + +ModelPackageStatus* 
ModelPackage_RemoveComponent(ModelPackage* pkg, const char* name) { + if (!pkg) return NullArg("pkg"); + if (!name) return NullArg("name"); + auto it = pkg->component_index_by_name.find(name); + if (it == pkg->component_index_by_name.end()) return nullptr; + size_t idx = it->second; + mp::RecordOrphanComponent(pkg, *pkg->components[idx]); + pkg->components.erase(pkg->components.begin() + idx); + auto comps_it = pkg->manifest.find("components"); + if (comps_it != pkg->manifest.end() && comps_it->is_object()) { + comps_it->erase(name); + } + RebuildComponentIndex(pkg); + return PostMutate(pkg); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Variants +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_SetVariant(ModelPackage* pkg, + const char* component_name, + const char* variant_name, + const char* variant_json) { + if (!pkg) return NullArg("pkg"); + if (!component_name) return NullArg("component_name"); + if (!variant_name) return NullArg("variant_name"); + if (!variant_json) return NullArg("variant_json"); + auto* comp = FindComponentRecord(pkg, component_name); + if (!comp) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("SetVariant: component '") + component_name + "' not found."); + } + ordered_json body; + if (auto* s = ParseJsonString(variant_json, + ("variant '" + std::string(variant_name) + "'").c_str(), &body)) return s; + + auto vr = std::make_unique(); + auto opts = mp::PathOptionsFor(pkg); + if (auto* s = mp::ParseVariantBody(comp->component_dir, pkg->package_root, opts, + pkg->strict_unknown_fields, + variant_name, body, vr.get())) return s; + + // Update component.body["variants"][variant_name] + if (!comp->body.contains("variants") || !comp->body["variants"].is_object()) { + comp->body["variants"] = ordered_json::object(); + } + comp->body["variants"][variant_name] = body; + // If component is inline, mirror into manifest. 
+ if (comp->storage == mp::ComponentStorage::kInline) { + pkg->manifest["components"][comp->name] = comp->body; + } + // Replace or append. + bool replaced = false; + for (auto& v : comp->variants) { + if (v->name == variant_name) { + mp::RecordOrphanVariantDir(pkg, *v); + v = std::move(vr); + replaced = true; + break; + } + } + if (!replaced) comp->variants.push_back(std::move(vr)); + + // Invalidate cached component JSON. + comp->component_json_cache.reset(); + return PostMutate(pkg); +} + +ModelPackageStatus* ModelPackage_RemoveVariant(ModelPackage* pkg, + const char* component_name, + const char* variant_name) { + if (!pkg) return NullArg("pkg"); + if (!component_name) return NullArg("component_name"); + if (!variant_name) return NullArg("variant_name"); + auto* comp = FindComponentRecord(pkg, component_name); + if (!comp) return nullptr; + auto pred = [&](const std::unique_ptr& v) { + if (v->name == variant_name) { + mp::RecordOrphanVariantDir(pkg, *v); + return true; + } + return false; + }; + comp->variants.erase(std::remove_if(comp->variants.begin(), comp->variants.end(), pred), + comp->variants.end()); + if (comp->body.contains("variants") && comp->body["variants"].is_object()) { + comp->body["variants"].erase(variant_name); + } + if (comp->storage == mp::ComponentStorage::kInline) { + pkg->manifest["components"][comp->name] = comp->body; + } + comp->component_json_cache.reset(); + return PostMutate(pkg); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Variant executor_info +// ───────────────────────────────────────────────────────────────────────────── + +namespace { + +ModelPackageStatus* ReparseVariantInPlace(ModelPackage* pkg, + mp::ComponentRecord* comp, + mp::VariantRecord* var) { + auto opts = mp::PathOptionsFor(pkg); + auto rebuilt = std::make_unique(); + if (auto* s = mp::ParseVariantBody(comp->component_dir, pkg->package_root, opts, + pkg->strict_unknown_fields, + var->name, var->body, rebuilt.get())) 
return s; + *var = std::move(*rebuilt); + return nullptr; +} + +ModelPackageStatus* MutateExecutorInfo(ModelPackage* pkg, + const char* component, + const char* variant, + const char* namespace_, + const ordered_json* new_value /* null = remove */) { + if (!pkg) return NullArg("pkg"); + if (!component) return NullArg("component"); + if (!variant) return NullArg("variant"); + if (!namespace_) return NullArg("namespace"); + auto* comp = FindComponentRecord(pkg, component); + if (!comp) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("component '") + component + "' not found."); + } + auto* var = FindVariantRecord(comp, variant); + if (!var) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("variant '") + variant + "' not found in component '" + + component + "'."); + } + if (!var->body.contains("executor_info") || !var->body["executor_info"].is_object()) { + if (!new_value) return nullptr; // remove on absent -> nothing to do + var->body["executor_info"] = ordered_json::object(); + } + if (new_value) { + var->body["executor_info"][namespace_] = *new_value; + } else { + var->body["executor_info"].erase(namespace_); + if (var->body["executor_info"].empty()) { + var->body.erase("executor_info"); + } + } + comp->body["variants"][var->name] = var->body; + if (comp->storage == mp::ComponentStorage::kInline) { + pkg->manifest["components"][comp->name] = comp->body; + } + if (auto* s = ReparseVariantInPlace(pkg, comp, var)) return s; + comp->component_json_cache.reset(); + return PostMutate(pkg, /*refresh_assets=*/false); +} + +} // namespace + +ModelPackageStatus* ModelPackage_SetVariantExecutorInfoInline(ModelPackage* pkg, + const char* component, + const char* variant, + const char* namespace_, + const char* info_json) { + if (!info_json) return NullArg("info_json"); + ordered_json body; + if (auto* s = ParseJsonString(info_json, "executor_info", &body)) return s; + if (auto* s = ExpectObject(body, "executor_info inline value")) return s; + 
return MutateExecutorInfo(pkg, component, variant, namespace_, &body); +} + +ModelPackageStatus* ModelPackage_SetVariantExecutorInfoExternal(ModelPackage* pkg, + const char* component, + const char* variant, + const char* namespace_, + const char* path) { + if (!path) return NullArg("path"); + ordered_json body = std::string(path); + return MutateExecutorInfo(pkg, component, variant, namespace_, &body); +} + +ModelPackageStatus* ModelPackage_RemoveVariantExecutorInfo(ModelPackage* pkg, + const char* component, + const char* variant, + const char* namespace_) { + return MutateExecutorInfo(pkg, component, variant, namespace_, nullptr); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Shared assets +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_AddSharedAsset(ModelPackage* pkg, + const char* source_dir, + const char* expected_uri_or_null, + bool copy_in, + const char** out_uri) { + if (!pkg) return NullArg("pkg"); + if (!source_dir) return NullArg("source_dir"); + if (!out_uri) return NullArg("out_uri"); + *out_uri = nullptr; + + if (!copy_in && pkg->layout == "portable") { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "AddSharedAsset: copy_in=false rejected in portable layout (the " + "path would point outside )."); + } + + std::string computed_uri; + if (auto* s = mp::ComputeDirectoryAssetUri(fs::path(source_dir), &computed_uri)) return s; + if (expected_uri_or_null) { + if (computed_uri != expected_uri_or_null) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + std::string("AddSharedAsset: hash mismatch — computed ") + + computed_uri + ", expected " + expected_uri_or_null + "."); + } + } + + if (!pkg->manifest.contains("shared_assets") || !pkg->manifest["shared_assets"].is_object()) { + pkg->manifest["shared_assets"] = ordered_json::object(); + } + if (copy_in) { + // No manifest entry needed — the asset will be materialized at the default + // convention 
path on commit. LoadSharedAssets surfaces the staged source + // immediately so the URI shows up in ModelPackage_Info() before commit. + pkg->pending_shared_asset_copies[computed_uri] = fs::path(source_dir); + } else { + pkg->manifest["shared_assets"][computed_uri] = std::string(source_dir); + } + + if (auto* s = PostMutate(pkg)) return s; + + // Look up the record and return its URI. After PostMutate, the URI is + // always present in shared_assets_index_by_uri (either via the override + // path or via the pending-copy tier of LoadSharedAssets). + auto sit = pkg->shared_asset_index_by_uri.find(computed_uri); + if (sit == pkg->shared_asset_index_by_uri.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + std::string("AddSharedAsset: failed to register URI ") + computed_uri); + } + *out_uri = pkg->shared_assets[sit->second]->uri_cache.c_str(); + return nullptr; +} + +ModelPackageStatus* ModelPackage_RemoveSharedAsset(ModelPackage* pkg, const char* uri) { + if (!pkg) return NullArg("pkg"); + if (!uri) return NullArg("uri"); + std::string uri_str(uri); + if (pkg->manifest.contains("shared_assets") && pkg->manifest["shared_assets"].is_object()) { + pkg->manifest["shared_assets"].erase(uri_str); + if (pkg->manifest["shared_assets"].empty()) { + pkg->manifest.erase("shared_assets"); + } + } + pkg->pending_shared_asset_copies.erase(uri_str); + // Physically remove the on-disk directory at the default convention. If it + // stays on disk, the next RefreshSharedAssets would auto-discover it again + // and the removal would be a no-op. We only touch paths that live inside + // package_root. 
+ if (!pkg->package_root.empty()) { + std::string dir_name = mp::DefaultSharedAssetDirName(uri_str); + if (!dir_name.empty()) { + std::filesystem::path on_disk = pkg->package_root / "shared_assets" / dir_name; + if (mp::IsInsidePackageRoot(pkg, on_disk)) { + std::error_code ec; + std::filesystem::remove_all(on_disk, ec); + } + } + } + return PostMutate(pkg); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Package metadata +// ───────────────────────────────────────────────────────────────────────────── + +namespace { + +void SetOrClearString(ordered_json* obj, const char* key, const char* value) { + if (value == nullptr) return; // leave untouched + if (value[0] == '\0') { + obj->erase(key); + } else { + (*obj)[key] = std::string(value); + } +} + +} // namespace + +ModelPackageStatus* ModelPackage_SetMetadata(ModelPackage* pkg, + const char* name_or_null, + const char* version_or_null, + const char* description_or_null) { + if (!pkg) return NullArg("pkg"); + SetOrClearString(&pkg->manifest, "package_name", name_or_null); + SetOrClearString(&pkg->manifest, "package_version", version_or_null); + SetOrClearString(&pkg->manifest, "description", description_or_null); + return PostMutate(pkg, /*refresh_assets=*/false); +} + +ModelPackageStatus* ModelPackage_SetLayout(ModelPackage* pkg, const char* layout) { + if (!pkg) return NullArg("pkg"); + if (!layout) return NullArg("layout"); + std::string l(layout); + if (l != "portable" && l != "installed") { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "SetLayout: layout must be 'portable' or 'installed'."); + } + pkg->manifest["layout"] = l; + pkg->layout = l; + return PostMutate(pkg, /*refresh_assets=*/false); +} + +ModelPackageStatus* ModelPackage_SetAdditionalMetadataJson(ModelPackage* pkg, + const char* scope, + const char* component_or_null, + const char* variant_or_null, + const char* json_or_null) { + if (!pkg) return NullArg("pkg"); + if (!scope) return NullArg("scope"); + 
std::string s(scope); + ordered_json* target = nullptr; + mp::ComponentRecord* comp = nullptr; + mp::VariantRecord* var = nullptr; + if (s == "manifest") { + if (component_or_null || variant_or_null) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + "SetAdditionalMetadataJson: 'manifest' scope takes no component/variant."); + } + target = &pkg->manifest; + } else if (s == "component") { + if (!component_or_null) return NullArg("component"); + if (variant_or_null) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + "SetAdditionalMetadataJson: 'component' scope takes no variant."); + } + comp = FindComponentRecord(pkg, component_or_null); + if (!comp) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("component '") + component_or_null + "' not found."); + } + target = &comp->body; + } else if (s == "variant") { + if (!component_or_null) return NullArg("component"); + if (!variant_or_null) return NullArg("variant"); + comp = FindComponentRecord(pkg, component_or_null); + if (!comp) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("component '") + component_or_null + "' not found."); + } + var = FindVariantRecord(comp, variant_or_null); + if (!var) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("variant '") + variant_or_null + "' not found."); + } + target = &var->body; + } else { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + "SetAdditionalMetadataJson: scope must be 'manifest', 'component', or 'variant'."); + } + if (json_or_null == nullptr) { + target->erase("additional_metadata"); + } else { + ordered_json body; + if (auto* st = ParseJsonString(json_or_null, "additional_metadata", &body)) return st; + (*target)["additional_metadata"] = body; + } + if (comp && comp->storage == mp::ComponentStorage::kInline) { + pkg->manifest["components"][comp->name] = comp->body; + } + if (comp) comp->component_json_cache.reset(); + if (var) var->additional_metadata_cache.reset(); + if (comp) 
comp->additional_metadata_cache.reset(); + return PostMutate(pkg, /*refresh_assets=*/false); +} + +} // extern "C" diff --git a/model_package/src/commit_prune_validate.cc b/model_package/src/commit_prune_validate.cc new file mode 100644 index 0000000000000..c603dd6bcbf00 --- /dev/null +++ b/model_package/src/commit_prune_validate.cc @@ -0,0 +1,769 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file commit_prune_validate.cc +/// \brief Commit, prune, and validate implementation. + +#include "model_package.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#include +#endif + +#include "asset_hasher.h" +#include "manifest_parser.h" +#include "model_package_impl.h" +#include "path_resolver.h" +#include "status_impl.h" + +namespace fs = std::filesystem; +namespace mp = model_package; +using model_package::MakeStatus; +using nlohmann::ordered_json; + +namespace { + +ModelPackageStatus* NullArg(const char* name) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + std::string("model_package: '") + name + "' must not be null."); +} + +// ───────────────────────────────────────────────────────────────────────────── +// fsync / random helpers (POSIX). Windows would substitute FlushFileBuffers + +// BCryptGenRandom; deferred to a follow-up. +// ───────────────────────────────────────────────────────────────────────────── + +std::string RandomSuffix() { + std::random_device rd; + uint64_t hi = (uint64_t(rd()) << 32) | rd(); + char buf[17]; + std::snprintf(buf, sizeof(buf), "%016llx", static_cast(hi)); + return buf; +} + +ModelPackageStatus* FsyncPath(const fs::path& p, bool is_dir) { +#ifdef _WIN32 + (void)p; + (void)is_dir; + return nullptr; +#else + int flags = is_dir ? 
(O_RDONLY | O_DIRECTORY) : O_RDONLY; + int fd = ::open(p.c_str(), flags); + if (fd < 0) { + // Best-effort: missing fsync targets are not fatal on tmpfs etc. + return nullptr; + } + if (::fsync(fd) != 0) { + int err = errno; + ::close(fd); + return MakeStatus(MODEL_PACKAGE_ERR_IO, + std::string("fsync '") + p.string() + "' failed: " + std::strerror(err)); + } + ::close(fd); + return nullptr; +#endif +} + +ModelPackageStatus* WriteFileAtomic(const fs::path& final_path, const std::string& bytes) { + fs::path tmp = final_path; + tmp += ".tmp." + RandomSuffix(); + { + std::ofstream f(tmp, std::ios::binary | std::ios::trunc); + if (!f) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Cannot open '" + tmp.string() + "' for writing."); + } + f.write(bytes.data(), static_cast(bytes.size())); + if (!f) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Write to '" + tmp.string() + "' failed."); + } + } + if (auto* s = FsyncPath(tmp, /*is_dir=*/false)) return s; + std::error_code ec; + fs::rename(tmp, final_path, ec); + if (ec) { + fs::remove(tmp, ec); + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Rename '" + tmp.string() + "' -> '" + final_path.string() + + "' failed: " + ec.message()); + } + if (auto* s = FsyncPath(final_path.parent_path(), /*is_dir=*/true)) return s; + return nullptr; +} + +ModelPackageStatus* CopyTreeNoFollow(const fs::path& src, const fs::path& dst) { + // Recursively copy `src` into `dst`. Refuses to follow symlinks (consistent + // with the directory hash semantics) so the on-disk bytes match the URI we + // already computed. 
+ std::error_code ec; + fs::create_directories(dst, ec); + if (ec) return MakeStatus(MODEL_PACKAGE_ERR_IO, + "mkdir '" + dst.string() + "': " + ec.message()); + for (fs::recursive_directory_iterator it(src, fs::directory_options::none, ec), end; + it != end; it.increment(ec)) { + if (ec) return MakeStatus(MODEL_PACKAGE_ERR_IO, + "iterate '" + src.string() + "': " + ec.message()); + const auto& entry = *it; + fs::path rel = fs::relative(entry.path(), src, ec); + fs::path target = dst / rel; + if (entry.is_symlink()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "shared asset source contains a symlink: '" + entry.path().string() + "'."); + } + if (entry.is_directory()) { + fs::create_directories(target, ec); + if (ec) return MakeStatus(MODEL_PACKAGE_ERR_IO, + "mkdir '" + target.string() + "': " + ec.message()); + } else if (entry.is_regular_file()) { + fs::create_directories(target.parent_path(), ec); + fs::copy_file(entry.path(), target, fs::copy_options::overwrite_existing, ec); + if (ec) return MakeStatus(MODEL_PACKAGE_ERR_IO, + "copy '" + entry.path().string() + "' -> '" + + target.string() + "': " + ec.message()); + if (auto* s = FsyncPath(target, /*is_dir=*/false)) return s; + } else { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "unsupported file kind in shared asset source: '" + + entry.path().string() + "'."); + } + } + if (auto* s = FsyncPath(dst, /*is_dir=*/true)) return s; + return nullptr; +} + +ModelPackageStatus* CheckPortableConfinement(const fs::path& root, + const fs::path& candidate, + const std::string& where) { + std::error_code ec; + fs::path c = candidate.lexically_normal(); + fs::path r = root.lexically_normal(); + if (c.is_absolute()) { + // Confirm c is under r. + auto rel = fs::relative(c, r, ec); + // An empty relative path, or one whose first component is "..", escapes the root. + // (Checking only the first character would wrongly reject in-root dot-prefixed names + // such as ".hidden/component.json".) 
+ if (ec || rel.empty() || rel.begin()->string() == "..") { + return MakeStatus(MODEL_PACKAGE_ERR_PATH_CONFINEMENT, + where + ": absolute path '" + c.string() + + "' escapes package_root '" + r.string() + "' (portable layout)."); + } + } else { + // Relative: a leading ".." escapes. + auto first = c.begin(); + if (first != c.end() && first->string() == "..") { + return MakeStatus(MODEL_PACKAGE_ERR_PATH_CONFINEMENT, + where + ": relative path '" + c.string() + + "' escapes package_root (portable layout)."); + } + } + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Manifest serialization +// ───────────────────────────────────────────────────────────────────────────── + +std::string SerializeManifestForCommit(const ModelPackage* pkg) { + // Use the live in-memory manifest, but for external components, the + // ComponentRecord::body may have diverged from the string path. The manifest + // entry stays as the string in that case; the body is serialized separately + // into the external file. + return pkg->manifest.dump(2) + "\n"; +} + +ordered_json SerializeComponentBody(const mp::ComponentRecord* comp) { + return comp->body; +} + +// ───────────────────────────────────────────────────────────────────────────── +// In-place commit (PRESERVE / DENSE) +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* CheckDenseConstraints(ModelPackage* pkg) { + // Reject external executor_info in dense mode (dense flattens everything, + // but the in-memory model never loads external executor_info bodies, so we + // can't inline them surgically. ERR_STATE so the caller's intent is clear.) 
+ for (const auto& comp : pkg->components) { + auto vit = comp->body.find("variants"); + if (vit == comp->body.end() || !vit->is_object()) continue; + for (auto v = vit->begin(); v != vit->end(); ++v) { + auto ei = v->find("executor_info"); + if (ei == v->end() || !ei->is_object()) continue; + for (auto e = ei->begin(); e != ei->end(); ++e) { + if (e->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "WRITE_DENSE: component '" + comp->name + "' variant '" + + v.key() + "' has external executor_info '" + e.key() + + "' (string path). Convert to inline via " + "SetVariantExecutorInfoInline before dense commit."); + } + } + } + } + return nullptr; +} + +ModelPackageStatus* CommitSharedAssetsCopyIn(ModelPackage* pkg, const fs::path& root) { + if (pkg->pending_shared_asset_copies.empty()) return nullptr; + fs::path assets_root = root / "shared_assets"; + std::error_code ec; + fs::create_directories(assets_root, ec); + for (const auto& [uri, src] : pkg->pending_shared_asset_copies) { + std::string dir_name = mp::DefaultSharedAssetDirName(uri); + fs::path final_dir = assets_root / dir_name; + if (fs::exists(final_dir, ec)) continue; // already materialized — trust it. + fs::path stage_dir = assets_root / (dir_name + ".tmp." + RandomSuffix()); + if (auto* s = CopyTreeNoFollow(src, stage_dir)) { + fs::remove_all(stage_dir, ec); + return s; + } + // Re-hash staging to verify TOCTOU did not strike. 
+ std::string verify_uri; + if (auto* s = mp::ComputeDirectoryAssetUri(stage_dir, &verify_uri)) { + fs::remove_all(stage_dir, ec); + return s; + } + if (verify_uri != uri) { + fs::remove_all(stage_dir, ec); + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "Shared asset source mutated during commit: expected " + + uri + ", staged " + verify_uri + "."); + } + fs::rename(stage_dir, final_dir, ec); + if (ec) { + fs::remove_all(stage_dir, ec); + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Rename shared asset dir '" + stage_dir.string() + "' -> '" + + final_dir.string() + "' failed: " + ec.message()); + } + if (auto* s = FsyncPath(assets_root, /*is_dir=*/true)) return s; + } + return nullptr; +} + +ModelPackageStatus* CommitExternalComponents(ModelPackage* pkg) { + // Write each external component's current in-memory body to its disk file. + // These are library-owned; for in-place PRESERVE commit we re-emit them + // every time (cheaper than tracking dirtiness). External executor_info + // files are opaque and intentionally left untouched. + for (const auto& comp : pkg->components) { + if (comp->storage != mp::ComponentStorage::kExternal) continue; + fs::path path = comp->external_path; + std::error_code ec; + fs::create_directories(path.parent_path(), ec); + std::string text = SerializeComponentBody(comp.get()).dump(2) + "\n"; + if (auto* s = WriteFileAtomic(path, text)) return s; + } + return nullptr; +} + +ModelPackageStatus* CommitInPlace(ModelPackage* pkg, ModelPackageWriteMode mode) { + if (pkg->package_root.empty()) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "Commit: package has no package_root. Use dest_root variant."); + } + std::error_code ec; + if (!fs::is_directory(pkg->package_root, ec)) { + fs::create_directories(pkg->package_root, ec); + if (ec) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Cannot create package_root '" + pkg->package_root.string() + + "': " + ec.message()); + } + } + + // Portable confinement pre-flight for external paths. 
+ if (pkg->layout == "portable") { + for (const auto& comp : pkg->components) { + if (comp->storage == mp::ComponentStorage::kExternal) { + if (auto* s = CheckPortableConfinement(pkg->package_root, comp->external_path, + "component '" + comp->name + "'")) return s; + } + } + } + + // Dense mode: flatten external components into manifest before writing. + if (mode == MODEL_PACKAGE_WRITE_DENSE) { + if (auto* s = CheckDenseConstraints(pkg)) return s; + for (auto& comp : pkg->components) { + if (comp->storage == mp::ComponentStorage::kExternal) { + pkg->manifest["components"][comp->name] = comp->body; + // After commit, this becomes inline. + comp->storage = mp::ComponentStorage::kInline; + comp->external_path.clear(); + comp->component_dir = pkg->package_root; + } + } + } + + if (auto* s = CommitSharedAssetsCopyIn(pkg, pkg->package_root)) return s; + if (mode != MODEL_PACKAGE_WRITE_DENSE) { + if (auto* s = CommitExternalComponents(pkg)) return s; + } + + // Final manifest write. + fs::path manifest_path = pkg->package_root / "manifest.json"; + if (auto* s = WriteFileAtomic(manifest_path, SerializeManifestForCommit(pkg))) return s; + + pkg->pending_shared_asset_copies.clear(); + + // Re-derive shared assets + info view to pick up the materialized assets. + if (auto* s = mp::RefreshSharedAssets(pkg, mp::PathOptionsFor(pkg))) return s; + if (auto* s = mp::RefreshPackageMetadata(pkg)) return s; + mp::DropViewCache(pkg); + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// dest_root commit ("save as"): write to dest_root, then re-parse & swap. 
+// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* CommitToDestRoot(ModelPackage* pkg, + const fs::path& dest_root, + ModelPackageWriteMode mode) { + std::error_code ec; + if (fs::exists(dest_root, ec)) { + if (!fs::is_directory(dest_root, ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "Commit dest_root '" + dest_root.string() + "' exists and is not a directory."); + } + if (!fs::is_empty(dest_root, ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "Commit dest_root '" + dest_root.string() + "' is not empty."); + } + } else { + fs::create_directories(dest_root, ec); + if (ec) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Cannot create dest_root '" + dest_root.string() + "': " + ec.message()); + } + } + + // Build a snapshot manifest mirroring `pkg->manifest`, then handle assets. + ordered_json manifest = pkg->manifest; + + // Dense mode constraints up-front. + if (mode == MODEL_PACKAGE_WRITE_DENSE) { + if (auto* s = CheckDenseConstraints(pkg)) return s; + for (const auto& comp : pkg->components) { + if (comp->storage == mp::ComponentStorage::kExternal) { + manifest["components"][comp->name] = comp->body; + } + } + } + + // Copy all shared assets into dest_root. Any manifest override entries are + // re-mapped to the default convention path under dest_root. + fs::path assets_root = dest_root / "shared_assets"; + // Gather source dirs for every URI we know about. + // 1. URIs already on disk (under current package_root) and not in pending: copy from there. + // 2. Pending copy_in sources: copy from staged source. + // 3. Manifest override entries: copy from the override path. 
+ std::vector> to_copy; + for (const auto& rec : pkg->shared_assets) { + auto pit = pkg->pending_shared_asset_copies.find(rec->uri); + if (pit != pkg->pending_shared_asset_copies.end()) { + to_copy.emplace_back(rec->uri, pit->second); + } else { + to_copy.emplace_back(rec->uri, rec->resolved_path); + } + } + // Pending copies without a SharedAssetRecord shouldn't happen now that + // LoadSharedAssets surfaces pending copies, but stay defensive. + for (const auto& [uri, src] : pkg->pending_shared_asset_copies) { + if (pkg->shared_asset_index_by_uri.find(uri) == pkg->shared_asset_index_by_uri.end()) { + to_copy.emplace_back(uri, src); + } + } + // Only materialize shared_assets/ when something will actually land in it. + if (!to_copy.empty()) { + fs::create_directories(assets_root, ec); + } + + for (const auto& [uri, src] : to_copy) { + if (!fs::is_directory(src, ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + "Commit dest_root: shared asset source '" + src.string() + + "' for " + uri + " is not a directory."); + } + std::string dir_name = mp::DefaultSharedAssetDirName(uri); + fs::path final_dir = assets_root / dir_name; + fs::path stage_dir = assets_root / (dir_name + ".tmp." + RandomSuffix()); + if (auto* s = CopyTreeNoFollow(src, stage_dir)) { + fs::remove_all(stage_dir, ec); + return s; + } + std::string verify_uri; + if (auto* s = mp::ComputeDirectoryAssetUri(stage_dir, &verify_uri)) { + fs::remove_all(stage_dir, ec); + return s; + } + if (verify_uri != uri) { + fs::remove_all(stage_dir, ec); + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "Shared asset hash mismatch during dest_root commit: expected " + + uri + ", staged " + verify_uri); + } + fs::rename(stage_dir, final_dir, ec); + if (ec) { + fs::remove_all(stage_dir, ec); + return MakeStatus(MODEL_PACKAGE_ERR_IO, "Rename failed: " + ec.message()); + } + } + // All assets now live at the default convention path; drop overrides. 
+ manifest.erase("shared_assets"); + + // External components (PRESERVE mode): re-emit under dest_root using the same + // path string from the manifest. We treat the manifest string as relative to + // dest_root for portable mode; absolute paths are kept as-is iff the layout + // is installed. + if (mode == MODEL_PACKAGE_WRITE_PRESERVE) { + auto comps_it = manifest.find("components"); + if (comps_it != manifest.end() && comps_it->is_object()) { + for (auto e = comps_it->begin(); e != comps_it->end(); ++e) { + if (!e->is_string()) continue; + fs::path p(e->get()); + fs::path target; + if (p.is_absolute()) { + if (pkg->layout == "portable") { + return MakeStatus(MODEL_PACKAGE_ERR_PATH_CONFINEMENT, + "dest_root commit (portable): component '" + e.key() + + "' has absolute path '" + p.string() + "'."); + } + target = p; + } else { + target = dest_root / p; + std::error_code ec2; + fs::path normalized = target.lexically_normal(); + if (normalized.string().find(dest_root.lexically_normal().string()) != 0) { + return MakeStatus(MODEL_PACKAGE_ERR_PATH_CONFINEMENT, + "dest_root commit (portable): component '" + e.key() + + "' relative path '" + p.string() + "' escapes dest_root."); + } + target = normalized; + } + // Find the corresponding component body to write. + std::string ext_body; + for (const auto& comp : pkg->components) { + if (comp->name == e.key()) { + ext_body = comp->body.dump(2) + "\n"; + break; + } + } + std::error_code ec_md; + fs::create_directories(target.parent_path(), ec_md); + if (auto* s = WriteFileAtomic(target, ext_body)) return s; + } + } + } + + fs::path manifest_path = dest_root / "manifest.json"; + if (auto* s = WriteFileAtomic(manifest_path, manifest.dump(2) + "\n")) return s; + + // Re-parse the newly written package into a fresh state and swap in. 
+ ModelPackageOpenOptions opts{}; + opts.allow_external_paths = pkg->allow_external_paths; + opts.follow_symlinks = pkg->follow_symlinks; + opts.strict_unknown_fields = pkg->strict_unknown_fields; + ModelPackage fresh; + if (auto* s = mp::ParsePackage(dest_root, opts, &fresh)) { + return s; + } + // Tear down the existing view cache for the old package, then swap. + mp::DropViewCache(pkg); + // Field-by-field swap (the opaque struct is non-trivial; std::swap of the + // struct works because all members are move/swap-friendly). + std::swap(pkg->package_root, fresh.package_root); + std::swap(pkg->manifest, fresh.manifest); + std::swap(pkg->layout, fresh.layout); + std::swap(pkg->components, fresh.components); + std::swap(pkg->shared_assets, fresh.shared_assets); + std::swap(pkg->component_index_by_name, fresh.component_index_by_name); + std::swap(pkg->shared_asset_index_by_uri, fresh.shared_asset_index_by_uri); + std::swap(pkg->package_name_cache, fresh.package_name_cache); + std::swap(pkg->package_version_cache, fresh.package_version_cache); + std::swap(pkg->description_cache, fresh.description_cache); + std::swap(pkg->layout_cache, fresh.layout_cache); + std::swap(pkg->additional_metadata_cache, fresh.additional_metadata_cache); + std::swap(pkg->schema_version_major, fresh.schema_version_major); + std::swap(pkg->schema_version_minor, fresh.schema_version_minor); + pkg->pending_shared_asset_copies.clear(); + pkg->info_cache.reset(); + + if (auto* s = mp::RefreshPackageMetadata(pkg)) return s; + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Prune +// ───────────────────────────────────────────────────────────────────────────── + +constexpr std::chrono::seconds kPruneGrace{60}; + +bool IsTmpName(const fs::path& p) { + std::string name = p.filename().string(); + return name.find(".tmp.") != std::string::npos; +} + +bool IsOldEnough(const fs::path& p) { + std::error_code ec; + auto last = 
fs::last_write_time(p, ec); + if (ec) return false; + auto now = decltype(last)::clock::now(); + return (now - last) >= kPruneGrace; +} + +bool IsAncestorOrEqual(const fs::path& ancestor, const fs::path& descendant) { + // ancestor == descendant, or descendant lives under ancestor (boundary aware). + auto a = ancestor.lexically_normal().generic_string(); + auto d = descendant.lexically_normal().generic_string(); + if (d.size() < a.size()) return false; + if (d.compare(0, a.size(), a) != 0) return false; + return d.size() == a.size() || d[a.size()] == '/'; +} + +std::vector CollectLiveDirs(const ModelPackage* pkg) { + std::vector out; + for (const auto& c : pkg->components) { + if (c->storage == mp::ComponentStorage::kExternal) { + out.push_back(c->component_dir); + } + for (const auto& v : c->variants) { + if (v->resolved_directory.has_value()) { + out.push_back(*v->resolved_directory); + } + } + } + return out; +} + +// Drop entries we've handled (removed, or unsafe to touch). Entries that +// reference live state stay for a future Prune call. Tracked orphans don't +// wait on the kPruneGrace window: they were recorded by an in-session +// mutation, so there's no concurrent writer to protect against. The grace +// window is still applied to the shared_assets sweep below, which discovers +// candidates fresh from disk. +void SweepOrphanDirs(ModelPackage* pkg, + std::vector* pending, + const std::vector& live_dirs) { + pending->erase(std::remove_if(pending->begin(), pending->end(), [&](const fs::path& p) { + if (!mp::IsInsidePackageRoot(pkg, p)) return true; // outside our scope + std::error_code ec; + if (!fs::exists(p, ec)) return true; + // Skip if any live dir IS p or lives under it; deleting would damage live state. 
+ for (const auto& live : live_dirs) { + if (IsAncestorOrEqual(p, live)) return false; + } + fs::remove_all(p, ec); + return true; + }), + pending->end()); +} + +} // namespace + +namespace model_package { + +bool IsInsidePackageRoot(const ModelPackage* pkg, const fs::path& p) { + if (pkg->package_root.empty()) return false; + return IsAncestorOrEqual(pkg->package_root, p); +} + +void RecordOrphanVariantDir(ModelPackage* pkg, const VariantRecord& v) { + if (!v.resolved_directory.has_value()) return; + if (!IsInsidePackageRoot(pkg, *v.resolved_directory)) return; + pkg->pending_orphan_variant_dirs.push_back(*v.resolved_directory); +} + +void RecordOrphanComponent(ModelPackage* pkg, const ComponentRecord& c) { + for (const auto& v : c.variants) { + RecordOrphanVariantDir(pkg, *v); + } + if (c.storage == ComponentStorage::kExternal && + IsInsidePackageRoot(pkg, c.component_dir)) { + pkg->pending_orphan_component_dirs.push_back(c.component_dir); + } +} + +} // namespace model_package + +extern "C" { + +ModelPackageStatus* ModelPackage_Commit(ModelPackage* pkg, + const char* dest_root_or_null, + ModelPackageWriteMode mode) { + if (!pkg) return NullArg("pkg"); + if (dest_root_or_null) { + return CommitToDestRoot(pkg, fs::path(dest_root_or_null), mode); + } + return CommitInPlace(pkg, mode); +} + +ModelPackageStatus* ModelPackage_Prune(ModelPackage* pkg) { + if (!pkg) return NullArg("pkg"); + if (pkg->package_root.empty()) return nullptr; + + // Shared assets are NEVER auto-pruned. The library cannot prove an asset is + // unused without parsing every consumer's executor_info payload, and a + // mistaken delete is worse than disk bloat for content-addressed dirs that + // dedupe naturally. Callers reclaim shared assets via explicit + // ModelPackage_RemoveSharedAsset(uri) (which still requires consumer-aware + // knowledge of what's reachable). 
+ // + // Stale `.tmp.` staging dirs from interrupted commits are reclaimed + // here after a grace window: they belong to this library's own staging + // protocol and aren't user data. + std::error_code ec; + fs::path assets_root = pkg->package_root / "shared_assets"; + if (fs::is_directory(assets_root, ec)) { + for (const auto& entry : fs::directory_iterator(assets_root, ec)) { + if (ec) break; + if (!entry.is_directory()) continue; + if (!IsTmpName(entry.path())) continue; + if (!IsOldEnough(entry.path())) continue; + fs::remove_all(entry.path(), ec); + } + } + + // Tracked-orphan sweep: components before variants so a component_dir + // removal reclaims its child variant dirs in one shot. + std::vector live_dirs = CollectLiveDirs(pkg); + SweepOrphanDirs(pkg, &pkg->pending_orphan_component_dirs, live_dirs); + SweepOrphanDirs(pkg, &pkg->pending_orphan_variant_dirs, live_dirs); + + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Validate +// ───────────────────────────────────────────────────────────────────────────── + +namespace { + +void AddFinding(ordered_json* arr, const std::string& code, const std::string& msg) { + ordered_json e = ordered_json::object(); + e["code"] = code; + e["message"] = msg; + arr->push_back(e); +} + +} // namespace + +ModelPackageStatus* ModelPackage_Validate(ModelPackage* pkg, int flags, + const char** out_report_json) { + if (!pkg) return NullArg("pkg"); + if (!out_report_json) return NullArg("out_report_json"); + *out_report_json = nullptr; + ordered_json report = ordered_json::object(); + report["errors"] = ordered_json::array(); + report["warnings"] = ordered_json::array(); + ordered_json* errors = &report["errors"]; + ordered_json* warnings = &report["warnings"]; + + std::error_code ec; + + // SCHEMA: re-validate the in-memory manifest by serializing then re-parsing + // into a scratch ModelPackage with strict mode. 
Validates schema for both + // committed and uncommitted state. + if (flags & MODEL_PACKAGE_VALIDATE_SCHEMA) { + // Re-run each component/variant through the parser to confirm shape. + for (const auto& comp : pkg->components) { + mp::ComponentRecord scratch; + auto opts = mp::PathOptionsFor(pkg); + if (auto* s = mp::ParseComponentBody(pkg->package_root, opts, + /*strict=*/true, + comp->name, comp->body, + comp->component_dir, &scratch)) { + AddFinding(errors, "SCHEMA", std::string("component '") + comp->name + "': " + ModelPackageStatus_Message(s)); + ModelPackageStatus_Release(s); + } + } + } + + // PATHS: each external component's path on disk; each shared-asset resolved_path exists. + if (flags & MODEL_PACKAGE_VALIDATE_PATHS) { + for (const auto& comp : pkg->components) { + if (comp->storage == mp::ComponentStorage::kExternal) { + if (!fs::exists(comp->external_path, ec)) { + AddFinding(warnings, "PATHS", + "component '" + comp->name + "' external file does not exist: " + + comp->external_path.string()); + } + } + } + for (const auto& rec : pkg->shared_assets) { + if (!fs::is_directory(rec->resolved_path, ec)) { + AddFinding(warnings, "PATHS", + "shared asset " + rec->uri + " resolved path is not a directory: " + + rec->resolved_path.string()); + } + } + } + + // ASSET_REHASH: re-hash each on-disk shared asset and compare to its URI. + if (flags & MODEL_PACKAGE_VALIDATE_ASSET_REHASH) { + for (const auto& rec : pkg->shared_assets) { + if (!fs::is_directory(rec->resolved_path, ec)) continue; // PATHS / REACH covers this. 
+ std::string computed; + if (auto* s = mp::ComputeDirectoryAssetUri(rec->resolved_path, &computed)) { + AddFinding(errors, "ASSET_REHASH", + "shared asset " + rec->uri + ": hashing failed: " + + ModelPackageStatus_Message(s)); + ModelPackageStatus_Release(s); + continue; + } + if (computed != rec->uri) { + AddFinding(errors, "ASSET_REHASH", + "shared asset " + rec->uri + " on-disk hash differs: " + computed); + } + } + } + + // UNKNOWN_FIELDS: re-run with strict=true (only flags top-level / known scopes). + if (flags & MODEL_PACKAGE_VALIDATE_UNKNOWN_FIELDS) { + static const char* kKnown[] = { + "schema_version", "package_name", "package_version", "description", + "layout", "components", "shared_assets", "additional_metadata"}; + for (auto it = pkg->manifest.begin(); it != pkg->manifest.end(); ++it) { + bool found = false; + for (auto* k : kKnown) + if (it.key() == k) { + found = true; + break; + } + if (!found) { + AddFinding(warnings, "UNKNOWN_FIELDS", + "manifest contains unknown field '" + it.key() + "'."); + } + } + } + + pkg->last_validate_report = report.dump(2); + *out_report_json = pkg->last_validate_report->c_str(); + if (!errors->empty()) { + return MakeStatus(MODEL_PACKAGE_ERR_STATE, + "ModelPackage_Validate: " + std::to_string(errors->size()) + + " error(s) found. See out_report_json for details."); + } + return nullptr; +} + +} // extern "C" diff --git a/model_package/src/manifest_parser.cc b/model_package/src/manifest_parser.cc new file mode 100644 index 0000000000000..58b7ad16fdd7d --- /dev/null +++ b/model_package/src/manifest_parser.cc @@ -0,0 +1,732 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "manifest_parser.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "path_resolver.h" +#include "status_impl.h" + +namespace fs = std::filesystem; + +namespace model_package { + +namespace { + +// The on-disk schema_version is a "." string (e.g. 
"1.0"). The major gates +// compatibility; the minor is informational and tells consumers which optional fields may +// be present. This build understands schema majors in [kMinSupportedSchemaMajor, +// kMaxSupportedSchemaMajor] and any minor: schema evolution within a major is additive and +// backward-compatible (newer minors only add optional fields), so a single parser reads +// every minor. A package whose major is below the minimum predates a breaking change and +// must be re-authored; one above the maximum was produced by a newer toolchain this build +// does not understand. kMaxKnownSchemaMinor is the highest minor this build authored/knows; +// a package with a higher minor is still accepted but may carry unknown optional fields, +// which are tolerated rather than rejected. +constexpr int64_t kMinSupportedSchemaMajor = 1; +constexpr int64_t kMaxSupportedSchemaMajor = 1; +constexpr int64_t kMaxKnownSchemaMinor = 0; +constexpr const char* kManifestFileName = "manifest.json"; +constexpr const char* kComponentFileName = "component.json"; + +constexpr const char* kSchemaVersionKey = "schema_version"; +constexpr const char* kPackageNameKey = "package_name"; +constexpr const char* kPackageVersionKey = "package_version"; +constexpr const char* kDescriptionKey = "description"; +constexpr const char* kLayoutKey = "layout"; +constexpr const char* kComponentsKey = "components"; +constexpr const char* kSharedAssetsKey = "shared_assets"; +constexpr const char* kAdditionalMetadataKey = "additional_metadata"; + +constexpr const char* kComponentNameKey = "component_name"; +constexpr const char* kVariantsKey = "variants"; + +constexpr const char* kVariantDirectoryKey = "variant_directory"; +constexpr const char* kEpKey = "ep"; +constexpr const char* kDeviceKey = "device"; +constexpr const char* kCompatibilityStringKey = "compatibility_string"; +constexpr const char* kExecutorInfoKey = "executor_info"; + +static const std::set kManifestKnownKeys = { + kSchemaVersionKey, + 
kPackageNameKey, + kPackageVersionKey, + kDescriptionKey, + kLayoutKey, + kComponentsKey, + kSharedAssetsKey, + kAdditionalMetadataKey, +}; + +static const std::set kComponentKnownKeys = { + kComponentNameKey, + kVariantsKey, + kAdditionalMetadataKey, +}; + +static const std::set kVariantKnownKeys = { + kVariantDirectoryKey, + kEpKey, + kDeviceKey, + kCompatibilityStringKey, + kExecutorInfoKey, + kAdditionalMetadataKey, +}; + +ModelPackageStatus* ReadFileToString(const fs::path& path, std::string* out) { + std::ifstream f(path, std::ios::binary); + if (!f) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Cannot open file: '" + path.string() + "': " + std::strerror(errno)); + } + std::ostringstream buf; + buf << f.rdbuf(); + *out = buf.str(); + return nullptr; +} + +ModelPackageStatus* ParseJsonFile(const fs::path& path, ordered_json* out) { + std::string contents; + if (auto* s = ReadFileToString(path, &contents)) return s; + try { + *out = ordered_json::parse(contents); + } catch (const ordered_json::parse_error& e) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "Failed to parse JSON at '" + path.string() + "': " + e.what()); + } + return nullptr; +} + +ModelPackageStatus* ExpectObject(const ordered_json& j, const std::string& where) { + if (!j.is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, where + ": expected a JSON object."); + } + return nullptr; +} + +ModelPackageStatus* CheckUnknownFields(const ordered_json& obj, + const std::set& known, + const std::string& where, + bool strict) { + if (!strict) return nullptr; + for (auto it = obj.begin(); it != obj.end(); ++it) { + if (known.find(it.key()) == known.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + where + ": unknown field '" + it.key() + "'."); + } + } + return nullptr; +} + +ModelPackageStatus* ResolveVariantDirectory(const fs::path& component_dir, + const fs::path& package_root, + const ordered_json& variant_body, + const std::string& variant_name, + const PathResolverOptions& opts, 
+ bool require_exists, + std::optional* out) { + auto it = variant_body.find(kVariantDirectoryKey); + bool explicitly_declared = (it != variant_body.end()); + std::string dir_input; + if (explicitly_declared) { + if (!it->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "variant '" + variant_name + "': variant_directory must be a string."); + } + dir_input = it->get(); + } else { + // Inferred default: missing-on-disk is fine; we just leave out unset. + dir_input = variant_name; + } + + fs::path resolved; + // Explicit value must exist; inferred default may not. + bool must_exist = require_exists || explicitly_declared; + auto* status = ResolvePath(component_dir, package_root, dir_input, opts, + must_exist, &resolved); + if (status) { + if (!must_exist && ModelPackageStatus_Code(status) == MODEL_PACKAGE_ERR_NOT_FOUND) { + ModelPackageStatus_Release(status); + *out = std::nullopt; + return nullptr; + } + return status; + } + std::error_code ec; + if (fs::exists(resolved, ec)) { + *out = resolved; + } else { + *out = std::nullopt; + } + return nullptr; +} + +ModelPackageStatus* ParseVariant(const fs::path& component_dir, + const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& variant_name, + const ordered_json& variant_body, + VariantRecord* out); +ModelPackageStatus* ParseComponent(const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& component_name, + const ordered_json& body, + const fs::path& component_dir, + ComponentRecord* out); +ModelPackageStatus* LoadSharedAssets(ModelPackage* pkg, const PathResolverOptions& opts); +ModelPackageStatus* PopulatePackageMetadata(ModelPackage* pkg); + +ModelPackageStatus* ParseVariant(const fs::path& component_dir, + const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& variant_name, + const ordered_json& variant_body, + VariantRecord* out) { + if (auto* s = 
ExpectObject(variant_body, "variant '" + variant_name + "'")) return s; + if (auto* s = CheckUnknownFields(variant_body, kVariantKnownKeys, + "variant '" + variant_name + "'", strict)) + return s; + + out->name = variant_name; + out->body = variant_body; + out->name_cache = variant_name; + + auto stringopt = [&](const char* key, std::optional* dst) -> ModelPackageStatus* { + auto it = variant_body.find(key); + if (it == variant_body.end()) return nullptr; + if (!it->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + std::string("variant '") + variant_name + "': '" + key + + "' must be a string."); + } + *dst = it->get(); + return nullptr; + }; + if (auto* s = stringopt(kEpKey, &out->ep_cache)) return s; + if (auto* s = stringopt(kDeviceKey, &out->device_cache)) return s; + if (auto* s = stringopt(kCompatibilityStringKey, &out->compatibility_string_cache)) return s; + + auto ei_it = variant_body.find(kExecutorInfoKey); + if (ei_it != variant_body.end()) { + if (!ei_it->is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "variant '" + variant_name + "': executor_info must be an object."); + } + for (auto e = ei_it->begin(); e != ei_it->end(); ++e) { + if (!e->is_string() && !e->is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "variant '" + variant_name + "': executor_info['" + e.key() + + "'] must be a string (path) or object (inline)."); + } + } + } + + // Resolve variant_directory if declared (records the resolved path when it + // exists on disk). We do NOT require the directory to exist here: executor + // semantics are not the library's concern, and executors must resolve their + // own file references against variant_directory at load time anyway. 
+ std::optional resolved_dir; + auto* status = ResolveVariantDirectory(component_dir, package_root, variant_body, + variant_name, opts, + /*require_exists=*/false, &resolved_dir); + if (status) return status; + out->resolved_directory = resolved_dir; + out->resolved_directory_attempted = true; + if (resolved_dir.has_value()) { + out->resolved_directory_cache = resolved_dir->string(); + } + + return nullptr; +} + +ModelPackageStatus* ParseComponent(const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& component_name, + const ordered_json& body, + const fs::path& component_dir, + ComponentRecord* out) { + if (auto* s = ExpectObject(body, "component '" + component_name + "'")) return s; + if (auto* s = CheckUnknownFields(body, kComponentKnownKeys, + "component '" + component_name + "'", strict)) + return s; + out->name = component_name; + out->name_cache = component_name; + out->component_dir = component_dir; + out->body = body; + + // Optional component_name override — for now we just sanity-check it. 
+ auto cn_it = body.find(kComponentNameKey); + if (cn_it != body.end() && !cn_it->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "component '" + component_name + "': component_name must be a string."); + } + + auto variants_it = body.find(kVariantsKey); + if (variants_it == body.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "component '" + component_name + "': missing required 'variants' object."); + } + if (!variants_it->is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "component '" + component_name + "': 'variants' must be an object."); + } + for (auto v = variants_it->begin(); v != variants_it->end(); ++v) { + auto vr = std::make_unique(); + if (auto* s = ParseVariant(component_dir, package_root, opts, strict, + v.key(), v.value(), vr.get())) { + return s; + } + out->variants.push_back(std::move(vr)); + } + return nullptr; +} + +ModelPackageStatus* LoadComponentForEntry(const fs::path& manifest_dir, + const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& name, + const ordered_json& value, + std::unique_ptr* out) { + auto rec = std::make_unique(); + if (value.is_string()) { + rec->storage = ComponentStorage::kExternal; + fs::path resolved; + if (auto* s = ResolvePath(manifest_dir, package_root, value.get(), + opts, /*must_exist=*/true, &resolved)) { + return s; + } + std::error_code ec; + if (fs::is_directory(resolved, ec)) { + resolved /= kComponentFileName; + if (!fs::exists(resolved)) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + "component '" + name + "': directory has no '" + + kComponentFileName + "'."); + } + } + rec->external_path = resolved; + ordered_json body; + if (auto* s = ParseJsonFile(resolved, &body)) return s; + fs::path component_dir = resolved.parent_path(); + if (auto* s = ParseComponent(package_root, opts, strict, name, body, component_dir, rec.get())) { + return s; + } + } else if (value.is_object()) { + rec->storage = ComponentStorage::kInline; + 
if (auto* s = ParseComponent(package_root, opts, strict, name, value, manifest_dir, rec.get())) { + return s; + } + } else { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "component '" + name + "': value must be a string (path) or object (inline)."); + } + *out = std::move(rec); + return nullptr; +} + +ModelPackageStatus* LoadSharedAssets(ModelPackage* pkg, const PathResolverOptions& opts) { + // Source-of-truth ordering for the assembled shared_assets vector: + // 1. Manifest overrides (in declaration order). These specify a custom + // on-disk path for an asset URI (e.g. a system-wide cache or another + // location outside /shared_assets/). + // 2. Discovered sha256- subdirectories under /shared_assets/. + // These resolve to the default-convention path. Missing shared_assets/ is + // not an error: portable packages may not ship one yet, installed + // packages may route everything through overrides. + // 3. Pending copy_in assets from the authoring API that haven't been + // committed yet. These surface immediately so callers see the asset + // they just added; the staged source dir is reported as resolved_path + // until commit materializes it under shared_assets/. + // Within each tier, an URI that's already known is skipped. 
+ std::vector ordered_uris; + std::unordered_map override_paths; + + auto sa_it = pkg->manifest.find(kSharedAssetsKey); + if (sa_it != pkg->manifest.end()) { + if (!sa_it->is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: shared_assets must be an object."); + } + for (auto e = sa_it->begin(); e != sa_it->end(); ++e) { + const std::string uri = e.key(); + if (!IsSha256AssetUri(uri)) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: shared_assets key '" + uri + "' is not a valid sha256: URI."); + } + if (!e->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: shared_assets['" + uri + "'] must be a string path."); + } + ordered_uris.push_back(uri); + override_paths.emplace(uri, e->get()); + } + } + std::set seen(ordered_uris.begin(), ordered_uris.end()); + + // Tier 2: discover sha256- dirs under /shared_assets/. + fs::path assets_root = pkg->package_root / "shared_assets"; + std::error_code ec; + if (!pkg->package_root.empty() && fs::is_directory(assets_root, ec)) { + std::vector discovered; + for (const auto& entry : fs::directory_iterator(assets_root, ec)) { + if (ec) break; + if (!entry.is_directory(ec)) continue; + std::string name = entry.path().filename().string(); + std::string uri = SharedAssetUriFromDirName(name); + if (uri.empty()) continue; // not a sha256- dir; ignore (.tmp staging, etc.) + if (!seen.insert(uri).second) continue; + discovered.push_back(std::move(uri)); + } + std::sort(discovered.begin(), discovered.end()); + for (auto& uri : discovered) ordered_uris.push_back(std::move(uri)); + } + + // Tier 3: pending copy_in entries. 
+ for (const auto& [uri, src] : pkg->pending_shared_asset_copies) { + if (!seen.insert(uri).second) continue; + ordered_uris.push_back(uri); + } + + for (const auto& uri : ordered_uris) { + auto rec = std::make_unique(); + rec->uri = uri; + rec->uri_cache = uri; + auto override_it = override_paths.find(uri); + fs::path resolved; + if (override_it != override_paths.end()) { + if (auto* s = ResolvePath(pkg->package_root, pkg->package_root, override_it->second, + opts, /*must_exist=*/false, &resolved)) { + return s; + } + } else if (auto pending_it = pkg->pending_shared_asset_copies.find(uri); + pending_it != pkg->pending_shared_asset_copies.end() && + override_paths.find(uri) == override_paths.end()) { + // Pending copy_in with no override: surface the staged source until commit. + resolved = pending_it->second; + } else { + // Default convention: /shared_assets/sha256-/ + resolved = assets_root / DefaultSharedAssetDirName(uri); + } + rec->resolved_path = resolved; + rec->resolved_path_cache = resolved.string(); + pkg->shared_asset_index_by_uri.emplace(uri, pkg->shared_assets.size()); + pkg->shared_assets.push_back(std::move(rec)); + } + return nullptr; +} + +ModelPackageStatus* ParseSchemaVersion(ModelPackage* pkg) { + auto sv_it = pkg->manifest.find(kSchemaVersionKey); + if (sv_it == pkg->manifest.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: missing required 'schema_version'."); + } + + // schema_version is a "." string (e.g. "1.0"). A bare integer is accepted + // as shorthand for ".0". + int64_t major = 0; + int64_t minor = 0; + if (sv_it->is_string()) { + const std::string sv = sv_it->get(); + const size_t dot = sv.find('.'); + const std::string major_str = (dot == std::string::npos) ? sv : sv.substr(0, dot); + const std::string minor_str = (dot == std::string::npos) ? 
std::string("0") : sv.substr(dot + 1); + auto parse_part = [](const std::string& s, int64_t* out) -> bool { + if (s.empty() || s.find_first_not_of("0123456789") != std::string::npos) return false; + try { + *out = std::stoll(s); + } catch (const std::exception&) { + return false; + } + return true; + }; + if (dot != std::string::npos && minor_str.find('.') != std::string::npos) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: 'schema_version' must be a \".\" string."); + } + if (!parse_part(major_str, &major) || !parse_part(minor_str, &minor)) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: 'schema_version' must be a \".\" string."); + } + } else if (sv_it->is_number_integer() || sv_it->is_number_unsigned()) { + major = sv_it->get(); + minor = 0; + } else { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: 'schema_version' must be a \".\" string."); + } + + if (major < kMinSupportedSchemaMajor || major > kMaxSupportedSchemaMajor) { + std::string supported = (kMinSupportedSchemaMajor == kMaxSupportedSchemaMajor) + ? std::to_string(kMinSupportedSchemaMajor) + : std::to_string(kMinSupportedSchemaMajor) + "-" + + std::to_string(kMaxSupportedSchemaMajor); + return MakeStatus(MODEL_PACKAGE_ERR_VERSION, + "manifest: schema_version major " + std::to_string(major) + + " is not supported (this build supports major " + supported + ")."); + } + pkg->schema_version_major = major; + pkg->schema_version_minor = minor; + + // A package authored at a newer minor than this build knows may carry optional fields this + // build does not recognize. Those are additive and must be tolerated rather than rejected, + // so relax unknown-field strictness for a newer minor. 
+ if (minor > kMaxKnownSchemaMinor) { + pkg->strict_unknown_fields = false; + } + return nullptr; +} + +ModelPackageStatus* PopulatePackageMetadata(ModelPackage* pkg) { + auto stropt = [&](const char* key, std::optional* dst) -> ModelPackageStatus* { + auto it = pkg->manifest.find(key); + if (it == pkg->manifest.end()) { + dst->reset(); + return nullptr; + } + if (!it->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + std::string("manifest: '") + key + "' must be a string."); + } + *dst = it->get(); + return nullptr; + }; + if (auto* s = stropt(kPackageNameKey, &pkg->package_name_cache)) return s; + if (auto* s = stropt(kPackageVersionKey, &pkg->package_version_cache)) return s; + if (auto* s = stropt(kDescriptionKey, &pkg->description_cache)) return s; + + // layout: default "portable" + auto layout_it = pkg->manifest.find(kLayoutKey); + if (layout_it != pkg->manifest.end()) { + if (!layout_it->is_string()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, "manifest: 'layout' must be a string."); + } + pkg->layout = layout_it->get(); + if (pkg->layout != "portable" && pkg->layout != "installed") { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: 'layout' must be 'portable' or 'installed'."); + } + } else { + pkg->layout = "portable"; + } + pkg->layout_cache = pkg->layout; + + // additional_metadata: serialize as JSON string if present. 
+ auto am_it = pkg->manifest.find(kAdditionalMetadataKey); + if (am_it != pkg->manifest.end()) { + pkg->additional_metadata_cache = am_it->dump(); + } else { + pkg->additional_metadata_cache.reset(); + } + return nullptr; +} + +} // namespace + +PathResolverOptions PathOptionsFor(const ModelPackage* pkg) { + PathResolverOptions o; + o.follow_symlinks = pkg->follow_symlinks; + o.allow_external_paths = pkg->allow_external_paths || (pkg->layout == "installed"); + return o; +} + +ModelPackageStatus* ParseVariantBody(const fs::path& component_dir, + const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& variant_name, + const ordered_json& variant_body, + VariantRecord* out) { + return ParseVariant(component_dir, package_root, opts, strict, variant_name, variant_body, out); +} + +ModelPackageStatus* ParseComponentBody(const fs::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& component_name, + const ordered_json& body, + const fs::path& component_dir, + ComponentRecord* out) { + return ParseComponent(package_root, opts, strict, component_name, body, component_dir, out); +} + +ModelPackageStatus* RefreshPackageMetadata(ModelPackage* pkg) { + pkg->package_name_cache.reset(); + pkg->package_version_cache.reset(); + pkg->description_cache.reset(); + pkg->additional_metadata_cache.reset(); + return PopulatePackageMetadata(pkg); +} + +ModelPackageStatus* RefreshSharedAssets(ModelPackage* pkg, const PathResolverOptions& opts) { + pkg->shared_assets.clear(); + pkg->shared_asset_index_by_uri.clear(); + return LoadSharedAssets(pkg, opts); +} + +namespace { + +ModelPackageStatus* ResolveExecutorInfoEntry(const ModelPackage* pkg, + const VariantRecord& var, + const std::string& ns, + const ordered_json& entry, + bool strict_missing_external, + std::string* dst_json) { + if (entry.is_object()) { + *dst_json = entry.dump(); + return nullptr; + } + if (entry.is_string()) { + if 
(!var.resolved_directory.has_value()) { + if (!strict_missing_external) { + dst_json->clear(); + return nullptr; + } + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + "variant '" + var.name + "': executor_info['" + ns + + "'] points at an external file but the variant has no " + "resolved variant_directory to anchor it."); + } + PathResolverOptions opts = PathOptionsFor(pkg); + fs::path resolved; + if (auto* s = ResolvePath(*var.resolved_directory, pkg->package_root, + entry.get(), opts, + /*must_exist=*/strict_missing_external, &resolved)) { + if (!strict_missing_external) { + ModelPackageStatus_Release(s); + dst_json->clear(); + return nullptr; + } + return s; + } + std::ifstream f(resolved, std::ios::binary); + if (!f) { + if (!strict_missing_external) { + dst_json->clear(); + return nullptr; + } + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "Cannot open executor_info file: '" + resolved.string() + "'."); + } + std::ostringstream buf; + buf << f.rdbuf(); + std::string contents = buf.str(); + try { + auto _ = ordered_json::parse(contents); + (void)_; + } catch (const std::exception& e) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + std::string("Failed to parse executor_info JSON at '") + + resolved.string() + "': " + e.what()); + } + *dst_json = std::move(contents); + return nullptr; + } + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "variant '" + var.name + "': executor_info['" + ns + + "'] must be a string or object."); +} + +} // namespace + +ModelPackageStatus* RefreshExecutorInfoCache(ModelPackage* pkg, bool strict_missing_external) { + for (auto& comp : pkg->components) { + for (auto& vp : comp->variants) { + VariantRecord& var = *vp; + var.executor_info_resolved.clear(); + auto ei_it = var.body.find("executor_info"); + if (ei_it == var.body.end() || !ei_it->is_object()) continue; + var.executor_info_resolved.reserve(ei_it->size()); + for (auto e = ei_it->begin(); e != ei_it->end(); ++e) { + std::string body_json; + if (auto* s = 
ResolveExecutorInfoEntry(pkg, var, e.key(), e.value(), + strict_missing_external, &body_json)) { + return s; + } + var.executor_info_resolved.emplace_back(e.key(), std::move(body_json)); + } + } + } + return nullptr; +} + +ModelPackageStatus* ParsePackage(const fs::path& package_root, + const ModelPackageOpenOptions& opts, + ModelPackage* pkg) { + std::error_code ec; + if (!fs::exists(package_root, ec) || !fs::is_directory(package_root, ec)) { + return MakeStatus(MODEL_PACKAGE_ERR_IO, + "package_root '" + package_root.string() + "' is not a directory."); + } + pkg->package_root = fs::canonical(package_root, ec); + if (ec) pkg->package_root = package_root; + pkg->allow_external_paths = opts.allow_external_paths; + pkg->follow_symlinks = opts.follow_symlinks; + pkg->strict_unknown_fields = opts.strict_unknown_fields; + + fs::path manifest_path = pkg->package_root / kManifestFileName; + if (auto* s = ParseJsonFile(manifest_path, &pkg->manifest)) return s; + if (auto* s = ExpectObject(pkg->manifest, "manifest")) return s; + + // Validate the schema version first so an unsupported package fails fast, before any + // component/asset parsing. May relax pkg->strict_unknown_fields for a newer minor. + if (auto* s = ParseSchemaVersion(pkg)) return s; + + // Layout pre-read for path-resolver options. Done before strict-unknown + // check because we need the layout value to decide path-confinement. + PathResolverOptions presolve_opts; + presolve_opts.follow_symlinks = opts.follow_symlinks; + presolve_opts.allow_external_paths = opts.allow_external_paths; + { + auto layout_it = pkg->manifest.find(kLayoutKey); + if (layout_it != pkg->manifest.end() && layout_it->is_string() && + layout_it->get() == "installed") { + presolve_opts.allow_external_paths = true; + } + } + + if (auto* s = CheckUnknownFields(pkg->manifest, kManifestKnownKeys, "manifest", + pkg->strict_unknown_fields)) + return s; + + // Components. 
+ auto comps_it = pkg->manifest.find(kComponentsKey); + if (comps_it == pkg->manifest.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, + "manifest: missing required 'components' object."); + } + if (!comps_it->is_object()) { + return MakeStatus(MODEL_PACKAGE_ERR_SCHEMA, "manifest: 'components' must be an object."); + } + for (auto e = comps_it->begin(); e != comps_it->end(); ++e) { + std::unique_ptr rec; + if (auto* s = LoadComponentForEntry(pkg->package_root, pkg->package_root, + presolve_opts, pkg->strict_unknown_fields, + e.key(), e.value(), &rec)) { + return s; + } + pkg->component_index_by_name.emplace(rec->name, pkg->components.size()); + pkg->components.push_back(std::move(rec)); + } + + if (auto* s = LoadSharedAssets(pkg, presolve_opts)) return s; + if (auto* s = PopulatePackageMetadata(pkg)) return s; + if (auto* s = RefreshExecutorInfoCache(pkg, /*strict_missing_external=*/true)) return s; + + return nullptr; +} + +} // namespace model_package diff --git a/model_package/src/manifest_parser.h b/model_package/src/manifest_parser.h new file mode 100644 index 0000000000000..d8266605440dc --- /dev/null +++ b/model_package/src/manifest_parser.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file manifest_parser.h +/// \brief Internal parser that reads a model package from disk into the +/// in-memory representation defined in model_package_impl.h. + +#pragma once + +#include "model_package_impl.h" +#include "path_resolver.h" + +namespace model_package { + +/// Parse the manifest at `<package_root>/manifest.json` and all referenced +/// external component files, then populate `*pkg`. Caller owns `pkg`. +ModelPackageStatus* ParsePackage(const std::filesystem::path& package_root, + const ModelPackageOpenOptions& opts, + ModelPackage* pkg); + +/// Parse a single variant body into `out`. Used by authoring.
+ModelPackageStatus* ParseVariantBody(const std::filesystem::path& component_dir, + const std::filesystem::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& variant_name, + const ordered_json& variant_body, + VariantRecord* out); + +/// Parse a single component body. `component_dir` is the directory used as the +/// base for the component's relative paths. +ModelPackageStatus* ParseComponentBody(const std::filesystem::path& package_root, + const PathResolverOptions& opts, + bool strict, + const std::string& component_name, + const ordered_json& body, + const std::filesystem::path& component_dir, + ComponentRecord* out); + +/// Re-derive package-level metadata (package_name, package_version, +/// description, layout, additional_metadata) from `pkg->manifest` into the +/// package's stable string buffers. schema_version is not re-derived here. +ModelPackageStatus* RefreshPackageMetadata(ModelPackage* pkg); + +/// Re-derive `pkg->shared_assets` from `pkg->manifest.shared_assets` overrides, +/// plus any `sha256-<hash>` subdirectories discovered under +/// `<package_root>/shared_assets/`, plus any pending copy_in entries staged via +/// the authoring API. Clears and replaces the existing shared_assets vector +/// and `shared_asset_index_by_uri`. +ModelPackageStatus* RefreshSharedAssets(ModelPackage* pkg, const PathResolverOptions& opts); + +/// Re-resolve every variant's executor_info entries into stable strings on the +/// VariantRecord (inline bodies dumped, external files loaded + JSON-parsed). +/// If `strict_missing_external` is true, missing external files are an error +/// (use at Open: the package is already published, files must be present); +/// if false, missing external files are recorded as an empty body (use during +/// authoring: callers may set the path before writing the file). Parse errors +/// on existing external files are always surfaced.
+ModelPackageStatus* RefreshExecutorInfoCache(ModelPackage* pkg, bool strict_missing_external); + +/// Build PathResolverOptions appropriate for `pkg` (respects layout). +PathResolverOptions PathOptionsFor(const ModelPackage* pkg); + +} // namespace model_package diff --git a/model_package/src/model_package_impl.cc b/model_package/src/model_package_impl.cc new file mode 100644 index 0000000000000..cf383c659b1c7 --- /dev/null +++ b/model_package/src/model_package_impl.cc @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file model_package_impl.cc +/// \brief Implementation of the public C API declared in model_package.h. + +#include "model_package.h" + +#include +#include +#include +#include + +#include "asset_hasher.h" +#include "manifest_parser.h" +#include "model_package_impl.h" +#include "path_resolver.h" +#include "status_impl.h" + +namespace mp = model_package; +using mp::MakeStatus; + +namespace { + +ModelPackageStatus* NullArg(const char* name) { + return MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG, + std::string("model_package: '") + name + "' must not be null."); +} + +const char* OptStr(const std::optional& s) { + return s.has_value() ? 
s->c_str() : nullptr; +} + +} // namespace + +// ───────────────────────────────────────────────────────────────────────────── +// View cache materialization +// ───────────────────────────────────────────────────────────────────────────── + +namespace model_package { + +void DropViewCache(ModelPackage* pkg) { + if (!pkg) return; + pkg->info_cache.reset(); + for (auto& comp : pkg->components) { + comp->component_json_cache.reset(); + comp->additional_metadata_cache.reset(); + for (auto& var : comp->variants) { + var->variant_json_cache.reset(); + var->additional_metadata_cache.reset(); + } + } + pkg->additional_metadata_cache.reset(); +} + +const InfoViewCache& BuildOrGetViewCache(const ModelPackage* pkg) { + if (pkg->info_cache.has_value()) return *pkg->info_cache; + + pkg->info_cache.emplace(); + auto& cache = *pkg->info_cache; + const size_t num_components = pkg->components.size(); + + cache.executor_infos_storage.resize(num_components); + cache.variants_storage.resize(num_components); + cache.components.resize(num_components); + + for (size_t ci = 0; ci < num_components; ++ci) { + const auto& comp = *pkg->components[ci]; + const size_t num_variants = comp.variants.size(); + cache.executor_infos_storage[ci].clear(); + cache.variants_storage[ci].resize(num_variants); + + // Total executor_info entry count across this component's variants. + size_t total_execs = 0; + for (const auto& vp : comp.variants) { + total_execs += vp->executor_info_resolved.size(); + } + cache.executor_infos_storage[ci].reserve(total_execs); + + // First pass: append all executor_info entries so storage pointers stay + // stable for the second pass. + std::vector> ei_ranges(num_variants); + + for (size_t vi = 0; vi < num_variants; ++vi) { + const auto& var = *comp.variants[vi]; + size_t ei_begin = cache.executor_infos_storage[ci].size(); + // executor_info_resolved is populated eagerly by RefreshExecutorInfoCache + // (at Open and on every mutation); any parse/IO error surfaces there. 
+ for (const auto& [ns_str, body_json] : var.executor_info_resolved) { + ModelExecutorInfoEntry entry{}; + entry.namespace_key = ns_str.c_str(); + entry.json = body_json.c_str(); + cache.executor_infos_storage[ci].push_back(entry); + } + ei_ranges[vi] = {ei_begin, cache.executor_infos_storage[ci].size()}; + } + + // Additional metadata strings live in the record-level cache; populate it + // lazily here as well. + for (size_t vi = 0; vi < num_variants; ++vi) { + auto& var = *comp.variants[vi]; + auto am_it = var.body.find("additional_metadata"); + if (am_it != var.body.end() && !var.additional_metadata_cache.has_value()) { + var.additional_metadata_cache = am_it->dump(); + } + } + if (auto am_it = comp.body.find("additional_metadata"); am_it != comp.body.end()) { + if (!comp.additional_metadata_cache.has_value()) { + comp.additional_metadata_cache = am_it->dump(); + } + } + + // Second pass: populate ModelVariantInfo entries pointing at the now-stable + // storage above. + for (size_t vi = 0; vi < num_variants; ++vi) { + const auto& var = *comp.variants[vi]; + ModelVariantInfo& vi_out = cache.variants_storage[ci][vi]; + vi_out = ModelVariantInfo{}; + vi_out.name = var.name_cache.c_str(); + vi_out.variant_directory = + var.resolved_directory_cache.has_value() ? var.resolved_directory_cache->c_str() : nullptr; + vi_out.ep = OptStr(var.ep_cache); + vi_out.device = OptStr(var.device_cache); + vi_out.compatibility_string = OptStr(var.compatibility_string_cache); + vi_out.additional_metadata_json = OptStr(var.additional_metadata_cache); + auto [ei_begin, ei_end] = ei_ranges[vi]; + vi_out.num_executor_infos = ei_end - ei_begin; + vi_out.executor_infos = + (vi_out.num_executor_infos > 0) ? 
&cache.executor_infos_storage[ci][ei_begin] : nullptr; + } + + ModelComponentInfo& ci_out = cache.components[ci]; + ci_out = ModelComponentInfo{}; + ci_out.name = comp.name_cache.c_str(); + ci_out.additional_metadata_json = OptStr(comp.additional_metadata_cache); + ci_out.num_variants = num_variants; + ci_out.variants = num_variants > 0 ? cache.variants_storage[ci].data() : nullptr; + } + + // Shared assets. + cache.shared_assets.resize(pkg->shared_assets.size()); + for (size_t i = 0; i < pkg->shared_assets.size(); ++i) { + const auto& rec = *pkg->shared_assets[i]; + ModelSharedAssetInfo& sa = cache.shared_assets[i]; + sa = ModelSharedAssetInfo{}; + sa.uri = rec.uri_cache.c_str(); + sa.resolved_path = rec.resolved_path_cache.c_str(); + } + + ModelPackageInfo& info = cache.info; + info = ModelPackageInfo{}; + info.schema_version_major = pkg->schema_version_major; + info.schema_version_minor = pkg->schema_version_minor; + info.package_name = OptStr(pkg->package_name_cache); + info.package_version = OptStr(pkg->package_version_cache); + info.description = OptStr(pkg->description_cache); + info.layout = pkg->layout_cache.c_str(); + info.additional_metadata_json = OptStr(pkg->additional_metadata_cache); + info.num_components = cache.components.size(); + info.components = cache.components.empty() ? nullptr : cache.components.data(); + info.num_shared_assets = cache.shared_assets.size(); + info.shared_assets = cache.shared_assets.empty() ? nullptr : cache.shared_assets.data(); + + return cache; +} + +} // namespace model_package + +// ───────────────────────────────────────────────────────────────────────────── +// Status helpers +// ───────────────────────────────────────────────────────────────────────────── + +extern "C" { + +const char* ModelPackageStatus_Message(const ModelPackageStatus* s) { + return s ? s->message.c_str() : nullptr; +} +ModelPackageErrorCode ModelPackageStatus_Code(const ModelPackageStatus* s) { + return s ? 
s->code : MODEL_PACKAGE_OK; +} +void ModelPackageStatus_Release(ModelPackageStatus* s) { + delete s; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Lifecycle +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_Open(const char* package_root, + const ModelPackageOpenOptions* opts, + ModelPackage** out) { + if (!package_root) return NullArg("package_root"); + if (!out) return NullArg("out"); + *out = nullptr; + + ModelPackageOpenOptions effective{}; + effective.allow_external_paths = false; + effective.follow_symlinks = true; + effective.strict_unknown_fields = true; + if (opts) { + effective.allow_external_paths = opts->allow_external_paths; + effective.follow_symlinks = opts->follow_symlinks; + effective.strict_unknown_fields = opts->strict_unknown_fields; + } + + auto pkg = std::make_unique(); + if (auto* s = mp::ParsePackage(std::filesystem::path(package_root), effective, pkg.get())) { + return s; + } + *out = pkg.release(); + return nullptr; +} + +void ModelPackage_Close(ModelPackage* pkg) { + if (!pkg) return; + mp::DropViewCache(pkg); + delete pkg; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Info tree + convenience lookups +// ───────────────────────────────────────────────────────────────────────────── + +const ModelPackageInfo* ModelPackage_Info(const ModelPackage* pkg) { + if (!pkg) return nullptr; + return &mp::BuildOrGetViewCache(pkg).info; +} + +const ModelComponentInfo* ModelPackage_FindComponent(const ModelPackageInfo* info, + const char* name) { + if (!info || !name) return nullptr; + for (size_t i = 0; i < info->num_components; ++i) { + if (info->components[i].name && std::strcmp(info->components[i].name, name) == 0) { + return &info->components[i]; + } + } + return nullptr; +} + +const ModelVariantInfo* ModelComponentInfo_FindVariant(const ModelComponentInfo* comp, + const char* name) { + if 
(!comp || !name) return nullptr; + for (size_t i = 0; i < comp->num_variants; ++i) { + if (comp->variants[i].name && std::strcmp(comp->variants[i].name, name) == 0) { + return &comp->variants[i]; + } + } + return nullptr; +} + +const ModelExecutorInfoEntry* ModelVariantInfo_FindExecutorInfo(const ModelVariantInfo* var, + const char* namespace_key) { + if (!var || !namespace_key) return nullptr; + for (size_t i = 0; i < var->num_executor_infos; ++i) { + if (var->executor_infos[i].namespace_key && + std::strcmp(var->executor_infos[i].namespace_key, namespace_key) == 0) { + return &var->executor_infos[i]; + } + } + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Shared assets +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_ResolveAssetUri(const ModelPackage* pkg, + const char* uri, + const char** out_path) { + if (!pkg) return NullArg("pkg"); + if (!uri) return NullArg("uri"); + if (!out_path) return NullArg("out_path"); + *out_path = nullptr; + auto it = pkg->shared_asset_index_by_uri.find(uri); + if (it == pkg->shared_asset_index_by_uri.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_ASSET_MISSING, + std::string("Asset URI not declared in this package: '") + uri + "'."); + } + *out_path = pkg->shared_assets[it->second]->resolved_path_cache.c_str(); + return nullptr; +} + +ModelPackageStatus* ModelPackage_ResolveStringRef(const ModelPackage* pkg, + const char* base_dir, + const char* input, + bool must_exist, + const char** out_path) { + if (!pkg) return NullArg("pkg"); + if (!input) return NullArg("input"); + if (!out_path) return NullArg("out_path"); + *out_path = nullptr; + static thread_local std::string slot; + + std::string uri_part, tail_part; + if (mp::TrySplitAssetUriPrefix(std::string(input), uri_part, tail_part)) { + auto asset_it = pkg->shared_asset_index_by_uri.find(uri_part); + if (asset_it == 
pkg->shared_asset_index_by_uri.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_ASSET_MISSING, + std::string("Asset URI not declared in this package: '") + uri_part + "'."); + } + const std::string& asset_folder = pkg->shared_assets[asset_it->second]->resolved_path_cache; + if (tail_part.empty()) { + slot = asset_folder; + *out_path = slot.c_str(); + return nullptr; + } + // Tail is resolved with portable confinement under the asset folder: + // no absolute, no `..`. follow_symlinks mirrors the package setting. + mp::PathResolverOptions tail_opts; + tail_opts.allow_external_paths = false; + tail_opts.follow_symlinks = pkg->follow_symlinks; + std::filesystem::path resolved; + if (auto* s = mp::ResolvePath(asset_folder, asset_folder, tail_part, tail_opts, + must_exist, &resolved)) { + return s; + } + slot = resolved.string(); + *out_path = slot.c_str(); + return nullptr; + } + + std::filesystem::path base = base_dir ? std::filesystem::path(base_dir) : pkg->package_root; + std::filesystem::path resolved; + if (auto* s = mp::ResolvePath(base, pkg->package_root, std::string(input), + mp::PathOptionsFor(pkg), must_exist, &resolved)) { + return s; + } + slot = resolved.string(); + *out_path = slot.c_str(); + return nullptr; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Round-trip JSON getters +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_GetComponentJson(const ModelPackage* pkg, + const char* component_name, + const char** out_json) { + if (!pkg) return NullArg("pkg"); + if (!component_name) return NullArg("component_name"); + if (!out_json) return NullArg("out_json"); + *out_json = nullptr; + auto it = pkg->component_index_by_name.find(component_name); + if (it == pkg->component_index_by_name.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("component '") + component_name + "' not found."); + } + auto& rec = 
pkg->components[it->second]; + if (!rec->component_json_cache.has_value()) { + rec->component_json_cache = rec->body.dump(); + } + *out_json = rec->component_json_cache->c_str(); + return nullptr; +} + +ModelPackageStatus* ModelPackage_GetVariantJson(const ModelPackage* pkg, + const char* component_name, + const char* variant_name, + const char** out_json) { + if (!pkg) return NullArg("pkg"); + if (!component_name) return NullArg("component_name"); + if (!variant_name) return NullArg("variant_name"); + if (!out_json) return NullArg("out_json"); + *out_json = nullptr; + auto it = pkg->component_index_by_name.find(component_name); + if (it == pkg->component_index_by_name.end()) { + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("component '") + component_name + "' not found."); + } + auto& comp = pkg->components[it->second]; + for (auto& var : comp->variants) { + if (var->name == variant_name) { + if (!var->variant_json_cache.has_value()) { + var->variant_json_cache = var->body.dump(); + } + *out_json = var->variant_json_cache->c_str(); + return nullptr; + } + } + return MakeStatus(MODEL_PACKAGE_ERR_NOT_FOUND, + std::string("variant '") + variant_name + "' not found in component '" + + component_name + "'."); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Hashing utility +// ───────────────────────────────────────────────────────────────────────────── + +ModelPackageStatus* ModelPackage_ComputeDirectoryHash(const char* source_dir, + const char** out_uri) { + if (!source_dir) return NullArg("source_dir"); + if (!out_uri) return NullArg("out_uri"); + *out_uri = nullptr; + static thread_local std::string slot; + if (auto* s = mp::ComputeDirectoryAssetUri(std::filesystem::path(source_dir), &slot)) { + return s; + } + *out_uri = slot.c_str(); + return nullptr; +} + +} // extern "C" diff --git a/model_package/src/model_package_impl.h b/model_package/src/model_package_impl.h new file mode 100644 index 
0000000000000..6770b9774e132 --- /dev/null +++ b/model_package/src/model_package_impl.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file model_package_impl.h +/// \brief Internal C++ representation of a ModelPackage handle. +/// +/// Records hold the parsed manifest data plus stable per-entity string buffers +/// so that all `const char*` exposed through the C API have package-owned +/// storage. The package builds an `InfoViewCache` lazily that materializes the +/// public POD struct tree returned by `ModelPackage_Info()`; any mutation +/// drops the cache so the next read produces a fresh tree. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "model_package.h" + +namespace model_package { + +using ordered_json = nlohmann::ordered_json; + +/// How the component's body is stored on disk relative to the manifest. +enum class ComponentStorage { + kInline, ///< body lives directly inside the manifest as an object + kExternal, ///< body lives in a separate file pointed to by a string +}; + +struct VariantRecord { + std::string name; + nlohmann::ordered_json body; ///< the full variant JSON object + + // Stable string buffers for ABI exposure. + std::string name_cache; + std::optional ep_cache; + std::optional device_cache; + std::optional compatibility_string_cache; + std::optional resolved_directory_cache; + mutable std::optional additional_metadata_cache; + mutable std::optional variant_json_cache; + + /// Resolved variant_directory for variants that have one. `std::nullopt` + /// means none was declared and the default location does not exist. + std::optional resolved_directory; + bool resolved_directory_attempted{false}; + + /// Pre-resolved executor_info entries. Populated eagerly at Open and + /// after any mutation that can touch executor_info. 
The first member is the + /// namespace key; the second is the serialized JSON body of that entry + /// (inline bodies are dumped, external file bodies are read + validated). + std::vector> executor_info_resolved; +}; + +struct ComponentRecord { + std::string name; + ComponentStorage storage{ComponentStorage::kInline}; + std::filesystem::path external_path; ///< valid iff storage == kExternal + std::filesystem::path component_dir; ///< base directory for relative paths inside this component + nlohmann::ordered_json body; ///< {"component_name": ..., "variants": {...}, "additional_metadata": {...}} + std::vector> variants; + + std::string name_cache; + mutable std::optional additional_metadata_cache; + mutable std::optional component_json_cache; +}; + +struct SharedAssetRecord { + std::string uri; ///< "sha256:" + std::filesystem::path resolved_path; + std::string uri_cache; + std::string resolved_path_cache; +}; + +/// Materialized POD-struct tree returned by ModelPackage_Info(). Owns all +/// backing storage (extra strings and array buffers) so pointers stay valid +/// until the next mutation drops the cache. +struct InfoViewCache { + // Per-variant arrays. Indexed [component_idx][variant_idx]. + std::vector> executor_infos_storage; + std::vector> variants_storage; + + std::vector components; + std::vector shared_assets; + ModelPackageInfo info{}; +}; + +} // namespace model_package + +// ───────────────────────────────────────────────────────────────────────────── +// Public opaque type (lives in the global namespace to match the C API) +// ───────────────────────────────────────────────────────────────────────────── + +struct ModelPackage { + std::filesystem::path package_root; + nlohmann::ordered_json manifest; ///< parsed manifest.json with declarations intact (component values stay in their original string-or-object form) + std::string layout; ///< "portable" | "installed" + + // Open-time options. 
+ bool allow_external_paths{false}; + bool follow_symlinks{true}; + bool strict_unknown_fields{true}; + + // Package-level parsed data and stable string buffers. + int64_t schema_version_major{0}; + int64_t schema_version_minor{0}; + std::optional package_name_cache; + std::optional package_version_cache; + std::optional description_cache; + std::string layout_cache; + mutable std::optional additional_metadata_cache; + + std::vector> components; + std::vector> shared_assets; + + std::unordered_map component_index_by_name; + std::unordered_map shared_asset_index_by_uri; + + /// Authoring-time staging for copy_in=true shared assets that have not been + /// committed yet. Keyed by sha256: URI. + std::unordered_map pending_shared_asset_copies; + + /// Paths removed from the live tree, candidates for ModelPackage_Prune. + /// Populated by the authoring API; never by walking package_root. + std::vector pending_orphan_variant_dirs; + std::vector pending_orphan_component_dirs; + + /// Cache for the most recent ModelPackage_Validate report JSON. + mutable std::optional last_validate_report; + + /// Lazily built; dropped on any mutation. + mutable std::optional info_cache; +}; + +namespace model_package { + +/// Drop the materialized view cache. Call after any mutation that affects the +/// view tree. Safe on a cleared cache. +void DropViewCache(ModelPackage* pkg); + +/// Return the package's info view, building it lazily. +const InfoViewCache& BuildOrGetViewCache(const ModelPackage* pkg); + +/// Returns true iff `p` is `package_root` or lives under it (lexically). +bool IsInsidePackageRoot(const ModelPackage* pkg, const std::filesystem::path& p); + +/// Push the variant's resolved_directory onto the Prune candidates if it's +/// inside package_root. No-op if unresolved. +void RecordOrphanVariantDir(ModelPackage* pkg, const VariantRecord& v); + +/// Push every variant_dir of `c`, plus `c.component_dir` if external. 
+void RecordOrphanComponent(ModelPackage* pkg, const ComponentRecord& c); + +} // namespace model_package diff --git a/model_package/src/model_package_internal.h b/model_package/src/model_package_internal.h deleted file mode 100644 index 8d116b78f5880..0000000000000 --- a/model_package/src/model_package_internal.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/// \file model_package_internal.h -/// \brief Internal C++ types for the model package library. - -#pragma once - -#include -#include -#include -#include -#include - -namespace model_package { - -// ───────────────────────────────────────────────────────────────────────────── -// Data types -// ───────────────────────────────────────────────────────────────────────────── - -/// EP compatibility declaration for a variant (opaque to this library). -struct EpCompatibility { - std::optional ep; - std::optional device; - std::optional compatibility_string; -}; - -/// A single model file within a variant. -struct VariantFile { - std::string filename; - std::filesystem::path resolved_path; - - std::optional> session_options; - std::optional> provider_options; - std::optional> shared_files; -}; - -/// A variant of a component. -struct Variant { - std::string name; - std::filesystem::path folder_path; - // Single EP compatibility entry per variant (from metadata.json). - EpCompatibility ep_compatibility; - // Single model file entry (from variant.json). Empty when variant.json is absent. - std::optional file; - std::optional consumer_metadata_json; -}; - -/// A component in the model package. -struct Component { - std::string name; - std::vector variants; -}; - -/// Top-level model package descriptor. 
-struct PackageInfo { - int64_t schema_version{}; - std::filesystem::path root_path; - std::vector components; -}; - -// ───────────────────────────────────────────────────────────────────────────── -// Context implementation -// ───────────────────────────────────────────────────────────────────────────── - -/// Internal context holding parsed package data and C API caches. -struct ContextImpl { - PackageInfo package_info; - - // Caches for C API string access (stable pointers). - std::vector component_names_cache; - std::unordered_map> variant_names_cache; - std::unordered_map folder_path_strings_cache; - - // Lookup helpers. - const Component* FindComponent(const char* name) const; - const Variant* FindVariant(const char* component_name, const char* variant_name) const; -}; - -} // namespace model_package diff --git a/model_package/src/parser.cc b/model_package/src/parser.cc deleted file mode 100644 index 70d95b0297e38..0000000000000 --- a/model_package/src/parser.cc +++ /dev/null @@ -1,595 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "parser.h" - -#include -#include -#include -#include -#include -#include - -#include "nlohmann/json.hpp" - -using json = nlohmann::json; - -namespace model_package { -namespace { - -// ───────────────────────────────────────────────────────────────────────────── -// JSON key constants -// ───────────────────────────────────────────────────────────────────────────── - -constexpr const char* kManifestFileName = "manifest.json"; -constexpr const char* kMetadataFileName = "metadata.json"; -constexpr const char* kVariantDescriptorFileName = "variant.json"; - -constexpr const char* kSchemaVersionKey = "schema_version"; -constexpr const char* kComponentsKey = "components"; -constexpr const char* kComponentNameKey = "component_name"; -constexpr const char* kVariantsKey = "variants"; - -constexpr const char* kEpKey = "ep"; -constexpr const char* kDeviceKey = "device"; -constexpr const char* kCompatibilityStringKey = "compatibility_string"; - -constexpr const char* kFilenameKey = "filename"; -constexpr const char* kSessionOptionsKey = "session_options"; -constexpr const char* kProviderOptionsKey = "provider_options"; -constexpr const char* kSharedFilesKey = "shared_files"; -constexpr const char* kConsumerMetadataKey = "consumer_metadata"; - -// ───────────────────────────────────────────────────────────────────────────── -// Internal schema types for deserialization -// ───────────────────────────────────────────────────────────────────────────── - -struct VariantMetadataSchema { - std::string filename; - std::optional> session_options; - std::optional> provider_options; - std::optional> shared_files; -}; - -struct EpCompatibilitySchema { - std::optional ep; - std::optional device; - std::optional compatibility_string; -}; - -struct VariantSchema { - EpCompatibilitySchema ep_info; -}; - -struct ComponentSchema { - std::optional component_name; - std::unordered_map variants; -}; - -struct ManifestSchema { - int64_t schema_version; - std::optional> components; -}; 
- -// ───────────────────────────────────────────────────────────────────────────── -// JSON helpers -// ───────────────────────────────────────────────────────────────────────────── - -std::string JsonScalarToString(const json& v, const char* key_name, const std::string& parent_key) { - if (v.is_string()) return v.get(); - if (v.is_number_integer()) return std::to_string(v.get()); - if (v.is_number_unsigned()) return std::to_string(v.get()); - if (v.is_number_float()) return v.dump(); - if (v.is_boolean()) return v.get() ? "true" : "false"; - - throw std::invalid_argument( - std::string("\"") + key_name + "\" under '" + parent_key + - "' must contain scalar (string/number/bool) values."); -} - -std::optional> ParseFlatOptionsObject( - const json& j, const char* key_name) { - if (!j.contains(key_name) || j[key_name].is_null()) { - return std::nullopt; - } - - const auto& obj = j[key_name]; - if (!obj.is_object()) { - throw std::invalid_argument(std::string("\"") + key_name + "\" must be an object."); - } - - std::unordered_map result; - result.reserve(obj.size()); - - for (auto it = obj.begin(); it != obj.end(); ++it) { - result.emplace(it.key(), JsonScalarToString(it.value(), key_name, it.key())); - } - - return result; -} - -std::optional ParseOptionalString(const json& j, const char* key_name) { - if (!j.contains(key_name) || j[key_name].is_null()) { - return std::nullopt; - } - - const auto& value = j[key_name]; - if (!value.is_string()) { - throw std::invalid_argument(std::string("\"") + key_name + "\" must be a string."); - } - return value.get(); -} - -// ───────────────────────────────────────────────────────────────────────────── -// nlohmann from_json overloads -// ───────────────────────────────────────────────────────────────────────────── - -void from_json(const json& j, EpCompatibilitySchema& c) { - if (!j.contains(kEpKey) || j[kEpKey].is_null()) { - throw std::invalid_argument(std::string("\"") + kEpKey + "\" is required in each ep_compatibility 
entry."); - } - if (!j[kEpKey].is_string()) { - throw std::invalid_argument(std::string("\"") + kEpKey + "\" must be a string."); - } - c.ep = j[kEpKey].get(); - if (c.ep->empty()) { - throw std::invalid_argument(std::string("\"") + kEpKey + "\" must be a non-empty string."); - } - - if (j.contains(kDeviceKey) && !j[kDeviceKey].is_null()) { - if (!j[kDeviceKey].is_string()) { - throw std::invalid_argument(std::string("\"") + kDeviceKey + "\" must be a string when present."); - } - c.device = j[kDeviceKey].get(); - } - c.compatibility_string = ParseOptionalString(j, kCompatibilityStringKey); -} - -void from_json(const json& j, VariantSchema& v) { - // EP fields (ep, device, compatibility_string) are now directly on the variant object. - // "ep" is required. - v.ep_info = j.get(); -} - -void from_json(const json& j, VariantMetadataSchema& v) { - v.filename = j.at(kFilenameKey).get(); - v.session_options = ParseFlatOptionsObject(j, kSessionOptionsKey); - v.provider_options = ParseFlatOptionsObject(j, kProviderOptionsKey); - v.shared_files = ParseFlatOptionsObject(j, kSharedFilesKey); -} - -void from_json(const json& j, ManifestSchema& m) { - m.schema_version = j.at(kSchemaVersionKey).get(); - - if (j.contains(kComponentsKey)) { - if (!j[kComponentsKey].is_array()) { - throw std::invalid_argument(std::string("\"") + kComponentsKey + "\" must be an array of strings"); - } - m.components = j[kComponentsKey].get>(); - } -} - -void from_json(const json& j, ComponentSchema& m) { - if (j.contains(kComponentNameKey) && j[kComponentNameKey].is_string()) { - m.component_name = j[kComponentNameKey].get(); - } - - m.variants = j.at(kVariantsKey).get>(); -} - -// ───────────────────────────────────────────────────────────────────────────── -// Parsing variants in declaration order (from the JSON object) -// ───────────────────────────────────────────────────────────────────────────── - -std::vector> ParseVariantsInOrder(const json& variants_obj) { - std::vector> result; - 
result.reserve(variants_obj.size()); - for (auto it = variants_obj.begin(); it != variants_obj.end(); ++it) { - result.emplace_back(it.key(), it.value().get()); - } - return result; -} - -// ───────────────────────────────────────────────────────────────────────────── -// Path validation -// ───────────────────────────────────────────────────────────────────────────── - -bool ValidatePathSegment(const std::string& segment, const char* segment_type, std::string& error) { - if (segment.empty()) { - error = std::string(segment_type) + " must not be empty."; - return false; - } - - if (std::filesystem::path(segment).is_absolute()) { - error = std::string(segment_type) + " must not be an absolute path: '" + segment + "'."; - return false; - } - - for (const auto& part : std::filesystem::path(segment)) { - if (part == "..") { - error = std::string(segment_type) + " must not contain '..' path components: '" + segment + "'."; - return false; - } - } - - return true; -} - -bool ValidatePathConfinement(const std::filesystem::path& resolved_path, - const std::filesystem::path& root, - const char* description, - std::string& error) { - auto normal_root = root.lexically_normal(); - auto normal_path = resolved_path.lexically_normal(); - - auto root_str = normal_root.string(); - auto path_str = normal_path.string(); - - if (path_str.size() < root_str.size() || - path_str.compare(0, root_str.size(), root_str) != 0 || - (path_str.size() > root_str.size() && path_str[root_str.size()] != std::filesystem::path::preferred_separator -#ifndef _WIN32 - && path_str[root_str.size()] != '/' -#endif - )) { - error = std::string(description) + " resolves outside the package root. 
Path: '" + - resolved_path.string() + "', Root: '" + root.string() + "'."; - return false; - } - - return true; -} - -// ───────────────────────────────────────────────────────────────────────────── -// Find single ONNX file in directory -// ───────────────────────────────────────────────────────────────────────────── - -bool FindSingleOnnxFile(const std::filesystem::path& search_dir, - std::filesystem::path& resolved_path, - std::string& error) { - std::vector onnx_files; - for (const auto& entry : std::filesystem::directory_iterator(search_dir)) { - if (!entry.is_regular_file()) continue; - - std::string ext = entry.path().extension().string(); - std::transform(ext.begin(), ext.end(), ext.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (ext == ".onnx") { - onnx_files.push_back(entry.path()); - } - } - - if (onnx_files.empty()) { - error = "No ONNX model file found under " + search_dir.string(); - return false; - } - - if (onnx_files.size() > 1) { - error = "Multiple ONNX model files found under " + search_dir.string() + - ". 
Multiple ONNX files per variant are not supported yet."; - return false; - } - - resolved_path = onnx_files.front(); - return true; -} - -// ───────────────────────────────────────────────────────────────────────────── -// Parse variants from a single component -// ───────────────────────────────────────────────────────────────────────────── - -bool ParseVariantsFromComponent(const std::string& component_name, - const std::filesystem::path& component_root, - const json* variants_obj, - std::vector& out_variants, - std::string& error) { - if (variants_obj == nullptr) { - error = "Missing metadata variants for component: " + component_name; - return false; - } - - std::vector> variants; - try { - variants = ParseVariantsInOrder(*variants_obj); - } catch (const std::exception& ex) { - error = "Invalid metadata variant schema for component '" + component_name + "': " + ex.what(); - return false; - } - - for (const auto& [variant_name, variant_schema] : variants) { - if (!ValidatePathSegment(variant_name, "Variant name", error)) return false; - - const std::filesystem::path variant_root = component_root / variant_name; - if (!ValidatePathConfinement(variant_root, component_root, "Variant directory", error)) return false; - - const std::filesystem::path variant_descriptor_path = variant_root / kVariantDescriptorFileName; - - Variant variant_info{}; - variant_info.name = variant_name; - variant_info.folder_path = variant_root; - - // variant.json is optional. If present, it declares the file list, - // per-file session/provider options, and consumer metadata. 
- if (std::filesystem::exists(variant_descriptor_path)) { - std::ifstream vf(variant_descriptor_path, std::ios::binary); - if (!vf) { - error = "Failed to open variant.json at " + variant_descriptor_path.string(); - return false; - } - - json variant_doc; - try { - variant_doc = json::parse(vf); - } catch (const std::exception& ex) { - error = "variant.json at " + variant_descriptor_path.string() + " is not valid JSON: " + ex.what(); - return false; - } - - VariantMetadataSchema variant_metadata; - try { - variant_metadata = variant_doc.get(); - } catch (const std::exception& ex) { - error = "variant.json at " + variant_descriptor_path.string() + " has invalid schema: " + ex.what(); - return false; - } - - // consumer_metadata is a top-level optional field parsed separately from the schema struct. - if (variant_doc.contains(kConsumerMetadataKey) && variant_doc[kConsumerMetadataKey].is_object()) { - variant_info.consumer_metadata_json = variant_doc[kConsumerMetadataKey].dump(); - } - - if (!ValidatePathSegment(variant_metadata.filename, "File name", error)) return false; - - const std::filesystem::path candidate_path = variant_root / variant_metadata.filename; - if (!ValidatePathConfinement(candidate_path, variant_root, "Variant file path", error)) return false; - - if (!std::filesystem::exists(candidate_path)) { - error = "Variant '" + variant_name + "', file '" + variant_metadata.filename + - "' path does not exist: " + candidate_path.string(); - return false; - } - - std::filesystem::path resolved_model_path; - if (std::filesystem::is_regular_file(candidate_path)) { - resolved_model_path = candidate_path; - } else if (std::filesystem::is_directory(candidate_path)) { - if (!FindSingleOnnxFile(candidate_path, resolved_model_path, error)) return false; - } else { - error = "Variant '" + variant_name + "', file '" + variant_metadata.filename + - "' path is neither a file nor directory: " + candidate_path.string(); - return false; - } - - VariantFile file_info{}; - 
file_info.filename = variant_metadata.filename; - file_info.resolved_path = std::move(resolved_model_path); - file_info.session_options = variant_metadata.session_options; - file_info.provider_options = variant_metadata.provider_options; - file_info.shared_files = variant_metadata.shared_files; - - variant_info.file = std::move(file_info); - } - - // EP compatibility from metadata.json (single entry per variant) - variant_info.ep_compatibility.ep = variant_schema.ep_info.ep; - variant_info.ep_compatibility.device = variant_schema.ep_info.device; - variant_info.ep_compatibility.compatibility_string = variant_schema.ep_info.compatibility_string; - - out_variants.push_back(std::move(variant_info)); - } - - return true; -} - -} // namespace - -// ───────────────────────────────────────────────────────────────────────────── -// Public parser entry point -// ───────────────────────────────────────────────────────────────────────────── - -bool ParsePackage(const std::filesystem::path& package_root, - PackageInfo& out_package, - std::string& out_error) { - out_package = {}; - out_package.root_path = package_root; - - // Check for single-component mode: metadata.json at root - const auto root_metadata_path = package_root / kMetadataFileName; - if (std::filesystem::exists(root_metadata_path) && - std::filesystem::is_regular_file(root_metadata_path)) { - std::ifstream mf(root_metadata_path, std::ios::binary); - if (!mf) { - out_error = "Failed to open metadata.json at " + root_metadata_path.string(); - return false; - } - - json metadata_doc; - try { - metadata_doc = json::parse(mf); - } catch (const std::exception& ex) { - out_error = "metadata.json at " + root_metadata_path.string() + " is not valid JSON: " + ex.what(); - return false; - } - - ComponentSchema metadata_schema; - try { - metadata_schema = metadata_doc.get(); - } catch (const std::exception& ex) { - out_error = "metadata.json at " + root_metadata_path.string() + " has invalid schema: " + ex.what(); - return 
false; - } - - const std::string component_name = - metadata_schema.component_name.has_value() - ? *metadata_schema.component_name - : package_root.filename().string(); - - const json* variants_obj = &metadata_doc.at(kVariantsKey); - - Component component{}; - component.name = component_name; - - if (!ParseVariantsFromComponent(component_name, package_root, variants_obj, - component.variants, out_error)) { - return false; - } - - out_package.schema_version = 0; // Single-component mode doesn't have a manifest - out_package.components.push_back(std::move(component)); - return true; - } - - // Multi-component mode: manifest.json at root - const auto manifest_path = package_root / kManifestFileName; - if (!std::filesystem::exists(manifest_path)) { - out_error = "No manifest.json found at " + manifest_path.string(); - return false; - } - - std::ifstream f(manifest_path, std::ios::binary); - if (!f) { - out_error = "Failed to open manifest.json at " + manifest_path.string(); - return false; - } - - json doc; - try { - doc = json::parse(f); - } catch (const std::exception& ex) { - out_error = std::string("manifest.json is not valid JSON: ") + ex.what(); - return false; - } - - ManifestSchema manifest_schema; - try { - manifest_schema = doc.get(); - } catch (const std::exception& ex) { - out_error = std::string("manifest.json has invalid schema: ") + ex.what(); - return false; - } - - if (manifest_schema.schema_version != 1) { - out_error = "Unsupported schema_version in manifest.json: " + - std::to_string(manifest_schema.schema_version) + ". 
Expected 1."; - return false; - } - - out_package.schema_version = manifest_schema.schema_version; - - const bool has_components = manifest_schema.components.has_value(); - std::vector component_names; - std::unordered_map discovered_metadata_docs; - - if (has_components) { - component_names = *manifest_schema.components; - } else { - const auto models_dir = package_root / "models"; - if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) { - out_error = "manifest.json missing \"components\" and no discoverable models directory at " + - models_dir.string(); - return false; - } - - for (const auto& entry : std::filesystem::directory_iterator(models_dir)) { - if (!entry.is_directory()) continue; - - const auto name = entry.path().filename().string(); - const auto metadata_path = entry.path() / kMetadataFileName; - if (!std::filesystem::exists(metadata_path)) continue; - - std::ifstream mf(metadata_path, std::ios::binary); - if (!mf) { - out_error = "Failed to open metadata.json at " + metadata_path.string(); - return false; - } - - json metadata_doc; - try { - metadata_doc = json::parse(mf); - (void)metadata_doc.get(); - } catch (const std::exception& ex) { - out_error = "metadata.json at " + metadata_path.string() + - " has invalid schema: " + std::string(ex.what()); - return false; - } - - discovered_metadata_docs.emplace(name, std::move(metadata_doc)); - component_names.push_back(name); - } - - if (component_names.empty()) { - out_error = - "manifest.json missing \"components\" and no component model folders with " - "metadata.json were found under " + - models_dir.string(); - return false; - } - } - - for (const auto& component_name : component_names) { - if (!ValidatePathSegment(component_name, "Component name", out_error)) return false; - - const auto component_root = package_root / "models" / component_name; - if (!ValidatePathConfinement(component_root, package_root, "Component directory", out_error)) return false; - - if 
(has_components && - (!std::filesystem::exists(component_root) || !std::filesystem::is_directory(component_root))) { - // Skip missing component directories (just warn — standalone library doesn't have logging, - // so we skip silently for now). - continue; - } - - json metadata_doc; - const json* variants_obj = nullptr; - const auto metadata_path = component_root / kMetadataFileName; - - if (!has_components) { - auto it_meta = discovered_metadata_docs.find(component_name); - if (it_meta != discovered_metadata_docs.end()) { - metadata_doc = it_meta->second; - variants_obj = &metadata_doc.at(kVariantsKey); - } - } else if (std::filesystem::exists(metadata_path)) { - std::ifstream mf(metadata_path, std::ios::binary); - if (mf) { - try { - metadata_doc = json::parse(mf); - (void)metadata_doc.get(); - variants_obj = &metadata_doc.at(kVariantsKey); - } catch (const std::exception&) { - // Ignore parse errors, fall through. - } - } - } - - if (!metadata_doc.is_null() && - metadata_doc.contains(kComponentNameKey) && - metadata_doc[kComponentNameKey].is_string()) { - const auto metadata_component_name = metadata_doc[kComponentNameKey].get(); - if (metadata_component_name != component_name) { - out_error = "metadata.json component_name '" + metadata_component_name + - "' does not match directory/manifest component name '" + component_name + "'."; - return false; - } - } - - Component component{}; - component.name = component_name; - - if (!ParseVariantsFromComponent(component_name, component_root, variants_obj, - component.variants, out_error)) { - return false; - } - - out_package.components.push_back(std::move(component)); - } - - if (out_package.components.empty()) { - out_error = "No valid component models were found under " + (package_root / "models").string(); - return false; - } - - return true; -} - -} // namespace model_package diff --git a/model_package/src/parser.h b/model_package/src/parser.h deleted file mode 100644 index ed3d22cb29d36..0000000000000 --- 
a/model_package/src/parser.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -/// \file parser.h -/// \brief Model package JSON parser (internal). - -#pragma once - -#include -#include - -#include "model_package_internal.h" - -namespace model_package { - -/// Parse a model package from a directory. -/// Reads manifest.json, metadata.json per component, variant.json per variant. -/// -/// \param[in] package_root Path to the model package root directory. -/// \param[out] out_package On success, filled with the parsed package info. -/// \param[out] out_error On failure, filled with an error message. -/// \return true on success, false on error. -bool ParsePackage(const std::filesystem::path& package_root, - PackageInfo& out_package, - std::string& out_error); - -} // namespace model_package diff --git a/model_package/src/path_resolver.cc b/model_package/src/path_resolver.cc new file mode 100644 index 0000000000000..d62662d3ffcb1 --- /dev/null +++ b/model_package/src/path_resolver.cc @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "path_resolver.h"
+
+#include <algorithm>
+#include <cstring>
+#include <filesystem>
+#include <string>
+#include <system_error>
+
+#include "status_impl.h"
+
+namespace fs = std::filesystem;
+
+namespace model_package {
+
+namespace {
+
+/// True for the characters allowed in a lowercase hex digest: 0-9 and a-f.
+bool IsHexLower(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'); }
+
+/// True if any component of `p` is a ".." parent reference.
+bool ContainsParentRefSegment(const fs::path& p) {
+  for (const auto& seg : p) {
+    if (seg == "..") return true;
+  }
+  return false;
+}
+
+/// True if `candidate` is `root` itself or lies beneath it, compared component by
+/// component. Comparing components instead of string prefixes keeps the check correct
+/// when `root` is a filesystem root ("/", "C:\") or carries a trailing separator, and
+/// it never lets "/pkg" match "/pkg_other". Both arguments are expected to be
+/// normalized already.
+bool IsPathWithin(fs::path root, const fs::path& candidate) {
+  // A trailing separator ("/pkg/") iterates as an empty final component; drop it so it
+  // cannot cause a spurious mismatch. A bare root has no relative part and is kept.
+  if (!root.has_filename() && root.has_relative_path()) root = root.parent_path();
+  const auto diverge =
+      std::mismatch(root.begin(), root.end(), candidate.begin(), candidate.end());
+  return diverge.first == root.end();
+}
+
+}  // namespace
+
+/// True if `uri` is exactly "sha256:" followed by 64 lowercase hex characters.
+bool IsSha256AssetUri(const std::string& uri) {
+  static constexpr const char* kPrefix = "sha256:";
+  static constexpr size_t kPrefixLen = 7;
+  static constexpr size_t kHexLen = 64;
+  if (uri.size() != kPrefixLen + kHexLen) return false;
+  if (uri.compare(0, kPrefixLen, kPrefix) != 0) return false;
+  for (size_t i = kPrefixLen; i < uri.size(); ++i) {
+    if (!IsHexLower(uri[i])) return false;
+  }
+  return true;
+}
+
+/// Resolve `input` against `base_dir` and, unless external paths are allowed, verify
+/// that the result stays inside `package_root`. Returns nullptr on success (with
+/// `*out` set) or a newly allocated failure status.
+ModelPackageStatus* ResolvePath(const fs::path& base_dir,
+                                const fs::path& package_root,
+                                const std::string& input,
+                                const PathResolverOptions& opts,
+                                bool must_exist,
+                                fs::path* out) {
+  if (!out) {
+    return model_package::MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG,
+                                     "ResolvePath: out must not be null.");
+  }
+  if (input.empty()) {
+    return model_package::MakeStatus(MODEL_PACKAGE_ERR_INVALID_ARG,
+                                     "ResolvePath: input must not be empty.");
+  }
+
+  fs::path raw(input);
+
+  if (!opts.allow_external_paths) {
+    // has_root_directory() additionally catches root-relative inputs such as "\foo" on
+    // Windows: they are neither absolute nor drive-qualified there, yet `base_dir / raw`
+    // would discard base_dir's directories and jump to the drive root.
+    if (raw.is_absolute() || raw.has_root_name() || raw.has_root_directory()) {
+      return model_package::MakeStatus(
+          MODEL_PACKAGE_ERR_PATH_CONFINEMENT,
+          std::string("ResolvePath: absolute or drive-rooted path '") + input +
+              "' is not allowed in portable layout.");
+    }
+    if (ContainsParentRefSegment(raw)) {
+      return model_package::MakeStatus(
+          MODEL_PACKAGE_ERR_PATH_CONFINEMENT,
+          std::string("ResolvePath: '..' segments are not allowed in portable layout: '") +
+              input + "'.");
+    }
+  }
+
+  fs::path joined = (raw.is_absolute() || raw.has_root_name()) ? raw : (base_dir / raw);
+
+  std::error_code ec;
+  fs::path canonical;
+  bool exists_on_disk = fs::exists(joined, ec);
+  if (!exists_on_disk) {
+    if (must_exist) {
+      return model_package::MakeStatus(
+          MODEL_PACKAGE_ERR_NOT_FOUND,
+          std::string("ResolvePath: '") + joined.string() + "' does not exist.");
+    }
+    // Missing leaf (common during authoring/commit). When following symlinks, use
+    // weakly_canonical so any existing symlinks in the path prefix are still resolved;
+    // lexically_normal would leave a symlinked prefix unresolved and let it escape
+    // package_root undetected. Fall back to lexical normalization if that fails.
+    if (opts.follow_symlinks) {
+      canonical = fs::weakly_canonical(joined, ec);
+      if (ec) canonical = joined.lexically_normal();
+    } else {
+      canonical = joined.lexically_normal();
+    }
+  } else if (opts.follow_symlinks) {
+    canonical = fs::canonical(joined, ec);
+    if (ec) {
+      return model_package::MakeStatus(
+          MODEL_PACKAGE_ERR_IO,
+          std::string("ResolvePath: canonical('") + joined.string() + "') failed: " + ec.message());
+    }
+  } else {
+    // NOTE(review): weakly_canonical resolves symlinks in the existing part of the path,
+    // so this branch still follows them even though follow_symlinks is false. Confirm
+    // that is intended.
+    canonical = fs::weakly_canonical(joined, ec);
+    if (ec) {
+      canonical = joined.lexically_normal();
+    }
+  }
+
+  if (!opts.allow_external_paths && !package_root.empty()) {
+    // Confinement check: canonical must live under package_root's canonical form. This runs
+    // whether or not the leaf exists, so a not-yet-created path that resolves outside
+    // package_root (e.g. through a symlinked prefix) is still rejected. It is skipped when
+    // package_root is empty, which happens for in-memory authoring before a package has been
+    // anchored to a directory (there is no on-disk root to confine against yet); the
+    // absolute-path and ".." lexical checks above still apply in that case.
+    fs::path canonical_root = fs::weakly_canonical(package_root, ec);
+    if (ec) canonical_root = package_root.lexically_normal();
+
+    const fs::path normal_root = canonical_root.lexically_normal();
+    const fs::path normal_target = canonical.lexically_normal();
+    if (!IsPathWithin(normal_root, normal_target)) {
+      return model_package::MakeStatus(
+          MODEL_PACKAGE_ERR_PATH_CONFINEMENT,
+          std::string("ResolvePath: '") + normal_target.string() +
+              "' escapes package_root '" + normal_root.string() + "'.");
+    }
+  }
+
+  *out = canonical;
+  return nullptr;
+}
+
+/// Split "sha256:<hex>" or "sha256:<hex>/<tail>" into `uri` and `tail`. Returns false
+/// (leaving the outputs untouched) when `input` does not start with a valid URI token.
+bool TrySplitAssetUriPrefix(const std::string& input, std::string& uri, std::string& tail) {
+  static constexpr size_t kPrefixLen = 7;  // "sha256:"
+  static constexpr size_t kHexLen = 64;
+  static constexpr size_t kUriLen = kPrefixLen + kHexLen;
+  if (input.size() < kUriLen) return false;
+  if (input.compare(0, kPrefixLen, "sha256:") != 0) return false;
+  for (size_t i = kPrefixLen; i < kUriLen; ++i) {
+    if (!IsHexLower(input[i])) return false;
+  }
+  if (input.size() == kUriLen) {
+    uri.assign(input);
+    tail.clear();
+    return true;
+  }
+  // Anything after the digest must be introduced by '/'.
+  if (input[kUriLen] != '/') return false;
+  uri.assign(input, 0, kUriLen);
+  tail.assign(input, kUriLen + 1, std::string::npos);
+  return true;
+}
+
+std::string DefaultSharedAssetDirName(const std::string& uri) {
+  if (!IsSha256AssetUri(uri)) return {};
+  return std::string(kSharedAssetOnDiskPrefix) + uri.substr(std::strlen("sha256:"));
+}
+
+std::string SharedAssetUriFromDirName(const std::string& dir_name) {
+  const size_t prefix_len = std::strlen(kSharedAssetOnDiskPrefix);
+  if (dir_name.size() != prefix_len + 64) return {};
+  if (dir_name.compare(0, prefix_len, kSharedAssetOnDiskPrefix) != 0) return {};
+  for (size_t i = prefix_len; i < dir_name.size(); ++i) {
+    if (!IsHexLower(dir_name[i])) return {};
+  }
+  return "sha256:" + dir_name.substr(prefix_len);
+}
+
+}  // namespace model_package
diff --git a/model_package/src/path_resolver.h b/model_package/src/path_resolver.h
new file mode 100644
index 0000000000000..f008897ff5bb0
--- /dev/null
+++ b/model_package/src/path_resolver.h
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/// \file path_resolver.h
+/// \brief Path-resolution and confinement helpers.
+
+#pragma once
+
+#include <filesystem>
+#include <string>
+
+#include "model_package_api.h"  // for ModelPackageStatus
+
+namespace model_package {
+
+struct PathResolverOptions {
+  bool allow_external_paths{false};
+  bool follow_symlinks{true};
+};
+
+/// Resolve a relative-or-absolute path string under a given base directory.
+/// In portable mode (`allow_external_paths == false`):
+///   - Reject absolute inputs (ERR_PATH_CONFINEMENT).
+///   - Reject any path that, after canonicalization, escapes `package_root`.
+///   - Reject `..` segments syntactically before resolution.
+/// In installed mode:
+///   - Absolute and `..` allowed.
+///   - No confinement check.
+///
+/// `must_exist` controls whether a missing target is an error (ERR_NOT_FOUND)
+/// or whether the resolved (non-canonical) path is returned anyway.
+/// Symlinks are followed when `follow_symlinks` is true.
+ModelPackageStatus* ResolvePath(const std::filesystem::path& base_dir,
+                                const std::filesystem::path& package_root,
+                                const std::string& input,
+                                const PathResolverOptions& opts,
+                                bool must_exist,
+                                std::filesystem::path* out);
+
+/// True if `uri` matches `^sha256:[0-9a-f]{64}$`.
+bool IsSha256AssetUri(const std::string& uri);
+
+/// If `input` begins with a `sha256:<hex>` token followed by end-of-string or
+/// '/', split into `uri` (the bare URI) and `tail` (substring after '/', or
+/// empty). Returns true on a match, false otherwise.
+bool TrySplitAssetUriPrefix(const std::string& input, std::string& uri, std::string& tail);
+
+/// Default on-disk directory name for a shared asset URI, i.e. the basename
+/// under `<package_root>/shared_assets/`. For `sha256:<hex>` this is
+/// `sha256-<hex>`. Returns empty string if `uri` is not a valid sha256 URI.
+std::string DefaultSharedAssetDirName(const std::string& uri);
+
+/// Inverse of `DefaultSharedAssetDirName`. If `dir_name` matches `sha256-<hex>`
+/// returns the corresponding `sha256:<hex>` URI; otherwise returns empty string.
+std::string SharedAssetUriFromDirName(const std::string& dir_name);
+
+/// Prefix shared by every default-convention shared-asset directory name.
+constexpr const char* kSharedAssetOnDiskPrefix = "sha256-";
+
+}  // namespace model_package
diff --git a/model_package/src/sha256.cc b/model_package/src/sha256.cc
new file mode 100644
index 0000000000000..1ea26a555ad43
--- /dev/null
+++ b/model_package/src/sha256.cc
@@ -0,0 +1,229 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// Clean-room SHA-256 (FIPS 180-4) implementation. No external crypto deps.
+// Intended for content-addressed asset hashing, not for cryptographic
+// authentication.
+
+#include "sha256.h"
+
+#include <algorithm>
+#include <cstring>
+#include <fstream>
+#include <string>
+
+namespace model_package {
+
+namespace {
+
+// Initial hash value H(0) (FIPS 180-4 section 5.3.3).
+constexpr uint32_t kInitState[8] = {
+    0x6a09e667u,
+    0xbb67ae85u,
+    0x3c6ef372u,
+    0xa54ff53au,
+    0x510e527fu,
+    0x9b05688cu,
+    0x1f83d9abu,
+    0x5be0cd19u,
+};
+
+// Round constants K (FIPS 180-4 section 4.2.2).
+constexpr uint32_t kRoundConstants[64] = {
+    0x428a2f98u,
+    0x71374491u,
+    0xb5c0fbcfu,
+    0xe9b5dba5u,
+    0x3956c25bu,
+    0x59f111f1u,
+    0x923f82a4u,
+    0xab1c5ed5u,
+    0xd807aa98u,
+    0x12835b01u,
+    0x243185beu,
+    0x550c7dc3u,
+    0x72be5d74u,
+    0x80deb1feu,
+    0x9bdc06a7u,
+    0xc19bf174u,
+    0xe49b69c1u,
+    0xefbe4786u,
+    0x0fc19dc6u,
+    0x240ca1ccu,
+    0x2de92c6fu,
+    0x4a7484aau,
+    0x5cb0a9dcu,
+    0x76f988dau,
+    0x983e5152u,
+    0xa831c66du,
+    0xb00327c8u,
+    0xbf597fc7u,
+    0xc6e00bf3u,
+    0xd5a79147u,
+    0x06ca6351u,
+    0x14292967u,
+    0x27b70a85u,
+    0x2e1b2138u,
+    0x4d2c6dfcu,
+    0x53380d13u,
+    0x650a7354u,
+    0x766a0abbu,
+    0x81c2c92eu,
+    0x92722c85u,
+    0xa2bfe8a1u,
+    0xa81a664bu,
+    0xc24b8b70u,
+    0xc76c51a3u,
+    0xd192e819u,
+    0xd6990624u,
+    0xf40e3585u,
+    0x106aa070u,
+    0x19a4c116u,
+    0x1e376c08u,
+    0x2748774cu,
+    0x34b0bcb5u,
+    0x391c0cb3u,
+    0x4ed8aa4au,
+    0x5b9cca4fu,
+    0x682e6ff3u,
+    0x748f82eeu,
+    0x78a5636fu,
+    0x84c87814u,
+    0x8cc70208u,
+    0x90befffau,
+    0xa4506cebu,
+    0xbef9a3f7u,
+    0xc67178f2u,
+};
+
+// Callers only pass n in [2, 25], so neither shift reaches 32.
+inline uint32_t Rotr(uint32_t x, int n) { return (x >> n) | (x << (32 - n)); }
+inline uint32_t Ch(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (~x & z); }
+inline uint32_t Maj(uint32_t x, uint32_t y, uint32_t z) { return (x & y) ^ (x & z) ^ (y & z); }
+inline uint32_t Bsig0(uint32_t x) { return Rotr(x, 2) ^ Rotr(x, 13) ^ Rotr(x, 22); }
+inline uint32_t Bsig1(uint32_t x) { return Rotr(x, 6) ^ Rotr(x, 11) ^ Rotr(x, 25); }
+inline uint32_t Ssig0(uint32_t x) { return Rotr(x, 7) ^ Rotr(x, 18) ^ (x >> 3); }
+inline uint32_t Ssig1(uint32_t x) { return Rotr(x, 17) ^ Rotr(x, 19) ^ (x >> 10); }
+
+}  // namespace
+
+Sha256::Sha256() {
+  std::memcpy(state_, kInitState, sizeof(state_));
+  bit_count_ = 0;
+  buffer_len_ = 0;
+}
+
+/// Fold one 64-byte block into state_.
+void Sha256::Transform(const uint8_t block[64]) {
+  // Message schedule: the first 16 words are the block read big-endian.
+  uint32_t w[64];
+  for (int i = 0; i < 16; ++i) {
+    w[i] = (static_cast<uint32_t>(block[i * 4]) << 24) |
+           (static_cast<uint32_t>(block[i * 4 + 1]) << 16) |
+           (static_cast<uint32_t>(block[i * 4 + 2]) << 8) |
+           (static_cast<uint32_t>(block[i * 4 + 3]));
+  }
+  for (int i = 16; i < 64; ++i) {
+    w[i] = Ssig1(w[i - 2]) + w[i - 7] + Ssig0(w[i - 15]) + w[i - 16];
+  }
+
+  uint32_t a = state_[0], b = state_[1], c = state_[2], d = state_[3];
+  uint32_t e = state_[4], f = state_[5], g = state_[6], h = state_[7];
+  for (int i = 0; i < 64; ++i) {
+    uint32_t t1 = h + Bsig1(e) + Ch(e, f, g) + kRoundConstants[i] + w[i];
+    uint32_t t2 = Bsig0(a) + Maj(a, b, c);
+    h = g;
+    g = f;
+    f = e;
+    e = d + t1;
+    d = c;
+    c = b;
+    b = a;
+    a = t1 + t2;
+  }
+  state_[0] += a;
+  state_[1] += b;
+  state_[2] += c;
+  state_[3] += d;
+  state_[4] += e;
+  state_[5] += f;
+  state_[6] += g;
+  state_[7] += h;
+}
+
+/// Absorb `len` bytes, running Transform each time the 64-byte buffer fills.
+void Sha256::Update(const void* data, size_t len) {
+  const uint8_t* p = static_cast<const uint8_t*>(data);
+  bit_count_ += static_cast<uint64_t>(len) * 8;
+  while (len > 0) {
+    size_t take = std::min(64 - buffer_len_, len);
+    std::memcpy(buffer_ + buffer_len_, p, take);
+    buffer_len_ += take;
+    p += take;
+    len -= take;
+    if (buffer_len_ == 64) {
+      Transform(buffer_);
+      buffer_len_ = 0;
+    }
+  }
+}
+
+/// Apply the final padding and write the 32-byte digest to `out`. The hasher is not
+/// reset here, so further Update/Final calls on it continue from the padded state.
+void Sha256::Final(uint8_t out[kDigestSize]) {
+  // Append 0x80, pad with zeros, append 64-bit big-endian length.
+  buffer_[buffer_len_++] = 0x80;
+  if (buffer_len_ > 56) {
+    // No room left for the 8-byte length: flush this block and pad a fresh one.
+    std::memset(buffer_ + buffer_len_, 0, 64 - buffer_len_);
+    Transform(buffer_);
+    buffer_len_ = 0;
+  }
+  std::memset(buffer_ + buffer_len_, 0, 56 - buffer_len_);
+  uint64_t bc = bit_count_;
+  for (int i = 7; i >= 0; --i) {
+    buffer_[56 + i] = static_cast<uint8_t>(bc & 0xff);
+    bc >>= 8;
+  }
+  Transform(buffer_);
+  for (int i = 0; i < 8; ++i) {
+    out[i * 4] = static_cast<uint8_t>((state_[i] >> 24) & 0xff);
+    out[i * 4 + 1] = static_cast<uint8_t>((state_[i] >> 16) & 0xff);
+    out[i * 4 + 2] = static_cast<uint8_t>((state_[i] >> 8) & 0xff);
+    out[i * 4 + 3] = static_cast<uint8_t>(state_[i] & 0xff);
+  }
+}
+
+namespace {
+constexpr char kHex[] = "0123456789abcdef";
+std::string ToHex(const uint8_t* bytes, size_t len) {
+  std::string s(len * 2, '0');
+  for (size_t i = 0; i < len; ++i) {
+    s[i * 2] = kHex[(bytes[i] >> 4) & 0x0f];
+    s[i * 2 + 1] = kHex[bytes[i] & 0x0f];
+  }
+  return s;
+}
+}  // namespace
+
+std::string Sha256::FinalHex() {
+  uint8_t out[kDigestSize];
+  Final(out);
+  return ToHex(out, kDigestSize);
+}
+
+std::string Sha256::HashBytesHex(const void* data, size_t len) {
+  Sha256 h;
+  h.Update(data, len);
+  return h.FinalHex();
+}
+
+std::string Sha256::HashStringHex(const std::string& s) {
+  return HashBytesHex(s.data(), s.size());
+}
+
+/// Hash the file at `path` in 8 KiB chunks. Returns an empty string if the file
+/// cannot be opened or a read fails part-way through.
+std::string Sha256::HashFileHex(const std::string& path) {
+  std::ifstream f(path, std::ios::binary);
+  if (!f) return std::string();
+  Sha256 h;
+  char buf[8192];
+  while (f) {
+    f.read(buf, sizeof(buf));
+    std::streamsize n = f.gcount();
+    if (n > 0) h.Update(buf, static_cast<size_t>(n));
+  }
+  // Reaching end-of-file sets eofbit/failbit only. badbit means the read itself failed,
+  // so the bytes hashed so far are not the whole file and must not be reported.
+  if (f.bad()) return std::string();
+  return h.FinalHex();
+}
+
+}  // namespace model_package
diff --git a/model_package/src/sha256.h b/model_package/src/sha256.h
new file mode 100644
index 0000000000000..da4125ecd80b0
--- /dev/null
+++ b/model_package/src/sha256.h
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/// \file sha256.h
+/// \brief Minimal SHA-256 implementation used for content-addressed assets.
+///        No external crypto dependency.
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <string>
+
+namespace model_package {
+
+/// Incremental SHA-256 hasher: construct, call Update() any number of times, then
+/// call Final() or FinalHex() once to obtain the digest.
+class Sha256 {
+ public:
+  static constexpr size_t kDigestSize = 32;
+
+  Sha256();
+  void Update(const void* data, size_t len);
+  void Update(const std::string& s) { Update(s.data(), s.size()); }
+  void Final(uint8_t out[kDigestSize]);
+
+  /// Hex-encoded (lowercase) digest, 64 chars.
+  std::string FinalHex();
+
+  static std::string HashBytesHex(const void* data, size_t len);
+  static std::string HashStringHex(const std::string& s);
+
+  /// Stream-hash a file by path. Returns the hex digest, or empty string on
+  /// IO error (caller should pre-check existence).
+  static std::string HashFileHex(const std::string& path);
+
+ private:
+  void Transform(const uint8_t block[64]);
+  uint32_t state_[8];   // running hash value
+  uint64_t bit_count_;  // total message length in bits
+  uint8_t buffer_[64];  // bytes of the current, not yet transformed block
+  size_t buffer_len_;   // number of valid bytes in buffer_
+};
+
+}  // namespace model_package
diff --git a/model_package/src/status_impl.h b/model_package/src/status_impl.h
new file mode 100644
index 0000000000000..6cc1c94238f98
--- /dev/null
+++ b/model_package/src/status_impl.h
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/// \file status_impl.h
+/// \brief Internal representation of ModelPackageStatus, shared by all
+///        implementation units in the model_package library.
+
+#pragma once
+
+#include <new>
+#include <string>
+#include <utility>
+
+#include "model_package_api.h"
+
+struct ModelPackageStatus {
+  ModelPackageErrorCode code{MODEL_PACKAGE_ERR_INVALID_ARG};
+  std::string message;
+};
+
+namespace model_package {
+
+/// Allocate a new failure status. Returns nullptr if allocation fails (callers
+/// should treat that as a generic error; we deliberately never throw out of the
+/// C API).
+inline ModelPackageStatus* MakeStatus(ModelPackageErrorCode code, std::string message) {
+  return new (std::nothrow) ModelPackageStatus{code, std::move(message)};
+}
+
+}  // namespace model_package
diff --git a/model_package/tests/test_asset_hashing.cc b/model_package/tests/test_asset_hashing.cc
new file mode 100644
index 0000000000000..717c3fefea4b6
--- /dev/null
+++ b/model_package/tests/test_asset_hashing.cc
@@ -0,0 +1,312 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/// \file test_asset_hashing.cc
+/// \brief Tests for the directory Merkle hash and SHA-256 implementation.
+
+#include "model_package.h"
+#include "model_package_api.h"
+#include "sha256.h"
+
+#include <cstdio>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+#include <random>
+#include <string>
+#include <system_error>
+
+namespace fs = std::filesystem;
+using model_package::Sha256;
+
+namespace {
+
+int g_failed = 0;
+int g_passed = 0;
+const char* g_current = "";
+
+#define CHECK(cond) \
+  do { \
+    if (!(cond)) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: CHECK(%s)\n", g_current, __LINE__, #cond); \
+      return false; \
+    } \
+  } while (0)
+
+#define CHECK_OK(status) \
+  do { \
+    ModelPackageStatus* _s = (status); \
+    if (_s != nullptr) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: expected OK, got: %s\n", \
+                   g_current, __LINE__, ModelPackageStatus_Message(_s)); \
+      ModelPackageStatus_Release(_s); \
+      return false; \
+    } \
+  } while (0)
+
+// Temporary directory with a random name, removed again on destruction.
+class Sandbox {
+ public:
+  Sandbox() {
+    std::random_device rd;
+    std::mt19937_64 g(rd());
+    char buf[32];
+    std::snprintf(buf, sizeof(buf), "mp_hash_%016llx", static_cast<unsigned long long>(g()));
+    root_ = fs::temp_directory_path() / buf;
+    fs::create_directories(root_);
+  }
+  ~Sandbox() {
+    std::error_code ec;
+    fs::remove_all(root_, ec);
+  }
+  Sandbox(const Sandbox&) = delete;
+  Sandbox& operator=(const Sandbox&) = delete;
+  const fs::path& root() const { return root_; }
+  void Write(const std::string& relpath, const std::string& contents) {
+    fs::path full = root_ / relpath;
+    fs::create_directories(full.parent_path());
+    std::ofstream f(full, std::ios::binary);
+    f << contents;
+  }
+
+ private:
+  fs::path root_;
+};
+
+// Directory paths are handed to the C API as narrow strings via
+// `root().string().c_str()`: fs::path::c_str() yields `const wchar_t*` on Windows and
+// would not convert to the `const char*` parameter.
+
+// FIPS-180-4 known-answer test vectors.
+bool test_sha256_known_vectors() {
+  CHECK(Sha256::HashStringHex("") ==
+        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855");
+  CHECK(Sha256::HashStringHex("abc") ==
+        "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad");
+  // Long message: 1,000,000 'a' characters.
+  std::string a_million(1000000, 'a');
+  CHECK(Sha256::HashStringHex(a_million) ==
+        "cdc76e5c9914fb9281a1c7e284d73e67f1809a48a497200e046d39ccc7112cd0");
+  return true;
+}
+
+bool test_sha256_incremental_matches_oneshot() {
+  std::string msg = "the quick brown fox jumps over the lazy dog";
+  std::string oneshot = Sha256::HashStringHex(msg);
+  Sha256 h;
+  for (char c : msg) h.Update(&c, 1);
+  CHECK(h.FinalHex() == oneshot);
+  return true;
+}
+
+bool test_directory_hash_basic() {
+  Sandbox s;
+  s.Write("a.txt", "alpha");
+  s.Write("b.txt", "beta");
+
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s.root().string().c_str(), &uri));
+  CHECK(uri != nullptr);
+  std::string u(uri);
+  CHECK(u.substr(0, 7) == "sha256:");
+  CHECK(u.size() == 7 + 64);
+  return true;
+}
+
+bool test_directory_hash_reproducible() {
+  Sandbox s1;
+  s1.Write("a.txt", "alpha");
+  s1.Write("nested/b.txt", "beta");
+
+  Sandbox s2;
+  s2.Write("a.txt", "alpha");
+  s2.Write("nested/b.txt", "beta");
+
+  const char* u1 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s1.root().string().c_str(), &u1));
+  std::string copy1(u1);
+
+  const char* u2 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s2.root().string().c_str(), &u2));
+  CHECK(copy1 == std::string(u2));
+  return true;
+}
+
+bool test_directory_hash_name_change_differs() {
+  Sandbox s1;
+  s1.Write("a.txt", "alpha");
+
+  Sandbox s2;
+  s2.Write("b.txt", "alpha");  // same content, different name
+
+  const char* u1 = nullptr;
+  const char* u2 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s1.root().string().c_str(), &u1));
+  std::string copy1(u1);
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s2.root().string().c_str(), &u2));
+  CHECK(copy1 != std::string(u2));
+  return true;
+}
+
+bool test_directory_hash_swapped_names_differ() {
+  Sandbox s1;
+  s1.Write("a.txt", "alpha");
+  s1.Write("b.txt", "beta");
+
+  Sandbox s2;
+  s2.Write("a.txt", "beta");  // swapped contents
+  s2.Write("b.txt", "alpha");
+
+  const char* u1 = nullptr;
+  const char* u2 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s1.root().string().c_str(), &u1));
+  std::string copy1(u1);
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s2.root().string().c_str(), &u2));
+  CHECK(copy1 != std::string(u2));
+  return true;
+}
+
+bool test_directory_hash_content_change_differs() {
+  Sandbox s1;
+  s1.Write("a.txt", "alpha");
+  Sandbox s2;
+  s2.Write("a.txt", "ALPHA");
+
+  const char* u1 = nullptr;
+  const char* u2 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s1.root().string().c_str(), &u1));
+  std::string copy1(u1);
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s2.root().string().c_str(), &u2));
+  CHECK(copy1 != std::string(u2));
+  return true;
+}
+
+bool test_directory_hash_empty_dirs_ignored() {
+  Sandbox s1;
+  s1.Write("a.txt", "alpha");
+  Sandbox s2;
+  s2.Write("a.txt", "alpha");
+  fs::create_directories(s2.root() / "empty_subdir");
+
+  const char* u1 = nullptr;
+  const char* u2 = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s1.root().string().c_str(), &u1));
+  std::string copy1(u1);
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s2.root().string().c_str(), &u2));
+  CHECK(copy1 == std::string(u2));
+  return true;
+}
+
+bool test_directory_hash_rejects_symlink() {
+  Sandbox s;
+  s.Write("a.txt", "alpha");
+  std::error_code ec;
+  fs::create_symlink("a.txt", s.root() / "a_link.txt", ec);
+  // If symlink creation isn't supported on this filesystem, skip the test
+  // (treat as pass — the rejection is the behavior under test).
+  if (ec) {
+    std::printf("[SKIP] %s (symlink unsupported)\n", g_current);
+    return true;
+  }
+  const char* uri = nullptr;
+  ModelPackageStatus* st = ModelPackage_ComputeDirectoryHash(s.root().string().c_str(), &uri);
+  CHECK(st != nullptr);
+  CHECK(ModelPackageStatus_Code(st) == MODEL_PACKAGE_ERR_SCHEMA);
+  ModelPackageStatus_Release(st);
+  return true;
+}
+
+bool test_directory_hash_known_value_single_file() {
+  // Known-answer check: the directory URI hashes a manifest of " \n"
+  // lines, so compute the expected value the same way and compare.
+  Sandbox s;
+  s.Write("a.txt", "alpha");
+
+  std::string file_hex = Sha256::HashStringHex("alpha");
+  std::string manifest = file_hex + " a.txt\n";
+  std::string expected = "sha256:" + Sha256::HashStringHex(manifest);
+
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s.root().string().c_str(), &uri));
+  CHECK(std::string(uri) == expected);
+  return true;
+}
+
+bool test_directory_hash_sorted_order_independent_of_walk() {
+  // Whether the OS walks "b.txt" before "a.txt" must not matter.
+  Sandbox s;
+  s.Write("a.txt", "alpha");
+  s.Write("b.txt", "beta");
+  s.Write("c.txt", "gamma");
+
+  // Compute expected manifest manually (sorted).
+  std::string hex_a = Sha256::HashStringHex("alpha");
+  std::string hex_b = Sha256::HashStringHex("beta");
+  std::string hex_c = Sha256::HashStringHex("gamma");
+  std::string manifest = hex_a + " a.txt\n" +
+                         hex_b + " b.txt\n" +
+                         hex_c + " c.txt\n";
+  std::string expected = "sha256:" + Sha256::HashStringHex(manifest);
+
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s.root().string().c_str(), &uri));
+  CHECK(std::string(uri) == expected);
+  return true;
+}
+
+bool test_directory_hash_uses_forward_slash() {
+  Sandbox s;
+  s.Write("dir/sub/c.txt", "x");
+
+  std::string file_hex = Sha256::HashStringHex("x");
+  // Path must be POSIX style in the manifest (forward slashes).
+  std::string manifest = file_hex + " dir/sub/c.txt\n";
+  std::string expected = "sha256:" + Sha256::HashStringHex(manifest);
+
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_ComputeDirectoryHash(s.root().string().c_str(), &uri));
+  CHECK(std::string(uri) == expected);
+  return true;
+}
+
+bool test_missing_directory_errors() {
+  const char* uri = nullptr;
+  ModelPackageStatus* s = ModelPackage_ComputeDirectoryHash("/tmp/does_not_exist_xyzzy_zzz", &uri);
+  CHECK(s != nullptr);
+  CHECK(ModelPackageStatus_Code(s) == MODEL_PACKAGE_ERR_NOT_FOUND);
+  ModelPackageStatus_Release(s);
+  return true;
+}
+
+struct Test {
+  const char* name;
+  bool (*fn)();
+};
+
+const Test kTests[] = {
+    {"sha256_known_vectors", test_sha256_known_vectors},
+    {"sha256_incremental_matches_oneshot", test_sha256_incremental_matches_oneshot},
+    {"directory_hash_basic", test_directory_hash_basic},
+    {"directory_hash_reproducible", test_directory_hash_reproducible},
+    {"directory_hash_name_change_differs", test_directory_hash_name_change_differs},
+    {"directory_hash_swapped_names_differ", test_directory_hash_swapped_names_differ},
+    {"directory_hash_content_change_differs", test_directory_hash_content_change_differs},
+    {"directory_hash_empty_dirs_ignored", test_directory_hash_empty_dirs_ignored},
+    {"directory_hash_rejects_symlink", test_directory_hash_rejects_symlink},
+    {"directory_hash_known_value_single_file", test_directory_hash_known_value_single_file},
+    {"directory_hash_sorted_order_independent_of_walk", test_directory_hash_sorted_order_independent_of_walk},
+    {"directory_hash_uses_forward_slash", test_directory_hash_uses_forward_slash},
+    {"missing_directory_errors", test_missing_directory_errors},
+};
+
+}  // namespace
+
+int main() {
+  for (const auto& t : kTests) {
+    g_current = t.name;
+    bool ok = t.fn();
+    if (ok) {
+      std::printf("[PASS] %s\n", t.name);
+      g_passed++;
+    } else {
+      g_failed++;
+    }
+  }
+  std::printf("\n=== %d passed, %d failed ===\n", g_passed, g_failed);
+  return g_failed == 0 ? 0 : 1;
+}
diff --git a/model_package/tests/test_authoring.cc b/model_package/tests/test_authoring.cc
new file mode 100644
index 0000000000000..4f6808d966093
--- /dev/null
+++ b/model_package/tests/test_authoring.cc
@@ -0,0 +1,525 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/// \file test_authoring.cc
+/// \brief Authoring (mutation) API tests.
+
+#include "model_package.h"
+#include "model_package_api.h"
+
+#include <cstdio>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+#include <random>
+#include <string>
+
+namespace fs = std::filesystem;
+
+namespace {
+
+int g_failed = 0;
+int g_passed = 0;
+const char* g_current = "";
+
+#define CHECK(cond) \
+  do { \
+    if (!(cond)) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: CHECK(%s)\n", g_current, __LINE__, #cond); \
+      return false; \
+    } \
+  } while (0)
+
+#define CHECK_OK(status) \
+  do { \
+    ModelPackageStatus* _s = (status); \
+    if (_s != nullptr) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: expected OK, got: %s\n", \
+                   g_current, __LINE__, ModelPackageStatus_Message(_s)); \
+      ModelPackageStatus_Release(_s); \
+      return false; \
+    } \
+  } while (0)
+
+#define CHECK_ERR(status, expected_code) \
+  do { \
+    ModelPackageStatus* _s = (status); \
+    if (_s == nullptr) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got OK\n", \
+                   g_current, __LINE__, (int)(expected_code)); \
+      return false; \
+    } \
+    ModelPackageErrorCode _c = ModelPackageStatus_Code(_s); \
+    if (_c != (expected_code)) { \
+      std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got %d (%s)\n", \
+                   g_current, __LINE__, (int)(expected_code), (int)_c, \
+                   ModelPackageStatus_Message(_s)); \
+      ModelPackageStatus_Release(_s); \
+      return false; \
+    } \
+    ModelPackageStatus_Release(_s); \
+  } while (0)
+
+// Temporary directory with a random name, removed again on destruction.
+class Sandbox {
+ public:
+  Sandbox() {
+    std::random_device rd;
+    std::mt19937_64 g(rd());
+    char buf[32];
+    std::snprintf(buf, sizeof(buf), "mp_auth_%016llx", static_cast<unsigned long long>(g()));
+    root_ = fs::temp_directory_path() / buf;
+    fs::create_directories(root_);
+  }
+  ~Sandbox() {
+    std::error_code ec;
+    fs::remove_all(root_, ec);
+  }
+  Sandbox(const Sandbox&) = delete;
+  Sandbox& operator=(const Sandbox&) = delete;
+  const fs::path& root() const { return root_; }
+  fs::path path(const std::string& rel) const { return root_ / rel; }
+  void Write(const std::string& rel, const std::string& contents) {
+    fs::path full = root_ / rel;
+    fs::create_directories(full.parent_path());
+    std::ofstream f(full, std::ios::binary);
+    f << contents;
+  }
+
+ private:
+  fs::path root_;
+};
+
+// Owns a ModelPackage* and closes it on scope exit.
+class PkgHandle {
+ public:
+  explicit PkgHandle(ModelPackage* p) : p_(p) {}
+  ~PkgHandle() { ModelPackage_Close(p_); }
+  PkgHandle(const PkgHandle&) = delete;
+  PkgHandle& operator=(const PkgHandle&) = delete;
+  ModelPackage* get() const { return p_; }
+
+ private:
+  ModelPackage* p_;
+};
+
+// ─────────────────────────────────────────────────────────────────────────────
+// ModelPackage_New
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_new_creates_empty_package() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  CHECK(raw != nullptr);
+  PkgHandle p(raw);
+  const ModelPackageInfo* info = ModelPackage_Info(p.get());
+  CHECK(info != nullptr);
+  CHECK(info->schema_version_major == 0);
+  CHECK(info->schema_version_minor == 0);
+  CHECK((info)->num_components == 0);
+  CHECK((info)->num_shared_assets == 0);
+  CHECK(std::string(info->layout) == "portable");
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Component operations
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_set_component_inline_basic() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "encoder",
+                                           R"({"variants": {}})"));
+  CHECK((ModelPackage_Info(p.get()))->num_components == 1);
+  const ModelComponentInfo* c = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "encoder");
+  CHECK(c != nullptr);
+  CHECK(std::string(c->name) == "encoder");
+  CHECK((c)->num_variants == 0);
+  return true;
+}
+
+bool test_set_component_inline_replaces_existing() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})"));
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c",
+                                           R"({"variants": {"v1": {"variant_directory": "."}}})"));
+  CHECK((ModelPackage_Info(p.get()))->num_components == 1);
+  const ModelComponentInfo* c = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c");
+  CHECK(c != nullptr);  // report a failed lookup instead of dereferencing null
+  CHECK((c)->num_variants == 1);
+  return true;
+}
+
+bool test_set_component_inline_rejects_unknown_field() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_ERR(ModelPackage_SetComponentInline(p.get(), "c",
+                                            R"({"variants": {}, "typo_field": 1})"),
+            MODEL_PACKAGE_ERR_SCHEMA);
+  CHECK((ModelPackage_Info(p.get()))->num_components == 0);
+  return true;
+}
+
+bool test_set_component_inline_rejects_bad_json() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_ERR(ModelPackage_SetComponentInline(p.get(), "c", "not-json"),
+            MODEL_PACKAGE_ERR_SCHEMA);
+  return true;
+}
+
+bool test_remove_component() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "a", R"({"variants": {}})"));
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "b", R"({"variants": {}})"));
+  CHECK((ModelPackage_Info(p.get()))->num_components == 2);
+  CHECK_OK(ModelPackage_RemoveComponent(p.get(), "a"));
+  CHECK((ModelPackage_Info(p.get()))->num_components == 1);
+  const ModelPackageInfo* info = ModelPackage_Info(p.get());
+  CHECK(ModelPackage_FindComponent(info, "a") == nullptr);
+  CHECK(ModelPackage_FindComponent(info, "b") != nullptr);
+  return true;
+}
+
+bool test_remove_missing_component_is_noop() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_RemoveComponent(p.get(), "nope"));
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Variant operations
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_set_variant_upsert() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})"));
+
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1",
+                                   R"({"variant_directory": ".", "ep": "CPU"})"));
+  const ModelComponentInfo* c = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c");
+  CHECK(c != nullptr);
+  CHECK((c)->num_variants == 1);
+  const ModelVariantInfo* v = ModelComponentInfo_FindVariant(c, "v1");
+  CHECK(v != nullptr);
+  CHECK(std::string(v->ep) == "CPU");
+
+  // Upsert: change ep.
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1",
+                                   R"({"variant_directory": ".", "ep": "CUDA"})"));
+  c = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c");
+  CHECK(c != nullptr);
+  CHECK((c)->num_variants == 1);
+  v = ModelComponentInfo_FindVariant(c, "v1");
+  CHECK(v != nullptr);
+  CHECK(std::string(v->ep) == "CUDA");
+  return true;
+}
+
+bool test_set_variant_unknown_component_errors() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_ERR(ModelPackage_SetVariant(p.get(), "nope", "v1", R"({"variant_directory": "."})"),
+            MODEL_PACKAGE_ERR_NOT_FOUND);
+  return true;
+}
+
+bool test_remove_variant() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})"));
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1", R"({"variant_directory": "."})"));
+  CHECK_OK(ModelPackage_RemoveVariant(p.get(), "c", "v1"));
+  const ModelComponentInfo* c = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c");
+  CHECK(c != nullptr);
+  CHECK((c)->num_variants == 0);
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Variant executor_info
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_set_executor_info_inline_and_remove() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})"));
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1", R"({"variant_directory": "."})"));
+
+  CHECK_OK(ModelPackage_SetVariantExecutorInfoInline(p.get(), "c", "v1", "ort",
+                                                     R"({"model": "m.onnx"})"));
+  const ModelVariantInfo* v = ModelComponentInfo_FindVariant(
+      ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c"), "v1");
+  CHECK(v != nullptr);
+  const char* ej = nullptr;
+  const ModelExecutorInfoEntry* ei = ModelVariantInfo_FindExecutorInfo(v, "ort");
+  ej = ei ? ei->json : nullptr;
+  CHECK(ej != nullptr);
+  CHECK(std::strstr(ej, "\"model\"") != nullptr);
+
+  CHECK_OK(ModelPackage_RemoveVariantExecutorInfo(p.get(), "c", "v1", "ort"));
+  v = ModelComponentInfo_FindVariant(ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c"), "v1");
+  CHECK(v != nullptr);
+  ei = ModelVariantInfo_FindExecutorInfo(v, "ort");
+  ej = ei ? ei->json : nullptr;
+  CHECK(ei == nullptr);
+  CHECK(ej == nullptr);
+  return true;
+}
+
+bool test_set_executor_info_external_records_path() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})"));
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1", R"({"variant_directory": "."})"));
+  CHECK_OK(ModelPackage_SetVariantExecutorInfoExternal(p.get(), "c", "v1", "ort",
+                                                       "ort_info.json"));
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Package metadata
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_set_metadata() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetMetadata(p.get(), "mypkg", "1.0.0", "desc"));
+  const ModelPackageInfo* info = ModelPackage_Info(p.get());
+  CHECK(std::string(info->package_name) == "mypkg");
+  CHECK(std::string(info->package_version) == "1.0.0");
+  CHECK(std::string(info->description) == "desc");
+
+  // Empty string clears.
+  CHECK_OK(ModelPackage_SetMetadata(p.get(), nullptr, "", nullptr));
+  info = ModelPackage_Info(p.get());
+  CHECK(info->package_version == nullptr);
+  CHECK(std::string(info->package_name) == "mypkg");
+  return true;
+}
+
+bool test_set_layout() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetLayout(p.get(), "installed"));
+  CHECK(std::string(ModelPackage_Info(p.get())->layout) == "installed");
+  CHECK_ERR(ModelPackage_SetLayout(p.get(), "weird"), MODEL_PACKAGE_ERR_SCHEMA);
+  return true;
+}
+
+bool test_set_additional_metadata_manifest_scope() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_OK(ModelPackage_SetAdditionalMetadataJson(p.get(), "manifest", nullptr, nullptr,
+                                                  R"({"author":"jambayk"})"));
+  const ModelPackageInfo* info = ModelPackage_Info(p.get());
+  CHECK(info->additional_metadata_json != nullptr);
+  CHECK(std::string(info->additional_metadata_json).find("jambayk") != std::string::npos);
+
+  // Clear.
+ CHECK_OK(ModelPackage_SetAdditionalMetadataJson(p.get(), "manifest", nullptr, nullptr, nullptr)); + info = ModelPackage_Info(p.get()); + CHECK(info->additional_metadata_json == nullptr); + return true; +} + +bool test_set_additional_metadata_variant_scope() { + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", R"({"variants": {}})")); + CHECK_OK(ModelPackage_SetVariant(p.get(), "c", "v1", R"({"variant_directory": "."})")); + CHECK_OK(ModelPackage_SetAdditionalMetadataJson(p.get(), "variant", "c", "v1", + R"({"foo":"bar"})")); + const ModelVariantInfo* v = ModelComponentInfo_FindVariant( + ModelPackage_FindComponent(ModelPackage_Info(p.get()), "c"), "v1"); + CHECK(v != nullptr); + const char* md = v->additional_metadata_json; + CHECK(md != nullptr); + CHECK(std::string(md).find("foo") != std::string::npos); + return true; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Shared assets — authoring +// ───────────────────────────────────────────────────────────────────────────── + +bool test_add_shared_asset_copy_in_true_portable_ok() { + Sandbox s; + s.Write("src/a.txt", "alpha"); + + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + const char* uri = nullptr; + CHECK_OK(ModelPackage_AddSharedAsset(p.get(), (s.root() / "src").c_str(), + nullptr, /*copy_in=*/true, &uri)); + CHECK(uri != nullptr); + CHECK(std::string(uri).substr(0, 7) == "sha256:"); + return true; +} + +bool test_add_shared_asset_copy_in_false_portable_rejected() { + Sandbox s; + s.Write("src/a.txt", "alpha"); + + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + const char* uri = nullptr; + CHECK_ERR(ModelPackage_AddSharedAsset(p.get(), (s.root() / "src").c_str(), + nullptr, /*copy_in=*/false, &uri), + MODEL_PACKAGE_ERR_STATE); + return true; +} + +bool 
test_add_shared_asset_copy_in_false_installed_ok() { + Sandbox s; + s.Write("src/a.txt", "alpha"); + + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetLayout(p.get(), "installed")); + const char* uri = nullptr; + CHECK_OK(ModelPackage_AddSharedAsset(p.get(), (s.root() / "src").c_str(), + nullptr, /*copy_in=*/false, &uri)); + CHECK(uri != nullptr); + // Surfaced as a manifest override -> shared_assets count should be 1. + CHECK((ModelPackage_Info(p.get()))->num_shared_assets == 1); + return true; +} + +bool test_add_shared_asset_expected_uri_mismatch_errors() { + Sandbox s; + s.Write("src/a.txt", "alpha"); + + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetLayout(p.get(), "installed")); + const char* uri = nullptr; + std::string bogus = "sha256:" + std::string(64, '0'); + CHECK_ERR(ModelPackage_AddSharedAsset(p.get(), (s.root() / "src").c_str(), + bogus.c_str(), /*copy_in=*/false, &uri), + MODEL_PACKAGE_ERR_STATE); + return true; +} + +bool test_remove_shared_asset() { + Sandbox s; + s.Write("src/a.txt", "alpha"); + + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetLayout(p.get(), "installed")); + const char* uri = nullptr; + CHECK_OK(ModelPackage_AddSharedAsset(p.get(), (s.root() / "src").c_str(), + nullptr, /*copy_in=*/false, &uri)); + std::string uri_copy(uri); + CHECK((ModelPackage_Info(p.get()))->num_shared_assets == 1); + CHECK_OK(ModelPackage_RemoveSharedAsset(p.get(), uri_copy.c_str())); + CHECK((ModelPackage_Info(p.get()))->num_shared_assets == 0); + return true; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Round-trip through GetComponentJson / GetVariantJson +// ───────────────────────────────────────────────────────────────────────────── + +bool test_round_trip_component_json() { + ModelPackage* raw = nullptr; + 
CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetComponentInline(p.get(), "c", + R"({"variants": {"v1": {"variant_directory": ".", "ep": "CPU"}}})")); + const char* j = nullptr; + CHECK_OK(ModelPackage_GetComponentJson(p.get(), "c", &j)); + CHECK(j != nullptr); + std::string s(j); + CHECK(s.find("\"variants\"") != std::string::npos); + CHECK(s.find("\"v1\"") != std::string::npos); + CHECK(s.find("\"CPU\"") != std::string::npos); + return true; +} + +// ───────────────────────────────────────────────────────────────────────────── +// View cache invalidation after mutation +// ───────────────────────────────────────────────────────────────────────────── + +bool test_view_cache_drops_on_remove() { + ModelPackage* raw = nullptr; + CHECK_OK(ModelPackage_New(&raw)); + PkgHandle p(raw); + CHECK_OK(ModelPackage_SetComponentInline(p.get(), "a", R"({"variants": {}})")); + CHECK_OK(ModelPackage_SetComponentInline(p.get(), "b", R"({"variants": {}})")); + const ModelComponentInfo* a = ModelPackage_FindComponent(ModelPackage_Info(p.get()), "a"); + CHECK(a != nullptr); + CHECK_OK(ModelPackage_RemoveComponent(p.get(), "a")); + // Old pointer was invalidated by the mutation; re-fetch and 'a' must now be gone. 
+ const ModelPackageInfo* info = ModelPackage_Info(p.get()); + CHECK(ModelPackage_FindComponent(info, "a") == nullptr); + CHECK(ModelPackage_FindComponent(info, "b") != nullptr); + return true; +} + +struct Test { + const char* name; + bool (*fn)(); +}; + +const Test kTests[] = { + {"new_creates_empty_package", test_new_creates_empty_package}, + {"set_component_inline_basic", test_set_component_inline_basic}, + {"set_component_inline_replaces_existing", test_set_component_inline_replaces_existing}, + {"set_component_inline_rejects_unknown_field", test_set_component_inline_rejects_unknown_field}, + {"set_component_inline_rejects_bad_json", test_set_component_inline_rejects_bad_json}, + {"remove_component", test_remove_component}, + {"remove_missing_component_is_noop", test_remove_missing_component_is_noop}, + {"set_variant_upsert", test_set_variant_upsert}, + {"set_variant_unknown_component_errors", test_set_variant_unknown_component_errors}, + {"remove_variant", test_remove_variant}, + {"set_executor_info_inline_and_remove", test_set_executor_info_inline_and_remove}, + {"set_executor_info_external_records_path", test_set_executor_info_external_records_path}, + {"set_metadata", test_set_metadata}, + {"set_layout", test_set_layout}, + {"set_additional_metadata_manifest_scope", test_set_additional_metadata_manifest_scope}, + {"set_additional_metadata_variant_scope", test_set_additional_metadata_variant_scope}, + {"add_shared_asset_copy_in_true_portable_ok", test_add_shared_asset_copy_in_true_portable_ok}, + {"add_shared_asset_copy_in_false_portable_rejected", test_add_shared_asset_copy_in_false_portable_rejected}, + {"add_shared_asset_copy_in_false_installed_ok", test_add_shared_asset_copy_in_false_installed_ok}, + {"add_shared_asset_expected_uri_mismatch_errors", test_add_shared_asset_expected_uri_mismatch_errors}, + {"remove_shared_asset", test_remove_shared_asset}, + {"round_trip_component_json", test_round_trip_component_json}, + {"view_cache_drops_on_remove", 
test_view_cache_drops_on_remove}, +}; + +} // namespace + +int main() { + for (const auto& t : kTests) { + g_current = t.name; + bool ok = t.fn(); + if (ok) { + std::printf("[PASS] %s\n", t.name); + g_passed++; + } else { + g_failed++; + } + } + std::printf("\n=== %d passed, %d failed ===\n", g_passed, g_failed); + return g_failed == 0 ? 0 : 1; +} diff --git a/model_package/tests/test_commit.cc b/model_package/tests/test_commit.cc new file mode 100644 index 0000000000000..4ede82394e170 --- /dev/null +++ b/model_package/tests/test_commit.cc @@ -0,0 +1,504 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +/// \file test_commit.cc +/// \brief Commit, prune, and validate tests. + +#include "model_package.h" +#include "model_package_api.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace { + +int g_failed = 0; +int g_passed = 0; +const char* g_current = ""; + +#define CHECK(cond) \ + do { \ + if (!(cond)) { \ + std::fprintf(stderr, "[FAIL] %s line %d: CHECK(%s)\n", g_current, __LINE__, #cond); \ + return false; \ + } \ + } while (0) + +#define CHECK_OK(status) \ + do { \ + ModelPackageStatus* _s = (status); \ + if (_s != nullptr) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected OK, got: %s\n", \ + g_current, __LINE__, ModelPackageStatus_Message(_s)); \ + ModelPackageStatus_Release(_s); \ + return false; \ + } \ + } while (0) + +#define CHECK_ERR(status, expected_code) \ + do { \ + ModelPackageStatus* _s = (status); \ + if (_s == nullptr) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got OK\n", \ + g_current, __LINE__, (int)(expected_code)); \ + return false; \ + } \ + ModelPackageErrorCode _c = ModelPackageStatus_Code(_s); \ + if (_c != (expected_code)) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got %d (%s)\n", \ + g_current, __LINE__, (int)(expected_code), (int)_c, \ + 
ModelPackageStatus_Message(_s)); \ + ModelPackageStatus_Release(_s); \ + return false; \ + } \ + ModelPackageStatus_Release(_s); \ + } while (0) + +class Sandbox { + public: + Sandbox() { + std::random_device rd; + std::mt19937_64 g(rd()); + char buf[32]; + std::snprintf(buf, sizeof(buf), "mp_commit_%016llx", static_cast(g())); + root_ = fs::temp_directory_path() / buf; + fs::create_directories(root_); + } + ~Sandbox() { + std::error_code ec; + fs::remove_all(root_, ec); + } + Sandbox(const Sandbox&) = delete; + Sandbox& operator=(const Sandbox&) = delete; + const fs::path& root() const { return root_; } + fs::path path(const std::string& rel) const { return root_ / rel; } + void Write(const std::string& rel, const std::string& contents) { + fs::path full = root_ / rel; + fs::create_directories(full.parent_path()); + std::ofstream f(full, std::ios::binary); + f << contents; + } + + private: + fs::path root_; +}; + +class PkgHandle { + public: + explicit PkgHandle(ModelPackage* p) : p_(p) {} + ~PkgHandle() { ModelPackage_Close(p_); } + PkgHandle(const PkgHandle&) = delete; + PkgHandle& operator=(const PkgHandle&) = delete; + ModelPackage* get() const { return p_; } + ModelPackage** outparam() { return &p_; } + + private: + ModelPackage* p_; +}; + +// Open a freshly-created in-memory package ready to commit at `root`. +// `root` must be empty/nonexistent for the subsequent dest_root commit. 
+PkgHandle MakeAuthoredPkgAt(const fs::path& /*root*/,
+                            const std::string& layout = "portable") {
+  // The authoring calls return a ModelPackageStatus* that the CHECK_* macros
+  // in this file release when it is non-null. Dropping it here would leak the
+  // status and let a half-built package reach the test body unnoticed, so the
+  // first failure is reported, released, and the remaining steps are skipped.
+  const auto step_ok = [](ModelPackageStatus* status, const char* what) {
+    if (status == nullptr) return true;
+    std::fprintf(stderr, "[SETUP] %s: %s failed: %s\n", g_current, what,
+                 ModelPackageStatus_Message(status));
+    ModelPackageStatus_Release(status);
+    return false;
+  };
+
+  ModelPackage* raw = nullptr;
+  bool ok = step_ok(ModelPackage_New(&raw), "ModelPackage_New");
+  if (ok && layout != "portable") {
+    ok = step_ok(ModelPackage_SetLayout(raw, layout.c_str()), "ModelPackage_SetLayout");
+  }
+  ok = ok && step_ok(ModelPackage_SetComponentInline(raw, "encoder", R"({"variants": {}})"),
+                     "ModelPackage_SetComponentInline");
+  ok = ok && step_ok(ModelPackage_SetVariant(raw, "encoder", "v1", R"({"ep": "CPU"})"),
+                     "ModelPackage_SetVariant");
+  (void)ok;  // The handle is returned either way; the caller's CHECKs decide the verdict.
+  return PkgHandle(raw);
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Commit (in-place, PRESERVE)
+// ─────────────────────────────────────────────────────────────────────────────
+
+// Commits a freshly authored package, reopens it from disk and checks that the
+// single "encoder"/"v1" entry survived the round trip.
+// NOTE(review): fs::path::c_str() is wchar_t* on Windows; the calls below assume
+// the API accepts the native path character type - confirm against model_package_api.h.
+bool test_commit_inplace_basic_roundtrip() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(), MODEL_PACKAGE_WRITE_PRESERVE));
+  // manifest.json exists.
+  CHECK(fs::is_regular_file(s.path("pkg") / "manifest.json"));
+
+  // Reopen and confirm.
+  ModelPackage* re = nullptr;
+  CHECK_OK(ModelPackage_Open(s.path("pkg").c_str(), nullptr, &re));
+  PkgHandle rep(re);
+  const ModelPackageInfo* info = ModelPackage_Info(rep.get());
+  CHECK(info != nullptr);
+  CHECK(info->num_components == 1);
+  const ModelComponentInfo* c = ModelPackage_FindComponent(info, "encoder");
+  CHECK(c != nullptr);
+  CHECK(c->num_variants == 1);
+  const ModelVariantInfo* v = ModelComponentInfo_FindVariant(c, "v1");
+  // Guard both pointers: constructing std::string from nullptr is undefined behaviour.
+  CHECK(v != nullptr);
+  CHECK(v->ep != nullptr);
+  CHECK(std::string(v->ep) == "CPU");
+  return true;
+}
+
+// An in-memory package with no root cannot be committed in place.
+bool test_commit_requires_package_root() {
+  ModelPackage* raw = nullptr;
+  CHECK_OK(ModelPackage_New(&raw));
+  PkgHandle p(raw);
+  CHECK_ERR(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE),
+            MODEL_PACKAGE_ERR_STATE);
+  return true;
+}
+
+// A PRESERVE commit writes the file behind an external component and the
+// component is still present after reopening.
+bool test_commit_external_component_writes_file() {
+  Sandbox s;
+  // Author an inline package committed to disk first.
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(), MODEL_PACKAGE_WRITE_PRESERVE));
+
+  // Reopen, add an external component pointing at a file that doesn't exist yet.
+  ModelPackage* re = nullptr;
+  CHECK_OK(ModelPackage_Open(s.path("pkg").c_str(), nullptr, &re));
+  PkgHandle rep(re);
+  CHECK_OK(ModelPackage_SetComponentExternal(rep.get(), "decoder", "decoder.json"));
+  CHECK_OK(ModelPackage_Commit(rep.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK(fs::is_regular_file(s.path("pkg") / "decoder.json"));
+  CHECK(fs::is_regular_file(s.path("pkg") / "manifest.json"));
+
+  // Reopen yet again and verify external component round-trips.
+  ModelPackage* re2 = nullptr;
+  CHECK_OK(ModelPackage_Open(s.path("pkg").c_str(), nullptr, &re2));
+  PkgHandle rep2(re2);
+  CHECK(ModelPackage_FindComponent(ModelPackage_Info(rep2.get()), "decoder") != nullptr);
+  return true;
+}
+
+// A copy_in shared asset added before the first commit lands under
+// shared_assets/sha256-<hex>/ in the package root.
+bool test_commit_pending_shared_asset_copy_in() {
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "hello world");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, /*copy_in=*/true, &uri));
+  CHECK(uri != nullptr);
+  std::string uri_copy(uri);
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  std::string hex = uri_copy.substr(7);  // Drop the "sha256:" prefix.
+  fs::path landed = s.path("pkg") / "shared_assets" / ("sha256-" + hex);
+  CHECK(fs::is_directory(landed));
+  CHECK(fs::is_regular_file(landed / "m.onnx"));
+  return true;
+}
+
+// A DENSE commit folds an external component into manifest.json instead of
+// writing its side file.
+bool test_commit_dense_inlines_external_component() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(), MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK_OK(ModelPackage_SetComponentExternal(p.get(), "decoder", "decoder.json"));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_DENSE));
+  // The dense commit should NOT have written decoder.json (component became inline).
+  CHECK(!fs::exists(s.path("pkg") / "decoder.json"));
+  // Manifest contains decoder as an inline object.
+  std::ifstream f(s.path("pkg") / "manifest.json");
+  CHECK(f.is_open());
+  std::ostringstream oss;
+  oss << f.rdbuf();
+  std::string m = oss.str();
+  CHECK(m.find("\"decoder\"") != std::string::npos);
+  CHECK(m.find("\"variants\"") != std::string::npos);
+  return true;
+}
+
+// A DENSE commit refuses a variant whose executor_info is an external file.
+bool test_commit_dense_rejects_external_executor_info() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_SetVariantExecutorInfoExternal(
+      p.get(), "encoder", "v1", "ort", "encoder/ort.json"));
+  CHECK_ERR(ModelPackage_Commit(p.get(), s.path("pkg").c_str(), MODEL_PACKAGE_WRITE_DENSE),
+            MODEL_PACKAGE_ERR_STATE);
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Commit (dest_root "save as")
+// ─────────────────────────────────────────────────────────────────────────────
+
+// Committing to a new dest_root produces a self-contained copy and re-anchors
+// the in-memory package there for later in-place commits.
+bool test_commit_dest_root_self_contained() {
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "alpha");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("orig"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("orig").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+
+  // Add an asset and commit as.
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, /*copy_in=*/true, &uri));
+  CHECK(uri != nullptr);
+  std::string uri_copy(uri);
+  fs::path saved = s.path("saved");
+  CHECK_OK(ModelPackage_Commit(p.get(), saved.c_str(), MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK(fs::is_regular_file(saved / "manifest.json"));
+  std::string hex = uri_copy.substr(7);
+  CHECK(fs::is_directory(saved / "shared_assets" / ("sha256-" + hex)));
+
+  // After dest_root commit, in-memory state reflects the new root.
+  // (We can verify by mutating + committing in-place again.)
+  CHECK_OK(ModelPackage_SetMetadata(p.get(), "savedpkg", "1.0", nullptr));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE));
+  // The most recent in-place commit should have landed at `saved`, not `orig`.
+  std::ifstream f(saved / "manifest.json");
+  CHECK(f.is_open());
+  std::ostringstream oss;
+  oss << f.rdbuf();
+  CHECK(oss.str().find("savedpkg") != std::string::npos);
+  return true;
+}
+
+// A dest_root commit into a non-empty directory is rejected.
+bool test_commit_dest_root_must_be_empty() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  s.Write("dest/something", "x");
+  // Try to commit to non-empty dest.
+  CHECK_ERR(ModelPackage_Commit(p.get(), s.path("dest").c_str(),
+                                MODEL_PACKAGE_WRITE_PRESERVE),
+            MODEL_PACKAGE_ERR_STATE);
+  return true;
+}
+
+// A dest_root commit re-hashes assets already on disk and refuses to copy one
+// whose content no longer matches its sha256 name.
+bool test_commit_dest_root_rehashes_existing_asset() {
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "alpha");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("orig"));
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, /*copy_in=*/true, &uri));
+  CHECK(uri != nullptr);
+  std::string uri_copy(uri);
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("orig").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+
+  // Tamper with the landed sha256-<hex>/ dir under the existing package root.
+  std::string hex = uri_copy.substr(7);
+  fs::path landed = s.path("orig") / "shared_assets" / ("sha256-" + hex) / "m.onnx";
+  // Make sure we overwrite the committed file rather than silently failing to
+  // open a path that was never written.
+  CHECK(fs::is_regular_file(landed));
+  {
+    std::ofstream f(landed, std::ios::binary);
+    f << "TAMPERED";
+  }
+
+  // CommitToDestRoot must rehash the source and refuse the mismatch.
+  CHECK_ERR(ModelPackage_Commit(p.get(), s.path("saved").c_str(),
+                                MODEL_PACKAGE_WRITE_PRESERVE),
+            MODEL_PACKAGE_ERR_STATE);
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Prune
+// ─────────────────────────────────────────────────────────────────────────────
+
+bool test_prune_never_touches_shared_assets() {
+  // Shared assets are content-addressed and only removed via explicit
+  // RemoveSharedAsset. Even an obviously orphan sha256-<hex>/ directory that
+  // matches no manifest entry must survive Prune.
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+
+  fs::path planted = s.path("pkg") / "shared_assets" /
+                     ("sha256-" + std::string(64, 'a'));
+  fs::create_directories(planted);
+  // Backdate mtime to past grace window to make sure it isn't grace-protected.
+  auto old = fs::file_time_type::clock::now() - std::chrono::seconds(120);
+  std::error_code ec;
+  fs::last_write_time(planted, old, ec);
+  // If backdating failed, survival could be due to the grace window alone.
+  CHECK(!ec);
+  CHECK_OK(ModelPackage_Prune(p.get()));
+  CHECK(fs::is_directory(planted));
+  return true;
+}
+
+// Removing a variant that owned an on-disk directory lets Prune delete it.
+bool test_prune_reclaims_tracked_orphan_variant_dirs() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  // Now that package_root is anchored, materialize an on-disk variant dir and
+  // register it so subsequent removal records a tracked orphan.
+  fs::path victim = s.path("pkg") / "encoder" / "v1";
+  fs::create_directories(victim);
+  CHECK_OK(ModelPackage_SetVariant(p.get(), "encoder", "v1",
+                                   R"({"ep":"CPU","variant_directory":"encoder/v1"})"));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK(fs::is_directory(victim));
+  CHECK_OK(ModelPackage_RemoveVariant(p.get(), "encoder", "v1"));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK_OK(ModelPackage_Prune(p.get()));
+  CHECK(!fs::exists(victim));
+  return true;
+}
+
+// Prune deletes a backdated "*.tmp.*" staging directory under shared_assets/.
+bool test_prune_removes_stale_staging_dirs() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+
+  fs::path stage = s.path("pkg") / "shared_assets" /
+                   ("sha256-" + std::string(64, 'c') + ".tmp.abcdef0123");
+  fs::create_directories(stage);
+  auto old = fs::file_time_type::clock::now() - std::chrono::seconds(120);
+  std::error_code ec;
+  fs::last_write_time(stage, old, ec);
+  // Surface a failed backdate here instead of as a confusing "still exists" below.
+  CHECK(!ec);
+  CHECK_OK(ModelPackage_Prune(p.get()));
+  CHECK(!fs::exists(stage));
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Validate
+// ─────────────────────────────────────────────────────────────────────────────
+
+// A freshly committed package validates with an empty "errors" list.
+bool test_validate_all_clean_package() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  const char* report = nullptr;
+  CHECK_OK(ModelPackage_Validate(p.get(), MODEL_PACKAGE_VALIDATE_ALL, &report));
+  CHECK(report != nullptr);
+  CHECK(std::string(report).find("\"errors\": []") != std::string::npos);
+  return true;
+}
+
+// PATHS validation mentions "PATHS" in its report when an external component
+// file has been deleted.
+bool test_validate_paths_flags_missing_external() {
+  Sandbox s;
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  // Register an external component then delete the file behind the library's back.
+  CHECK_OK(ModelPackage_SetComponentExternal(p.get(), "decoder", "decoder.json"));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr, MODEL_PACKAGE_WRITE_PRESERVE));
+  std::error_code ec;
+  fs::remove(s.path("pkg") / "decoder.json", ec);
+  const char* report = nullptr;
+  CHECK_OK(ModelPackage_Validate(p.get(), MODEL_PACKAGE_VALIDATE_PATHS, &report));
+  CHECK(report != nullptr);
+  CHECK(std::string(report).find("PATHS") != std::string::npos);
+  return true;
+}
+
+// ASSET_REHASH validation fails, and says so in the report, after a committed
+// shared asset is modified on disk.
+bool test_validate_asset_rehash_detects_mutation() {
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "alpha");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, /*copy_in=*/true, &uri));
+  CHECK(uri != nullptr);
+  std::string uri_copy(uri);
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  // Mutate the on-disk shared asset directly.
+  std::string hex = uri_copy.substr(7);
+  fs::path landed = s.path("pkg") / "shared_assets" / ("sha256-" + hex) / "m.onnx";
+  CHECK(fs::is_regular_file(landed));
+  {
+    std::ofstream f(landed, std::ios::binary);
+    f << "MUTATED";
+  }
+  const char* report = nullptr;
+  CHECK_ERR(ModelPackage_Validate(p.get(), MODEL_PACKAGE_VALIDATE_ASSET_REHASH, &report),
+            MODEL_PACKAGE_ERR_STATE);
+  // The call above returned an error, so do not assume the report was filled in.
+  CHECK(report != nullptr);
+  CHECK(std::string(report).find("ASSET_REHASH") != std::string::npos);
+  return true;
+}
+
+bool test_commit_accepts_unreferenced_shared_asset() {
+  // Shared assets no longer require an in-manifest reference: AddSharedAsset
+  // signals the user's intent to ship the asset, period. Commit materializes
+  // it under shared_assets/ at the default-convention path.
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "alpha");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, /*copy_in=*/true, &uri));
+  CHECK(uri != nullptr);
+  std::string uri_copy(uri);
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  std::string hex = uri_copy.substr(7);
+  CHECK(fs::is_directory(s.path("pkg") / "shared_assets" / ("sha256-" + hex)));
+  // Same on dest_root path.
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("saved").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  CHECK(fs::is_directory(s.path("saved") / "shared_assets" / ("sha256-" + hex)));
+  return true;
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Atomicity hint: no stray .tmp.* under the package root after successful commit
+// ─────────────────────────────────────────────────────────────────────────────
+
+// Walks the committed package and fails if any entry name contains ".tmp.".
+bool test_commit_leaves_no_temp_files() {
+  Sandbox s;
+  s.Write("src_asset/m.onnx", "alpha");
+  PkgHandle p = MakeAuthoredPkgAt(s.path("pkg"));
+  CHECK_OK(ModelPackage_Commit(p.get(), s.path("pkg").c_str(),
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  const char* uri = nullptr;
+  CHECK_OK(ModelPackage_AddSharedAsset(p.get(), s.path("src_asset").c_str(),
+                                       nullptr, true, &uri));
+  (void)uri;
+  CHECK_OK(ModelPackage_SetComponentExternal(p.get(), "decoder", "decoder.json"));
+  CHECK_OK(ModelPackage_Commit(p.get(), nullptr,
+                               MODEL_PACKAGE_WRITE_PRESERVE));
+  std::error_code ec;
+  fs::recursive_directory_iterator walk(s.path("pkg"), ec);
+  // A walk that could not start would otherwise look like "no temp files found".
+  CHECK(!ec);
+  for (const auto& e : walk) {
+    if (e.path().filename().string().find(".tmp.") != std::string::npos) {
+      // path::c_str() is wchar_t* on Windows, so convert before handing it to %s.
+      // The [FAIL] prefix matches the CHECK macros; main() prints nothing on failure.
+      std::fprintf(stderr, "[FAIL] %s: stray temp file: %s\n", g_current,
+                   e.path().string().c_str());
+      return false;
+    }
+  }
+  return true;
+}
+
+// One registered test case: display name plus the function to run.
+struct Test {
+  const char* name;
+  bool (*fn)();
+};
+
+// Execution order for main(); every test defined above is listed exactly once.
+const Test kTests[] = {
+    {"commit_inplace_basic_roundtrip", test_commit_inplace_basic_roundtrip},
+    {"commit_requires_package_root", test_commit_requires_package_root},
+    {"commit_external_component_writes_file", test_commit_external_component_writes_file},
+    {"commit_pending_shared_asset_copy_in", test_commit_pending_shared_asset_copy_in},
+    {"commit_dense_inlines_external_component", test_commit_dense_inlines_external_component},
+    {"commit_dense_rejects_external_executor_info", test_commit_dense_rejects_external_executor_info},
+    {"commit_dest_root_self_contained", test_commit_dest_root_self_contained},
+    {"commit_dest_root_must_be_empty", test_commit_dest_root_must_be_empty},
+    {"commit_dest_root_rehashes_existing_asset", test_commit_dest_root_rehashes_existing_asset},
+    {"prune_never_touches_shared_assets", test_prune_never_touches_shared_assets},
+    {"prune_reclaims_tracked_orphan_variant_dirs", test_prune_reclaims_tracked_orphan_variant_dirs},
+    {"prune_removes_stale_staging_dirs", test_prune_removes_stale_staging_dirs},
+    {"validate_all_clean_package", test_validate_all_clean_package},
+    {"validate_paths_flags_missing_external", test_validate_paths_flags_missing_external},
+    {"validate_asset_rehash_detects_mutation", test_validate_asset_rehash_detects_mutation},
+    {"commit_accepts_unreferenced_shared_asset", test_commit_accepts_unreferenced_shared_asset},
+    {"commit_leaves_no_temp_files", test_commit_leaves_no_temp_files},
+};
+
+}  // namespace
+
+// Runs every entry of kTests in order, prints a [PASS] line for each success
+// and a final tally; the exit code is non-zero when any test failed.
+int main() {
+  for (const auto& t : kTests) {
+    g_current = t.name;
+    bool ok = t.fn();
+    if (ok) {
+      std::printf("[PASS] %s\n", t.name);
+      g_passed++;
+    } else {
+      g_failed++;
+    }
+  }
+  std::printf("\n=== %d passed, %d failed ===\n", g_passed, g_failed);
+  return g_failed == 0 ? 0 : 1;
+}
diff --git a/model_package/tests/test_inspection.cc b/model_package/tests/test_inspection.cc
new file mode 100644
index 0000000000000..0b51681bc7c80
--- /dev/null
+++ b/model_package/tests/test_inspection.cc
@@ -0,0 +1,582 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +/// \file test_inspection.cc +/// \brief Tests for the read-only inspection API (model_package.h). + +#include "model_package.h" +#include "model_package_api.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace { + +int g_failed = 0; +int g_passed = 0; +const char* g_current = ""; + +#define CHECK(cond) \ + do { \ + if (!(cond)) { \ + std::fprintf(stderr, "[FAIL] %s line %d: CHECK(%s)\n", g_current, __LINE__, #cond); \ + return false; \ + } \ + } while (0) + +#define CHECK_OK(status) \ + do { \ + ModelPackageStatus* _s = (status); \ + if (_s != nullptr) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected OK, got: %s\n", \ + g_current, __LINE__, ModelPackageStatus_Message(_s)); \ + ModelPackageStatus_Release(_s); \ + return false; \ + } \ + } while (0) + +#define CHECK_ERR(status, expected_code) \ + do { \ + ModelPackageStatus* _s = (status); \ + if (_s == nullptr) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got OK\n", \ + g_current, __LINE__, (int)(expected_code)); \ + return false; \ + } \ + ModelPackageErrorCode _c = ModelPackageStatus_Code(_s); \ + if (_c != (expected_code)) { \ + std::fprintf(stderr, "[FAIL] %s line %d: expected error %d, got %d: %s\n", \ + g_current, __LINE__, (int)(expected_code), (int)_c, \ + ModelPackageStatus_Message(_s)); \ + ModelPackageStatus_Release(_s); \ + return false; \ + } \ + ModelPackageStatus_Release(_s); \ + } while (0) + +class Sandbox { + public: + Sandbox() { + std::random_device rd; + std::mt19937_64 g(rd()); + char buf[32]; + std::snprintf(buf, sizeof(buf), "mp_inspect_%016llx", static_cast(g())); + root_ = fs::temp_directory_path() / buf; + fs::create_directories(root_); + } + ~Sandbox() { + std::error_code ec; + fs::remove_all(root_, ec); + } + Sandbox(const Sandbox&) = delete; + Sandbox& operator=(const Sandbox&) = delete; + + const fs::path& root() const { return root_; } + + void Write(const std::string& relpath, const 
std::string& contents) { + fs::path full = root_ / relpath; + fs::create_directories(full.parent_path()); + std::ofstream f(full, std::ios::binary); + f << contents; + } + + void Touch(const std::string& relpath) { Write(relpath, ""); } + + private: + fs::path root_; +}; + +bool test_open_minimal_inline() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "package_name": "test", + "components": { + "alpha": { + "variants": { + "cpu": {} + } + } + } + })"); + + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + CHECK(pkg != nullptr); + + const ModelPackageInfo* info = ModelPackage_Info(pkg); + CHECK(info != nullptr); + CHECK(info->schema_version_major == 1); + CHECK(info->schema_version_minor == 0); + CHECK(std::string(info->package_name) == "test"); + CHECK(std::string(info->layout) == "portable"); + CHECK((info)->num_components == 1); + CHECK((info)->num_shared_assets == 0); + CHECK(info->additional_metadata_json == nullptr); + + const ModelComponentInfo* c = &(info)->components[0]; + CHECK(c != nullptr); + CHECK(std::string(c->name) == "alpha"); + CHECK((c)->num_variants == 1); + + const ModelVariantInfo* v = &(c)->variants[0]; + CHECK(v != nullptr); + CHECK(std::string(v->name) == "cpu"); + CHECK(v->ep == nullptr); + CHECK(v->device == nullptr); + CHECK(v->compatibility_string == nullptr); + + ModelPackage_Close(pkg); + return true; +} + +bool test_open_full_inline_with_metadata() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "package_name": "phi-4", + "package_version": "1.2.3", + "description": "demo", + "layout": "portable", + "additional_metadata": {"author": "team"}, + "components": { + "decoder": { + "additional_metadata": {"size": "small"}, + "variants": { + "cuda_fp16": { + "variant_directory": "decoder/cuda_fp16", + "ep": "CUDAExecutionProvider", + "device": "gpu", + "compatibility_string": "sm_80", + "additional_metadata": {"notes": "quantized"} + } + } + } + } + 
})"); + fs::create_directories(s.root() / "decoder" / "cuda_fp16"); + + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelPackageInfo* info = ModelPackage_Info(pkg); + CHECK(std::string(info->package_name) == "phi-4"); + CHECK(std::string(info->package_version) == "1.2.3"); + CHECK(std::string(info->description) == "demo"); + CHECK(info->additional_metadata_json != nullptr); + CHECK(std::string(info->additional_metadata_json).find("\"author\":\"team\"") != std::string::npos); + + const ModelComponentInfo* c = ModelPackage_FindComponent(info, "decoder"); + CHECK(c != nullptr); + const char* comp_meta = c->additional_metadata_json; + CHECK(comp_meta != nullptr); + CHECK(std::string(comp_meta).find("\"size\":\"small\"") != std::string::npos); + + const ModelVariantInfo* v = ModelComponentInfo_FindVariant(c, "cuda_fp16"); + CHECK(v != nullptr); + CHECK(std::string(v->ep) == "CUDAExecutionProvider"); + CHECK(std::string(v->device) == "gpu"); + CHECK(std::string(v->compatibility_string) == "sm_80"); + const char* var_meta = v->additional_metadata_json; + CHECK(var_meta != nullptr); + CHECK(std::string(var_meta).find("\"notes\":\"quantized\"") != std::string::npos); + + const char* resolved = v->variant_directory; + CHECK(resolved != nullptr); + CHECK(std::string(resolved).find("decoder/cuda_fp16") != std::string::npos); + + ModelPackage_Close(pkg); + return true; +} + +bool test_external_component_file() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "decoder": "components/decoder.json" } + })"); + s.Write("components/decoder.json", R"({ + "variants": { "cpu": {} } + })"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelComponentInfo* c = ModelPackage_FindComponent(ModelPackage_Info(pkg), "decoder"); + CHECK(c != nullptr); + CHECK((c)->num_variants == 1); + ModelPackage_Close(pkg); + return true; +} + +bool 
test_external_component_directory() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "decoder": "components/decoder" } + })"); + s.Write("components/decoder/component.json", R"({ + "variants": { "cpu": {} } + })"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + CHECK((ModelPackage_Info(pkg))->num_components == 1); + ModelPackage_Close(pkg); + return true; +} + +bool test_executor_info_inline_and_external() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { + "decoder": { + "variants": { + "cuda": { + "variant_directory": "v", + "executor_info": { + "ort": "ort_info.json", + "other": {"x": 1} + } + } + } + } + } + })"); + fs::create_directories(s.root() / "v"); + s.Write("v/ort_info.json", R"({"model_file":"model.onnx"})"); + + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelPackageInfo* info = ModelPackage_Info(pkg); + const ModelVariantInfo* v = + ModelComponentInfo_FindVariant(ModelPackage_FindComponent(info, "decoder"), "cuda"); + CHECK(v != nullptr); + + const ModelExecutorInfoEntry* ort_ei = ModelVariantInfo_FindExecutorInfo(v, "ort"); + const char* ort_json = ort_ei ? ort_ei->json : nullptr; + CHECK(ort_json != nullptr); + CHECK(std::string(ort_json).find("model.onnx") != std::string::npos); + + const ModelExecutorInfoEntry* other_ei = ModelVariantInfo_FindExecutorInfo(v, "other"); + const char* other_json = other_ei ? other_ei->json : nullptr; + CHECK(other_json != nullptr); + CHECK(std::string(other_json).find("\"x\":1") != std::string::npos); + + const ModelExecutorInfoEntry* missing_ei = ModelVariantInfo_FindExecutorInfo(v, "absent"); + const char* missing = missing_ei ? 
missing_ei->json : nullptr; + CHECK(missing_ei == nullptr); + CHECK(missing == nullptr); + + ModelPackage_Close(pkg); + return true; +} + +bool test_inline_executor_info_without_directory_accepted() { + // Library no longer requires variant_directory to exist for inline + // executor_info. Executors interpret their own payload. + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { + "decoder": { + "variants": { + "cuda": { + "executor_info": { "other": {"x": 1} } + } + } + } + } + })"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + ModelPackage_Close(pkg); + return true; +} + +bool test_path_confinement_rejects_external_paths() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "x": "../escape.json" } + })"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_PATH_CONFINEMENT); + return true; +} + +bool test_installed_layout_allows_absolute() { + // Build a package whose component lives outside its root. + Sandbox external; + external.Write("decoder.json", R"({"variants": {"cpu": {}}})"); + + Sandbox s; + std::string abs_comp = (external.root() / "decoder.json").string(); + // Escape backslashes for any platform that uses them — POSIX is fine as-is. 
+ s.Write("manifest.json", std::string(R"({ + "schema_version": 1, + "layout": "installed", + "components": {"decoder": ")") + + abs_comp + R"("} + })"); + + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + CHECK((ModelPackage_Info(pkg))->num_components == 1); + ModelPackage_Close(pkg); + return true; +} + +bool test_shared_assets_resolve() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "shared_assets": { + "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa": "assets/a" + }, + "components": { + "x": { + "variants": { + "cpu": {} + } + } + } + })"); + fs::create_directories(s.root() / "assets" / "a"); + // Discovery: an on-disk sha256- dir without an override entry must + // surface alongside the explicit override. + fs::create_directories( + s.root() / "shared_assets" / + "sha256-bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"); + + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + CHECK((ModelPackage_Info(pkg))->num_shared_assets == 2); + + const ModelSharedAssetInfo* a = &(ModelPackage_Info(pkg))->shared_assets[0]; + CHECK(a != nullptr); + CHECK(std::string(a->uri).find("aaaa") != std::string::npos); + CHECK(std::string(a->resolved_path).find("assets/a") != std::string::npos); + + const ModelSharedAssetInfo* b = &(ModelPackage_Info(pkg))->shared_assets[1]; + CHECK(b != nullptr); + CHECK(std::string(b->uri).find("bbbb") != std::string::npos); + // Default convention path: shared_assets/sha256- + CHECK(std::string(b->resolved_path).find("shared_assets/sha256-bb") != std::string::npos); + + // Resolve via API. 
+ const char* path = nullptr; + CHECK_OK(ModelPackage_ResolveAssetUri(pkg, + "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + &path)); + CHECK(std::string(path).find("assets/a") != std::string::npos); + + CHECK_ERR(ModelPackage_ResolveAssetUri(pkg, "sha256:not_a_known_one", &path), + MODEL_PACKAGE_ERR_ASSET_MISSING); + + ModelPackage_Close(pkg); + return true; +} + +bool test_unknown_field_rejected_strict() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "x": {"variants": {"cpu": {"typo_field": 1}}} } + })"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_SCHEMA); + return true; +} + +bool test_unknown_field_tolerated_lenient() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "x": {"variants": {"cpu": {"typo_field": 1}}} } + })"); + ModelPackageOpenOptions opts{}; + opts.strict_unknown_fields = false; + opts.follow_symlinks = true; + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), &opts, &pkg)); + ModelPackage_Close(pkg); + return true; +} + +bool test_round_trip_getters_preserve_order() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "decoder": {"variants": {"cuda": {"ep":"CUDAExecutionProvider","device":"gpu"}}} } + })"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const char* comp_json = nullptr; + CHECK_OK(ModelPackage_GetComponentJson(pkg, "decoder", &comp_json)); + CHECK(comp_json != nullptr); + CHECK(std::string(comp_json).find("\"variants\":") != std::string::npos); + + const char* var_json = nullptr; + CHECK_OK(ModelPackage_GetVariantJson(pkg, "decoder", "cuda", &var_json)); + CHECK(var_json != nullptr); + // "ep" must appear before "device" — ordered_json preserves declaration order. 
 + size_t ep_pos = std::string(var_json).find("\"ep\""); + size_t dev_pos = std::string(var_json).find("\"device\""); + CHECK(ep_pos != std::string::npos && dev_pos != std::string::npos && ep_pos < dev_pos); + ModelPackage_Close(pkg); + return true; +} + +bool test_round_trip_preserves_unknown_fields_lenient() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "components": { "x": {"variants": {"cpu": {"future_field":"keepme"}}} } + })"); + ModelPackageOpenOptions opts{}; + opts.strict_unknown_fields = false; + opts.follow_symlinks = true; + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), &opts, &pkg)); + const char* var_json = nullptr; + CHECK_OK(ModelPackage_GetVariantJson(pkg, "x", "cpu", &var_json)); + CHECK(std::string(var_json).find("future_field") != std::string::npos); + ModelPackage_Close(pkg); + return true; +} + +bool test_missing_manifest() { + Sandbox s; + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_IO); + return true; +} + +bool test_unsupported_schema_version() { + Sandbox s; + s.Write("manifest.json", R"({"schema_version": 99, "components": {}})"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_VERSION); + return true; +} + +bool test_schema_version_string_and_minor() { + // A "<major>.<minor>" string parses into the split major/minor fields. + { + Sandbox s; + s.Write("manifest.json", + R"({"schema_version": "1.0", "components": {"a": {"variants": {"cpu": {}}}}})"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelPackageInfo* info = ModelPackage_Info(pkg); + CHECK(info->schema_version_major == 1); + CHECK(info->schema_version_minor == 0); + ModelPackage_Close(pkg); + } + + // A newer minor than this build knows is accepted, and its unknown additive fields are + // tolerated rather than rejected even under the default strict mode.
+ { + Sandbox s; + s.Write("manifest.json", + R"({"schema_version": "1.7", "some_future_field": true, + "components": {"a": {"variants": {"cpu": {}}}}})"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelPackageInfo* info = ModelPackage_Info(pkg); + CHECK(info->schema_version_major == 1); + CHECK(info->schema_version_minor == 7); + ModelPackage_Close(pkg); + } + + // An unsupported major is rejected regardless of minor. + { + Sandbox s; + s.Write("manifest.json", R"({"schema_version": "2.0", "components": {}})"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_VERSION); + } + + // A malformed schema_version string is a schema error. + { + Sandbox s; + s.Write("manifest.json", R"({"schema_version": "1.x", "components": {}})"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_SCHEMA); + } + return true; +} + +bool test_invalid_sha256_uri_rejected() { + Sandbox s; + s.Write("manifest.json", R"({ + "schema_version": 1, + "shared_assets": { "sha256:notenough": "assets/a" }, + "components": {"x": {"variants": {"cpu": {}}}} + })"); + ModelPackage* pkg = nullptr; + CHECK_ERR(ModelPackage_Open(s.root().c_str(), nullptr, &pkg), MODEL_PACKAGE_ERR_SCHEMA); + return true; +} + +bool test_find_returns_null_on_missing() { + Sandbox s; + s.Write("manifest.json", R"({"schema_version":1,"components":{"a":{"variants":{"cpu":{}}}}})"); + ModelPackage* pkg = nullptr; + CHECK_OK(ModelPackage_Open(s.root().c_str(), nullptr, &pkg)); + const ModelPackageInfo* info = ModelPackage_Info(pkg); + CHECK(ModelPackage_FindComponent(info, "missing") == nullptr); + CHECK(ModelComponentInfo_FindVariant(ModelPackage_FindComponent(info, "a"), "missing") == nullptr); + ModelPackage_Close(pkg); + return true; +} + +struct Test { + const char* name; + bool (*fn)(); +}; + +const Test kTests[] = { + 
{"open_minimal_inline", test_open_minimal_inline}, + {"open_full_inline_with_metadata", test_open_full_inline_with_metadata}, + {"external_component_file", test_external_component_file}, + {"external_component_directory", test_external_component_directory}, + {"executor_info_inline_and_external", test_executor_info_inline_and_external}, + {"inline_executor_info_without_directory_accepted", + test_inline_executor_info_without_directory_accepted}, + {"path_confinement_rejects_external_paths", test_path_confinement_rejects_external_paths}, + {"installed_layout_allows_absolute", test_installed_layout_allows_absolute}, + {"shared_assets_resolve", test_shared_assets_resolve}, + {"unknown_field_rejected_strict", test_unknown_field_rejected_strict}, + {"unknown_field_tolerated_lenient", test_unknown_field_tolerated_lenient}, + {"round_trip_getters_preserve_order", test_round_trip_getters_preserve_order}, + {"round_trip_preserves_unknown_fields_lenient", + test_round_trip_preserves_unknown_fields_lenient}, + {"missing_manifest", test_missing_manifest}, + {"unsupported_schema_version", test_unsupported_schema_version}, + {"schema_version_string_and_minor", test_schema_version_string_and_minor}, + {"invalid_sha256_uri_rejected", test_invalid_sha256_uri_rejected}, + {"find_returns_null_on_missing", test_find_returns_null_on_missing}, +}; + +} // namespace + +int main() { + for (const auto& t : kTests) { + g_current = t.name; + bool ok = t.fn(); + if (ok) { + std::printf("[PASS] %s\n", t.name); + g_passed++; + } else { + g_failed++; + } + } + std::printf("\n=== %d passed, %d failed ===\n", g_passed, g_failed); + return g_failed == 0 ? 
0 : 1; +} diff --git a/onnxruntime/core/session/model_package/README.md b/onnxruntime/core/session/model_package/README.md new file mode 100644 index 0000000000000..c4919219d7d40 --- /dev/null +++ b/onnxruntime/core/session/model_package/README.md @@ -0,0 +1,246 @@ +# ORT Model Package Integration + +This directory implements ONNX Runtime's consumer-side glue for the +standalone [`model_package` library](../../../../model_package/README.md): +loading packages, selecting variants against the runtime's execution +providers, and creating an `OrtSession` for the chosen variant. + +The package format, manifest schema, shared-asset rules, and the C +authoring/inspection API all live in `model_package/`. **This directory +adds three things on top**: + +1. The `executor_info["ort"]` payload schema (this is ORT's slot in the + variant body). +2. The variant selection algorithm, which queries each execution provider + factory and picks the highest-scoring variant. +3. The experimental `OrtModelPackageApi_*` C functions that wrap the library + and expose session creation. They are registered in + `include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc` and + resolved by name through `OrtApi::GetExperimentalFunction`. + +ORT links the `model_package` library as a static archive; the library +itself never links against ORT. + +--- + +## Files + +| File | Responsibility | +| ------------------------------------- | -------------- | +| `model_package_context.h/.cc` | Translates the `model_package` library's C info tree into ORT-internal C++ structs (`ModelPackageInfo`, `ComponentInfo`, `VariantInfo`, `VariantModelInfo`). Parses the `executor_info["ort"]` payload. Owns `ModelPackageContext` (package-level) and `ModelPackageComponentContext` (per-component, with selected variant and provider list). 
| +| `model_package_options.h/.cc` | `ModelPackageOptions` snapshots EP intent (factories, devices, EP-name list) from an `OrtSessionOptions` at the moment `OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28` is called. Drives variant selection and provider construction. | +| `model_package_variant_selector.h/.cc`| `VariantSelector::SelectVariant` picks the best variant from a component given the EP list. Uses `OrtEpFactory::ValidateCompiledModelCompatibilityInfo`. | + +The C entry points themselves live in +`onnxruntime/core/session/model_package_api.cc` under +`namespace OrtExperimentalApis`. + +--- + +## `executor_info["ort"]` schema + +ORT's slot in `variant.executor_info` is a JSON object. All fields are +optional, but in practice `model_file` is required to load a session. + +```jsonc +{ + "model_file": "model.onnx", + "external_data": "weights", + "session_options": { "session.intra_op_thread_count": "4" }, + "provider_options": { "device_id": "0" } +} +``` + +| Field | Type | Required | Notes | +| ------------------ | ------ | -------- | ----- | +| `model_file` | string | yes (for session) | Path to the model file inside the variant. Resolved via `ModelPackage_ResolveStringRef`, anchored at the variant directory. Accepts relative paths, absolute paths or `..` segments (installed layout only), and `sha256:<hex>[/sub/path]` for shared-asset content. | +| `external_data` | string | no | Folder containing the model's external-initializers blobs. Wired into the session as ORT's external-initializers folder hint. Same resolution rules as `model_file`. | +| `session_options` | object | no | Map of `string -> string`. Merged on top of a fresh `OrtSessionOptions` when the caller passes `session_options == NULL` to `CreateSession`. Ignored when the caller supplies their own `OrtSessionOptions`. | +| `provider_options` | object | no | Map of `string -> string`. Merged into the variant's EP provider options on the default path.
Ignored when the caller supplies their own `OrtSessionOptions`. | + +#### Inline vs external + +The slot follows the standard `executor_info` shape: the value may be either + +- a **string**, a path to a JSON file containing the body above (commonly + `ort_info.json` next to `model.onnx`), or +- an **object**, the body inlined into `component.json` / + `manifest.json`. + +Inline form keeps the package single-file. External form (the common case) +keeps the variant directory self-describing and survives `executor_info` +schema evolution without rewriting the manifest. + +The key under `executor_info` is the **executor namespace name** (`"ort"`), +not the EP. Other consumers use their own namespace key, so a single +variant can carry per-consumer payloads side by side. + +--- + +## Variant selection + +`ModelPackageOptions(env, session_options)` captures the **EP intent**: the +ordered list of execution providers registered on the session options, plus +their associated `OrtEpDevice` / `OrtHardwareDevice` / metadata. + +`VariantSelector::SelectVariant(component, ep_infos, &selected)` then walks +the component's variants and picks the best match: + +1. Use only the **first** EP from the captured list. (A policy may rank + several EPs; callers that need a specific EP should put it first. + Ranking across the full EP list is on the TODO list.) +2. For each variant, require `variant.ep == ep_info.ep_name`. +3. If `variant.device` is set (`"cpu"` / `"gpu"` / `"npu"`), require it to + match at least one of the EP's `OrtHardwareDevice` entries. +4. If both pass, call `OrtEpFactory::ValidateCompiledModelCompatibilityInfo` + with `variant.compatibility_string`. 
The EP returns an + `OrtCompiledModelCompatibility` enum which maps to a score: + + | Enum | Score | + | -------------------------------------------- | ----- | + | `EP_SUPPORTED_OPTIMAL` | 100 | + | `EP_SUPPORTED_PREFER_RECOMPILATION` | 50 | + | `EP_NOT_APPLICABLE` (or EP too old / no ABI) | 0 | + | `EP_UNSUPPORTED` | rejected | + +5. Pick the highest-scoring matching variant. Manifest declaration order + breaks ties. + +If no variant matches, `SelectComponent` fails with "No suitable model +variant found for the configured execution providers." + +ORT does **not** parse `compatibility_string`. The EP owns the format and +may encode multiple sub-targets (SoC ids, ISA flags, etc.) into the single +string internally; ORT only round-trips it through the EP callback. + +--- + +## Session creation contract + +`OrtModelPackageApi_CreateSession_SinceV28(env, component_ctx, session_options, &session)`. + +The `component_ctx` already knows which variant won selection and which +provider list it should use. Two paths: + +- **`session_options == NULL` (default).** ORT starts from a fresh + `OrtSessionOptions` and merges the variant's `session_options` / + `provider_options` from `executor_info["ort"]` on top. EPs declared in the + manifest are constructed and registered. This is what nearly all callers + want. + +- **`session_options != NULL` (advanced).** ORT uses the caller-supplied + `OrtSessionOptions` as-is. The manifest's `session_options` and + `provider_options` are **not** merged. Use this when you need custom EP + setup that does not round-trip through string options (shared CUDA + streams, shared QNN EP contexts, custom allocators, ...). The + `OrtSessionOptions` passed earlier to + `CreateModelPackageOptionsFromSessionOptions` only drives variant + selection / EP discovery; it is never silently re-applied here. 
 + +In both modes, `external_data` from `executor_info["ort"]` is wired in as +ORT's external-initializers folder hint, so the model file can reference +weights stored next to (or shared by) the package. + +--- + +## C API surface + +The model package API is exposed via ONNX Runtime's +[experimental C API](../../../../docs/design/Experimental_C_API.md). Each +function is registered as a separate entry in +`include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc` with +prefix `OrtModelPackageApi_` and version suffix `_SinceV28`. Consumers look +the functions up by name through `OrtApi::GetExperimentalFunction`, either +directly or via the typed C++ accessors in `Ort::Experimental::*` generated +from `onnxruntime_experimental_c_api.h`. + +The opaque handle types (`OrtModelPackageOptions`, `OrtModelPackageContext`, +`OrtModelPackageComponentContext`) are forward-declared at the top of +`onnxruntime_experimental_c_api.h`. + +Registered entries: + +| Function | Notes | +| ----------------------------------------------------- | ----- | +| `CreateModelPackageOptionsFromSessionOptions` | Snapshots EP intent. | +| `ReleaseModelPackageOptions` | | +| `CreateModelPackageContext` | Parses the manifest. | +| `ReleaseModelPackageContext` | | +| `ModelPackage_GetSchemaVersion` | | +| `ModelPackage_GetComponentCount` | | +| `ModelPackage_GetComponentNames` | | +| `ModelPackage_GetVariantCount` | | +| `ModelPackage_GetVariantNames` | | +| `ModelPackage_GetVariantEpName` | | +| `ModelPackage_ResolveStringRef` | Resolves a relative path or `sha256:` reference declared in the package; the result is valid until the next call on the same context. | +| `SelectComponent` | Resolves the best-matching variant. | +| `ReleaseModelPackageComponentContext` | | +| `ModelPackageComponent_GetSelectedVariantName` | | +| `ModelPackageComponent_GetSelectedVariantFolderPath` | | +| `CreateSession` | | + +> Experimental functions are not part of the stable ABI. Names, signatures +> and behaviour may change between releases until the surface is promoted +> to the stable `OrtApi`. Callers should null-check every lookup.
+ +Typical flow: + +```cpp +#include "onnxruntime_c_api.h" +#include "onnxruntime_experimental_c_api.h" + +const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + +auto fn_create_opts = + Ort::Experimental::Get_OrtModelPackageApi_CreateModelPackageOptionsFromSessionOptions_SinceV28_Fn(ort); +auto fn_release_opts = + Ort::Experimental::Get_OrtModelPackageApi_ReleaseModelPackageOptions_SinceV28_Fn(ort); +auto fn_create_ctx = + Ort::Experimental::Get_OrtModelPackageApi_CreateModelPackageContext_SinceV28_Fn(ort); +auto fn_release_ctx = + Ort::Experimental::Get_OrtModelPackageApi_ReleaseModelPackageContext_SinceV28_Fn(ort); +auto fn_select = + Ort::Experimental::Get_OrtModelPackageApi_SelectComponent_SinceV28_Fn(ort); +auto fn_release_comp = + Ort::Experimental::Get_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn(ort); +auto fn_create_session = + Ort::Experimental::Get_OrtModelPackageApi_CreateSession_SinceV28_Fn(ort); + +OrtSessionOptions* so = nullptr; +ort->CreateSessionOptions(&so); +ort->SessionOptionsAppendExecutionProvider(so, "CUDAExecutionProvider", nullptr, nullptr, 0); + +OrtModelPackageOptions* mp_opts = nullptr; +fn_create_opts(env, so, &mp_opts); + +OrtModelPackageContext* ctx = nullptr; +fn_create_ctx(ORT_TSTR("/path/to/pkg"), &ctx); + +OrtModelPackageComponentContext* comp_ctx = nullptr; +fn_select(ctx, "decoder", mp_opts, &comp_ctx); + +OrtSession* session = nullptr; +fn_create_session(env, comp_ctx, nullptr, &session); + +ort->ReleaseSession(session); +fn_release_comp(comp_ctx); +fn_release_ctx(ctx); +fn_release_opts(mp_opts); +ort->ReleaseSessionOptions(so); +``` + +All `const char*` / `const ORTCHAR_T*` / array pointers returned by the API +are owned by the context that produced them and remain valid until the +context is released. 
+ +--- + +## See also + +- [`model_package/README.md`](../../../../model_package/README.md): package + format, manifest/component schema, shared assets, path resolution, the + authoring C API, and the `executor_info` extension point. +- [`docs/design/Experimental_C_API.md`](../../../../docs/design/Experimental_C_API.md): + design and lifecycle rules for the experimental C API mechanism that + hosts these entries. +- `include/onnxruntime/core/session/onnxruntime_experimental_c_api.inc`: + the canonical list of `OrtModelPackageApi_*` entries. diff --git a/onnxruntime/core/session/model_package/model_package_context.cc b/onnxruntime/core/session/model_package/model_package_context.cc index ca4adb9c877a5..a0da46a10f88f 100644 --- a/onnxruntime/core/session/model_package/model_package_context.cc +++ b/onnxruntime/core/session/model_package/model_package_context.cc @@ -5,9 +5,7 @@ #include #include -#include #include -#include #include #include @@ -22,15 +20,23 @@ #include "core/session/provider_policy_context.h" #include "core/session/utils.h" -// We intentionally use the standalone model_package library's internal C++ types directly -// (model_package::ParsePackage, model_package_internal.h) rather than its public C API -// (ModelPackage_* functions). This avoids double-wrapping since ORT compiles the library in-tree. -// The public C API exists for external consumers (GenAI, FL) who link independently. -#include "model_package_internal.h" -#include "parser.h" +// Use the standalone model_package library's public C API. The library has no ORT +// dependency; ORT links it as a static archive (see cmake/onnxruntime_session.cmake) +// and translates the C handles into the ORT-internal C++ types defined in +// model_package_context.h here. +#include "model_package.h" namespace onnxruntime { +namespace { +// Deleter for the type-erased model_package handle held by ModelPackageContext. 
+void CloseModelPackageHandle(void* handle) { + if (handle != nullptr) { + ::ModelPackage_Close(static_cast<::ModelPackage*>(handle)); + } +} +} // namespace + namespace { Status FillOptionCachesFromMap( @@ -150,10 +156,9 @@ Status ModelPackageComponentContext::GetSelectedVariantFilePath(std::filesystem: "Selected variant index out of range for component: ", component_model_name_); const auto& selected_variant = component_model_info_.variants[selected_idx]; - ORT_RETURN_IF(!selected_variant.file.has_value(), + ORT_RETURN_IF(!selected_variant.file.has_value() || selected_variant.file->identifier.empty(), "Selected variant '", selected_variant.variant_name, - "' does not have a variant.json descriptor (or it lacks a 'filename' entry). " - "Component: ", + "' has no executor_info[\"ort\"][\"model_file\"]. Component: ", component_model_name_); out_path = selected_variant.file->model_file_path; @@ -345,54 +350,172 @@ Status ModelPackageComponentContext::GetSelectedVariantName(const std::string*& return Status::OK(); } -ModelPackageContext::ModelPackageContext(const std::filesystem::path& package_root) { - // Use the standalone model_package library for parsing. 
- model_package::PackageInfo pkg_info; - std::string error; - if (!model_package::ParsePackage(package_root, pkg_info, error)) { - ORT_THROW("Failed to parse model package: ", error); +Status ModelPackageComponentContext::GetSelectedVariantExternalDataFolder( + const std::string*& out_folder) const { + out_folder = nullptr; + const VariantInfo* selected_variant = nullptr; + ORT_RETURN_IF_ERROR(GetSelectedVariantInfo(selected_variant)); + ORT_RETURN_IF(selected_variant == nullptr, + "Selected variant is null for component: ", component_model_name_); + if (selected_variant->file.has_value() && + selected_variant->file->external_data_folder_path.has_value() && + !selected_variant->file->external_data_folder_path->empty()) { + out_folder = &(*selected_variant->file->external_data_folder_path); + } + return Status::OK(); +} + +ModelPackageContext::ModelPackageContext(const std::filesystem::path& package_root) + : package_handle_(nullptr, &CloseModelPackageHandle), package_root_(package_root) { + // Open the package via the model_package C API and keep the handle open for this context's + // lifetime (owned by package_handle_) so path references can be resolved later without + // reopening. The unique_ptr releases the handle even on exception paths during conversion. + ::ModelPackage* pkg = nullptr; + if (::ModelPackageStatus* st = ::ModelPackage_Open(package_root.string().c_str(), nullptr, &pkg)) { + std::string msg = ::ModelPackageStatus_Message(st) ? ::ModelPackageStatus_Message(st) : "unknown error"; + ::ModelPackageStatus_Release(st); + ORT_THROW("Failed to open model package at '", package_root.string(), "': ", msg); } + package_handle_.reset(pkg); - // Convert standalone library types to ORT internal types. - model_package_info_.schema_version = pkg_info.schema_version; + const ::ModelPackageInfo* pkg_info = ::ModelPackage_Info(pkg); + model_package_info_.schema_version = pkg_info ? 
pkg_info->schema_version_major : 0; model_package_info_.components.clear(); component_name_to_index_.clear(); - for (const auto& component : pkg_info.components) { - const auto& name = component.name; - size_t component_idx = model_package_info_.components.size(); - component_name_to_index_[name] = component_idx; + const size_t component_count = pkg_info ? pkg_info->num_components : 0; + for (size_t ci = 0; ci < component_count; ++ci) { + const ::ModelComponentInfo* component = &pkg_info->components[ci]; + + std::string component_name = component->name ? component->name : ""; + const size_t component_idx = model_package_info_.components.size(); + component_name_to_index_[component_name] = component_idx; ComponentInfo ort_component{}; - ort_component.component_name = name; + ort_component.component_name = component_name; ort_component.selected_variant_index.reset(); - for (const auto& variant : component.variants) { + const size_t variant_count = component->num_variants; + for (size_t vi = 0; vi < variant_count; ++vi) { + const ::ModelVariantInfo* variant = &component->variants[vi]; + VariantInfo ort_variant{}; - ort_variant.component_name = name; - ort_variant.variant_name = variant.name; - ort_variant.folder_path = variant.folder_path; - - // Convert EP compatibility (single entry per variant). - ort_variant.ep_compatibility.ep = variant.ep_compatibility.ep; - ort_variant.ep_compatibility.device = variant.ep_compatibility.device; - ort_variant.ep_compatibility.compatibility_string = variant.ep_compatibility.compatibility_string; - ort_variant.ep_compatibility.compiled_model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; - - // Convert file entry (single file per variant). - if (variant.file.has_value()) { + ort_variant.component_name = component_name; + ort_variant.variant_name = variant->name ? variant->name : ""; + + // Resolve the variant directory. 
Absence is treated as a soft signal; + // downstream callers that require a directory surface a clearer error + // at the point of use. + if (variant->variant_directory != nullptr) { + ort_variant.folder_path = std::filesystem::path(variant->variant_directory); + } + + // EP compatibility (single entry per variant). + if (variant->ep != nullptr) ort_variant.ep_compatibility.ep = std::string(variant->ep); + if (variant->device != nullptr) ort_variant.ep_compatibility.device = std::string(variant->device); + if (variant->compatibility_string != nullptr) + ort_variant.ep_compatibility.compatibility_string = std::string(variant->compatibility_string); + ort_variant.ep_compatibility.compiled_model_compatibility = + OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + + // Resolve the ORT executor_info from the manifest. + std::optional ort_obj; + if (const ::ModelExecutorInfoEntry* ei = + ::ModelVariantInfo_FindExecutorInfo(variant, "ort")) { + if (ei->json != nullptr && ei->json[0] != '\0') { + try { + ort_obj = json::parse(ei->json); + } catch (const std::exception& e) { + ORT_THROW("Failed to parse executor_info[\"ort\"] JSON for variant '", + ort_variant.variant_name, "' in component '", component_name, "': ", e.what()); + } + } + } + + if (ort_obj.has_value()) { + if (!ort_obj->is_object()) { + ORT_THROW("ORT variant configuration must be a JSON object for variant '", + ort_variant.variant_name, "' in component '", component_name, "'"); + } + VariantModelInfo ort_file{}; - ort_file.identifier = variant.file->filename; - ort_file.model_file_path = variant.file->resolved_path; - ort_file.session_options = variant.file->session_options; - ort_file.provider_options = variant.file->provider_options; - ort_file.shared_files = variant.file->shared_files; - ort_variant.file = std::move(ort_file); + + // Common resolver for ORT-side string refs (model_file, external_data). 
+ // Delegates to ModelPackage_ResolveStringRef so accepted forms (relative, + // absolute, '..', sha256: URI, sha256: URI + subpath) and portable/installed + // confinement match the rest of the model_package library. + const std::string base_dir_str = ort_variant.folder_path.string(); + const char* base_dir = base_dir_str.empty() ? nullptr : base_dir_str.c_str(); + auto resolve_string_ref = [&](const char* field, const std::string& input, + bool must_exist) -> std::string { + const char* resolved = nullptr; + if (::ModelPackageStatus* st = ::ModelPackage_ResolveStringRef( + pkg, base_dir, input.c_str(), must_exist, &resolved)) { + std::string msg = ::ModelPackageStatus_Message(st) ? ::ModelPackageStatus_Message(st) + : "unknown error"; + ::ModelPackageStatus_Release(st); + ORT_THROW("Failed to resolve ORT variant '", field, "' = '", input, "' for variant '", + ort_variant.variant_name, "' in component '", component_name, "': ", msg); + } + return resolved ? std::string(resolved) : std::string{}; + }; + + if (auto it = ort_obj->find("model_file"); it != ort_obj->end()) { + if (!it->is_string()) { + ORT_THROW("ORT variant configuration: model_file must be a string for variant '", + ort_variant.variant_name, "' in component '", component_name, "'"); + } + const std::string model_file = it->get(); + ort_file.identifier = model_file; + ort_file.model_file_path = resolve_string_ref("model_file", model_file, + /*must_exist=*/false); + } + + auto fill_string_map = [&](const char* key, + std::optional>& dest) { + auto it = ort_obj->find(key); + if (it == ort_obj->end()) return; + if (!it->is_object()) { + ORT_THROW("ORT variant configuration: '", key, "' must be a JSON object for variant '", + ort_variant.variant_name, "' in component '", component_name, "'"); + } + std::unordered_map out; + out.reserve(it->size()); + for (auto kv = it->begin(); kv != it->end(); ++kv) { + if (!kv.value().is_string()) { + ORT_THROW("ORT variant configuration: '", key, "' entries must be 
strings for variant '", + ort_variant.variant_name, "' in component '", component_name, "'"); + } + out.emplace(kv.key(), kv.value().get()); + } + dest = std::move(out); + }; + fill_string_map("session_options", ort_file.session_options); + fill_string_map("provider_options", ort_file.provider_options); + + if (auto it = ort_obj->find("external_data"); it != ort_obj->end()) { + if (!it->is_string()) { + ORT_THROW("ORT variant configuration: external_data must be a string for variant '", + ort_variant.variant_name, "' in component '", component_name, "'"); + } + ort_file.external_data_folder_path = resolve_string_ref( + "external_data", it->get(), /*must_exist=*/false); + } + + if (!ort_file.identifier.empty() || ort_file.session_options.has_value() || + ort_file.provider_options.has_value() || ort_file.external_data_folder_path.has_value()) { + ort_variant.file = std::move(ort_file); + } } - // Consumer metadata. - if (variant.consumer_metadata_json.has_value()) { - ort_variant.consumer_metadata = nlohmann::json::parse(*variant.consumer_metadata_json); + // Variant-scope additional_metadata. + if (variant->additional_metadata_json != nullptr) { + try { + ort_variant.consumer_metadata = json::parse(variant->additional_metadata_json); + } catch (const std::exception& e) { + ORT_THROW("Failed to parse additional_metadata JSON for variant '", ort_variant.variant_name, + "' in component '", component_name, "': ", e.what()); + } } model_variant_infos_.push_back(ort_variant); @@ -405,7 +528,6 @@ ModelPackageContext::ModelPackageContext(const std::filesystem::path& package_ro // Create component names cache for quick lookup. 
component_names_cache_.clear(); component_names_cache_.reserve(model_package_info_.components.size()); - for (const auto& component : model_package_info_.components) { component_names_cache_.push_back(component.component_name); } @@ -415,6 +537,29 @@ size_t ModelPackageContext::GetComponentCount() const noexcept { return model_package_info_.components.size(); } +Status ModelPackageContext::ResolveStringRef(const std::string& base_dir, + const std::string& input, + bool must_exist, + const char*& out_path) const { + out_path = nullptr; + auto* pkg = static_cast<::ModelPackage*>(package_handle_.get()); + if (pkg == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model package handle is not open"); + } + const char* resolved = nullptr; + if (::ModelPackageStatus* st = ::ModelPackage_ResolveStringRef( + pkg, base_dir.empty() ? nullptr : base_dir.c_str(), input.c_str(), must_exist, &resolved)) { + std::string msg = ::ModelPackageStatus_Message(st) ? ::ModelPackageStatus_Message(st) : "unknown error"; + ::ModelPackageStatus_Release(st); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to resolve '", input, "' in model package: ", msg); + } + // Copy out of the library's transient buffer into a context-owned cache so the returned + // pointer stays valid until the next ResolveStringRef call. + resolve_string_ref_cache_ = resolved ? 
resolved : ""; + out_path = resolve_string_ref_cache_.c_str(); + return Status::OK(); +} + Status ModelPackageContext::GetComponentNames(gsl::span& out_names) const { out_names = gsl::span(component_names_cache_.data(), component_names_cache_.size()); diff --git a/onnxruntime/core/session/model_package/model_package_context.h b/onnxruntime/core/session/model_package/model_package_context.h index 1ecfffd7f74e0..a5ed5a04e917c 100644 --- a/onnxruntime/core/session/model_package/model_package_context.h +++ b/onnxruntime/core/session/model_package/model_package_context.h @@ -39,7 +39,11 @@ struct VariantModelInfo { // from variant.json file entry std::optional> session_options; std::optional> provider_options; - std::optional> shared_files; // logical_name -> checksum/path + + // Resolved folder containing the model's external initializer file, when + // executor_info.ort.external_data was set (path or sha256: URI). Empty + // otherwise. Used as the ORT external-initializers folder hint. + std::optional external_data_folder_path; }; // variant-level info (metadata.json + variant.json) @@ -124,6 +128,10 @@ class ModelPackageComponentContext { Status GetSelectedVariantName(const std::string*& out_name) const; + // Returns the resolved external_data folder for the selected variant, or + // nullptr-on-success if none was declared. Borrowed from VariantModelInfo. + Status GetSelectedVariantExternalDataFolder(const std::string*& out_folder) const; + std::vector>& MutableProviderList() { return provider_list_; } const std::vector& ExecutionDevices() const { return execution_devices_; } const std::vector& DevicesSelected() const { return devices_selected_; } @@ -184,7 +192,7 @@ class ModelPackageContext { gsl::span& out_variant_names) const; // Get the EP compatibility info declared on a variant. - // Lets callers (e.g. 
GenAI defaulting logic) inspect what EP a variant targets + // Lets callers inspect what EP a variant targets // before any EP has been resolved / before SelectComponent has been called. Status GetVariantEpCompatibility(const std::string& component_name, const std::string& variant_name, @@ -198,7 +206,25 @@ class ModelPackageContext { return model_variant_infos_; } + // Resolves a path reference from the package against the model_package library's rules: + // a "sha256:<digest>[/tail]" content-addressed shared-asset reference (honoring manifest + // overrides), or a plain relative path resolved against `base_dir` (empty base_dir falls + // back to the package root). When `must_exist` is true the resolved path must exist on + // disk. The returned pointer is owned by this context and stays valid until the next + // ResolveStringRef call. The underlying package handle is kept open for the context's + // lifetime so no reopen/reparse happens per call. + Status ResolveStringRef(const std::string& base_dir, const std::string& input, + bool must_exist, const char*& out_path) const; + private: + // The open model_package library handle, kept alive for this context's lifetime so path + // references can be resolved on demand. Stored type-erased (void*) to keep the + // model_package C header out of this ORT header; the deleter defined in the .cc closes it + // via ModelPackage_Close. 
+ std::unique_ptr package_handle_; + std::filesystem::path package_root_{}; + mutable std::string resolve_string_ref_cache_{}; + ModelPackageInfo model_package_info_{}; std::vector model_variant_infos_; diff --git a/onnxruntime/core/session/model_package_api.cc b/onnxruntime/core/session/model_package_api.cc index 5fd25f9511eb9..aeae2ec1855d3 100644 --- a/onnxruntime/core/session/model_package_api.cc +++ b/onnxruntime/core/session/model_package_api.cc @@ -14,7 +14,6 @@ #include "core/session/model_package/model_package_context.h" #include "core/session/model_package/model_package_options.h" #include "core/session/utils.h" - #endif using namespace onnxruntime; @@ -400,13 +399,13 @@ ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28, const onnxruntime::VariantEpCompatibilityInfo* info = nullptr; auto status = reinterpret_cast(ctx)->GetVariantEpCompatibility( component_name, variant_name, info); + if (!status.IsOK()) { + if (out_ep != nullptr) *out_ep = nullptr; + return onnxruntime::ToOrtStatus(status); + } if (out_ep != nullptr) { - if (status.IsOK() && info != nullptr && info->ep.has_value()) { - *out_ep = info->ep->c_str(); - } else { - *out_ep = nullptr; - } + *out_ep = (info != nullptr && info->ep.has_value()) ? 
info->ep->c_str() : nullptr; } return nullptr; #else @@ -419,6 +418,39 @@ ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_ResolveStringRef_SinceV28, + _In_ const OrtModelPackageContext* ctx, + _In_opt_ const char* base_dir, + _In_ const char* input, + _In_ int must_exist, + _Outptr_ const char** out_path) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (ctx == nullptr || input == nullptr || out_path == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ctx, input, and out_path must be non-null"); + } + *out_path = nullptr; + + const char* resolved = nullptr; + auto status = reinterpret_cast(ctx)->ResolveStringRef( + base_dir != nullptr ? std::string(base_dir) : std::string{}, std::string(input), + must_exist != 0, resolved); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } + *out_path = resolved; + return nullptr; +#else + ORT_UNUSED_PARAMETER(ctx); + ORT_UNUSED_PARAMETER(base_dir); + ORT_UNUSED_PARAMETER(input); + ORT_UNUSED_PARAMETER(must_exist); + ORT_UNUSED_PARAMETER(out_path); + RETURN_NOT_IMPL_IN_MINIMAL_BUILD(); +#endif + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtModelPackageApi_ModelPackage_GetSchemaVersion_SinceV28, _In_ const OrtModelPackageContext* ctx, _Out_ int64_t* out_version) { diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 330974aeed8d8..d196221ec55dc 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -981,16 +981,71 @@ OrtStatus* CreateSessionForModelPackage(_In_ const OrtSessionOptions* options, const std::filesystem::path& selected_model_path, onnxruntime::ModelPackageComponentContext& model_package_context, std::unique_ptr& sess) { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadSingleModelImpl(options, env, - selected_model_path.c_str(), - /*model_data*/ nullptr, - /*model_data_length*/ 0, - sess)); + // When the variant 
declares an external_data folder (e.g. a shared asset + // under <package_root>/shared_assets/sha256-<digest>/) we must switch to + // buffer load: ORT only honors session.model_external_initializers_file_folder_path + // when model_location_ is empty (see inference_session.cc). The mmap'd + // model buffer can be released right after Load; external initializers + // are read from the folder hint during Initialize. + const std::string* external_data_folder = nullptr; + ORT_API_RETURN_IF_STATUS_NOT_OK( + model_package_context.GetSelectedVariantExternalDataFolder(external_data_folder)); + + std::unique_ptr cloned_options; + const OrtSessionOptions* options_to_use = options; + onnxruntime::Env::MappedMemoryPtr mapped_model; + const void* model_data = nullptr; + size_t model_data_length = 0; + + if (external_data_folder != nullptr) { + cloned_options = options ? std::make_unique(*options) + : std::make_unique(); + ORT_API_RETURN_IF_STATUS_NOT_OK( + cloned_options->value.config_options.AddConfigEntry( + kOrtSessionOptionsModelExternalInitializersFileFolderPath, + external_data_folder->c_str())); + options_to_use = cloned_options.get(); + + size_t model_file_length = 0; + ORT_API_RETURN_IF_STATUS_NOT_OK( + onnxruntime::Env::Default().GetFileLength(selected_model_path.c_str(), model_file_length)); + if (model_file_length == 0) { + return OrtApis::CreateStatus( + ORT_FAIL, + ("model_package: selected variant model file is empty: " + selected_model_path.string()).c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK( + onnxruntime::Env::Default().MapFileIntoMemory(selected_model_path.c_str(), + /*offset=*/0, + model_file_length, + mapped_model)); + model_data = mapped_model.get(); + model_data_length = model_file_length; + } else if (options_to_use == nullptr) { + // No external_data and caller did not pass options: synthesize a default + // OrtSessionOptions so the downstream *options_to_use dereferences are safe. 
+ cloned_options = std::make_unique(); + options_to_use = cloned_options.get(); + } + + if (model_data != nullptr) { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadSingleModelImpl(options_to_use, env, + /*model_path*/ nullptr, + model_data, + model_data_length, + sess)); + } else { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadSingleModelImpl(options_to_use, env, + selected_model_path.c_str(), + /*model_data*/ nullptr, + /*model_data_length*/ 0, + sess)); + } + mapped_model.reset(); - // Always rebuild providers from the effective session options (which include merged variant - // provider options). Providers created during EP selection used the original session options - // and would not reflect variant-specific provider options. - ORT_API_RETURN_IF_STATUS_NOT_OK(model_package_context.RebuildProviderListForSession(env, *options)); + // Providers were created earlier from the original options; rebuild now so + // any merged variant-specific provider options take effect. + ORT_API_RETURN_IF_STATUS_NOT_OK(model_package_context.RebuildProviderListForSession(env, *options_to_use)); auto& provider_list = model_package_context.MutableProviderList(); @@ -1000,10 +1055,10 @@ OrtStatus* CreateSessionForModelPackage(_In_ const OrtSessionOptions* options, } } - if (model_package_context.IsFromPolicy() && options != nullptr) { + if (model_package_context.IsFromPolicy()) { ProviderPolicyContext provider_policy_context; ORT_API_RETURN_IF_STATUS_NOT_OK(provider_policy_context.LogTelemetry( - *sess, *options, + *sess, *options_to_use, model_package_context.ExecutionDevices(), model_package_context.DevicesSelected())); } diff --git a/onnxruntime/test/autoep/test_model_package.cc b/onnxruntime/test/autoep/test_model_package.cc index ee5c8bb567e1e..8b887b3c41330 100644 --- a/onnxruntime/test/autoep/test_model_package.cc +++ b/onnxruntime/test/autoep/test_model_package.cc @@ -1,15 +1,18 @@ -// Copyright (c) Microsoft Corporation. +// Copyright (c) Microsoft Corporation. 
// Licensed under the MIT License. #include #include #include +#include #include #include #include +#include #include #include "gtest/gtest.h" +#include "nlohmann/json.hpp" #include "core/session/model_package/model_package_context.h" #include "core/session/onnxruntime_experimental_cxx_api.h" @@ -47,6 +50,8 @@ struct ModelPackageFns { ModelPackage_GetVariantNames{nullptr}; OrtExperimental_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_Fn ModelPackage_GetVariantEpName{nullptr}; + OrtExperimental_OrtModelPackageApi_ModelPackage_ResolveStringRef_SinceV28_Fn + ModelPackage_ResolveStringRef{nullptr}; OrtExperimental_OrtModelPackageApi_SelectComponent_SinceV28_Fn SelectComponent{nullptr}; OrtExperimental_OrtModelPackageApi_ReleaseModelPackageComponentContext_SinceV28_Fn @@ -84,6 +89,8 @@ inline const ModelPackageFns& GetModelPackageFns() { Exp::Get_OrtModelPackageApi_ModelPackage_GetVariantNames_SinceV28_FnOrThrow(api); f.ModelPackage_GetVariantEpName = Exp::Get_OrtModelPackageApi_ModelPackage_GetVariantEpName_SinceV28_FnOrThrow(api); + f.ModelPackage_ResolveStringRef = + Exp::Get_OrtModelPackageApi_ModelPackage_ResolveStringRef_SinceV28_FnOrThrow(api); f.SelectComponent = Exp::Get_OrtModelPackageApi_SelectComponent_SinceV28_FnOrThrow(api); f.ReleaseModelPackageComponentContext = @@ -98,218 +105,121 @@ inline const ModelPackageFns& GetModelPackageFns() { }(); return fns; } -// ------------------------------------------------------------------ -// Helpers to build a test model package on disk -// ------------------------------------------------------------------ -std::filesystem::path CreateManifestJson(const std::filesystem::path& package_root, - std::string_view manifest_json) { - std::filesystem::path manifest_path = package_root / "manifest.json"; - std::filesystem::create_directories(package_root); - - std::ofstream os(manifest_path, std::ios::binary); - os << manifest_json; - return manifest_path; -} - -std::string MakeVariantJson(std::string_view 
filename) { - std::ostringstream oss; - oss << R"({ - "filename": ")" - << filename << R"(" - })"; - return oss.str(); -} - -void CreateVariantDescriptor(const std::filesystem::path& package_root, - std::string_view component_name, - std::string_view variant_name, - std::string_view variant_json) { - const auto variant_root = package_root / "models" / std::string(component_name) / std::string(variant_name); - std::filesystem::create_directories(variant_root); - - std::ofstream os(variant_root / "variant.json", std::ios::binary); - os << variant_json; -} +// ──────────────────────────────────────────────────────────────────────────── +// Fixture helpers for building model packages on disk. +// Every package is a single manifest.json at the package root that declares +// components/variants/executor_info inline. Variant directories live at +// `///` and contain the model file. +// ──────────────────────────────────────────────────────────────────────────── + +struct VariantSpec { + std::string variant_name; + std::string ep; // empty => omit + std::string device; // empty => omit + std::string compatibility_string; // empty => omit + std::filesystem::path source_model; // empty => no executor_info + std::optional> session_options; + std::optional> provider_options; +}; -std::filesystem::path CreateModelPackage( - const std::filesystem::path& package_root, - std::string_view manifest_json, - std::string_view component_name, - std::string_view variant_name_1, - std::string_view variant_name_2, - const std::filesystem::path& source_model_1, - const std::filesystem::path& source_model_2) { +// Build a single-component new-schema package on disk and return its root. +// `package_root` is wiped before writing. 
+std::filesystem::path BuildPackage(const std::filesystem::path& package_root, + const std::string& component_name, + const std::vector& variants) { std::error_code ec; std::filesystem::remove_all(package_root, ec); std::filesystem::create_directories(package_root); - CreateManifestJson(package_root, manifest_json); - - const auto models_root = package_root / "models" / std::string(component_name); - const auto variant1_dir = models_root / std::string(variant_name_1); - const auto variant2_dir = models_root / std::string(variant_name_2); - - std::filesystem::create_directories(variant1_dir); - std::filesystem::create_directories(variant2_dir); - - const auto variant1_model = variant1_dir / source_model_1.filename(); - const auto variant2_model = variant2_dir / source_model_2.filename(); - - std::filesystem::copy_file(source_model_1, variant1_model, std::filesystem::copy_options::overwrite_existing, ec); - std::filesystem::copy_file(source_model_2, variant2_model, std::filesystem::copy_options::overwrite_existing, ec); - - CreateVariantDescriptor(package_root, component_name, variant_name_1, - MakeVariantJson(source_model_1.filename().string())); - CreateVariantDescriptor(package_root, component_name, variant_name_2, - MakeVariantJson(source_model_2.filename().string())); - - return package_root; -} - -std::filesystem::path CreateComponentModelMetadata( - const std::filesystem::path& package_root, - std::string_view component_name, - std::string_view metadata_json) { - const auto component_root = package_root / "models" / std::string(component_name); - - std::filesystem::create_directories(component_root); - - const std::filesystem::path metadata_path = component_root / "metadata.json"; - std::ofstream os(metadata_path, std::ios::binary); - os << metadata_json; - - return component_root; -} - -std::string MakeManifestJson(std::string_view component_name) { - std::ostringstream oss; - oss << R"({ - "schema_version": 1, - "components": [")" - << component_name << 
R"("] - })"; - return oss.str(); -} - -std::string MakeMetadataJsonTwoVariants(std::string_view component_name, - std::string_view variant_name_1, - std::string_view variant_ep_1, - std::string_view variant_device_1, - std::string_view variant_compatibility_string_1, - std::string_view variant_name_2, - std::string_view variant_ep_2, - std::string_view variant_device_2, - std::string_view variant_compatibility_string_2) { - std::ostringstream oss; - oss << R"({ - "component_name": ")" - << component_name << R"(", - "variants": { - ")" - << variant_name_1 << R"(": { - "ep": ")" - << variant_ep_1 << R"(", - "device": ")" - << variant_device_1 << R"(", - "compatibility_string": ")" - << variant_compatibility_string_1 << R"(" - }, - ")" - << variant_name_2 << R"(": { - "ep": ")" - << variant_ep_2 << R"(", - "device": ")" - << variant_device_2 << R"(", - "compatibility_string": ")" - << variant_compatibility_string_2 << R"(" + using ojson = nlohmann::ordered_json; + ojson variants_obj = ojson::object(); + for (const auto& v : variants) { + const std::string variant_dir_rel = component_name + "/" + v.variant_name; + const auto variant_dir_abs = package_root / component_name / v.variant_name; + std::filesystem::create_directories(variant_dir_abs); + + ojson variant_obj = ojson::object(); + variant_obj["variant_directory"] = variant_dir_rel; + if (!v.ep.empty()) variant_obj["ep"] = v.ep; + if (!v.device.empty()) variant_obj["device"] = v.device; + if (!v.compatibility_string.empty()) variant_obj["compatibility_string"] = v.compatibility_string; + + if (!v.source_model.empty()) { + const std::string model_filename = v.source_model.filename().string(); + std::filesystem::copy_file(v.source_model, variant_dir_abs / model_filename, + std::filesystem::copy_options::overwrite_existing, ec); + + ojson ort_info = ojson::object(); + ort_info["model_file"] = model_filename; + if (v.session_options.has_value()) { + ojson so = ojson::object(); + for (const auto& kv : 
*v.session_options) so[kv.first] = kv.second; + ort_info["session_options"] = std::move(so); } - } - })"; - return oss.str(); -} - -std::filesystem::path CreateModelPackageApiTestPackage(bool multi_file_variant = false) { - const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_api_test"; - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - - constexpr std::string_view manifest_json = R"({ - "schema_version": 1, - "components": ["model_1"] - })"; - - CreateModelPackage(package_root, manifest_json, - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); - - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", - "device": "cpu", - "compatibility_string": "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch1" - }, - "variant_2": { - "ep": "example_ep", - "device": "npu", - "compatibility_string": "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch2" + if (v.provider_options.has_value()) { + ojson po = ojson::object(); + for (const auto& kv : *v.provider_options) po[kv.first] = kv.second; + ort_info["provider_options"] = std::move(po); } + ojson executor_info = ojson::object(); + executor_info["ort"] = std::move(ort_info); + variant_obj["executor_info"] = std::move(executor_info); } - })"; - - CreateComponentModelMetadata(package_root, "model_1", metadata_json); - if (!multi_file_variant) { - std::ofstream os(package_root / "models" / "model_1" / "variant_1" / "variant.json", std::ios::binary); - os << R"({ - "filename": "mul_1.onnx", - "session_options": { - "session.disable_prepacking": "1", - "session.intra_op.allow_spinning": "0" - }, - "provider_options": { - "backend_path": "example_backend", - "enable_htp": "1" - } - })"; - } else { - // Multi-file variants are no longer supported. 
For backward-compat testing, - // just write a single-file variant.json. - std::ofstream os(package_root / "models" / "model_1" / "variant_1" / "variant.json", std::ios::binary); - os << R"({ - "filename": "mul_1.onnx", - "session_options": { - "session.disable_prepacking": "1", - "session.intra_op.allow_spinning": "0" - }, - "provider_options": { - "backend_path": "example_backend", - "enable_htp": "1" - } - })"; + variants_obj[v.variant_name] = std::move(variant_obj); } - { - std::ofstream os(package_root / "models" / "model_1" / "variant_2" / "variant.json", std::ios::binary); - os << R"({ - "filename": "mul_16.onnx" - })"; - } + ojson component_obj = ojson::object(); + component_obj["variants"] = std::move(variants_obj); + + ojson components_obj = ojson::object(); + components_obj[component_name] = std::move(component_obj); + + ojson manifest = ojson::object(); + manifest["schema_version"] = "1.0"; + manifest["components"] = std::move(components_obj); + std::ofstream os(package_root / "manifest.json", std::ios::binary); + os << manifest.dump(2); return package_root; } +// Convenience: most tests use the same two-variant shape backed by mul_1.onnx / +// mul_16.onnx. `compat_1` and `compat_2` default to empty (no compatibility string). 
+std::filesystem::path BuildTwoVariantPackage(const std::filesystem::path& package_root, + std::string_view variant_name_1, + std::string_view device_1, + std::string_view compat_1, + const std::filesystem::path& model_1, + std::string_view variant_name_2, + std::string_view device_2, + std::string_view compat_2, + const std::filesystem::path& model_2, + std::string_view ep_name = "example_ep") { + std::vector variants; + variants.push_back(VariantSpec{std::string(variant_name_1), std::string(ep_name), std::string(device_1), std::string(compat_1), model_1, {}, {}}); + variants.push_back(VariantSpec{std::string(variant_name_2), std::string(ep_name), std::string(device_2), std::string(compat_2), model_2, {}, {}}); + return BuildPackage(package_root, "model_1", variants); +} + } // namespace -// ------------------------------------------------------------------ +// ──────────────────────────────────────────────────────────────────────────── // Model Package API tests -// ------------------------------------------------------------------ +// ──────────────────────────────────────────────────────────────────────────── TEST(ModelPackageApiTest, PackageContextQueries) { - const auto package_root = CreateModelPackageApiTestPackage(); + const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_api_test"; + BuildTwoVariantPackage(package_root, + "variant_1", "cpu", + "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch1", + "testdata/mul_1.onnx", + "variant_2", "npu", + "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch2", + "testdata/mul_16.onnx"); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); @@ -320,7 +230,6 @@ TEST(ModelPackageApiTest, PackageContextQueries) { 
ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); model_pkg_context.reset(raw_context); - // Query: component count + names size_t component_count = 0; ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetComponentCount(model_pkg_context.get(), &component_count)); ASSERT_EQ(component_count, 1u); @@ -334,7 +243,6 @@ TEST(ModelPackageApiTest, PackageContextQueries) { ASSERT_NE(component_names[0], nullptr); EXPECT_STREQ(component_names[0], "model_1"); - // Query: variant count + names size_t variant_count = 0; ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantCount( model_pkg_context.get(), "model_1", &variant_count)); @@ -358,8 +266,82 @@ TEST(ModelPackageApiTest, PackageContextQueries) { std::filesystem::remove_all(package_root, ec); } +TEST(ModelPackageApiTest, ResolveStringRef) { + const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_resolve_test"; + std::vector variants; + variants.push_back(VariantSpec{"variant_1", "example_ep", "cpu", "", "testdata/mul_1.onnx", {}, {}}); + BuildPackage(package_root, "model_1", variants); + + // A content-addressed shared asset, discovered by convention at shared_assets/sha256-<digest>/. 
+ const std::string digest(64, 'a'); + const auto asset_dir = package_root / "shared_assets" / ("sha256-" + digest); + std::filesystem::create_directories(asset_dir); + { + std::ofstream os(asset_dir / "asset.txt", std::ios::binary); + os << "hello"; + } + + const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.ModelPackage_ResolveStringRef, nullptr) << "Model package experimental API is not available"; + + auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { + if (p) pkg_api.ReleaseModelPackageContext(p); + }; + std::unique_ptr ctx(nullptr, context_deleter); + OrtModelPackageContext* raw_context = nullptr; + ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_context)); + ctx.reset(raw_context); + + const char* resolved = nullptr; + + // "sha256:<digest>" resolves to the shared asset directory (override/discovery aware). + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_ResolveStringRef( + ctx.get(), nullptr, ("sha256:" + digest).c_str(), /*must_exist=*/1, &resolved)); + ASSERT_NE(resolved, nullptr); + EXPECT_EQ(std::filesystem::canonical(resolved), std::filesystem::canonical(asset_dir)); + + // "sha256:<digest>/<tail>" resolves the confined tail under the asset directory. + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_ResolveStringRef( + ctx.get(), nullptr, ("sha256:" + digest + "/asset.txt").c_str(), /*must_exist=*/1, &resolved)); + ASSERT_NE(resolved, nullptr); + EXPECT_EQ(std::filesystem::canonical(resolved), std::filesystem::canonical(asset_dir / "asset.txt")); + + // A plain relative path resolves against base_dir. + const auto variant_dir = package_root / "model_1" / "variant_1"; + ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_ResolveStringRef( + ctx.get(), variant_dir.string().c_str(), "mul_1.onnx", /*must_exist=*/1, &resolved)); + ASSERT_NE(resolved, nullptr); + EXPECT_EQ(std::filesystem::canonical(resolved), std::filesystem::canonical(variant_dir / "mul_1.onnx")); + + // An undeclared sha256 asset is rejected even when must_exist is false. 
+ const std::string missing_digest(64, 'b'); + OrtStatus* status = pkg_api.ModelPackage_ResolveStringRef( + ctx.get(), nullptr, ("sha256:" + missing_digest).c_str(), /*must_exist=*/0, &resolved); + EXPECT_NE(status, nullptr); + if (status != nullptr) Ort::GetApi().ReleaseStatus(status); + + std::error_code ec; + std::filesystem::remove_all(package_root, ec); +} + TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateSession) { - const auto package_root = CreateModelPackageApiTestPackage(); + const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_api_test"; + std::vector variants; + variants.push_back(VariantSpec{ + "variant_1", "example_ep", "cpu", + "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch1", + "testdata/mul_1.onnx", + std::unordered_map{ + {"session.disable_prepacking", "1"}, + {"session.intra_op.allow_spinning", "0"}, + }, + std::unordered_map{ + {"backend_path", "example_backend"}, + {"enable_htp", "1"}, + }}); + variants.push_back(VariantSpec{ + "variant_2", "example_ep", "npu", "example_ep;version=0.1.0;ort_api_version=25;hardware_architecture=arch2", "testdata/mul_16.onnx", {}, {}}); + BuildPackage(package_root, "model_1", variants); RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); @@ -370,6 +352,7 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); @@ -428,156 +411,64 @@ TEST(ModelPackageApiTest, SingleFileVariantInComponent_SelectComponentAndCreateS } TEST(ModelPackageTest, 
LoadModelPackageAndRunInference_PluginEp_AppendV2) { - // Test Case 1: - // package_root is a model package directory which contains a manifest.json. - // This model package only contains one component model and it has a metadata.json. - // ORT should parse the manifest and the metadata.json to get model variants' constraints. - // ORT selects most suitable model variant based on constraints and then loads it to run inference successfully. - { - // Build model package on disk - const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; - CreateModelPackage(package_root, MakeManifestJson("model_1"), - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); - - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_1", "example_ep", "cpu", "", - "variant_2", "example_ep", "npu", ""); - - CreateComponentModelMetadata(package_root, - "model_1", - metadata_json); - - // Register example EP and get its device - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - - // Prepare session options with ExampleEP appended - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - - // Create session from package root (directory) - // ORT should pick the variant_1 model since the constraints match the example EP device (device "cpu" matches) - Ort::Session session(*ort_env, package_root.c_str(), session_options); - - // Prepare input X - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - Ort::Value input = Ort::Value::CreateTensor(memory_info, 
input_data.data(), input_data.size(), - shape.data(), shape.size()); - const char* input_names[] = {"X"}; - const char* output_names[] = {"Y"}; - std::vector inputs; - inputs.push_back(std::move(input)); - - // Run - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(), - output_names, 1); - ASSERT_EQ(outputs.size(), 1u); - const float* out = outputs[0].GetTensorData(); - gsl::span out_span(out, input_data.size()); - EXPECT_THAT(out_span, ::testing::ElementsAre(1.f, 4.f, 9.f, 16.f, 25.f, 36.f)); - - // Cleanup - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - } - - // Test Case 2: - // package_root is a component model directory which contains a metadata.json. - // ORT should parse metadata.json to get model variants' constraints. - // ORT selects most suitable model variant based on constraints and then loads it to run inference successfully. - { - // Build model package on disk - const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; + // package_root is a new-schema model package directory with one component and two variants. + // ORT parses the manifest, selects the variant whose device matches the registered EP (cpu), + // and loads/runs it successfully. 
+ const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; + BuildTwoVariantPackage(package_root, + "variant_1", "cpu", "", + "testdata/mul_1.onnx", + "variant_2", "npu", "", + "testdata/mul_16.onnx"); - CreateModelPackage(package_root, MakeManifestJson("model_1"), - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_1", "example_ep", "cpu", "", - "variant_2", "example_ep", "npu", ""); + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - const auto component_model_root = CreateComponentModelMetadata(package_root, - "model_1", - metadata_json); + Ort::Session session(*ort_env, package_root.c_str(), session_options); - // Register example EP and get its device - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + Ort::Value input = Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(), + shape.data(), shape.size()); + const char* input_names[] = {"X"}; + const char* output_names[] = {"Y"}; + std::vector inputs; + inputs.push_back(std::move(input)); - // Prepare session options with ExampleEP appended - Ort::SessionOptions session_options; - std::unordered_map ep_options; - 
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(), + output_names, 1); + ASSERT_EQ(outputs.size(), 1u); + const float* out = outputs[0].GetTensorData(); + gsl::span out_span(out, input_data.size()); + EXPECT_THAT(out_span, ::testing::ElementsAre(1.f, 4.f, 9.f, 16.f, 25.f, 36.f)); - // Create session from component model root (directory) - // ORT should pick the variant_1 model since the constraints match the example EP device (device "cpu" matches) - Ort::Session session(*ort_env, component_model_root.c_str(), session_options); - - // Prepare input X - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - Ort::Value input = Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(), - shape.data(), shape.size()); - const char* input_names[] = {"X"}; - const char* output_names[] = {"Y"}; - std::vector inputs; - inputs.push_back(std::move(input)); - - // Run - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(), - output_names, 1); - ASSERT_EQ(outputs.size(), 1u); - const float* out = outputs[0].GetTensorData(); - gsl::span out_span(out, input_data.size()); - EXPECT_THAT(out_span, ::testing::ElementsAre(1.f, 4.f, 9.f, 16.f, 25.f, 36.f)); - - // Cleanup - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - } + std::error_code ec; + std::filesystem::remove_all(package_root, ec); } TEST(ModelPackageTest, LoadModelPackageAndRunInference_PreferCpu) { - // Build model package on disk const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; + BuildTwoVariantPackage(package_root, + "variant_1", "cpu", "", + "testdata/mul_1.onnx", + "variant_2", "npu", "", + "testdata/mul_16.onnx"); - 
CreateModelPackage(package_root, MakeManifestJson("model_1"), - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); - - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_1", "example_ep", "cpu", "", - "variant_2", "example_ep", "npu", ""); - - CreateComponentModelMetadata(package_root, - "model_1", - metadata_json); - - // Register example EP and get its device RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - // Prepare session options with ExampleEP appended Ort::SessionOptions session_options; session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - // Create session from package root (directory) - // ORT should pick the variant_1 model since the constraints match the example EP device (device "cpu" matches) Ort::Session session(*ort_env, package_root.c_str(), session_options); - // Prepare input X Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); std::vector shape = {3, 2}; std::vector input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; @@ -588,7 +479,6 @@ TEST(ModelPackageTest, LoadModelPackageAndRunInference_PreferCpu) { std::vector inputs; inputs.push_back(std::move(input)); - // Run auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(), output_names, 1); ASSERT_EQ(outputs.size(), 1u); @@ -596,7 +486,6 @@ TEST(ModelPackageTest, LoadModelPackageAndRunInference_PreferCpu) { gsl::span out_span(out, input_data.size()); EXPECT_THAT(out_span, ::testing::ElementsAre(1.f, 4.f, 9.f, 16.f, 25.f, 36.f)); - // Cleanup std::error_code ec; std::filesystem::remove_all(package_root, ec); } @@ -610,7 +499,6 @@ TEST(ModelPackageTest, CheckCompiledModelCompatibilityInfo) { const ORTCHAR_T* output_model_file = 
ORT_TSTR("plugin_ep_compat_test.onnx"); std::filesystem::remove(output_model_file); - // Compile the model { Ort::SessionOptions session_options; std::unordered_map ep_options; @@ -625,153 +513,41 @@ TEST(ModelPackageTest, CheckCompiledModelCompatibilityInfo) { ASSERT_TRUE(std::filesystem::exists(output_model_file)); } - // Build model package on disk - const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; - - CreateModelPackage(package_root, MakeManifestJson("model_1"), - "model_1", "variant_2", "variant_1", - std::filesystem::path{"testdata/mul_16.onnx"}, std::filesystem::path{"plugin_ep_compat_test.onnx"}); - // Build compat strings dynamically against current ORT_API_VERSION so the EP's ORT-version check - // doesn't short-circuit to PREFER_RECOMPILATION for both variants (which would make hardware_architecture - // irrelevant and the variant ranking collapse to a tie). With matching ORT versions, the arch differentiates: - // arch1 -> OPTIMAL, arch2 -> PREFER_RECOMPILATION; variant_1 must win. + // doesn't short-circuit to PREFER_RECOMPILATION for both variants. With matching ORT versions the + // hardware_architecture field differentiates: arch1 -> OPTIMAL, arch2 -> PREFER_RECOMPILATION, so + // variant_1 (mul_1) must win over variant_2 (mul_16). If variant_2 was picked, session init would + // fail with "No Op registered for Mul16". 
const std::string ort_api_version_str = std::to_string(ORT_API_VERSION); const std::string compat_arch2 = "example_ep;version=0.1.0;ort_api_version=" + ort_api_version_str + ";hardware_architecture=arch2"; const std::string compat_arch1 = "example_ep;version=0.1.0;ort_api_version=" + ort_api_version_str + ";hardware_architecture=arch1"; - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_2", "example_ep", "cpu", compat_arch2.c_str(), - "variant_1", "example_ep", "cpu", compat_arch1.c_str()); - CreateComponentModelMetadata(package_root, - "model_1", - metadata_json); + const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_test"; + BuildTwoVariantPackage(package_root, + "variant_2", "cpu", compat_arch2, + std::filesystem::path{"testdata/mul_16.onnx"}, + "variant_1", "cpu", compat_arch1, + std::filesystem::path{"plugin_ep_compat_test.onnx"}); - // Prepare session options with ExampleEP appended Ort::SessionOptions session_options; std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - // Create session from package root (directory) - // ORT should pick the variant_1 model since it has OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL for the example EP, - // while variant_2 is only OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION. - // If variant_2 was selected and loaded, i.e. mul_16.onnx, session initialization would fail with error "Error No Op registered for Mul16". Ort::Session session(*ort_env, package_root.c_str(), session_options); - // Cleanup std::error_code ec; std::filesystem::remove_all(package_root, ec); } -TEST(ModelPackageTest, LoadModelPackageAndRunInference_DiscoverComponentsFromModelsFolder) { - // manifest.json without "components"; discovery should scan models/* with metadata.json. 
- const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_discover_test"; - constexpr std::string_view manifest_json = R"({ - "schema_version": 1, - "model_name": "test_model" - })"; - - CreateModelPackage(package_root, manifest_json, - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); - - // Prepare component model with metadata and variants - const std::string component_name = "model_1"; - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_1", "example_ep", "cpu", "", - "variant_2", "example_ep", "npu", ""); - - // Create metadata.json under models/model_1 - const auto component_root = CreateComponentModelMetadata(package_root, - component_name, - metadata_json); - - // Add another component folder without metadata to ensure it's ignored - std::filesystem::create_directories(package_root / "models" / "ignored_component"); - - // Register example EP and get its device - RegisteredEpDeviceUniquePtr example_ep; - ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); - Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - - // Prepare session options with ExampleEP appended - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - - // Create session from package root (directory). Discovery should find model_1 via metadata.json, - // then pick variant_1 (device cpu) matching the example EP device. - // If variant_2 was selected and loaded, i.e. mul_16.onnx, session initialization would fail with error "Error No Op registered for Mul16". 
- Ort::Session session(*ort_env, package_root.c_str(), session_options); - - // Prepare input X - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input_data = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - Ort::Value input = Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(), - shape.data(), shape.size()); - const char* input_names[] = {"X"}; - const char* output_names[] = {"Y"}; - std::vector inputs; - inputs.push_back(std::move(input)); - - // Run - auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(), - output_names, 1); - ASSERT_EQ(outputs.size(), 1u); - const float* out = outputs[0].GetTensorData(); - gsl::span out_span(out, input_data.size()); - EXPECT_THAT(out_span, ::testing::ElementsAre(1.f, 4.f, 9.f, 16.f, 25.f, 36.f)); - - // Cleanup - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageTest, ParseVariantsFromRoot_PackageRootDirectory) { +TEST(ModelPackageTest, ParseVariantsFromPackageRoot) { const auto package_root = std::filesystem::temp_directory_path() / "ort_model_package_parse_from_package_root"; - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - - // package_root is a model package directory (has manifest.json). 
- constexpr std::string_view manifest_json = R"({ - "schema_version": 1, - "components": ["model_1"] - })"; - - CreateModelPackage(package_root, manifest_json, - "model_1", "variant_1", "variant_2", - std::filesystem::path{"testdata/mul_1.onnx"}, std::filesystem::path{"testdata/mul_16.onnx"}); - - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", - "device": "cpu" - }, - "variant_2": { - "ep": "example_ep", - "device": "npu" - } - } - })"; - - CreateComponentModelMetadata(package_root, "model_1", metadata_json); - - // New schema: per-variant descriptor in variant.json - { - std::ofstream os(package_root / "models" / "model_1" / "variant_1" / "variant.json", std::ios::binary); - os << R"({ "filename": "mul_1.onnx" })"; - } - { - std::ofstream os(package_root / "models" / "model_1" / "variant_2" / "variant.json", std::ios::binary); - os << R"({ "filename": "mul_16.onnx" })"; - } + BuildTwoVariantPackage(package_root, + "variant_1", "cpu", "", + std::filesystem::path{"testdata/mul_1.onnx"}, + "variant_2", "npu", "", + std::filesystem::path{"testdata/mul_16.onnx"}); ModelPackageContext ctx(package_root); const auto& variants = ctx.GetVariantInfos(); @@ -795,207 +571,19 @@ TEST(ModelPackageTest, ParseVariantsFromRoot_PackageRootDirectory) { EXPECT_EQ(v2->ep_compatibility.ep.value_or(""), "example_ep"); EXPECT_EQ(v2->ep_compatibility.device.value_or(""), "npu"); - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageTest, ParseVariantsFromRoot_ComponentModelDirectory) { - const auto component_root = std::filesystem::temp_directory_path() / "ort_model_package_parse_from_component_root"; - std::error_code ec; - std::filesystem::remove_all(component_root, ec); - std::filesystem::create_directories(component_root); - - // package_root is a component model directory (has metadata.json, no manifest.json). 
- const auto variant_dir = component_root / "variant_1"; - std::filesystem::create_directories(variant_dir); - std::filesystem::copy_file("testdata/mul_1.onnx", variant_dir / "mul_1.onnx", - std::filesystem::copy_options::overwrite_existing, ec); - - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", - "device": "cpu" - } - } - })"; - - { - std::ofstream os(component_root / "metadata.json", std::ios::binary); - os << metadata_json; - } - - { - std::ofstream os(variant_dir / "variant.json", std::ios::binary); - os << R"({ "filename": "mul_1.onnx" })"; - } - - ModelPackageContext ctx(component_root); - const auto& variants = ctx.GetVariantInfos(); - - ASSERT_EQ(variants.size(), 1u); - ASSERT_TRUE(variants[0].file.has_value()); - EXPECT_EQ(variants[0].file->model_file_path.filename().string(), "mul_1.onnx"); - - EXPECT_EQ(variants[0].ep_compatibility.ep.value_or(""), "example_ep"); - EXPECT_EQ(variants[0].ep_compatibility.device.value_or(""), "cpu"); - - std::filesystem::remove_all(component_root, ec); -} - -// ------------------------------------------------------------------ -// Tests for descriptor parser: enforced "ep" field in variant EP metadata. -// ------------------------------------------------------------------ -namespace { - -// Make a single-component, single-variant package on disk where metadata.json is written -// directly at the package root (the "single-component metadata flow" of the parser). -// In this flow variant EP metadata schema validation errors are propagated, instead of being -// swallowed by the manifest-driven discovery path which falls back to "Missing metadata variants". -// Returns the package_root. 
-std::filesystem::path MakeSingleComponentPackageWithMetadata(std::string_view subdir, - std::string_view metadata_json, - std::string_view variant_json = R"({"filename":"mul_1.onnx"})") { - const auto package_root = std::filesystem::temp_directory_path() / std::string(subdir); std::error_code ec; std::filesystem::remove_all(package_root, ec); - std::filesystem::create_directories(package_root); - - // Write metadata.json directly under package_root (no manifest, no models/ subdir). - { - std::ofstream os(package_root / "metadata.json", std::ios::binary); - os << metadata_json; - } - - // Variants live directly under package_root for the single-component flow. - const auto variant_dir = package_root / "variant_1"; - std::filesystem::create_directories(variant_dir); - std::filesystem::copy_file("testdata/mul_1.onnx", variant_dir / "mul_1.onnx", - std::filesystem::copy_options::overwrite_existing, ec); - - std::ofstream os(variant_dir / "variant.json", std::ios::binary); - os << variant_json; - - return package_root; } -} // namespace - -TEST(ModelPackageTest, ParserRejects_EpCompatibilityMissingEp) { - // The "ep" field is required in every variant descriptor. - // Omitting it must yield a parse error (not silently accept a wildcard / portable variant). 
- constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "device": "cpu", - "compatibility_string": "anything" - } - } - })"; - const auto package_root = MakeSingleComponentPackageWithMetadata( - "ort_model_package_parser_missing_ep", metadata_json); - - try { - ModelPackageContext ctx(package_root); - FAIL() << "Expected exception for missing 'ep' field"; - } catch (const std::exception& e) { - EXPECT_NE(std::string(e.what()).find("ep"), std::string::npos) << e.what(); - } - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageTest, ParserRejects_EpCompatibilityNullEp) { - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": null, - "device": "cpu" - } - } - })"; - const auto package_root = MakeSingleComponentPackageWithMetadata( - "ort_model_package_parser_null_ep", metadata_json); - - try { - ModelPackageContext ctx(package_root); - FAIL() << "Expected exception for null 'ep' field"; - } catch (const std::exception& e) { - EXPECT_NE(std::string(e.what()).find("ep"), std::string::npos) << e.what(); - } - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -TEST(ModelPackageTest, ParserRejects_EpCompatibilityEmptyEp) { - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "", - "device": "cpu" - } - } - })"; - const auto package_root = MakeSingleComponentPackageWithMetadata( - "ort_model_package_parser_empty_ep", metadata_json); - - try { - ModelPackageContext ctx(package_root); - FAIL() << "Expected exception for empty 'ep' field"; - } catch (const std::exception& e) { - EXPECT_NE(std::string(e.what()).find("ep"), std::string::npos) << e.what(); - } - - std::error_code ec; - std::filesystem::remove_all(package_root, ec); -} - -// ------------------------------------------------------------------ -// Tests for 
new pre-selection EP-compat traversal accessors. -// ------------------------------------------------------------------ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { const auto package_root = std::filesystem::temp_directory_path() / "ort_mp_pre_selection_ep_name"; - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - - CreateManifestJson(package_root, MakeManifestJson("model_1")); - - const auto variant1_dir = package_root / "models" / "model_1" / "variant_1"; - const auto variant2_dir = package_root / "models" / "model_1" / "variant_2"; - std::filesystem::create_directories(variant1_dir); - std::filesystem::create_directories(variant2_dir); - std::filesystem::copy_file("testdata/mul_1.onnx", variant1_dir / "mul_1.onnx", - std::filesystem::copy_options::overwrite_existing, ec); - std::filesystem::copy_file("testdata/mul_1.onnx", variant2_dir / "mul_1.onnx", - std::filesystem::copy_options::overwrite_existing, ec); - - // Each variant declares a single EP. 
- constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", - "device": "cpu" - }, - "variant_2": { - "ep": "other_ep", - "device": "npu" - } - } - })"; - CreateComponentModelMetadata(package_root, "model_1", metadata_json); - - for (const auto& d : {variant1_dir, variant2_dir}) { - std::ofstream os(d / "variant.json", std::ios::binary); - os << R"({"filename":"mul_1.onnx"})"; - } + std::vector variants; + variants.push_back(VariantSpec{"variant_1", "example_ep", "cpu", "", "testdata/mul_1.onnx", {}, {}}); + variants.push_back(VariantSpec{"variant_2", "other_ep", "npu", "", "testdata/mul_1.onnx", {}, {}}); + BuildPackage(package_root, "model_1", variants); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); @@ -1005,14 +593,12 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageContext(package_root.c_str(), &raw_ctx)); ctx.reset(raw_ctx); - // variant_1 targets example_ep const char* ep1 = nullptr; ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", &ep1)); ASSERT_NE(ep1, nullptr); EXPECT_STREQ(ep1, "example_ep"); - // variant_2 targets other_ep const char* ep2 = nullptr; ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_2", &ep2)); @@ -1023,47 +609,34 @@ TEST(ModelPackageApiTest, GetVariantEpName_ReturnsSingleEp) { ASSERT_ORTSTATUS_OK(pkg_api.ModelPackage_GetVariantEpName( ctx.get(), "model_1", "variant_1", nullptr)); + std::error_code ec; std::filesystem::remove_all(package_root, ec); } -// ------------------------------------------------------------------ -// ------------------------------------------------------------------ -// Test: variant 
selector tie-break is deterministic across repeated invocations. -// Two variants advertise compatibility for the same EP/device and EP returns the same -// validation score for both -- selection must be stable. -// ------------------------------------------------------------------ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) { // Both variants point at the *same* model file (mul_1.onnx) so whichever wins works at runtime. // They advertise identical EP/device pairs and empty compatibility_string so the EP returns the - // same score (NOT_APPLICABLE) for both -- a tie. The fix in commit 27217da484 guarantees that - // ties resolve deterministically, i.e., selection is stable across repeated runs. + // same score (NOT_APPLICABLE) for both: ties must resolve deterministically across runs. RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); - std::string first_selected_filename; + std::string first_selected_variant; for (int iter = 0; iter < 5; ++iter) { const auto package_root = std::filesystem::temp_directory_path() / "ort_mp_tie_break"; - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - - CreateModelPackage(package_root, MakeManifestJson("model_1"), - "model_1", "variant_a", "variant_b", - std::filesystem::path{"testdata/mul_1.onnx"}, - std::filesystem::path{"testdata/mul_1.onnx"}); - - const std::string metadata_json = MakeMetadataJsonTwoVariants( - "model_1", - "variant_a", "example_ep", "cpu", "", - "variant_b", "example_ep", "cpu", ""); - CreateComponentModelMetadata(package_root, "model_1", metadata_json); + BuildTwoVariantPackage(package_root, + "variant_a", "cpu", "", + std::filesystem::path{"testdata/mul_1.onnx"}, + "variant_b", "cpu", "", + std::filesystem::path{"testdata/mul_1.onnx"}); Ort::SessionOptions session_options; std::unordered_map ep_options; 
session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; @@ -1090,61 +663,34 @@ TEST(ModelPackageTest, VariantSelector_TieBreakIsDeterministic) { ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr); - // Path looks like .../models/model_1/<variant> -- the folder name is the variant. + // Variant directories live at <package_root>/model_1/<variant>; the leaf name is the variant. const auto selected_variant_dir = std::filesystem::path(selected_folder).filename().string(); ASSERT_TRUE(selected_variant_dir == "variant_a" || selected_variant_dir == "variant_b") << "unexpected variant dir: " << selected_variant_dir; if (iter == 0) { - first_selected_filename = selected_variant_dir; + first_selected_variant = selected_variant_dir; } else { - EXPECT_EQ(selected_variant_dir, first_selected_filename) + EXPECT_EQ(selected_variant_dir, first_selected_variant) << "tie-break selection drifted across runs (iter " << iter << ")"; } + std::error_code ec; std::filesystem::remove_all(package_root, ec); } } -// ------------------------------------------------------------------ -// Test: a variant's per-file `session_options` flow through OrtApis::AddSessionConfigEntry. -// We verify this by feeding a *known* typed key (session.intra_op_num_threads) a non-integer value: -// pre-change behavior would silently stuff it into AddConfigEntry and succeed; post-change -// behavior parses it via the typed dispatcher and fails CreateSession with a parse error. 
-// ------------------------------------------------------------------ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEntry) { + // Per-variant session_options assigns a typed key (session.intra_op_num_threads) a value that + // is not a valid integer. Routing this through OrtApis::AddSessionConfigEntry must reject it. const auto package_root = std::filesystem::temp_directory_path() / "ort_mp_session_options_dispatch"; - std::error_code ec; - std::filesystem::remove_all(package_root, ec); - - CreateManifestJson(package_root, MakeManifestJson("model_1")); - - const auto variant_dir = package_root / "models" / "model_1" / "variant_1"; - std::filesystem::create_directories(variant_dir); - std::filesystem::copy_file("testdata/mul_1.onnx", variant_dir / "mul_1.onnx", - std::filesystem::copy_options::overwrite_existing, ec); - - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", "device": "cpu" - } - } - })"; - CreateComponentModelMetadata(package_root, "model_1", metadata_json); - - // Per-file session_options assigns a typed key (session.intra_op_num_threads) a value that is not a - // valid integer. Routing this through OrtApis::AddSessionConfigEntry (the new behavior) must reject it. 
- { - std::ofstream os(variant_dir / "variant.json", std::ios::binary); - os << R"({ - "filename": "mul_1.onnx", - "session_options": { - "session.intra_op_num_threads": "not_an_int" - } - })"; - } + std::vector variants; + variants.push_back(VariantSpec{ + "variant_1", "example_ep", "cpu", "", "testdata/mul_1.onnx", std::unordered_map{ + {"session.intra_op_num_threads", "not_an_int"}, + }, + {}}); + BuildPackage(package_root, "model_1", variants); RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); @@ -1155,6 +701,7 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; auto options_deleter = [&pkg_api](OrtModelPackageOptions* p) { if (p) pkg_api.ReleaseModelPackageOptions(p); }; auto context_deleter = [&pkg_api](OrtModelPackageContext* p) { if (p) pkg_api.ReleaseModelPackageContext(p); }; @@ -1177,13 +724,9 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn ASSERT_ORTSTATUS_OK(pkg_api.SelectComponent(ctx.get(), "model_1", mp_opts.get(), &raw_comp_ctx)); comp_ctx.reset(raw_comp_ctx); - // CreateSession iterates the per-file session_options and dispatches each through OrtApis::AddSessionConfigEntry. - // The bad int value must surface as an error from this call. - // Pass nullptr for session_options so the metadata-merge path runs (it is skipped when the caller - // supplies their own session_options). + // Pass nullptr for session_options so the metadata-merge path runs. OrtSession* raw_session = nullptr; OrtStatus* st = pkg_api.CreateSession(*ort_env, comp_ctx.get(), /*session_options=*/nullptr, &raw_session); - // Clean up session first to avoid leaks if assertion fails. 
if (raw_session != nullptr) { Ort::GetApi().ReleaseSession(raw_session); raw_session = nullptr; @@ -1192,45 +735,32 @@ TEST(ModelPackageTest, VariantSessionOptions_DispatchedThroughAddSessionConfigEn const std::string err_msg = Ort::GetApi().GetErrorMessage(st); Ort::GetApi().ReleaseStatus(st); - // Message should mention either AddSessionConfigEntry or the typed-int parse failure. const bool mentions_dispatch = err_msg.find("AddSessionConfigEntry") != std::string::npos || err_msg.find("base-10 int32") != std::string::npos || err_msg.find("intra_op_num_threads") != std::string::npos; EXPECT_TRUE(mentions_dispatch) << "error did not mention typed dispatch: " << err_msg; + std::error_code ec; std::filesystem::remove_all(package_root, ec); } -// ------------------------------------------------------------------ -// Test: GetSelectedVariantFolderPath returns correct path even when variant.json is absent. -// ------------------------------------------------------------------ -TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { - const auto package_root = std::filesystem::temp_directory_path() / "ort_mp_folder_path_no_variant_json"; +// GetSelectedVariantFolderPath returns the correct path even when the variant +// declares no executor_info (i.e., no `file` descriptor for the variant). +TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenExecutorInfoAbsent) { + const auto package_root = std::filesystem::temp_directory_path() / "ort_mp_folder_path_no_executor_info"; + std::vector variants; + // No source_model => no executor_info is emitted for this variant. + VariantSpec only{"variant_1", "example_ep", "cpu", "", {}, {}, {}}; + variants.push_back(only); + BuildPackage(package_root, "model_1", variants); + + // Drop a model file in the variant directory so the package looks plausible on disk. 
std::error_code ec; - std::filesystem::remove_all(package_root, ec); - std::filesystem::create_directories(package_root); - - CreateManifestJson(package_root, MakeManifestJson("model_1")); - - const auto variant_dir = package_root / "models" / "model_1" / "variant_1"; - std::filesystem::create_directories(variant_dir); - - // Copy a model file but do NOT create variant.json - std::filesystem::copy_file("testdata/mul_1.onnx", variant_dir / "mul_1.onnx", + std::filesystem::copy_file("testdata/mul_1.onnx", + package_root / "model_1" / "variant_1" / "mul_1.onnx", std::filesystem::copy_options::overwrite_existing, ec); - constexpr std::string_view metadata_json = R"({ - "component_name": "model_1", - "variants": { - "variant_1": { - "ep": "example_ep", - "device": "cpu" - } - } - })"; - CreateComponentModelMetadata(package_root, "model_1", metadata_json); - RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); @@ -1240,6 +770,7 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { so.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); const auto& pkg_api = GetModelPackageFns(); + ASSERT_NE(pkg_api.CreateModelPackageContext, nullptr) << "Model package experimental API is not available"; OrtModelPackageOptions* raw_mp_opts = nullptr; ASSERT_ORTSTATUS_OK(pkg_api.CreateModelPackageOptionsFromSessionOptions(*ort_env, so, &raw_mp_opts)); @@ -1258,7 +789,6 @@ TEST(ModelPackageApiTest, FolderPath_ReturnsCorrectPath_WhenVariantJsonAbsent) { }; std::unique_ptr comp_ctx(raw_comp_ctx, component_context_deleter); - // GetSelectedVariantFolderPath should return the variant directory even without variant.json. 
const ORTCHAR_T* selected_folder = nullptr; ASSERT_ORTSTATUS_OK(pkg_api.ModelPackageComponent_GetSelectedVariantFolderPath(comp_ctx.get(), &selected_folder)); ASSERT_NE(selected_folder, nullptr);