From 3fabdc8fae30c88399b0cb310ce5e80d6ead4891 Mon Sep 17 00:00:00 2001 From: Anoop Narang Date: Tue, 2 Jun 2026 13:34:55 +0530 Subject: [PATCH] fix(rule): keep index for k-NN that returns metadata, fall back when vector is projected MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A k-NN query that projects the indexed vector column (SELECT *, or SELECT id, embedding) crashed with a post-optimizer schema-mismatch when an index was present: the passthrough branch built its output from the index node's columns (addressing key + non-vector columns), which can't include the vector, so the rewritten plan's schema differed from the original and DataFusion's invariant check aborted the query. The fix is output-aware. The rule now also anchors on a Projection sitting directly over a passthrough k-NN Sort and drives the rewrite from that outer projection's columns — the query's real output: - vector NOT in output (e.g. SELECT id ... ORDER BY l2_distance(emb, ...), the common "nearest ids" query) -> every output column is producible from the node, so the index is still used. - vector IN output (SELECT *, SELECT id, embedding) -> the rewrite can't produce it, so the rule declines and the query falls back to exact brute-force search (correct, like the existing DESC / metric-mismatch fallbacks) instead of crashing. This keeps the metadata-only k-NN path on the index (no regression) while fixing the crash. A code comment records the rejected alternative (have USearchExec reconstruct the vector via index.get) and why: it would make the index a second source of returned vectors that must byte-match the source, which breaks under F16 quantization. Regression tests model production (lookup schema excludes the vector column, which the existing tests' provider included, masking the bug). README documents the fallback. 
Fixes #508 --- README.md | 32 ++++-- src/rule.rs | 114 ++++++++++++++---- tests/vector_col_projection.rs | 204 +++++++++++++++++++++++++++++++++ 3 files changed, 319 insertions(+), 31 deletions(-) create mode 100644 tests/vector_col_projection.rs diff --git a/README.md b/README.md index f3b2da1..e1d9ff2 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A DataFusion extension that integrates [USearch](https://github.com/unum-cloud/u Queries matching the `ORDER BY distance_fn(col, query) LIMIT k` pattern are **transparently rewritten** by an optimizer rule into a native USearch index call — no query rewrite needed from the caller. `WHERE` clause filters are handled adaptively: high-selectivity filters use USearch's in-graph predicate API; low-selectivity filters bypass HNSW entirely and scan the data directly. -**DataFusion:** 52.2   **USearch:** 2.24 +**DataFusion:** 53   **USearch:** 2.24 --- @@ -230,20 +230,33 @@ tests/ ### Optimizer rewrite -The rule (`rule.rs`) matches two logical plan shapes: +The rule (`rule.rs`) matches the `Sort(fetch=k)` over a `TableScan`, with an +optional `Projection` between them and an optional `Filter` directly above the +scan: ``` Sort(fetch=k, ORDER BY dist ASC) - Projection([..., distance_fn(col, lit) AS dist, ...]) - TableScan(name) - -Sort(fetch=k, ORDER BY dist ASC) - Projection([..., distance_fn(col, lit) AS dist, ...]) - Filter(predicate) + [ Projection([..., distance_fn(col, lit) AS dist, ...]) ] ← optional + [ Filter(predicate) ] ← optional TableScan(name) ``` -Preconditions: sort is `ASC`, distance UDF matches index metric, table is registered, query vector is a literal. When the rule fires, it replaces the inner nodes with a `USearchNode` leaf carrying: table name, vector column, query vector, k, distance type, and absorbed filter predicates. The `Sort` node is preserved above for final ordering. 
+DataFusion omits the `Projection` for `SELECT *` (and for any SELECT whose +columns come straight from the scan), so the `Sort` can sit directly on the +`TableScan`. + +Preconditions: sort is `ASC`, distance UDF matches index metric, table is +registered, query vector is a literal. When the rule fires, it replaces the inner +nodes with a `USearchNode` leaf carrying: table name, vector column, query +vector, k, distance type, and absorbed filter predicates. The `Sort` node is +preserved above for final ordering. + +**Schema preservation:** an optimizer rule must not change the plan's output +schema. The `USearchNode` produces only what the `lookup_provider` can fetch by +key (addressing key + non-vector columns) plus `_distance` — it cannot produce +the indexed vector column. If the query's output would include the +vector column (e.g. `SELECT *`), the rule declines and the query falls back to +exact execution rather than emitting a schema-incompatible plan. Physical planning (`planner.rs`) translates `USearchNode` into `USearchExec`, a physical plan node that executes the actual search. @@ -305,6 +318,7 @@ Tests cover optimizer rule matching/rejection, end-to-end execution through both | Limitation | Notes | |---|---| +| Projecting the indexed vector column | A k-NN query whose output includes the vector column itself (e.g. `SELECT *`, or `SELECT id, vector`) falls back to exact execution. The `lookup_provider` does not store the vector column (see [registration](#register-providers-and-set-up-the-sessioncontext)), so the rewrite cannot reproduce it. Project the metadata columns and the distance instead. | | Stacked `Filter` nodes | Only one `Filter -> TableScan` layer is absorbed. `Filter -> Filter -> TableScan` falls back to exact execution. DataFusion typically combines multiple WHERE conditions into a single Filter, so this rarely occurs. | | Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). 
Column references or subquery results are not rewritten. Use `vector_search_vector` for explicit ANN queries. | | `ef_search` per-query | `expansion_search` is global to the index instance. Per-query adjustment is not supported. | diff --git a/src/rule.rs b/src/rule.rs index 7896516..1c502fd 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -48,21 +48,66 @@ impl USearchRule { } fn try_match(&self, plan: &LogicalPlan) -> Option { + match plan { + // Anchor on the Sort itself. The projection (if any) sits *below* the + // Sort and supplies its output columns; SELECT * omits it entirely. + LogicalPlan::Sort(sort) => { + let (proj_exprs_slice, after_sort): (&[Expr], &LogicalPlan) = + match sort.input.as_ref() { + LogicalPlan::Projection(p) => (p.expr.as_slice(), p.input.as_ref()), + other => (&[], other), + }; + self.build_rewrite(sort, proj_exprs_slice, after_sort) + } + + // Output-aware passthrough. When a Projection sits directly over a + // k-NN Sort that rests on the scan (no projection between them), drive + // the rewrite with the OUTER projection's columns — i.e. the query's + // real output. The rewrite can only produce the index node's columns + // (addressing key + non-vector columns + _distance), never the indexed + // vector itself. Routing the output columns through `build_rewrite` + // lets it fire when they're all producible (e.g. `SELECT id … ORDER BY + // l2_distance(emb, …)`) and decline — falling back to exact search — + // when the output needs the vector (`SELECT *`, `SELECT id, emb`), + // rather than emitting a schema the consumer can't satisfy (issue #508). + // + // ALTERNATIVE (not taken): teach USearchExec to reconstruct the vector + // column for the result keys via `index.get(key)`, so even + // vector-returning queries stay on the index. 
Rejected to keep a single + // source of truth for returned vectors — the index would otherwise be a + // second source that must byte-match the parquet (breaks under F16 + // quantization, and relies on USearch never transforming stored vectors). + // See the README "Limitations" entry and runtimedb issue #508. + LogicalPlan::Projection(outer) => { + let LogicalPlan::Sort(sort) = outer.input.as_ref() else { + return None; + }; + // Only the passthrough shape; the remap shape (projection *below* + // the Sort) is handled when we visit the Sort above. + if !matches!( + sort.input.as_ref(), + LogicalPlan::TableScan(_) | LogicalPlan::Filter(_) + ) { + return None; + } + self.build_rewrite(sort, &outer.expr, sort.input.as_ref()) + } + + _ => None, + } + } + + fn build_rewrite( + &self, + sort: &datafusion::logical_expr::logical_plan::Sort, + proj_exprs_slice: &[Expr], + after_sort: &LogicalPlan, + ) -> Option { use datafusion::logical_expr::logical_plan::TableScan; // Require Sort with embedded fetch limit. - let sort = match plan { - LogicalPlan::Sort(s) => s, - _ => return None, - }; let k = sort.fetch?; - // Projection is optional — DataFusion 51 omits it for SELECT * queries. - let (proj_exprs_slice, after_sort): (&[Expr], &LogicalPlan) = match sort.input.as_ref() { - LogicalPlan::Projection(p) => (p.expr.as_slice(), p.input.as_ref()), - other => (&[], other), - }; - // Accept TableScan directly, or Filter(TableScan) for WHERE clauses. // Deeper nesting (Filter→Filter→…) is not absorbed — the rule does // not fire and DataFusion falls back to exact execution. @@ -155,7 +200,14 @@ impl USearchRule { // Build the final user-visible projection over USearchNode output. let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance"); let final_proj_exprs = if proj_exprs_slice.is_empty() { - passthrough_projection(&vsn_df_schema, &table_ref) + // No explicit Projection node (e.g. 
SELECT *, or a SELECT whose + // columns come straight from the scan, so the Sort sits directly on + // the TableScan). The rewrite must reproduce the original output + // columns; if any isn't producible from the node — the indexed + // vector column is never stored in the fetch path — bail so the + // query falls back to exact brute-force search, like the other + // unsupported shapes (DESC, metric mismatch, stacked filters). + passthrough_projection(after_sort.schema().as_ref(), &vsn_df_schema, &table_ref)? } else { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; @@ -375,21 +427,39 @@ fn build_outer_projection(exprs: &[Expr]) -> Vec { .collect() } -/// Build a passthrough Projection for SELECT * queries (no original Projection node). -/// Projects only the original table columns (not `_distance`) so the output schema -/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression -/// on the k result rows returned by USearchExec (O(k × dim), negligible for small k). -fn passthrough_projection(schema: &DFSchema, table_ref: &TableReference) -> Vec { - schema +/// Build a passthrough Projection for queries with no explicit Projection node +/// (e.g. `SELECT *`, or a SELECT whose columns come straight from the scan so the +/// Sort sits directly on the TableScan). +/// +/// The projection must reproduce the *original* output columns (`original_schema`, +/// the Sort's input). The `USearchNode` can only produce the columns in +/// `node_schema` — the fetch path's addressing key + non-vector columns + +/// `_distance`; the indexed vector column is never stored there (see +/// `PointLookupProvider`). If the original output needs a column the node can't +/// produce (the vector column), return `None` so the rule declines to rewrite and +/// the query falls back to exact brute-force search. The Sort re-evaluates the +/// distance UDF on the k result rows returned by USearchExec (O(k × dim)). 
+fn passthrough_projection( + original_schema: &DFSchema, + node_schema: &DFSchema, + table_ref: &TableReference, +) -> Option> { + original_schema .inner() .fields() .iter() - .filter(|f| f.name() != "_distance") .map(|f| { - Expr::Column(datafusion::common::Column::new( - Some(table_ref.clone()), - f.name().as_str(), - )) + let producible = node_schema + .inner() + .fields() + .iter() + .any(|nf| nf.name() == f.name()); + producible.then(|| { + Expr::Column(datafusion::common::Column::new( + Some(table_ref.clone()), + f.name().as_str(), + )) + }) }) .collect() } diff --git a/tests/vector_col_projection.rs b/tests/vector_col_projection.rs new file mode 100644 index 0000000..d5da67e --- /dev/null +++ b/tests/vector_col_projection.rs @@ -0,0 +1,204 @@ +// tests/vector_col_projection.rs — Regression tests for the case where a k-NN +// query projects the indexed vector column itself (or SELECT *). +// +// Unlike tests/optimizer_rule.rs, the lookup provider's schema here DELIBERATELY +// EXCLUDES the vector column — faithfully modelling production, where the SQLite +// sidecar stores only the addressing key + non-vector columns (the vector itself +// is never stored). The registry derives meta.schema from the lookup provider, so +// meta.schema lacks the vector column. +// +// In this configuration the rewrite cannot reproduce the vector column in its +// output. The rule must therefore decline to fire (fall back to brute-force exact +// search) rather than produce a plan whose output schema differs from the +// original — which trips DataFusion's post-optimizer invariant check. + +use std::sync::Arc; + +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use usearch::{Index, IndexOptions, MetricKind, ScalarKind}; + +use datafusion_vector_search_ext::{HashKeyProvider, USearchNode, USearchRegistry, register_all}; + +/// The user-visible table: addressing key absent, vector column present. 
+fn table_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + Field::new( + "embedding", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])) +} + +/// The sidecar/lookup schema: synthetic addressing key + non-vector columns. +/// The vector column is excluded — exactly as the SQLite sidecar stores it. +fn lookup_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("_key", DataType::UInt64, false), + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + ])) +} + +fn make_index() -> Arc { + let options = IndexOptions { + dimensions: 4, + metric: MetricKind::L2sq, + quantization: ScalarKind::F32, + ..Default::default() + }; + Arc::new(Index::new(&options).expect("usearch Index::new failed")) +} + +async fn make_ctx() -> SessionContext { + // Registry's scan provider mirrors the parquet provider: full schema incl. + // the synthetic `_key` and the vector column. Keyed on `_key`. + let scan_provider = Arc::new( + HashKeyProvider::try_new( + Arc::new(Schema::new(vec![ + Field::new("_key", DataType::UInt64, false), + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])), + vec![], + "_key", + ) + .expect("scan HashKeyProvider::try_new failed"), + ); + + // Registry's lookup provider mirrors the SQLite sidecar: NO vector column. 
+ let lookup_provider = Arc::new( + HashKeyProvider::try_new(lookup_schema(), vec![], "_key") + .expect("lookup HashKeyProvider::try_new failed"), + ); + + let reg = USearchRegistry::new(); + reg.add( + "items::embedding", + make_index(), + scan_provider, + lookup_provider, + "_key", + MetricKind::L2sq, + ScalarKind::F32, + ) + .expect("USearchRegistry::add failed"); + let registry = reg.into_arc(); + + let ctx = SessionContext::default(); + register_all(&ctx, registry).expect("register_all failed"); + + // The SQL-visible table carries the real columns (no `_key`). + let table = Arc::new( + HashKeyProvider::try_new(table_schema(), vec![], "id") + .expect("table HashKeyProvider::try_new failed"), + ); + ctx.register_table("items", table) + .expect("register_table failed"); + ctx +} + +fn contains_usearch_node(plan: &LogicalPlan) -> bool { + if let LogicalPlan::Extension(ext) = plan + && ext.node.as_any().downcast_ref::().is_some() + { + return true; + } + plan.inputs().iter().any(|c| contains_usearch_node(c)) +} + +const Q: &str = "ARRAY[0.1, 0.2, 0.3, 0.4]"; + +/// SELECT * over an indexed table: the output includes the vector column, which +/// the rewrite cannot produce. The rule must NOT fire, and optimization must +/// succeed (falling back to exact search) rather than erroring on a schema +/// mismatch. +#[tokio::test] +async fn test_select_star_with_vector_index_does_not_crash() { + let ctx = make_ctx().await; + let sql = format!("SELECT * FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error when the vector column is in the output"); + assert!( + !contains_usearch_node(&plan), + "vector column in output → rule must fall back to exact search, not rewrite\nPlan: {plan:?}" + ); +} + +/// Explicitly projecting the indexed vector column has the same requirement. 
+#[tokio::test] +async fn test_select_vector_column_does_not_crash() { + let ctx = make_ctx().await; + let sql = + format!("SELECT id, embedding FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error when the vector column is in the output"); + assert!( + !contains_usearch_node(&plan), + "vector column in output → rule must fall back to exact search, not rewrite\nPlan: {plan:?}" + ); +} + +/// The canonical vector-search query (distance aliased in the SELECT, ORDER BY +/// the alias) keeps rewriting: its output columns (id, the distance) are all +/// producible from the sidecar, and the aliased distance forces a Projection +/// below the Sort so the rewrite reproduces the schema exactly. +#[tokio::test] +async fn test_aliased_distance_still_rewrites() { + let ctx = make_ctx().await; + let sql = format!( + "SELECT id, l2_distance(embedding, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 2" + ); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization failed"); + assert!( + contains_usearch_node(&plan), + "no vector column in output → rule must still fire\nPlan: {plan:?}" + ); +} + +/// A bare projection whose ORDER BY computes the distance inline (distance not in +/// the SELECT, no Projection below the Sort) — the common "give me the nearest +/// rows' ids" query. The indexed vector is only a sort input, not an output +/// column, so the rule must still use the index. This is the regression guard for +/// the over-eager fallback: the output-aware passthrough drives the rewrite from +/// the OUTER projection (`[id]`, all producible), not the Sort's schema. 
+#[tokio::test] +async fn test_bare_select_inline_distance_still_rewrites() { + let ctx = make_ctx().await; + let sql = format!("SELECT id FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error"); + assert!( + contains_usearch_node(&plan), + "vector not in output → rule must still use the index, not fall back\nPlan: {plan:?}" + ); +}