From 91eb27e97a89b5af539888d32918707333d6f31f Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 22 Jun 2026 17:18:56 +0100 Subject: [PATCH 1/7] initial commit of rust v2 sdk --- sdk_v2/rust/.clippy.toml | 2 + sdk_v2/rust/.rustfmt.toml | 3 + sdk_v2/rust/Cargo.toml | 50 ++ sdk_v2/rust/GENERATE-DOCS.md | 41 ++ sdk_v2/rust/LICENSE.txt | 21 + sdk_v2/rust/README.md | 608 ++++++++++++++++ sdk_v2/rust/build.rs | 306 ++++++++ sdk_v2/rust/docs/api.md | 552 +++++++++++++++ sdk_v2/rust/examples/chat_completion.rs | 97 +++ sdk_v2/rust/examples/interactive_chat.rs | 112 +++ sdk_v2/rust/examples/tool_calling.rs | 202 ++++++ sdk_v2/rust/src/catalog.rs | 152 ++++ sdk_v2/rust/src/configuration.rs | 253 +++++++ sdk_v2/rust/src/detail/api.rs | 431 ++++++++++++ sdk_v2/rust/src/detail/ffi.rs | 658 ++++++++++++++++++ sdk_v2/rust/src/detail/info.rs | 152 ++++ sdk_v2/rust/src/detail/items.rs | 226 ++++++ sdk_v2/rust/src/detail/manager.rs | 229 ++++++ sdk_v2/rust/src/detail/mod.rs | 13 + sdk_v2/rust/src/detail/model.rs | 405 +++++++++++ sdk_v2/rust/src/detail/native.rs | 250 +++++++ sdk_v2/rust/src/detail/session.rs | 311 +++++++++ sdk_v2/rust/src/detail/task.rs | 16 + sdk_v2/rust/src/error.rs | 36 + sdk_v2/rust/src/foundry_local_manager.rs | 351 ++++++++++ sdk_v2/rust/src/lib.rs | 45 ++ sdk_v2/rust/src/openai/audio_client.rs | 210 ++++++ sdk_v2/rust/src/openai/chat_client.rs | 317 +++++++++ sdk_v2/rust/src/openai/embedding_client.rs | 98 +++ sdk_v2/rust/src/openai/json_stream.rs | 49 ++ sdk_v2/rust/src/openai/live_audio_session.rs | 471 +++++++++++++ sdk_v2/rust/src/openai/mod.rs | 17 + sdk_v2/rust/src/types.rs | 151 ++++ .../tests/integration/audio_client_test.rs | 130 ++++ sdk_v2/rust/tests/integration/catalog_test.rs | 106 +++ .../tests/integration/chat_client_test.rs | 334 +++++++++ sdk_v2/rust/tests/integration/common/mod.rs | 128 ++++ .../integration/embedding_client_test.rs | 223 ++++++ .../rust/tests/integration/live_audio_test.rs | 114 +++ sdk_v2/rust/tests/integration/main.rs | 18 + sdk_v2/rust/tests/integration/manager_test.rs | 21 + sdk_v2/rust/tests/integration/model_test.rs | 323 +++++++++ .../tests/integration/web_service_test.rs | 161 +++++ 43 files changed, 8393 insertions(+) create mode 100644 sdk_v2/rust/.clippy.toml create mode 100644 sdk_v2/rust/.rustfmt.toml create mode 100644 sdk_v2/rust/Cargo.toml create mode 100644 sdk_v2/rust/GENERATE-DOCS.md create mode 100644 sdk_v2/rust/LICENSE.txt create mode 100644 sdk_v2/rust/README.md create mode 100644 sdk_v2/rust/build.rs create mode 100644 sdk_v2/rust/docs/api.md create mode 100644 sdk_v2/rust/examples/chat_completion.rs create mode 100644 sdk_v2/rust/examples/interactive_chat.rs create mode 100644 sdk_v2/rust/examples/tool_calling.rs create mode 100644 sdk_v2/rust/src/catalog.rs create mode 100644 sdk_v2/rust/src/configuration.rs create mode 100644 sdk_v2/rust/src/detail/api.rs create mode 100644 sdk_v2/rust/src/detail/ffi.rs create mode 100644 sdk_v2/rust/src/detail/info.rs create mode 100644 sdk_v2/rust/src/detail/items.rs create mode 100644 sdk_v2/rust/src/detail/manager.rs create mode 100644 sdk_v2/rust/src/detail/mod.rs create mode 100644 sdk_v2/rust/src/detail/model.rs create mode 100644 sdk_v2/rust/src/detail/native.rs create mode 100644 sdk_v2/rust/src/detail/session.rs create mode 100644 sdk_v2/rust/src/detail/task.rs create mode 100644 sdk_v2/rust/src/error.rs create mode 100644 sdk_v2/rust/src/foundry_local_manager.rs create mode 100644 sdk_v2/rust/src/lib.rs create mode 100644 sdk_v2/rust/src/openai/audio_client.rs create mode 100644 sdk_v2/rust/src/openai/chat_client.rs create mode 100644 sdk_v2/rust/src/openai/embedding_client.rs create mode 100644 sdk_v2/rust/src/openai/json_stream.rs create mode 100644 sdk_v2/rust/src/openai/live_audio_session.rs create mode 100644 sdk_v2/rust/src/openai/mod.rs create mode 100644 sdk_v2/rust/src/types.rs create mode 100644 sdk_v2/rust/tests/integration/audio_client_test.rs create mode 100644 sdk_v2/rust/tests/integration/catalog_test.rs create mode 100644 sdk_v2/rust/tests/integration/chat_client_test.rs create mode 100644 sdk_v2/rust/tests/integration/common/mod.rs create mode 100644 sdk_v2/rust/tests/integration/embedding_client_test.rs create mode 100644 sdk_v2/rust/tests/integration/live_audio_test.rs create mode 100644 sdk_v2/rust/tests/integration/main.rs create mode 100644 sdk_v2/rust/tests/integration/manager_test.rs create mode 100644 sdk_v2/rust/tests/integration/model_test.rs create mode 100644 sdk_v2/rust/tests/integration/web_service_test.rs diff --git a/sdk_v2/rust/.clippy.toml b/sdk_v2/rust/.clippy.toml new file mode 100644 index 000000000..1d42f2f12 --- /dev/null +++ b/sdk_v2/rust/.clippy.toml @@ -0,0 +1,2 @@ +# Clippy configuration for Foundry Local Rust SDK +msrv = "1.70" diff --git a/sdk_v2/rust/.rustfmt.toml b/sdk_v2/rust/.rustfmt.toml new file mode 100644 index 000000000..dce363edf --- /dev/null +++ b/sdk_v2/rust/.rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2021" +max_width = 100 +use_field_init_shorthand = true diff --git a/sdk_v2/rust/Cargo.toml b/sdk_v2/rust/Cargo.toml new file mode 100644 index 000000000..d85dcbe3f --- /dev/null +++ b/sdk_v2/rust/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "foundry-local-sdk" +version = "2.0.0" +edition = "2021" +license = "MIT" +readme = "README.md" +description = "Local AI model inference powered by the Foundry Local Core engine" +homepage = "https://www.foundrylocal.ai/" +repository = "https://github.com/microsoft/Foundry-Local" +documentation = "https://github.com/microsoft/Foundry-Local/blob/main/sdk_v2/rust/docs/api.md" +include = ["src/**", "build.rs", "Cargo.toml", "README.md", "LICENSE.txt", "deps_versions.json"] + +[features] +default = [] +winml = [] +nightly = [] + +[dependencies] +libloading = "0.8" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } +tokio-stream = "0.1" +tokio-util = "0.7" +futures-core = "0.3" +reqwest = { version = "0.12", features = ["json"] } +urlencoding = "2" +async-openai = { version = "=0.33.1", default-features = false, features = ["chat-completion-types", "embedding-types"] } + +[build-dependencies] +ureq = "3" +zip = "2" +serde_json = "1" +serde = { version = "1", features = ["derive"] } + +[[example]] +name = "chat_completion" +path = "examples/chat_completion.rs" + +[[example]] +name = "tool_calling" +path = "examples/tool_calling.rs" + +[[example]] +name = "interactive_chat" +path = "examples/interactive_chat.rs" + +[lints.clippy] +all = { level = "warn", priority = -1 } diff --git a/sdk_v2/rust/GENERATE-DOCS.md b/sdk_v2/rust/GENERATE-DOCS.md new file mode 100644 index 000000000..f02b5d99b --- /dev/null +++ b/sdk_v2/rust/GENERATE-DOCS.md @@ -0,0 +1,41 @@ +# Generating API Reference Docs + +The Rust SDK uses `cargo doc` to generate API documentation from `///` doc comments in the source code. + +## Viewing Docs Locally + +To generate and open the API docs in your browser: + +```bash +cd sdk/rust +cargo doc --no-deps --open +``` + +This generates HTML documentation at `target/doc/foundry_local_sdk/index.html`. + +## Public API Surface + +The SDK re-exports all public types from the crate root. Key modules: + +| Module / Type | Description | +|---|---| +| `FoundryLocalManager` | Singleton entry point — SDK initialisation, web service lifecycle | +| `FoundryLocalConfig` | Configuration (app name, log level, service endpoint) | +| `Catalog` | Model discovery and lookup | +| `Model` | Grouped model (alias → best variant) | +| `ModelVariant` | Single variant — download, load, unload | +| `ChatClient` | OpenAI-compatible chat completions (sync + streaming) | +| `AudioClient` | OpenAI-compatible audio transcription (sync + streaming) | +| `CreateChatCompletionResponse` | Typed chat completion response (from `async-openai`) | +| `CreateChatCompletionStreamResponse` | Typed streaming chat chunk (from `async-openai`) | +| `AudioTranscriptionResponse` | Typed audio transcription response | +| `FoundryLocalError` | Error enum with variants for all failure modes | + +## Notes + +- Unlike the C# and JS SDKs which commit generated markdown docs, Rust's convention is to generate HTML docs on demand with `cargo doc`. +- Once the crate is published to crates.io, docs will be automatically hosted at [docs.rs](https://docs.rs). +- Use `--document-private-items` to include internal/private API in the generated docs: + ```bash + cargo doc --no-deps --document-private-items --open + ``` diff --git a/sdk_v2/rust/LICENSE.txt b/sdk_v2/rust/LICENSE.txt new file mode 100644 index 000000000..48bc6bb49 --- /dev/null +++ b/sdk_v2/rust/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md new file mode 100644 index 000000000..9e4bf8673 --- /dev/null +++ b/sdk_v2/rust/README.md @@ -0,0 +1,608 @@ +# Foundry Local Rust SDK (v2) + +The Foundry Local Rust SDK provides an async Rust interface for running AI models locally on your machine. Discover, download, load, and run inference — all without cloud dependencies. + +> **v2 — built on the Foundry Local C++ engine.** This is the v2 binding, layered on the +> `foundry_local` C ABI (the same native engine used by the v2 Python and C# SDKs). Its public +> API is **fully backwards-compatible** with the v1 Rust SDK (`sdk/rust`) — existing code +> compiles and runs unchanged. + +## Features + +- **Local-first AI** — Run models entirely on your machine with no cloud calls +- **Model catalog** — Browse and discover available models; check what's cached or loaded +- **Automatic model management** — Download, load, unload, and remove models from cache +- **Chat completions** — OpenAI-compatible chat API with both non-streaming and streaming responses +- **Embeddings** — Generate text embeddings via OpenAI-compatible API +- **Audio transcription** — Transcribe audio files locally with streaming support +- **Tool calling** — Function/tool calling with streaming, multi-turn conversation support +- **Response format control** — Text, JSON, JSON Schema, and Lark grammar constrained output +- **Multi-variant models** — Models can have multiple variants (e.g., different quantizations) with automatic selection of the best cached variant +- **Embedded web service** — Start a local HTTP server for OpenAI-compatible API access +- **WinML support** — Automatic execution provider download on Windows for NPU/GPU acceleration +- **Configurable inference** — Control temperature, max tokens, top-k, top-p, frequency penalty, random seed, and more +- **Async-first** — Every operation is `async`; designed for use with the `tokio` runtime +- **Safe FFI** — Dynamically loads the native Foundry Local engine (`foundry_local`) with a safe Rust wrapper + +## Prerequisites + +- **Rust** 1.70+ (stable toolchain) +- The native `foundry_local` engine (plus its ONNX Runtime / GenAI dependencies) available at + build or run time — see [Native binary](#native-binary) + +## Installation + +```sh +cargo add foundry-local-sdk +``` + +Or add to your `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = "2" +``` + +You also need an async runtime. Most examples use [tokio](https://crates.io/crates/tokio): + +```toml +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" # for StreamExt on streaming responses +``` + +### Feature Flags + +| Feature | Description | +|-----------|-------------| +| `winml` | Use the WinML backend (Windows only). Selects different ONNX Runtime and GenAI packages for NPU/GPU acceleration. | +| `nightly` | Resolve the latest nightly build of the Core package from the ORT-Nightly feed. | + +Enable features in `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = { version = "2", features = ["winml"] } +``` + +> **Note:** The `winml` feature is only relevant on Windows. On macOS and Linux, the standard build is used regardless. No code changes are needed — your application code stays the same. + +With `winml` enabled on Windows, the `winml` feature selects the WinML Runtime package (`Microsoft.AI.Foundry.Local.Runtime.WinML`) when downloading via `FOUNDRY_LOCAL_RUNTIME_VERSION`, and the WinML execution-provider DLLs are pre-loaded alongside `foundry_local`. See [Native binary](#native-binary) for how the engine is obtained. + +### Explicit EP Management + +You can explicitly discover and download execution providers: + +```rust +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; + +let manager = FoundryLocalManager::create(FoundryLocalConfig::new("my_app"))?; + +// Discover available EPs and their status +let eps = manager.discover_eps()?; +for ep in &eps { + println!("{} — registered: {}", ep.name, ep.is_registered); +} + +// Download and register all available EPs +let result = manager.download_and_register_eps(None).await?; +println!("Success: {}, Status: {}", result.success, result.status); + +// Download only specific EPs +let result = manager.download_and_register_eps(Some(&[eps[0].name.as_str()])).await?; +``` + +#### Per-EP download progress + +Use `download_and_register_eps_with_progress` to receive typed `(ep_name, percent)` updates +as each EP downloads (`percent` is 0.0–100.0): + +```rust +use std::sync::{Arc, Mutex}; + +let current_ep = Arc::new(Mutex::new(String::new())); +let ep = Arc::clone(¤t_ep); +manager.download_and_register_eps_with_progress(None, move |ep_name: &str, percent: f64| { + let mut current = ep.lock().unwrap(); + if ep_name != current.as_str() { + if !current.is_empty() { + println!(); + } + *current = ep_name.to_string(); + } + print!("\r {} {:5.1}%", ep_name, percent); +}).await?; +println!(); +``` + +#### Cancelling model and EP downloads + +Use a shared `Arc` with the download builders. Set the flag from another task or signal handler to stop the in-progress download. + +```rust +use std::sync::{ + atomic::AtomicBool, + Arc, +}; + +// manager and model already initialized +let cancel_flag = Arc::new(AtomicBool::new(false)); +// call cancel_flag.store(true, ...) from another task or signal handler to cancel + +manager + .download_and_register_eps_builder() + .cancel(Arc::clone(&cancel_flag)) + .run() + .await?; +model + .download_builder() + .cancel(Arc::clone(&cancel_flag)) + .run() + .await?; +``` + +Catalog access does not block on EP downloads. Call `download_and_register_eps` when you need hardware-accelerated execution providers. + +## Quick Start + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalManager, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // 1. Initialize the manager — loads native libraries and starts the engine + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("my_app"))?; + + // 2. Get a model from the catalog and load it + let model = manager.catalog().get_model("phi-3.5-mini").await?; + model.load().await?; + + // 3. Create a chat client and run inference + let client = model.create_chat_client() + .temperature(0.7) + .max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is the capital of France?").into(), + ]; + + let response = client.complete_chat(&messages, None).await?; + println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); + + // 4. Clean up + model.unload().await?; + + Ok(()) +} +``` + +## Usage + +### Browsing the Model Catalog + +The `Catalog` lets you discover what models are available, which are already cached locally, and which are currently loaded in memory. + +```rust +let catalog = manager.catalog(); + +// List all available models +let models = catalog.get_models().await?; +for model in &models { + println!("{} (id: {})", model.alias(), model.id()); +} + +// Look up a specific model by alias +let model = catalog.get_model("phi-3.5-mini").await?; + +// Look up a specific variant by its unique model ID +let variant = catalog.get_model_variant("phi-3.5-mini-generic-gpu-4").await?; + +// See what's already downloaded +let cached = catalog.get_cached_models().await?; + +// See what's currently loaded in memory +let loaded = catalog.get_loaded_models().await?; +``` + +### Model Lifecycle + +Each model may have multiple variants (different quantizations, hardware targets). The SDK auto-selects the best available variant, preferring cached versions. All models are represented by the `Model` type. + +```rust +let model = catalog.get_model("phi-3.5-mini").await?; + +// Inspect available variants +println!("Selected: {}", model.id()); +for v in model.variants() { + println!(" {} (info.cached: {})", v.id(), v.info().cached); +} +``` + +Download, load, and unload: + +```rust +// Download with progress reporting +model.download(Some(|progress: f64| { + print!("\r{progress:.1}%"); + std::io::Write::flush(&mut std::io::stdout()).ok(); +})).await?; + +// Or use the builder when combining progress, cancellation, or future options +let cancel_flag = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); +model.download_builder() + .progress(|progress| { + print!("\r{progress:.1}%"); + std::io::Write::flush(&mut std::io::stdout()).ok(); + }) + .cancel(cancel_flag.clone()) + .run() + .await?; + +// Load into memory +model.load().await?; + +// Unload when done +model.unload().await?; + +// Remove from local cache entirely +model.remove_from_cache().await?; +``` + +### Chat Completions + +The `ChatClient` follows the OpenAI Chat Completion API structure. + +```rust +let client = model.create_chat_client() + +// Configure generation settings (fluent builder) + .temperature(0.7) + .max_tokens(256) + .top_p(0.9) + .frequency_penalty(0.5); + +// Non-streaming completion +let response = client.complete_chat( + &[ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain Rust's ownership model.").into(), + ], + None, +).await?; + +println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); +``` + +### Streaming Responses + +For real-time token-by-token output, use streaming: + +```rust +use tokio_stream::StreamExt; + +let mut stream = client.complete_streaming_chat( + &[ChatCompletionRequestUserMessage::from("Write a short poem about Rust.").into()], + None, +).await?; + +while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(content) = &chunk.choices[0].delta.content { + print!("{content}"); + } +} + +// Errors from the native core are delivered as stream items — +// no separate close() call needed. +``` + +### Tool Calling + +Define functions the model can call and handle the multi-turn conversation: + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, + ChatCompletionTools, ChatToolChoice, FinishReason, +}; +use serde_json::json; + +// Define available tools +let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { "type": "string", "description": "City name" } + }, + "required": ["location"] + } + } +}]))?; + +let client = model.create_chat_client() + .max_tokens(512) + .tool_choice(ChatToolChoice::Auto); + +let mut messages: Vec = vec![ + ChatCompletionRequestUserMessage::from("What's the weather in Seattle?").into(), +]; + +// First request — model may call a tool +let response = client.complete_chat(&messages, Some(&tools)).await?; +let choice = &response.choices[0]; + +if choice.finish_reason == Some(FinishReason::ToolCalls) { + if let Some(tool_calls) = &choice.message.tool_calls { + for tc in tool_calls { + // Execute the tool (your application logic) + let result = execute_tool(&tc.function.name, &tc.function.arguments); + + // Add assistant message with tool calls, then the tool result + messages.push(serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ "id": tc.id, "type": "function", + "function": { "name": tc.function.name, + "arguments": tc.function.arguments } }] + }))?); + messages.push(ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc.id.clone(), + }.into()); + } + + // Continue the conversation with tool results + let final_response = client.complete_chat(&messages, Some(&tools)).await?; + println!("{}", final_response.choices[0].message.content.as_deref().unwrap_or("")); + } +} +``` + +Tool calling also works with streaming via `complete_streaming_chat` — accumulate tool call fragments during streaming and check for `FinishReason::ToolCalls`. + +### Response Format Options + +Control the output format of chat completions: + +```rust +use foundry_local_sdk::ChatResponseFormat; + +// Plain text (default) +let client = model.create_chat_client() + .response_format(ChatResponseFormat::Text); + +// Unstructured JSON output +let client = model.create_chat_client() + .response_format(ChatResponseFormat::JsonObject); + +// JSON constrained to a schema +let client = model.create_chat_client() + .response_format(ChatResponseFormat::JsonSchema(r#"{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name", "age"] + }"#.to_string())); + +// Output constrained by a Lark grammar (Foundry extension) +let client = model.create_chat_client() + .response_format(ChatResponseFormat::LarkGrammar(grammar.to_string())); +``` + +### Embeddings + +Generate text embeddings using the `EmbeddingClient`: + +```rust +let embedding_client = model.create_embedding_client(); + +// Single input +let response = embedding_client + .generate_embedding("The quick brown fox jumps over the lazy dog") + .await?; +let embedding = &response.data[0].embedding; // Vec +println!("Dimensions: {}", embedding.len()); + +// Batch input +let batch_response = embedding_client + .generate_embeddings(&["The quick brown fox", "The capital of France is Paris"]) + .await?; +// batch_response.data[0].embedding, batch_response.data[1].embedding +``` + +### Audio Transcription + +Transcribe audio files locally using the `AudioClient`: + +```rust +let model = manager.catalog().get_model("whisper-tiny").await?; +model.load().await?; + +let audio_client = model.create_audio_client() + .language("en"); + +// Non-streaming transcription +let result = audio_client.transcribe("recording.wav").await?; +println!("{}", result.text); +``` + +#### Streaming Transcription + +```rust +use tokio_stream::StreamExt; + +let mut stream = audio_client.transcribe_streaming("recording.wav").await?; +while let Some(chunk) = stream.next().await { + print!("{}", chunk?.text); +} +``` + +### Embedded Web Service + +Start a local HTTP server that exposes an OpenAI-compatible REST API: + +```rust +manager.start_web_service().await?; +let urls = manager.urls()?; +println!("Service running at: {:?}", urls); + +// Any OpenAI-compatible client or tool can now connect to the endpoint. +// ... + +manager.stop_web_service().await?; +``` + +### Chat Client Settings + +All settings are configured via chainable builder methods on `ChatClient`: + +| Method | Type | Description | +|--------|------|-------------| +| `temperature(v)` | `f64` | Sampling temperature (0.0–2.0; higher = more random) | +| `max_tokens(v)` | `u32` | Maximum number of tokens to generate | +| `top_p(v)` | `f64` | Nucleus sampling probability (0.0–1.0) | +| `top_k(v)` | `u32` | Top-k sampling parameter (Foundry extension) | +| `frequency_penalty(v)` | `f64` | Frequency penalty | +| `presence_penalty(v)` | `f64` | Presence penalty | +| `n(v)` | `u32` | Number of completions to generate | +| `random_seed(v)` | `u64` | Random seed for reproducible results (Foundry extension) | +| `response_format(v)` | `ChatResponseFormat` | Output format (Text, JsonObject, JsonSchema, LarkGrammar) | +| `tool_choice(v)` | `ChatToolChoice` | Tool selection strategy (None, Auto, Required, Function) | + +## Error Handling + +All fallible operations return `foundry_local_sdk::Result`, which is an alias for `std::result::Result`. + +```rust +use foundry_local_sdk::FoundryLocalError; + +match manager.catalog().get_model("nonexistent").await { + Ok(model) => { /* use model */ } + Err(FoundryLocalError::ModelOperation { reason }) => { + eprintln!("Model error: {reason}"); + } + Err(FoundryLocalError::CommandExecution { reason }) => { + eprintln!("Core engine error: {reason}"); + } + Err(e) => { + eprintln!("Unexpected error: {e}"); + } +} +``` + +### Error Variants + +| Variant | Description | +|---------|-------------| +| `LibraryLoad { reason }` | The native core library could not be loaded | +| `CommandExecution { reason }` | A command executed against native core returned an error | +| `InvalidConfiguration { reason }` | The provided configuration is invalid | +| `ModelOperation { reason }` | A model operation failed (load, unload, download, etc.) | +| `HttpRequest(reqwest::Error)` | An HTTP request to an external service failed | +| `Serialization(serde_json::Error)` | JSON serialization/deserialization failed | +| `Validation { reason }` | A validation check on user-supplied input failed | +| `Io(std::io::Error)` | An I/O error occurred | +| `Internal { reason }` | An internal SDK error (e.g. poisoned lock) | + +## Configuration + +The SDK is configured via `FoundryLocalConfig` when creating the manager: + +```rust +use foundry_local_sdk::{FoundryLocalConfig, LogLevel}; + +let config = FoundryLocalConfig::new("my_app") + .log_level(LogLevel::Info) + .model_cache_dir("/path/to/cache") + .web_service_urls("http://127.0.0.1:5000"); + +let manager = FoundryLocalManager::create(config)?; +``` + +| Setting | Builder method | Default | Description | +|---------|---------------|---------|-------------| +| App name | `new(name)` | **(required)** | Your application name | +| App data dir | `.app_data_dir(dir)` | `~/.{app_name}` | Application data directory | +| Model cache dir | `.model_cache_dir(dir)` | `{app_data_dir}/cache/models` | Where models are stored locally | +| Logs dir | `.logs_dir(dir)` | `{app_data_dir}/logs` | Log output directory | +| Log level | `.log_level(level)` | `Warn` | `Trace`, `Debug`, `Info`, `Warn`, `Error`, `Fatal` | +| Web service URLs | `.web_service_urls(urls)` | `None` | Bind address for the embedded web service | +| Service endpoint | `.service_endpoint(url)` | `None` | URL of an existing external service to connect to | +| Library path | `.library_path(path)` | Auto-discovered | Path to the native `foundry_local` library (or its directory) | +| Additional settings | `.additional_setting(k, v)` | `None` | Extra key-value settings passed to Core | +| Logger | `.logger(impl Logger)` | `None` | Application logger (stub — not yet wired) | + +## How It Works + +### Native binary + +The SDK loads the `foundry_local` native library (and its ONNX Runtime / GenAI dependencies) at +runtime. The `build.rs` build script can obtain it in two ways, controlled by environment variables: + +| Variable | Purpose | +|----------|---------| +| `FOUNDRY_LOCAL_NATIVE_BIN_DIR` | Copy native binaries from a local C++ build output directory (the dev path). Mirrors the C# `FoundryLocalNativeBinDir`. | +| `FOUNDRY_LOCAL_RUNTIME_VERSION` | Download the Runtime NuGet package (`Microsoft.AI.Foundry.Local.Runtime`, or `.Runtime.WinML` with the `winml` feature) plus ONNX Runtime / GenAI for the target RID. | + +If neither is set at build time, the library is resolved at **runtime** from (in order): + +1. `FoundryLocalConfig::library_path` — a path to the `foundry_local` library file or its directory. +2. The `FOUNDRY_LOCAL_LIB_DIR` environment variable. +3. The directory of the running executable. +4. The system loader search path. + +> **Migration note:** in v1, `FoundryLocalConfig::library_path` pointed at the +> `Microsoft.AI.Foundry.Local.Core` library. In v2 it points at the `foundry_local` library (or its +> containing directory). The builder method and signature are unchanged. + +ONNX Runtime and GenAI are pre-loaded before `foundry_local` so the dynamic loader resolves the +engine's dependencies regardless of rpath/search-path setup. On platforms where GenAI honours it, +`ORT_LIB_PATH` is set to the pre-loaded ONNX Runtime. + +### Runtime Loading + +At runtime, the SDK uses `libloading` to dynamically load the `foundry_local` library, resolve the +API function table via `FoundryLocalGetApi`, and cache the sub-API tables. No static linking or +system-wide installation is required. + +## Platform Support + +| Platform | RID | Status | +|-----------------|--------------|--------| +| Windows x64 | `win-x64` | ✅ | +| Windows ARM64 | `win-arm64` | ✅ | +| Linux x64 | `linux-x64` | ✅ | +| Linux ARM64 | `linux-arm64`| ✅ | +| macOS ARM64 | `osx-arm64` | ✅ | + +## Running Examples + +Sample applications are available in [`samples/rust/`](../../samples/rust/): + +| Sample | Description | +|--------|-------------| +| `native-chat-completions` | Non-streaming and streaming chat completions | +| `tool-calling-foundry-local` | Function/tool calling with multi-turn conversations | +| `audio-transcription-example` | Audio transcription (non-streaming and streaming) | +| `foundry-local-webserver` | Embedded OpenAI-compatible REST API server | + +Run a sample with: + +```sh +cd samples/rust +cargo run -p native-chat-completions +``` + +## License + +Microsoft Software License Terms — see [LICENSE](../../LICENSE) for details. diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs new file mode 100644 index 000000000..5b3100eae --- /dev/null +++ b/sdk_v2/rust/build.rs @@ -0,0 +1,306 @@ +//! Build script for the Foundry Local v2 Rust SDK. +//! +//! Obtains the native `foundry_local` library (plus ONNX Runtime + GenAI) and +//! makes it discoverable at runtime via the `FOUNDRY_NATIVE_DIR` compile-time +//! env that `detail::api` consults. +//! +//! Native acquisition order: +//! 1. `FOUNDRY_LOCAL_NATIVE_BIN_DIR` — copy native files from a local C++ build +//! (the dev path, mirroring the C# `FoundryLocalNativeBinDir`). +//! 2. `FOUNDRY_LOCAL_RUNTIME_VERSION` — download the Runtime NuGet package +//! (`Microsoft.AI.Foundry.Local.Runtime[.WinML]`) plus ORT/GenAI for the RID. +//! 3. Otherwise no-op: the library is resolved at runtime from +//! `FOUNDRY_LOCAL_LIB_DIR`, next to the executable, or the system path. + +use std::env; +use std::fs; +use std::io::{self, Read}; +use std::path::{Path, PathBuf}; + +const FEEDS: &[&str] = &[ + "https://api.nuget.org/v3/index.json", + "https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json", +]; + +struct DepsVersions { + ort: String, + genai: String, +} + +fn load_deps_versions(manifest_dir: &Path) -> DepsVersions { + let candidates = [ + manifest_dir.join("deps_versions.json"), + manifest_dir.join("..").join("deps_versions.json"), + ]; + let json_path = candidates + .iter() + .find(|p| p.exists()) + .cloned() + .unwrap_or_else(|| candidates[0].clone()); + println!("cargo:rerun-if-changed={}", json_path.display()); + + let content = fs::read_to_string(&json_path).unwrap_or_default(); + let stripped = content.strip_prefix('\u{FEFF}').unwrap_or(&content); + let val: serde_json::Value = serde_json::from_str(stripped).unwrap_or(serde_json::Value::Null); + let s = |key: &str| -> String { + val.get(key) + .and_then(|o| o.get("version")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string() + }; + DepsVersions { + ort: s("onnxruntime"), + genai: s("onnxruntime-genai"), + } +} + +fn get_rid() -> Option<&'static str> { + match (env::consts::OS, env::consts::ARCH) { + ("windows", "x86_64") => Some("win-x64"), + ("windows", "aarch64") => Some("win-arm64"), + ("linux", "x86_64") => Some("linux-x64"), + ("linux", "aarch64") => Some("linux-arm64"), + ("macos", "aarch64") => Some("osx-arm64"), + ("macos", "x86_64") => Some("osx-x64"), + _ => None, + } +} + +fn native_lib_extension() -> &'static str { + match env::consts::OS { + "windows" => "dll", + "macos" => "dylib", + _ => "so", + } +} + +struct NuGetPackage { + name: String, + version: String, + expected_file: String, +} + +fn get_packages(deps: &DepsVersions, runtime_version: &str) -> Vec { + let ext = native_lib_extension(); + let prefix = if env::consts::OS == "windows" { + "" + } else { + "lib" + }; + let runtime_name = if env::var("CARGO_FEATURE_WINML").is_ok() { + "Microsoft.AI.Foundry.Local.Runtime.WinML" + } else { + "Microsoft.AI.Foundry.Local.Runtime" + }; + + vec![ + NuGetPackage { + name: runtime_name.to_string(), + version: runtime_version.to_string(), + expected_file: format!("{prefix}foundry_local.{ext}"), + }, + NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Foundry".to_string(), + version: deps.ort.clone(), + expected_file: format!("{prefix}onnxruntime.{ext}"), + }, + NuGetPackage { + name: "Microsoft.ML.OnnxRuntimeGenAI.Foundry".to_string(), + version: deps.genai.clone(), + expected_file: format!("{prefix}onnxruntime-genai.{ext}"), + }, + ] +} + +fn resolve_base_address(feed_url: &str) -> Result { + let body: String = ureq::get(feed_url) + .call() + .map_err(|e| format!("fetch feed index {feed_url}: {e}"))? + .body_mut() + .read_to_string() + .map_err(|e| format!("read feed index: {e}"))?; + let index: serde_json::Value = + serde_json::from_str(&body).map_err(|e| format!("parse feed index: {e}"))?; + for resource in index["resources"].as_array().ok_or("missing resources")? { + if resource["@type"].as_str() == Some("PackageBaseAddress/3.0.0") { + if let Some(id) = resource["@id"].as_str() { + return Ok(if id.ends_with('/') { + id.to_string() + } else { + format!("{id}/") + }); + } + } + } + Err(format!("no PackageBaseAddress in {feed_url}")) +} + +fn try_download( + pkg: &NuGetPackage, + rid: &str, + out_dir: &Path, + feed_url: &str, +) -> Result { + let base = resolve_base_address(feed_url)?; + let name = pkg.name.to_lowercase(); + let version = pkg.version.to_lowercase(); + let url = format!("{base}{name}/{version}/{name}.{version}.nupkg"); + println!("cargo:warning=Downloading {} {}", pkg.name, pkg.version); + + let mut response = ureq::get(&url) + .call() + .map_err(|e| format!("download {}: {e}", pkg.name))?; + let mut bytes = Vec::new(); + response + .body_mut() + .as_reader() + .read_to_end(&mut bytes) + .map_err(|e| format!("read body {}: {e}", pkg.name))?; + + let ext = native_lib_extension(); + let native_prefix = format!("runtimes/{rid}/native/"); + let runtime_prefix = format!("runtimes/{rid}/"); + let mut archive = zip::ZipArchive::new(io::Cursor::new(&bytes)) + .map_err(|e| format!("open nupkg {}: {e}", pkg.name))?; + + let mut extracted = 0usize; + for i in 0..archive.len() { + let mut file = archive.by_index(i).map_err(|e| format!("zip entry: {e}"))?; + let entry = file.name().to_string(); + if !entry.ends_with(&format!(".{ext}")) { + continue; + } + let direct = entry + .strip_prefix(&runtime_prefix) + .map(|r| !r.is_empty() && !r.contains('/')) + .unwrap_or(false); + if !entry.starts_with(&native_prefix) && !direct { + continue; + } + let file_name = match Path::new(&entry).file_name() { + Some(n) => n.to_string_lossy().to_string(), + None => continue, + }; + let dest = out_dir.join(&file_name); + let mut out = + fs::File::create(&dest).map_err(|e| format!("create {}: {e}", dest.display()))?; + io::copy(&mut file, &mut out).map_err(|e| format!("write {}: {e}", dest.display()))?; + println!("cargo:warning= Extracted {file_name}"); + extracted += 1; + } + Ok(extracted) +} + +fn download_and_extract(pkg: &NuGetPackage, rid: &str, out_dir: &Path) -> Result<(), String> { + if out_dir.join(&pkg.expected_file).exists() { + return Ok(()); + } + if pkg.version.trim().is_empty() { + return Err(format!("no version configured for {}", pkg.name)); + } + let mut last = String::new(); + for feed in FEEDS { + match try_download(pkg, rid, out_dir, feed) { + Ok(_) => return Ok(()), + Err(e) => last = e, + } + } + Err(format!( + "download {} {} failed: {last}", + pkg.name, pkg.version + )) +} + +fn copy_from_local_dir(src: &Path, out_dir: &Path) -> bool { + let ext = native_lib_extension(); + let Ok(entries) = fs::read_dir(src) else { + return false; + }; + let mut copied = false; + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().and_then(|e| e.to_str()) == Some(ext) { + if let Some(name) = path.file_name() { + if fs::copy(&path, out_dir.join(name)).is_ok() { + println!( + "cargo:warning=Copied {} from FOUNDRY_LOCAL_NATIVE_BIN_DIR", + name.to_string_lossy() + ); + copied = true; + } + } + } + } + copied +} + +fn emit_native_dir(out_dir: &Path) { + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + #[cfg(windows)] + println!("cargo:rustc-link-lib=kernel32"); +} + +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + println!("cargo:rerun-if-env-changed=FOUNDRY_LOCAL_NATIVE_BIN_DIR"); + println!("cargo:rerun-if-env-changed=FOUNDRY_LOCAL_RUNTIME_VERSION"); + println!("cargo:rerun-if-env-changed=CARGO_FEATURE_WINML"); + + let manifest_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap_or_default()); + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set")); + + // 1. Local C++ build output (dev path). + if let Ok(local) = env::var("FOUNDRY_LOCAL_NATIVE_BIN_DIR") { + let src = Path::new(&local); + if src.is_dir() && copy_from_local_dir(src, &out_dir) { + emit_native_dir(&out_dir); + return; + } + } + + // 2. Runtime NuGet download (release path), only when a version is pinned. + let runtime_version = env::var("FOUNDRY_LOCAL_RUNTIME_VERSION").unwrap_or_default(); + if !runtime_version.trim().is_empty() { + let rid = match get_rid() { + Some(r) => r, + None => { + println!( + "cargo:warning=Unsupported platform {} {}; skipping native download.", + env::consts::OS, + env::consts::ARCH + ); + return; + } + }; + let deps = load_deps_versions(&manifest_dir); + let packages = get_packages(&deps, &runtime_version); + let mut failed = false; + for pkg in &packages { + if let Err(e) = download_and_extract(pkg, rid, &out_dir) { + println!("cargo:warning={e}"); + failed = true; + } + } + if !failed { + emit_native_dir(&out_dir); + } + return; + } + + // 3. No build-time native configured. Runtime discovery handles loading: + // FOUNDRY_LOCAL_LIB_DIR, the executable's directory, or the system loader + // path (see detail::api::resolve_library_path). Stay quiet when the runtime + // override is set — the `FOUNDRY_LOCAL_LIB_DIR=... cargo run` workflow is a + // fully supported path and shouldn't trigger a build warning. Only hint when + // nothing is configured at build *or* run time. + println!("cargo:rerun-if-env-changed=FOUNDRY_LOCAL_LIB_DIR"); + if env::var_os("FOUNDRY_LOCAL_LIB_DIR").is_none() { + println!( + "cargo:warning=foundry-local-sdk: no native library configured. Provide it at build time \ + via FOUNDRY_LOCAL_NATIVE_BIN_DIR (local C++ build) or FOUNDRY_LOCAL_RUNTIME_VERSION (NuGet), \ + or at runtime via FOUNDRY_LOCAL_LIB_DIR / by placing foundry_local on the loader search path." + ); + } +} diff --git a/sdk_v2/rust/docs/api.md b/sdk_v2/rust/docs/api.md new file mode 100644 index 000000000..8dcb0c292 --- /dev/null +++ b/sdk_v2/rust/docs/api.md @@ -0,0 +1,552 @@ +# Foundry Local Rust SDK — Public API Reference + +> Auto-generated from `sdk/rust/src` source files. + +## Table of Contents + +- [Entry Point](#entry-point) + - [FoundryLocalManager](#foundrylocalmanager) + - [FoundryLocalConfig](#foundrylocalconfig) + - [Logger](#logger) + - [LogLevel](#loglevel) +- [Model Catalog](#model-catalog) + - [Catalog](#catalog) + - [Model](#model) +- [OpenAI Clients](#openai-clients) + - [ChatClient](#chatclient) + - [ChatCompletionStream](#chatcompletionstream) + - [EmbeddingClient](#embeddingclient) + - [EmbeddingResponse](#embeddingresponse) + - [AudioClient](#audioclient) + - [AudioTranscriptionStream](#audiotranscriptionstream) + - [AudioTranscriptionResponse](#audiotranscriptionresponse) + - [TranscriptionSegment](#transcriptionsegment) + - [TranscriptionWord](#transcriptionword) + - [JsonStream\](#jsonstreamt) +- [Types](#types) + - [ModelInfo](#modelinfo) + - [ChatResponseFormat](#chatresponseformat) + - [ChatToolChoice](#chattoolchoice) + - [DeviceType](#devicetype) + - [PromptTemplate](#prompttemplate) + - [Runtime](#runtime) + - [ModelSettings](#modelsettings) + - [Parameter](#parameter) +- [Error Handling](#error-handling) + - [FoundryLocalError](#foundrylocalerror) +- [Re-exported OpenAI Types](#re-exported-openai-types) + +--- + +## Entry Point + +### FoundryLocalManager + +Primary entry point for interacting with Foundry Local. Singleton — created once via `create()`. + +```rust +pub struct FoundryLocalManager { /* private fields */ } +``` + +| Method | Signature | Description | +|--------|-----------|-------------| +| `create` | `fn create(config: FoundryLocalConfig) -> Result<&'static Self, FoundryLocalError>` | Initialise the SDK. First call creates the singleton; subsequent calls return the existing instance (config is ignored after first call). | +| `catalog` | `fn catalog(&self) -> &Catalog` | Access the model catalog. | +| `urls` | `fn urls(&self) -> Result, FoundryLocalError>` | URLs the local web service is listening on. Empty until `start_web_service` is called. | +| `start_web_service` | `async fn start_web_service(&self) -> Result<(), FoundryLocalError>` | Start the local web service. Retrieve listening URLs via `urls()`. | +| `stop_web_service` | `async fn stop_web_service(&self) -> Result<(), FoundryLocalError>` | Stop the local web service. | + +--- + +### FoundryLocalConfig + +User-facing configuration for initializing the SDK. Fields are private; use +the builder methods to customise. + +```rust +pub struct FoundryLocalConfig { /* private fields */ } +``` + +| Method | Signature | Description | +|--------|-----------|-------------| +| `new` | `fn new(app_name: impl Into) -> Self` | Create a new configuration. All optional fields default to `None`. | +| `app_data_dir` | `fn app_data_dir(self, dir: impl Into) -> Self` | Override the application-data directory. | +| `model_cache_dir` | `fn model_cache_dir(self, dir: impl Into) -> Self` | Override the model-cache directory. | +| `logs_dir` | `fn logs_dir(self, dir: impl Into) -> Self` | Override the logs directory. | +| `log_level` | `fn log_level(self, level: LogLevel) -> Self` | Set the log level. | +| `web_service_urls` | `fn web_service_urls(self, urls: impl Into) -> Self` | Set the web-service listen URLs. | +| `service_endpoint` | `fn service_endpoint(self, endpoint: impl Into) -> Self` | Set an external service endpoint URL. | +| `library_path` | `fn library_path(self, path: impl Into) -> Self` | Override the path to the native core library. | +| `additional_setting` | `fn additional_setting(self, key: impl Into, value: impl Into) -> Self` | Add a key-value pair to additional settings. | +| `logger` | `fn logger(self, logger: impl Logger + 'static) -> Self` | Provide an application logger (stub — not yet wired into native core). | + +**Example:** +```rust +let config = FoundryLocalConfig::new("my_app") + .log_level(LogLevel::Debug) + .model_cache_dir("/path/to/cache"); +``` + +--- + +### LogLevel + +```rust +pub enum LogLevel { + Trace, + Debug, + Info, + Warn, + Error, + Fatal, +} +``` + +--- + +### Logger + +Application logger trait. Implement this to receive SDK log messages. + +> **Note:** Stub — not yet wired into the native core. Stored in configuration for future use. + +```rust +pub trait Logger: Send + Sync { + fn log(&self, level: LogLevel, message: &str); +} +``` + +--- + +### Catalog + +Discovers, caches, and looks up available models. + +```rust +pub struct Catalog { /* private fields */ } +``` + +| Method | Signature | Description | +|--------|-----------|-------------| +| `name` | `fn name(&self) -> &str` | Catalog name as reported by the native core. | +| `update_models` | `async fn update_models(&self) -> Result<(), FoundryLocalError>` | Refresh catalog if cache expired or invalidated. | +| `get_models` | `async fn get_models(&self) -> Result>, FoundryLocalError>` | Return all known models. | +| `get_model` | `async fn get_model(&self, alias: &str) -> Result, FoundryLocalError>` | Look up a model by alias. | +| `get_model_variant` | `async fn get_model_variant(&self, id: &str) -> Result, FoundryLocalError>` | Look up a variant by unique id. | +| `get_cached_models` | `async fn get_cached_models(&self) -> Result>, FoundryLocalError>` | Return only variants cached on disk. | +| `get_loaded_models` | `async fn get_loaded_models(&self) -> Result>, FoundryLocalError>` | Return model variants currently loaded in memory. | + +--- + +### Model + +Groups one or more variants sharing the same alias. By default, the cached variant is selected. + +```rust +pub struct Model { /* private fields */ } +``` + +| Method | Signature | Description | +|--------|-----------|-------------| +| `alias` | `fn alias(&self) -> &str` | Alias shared by all variants. | +| `id` | `fn id(&self) -> &str` | Unique identifier of the selected variant. | +| `variants` | `fn variants(&self) -> Vec>` | All variants in this model. | +| `select_variant` | `fn select_variant(&self, variant: &Model) -> Result<(), FoundryLocalError>` | Select a variant from `variants()`. | +| `select_variant_by_id` | `fn select_variant_by_id(&self, id: &str) -> Result<(), FoundryLocalError>` | Select a variant by its unique id string. | +| `is_cached` | `async fn is_cached(&self) -> Result` | Whether the selected variant is cached on disk. | +| `is_loaded` | `async fn is_loaded(&self) -> Result` | Whether the selected variant is loaded in memory. | +| `download` | `async fn download(&self, progress: Option) -> Result<(), FoundryLocalError>` | Download the selected variant. `F: FnMut(f64) + Send + 'static` — receives progress as a percentage (0.0–100.0). | +| `path` | `async fn path(&self) -> Result` | Local file-system path of the selected variant. | +| `load` | `async fn load(&self) -> Result<(), FoundryLocalError>` | Load the selected variant into memory. | +| `unload` | `async fn unload(&self) -> Result` | Unload the selected variant from memory. | +| `remove_from_cache` | `async fn remove_from_cache(&self) -> Result` | Remove the selected variant from the local cache. | +| `create_chat_client` | `fn create_chat_client(&self) -> ChatClient` | Create a ChatClient bound to the selected variant. | +| `create_audio_client` | `fn create_audio_client(&self) -> AudioClient` | Create an AudioClient bound to the selected variant. | + +--- + +## OpenAI Clients + +### ChatClient + +OpenAI-compatible chat completions backed by a local model. Uses a consuming builder pattern. + +```rust +pub struct ChatClient { /* private fields */ } +``` + +**Builder methods** (all `mut self -> Self`): + +| Method | Signature | Description | +|--------|-----------|-------------| +| `frequency_penalty` | `fn frequency_penalty(mut self, v: f64) -> Self` | Set the frequency penalty. | +| `max_tokens` | `fn max_tokens(mut self, v: u32) -> Self` | Maximum tokens to generate. | +| `n` | `fn n(mut self, v: u32) -> Self` | Number of completions. | +| `temperature` | `fn temperature(mut self, v: f64) -> Self` | Sampling temperature. | +| `presence_penalty` | `fn presence_penalty(mut self, v: f64) -> Self` | Presence penalty. | +| `top_p` | `fn top_p(mut self, v: f64) -> Self` | Nucleus sampling probability. | +| `top_k` | `fn top_k(mut self, v: u32) -> Self` | Top-k sampling *(Foundry extension)*. | +| `random_seed` | `fn random_seed(mut self, v: u64) -> Self` | Random seed for reproducibility *(Foundry extension)*. | +| `response_format` | `fn response_format(mut self, v: ChatResponseFormat) -> Self` | Desired response format. | +| `tool_choice` | `fn tool_choice(mut self, v: ChatToolChoice) -> Self` | Tool choice strategy. | + +**Completion methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `complete_chat` | `async fn complete_chat(&self, messages: &[ChatCompletionRequestMessage], tools: Option<&[ChatCompletionTools]>) -> Result` | Non-streaming chat completion. | +| `complete_streaming_chat` | `async fn complete_streaming_chat(&self, messages: &[ChatCompletionRequestMessage], tools: Option<&[ChatCompletionTools]>) -> Result` | Streaming chat completion. | + +**Example:** +```rust +let client = model.create_chat_client() + .temperature(0.7) + .max_tokens(256); +``` + +--- + +### ChatCompletionStream + +```rust +pub type ChatCompletionStream = JsonStream; +``` + +A stream of `CreateChatCompletionStreamResponse` chunks. Use with `StreamExt::next()`. + +--- + +### EmbeddingClient + +OpenAI-compatible embedding generation backed by a local model. + +| Method | Description | +|---|---| +| `new(model_id, core)` | *(internal)* Create a new client | +| `generate_embedding(input: &str) -> Result` | Generate embedding for a single input | +| `generate_embeddings(inputs: &[&str]) -> Result` | Generate embeddings for multiple inputs | + +Returns `async_openai::types::embeddings::CreateEmbeddingResponse`: + +| Field | Type | Description | +|---|---|---| +| `model` | `String` | Model used for generation | +| `object` | `String` | Object type (always `"list"`) | +| `data` | `Vec` | List of embedding results | +| `usage` | `Usage` | Token usage information | + +Each `Embedding` in `data`: + +| Field | Type | Description | +|---|---|---| +| `index` | `u32` | Index of this embedding in the batch | +| `embedding` | `Vec` | The embedding vector (float32) | + +--- + +### AudioClient + +OpenAI-compatible audio transcription backed by a local model. + +```rust +pub struct AudioClient { /* private fields */ } +``` + +**Builder methods** (all `mut self -> Self`): + +| Method | Signature | Description | +|--------|-----------|-------------| +| `language` | `fn language(mut self, lang: impl Into) -> Self` | Language hint for transcription. | +| `temperature` | `fn temperature(mut self, v: f64) -> Self` | Sampling temperature. | + +**Transcription methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `transcribe` | `async fn transcribe(&self, audio_file_path: impl AsRef) -> Result` | Transcribe an audio file. | +| `transcribe_streaming` | `async fn transcribe_streaming(&self, audio_file_path: impl AsRef) -> Result` | Streaming transcription. | + +**Example:** +```rust +let client = model.create_audio_client() + .language("en") + .temperature(0.2); +``` + +--- + +### AudioTranscriptionStream + +```rust +pub type AudioTranscriptionStream = JsonStream; +``` + +A stream of `AudioTranscriptionResponse` chunks. Use with `StreamExt::next()`. + +--- + +### AudioTranscriptionResponse + +```rust +pub struct AudioTranscriptionResponse { + pub text: String, // The transcribed text + pub language: Option, // Language of input audio (if detected) + pub duration: Option, // Duration in seconds (if available) + pub segments: Option>, // Transcription segments (if available) + pub words: Option>, // Words with timestamps (if available) +} +``` + +Derives: `Debug`, `Clone`, `Deserialize`, `Serialize` + +--- + +### TranscriptionSegment + +A segment of a transcription, as returned by the OpenAI-compatible API. + +```rust +pub struct TranscriptionSegment { + pub id: i32, + pub seek: i32, + pub start: f64, + pub end: f64, + pub text: String, + pub tokens: Option>, + pub temperature: Option, + pub avg_logprob: Option, + pub compression_ratio: Option, + pub no_speech_prob: Option, +} +``` + +Derives: `Debug`, `Clone`, `Deserialize`, `Serialize` + +--- + +### TranscriptionWord + +A word with timing information, as returned by the OpenAI-compatible API. + +```rust +pub struct TranscriptionWord { + pub word: String, + pub start: f64, + pub end: f64, +} +``` + +Derives: `Debug`, `Clone`, `Deserialize`, `Serialize` + +--- + +### JsonStream\ + +Generic stream that deserializes each received JSON string chunk into `T`. Empty chunks are silently skipped. + +```rust +pub struct JsonStream { /* private fields */ } + +impl Unpin for JsonStream {} +impl Stream for JsonStream { + type Item = Result; +} +``` + +--- + +## Types + +### ModelInfo + +Full metadata for a model variant as returned by the catalog. + +```rust +pub struct ModelInfo { + pub id: String, + pub name: String, + pub version: u64, + pub alias: String, + pub display_name: Option, + pub provider_type: String, + pub uri: String, + pub model_type: String, + pub prompt_template: Option, + pub publisher: Option, + pub model_settings: Option, + pub license: Option, + pub license_description: Option, + pub cached: bool, + pub task: Option, + pub runtime: Option, + pub file_size_mb: Option, + pub supports_tool_calling: Option, + pub max_output_tokens: Option, + pub min_fl_version: Option, + pub created_at_unix: u64, +} +``` + +Derives: `Debug`, `Clone`, `Deserialize` + +--- + +### ChatResponseFormat + +```rust +pub enum ChatResponseFormat { + Text, // Plain text output (default) + JsonObject, // JSON output (unstructured) + JsonSchema(String), // JSON constrained by schema string + LarkGrammar(String), // Lark grammar constraint (Foundry extension) +} +``` + +--- + +### ChatToolChoice + +```rust +pub enum ChatToolChoice { + None, // Model will not call any tool + Auto, // Model decides whether to call a tool + Required, // Model must call at least one tool + Function(String), // Model must call the named function +} +``` + +--- + +### DeviceType + +```rust +pub enum DeviceType { + Invalid, + CPU, + GPU, + NPU, +} +``` + +--- + +### PromptTemplate + +```rust +pub struct PromptTemplate { + pub system: Option, + pub user: Option, + pub assistant: Option, + pub prompt: Option, +} +``` + +--- + +### Runtime + +```rust +pub struct Runtime { + pub device_type: DeviceType, + pub execution_provider: String, +} +``` + +--- + +### ModelSettings + +```rust +pub struct ModelSettings { + pub parameters: Option>, +} +``` + +--- + +### Parameter + +```rust +pub struct Parameter { + pub name: String, + pub value: Option, +} +``` + +--- + +## Error Handling + +### FoundryLocalError + +```rust +pub enum FoundryLocalError { + /// The native core library could not be loaded. + LibraryLoad { reason: String }, + + /// A command executed against the native core returned an error. + CommandExecution { reason: String }, + + /// The provided configuration is invalid. + InvalidConfiguration { reason: String }, + + /// A model operation failed (load, unload, download, etc.). + ModelOperation { reason: String }, + + /// An HTTP request to the external service failed. + HttpRequest(reqwest::Error), + + /// Serialization or deserialization of JSON data failed. + Serialization(serde_json::Error), + + /// A validation check on user-supplied input failed. + Validation { reason: String }, + + /// An I/O error occurred. + Io(std::io::Error), + + /// An internal SDK error (e.g. poisoned lock). + Internal { reason: String }, +} +``` + +Implements: `Display`, `Error`, `From`, `From`, `From` + +> **Note:** The `Result` type alias (`std::result::Result`) is defined +> in `error.rs` for internal SDK use but is **not** re-exported from the crate root. +> Public API signatures use `Result` explicitly to avoid shadowing +> the standard `Result`. + +--- + +## Re-exported OpenAI Types + +The following types from `async_openai` are re-exported at the crate root for convenience: + +**Request types:** +- `ChatCompletionRequestMessage` +- `ChatCompletionRequestSystemMessage` +- `ChatCompletionRequestUserMessage` +- `ChatCompletionRequestAssistantMessage` +- `ChatCompletionRequestToolMessage` +- `ChatCompletionTools` +- `ChatCompletionToolChoiceOption` +- `ChatCompletionNamedToolChoice` +- `FunctionObject` + +**Response types:** +- `CreateChatCompletionResponse` +- `CreateChatCompletionStreamResponse` +- `ChatChoice` +- `ChatChoiceStream` +- `ChatCompletionResponseMessage` +- `ChatCompletionStreamResponseDelta` +- `CompletionUsage` +- `FinishReason` + +**Tool call types:** +- `ChatCompletionMessageToolCall` +- `ChatCompletionMessageToolCallChunk` +- `ChatCompletionMessageToolCalls` +- `FunctionCall` +- `FunctionCallStream` diff --git a/sdk_v2/rust/examples/chat_completion.rs b/sdk_v2/rust/examples/chat_completion.rs new file mode 100644 index 000000000..f3ac15c81 --- /dev/null +++ b/sdk_v2/rust/examples/chat_completion.rs @@ -0,0 +1,97 @@ +//! Basic chat completion example demonstrating synchronous and streaming +//! usage of the Foundry Local SDK. + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. Initialise the manager ──────────────────────────────────────── + let config = FoundryLocalConfig::new("foundry_local_samples"); + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. List available models ───────────────────────────────────────── + let models = manager.catalog().get_models().await?; + println!("Available models:"); + for model in &models { + println!(" • {} (id: {})", model.alias(), model.id()); + } + + // ── 3. Pick a model and ensure it is loaded ────────────────────────── + // Prefer a known chat model; fall back to the first available. + let model_alias = ["phi-3.5-mini", "phi-4-mini"] + .iter() + .find(|alias| models.iter().any(|m| m.alias() == **alias)) + .map(|s| s.to_string()) + .or_else(|| models.first().map(|m| m.alias().to_string())) + .expect("No models available in the catalog"); + + let model = manager.catalog().get_model(&model_alias).await?; + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model + .download(Some(|progress: f64| { + println!(" {progress:.1}%"); + })) + .await?; + } + + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 4. Synchronous chat completion ─────────────────────────────────── + let client = model.create_chat_client().temperature(0.7).max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is Rust's ownership model?").into(), + ]; + + println!("\n--- Synchronous completion ---"); + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("Assistant: {content}"); + } + } + + // ── 5. Streaming chat completion ───────────────────────────────────── + println!("\n--- Streaming completion ---"); + let stream_messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain the borrow checker in two sentences.") + .into(), + ]; + + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&stream_messages, None) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!(); + + // ── 6. Unload the model────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs new file mode 100644 index 000000000..e1ccb5641 --- /dev/null +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -0,0 +1,112 @@ +//! Interactive chat example — a simple terminal chatbot powered by Foundry Local. +//! +//! Run with: `cargo run --example interactive_chat` + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, FoundryLocalConfig, + FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // ── Initialise ─────────────────────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + + // Pick the first available model (or change this to a specific alias) + let catalog = manager.catalog(); + let models = catalog.get_models().await?; + + println!("Available models:"); + for (i, m) in models.iter().enumerate() { + println!(" [{i}] {}", m.alias()); + } + + print!("\nSelect a model number (default 0): "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let idx: usize = choice.trim().parse().unwrap_or(0); + + let alias = models + .get(idx) + .map(|m| m.alias().to_string()) + .unwrap_or_else(|| models[0].alias().to_string()); + + let model = catalog.get_model(&alias).await?; + + // Download if needed + if !model.is_cached().await? { + println!("Downloading '{alias}'…"); + model.download(Some(|p: f64| print!("\r {p:.1}%"))).await?; + println!(); + } + + println!("Loading '{alias}'…"); + model.load().await?; + println!("Ready! Type your messages below. Press Ctrl-D (or type 'quit') to exit.\n"); + + // ── Chat loop ──────────────────────────────────────────────────────── + let client = model.create_chat_client().temperature(0.7).max_tokens(512); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful, concise assistant.").into(), + ]; + + loop { + print!("You: "); + io::stdout().flush()?; + + let mut input = String::new(); + if io::stdin().read_line(&mut input)? == 0 { + break; // EOF (Ctrl-D) + } + + let input = input.trim(); + if input.is_empty() { + continue; + } + if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { + break; + } + + messages.push(ChatCompletionRequestUserMessage::from(input).into()); + + // Stream the response token by token + print!("Assistant: "); + io::stdout().flush()?; + + let mut full_response = String::new(); + let mut stream = client.complete_streaming_chat(&messages, None).await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + full_response.push_str(content); + } + } + } + println!("\n"); + + // Add assistant reply to history for multi-turn conversation + messages.push( + ChatCompletionRequestAssistantMessage { + content: Some(full_response.into()), + ..Default::default() + } + .into(), + ); + } + + // ── Cleanup ────────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Goodbye!"); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs new file mode 100644 index 000000000..f556b2a92 --- /dev/null +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -0,0 +1,202 @@ +//! Tool-calling example demonstrating how to define tools, handle +//! `tool_calls` in streaming responses, execute the tool locally, +//! and feed results back for a multi-turn conversation. + +use std::collections::HashMap; +use std::io::{self, Write}; + +use serde_json::{json, Value}; +use tokio_stream::StreamExt; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatCompletionTools, + ChatToolChoice, FinishReason, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +/// A trivial tool that multiplies two numbers. +fn multiply(a: f64, b: f64) -> f64 { + a * b +} + +/// Dispatch a tool call by name and arguments. +fn invoke_tool(name: &str, arguments: &Value) -> Result { + match name { + "multiply" => { + let a = arguments.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0); + let b = arguments.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0); + let result = multiply(a, b); + Ok(result.to_string()) + } + _ => Ok(format!("Unknown tool: {name}")), + } +} + +#[derive(Default, Clone)] +struct StreamedToolCall { + id: String, + name: String, + arguments: String, +} + +#[derive(Default)] +struct ToolCallState { + /// In-progress tool calls indexed by their stream position. + pending: HashMap, + /// Finalized tool calls ready for execution. + completed: Vec, +} + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. Initialise ──────────────────────────────────────────────────── + let config = FoundryLocalConfig::new("foundry_local_samples"); + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. Load a model ────────────────────────────────────────────────── + let models = manager.catalog().get_models().await?; + let model = models + .iter() + .find(|m| m.info().supports_tool_calling == Some(true)) + .or_else(|| models.first()) + .expect("No models available"); + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model.download(Some(|p: f64| println!(" {p:.1}%"))).await?; + } + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 3. Create a chat client with tool_choice = required ────────────── + let client = model + .create_chat_client() + .tool_choice(ChatToolChoice::Required) + .max_tokens(512); + + let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": { "type": "number", "description": "First operand" }, + "b": { "type": "number", "description": "Second operand" } + }, + "required": ["a", "b"] + } + } + }])) + .expect("Failed to parse tool definitions"); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from( + "You are a helpful calculator assistant. Use the multiply tool when asked to multiply.", + ) + .into(), + ChatCompletionRequestUserMessage::from("What is 6 times 7?").into(), + ]; + + // ── 4. First streaming call – expect tool_calls ────────────────────── + println!("Sending initial request…"); + + let mut state = ToolCallState::default(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for tc in tool_calls { + let idx = tc.index; + let entry = state.pending.entry(idx).or_default(); + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + entry.name = name.clone(); + } + if let Some(ref args) = func.arguments { + entry.arguments.push_str(args); + } + } + if let Some(ref id) = tc.id { + entry.id = id.clone(); + } + } + } + + if choice.finish_reason == Some(FinishReason::ToolCalls) { + for (_, call) in state.pending.drain() { + state.completed.push(json!({ + "id": call.id, + "type": "function", + "function": { + "name": call.name, + "arguments": call.arguments, + } + })); + } + } + } + } + // ── 5. Execute the tool(s)─────────────────────────────────────────── + for tc in &state.completed { + let func = &tc["function"]; + let name = func["name"].as_str().unwrap_or_default(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + + println!("Tool call: {name}({args})"); + let result = invoke_tool(name, &args)?; + println!("Tool result: {result}"); + + // Append the assistant's tool_calls message and the tool result. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [tc], + })) + .expect("Failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc["id"].as_str().unwrap_or_default().to_string(), + } + .into(), + ); + } + + // ── 6. Continue the conversation with auto tool_choice ─────────────── + let client = client.tool_choice(ChatToolChoice::Auto); + + println!("\nContinuing conversation…"); + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!(); + + // ── 7. Clean up────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs new file mode 100644 index 000000000..95f35ce37 --- /dev/null +++ b/sdk_v2/rust/src/catalog.rs @@ -0,0 +1,152 @@ +//! Model catalog — discovery and lookup for available models. +//! +//! The native catalog (owned by the [`FoundryLocalManager`](crate::FoundryLocalManager)) +//! caches the model list and refreshes itself, so this is a thin async wrapper +//! that preserves the legacy public surface. + +use std::sync::Arc; + +use crate::detail::api::Api; +use crate::detail::model::Model; +use crate::detail::native::NativeCatalog; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; + +/// The model catalog provides discovery and lookup for all available models. +pub struct Catalog { + native: NativeCatalog, + name: String, +} + +impl Catalog { + pub(crate) fn new(api: Arc, ptr: *mut crate::detail::ffi::flCatalog) -> Result { + let native = NativeCatalog::new(api, ptr); + let name = native.name().unwrap_or_else(|_| "default".into()); + Ok(Self { native, name }) + } + + /// Catalog name as reported by the native core. + pub fn name(&self) -> &str { + &self.name + } + + /// Refresh the catalog from the native core. + /// + /// The native catalog manages its own caching and refresh, so this is a + /// no-op retained for API compatibility. + pub async fn update_models(&self) -> Result<()> { + Ok(()) + } + + /// Return all known models keyed by alias. + pub async fn get_models(&self) -> Result>> { + let native = self.native.clone(); + spawn_blocking(move || { + native + .get_models()? + .into_iter() + .map(|m| Model::from_group(&native.api, m).map(Arc::new)) + .collect() + }) + .await + } + + /// Look up a model by its alias. + pub async fn get_model(&self, alias: &str) -> Result> { + if alias.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Model alias must be a non-empty string".into(), + }); + } + let native = self.native.clone(); + let alias = alias.to_owned(); + spawn_blocking(move || match native.get_model(&alias)? { + Some(m) => Model::from_group(&native.api, m).map(Arc::new), + None => { + let available: Vec = native + .get_models() + .ok() + .map(|models| { + models + .iter() + .filter_map(|m| { + m.info_ptr().ok().map(|info| unsafe { + crate::detail::api::cstr_to_string((native + .api + .model_api() + .Info_GetAlias)( + info + )) + .unwrap_or_default() + }) + }) + .collect() + }) + .unwrap_or_default(); + Err(FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + }) + } + }) + .await + } + + /// Look up a specific model variant by its unique id. + /// + /// NOTE: This will return a `Model` representing a single variant. Use + /// [`get_model`](Catalog::get_model) to obtain a `Model` with all + /// available variants. + pub async fn get_model_variant(&self, id: &str) -> Result> { + if id.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Variant id must be a non-empty string".into(), + }); + } + let native = self.native.clone(); + let id = id.to_owned(); + spawn_blocking(move || match native.get_model_variant(&id)? { + Some(m) => Model::from_variant(&native.api, m).map(Arc::new), + None => Err(FoundryLocalError::ModelOperation { + reason: format!("Unknown variant id '{id}'."), + }), + }) + .await + } + + /// Return only the model variants that are currently cached on disk. + pub async fn get_cached_models(&self) -> Result>> { + let native = self.native.clone(); + spawn_blocking(move || { + native + .get_cached_models()? + .into_iter() + .map(|m| Model::from_variant(&native.api, m).map(Arc::new)) + .collect() + }) + .await + } + + /// Return model variants that are currently loaded into memory. + pub async fn get_loaded_models(&self) -> Result>> { + let native = self.native.clone(); + spawn_blocking(move || { + native + .get_loaded_models()? + .into_iter() + .map(|m| Model::from_variant(&native.api, m).map(Arc::new)) + .collect() + }) + .await + } + + /// Resolve the latest catalog version for the provided model or variant. + pub async fn get_latest_version(&self, model_or_model_variant: &Model) -> Result> { + let native = self.native.clone(); + let target = model_or_model_variant.selected_native().clone(); + spawn_blocking(move || { + let latest = native.get_latest_version(&target)?; + Model::from_variant(&native.api, latest).map(Arc::new) + }) + .await + } +} diff --git a/sdk_v2/rust/src/configuration.rs b/sdk_v2/rust/src/configuration.rs new file mode 100644 index 000000000..7b0bfdc19 --- /dev/null +++ b/sdk_v2/rust/src/configuration.rs @@ -0,0 +1,253 @@ +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; + +use crate::detail::api::{to_cstring, Api, Kvps}; +use crate::detail::ffi::*; +use crate::error::{FoundryLocalError, Result}; + +/// Log level for the Foundry Local service. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogLevel { + Trace, + Debug, + Info, + Warn, + Error, + Fatal, +} + +impl LogLevel { + /// Map to the native `flLogLevel` value. + fn as_native(&self) -> flLogLevel { + match self { + Self::Trace => FOUNDRY_LOCAL_LOG_VERBOSE, + Self::Debug => FOUNDRY_LOCAL_LOG_DEBUG, + Self::Info => FOUNDRY_LOCAL_LOG_INFO, + Self::Warn => FOUNDRY_LOCAL_LOG_WARNING, + Self::Error => FOUNDRY_LOCAL_LOG_ERROR, + Self::Fatal => FOUNDRY_LOCAL_LOG_FATAL, + } + } +} + +/// Application-level logger that the SDK can use to emit diagnostic messages. +/// +/// This is a stub — the logger is stored in the configuration and passed +/// through to the manager, but it is not wired into the native core yet. +pub trait Logger: Send + Sync { + /// Log a message at the given severity level. + fn log(&self, level: LogLevel, message: &str); +} + +/// User-facing configuration for initializing the Foundry Local SDK. +/// +/// Construct with [`FoundryLocalConfig::new`] and customise via the builder +/// methods: +/// +/// ```ignore +/// let config = FoundryLocalConfig::new("my_app") +/// .log_level(LogLevel::Debug) +/// .model_cache_dir("/path/to/cache"); +/// ``` +#[derive(Default)] +pub struct FoundryLocalConfig { + app_name: String, + app_data_dir: Option, + model_cache_dir: Option, + logs_dir: Option, + log_level: Option, + web_service_urls: Option, + service_endpoint: Option, + library_path: Option, + additional_settings: Option>, + logger: Option>, +} + +impl fmt::Debug for FoundryLocalConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FoundryLocalConfig") + .field("app_name", &self.app_name) + .field("app_data_dir", &self.app_data_dir) + .field("model_cache_dir", &self.model_cache_dir) + .field("logs_dir", &self.logs_dir) + .field("log_level", &self.log_level) + .field("web_service_urls", &self.web_service_urls) + .field("service_endpoint", &self.service_endpoint) + .field("library_path", &self.library_path) + .field("additional_settings", &self.additional_settings) + .field("logger", &self.logger.as_ref().map(|_| "..")) + .finish() + } +} + +impl FoundryLocalConfig { + /// Create a new configuration with the given application name. + /// + /// All other fields default to `None`. Use the builder methods to + /// customise: + /// + /// ```ignore + /// let config = FoundryLocalConfig::new("my_app") + /// .log_level(LogLevel::Debug) + /// .model_cache_dir("/path/to/cache"); + /// ``` + pub fn new(app_name: impl Into) -> Self { + Self { + app_name: app_name.into(), + ..Self::default() + } + } + + /// Override the application-data directory. + pub fn app_data_dir(mut self, dir: impl Into) -> Self { + self.app_data_dir = Some(dir.into()); + self + } + + /// Override the model-cache directory. + pub fn model_cache_dir(mut self, dir: impl Into) -> Self { + self.model_cache_dir = Some(dir.into()); + self + } + + /// Override the logs directory. + pub fn logs_dir(mut self, dir: impl Into) -> Self { + self.logs_dir = Some(dir.into()); + self + } + + /// Set the log level. + pub fn log_level(mut self, level: LogLevel) -> Self { + self.log_level = Some(level); + self + } + + /// Set the web-service listen URLs (e.g. `"http://localhost:5273"`). + pub fn web_service_urls(mut self, urls: impl Into) -> Self { + self.web_service_urls = Some(urls.into()); + self + } + + /// Set an external service endpoint URL. + pub fn service_endpoint(mut self, endpoint: impl Into) -> Self { + self.service_endpoint = Some(endpoint.into()); + self + } + + /// Override the path to the native Foundry Local Core library. + pub fn library_path(mut self, path: impl Into) -> Self { + self.library_path = Some(path.into()); + self + } + + /// Add a single key-value pair to the additional settings map. + pub fn additional_setting(mut self, key: impl Into, value: impl Into) -> Self { + self.additional_settings + .get_or_insert_with(HashMap::new) + .insert(key.into(), value.into()); + self + } + + /// Provide an application logger. + /// + /// *Stub* — the logger is stored but not yet wired into the native core. + pub fn logger(mut self, logger: impl Logger + 'static) -> Self { + self.logger = Some(Box::new(logger)); + self + } + + // ── Crate-internal helpers ─────────────────────────────────────────────── + + /// The configured native library path override, if any. + pub(crate) fn library_path_ref(&self) -> Option<&str> { + self.library_path.as_deref() + } + + /// Take ownership of the configured logger (consumed once by the manager). + pub(crate) fn take_logger(&mut self) -> Option> { + self.logger.take() + } + + /// Build a native `flConfiguration` from this configuration. + /// + /// Returns [`FoundryLocalError::InvalidConfiguration`] when `app_name` is + /// empty or blank. + pub(crate) fn build_native(&self, api: &Arc) -> Result { + let app_name = self.app_name.trim(); + if app_name.is_empty() { + return Err(FoundryLocalError::InvalidConfiguration { + reason: "app_name must be set and non-empty".into(), + }); + } + + let cfg = NativeConfig::create(Arc::clone(api), app_name)?; + let c = api.config_api(); + + if let Some(dir) = &self.app_data_dir { + let s = to_cstring(dir)?; + // SAFETY: ptr is valid; the native call copies the string. + api.check(unsafe { (c.SetAppDataDir)(cfg.ptr, s.as_ptr()) })?; + } + if let Some(dir) = &self.model_cache_dir { + let s = to_cstring(dir)?; + api.check(unsafe { (c.SetModelCacheDir)(cfg.ptr, s.as_ptr()) })?; + } + if let Some(dir) = &self.logs_dir { + let s = to_cstring(dir)?; + api.check(unsafe { (c.SetLogsDir)(cfg.ptr, s.as_ptr()) })?; + } + if let Some(level) = self.log_level { + api.check(unsafe { (c.SetDefaultLogLevel)(cfg.ptr, level.as_native()) })?; + } + if let Some(urls) = &self.web_service_urls { + for url in urls.split(',').map(str::trim).filter(|u| !u.is_empty()) { + let s = to_cstring(url)?; + api.check(unsafe { (c.AddWebServiceEndpoint)(cfg.ptr, s.as_ptr()) })?; + } + } + if let Some(endpoint) = &self.service_endpoint { + let s = to_cstring(endpoint)?; + api.check(unsafe { (c.SetExternalServiceUrl)(cfg.ptr, s.as_ptr()) })?; + } + if let Some(extra) = &self.additional_settings { + if !extra.is_empty() { + let kvps = Kvps::from_pairs(Arc::clone(api), extra.iter())?; + api.check(unsafe { (c.SetAdditionalOptions)(cfg.ptr, kvps.as_ptr()) })?; + } + } + + Ok(cfg) + } +} + +/// Owning wrapper around a native `flConfiguration`, released on drop. +pub(crate) struct NativeConfig { + api: Arc, + ptr: *mut flConfiguration, +} + +impl NativeConfig { + fn create(api: Arc, app_name: &str) -> Result { + let name = to_cstring(app_name)?; + let mut ptr: *mut flConfiguration = std::ptr::null_mut(); + // SAFETY: `Create` writes a valid handle into `ptr` on success. + let status = unsafe { (api.config_api().Create)(name.as_ptr(), &mut ptr) }; + api.check(status)?; + Ok(Self { api, ptr }) + } + + pub(crate) fn as_ptr(&self) -> *const flConfiguration { + self.ptr + } +} + +impl Drop for NativeConfig { + fn drop(&mut self) { + if !self.ptr.is_null() { + // SAFETY: `ptr` was created by `Create` and not yet released. + unsafe { (self.api.config_api().Configuration_Release)(self.ptr) }; + self.ptr = std::ptr::null_mut(); + } + } +} diff --git a/sdk_v2/rust/src/detail/api.rs b/sdk_v2/rust/src/detail/api.rs new file mode 100644 index 000000000..5e6481254 --- /dev/null +++ b/sdk_v2/rust/src/detail/api.rs @@ -0,0 +1,431 @@ +//! Safe wrapper around the dynamically-loaded Foundry Local C ABI. +//! +//! [`Api`] loads the `foundry_local` shared library (pre-loading its ONNX Runtime +//! and GenAI dependencies first), resolves the root function table via +//! `FoundryLocalGetApi`, and caches the sub-API tables. It also provides the +//! `flStatus*` → [`FoundryLocalError`] mapping and small FFI string / key-value +//! helpers used throughout the `detail` layer. + +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use libloading::Library; + +use super::ffi::*; +use crate::error::{FoundryLocalError, Result}; + +// ── Library file names ─────────────────────────────────────────────────────── + +#[cfg(target_os = "windows")] +const LIB_FILE: &str = "foundry_local.dll"; +#[cfg(target_os = "macos")] +const LIB_FILE: &str = "libfoundry_local.dylib"; +#[cfg(all(unix, not(target_os = "macos")))] +const LIB_FILE: &str = "libfoundry_local.so"; + +#[cfg(target_os = "windows")] +const ORT_FILE: &str = "onnxruntime.dll"; +#[cfg(target_os = "macos")] +const ORT_FILE: &str = "libonnxruntime.dylib"; +#[cfg(all(unix, not(target_os = "macos")))] +const ORT_FILE: &str = "libonnxruntime.so"; + +#[cfg(target_os = "windows")] +const GENAI_FILE: &str = "onnxruntime-genai.dll"; +#[cfg(target_os = "macos")] +const GENAI_FILE: &str = "libonnxruntime-genai.dylib"; +#[cfg(all(unix, not(target_os = "macos")))] +const GENAI_FILE: &str = "libonnxruntime-genai.so"; + +// ── FFI string helpers ─────────────────────────────────────────────────────── + +/// Build a [`CString`], mapping an interior NUL to a validation error. +pub(crate) fn to_cstring(s: &str) -> Result { + CString::new(s).map_err(|e| FoundryLocalError::Validation { + reason: format!("string contains an interior NUL byte: {e}"), + }) +} + +/// Read a borrowed C string into an owned `String`. Returns `None` for null. +/// +/// # Safety +/// `ptr` must be null or point to a valid NUL-terminated string that stays alive +/// for the duration of this call. +pub(crate) unsafe fn cstr_to_string(ptr: *const c_char) -> Option { + if ptr.is_null() { + None + } else { + Some(CStr::from_ptr(ptr).to_string_lossy().into_owned()) + } +} + +fn map_error(code: flErrorCode, message: String) -> FoundryLocalError { + match code { + FOUNDRY_LOCAL_ERROR_INVALID_ARGUMENT | FOUNDRY_LOCAL_ERROR_INVALID_USAGE => { + FoundryLocalError::Validation { reason: message } + } + FOUNDRY_LOCAL_ERROR_INTERNAL | FOUNDRY_LOCAL_ERROR_NOT_IMPLEMENTED => { + FoundryLocalError::Internal { reason: message } + } + FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED => FoundryLocalError::CommandExecution { + reason: if message.is_empty() { + "Operation cancelled".into() + } else { + message + }, + }, + _ => FoundryLocalError::CommandExecution { reason: message }, + } +} + +// ── Api ────────────────────────────────────────────────────────────────────── + +/// Loaded native library plus its resolved function tables. +/// +/// `Api` is `Send + Sync`: the underlying library is thread-safe for distinct +/// handles, and the cached vtable pointers are read-only for the process lifetime. +pub(crate) struct Api { + _foundry_lib: Library, + _dep_libs: Vec, + root: *const flApiVtable, + catalog: *const flCatalogApiVtable, + config: *const flConfigurationApiVtable, + item: *const flItemApiVtable, + inference: *const flInferenceApiVtable, + model: *const flModelApiVtable, +} + +// SAFETY: the native library is documented as thread-safe for independent +// handles; the cached vtable pointers are immutable for the library's lifetime. +unsafe impl Send for Api {} +unsafe impl Sync for Api {} + +impl std::fmt::Debug for Api { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Api").finish_non_exhaustive() + } +} + +impl Api { + /// Load the native library and resolve the API tables. + /// + /// `library_path` is an optional override pointing at the `foundry_local` + /// shared library file or the directory containing it. + pub(crate) fn load(library_path: Option<&str>) -> Result { + let (lib_path, native_dir) = resolve_library_path(library_path)?; + + // Pre-load ONNX Runtime then GenAI so the dynamic loader resolves + // foundry_local's dependencies regardless of rpath/search-path setup. + let dep_libs = preload_dependencies(native_dir.as_deref()); + + // SAFETY: `lib_path` was resolved from trusted configuration / build + // inputs. Loading a shared library executes foreign initialisers. + let foundry_lib = unsafe { + load_library(&lib_path).map_err(|e| match e { + FoundryLocalError::LibraryLoad { reason } => FoundryLocalError::LibraryLoad { + reason: format!( + "{reason}. Ensure the native Foundry Local library is available: set \ + FOUNDRY_LOCAL_LIB_DIR to the directory containing {LIB_FILE} (alongside its \ + onnxruntime and onnxruntime-genai libraries), pass it via \ + FoundryLocalConfig::library_path, or place it on the loader search path." + ), + }, + other => other, + })? + }; + + // SAFETY: the library exports `FoundryLocalGetApi` with this signature. + let root = unsafe { + let get_api: libloading::Symbol = foundry_lib + .get(FOUNDRY_LOCAL_GET_API_SYMBOL) + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("symbol 'FoundryLocalGetApi' not found: {e}"), + })?; + let ptr = get_api(FOUNDRY_LOCAL_API_VERSION); + if ptr.is_null() { + return Err(FoundryLocalError::LibraryLoad { + reason: format!( + "FoundryLocalGetApi({FOUNDRY_LOCAL_API_VERSION}) returned null (unsupported API version)" + ), + }); + } + ptr + }; + + // SAFETY: `root` is a valid, non-null vtable for the library's lifetime. + let root_ref = unsafe { &*root }; + let catalog = unsafe { (root_ref.GetCatalogApi)() }; + let config = unsafe { (root_ref.GetConfigurationApi)() }; + let item = unsafe { (root_ref.GetItemApi)() }; + let inference = unsafe { (root_ref.GetInferenceApi)() }; + let model = unsafe { (root_ref.GetModelApi)() }; + + for (name, ptr) in [ + ("CatalogApi", catalog as *const ()), + ("ConfigurationApi", config as *const ()), + ("ItemApi", item as *const ()), + ("InferenceApi", inference as *const ()), + ("ModelApi", model as *const ()), + ] { + if ptr.is_null() { + return Err(FoundryLocalError::LibraryLoad { + reason: format!("native {name} table is null"), + }); + } + } + + Ok(Self { + _foundry_lib: foundry_lib, + _dep_libs: dep_libs, + root, + catalog, + config, + item, + inference, + model, + }) + } + + #[inline] + pub(crate) fn root(&self) -> &flApiVtable { + // SAFETY: non-null and valid for the library's lifetime. + unsafe { &*self.root } + } + #[inline] + pub(crate) fn catalog_api(&self) -> &flCatalogApiVtable { + unsafe { &*self.catalog } + } + #[inline] + pub(crate) fn config_api(&self) -> &flConfigurationApiVtable { + unsafe { &*self.config } + } + #[inline] + pub(crate) fn item_api(&self) -> &flItemApiVtable { + unsafe { &*self.item } + } + #[inline] + pub(crate) fn inference_api(&self) -> &flInferenceApiVtable { + unsafe { &*self.inference } + } + #[inline] + pub(crate) fn model_api(&self) -> &flModelApiVtable { + unsafe { &*self.model } + } + + /// Convert a returned `flStatus*` into a `Result`. A non-null status is an error. + pub(crate) fn check(&self, status: flStatusPtr) -> Result<()> { + if status.is_null() { + return Ok(()); + } + let root = self.root(); + // SAFETY: `status` is a valid non-null status owned by us until released. + unsafe { + let code = (root.Status_GetErrorCode)(status); + let message = cstr_to_string((root.Status_GetErrorMessage)(status)).unwrap_or_default(); + (root.Status_Release)(status); + Err(map_error(code, message)) + } + } + + /// Read a returned `flStatus*` as an optional message. `None` for success + /// (null status). Releases the status. Used where a non-null status is a + /// soft/partial failure rather than a hard error (e.g. EP registration). + pub(crate) fn status_message(&self, status: flStatusPtr) -> Option { + if status.is_null() { + return None; + } + let root = self.root(); + // SAFETY: `status` is a valid non-null status owned by us until released. + unsafe { + let message = cstr_to_string((root.Status_GetErrorMessage)(status)).unwrap_or_default(); + (root.Status_Release)(status); + Some(message) + } + } +} + +// ── KeyValuePairs RAII helper ──────────────────────────────────────────────── + +/// Owning wrapper around a native `flKeyValuePairs`, released on drop. +pub(crate) struct Kvps { + api: Arc, + ptr: *mut flKeyValuePairs, +} + +impl Kvps { + /// Create an empty key/value collection. + pub(crate) fn new(api: Arc) -> Self { + let mut ptr: *mut flKeyValuePairs = std::ptr::null_mut(); + // SAFETY: `CreateKeyValuePairs` writes a valid handle into `ptr`. + unsafe { (api.root().CreateKeyValuePairs)(&mut ptr) }; + Self { api, ptr } + } + + /// Build from an iterator of `(key, value)` string pairs. + pub(crate) fn from_pairs(api: Arc, pairs: I) -> Result + where + I: IntoIterator, + K: AsRef, + V: AsRef, + { + let mut kvps = Self::new(api); + for (k, v) in pairs { + kvps.set(k.as_ref(), v.as_ref())?; + } + Ok(kvps) + } + + /// Add or replace a key/value pair. + pub(crate) fn set(&mut self, key: &str, value: &str) -> Result<()> { + let key = to_cstring(key)?; + let value = to_cstring(value)?; + // SAFETY: the native call copies both strings; our CStrings outlive the call. + unsafe { (self.api.root().AddKeyValuePair)(self.ptr, key.as_ptr(), value.as_ptr()) }; + Ok(()) + } + + pub(crate) fn as_ptr(&self) -> *const flKeyValuePairs { + self.ptr + } +} + +impl Drop for Kvps { + fn drop(&mut self) { + if !self.ptr.is_null() { + // SAFETY: `ptr` was created by `CreateKeyValuePairs` and not yet released. + unsafe { (self.api.root().KeyValuePairs_Release)(self.ptr) }; + self.ptr = std::ptr::null_mut(); + } + } +} + +/// Read a borrowed native `flKeyValuePairs` into an owned `Vec<(String, Option)>`. +/// +/// # Safety +/// `kvps` must be null or a valid pointer that stays alive for the duration of this call. +pub(crate) unsafe fn read_kvps( + api: &Api, + kvps: *const flKeyValuePairs, +) -> Vec<(String, Option)> { + if kvps.is_null() { + return Vec::new(); + } + let mut keys: *const *const c_char = std::ptr::null(); + let mut values: *const *const c_char = std::ptr::null(); + let mut count: usize = 0; + (api.root().GetKeyValuePairs)(kvps, &mut keys, &mut values, &mut count); + if keys.is_null() || count == 0 { + return Vec::new(); + } + let mut out = Vec::with_capacity(count); + for i in 0..count { + let key = cstr_to_string(*keys.add(i)).unwrap_or_default(); + let value = if values.is_null() { + None + } else { + cstr_to_string(*values.add(i)) + }; + out.push((key, value)); + } + out +} + +// ── Library discovery & loading ────────────────────────────────────────────── + +#[cfg(unix)] +unsafe fn load_library(path: &Path) -> Result { + use libloading::os::unix::{Library as UnixLibrary, RTLD_GLOBAL, RTLD_NOW}; + UnixLibrary::open(Some(path), RTLD_NOW | RTLD_GLOBAL) + .map(Library::from) + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("failed to load native library at {}: {e}", path.display()), + }) +} + +#[cfg(windows)] +unsafe fn load_library(path: &Path) -> Result { + Library::new(path).map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("failed to load native library at {}: {e}", path.display()), + }) +} + +/// Best-effort pre-load of ORT and GenAI from the native directory. Failures are +/// ignored: foundry_local may resolve them via rpath/search path instead. +fn preload_dependencies(native_dir: Option<&Path>) -> Vec { + let mut libs = Vec::new(); + let Some(dir) = native_dir else { + return libs; + }; + + #[allow(unused_mut)] + let mut deps: Vec<&str> = vec![ORT_FILE, GENAI_FILE]; + #[cfg(all(windows, feature = "winml"))] + deps.push("Microsoft.Windows.AI.MachineLearning.dll"); + + // Help GenAI's dlopen build find the exact ORT we are pre-loading. + let ort_path = dir.join(ORT_FILE); + if ort_path.exists() && std::env::var_os("ORT_LIB_PATH").is_none() { + std::env::set_var("ORT_LIB_PATH", &ort_path); + } + + for dep in deps { + let dep_path = dir.join(dep); + if dep_path.exists() { + // SAFETY: pre-loading a known dependency from the trusted native dir. + if let Ok(lib) = unsafe { load_library(&dep_path) } { + libs.push(lib); + } + } + } + libs +} + +/// Resolve the full path to the `foundry_local` library and its containing dir. +/// +/// Search order: +/// 1. `library_path` override (a file or a directory). +/// 2. `FOUNDRY_LOCAL_LIB_DIR` environment variable (dev override). +/// 3. `FOUNDRY_NATIVE_DIR` compile-time path (set by `build.rs`). +/// 4. The directory of the current executable. +/// 5. Bare library name resolved via the system search path. +fn resolve_library_path(library_path: Option<&str>) -> Result<(PathBuf, Option)> { + // 1. Explicit override: a file path or a directory. + if let Some(p) = library_path { + let path = Path::new(p); + if path.is_file() { + let dir = path.parent().map(Path::to_path_buf); + return Ok((path.to_path_buf(), dir)); + } + let candidate = path.join(LIB_FILE); + if candidate.is_file() { + return Ok((candidate, Some(path.to_path_buf()))); + } + } + + let mut dirs: Vec = Vec::new(); + if let Ok(dir) = std::env::var("FOUNDRY_LOCAL_LIB_DIR") { + if !dir.trim().is_empty() { + dirs.push(PathBuf::from(dir)); + } + } + if let Some(dir) = option_env!("FOUNDRY_NATIVE_DIR") { + dirs.push(PathBuf::from(dir)); + } + if let Ok(exe) = std::env::current_exe() { + if let Some(dir) = exe.parent() { + dirs.push(dir.to_path_buf()); + } + } + + for dir in &dirs { + let candidate = dir.join(LIB_FILE); + if candidate.is_file() { + return Ok((candidate, Some(dir.clone()))); + } + } + + // 5. Fall back to the bare library name (system loader search path). + Ok((PathBuf::from(LIB_FILE), None)) +} diff --git a/sdk_v2/rust/src/detail/ffi.rs b/sdk_v2/rust/src/detail/ffi.rs new file mode 100644 index 000000000..e87c350f3 --- /dev/null +++ b/sdk_v2/rust/src/detail/ffi.rs @@ -0,0 +1,658 @@ +//! Raw FFI declarations mirroring `sdk_v2/cpp/include/foundry_local/foundry_local_c.h`. +//! +//! This module is a faithful `repr(C)` transcription of the Foundry Local C ABI. +//! It exposes the opaque handle types, versioned data structs, enums, callback +//! signatures, and the six function-pointer tables that the library returns via +//! [`FoundryLocalGetApi`]. +//! +//! Everything here is `unsafe` to use; safe wrappers live in [`super::api`] and the +//! higher-level detail modules. +#![allow(non_snake_case)] +#![allow(non_camel_case_types)] +#![allow(dead_code)] + +use core::ffi::c_void; +use std::os::raw::{c_char, c_int}; + +/// The library is built against this API version (`FOUNDRY_LOCAL_API_VERSION`). +pub const FOUNDRY_LOCAL_API_VERSION: u32 = 1; + +// ── Opaque handle types ────────────────────────────────────────────────────── + +macro_rules! opaque_type { + ($name:ident) => { + #[repr(C)] + pub struct $name { + _data: [u8; 0], + _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>, + } + }; +} + +opaque_type!(flCatalog); +opaque_type!(flConfiguration); +opaque_type!(flItem); +opaque_type!(flItemQueue); +opaque_type!(flKeyValuePairs); +opaque_type!(flManager); +opaque_type!(flModel); +opaque_type!(flModelInfo); +opaque_type!(flModelList); +opaque_type!(flRequest); +opaque_type!(flResponse); +opaque_type!(flSession); +opaque_type!(flStatus); + +/// A non-null `flStatus*` indicates an error; `null` indicates success. +pub type flStatusPtr = *mut flStatus; + +// ── Enums (C enums are `int`-sized) ────────────────────────────────────────── + +pub type flErrorCode = c_int; +pub const FOUNDRY_LOCAL_OK: flErrorCode = 0; +pub const FOUNDRY_LOCAL_ERROR_NOT_IMPLEMENTED: flErrorCode = 1; +pub const FOUNDRY_LOCAL_ERROR_INTERNAL: flErrorCode = 2; +pub const FOUNDRY_LOCAL_ERROR_INVALID_ARGUMENT: flErrorCode = 3; +pub const FOUNDRY_LOCAL_ERROR_INVALID_USAGE: flErrorCode = 4; +pub const FOUNDRY_LOCAL_ERROR_OPERATION_CANCELLED: flErrorCode = 5; +pub const FOUNDRY_LOCAL_ERROR_NETWORK: flErrorCode = 6; + +pub type flLogLevel = c_int; +pub const FOUNDRY_LOCAL_LOG_VERBOSE: flLogLevel = 0; +pub const FOUNDRY_LOCAL_LOG_DEBUG: flLogLevel = 1; +pub const FOUNDRY_LOCAL_LOG_INFO: flLogLevel = 2; +pub const FOUNDRY_LOCAL_LOG_WARNING: flLogLevel = 3; +pub const FOUNDRY_LOCAL_LOG_ERROR: flLogLevel = 4; +pub const FOUNDRY_LOCAL_LOG_FATAL: flLogLevel = 5; + +pub type flDeviceType = c_int; +pub const FOUNDRY_LOCAL_DEVICE_NOTSET: flDeviceType = 0; +pub const FOUNDRY_LOCAL_DEVICE_CPU: flDeviceType = 1; +pub const FOUNDRY_LOCAL_DEVICE_GPU: flDeviceType = 2; +pub const FOUNDRY_LOCAL_DEVICE_NPU: flDeviceType = 3; + +pub type flTensorDataType = c_int; + +pub type flItemType = c_int; +pub const FOUNDRY_LOCAL_ITEM_UNKNOWN: flItemType = 0; +pub const FOUNDRY_LOCAL_ITEM_BYTES: flItemType = 1; +pub const FOUNDRY_LOCAL_ITEM_TENSOR: flItemType = 10; +pub const FOUNDRY_LOCAL_ITEM_TEXT: flItemType = 20; +pub const FOUNDRY_LOCAL_ITEM_MESSAGE: flItemType = 21; +pub const FOUNDRY_LOCAL_ITEM_IMAGE: flItemType = 25; +pub const FOUNDRY_LOCAL_ITEM_AUDIO: flItemType = 30; +pub const FOUNDRY_LOCAL_ITEM_TOOL_CALL: flItemType = 100; +pub const FOUNDRY_LOCAL_ITEM_TOOL_RESULT: flItemType = 101; +pub const FOUNDRY_LOCAL_ITEM_QUEUE: flItemType = 200; + +pub type flTextItemType = c_int; +pub const FOUNDRY_LOCAL_TEXT_ITEM_TYPE_DEFAULT: flTextItemType = 0; +pub const FOUNDRY_LOCAL_TEXT_ITEM_TYPE_REASONING: flTextItemType = 1; +pub const FOUNDRY_LOCAL_TEXT_ITEM_TYPE_OPENAI_JSON: flTextItemType = 2; + +pub type flMessageRole = c_int; +pub const FOUNDRY_LOCAL_ROLE_NONE: flMessageRole = 0; +pub const FOUNDRY_LOCAL_ROLE_SYSTEM: flMessageRole = 1; +pub const FOUNDRY_LOCAL_ROLE_USER: flMessageRole = 2; +pub const FOUNDRY_LOCAL_ROLE_ASSISTANT: flMessageRole = 3; +pub const FOUNDRY_LOCAL_ROLE_TOOL: flMessageRole = 4; +pub const FOUNDRY_LOCAL_ROLE_DEVELOPER: flMessageRole = 5; + +pub type flFinishReason = c_int; +pub const FOUNDRY_LOCAL_FINISH_NONE: flFinishReason = 0; +pub const FOUNDRY_LOCAL_FINISH_ERROR: flFinishReason = 1; +pub const FOUNDRY_LOCAL_FINISH_STOP: flFinishReason = 2; +pub const FOUNDRY_LOCAL_FINISH_LENGTH: flFinishReason = 3; +pub const FOUNDRY_LOCAL_FINISH_TOOL_CALLS: flFinishReason = 4; + +pub type flToolChoice = c_int; +pub const FOUNDRY_LOCAL_TOOL_CHOICE_AUTO: flToolChoice = 0; +pub const FOUNDRY_LOCAL_TOOL_CHOICE_NONE: flToolChoice = 1; +pub const FOUNDRY_LOCAL_TOOL_CHOICE_REQUIRED: flToolChoice = 2; + +// ── Well-known property / parameter keys ───────────────────────────────────── + +pub const FOUNDRY_LOCAL_MODEL_PROP_DISPLAY_NAME_STR: &str = "display_name"; +pub const FOUNDRY_LOCAL_MODEL_PROP_MODEL_TYPE_STR: &str = "type"; +pub const FOUNDRY_LOCAL_MODEL_PROP_PUBLISHER_STR: &str = "publisher"; +pub const FOUNDRY_LOCAL_MODEL_PROP_LICENSE_STR: &str = "license"; +pub const FOUNDRY_LOCAL_MODEL_PROP_LICENSE_DESCRIPTION_STR: &str = "license_description"; +pub const FOUNDRY_LOCAL_MODEL_PROP_TASK_STR: &str = "task"; +pub const FOUNDRY_LOCAL_MODEL_PROP_MODEL_PROVIDER_STR: &str = "model_provider"; +pub const FOUNDRY_LOCAL_MODEL_PROP_MIN_FL_VERSION_STR: &str = "min_fl_version"; +pub const FOUNDRY_LOCAL_MODEL_PROP_INPUT_MODALITIES_STR: &str = "input_modalities"; +pub const FOUNDRY_LOCAL_MODEL_PROP_OUTPUT_MODALITIES_STR: &str = "output_modalities"; +pub const FOUNDRY_LOCAL_MODEL_PROP_CAPABILITIES_STR: &str = "capabilities"; + +pub const FOUNDRY_LOCAL_MODEL_PROP_SUPPORTS_TOOL_CALLING_INT: &str = "supports_tool_calling"; +pub const FOUNDRY_LOCAL_MODEL_PROP_SUPPORTS_REASONING_INT: &str = "supports_reasoning"; +pub const FOUNDRY_LOCAL_MODEL_PROP_FILESIZE_MB_INT: &str = "filesize_mb"; +pub const FOUNDRY_LOCAL_MODEL_PROP_MAX_OUTPUT_TOKENS_INT: &str = "max_output_tokens"; +pub const FOUNDRY_LOCAL_MODEL_PROP_CREATED_AT_UNIX_INT: &str = "created_at_unix"; +pub const FOUNDRY_LOCAL_MODEL_PROP_CONTEXT_LENGTH_INT: &str = "context_length"; + +pub const FOUNDRY_LOCAL_PARAM_TEMPERATURE: &str = "temperature"; +pub const FOUNDRY_LOCAL_PARAM_TOP_P: &str = "top_p"; +pub const FOUNDRY_LOCAL_PARAM_TOP_K: &str = "top_k"; +pub const FOUNDRY_LOCAL_PARAM_MAX_OUTPUT_TOKENS: &str = "max_output_tokens"; +pub const FOUNDRY_LOCAL_PARAM_FREQUENCY_PENALTY: &str = "frequency_penalty"; +pub const FOUNDRY_LOCAL_PARAM_PRESENCE_PENALTY: &str = "presence_penalty"; +pub const FOUNDRY_LOCAL_PARAM_SEED: &str = "seed"; +pub const FOUNDRY_LOCAL_PARAM_TOOL_CHOICE: &str = "tool_choice"; + +// ── Versioned data structs ─────────────────────────────────────────────────── + +#[repr(C)] +pub struct flUsage { + pub version: u32, + pub prompt_tokens: i64, + pub completion_tokens: i64, + pub total_tokens: i64, +} + +#[repr(C)] +pub struct flEpInfo { + pub version: u32, + pub name: *const c_char, + pub is_registered: bool, +} + +pub type flBytesDataDeleter = + Option; +pub type flTensorDataDeleter = + Option; +pub type flImageDataDeleter = + Option; +pub type flAudioDataDeleter = + Option; + +#[repr(C)] +pub struct flTextData { + pub version: u32, + pub text: *const c_char, + pub r#type: flTextItemType, +} + +#[repr(C)] +pub struct flBytesData { + pub version: u32, + pub item_type: flItemType, + pub data: *const c_void, + pub mutable_data: *mut c_void, + pub data_size: usize, + pub deleter: flBytesDataDeleter, + pub deleter_user_data: *mut c_void, +} + +#[repr(C)] +pub struct flTensorData { + pub version: u32, + pub data_type: flTensorDataType, + pub data: *const c_void, + pub mutable_data: *mut c_void, + pub shape: *const i64, + pub rank: usize, + pub deleter: flTensorDataDeleter, + pub deleter_user_data: *mut c_void, +} + +#[repr(C)] +pub struct flMessageData { + pub version: u32, + pub role: flMessageRole, + pub content_items: *const *const flItem, + pub content_items_count: usize, + pub name: *const c_char, +} + +#[repr(C)] +pub struct flImageData { + pub version: u32, + pub data: *const c_void, + pub mutable_data: *mut c_void, + pub data_size: usize, + pub format: *const c_char, + pub uri: *const c_char, + pub deleter: flImageDataDeleter, + pub deleter_user_data: *mut c_void, +} + +#[repr(C)] +pub struct flAudioData { + pub version: u32, + pub data: *const c_void, + pub mutable_data: *mut c_void, + pub data_size: usize, + pub format: *const c_char, + pub uri: *const c_char, + pub sample_rate: c_int, + pub channels: c_int, + pub deleter: flAudioDataDeleter, + pub deleter_user_data: *mut c_void, +} + +#[repr(C)] +pub struct flToolCallData { + pub version: u32, + pub call_id: *const c_char, + pub name: *const c_char, + pub arguments: *const c_char, +} + +#[repr(C)] +pub struct flToolResultData { + pub version: u32, + pub call_id: *const c_char, + pub result: *const c_char, +} + +#[repr(C)] +pub struct flStreamingCallbackData { + pub version: u32, + pub item_queue: *mut flItemQueue, +} + +#[repr(C)] +pub struct flToolDefinition { + pub version: u32, + pub name: *const c_char, + pub description: *const c_char, + pub json_schema: *const c_char, +} + +// ── Callback types (plain C calling convention) ────────────────────────────── + +pub type flProgressCallback = + Option c_int>; +pub type flStreamingCallback = + Option c_int>; +pub type flEpProgressCallback = Option< + unsafe extern "C" fn(ep_name: *const c_char, value: f32, user_data: *mut c_void) -> c_int, +>; + +// ── Exported entry points (FL_API_CALL == __stdcall on Win32) ──────────────── +// +// The library is loaded at runtime via `libloading`; these are the signatures of +// the two exported symbols resolved from it. `flApi` is the root vtable type. + +pub type flApi = flApiVtable; + +pub type FoundryLocalGetApiFn = unsafe extern "system" fn(version: u32) -> *const flApiVtable; +pub type FoundryLocalGetVersionStringFn = unsafe extern "system" fn() -> *const c_char; + +pub const FOUNDRY_LOCAL_GET_API_SYMBOL: &[u8] = b"FoundryLocalGetApi\0"; +pub const FOUNDRY_LOCAL_GET_VERSION_STRING_SYMBOL: &[u8] = b"FoundryLocalGetVersionString\0"; + +// ── Function tables ────────────────────────────────────────────────────────── +// +// Field order and signatures MUST match foundry_local_c.h exactly. New entries +// are only ever appended at the end of each table. + +/// Root API table (`flApi`). +#[repr(C)] +pub struct flApiVtable { + // Status + pub Status_Create: + unsafe extern "system" fn(error_code: flErrorCode, error_msg: *const c_char) -> flStatusPtr, + pub Status_Release: unsafe extern "system" fn(instance: *mut flStatus), + pub Status_GetErrorCode: unsafe extern "system" fn(status: *const flStatus) -> flErrorCode, + pub Status_GetErrorMessage: unsafe extern "system" fn(status: *const flStatus) -> *const c_char, + + // Manager lifecycle + pub Manager_Create: unsafe extern "system" fn( + config: *const flConfiguration, + out_manager: *mut *mut flManager, + ) -> flStatusPtr, + pub Manager_Release: unsafe extern "system" fn(instance: *mut flManager), + + pub Manager_GetCatalog: unsafe extern "system" fn( + manager: *const flManager, + out_catalog: *mut *mut flCatalog, + ) -> flStatusPtr, + pub Manager_WebServiceStart: unsafe extern "system" fn(manager: *mut flManager) -> flStatusPtr, + pub Manager_WebServiceUrls: unsafe extern "system" fn( + manager: *const flManager, + out_urls: *mut *const *const c_char, + out_num_urls: *mut usize, + ) -> flStatusPtr, + pub Manager_WebServiceStop: unsafe extern "system" fn(manager: *mut flManager) -> flStatusPtr, + + // Sub-API accessors + pub GetCatalogApi: unsafe extern "system" fn() -> *const flCatalogApiVtable, + pub GetConfigurationApi: unsafe extern "system" fn() -> *const flConfigurationApiVtable, + pub GetItemApi: unsafe extern "system" fn() -> *const flItemApiVtable, + pub GetInferenceApi: unsafe extern "system" fn() -> *const flInferenceApiVtable, + pub GetModelApi: unsafe extern "system" fn() -> *const flModelApiVtable, + + // KeyValuePairs + pub CreateKeyValuePairs: unsafe extern "system" fn(out: *mut *mut flKeyValuePairs), + pub AddKeyValuePair: unsafe extern "system" fn( + kvps: *mut flKeyValuePairs, + key: *const c_char, + value: *const c_char, + ), + pub GetKeyValue: unsafe extern "system" fn( + kvps: *const flKeyValuePairs, + key: *const c_char, + ) -> *const c_char, + pub GetKeyValuePairs: unsafe extern "system" fn( + kvps: *const flKeyValuePairs, + keys: *mut *const *const c_char, + values: *mut *const *const c_char, + num_entries: *mut usize, + ), + pub RemoveKeyValuePair: + unsafe extern "system" fn(kvps: *mut flKeyValuePairs, key: *const c_char), + pub KeyValuePairs_Release: unsafe extern "system" fn(instance: *mut flKeyValuePairs), + + // ModelList + pub ModelList_Release: unsafe extern "system" fn(instance: *mut flModelList), + pub ModelList_Size: unsafe extern "system" fn(models: *const flModelList) -> usize, + pub ModelList_GetAt: + unsafe extern "system" fn(models: *const flModelList, idx: usize) -> *mut flModel, + + // EP detection + pub Manager_GetDiscoverableEps: unsafe extern "system" fn( + manager: *const flManager, + out_eps: *mut *const flEpInfo, + out_count: *mut usize, + ) -> flStatusPtr, + pub Manager_DownloadAndRegisterEps: unsafe extern "system" fn( + manager: *mut flManager, + ep_names: *const *const c_char, + num_ep_names: usize, + callback: flEpProgressCallback, + user_data: *mut c_void, + ) -> flStatusPtr, + pub Manager_IsEpDownloadInProgress: + unsafe extern "system" fn(manager: *const flManager) -> bool, + + pub Manager_Shutdown: unsafe extern "system" fn(manager: *mut flManager) -> flStatusPtr, + pub Manager_IsShutdownRequested: unsafe extern "system" fn(manager: *const flManager) -> bool, +} + +/// Item API table (`flItemApi`). +#[repr(C)] +pub struct flItemApiVtable { + pub Create: + unsafe extern "system" fn(item_type: flItemType, out_item: *mut *mut flItem) -> flStatusPtr, + pub Item_Release: unsafe extern "system" fn(instance: *mut flItem), + pub GetType: unsafe extern "system" fn(item: *const flItem) -> flItemType, + + pub SetBytes: + unsafe extern "system" fn(item: *mut flItem, bytes: *const flBytesData) -> flStatusPtr, + pub SetTensor: + unsafe extern "system" fn(item: *mut flItem, tensor: *const flTensorData) -> flStatusPtr, + pub SetText: + unsafe extern "system" fn(item: *mut flItem, text_data: *const flTextData) -> flStatusPtr, + pub SetMessage: + unsafe extern "system" fn(item: *mut flItem, message: *const flMessageData) -> flStatusPtr, + pub SetImage: + unsafe extern "system" fn(item: *mut flItem, image: *const flImageData) -> flStatusPtr, + pub SetAudio: + unsafe extern "system" fn(item: *mut flItem, audio: *const flAudioData) -> flStatusPtr, + pub SetToolCall: unsafe extern "system" fn( + item: *mut flItem, + tool_call: *const flToolCallData, + ) -> flStatusPtr, + pub SetToolResult: unsafe extern "system" fn( + item: *mut flItem, + tool_result: *const flToolResultData, + ) -> flStatusPtr, + + pub GetBytes: + unsafe extern "system" fn(item: *const flItem, out_bytes: *mut flBytesData) -> flStatusPtr, + pub GetText: unsafe extern "system" fn( + item: *const flItem, + out_text_data: *mut flTextData, + ) -> flStatusPtr, + pub GetMessage: unsafe extern "system" fn( + item: *const flItem, + out_message: *mut flMessageData, + ) -> flStatusPtr, + pub GetTensor: unsafe extern "system" fn( + item: *const flItem, + out_tensor: *mut flTensorData, + ) -> flStatusPtr, + pub GetImage: + unsafe extern "system" fn(item: *const flItem, out_image: *mut flImageData) -> flStatusPtr, + pub GetAudio: + unsafe extern "system" fn(item: *const flItem, out_audio: *mut flAudioData) -> flStatusPtr, + pub GetToolCall: unsafe extern "system" fn( + item: *const flItem, + out_tool_call: *mut flToolCallData, + ) -> flStatusPtr, + pub GetToolResult: unsafe extern "system" fn( + item: *const flItem, + out_tool_result: *mut flToolResultData, + ) -> flStatusPtr, + + pub GetMetadata: unsafe extern "system" fn( + item: *const flItem, + out_metadata: *mut *const flKeyValuePairs, + ) -> flStatusPtr, + pub GetMutableMetadata: unsafe extern "system" fn( + item: *mut flItem, + out_metadata: *mut *mut flKeyValuePairs, + ) -> flStatusPtr, + pub GetQueue: unsafe extern "system" fn( + item: *mut flItem, + out_queue: *mut *mut flItemQueue, + ) -> flStatusPtr, + + // ItemQueue + pub ItemQueue_Create: + unsafe extern "system" fn(out_queue: *mut *mut flItemQueue) -> flStatusPtr, + pub ItemQueue_Release: unsafe extern "system" fn(instance: *mut flItemQueue), + pub ItemQueue_Push: + unsafe extern "system" fn(queue: *mut flItemQueue, item: *mut flItem) -> flStatusPtr, + pub ItemQueue_TryPop: + unsafe extern "system" fn(queue: *mut flItemQueue, out_item: *mut *mut flItem) -> bool, + pub ItemQueue_Size: unsafe extern "system" fn(queue: *const flItemQueue) -> usize, + pub ItemQueue_MarkFinished: unsafe extern "system" fn(queue: *mut flItemQueue), + pub ItemQueue_IsFinished: unsafe extern "system" fn(queue: *const flItemQueue) -> bool, +} + +/// Inference API table (`flInferenceApi`). +#[repr(C)] +pub struct flInferenceApiVtable { + pub Request_Create: unsafe extern "system" fn(out_request: *mut *mut flRequest) -> flStatusPtr, + pub Request_Release: unsafe extern "system" fn(instance: *mut flRequest), + pub Request_AddItem: unsafe extern "system" fn( + request: *mut flRequest, + item: *mut flItem, + take_ownership: bool, + ) -> flStatusPtr, + pub Request_GetItemCount: unsafe extern "system" fn(request: *const flRequest) -> usize, + pub Request_GetItem: unsafe extern "system" fn( + request: *const flRequest, + idx: usize, + out_item: *mut *const flItem, + ) -> flStatusPtr, + pub Request_SetOptions: unsafe extern "system" fn( + request: *mut flRequest, + options: *const flKeyValuePairs, + ) -> flStatusPtr, + pub Request_Cancel: unsafe extern "system" fn(request: *mut flRequest) -> flStatusPtr, + + pub Response_Create: + unsafe extern "system" fn(out_response: *mut *mut flResponse) -> flStatusPtr, + pub Response_Release: unsafe extern "system" fn(instance: *mut flResponse), + pub Response_GetItemCount: unsafe extern "system" fn(response: *const flResponse) -> usize, + pub Response_GetItem: unsafe extern "system" fn( + response: *const flResponse, + idx: usize, + out_item: *mut *const flItem, + ) -> flStatusPtr, + pub Response_GetFinishReason: + unsafe extern "system" fn(response: *const flResponse) -> flFinishReason, + pub Response_GetUsage: unsafe extern "system" fn( + response: *const flResponse, + out_usage: *mut flUsage, + ) -> flStatusPtr, + + pub Session_Create: unsafe extern "system" fn( + model: *const flModel, + out_session: *mut *mut flSession, + ) -> flStatusPtr, + pub Session_Release: unsafe extern "system" fn(instance: *mut flSession), + pub Session_SetStreamingCallback: unsafe extern "system" fn( + session: *mut flSession, + callback: flStreamingCallback, + user_data: *mut c_void, + ) -> flStatusPtr, + pub Session_SetOptions: unsafe extern "system" fn( + session: *mut flSession, + options: *const flKeyValuePairs, + ) -> flStatusPtr, + pub Session_ProcessRequest: unsafe extern "system" fn( + session: *mut flSession, + request: *const flRequest, + response: *mut *mut flResponse, + ) -> flStatusPtr, + pub Session_AddToolDefinition: unsafe extern "system" fn( + session: *mut flSession, + tool_def: *const flToolDefinition, + ) -> flStatusPtr, + pub Session_RemoveToolDefinition: unsafe extern "system" fn( + session: *mut flSession, + tool_name: *const c_char, + out_removed: *mut bool, + ) -> flStatusPtr, + pub Session_GetTurnCount: unsafe extern "system" fn(session: *const flSession) -> usize, + pub Session_UndoTurns: + unsafe extern "system" fn(session: *mut flSession, count: usize) -> flStatusPtr, +} + +/// Configuration API table (`flConfigurationApi`). +#[repr(C)] +pub struct flConfigurationApiVtable { + pub Create: unsafe extern "system" fn( + app_name: *const c_char, + out_config: *mut *mut flConfiguration, + ) -> flStatusPtr, + pub Configuration_Release: unsafe extern "system" fn(instance: *mut flConfiguration), + pub SetDefaultLogLevel: + unsafe extern "system" fn(config: *mut flConfiguration, level: flLogLevel) -> flStatusPtr, + pub SetAppDataDir: + unsafe extern "system" fn(config: *mut flConfiguration, dir: *const c_char) -> flStatusPtr, + pub SetLogsDir: + unsafe extern "system" fn(config: *mut flConfiguration, dir: *const c_char) -> flStatusPtr, + pub SetModelCacheDir: + unsafe extern "system" fn(config: *mut flConfiguration, dir: *const c_char) -> flStatusPtr, + pub AddCatalogUrl: unsafe extern "system" fn( + config: *mut flConfiguration, + url: *const c_char, + filter_override: *const c_char, + ) -> flStatusPtr, + pub SetCatalogRegion: unsafe extern "system" fn( + config: *mut flConfiguration, + region: *const c_char, + ) -> flStatusPtr, + pub AddWebServiceEndpoint: + unsafe extern "system" fn(config: *mut flConfiguration, url: *const c_char) -> flStatusPtr, + pub SetExternalServiceUrl: + unsafe extern "system" fn(config: *mut flConfiguration, url: *const c_char) -> flStatusPtr, + pub SetAdditionalOptions: unsafe extern "system" fn( + config: *mut flConfiguration, + options: *const flKeyValuePairs, + ) -> flStatusPtr, +} + +/// Catalog API table (`flCatalogApi`). +#[repr(C)] +pub struct flCatalogApiVtable { + pub GetName: unsafe extern "system" fn( + catalog: *const flCatalog, + out_name: *mut *const c_char, + ) -> flStatusPtr, + pub GetModels: unsafe extern "system" fn( + catalog: *const flCatalog, + out_models: *mut *mut flModelList, + ) -> flStatusPtr, + pub GetModel: unsafe extern "system" fn( + catalog: *const flCatalog, + alias: *const c_char, + out_model: *mut *mut flModel, + ) -> flStatusPtr, + pub GetModelVariant: unsafe extern "system" fn( + catalog: *const flCatalog, + model_id: *const c_char, + out_model: *mut *mut flModel, + ) -> flStatusPtr, + pub GetLatestVersion: unsafe extern "system" fn( + catalog: *const flCatalog, + model: *const flModel, + out_model: *mut *mut flModel, + ) -> flStatusPtr, + pub GetCachedModels: unsafe extern "system" fn( + catalog: *const flCatalog, + out_models: *mut *mut flModelList, + ) -> flStatusPtr, + pub GetLoadedModels: unsafe extern "system" fn( + catalog: *const flCatalog, + out_models: *mut *mut flModelList, + ) -> flStatusPtr, +} + +/// Model API table (`flModelApi`). +#[repr(C)] +pub struct flModelApiVtable { + pub GetInfo: unsafe extern "system" fn( + model: *const flModel, + out_info: *mut *const flModelInfo, + ) -> flStatusPtr, + pub GetInputOutputInfo: unsafe extern "system" fn( + model: *const flModel, + out_inputs: *mut *const *const flItem, + out_num_inputs: *mut usize, + out_outputs: *mut *const *const flItem, + out_num_outputs: *mut usize, + ) -> flStatusPtr, + + pub IsCached: + unsafe extern "system" fn(model: *const flModel, out_cached: *mut c_int) -> flStatusPtr, + pub GetPath: unsafe extern "system" fn( + model: *const flModel, + out_path: *mut *const c_char, + ) -> flStatusPtr, + pub Download: unsafe extern "system" fn( + model: *mut flModel, + callback: flProgressCallback, + user_data: *mut c_void, + ) -> flStatusPtr, + + pub IsLoaded: + unsafe extern "system" fn(model: *const flModel, out_loaded: *mut c_int) -> flStatusPtr, + pub Load: unsafe extern "system" fn(model: *mut flModel) -> flStatusPtr, + pub Unload: unsafe extern "system" fn(model: *mut flModel) -> flStatusPtr, + pub RemoveFromCache: unsafe extern "system" fn(model: *mut flModel) -> flStatusPtr, + + pub GetVariants: unsafe extern "system" fn( + model: *const flModel, + out_variants: *mut *mut flModelList, + ) -> flStatusPtr, + pub SelectVariant: + unsafe extern "system" fn(model: *mut flModel, variant: *const flModel) -> flStatusPtr, + + pub Info_GetId: unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetName: unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetVersion: unsafe extern "system" fn(info: *const flModelInfo) -> c_int, + pub Info_GetAlias: unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetUri: unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetDeviceType: unsafe extern "system" fn(info: *const flModelInfo) -> flDeviceType, + pub Info_GetExecutionProvider: + unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetTask: unsafe extern "system" fn(info: *const flModelInfo) -> *const c_char, + pub Info_GetPromptTemplates: + unsafe extern "system" fn(info: *const flModelInfo) -> *const flKeyValuePairs, + pub Info_GetModelSettings: + unsafe extern "system" fn(info: *const flModelInfo) -> *const flKeyValuePairs, + pub Info_GetStringProperty: + unsafe extern "system" fn(info: *const flModelInfo, key: *const c_char) -> *const c_char, + pub Info_GetIntProperty: unsafe extern "system" fn( + info: *const flModelInfo, + key: *const c_char, + default_value: i64, + ) -> i64, +} diff --git a/sdk_v2/rust/src/detail/info.rs b/sdk_v2/rust/src/detail/info.rs new file mode 100644 index 000000000..3300a299e --- /dev/null +++ b/sdk_v2/rust/src/detail/info.rs @@ -0,0 +1,152 @@ +//! Builds the public [`ModelInfo`] from native `flModelInfo` accessors. + +use super::api::{cstr_to_string, read_kvps, Api}; +use super::ffi::*; +use super::native::NativeModel; +use crate::error::Result; +use crate::types::{DeviceType, ModelInfo, ModelSettings, Parameter, PromptTemplate, Runtime}; + +fn device_type(value: flDeviceType) -> DeviceType { + match value { + FOUNDRY_LOCAL_DEVICE_CPU => DeviceType::CPU, + FOUNDRY_LOCAL_DEVICE_GPU => DeviceType::GPU, + FOUNDRY_LOCAL_DEVICE_NPU => DeviceType::NPU, + _ => DeviceType::Invalid, + } +} + +/// Build a fully-populated [`ModelInfo`] for the given model handle. +pub(crate) fn build_model_info(api: &Api, native: &NativeModel) -> Result { + let info = native.info_ptr()?; + let m = api.model_api(); + + // SAFETY: `info` is a valid model-owned info pointer; all accessors below + // return pointers owned by the model and valid for the duration of use. + unsafe { + let id = cstr_to_string((m.Info_GetId)(info)).unwrap_or_default(); + let name = cstr_to_string((m.Info_GetName)(info)).unwrap_or_default(); + let version = (m.Info_GetVersion)(info).max(0) as u64; + let alias = cstr_to_string((m.Info_GetAlias)(info)).unwrap_or_default(); + let uri = cstr_to_string((m.Info_GetUri)(info)).unwrap_or_default(); + let task = cstr_to_string((m.Info_GetTask)(info)); + let execution_provider = cstr_to_string((m.Info_GetExecutionProvider)(info)); + let device = device_type((m.Info_GetDeviceType)(info)); + + let str_prop = |key: &str| -> Option { + let c = std::ffi::CString::new(key).ok()?; + cstr_to_string((m.Info_GetStringProperty)(info, c.as_ptr())) + }; + let int_prop = |key: &str, default_value: i64| -> i64 { + match std::ffi::CString::new(key) { + Ok(c) => (m.Info_GetIntProperty)(info, c.as_ptr(), default_value), + Err(_) => default_value, + } + }; + let opt_u64 = |v: i64| -> Option { + if v >= 0 { + Some(v as u64) + } else { + None + } + }; + let opt_bool = |v: i64| -> Option { + match v { + 0 => Some(false), + 1 => Some(true), + _ => None, + } + }; + + let runtime = execution_provider.clone().map(|ep| Runtime { + device_type: device.clone(), + execution_provider: ep, + }); + + let prompt_template = build_prompt_template(api, (m.Info_GetPromptTemplates)(info)); + let model_settings = build_model_settings(api, (m.Info_GetModelSettings)(info)); + + Ok(ModelInfo { + id, + name, + version, + alias, + display_name: str_prop(FOUNDRY_LOCAL_MODEL_PROP_DISPLAY_NAME_STR), + provider_type: str_prop(FOUNDRY_LOCAL_MODEL_PROP_MODEL_PROVIDER_STR) + .unwrap_or_default(), + uri, + model_type: str_prop(FOUNDRY_LOCAL_MODEL_PROP_MODEL_TYPE_STR).unwrap_or_default(), + prompt_template, + publisher: str_prop(FOUNDRY_LOCAL_MODEL_PROP_PUBLISHER_STR), + model_settings, + license: str_prop(FOUNDRY_LOCAL_MODEL_PROP_LICENSE_STR), + license_description: str_prop(FOUNDRY_LOCAL_MODEL_PROP_LICENSE_DESCRIPTION_STR), + cached: native.is_cached()?, + task, + runtime, + file_size_mb: opt_u64(int_prop(FOUNDRY_LOCAL_MODEL_PROP_FILESIZE_MB_INT, -1)), + supports_tool_calling: opt_bool(int_prop( + FOUNDRY_LOCAL_MODEL_PROP_SUPPORTS_TOOL_CALLING_INT, + -1, + )), + max_output_tokens: opt_u64(int_prop( + FOUNDRY_LOCAL_MODEL_PROP_MAX_OUTPUT_TOKENS_INT, + -1, + )), + min_fl_version: str_prop(FOUNDRY_LOCAL_MODEL_PROP_MIN_FL_VERSION_STR), + created_at_unix: int_prop(FOUNDRY_LOCAL_MODEL_PROP_CREATED_AT_UNIX_INT, 0).max(0) + as u64, + context_length: opt_u64(int_prop(FOUNDRY_LOCAL_MODEL_PROP_CONTEXT_LENGTH_INT, -1)), + input_modalities: str_prop(FOUNDRY_LOCAL_MODEL_PROP_INPUT_MODALITIES_STR), + output_modalities: str_prop(FOUNDRY_LOCAL_MODEL_PROP_OUTPUT_MODALITIES_STR), + capabilities: str_prop(FOUNDRY_LOCAL_MODEL_PROP_CAPABILITIES_STR), + }) + } +} + +unsafe fn build_prompt_template(api: &Api, kvps: *const flKeyValuePairs) -> Option { + if kvps.is_null() { + return None; + } + let pairs = read_kvps(api, kvps); + if pairs.is_empty() { + return None; + } + let get = |key: &str| -> Option { + pairs + .iter() + .find(|(k, _)| k == key) + .and_then(|(_, v)| v.clone()) + }; + let template = PromptTemplate { + system: get("system"), + user: get("user"), + assistant: get("assistant"), + prompt: get("prompt"), + }; + if template.system.is_none() + && template.user.is_none() + && template.assistant.is_none() + && template.prompt.is_none() + { + None + } else { + Some(template) + } +} + +unsafe fn build_model_settings(api: &Api, kvps: *const flKeyValuePairs) -> Option { + if kvps.is_null() { + return None; + } + let pairs = read_kvps(api, kvps); + if pairs.is_empty() { + return None; + } + let parameters = pairs + .into_iter() + .map(|(name, value)| Parameter { name, value }) + .collect::>(); + Some(ModelSettings { + parameters: Some(parameters), + }) +} diff --git a/sdk_v2/rust/src/detail/items.rs b/sdk_v2/rust/src/detail/items.rs new file mode 100644 index 000000000..dcec5b6a1 --- /dev/null +++ b/sdk_v2/rust/src/detail/items.rs @@ -0,0 +1,226 @@ +//! Helpers for constructing and reading native `flItem`s. +//! +//! The OpenAI facade and live-audio session build TEXT (`OPENAI_JSON`), AUDIO, +//! and BYTES items, and read TEXT items back out of responses / streamed queues. + +use std::os::raw::c_char; +use std::ptr; + +use super::api::{cstr_to_string, to_cstring, Api}; +use super::ffi::*; +use crate::error::Result; + +/// Create a TEXT item with the given subtype. The native layer copies the text. +pub(crate) fn make_text_item( + api: &Api, + text: &str, + item_type: flTextItemType, +) -> Result<*mut flItem> { + let mut item: *mut flItem = ptr::null_mut(); + api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_TEXT, &mut item) })?; + let c = to_cstring(text)?; + let data = flTextData { + version: FOUNDRY_LOCAL_API_VERSION, + text: c.as_ptr(), + r#type: item_type, + }; + // SAFETY: `item` is a valid TEXT item; the native call copies the string. + let status = unsafe { (api.item_api().SetText)(item, &data) }; + if let Err(e) = api.check(status) { + unsafe { (api.item_api().Item_Release)(item) }; + return Err(e); + } + Ok(item) +} + +/// Create a TEXT item carrying an opaque OpenAI REST JSON payload. +pub(crate) fn make_openai_json_item(api: &Api, json: &str) -> Result<*mut flItem> { + make_text_item(api, json, FOUNDRY_LOCAL_TEXT_ITEM_TYPE_OPENAI_JSON) +} + +/// Create a byte-backed AUDIO item describing a PCM format (used as the live +/// audio format descriptor) or carrying audio bytes. The native layer copies +/// the data, so the caller's buffer need not outlive the call. +pub(crate) fn make_audio_item( + api: &Api, + data: &[u8], + format: Option<&str>, + sample_rate: i32, + channels: i32, +) -> Result<*mut flItem> { + let mut item: *mut flItem = ptr::null_mut(); + api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_AUDIO, &mut item) })?; + + let format_c = match format { + Some(f) => Some(to_cstring(f)?), + None => None, + }; + + // Like SetBytes, SetAudio does not copy the sample buffer — it borrows the + // pointer (and frees it via the deleter when one is supplied). Transfer an + // owned heap allocation so the buffer outlives this call; the format string + // is copied natively, so a transient CString is fine. + let (data_ptr, len, deleter): (*mut u8, usize, flAudioDataDeleter) = if data.is_empty() { + (ptr::null_mut(), 0, None) + } else { + let boxed: Box<[u8]> = data.to_vec().into_boxed_slice(); + let len = boxed.len(); + ( + Box::into_raw(boxed) as *mut u8, + len, + Some(rust_audio_deleter), + ) + }; + + let audio = flAudioData { + version: FOUNDRY_LOCAL_API_VERSION, + data: data_ptr as *const std::ffi::c_void, + mutable_data: data_ptr as *mut std::ffi::c_void, + data_size: len, + format: format_c.as_ref().map_or(ptr::null(), |c| c.as_ptr()), + uri: ptr::null(), + sample_rate, + channels, + deleter, + deleter_user_data: ptr::null_mut(), + }; + // SAFETY: `item` is a valid AUDIO item. On success the item owns `data_ptr` + // (freed via the deleter); on failure we reclaim it here to avoid a leak. + let status = unsafe { (api.item_api().SetAudio)(item, &audio) }; + if let Err(e) = api.check(status) { + unsafe { + if !data_ptr.is_null() { + drop(Box::from_raw(ptr::slice_from_raw_parts_mut(data_ptr, len))); + } + (api.item_api().Item_Release)(item); + } + return Err(e); + } + Ok(item) +} + +/// Deleter that reclaims a Rust-allocated `Box<[u8]>` owned by an AUDIO item. +/// Mirrors [`rust_bytes_deleter`]; see its docs for the ownership contract. +unsafe extern "C" fn rust_audio_deleter( + data: *const flAudioData, + _user_data: *mut std::ffi::c_void, +) { + if data.is_null() { + return; + } + let d = &*data; + if !d.mutable_data.is_null() && d.data_size > 0 { + let slice = ptr::slice_from_raw_parts_mut(d.mutable_data as *mut u8, d.data_size); + drop(Box::from_raw(slice)); + } +} + +/// Deleter that reclaims a Rust-allocated `Box<[u8]>` owned by a BYTES item. +/// +/// The native item calls this on destruction. `mutable_data` is the pointer we +/// handed over via `Box::into_raw`, and `data_size` is its length; together they +/// reconstruct the boxed slice so it is dropped exactly once. +unsafe extern "C" fn rust_bytes_deleter( + data: *const flBytesData, + _user_data: *mut std::ffi::c_void, +) { + if data.is_null() { + return; + } + let d = &*data; + if !d.mutable_data.is_null() && d.data_size > 0 { + let slice = ptr::slice_from_raw_parts_mut(d.mutable_data as *mut u8, d.data_size); + drop(Box::from_raw(slice)); + } +} + +/// Create a BYTES item tagged with the given originating item type (e.g. AUDIO +/// for raw PCM chunks pushed into a live session). +/// +/// The native `SetBytes` does **not** copy — it stores the pointer and (when a +/// deleter is supplied) takes ownership of the buffer, freeing it via the +/// deleter when the item is destroyed. The item may be consumed asynchronously +/// (e.g. drained from an `ItemQueue` by a streaming worker long after this +/// returns), so the buffer must outlive this call. We therefore transfer an +/// owned heap allocation to the item rather than lending a caller buffer. +pub(crate) fn make_bytes_item( + api: &Api, + data: &[u8], + item_type: flItemType, +) -> Result<*mut flItem> { + let mut item: *mut flItem = ptr::null_mut(); + api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_BYTES, &mut item) })?; + + if data.is_empty() { + let bytes = flBytesData { + version: FOUNDRY_LOCAL_API_VERSION, + item_type, + data: ptr::null(), + mutable_data: ptr::null_mut(), + data_size: 0, + deleter: None, + deleter_user_data: ptr::null_mut(), + }; + let status = unsafe { (api.item_api().SetBytes)(item, &bytes) }; + if let Err(e) = api.check(status) { + unsafe { (api.item_api().Item_Release)(item) }; + return Err(e); + } + return Ok(item); + } + + // Transfer an owned copy to the item. `into_boxed_slice` guarantees + // capacity == len, so the deleter can reconstruct it from (ptr, len). + let boxed: Box<[u8]> = data.to_vec().into_boxed_slice(); + let len = boxed.len(); + let raw = Box::into_raw(boxed) as *mut u8; + + let bytes = flBytesData { + version: FOUNDRY_LOCAL_API_VERSION, + item_type, + data: raw as *const std::ffi::c_void, + mutable_data: raw as *mut std::ffi::c_void, + data_size: len, + deleter: Some(rust_bytes_deleter), + deleter_user_data: ptr::null_mut(), + }; + + // SAFETY: `item` is a valid BYTES item. On success the item owns `raw` and + // frees it via the deleter; on failure SetBytesData was not applied, so we + // reclaim and drop the box here to avoid leaking it. + let status = unsafe { (api.item_api().SetBytes)(item, &bytes) }; + if let Err(e) = api.check(status) { + unsafe { + let slice = ptr::slice_from_raw_parts_mut(raw, len); + drop(Box::from_raw(slice)); + (api.item_api().Item_Release)(item); + } + return Err(e); + } + Ok(item) +} + +/// Read the text of a TEXT item. Returns `None` for null/non-text items. +/// +/// # Safety +/// `item` must be null or a valid item pointer alive for the duration of this call. +pub(crate) unsafe fn read_text_item(api: &Api, item: *const flItem) -> Option { + if item.is_null() { + return None; + } + if (api.item_api().GetType)(item) != FOUNDRY_LOCAL_ITEM_TEXT { + return None; + } + let mut data = flTextData { + version: FOUNDRY_LOCAL_API_VERSION, + text: ptr::null::(), + r#type: FOUNDRY_LOCAL_TEXT_ITEM_TYPE_DEFAULT, + }; + if api + .check((api.item_api().GetText)(item, &mut data)) + .is_err() + { + return None; + } + cstr_to_string(data.text) +} diff --git a/sdk_v2/rust/src/detail/manager.rs b/sdk_v2/rust/src/detail/manager.rs new file mode 100644 index 000000000..6e1767f63 --- /dev/null +++ b/sdk_v2/rust/src/detail/manager.rs @@ -0,0 +1,229 @@ +//! Owning wrapper around a native `flManager` plus its discovery / web-service / +//! execution-provider operations. + +use std::os::raw::c_char; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use super::api::{cstr_to_string, Api}; +use super::ffi::*; +use crate::error::{FoundryLocalError, Result}; +use crate::types::EpInfo; + +/// Owns a native `flManager`. +/// +/// The manager is held by a process-lifetime singleton, so [`Drop`] effectively +/// never runs; the native handle is instead released by an `atexit` hook (see +/// [`teardown`](Self::teardown)) before the library's C++ static destructors run. +pub(crate) struct NativeManager { + api: Arc, + ptr: *mut flManager, + /// Set once the native manager has been released, so shutdown/release and + /// the `atexit` hook and `Drop` all coordinate to release exactly once. + released: AtomicBool, +} + +// SAFETY: the native manager is thread-safe; shutdown is documented as callable +// from any thread. +unsafe impl Send for NativeManager {} +unsafe impl Sync for NativeManager {} + +impl NativeManager { + /// Create a manager from a fully-built native configuration. + pub(crate) fn create(api: Arc, config: *const flConfiguration) -> Result { + let mut ptr: *mut flManager = std::ptr::null_mut(); + let status = unsafe { (api.root().Manager_Create)(config, &mut ptr) }; + api.check(status)?; + if ptr.is_null() { + return Err(FoundryLocalError::Internal { + reason: "Manager_Create returned a null manager".into(), + }); + } + Ok(Self { + api, + ptr, + released: AtomicBool::new(false), + }) + } + + /// The manager-owned catalog handle. + pub(crate) fn catalog_ptr(&self) -> Result<*mut flCatalog> { + let mut catalog: *mut flCatalog = std::ptr::null_mut(); + let status = unsafe { (self.api.root().Manager_GetCatalog)(self.ptr, &mut catalog) }; + self.api.check(status)?; + if catalog.is_null() { + return Err(FoundryLocalError::Internal { + reason: "Manager_GetCatalog returned a null catalog".into(), + }); + } + Ok(catalog) + } + + pub(crate) fn web_service_start(&self) -> Result<()> { + let status = unsafe { (self.api.root().Manager_WebServiceStart)(self.ptr) }; + self.api.check(status) + } + + pub(crate) fn web_service_stop(&self) -> Result<()> { + let status = unsafe { (self.api.root().Manager_WebServiceStop)(self.ptr) }; + self.api.check(status) + } + + pub(crate) fn web_service_urls(&self) -> Result> { + let mut urls: *const *const c_char = std::ptr::null(); + let mut count: usize = 0; + let status = + unsafe { (self.api.root().Manager_WebServiceUrls)(self.ptr, &mut urls, &mut count) }; + self.api.check(status)?; + if urls.is_null() || count == 0 { + return Ok(Vec::new()); + } + let mut out = Vec::with_capacity(count); + for i in 0..count { + // SAFETY: `urls` points to `count` valid C-string pointers. + if let Some(s) = unsafe { cstr_to_string(*urls.add(i)) } { + out.push(s); + } + } + Ok(out) + } + + pub(crate) fn discover_eps(&self) -> Result> { + let mut eps: *const flEpInfo = std::ptr::null(); + let mut count: usize = 0; + let status = + unsafe { (self.api.root().Manager_GetDiscoverableEps)(self.ptr, &mut eps, &mut count) }; + self.api.check(status)?; + if eps.is_null() || count == 0 { + return Ok(Vec::new()); + } + let mut out = Vec::with_capacity(count); + for i in 0..count { + // SAFETY: `eps` points to `count` valid flEpInfo structs. + let ep = unsafe { &*eps.add(i) }; + out.push(EpInfo { + name: unsafe { cstr_to_string(ep.name) }.unwrap_or_default(), + is_registered: ep.is_registered, + }); + } + Ok(out) + } + + /// Run an EP download/registration. Returns `None` on full success or + /// `Some(message)` describing a partial failure (a non-null native status). + pub(crate) fn download_and_register_eps( + &self, + names: Option<&[&str]>, + progress: Option, + cancel_flag: Option>, + ) -> Option { + // Build the optional C array of name pointers (kept alive across the call). + let name_cstrings: Option> = names.map(|ns| { + ns.iter() + .filter_map(|n| std::ffi::CString::new(*n).ok()) + .collect() + }); + let name_ptrs: Option> = name_cstrings + .as_ref() + .map(|cs| cs.iter().map(|c| c.as_ptr()).collect()); + + let (names_ptr, names_len) = match &name_ptrs { + Some(p) if !p.is_empty() => (p.as_ptr(), p.len()), + _ => (std::ptr::null(), 0usize), + }; + + let mut ctx = EpCtx { + progress, + cancel_flag, + cancelled: false, + }; + let user_data = &mut ctx as *mut EpCtx as *mut std::ffi::c_void; + let callback: flEpProgressCallback = Some(ep_trampoline); + + // SAFETY: name pointers and `ctx` outlive this blocking call. + let status = unsafe { + (self.api.root().Manager_DownloadAndRegisterEps)( + self.ptr, names_ptr, names_len, callback, user_data, + ) + }; + self.api.status_message(status) + } + + /// Begin graceful shutdown of the native manager (`Manager_Shutdown`). + /// + /// Idempotent and safe to call from any thread. Does **not** release the + /// native handle — that happens once at process exit via [`teardown`](Self::teardown). + pub(crate) fn shutdown(&self) -> Result<()> { + if self.released.load(Ordering::Acquire) { + return Ok(()); + } + // SAFETY: `ptr` is a live manager handle (not yet released — guarded above). + let status = unsafe { (self.api.root().Manager_Shutdown)(self.ptr) }; + self.api.check(status) + } + + /// Run the prescribed teardown exactly once: `Manager_Shutdown` then + /// `Manager_Release`. + /// + /// This is invoked from the process-exit hook (and `Drop`) so the manager's + /// C++ destructor runs *before* the library's static destructors — avoiding + /// the spdlog teardown abort (`mutex lock failed`) documented for the other + /// SDK bindings. Releasing is always attempted, even if shutdown errored. + pub(crate) fn teardown(&self) { + if self.released.swap(true, Ordering::AcqRel) { + return; + } + // SAFETY: `ptr` was created by Manager_Create and is released exactly + // once (guarded by the `released` swap above). + unsafe { + let status = (self.api.root().Manager_Shutdown)(self.ptr); + if !status.is_null() { + (self.api.root().Status_Release)(status); + } + (self.api.root().Manager_Release)(self.ptr); + } + } +} + +impl Drop for NativeManager { + fn drop(&mut self) { + self.teardown(); + } +} + +/// Boxed `(ep_name, percent)` progress callback. +pub(crate) type EpProgressCallback = Box; + +struct EpCtx { + progress: Option, + cancel_flag: Option>, + cancelled: bool, +} + +unsafe extern "C" fn ep_trampoline( + ep_name: *const c_char, + value: f32, + user_data: *mut std::ffi::c_void, +) -> std::os::raw::c_int { + if user_data.is_null() { + return 0; + } + let result = catch_unwind(AssertUnwindSafe(|| { + let ctx = &mut *(user_data as *mut EpCtx); + if ctx + .cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)) + { + ctx.cancelled = true; + return 1; + } + if let Some(cb) = ctx.progress.as_mut() { + let name = cstr_to_string(ep_name).unwrap_or_default(); + cb(&name, value as f64); + } + 0 + })); + result.unwrap_or(1) +} diff --git a/sdk_v2/rust/src/detail/mod.rs b/sdk_v2/rust/src/detail/mod.rs new file mode 100644 index 000000000..f8e299ea4 --- /dev/null +++ b/sdk_v2/rust/src/detail/mod.rs @@ -0,0 +1,13 @@ +//! Internal implementation detail of the Foundry Local SDK. +//! +//! Nothing in this module is part of the public API. + +pub(crate) mod api; +pub(crate) mod ffi; +pub(crate) mod info; +pub(crate) mod items; +pub(crate) mod manager; +pub(crate) mod model; +pub(crate) mod native; +pub(crate) mod session; +pub(crate) mod task; diff --git a/sdk_v2/rust/src/detail/model.rs b/sdk_v2/rust/src/detail/model.rs new file mode 100644 index 000000000..53712c49d --- /dev/null +++ b/sdk_v2/rust/src/detail/model.rs @@ -0,0 +1,405 @@ +//! Public [`Model`] type backed by catalog-owned native handles. +//! +//! Mirrors the legacy SDK: a `Model` is either a single variant or a group of +//! variants sharing an alias. Selection is tracked Rust-side (an index) and all +//! operations delegate to the selected variant's native handle, so +//! [`Model::info`] / [`Model::id`] return references that always reflect the +//! current selection. + +use std::fmt; +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; +use std::sync::Arc; + +use super::api::Api; +use super::info::build_model_info; +use super::native::NativeModel; +use super::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; +use crate::types::ModelInfo; + +/// One specific variant: its native handle plus the cached, immutable metadata. +#[derive(Clone)] +pub(crate) struct VariantData { + native: NativeModel, + info: ModelInfo, +} + +/// The public model type. +/// +/// A `Model` may represent either a group of variants (as returned by +/// [`Catalog::get_model`](crate::Catalog::get_model)) or a single variant (as +/// returned by [`Catalog::get_model_variant`](crate::Catalog::get_model_variant) +/// or [`Model::variants`]). +pub struct Model { + inner: ModelKind, +} + +type DownloadProgressCallback = Box; + +/// Builder for configuring and running a model download. +/// +/// Use this builder when combining optional settings like progress and cancellation. +pub struct DownloadBuilder<'a> { + model: &'a Model, + progress: Option, + cancel_flag: Option>, +} + +impl<'a> DownloadBuilder<'a> { + fn new(model: &'a Model) -> Self { + Self { + model, + progress: None, + cancel_flag: None, + } + } + + /// Report download progress as a percentage from 0.0 to 100.0. + pub fn progress(mut self, callback: F) -> Self + where + F: FnMut(f64) + Send + 'static, + { + self.progress = Some(Box::new(callback)); + self + } + + /// Cancel the download when `cancel_flag` is set to `true`. + pub fn cancel(mut self, cancel_flag: Arc) -> Self { + self.cancel_flag = Some(cancel_flag); + self + } + + /// Run the configured download. + pub async fn run(self) -> Result<()> { + let native = self.model.selected_variant().native.clone(); + let progress = self.progress; + let cancel_flag = self.cancel_flag; + spawn_blocking(move || native.download(progress, cancel_flag)).await + } +} +#[allow(clippy::large_enum_variant)] +enum ModelKind { + /// A single model variant (from `get_model_variant` or `variants()`). + Variant(VariantData), + /// A group of variants sharing the same alias (from `get_model`). + Group { + alias: String, + variants: Vec, + selected: AtomicUsize, + }, +} + +impl Clone for Model { + fn clone(&self) -> Self { + Self { + inner: match &self.inner { + ModelKind::Variant(v) => ModelKind::Variant(v.clone()), + ModelKind::Group { + alias, + variants, + selected, + } => ModelKind::Group { + alias: alias.clone(), + variants: variants.clone(), + selected: AtomicUsize::new(selected.load(Relaxed)), + }, + }, + } + } +} + +impl fmt::Debug for Model { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + ModelKind::Variant(v) => f + .debug_struct("Model::ModelVariant") + .field("id", &v.info.id) + .field("alias", &v.info.alias) + .finish(), + ModelKind::Group { + alias, + variants, + selected, + } => f + .debug_struct("Model::Model") + .field("alias", alias) + .field("id", &variants[selected.load(Relaxed)].info.id) + .field("variants_count", &variants.len()) + .field("selected_index", &selected.load(Relaxed)) + .finish(), + } + } +} + +// ── Construction (crate-internal) ──────────────────────────────────────────── + +impl Model { + /// Wrap a single leaf variant. + pub(crate) fn from_variant(api: &Arc, native: NativeModel) -> Result { + let info = build_model_info(api, &native)?; + Ok(Self { + inner: ModelKind::Variant(VariantData { native, info }), + }) + } + + /// Wrap an alias-group model, eagerly loading its variants. + pub(crate) fn from_group(api: &Arc, native: NativeModel) -> Result { + let group_info = build_model_info(api, &native)?; + let alias = group_info.alias.clone(); + + let mut variants = Vec::new(); + for variant_native in native.get_variants()? { + let info = build_model_info(api, &variant_native)?; + variants.push(VariantData { + native: variant_native, + info, + }); + } + + // A leaf masquerading as a group: fall back to a single-variant model. + if variants.is_empty() { + return Ok(Self { + inner: ModelKind::Variant(VariantData { + native, + info: group_info, + }), + }); + } + + // Prefer the first cached variant as the initial selection. + let selected = variants.iter().position(|v| v.info.cached).unwrap_or(0); + + Ok(Self { + inner: ModelKind::Group { + alias, + variants, + selected: AtomicUsize::new(selected), + }, + }) + } +} + +// ── Private helpers ────────────────────────────────────────────────────────── + +impl Model { + fn selected_variant(&self) -> &VariantData { + match &self.inner { + ModelKind::Variant(v) => v, + ModelKind::Group { + variants, selected, .. + } => &variants[selected.load(Relaxed)], + } + } + + pub(crate) fn selected_native(&self) -> &NativeModel { + &self.selected_variant().native + } +} + +// ── Public API ─────────────────────────────────────────────────────────────── + +impl Model { + /// Unique identifier of the (selected) variant. + pub fn id(&self) -> &str { + &self.selected_variant().info.id + } + + /// Alias shared by all variants of this model. + pub fn alias(&self) -> &str { + match &self.inner { + ModelKind::Variant(v) => &v.info.alias, + ModelKind::Group { alias, .. } => alias, + } + } + + /// Full catalog metadata for the (selected) variant. + pub fn info(&self) -> &ModelInfo { + &self.selected_variant().info + } + + /// Maximum context length (in tokens), or `None` if unknown. + pub fn context_length(&self) -> Option { + self.selected_variant().info.context_length + } + + /// Comma-separated input modalities (e.g. `"text,image"`), or `None`. + pub fn input_modalities(&self) -> Option<&str> { + self.selected_variant().info.input_modalities.as_deref() + } + + /// Comma-separated output modalities (e.g. `"text"`), or `None`. + pub fn output_modalities(&self) -> Option<&str> { + self.selected_variant().info.output_modalities.as_deref() + } + + /// Capability tags (e.g. `"reasoning"`), or `None`. + pub fn capabilities(&self) -> Option<&str> { + self.selected_variant().info.capabilities.as_deref() + } + + /// Whether the model supports tool/function calling, or `None`. + pub fn supports_tool_calling(&self) -> Option { + self.selected_variant().info.supports_tool_calling + } + + /// Whether the (selected) variant is cached on disk. + pub async fn is_cached(&self) -> Result { + let native = self.selected_native().clone(); + spawn_blocking(move || native.is_cached()).await + } + + /// Whether the (selected) variant is loaded into memory. + pub async fn is_loaded(&self) -> Result { + let native = self.selected_native().clone(); + spawn_blocking(move || native.is_loaded()).await + } + + /// Download the (selected) variant. If `progress` is provided it + /// receives download progress as a percentage (0.0–100.0). + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + let native = self.selected_native().clone(); + let progress: Option = + progress.map(|f| Box::new(f) as DownloadProgressCallback); + spawn_blocking(move || native.download(progress, None)).await + } + + /// Configure and run a model download with a builder. + /// + /// Use this for call sites that need progress, cancellation, or future + /// download options. + pub fn download_builder(&self) -> DownloadBuilder<'_> { + DownloadBuilder::new(self) + } + + /// Return the local file-system path of the (selected) variant. + pub async fn path(&self) -> Result { + let native = self.selected_native().clone(); + let id = self.id().to_owned(); + let path = spawn_blocking(move || native.path()).await?; + match path { + Some(p) => Ok(PathBuf::from(p)), + None => Err(FoundryLocalError::ModelOperation { + reason: format!("Error getting path for model {id}. Has it been downloaded?"), + }), + } + } + + /// Load the (selected) variant into memory. + pub async fn load(&self) -> Result<()> { + let native = self.selected_native().clone(); + spawn_blocking(move || native.load()).await + } + + /// Unload the (selected) variant from memory. + pub async fn unload(&self) -> Result { + let native = self.selected_native().clone(); + spawn_blocking(move || native.unload()).await?; + Ok(String::new()) + } + + /// Remove the (selected) variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + let native = self.selected_native().clone(); + spawn_blocking(move || native.remove_from_cache()).await?; + Ok(String::new()) + } + + /// Create a [`ChatClient`](crate::openai::ChatClient) bound to the (selected) variant. + pub fn create_chat_client(&self) -> crate::openai::ChatClient { + crate::openai::ChatClient::new(self.id(), self.selected_native().clone()) + } + + /// Create an [`AudioClient`](crate::openai::AudioClient) bound to the (selected) variant. + pub fn create_audio_client(&self) -> crate::openai::AudioClient { + crate::openai::AudioClient::new(self.id(), self.selected_native().clone()) + } + + /// Create an [`EmbeddingClient`](crate::openai::EmbeddingClient) bound to the (selected) variant. + pub fn create_embedding_client(&self) -> crate::openai::EmbeddingClient { + crate::openai::EmbeddingClient::new(self.id(), self.selected_native().clone()) + } + + /// Available variants of this model. + /// + /// For a single-variant model (e.g. from + /// [`Catalog::get_model_variant`](crate::Catalog::get_model_variant)), + /// this returns a single-element list containing itself. + pub fn variants(&self) -> Vec> { + match &self.inner { + ModelKind::Variant(v) => { + vec![Arc::new(Model { + inner: ModelKind::Variant(v.clone()), + })] + } + ModelKind::Group { variants, .. } => variants + .iter() + .map(|v| { + Arc::new(Model { + inner: ModelKind::Variant(v.clone()), + }) + }) + .collect(), + } + } + + /// Select a variant to use for subsequent operations. + /// + /// The `variant` must be one of the models returned by [`variants`](Model::variants). + /// + /// # Errors + /// + /// Returns an error if the variant does not belong to this model. + /// For single-variant models this always returns an error — use + /// [`Catalog::get_model`](crate::Catalog::get_model) to obtain a model + /// with all variants available. + pub fn select_variant(&self, variant: &Model) -> Result<()> { + self.select_variant_by_id(variant.id()) + } + + /// Select a variant by its unique id string. + /// + /// This is a convenience method for cases where you have a variant id + /// from an external source. Prefer [`select_variant`](Model::select_variant) + /// when you already have a [`Model`] reference from [`variants`](Model::variants). + /// + /// # Errors + /// + /// Returns an error if no variant with the given id exists. + /// For single-variant models this always returns an error — use + /// [`Catalog::get_model`](crate::Catalog::get_model) to obtain a model + /// with all variants available. + pub fn select_variant_by_id(&self, id: &str) -> Result<()> { + match &self.inner { + ModelKind::Variant(v) => Err(FoundryLocalError::ModelOperation { + reason: format!( + "select_variant is not supported on a single variant. \ + Call Catalog::get_model(\"{}\") to get a model with all variants available.", + v.info.alias + ), + }), + ModelKind::Group { + variants, + selected, + alias, + } => match variants.iter().position(|v| v.info.id == id) { + Some(pos) => { + selected.store(pos, Relaxed); + Ok(()) + } + None => { + let available: Vec<&str> = + variants.iter().map(|v| v.info.id.as_str()).collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Variant '{id}' not found for model '{alias}'. Available: {available:?}", + ), + }) + } + }, + } + } +} diff --git a/sdk_v2/rust/src/detail/native.rs b/sdk_v2/rust/src/detail/native.rs new file mode 100644 index 000000000..b2a362fa0 --- /dev/null +++ b/sdk_v2/rust/src/detail/native.rs @@ -0,0 +1,250 @@ +//! Non-owning wrapper around a native `flModel` handle. +//! +//! `flModel*` handles are owned by the catalog (which lives for the process +//! lifetime via the manager singleton) and are never released individually. A +//! `flModelList*`, by contrast, is owned by the caller: [`collect_models`] +//! eagerly extracts the contained handles and then releases the list. + +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use super::api::{cstr_to_string, Api}; +use super::ffi::*; +use crate::error::{FoundryLocalError, Result}; + +/// A borrowed handle to a catalog-owned `flModel`. +#[derive(Clone)] +pub(crate) struct NativeModel { + pub(crate) api: Arc, + pub(crate) ptr: *mut flModel, +} + +// SAFETY: model handles are owned by the catalog (process-lifetime) and the +// native implementation is thread-safe for independent operations. +unsafe impl Send for NativeModel {} +unsafe impl Sync for NativeModel {} + +impl NativeModel { + pub(crate) fn new(api: Arc, ptr: *mut flModel) -> Self { + Self { api, ptr } + } + + pub(crate) fn info_ptr(&self) -> Result<*const flModelInfo> { + let mut info: *const flModelInfo = std::ptr::null(); + // SAFETY: `ptr` is a valid catalog-owned model handle. + let status = unsafe { (self.api.model_api().GetInfo)(self.ptr, &mut info) }; + self.api.check(status)?; + Ok(info) + } + + pub(crate) fn is_cached(&self) -> Result { + let mut cached: std::os::raw::c_int = 0; + let status = unsafe { (self.api.model_api().IsCached)(self.ptr, &mut cached) }; + self.api.check(status)?; + Ok(cached != 0) + } + + pub(crate) fn is_loaded(&self) -> Result { + let mut loaded: std::os::raw::c_int = 0; + let status = unsafe { (self.api.model_api().IsLoaded)(self.ptr, &mut loaded) }; + self.api.check(status)?; + Ok(loaded != 0) + } + + pub(crate) fn path(&self) -> Result> { + let mut path: *const std::os::raw::c_char = std::ptr::null(); + let status = unsafe { (self.api.model_api().GetPath)(self.ptr, &mut path) }; + self.api.check(status)?; + // SAFETY: `path`, when non-null, points to model-owned storage valid now. + Ok(unsafe { cstr_to_string(path) }) + } + + pub(crate) fn load(&self) -> Result<()> { + let status = unsafe { (self.api.model_api().Load)(self.ptr) }; + self.api.check(status) + } + + pub(crate) fn unload(&self) -> Result<()> { + let status = unsafe { (self.api.model_api().Unload)(self.ptr) }; + self.api.check(status) + } + + pub(crate) fn remove_from_cache(&self) -> Result<()> { + let status = unsafe { (self.api.model_api().RemoveFromCache)(self.ptr) }; + self.api.check(status) + } + + pub(crate) fn get_variants(&self) -> Result> { + let mut list: *mut flModelList = std::ptr::null_mut(); + let status = unsafe { (self.api.model_api().GetVariants)(self.ptr, &mut list) }; + self.api.check(status)?; + Ok(collect_models(&self.api, list)) + } + + /// Download the model, optionally reporting progress (0.0–100.0) and + /// honouring a cancellation flag. Blocking. + pub(crate) fn download( + &self, + progress: Option>, + cancel_flag: Option>, + ) -> Result<()> { + let mut ctx = DownloadCtx { + progress, + cancel_flag, + cancelled: false, + }; + let callback: flProgressCallback = Some(download_trampoline); + let user_data = &mut ctx as *mut DownloadCtx as *mut std::ffi::c_void; + // SAFETY: `ctx` outlives the blocking native call; the trampoline only + // dereferences `user_data` for the duration of the call. + let status = unsafe { (self.api.model_api().Download)(self.ptr, callback, user_data) }; + if ctx.cancelled { + self.api.check(status).ok(); + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".into(), + }); + } + self.api.check(status) + } +} + +struct DownloadCtx { + progress: Option>, + cancel_flag: Option>, + cancelled: bool, +} + +unsafe extern "C" fn download_trampoline( + value: f32, + user_data: *mut std::ffi::c_void, +) -> std::os::raw::c_int { + if user_data.is_null() { + return 0; + } + let result = catch_unwind(AssertUnwindSafe(|| { + let ctx = &mut *(user_data as *mut DownloadCtx); + if ctx + .cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)) + { + ctx.cancelled = true; + return 1; + } + if let Some(cb) = ctx.progress.as_mut() { + cb(value as f64); + } + 0 + })); + result.unwrap_or(1) +} + +/// Eagerly extract all model handles from a `flModelList`, then release the list. +/// +/// The returned handles remain valid for the catalog's lifetime. +pub(crate) fn collect_models(api: &Arc, list: *mut flModelList) -> Vec { + if list.is_null() { + return Vec::new(); + } + let root = api.root(); + // SAFETY: `list` is a valid list handle we own until released below. + let size = unsafe { (root.ModelList_Size)(list) }; + let mut out = Vec::with_capacity(size); + for i in 0..size { + let m = unsafe { (root.ModelList_GetAt)(list, i) }; + if !m.is_null() { + out.push(NativeModel::new(Arc::clone(api), m)); + } + } + unsafe { (root.ModelList_Release)(list) }; + out +} + +/// A borrowed handle to a manager-owned `flCatalog`. +#[derive(Clone)] +pub(crate) struct NativeCatalog { + pub(crate) api: Arc, + pub(crate) ptr: *mut flCatalog, +} + +// SAFETY: the catalog handle is owned by the manager (process-lifetime) and the +// native implementation is thread-safe. +unsafe impl Send for NativeCatalog {} +unsafe impl Sync for NativeCatalog {} + +impl NativeCatalog { + pub(crate) fn new(api: Arc, ptr: *mut flCatalog) -> Self { + Self { api, ptr } + } + + pub(crate) fn name(&self) -> Result { + let mut name: *const std::os::raw::c_char = std::ptr::null(); + let status = unsafe { (self.api.catalog_api().GetName)(self.ptr, &mut name) }; + self.api.check(status)?; + Ok(unsafe { cstr_to_string(name) }.unwrap_or_default()) + } + + fn list( + &self, + f: unsafe extern "system" fn(*const flCatalog, *mut *mut flModelList) -> flStatusPtr, + ) -> Result> { + let mut list: *mut flModelList = std::ptr::null_mut(); + let status = unsafe { f(self.ptr, &mut list) }; + self.api.check(status)?; + Ok(collect_models(&self.api, list)) + } + + pub(crate) fn get_models(&self) -> Result> { + self.list(self.api.catalog_api().GetModels) + } + + pub(crate) fn get_cached_models(&self) -> Result> { + self.list(self.api.catalog_api().GetCachedModels) + } + + pub(crate) fn get_loaded_models(&self) -> Result> { + self.list(self.api.catalog_api().GetLoadedModels) + } + + fn lookup( + &self, + f: unsafe extern "system" fn( + *const flCatalog, + *const std::os::raw::c_char, + *mut *mut flModel, + ) -> flStatusPtr, + key: &str, + ) -> Result> { + let c = super::api::to_cstring(key)?; + let mut model: *mut flModel = std::ptr::null_mut(); + let status = unsafe { f(self.ptr, c.as_ptr(), &mut model) }; + self.api.check(status)?; + if model.is_null() { + Ok(None) + } else { + Ok(Some(NativeModel::new(Arc::clone(&self.api), model))) + } + } + + pub(crate) fn get_model(&self, alias: &str) -> Result> { + self.lookup(self.api.catalog_api().GetModel, alias) + } + + pub(crate) fn get_model_variant(&self, model_id: &str) -> Result> { + self.lookup(self.api.catalog_api().GetModelVariant, model_id) + } + + pub(crate) fn get_latest_version(&self, model: &NativeModel) -> Result { + let mut latest: *mut flModel = std::ptr::null_mut(); + let status = + unsafe { (self.api.catalog_api().GetLatestVersion)(self.ptr, model.ptr, &mut latest) }; + self.api.check(status)?; + if latest.is_null() { + return Err(FoundryLocalError::ModelOperation { + reason: "Catalog returned no latest version for the model.".into(), + }); + } + Ok(NativeModel::new(Arc::clone(&self.api), latest)) + } +} diff --git a/sdk_v2/rust/src/detail/session.rs b/sdk_v2/rust/src/detail/session.rs new file mode 100644 index 000000000..53be90965 --- /dev/null +++ b/sdk_v2/rust/src/detail/session.rs @@ -0,0 +1,311 @@ +//! Safe wrappers over native `flRequest`, `flResponse`, and `flSession`, plus the +//! OpenAI-JSON request/response and streaming bridges used by the OpenAI facade. + +use std::os::raw::c_int; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::ptr; +use std::sync::Arc; + +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; + +use super::api::Api; +use super::ffi::*; +use super::items::{make_bytes_item, make_openai_json_item, read_text_item}; +use super::native::NativeModel; +use crate::error::{FoundryLocalError, Result}; + +/// Per-item transform applied to streamed TEXT payloads before they are emitted. +pub(crate) type StreamTransform = Box Option + Send>; + +// ── Request ────────────────────────────────────────────────────────────────── + +pub(crate) struct NativeRequest { + api: Arc, + ptr: *mut flRequest, +} + +impl NativeRequest { + pub(crate) fn new(api: Arc) -> Result { + let mut ptr: *mut flRequest = ptr::null_mut(); + api.check(unsafe { (api.inference_api().Request_Create)(&mut ptr) })?; + Ok(Self { api, ptr }) + } + + /// Add an item, transferring ownership to the request. + pub(crate) fn add_item(&self, item: *mut flItem, take_ownership: bool) -> Result<()> { + let status = + unsafe { (self.api.inference_api().Request_AddItem)(self.ptr, item, take_ownership) }; + self.api.check(status) + } +} + +impl Drop for NativeRequest { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { (self.api.inference_api().Request_Release)(self.ptr) }; + self.ptr = ptr::null_mut(); + } + } +} + +// ── Response ───────────────────────────────────────────────────────────────── + +pub(crate) struct NativeResponse { + api: Arc, + ptr: *mut flResponse, +} + +impl NativeResponse { + pub(crate) fn item_count(&self) -> usize { + unsafe { (self.api.inference_api().Response_GetItemCount)(self.ptr) } + } + + /// Read the text payload of the response item at `idx` (if it is a TEXT item). + pub(crate) fn item_text(&self, idx: usize) -> Option { + let mut item: *const flItem = ptr::null(); + let status = + unsafe { (self.api.inference_api().Response_GetItem)(self.ptr, idx, &mut item) }; + if self.api.check(status).is_err() { + return None; + } + unsafe { read_text_item(&self.api, item) } + } +} + +impl Drop for NativeResponse { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { (self.api.inference_api().Response_Release)(self.ptr) }; + self.ptr = ptr::null_mut(); + } + } +} + +// ── ItemQueue ──────────────────────────────────────────────────────────────── + +/// Owning wrapper around a native input `flItemQueue`. +/// +/// In the C ABI an `ItemQueue` *is* an `Item` (same pointer, castable), so it can +/// be added to a request directly and released via `Item_Release`. +pub(crate) struct NativeItemQueue { + api: Arc, + ptr: *mut flItemQueue, +} + +// SAFETY: the native item queue is documented as thread-safe (multi-producer / +// multi-consumer); pushing from any thread is supported. +unsafe impl Send for NativeItemQueue {} +unsafe impl Sync for NativeItemQueue {} + +impl NativeItemQueue { + pub(crate) fn new(api: Arc) -> Result { + let mut ptr: *mut flItemQueue = ptr::null_mut(); + api.check(unsafe { (api.item_api().ItemQueue_Create)(&mut ptr) })?; + Ok(Self { api, ptr }) + } + + /// The queue as an `flItem*` (for adding to a request). + pub(crate) fn as_item_ptr(&self) -> *mut flItem { + self.ptr as *mut flItem + } + + /// Push an item, transferring ownership into the queue. + pub(crate) fn push_item(&self, item: *mut flItem) -> Result<()> { + self.api + .check(unsafe { (self.api.item_api().ItemQueue_Push)(self.ptr, item) }) + } + + /// Create a BYTES item from `data` and push it into the queue. + pub(crate) fn push_bytes(&self, data: &[u8], item_type: flItemType) -> Result<()> { + let item = make_bytes_item(&self.api, data, item_type)?; + if let Err(e) = self.push_item(item) { + // Push failed — we still own the item, so release it. + unsafe { (self.api.item_api().Item_Release)(item) }; + return Err(e); + } + Ok(()) + } + + /// Signal that no more items will be pushed. + pub(crate) fn mark_finished(&self) { + unsafe { (self.api.item_api().ItemQueue_MarkFinished)(self.ptr) }; + } +} + +impl Drop for NativeItemQueue { + fn drop(&mut self) { + if !self.ptr.is_null() { + // The queue is an Item; release via the polymorphic Item destructor. + unsafe { (self.api.item_api().Item_Release)(self.ptr as *mut flItem) }; + self.ptr = ptr::null_mut(); + } + } +} + +// ── Session ────────────────────────────────────────────────────────────────── + +pub(crate) struct NativeSession { + pub(crate) api: Arc, + ptr: *mut flSession, +} + +// SAFETY: a session is used from a single worker at a time; the native layer is +// thread-safe for the create/process/release lifecycle used here. +unsafe impl Send for NativeSession {} +unsafe impl Sync for NativeSession {} + +impl NativeSession { + /// Create a session bound to the given model variant. + pub(crate) fn create(model: &NativeModel) -> Result { + let api = Arc::clone(&model.api); + let mut ptr: *mut flSession = ptr::null_mut(); + api.check(unsafe { (api.inference_api().Session_Create)(model.ptr, &mut ptr) })?; + Ok(Self { api, ptr }) + } + + pub(crate) fn set_streaming_callback( + &self, + callback: flStreamingCallback, + user_data: *mut std::ffi::c_void, + ) -> Result<()> { + let status = unsafe { + (self.api.inference_api().Session_SetStreamingCallback)(self.ptr, callback, user_data) + }; + self.api.check(status) + } + + pub(crate) fn process_request(&self, request: &NativeRequest) -> Result { + let mut resp: *mut flResponse = ptr::null_mut(); + let status = unsafe { + (self.api.inference_api().Session_ProcessRequest)(self.ptr, request.ptr, &mut resp) + }; + self.api.check(status)?; + Ok(NativeResponse { + api: Arc::clone(&self.api), + ptr: resp, + }) + } + + /// Run a non-streaming OpenAI-JSON request and return the response payload. + /// + /// The request JSON is sent as a single `OPENAI_JSON` TEXT item; the response + /// payload is the text of the first response item. Blocking. + pub(crate) fn run_openai_json(&self, request_json: &str) -> Result { + let request = NativeRequest::new(Arc::clone(&self.api))?; + let item = make_openai_json_item(&self.api, request_json)?; + request.add_item(item, true)?; + let response = self.process_request(&request)?; + if response.item_count() == 0 { + return Err(FoundryLocalError::CommandExecution { + reason: "Native response contained no items".into(), + }); + } + response + .item_text(0) + .ok_or_else(|| FoundryLocalError::CommandExecution { + reason: "Native response item was not readable text".into(), + }) + } +} + +impl Drop for NativeSession { + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { (self.api.inference_api().Session_Release)(self.ptr) }; + self.ptr = ptr::null_mut(); + } + } +} + +// ── Streaming bridge ───────────────────────────────────────────────────────── + +struct StreamCtx { + api: Arc, + tx: UnboundedSender>, + transform: StreamTransform, +} + +unsafe extern "C" fn stream_trampoline( + data: flStreamingCallbackData, + user_data: *mut std::ffi::c_void, +) -> c_int { + if user_data.is_null() { + return 0; + } + let result = catch_unwind(AssertUnwindSafe(|| { + let ctx = &*(user_data as *const StreamCtx); + let queue = data.item_queue; + if queue.is_null() { + return 0; + } + let item_api = ctx.api.item_api(); + loop { + let mut item: *mut flItem = ptr::null_mut(); + let popped = (item_api.ItemQueue_TryPop)(queue, &mut item); + if !popped { + break; + } + if item.is_null() { + continue; + } + // Ownership of `item` transferred to us — read then release. + let text = read_text_item(&ctx.api, item); + (item_api.Item_Release)(item); + if let Some(text) = text { + if let Some(transformed) = (ctx.transform)(text) { + if ctx.tx.send(Ok(transformed)).is_err() { + return 1; // receiver dropped — cancel generation + } + } + } + } + 0 + })); + result.unwrap_or(1) +} + +/// Run a streaming OpenAI-JSON request, returning a channel of transformed +/// per-item TEXT payloads. +/// +/// `transform` is applied to each streamed item's text (return `None` to skip +/// an item). The session is created and processed on a blocking worker thread; +/// the channel closes when generation completes or errors. +pub(crate) fn run_openai_json_streaming( + session: NativeSession, + request_json: String, + transform: StreamTransform, +) -> UnboundedReceiver> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); + + tokio::task::spawn_blocking(move || { + let ctx = Box::new(StreamCtx { + api: Arc::clone(&session.api), + tx: tx.clone(), + transform, + }); + let ctx_ptr = &*ctx as *const StreamCtx as *mut std::ffi::c_void; + + if let Err(e) = session.set_streaming_callback(Some(stream_trampoline), ctx_ptr) { + let _ = tx.send(Err(e)); + return; + } + + let run = (|| -> Result<()> { + let request = NativeRequest::new(Arc::clone(&session.api))?; + let item = make_openai_json_item(&session.api, &request_json)?; + request.add_item(item, true)?; + let _response = session.process_request(&request)?; + Ok(()) + })(); + if let Err(e) = run { + let _ = tx.send(Err(e)); + } + + // Uninstall the callback before the context/session are dropped. + let _ = session.set_streaming_callback(None, ptr::null_mut()); + drop(ctx); + drop(session); + }); + + rx +} diff --git a/sdk_v2/rust/src/detail/task.rs b/sdk_v2/rust/src/detail/task.rs new file mode 100644 index 000000000..36b5679e6 --- /dev/null +++ b/sdk_v2/rust/src/detail/task.rs @@ -0,0 +1,16 @@ +//! Shared async helper for running blocking native calls. + +use crate::error::{FoundryLocalError, Result}; + +/// Run a blocking native operation on the tokio blocking pool. +pub(crate) async fn spawn_blocking(f: F) -> Result +where + F: FnOnce() -> Result + Send + 'static, + T: Send + 'static, +{ + tokio::task::spawn_blocking(f) + .await + .map_err(|e| FoundryLocalError::Internal { + reason: format!("blocking task join error: {e}"), + })? +} diff --git a/sdk_v2/rust/src/error.rs b/sdk_v2/rust/src/error.rs new file mode 100644 index 000000000..c99dbfbf1 --- /dev/null +++ b/sdk_v2/rust/src/error.rs @@ -0,0 +1,36 @@ +use thiserror::Error; + +/// Errors that can occur when using the Foundry Local SDK. +#[derive(Debug, Error)] +pub enum FoundryLocalError { + /// The native core library could not be loaded. + #[error("library load error: {reason}")] + LibraryLoad { reason: String }, + /// A command executed against the native core returned an error. + #[error("command execution error: {reason}")] + CommandExecution { reason: String }, + /// The provided configuration is invalid. + #[error("invalid configuration: {reason}")] + InvalidConfiguration { reason: String }, + /// A model operation failed (load, unload, download, etc.). + #[error("model operation error: {reason}")] + ModelOperation { reason: String }, + /// An HTTP request to the external service failed. + #[error("HTTP request error: {0}")] + HttpRequest(#[from] reqwest::Error), + /// Serialization or deserialization of JSON data failed. + #[error("serialization error: {0}")] + Serialization(#[from] serde_json::Error), + /// A validation check on user-supplied input failed. + #[error("validation error: {reason}")] + Validation { reason: String }, + /// An I/O error occurred. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + /// An internal SDK error (e.g. poisoned lock). + #[error("internal error: {reason}")] + Internal { reason: String }, +} + +/// Convenience alias used throughout the SDK. +pub type Result = std::result::Result; diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs new file mode 100644 index 000000000..0ccd3b656 --- /dev/null +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -0,0 +1,351 @@ +//! Top-level entry point for the Foundry Local SDK. +//! +//! [`FoundryLocalManager`] is a singleton that initialises the native core +//! library, provides access to the model [`Catalog`], and can start / stop +//! the local web service. + +use std::sync::atomic::AtomicBool; +use std::sync::{Arc, Mutex, OnceLock}; + +use crate::catalog::Catalog; +use crate::configuration::{FoundryLocalConfig, Logger}; +use crate::detail::api::Api; +use crate::detail::manager::{EpProgressCallback, NativeManager}; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; +use crate::types::{EpDownloadResult, EpInfo}; + +/// Global singleton holder — only stores a successfully initialised manager. +static INSTANCE: OnceLock = OnceLock::new(); +/// Guard to ensure only one thread attempts initialisation at a time. +static INIT_GUARD: Mutex<()> = Mutex::new(()); + +/// Primary entry point for interacting with Foundry Local. +/// +/// Created once via [`FoundryLocalManager::create`]; subsequent calls return +/// the existing instance. +pub struct FoundryLocalManager { + native: Arc, + catalog: Catalog, + urls: Mutex>, + /// Application logger (stub — not yet wired into the native core). + _logger: Option>, +} + +type EpDownloadProgressCallback = Box; + +/// Builder for configuring and running execution provider downloads. +pub struct EpDownloadBuilder<'a> { + manager: &'a FoundryLocalManager, + names: Option>, + progress_callback: Option, + cancel_flag: Option>, +} + +impl<'a> EpDownloadBuilder<'a> { + fn new(manager: &'a FoundryLocalManager) -> Self { + Self { + manager, + names: None, + progress_callback: None, + cancel_flag: None, + } + } + + /// Download only the named execution providers. + pub fn names(mut self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.names = Some(names.into_iter().map(Into::into).collect()); + self + } + + /// Report per-EP download progress as `(ep_name, percent)`. + pub fn progress(mut self, callback: F) -> Self + where + F: FnMut(&str, f64) + Send + 'static, + { + self.progress_callback = Some(Box::new(callback)); + self + } + + /// Cancel the download when `cancel_flag` is set to `true`. + pub fn cancel(mut self, cancel_flag: Arc) -> Self { + self.cancel_flag = Some(cancel_flag); + self + } + + /// Run the configured execution provider download. + pub async fn run(self) -> Result { + self.manager + .download_and_register_eps_impl(self.names, self.progress_callback, self.cancel_flag) + .await + } +} + +impl FoundryLocalManager { + /// Initialise the SDK. + /// + /// The first call creates the singleton, loads the native library, runs + /// the initialisation, and builds the model catalog. Subsequent calls + /// return a reference to the same instance (the provided config is + /// ignored after the first call). + pub fn create(config: FoundryLocalConfig) -> Result<&'static Self> { + // Fast path: singleton already initialised. + if let Some(manager) = INSTANCE.get() { + return Ok(manager); + } + + // Slow path: acquire init guard so only one thread attempts initialisation. + let _guard = INIT_GUARD.lock().map_err(|_| FoundryLocalError::Internal { + reason: "initialisation guard poisoned".into(), + })?; + + // Double-check after acquiring the lock. + if let Some(manager) = INSTANCE.get() { + return Ok(manager); + } + + let mut config = config; + let api = Arc::new(Api::load(config.library_path_ref())?); + let logger = config.take_logger(); + let native_config = config.build_native(&api)?; + + let native = Arc::new(NativeManager::create( + Arc::clone(&api), + native_config.as_ptr(), + )?); + // `native_config` is dropped here; Manager_Create has copied what it needs. + + let catalog_ptr = native.catalog_ptr()?; + let catalog = Catalog::new(Arc::clone(&api), catalog_ptr)?; + + let manager = FoundryLocalManager { + native, + catalog, + urls: Mutex::new(Vec::new()), + _logger: logger, + }; + + // Only cache on success — failures allow the next caller to retry. + match INSTANCE.set(manager) { + Ok(()) => { + // Register a process-exit hook to release the native manager + // before the library's C++ static destructors run. Without + // this, the native dtor chain (Manager -> logger -> spdlog + // flush) can fire after spdlog's global thread pool is gone, + // raising `mutex lock failed` and aborting the process. The + // hook mirrors the Python SDK's `atexit` teardown. + register_exit_teardown(); + Ok(INSTANCE.get().unwrap()) + } + // Another thread beat us — return their instance. + Err(_) => Ok(INSTANCE.get().unwrap()), + } + } + + /// Access the model catalog. + pub fn catalog(&self) -> &Catalog { + &self.catalog + } + + /// Begin a graceful shutdown of the local engine. + /// + /// Stops the web service, prevents new model loads, stops existing + /// sessions, and unloads models. Idempotent and safe to call from any + /// thread. + /// + /// Calling this is optional: the native manager is released automatically + /// at process exit. Use it when you want to deterministically wind the + /// engine down before exiting. After calling `shutdown`, the manager should + /// not be used for further inference. + pub fn shutdown(&self) -> Result<()> { + self.native.shutdown() + } + + /// URLs that the local web service is listening on. + /// + /// Empty until [`Self::start_web_service`] has been called. + pub fn urls(&self) -> Result> { + let lock = self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })?; + Ok(lock.clone()) + } + + /// Start the local web service. + /// + /// The listening URLs are stored internally and can be retrieved via + /// [`Self::urls`] after this method returns. + pub async fn start_web_service(&self) -> Result<()> { + let native = Arc::clone(&self.native); + let urls = spawn_blocking(move || { + native.web_service_start()?; + native.web_service_urls() + }) + .await?; + *self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })? = urls; + Ok(()) + } + + /// Stop the local web service. + pub async fn stop_web_service(&self) -> Result<()> { + let native = Arc::clone(&self.native); + spawn_blocking(move || native.web_service_stop()).await?; + self.urls + .lock() + .map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })? + .clear(); + Ok(()) + } + + /// Discover available execution providers and their registration status. + pub fn discover_eps(&self) -> Result> { + self.native.discover_eps() + } + + /// Download and register execution providers. + /// + /// If `names` is `None` or empty, all available EPs are downloaded. + /// Otherwise only the named EPs are downloaded and registered. + pub async fn download_and_register_eps( + &self, + names: Option<&[&str]>, + ) -> Result { + let names = names.map(|n| n.iter().map(|s| s.to_string()).collect::>()); + self.download_and_register_eps_impl(names, None, None).await + } + + /// Download and register execution providers, reporting per-EP progress. + /// + /// If `names` is `None` or empty, all available EPs are downloaded. + /// Otherwise only the named EPs are downloaded and registered. + /// + /// `progress_callback` receives `(ep_name, percent)` where `percent` + /// ranges from 0.0 to 100.0 as each EP downloads. + pub async fn download_and_register_eps_with_progress( + &self, + names: Option<&[&str]>, + progress_callback: F, + ) -> Result + where + F: FnMut(&str, f64) + Send + 'static, + { + let names = names.map(|n| n.iter().map(|s| s.to_string()).collect::>()); + self.download_and_register_eps_impl(names, Some(Box::new(progress_callback)), None) + .await + } + + /// Configure and run execution provider downloads with a builder. + /// + /// Use this for call sites that need names, progress, cancellation, or + /// future download options. + pub fn download_and_register_eps_builder(&self) -> EpDownloadBuilder<'_> { + EpDownloadBuilder::new(self) + } + + async fn download_and_register_eps_impl( + &self, + names: Option>, + progress_callback: Option, + cancel_flag: Option>, + ) -> Result { + let native = Arc::clone(&self.native); + + // Snapshot requested EP names (default: all discoverable). + let requested: Vec = match &names { + Some(n) if !n.is_empty() => n.clone(), + _ => native + .discover_eps() + .unwrap_or_default() + .into_iter() + .map(|e| e.name) + .collect(), + }; + + let (message, after) = spawn_blocking(move || { + let name_refs: Option> = names + .as_ref() + .map(|n| n.iter().map(String::as_str).collect()); + let progress: Option = + progress_callback.map(|cb| cb as EpProgressCallback); + let message = + native.download_and_register_eps(name_refs.as_deref(), progress, cancel_flag); + // Re-query registration state to synthesise the per-EP result. + let after = native.discover_eps().unwrap_or_default(); + Ok::<(Option, Vec), FoundryLocalError>((message, after)) + }) + .await?; + + let registered_eps: Vec = requested + .iter() + .filter(|name| after.iter().any(|e| &e.name == *name && e.is_registered)) + .cloned() + .collect(); + let failed_eps: Vec = requested + .iter() + .filter(|name| !registered_eps.contains(*name)) + .cloned() + .collect(); + + let success = message.is_none() && failed_eps.is_empty(); + let status = match &message { + None => "All requested execution providers were registered successfully.".to_string(), + Some(msg) if msg.is_empty() => { + "One or more execution providers failed to register.".to_string() + } + Some(msg) => msg.clone(), + }; + + let result = EpDownloadResult { + success, + status, + registered_eps, + failed_eps, + }; + + // Invalidate the catalog cache if any EP was newly registered so the next + // access re-fetches models with the updated set of available EPs. + if result.success || !result.registered_eps.is_empty() { + let _ = self.catalog.update_models().await; + } + + Ok(result) + } +} + +/// Register the process-exit teardown hook exactly once. +/// +/// Uses the C runtime's `atexit`, which runs registered handlers in LIFO order. +/// Because the `foundry_local` library is `dlopen`ed during `create()` (before +/// this registration), its static destructors are registered earlier and +/// therefore run *after* our hook — giving us the window to release the native +/// manager while the engine's globals (e.g. the spdlog thread pool) are still +/// alive. +fn register_exit_teardown() { + extern "C" { + fn atexit(cb: extern "C" fn()) -> std::os::raw::c_int; + } + // SAFETY: `exit_teardown` is a valid `extern "C"` function with no captured + // state; registering it with the C runtime is sound. + unsafe { + atexit(exit_teardown); + } +} + +/// `atexit` callback: release the singleton's native manager before the +/// library's C++ static destructors run. Panic-safe (a panic must never unwind +/// across the C runtime boundary). +extern "C" fn exit_teardown() { + let _ = std::panic::catch_unwind(|| { + if let Some(manager) = INSTANCE.get() { + manager.native.teardown(); + } + }); +} diff --git a/sdk_v2/rust/src/lib.rs b/sdk_v2/rust/src/lib.rs new file mode 100644 index 000000000..73c4180a0 --- /dev/null +++ b/sdk_v2/rust/src/lib.rs @@ -0,0 +1,45 @@ +//! Foundry Local Rust SDK +//! +//! Local AI model inference powered by the Foundry Local Core engine. + +mod catalog; +mod configuration; +mod error; +mod foundry_local_manager; +mod types; + +pub(crate) mod detail; +pub mod openai; + +pub use self::catalog::Catalog; +pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; +pub use self::detail::model::{DownloadBuilder, Model}; +pub use self::error::FoundryLocalError; +pub use self::foundry_local_manager::{EpDownloadBuilder, FoundryLocalManager}; +pub use self::types::{ + ChatResponseFormat, ChatToolChoice, DeviceType, EpDownloadResult, EpInfo, ModelInfo, + ModelSettings, Parameter, PromptTemplate, Runtime, +}; + +// Re-export OpenAI request types so callers can construct typed messages. +pub use async_openai::types::chat::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, + ChatCompletionToolChoiceOption, ChatCompletionTools, FunctionObject, +}; + +// Re-export OpenAI response types for convenience. +pub use crate::openai::{ + AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, ContentPart, + CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse, + LiveAudioTranscriptionSession, LiveAudioTranscriptionStream, TranscriptionSegment, + TranscriptionWord, +}; +pub use async_openai::types::chat::{ + ChatChoice, ChatChoiceStream, ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallChunk, ChatCompletionMessageToolCalls, + ChatCompletionResponseMessage, ChatCompletionStreamResponseDelta, CompletionUsage, + CreateChatCompletionResponse, CreateChatCompletionStreamResponse, FinishReason, FunctionCall, + FunctionCallStream, +}; diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs new file mode 100644 index 000000000..ed631b7ed --- /dev/null +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -0,0 +1,210 @@ +//! OpenAI-compatible audio transcription client. + +use std::path::Path; + +use serde_json::{json, Value}; + +use crate::detail::native::NativeModel; +use crate::detail::session::{run_openai_json_streaming, NativeSession}; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; + +use super::json_stream::JsonStream; +use super::live_audio_session::LiveAudioTranscriptionSession; + +/// A segment of a transcription, as returned by the OpenAI-compatible API. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct TranscriptionSegment { + /// Segment index. + pub id: i32, + /// Seek offset of the segment. + pub seek: i32, + /// Start time of the segment in seconds. + pub start: f64, + /// End time of the segment in seconds. + pub end: f64, + /// Transcribed text of the segment. + pub text: String, + /// Token IDs corresponding to the text. + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens: Option>, + /// Temperature used for generation. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + /// Average log probability of the segment. + #[serde(skip_serializing_if = "Option::is_none")] + pub avg_logprob: Option, + /// Compression ratio of the segment. + #[serde(skip_serializing_if = "Option::is_none")] + pub compression_ratio: Option, + /// Probability of no speech in the segment. + #[serde(skip_serializing_if = "Option::is_none")] + pub no_speech_prob: Option, +} + +/// A word with timing information, as returned by the OpenAI-compatible API. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct TranscriptionWord { + /// The word text. + pub word: String, + /// Start time of the word in seconds. + pub start: f64, + /// End time of the word in seconds. + pub end: f64, +} + +/// OpenAI-compatible audio transcription response. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct AudioTranscriptionResponse { + /// The transcribed text. + pub text: String, + /// The language of the input audio (if detected). + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + /// Duration of the input audio in seconds (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// Segments of the transcription (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, + /// Words with timestamps (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, +} + +/// Tuning knobs for audio transcription requests. +/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let client = model.create_audio_client() +/// .language("en") +/// .temperature(0.2); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct AudioClientSettings { + language: Option, + temperature: Option, +} + +impl AudioClientSettings { + fn serialize(&self, model_id: &str, file_name: &str) -> Value { + let mut map = serde_json::Map::new(); + + map.insert("model".into(), json!(model_id)); + map.insert("filename".into(), json!(file_name)); + + if let Some(ref lang) = self.language { + map.insert("language".into(), json!(lang)); + } + if let Some(temp) = self.temperature { + map.insert("temperature".into(), json!(temp)); + } + + Value::Object(map) + } +} + +/// A stream of [`AudioTranscriptionResponse`] chunks. +/// +/// Returned by [`AudioClient::transcribe_streaming`]. +pub type AudioTranscriptionStream = JsonStream; + +/// Client for OpenAI-compatible audio transcription backed by a local model. +pub struct AudioClient { + model_id: String, + model: NativeModel, + settings: AudioClientSettings, +} + +impl AudioClient { + pub(crate) fn new(model_id: &str, model: NativeModel) -> Self { + Self { + model_id: model_id.to_owned(), + model, + settings: AudioClientSettings::default(), + } + } + + /// Set the language hint for transcription. + pub fn language(mut self, lang: impl Into) -> Self { + self.settings.language = Some(lang.into()); + self + } + + /// Set the sampling temperature. + pub fn temperature(mut self, v: f64) -> Self { + self.settings.temperature = Some(v); + self + } + + /// Transcribe an audio file. + pub async fn transcribe( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; + Self::validate_path(path_str)?; + + let request = self.settings.serialize(&self.model_id, path_str); + let request_json = serde_json::to_string(&request)?; + let model = self.model.clone(); + + let raw = spawn_blocking(move || { + let session = NativeSession::create(&model)?; + session.run_openai_json(&request_json) + }) + .await?; + + let parsed: AudioTranscriptionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Transcribe an audio file with streaming results, returning an + /// [`AudioTranscriptionStream`]. + pub async fn transcribe_streaming( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; + Self::validate_path(path_str)?; + + let request = self.settings.serialize(&self.model_id, path_str); + let request_json = serde_json::to_string(&request)?; + let model = self.model.clone(); + + let session = spawn_blocking(move || NativeSession::create(&model)).await?; + let rx = run_openai_json_streaming(session, request_json, Box::new(Some)); + Ok(AudioTranscriptionStream::new(rx)) + } + + /// Create a [`LiveAudioTranscriptionSession`] for real-time audio + /// streaming transcription. + /// + /// Configure the session's [`settings`](LiveAudioTranscriptionSession::settings) + /// before calling [`start`](LiveAudioTranscriptionSession::start). + pub fn create_live_transcription_session(&self) -> LiveAudioTranscriptionSession { + LiveAudioTranscriptionSession::new(&self.model_id, self.model.clone()) + } + + fn validate_path(path: &str) -> Result<()> { + if path.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "audio_file_path must be a non-empty string".into(), + }); + } + Ok(()) + } +} diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs new file mode 100644 index 000000000..ca779e92c --- /dev/null +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -0,0 +1,317 @@ +//! OpenAI-compatible chat completions client. + +use std::collections::HashMap; + +use async_openai::types::chat::{ + ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +}; +use serde_json::{json, Value}; + +use crate::detail::native::NativeModel; +use crate::detail::session::{run_openai_json_streaming, NativeSession}; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; +use crate::types::{ChatResponseFormat, ChatToolChoice}; + +use super::json_stream::JsonStream; + +/// Tuning knobs for chat completion requests. +/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let client = model.create_chat_client() +/// .temperature(0.7) +/// .max_tokens(256); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ChatClientSettings { + frequency_penalty: Option, + max_tokens: Option, + n: Option, + temperature: Option, + presence_penalty: Option, + top_p: Option, + top_k: Option, + random_seed: Option, + response_format: Option, + tool_choice: Option, +} + +impl ChatClientSettings { + fn serialize(&self) -> Value { + let mut map = serde_json::Map::new(); + + if let Some(v) = self.frequency_penalty { + map.insert("frequency_penalty".into(), json!(v)); + } + if let Some(v) = self.max_tokens { + map.insert("max_tokens".into(), json!(v)); + } + if let Some(v) = self.n { + map.insert("n".into(), json!(v)); + } + if let Some(v) = self.presence_penalty { + map.insert("presence_penalty".into(), json!(v)); + } + if let Some(v) = self.temperature { + map.insert("temperature".into(), json!(v)); + } + if let Some(v) = self.top_p { + map.insert("top_p".into(), json!(v)); + } + + if let Some(ref rf) = self.response_format { + let mut rf_map = serde_json::Map::new(); + match rf { + ChatResponseFormat::Text => { + rf_map.insert("type".into(), json!("text")); + } + ChatResponseFormat::JsonObject => { + rf_map.insert("type".into(), json!("json_object")); + } + ChatResponseFormat::JsonSchema(schema) => { + rf_map.insert("type".into(), json!("json_schema")); + rf_map.insert("jsonSchema".into(), json!(schema)); + } + ChatResponseFormat::LarkGrammar(grammar) => { + rf_map.insert("type".into(), json!("lark_grammar")); + rf_map.insert("larkGrammar".into(), json!(grammar)); + } + } + map.insert("response_format".into(), Value::Object(rf_map)); + } + + if let Some(ref tc) = self.tool_choice { + let mut tc_map = serde_json::Map::new(); + match tc { + ChatToolChoice::None => { + tc_map.insert("type".into(), json!("none")); + } + ChatToolChoice::Auto => { + tc_map.insert("type".into(), json!("auto")); + } + ChatToolChoice::Required => { + tc_map.insert("type".into(), json!("required")); + } + ChatToolChoice::Function(name) => { + tc_map.insert("type".into(), json!("function")); + tc_map.insert("name".into(), json!(name)); + } + } + map.insert("tool_choice".into(), Value::Object(tc_map)); + } + + // Foundry-specific metadata for settings that don't map directly to + // the OpenAI spec. + let mut metadata: HashMap = HashMap::new(); + if let Some(k) = self.top_k { + metadata.insert("top_k".into(), k.to_string()); + } + if let Some(s) = self.random_seed { + metadata.insert("random_seed".into(), s.to_string()); + } + if !metadata.is_empty() { + map.insert("metadata".into(), json!(metadata)); + } + + Value::Object(map) + } +} + +/// A stream of [`CreateChatCompletionStreamResponse`] chunks. +/// +/// Returned by [`ChatClient::complete_streaming_chat`]. +pub type ChatCompletionStream = JsonStream; + +/// Client for OpenAI-compatible chat completions backed by a local model. +pub struct ChatClient { + model_id: String, + model: NativeModel, + settings: ChatClientSettings, +} + +impl ChatClient { + pub(crate) fn new(model_id: &str, model: NativeModel) -> Self { + Self { + model_id: model_id.to_owned(), + model, + settings: ChatClientSettings::default(), + } + } + + /// Set the frequency penalty. + pub fn frequency_penalty(mut self, v: f64) -> Self { + self.settings.frequency_penalty = Some(v); + self + } + + /// Set the maximum number of tokens to generate. + pub fn max_tokens(mut self, v: u32) -> Self { + self.settings.max_tokens = Some(v); + self + } + + /// Set the number of completions to generate. + pub fn n(mut self, v: u32) -> Self { + self.settings.n = Some(v); + self + } + + /// Set the sampling temperature. + pub fn temperature(mut self, v: f64) -> Self { + self.settings.temperature = Some(v); + self + } + + /// Set the presence penalty. + pub fn presence_penalty(mut self, v: f64) -> Self { + self.settings.presence_penalty = Some(v); + self + } + + /// Set the nucleus sampling probability. + pub fn top_p(mut self, v: f64) -> Self { + self.settings.top_p = Some(v); + self + } + + /// Set the top-k sampling parameter (Foundry extension). + pub fn top_k(mut self, v: u32) -> Self { + self.settings.top_k = Some(v); + self + } + + /// Set the random seed for reproducible results (Foundry extension). + pub fn random_seed(mut self, v: u64) -> Self { + self.settings.random_seed = Some(v); + self + } + + /// Set the desired response format. + pub fn response_format(mut self, v: ChatResponseFormat) -> Self { + self.settings.response_format = Some(v); + self + } + + /// Set the tool choice strategy. + pub fn tool_choice(mut self, v: ChatToolChoice) -> Self { + self.settings.tool_choice = Some(v); + self + } + + /// Perform a non-streaming chat completion. + pub async fn complete_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); + } + + let request = self.build_request(messages, tools, false)?; + let request_json = serde_json::to_string(&request)?; + let model = self.model.clone(); + + let raw = spawn_blocking(move || { + let session = NativeSession::create(&model)?; + session.run_openai_json(&request_json) + }) + .await?; + + let parsed: CreateChatCompletionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Perform a streaming chat completion, returning a [`ChatCompletionStream`]. + /// + /// Use the stream with `futures_core::StreamExt::next()` or + /// `tokio_stream::StreamExt::next()`. + pub async fn complete_streaming_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); + } + + let request = self.build_request(messages, tools, true)?; + let request_json = serde_json::to_string(&request)?; + let model = self.model.clone(); + + let session = spawn_blocking(move || NativeSession::create(&model)).await?; + let rx = run_openai_json_streaming(session, request_json, Box::new(normalize_chat_chunk)); + Ok(ChatCompletionStream::new(rx)) + } + + fn build_request( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + stream: bool, + ) -> Result { + let settings_value = self.settings.serialize(); + let mut map = match settings_value { + Value::Object(m) => m, + _ => serde_json::Map::new(), + }; + + map.insert("model".into(), json!(self.model_id)); + map.insert("messages".into(), serde_json::to_value(messages)?); + + if stream { + map.insert("stream".into(), json!(true)); + } + + if let Some(t) = tools { + map.insert("tools".into(), serde_json::to_value(t)?); + } + + Ok(Value::Object(map)) + } +} + +/// Normalize a streamed chat chunk so it parses as a +/// [`CreateChatCompletionStreamResponse`]. +/// +/// Foundry Local streams tool calls under `"message"` instead of the standard +/// `"delta"`; rewrite each such choice and ensure tool calls carry an `index`. +/// Chunks that are not valid JSON are passed through unchanged so the stream +/// surfaces the original parse error. +fn normalize_chat_chunk(text: String) -> Option { + let mut value: Value = match serde_json::from_str(&text) { + Ok(v) => v, + Err(_) => return Some(text), + }; + + if let Some(choices) = value.get_mut("choices").and_then(Value::as_array_mut) { + for choice in choices { + let Some(obj) = choice.as_object_mut() else { + continue; + }; + if obj.contains_key("message") && !obj.contains_key("delta") { + if let Some(mut message) = obj.remove("message") { + if let Some(tool_calls) = + message.get_mut("tool_calls").and_then(Value::as_array_mut) + { + for (i, tc) in tool_calls.iter_mut().enumerate() { + if let Some(tc_obj) = tc.as_object_mut() { + tc_obj.entry("index").or_insert_with(|| json!(i)); + } + } + } + obj.insert("delta".into(), message); + } + } + } + } + + serde_json::to_string(&value).ok().or(Some(text)) +} diff --git a/sdk_v2/rust/src/openai/embedding_client.rs b/sdk_v2/rust/src/openai/embedding_client.rs new file mode 100644 index 000000000..2d42c395d --- /dev/null +++ b/sdk_v2/rust/src/openai/embedding_client.rs @@ -0,0 +1,98 @@ +//! OpenAI-compatible embedding client. + +use async_openai::types::embeddings::CreateEmbeddingResponse; +use serde_json::{json, Value}; + +use crate::detail::native::NativeModel; +use crate::detail::session::NativeSession; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; + +/// Client for OpenAI-compatible embedding generation backed by a local model. +pub struct EmbeddingClient { + model_id: String, + model: NativeModel, +} + +impl EmbeddingClient { + pub(crate) fn new(model_id: &str, model: NativeModel) -> Self { + Self { + model_id: model_id.to_owned(), + model, + } + } + + /// Generate embeddings for a single input text. + pub async fn generate_embedding(&self, input: &str) -> Result { + Self::validate_input(input)?; + let request = self.build_request(json!(input)); + self.execute_request(request).await + } + + /// Generate embeddings for multiple input texts in a single request. + pub async fn generate_embeddings(&self, inputs: &[&str]) -> Result { + if inputs.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "inputs must be a non-empty array".into(), + }); + } + for input in inputs { + Self::validate_input(input)?; + } + let request = self.build_request(json!(inputs)); + self.execute_request(request).await + } + + async fn execute_request(&self, request: Value) -> Result { + let request_json = serde_json::to_string(&request)?; + let model = self.model.clone(); + + let raw = spawn_blocking(move || { + let session = NativeSession::create(&model)?; + session.run_openai_json(&request_json) + }) + .await?; + + // Patch the response to add fields required by async_openai types + // that the server doesn't return (object on each item, usage) + let mut response_value: Value = serde_json::from_str(&raw)?; + if let Some(data) = response_value + .get_mut("data") + .and_then(|d| d.as_array_mut()) + { + for item in data { + if item.get("object").is_none() { + item.as_object_mut() + .map(|m| m.insert("object".into(), json!("embedding"))); + } + } + } + if response_value.get("usage").is_none() { + response_value.as_object_mut().map(|m| { + m.insert( + "usage".into(), + json!({"prompt_tokens": 0, "total_tokens": 0}), + ) + }); + } + + let parsed: CreateEmbeddingResponse = serde_json::from_value(response_value)?; + Ok(parsed) + } + + fn build_request(&self, input: Value) -> Value { + json!({ + "model": self.model_id, + "input": input, + }) + } + + fn validate_input(input: &str) -> Result<()> { + if input.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "input must be a non-empty string".into(), + }); + } + Ok(()) + } +} diff --git a/sdk_v2/rust/src/openai/json_stream.rs b/sdk_v2/rust/src/openai/json_stream.rs new file mode 100644 index 000000000..b9dca189a --- /dev/null +++ b/sdk_v2/rust/src/openai/json_stream.rs @@ -0,0 +1,49 @@ +//! Generic JSON-deserializing stream over an unbounded channel of raw strings. + +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use serde::de::DeserializeOwned; + +use crate::error::{FoundryLocalError, Result}; + +/// A stream that deserializes each received string chunk into `T`. +/// +/// Empty chunks are silently skipped. +pub struct JsonStream { + rx: tokio::sync::mpsc::UnboundedReceiver>, + _marker: PhantomData, +} + +impl JsonStream { + pub(crate) fn new(rx: tokio::sync::mpsc::UnboundedReceiver>) -> Self { + Self { + rx, + _marker: PhantomData, + } + } +} + +impl Unpin for JsonStream {} + +impl futures_core::Stream for JsonStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(Ok(chunk))) => { + if chunk.is_empty() { + continue; + } + let parsed = serde_json::from_str::(&chunk).map_err(FoundryLocalError::from); + return Poll::Ready(Some(parsed)); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/sdk_v2/rust/src/openai/live_audio_session.rs b/sdk_v2/rust/src/openai/live_audio_session.rs new file mode 100644 index 000000000..0a1ad999e --- /dev/null +++ b/sdk_v2/rust/src/openai/live_audio_session.rs @@ -0,0 +1,471 @@ +//! Live audio transcription streaming session. +//! +//! Provides real-time audio streaming ASR (Automatic Speech Recognition). +//! Audio data from a microphone (or other source) is pushed in as PCM chunks +//! and transcription results are returned as an async [`Stream`](futures_core::Stream). +//! +//! # Example +//! +//! ```ignore +//! let audio_client = model.create_audio_client(); +//! let mut session = audio_client.create_live_transcription_session(); +//! session.settings.sample_rate = 16000; +//! session.settings.channels = 1; +//! session.settings.language = Some("en".into()); +//! +//! session.start(None).await?; +//! +//! // Push audio from microphone callback +//! session.append(&pcm_bytes, None).await?; +//! +//! // Read results as async stream +//! use tokio_stream::StreamExt; +//! let mut stream = session.get_stream().await?; +//! while let Some(result) = stream.next().await { +//! let result = result?; +//! print!("{}", result.content[0].text); +//! } +//! +//! session.stop(None).await?; +//! ``` + +use std::os::raw::c_int; +use std::panic::{catch_unwind, AssertUnwindSafe}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; +use tokio_util::sync::CancellationToken; + +use crate::detail::api::Api; +use crate::detail::ffi::{flItem, flStreamingCallbackData, FOUNDRY_LOCAL_ITEM_BYTES}; +use crate::detail::items::{make_audio_item, read_text_item}; +use crate::detail::native::NativeModel; +use crate::detail::session::{NativeItemQueue, NativeRequest, NativeSession}; +use crate::detail::task::spawn_blocking; +use crate::error::{FoundryLocalError, Result}; + +// ── Types ──────────────────────────────────────────────────────────────────── + +/// Audio format settings for a live transcription session. +/// +/// Must be configured before calling [`LiveAudioTranscriptionSession::start`]. +/// Settings are frozen once the session starts. +#[derive(Debug, Clone)] +pub struct LiveAudioTranscriptionOptions { + /// PCM sample rate in Hz. Default: 16000. + pub sample_rate: u32, + /// Number of audio channels. Default: 1 (mono). + pub channels: u32, + /// Number of bits per audio sample. Default: 16. + pub bits_per_sample: u32, + /// Optional BCP-47 language hint (e.g., `"en"`, `"zh"`). + pub language: Option, + /// Maximum number of audio chunks buffered in the internal push queue. + /// If the queue is full, [`LiveAudioTranscriptionSession::append`] will + /// wait asynchronously. + /// Default: 100 (~3 seconds of audio at typical chunk sizes). + pub push_queue_capacity: usize, +} + +impl Default for LiveAudioTranscriptionOptions { + fn default() -> Self { + Self { + sample_rate: 16000, + channels: 1, + bits_per_sample: 16, + language: None, + push_queue_capacity: 100, + } + } +} + +/// Internal raw deserialization target matching the native core's JSON format. +#[derive(Debug, Clone, serde::Deserialize)] +struct LiveAudioTranscriptionRaw { + #[serde(default)] + is_final: bool, + #[serde(default)] + text: String, + start_time: Option, + end_time: Option, + id: Option, +} + +/// A content part within a [`LiveAudioTranscriptionResponse`]. +/// +/// Mirrors the C# `ContentPart` shape from the OpenAI Realtime API so that +/// callers can access `result.content[0].text` or `result.content[0].transcript` +/// consistently across SDKs. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ContentPart { + /// The transcribed text. + pub text: String, + /// Same as `text` — provided for OpenAI Realtime API compatibility. + pub transcript: String, +} + +/// Transcription result from a live audio streaming session. +/// +/// Shaped to match the C# `LiveAudioTranscriptionResponse : ConversationItem` +/// so that callers access text via `result.content[0].text` or +/// `result.content[0].transcript`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct LiveAudioTranscriptionResponse { + /// Content parts — typically a single element. Access text via + /// `result.content[0].text` or `result.content[0].transcript`. + pub content: Vec, + /// Whether this is a final or partial (interim) result. + /// Nemotron models always return `true`; other models may return `false` + /// for interim hypotheses that will be replaced by a subsequent final result. + pub is_final: bool, + /// Start time offset of this segment in the audio stream (seconds). + pub start_time: Option, + /// End time offset of this segment in the audio stream (seconds). + pub end_time: Option, + /// Unique identifier for this result (if available). + pub id: Option, +} + +impl LiveAudioTranscriptionResponse { + /// Parse a transcription response from the native core's JSON format. + pub fn from_json(json: &str) -> Result { + serde_json::from_str::(json) + .map(Self::from_raw) + .map_err(FoundryLocalError::from) + } + + fn from_raw(raw: LiveAudioTranscriptionRaw) -> Self { + Self { + content: vec![ContentPart { + transcript: raw.text.clone(), + text: raw.text, + }], + is_final: raw.is_final, + start_time: raw.start_time, + end_time: raw.end_time, + id: raw.id, + } + } + + /// Build a response from a plain transcript string. + fn from_text(text: String, is_final: bool) -> Self { + Self { + content: vec![ContentPart { + transcript: text.clone(), + text, + }], + is_final, + start_time: None, + end_time: None, + id: None, + } + } +} + +/// Structured error response from the native core. +#[derive(Debug, Clone, serde::Deserialize)] +pub struct CoreErrorResponse { + /// Error code (e.g. `"ASR_SESSION_NOT_FOUND"`). + pub code: String, + /// Human-readable error message. + pub message: String, + /// Whether this error is transient (retryable). + #[serde(rename = "isTransient", default)] + pub is_transient: bool, +} + +impl CoreErrorResponse { + /// Attempt to parse a native error string as structured JSON. + /// Returns `None` if the error is not valid JSON or doesn't match the schema. + pub fn try_parse(error_string: &str) -> Option { + serde_json::from_str(error_string).ok() + } +} + +// ── Stream type ────────────────────────────────────────────────────────────── + +/// An async stream of [`LiveAudioTranscriptionResponse`] items. +/// +/// Returned by [`LiveAudioTranscriptionSession::get_stream`]. +/// Implements [`futures_core::Stream`]. +pub struct LiveAudioTranscriptionStream { + rx: UnboundedReceiver>, +} + +impl futures_core::Stream for LiveAudioTranscriptionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +// ── Session state ──────────────────────────────────────────────────────────── + +#[derive(Default)] +struct SessionState { + started: bool, + stopped: bool, + queue: Option>, + output_rx: Option>>, + worker: Option>, +} + +// ── Session ────────────────────────────────────────────────────────────────── + +/// Session for real-time audio streaming ASR (Automatic Speech Recognition). +/// +/// Audio data from a microphone (or other source) is pushed in as PCM chunks +/// via [`append`](Self::append), and transcription results are returned as an +/// async [`Stream`](futures_core::Stream) via [`get_stream`](Self::get_stream). +/// +/// Created via [`AudioClient::create_live_transcription_session`](super::AudioClient::create_live_transcription_session). +/// +/// # Cancellation +/// +/// All lifecycle methods accept an optional [`CancellationToken`]. Pass `None` +/// to use the default (no cancellation). +pub struct LiveAudioTranscriptionSession { + model: NativeModel, + /// Audio format settings. Must be configured before calling [`start`](Self::start). + /// Settings are frozen once the session starts. + pub settings: LiveAudioTranscriptionOptions, + state: tokio::sync::Mutex, +} + +impl LiveAudioTranscriptionSession { + pub(crate) fn new(_model_id: &str, model: NativeModel) -> Self { + Self { + model, + settings: LiveAudioTranscriptionOptions::default(), + state: tokio::sync::Mutex::new(SessionState::default()), + } + } + + /// Start a real-time audio streaming session. + /// + /// Must be called before [`append`](Self::append) or + /// [`get_stream`](Self::get_stream). Settings are frozen after this call. + pub async fn start(&self, ct: Option) -> Result<()> { + let mut state = self.state.lock().await; + + if state.started { + return Err(FoundryLocalError::Validation { + reason: "Streaming session already started. Call stop() first.".into(), + }); + } + + if let Some(token) = &ct { + if token.is_cancelled() { + return Err(FoundryLocalError::CommandExecution { + reason: "Start cancelled".into(), + }); + } + } + + let settings = self.settings.clone(); + let model = self.model.clone(); + let api = Arc::clone(&model.api); + + // Create the shared input queue up front so `append` can push into it. + let queue = Arc::new(NativeItemQueue::new(api)?); + + let (output_tx, output_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + + let worker_queue = Arc::clone(&queue); + let worker = tokio::task::spawn_blocking(move || { + run_worker(model, settings, worker_queue, output_tx); + }); + + state.started = true; + state.stopped = false; + state.queue = Some(queue); + state.output_rx = Some(output_rx); + state.worker = Some(worker); + + Ok(()) + } + + /// Push a chunk of raw PCM audio data to the streaming session. + /// + /// The data is copied internally so the caller can reuse the buffer. + pub async fn append(&self, pcm_data: &[u8], ct: Option) -> Result<()> { + if let Some(token) = &ct { + if token.is_cancelled() { + return Err(FoundryLocalError::CommandExecution { + reason: "Append cancelled".into(), + }); + } + } + + let queue = { + let state = self.state.lock().await; + if !state.started || state.stopped { + return Err(FoundryLocalError::Validation { + reason: "No active streaming session. Call start() first.".into(), + }); + } + state + .queue + .clone() + .ok_or_else(|| FoundryLocalError::Internal { + reason: "Input queue not available — session may be in an invalid state".into(), + })? + }; + + let data = pcm_data.to_vec(); + spawn_blocking(move || queue.push_bytes(&data, FOUNDRY_LOCAL_ITEM_BYTES)).await + } + + /// Get the async stream of transcription results. + /// + /// Results arrive as the native ASR engine processes audio data. + /// Can only be called once per session (the receiver is moved out). + pub async fn get_stream(&self) -> Result { + let mut state = self.state.lock().await; + let rx = state + .output_rx + .take() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "No active streaming session, or stream already taken. \ + Call start() first and only call get_stream() once." + .into(), + })?; + Ok(LiveAudioTranscriptionStream { rx }) + } + + /// Signal end-of-audio and stop the streaming session. + /// + /// Any remaining buffered audio is drained to the native engine first; + /// final results are delivered through the transcription stream before it + /// completes. The native stop always completes to avoid session leaks, + /// even if the provided [`CancellationToken`] fires. + pub async fn stop(&self, _ct: Option) -> Result<()> { + let worker = { + let mut state = self.state.lock().await; + if !state.started || state.stopped { + return Ok(()); + } + state.stopped = true; + if let Some(queue) = &state.queue { + queue.mark_finished(); + } + state.worker.take() + }; + + if let Some(handle) = worker { + let _ = handle.await; + } + Ok(()) + } +} + +/// Streaming-callback context: forwards interim transcripts to the output channel. +struct LiveCtx { + api: Arc, + tx: UnboundedSender>, +} + +unsafe extern "C" fn live_trampoline( + data: flStreamingCallbackData, + user_data: *mut std::ffi::c_void, +) -> c_int { + if user_data.is_null() { + return 0; + } + let result = catch_unwind(AssertUnwindSafe(|| { + let ctx = &*(user_data as *const LiveCtx); + let queue = data.item_queue; + if queue.is_null() { + return 0; + } + let item_api = ctx.api.item_api(); + loop { + let mut item: *mut flItem = std::ptr::null_mut(); + if !(item_api.ItemQueue_TryPop)(queue, &mut item) { + break; + } + if item.is_null() { + continue; + } + let text = read_text_item(&ctx.api, item); + (item_api.Item_Release)(item); + if let Some(text) = text { + if !text.is_empty() + && ctx + .tx + .send(Ok(LiveAudioTranscriptionResponse::from_text(text, false))) + .is_err() + { + return 1; // receiver dropped — cancel + } + } + } + 0 + })); + result.unwrap_or(1) +} + +/// Blocking worker: builds the session/request, installs the streaming callback, +/// processes the audio queue to completion, then emits the final transcript. +fn run_worker( + model: NativeModel, + settings: LiveAudioTranscriptionOptions, + queue: Arc, + output_tx: UnboundedSender>, +) { + let api = Arc::clone(&model.api); + + let run = (|| -> Result<()> { + let session = NativeSession::create(&model)?; + + let mut ctx = Box::new(LiveCtx { + api: Arc::clone(&api), + tx: output_tx.clone(), + }); + let ctx_ptr = &mut *ctx as *mut LiveCtx as *mut std::ffi::c_void; + session.set_streaming_callback(Some(live_trampoline), ctx_ptr)?; + + let request = NativeRequest::new(Arc::clone(&api))?; + let format = make_audio_item( + &api, + &[], + Some("pcm"), + settings.sample_rate as i32, + settings.channels as i32, + )?; + request.add_item(format, true)?; + // The input queue stays owned by us (append pushes into it). + request.add_item(queue.as_item_ptr(), false)?; + + let response = session.process_request(&request)?; + + // Aggregate the terminal transcript from the final response items. + let mut final_text = String::new(); + for i in 0..response.item_count() { + if let Some(text) = response.item_text(i) { + final_text.push_str(&text); + } + } + + // Uninstall the callback before the context is dropped. + let _ = session.set_streaming_callback(None, std::ptr::null_mut()); + drop(ctx); + + if !final_text.is_empty() { + let _ = output_tx.send(Ok(LiveAudioTranscriptionResponse::from_text( + final_text, true, + ))); + } + Ok(()) + })(); + + if let Err(e) = run { + let _ = output_tx.send(Err(e)); + } + // `queue` Arc clone and `session`/`request` drop here. + drop(queue); +} diff --git a/sdk_v2/rust/src/openai/mod.rs b/sdk_v2/rust/src/openai/mod.rs new file mode 100644 index 000000000..9df1fd94e --- /dev/null +++ b/sdk_v2/rust/src/openai/mod.rs @@ -0,0 +1,17 @@ +mod audio_client; +mod chat_client; +mod embedding_client; +mod json_stream; +mod live_audio_session; + +pub use self::audio_client::{ + AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream, + TranscriptionSegment, TranscriptionWord, +}; +pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; +pub use self::embedding_client::EmbeddingClient; +pub use self::json_stream::JsonStream; +pub use self::live_audio_session::{ + ContentPart, CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse, + LiveAudioTranscriptionSession, LiveAudioTranscriptionStream, +}; diff --git a/sdk_v2/rust/src/types.rs b/sdk_v2/rust/src/types.rs new file mode 100644 index 000000000..f39109adf --- /dev/null +++ b/sdk_v2/rust/src/types.rs @@ -0,0 +1,151 @@ +use serde::{Deserialize, Serialize}; + +/// Hardware device type for model execution. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum DeviceType { + Invalid, + #[default] + CPU, + GPU, + NPU, +} + +/// Prompt template describing how messages are formatted for the model. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptTemplate { + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub assistant: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, +} + +/// Runtime information for a model (device type and execution provider). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Runtime { + pub device_type: DeviceType, + pub execution_provider: String, +} + +/// A single parameter key-value pair used in model settings. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Parameter { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +/// Model-level settings containing a list of parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelSettings { + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option>, +} + +/// Full metadata for a model variant as returned by the catalog. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub version: u64, + pub alias: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, + pub provider_type: String, + pub uri: String, + pub model_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_template: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub publisher: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_settings: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license_description: Option, + pub cached: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_size_mb: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub supports_tool_calling: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_fl_version: Option, + #[serde(default)] + pub created_at_unix: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub context_length: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_modalities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_modalities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub capabilities: Option, +} + +/// Desired response format for chat completions. +/// +/// Extends the standard OpenAI formats with the Foundry-specific +/// `LarkGrammar` variant. +#[derive(Debug, Clone)] +pub enum ChatResponseFormat { + /// Plain text output (default). + Text, + /// JSON output (unstructured). + JsonObject, + /// JSON output constrained by a schema string. + JsonSchema(String), + /// Output constrained by a Lark grammar (Foundry extension). + LarkGrammar(String), +} + +/// Tool choice configuration for chat completions. +#[derive(Debug, Clone)] +pub enum ChatToolChoice { + /// Model will not call any tool. + None, + /// Model decides whether to call a tool. + Auto, + /// Model must call at least one tool. + Required, + /// Model must call the named function. + Function(String), +} + +/// Information about an available execution provider. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct EpInfo { + /// The name of the execution provider. + pub name: String, + /// Whether this EP is currently registered and ready for use. + pub is_registered: bool, +} + +/// Result of a download-and-register execution-provider operation. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub struct EpDownloadResult { + /// Whether all requested EPs were successfully registered. + pub success: bool, + /// Human-readable status message. + pub status: String, + /// Names of EPs that were successfully registered. + pub registered_eps: Vec, + /// Names of EPs that failed to register. + pub failed_eps: Vec, +} diff --git a/sdk_v2/rust/tests/integration/audio_client_test.rs b/sdk_v2/rust/tests/integration/audio_client_test.rs new file mode 100644 index 000000000..47cef9d01 --- /dev/null +++ b/sdk_v2/rust/tests/integration/audio_client_test.rs @@ -0,0 +1,130 @@ +use super::common; +use foundry_local_sdk::openai::AudioClient; +use std::sync::Arc; +use tokio_stream::StreamExt; + +async fn setup_audio_client() -> (AudioClient, Arc) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::WHISPER_MODEL_ALIAS) + .await + .expect("get_model(whisper-tiny) failed"); + model.load().await.expect("model.load() failed"); + (model.create_audio_client(), model) +} + +fn audio_file() -> String { + common::get_audio_file_path().to_string_lossy().into_owned() +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming() { + let (client, model) = setup_audio_client().await; + let client = client.language("en").temperature(0.0); + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming_with_temperature() { + let (client, model) = setup_audio_client().await; + let client = client.language("en").temperature(0.5); + + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe with temperature failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_with_streaming() { + let (client, model) = setup_audio_client().await; + let client = client.language("en").temperature(0.0); + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_with_streaming_with_temperature() { + let (client, model) = setup_audio_client().await; + let client = client.language("en").temperature(0.5); + + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming with temperature setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = client.transcribe("").await; + assert!(result.is_err(), "Expected error for empty audio file path"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = client.transcribe_streaming("").await; + assert!( + result.is_err(), + "Expected error for empty audio file path in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/catalog_test.rs b/sdk_v2/rust/tests/integration/catalog_test.rs new file mode 100644 index 000000000..d418c7a73 --- /dev/null +++ b/sdk_v2/rust/tests/integration/catalog_test.rs @@ -0,0 +1,106 @@ +use super::common; +use foundry_local_sdk::Catalog; + +fn catalog() -> &'static Catalog { + common::get_test_manager().catalog() +} + +#[test] +fn should_initialize_with_catalog_name() { + let cat = catalog(); + let name = cat.name(); + assert!(!name.is_empty(), "Catalog name must not be empty"); +} + +#[tokio::test] +async fn should_list_models() { + let cat = catalog(); + let models = cat.get_models().await.expect("get_models failed"); + + assert!( + !models.is_empty(), + "Expected at least one model in the catalog" + ); + + let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' not found in catalog", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_get_model_by_alias() { + let cat = catalog(); + let model = cat + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_empty_alias() { + let cat = catalog(); + let result = cat.get_model("").await; + assert!(result.is_err(), "Expected error for empty alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Model alias must be a non-empty string"), + "Unexpected error message: {err_msg}" + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_unknown_alias() { + let cat = catalog(); + let result = cat.get_model("unknown-nonexistent-model-alias").await; + assert!(result.is_err(), "Expected error for unknown alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Unknown model alias"), + "Error should mention unknown alias: {err_msg}" + ); + assert!( + err_msg.contains("Available"), + "Error should list available models: {err_msg}" + ); +} + +#[tokio::test] +async fn should_get_cached_models() { + let cat = catalog(); + let cached = cat + .get_cached_models() + .await + .expect("get_cached_models failed"); + + assert!(!cached.is_empty(), "Expected at least one cached model"); + + let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' should be in the cached models list", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_empty_id() { + let cat = catalog(); + let result = cat.get_model_variant("").await; + assert!(result.is_err(), "Expected error for empty variant ID"); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_unknown_id() { + let cat = catalog(); + let result = cat + .get_model_variant("unknown-nonexistent-variant-id") + .await; + assert!(result.is_err(), "Expected error for unknown variant ID"); +} diff --git a/sdk_v2/rust/tests/integration/chat_client_test.rs b/sdk_v2/rust/tests/integration/chat_client_test.rs new file mode 100644 index 000000000..b24f3804b --- /dev/null +++ b/sdk_v2/rust/tests/integration/chat_client_test.rs @@ -0,0 +1,334 @@ +use super::common; +use foundry_local_sdk::openai::ChatClient; +use foundry_local_sdk::{ + ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, ChatToolChoice, +}; +use serde_json::json; +use std::sync::Arc; +use tokio_stream::StreamExt; + +async fn setup_chat_client() -> (ChatClient, Arc) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let client = model.create_chat_client().max_tokens(500).temperature(0.0); + (client, model) +} + +fn user_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestUserMessage::from(content).into() +} + +fn system_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestSystemMessage::from(content).into() +} + +fn assistant_message(content: &str) -> ChatCompletionRequestMessage { + serde_json::from_value(json!({ "role": "assistant", "content": content })) + .expect("failed to construct assistant message") +} + +#[tokio::test] +async fn should_perform_chat_completion() { + let (client, model) = setup_chat_client().await; + let messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let response = client + .complete_chat(&messages, None) + .await + .expect("complete_chat failed"); + let content = response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + println!("Response: {content}"); + + println!("REST response: {content}"); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_perform_streaming_chat_completion() { + let (client, model) = setup_chat_client().await; + let mut messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let mut first_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (first turn) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + first_result.push_str(content); + } + } + } + println!("First turn: {first_result}"); + + assert!( + first_result.contains("42"), + "First turn should contain '42', got: {first_result}" + ); + + messages.push(assistant_message(&first_result)); + messages.push(user_message("Now add 25 to that result.")); + + let mut second_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (follow-up) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + second_result.push_str(content); + } + } + } + println!("Follow-up: {second_result}"); + + assert!( + second_result.contains("67"), + "Follow-up should contain '67', got: {second_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_streaming_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!( + result.is_err(), + "Expected error for empty messages in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} + +// Note: The "invalid callback" test was removed because it was an exact +// duplicate of should_throw_when_completing_streaming_chat_with_empty_messages. + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_non_streaming() { + let (client, model) = setup_chat_client().await; + let client = client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("complete_chat with tools failed"); + + let choice = response + .choices + .first() + .expect("Expected at least one choice"); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .expect("Expected tool_calls"); + assert!( + !tool_calls.is_empty(), + "Expected at least one tool call in the response" + ); + + let tool_call = match &tool_calls[0] { + ChatCompletionMessageToolCalls::Function(tc) => tc, + _ => panic!("Expected a function tool call"), + }; + assert_eq!( + tool_call.function.name, "multiply", + "Expected tool call to 'multiply'" + ); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let tool_call_id = &tool_call.id; + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + let client = client.tool_choice(ChatToolChoice::Auto); + + let final_response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("follow-up complete_chat with tools failed"); + let content = final_response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + println!("Tool call result: {content}"); + + assert!( + content.contains("42"), + "Final answer should contain '42', got: {content}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_streaming() { + let (client, model) = setup_chat_client().await; + let client = client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut tool_call_id = String::new(); + + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming tool call setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for call in tool_calls { + if let Some(ref func) = call.function { + if let Some(ref name) = func.name { + tool_call_name.push_str(name); + } + if let Some(ref args) = func.arguments { + tool_call_args.push_str(args); + } + } + if let Some(ref id) = call.id { + tool_call_id = id.clone(); + } + } + } + } + } + assert_eq!( + tool_call_name, "multiply", + "Expected streamed tool call to 'multiply'" + ); + + let args: serde_json::Value = + serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + let client = client.tool_choice(ChatToolChoice::Auto); + + let mut final_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming follow-up setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + final_result.push_str(content); + } + } + } + println!("Streamed tool call result: {final_result}"); + + assert!( + final_result.contains("42"), + "Streamed final answer should contain '42', got: {final_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/common/mod.rs b/sdk_v2/rust/tests/integration/common/mod.rs new file mode 100644 index 000000000..81897b7ba --- /dev/null +++ b/sdk_v2/rust/tests/integration/common/mod.rs @@ -0,0 +1,128 @@ +//! Shared test utilities and configuration for Foundry Local SDK integration tests. +//! +//! Mirrors `testUtils.ts` from the JavaScript SDK test suite. + +#![allow(dead_code)] + +use std::path::PathBuf; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager, LogLevel}; + +/// Default model alias used for chat-completion integration tests. +pub const TEST_MODEL_ALIAS: &str = "qwen2.5-0.5b"; + +/// Default model alias used for audio-transcription integration tests. +pub const WHISPER_MODEL_ALIAS: &str = "whisper-tiny"; + +/// Default model alias used for embedding integration tests. +pub const EMBEDDING_MODEL_ALIAS: &str = "qwen3-embedding-0.6b"; + +/// Expected transcription text fragment for the shared audio test file. +pub const EXPECTED_TRANSCRIPTION_TEXT: &str = + " And lots of times you need to give people more than one link at a time"; + +// ── Environment helpers ────────────────────────────────────────────────────── + +/// Returns `true` when the tests are running inside a CI environment +/// (Azure DevOps or GitHub Actions). +pub fn is_running_in_ci() -> bool { + let azure_devops = std::env::var("TF_BUILD").unwrap_or_else(|_| "false".into()); + let github_actions = std::env::var("GITHUB_ACTIONS").unwrap_or_else(|_| "false".into()); + azure_devops.eq_ignore_ascii_case("true") || github_actions.eq_ignore_ascii_case("true") +} + +/// Walk upward from `CARGO_MANIFEST_DIR` until a `.git` directory is found. +pub fn get_git_repo_root() -> PathBuf { + let mut current = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + loop { + if current.join(".git").exists() { + return current; + } + if !current.pop() { + panic!( + "Could not locate git repo root starting from {}", + env!("CARGO_MANIFEST_DIR") + ); + } + } +} + +/// Path to the shared test-data directory. +/// Uses FOUNDRY_TEST_DATA_DIR env var if set (CI), otherwise falls back +/// to looking for test-data-shared as a sibling of the repo root. +pub fn get_test_data_shared_path() -> PathBuf { + if let Ok(env_path) = std::env::var("FOUNDRY_TEST_DATA_DIR") { + let p = PathBuf::from(&env_path); + if p.is_dir() { + return p; + } + } + let repo_root = get_git_repo_root(); + repo_root + .parent() + .expect("repo root has no parent") + .join("test-data-shared") +} + +/// Path to the shared audio test file used by audio-client tests. +pub fn get_audio_file_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("testdata") + .join("Recording.mp3") +} + +// ── Test configuration ─────────────────────────────────────────────────────── + +/// Build a [`FoundryLocalConfig`] suitable for integration tests. +/// +/// * `modelCacheDir` → `/../test-data-shared` +/// * `logsDir` → `/sdk/rust/logs` +/// * `logLevel` → `Warn` +pub fn test_config() -> FoundryLocalConfig { + let repo_root = get_git_repo_root(); + let logs_dir = repo_root.join("sdk").join("rust").join("logs"); + + FoundryLocalConfig::new("FoundryLocalTest") + .model_cache_dir(get_test_data_shared_path().to_string_lossy().into_owned()) + .logs_dir(logs_dir.to_string_lossy().into_owned()) + .log_level(LogLevel::Warn) +} + +/// Create (or return the cached) [`FoundryLocalManager`] for tests. +/// +/// Panics if creation fails so that test set-up failures are immediately +/// visible. +pub fn get_test_manager() -> &'static FoundryLocalManager { + FoundryLocalManager::create(test_config()).expect("Failed to create FoundryLocalManager") +} + +// ── Tool definitions ───────────────────────────────────────────────────────── + +/// Returns a tool definition for a simple "multiply" function. +/// +/// Used by tool-calling chat-completion tests. +pub fn get_multiply_tool() -> foundry_local_sdk::ChatCompletionTools { + serde_json::from_value(serde_json::json!({ + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number" + }, + "b": { + "type": "number", + "description": "The second number" + } + }, + "required": ["a", "b"] + } + } + })) + .expect("Failed to parse multiply tool definition") +} diff --git a/sdk_v2/rust/tests/integration/embedding_client_test.rs b/sdk_v2/rust/tests/integration/embedding_client_test.rs new file mode 100644 index 000000000..0f577329a --- /dev/null +++ b/sdk_v2/rust/tests/integration/embedding_client_test.rs @@ -0,0 +1,223 @@ +//! Integration tests for EmbeddingClient. + +use std::sync::Arc; + +use foundry_local_sdk::openai::EmbeddingClient; +use foundry_local_sdk::Model; + +use crate::common; + +async fn setup_embedding_client() -> (EmbeddingClient, Arc) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let model = catalog + .get_model(common::EMBEDDING_MODEL_ALIAS) + .await + .expect("embedding model should exist in catalog"); + + model.load().await.expect("model should load successfully"); + + let client = model.create_embedding_client(); + (client, model) +} + +#[tokio::test] +async fn should_generate_embedding() { + let (client, model) = setup_embedding_client().await; + + let response = client + .generate_embedding("The quick brown fox jumps over the lazy dog") + .await + .expect("embedding should succeed"); + + assert_eq!(response.data.len(), 1); + assert_eq!(response.data[0].index, 0); + assert_eq!(response.data[0].embedding.len(), 1024); + + println!("Embedding dimension: {}", response.data[0].embedding.len()); + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_generate_normalized_embedding() { + let (client, model) = setup_embedding_client().await; + + let inputs = [ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + ]; + + for input in &inputs { + let response = client + .generate_embedding(input) + .await + .expect("embedding should succeed"); + + let embedding = &response.data[0].embedding; + assert_eq!(embedding.len(), 1024); + + // Verify L2 norm is approximately 1.0 + let norm: f32 = embedding.iter().map(|v| v * v).sum::().sqrt(); + assert!( + (0.99_f32..=1.01_f32).contains(&norm), + "L2 norm {norm} not approximately 1.0" + ); + + for val in embedding { + assert!( + (-1.0_f32..=1.0_f32).contains(val), + "value {val} outside [-1, 1]" + ); + } + } + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_produce_different_embeddings_for_different_inputs() { + let (client, model) = setup_embedding_client().await; + + let response1 = client + .generate_embedding("The quick brown fox") + .await + .expect("embedding should succeed"); + + let response2 = client + .generate_embedding("The capital of France is Paris") + .await + .expect("embedding should succeed"); + + let emb1 = &response1.data[0].embedding; + let emb2 = &response2.data[0].embedding; + + assert_eq!(emb1.len(), emb2.len()); + + // Cosine similarity should not be 1.0 + let dot: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum(); + let norm1: f32 = emb1.iter().map(|v| v * v).sum::().sqrt(); + let norm2: f32 = emb2.iter().map(|v| v * v).sum::().sqrt(); + let cosine_similarity = dot / (norm1 * norm2); + assert!( + cosine_similarity < 0.99_f32, + "cosine similarity {cosine_similarity} should be < 0.99" + ); + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_produce_same_embedding_for_same_input() { + let (client, model) = setup_embedding_client().await; + + let response1 = client + .generate_embedding("Deterministic embedding test") + .await + .expect("embedding should succeed"); + + let response2 = client + .generate_embedding("Deterministic embedding test") + .await + .expect("embedding should succeed"); + + let emb1 = &response1.data[0].embedding; + let emb2 = &response2.data[0].embedding; + + for (i, (a, b)) in emb1.iter().zip(emb2.iter()).enumerate() { + assert_eq!(a, b, "mismatch at index {i}"); + } + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_throw_for_empty_input() { + let (client, model) = setup_embedding_client().await; + + let result = client.generate_embedding("").await; + assert!(result.is_err(), "empty input should return an error"); + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_throw_for_empty_batch() { + let (client, model) = setup_embedding_client().await; + + let result = client.generate_embeddings(&[]).await; + assert!(result.is_err(), "empty batch should return an error"); + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_generate_batch_embeddings() { + let (client, model) = setup_embedding_client().await; + + let response = client + .generate_embeddings(&[ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + ]) + .await + .expect("batch embedding should succeed"); + + assert_eq!(response.data.len(), 3); + for (i, data) in response.data.iter().enumerate() { + assert_eq!(data.index, i as u32); + assert_eq!(data.embedding.len(), 1024); + } + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_generate_normalized_batch_embeddings() { + let (client, model) = setup_embedding_client().await; + + let response = client + .generate_embeddings(&["Hello world", "Goodbye world"]) + .await + .expect("batch embedding should succeed"); + + assert_eq!(response.data.len(), 2); + for data in &response.data { + let norm: f32 = data.embedding.iter().map(|v| v * v).sum::().sqrt(); + assert!( + (0.99_f32..=1.01_f32).contains(&norm), + "L2 norm {norm} not approximately 1.0" + ); + } + + model.unload().await.expect("unload should succeed"); +} + +#[tokio::test] +async fn should_match_single_and_batch_results() { + let (client, model) = setup_embedding_client().await; + + let single = client + .generate_embedding("The capital of France is Paris") + .await + .expect("single embedding should succeed"); + + let batch = client + .generate_embeddings(&["The capital of France is Paris"]) + .await + .expect("batch embedding should succeed"); + + assert_eq!(batch.data.len(), 1); + for (a, b) in single.data[0] + .embedding + .iter() + .zip(batch.data[0].embedding.iter()) + { + assert_eq!(a, b); + } + + model.unload().await.expect("unload should succeed"); +} diff --git a/sdk_v2/rust/tests/integration/live_audio_test.rs b/sdk_v2/rust/tests/integration/live_audio_test.rs new file mode 100644 index 000000000..d7f2f2da8 --- /dev/null +++ b/sdk_v2/rust/tests/integration/live_audio_test.rs @@ -0,0 +1,114 @@ +use super::common; +use std::sync::Arc; +use tokio_stream::StreamExt; + +/// Generate synthetic PCM audio (440Hz sine wave, 16kHz, 16-bit mono). +fn generate_sine_wave_pcm(sample_rate: i32, duration_seconds: i32, frequency: f64) -> Vec { + let total_samples = (sample_rate * duration_seconds) as usize; + let mut pcm_bytes = vec![0u8; total_samples * 2]; // 16-bit = 2 bytes per sample + + for i in 0..total_samples { + let t = i as f64 / sample_rate as f64; + let sample = + (i16::MAX as f64 * 0.5 * (2.0 * std::f64::consts::PI * frequency * t).sin()) as i16; + pcm_bytes[i * 2] = (sample & 0xFF) as u8; + pcm_bytes[i * 2 + 1] = ((sample >> 8) & 0xFF) as u8; + } + + pcm_bytes +} + +// --- E2E streaming test with synthetic PCM audio --- + +#[tokio::test] +async fn live_streaming_e2e_with_synthetic_pcm_returns_valid_response() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + // Try to get a nemotron or whisper model for audio streaming + let model = match catalog.get_model("nemotron").await { + Ok(m) => m, + Err(_) => match catalog.get_model(common::WHISPER_MODEL_ALIAS).await { + Ok(m) => m, + Err(_) => { + eprintln!("Skipping E2E test: no audio model available"); + return; + } + }, + }; + + if !model.is_cached().await.unwrap_or(false) { + eprintln!("Skipping E2E test: model not cached"); + return; + } + + model.load().await.expect("model.load() failed"); + + let audio_client = model.create_audio_client(); + let session = audio_client.create_live_transcription_session(); + + // Verify default settings + assert_eq!(session.settings.sample_rate, 16000); + assert_eq!(session.settings.channels, 1); + assert_eq!(session.settings.bits_per_sample, 16); + + if let Err(e) = session.start(None).await { + eprintln!("Skipping E2E test: could not start session: {e}"); + model.unload().await.ok(); + return; + } + + // Start collecting results in background (must start before pushing audio) + let mut stream = session.get_stream().await.expect("get_stream failed"); + + let results = Arc::new(tokio::sync::Mutex::new(Vec::new())); + let stream_error: Arc>> = + Arc::new(tokio::sync::Mutex::new(None)); + let results_clone = Arc::clone(&results); + let error_clone = Arc::clone(&stream_error); + let read_task = tokio::spawn(async move { + while let Some(result) = stream.next().await { + match result { + Ok(r) => results_clone.lock().await.push(r), + Err(e) => { + *error_clone.lock().await = Some(format!("{e}")); + break; + } + } + } + }); + + // Generate ~2 seconds of synthetic PCM audio (440Hz sine wave) + let pcm_bytes = generate_sine_wave_pcm(16000, 2, 440.0); + + // Push audio in chunks (100ms each, matching typical mic callback size) + let chunk_size = 16000 / 10 * 2; // 100ms of 16-bit audio = 3200 bytes + for offset in (0..pcm_bytes.len()).step_by(chunk_size) { + let end = std::cmp::min(offset + chunk_size, pcm_bytes.len()); + session + .append(&pcm_bytes[offset..end], None) + .await + .expect("append failed"); + } + + // Stop session to flush remaining audio and complete the stream + session.stop(None).await.expect("stop failed"); + read_task.await.expect("read task failed"); + + // Verify no stream errors occurred + assert!( + stream_error.lock().await.is_none(), + "Stream produced an error: {:?}", + stream_error.lock().await + ); + + // Verify response attributes — synthetic audio may or may not produce text, + // but the response objects should be properly structured (C#-compatible envelope) + let results = results.lock().await; + for result in results.iter() { + assert!(!result.content.is_empty(), "content must not be empty"); + assert_eq!(result.content[0].text, result.content[0].transcript); + } + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/main.rs b/sdk_v2/rust/tests/integration/main.rs new file mode 100644 index 000000000..055760003 --- /dev/null +++ b/sdk_v2/rust/tests/integration/main.rs @@ -0,0 +1,18 @@ +//! Single integration test binary for the Foundry Local Rust SDK. +//! +//! All test modules are compiled into one binary so the native core is only +//! initialised once (via the `OnceLock` singleton in `FoundryLocalManager`). +//! Running them as separate binaries causes "already initialized" errors +//! because the .NET native runtime retains state across process-level +//! library loads. + +mod common; + +mod audio_client_test; +mod catalog_test; +mod chat_client_test; +mod embedding_client_test; +mod live_audio_test; +mod manager_test; +mod model_test; +mod web_service_test; diff --git a/sdk_v2/rust/tests/integration/manager_test.rs b/sdk_v2/rust/tests/integration/manager_test.rs new file mode 100644 index 000000000..aa3e06148 --- /dev/null +++ b/sdk_v2/rust/tests/integration/manager_test.rs @@ -0,0 +1,21 @@ +use super::common; +use foundry_local_sdk::FoundryLocalManager; + +#[test] +fn should_initialize_successfully() { + let config = common::test_config(); + let manager = FoundryLocalManager::create(config); + assert!( + manager.is_ok(), + "Manager creation failed: {:?}", + manager.err() + ); +} + +#[test] +fn should_return_catalog_with_non_empty_name() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let name = catalog.name(); + assert!(!name.is_empty(), "Catalog name should not be empty"); +} diff --git a/sdk_v2/rust/tests/integration/model_test.rs b/sdk_v2/rust/tests/integration/model_test.rs new file mode 100644 index 000000000..c1ffa171e --- /dev/null +++ b/sdk_v2/rust/tests/integration/model_test.rs @@ -0,0 +1,323 @@ +use super::common; +use std::sync::Arc; + +// ── Cached model verification ──────────────────────────────────────────────── + +#[tokio::test] +async fn should_verify_cached_models_from_test_data_shared() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let cached = catalog + .get_cached_models() + .await + .expect("get_cached_models failed"); + + let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + has_qwen, + "'{}' should be present in cached models", + common::TEST_MODEL_ALIAS + ); + + let has_whisper = cached + .iter() + .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + assert!( + has_whisper, + "'{}' should be present in cached models", + common::WHISPER_MODEL_ALIAS + ); +} + +// ── Load / Unload ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn should_load_and_unload_model() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + model.load().await.expect("model.load() failed"); + assert!( + model.is_loaded().await.expect("is_loaded check failed"), + "Model should be loaded after load()" + ); + + model.unload().await.expect("model.unload() failed"); + assert!( + !model.is_loaded().await.expect("is_loaded check failed"), + "Model should not be loaded after unload()" + ); +} + +// ── Introspection ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn should_expose_alias() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_expose_non_empty_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + println!("Model id: {}", model.id()); + + assert!( + !model.id().is_empty(), + "Model id() should be a non-empty string" + ); +} + +#[tokio::test] +async fn should_have_at_least_one_variant() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let variants = model.variants(); + println!("Model has {} variant(s)", variants.len()); + + assert!( + !variants.is_empty(), + "Model should have at least one variant" + ); +} + +#[tokio::test] +async fn should_have_selected_variant_matching_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + // The model's id() should return the selected variant's id + // info() delegates to the selected variant, so id() and info().id must agree + assert_eq!( + model.id(), + model.info().id, + "model.id() should match model.info().id (the selected variant's metadata)" + ); +} + +#[tokio::test] +async fn should_report_cached_model_as_cached() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let cached = model.is_cached().await.expect("is_cached() should succeed"); + assert!( + cached, + "Test model '{}' should be cached (from test-data-shared)", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_return_non_empty_path_for_cached_model() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let path = model.path().await.expect("path() should succeed"); + println!("Model path: {}", path.display()); + + assert!( + !path.as_os_str().is_empty(), + "Cached model should have a non-empty path" + ); +} + +#[tokio::test] +async fn should_select_variant_by_model() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + // Remember the original selection so we can restore it afterward. + let original_id = model.id().to_string(); + + let first_variant = model.variants()[0].clone(); + let first_variant_id = first_variant.id().to_string(); + model + .select_variant(&first_variant) + .expect("select_variant should succeed"); + assert_eq!( + model.id(), + first_variant_id, + "After select_variant, id() should match the selected variant" + ); + + // Restore the original variant so other tests sharing this + // model via the catalog are not affected. + model + .select_variant_by_id(&original_id) + .expect("restoring original variant should succeed"); +} + +#[tokio::test] +async fn should_select_variant_by_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let original_id = model.id().to_string(); + + let first_variant_id = model.variants()[0].id().to_string(); + model + .select_variant_by_id(&first_variant_id) + .expect("select_variant_by_id should succeed"); + assert_eq!( + model.id(), + first_variant_id, + "After select_variant_by_id, id() should match the selected variant" + ); + + model + .select_variant_by_id(&original_id) + .expect("restoring original variant should succeed"); +} + +#[tokio::test] +async fn should_fail_to_select_unknown_variant() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let result = model.select_variant_by_id("nonexistent-variant-id"); + assert!( + result.is_err(), + "select_variant_by_id with unknown ID should fail" + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': {err_msg}" + ); +} + +// ── Load manager (core interop) ────────────────────────────────────────────── + +async fn get_test_model() -> Arc { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed") +} + +#[tokio::test] +async fn should_load_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_unload_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_list_loaded_models_using_core_interop() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let loaded = catalog + .get_loaded_models() + .await + .expect("catalog.get_loaded_models() failed"); + + let _ = loaded; +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_load_and_unload_model_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + let model = get_test_model().await; + + manager + .start_web_service() + .await + .expect("start_web_service failed"); + + model + .load() + .await + .expect("load via external service failed"); + + model + .unload() + .await + .expect("unload via external service failed"); +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_list_loaded_models_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + + manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let catalog = manager.catalog(); + let loaded = catalog + .get_loaded_models() + .await + .expect("get_loaded_models via external service failed"); + + let _ = loaded; +} diff --git a/sdk_v2/rust/tests/integration/web_service_test.rs b/sdk_v2/rust/tests/integration/web_service_test.rs new file mode 100644 index 000000000..9222f9d45 --- /dev/null +++ b/sdk_v2/rust/tests/integration/web_service_test.rs @@ -0,0 +1,161 @@ +use super::common; +use serde_json::json; + +/// Start the web service, make a non-streaming POST to v1/chat/completions, +/// verify we get a valid response, then stop the service. +#[tokio::test] +async fn should_complete_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + manager + .start_web_service() + .await + .expect("start_web_service failed"); + let urls = manager.urls().expect("urls() should succeed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": false + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + resp.status().is_success(), + "Expected 2xx, got {}", + resp.status() + ); + + let body: serde_json::Value = resp.json().await.expect("failed to parse response JSON"); + let content = body + .pointer("/choices/0/message/content") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + println!("REST response: {content}"); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// Start the web service, make a streaming POST to v1/chat/completions, +/// collect SSE chunks, verify we get a valid streamed response. +#[tokio::test] +async fn should_stream_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + manager + .start_web_service() + .await + .expect("start_web_service failed"); + let urls = manager.urls().expect("urls() should succeed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": true + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + response.status().is_success(), + "Expected 2xx, got {}", + response.status() + ); + + let mut full_text = String::new(); + while let Some(chunk) = response.chunk().await.expect("chunk read failed") { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + full_text.push_str(content); + } + } + } + } + } + + println!("REST streamed response: {full_text}"); + + assert!( + full_text.contains("42"), + "Expected streamed response to contain '42', got: {full_text}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// urls() should return the listening addresses after start_web_service. +#[tokio::test] +async fn should_expose_urls_after_start() { + let manager = common::get_test_manager(); + + manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let urls = manager.urls().expect("urls() should succeed"); + println!("Web service URLs: {urls:?}"); + assert!(!urls.is_empty(), "urls() should return URLs after start"); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); +} From 3dcb8337329e4c944ae5c39f548b5e7e5c897796 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 22 Jun 2026 17:52:29 +0100 Subject: [PATCH 2/7] bug fixes --- sdk_v2/rust/src/detail/items.rs | 10 ++++++---- sdk_v2/rust/src/detail/session.rs | 15 +++++++++------ sdk_v2/rust/src/openai/live_audio_session.rs | 10 +++++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/sdk_v2/rust/src/detail/items.rs b/sdk_v2/rust/src/detail/items.rs index dcec5b6a1..5f0de9f4f 100644 --- a/sdk_v2/rust/src/detail/items.rs +++ b/sdk_v2/rust/src/detail/items.rs @@ -16,9 +16,10 @@ pub(crate) fn make_text_item( text: &str, item_type: flTextItemType, ) -> Result<*mut flItem> { + // Convert before Create so a NUL-conversion error can't leak the item. + let c = to_cstring(text)?; let mut item: *mut flItem = ptr::null_mut(); api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_TEXT, &mut item) })?; - let c = to_cstring(text)?; let data = flTextData { version: FOUNDRY_LOCAL_API_VERSION, text: c.as_ptr(), @@ -48,14 +49,15 @@ pub(crate) fn make_audio_item( sample_rate: i32, channels: i32, ) -> Result<*mut flItem> { - let mut item: *mut flItem = ptr::null_mut(); - api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_AUDIO, &mut item) })?; - + // Convert before Create so a NUL-conversion error can't leak the item. let format_c = match format { Some(f) => Some(to_cstring(f)?), None => None, }; + let mut item: *mut flItem = ptr::null_mut(); + api.check(unsafe { (api.item_api().Create)(FOUNDRY_LOCAL_ITEM_AUDIO, &mut item) })?; + // Like SetBytes, SetAudio does not copy the sample buffer — it borrows the // pointer (and frees it via the deleter when one is supplied). Transfer an // owned heap allocation so the buffer outlives this call; the format string diff --git a/sdk_v2/rust/src/detail/session.rs b/sdk_v2/rust/src/detail/session.rs index 53be90965..e661a12d9 100644 --- a/sdk_v2/rust/src/detail/session.rs +++ b/sdk_v2/rust/src/detail/session.rs @@ -110,6 +110,12 @@ impl NativeItemQueue { } /// Push an item, transferring ownership into the queue. + /// + /// `ItemQueue_Push` takes ownership of `item` *unconditionally* for a + /// non-null queue and item: the native side moves the raw pointer into a + /// `unique_ptr` before enqueuing, so even if enqueuing fails the item is + /// already (or will be) freed. Callers must therefore **not** release `item` + /// on a returned error — doing so would double-free. pub(crate) fn push_item(&self, item: *mut flItem) -> Result<()> { self.api .check(unsafe { (self.api.item_api().ItemQueue_Push)(self.ptr, item) }) @@ -118,12 +124,9 @@ impl NativeItemQueue { /// Create a BYTES item from `data` and push it into the queue. pub(crate) fn push_bytes(&self, data: &[u8], item_type: flItemType) -> Result<()> { let item = make_bytes_item(&self.api, data, item_type)?; - if let Err(e) = self.push_item(item) { - // Push failed — we still own the item, so release it. - unsafe { (self.api.item_api().Item_Release)(item) }; - return Err(e); - } - Ok(()) + // `push_item` consumes `item` on every path (see its docs); do not + // release it here on error. + self.push_item(item) } /// Signal that no more items will be pushed. diff --git a/sdk_v2/rust/src/openai/live_audio_session.rs b/sdk_v2/rust/src/openai/live_audio_session.rs index 0a1ad999e..52302caa8 100644 --- a/sdk_v2/rust/src/openai/live_audio_session.rs +++ b/sdk_v2/rust/src/openai/live_audio_session.rs @@ -441,7 +441,13 @@ fn run_worker( // The input queue stays owned by us (append pushes into it). request.add_item(queue.as_item_ptr(), false)?; - let response = session.process_request(&request)?; + let response = session.process_request(&request); + + // Always uninstall the streaming callback before `ctx` can be dropped — + // on the error path too — so the native session never retains a dangling + // `user_data` pointer into the freed context. + let _ = session.set_streaming_callback(None, std::ptr::null_mut()); + let response = response?; // Aggregate the terminal transcript from the final response items. let mut final_text = String::new(); @@ -451,8 +457,6 @@ fn run_worker( } } - // Uninstall the callback before the context is dropped. - let _ = session.set_streaming_callback(None, std::ptr::null_mut()); drop(ctx); if !final_text.is_empty() { From b76c4bdf118a841d61fffc3325289f81309c104d Mon Sep 17 00:00:00 2001 From: samuel100 Date: Wed, 24 Jun 2026 10:53:11 +0100 Subject: [PATCH 3/7] WebGpu error fix --- sdk_v2/rust/src/foundry_local_manager.rs | 78 +++--------------------- 1 file changed, 9 insertions(+), 69 deletions(-) diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs index 0ccd3b656..e80eb9132 100644 --- a/sdk_v2/rust/src/foundry_local_manager.rs +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -5,7 +5,7 @@ //! the local web service. use std::sync::atomic::AtomicBool; -use std::sync::{Arc, Mutex, OnceLock}; +use std::sync::{Arc, Mutex}; use crate::catalog::Catalog; use crate::configuration::{FoundryLocalConfig, Logger}; @@ -15,10 +15,6 @@ use crate::detail::task::spawn_blocking; use crate::error::{FoundryLocalError, Result}; use crate::types::{EpDownloadResult, EpInfo}; -/// Global singleton holder — only stores a successfully initialised manager. -static INSTANCE: OnceLock = OnceLock::new(); -/// Guard to ensure only one thread attempts initialisation at a time. -static INIT_GUARD: Mutex<()> = Mutex::new(()); /// Primary entry point for interacting with Foundry Local. /// @@ -92,22 +88,7 @@ impl FoundryLocalManager { /// the initialisation, and builds the model catalog. Subsequent calls /// return a reference to the same instance (the provided config is /// ignored after the first call). - pub fn create(config: FoundryLocalConfig) -> Result<&'static Self> { - // Fast path: singleton already initialised. - if let Some(manager) = INSTANCE.get() { - return Ok(manager); - } - - // Slow path: acquire init guard so only one thread attempts initialisation. - let _guard = INIT_GUARD.lock().map_err(|_| FoundryLocalError::Internal { - reason: "initialisation guard poisoned".into(), - })?; - - // Double-check after acquiring the lock. - if let Some(manager) = INSTANCE.get() { - return Ok(manager); - } - + pub fn create(config: FoundryLocalConfig) -> Result { let mut config = config; let api = Arc::new(Api::load(config.library_path_ref())?); let logger = config.take_logger(); @@ -117,33 +98,21 @@ impl FoundryLocalManager { Arc::clone(&api), native_config.as_ptr(), )?); - // `native_config` is dropped here; Manager_Create has copied what it needs. let catalog_ptr = native.catalog_ptr()?; let catalog = Catalog::new(Arc::clone(&api), catalog_ptr)?; - let manager = FoundryLocalManager { + // Owned manager: teardown runs via RAII when this value is dropped + // (mid-program), not via a process-exit atexit hook. This releases the + // native manager (and unregisters EPs) while the engine/Metal runtime is + // still alive — matching the C++ SDK's local-`Manager` semantics, and + // avoiding the WebGPU `ReleaseEpFactory` teardown throw (ORT #29206). + Ok(FoundryLocalManager { native, catalog, urls: Mutex::new(Vec::new()), _logger: logger, - }; - - // Only cache on success — failures allow the next caller to retry. - match INSTANCE.set(manager) { - Ok(()) => { - // Register a process-exit hook to release the native manager - // before the library's C++ static destructors run. Without - // this, the native dtor chain (Manager -> logger -> spdlog - // flush) can fire after spdlog's global thread pool is gone, - // raising `mutex lock failed` and aborting the process. The - // hook mirrors the Python SDK's `atexit` teardown. - register_exit_teardown(); - Ok(INSTANCE.get().unwrap()) - } - // Another thread beat us — return their instance. - Err(_) => Ok(INSTANCE.get().unwrap()), - } + }) } /// Access the model catalog. @@ -320,32 +289,3 @@ impl FoundryLocalManager { } } -/// Register the process-exit teardown hook exactly once. -/// -/// Uses the C runtime's `atexit`, which runs registered handlers in LIFO order. -/// Because the `foundry_local` library is `dlopen`ed during `create()` (before -/// this registration), its static destructors are registered earlier and -/// therefore run *after* our hook — giving us the window to release the native -/// manager while the engine's globals (e.g. the spdlog thread pool) are still -/// alive. -fn register_exit_teardown() { - extern "C" { - fn atexit(cb: extern "C" fn()) -> std::os::raw::c_int; - } - // SAFETY: `exit_teardown` is a valid `extern "C"` function with no captured - // state; registering it with the C runtime is sound. - unsafe { - atexit(exit_teardown); - } -} - -/// `atexit` callback: release the singleton's native manager before the -/// library's C++ static destructors run. Panic-safe (a panic must never unwind -/// across the C runtime boundary). -extern "C" fn exit_teardown() { - let _ = std::panic::catch_unwind(|| { - if let Some(manager) = INSTANCE.get() { - manager.native.teardown(); - } - }); -} From d92a1e54604728ef2c27d5c0803d5dc4aa9f6682 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Wed, 24 Jun 2026 11:44:47 +0100 Subject: [PATCH 4/7] Make FoundryLocalManager a shared (Arc) resettable singleton Replace the owned-Self FoundryLocalManager::create with an Arc that shares one process-wide instance via a static OnceLock>>. While any handle is alive, callers share the same instance; when the last Arc drops, NativeManager::Drop runs teardown (Manager_Shutdown + Manager_Release) while the ORT runtime is still alive and before the library's C++ static destructors -- restoring singleton semantics without the atexit hook that caused the WebGPU ReleaseEpFactory double-unregister (ORT #29206). Use OnceLock to wrap the Mutex (const Weak::new() needs Rust 1.73, above the crate's 1.70 MSRV). Update stale atexit/singleton doc comments in manager.rs, foundry_local_manager.rs, and docs/api.md. Keep the test helper's &'static signature by holding a process-lifetime OnceLock>, so all existing call sites compile unchanged. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/docs/api.md | 4 +- sdk_v2/rust/src/detail/manager.rs | 19 ++--- sdk_v2/rust/src/foundry_local_manager.rs | 80 +++++++++++++++------ sdk_v2/rust/tests/integration/common/mod.rs | 10 ++- 4 files changed, 79 insertions(+), 34 deletions(-) diff --git a/sdk_v2/rust/docs/api.md b/sdk_v2/rust/docs/api.md index 8dcb0c292..4400c6c67 100644 --- a/sdk_v2/rust/docs/api.md +++ b/sdk_v2/rust/docs/api.md @@ -42,7 +42,7 @@ ### FoundryLocalManager -Primary entry point for interacting with Foundry Local. Singleton — created once via `create()`. +Primary entry point for interacting with Foundry Local. Shared instance — while any handle is alive, `create()` returns the same instance; it is torn down when the last handle is dropped. ```rust pub struct FoundryLocalManager { /* private fields */ } @@ -50,7 +50,7 @@ pub struct FoundryLocalManager { /* private fields */ } | Method | Signature | Description | |--------|-----------|-------------| -| `create` | `fn create(config: FoundryLocalConfig) -> Result<&'static Self, FoundryLocalError>` | Initialise the SDK. First call creates the singleton; subsequent calls return the existing instance (config is ignored after first call). | +| `create` | `fn create(config: FoundryLocalConfig) -> Result, FoundryLocalError>` | Initialise the SDK. Returns a shared handle: while any handle is alive, all calls return the same instance (config ignored after the first). Once the last handle is dropped the native manager is torn down via `Drop`, and a later call builds a fresh instance. | | `catalog` | `fn catalog(&self) -> &Catalog` | Access the model catalog. | | `urls` | `fn urls(&self) -> Result, FoundryLocalError>` | URLs the local web service is listening on. Empty until `start_web_service` is called. | | `start_web_service` | `async fn start_web_service(&self) -> Result<(), FoundryLocalError>` | Start the local web service. Retrieve listening URLs via `urls()`. | diff --git a/sdk_v2/rust/src/detail/manager.rs b/sdk_v2/rust/src/detail/manager.rs index 6e1767f63..d01982b00 100644 --- a/sdk_v2/rust/src/detail/manager.rs +++ b/sdk_v2/rust/src/detail/manager.rs @@ -13,14 +13,14 @@ use crate::types::EpInfo; /// Owns a native `flManager`. /// -/// The manager is held by a process-lifetime singleton, so [`Drop`] effectively -/// never runs; the native handle is instead released by an `atexit` hook (see -/// [`teardown`](Self::teardown)) before the library's C++ static destructors run. +/// The native handle is released exactly once by [`Drop`] (via +/// [`teardown`](Self::teardown)) when the last owner is dropped — while the ORT +/// runtime is still alive and before the library's C++ static destructors run. pub(crate) struct NativeManager { api: Arc, ptr: *mut flManager, - /// Set once the native manager has been released, so shutdown/release and - /// the `atexit` hook and `Drop` all coordinate to release exactly once. + /// Set once the native manager has been released, so `shutdown`, explicit + /// `teardown`, and `Drop` all coordinate to release exactly once. released: AtomicBool, } @@ -153,7 +153,8 @@ impl NativeManager { /// Begin graceful shutdown of the native manager (`Manager_Shutdown`). /// /// Idempotent and safe to call from any thread. Does **not** release the - /// native handle — that happens once at process exit via [`teardown`](Self::teardown). + /// native handle — that happens once the last owner is dropped, via + /// [`teardown`](Self::teardown). pub(crate) fn shutdown(&self) -> Result<()> { if self.released.load(Ordering::Acquire) { return Ok(()); @@ -166,10 +167,12 @@ impl NativeManager { /// Run the prescribed teardown exactly once: `Manager_Shutdown` then /// `Manager_Release`. /// - /// This is invoked from the process-exit hook (and `Drop`) so the manager's + /// Invoked from [`Drop`] when the last owner is released, so the manager's /// C++ destructor runs *before* the library's static destructors — avoiding /// the spdlog teardown abort (`mutex lock failed`) documented for the other - /// SDK bindings. Releasing is always attempted, even if shutdown errored. + /// SDK bindings, and the WebGPU `ReleaseEpFactory` throw that a process-exit + /// release would trigger. Releasing is always attempted, even if shutdown + /// errored. pub(crate) fn teardown(&self) { if self.released.swap(true, Ordering::AcqRel) { return; diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs index e80eb9132..c2cb9e71a 100644 --- a/sdk_v2/rust/src/foundry_local_manager.rs +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -1,11 +1,11 @@ //! Top-level entry point for the Foundry Local SDK. //! -//! [`FoundryLocalManager`] is a singleton that initialises the native core -//! library, provides access to the model [`Catalog`], and can start / stop -//! the local web service. +//! [`FoundryLocalManager`] initialises the native core library, provides access +//! to the model [`Catalog`], and can start / stop the local web service. While a +//! handle is alive it is shared process-wide (see [`FoundryLocalManager::create`]). use std::sync::atomic::AtomicBool; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock, Weak}; use crate::catalog::Catalog; use crate::configuration::{FoundryLocalConfig, Logger}; @@ -15,11 +15,27 @@ use crate::detail::task::spawn_blocking; use crate::error::{FoundryLocalError, Result}; use crate::types::{EpDownloadResult, EpInfo}; +/// Process-wide weak handle to the live manager. +/// +/// Holds a [`Weak`] so the global never keeps the manager alive past its last +/// strong reference: when the final [`Arc`] returned by +/// [`FoundryLocalManager::create`] is dropped, the native manager is torn down +/// deterministically via [`Drop`] — while the ORT runtime is still alive and +/// before the library's C++ static destructors run. +/// +/// Wrapped in a [`OnceLock`] (rather than a `const` `Mutex::new`) to keep the +/// crate compatible with its minimum supported Rust version. +static INSTANCE: OnceLock>> = OnceLock::new(); + +/// The lazily-initialised slot holding the shared-instance weak handle. +fn instance_slot() -> &'static Mutex> { + INSTANCE.get_or_init(|| Mutex::new(Weak::new())) +} /// Primary entry point for interacting with Foundry Local. /// -/// Created once via [`FoundryLocalManager::create`]; subsequent calls return -/// the existing instance. +/// Obtain a handle with [`FoundryLocalManager::create`]. While at least one +/// handle is alive, every caller shares the same instance. pub struct FoundryLocalManager { native: Arc, catalog: Catalog, @@ -82,13 +98,32 @@ impl<'a> EpDownloadBuilder<'a> { } impl FoundryLocalManager { - /// Initialise the SDK. + /// Initialise the SDK and return a shared handle to the manager. + /// + /// While at least one returned [`Arc`] is alive, every call returns the + /// **same** instance (a process-wide singleton) and the `config` passed to + /// later calls is ignored. Once the last handle is dropped the native + /// manager is torn down; a subsequent call then builds a fresh instance + /// from the new `config`. /// - /// The first call creates the singleton, loads the native library, runs - /// the initialisation, and builds the model catalog. Subsequent calls - /// return a reference to the same instance (the provided config is - /// ignored after the first call). - pub fn create(config: FoundryLocalConfig) -> Result { + /// Teardown runs via [`Drop`] when the final handle is released — not via a + /// process-exit hook — so the native manager (and its EP unregistration) + /// shuts down while the engine / ORT runtime is still alive. This matches + /// the C++ SDK's local-`Manager` semantics and avoids the WebGPU + /// `ReleaseEpFactory` teardown throw (ORT #29206). + pub fn create(config: FoundryLocalConfig) -> Result> { + // Hold the lock across initialisation so only one thread builds the + // instance; concurrent callers then observe and share it. A poisoned + // lock is recoverable: the guarded `Weak` is valid regardless of panics. + let mut slot = instance_slot() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + + // Reuse the live instance if one already exists. + if let Some(existing) = slot.upgrade() { + return Ok(existing); + } + let mut config = config; let api = Arc::new(Api::load(config.library_path_ref())?); let logger = config.take_logger(); @@ -102,17 +137,17 @@ impl FoundryLocalManager { let catalog_ptr = native.catalog_ptr()?; let catalog = Catalog::new(Arc::clone(&api), catalog_ptr)?; - // Owned manager: teardown runs via RAII when this value is dropped - // (mid-program), not via a process-exit atexit hook. This releases the - // native manager (and unregisters EPs) while the engine/Metal runtime is - // still alive — matching the C++ SDK's local-`Manager` semantics, and - // avoiding the WebGPU `ReleaseEpFactory` teardown throw (ORT #29206). - Ok(FoundryLocalManager { + let manager = Arc::new(FoundryLocalManager { native, catalog, urls: Mutex::new(Vec::new()), _logger: logger, - }) + }); + + // Record a weak reference so future calls share this instance without + // keeping it alive past the caller's last strong handle. + *slot = Arc::downgrade(&manager); + Ok(manager) } /// Access the model catalog. @@ -127,9 +162,9 @@ impl FoundryLocalManager { /// thread. /// /// Calling this is optional: the native manager is released automatically - /// at process exit. Use it when you want to deterministically wind the - /// engine down before exiting. After calling `shutdown`, the manager should - /// not be used for further inference. + /// when the last handle is dropped. Use it when you want to deterministically + /// wind the engine down before releasing the handle. After calling + /// `shutdown`, the manager should not be used for further inference. pub fn shutdown(&self) -> Result<()> { self.native.shutdown() } @@ -288,4 +323,3 @@ impl FoundryLocalManager { Ok(result) } } - diff --git a/sdk_v2/rust/tests/integration/common/mod.rs b/sdk_v2/rust/tests/integration/common/mod.rs index 81897b7ba..3ec39d1ca 100644 --- a/sdk_v2/rust/tests/integration/common/mod.rs +++ b/sdk_v2/rust/tests/integration/common/mod.rs @@ -5,6 +5,7 @@ #![allow(dead_code)] use std::path::PathBuf; +use std::sync::{Arc, OnceLock}; use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager, LogLevel}; @@ -91,10 +92,17 @@ pub fn test_config() -> FoundryLocalConfig { /// Create (or return the cached) [`FoundryLocalManager`] for tests. /// +/// Holds a process-lifetime strong handle so the shared instance survives for +/// the whole test binary (avoiding repeated native init), and hands out a +/// `'static` borrow into it. +/// /// Panics if creation fails so that test set-up failures are immediately /// visible. pub fn get_test_manager() -> &'static FoundryLocalManager { - FoundryLocalManager::create(test_config()).expect("Failed to create FoundryLocalManager") + static TEST_MANAGER: OnceLock> = OnceLock::new(); + TEST_MANAGER.get_or_init(|| { + FoundryLocalManager::create(test_config()).expect("Failed to create FoundryLocalManager") + }) } // ── Tool definitions ───────────────────────────────────────────────────────── From 8ef0679119b377e9d922fabd3bd09a4f0d26f587 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Wed, 24 Jun 2026 12:33:54 +0100 Subject: [PATCH 5/7] Tie native handle lifetimes to the manager; serialize native create/release Fix a use-after-free reachable from safe code: Model, the OpenAI clients, and sessions held only a raw flModel*/flSession* plus Arc (which keeps just the shared library loaded), so dropping the last Arc released the native catalog and model handles while those derived objects were still alive. A natural factory pattern (create manager, get a client, return it) would dangle. This was latent under the old leaked &'static singleton and became reachable once teardown moved onto last-Arc Drop. Thread a strong Arc keep-alive through NativeModel, NativeCatalog, and NativeSession so every derived handle keeps the native manager (and thus the catalog/model handles it owns) alive until the handle itself is dropped. The keep-alive targets NativeManager (a leaf that owns no Rust wrappers), so there is no reference cycle. Also add NATIVE_LIFECYCLE, a Mutex<()> held across both Manager_Create and Manager_Release, to close the create/teardown race: an Arc's strong count reaches zero slightly before Drop finishes Manager_Release, so a concurrent create() could otherwise observe no instance yet be rejected by the native single-instance guard. Validated: cargo fmt/check/clippy/doc clean (lib, tests, examples). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/src/catalog.rs | 9 ++- sdk_v2/rust/src/detail/manager.rs | 25 +++++++- sdk_v2/rust/src/detail/native.rs | 72 +++++++++++++++++------- sdk_v2/rust/src/detail/session.rs | 13 ++++- sdk_v2/rust/src/foundry_local_manager.rs | 10 ++-- 5 files changed, 100 insertions(+), 29 deletions(-) diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs index 95f35ce37..60694f894 100644 --- a/sdk_v2/rust/src/catalog.rs +++ b/sdk_v2/rust/src/catalog.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use crate::detail::api::Api; +use crate::detail::manager::NativeManager; use crate::detail::model::Model; use crate::detail::native::NativeCatalog; use crate::detail::task::spawn_blocking; @@ -19,8 +20,12 @@ pub struct Catalog { } impl Catalog { - pub(crate) fn new(api: Arc, ptr: *mut crate::detail::ffi::flCatalog) -> Result { - let native = NativeCatalog::new(api, ptr); + pub(crate) fn new( + api: Arc, + ptr: *mut crate::detail::ffi::flCatalog, + manager: Arc, + ) -> Result { + let native = NativeCatalog::new(api, ptr, manager); let name = native.name().unwrap_or_else(|_| "default".into()); Ok(Self { native, name }) } diff --git a/sdk_v2/rust/src/detail/manager.rs b/sdk_v2/rust/src/detail/manager.rs index d01982b00..908625095 100644 --- a/sdk_v2/rust/src/detail/manager.rs +++ b/sdk_v2/rust/src/detail/manager.rs @@ -4,13 +4,31 @@ use std::os::raw::c_char; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use super::api::{cstr_to_string, Api}; use super::ffi::*; use crate::error::{FoundryLocalError, Result}; use crate::types::EpInfo; +/// Serializes native manager creation against release. +/// +/// The native core enforces a single live `flManager` (`Manager_Create` fails +/// with `INVALID_USAGE` while one exists). Because an [`Arc`]'s +/// strong count reaches zero slightly before [`Drop`] finishes `Manager_Release`, +/// a concurrent `create` could otherwise observe "no instance" yet race the +/// in-flight release and be rejected. Holding this lock across both +/// `Manager_Create` and `Manager_Release` closes that window. +static NATIVE_LIFECYCLE: Mutex<()> = Mutex::new(()); + +/// Lock [`NATIVE_LIFECYCLE`], recovering from poisoning (the guarded `()` has no +/// invariants a panic could break). +fn lock_lifecycle() -> std::sync::MutexGuard<'static, ()> { + NATIVE_LIFECYCLE + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + /// Owns a native `flManager`. /// /// The native handle is released exactly once by [`Drop`] (via @@ -32,6 +50,8 @@ unsafe impl Sync for NativeManager {} impl NativeManager { /// Create a manager from a fully-built native configuration. pub(crate) fn create(api: Arc, config: *const flConfiguration) -> Result { + // Serialize against any in-flight native release (see `NATIVE_LIFECYCLE`). + let _lifecycle = lock_lifecycle(); let mut ptr: *mut flManager = std::ptr::null_mut(); let status = unsafe { (api.root().Manager_Create)(config, &mut ptr) }; api.check(status)?; @@ -177,6 +197,9 @@ impl NativeManager { if self.released.swap(true, Ordering::AcqRel) { return; } + // Serialize against a concurrent `create` so the native core never sees + // an overlapping create/release (see `NATIVE_LIFECYCLE`). + let _lifecycle = lock_lifecycle(); // SAFETY: `ptr` was created by Manager_Create and is released exactly // once (guarded by the `released` swap above). unsafe { diff --git a/sdk_v2/rust/src/detail/native.rs b/sdk_v2/rust/src/detail/native.rs index b2a362fa0..f13eda343 100644 --- a/sdk_v2/rust/src/detail/native.rs +++ b/sdk_v2/rust/src/detail/native.rs @@ -1,7 +1,10 @@ -//! Non-owning wrapper around a native `flModel` handle. +//! Wrappers around native `flModel` / `flCatalog` handles. //! -//! `flModel*` handles are owned by the catalog (which lives for the process -//! lifetime via the manager singleton) and are never released individually. A +//! `flModel*` and `flCatalog*` handles are owned by the native manager and are +//! never released individually; they stay valid only while that manager is +//! alive. Each wrapper therefore holds a strong [`Arc`] so a +//! handle can safely outlive the +//! [`FoundryLocalManager`](crate::FoundryLocalManager) that created it. A //! `flModelList*`, by contrast, is owned by the caller: [`collect_models`] //! eagerly extracts the contained handles and then releases the list. @@ -11,23 +14,34 @@ use std::sync::Arc; use super::api::{cstr_to_string, Api}; use super::ffi::*; +use super::manager::NativeManager; use crate::error::{FoundryLocalError, Result}; -/// A borrowed handle to a catalog-owned `flModel`. +/// A handle to a manager-owned `flModel`, plus a strong reference to the owning +/// native manager so the catalog (and therefore this handle) stays alive for as +/// long as the model is held. #[derive(Clone)] pub(crate) struct NativeModel { pub(crate) api: Arc, pub(crate) ptr: *mut flModel, + /// Keeps the native manager — which owns the catalog and this `flModel` — + /// alive. Propagated to derived models/sessions; never dereferenced here. + manager: Arc, } -// SAFETY: model handles are owned by the catalog (process-lifetime) and the -// native implementation is thread-safe for independent operations. +// SAFETY: the `flModel` is owned by the native manager kept alive via `manager`, +// and the native implementation is thread-safe for independent operations. unsafe impl Send for NativeModel {} unsafe impl Sync for NativeModel {} impl NativeModel { - pub(crate) fn new(api: Arc, ptr: *mut flModel) -> Self { - Self { api, ptr } + pub(crate) fn new(api: Arc, ptr: *mut flModel, manager: Arc) -> Self { + Self { api, ptr, manager } + } + + /// Clone the keep-alive handle to the owning native manager. + pub(crate) fn manager(&self) -> Arc { + Arc::clone(&self.manager) } pub(crate) fn info_ptr(&self) -> Result<*const flModelInfo> { @@ -79,7 +93,7 @@ impl NativeModel { let mut list: *mut flModelList = std::ptr::null_mut(); let status = unsafe { (self.api.model_api().GetVariants)(self.ptr, &mut list) }; self.api.check(status)?; - Ok(collect_models(&self.api, list)) + Ok(collect_models(&self.api, &self.manager, list)) } /// Download the model, optionally reporting progress (0.0–100.0) and @@ -142,8 +156,13 @@ unsafe extern "C" fn download_trampoline( /// Eagerly extract all model handles from a `flModelList`, then release the list. /// -/// The returned handles remain valid for the catalog's lifetime. -pub(crate) fn collect_models(api: &Arc, list: *mut flModelList) -> Vec { +/// Each extracted handle is tagged with `manager` so it keeps the owning native +/// manager alive for as long as the handle is held. +pub(crate) fn collect_models( + api: &Arc, + manager: &Arc, + list: *mut flModelList, +) -> Vec { if list.is_null() { return Vec::new(); } @@ -154,28 +173,33 @@ pub(crate) fn collect_models(api: &Arc, list: *mut flModelList) -> Vec, pub(crate) ptr: *mut flCatalog, + /// Keeps the native manager that owns this `flCatalog` alive; also tagged + /// onto every model produced from this catalog. + manager: Arc, } -// SAFETY: the catalog handle is owned by the manager (process-lifetime) and the -// native implementation is thread-safe. +// SAFETY: the `flCatalog` is owned by the native manager kept alive via +// `manager`, and the native implementation is thread-safe. unsafe impl Send for NativeCatalog {} unsafe impl Sync for NativeCatalog {} impl NativeCatalog { - pub(crate) fn new(api: Arc, ptr: *mut flCatalog) -> Self { - Self { api, ptr } + pub(crate) fn new(api: Arc, ptr: *mut flCatalog, manager: Arc) -> Self { + Self { api, ptr, manager } } pub(crate) fn name(&self) -> Result { @@ -192,7 +216,7 @@ impl NativeCatalog { let mut list: *mut flModelList = std::ptr::null_mut(); let status = unsafe { f(self.ptr, &mut list) }; self.api.check(status)?; - Ok(collect_models(&self.api, list)) + Ok(collect_models(&self.api, &self.manager, list)) } pub(crate) fn get_models(&self) -> Result> { @@ -223,7 +247,11 @@ impl NativeCatalog { if model.is_null() { Ok(None) } else { - Ok(Some(NativeModel::new(Arc::clone(&self.api), model))) + Ok(Some(NativeModel::new( + Arc::clone(&self.api), + model, + Arc::clone(&self.manager), + ))) } } @@ -245,6 +273,10 @@ impl NativeCatalog { reason: "Catalog returned no latest version for the model.".into(), }); } - Ok(NativeModel::new(Arc::clone(&self.api), latest)) + Ok(NativeModel::new( + Arc::clone(&self.api), + latest, + Arc::clone(&self.manager), + )) } } diff --git a/sdk_v2/rust/src/detail/session.rs b/sdk_v2/rust/src/detail/session.rs index e661a12d9..5421fb966 100644 --- a/sdk_v2/rust/src/detail/session.rs +++ b/sdk_v2/rust/src/detail/session.rs @@ -11,6 +11,7 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use super::api::Api; use super::ffi::*; use super::items::{make_bytes_item, make_openai_json_item, read_text_item}; +use super::manager::NativeManager; use super::native::NativeModel; use crate::error::{FoundryLocalError, Result}; @@ -150,10 +151,14 @@ impl Drop for NativeItemQueue { pub(crate) struct NativeSession { pub(crate) api: Arc, ptr: *mut flSession, + /// Keeps the native manager (which owns the model this session was created + /// from) alive for the session's lifetime; never dereferenced here. + _manager: Arc, } // SAFETY: a session is used from a single worker at a time; the native layer is -// thread-safe for the create/process/release lifecycle used here. +// thread-safe for the create/process/release lifecycle used here. The owning +// native manager is kept alive via `_manager`. unsafe impl Send for NativeSession {} unsafe impl Sync for NativeSession {} @@ -163,7 +168,11 @@ impl NativeSession { let api = Arc::clone(&model.api); let mut ptr: *mut flSession = ptr::null_mut(); api.check(unsafe { (api.inference_api().Session_Create)(model.ptr, &mut ptr) })?; - Ok(Self { api, ptr }) + Ok(Self { + api, + ptr, + _manager: model.manager(), + }) } pub(crate) fn set_streaming_callback( diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs index c2cb9e71a..095b834a9 100644 --- a/sdk_v2/rust/src/foundry_local_manager.rs +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -102,9 +102,11 @@ impl FoundryLocalManager { /// /// While at least one returned [`Arc`] is alive, every call returns the /// **same** instance (a process-wide singleton) and the `config` passed to - /// later calls is ignored. Once the last handle is dropped the native - /// manager is torn down; a subsequent call then builds a fresh instance - /// from the new `config`. + /// later calls is ignored. The native manager is torn down once every handle + /// derived from it is gone — this `Arc`, plus any [`Model`](crate::Model), + /// client, or session it produced, each of which keeps the native manager + /// alive so handles can safely outlive this `Arc`. A subsequent call then + /// builds a fresh instance from the new `config`. /// /// Teardown runs via [`Drop`] when the final handle is released — not via a /// process-exit hook — so the native manager (and its EP unregistration) @@ -135,7 +137,7 @@ impl FoundryLocalManager { )?); let catalog_ptr = native.catalog_ptr()?; - let catalog = Catalog::new(Arc::clone(&api), catalog_ptr)?; + let catalog = Catalog::new(Arc::clone(&api), catalog_ptr, Arc::clone(&native))?; let manager = Arc::new(FoundryLocalManager { native, From ee7082cf0e685b1e4f37d071a0969f1588bc78f5 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Wed, 24 Jun 2026 13:37:06 +0100 Subject: [PATCH 6/7] Realign Rust FFI item vtable with new speech result types in C ABI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit origin/main #746 (Add speech result types) inserted GetSpeechSegment and GetSpeechResult into flItemApi between GetToolResult and GetMetadata. Because ffi.rs mirrors the vtable positionally via #[repr(C)], the old layout left GetMetadata/GetMutableMetadata/GetQueue and all ItemQueue_* slots shifted by two pointers — so calls like ItemQueue_Push/TryPop (used by streaming) would misdispatch to the wrong native function against a core built from the new header. Add the two function-pointer slots in the matching position, plus the supporting flSpeechWord/flSpeechSegmentData/flSpeechResultData structs, flSpeechSegmentKind, the FOUNDRY_LOCAL_ITEM_SPEECH_SEGMENT/RESULT item-type constants, and the DURATION/CONFIDENCE_UNSET sentinels. API version is unchanged (still 1). Validated: cargo fmt/check/clippy clean (lib, tests, examples). Vtable order now matches foundry_local_c.h flItemApi field-for-field (31 entries). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/src/detail/ffi.rs | 62 +++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/sdk_v2/rust/src/detail/ffi.rs b/sdk_v2/rust/src/detail/ffi.rs index e87c350f3..4d64cf709 100644 --- a/sdk_v2/rust/src/detail/ffi.rs +++ b/sdk_v2/rust/src/detail/ffi.rs @@ -81,6 +81,8 @@ pub const FOUNDRY_LOCAL_ITEM_TEXT: flItemType = 20; pub const FOUNDRY_LOCAL_ITEM_MESSAGE: flItemType = 21; pub const FOUNDRY_LOCAL_ITEM_IMAGE: flItemType = 25; pub const FOUNDRY_LOCAL_ITEM_AUDIO: flItemType = 30; +pub const FOUNDRY_LOCAL_ITEM_SPEECH_SEGMENT: flItemType = 31; +pub const FOUNDRY_LOCAL_ITEM_SPEECH_RESULT: flItemType = 32; pub const FOUNDRY_LOCAL_ITEM_TOOL_CALL: flItemType = 100; pub const FOUNDRY_LOCAL_ITEM_TOOL_RESULT: flItemType = 101; pub const FOUNDRY_LOCAL_ITEM_QUEUE: flItemType = 200; @@ -246,6 +248,51 @@ pub struct flToolResultData { pub result: *const c_char, } +// ── Speech recognition output types (output-only; see foundry_local_c.h) ────── + +/// Sentinel for absent time fields in the speech structs (`INT64_MIN`). +pub const FOUNDRY_LOCAL_DURATION_UNSET: i64 = i64::MIN; +/// Sentinel for absent confidence in `flSpeechWord` (`-FLT_MAX`). +pub const FOUNDRY_LOCAL_CONFIDENCE_UNSET: f32 = f32::MIN; + +pub type flSpeechSegmentKind = c_int; +pub const FOUNDRY_LOCAL_SPEECH_SEGMENT_NONE: flSpeechSegmentKind = 0; +pub const FOUNDRY_LOCAL_SPEECH_SEGMENT_PARTIAL: flSpeechSegmentKind = 1; +pub const FOUNDRY_LOCAL_SPEECH_SEGMENT_FINAL: flSpeechSegmentKind = 2; + +#[repr(C)] +pub struct flSpeechWord { + pub version: u32, + pub text: *const c_char, + pub start_time_ms: i64, + pub end_time_ms: i64, + pub confidence: f32, + pub speaker_id: *const c_char, +} + +#[repr(C)] +pub struct flSpeechSegmentData { + pub version: u32, + pub kind: flSpeechSegmentKind, + pub text: *const c_char, + pub start_time_ms: i64, + pub end_time_ms: i64, + pub utterance_start: bool, + pub words: *const flSpeechWord, + pub words_count: usize, + pub language: *const c_char, +} + +#[repr(C)] +pub struct flSpeechResultData { + pub version: u32, + pub text: *const c_char, + pub language: *const c_char, + pub duration_ms: i64, + pub segments: *const *const flItem, + pub segments_count: usize, +} + #[repr(C)] pub struct flStreamingCallbackData { pub version: u32, @@ -285,8 +332,10 @@ pub const FOUNDRY_LOCAL_GET_VERSION_STRING_SYMBOL: &[u8] = b"FoundryLocalGetVers // ── Function tables ────────────────────────────────────────────────────────── // -// Field order and signatures MUST match foundry_local_c.h exactly. New entries -// are only ever appended at the end of each table. +// Field order and signatures MUST match foundry_local_c.h exactly — the tables +// are consumed positionally, so a new entry inserted mid-table upstream must be +// mirrored at the same position here (e.g. GetSpeechSegment/GetSpeechResult sit +// between GetToolResult and GetMetadata in flItemApi). /// Root API table (`flApi`). #[repr(C)] @@ -427,6 +476,15 @@ pub struct flItemApiVtable { out_tool_result: *mut flToolResultData, ) -> flStatusPtr, + pub GetSpeechSegment: unsafe extern "system" fn( + item: *const flItem, + out_segment: *mut flSpeechSegmentData, + ) -> flStatusPtr, + pub GetSpeechResult: unsafe extern "system" fn( + item: *const flItem, + out_result: *mut flSpeechResultData, + ) -> flStatusPtr, + pub GetMetadata: unsafe extern "system" fn( item: *const flItem, out_metadata: *mut *const flKeyValuePairs, From c6317177eb7b1746c0036f594bd6570baeb816a9 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Wed, 24 Jun 2026 13:43:45 +0100 Subject: [PATCH 7/7] Read speech result items in live audio transcription MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit origin/main #746 changed the native audio path to emit SPEECH_SEGMENT items (streaming) and a SPEECH_RESULT item (final), replacing the TextItem outputs. LiveAudioTranscriptionSession uses that native path but read results only via read_text_item / item_text (ITEM_TEXT only), so against the new core it silently produced empty streaming results and an empty final transcript. Add read_speech_segment / read_speech_result_text helpers (via the new GetSpeechSegment / GetSpeechResult accessors) and wire them into the streaming trampoline and final-transcript aggregation, with a TEXT fallback so the OpenAI-JSON path and older cores keep working. Segment timing (ms→s) and PARTIAL/FINAL state are mapped onto the existing response envelope. AudioClient::transcribe / transcribe_streaming are unaffected — they use the OpenAI-JSON path, which still returns OPENAI_JSON-tagged TEXT items. Validated: cargo fmt/check/clippy clean (lib, tests, examples). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/src/detail/items.rs | 80 ++++++++++++++++++++ sdk_v2/rust/src/detail/session.rs | 15 +++- sdk_v2/rust/src/openai/live_audio_session.rs | 46 ++++++++--- 3 files changed, 129 insertions(+), 12 deletions(-) diff --git a/sdk_v2/rust/src/detail/items.rs b/sdk_v2/rust/src/detail/items.rs index 5f0de9f4f..bbb3acb7b 100644 --- a/sdk_v2/rust/src/detail/items.rs +++ b/sdk_v2/rust/src/detail/items.rs @@ -226,3 +226,83 @@ pub(crate) unsafe fn read_text_item(api: &Api, item: *const flItem) -> Option, + pub end_time_s: Option, +} + +/// Convert a native millisecond field to seconds, mapping the UNSET sentinel to `None`. +fn duration_ms_to_seconds(ms: i64) -> Option { + if ms == FOUNDRY_LOCAL_DURATION_UNSET { + None + } else { + Some(ms as f64 / 1000.0) + } +} + +/// Read a SPEECH_SEGMENT item (output-only). Returns `None` for null/other items. +/// +/// # Safety +/// `item` must be null or a valid item pointer alive for the duration of this call. +pub(crate) unsafe fn read_speech_segment( + api: &Api, + item: *const flItem, +) -> Option { + if item.is_null() || (api.item_api().GetType)(item) != FOUNDRY_LOCAL_ITEM_SPEECH_SEGMENT { + return None; + } + let mut data = flSpeechSegmentData { + version: FOUNDRY_LOCAL_API_VERSION, + kind: FOUNDRY_LOCAL_SPEECH_SEGMENT_NONE, + text: ptr::null::(), + start_time_ms: FOUNDRY_LOCAL_DURATION_UNSET, + end_time_ms: FOUNDRY_LOCAL_DURATION_UNSET, + utterance_start: false, + words: ptr::null::(), + words_count: 0, + language: ptr::null::(), + }; + if api + .check((api.item_api().GetSpeechSegment)(item, &mut data)) + .is_err() + { + return None; + } + Some(SpeechSegmentText { + text: cstr_to_string(data.text).unwrap_or_default(), + // PARTIAL is an interim hypothesis; FINAL (and NONE entries) are stable. + is_final: data.kind == FOUNDRY_LOCAL_SPEECH_SEGMENT_FINAL, + start_time_s: duration_ms_to_seconds(data.start_time_ms), + end_time_s: duration_ms_to_seconds(data.end_time_ms), + }) +} + +/// Read the concatenated transcript of a SPEECH_RESULT item (output-only). +/// Returns `None` for null/other items. +/// +/// # Safety +/// `item` must be null or a valid item pointer alive for the duration of this call. +pub(crate) unsafe fn read_speech_result_text(api: &Api, item: *const flItem) -> Option { + if item.is_null() || (api.item_api().GetType)(item) != FOUNDRY_LOCAL_ITEM_SPEECH_RESULT { + return None; + } + let mut data = flSpeechResultData { + version: FOUNDRY_LOCAL_API_VERSION, + text: ptr::null::(), + language: ptr::null::(), + duration_ms: FOUNDRY_LOCAL_DURATION_UNSET, + segments: ptr::null::<*const flItem>(), + segments_count: 0, + }; + if api + .check((api.item_api().GetSpeechResult)(item, &mut data)) + .is_err() + { + return None; + } + cstr_to_string(data.text) +} diff --git a/sdk_v2/rust/src/detail/session.rs b/sdk_v2/rust/src/detail/session.rs index 5421fb966..f263b4020 100644 --- a/sdk_v2/rust/src/detail/session.rs +++ b/sdk_v2/rust/src/detail/session.rs @@ -10,7 +10,9 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use super::api::Api; use super::ffi::*; -use super::items::{make_bytes_item, make_openai_json_item, read_text_item}; +use super::items::{ + make_bytes_item, make_openai_json_item, read_speech_result_text, read_text_item, +}; use super::manager::NativeManager; use super::native::NativeModel; use crate::error::{FoundryLocalError, Result}; @@ -71,6 +73,17 @@ impl NativeResponse { } unsafe { read_text_item(&self.api, item) } } + + /// Read the transcript of the response item at `idx` (if it is a SPEECH_RESULT item). + pub(crate) fn item_speech_result_text(&self, idx: usize) -> Option { + let mut item: *const flItem = ptr::null(); + let status = + unsafe { (self.api.inference_api().Response_GetItem)(self.ptr, idx, &mut item) }; + if self.api.check(status).is_err() { + return None; + } + unsafe { read_speech_result_text(&self.api, item) } + } } impl Drop for NativeResponse { diff --git a/sdk_v2/rust/src/openai/live_audio_session.rs b/sdk_v2/rust/src/openai/live_audio_session.rs index 52302caa8..4aa3a24e6 100644 --- a/sdk_v2/rust/src/openai/live_audio_session.rs +++ b/sdk_v2/rust/src/openai/live_audio_session.rs @@ -40,7 +40,9 @@ use tokio_util::sync::CancellationToken; use crate::detail::api::Api; use crate::detail::ffi::{flItem, flStreamingCallbackData, FOUNDRY_LOCAL_ITEM_BYTES}; -use crate::detail::items::{make_audio_item, read_text_item}; +use crate::detail::items::{ + make_audio_item, read_speech_segment, read_text_item, SpeechSegmentText, +}; use crate::detail::native::NativeModel; use crate::detail::session::{NativeItemQueue, NativeRequest, NativeSession}; use crate::detail::task::spawn_blocking; @@ -162,6 +164,20 @@ impl LiveAudioTranscriptionResponse { id: None, } } + + /// Build a response from a SPEECH_SEGMENT item's content. + fn from_segment(seg: SpeechSegmentText) -> Self { + Self { + content: vec![ContentPart { + transcript: seg.text.clone(), + text: seg.text, + }], + is_final: seg.is_final, + start_time: seg.start_time_s, + end_time: seg.end_time_s, + id: None, + } + } } /// Structured error response from the native core. @@ -391,15 +407,18 @@ unsafe extern "C" fn live_trampoline( if item.is_null() { continue; } - let text = read_text_item(&ctx.api, item); + // The native audio path now streams SPEECH_SEGMENT items; older cores + // (and the OpenAI-JSON path) stream plain TEXT items. Handle both. + let response = if let Some(text) = read_text_item(&ctx.api, item) { + (!text.is_empty()).then(|| LiveAudioTranscriptionResponse::from_text(text, false)) + } else { + read_speech_segment(&ctx.api, item) + .filter(|seg| !seg.text.is_empty()) + .map(LiveAudioTranscriptionResponse::from_segment) + }; (item_api.Item_Release)(item); - if let Some(text) = text { - if !text.is_empty() - && ctx - .tx - .send(Ok(LiveAudioTranscriptionResponse::from_text(text, false))) - .is_err() - { + if let Some(response) = response { + if ctx.tx.send(Ok(response)).is_err() { return 1; // receiver dropped — cancel } } @@ -449,10 +468,15 @@ fn run_worker( let _ = session.set_streaming_callback(None, std::ptr::null_mut()); let response = response?; - // Aggregate the terminal transcript from the final response items. + // Aggregate the terminal transcript from the final response items. The + // native audio path returns a SPEECH_RESULT item; the OpenAI-JSON path + // (and older cores) return TEXT items. let mut final_text = String::new(); for i in 0..response.item_count() { - if let Some(text) = response.item_text(i) { + if let Some(text) = response + .item_text(i) + .or_else(|| response.item_speech_result_text(i)) + { final_text.push_str(&text); } }