diff --git a/README.md b/README.md index f3b2da1..e1d9ff2 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A DataFusion extension that integrates [USearch](https://github.com/unum-cloud/u Queries matching the `ORDER BY distance_fn(col, query) LIMIT k` pattern are **transparently rewritten** by an optimizer rule into a native USearch index call — no query rewrite needed from the caller. `WHERE` clause filters are handled adaptively: high-selectivity filters use USearch's in-graph predicate API; low-selectivity filters bypass HNSW entirely and scan the data directly. -**DataFusion:** 52.2   **USearch:** 2.24 +**DataFusion:** 53   **USearch:** 2.24 --- @@ -230,20 +230,33 @@ tests/ ### Optimizer rewrite -The rule (`rule.rs`) matches two logical plan shapes: +The rule (`rule.rs`) matches the `Sort(fetch=k)` over a `TableScan`, with an +optional `Projection` between them and an optional `Filter` directly above the +scan: ``` Sort(fetch=k, ORDER BY dist ASC) - Projection([..., distance_fn(col, lit) AS dist, ...]) - TableScan(name) - -Sort(fetch=k, ORDER BY dist ASC) - Projection([..., distance_fn(col, lit) AS dist, ...]) - Filter(predicate) + [ Projection([..., distance_fn(col, lit) AS dist, ...]) ] ← optional + [ Filter(predicate) ] ← optional TableScan(name) ``` -Preconditions: sort is `ASC`, distance UDF matches index metric, table is registered, query vector is a literal. When the rule fires, it replaces the inner nodes with a `USearchNode` leaf carrying: table name, vector column, query vector, k, distance type, and absorbed filter predicates. The `Sort` node is preserved above for final ordering. +DataFusion omits the `Projection` for `SELECT *` (and for any SELECT whose +columns come straight from the scan), so the `Sort` can sit directly on the +`TableScan`. + +Preconditions: sort is `ASC`, distance UDF matches index metric, table is +registered, query vector is a literal. When the rule fires, it replaces the inner +nodes with a `USearchNode` leaf carrying: table name, vector column, query +vector, k, distance type, and absorbed filter predicates. The `Sort` node is +preserved above for final ordering. + +**Schema preservation:** an optimizer rule must not change the plan's output +schema. The `USearchNode` produces only what the `lookup_provider` can fetch by +key (addressing key + non-vector columns) plus `_distance` — it cannot produce +the indexed vector column. If the matched `Sort`'s output would include the +vector column (e.g. `SELECT *`), the rule declines and the query falls back to +exact execution rather than emitting a schema-incompatible plan. Physical planning (`planner.rs`) translates `USearchNode` into `USearchExec`, a physical plan node that executes the actual search. @@ -305,6 +318,7 @@ Tests cover optimizer rule matching/rejection, end-to-end execution through both | Limitation | Notes | |---|---| +| Projecting the indexed vector column | A k-NN query whose output includes the vector column itself (e.g. `SELECT *`, or `SELECT id, vector`) falls back to exact execution. The `lookup_provider` does not store the vector column (see [registration](#register-providers-and-set-up-the-sessioncontext)), so the rewrite cannot reproduce it. Project the metadata columns and the distance instead. | | Stacked `Filter` nodes | Only one `Filter -> TableScan` layer is absorbed. `Filter -> Filter -> TableScan` falls back to exact execution. DataFusion typically combines multiple WHERE conditions into a single Filter, so this rarely occurs. | | Runtime query vectors | The query vector must be a compile-time literal (`ARRAY[0.1, ...]`). Column references or subquery results are not rewritten. Use `vector_search_vector` for explicit ANN queries. | | `ef_search` per-query | `expansion_search` is global to the index instance. Per-query adjustment is not supported. | diff --git a/src/rule.rs b/src/rule.rs index 7896516..1c502fd 100644 --- a/src/rule.rs +++ b/src/rule.rs @@ -48,21 +48,66 @@ impl USearchRule { } fn try_match(&self, plan: &LogicalPlan) -> Option { + match plan { + // Anchor on the Sort itself. The projection (if any) sits *below* the + // Sort and supplies its output columns; SELECT * omits it entirely. + LogicalPlan::Sort(sort) => { + let (proj_exprs_slice, after_sort): (&[Expr], &LogicalPlan) = + match sort.input.as_ref() { + LogicalPlan::Projection(p) => (p.expr.as_slice(), p.input.as_ref()), + other => (&[], other), + }; + self.build_rewrite(sort, proj_exprs_slice, after_sort) + } + + // Output-aware passthrough. When a Projection sits directly over a + // k-NN Sort that rests on the scan (no projection between them), drive + // the rewrite with the OUTER projection's columns — i.e. the query's + // real output. The rewrite can only produce the index node's columns + // (addressing key + non-vector columns + _distance), never the indexed + // vector itself. Routing the output columns through `build_rewrite` + // lets it fire when they're all producible (e.g. `SELECT id … ORDER BY + // l2_distance(emb, …)`) and decline — falling back to exact search — + // when the output needs the vector (`SELECT *`, `SELECT id, emb`), + // rather than emitting a schema the consumer can't satisfy (issue #508). + // + // ALTERNATIVE (not taken): teach USearchExec to reconstruct the vector + // column for the result keys via `index.get(key)`, so even + // vector-returning queries stay on the index. Rejected to keep a single + // source of truth for returned vectors — the index would otherwise be a + // second source that must byte-match the parquet (breaks under F16 + // quantization, and relies on USearch never transforming stored vectors). + // See the README "Limitations" entry and runtimedb issue #508. + LogicalPlan::Projection(outer) => { + let LogicalPlan::Sort(sort) = outer.input.as_ref() else { + return None; + }; + // Only the passthrough shape; the remap shape (projection *below* + // the Sort) is handled when we visit the Sort above. + if !matches!( + sort.input.as_ref(), + LogicalPlan::TableScan(_) | LogicalPlan::Filter(_) + ) { + return None; + } + self.build_rewrite(sort, &outer.expr, sort.input.as_ref()) + } + + _ => None, + } + } + + fn build_rewrite( + &self, + sort: &datafusion::logical_expr::logical_plan::Sort, + proj_exprs_slice: &[Expr], + after_sort: &LogicalPlan, + ) -> Option { use datafusion::logical_expr::logical_plan::TableScan; // Require Sort with embedded fetch limit. - let sort = match plan { - LogicalPlan::Sort(s) => s, - _ => return None, - }; let k = sort.fetch?; - // Projection is optional — DataFusion 51 omits it for SELECT * queries. - let (proj_exprs_slice, after_sort): (&[Expr], &LogicalPlan) = match sort.input.as_ref() { - LogicalPlan::Projection(p) => (p.expr.as_slice(), p.input.as_ref()), - other => (&[], other), - }; - // Accept TableScan directly, or Filter(TableScan) for WHERE clauses. // Deeper nesting (Filter→Filter→…) is not absorbed — the rule does // not fire and DataFusion falls back to exact execution. @@ -155,7 +200,14 @@ impl USearchRule { // Build the final user-visible projection over USearchNode output. let dist_alias_str = dist_alias.as_deref().unwrap_or("_distance"); let final_proj_exprs = if proj_exprs_slice.is_empty() { - passthrough_projection(&vsn_df_schema, &table_ref) + // No explicit Projection node (e.g. SELECT *, or a SELECT whose + // columns come straight from the scan, so the Sort sits directly on + // the TableScan). The rewrite must reproduce the original output + // columns; if any isn't producible from the node — the indexed + // vector column is never stored in the fetch path — bail so the + // query falls back to exact brute-force search, like the other + // unsupported shapes (DESC, metric mismatch, stacked filters). + passthrough_projection(after_sort.schema().as_ref(), &vsn_df_schema, &table_ref)? } else { remap_projections(proj_exprs_slice, dist_alias_str, &table_ref) }; @@ -375,21 +427,39 @@ fn build_outer_projection(exprs: &[Expr]) -> Vec { .collect() } -/// Build a passthrough Projection for SELECT * queries (no original Projection node). -/// Projects only the original table columns (not `_distance`) so the output schema -/// matches the original Sort schema. The Sort re-evaluates the distance UDF expression -/// on the k result rows returned by USearchExec (O(k × dim), negligible for small k). -fn passthrough_projection(schema: &DFSchema, table_ref: &TableReference) -> Vec { - schema +/// Build a passthrough Projection for queries with no explicit Projection node +/// (e.g. `SELECT *`, or a SELECT whose columns come straight from the scan so the +/// Sort sits directly on the TableScan). +/// +/// The projection must reproduce the *original* output columns (`original_schema`, +/// the Sort's input). The `USearchNode` can only produce the columns in +/// `node_schema` — the fetch path's addressing key + non-vector columns + +/// `_distance`; the indexed vector column is never stored there (see +/// `PointLookupProvider`). If the original output needs a column the node can't +/// produce (the vector column), return `None` so the rule declines to rewrite and +/// the query falls back to exact brute-force search. The Sort re-evaluates the +/// distance UDF on the k result rows returned by USearchExec (O(k × dim)). +fn passthrough_projection( + original_schema: &DFSchema, + node_schema: &DFSchema, + table_ref: &TableReference, +) -> Option> { + original_schema .inner() .fields() .iter() - .filter(|f| f.name() != "_distance") .map(|f| { - Expr::Column(datafusion::common::Column::new( - Some(table_ref.clone()), - f.name().as_str(), - )) + let producible = node_schema + .inner() + .fields() + .iter() + .any(|nf| nf.name() == f.name()); + producible.then(|| { + Expr::Column(datafusion::common::Column::new( + Some(table_ref.clone()), + f.name().as_str(), + )) + }) }) .collect() } diff --git a/tests/vector_col_projection.rs b/tests/vector_col_projection.rs new file mode 100644 index 0000000..d5da67e --- /dev/null +++ b/tests/vector_col_projection.rs @@ -0,0 +1,204 @@ +// tests/vector_col_projection.rs — Regression tests for the case where a k-NN +// query projects the indexed vector column itself (or SELECT *). +// +// Unlike tests/optimizer_rule.rs, the lookup provider's schema here DELIBERATELY +// EXCLUDES the vector column — faithfully modelling production, where the SQLite +// sidecar stores only the addressing key + non-vector columns (the vector itself +// is never stored). The registry derives meta.schema from the lookup provider, so +// meta.schema lacks the vector column. +// +// In this configuration the rewrite cannot reproduce the vector column in its +// output. The rule must therefore decline to fire (fall back to brute-force exact +// search) rather than produce a plan whose output schema differs from the +// original — which trips DataFusion's post-optimizer invariant check. + +use std::sync::Arc; + +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::LogicalPlan; +use datafusion::prelude::SessionContext; +use usearch::{Index, IndexOptions, MetricKind, ScalarKind}; + +use datafusion_vector_search_ext::{HashKeyProvider, USearchNode, USearchRegistry, register_all}; + +/// The user-visible table: addressing key absent, vector column present. +fn table_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + Field::new( + "embedding", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4), + false, + ), + ])) +} + +/// The sidecar/lookup schema: synthetic addressing key + non-vector columns. +/// The vector column is excluded — exactly as the SQLite sidecar stores it. +fn lookup_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("_key", DataType::UInt64, false), + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + ])) +} + +fn make_index() -> Arc { + let options = IndexOptions { + dimensions: 4, + metric: MetricKind::L2sq, + quantization: ScalarKind::F32, + ..Default::default() + }; + Arc::new(Index::new(&options).expect("usearch Index::new failed")) +} + +async fn make_ctx() -> SessionContext { + // Registry's scan provider mirrors the parquet provider: full schema incl. + // the synthetic `_key` and the vector column. Keyed on `_key`. + let scan_provider = Arc::new( + HashKeyProvider::try_new( + Arc::new(Schema::new(vec![ + Field::new("_key", DataType::UInt64, false), + Field::new("id", DataType::Int32, false), + Field::new("label", DataType::Utf8, true), + Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float32, true)), + 4, + ), + false, + ), + ])), + vec![], + "_key", + ) + .expect("scan HashKeyProvider::try_new failed"), + ); + + // Registry's lookup provider mirrors the SQLite sidecar: NO vector column. + let lookup_provider = Arc::new( + HashKeyProvider::try_new(lookup_schema(), vec![], "_key") + .expect("lookup HashKeyProvider::try_new failed"), + ); + + let reg = USearchRegistry::new(); + reg.add( + "items::embedding", + make_index(), + scan_provider, + lookup_provider, + "_key", + MetricKind::L2sq, + ScalarKind::F32, + ) + .expect("USearchRegistry::add failed"); + let registry = reg.into_arc(); + + let ctx = SessionContext::default(); + register_all(&ctx, registry).expect("register_all failed"); + + // The SQL-visible table carries the real columns (no `_key`). + let table = Arc::new( + HashKeyProvider::try_new(table_schema(), vec![], "id") + .expect("table HashKeyProvider::try_new failed"), + ); + ctx.register_table("items", table) + .expect("register_table failed"); + ctx +} + +fn contains_usearch_node(plan: &LogicalPlan) -> bool { + if let LogicalPlan::Extension(ext) = plan + && ext.node.as_any().downcast_ref::().is_some() + { + return true; + } + plan.inputs().iter().any(|c| contains_usearch_node(c)) +} + +const Q: &str = "ARRAY[0.1, 0.2, 0.3, 0.4]"; + +/// SELECT * over an indexed table: the output includes the vector column, which +/// the rewrite cannot produce. The rule must NOT fire, and optimization must +/// succeed (falling back to exact search) rather than erroring on a schema +/// mismatch. +#[tokio::test] +async fn test_select_star_with_vector_index_does_not_crash() { + let ctx = make_ctx().await; + let sql = format!("SELECT * FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error when the vector column is in the output"); + assert!( + !contains_usearch_node(&plan), + "vector column in output → rule must fall back to exact search, not rewrite\nPlan: {plan:?}" + ); +} + +/// Explicitly projecting the indexed vector column has the same requirement. +#[tokio::test] +async fn test_select_vector_column_does_not_crash() { + let ctx = make_ctx().await; + let sql = + format!("SELECT id, embedding FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error when the vector column is in the output"); + assert!( + !contains_usearch_node(&plan), + "vector column in output → rule must fall back to exact search, not rewrite\nPlan: {plan:?}" + ); +} + +/// The canonical vector-search query (distance aliased in the SELECT, ORDER BY +/// the alias) keeps rewriting: its output columns (id, the distance) are all +/// producible from the sidecar, and the aliased distance forces a Projection +/// below the Sort so the rewrite reproduces the schema exactly. +#[tokio::test] +async fn test_aliased_distance_still_rewrites() { + let ctx = make_ctx().await; + let sql = format!( + "SELECT id, l2_distance(embedding, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 2" + ); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization failed"); + assert!( + contains_usearch_node(&plan), + "no vector column in output → rule must still fire\nPlan: {plan:?}" + ); +} + +/// A bare projection whose ORDER BY computes the distance inline (distance not in +/// the SELECT, no Projection below the Sort) — the common "give me the nearest +/// rows' ids" query. The indexed vector is only a sort input, not an output +/// column, so the rule must still use the index. This is the regression guard for +/// the over-eager fallback: the output-aware passthrough drives the rewrite from +/// the OUTER projection (`[id]`, all producible), not the Sort's schema. +#[tokio::test] +async fn test_bare_select_inline_distance_still_rewrites() { + let ctx = make_ctx().await; + let sql = format!("SELECT id FROM items ORDER BY l2_distance(embedding, {Q}) ASC LIMIT 2"); + let plan = ctx + .sql(&sql) + .await + .expect("SQL analysis failed") + .into_optimized_plan() + .expect("optimization must not error"); + assert!( + contains_usearch_node(&plan), + "vector not in output → rule must still use the index, not fall back\nPlan: {plan:?}" + ); +}