diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e75e40a..929f095c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,6 +112,13 @@ jobs: - name: Install Rust uses: dtolnay/rust-toolchain@stable + with: + targets: wasm32-unknown-unknown + + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true - name: Cache cargo uses: Swatinem/rust-cache@v2 @@ -122,13 +129,93 @@ jobs: run: cargo fmt --check - name: Run cargo clippy - run: cargo clippy -- -D warnings + run: cargo clippy --all-targets --all-features -- -D warnings - name: Run cargo test run: cargo test env: RUST_MIN_STACK: 16777216 + - name: Check package metadata + run: cargo test --test package_metadata + + - name: Check crate package dry run + run: cargo package --locked --no-verify + + - name: Check standalone feature surfaces + run: | + cargo check --no-default-features --all-targets + cargo check --features python --lib + cargo check --features python-adbc --lib + cargo check --features mcp-server --bin sidemantic-mcp --test mcp_protocol + cargo check --features mcp-adbc --bin sidemantic-mcp + cargo check --features runtime-server --bin sidemantic-server --test http_server + cargo check --features runtime-server-adbc --bin sidemantic-server + cargo check --features runtime-lsp --bin sidemantic-lsp --test lsp_protocol_smoke + cargo check --features workbench-tui --bin sidemantic --bin sidemantic-workbench + cargo check --features workbench-adbc --bin sidemantic --bin sidemantic-workbench + + - name: Install DuckDB ADBC driver + run: | + curl -L https://github.com/duckdb/duckdb/releases/download/v1.4.2/libduckdb-linux-amd64.zip -o /tmp/libduckdb.zip + unzip -q /tmp/libduckdb.zip -d /tmp/libduckdb + echo "SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER=/tmp/libduckdb/libduckdb.so" >> "$GITHUB_ENV" + + - name: Install SQLite ADBC driver + run: | + uv run --no-project --with adbc-driver-sqlite python - <<'PY' >> "$GITHUB_ENV" + import adbc_driver_sqlite + + print(f"SIDEMANTIC_TEST_ADBC_SQLITE_DRIVER={adbc_driver_sqlite._driver_path()}") + print("SIDEMANTIC_TEST_ADBC_SQLITE_URI=:memory:") + print("SIDEMANTIC_TEST_ADBC_REQUIRE=duckdb,sqlite") + PY + + - name: Run standalone protocol and UI tests + run: | + cargo test --features mcp-server --test mcp_protocol + cargo test --features runtime-server --test http_server + cargo test --features runtime-lsp --test lsp_protocol_smoke + cargo test --features workbench-tui --test cli_smoke workbench + cargo test --features workbench-adbc --test cli_smoke workbench + cargo test --features workbench-tui --test workbench_pty_smoke + cargo test --features workbench-adbc --test workbench_pty_smoke + cargo test --features adbc-exec --test adbc_driver_matrix + cargo test --features mcp-adbc,runtime-server-adbc --test adbc_duckdb_e2e + + - name: Smoke C ABI header and static library + run: | + cargo build --release --lib + cc -std=c11 -Wall -Wextra -I include tests/c_abi_smoke.c target/release/libsidemantic.a -lpthread -ldl -lm -lrt -o /tmp/sidemantic_c_abi_smoke + /tmp/sidemantic_c_abi_smoke + + - name: Check WASM build + run: cargo check --no-default-features --features wasm --target wasm32-unknown-unknown --lib + + - name: Install wasm-bindgen test runner + run: cargo install wasm-bindgen-cli --version 0.2.110 + + - name: Run WASM runtime tests + run: CARGO_TARGET_WASM32_UNKNOWN_UNKNOWN_RUNNER=wasm-bindgen-test-runner cargo test --target wasm32-unknown-unknown --features wasm --test wasm_bindgen_runtime + + - name: Check Python extension build + run: uvx maturin build --out dist + + - name: Smoke Python extension wheel + run: uv run --no-project --with dist/*.whl tests/python_wheel_smoke.py + + - name: Check lightweight Python extension build + run: uvx maturin build --no-default-features --features python --out dist-python + + - name: Smoke lightweight Python extension wheel + run: uv run --no-project --with dist-python/*.whl tests/python_wheel_python_smoke.py + + - name: Check Python extension ADBC build + run: uvx maturin build --no-default-features --features python-adbc --out dist-adbc + + - name: Smoke Python extension ADBC wheel + run: uv run --no-project --with dist-adbc/*.whl tests/python_wheel_adbc_smoke.py + duckdb-extension: name: DuckDB Extension needs: check-rust-changes @@ -160,10 +247,6 @@ jobs: - name: Install build dependencies run: sudo apt-get update && sudo apt-get install -y ninja-build - - name: Build Rust library - working-directory: sidemantic-rs - run: cargo build --release - - name: Build DuckDB extension working-directory: sidemantic-duckdb run: make diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index a7a5b1fa..e2c9d700 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -185,7 +185,6 @@ jobs: env: CLICKHOUSE_DB: default CLICKHOUSE_USER: default - CLICKHOUSE_PASSWORD: clickhouse CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT: 1 ports: - 8123:8123 @@ -221,6 +220,10 @@ jobs: - name: Install ADBC driver (best effort) run: | DB="${{ matrix.db }}" + if [ "$DB" = "clickhouse" ]; then + uvx dbc install --pre clickhouse + exit 0 + fi PKG_DB="$DB" if [ "$DB" = "postgres" ]; then PKG_DB="postgresql" @@ -237,6 +240,64 @@ jobs: BIGQUERY_DATASET: "test_dataset" CLICKHOUSE_HOST: "localhost" CLICKHOUSE_PORT: "8123" - CLICKHOUSE_PASSWORD: "clickhouse" SNOWFLAKE_TEST: "1" run: uv run pytest -m integration tests/db/test_adbc_ci_smoke.py -v + + - name: Install Rust for Rust ADBC probe + uses: dtolnay/rust-toolchain@stable + + - name: Export Postgres ADBC driver for Rust + if: matrix.db == 'postgres' + run: | + uv pip install adbc-driver-postgresql + uv run python - <<'PY' >> "$GITHUB_ENV" + import adbc_driver_postgresql + + print(f"SIDEMANTIC_TEST_ADBC_POSTGRES_DRIVER={adbc_driver_postgresql._driver_path()}") + PY + echo "SIDEMANTIC_TEST_ADBC_POSTGRES_URI=postgresql://test:test@localhost:5432/sidemantic_test" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_REQUIRE=postgres" >> "$GITHUB_ENV" + + - name: Export BigQuery ADBC driver for Rust + if: matrix.db == 'bigquery' + env: + BIGQUERY_ADBC_CREDENTIALS_JSON: ${{ secrets.BIGQUERY_ADBC_CREDENTIALS_JSON }} + BIGQUERY_ADBC_PROJECT: ${{ vars.BIGQUERY_ADBC_PROJECT }} + BIGQUERY_ADBC_DATASET: ${{ vars.BIGQUERY_ADBC_DATASET }} + run: | + if [ -z "$BIGQUERY_ADBC_CREDENTIALS_JSON" ] || [ -z "$BIGQUERY_ADBC_PROJECT" ] || [ -z "$BIGQUERY_ADBC_DATASET" ]; then + echo "Skipping Rust BigQuery ADBC probe; BIGQUERY_ADBC_CREDENTIALS_JSON secret plus BIGQUERY_ADBC_PROJECT/BIGQUERY_ADBC_DATASET vars are required." + exit 0 + fi + uvx dbc install bigquery + CREDENTIALS_FILE="$(mktemp)" + printf '%s' "$BIGQUERY_ADBC_CREDENTIALS_JSON" > "$CREDENTIALS_FILE" + echo "::add-mask::$BIGQUERY_ADBC_CREDENTIALS_JSON" + echo "SIDEMANTIC_TEST_ADBC_BIGQUERY_DRIVER=bigquery" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_BIGQUERY_DBOPTS=adbc.bigquery.sql.project_id=$BIGQUERY_ADBC_PROJECT,adbc.bigquery.sql.dataset_id=$BIGQUERY_ADBC_DATASET,adbc.bigquery.sql.auth_type=adbc.bigquery.sql.auth_type.json_credential_file,adbc.bigquery.sql.auth_credentials=$CREDENTIALS_FILE" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_REQUIRE=bigquery" >> "$GITHUB_ENV" + + - name: Export Snowflake ADBC driver for Rust + if: matrix.db == 'snowflake' + env: + SNOWFLAKE_ADBC_URI: ${{ secrets.SNOWFLAKE_ADBC_URI }} + run: | + if [ -z "$SNOWFLAKE_ADBC_URI" ]; then + echo "Skipping Rust Snowflake ADBC probe; SNOWFLAKE_ADBC_URI secret is required." + exit 0 + fi + uvx dbc install snowflake + echo "::add-mask::$SNOWFLAKE_ADBC_URI" + echo "SIDEMANTIC_TEST_ADBC_SNOWFLAKE_DRIVER=snowflake" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_SNOWFLAKE_URI=$SNOWFLAKE_ADBC_URI" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_REQUIRE=snowflake" >> "$GITHUB_ENV" + + - name: Export ClickHouse ADBC driver for Rust + if: matrix.db == 'clickhouse' + run: | + echo "SIDEMANTIC_TEST_ADBC_CLICKHOUSE_DRIVER=clickhouse" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_CLICKHOUSE_URI=http://localhost:8123/" >> "$GITHUB_ENV" + echo "SIDEMANTIC_TEST_ADBC_REQUIRE=clickhouse" >> "$GITHUB_ENV" + + - name: Run Rust ADBC probe + run: cargo test --manifest-path sidemantic-rs/Cargo.toml --features adbc-exec --test adbc_driver_matrix diff --git a/sidemantic-duckdb/CMakeLists.txt b/sidemantic-duckdb/CMakeLists.txt index 0afd6ef2..0309d647 100644 --- a/sidemantic-duckdb/CMakeLists.txt +++ b/sidemantic-duckdb/CMakeLists.txt @@ -11,9 +11,44 @@ include_directories(src/include) # Path to sidemantic-rs set(SIDEMANTIC_RS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../sidemantic-rs") -set(SIDEMANTIC_LIB "${SIDEMANTIC_RS_DIR}/target/release/libsidemantic.a") +set(SIDEMANTIC_CARGO_TARGET_DIR "${SIDEMANTIC_RS_DIR}/target" CACHE PATH "Cargo target directory for sidemantic-rs") +set(SIDEMANTIC_CARGO_PROFILE "release" CACHE STRING "Cargo profile used for the sidemantic-rs static library") set(SIDEMANTIC_INCLUDE "${SIDEMANTIC_RS_DIR}/include") +if(WIN32) + set(SIDEMANTIC_STATICLIB_NAME "sidemantic.lib") +else() + set(SIDEMANTIC_STATICLIB_NAME "libsidemantic.a") +endif() + +set(SIDEMANTIC_LIB "${SIDEMANTIC_CARGO_TARGET_DIR}/${SIDEMANTIC_CARGO_PROFILE}/${SIDEMANTIC_STATICLIB_NAME}") + +find_program(CARGO_EXECUTABLE cargo) +if(NOT CARGO_EXECUTABLE) + message(FATAL_ERROR "cargo is required to build the sidemantic DuckDB extension") +endif() + +set(SIDEMANTIC_CARGO_BUILD_ARGS build --manifest-path "${SIDEMANTIC_RS_DIR}/Cargo.toml" --lib) +if(SIDEMANTIC_CARGO_PROFILE STREQUAL "release") + list(APPEND SIDEMANTIC_CARGO_BUILD_ARGS --release) +elseif(NOT SIDEMANTIC_CARGO_PROFILE STREQUAL "debug") + list(APPEND SIDEMANTIC_CARGO_BUILD_ARGS --profile "${SIDEMANTIC_CARGO_PROFILE}") +endif() + +file(GLOB_RECURSE SIDEMANTIC_RUST_SOURCES + "${SIDEMANTIC_RS_DIR}/src/*.rs" + "${SIDEMANTIC_RS_DIR}/include/*.h") + +add_custom_command( + OUTPUT "${SIDEMANTIC_LIB}" + COMMAND ${CMAKE_COMMAND} -E env "CARGO_TARGET_DIR=${SIDEMANTIC_CARGO_TARGET_DIR}" "${CARGO_EXECUTABLE}" ${SIDEMANTIC_CARGO_BUILD_ARGS} + WORKING_DIRECTORY "${SIDEMANTIC_RS_DIR}" + DEPENDS "${SIDEMANTIC_RS_DIR}/Cargo.toml" "${SIDEMANTIC_RS_DIR}/Cargo.lock" ${SIDEMANTIC_RUST_SOURCES} + COMMENT "Building sidemantic-rs static library" + VERBATIM) + +add_custom_target(sidemantic_rust_staticlib DEPENDS "${SIDEMANTIC_LIB}") + # Include Rust library headers include_directories(${SIDEMANTIC_INCLUDE}) @@ -22,6 +57,9 @@ set(EXTENSION_SOURCES src/sidemantic_extension.cpp) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES}) +add_dependencies(${EXTENSION_NAME} sidemantic_rust_staticlib) +add_dependencies(${LOADABLE_EXTENSION_NAME} sidemantic_rust_staticlib) + # Link the Rust static library target_link_libraries(${EXTENSION_NAME} ${SIDEMANTIC_LIB}) target_link_libraries(${LOADABLE_EXTENSION_NAME} ${SIDEMANTIC_LIB}) diff --git a/sidemantic-duckdb/Makefile b/sidemantic-duckdb/Makefile index e91db43b..645c52db 100644 --- a/sidemantic-duckdb/Makefile +++ b/sidemantic-duckdb/Makefile @@ -1,8 +1,8 @@ PROJ_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) # Configuration of extension -EXT_NAME=quack +EXT_NAME=sidemantic EXT_CONFIG=${PROJ_DIR}extension_config.cmake # Include the Makefile from extension-ci-tools -include extension-ci-tools/makefiles/duckdb_extension.Makefile \ No newline at end of file +include extension-ci-tools/makefiles/duckdb_extension.Makefile diff --git a/sidemantic-duckdb/src/include/sidemantic_extension.hpp b/sidemantic-duckdb/src/include/sidemantic_extension.hpp index ca89bcdf..3d0e1220 100644 --- a/sidemantic-duckdb/src/include/sidemantic_extension.hpp +++ b/sidemantic-duckdb/src/include/sidemantic_extension.hpp @@ -24,6 +24,14 @@ ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, ParserExtensionPlanResult sidemantic_plan(ParserExtensionInfo *, ClientContext &, unique_ptr); +struct SidemanticParserInfo : ParserExtensionInfo { + SidemanticParserInfo(string db_path, string context_key) + : db_path(std::move(db_path)), context_key(std::move(context_key)) {} + + string db_path; + string context_key; +}; + // Operator extension: handles binding after parsing struct SidemanticOperatorExtension : public OperatorExtension { SidemanticOperatorExtension() : OperatorExtension() { Bind = sidemantic_bind; } @@ -36,9 +44,11 @@ struct SidemanticOperatorExtension : public OperatorExtension { // Parser extension: intercepts query strings struct SidemanticParserExtension : public ParserExtension { - SidemanticParserExtension() : ParserExtension() { + SidemanticParserExtension(string db_path, string context_key) : ParserExtension() { parse_function = sidemantic_parse; plan_function = sidemantic_plan; + parser_info = + make_shared_ptr(std::move(db_path), std::move(context_key)); } }; diff --git a/sidemantic-duckdb/src/sidemantic_extension.cpp b/sidemantic-duckdb/src/sidemantic_extension.cpp index 436d2e1d..600f068b 100644 --- a/sidemantic-duckdb/src/sidemantic_extension.cpp +++ b/sidemantic-duckdb/src/sidemantic_extension.cpp @@ -4,30 +4,36 @@ #include "duckdb/parser/parser.hpp" #include "duckdb/parser/statement/extension_statement.hpp" #include "duckdb/function/table_function.hpp" +#include "sidemantic.h" -// Rust FFI -extern "C" { - struct SidemanticRewriteResult { - char *sql; - char *error; - bool was_rewritten; - }; +#include +#include + +namespace duckdb { - char *sidemantic_load_yaml(const char *yaml); - char *sidemantic_load_file(const char *path); - void sidemantic_clear(void); - bool sidemantic_is_model(const char *table_name); - char *sidemantic_list_models(void); - SidemanticRewriteResult sidemantic_rewrite(const char *sql); - void sidemantic_free(char *ptr); - void sidemantic_free_result(SidemanticRewriteResult result); - char *sidemantic_define(const char *definition_sql, const char *db_path, bool replace); - char *sidemantic_autoload(const char *db_path); - char *sidemantic_add_definition(const char *definition_sql, const char *db_path, bool is_replace); - char *sidemantic_use(const char *model_name); +static std::string DatabasePath(DatabaseInstance &db) { + auto &db_config = db.config; + if (!db_config.options.database_path.empty()) { + return db_config.options.database_path; + } + return ""; } -namespace duckdb { +static std::string ContextKey(DatabaseInstance &db) { + auto path = DatabasePath(db); + if (!path.empty()) { + return "duckdb:" + path; + } + return "duckdb:memory:" + std::to_string(reinterpret_cast(&db)); +} + +static const char *ContextKeyPtr(const std::string &context_key) { + return context_key.empty() ? nullptr : context_key.c_str(); +} + +static const SidemanticParserInfo *ParserInfo(ParserExtensionInfo *info) { + return dynamic_cast(info); +} //============================================================================= // TABLE FUNCTION: sidemantic_load(yaml) @@ -59,7 +65,8 @@ static void SidemanticLoadFunction(ClientContext &context, TableFunctionInput &d } data.done = true; - char *error = sidemantic_load_yaml(data.yaml_content.c_str()); + auto context_key = ContextKey(*context.db); + char *error = sidemantic_load_yaml_for_context(ContextKeyPtr(context_key), data.yaml_content.c_str()); if (error) { string error_msg(error); sidemantic_free(error); @@ -100,7 +107,8 @@ static void SidemanticLoadFileFunction(ClientContext &context, TableFunctionInpu } data.done = true; - char *error = sidemantic_load_file(data.file_path.c_str()); + auto context_key = ContextKey(*context.db); + char *error = sidemantic_load_file_for_context(ContextKeyPtr(context_key), data.file_path.c_str()); if (error) { string error_msg(error); sidemantic_free(error); @@ -136,7 +144,8 @@ static void SidemanticModelsFunction(ClientContext &context, TableFunctionInput } data.done = true; - char *models_str = sidemantic_list_models(); + auto context_key = ContextKey(*context.db); + char *models_str = sidemantic_list_models_for_context(ContextKeyPtr(context_key)); if (!models_str) { output.SetCardinality(0); return; @@ -176,7 +185,12 @@ static void SidemanticRewriteSqlFunction(DataChunk &args, ExpressionState &state auto &sql_vector = args.data[0]; UnaryExecutor::Execute( sql_vector, result, args.size(), [&](string_t sql) { - SidemanticRewriteResult res = sidemantic_rewrite(sql.GetString().c_str()); + std::string context_key; + if (state.HasContext()) { + context_key = ContextKey(*state.GetContext().db); + } + SidemanticRewriteResult res = + sidemantic_rewrite_for_context(ContextKeyPtr(context_key), sql.GetString().c_str()); if (res.error) { string error_msg(res.error); @@ -429,11 +443,13 @@ static int IsCreateModelStatement(const std::string &query, std::string &definit return is_replace ? 2 : 1; } -// Global to store database path for parser extension (set during extension load) -static std::string g_db_path; - -ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, +ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *info, const std::string &query) { + auto parser_info = ParserInfo(info); + const char *context_key_ptr = parser_info ? ContextKeyPtr(parser_info->context_key) : nullptr; + const char *db_path_ptr = + parser_info && !parser_info->db_path.empty() ? parser_info->db_path.c_str() : nullptr; + // Check for SEMANTIC prefix std::string stripped_query; if (!StartsWithSemantic(query, stripped_query)) { @@ -448,9 +464,8 @@ ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, if (create_type > 0) { // This is a CREATE MODEL statement - handle specially bool replace = (create_type == 2); - const char *db_path_ptr = g_db_path.empty() ? nullptr : g_db_path.c_str(); - char *error = sidemantic_define(definition.c_str(), db_path_ptr, replace); + char *error = sidemantic_define_for_context(context_key_ptr, definition.c_str(), db_path_ptr, replace); if (error) { string error_msg(error); sidemantic_free(error); @@ -489,7 +504,7 @@ ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, } if (!model_name.empty()) { - char *error = sidemantic_use(model_name.c_str()); + char *error = sidemantic_use_for_context(context_key_ptr, model_name.c_str()); if (error) { string error_msg(error); sidemantic_free(error); @@ -512,9 +527,8 @@ ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, bool is_replace = false; std::string def_type = IsDefinitionStatement(stripped_query, definition, is_replace); if (!def_type.empty()) { - const char *db_path_ptr = g_db_path.empty() ? nullptr : g_db_path.c_str(); - - char *error = sidemantic_add_definition(definition.c_str(), db_path_ptr, is_replace); + char *error = + sidemantic_add_definition_for_context(context_key_ptr, definition.c_str(), db_path_ptr, is_replace); if (error) { string error_msg(error); sidemantic_free(error); @@ -533,7 +547,7 @@ ParserExtensionParseResult sidemantic_parse(ParserExtensionInfo *, } // Regular SEMANTIC SELECT query - try to rewrite using sidemantic - SidemanticRewriteResult result = sidemantic_rewrite(stripped_query.c_str()); + SidemanticRewriteResult result = sidemantic_rewrite_for_context(context_key_ptr, stripped_query.c_str()); // If there was an error, return it if (result.error) { @@ -606,25 +620,20 @@ static void LoadInternal(ExtensionLoader &loader) { auto &db = loader.GetDatabaseInstance(); auto &config = DBConfig::GetConfig(db); - // Capture database path for CREATE MODEL statements - auto &db_config = db.config; - if (!db_config.options.database_path.empty()) { - g_db_path = db_config.options.database_path; - } else { - g_db_path.clear(); - } + auto db_path = DatabasePath(db); + auto context_key = ContextKey(db); // Auto-load definitions from file if it exists - const char *db_path_ptr = g_db_path.empty() ? nullptr : g_db_path.c_str(); - char *error = sidemantic_autoload(db_path_ptr); + const char *db_path_ptr = db_path.empty() ? nullptr : db_path.c_str(); + char *error = sidemantic_autoload_for_context(ContextKeyPtr(context_key), db_path_ptr); if (error) { // Log warning but don't fail extension load - // fprintf(stderr, "Warning: failed to autoload sidemantic definitions: %s\n", error); + std::fprintf(stderr, "Warning: failed to autoload sidemantic definitions: %s\n", error); sidemantic_free(error); } // Register parser extension - SidemanticParserExtension parser; + SidemanticParserExtension parser(db_path, context_key); config.parser_extensions.push_back(parser); // Register operator extension diff --git a/sidemantic-duckdb/test/sql/sidemantic.test b/sidemantic-duckdb/test/sql/sidemantic.test index 3086239d..52d9e13e 100644 --- a/sidemantic-duckdb/test/sql/sidemantic.test +++ b/sidemantic-duckdb/test/sql/sidemantic.test @@ -100,3 +100,65 @@ SELECT * FROM sidemantic_models(); orders products test_model + +# Test repeated YAML load replaces same-name model instead of failing on duplicates +statement ok +SELECT * FROM sidemantic_load(' +models: + - name: reloadable + table: orders + primary_key: order_id + metrics: + - name: order_count + agg: count +'); + +statement ok +SELECT * FROM sidemantic_load(' +models: + - name: reloadable + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +'); + +query R +SEMANTIC SELECT reloadable.revenue FROM reloadable; +---- +225.00 + +# Test SEMANTIC MODEL active-model switching and repeated use/rewrite sequence +statement ok +SEMANTIC CREATE MODEL sql_orders (name sql_orders, table orders, primary_key order_id); + +statement ok +SEMANTIC CREATE METRIC sql_revenue AS SUM(amount); + +statement ok +SEMANTIC CREATE DIMENSION status AS status; + +query IR +SEMANTIC SELECT sql_orders.status, sql_orders.sql_revenue FROM sql_orders ORDER BY sql_orders.status; +---- +completed 150.00 +pending 75.00 + +statement ok +SEMANTIC CREATE OR REPLACE MODEL sql_orders (name sql_orders, table orders, primary_key order_id); + +statement ok +SEMANTIC MODEL sql_orders; + +statement ok +SEMANTIC CREATE OR REPLACE METRIC sql_order_count AS COUNT(*); + +query I +SEMANTIC SELECT sql_orders.sql_order_count FROM sql_orders; +---- +3 diff --git a/sidemantic-duckdb/test/sql/sidemantic_memory_and_invalid_persistence.test b/sidemantic-duckdb/test/sql/sidemantic_memory_and_invalid_persistence.test new file mode 100644 index 00000000..ed4bdf39 --- /dev/null +++ b/sidemantic-duckdb/test/sql/sidemantic_memory_and_invalid_persistence.test @@ -0,0 +1,64 @@ +# name: test/sql/sidemantic_memory_and_invalid_persistence.test +# description: Test in-memory session-local definitions and invalid persisted definitions +# group: [sidemantic] + +require sidemantic + +statement ok +SEMANTIC CREATE MODEL temp_orders (name temp_orders, table temp_orders, primary_key order_id); + +query I +SELECT count(*) FROM sidemantic_models() WHERE model_name = 'temp_orders'; +---- +1 + +restart + +require sidemantic + +query I +SELECT count(*) FROM sidemantic_models() WHERE model_name = 'temp_orders'; +---- +0 + +load __TEST_DIR__/sidemantic_invalid_persistence.duckdb + +statement ok +CREATE TABLE valid_events (event_id INT); + +statement ok +INSERT INTO valid_events VALUES (1), (2); + +statement ok +SEMANTIC CREATE MODEL valid_events (name valid_events, table valid_events, primary_key event_id); + +statement ok +SEMANTIC CREATE METRIC event_count AS COUNT(*); + +query I +SEMANTIC SELECT valid_events.event_count FROM valid_events; +---- +2 + +statement ok +COPY (SELECT 'MODEL (' AS line) TO '__TEST_DIR__/sidemantic_invalid_persistence.sidemantic.sql' (HEADER false); + +restart + +require sidemantic + +query I +SELECT count(*) FROM sidemantic_models() WHERE model_name = 'valid_events'; +---- +0 + +statement ok +SEMANTIC CREATE MODEL valid_events (name valid_events, table valid_events, primary_key event_id); + +statement ok +SEMANTIC CREATE METRIC event_count AS COUNT(*); + +query I +SEMANTIC SELECT valid_events.event_count FROM valid_events; +---- +2 diff --git a/sidemantic-duckdb/test/sql/sidemantic_persistence.test b/sidemantic-duckdb/test/sql/sidemantic_persistence.test new file mode 100644 index 00000000..358ab6ac --- /dev/null +++ b/sidemantic-duckdb/test/sql/sidemantic_persistence.test @@ -0,0 +1,61 @@ +# name: test/sql/sidemantic_persistence.test +# description: Test sidemantic extension persistence and autoload behavior +# group: [sidemantic] + +load __TEST_DIR__/sidemantic_persistence.duckdb + +require sidemantic + +statement ok +CREATE TABLE events (event_id INT, amount INT); + +statement ok +INSERT INTO events VALUES (1, 10), (2, 20), (3, 30); + +statement ok +SEMANTIC CREATE MODEL events (name events, table events, primary_key event_id); + +statement ok +SEMANTIC CREATE METRIC event_count AS COUNT(*); + +query I +SEMANTIC SELECT events.event_count FROM events; +---- +3 + +restart + +require sidemantic + +query I +SELECT * FROM sidemantic_models(); +---- +events + +query I +SEMANTIC SELECT events.event_count FROM events; +---- +3 + +statement ok +SEMANTIC CREATE OR REPLACE MODEL events (name events, table events, primary_key event_id); + +statement ok +SEMANTIC MODEL events; + +statement ok +SEMANTIC CREATE OR REPLACE METRIC total_amount AS SUM(amount); + +query I +SEMANTIC SELECT events.total_amount FROM events; +---- +60 + +restart + +require sidemantic + +query I +SEMANTIC SELECT events.total_amount FROM events; +---- +60 diff --git a/sidemantic-rs/.gitignore b/sidemantic-rs/.gitignore index b83d2226..a742227c 100644 --- a/sidemantic-rs/.gitignore +++ b/sidemantic-rs/.gitignore @@ -1 +1,2 @@ /target/ +/dist/ diff --git a/sidemantic-rs/Cargo.lock b/sidemantic-rs/Cargo.lock index a5aa5d80..7d527be4 100644 --- a/sidemantic-rs/Cargo.lock +++ b/sidemantic-rs/Cargo.lock @@ -3,246 +3,450 @@ version = 4 [[package]] -name = "aho-corasick" -version = "1.1.4" +name = "adbc_core" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +checksum = "e8dbe031527c9856a1e2df5e82aa8e568ffaab3be897f70d874477fb42a783bb" dependencies = [ - "memchr", + "arrow-array", + "arrow-schema", ] [[package]] -name = "equivalent" -version = "1.0.2" +name = "adbc_driver_manager" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +checksum = "5beaa87308b040adcf482fb760f64ced1cbcd9469fe86ddd5be221dabbbc544e" +dependencies = [ + "adbc_core", + "adbc_ffi", + "arrow-array", + "arrow-schema", + "libloading", + "toml", + "windows-registry", + "windows-sys 0.61.2", +] [[package]] -name = "hashbrown" -version = "0.16.1" +name = "adbc_ffi" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +checksum = "3600ae9aec2907516d088189e3b863029280f1953dd0eab903c7f4c862a0ce81" +dependencies = [ + "adbc_core", + "arrow-array", + "arrow-schema", +] [[package]] -name = "indexmap" -version = "2.12.1" +name = "ahash" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ - "equivalent", - "hashbrown", + "cfg-if", + "const-random", + "getrandom 0.3.4", + "once_cell", + "version_check", + "zerocopy", ] [[package]] -name = "itoa" -version = "1.0.15" +name = "aho-corasick" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] [[package]] -name = "lazy_static" -version = "1.5.0" +name = "allocator-api2" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] -name = "memchr" -version = "2.7.6" +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "arrow-array" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +checksum = "4c8955af33b25f3b175ee10af580577280b4bd01f7e823d94c7cdef7cf8c9aef" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "hashbrown 0.16.1", + "num-complex", + "num-integer", + "num-traits", +] [[package]] -name = "minimal-lexical" -version = "0.2.1" +name = "arrow-buffer" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +checksum = "c697ddca96183182f35b3a18e50b9110b11e916d7b7799cbfd4d34662f2c56c2" +dependencies = [ + "bytes", + "half", + "num-bigint", + "num-traits", +] [[package]] -name = "nom" -version = "7.1.3" +name = "arrow-data" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "1fdd994a9d28e6365aa78e15da3f3950c0fdcea6b963a12fa1c391afb637b304" dependencies = [ - "memchr", - "minimal-lexical", + "arrow-buffer", + "arrow-schema", + "half", + "num-integer", + "num-traits", ] [[package]] -name = "once_cell" -version = "1.21.3" +name = "arrow-ipc" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "abf7df950701ab528bf7c0cf7eeadc0445d03ef5d6ffc151eaae6b38a58feff1" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "flatbuffers", +] [[package]] -name = "polyglot-sql" -version = "0.1.5" +name = "arrow-schema" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50787ca43ba1cd7aa4092ac5e5669aef59ebb273c960755a6bd01cd8df0eaa20" +checksum = "8c872d36b7bf2a6a6a2b40de9156265f0242910791db366a2c17476ba8330d68" dependencies = [ - "serde", - "serde_json", - "thiserror 1.0.69", - "unicode-segmentation", + "bitflags 2.11.0", ] [[package]] -name = "proc-macro2" -version = "1.0.103" +name = "arrow-select" +version = "57.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +checksum = "68bf3e3efbd1278f770d67e5dc410257300b161b93baedb3aae836144edcaf4b" dependencies = [ - "unicode-ident", + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num-traits", ] [[package]] -name = "quote" -version = "1.0.42" +name = "async-trait" +version = "0.1.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", + "quote", + "syn", ] [[package]] -name = "regex" -version = "1.12.2" +name = "atomic-waker" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", + "proc-macro2", + "quote", + "syn", ] [[package]] -name = "regex-automata" -version = "0.4.13" +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "axum" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ - "aho-corasick", + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", "memchr", - "regex-syntax", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower 0.5.3", + "tower-layer", + "tower-service", + "tracing", ] [[package]] -name = "regex-syntax" -version = "0.8.8" +name = "axum-core" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] [[package]] -name = "ryu" -version = "1.0.20" +name = "base64" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] -name = "serde" -version = "1.0.228" +name = "bitflags" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cassowary" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df8670b8c7b9dae1793364eafadf7239c40d669904660c5960d74cfd80b46a53" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" dependencies = [ - "serde_core", - "serde_derive", + "rustversion", ] [[package]] -name = "serde_core" -version = "1.0.228" +name = "cc" +version = "1.2.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" dependencies = [ - "serde_derive", + "find-msvc-tools", + "shlex", ] [[package]] -name = "serde_derive" -version = "1.0.228" +name = "cfg-if" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" dependencies = [ - "proc-macro2", - "quote", - "syn", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", ] [[package]] -name = "serde_json" -version = "1.0.145" +name = "compact_str" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" dependencies = [ + "castaway", + "cfg-if", "itoa", - "memchr", + "rustversion", "ryu", - "serde", - "serde_core", + "static_assertions", ] [[package]] -name = "serde_yaml" -version = "0.9.34+deprecated" +name = "const-random" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" dependencies = [ - "indexmap", - "itoa", - "ryu", - "serde", - "unsafe-libyaml", + "const-random-macro", ] [[package]] -name = "sidemantic" -version = "0.1.0" +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "lazy_static", - "nom", + "getrandom 0.2.17", "once_cell", - "polyglot-sql", - "regex", - "serde", - "serde_json", - "serde_yaml", - "thiserror 2.0.17", + "tiny-keccak", ] [[package]] -name = "syn" -version = "2.0.111" +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crossterm" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" +dependencies = [ + "bitflags 2.11.0", + "crossterm_winapi", + "mio", + "parking_lot", + "rustix", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" dependencies = [ + "winapi", +] + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", "proc-macro2", "quote", - "unicode-ident", + "strsim", + "syn", ] [[package]] -name = "thiserror" -version = "1.0.69" +name = "darling_macro" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ - "thiserror-impl 1.0.69", + "darling_core", + "quote", + "syn", ] [[package]] -name = "thiserror" -version = "2.0.17" +name = "dashmap" +version = "5.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ - "thiserror-impl 2.0.17", + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", ] [[package]] -name = "thiserror-impl" -version = "1.0.69" +name = "displaydoc" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", @@ -250,10 +454,117 @@ dependencies = [ ] [[package]] -name = "thiserror-impl" -version = "2.0.17" +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flatbuffers" +version = "25.12.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35f6839d7b3b98adde531effaf34f0c2badc6f4735d26fe74709d8e513a96ef3" +dependencies = [ + "bitflags 2.11.0", + "rustc_version", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -261,19 +572,2068 @@ dependencies = [ ] [[package]] -name = "unicode-ident" -version = "1.0.22" +name = "futures-sink" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] -name = "unicode-segmentation" -version = "1.12.0" +name = "futures-task" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] -name = "unsafe-libyaml" -version = "0.2.11" +name = "futures-util" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", +] + +[[package]] +name = "indoc" +version = "2.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79cf5c93f93228cf8efb3ba362535fb11199ac548a09ce117c9b1adc3030d706" +dependencies = [ + "rustversion", +] + +[[package]] +name = "instability" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357b7205c6cd18dd2c86ed312d1e70add149aea98e7ef72b9fdf0270e555c11d" +dependencies = [ + "darling", + "indoc", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown 0.15.5", +] + +[[package]] +name = "lsp-types" +version = "0.94.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66bfd44a06ae10647fe3f8214762e9369fd4248df1350924b4ef9e770a85ea1" +dependencies = [ + "bitflags 1.3.2", + "serde", + "serde_json", + "serde_repr", + "url", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "minicov" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" +dependencies = [ + "cc", + "walkdir", +] + +[[package]] +name = "minijinja" +version = "2.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c54f3bcc034dd74496b5ca929fd0b710186672d5ff0b0f255a9ceb259042ece" +dependencies = [ + "serde", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pastey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "polyglot-sql" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bccd5cbca75edb3852dad81a34d9fa14c3abb2725d3cbc5dcb2d9e8250e4801" +dependencies = [ + "serde", + "serde_json", + "thiserror 1.0.69", + "unicode-segmentation", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.23.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "ratatui" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" +dependencies = [ + "bitflags 2.11.0", + "cassowary", + "compact_str", + "crossterm", + "indoc", + "instability", + "itertools", + "lru", + "paste", + "strum", + "unicode-segmentation", + "unicode-truncate", + "unicode-width 0.2.0", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags 2.11.0", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + +[[package]] +name = "rmcp" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc4c9c94680f75470ee8083a0667988b5d7b5beb70b9f998a8e51de7c682ce60" +dependencies = [ + "async-trait", + "base64", + "chrono", + "futures", + "pastey", + "pin-project-lite", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90c23c8f26cae4da838fbc3eadfaecf2d549d97c04b558e7bd90526a9c28b42a" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "chrono", + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_spanned" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +dependencies = [ + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_yaml" +version = "0.9.34+deprecated" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" +dependencies = [ + "indexmap", + "itoa", + "ryu", + "serde", + "unsafe-libyaml", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "sidemantic" +version = "0.1.0" +dependencies = [ + "adbc_core", + "adbc_driver_manager", + "arrow-array", + "arrow-ipc", + "arrow-schema", + "axum", + "chrono", + "crossterm", + "lazy_static", + "minijinja", + "nom", + "once_cell", + "polyglot-sql", + "pyo3", + "ratatui", + "regex", + "rmcp", + "serde", + "serde_json", + "serde_yaml", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tower-lsp", + "wasm-bindgen", + "wasm-bindgen-test", +] + +[[package]] +name = "signal-hook" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "toml" +version = "0.9.12+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf92845e79fc2e2def6a5d828f0801e29a2f8acc037becc5ab08595c7d5e9863" +dependencies = [ + "serde_spanned", + "toml_datetime", + "toml_parser", + "toml_writer", + "winnow", +] + +[[package]] +name = "toml_datetime" +version = "0.7.5+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92e1cfed4a3038bc5a127e35a2d360f145e1f4b971b551a2ba5fd7aedf7e1347" +dependencies = [ + "serde_core", +] + +[[package]] +name = "toml_parser" +version = "1.0.9+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +dependencies = [ + "winnow", +] + +[[package]] +name = "toml_writer" +version = "1.0.6+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-lsp" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4ba052b54a6627628d9b3c34c176e7eda8359b7da9acd497b9f20998d118508" +dependencies = [ + "async-trait", + "auto_impl", + "bytes", + "dashmap", + "futures", + "httparse", + "lsp-types", + "memchr", + "serde", + "serde_json", + "tokio", + "tokio-util", + "tower 0.4.13", + "tower-lsp-macros", + "tracing", +] + +[[package]] +name = "tower-lsp-macros" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84fd902d4e0b9a4b27f2f440108dc034e1758628a9b702f8ec61ad66355422fa" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-truncate" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3644627a5af5fa321c95b9b235a72fd24cd29c648c2c379431e6628655627bf" +dependencies = [ + "itertools", + "unicode-segmentation", + "unicode-width 0.1.14", +] + +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "unsafe-libyaml" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", + "serde_derive", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da" +dependencies = [ + "cfg-if", + "futures-util", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-bindgen-test" +version = "0.3.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d46bdcdbd11d994eb7d595fc5c6d1cab2bcd4374377e012f7e1691ac7f79e8ca" +dependencies = [ + "async-trait", + "cast", + "js-sys", + "libm", + "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e52641fa14e747360653ca52b5404268fec20badd1c097d3d2740ef121a09b09" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343a5d3439d1d59fd0fc5ae509e18c148db96777032deb880b59b2d72e9c55ad" + +[[package]] +name = "web-sys" +version = "0.3.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/sidemantic-rs/Cargo.toml b/sidemantic-rs/Cargo.toml index bd653ec2..327722c9 100644 --- a/sidemantic-rs/Cargo.toml +++ b/sidemantic-rs/Cargo.toml @@ -3,12 +3,29 @@ name = "sidemantic" version = "0.1.0" edition = "2021" description = "A SQL-first semantic layer in Rust" +license = "AGPL-3.0-only" +repository = "https://github.com/sidequery/sidemantic" +homepage = "https://sidemantic.com" [lib] -crate-type = ["rlib", "staticlib"] +crate-type = ["rlib", "staticlib", "cdylib"] + +[features] +default = [] +python = ["dep:pyo3"] +python-adbc = ["python", "adbc-exec"] +adbc-exec = ["dep:adbc_driver_manager", "dep:adbc_core", "dep:arrow-array", "dep:arrow-schema", "dep:arrow-ipc"] +wasm = ["dep:wasm-bindgen"] +mcp-server = ["dep:rmcp", "dep:tokio"] +mcp-adbc = ["mcp-server", "adbc-exec"] +runtime-server = ["dep:axum", "dep:tokio"] +runtime-server-adbc = ["runtime-server", "adbc-exec", "dep:tokio-stream"] +runtime-lsp = ["dep:tokio", "dep:tower-lsp"] +workbench-tui = ["dep:ratatui", "dep:crossterm"] +workbench-adbc = ["workbench-tui", "adbc-exec"] [dependencies] -polyglot-sql = "0.1" +polyglot-sql = "0.1.4" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" serde_yaml = "0.9" @@ -17,3 +34,46 @@ once_cell = "1.19" regex = "1.10" lazy_static = "1.5" nom = "7.1" +pyo3 = { version = "0.23", optional = true, features = ["abi3-py311", "extension-module"] } +adbc_driver_manager = { version = "0.22.0", optional = true } +adbc_core = { version = "0.22.0", optional = true } +arrow-array = { version = ">=53.1.0, <58", default-features = false, optional = true } +arrow-schema = { version = ">=53.1.0, <58", default-features = false, optional = true } +arrow-ipc = { version = ">=53.1.0, <58", default-features = false, optional = true } +chrono = { version = "0.4", default-features = false, features = ["clock"] } +minijinja = { version = "2.5", features = ["builtins", "serde"] } +wasm-bindgen = { version = "0.2", optional = true } +rmcp = { version = "0.16.0", optional = true, features = ["transport-io"] } +tokio = { version = "1", optional = true, features = ["rt-multi-thread", "macros", "io-std", "sync"] } +tokio-stream = { version = "0.1", optional = true } +axum = { version = "0.8.1", optional = true, features = ["json"] } +tower-lsp = { version = "0.20.0", optional = true } +ratatui = { version = "0.29.0", optional = true } +crossterm = { version = "0.28.1", optional = true } + +[dev-dependencies] +wasm-bindgen-test = "0.3" + +[[bin]] +name = "sidemantic-mcp" +path = "src/bin/sidemantic-mcp.rs" +required-features = ["mcp-server"] + +[[bin]] +name = "sidemantic-server" +path = "src/bin/sidemantic-server.rs" +required-features = ["runtime-server"] + +[[bin]] +name = "sidemantic-lsp" +path = "src/bin/sidemantic-lsp.rs" +required-features = ["runtime-lsp"] + +[[bin]] +name = "sidemantic" +path = "src/main.rs" + +[[bin]] +name = "sidemantic-workbench" +path = "src/bin/sidemantic-workbench.rs" +required-features = ["workbench-tui"] diff --git a/sidemantic-rs/include/sidemantic.h b/sidemantic-rs/include/sidemantic.h index 6dcd5cfa..0f996c0c 100644 --- a/sidemantic-rs/include/sidemantic.h +++ b/sidemantic-rs/include/sidemantic.h @@ -27,6 +27,7 @@ typedef struct { * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_load_yaml(const char *yaml); +char *sidemantic_load_yaml_for_context(const char *context, const char *yaml); /* * Load semantic models from a file or directory path. @@ -36,11 +37,13 @@ char *sidemantic_load_yaml(const char *yaml); * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_load_file(const char *path); +char *sidemantic_load_file_for_context(const char *context, const char *path); /* * Clear all loaded semantic models. */ void sidemantic_clear(void); +void sidemantic_clear_for_context(const char *context); /* * Define a semantic model from SQL definition format. @@ -48,14 +51,15 @@ void sidemantic_clear(void); * Parses the definition, saves to file, and loads into current session. * If `replace` is true, removes any existing model with the same name from the file. * - * db_path: Path to the database file (NULL for in-memory). + * db_path: Path to the database file (NULL for in-memory/session-local). * - If db_path is "foo.duckdb", definitions are saved to "foo.sidemantic.sql" - * - If db_path is NULL or ":memory:", definitions are saved to "./sidemantic_definitions.sql" + * - If db_path is NULL or ":memory:", definitions are not persisted * * Returns NULL on success, error message on failure. * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_define(const char *definition_sql, const char *db_path, bool replace); +char *sidemantic_define_for_context(const char *context, const char *definition_sql, const char *db_path, bool replace); /* * Auto-load definitions from file if it exists. @@ -67,6 +71,7 @@ char *sidemantic_define(const char *definition_sql, const char *db_path, bool re * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_autoload(const char *db_path); +char *sidemantic_autoload_for_context(const char *context, const char *db_path); /* * Add a metric/dimension/segment to a model. @@ -85,6 +90,7 @@ char *sidemantic_autoload(const char *db_path); * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_add_definition(const char *definition_sql, const char *db_path, bool is_replace); +char *sidemantic_add_definition_for_context(const char *context, const char *definition_sql, const char *db_path, bool is_replace); /* * Set the active model for subsequent METRIC/DIMENSION/SEGMENT additions. @@ -95,11 +101,13 @@ char *sidemantic_add_definition(const char *definition_sql, const char *db_path, * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_use(const char *model_name); +char *sidemantic_use_for_context(const char *context, const char *model_name); /* * Check if a table name is a registered semantic model. */ bool sidemantic_is_model(const char *table_name); +bool sidemantic_is_model_for_context(const char *context, const char *table_name); /* * Get list of registered model names (comma-separated). @@ -107,6 +115,7 @@ bool sidemantic_is_model(const char *table_name); * Caller must free the returned string with sidemantic_free(). */ char *sidemantic_list_models(void); +char *sidemantic_list_models_for_context(const char *context); /* * Rewrite a SQL query using semantic definitions. @@ -115,6 +124,7 @@ char *sidemantic_list_models(void); * Caller must free with sidemantic_free_result(). */ SidemanticRewriteResult sidemantic_rewrite(const char *sql); +SidemanticRewriteResult sidemantic_rewrite_for_context(const char *context, const char *sql); /* * Free a string returned by sidemantic functions. diff --git a/sidemantic-rs/pyproject.toml b/sidemantic-rs/pyproject.toml new file mode 100644 index 00000000..824482b9 --- /dev/null +++ b/sidemantic-rs/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["maturin>=1.7,<2.0"] +build-backend = "maturin" + +[project] +name = "sidemantic-rs" +version = "0.1.0" +description = "Standalone Rust runtime bindings for Sidemantic" +license = "AGPL-3.0-only" +requires-python = ">=3.11" + +[tool.maturin] +module-name = "sidemantic_rs" +features = ["python-adbc"] diff --git a/sidemantic-rs/src/bin/sidemantic-demo.rs b/sidemantic-rs/src/bin/sidemantic-demo.rs new file mode 100644 index 00000000..820a5e4a --- /dev/null +++ b/sidemantic-rs/src/bin/sidemantic-demo.rs @@ -0,0 +1,200 @@ +//! Example usage of sidemantic-rs + +use sidemantic::sql::{QueryRewriter, SemanticQuery, SqlGenerator}; +use sidemantic::{ + load_from_string, Dimension, Metric, Model, Relationship, Segment, SemanticGraph, +}; + +fn main() { + println!("=== Sidemantic-rs Demo ===\n"); + + // Demo 1: Programmatic API + demo_programmatic_api(); + + // Demo 2: YAML Loading (native format) + demo_yaml_loading(); + + // Demo 3: Cube.js Format + demo_cube_format(); + + // Demo 4: Segments + demo_segments(); + + // Demo 5: Query Rewriter + demo_query_rewriter(); +} + +fn demo_programmatic_api() { + println!("--- 1. Programmatic API ---\n"); + + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_dimension(Dimension::categorical("status")) + .with_dimension(Dimension::time("order_date").with_sql("created_at")) + .with_metric(Metric::sum("revenue", "amount")) + .with_metric(Metric::count("order_count")) + .with_metric(Metric::avg("avg_order_value", "amount")) + .with_relationship(Relationship::many_to_one("customers")); + + let customers = Model::new("customers", "id") + .with_table("customers") + .with_dimension(Dimension::categorical("name")) + .with_dimension(Dimension::categorical("country")); + + graph.add_model(orders).unwrap(); + graph.add_model(customers).unwrap(); + + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into(), "orders.order_count".into()]) + .with_dimensions(vec!["orders.status".into()]); + + println!("Query: revenue and order_count by status"); + println!("{}\n", generator.generate(&query).unwrap()); +} + +fn demo_yaml_loading() { + println!("--- 2. YAML Loading (Native Format) ---\n"); + + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: order_date + type: time + sql: created_at + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count +"#; + + let graph = load_from_string(yaml).unwrap(); + let generator = SqlGenerator::new(&graph); + + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into()]) + .with_dimensions(vec!["orders.status".into()]); + + println!("Loaded from YAML:"); + println!("{}\n", generator.generate(&query).unwrap()); +} + +fn demo_cube_format() { + println!("--- 3. Cube.js Format ---\n"); + + let yaml = r#" +cubes: + - name: orders + sql_table: public.orders + + dimensions: + - name: status + sql: "${CUBE}.status" + type: string + - name: created_at + sql: "${CUBE}.created_at" + type: time + + measures: + - name: revenue + sql: "${CUBE}.amount" + type: sum + - name: order_count + type: count + + segments: + - name: completed + sql: "${CUBE}.status = 'completed'" +"#; + + let graph = load_from_string(yaml).unwrap(); + let model = graph.get_model("orders").unwrap(); + + println!("Converted from Cube.js format:"); + println!(" Table: {:?}", model.table); + println!( + " Dimensions: {:?}", + model.dimensions.iter().map(|d| &d.name).collect::>() + ); + println!( + " Metrics: {:?}", + model.metrics.iter().map(|m| &m.name).collect::>() + ); + println!( + " Segments: {:?}\n", + model.segments.iter().map(|s| &s.name).collect::>() + ); + + let generator = SqlGenerator::new(&graph); + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into()]) + .with_dimensions(vec!["orders.status".into()]); + + println!("Generated SQL:"); + println!("{}\n", generator.generate(&query).unwrap()); +} + +fn demo_segments() { + println!("--- 4. Segments (Reusable Filters) ---\n"); + + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_dimension(Dimension::categorical("status")) + .with_metric(Metric::sum("revenue", "amount")) + .with_segment(Segment::new("completed", "{model}.status = 'completed'")) + .with_segment(Segment::new("high_value", "{model}.amount > 100")); + + graph.add_model(orders).unwrap(); + + let generator = SqlGenerator::new(&graph); + + // Query with segment + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into()]) + .with_segments(vec!["orders.completed".into()]); + + println!("Query with 'completed' segment:"); + println!("{}\n", generator.generate(&query).unwrap()); + + // Query with multiple segments + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".into()]) + .with_segments(vec!["orders.completed".into(), "orders.high_value".into()]); + + println!("Query with multiple segments:"); + println!("{}\n", generator.generate(&query).unwrap()); +} + +fn demo_query_rewriter() { + println!("--- 5. Query Rewriter ---\n"); + + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("public.orders") + .with_dimension(Dimension::categorical("status")) + .with_metric(Metric::sum("revenue", "amount")) + .with_metric(Metric::count("order_count")); + + graph.add_model(orders).unwrap(); + + let rewriter = QueryRewriter::new(&graph); + + let sql = "SELECT orders.revenue, orders.status FROM orders WHERE orders.status = 'pending'"; + println!("Original SQL:"); + println!(" {sql}\n"); + println!("Rewritten SQL:"); + println!("{}\n", rewriter.rewrite(sql).unwrap()); +} diff --git a/sidemantic-rs/src/bin/sidemantic-lsp.rs b/sidemantic-rs/src/bin/sidemantic-lsp.rs new file mode 100644 index 00000000..36676144 --- /dev/null +++ b/sidemantic-rs/src/bin/sidemantic-lsp.rs @@ -0,0 +1,874 @@ +//! Rust-native LSP server for Sidemantic SQL definitions. + +use std::collections::HashMap; +use std::sync::Arc; + +use serde_json::Value as JsonValue; +use sidemantic::runtime::parse_sql_statement_blocks_payload; +use tokio::sync::RwLock; +use tower_lsp::jsonrpc::Result; +use tower_lsp::lsp_types::{ + CodeAction, CodeActionKind, CodeActionOrCommand, CodeActionParams, + CodeActionProviderCapability, CodeActionResponse, CompletionItem, CompletionItemKind, + CompletionOptions, CompletionParams, CompletionResponse, Diagnostic, DiagnosticSeverity, + DidChangeTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams, + DocumentFormattingParams, DocumentSymbol, DocumentSymbolParams, DocumentSymbolResponse, + GotoDefinitionParams, GotoDefinitionResponse, Hover, HoverContents, HoverParams, + InitializeParams, InitializeResult, InitializedParams, Location, MarkupContent, MarkupKind, + MessageType, OneOf, ParameterInformation, ParameterLabel, Position, Range, ReferenceParams, + RenameParams, ServerCapabilities, SignatureHelp, SignatureHelpOptions, SignatureHelpParams, + SignatureInformation, SymbolKind, TextDocumentSyncCapability, TextDocumentSyncKind, TextEdit, + Url, WorkspaceEdit, +}; +use tower_lsp::{Client, LanguageServer, LspService, Server}; + +const KEYWORDS: &[&str] = &[ + "MODEL", + "DIMENSION", + "METRIC", + "RELATIONSHIP", + "SEGMENT", + "PARAMETER", + "PRE_AGGREGATION", +]; + +fn model_properties(def_type: &str) -> &'static [(&'static str, &'static str)] { + match def_type { + "MODEL" => &[ + ("name", "Unique model name."), + ("table", "Physical table for the model."), + ("sql", "Optional model SQL."), + ("primary_key", "Primary key column."), + ("default_time_dimension", "Default time dimension."), + ], + "DIMENSION" => &[ + ("name", "Unique dimension name."), + ( + "type", + "Dimension type (categorical, time, number, boolean).", + ), + ("sql", "Dimension SQL expression."), + ("description", "Human description."), + ("granularity", "Time granularity for time dimensions."), + ], + "METRIC" => &[ + ("name", "Unique metric name."), + ("agg", "Aggregation type."), + ("sql", "Metric SQL expression."), + ("type", "Metric semantic type."), + ("description", "Human description."), + ], + "RELATIONSHIP" => &[ + ("name", "Target model name."), + ("type", "Relationship type."), + ("foreign_key", "Foreign key column."), + ("primary_key", "Primary key column."), + ("through", "Junction model for many_to_many."), + ], + "SEGMENT" => &[ + ("name", "Unique segment name."), + ("sql", "Segment filter SQL."), + ("description", "Human description."), + ("public", "Visibility flag."), + ], + "PARAMETER" => &[ + ("name", "Unique parameter name."), + ("type", "Parameter type."), + ("default", "Default parameter value."), + ("description", "Human description."), + ], + "PRE_AGGREGATION" => &[ + ("name", "Unique pre-aggregation name."), + ("type", "Pre-aggregation type."), + ("measures", "Metric references included in the rollup."), + ("dimensions", "Dimension references included in the rollup."), + ("time_dimension", "Time dimension for granular rollups."), + ("granularity", "Time granularity."), + ], + _ => &[], + } +} + +fn keyword_doc(keyword: &str) -> Option<&'static str> { + match keyword { + "MODEL" => Some("Top-level model definition."), + "DIMENSION" => Some("Dimension definition inside or alongside a model."), + "METRIC" => Some("Metric definition."), + "RELATIONSHIP" => Some("Relationship definition between models."), + "SEGMENT" => Some("Reusable filter segment."), + "PARAMETER" => Some("Runtime query parameter definition."), + "PRE_AGGREGATION" => Some("Pre-aggregation definition for query routing."), + _ => None, + } +} + +fn symbol_kind_for_def(def_type: &str) -> SymbolKind { + match def_type { + "MODEL" => SymbolKind::CLASS, + "DIMENSION" => SymbolKind::FIELD, + "METRIC" => SymbolKind::FUNCTION, + "RELATIONSHIP" => SymbolKind::INTERFACE, + "SEGMENT" => SymbolKind::BOOLEAN, + "PARAMETER" => SymbolKind::VARIABLE, + "PRE_AGGREGATION" => SymbolKind::STRUCT, + _ => SymbolKind::OBJECT, + } +} + +fn get_completion_context(text: &str, line: u32, character: u32) -> String { + let lines: Vec<&str> = text.split('\n').collect(); + let mut paren_depth: i32 = 0; + let mut current_def: Option<&str> = None; + let max_line = line as usize; + + let mut i = max_line as i32; + while i >= 0 { + let idx = i as usize; + let check_line = if idx == max_line { + let ch = character as usize; + lines + .get(idx) + .map(|line_text| { + if ch <= line_text.len() { + &line_text[..ch] + } else { + *line_text + } + }) + .unwrap_or("") + } else { + lines.get(idx).copied().unwrap_or("") + }; + + for ch in check_line.chars().rev() { + if ch == ')' { + paren_depth += 1; + } else if ch == '(' { + paren_depth -= 1; + } + } + + let full_line_upper = lines + .get(idx) + .map(|line_text| line_text.to_ascii_uppercase()) + .unwrap_or_default(); + for keyword in KEYWORDS { + if full_line_upper.contains(keyword) && full_line_upper.contains('(') { + if paren_depth < 0 { + return format!("inside_{}", keyword.to_ascii_lowercase()); + } + current_def = Some(keyword); + } + } + + i -= 1; + } + + if paren_depth < 0 { + if let Some(keyword) = current_def { + return format!("inside_{}", keyword.to_ascii_lowercase()); + } + } + + "top_level".to_string() +} + +fn get_word_at_position(text: &str, line: u32, character: u32) -> Option { + let lines: Vec<&str> = text.split('\n').collect(); + let line_text = lines.get(line as usize)?; + if line_text.is_empty() { + return None; + } + let mut ch = character as usize; + if ch >= line_text.len() { + ch = line_text.len().saturating_sub(1); + } + + let bytes = line_text.as_bytes(); + if !(bytes[ch].is_ascii_alphanumeric() || bytes[ch] == b'_') { + return None; + } + + let mut start = ch; + let mut end = ch; + + while start > 0 && (bytes[start - 1].is_ascii_alphanumeric() || bytes[start - 1] == b'_') { + start -= 1; + } + while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') { + end += 1; + } + + Some(line_text[start..end].to_string()) +} + +fn diagnostics_for_text(text: &str) -> Vec { + match parse_sql_statement_blocks_payload(text) { + Ok(_) => Vec::new(), + Err(e) => vec![Diagnostic { + range: Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 0, + character: 100, + }, + }, + severity: Some(DiagnosticSeverity::ERROR), + code: None, + code_description: None, + source: Some("sidemantic-rs".to_string()), + message: format!("Parse error: {e}"), + related_information: None, + tags: None, + data: None, + }], + } +} + +#[derive(Debug, Clone)] +struct DefinitionInfo { + def_type: String, + name: String, + range: Range, + selection_range: Range, +} + +fn offset_to_position(text: &str, offset: usize) -> Position { + let mut line = 0u32; + let mut character = 0u32; + for (idx, ch) in text.char_indices() { + if idx >= offset { + break; + } + if ch == '\n' { + line += 1; + character = 0; + } else { + character += 1; + } + } + Position { line, character } +} + +fn range_for_offsets(text: &str, start: usize, end: usize) -> Range { + Range { + start: offset_to_position(text, start), + end: offset_to_position(text, end), + } +} + +fn find_matching_definition_end(text: &str, open_offset: usize) -> usize { + let mut depth = 0i32; + for (offset, ch) in text[open_offset..].char_indices() { + match ch { + '(' => depth += 1, + ')' => { + depth -= 1; + if depth == 0 { + let after_close = open_offset + offset + ch.len_utf8(); + let rest = &text[after_close..]; + let semicolon_len = rest.chars().take_while(|ch| ch.is_whitespace()).count() + + usize::from(rest.trim_start().starts_with(';')); + return after_close + semicolon_len; + } + } + _ => {} + } + } + text.len() +} + +fn find_definition_name(block: &str, base_offset: usize) -> Option<(String, usize, usize)> { + let bytes = block.as_bytes(); + let mut idx = 0usize; + while idx + 4 <= bytes.len() { + if !block[idx..].starts_with("name") { + idx += 1; + continue; + } + let before_ok = idx == 0 || !bytes[idx - 1].is_ascii_alphanumeric(); + let after = idx + 4; + let after_ok = after < bytes.len() && bytes[after].is_ascii_whitespace(); + if !before_ok || !after_ok { + idx += 1; + continue; + } + let mut start = after; + while start < bytes.len() && bytes[start].is_ascii_whitespace() { + start += 1; + } + if start >= bytes.len() { + return None; + } + + let (name_start, name_end, name) = if bytes[start] == b'"' || bytes[start] == b'\'' { + let quote = bytes[start]; + let name_start = start + 1; + let mut name_end = name_start; + while name_end < bytes.len() && bytes[name_end] != quote { + name_end += 1; + } + ( + name_start, + name_end, + block[name_start..name_end].to_string(), + ) + } else { + let name_start = start; + let mut name_end = name_start; + while name_end < bytes.len() + && (bytes[name_end].is_ascii_alphanumeric() || bytes[name_end] == b'_') + { + name_end += 1; + } + ( + name_start, + name_end, + block[name_start..name_end].to_string(), + ) + }; + if !name.is_empty() { + return Some((name, base_offset + name_start, base_offset + name_end)); + } + idx += 1; + } + None +} + +fn extract_definitions(text: &str) -> Vec { + let mut definitions = Vec::new(); + let upper = text.to_ascii_uppercase(); + let mut cursor = 0usize; + while cursor < text.len() { + let mut next_match: Option<(&str, usize)> = None; + for keyword in KEYWORDS { + if let Some(relative) = upper[cursor..].find(keyword) { + let offset = cursor + relative; + let before_ok = offset == 0 + || !upper.as_bytes()[offset.saturating_sub(1)].is_ascii_alphanumeric(); + let after_keyword = offset + keyword.len(); + let after_ok = upper[after_keyword..].trim_start().starts_with('('); + if before_ok + && after_ok + && next_match + .as_ref() + .is_none_or(|(_, current_offset)| offset < *current_offset) + { + next_match = Some((keyword, offset)); + } + } + } + + let Some((keyword, start_offset)) = next_match else { + break; + }; + let open_offset = text[start_offset..] + .find('(') + .map(|relative| start_offset + relative) + .unwrap_or(start_offset + keyword.len()); + let end_offset = find_matching_definition_end(text, open_offset); + let block = &text[start_offset..end_offset.min(text.len())]; + if let Some((name, name_start, name_end)) = find_definition_name(block, start_offset) { + definitions.push(DefinitionInfo { + def_type: keyword.to_string(), + name, + range: range_for_offsets(text, start_offset, end_offset), + selection_range: range_for_offsets(text, name_start, name_end), + }); + } + cursor = end_offset.max(start_offset + keyword.len()); + } + definitions +} + +fn format_sidemantic_document(text: &str) -> Option { + let payload = parse_sql_statement_blocks_payload(text).ok()?; + let blocks: JsonValue = serde_json::from_str(&payload).ok()?; + let blocks = blocks.as_array()?; + let mut formatted_blocks = Vec::new(); + for block in blocks { + let kind = block.get("kind")?.as_str()?; + let properties = block.get("properties")?.as_object()?; + let keyword = kind.to_ascii_uppercase(); + let keyword = if keyword == "PRE_AGGREGATION" { + "PRE_AGGREGATION".to_string() + } else { + keyword + }; + let mut lines = Vec::new(); + for (idx, (key, value)) in properties.iter().enumerate() { + let value_text = match value { + JsonValue::String(value) => value.clone(), + other => other.to_string(), + }; + let comma = if idx + 1 == properties.len() { "" } else { "," }; + lines.push(format!(" {key} {value_text}{comma}")); + } + formatted_blocks.push(format!("{keyword} (\n{}\n);", lines.join("\n"))); + } + Some(format!("{}\n", formatted_blocks.join("\n\n"))) +} + +fn signature_help_for_context(context: &str, word: Option<&str>) -> Option { + let keyword = context + .strip_prefix("inside_") + .map(|def_type| def_type.to_ascii_uppercase()) + .or_else(|| word.map(|word| word.to_ascii_uppercase()))?; + if !KEYWORDS.contains(&keyword.as_str()) { + return None; + } + let props = model_properties(&keyword) + .iter() + .map(|(name, _)| format!("{name}: value")) + .collect::>(); + let label = format!("{}({})", keyword, props.join(", ")); + let parameters = props + .iter() + .map(|label| ParameterInformation { + label: ParameterLabel::Simple(label.clone()), + documentation: None, + }) + .collect::>(); + Some(SignatureHelp { + signatures: vec![SignatureInformation { + label, + documentation: None, + parameters: Some(parameters), + active_parameter: None, + }], + active_signature: Some(0), + active_parameter: Some(0), + }) +} + +fn reference_locations( + uri: &Url, + text: &str, + word: &str, + include_declaration: bool, +) -> Vec { + if KEYWORDS.contains(&word.to_ascii_uppercase().as_str()) { + return Vec::new(); + } + let definitions = extract_definitions(text); + let declaration_ranges = definitions + .iter() + .filter(|definition| definition.name.eq_ignore_ascii_case(word)) + .map(|definition| definition.selection_range) + .collect::>(); + let mut locations = Vec::new(); + let mut cursor = 0usize; + while let Some(relative) = text[cursor..] + .to_ascii_lowercase() + .find(&word.to_ascii_lowercase()) + { + let start = cursor + relative; + let end = start + word.len(); + let before_ok = start == 0 || !text.as_bytes()[start - 1].is_ascii_alphanumeric(); + let after_ok = end == text.len() || !text.as_bytes()[end].is_ascii_alphanumeric(); + if before_ok && after_ok { + let range = range_for_offsets(text, start, end); + let is_declaration = declaration_ranges.contains(&range); + if include_declaration || !is_declaration { + locations.push(Location { + uri: uri.clone(), + range, + }); + } + } + cursor = end; + } + locations +} + +struct Backend { + client: Client, + documents: Arc>>, +} + +impl Backend { + async fn publish_diagnostics_for_uri(&self, uri: &Url) { + let docs = self.documents.read().await; + if let Some(text) = docs.get(uri) { + let diagnostics = diagnostics_for_text(text); + self.client + .publish_diagnostics(uri.clone(), diagnostics, None) + .await; + } + } +} + +#[tower_lsp::async_trait] +impl LanguageServer for Backend { + async fn initialize(&self, _: InitializeParams) -> Result { + Ok(InitializeResult { + server_info: None, + capabilities: ServerCapabilities { + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), + completion_provider: Some(CompletionOptions::default()), + hover_provider: Some(tower_lsp::lsp_types::HoverProviderCapability::Simple(true)), + document_formatting_provider: Some(OneOf::Left(true)), + document_symbol_provider: Some(OneOf::Left(true)), + definition_provider: Some(OneOf::Left(true)), + references_provider: Some(OneOf::Left(true)), + rename_provider: Some(OneOf::Left(true)), + signature_help_provider: Some(SignatureHelpOptions::default()), + code_action_provider: Some(CodeActionProviderCapability::Simple(true)), + ..ServerCapabilities::default() + }, + }) + } + + async fn initialized(&self, _: InitializedParams) { + self.client + .log_message(MessageType::INFO, "sidemantic-rs LSP initialized") + .await; + } + + async fn shutdown(&self) -> Result<()> { + Ok(()) + } + + async fn did_open(&self, params: DidOpenTextDocumentParams) { + { + let mut docs = self.documents.write().await; + docs.insert( + params.text_document.uri.clone(), + params.text_document.text.clone(), + ); + } + self.publish_diagnostics_for_uri(¶ms.text_document.uri) + .await; + } + + async fn did_change(&self, params: DidChangeTextDocumentParams) { + if let Some(change) = params.content_changes.first() { + let mut docs = self.documents.write().await; + docs.insert(params.text_document.uri.clone(), change.text.clone()); + } + self.publish_diagnostics_for_uri(¶ms.text_document.uri) + .await; + } + + async fn did_save(&self, params: DidSaveTextDocumentParams) { + self.publish_diagnostics_for_uri(¶ms.text_document.uri) + .await; + } + + async fn completion(&self, params: CompletionParams) -> Result> { + let docs = self.documents.read().await; + let Some(text) = docs.get(¶ms.text_document_position.text_document.uri) else { + return Ok(Some(CompletionResponse::Array(Vec::new()))); + }; + + let context = get_completion_context( + text, + params.text_document_position.position.line, + params.text_document_position.position.character, + ); + + let mut items = Vec::new(); + if context == "top_level" { + for keyword in KEYWORDS { + items.push(CompletionItem { + label: (*keyword).to_string(), + kind: Some(CompletionItemKind::KEYWORD), + detail: Some("Sidemantic definition".to_string()), + insert_text: Some(format!("{keyword} (\n name $1,\n $0\n);")), + insert_text_format: Some(tower_lsp::lsp_types::InsertTextFormat::SNIPPET), + ..CompletionItem::default() + }); + } + } else if let Some(def_type) = context.strip_prefix("inside_") { + let upper = def_type.to_ascii_uppercase(); + for (prop, description) in model_properties(&upper) { + items.push(CompletionItem { + label: (*prop).to_string(), + kind: Some(CompletionItemKind::PROPERTY), + detail: Some((*description).to_string()), + insert_text: Some(format!("{prop} $0,")), + insert_text_format: Some(tower_lsp::lsp_types::InsertTextFormat::SNIPPET), + ..CompletionItem::default() + }); + } + } + + Ok(Some(CompletionResponse::Array(items))) + } + + async fn hover(&self, params: HoverParams) -> Result> { + let docs = self.documents.read().await; + let Some(text) = docs.get(¶ms.text_document_position_params.text_document.uri) else { + return Ok(None); + }; + let pos = params.text_document_position_params.position; + let Some(word) = get_word_at_position(text, pos.line, pos.character) else { + return Ok(None); + }; + + let word_upper = word.to_ascii_uppercase(); + if let Some(doc) = keyword_doc(&word_upper) { + return Ok(Some(Hover { + contents: HoverContents::Markup(MarkupContent { + kind: MarkupKind::Markdown, + value: format!("**{word_upper}**\n\n{doc}"), + }), + range: None, + })); + } + + let context = get_completion_context(text, pos.line, pos.character); + if let Some(def_type) = context.strip_prefix("inside_") { + let upper = def_type.to_ascii_uppercase(); + for (prop, description) in model_properties(&upper) { + if prop.eq_ignore_ascii_case(&word) { + return Ok(Some(Hover { + contents: HoverContents::Markup(MarkupContent { + kind: MarkupKind::Markdown, + value: format!("**{prop}**\n\n{description}"), + }), + range: None, + })); + } + } + } + + Ok(None) + } + + async fn formatting(&self, params: DocumentFormattingParams) -> Result>> { + let docs = self.documents.read().await; + let Some(text) = docs.get(¶ms.text_document.uri) else { + return Ok(None); + }; + let Some(formatted) = format_sidemantic_document(text) else { + return Ok(None); + }; + let end = offset_to_position(text, text.len()); + Ok(Some(vec![TextEdit { + range: Range { + start: Position { + line: 0, + character: 0, + }, + end, + }, + new_text: formatted, + }])) + } + + async fn document_symbol( + &self, + params: DocumentSymbolParams, + ) -> Result> { + let docs = self.documents.read().await; + let Some(text) = docs.get(¶ms.text_document.uri) else { + return Ok(Some(DocumentSymbolResponse::Nested(Vec::new()))); + }; + let symbols = extract_definitions(text) + .into_iter() + .map(|definition| { + #[allow(deprecated)] + DocumentSymbol { + name: definition.name, + detail: Some(definition.def_type.clone()), + kind: symbol_kind_for_def(&definition.def_type), + tags: None, + deprecated: None, + range: definition.range, + selection_range: definition.selection_range, + children: None, + } + }) + .collect::>(); + Ok(Some(DocumentSymbolResponse::Nested(symbols))) + } + + async fn signature_help(&self, params: SignatureHelpParams) -> Result> { + let docs = self.documents.read().await; + let Some(text) = docs.get(¶ms.text_document_position_params.text_document.uri) else { + return Ok(None); + }; + let pos = params.text_document_position_params.position; + let context = get_completion_context(text, pos.line, pos.character); + let word = get_word_at_position(text, pos.line, pos.character); + Ok(signature_help_for_context(&context, word.as_deref())) + } + + async fn goto_definition( + &self, + params: GotoDefinitionParams, + ) -> Result> { + let docs = self.documents.read().await; + let uri = params.text_document_position_params.text_document.uri; + let Some(text) = docs.get(&uri) else { + return Ok(None); + }; + let pos = params.text_document_position_params.position; + let Some(word) = get_word_at_position(text, pos.line, pos.character) else { + return Ok(None); + }; + let Some(definition) = extract_definitions(text) + .into_iter() + .find(|definition| definition.name.eq_ignore_ascii_case(&word)) + else { + return Ok(None); + }; + Ok(Some(GotoDefinitionResponse::Scalar(Location { + uri, + range: definition.selection_range, + }))) + } + + async fn references(&self, params: ReferenceParams) -> Result>> { + let docs = self.documents.read().await; + let uri = params.text_document_position.text_document.uri; + let Some(text) = docs.get(&uri) else { + return Ok(Some(Vec::new())); + }; + let pos = params.text_document_position.position; + let Some(word) = get_word_at_position(text, pos.line, pos.character) else { + return Ok(Some(Vec::new())); + }; + Ok(Some(reference_locations( + &uri, + text, + &word, + params.context.include_declaration, + ))) + } + + async fn rename(&self, params: RenameParams) -> Result> { + let docs = self.documents.read().await; + let uri = params.text_document_position.text_document.uri; + let Some(text) = docs.get(&uri) else { + return Ok(None); + }; + let pos = params.text_document_position.position; + let Some(word) = get_word_at_position(text, pos.line, pos.character) else { + return Ok(None); + }; + if KEYWORDS.contains(&word.to_ascii_uppercase().as_str()) { + return Ok(None); + } + let edits = reference_locations(&uri, text, &word, true) + .into_iter() + .map(|location| TextEdit { + range: location.range, + new_text: params.new_name.clone(), + }) + .collect::>(); + if edits.is_empty() { + return Ok(None); + } + let mut changes = HashMap::new(); + changes.insert(uri, edits); + Ok(Some(WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + })) + } + + async fn code_action(&self, params: CodeActionParams) -> Result> { + if !params.context.diagnostics.iter().any(|diagnostic| { + diagnostic.message.contains("name") || diagnostic.message.contains("Parse error") + }) { + return Ok(Some(Vec::new())); + } + let docs = self.documents.read().await; + let uri = params.text_document.uri; + let Some(text) = docs.get(&uri) else { + return Ok(Some(Vec::new())); + }; + let insertion_line = text + .lines() + .position(|line| line.contains('(')) + .map(|line| line as u32 + 1) + .unwrap_or(1); + let edit = TextEdit { + range: Range { + start: Position { + line: insertion_line, + character: 0, + }, + end: Position { + line: insertion_line, + character: 0, + }, + }, + new_text: " name model_name,\n".to_string(), + }; + let mut changes = HashMap::new(); + changes.insert(uri, vec![edit]); + Ok(Some(vec![CodeActionOrCommand::CodeAction(CodeAction { + title: "Add missing name property".to_string(), + kind: Some(CodeActionKind::QUICKFIX), + diagnostics: Some(params.context.diagnostics), + edit: Some(WorkspaceEdit { + changes: Some(changes), + document_changes: None, + change_annotations: None, + }), + command: None, + is_preferred: Some(true), + disabled: None, + data: None, + })])) + } +} + +#[tokio::main] +async fn main() { + let stdin = tokio::io::stdin(); + let stdout = tokio::io::stdout(); + + let (service, socket) = LspService::new(|client| Backend { + client, + documents: Arc::new(RwLock::new(HashMap::new())), + }); + + Server::new(stdin, stdout, socket).serve(service).await; +} + +#[cfg(test)] +mod tests { + use super::{get_completion_context, get_word_at_position}; + + #[test] + fn completion_context_top_level() { + let text = "\nMODEL (\n name orders\n);\n\n"; + assert_eq!(get_completion_context(text, 5, 0), "top_level"); + } + + #[test] + fn completion_context_inside_model() { + let text = "MODEL (\n name orders,\n\n);"; + assert_eq!(get_completion_context(text, 2, 4), "inside_model"); + } + + #[test] + fn completion_context_inside_metric() { + let text = "MODEL (name orders);\n\nMETRIC (\n name revenue,\n\n);"; + assert_eq!(get_completion_context(text, 4, 4), "inside_metric"); + } + + #[test] + fn word_at_position() { + let text = "MODEL (\n name orders,\n);"; + assert_eq!(get_word_at_position(text, 0, 2).as_deref(), Some("MODEL")); + assert_eq!(get_word_at_position(text, 1, 6).as_deref(), Some("name")); + assert_eq!(get_word_at_position(text, 1, 12).as_deref(), Some("orders")); + } + + #[test] + fn word_at_position_none_for_whitespace() { + let text = "MODEL ( )"; + assert_eq!(get_word_at_position(text, 0, 8), None); + } +} diff --git a/sidemantic-rs/src/bin/sidemantic-mcp.rs b/sidemantic-rs/src/bin/sidemantic-mcp.rs new file mode 100644 index 00000000..27f90b57 --- /dev/null +++ b/sidemantic-rs/src/bin/sidemantic-mcp.rs @@ -0,0 +1,1714 @@ +//! Sidemantic MCP server implemented with the rmcp Rust SDK. +//! +//! This binary provides Rust-native MCP tools for semantic model introspection +//! and query compilation/execution. + +use std::collections::HashMap; +use std::env; +use std::path::PathBuf; +use std::sync::Arc; + +#[cfg(feature = "mcp-adbc")] +use adbc_core::options::{OptionConnection, OptionDatabase, OptionValue}; +use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::{ + Annotated, CallToolResult, ListResourcesResult, PaginatedRequestParams, RawResource, + ReadResourceRequestParams, ReadResourceResult, ResourceContents, ServerCapabilities, + ServerInfo, + }, + schemars, + schemars::JsonSchema, + service::{RequestContext, RoleServer}, + tool, tool_handler, tool_router, + transport::stdio, + ErrorData as McpError, ServerHandler, ServiceExt, +}; +use serde::Deserialize; +use serde_json::{json, Map as JsonMap, Value as JsonValue}; +use sidemantic::runtime::interpolate_query_filters; +#[cfg(feature = "mcp-adbc")] +use sidemantic::{ + chart_auto_detect_columns, chart_encoding_type, chart_format_label, chart_select_type, +}; +#[cfg(feature = "mcp-adbc")] +use sidemantic::{execute_with_adbc, AdbcExecutionRequest, AdbcExecutionResult, AdbcValue}; +use sidemantic::{Metric, Model, Relationship, RelationshipType, SemanticQuery, SidemanticRuntime}; + +const CATALOG_RESOURCE_URI: &str = "semantic://catalog"; +#[cfg(feature = "mcp-adbc")] +const CHART_COLORS: [&str; 8] = [ + "#2E5EAA", "#E8702A", "#4C9A2A", "#9B59B6", "#1ABC9C", "#E74C3C", "#F39C12", "#34495E", +]; + +#[cfg(feature = "mcp-adbc")] +type DatabaseOption = (OptionDatabase, OptionValue); +#[cfg(not(feature = "mcp-adbc"))] +type DatabaseOption = (String, String); +#[cfg(feature = "mcp-adbc")] +type ConnectionOption = (OptionConnection, OptionValue); +#[cfg(not(feature = "mcp-adbc"))] +type ConnectionOption = (String, String); + +#[cfg_attr(not(feature = "mcp-adbc"), allow(dead_code))] +#[derive(Debug, Clone)] +struct SidemanticMcpServer { + runtime: Arc, + adbc_driver: Option, + adbc_uri: Option, + adbc_entrypoint: Option, + database_options: Vec, + connection_options: Vec, + tool_router: ToolRouter, +} + +#[derive(Debug, Default)] +struct ServerConfig { + models_path: String, + adbc_driver: Option, + adbc_uri: Option, + adbc_entrypoint: Option, + database_options: Vec, + connection_options: Vec, +} + +#[derive(Debug, Clone, Deserialize, JsonSchema)] +struct GetModelsRequest { + #[serde(default)] + model_names: Vec, +} + +#[derive(Debug, Clone, Default, Deserialize, JsonSchema)] +struct QueryRequest { + #[serde(default)] + dimensions: Vec, + #[serde(default)] + metrics: Vec, + #[serde(default, rename = "where")] + where_clause: Option, + #[serde(default)] + filters: Vec, + #[serde(default)] + segments: Vec, + #[serde(default)] + order_by: Vec, + #[serde(default)] + limit: Option, + #[serde(default)] + offset: Option, + #[serde(default)] + ungrouped: bool, + #[serde(default)] + dry_run: bool, + #[serde(default)] + use_preaggregations: bool, + #[serde(default)] + parameters: JsonMap, +} + +#[derive(Debug, Clone, Default, Deserialize, JsonSchema)] +struct ValidateQueryRequest { + #[serde(default)] + dimensions: Vec, + #[serde(default)] + metrics: Vec, +} + +#[derive(Debug, Clone, Deserialize, JsonSchema)] +struct SQLRequest { + query: String, +} + +#[derive(Debug, Clone, Deserialize, JsonSchema)] +struct ChartRequest { + #[serde(default)] + dimensions: Vec, + #[serde(default)] + metrics: Vec, + #[serde(default, rename = "where")] + where_clause: Option, + #[serde(default)] + filters: Vec, + #[serde(default)] + segments: Vec, + #[serde(default)] + order_by: Vec, + #[serde(default)] + limit: Option, + #[serde(default)] + offset: Option, + #[serde(default)] + parameters: JsonMap, + #[serde(default = "default_chart_type")] + chart_type: String, + #[serde(default)] + title: String, + #[serde(default = "default_chart_width")] + width: usize, + #[serde(default = "default_chart_height")] + height: usize, +} + +fn default_chart_type() -> String { + "auto".to_string() +} + +fn default_chart_width() -> usize { + 600 +} + +fn default_chart_height() -> usize { + 400 +} + +impl SidemanticMcpServer { + fn new(runtime: SidemanticRuntime, config: ServerConfig) -> Self { + Self { + runtime: Arc::new(runtime), + adbc_driver: config.adbc_driver, + adbc_uri: config.adbc_uri, + adbc_entrypoint: config.adbc_entrypoint, + database_options: config.database_options, + connection_options: config.connection_options, + tool_router: Self::tool_router(), + } + } + + fn compile_request(&self, request: &QueryRequest) -> Result { + let mut filters = request.filters.clone(); + if let Some(where_clause) = &request.where_clause { + if !where_clause.trim().is_empty() { + filters.push(where_clause.clone()); + } + } + let parameter_values = request + .parameters + .iter() + .map(|(key, value)| { + serde_yaml::to_value(value) + .map(|value| (key.clone(), value)) + .map_err(|e| { + McpError::invalid_params( + format!("failed to parse query parameter '{key}': {e}"), + None, + ) + }) + }) + .collect::, _>>()?; + let filters = interpolate_query_filters(self.runtime.graph(), filters, ¶meter_values) + .map_err(|e| { + McpError::invalid_params(format!("failed to interpolate query parameters: {e}"), None) + })?; + + let mut query = SemanticQuery::new() + .with_dimensions(request.dimensions.clone()) + .with_metrics(request.metrics.clone()) + .with_filters(filters) + .with_segments(request.segments.clone()) + .with_ungrouped(request.ungrouped) + .with_use_preaggregations(request.use_preaggregations); + + if !request.order_by.is_empty() { + query = query.with_order_by(request.order_by.clone()); + } + if let Some(limit) = request.limit { + query = query.with_limit(limit); + } + if let Some(offset) = request.offset { + query = query.with_offset(offset); + } + + self.runtime + .compile(&query) + .map_err(|e| McpError::invalid_params(format!("failed to compile query: {e}"), None)) + } + + fn execute_sql_tool_response( + &self, + sql: String, + original_sql: Option, + tool_name: &str, + ) -> Result { + let _ = tool_name; + #[cfg(not(feature = "mcp-adbc"))] + { + let _ = (sql, original_sql); + return Err(McpError::invalid_params( + format!( + "ADBC execution support is not enabled. Rebuild with feature 'mcp-adbc' to use {tool_name}." + ), + None, + )); + } + + #[cfg(feature = "mcp-adbc")] + { + let result = self.execute_sql_with_adbc(&sql)?; + let rows = adbc_rows_to_json_rows(&result.columns, &result.rows); + let mut response = JsonMap::new(); + response.insert("sql".to_string(), json!(sql)); + if let Some(original_sql) = original_sql { + response.insert("original_sql".to_string(), json!(original_sql)); + } + response.insert("rows".to_string(), JsonValue::Array(rows.clone())); + response.insert("row_count".to_string(), json!(rows.len())); + Ok(CallToolResult::structured(JsonValue::Object(response))) + } + } + + #[cfg(feature = "mcp-adbc")] + fn execute_sql_with_adbc(&self, sql: &str) -> Result { + let Some(driver) = self.adbc_driver.as_ref() else { + return Err(McpError::invalid_params( + "ADBC driver is not configured. Set SIDEMANTIC_MCP_ADBC_DRIVER or pass --driver." + .to_string(), + None, + )); + }; + + execute_with_adbc(AdbcExecutionRequest { + driver: driver.clone(), + sql: sql.to_string(), + uri: self.adbc_uri.clone(), + entrypoint: self.adbc_entrypoint.clone(), + database_options: self.database_options.clone(), + connection_options: self.connection_options.clone(), + }) + .map_err(|e| { + McpError::internal_error(format!("failed to execute query via ADBC: {e}"), None) + }) + } +} + +#[tool_router] +impl SidemanticMcpServer { + #[tool( + name = "list_models", + description = "List all available semantic models with dimensions and metrics." + )] + async fn list_models(&self) -> Result { + let payload = self.runtime.loaded_graph_payload(); + let models = payload + .models + .iter() + .map(model_summary_json) + .collect::>(); + Ok(CallToolResult::structured(json!({ "models": models }))) + } + + #[tool( + name = "get_models", + description = "Get detailed metadata for one or more semantic models." + )] + async fn get_models( + &self, + Parameters(request): Parameters, + ) -> Result { + let payload = self.runtime.loaded_graph_payload(); + let model_map: HashMap<&str, &Model> = payload + .models + .iter() + .map(|model| (model.name.as_str(), model)) + .collect(); + + let mut details = Vec::new(); + for model_name in &request.model_names { + let Some(model) = model_map.get(model_name.as_str()) else { + continue; + }; + + details.push(model_detail_json(model, &model_map)); + } + + Ok(CallToolResult::structured(json!({ "models": details }))) + } + + #[tool( + name = "get_semantic_graph", + description = "Discover models, relationships, graph-level metrics, and joinable model pairs." + )] + async fn get_semantic_graph(&self) -> Result { + Ok(CallToolResult::structured(graph_payload(&self.runtime))) + } + + #[tool( + name = "compile_query", + description = "Compile semantic dimensions/metrics into SQL without execution." + )] + async fn compile_query( + &self, + Parameters(request): Parameters, + ) -> Result { + let sql = self.compile_request(&request)?; + Ok(CallToolResult::structured(json!({ "sql": sql }))) + } + + #[tool( + name = "validate_query", + description = "Validate semantic dimensions and metrics without compiling or executing." + )] + async fn validate_query( + &self, + Parameters(request): Parameters, + ) -> Result { + let errors = self + .runtime + .validate_query_references(&request.metrics, &request.dimensions); + Ok(CallToolResult::structured(json!({ + "valid": errors.is_empty(), + "errors": errors + }))) + } + + #[tool( + name = "run_query", + description = "Compile and execute a semantic query using ADBC driver manager." + )] + async fn run_query( + &self, + Parameters(request): Parameters, + ) -> Result { + let sql = self.compile_request(&request)?; + if request.dry_run { + return Ok(CallToolResult::structured(json!({ "sql": sql }))); + } + + #[cfg(not(feature = "mcp-adbc"))] + { + let _ = sql; + return Err(McpError::invalid_params( + "ADBC execution support is not enabled. Rebuild with feature 'mcp-adbc' to use run_query." + .to_string(), + None, + )); + } + + #[cfg(feature = "mcp-adbc")] + { + let result = self.execute_sql_with_adbc(&sql)?; + let rows = adbc_rows_to_json_rows(&result.columns, &result.rows); + + Ok(CallToolResult::structured(json!({ + "sql": sql, + "rows": rows, + "row_count": rows.len() + }))) + } + } + + #[tool( + name = "run_sql", + description = "Rewrite semantic SQL and execute it using ADBC driver manager." + )] + async fn run_sql( + &self, + Parameters(request): Parameters, + ) -> Result { + let original_sql = normalize_sql(&request.query) + .map_err(|message| McpError::invalid_params(message, None))?; + let sql = self + .runtime + .rewrite(&original_sql) + .map_err(|e| McpError::invalid_params(format!("failed to rewrite SQL: {e}"), None))?; + self.execute_sql_tool_response(sql, Some(original_sql), "run_sql") + } + + #[tool( + name = "create_chart", + description = "Execute a semantic query and return a Vega-Lite chart spec plus a PNG preview." + )] + async fn create_chart( + &self, + Parameters(request): Parameters, + ) -> Result { + let query_request = QueryRequest { + dimensions: request.dimensions.clone(), + metrics: request.metrics.clone(), + where_clause: request.where_clause.clone(), + filters: request.filters.clone(), + segments: request.segments.clone(), + order_by: request.order_by.clone(), + limit: request.limit, + offset: request.offset, + parameters: request.parameters.clone(), + ungrouped: false, + dry_run: false, + use_preaggregations: false, + }; + let sql = self.compile_request(&query_request)?; + + #[cfg(not(feature = "mcp-adbc"))] + { + let _ = ( + &request.chart_type, + &request.title, + request.width, + request.height, + ); + let _ = sql; + return Err(McpError::invalid_params( + "ADBC execution support is not enabled. Rebuild with feature 'mcp-adbc' to use create_chart." + .to_string(), + None, + )); + } + + #[cfg(feature = "mcp-adbc")] + { + if request.width == 0 || request.height == 0 { + return Err(McpError::invalid_params( + "chart width and height must be greater than zero".to_string(), + None, + )); + } + + let result = self.execute_sql_with_adbc(&sql)?; + let rows = adbc_rows_to_json_rows(&result.columns, &result.rows); + if rows.is_empty() { + return Err(McpError::invalid_params( + "Query returned no data. Check filters or use run_query to inspect results first." + .to_string(), + None, + )); + } + + let title = if request.title.trim().is_empty() { + generate_chart_title(&request.dimensions, &request.metrics) + } else { + request.title.clone() + }; + let chart = build_chart_payload( + rows, + &result.columns, + &request.chart_type, + &title, + request.width, + request.height, + )?; + + Ok(CallToolResult::structured(json!({ + "sql": sql, + "vega_spec": chart.vega_spec, + "png_base64": chart.png_base64, + "row_count": chart.row_count + }))) + } + } +} + +#[tool_handler] +impl ServerHandler for SidemanticMcpServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some( + "Rust-native Sidemantic MCP server. Tools: list_models, get_models, get_semantic_graph, validate_query, compile_query, run_query, run_sql, create_chart. Resource: semantic://catalog." + .to_string(), + ), + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + ..Default::default() + } + } + + fn list_resources( + &self, + _request: Option, + _context: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + std::future::ready(Ok(ListResourcesResult::with_all_items(vec![ + Annotated::new( + RawResource { + uri: CATALOG_RESOURCE_URI.to_string(), + name: "catalog".to_string(), + title: Some("Sidemantic Catalog Metadata".to_string()), + description: Some( + "Postgres-compatible catalog metadata for the semantic layer.".to_string(), + ), + mime_type: Some("application/json".to_string()), + size: None, + icons: None, + meta: None, + }, + None, + ), + ]))) + } + + fn read_resource( + &self, + request: ReadResourceRequestParams, + _context: RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let result = if request.uri == CATALOG_RESOURCE_URI { + self.runtime + .generate_catalog_metadata("public") + .map_err(|e| { + McpError::internal_error( + format!("failed to generate catalog metadata: {e}"), + None, + ) + }) + .map(|catalog| ReadResourceResult { + contents: vec![ResourceContents::TextResourceContents { + uri: CATALOG_RESOURCE_URI.to_string(), + mime_type: Some("application/json".to_string()), + text: catalog, + meta: None, + }], + }) + } else { + Err(McpError::resource_not_found( + format!("resource not found: {}", request.uri), + None, + )) + }; + std::future::ready(result) + } +} + +fn model_summary_json(model: &Model) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(model.name)); + entry.insert("table".to_string(), json!(model.table)); + entry.insert( + "dimensions".to_string(), + json!(model + .dimensions + .iter() + .map(|dimension| dimension.name.clone()) + .collect::>()), + ); + entry.insert( + "metrics".to_string(), + json!(model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect::>()), + ); + entry.insert( + "relationships".to_string(), + json!(model.relationships.len()), + ); + JsonValue::Object(entry) +} + +fn graph_model_summary_json(model: &Model) -> JsonValue { + let mut entry = match model_summary_json(model) { + JsonValue::Object(entry) => entry, + _ => JsonMap::new(), + }; + entry.insert( + "relationships".to_string(), + JsonValue::Array( + model + .relationships + .iter() + .map(|relationship| { + json!({ + "name": relationship.name, + "type": enum_json_name(&relationship.r#type) + .unwrap_or_else(|| "many_to_one".to_string()) + }) + }) + .collect(), + ), + ); + if !model.segments.is_empty() { + entry.insert( + "segments".to_string(), + json!(model + .segments + .iter() + .map(|segment| segment.name.clone()) + .collect::>()), + ); + } + if let Some(description) = &model.description { + entry.insert("description".to_string(), json!(description)); + } + if !model.primary_key.is_empty() { + entry.insert("primary_key".to_string(), json!(model.primary_key)); + } + if let Some(default_time_dimension) = &model.default_time_dimension { + entry.insert( + "default_time_dimension".to_string(), + json!(default_time_dimension), + ); + } + JsonValue::Object(entry) +} + +fn dimension_json(dimension: &sidemantic::Dimension) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(dimension.name)); + entry.insert( + "type".to_string(), + json!(enum_json_name(&dimension.r#type).unwrap_or_else(|| "categorical".to_string())), + ); + entry.insert("sql".to_string(), json!(dimension.sql)); + if let Some(description) = &dimension.description { + entry.insert("description".to_string(), json!(description)); + } + if let Some(label) = &dimension.label { + entry.insert("label".to_string(), json!(label)); + } + if let Some(granularity) = &dimension.granularity { + entry.insert("granularity".to_string(), json!(granularity)); + } + if let Some(supported_granularities) = &dimension.supported_granularities { + entry.insert( + "supported_granularities".to_string(), + json!(supported_granularities), + ); + } + if let Some(format) = &dimension.format { + entry.insert("format".to_string(), json!(format)); + } + if let Some(value_format_name) = &dimension.value_format_name { + entry.insert("value_format_name".to_string(), json!(value_format_name)); + } + if let Some(parent) = &dimension.parent { + entry.insert("parent".to_string(), json!(parent)); + } + JsonValue::Object(entry) +} + +fn segment_json(segment: &sidemantic::Segment) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(segment.name)); + entry.insert("sql".to_string(), json!(segment.sql)); + if let Some(description) = &segment.description { + entry.insert("description".to_string(), json!(description)); + } + if !segment.public { + entry.insert("public".to_string(), json!(false)); + } + JsonValue::Object(entry) +} + +fn model_detail_json(model: &Model, model_map: &HashMap<&str, &Model>) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(model.name)); + entry.insert("table".to_string(), json!(model.table)); + entry.insert("primary_key".to_string(), json!(model.primary_key)); + entry.insert( + "dimensions".to_string(), + JsonValue::Array(model.dimensions.iter().map(dimension_json).collect()), + ); + entry.insert( + "metrics".to_string(), + JsonValue::Array(model.metrics.iter().map(metric_json).collect()), + ); + entry.insert( + "relationships".to_string(), + JsonValue::Array( + model + .relationships + .iter() + .map(|relationship| relationship_json(model, relationship, model_map)) + .collect(), + ), + ); + if !model.segments.is_empty() { + entry.insert( + "segments".to_string(), + JsonValue::Array(model.segments.iter().map(segment_json).collect()), + ); + } + if let Some(description) = &model.description { + entry.insert("description".to_string(), json!(description)); + } + if let Some(sql) = &model.sql { + entry.insert("sql".to_string(), json!(sql)); + } + if let Some(default_time_dimension) = &model.default_time_dimension { + entry.insert( + "default_time_dimension".to_string(), + json!(default_time_dimension), + ); + } + if let Some(default_grain) = &model.default_grain { + entry.insert("default_grain".to_string(), json!(default_grain)); + } + JsonValue::Object(entry) +} + +fn graph_payload(runtime: &SidemanticRuntime) -> JsonValue { + let payload = runtime.loaded_graph_payload(); + let models = payload + .models + .iter() + .map(graph_model_summary_json) + .collect::>(); + let graph_metrics = payload + .top_level_metrics + .iter() + .map(metric_json) + .collect::>(); + let model_names = payload + .models + .iter() + .map(|model| model.name.clone()) + .collect::>(); + let mut joinable_pairs = Vec::new(); + for (idx, left_name) in model_names.iter().enumerate() { + for right_name in model_names.iter().skip(idx + 1) { + if let Ok(path) = runtime.find_join_path(left_name, right_name) { + joinable_pairs.push(json!({ + "from": left_name, + "to": right_name, + "hops": path.steps.len() + })); + } + } + } + + let mut result = JsonMap::new(); + result.insert("models".to_string(), JsonValue::Array(models)); + result.insert( + "joinable_pairs".to_string(), + JsonValue::Array(joinable_pairs), + ); + if !graph_metrics.is_empty() { + result.insert("graph_metrics".to_string(), JsonValue::Array(graph_metrics)); + } + JsonValue::Object(result) +} + +fn enum_json_name(value: &T) -> Option { + serde_json::to_value(value) + .ok() + .and_then(|value| value.as_str().map(ToString::to_string)) +} + +fn metric_json(metric: &Metric) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(metric.name)); + entry.insert("sql".to_string(), json!(metric.sql)); + if let Some(agg) = &metric.agg { + if let Some(name) = enum_json_name(agg) { + entry.insert("agg".to_string(), json!(name)); + } + } + if let Some(name) = enum_json_name(&metric.r#type) { + entry.insert("type".to_string(), json!(name)); + } + if let Some(description) = &metric.description { + entry.insert("description".to_string(), json!(description)); + } + if !metric.filters.is_empty() { + entry.insert("filters".to_string(), json!(metric.filters)); + } + JsonValue::Object(entry) +} + +fn relationship_json( + model: &Model, + relationship: &Relationship, + model_map: &HashMap<&str, &Model>, +) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(relationship.name)); + entry.insert( + "type".to_string(), + json!(enum_json_name(&relationship.r#type).unwrap_or_else(|| "many_to_one".to_string())), + ); + if let Some(foreign_key) = &relationship.foreign_key { + entry.insert("foreign_key".to_string(), json!(foreign_key)); + } + if let Some(primary_key) = &relationship.primary_key { + entry.insert("primary_key".to_string(), json!(primary_key)); + } + if let Some(through) = &relationship.through { + entry.insert("through".to_string(), json!(through)); + } + if let Some(through_fk) = &relationship.through_foreign_key { + entry.insert("through_foreign_key".to_string(), json!(through_fk)); + } + if let Some(related_fk) = &relationship.related_foreign_key { + entry.insert("related_foreign_key".to_string(), json!(related_fk)); + } + + if let Some(join_condition) = format_join_condition(model, relationship, model_map) { + entry.insert("join_condition".to_string(), json!(join_condition)); + } + + JsonValue::Object(entry) +} + +fn format_join_condition( + model: &Model, + relationship: &Relationship, + model_map: &HashMap<&str, &Model>, +) -> Option { + let related_model = model_map.get(relationship.name.as_str())?; + let related_name = relationship.name.as_str(); + let model_name = model.name.as_str(); + + match relationship.r#type { + RelationshipType::ManyToOne => { + let fk = relationship + .foreign_key + .clone() + .unwrap_or_else(|| format!("{related_name}_id")); + let pk = relationship + .primary_key + .clone() + .unwrap_or_else(|| related_model.primary_key.clone()); + Some(format!("{model_name}.{fk} = {related_name}.{pk}")) + } + RelationshipType::OneToMany | RelationshipType::OneToOne => { + let fk = relationship.foreign_key.clone()?; + let pk = model.primary_key.clone(); + Some(format!("{related_name}.{fk} = {model_name}.{pk}")) + } + RelationshipType::ManyToMany => { + if let Some(through) = &relationship.through { + let _junction_model = model_map.get(through.as_str())?; + let (junction_self_fk, junction_related_fk) = relationship.junction_keys(); + let junction_self_fk = junction_self_fk?; + let junction_related_fk = junction_related_fk?; + let base_pk = model.primary_key.clone(); + let related_pk = relationship + .primary_key + .clone() + .unwrap_or_else(|| related_model.primary_key.clone()); + return Some(format!( + "{model_name}.{base_pk} = {through}.{junction_self_fk} AND {through}.{junction_related_fk} = {related_name}.{related_pk}" + )); + } + + relationship.foreign_key.clone().map(|foreign_key| { + format!( + "{model_name}.{} = {related_name}.{foreign_key}", + model.primary_key + ) + }) + } + } +} + +#[cfg(feature = "mcp-adbc")] +fn adbc_value_to_json(value: &AdbcValue) -> JsonValue { + match value { + AdbcValue::Null => JsonValue::Null, + AdbcValue::Bool(v) => json!(v), + AdbcValue::I64(v) => json!(v), + AdbcValue::U64(v) => json!(v), + AdbcValue::F64(v) => json!(v), + AdbcValue::String(v) => json!(v), + AdbcValue::Bytes(v) => json!(v), + } +} + +#[cfg(feature = "mcp-adbc")] +fn adbc_rows_to_json_rows(columns: &[String], rows: &[Vec]) -> Vec { + rows.iter() + .map(|row| { + let mut row_map = JsonMap::new(); + for (idx, column) in columns.iter().enumerate() { + let value = row + .get(idx) + .map(adbc_value_to_json) + .unwrap_or(JsonValue::Null); + row_map.insert(column.clone(), value); + } + JsonValue::Object(row_map) + }) + .collect() +} + +fn normalize_sql(sql: &str) -> Result { + let mut normalized = sql.trim().to_string(); + if normalized.is_empty() { + return Err("SQL query cannot be empty".to_string()); + } + while normalized.ends_with(';') { + normalized.pop(); + normalized = normalized.trim_end().to_string(); + } + if has_unquoted_semicolon(&normalized) { + return Err("Only one SQL statement is allowed".to_string()); + } + Ok(normalized) +} + +#[cfg(feature = "mcp-adbc")] +struct ChartPayload { + vega_spec: JsonValue, + png_base64: String, + row_count: usize, +} + +#[cfg(feature = "mcp-adbc")] +fn build_chart_payload( + rows: Vec, + columns: &[String], + requested_type: &str, + title: &str, + width: usize, + height: usize, +) -> Result { + if columns.is_empty() { + return Err(McpError::invalid_params( + "chart query returned no columns".to_string(), + None, + )); + } + + let numeric_flags = columns + .iter() + .skip(1) + .map(|column| { + rows.iter() + .filter_map(JsonValue::as_object) + .filter_map(|row| row.get(column)) + .any(is_json_number) + }) + .collect::>(); + let (x, y_cols) = chart_auto_detect_columns(columns, &numeric_flags) + .map_err(|e| McpError::invalid_params(format!("failed to build chart: {e}"), None))?; + let mut chart_rows = rows.clone(); + if x == "index" { + for (idx, row) in chart_rows.iter_mut().enumerate() { + if let Some(row) = row.as_object_mut() { + row.insert("index".to_string(), json!(idx)); + } + } + } + + let x_value_kind = chart_rows + .first() + .and_then(JsonValue::as_object) + .and_then(|row| row.get(&x)) + .map(json_value_kind) + .unwrap_or("other"); + let chart_type = if requested_type == "auto" { + chart_select_type(&x, x_value_kind, y_cols.len()) + } else { + requested_type.to_string() + }; + let allowed = ["bar", "line", "area", "scatter", "point"]; + if !allowed.contains(&chart_type.as_str()) { + return Err(McpError::invalid_params( + format!("Unsupported chart type: {chart_type}"), + None, + )); + } + + let vega_spec = build_vega_spec(&chart_rows, &x, &y_cols, &chart_type, title, width, height); + let png_base64 = + render_chart_png_data_url(&chart_rows, &x, &y_cols, &chart_type, width, height); + Ok(ChartPayload { + vega_spec, + png_base64, + row_count: chart_rows.len(), + }) +} + +#[cfg(feature = "mcp-adbc")] +fn build_vega_spec( + rows: &[JsonValue], + x: &str, + y_cols: &[String], + chart_type: &str, + title: &str, + width: usize, + height: usize, +) -> JsonValue { + let x_label = chart_format_label(x); + let data = json!({ "values": rows }); + let config = json!({ + "font": "Inter, system-ui, -apple-system, sans-serif", + "title": { + "fontSize": 18, + "fontWeight": 600, + "anchor": "start", + "color": "#1a1a1a", + "offset": 20 + }, + "axis": { + "labelFontSize": 12, + "titleFontSize": 13, + "titleFontWeight": 500, + "titleColor": "#4a4a4a", + "labelColor": "#6a6a6a", + "gridColor": "#e8e8e8", + "gridOpacity": 0.6, + "domainColor": "#cccccc", + "tickColor": "#cccccc", + "titlePadding": 12, + "labelPadding": 8 + }, + "legend": { + "titleFontSize": 13, + "titleFontWeight": 500, + "labelFontSize": 12, + "titleColor": "#4a4a4a", + "labelColor": "#6a6a6a", + "symbolSize": 100, + "orient": "right", + "offset": 10 + }, + "view": { "strokeWidth": 0 }, + "bar": { "cornerRadiusEnd": 2 }, + "line": { "strokeCap": "round" }, + "point": { "filled": true } + }); + + if chart_type == "scatter" { + let y = y_cols.first().cloned().unwrap_or_else(|| x.to_string()); + return json!({ + "$schema": "https://vega.github.io/schema/vega-lite/v5.json", + "data": data, + "mark": { "type": "circle", "size": 80, "opacity": 0.7 }, + "encoding": { + "x": { "field": x, "type": "quantitative", "title": x_label }, + "y": { "field": y, "type": "quantitative", "title": chart_format_label(&y) }, + "color": { "value": CHART_COLORS[0] }, + "tooltip": [ + { "field": x, "type": "quantitative", "title": x_label, "format": ",.2f" }, + { "field": y, "type": "quantitative", "title": chart_format_label(&y), "format": ",.2f" } + ] + }, + "config": config, + "width": width, + "height": height, + "title": title + }); + } + + let x_type = chart_encoding_type(x); + if y_cols.len() > 1 { + let mark = match chart_type { + "line" => json!({ "type": "line", "point": true, "strokeWidth": 2.5 }), + "area" => json!({ "type": "area", "opacity": 0.6, "line": true }), + "point" => json!({ "type": "point", "size": 80, "filled": true }), + _ => json!({ "type": "bar" }), + }; + return json!({ + "$schema": "https://vega.github.io/schema/vega-lite/v5.json", + "data": data, + "transform": [{ "fold": y_cols, "as": ["metric", "value"] }], + "mark": mark, + "encoding": { + "x": { "field": x, "type": x_type, "title": x_label }, + "y": { "field": "value", "type": "quantitative", "title": "Value" }, + "color": { + "field": "metric", + "type": "nominal", + "title": "Metric", + "scale": { "range": CHART_COLORS } + }, + "tooltip": [ + { "field": x, "type": x_type, "title": x_label }, + { "field": "metric", "type": "nominal", "title": "Metric" }, + { "field": "value", "type": "quantitative", "title": "Value", "format": ",.2f" } + ] + }, + "config": config, + "width": width, + "height": height, + "title": title + }); + } + + let y = y_cols.first().cloned().unwrap_or_else(|| x.to_string()); + let y_label = chart_format_label(&y); + let mark = match chart_type { + "line" => json!({ "type": "line", "point": true, "strokeWidth": 3 }), + "area" => json!({ "type": "area", "opacity": 0.7, "line": true }), + "point" => json!({ "type": "point", "size": 100, "filled": true }), + _ => json!({ "type": "bar" }), + }; + let mut encoding = JsonMap::new(); + encoding.insert( + "x".to_string(), + json!({ "field": x, "type": x_type, "title": x_label }), + ); + encoding.insert( + "y".to_string(), + json!({ "field": y, "type": "quantitative", "title": y_label }), + ); + encoding.insert( + "tooltip".to_string(), + json!([ + { "field": x, "type": x_type, "title": x_label }, + { "field": y, "type": "quantitative", "title": y_label, "format": ",.2f" } + ]), + ); + if chart_type == "bar" { + encoding.insert( + "color".to_string(), + json!({ + "field": x, + "type": x_type, + "legend": null, + "scale": { "range": CHART_COLORS } + }), + ); + } + + json!({ + "$schema": "https://vega.github.io/schema/vega-lite/v5.json", + "data": data, + "mark": mark, + "encoding": JsonValue::Object(encoding), + "config": config, + "width": width, + "height": height, + "title": title + }) +} + +#[cfg(feature = "mcp-adbc")] +fn generate_chart_title(dimensions: &[String], metrics: &[String]) -> String { + if metrics.is_empty() { + return "Data Visualization".to_string(); + } + + let metric_names = metrics + .iter() + .map(|metric| format_chart_field_name(metric)) + .collect::>(); + let mut title = if metric_names.len() == 1 { + metric_names[0].clone() + } else if metric_names.len() == 2 { + format!("{} & {}", metric_names[0], metric_names[1]) + } else { + format!("{} & {} more", metric_names[0], metric_names.len() - 1) + }; + + if let Some(dimension) = dimensions.first() { + title = format!("{} by {}", title, format_chart_field_name(dimension)); + } + title +} + +#[cfg(feature = "mcp-adbc")] +fn format_chart_field_name(field: &str) -> String { + let field = field.rsplit('.').next().unwrap_or(field); + if let Some((base, granularity)) = field.rsplit_once("__") { + return format!( + "{} ({})", + chart_format_label(base), + chart_format_label(granularity) + ); + } + chart_format_label(field) +} + +#[cfg(feature = "mcp-adbc")] +fn is_json_number(value: &JsonValue) -> bool { + matches!(value, JsonValue::Number(_)) +} + +#[cfg(feature = "mcp-adbc")] +fn json_value_kind(value: &JsonValue) -> &'static str { + match value { + JsonValue::Number(_) => "number", + JsonValue::String(_) => "string", + _ => "other", + } +} + +#[cfg(feature = "mcp-adbc")] +fn json_number(value: &JsonValue) -> Option { + value + .as_f64() + .or_else(|| value.as_i64().map(|value| value as f64)) + .or_else(|| value.as_u64().map(|value| value as f64)) +} + +#[cfg(feature = "mcp-adbc")] +fn render_chart_png_data_url( + rows: &[JsonValue], + x: &str, + y_cols: &[String], + chart_type: &str, + width: usize, + height: usize, +) -> String { + let width = width.clamp(1, 1200); + let height = height.clamp(1, 900); + let mut image = vec![255u8; width * height * 3]; + let left = 40usize.min(width.saturating_sub(1)); + let right = 10usize.min(width.saturating_sub(1)); + let top = 12usize.min(height.saturating_sub(1)); + let bottom = 28usize.min(height.saturating_sub(1)); + let plot_left = left; + let plot_right = width.saturating_sub(right + 1).max(plot_left); + let plot_top = top; + let plot_bottom = height.saturating_sub(bottom + 1).max(plot_top); + draw_line( + &mut image, + width, + height, + (plot_left as isize, plot_bottom as isize), + (plot_right as isize, plot_bottom as isize), + [204, 204, 204], + ); + draw_line( + &mut image, + width, + height, + (plot_left as isize, plot_top as isize), + (plot_left as isize, plot_bottom as isize), + [204, 204, 204], + ); + + let y = y_cols.first().map(String::as_str).unwrap_or(x); + let values = rows + .iter() + .filter_map(JsonValue::as_object) + .filter_map(|row| row.get(y).and_then(json_number)) + .collect::>(); + if values.is_empty() { + let png = encode_rgb_png(width as u32, height as u32, &image); + return format!("data:image/png;base64,{}", base64_encode(&png)); + } + + let min_value = values.iter().copied().fold(0.0_f64, f64::min); + let mut max_value = values.iter().copied().fold(0.0_f64, f64::max); + if (max_value - min_value).abs() < f64::EPSILON { + max_value = min_value + 1.0; + } + let scale_y = |value: f64| -> isize { + let span = max_value - min_value; + let normalized = (value - min_value) / span; + (plot_bottom as f64 - normalized * (plot_bottom.saturating_sub(plot_top) as f64)).round() + as isize + }; + let color = [46, 94, 170]; + + if chart_type == "bar" { + let slot = + (plot_right.saturating_sub(plot_left).max(1) as f64 / values.len() as f64).max(1.0); + let bar_width = (slot * 0.7).max(1.0) as usize; + for (idx, value) in values.iter().enumerate() { + let center = plot_left as f64 + slot * (idx as f64 + 0.5); + let x0 = center.round() as isize - (bar_width as isize / 2); + let x1 = x0 + bar_width as isize; + let y0 = scale_y(*value); + fill_rect( + &mut image, + width, + height, + (x0, y0.min(plot_bottom as isize)), + (x1, plot_bottom as isize), + color, + ); + } + } else { + let denom = values.len().saturating_sub(1).max(1) as f64; + let points = values + .iter() + .enumerate() + .map(|(idx, value)| { + let x = plot_left as f64 + + (plot_right.saturating_sub(plot_left) as f64 * idx as f64 / denom); + (x.round() as isize, scale_y(*value)) + }) + .collect::>(); + + for window in points.windows(2) { + draw_line(&mut image, width, height, window[0], window[1], color); + } + for (x, y) in points { + fill_circle(&mut image, width, height, x, y, 3, color); + } + } + + let png = encode_rgb_png(width as u32, height as u32, &image); + format!("data:image/png;base64,{}", base64_encode(&png)) +} + +#[cfg(feature = "mcp-adbc")] +fn set_pixel(image: &mut [u8], width: usize, height: usize, x: isize, y: isize, color: [u8; 3]) { + if x < 0 || y < 0 || x >= width as isize || y >= height as isize { + return; + } + let idx = ((y as usize * width) + x as usize) * 3; + image[idx..idx + 3].copy_from_slice(&color); +} + +#[cfg(feature = "mcp-adbc")] +fn draw_line( + image: &mut [u8], + width: usize, + height: usize, + start: (isize, isize), + end: (isize, isize), + color: [u8; 3], +) { + let (mut x0, mut y0) = start; + let (x1, y1) = end; + let dx = (x1 - x0).abs(); + let sx = if x0 < x1 { 1 } else { -1 }; + let dy = -(y1 - y0).abs(); + let sy = if y0 < y1 { 1 } else { -1 }; + let mut err = dx + dy; + loop { + set_pixel(image, width, height, x0, y0, color); + if x0 == x1 && y0 == y1 { + break; + } + let e2 = 2 * err; + if e2 >= dy { + err += dy; + x0 += sx; + } + if e2 <= dx { + err += dx; + y0 += sy; + } + } +} + +#[cfg(feature = "mcp-adbc")] +fn fill_rect( + image: &mut [u8], + width: usize, + height: usize, + top_left: (isize, isize), + bottom_right: (isize, isize), + color: [u8; 3], +) { + let (x0, y0) = top_left; + let (x1, y1) = bottom_right; + for y in y0.max(0)..=y1.min(height as isize - 1) { + for x in x0.max(0)..=x1.min(width as isize - 1) { + set_pixel(image, width, height, x, y, color); + } + } +} + +#[cfg(feature = "mcp-adbc")] +fn fill_circle( + image: &mut [u8], + width: usize, + height: usize, + cx: isize, + cy: isize, + radius: isize, + color: [u8; 3], +) { + for y in -radius..=radius { + for x in -radius..=radius { + if x * x + y * y <= radius * radius { + set_pixel(image, width, height, cx + x, cy + y, color); + } + } + } +} + +#[cfg(feature = "mcp-adbc")] +fn encode_rgb_png(width: u32, height: u32, rgb: &[u8]) -> Vec { + let mut scanlines = Vec::with_capacity((width as usize * 3 + 1) * height as usize); + for row in 0..height as usize { + scanlines.push(0); + let start = row * width as usize * 3; + let end = start + width as usize * 3; + scanlines.extend_from_slice(&rgb[start..end]); + } + + let mut out = Vec::new(); + out.extend_from_slice(&[137, 80, 78, 71, 13, 10, 26, 10]); + let mut ihdr = Vec::new(); + ihdr.extend_from_slice(&width.to_be_bytes()); + ihdr.extend_from_slice(&height.to_be_bytes()); + ihdr.extend_from_slice(&[8, 2, 0, 0, 0]); + write_png_chunk(&mut out, b"IHDR", &ihdr); + write_png_chunk(&mut out, b"IDAT", &zlib_store(&scanlines)); + write_png_chunk(&mut out, b"IEND", &[]); + out +} + +#[cfg(feature = "mcp-adbc")] +fn write_png_chunk(out: &mut Vec, kind: &[u8; 4], data: &[u8]) { + out.extend_from_slice(&(data.len() as u32).to_be_bytes()); + out.extend_from_slice(kind); + out.extend_from_slice(data); + let mut crc_input = Vec::with_capacity(kind.len() + data.len()); + crc_input.extend_from_slice(kind); + crc_input.extend_from_slice(data); + out.extend_from_slice(&crc32(&crc_input).to_be_bytes()); +} + +#[cfg(feature = "mcp-adbc")] +fn zlib_store(data: &[u8]) -> Vec { + let mut out = vec![0x78, 0x01]; + let mut offset = 0; + while offset < data.len() { + let remaining = data.len() - offset; + let chunk_len = remaining.min(65_535); + let is_final = offset + chunk_len >= data.len(); + out.push(if is_final { 1 } else { 0 }); + let len = chunk_len as u16; + out.extend_from_slice(&len.to_le_bytes()); + out.extend_from_slice(&(!len).to_le_bytes()); + out.extend_from_slice(&data[offset..offset + chunk_len]); + offset += chunk_len; + } + out.extend_from_slice(&adler32(data).to_be_bytes()); + out +} + +#[cfg(feature = "mcp-adbc")] +fn adler32(data: &[u8]) -> u32 { + const MOD_ADLER: u32 = 65_521; + let mut a = 1u32; + let mut b = 0u32; + for byte in data { + a = (a + *byte as u32) % MOD_ADLER; + b = (b + a) % MOD_ADLER; + } + (b << 16) | a +} + +#[cfg(feature = "mcp-adbc")] +fn crc32(data: &[u8]) -> u32 { + let mut crc = 0xffff_ffffu32; + for byte in data { + crc ^= *byte as u32; + for _ in 0..8 { + let mask = if crc & 1 == 1 { 0xedb8_8320 } else { 0 }; + crc = (crc >> 1) ^ mask; + } + } + !crc +} + +#[cfg(feature = "mcp-adbc")] +fn base64_encode(data: &[u8]) -> String { + const TABLE: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + let mut out = String::with_capacity(data.len().div_ceil(3) * 4); + for chunk in data.chunks(3) { + let b0 = chunk[0]; + let b1 = *chunk.get(1).unwrap_or(&0); + let b2 = *chunk.get(2).unwrap_or(&0); + out.push(TABLE[(b0 >> 2) as usize] as char); + out.push(TABLE[(((b0 & 0b0000_0011) << 4) | (b1 >> 4)) as usize] as char); + if chunk.len() > 1 { + out.push(TABLE[(((b1 & 0b0000_1111) << 2) | (b2 >> 6)) as usize] as char); + } else { + out.push('='); + } + if chunk.len() > 2 { + out.push(TABLE[(b2 & 0b0011_1111) as usize] as char); + } else { + out.push('='); + } + } + out +} + +fn has_unquoted_semicolon(sql: &str) -> bool { + let mut in_single = false; + let mut in_double = false; + let mut prev = '\0'; + for ch in sql.chars() { + match ch { + '\'' if !in_double && prev != '\\' => in_single = !in_single, + '"' if !in_single && prev != '\\' => in_double = !in_double, + ';' if !in_single && !in_double => return true, + _ => {} + } + prev = ch; + } + false +} + +fn parse_config() -> Result { + let mut models_path: Option = None; + let mut adbc_driver: Option = env::var("SIDEMANTIC_MCP_ADBC_DRIVER").ok(); + let mut adbc_uri: Option = env::var("SIDEMANTIC_MCP_ADBC_URI").ok(); + let mut adbc_entrypoint: Option = env::var("SIDEMANTIC_MCP_ADBC_ENTRYPOINT").ok(); + let mut database_options: Vec = Vec::new(); + let mut connection_options: Vec = Vec::new(); + if let Ok(env_dbopts) = + env::var("SIDEMANTIC_MCP_ADBC_DBOPTS").or_else(|_| env::var("SIDEMANTIC_ADBC_DBOPTS")) + { + database_options.extend(parse_database_options(&env_dbopts)?); + } + if let Ok(env_connopts) = + env::var("SIDEMANTIC_MCP_ADBC_CONNOPTS").or_else(|_| env::var("SIDEMANTIC_ADBC_CONNOPTS")) + { + connection_options.extend(parse_connection_options(&env_connopts)?); + } + + let mut args = env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--models" => { + let value = args + .next() + .ok_or_else(|| "--models requires a path value".to_string())?; + models_path = Some(value); + } + "--driver" => { + let value = args + .next() + .ok_or_else(|| "--driver requires a value".to_string())?; + adbc_driver = Some(value); + } + "--uri" => { + let value = args + .next() + .ok_or_else(|| "--uri requires a value".to_string())?; + adbc_uri = Some(value); + } + "--entrypoint" => { + let value = args + .next() + .ok_or_else(|| "--entrypoint requires a value".to_string())?; + adbc_entrypoint = Some(value); + } + "--dbopt" => { + let value = args + .next() + .ok_or_else(|| "--dbopt requires a value".to_string())?; + database_options.extend(parse_database_options(&value)?); + } + "--connopt" => { + let value = args + .next() + .ok_or_else(|| "--connopt requires a value".to_string())?; + connection_options.extend(parse_connection_options(&value)?); + } + "--help" | "-h" => { + return Err( + "Usage: sidemantic-mcp [] [--models ] [--driver ] [--uri ] [--entrypoint ] [--dbopt ] [--connopt ]".to_string() + ); + } + value if !value.starts_with('-') && models_path.is_none() => { + models_path = Some(value.to_string()); + } + unknown => { + return Err(format!( + "unknown argument: {unknown}. Use --help for usage." + )); + } + } + } + + let models_path = models_path + .or_else(|| env::var("SIDEMANTIC_MCP_MODELS").ok()) + .unwrap_or_else(|| ".".to_string()); + + Ok(ServerConfig { + models_path, + adbc_driver, + adbc_uri, + adbc_entrypoint, + database_options, + connection_options, + }) +} + +fn parse_kv_pairs(input: &str, option_name: &str) -> Result, String> { + let mut pairs = Vec::new(); + for fragment in input + .split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + { + let (key, value) = fragment + .split_once('=') + .ok_or_else(|| format!("{option_name} expects key=value, got '{fragment}'"))?; + if key.trim().is_empty() { + return Err(format!("{option_name} key cannot be empty: '{fragment}'")); + } + pairs.push((key.trim().to_string(), value.to_string())); + } + if pairs.is_empty() { + return Err(format!("{option_name} expects key=value pairs")); + } + Ok(pairs) +} + +#[cfg(feature = "mcp-adbc")] +fn parse_option_value(value: &str) -> OptionValue { + if let Some(rest) = value.strip_prefix("int:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Int(parsed); + } + } + if let Some(rest) = value.strip_prefix("float:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Double(parsed); + } + } + if let Some(rest) = value.strip_prefix("str:") { + return OptionValue::String(rest.to_string()); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Int(parsed); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Double(parsed); + } + OptionValue::String(value.to_string()) +} + +#[cfg(not(feature = "mcp-adbc"))] +fn parse_option_value(value: &str) -> String { + value.to_string() +} + +fn parse_database_options(input: &str) -> Result, String> { + let mut parsed = Vec::new(); + for (key, raw_value) in parse_kv_pairs(input, "--dbopt")? { + #[cfg(feature = "mcp-adbc")] + parsed.push(( + OptionDatabase::from(key.as_str()), + parse_option_value(&raw_value), + )); + #[cfg(not(feature = "mcp-adbc"))] + parsed.push((key, parse_option_value(&raw_value))); + } + Ok(parsed) +} + +fn parse_connection_options(input: &str) -> Result, String> { + let mut parsed = Vec::new(); + for (key, raw_value) in parse_kv_pairs(input, "--connopt")? { + #[cfg(feature = "mcp-adbc")] + parsed.push(( + OptionConnection::from(key.as_str()), + parse_option_value(&raw_value), + )); + #[cfg(not(feature = "mcp-adbc"))] + parsed.push((key, parse_option_value(&raw_value))); + } + Ok(parsed) +} + +fn load_runtime(models_path: &str) -> Result { + let path = PathBuf::from(models_path); + if path.is_dir() { + return SidemanticRuntime::from_directory(path) + .map_err(|e| format!("failed to load models from directory '{models_path}': {e}")); + } + if path.is_file() { + return SidemanticRuntime::from_file(path) + .map_err(|e| format!("failed to load models from file '{models_path}': {e}")); + } + Err(format!( + "models path '{models_path}' is not a readable file or directory" + )) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = parse_config()?; + let runtime = load_runtime(&config.models_path)?; + let server = SidemanticMcpServer::new(runtime, config); + + let service = server.serve(stdio()).await?; + service.waiting().await?; + Ok(()) +} diff --git a/sidemantic-rs/src/bin/sidemantic-server.rs b/sidemantic-rs/src/bin/sidemantic-server.rs new file mode 100644 index 00000000..88b19992 --- /dev/null +++ b/sidemantic-rs/src/bin/sidemantic-server.rs @@ -0,0 +1,1492 @@ +//! Rust-native HTTP runtime server for Sidemantic. +//! +//! This server exposes semantic compile/execute operations over HTTP using axum. + +use std::collections::HashMap; +use std::env; +#[cfg(feature = "runtime-server-adbc")] +use std::io::{self, Write}; +use std::path::PathBuf; +use std::sync::Arc; + +#[cfg(feature = "runtime-server-adbc")] +use adbc_core::options::{OptionConnection, OptionDatabase, OptionValue}; +#[cfg(feature = "runtime-server-adbc")] +use axum::body::Bytes; +use axum::body::{to_bytes, Body}; +use axum::extract::{Path, Query, Request, State}; +use axum::http::{header, HeaderMap, HeaderValue, Method, StatusCode}; +use axum::middleware::{self, Next}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Map as JsonMap, Value as JsonValue}; +use sidemantic::runtime::interpolate_query_filters; +#[cfg(feature = "runtime-server-adbc")] +use sidemantic::{ + execute_with_adbc, execute_with_adbc_arrow_ipc, write_adbc_arrow_ipc, AdbcExecutionRequest, + AdbcValue, +}; +use sidemantic::{Metric, Model, Relationship, RelationshipType, SemanticQuery, SidemanticRuntime}; +#[cfg(feature = "runtime-server-adbc")] +use tokio_stream::wrappers::ReceiverStream; + +#[cfg(feature = "runtime-server-adbc")] +type DatabaseOption = (OptionDatabase, OptionValue); +#[cfg(not(feature = "runtime-server-adbc"))] +type DatabaseOption = (String, String); +#[cfg(feature = "runtime-server-adbc")] +type ConnectionOption = (OptionConnection, OptionValue); +#[cfg(not(feature = "runtime-server-adbc"))] +type ConnectionOption = (String, String); + +const ARROW_STREAM_MEDIA_TYPE: &str = "application/vnd.apache.arrow.stream"; +#[cfg(feature = "runtime-server-adbc")] +const ARROW_STREAM_CHUNK_BYTES: usize = 64 * 1024; + +#[cfg_attr(not(feature = "runtime-server-adbc"), allow(dead_code))] +#[derive(Debug, Clone)] +struct AppState { + runtime: Arc, + adbc_driver: Option, + adbc_uri: Option, + adbc_entrypoint: Option, + database_options: Vec, + connection_options: Vec, +} + +#[derive(Debug, Clone)] +struct HttpControls { + auth_token: Option, + cors_origins: Vec, + max_request_body_bytes: usize, +} + +#[derive(Debug, Default)] +struct ServerConfig { + models_path: String, + bind: String, + auth_token: Option, + cors_origins: Vec, + max_request_body_bytes: usize, + adbc_driver: Option, + adbc_uri: Option, + adbc_entrypoint: Option, + database_options: Vec, + connection_options: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct QueryRequest { + #[serde(default)] + dimensions: Vec, + #[serde(default)] + metrics: Vec, + #[serde(default, rename = "where")] + where_clause: Option, + #[serde(default)] + filters: Vec, + #[serde(default)] + segments: Vec, + #[serde(default)] + order_by: Vec, + #[serde(default)] + limit: Option, + #[serde(default)] + offset: Option, + #[serde(default)] + ungrouped: bool, + #[serde(default)] + use_preaggregations: bool, + #[serde(default)] + parameters: JsonMap, +} + +#[derive(Debug, Clone, Deserialize)] +struct SQLRequest { + query: String, +} + +#[derive(Debug, Clone, Default, Deserialize)] +struct ResponseFormatParams { + format: Option, + transport: Option, + stream: Option, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum ArrowTransport { + Buffered, + Chunked, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum ResponseFormat { + Json, + Arrow(ArrowTransport), +} + +#[derive(Debug, Clone, Deserialize)] +struct GetModelsRequest { + #[serde(default)] + model_names: Vec, +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +fn json_error(status: StatusCode, message: impl Into) -> (StatusCode, Json) { + ( + status, + Json(ErrorResponse { + error: message.into(), + }), + ) +} + +fn json_error_response(status: StatusCode, message: impl Into) -> Response { + ( + status, + Json(ErrorResponse { + error: message.into(), + }), + ) + .into_response() +} + +fn resolve_response_format( + params: &ResponseFormatParams, + headers: &HeaderMap, +) -> Result)> { + let response_format = if let Some(format) = params.format.as_deref() { + match format.to_ascii_lowercase().as_str() { + "json" => ResponseFormat::Json, + "arrow" => ResponseFormat::Arrow(resolve_arrow_transport(params)?), + other => { + return Err(json_error( + StatusCode::BAD_REQUEST, + format!("unsupported response format '{other}'"), + )); + } + } + } else { + let wants_arrow = headers + .get(header::ACCEPT) + .and_then(|value| value.to_str().ok()) + .map(|value| value.contains(ARROW_STREAM_MEDIA_TYPE)) + .unwrap_or(false); + if wants_arrow { + ResponseFormat::Arrow(resolve_arrow_transport(params)?) + } else { + ResponseFormat::Json + } + }; + + if response_format == ResponseFormat::Json + && (params.transport.is_some() || params.stream.is_some()) + { + return Err(json_error( + StatusCode::BAD_REQUEST, + "chunked transport streaming is only supported for Arrow responses", + )); + } + + Ok(response_format) +} + +fn resolve_arrow_transport( + params: &ResponseFormatParams, +) -> Result)> { + let mut transport = ArrowTransport::Buffered; + + if let Some(value) = params.transport.as_deref() { + transport = match value.to_ascii_lowercase().as_str() { + "buffered" | "buffer" => ArrowTransport::Buffered, + "chunked" | "stream" | "streaming" => ArrowTransport::Chunked, + other => { + return Err(json_error( + StatusCode::BAD_REQUEST, + format!("unsupported Arrow transport '{other}'"), + )); + } + }; + } + + if let Some(value) = params.stream.as_deref() { + let stream_transport = match value.to_ascii_lowercase().as_str() { + "true" | "1" | "yes" | "chunked" | "stream" | "streaming" => ArrowTransport::Chunked, + "false" | "0" | "no" | "buffered" | "buffer" => ArrowTransport::Buffered, + other => { + return Err(json_error( + StatusCode::BAD_REQUEST, + format!("unsupported Arrow stream option '{other}'"), + )); + } + }; + if params.transport.is_some() && transport != stream_transport { + return Err(json_error( + StatusCode::BAD_REQUEST, + "conflicting Arrow transport and stream options", + )); + } + transport = stream_transport; + } + + Ok(transport) +} + +#[cfg(feature = "runtime-server-adbc")] +fn arrow_response(bytes: Vec, row_count: usize) -> Response { + let mut response = Body::from(bytes).into_response(); + let headers = response.headers_mut(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static(ARROW_STREAM_MEDIA_TYPE), + ); + headers.insert( + "x-sidemantic-row-count", + HeaderValue::from_str(&row_count.to_string()).expect("row count header should be valid"), + ); + headers.insert("x-sidemantic-dialect", HeaderValue::from_static("generic")); + headers.insert( + "x-sidemantic-arrow-transport", + HeaderValue::from_static("buffered"), + ); + response +} + +#[cfg(feature = "runtime-server-adbc")] +fn arrow_chunked_response( + receiver: tokio::sync::mpsc::Receiver>, +) -> Response { + let stream = ReceiverStream::new(receiver); + let mut response = Body::from_stream(stream).into_response(); + let headers = response.headers_mut(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static(ARROW_STREAM_MEDIA_TYPE), + ); + headers.insert("x-sidemantic-dialect", HeaderValue::from_static("generic")); + headers.insert( + "x-sidemantic-arrow-transport", + HeaderValue::from_static("chunked"), + ); + response +} + +#[cfg(feature = "runtime-server-adbc")] +struct ArrowChunkWriter { + sender: tokio::sync::mpsc::Sender>, + buffer: Vec, + chunk_size: usize, +} + +#[cfg(feature = "runtime-server-adbc")] +impl ArrowChunkWriter { + fn new(sender: tokio::sync::mpsc::Sender>) -> Self { + Self { + sender, + buffer: Vec::with_capacity(ARROW_STREAM_CHUNK_BYTES), + chunk_size: ARROW_STREAM_CHUNK_BYTES, + } + } + + fn flush_buffer(&mut self) -> io::Result<()> { + if self.buffer.is_empty() { + return Ok(()); + } + + let mut chunk = Vec::with_capacity(self.chunk_size); + std::mem::swap(&mut self.buffer, &mut chunk); + self.sender + .blocking_send(Ok(Bytes::from(chunk))) + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "HTTP client disconnected")) + } +} + +#[cfg(feature = "runtime-server-adbc")] +impl Write for ArrowChunkWriter { + fn write(&mut self, mut buf: &[u8]) -> io::Result { + let bytes_written = buf.len(); + while !buf.is_empty() { + let remaining = self.chunk_size - self.buffer.len(); + let take = remaining.min(buf.len()); + self.buffer.extend_from_slice(&buf[..take]); + buf = &buf[take..]; + + if self.buffer.len() >= self.chunk_size { + self.flush_buffer()?; + } + } + Ok(bytes_written) + } + + fn flush(&mut self) -> io::Result<()> { + self.flush_buffer() + } +} + +fn cors_allowed_origin(controls: &HttpControls, request: &Request) -> Option { + if controls.cors_origins.is_empty() { + return None; + } + let origin = request.headers().get(header::ORIGIN)?.to_str().ok()?; + if controls.cors_origins.iter().any(|allowed| allowed == "*") { + return HeaderValue::from_str(origin).ok(); + } + controls + .cors_origins + .iter() + .any(|allowed| allowed == origin) + .then(|| HeaderValue::from_str(origin).ok()) + .flatten() +} + +fn apply_cors(mut response: Response, origin: Option) -> Response { + if let Some(origin) = origin { + let headers = response.headers_mut(); + headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin); + headers.insert( + header::ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static("GET,POST,OPTIONS"), + ); + headers.insert( + header::ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static("authorization,content-type"), + ); + headers.insert( + header::ACCESS_CONTROL_MAX_AGE, + HeaderValue::from_static("600"), + ); + } + response +} + +async fn http_controls_middleware( + State(controls): State>, + mut request: Request, + next: Next, +) -> Response { + let origin = cors_allowed_origin(&controls, &request); + if request.method() == Method::OPTIONS { + return apply_cors(StatusCode::NO_CONTENT.into_response(), origin); + } + + if request.uri().path() != "/readyz" { + if let Some(expected) = &controls.auth_token { + let authorized = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value == format!("Bearer {expected}")); + if !authorized { + let mut response = json_error_response(StatusCode::UNAUTHORIZED, "Unauthorized"); + response + .headers_mut() + .insert(header::WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")); + return apply_cors(response, origin); + } + } + } + + if matches!( + request.method(), + &Method::POST | &Method::PUT | &Method::PATCH + ) { + if let Some(content_length) = request + .headers() + .get(header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.parse::().ok()) + { + if content_length > controls.max_request_body_bytes { + return apply_cors( + json_error_response( + StatusCode::PAYLOAD_TOO_LARGE, + format!( + "Request body exceeds {} bytes", + controls.max_request_body_bytes + ), + ), + origin, + ); + } + } + + let (parts, body) = request.into_parts(); + let bytes = match to_bytes(body, controls.max_request_body_bytes + 1).await { + Ok(bytes) if bytes.len() <= controls.max_request_body_bytes => bytes, + _ => { + return apply_cors( + json_error_response( + StatusCode::PAYLOAD_TOO_LARGE, + format!( + "Request body exceeds {} bytes", + controls.max_request_body_bytes + ), + ), + origin, + ); + } + }; + request = Request::from_parts(parts, Body::from(bytes)); + } + + apply_cors(next.run(request).await, origin) +} + +fn compile_request(runtime: &SidemanticRuntime, request: &QueryRequest) -> Result { + let mut filters = request.filters.clone(); + if let Some(where_clause) = &request.where_clause { + if !where_clause.trim().is_empty() { + filters.push(where_clause.clone()); + } + } + let parameter_values = request + .parameters + .iter() + .map(|(key, value)| { + serde_yaml::to_value(value) + .map(|value| (key.clone(), value)) + .map_err(|e| format!("failed to parse query parameter '{key}': {e}")) + }) + .collect::, _>>()?; + let filters = interpolate_query_filters(runtime.graph(), filters, ¶meter_values) + .map_err(|e| format!("failed to interpolate query parameters: {e}"))?; + + let mut query = SemanticQuery::new() + .with_dimensions(request.dimensions.clone()) + .with_metrics(request.metrics.clone()) + .with_filters(filters) + .with_segments(request.segments.clone()) + .with_ungrouped(request.ungrouped) + .with_use_preaggregations(request.use_preaggregations); + + if !request.order_by.is_empty() { + query = query.with_order_by(request.order_by.clone()); + } + if let Some(limit) = request.limit { + query = query.with_limit(limit); + } + if let Some(offset) = request.offset { + query = query.with_offset(offset); + } + + runtime + .compile(&query) + .map_err(|e| format!("failed to compile query: {e}")) +} + +async fn readyz() -> Json { + Json(json!({ "status": "ok" })) +} + +async fn health(State(state): State>) -> Json { + let payload = state.runtime.loaded_graph_payload(); + Json(json!({ + "status": "ok", + "version": env!("CARGO_PKG_VERSION"), + "dialect": "generic", + "model_count": payload.models.len() + })) +} + +fn model_summary_json(model: &Model) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(model.name)); + entry.insert("table".to_string(), json!(model.table)); + entry.insert( + "dimensions".to_string(), + json!(model + .dimensions + .iter() + .map(|dimension| dimension.name.clone()) + .collect::>()), + ); + entry.insert( + "metrics".to_string(), + json!(model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect::>()), + ); + entry.insert( + "relationships".to_string(), + json!(model.relationships.len()), + ); + JsonValue::Object(entry) +} + +fn graph_model_summary_json(model: &Model) -> JsonValue { + let mut entry = match model_summary_json(model) { + JsonValue::Object(entry) => entry, + _ => JsonMap::new(), + }; + entry.insert( + "relationships".to_string(), + JsonValue::Array( + model + .relationships + .iter() + .map(|relationship| { + json!({ + "name": relationship.name, + "type": enum_json_name(&relationship.r#type) + .unwrap_or_else(|| "many_to_one".to_string()) + }) + }) + .collect(), + ), + ); + if !model.segments.is_empty() { + entry.insert( + "segments".to_string(), + json!(model + .segments + .iter() + .map(|segment| segment.name.clone()) + .collect::>()), + ); + } + if let Some(description) = &model.description { + entry.insert("description".to_string(), json!(description)); + } + if !model.primary_key.is_empty() { + entry.insert("primary_key".to_string(), json!(model.primary_key)); + } + if let Some(default_time_dimension) = &model.default_time_dimension { + entry.insert( + "default_time_dimension".to_string(), + json!(default_time_dimension), + ); + } + JsonValue::Object(entry) +} + +fn dimension_json(dimension: &sidemantic::Dimension) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(dimension.name)); + entry.insert( + "type".to_string(), + json!(enum_json_name(&dimension.r#type).unwrap_or_else(|| "categorical".to_string())), + ); + entry.insert("sql".to_string(), json!(dimension.sql)); + if let Some(description) = &dimension.description { + entry.insert("description".to_string(), json!(description)); + } + if let Some(label) = &dimension.label { + entry.insert("label".to_string(), json!(label)); + } + if let Some(granularity) = &dimension.granularity { + entry.insert("granularity".to_string(), json!(granularity)); + } + if let Some(supported_granularities) = &dimension.supported_granularities { + entry.insert( + "supported_granularities".to_string(), + json!(supported_granularities), + ); + } + if let Some(format) = &dimension.format { + entry.insert("format".to_string(), json!(format)); + } + if let Some(value_format_name) = &dimension.value_format_name { + entry.insert("value_format_name".to_string(), json!(value_format_name)); + } + if let Some(parent) = &dimension.parent { + entry.insert("parent".to_string(), json!(parent)); + } + JsonValue::Object(entry) +} + +fn segment_json(segment: &sidemantic::Segment) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(segment.name)); + entry.insert("sql".to_string(), json!(segment.sql)); + if let Some(description) = &segment.description { + entry.insert("description".to_string(), json!(description)); + } + if !segment.public { + entry.insert("public".to_string(), json!(false)); + } + JsonValue::Object(entry) +} + +fn model_detail_json(model: &Model, model_map: &HashMap<&str, &Model>) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(model.name)); + entry.insert("table".to_string(), json!(model.table)); + entry.insert("primary_key".to_string(), json!(model.primary_key)); + entry.insert( + "dimensions".to_string(), + JsonValue::Array(model.dimensions.iter().map(dimension_json).collect()), + ); + entry.insert( + "metrics".to_string(), + JsonValue::Array(model.metrics.iter().map(metric_json).collect()), + ); + entry.insert( + "relationships".to_string(), + JsonValue::Array( + model + .relationships + .iter() + .map(|relationship| relationship_json(model, relationship, model_map)) + .collect(), + ), + ); + if !model.segments.is_empty() { + entry.insert( + "segments".to_string(), + JsonValue::Array(model.segments.iter().map(segment_json).collect()), + ); + } + if let Some(description) = &model.description { + entry.insert("description".to_string(), json!(description)); + } + if let Some(sql) = &model.sql { + entry.insert("sql".to_string(), json!(sql)); + } + if let Some(default_time_dimension) = &model.default_time_dimension { + entry.insert( + "default_time_dimension".to_string(), + json!(default_time_dimension), + ); + } + if let Some(default_grain) = &model.default_grain { + entry.insert("default_grain".to_string(), json!(default_grain)); + } + JsonValue::Object(entry) +} + +async fn list_models(State(state): State>) -> Json { + let payload = state.runtime.loaded_graph_payload(); + let models = payload.models.iter().map(model_summary_json).collect(); + Json(JsonValue::Array(models)) +} + +async fn get_model( + State(state): State>, + Path(model_name): Path, +) -> Result, (StatusCode, Json)> { + let payload = state.runtime.loaded_graph_payload(); + let model_map: HashMap<&str, &Model> = payload + .models + .iter() + .map(|model| (model.name.as_str(), model)) + .collect(); + let Some(model) = model_map.get(model_name.as_str()) else { + return Err(json_error( + StatusCode::NOT_FOUND, + format!("model not found: {model_name}"), + )); + }; + + Ok(Json(model_detail_json(model, &model_map))) +} + +async fn get_models( + State(state): State>, + Json(request): Json, +) -> Json { + let payload = state.runtime.loaded_graph_payload(); + let model_map: HashMap<&str, &Model> = payload + .models + .iter() + .map(|model| (model.name.as_str(), model)) + .collect(); + + let mut details = Vec::new(); + for model_name in &request.model_names { + let Some(model) = model_map.get(model_name.as_str()) else { + continue; + }; + + details.push(model_detail_json(model, &model_map)); + } + + Json(JsonValue::Array(details)) +} + +fn graph_payload(runtime: &SidemanticRuntime) -> JsonValue { + let payload = runtime.loaded_graph_payload(); + let models = payload + .models + .iter() + .map(graph_model_summary_json) + .collect::>(); + let graph_metrics = payload + .top_level_metrics + .iter() + .map(metric_json) + .collect::>(); + + let model_names = payload + .models + .iter() + .map(|model| model.name.clone()) + .collect::>(); + let mut joinable_pairs = Vec::new(); + for (idx, left_name) in model_names.iter().enumerate() { + for right_name in model_names.iter().skip(idx + 1) { + if let Ok(path) = runtime.find_join_path(left_name, right_name) { + joinable_pairs.push(json!({ + "from": left_name, + "to": right_name, + "hops": path.steps.len() + })); + } + } + } + + let mut result = JsonMap::new(); + result.insert("models".to_string(), JsonValue::Array(models)); + result.insert( + "joinable_pairs".to_string(), + JsonValue::Array(joinable_pairs), + ); + if !graph_metrics.is_empty() { + result.insert("graph_metrics".to_string(), JsonValue::Array(graph_metrics)); + } + JsonValue::Object(result) +} + +async fn graph(State(state): State>) -> Json { + Json(graph_payload(&state.runtime)) +} + +async fn compile_query( + State(state): State>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + let sql = compile_request(&state.runtime, &request) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + Ok(Json(json!({ "sql": sql }))) +} + +async fn run_query( + State(state): State>, + Query(format_params): Query, + headers: HeaderMap, + Json(request): Json, +) -> Result)> { + let response_format = resolve_response_format(&format_params, &headers)?; + let sql = compile_request(&state.runtime, &request) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + execute_sql_response(&state, sql, None, "/query/run", response_format) +} + +async fn compile_sql( + State(state): State>, + Json(request): Json, +) -> Result, (StatusCode, Json)> { + let query = normalize_sql(&request.query) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + let sql = state.runtime.rewrite(&query).map_err(|e| { + json_error( + StatusCode::BAD_REQUEST, + format!("failed to rewrite SQL: {e}"), + ) + })?; + Ok(Json(json!({ "sql": sql }))) +} + +async fn run_sql( + State(state): State>, + Query(format_params): Query, + headers: HeaderMap, + Json(request): Json, +) -> Result)> { + let response_format = resolve_response_format(&format_params, &headers)?; + let original_sql = normalize_sql(&request.query) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + let sql = state.runtime.rewrite(&original_sql).map_err(|e| { + json_error( + StatusCode::BAD_REQUEST, + format!("failed to rewrite SQL: {e}"), + ) + })?; + execute_sql_response(&state, sql, Some(original_sql), "/sql", response_format) +} + +async fn run_raw_sql( + State(state): State>, + Query(format_params): Query, + headers: HeaderMap, + Json(request): Json, +) -> Result)> { + let response_format = resolve_response_format(&format_params, &headers)?; + let sql = normalize_sql(&request.query) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + require_select_only_sql(&sql) + .map_err(|message| json_error(StatusCode::BAD_REQUEST, message))?; + execute_sql_response(&state, sql, None, "/raw", response_format) +} + +fn execute_sql_response( + state: &AppState, + sql: String, + original_sql: Option, + route_name: &str, + response_format: ResponseFormat, +) -> Result)> { + match response_format { + ResponseFormat::Json => { + execute_sql_json(state, sql, original_sql, route_name).map(IntoResponse::into_response) + } + ResponseFormat::Arrow(ArrowTransport::Buffered) => { + execute_sql_arrow(state, sql, route_name) + } + ResponseFormat::Arrow(ArrowTransport::Chunked) => { + execute_sql_arrow_chunked(state, sql, route_name) + } + } +} + +fn execute_sql_arrow( + state: &AppState, + sql: String, + route_name: &str, +) -> Result)> { + let _ = route_name; + #[cfg(not(feature = "runtime-server-adbc"))] + { + let _ = (state, sql); + return Err(json_error( + StatusCode::BAD_REQUEST, + format!( + "ADBC execution support is not enabled. Rebuild with feature 'runtime-server-adbc' to use {route_name}." + ), + )); + } + + #[cfg(feature = "runtime-server-adbc")] + { + let Some(driver) = state.adbc_driver.clone() else { + return Err(json_error( + StatusCode::BAD_REQUEST, + "ADBC driver is not configured. Set SIDEMANTIC_SERVER_ADBC_DRIVER or pass --driver.", + )); + }; + + let result = execute_with_adbc_arrow_ipc(AdbcExecutionRequest { + driver, + sql, + uri: state.adbc_uri.clone(), + entrypoint: state.adbc_entrypoint.clone(), + database_options: state.database_options.clone(), + connection_options: state.connection_options.clone(), + }) + .map_err(|e| { + json_error( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to execute query via ADBC: {e}"), + ) + })?; + + Ok(arrow_response(result.bytes, result.row_count)) + } +} + +fn execute_sql_arrow_chunked( + state: &AppState, + sql: String, + route_name: &str, +) -> Result)> { + let _ = route_name; + #[cfg(not(feature = "runtime-server-adbc"))] + { + let _ = (state, sql); + return Err(json_error( + StatusCode::BAD_REQUEST, + format!( + "ADBC execution support is not enabled. Rebuild with feature 'runtime-server-adbc' to use {route_name}." + ), + )); + } + + #[cfg(feature = "runtime-server-adbc")] + { + let Some(driver) = state.adbc_driver.clone() else { + return Err(json_error( + StatusCode::BAD_REQUEST, + "ADBC driver is not configured. Set SIDEMANTIC_SERVER_ADBC_DRIVER or pass --driver.", + )); + }; + + let request = AdbcExecutionRequest { + driver, + sql, + uri: state.adbc_uri.clone(), + entrypoint: state.adbc_entrypoint.clone(), + database_options: state.database_options.clone(), + connection_options: state.connection_options.clone(), + }; + let (sender, receiver) = tokio::sync::mpsc::channel(8); + tokio::task::spawn_blocking(move || { + let writer = ArrowChunkWriter::new(sender.clone()); + if let Err(err) = write_adbc_arrow_ipc(request, writer) { + let _ = sender.blocking_send(Err(io::Error::other(format!( + "failed to execute query via ADBC: {err}" + )))); + } + }); + + Ok(arrow_chunked_response(receiver)) + } +} + +fn execute_sql_json( + state: &AppState, + sql: String, + original_sql: Option, + route_name: &str, +) -> Result, (StatusCode, Json)> { + let _ = route_name; + #[cfg(not(feature = "runtime-server-adbc"))] + { + let _ = (state, sql, original_sql); + return Err(json_error( + StatusCode::BAD_REQUEST, + format!( + "ADBC execution support is not enabled. Rebuild with feature 'runtime-server-adbc' to use {route_name}." + ), + )); + } + + #[cfg(feature = "runtime-server-adbc")] + { + let Some(driver) = state.adbc_driver.clone() else { + return Err(json_error( + StatusCode::BAD_REQUEST, + "ADBC driver is not configured. Set SIDEMANTIC_SERVER_ADBC_DRIVER or pass --driver.", + )); + }; + + let result = execute_with_adbc(AdbcExecutionRequest { + driver, + sql: sql.clone(), + uri: state.adbc_uri.clone(), + entrypoint: state.adbc_entrypoint.clone(), + database_options: state.database_options.clone(), + connection_options: state.connection_options.clone(), + }) + .map_err(|e| { + json_error( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to execute query via ADBC: {e}"), + ) + })?; + + let rows = adbc_rows_to_json_rows(&result.columns, &result.rows); + let mut response = JsonMap::new(); + response.insert("sql".to_string(), json!(sql)); + if let Some(original_sql) = original_sql { + response.insert("original_sql".to_string(), json!(original_sql)); + } + response.insert("row_count".to_string(), json!(rows.len())); + response.insert("rows".to_string(), JsonValue::Array(rows)); + Ok(Json(JsonValue::Object(response))) + } +} + +#[cfg(feature = "runtime-server-adbc")] +fn adbc_rows_to_json_rows(columns: &[String], rows: &[Vec]) -> Vec { + rows.iter() + .map(|row| { + let mut row_map = JsonMap::new(); + for (idx, column) in columns.iter().enumerate() { + let value = row + .get(idx) + .map(adbc_value_to_json) + .unwrap_or(JsonValue::Null); + row_map.insert(column.clone(), value); + } + JsonValue::Object(row_map) + }) + .collect() +} + +fn normalize_sql(sql: &str) -> Result { + let mut normalized = sql.trim().to_string(); + if normalized.is_empty() { + return Err("SQL query cannot be empty".to_string()); + } + while normalized.ends_with(';') { + normalized.pop(); + normalized = normalized.trim_end().to_string(); + } + if has_unquoted_semicolon(&normalized) { + return Err("Only one SQL statement is allowed".to_string()); + } + Ok(normalized) +} + +fn has_unquoted_semicolon(sql: &str) -> bool { + let mut in_single = false; + let mut in_double = false; + let mut prev = '\0'; + for ch in sql.chars() { + match ch { + '\'' if !in_double && prev != '\\' => in_single = !in_single, + '"' if !in_single && prev != '\\' => in_double = !in_double, + ';' if !in_single && !in_double => return true, + _ => {} + } + prev = ch; + } + false +} + +fn require_select_only_sql(sql: &str) -> Result<(), String> { + let scrubbed = scrub_quoted_sql(sql); + let lower = scrubbed.to_ascii_lowercase(); + let first_word = lower + .split_whitespace() + .next() + .ok_or_else(|| "SQL query cannot be empty".to_string())?; + if first_word != "select" && first_word != "with" { + return Err("Raw SQL execution only supports SELECT statements".to_string()); + } + for banned in [ + "insert", "update", "delete", "drop", "create", "alter", "truncate", "merge", "copy", + "call", "grant", "revoke", + ] { + if lower + .split(|ch: char| !ch.is_ascii_alphanumeric() && ch != '_') + .any(|token| token == banned) + { + return Err(format!( + "Raw SQL execution only supports SELECT statements; found {banned}" + )); + } + } + Ok(()) +} + +fn scrub_quoted_sql(sql: &str) -> String { + let mut scrubbed = String::with_capacity(sql.len()); + let mut in_single = false; + let mut in_double = false; + let mut prev = '\0'; + for ch in sql.chars() { + match ch { + '\'' if !in_double && prev != '\\' => { + in_single = !in_single; + scrubbed.push(' '); + } + '"' if !in_single && prev != '\\' => { + in_double = !in_double; + scrubbed.push(' '); + } + _ if in_single || in_double => scrubbed.push(' '), + _ => scrubbed.push(ch), + } + prev = ch; + } + scrubbed +} + +fn parse_config() -> Result { + let mut models_path: Option = None; + let mut bind = env::var("SIDEMANTIC_SERVER_BIND") + .ok() + .unwrap_or_else(|| "127.0.0.1:4543".to_string()); + let mut auth_token = env::var("SIDEMANTIC_SERVER_AUTH_TOKEN").ok(); + let mut cors_origins = env::var("SIDEMANTIC_SERVER_CORS_ORIGINS") + .ok() + .map(|value| { + value + .split(',') + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToString::to_string) + .collect::>() + }) + .unwrap_or_default(); + let mut max_request_body_bytes = env::var("SIDEMANTIC_SERVER_MAX_REQUEST_BODY_BYTES") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(1024 * 1024); + let mut adbc_driver = env::var("SIDEMANTIC_SERVER_ADBC_DRIVER") + .ok() + .or_else(|| env::var("SIDEMANTIC_MCP_ADBC_DRIVER").ok()); + let mut adbc_uri = env::var("SIDEMANTIC_SERVER_ADBC_URI") + .ok() + .or_else(|| env::var("SIDEMANTIC_MCP_ADBC_URI").ok()); + let mut adbc_entrypoint = env::var("SIDEMANTIC_SERVER_ADBC_ENTRYPOINT") + .ok() + .or_else(|| env::var("SIDEMANTIC_MCP_ADBC_ENTRYPOINT").ok()); + let mut database_options: Vec = Vec::new(); + let mut connection_options: Vec = Vec::new(); + if let Ok(env_dbopts) = + env::var("SIDEMANTIC_SERVER_ADBC_DBOPTS").or_else(|_| env::var("SIDEMANTIC_ADBC_DBOPTS")) + { + database_options.extend(parse_database_options(&env_dbopts)?); + } + if let Ok(env_connopts) = env::var("SIDEMANTIC_SERVER_ADBC_CONNOPTS") + .or_else(|_| env::var("SIDEMANTIC_ADBC_CONNOPTS")) + { + connection_options.extend(parse_connection_options(&env_connopts)?); + } + + let mut args = env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--models" => { + let value = args + .next() + .ok_or_else(|| "--models requires a path value".to_string())?; + models_path = Some(value); + } + "--bind" => { + bind = args + .next() + .ok_or_else(|| "--bind requires a value".to_string())?; + } + "--auth-token" => { + auth_token = Some( + args.next() + .ok_or_else(|| "--auth-token requires a value".to_string())?, + ); + } + "--cors-origin" => { + cors_origins.push( + args.next() + .ok_or_else(|| "--cors-origin requires a value".to_string())?, + ); + } + "--max-request-body-bytes" => { + let value = args + .next() + .ok_or_else(|| "--max-request-body-bytes requires a value".to_string())?; + max_request_body_bytes = value + .parse::() + .map_err(|_| "--max-request-body-bytes must be an integer".to_string())?; + } + "--driver" => { + adbc_driver = Some( + args.next() + .ok_or_else(|| "--driver requires a value".to_string())?, + ); + } + "--uri" => { + adbc_uri = Some( + args.next() + .ok_or_else(|| "--uri requires a value".to_string())?, + ); + } + "--entrypoint" => { + adbc_entrypoint = Some( + args.next() + .ok_or_else(|| "--entrypoint requires a value".to_string())?, + ); + } + "--dbopt" => { + let value = args + .next() + .ok_or_else(|| "--dbopt requires a value".to_string())?; + database_options.extend(parse_database_options(&value)?); + } + "--connopt" => { + let value = args + .next() + .ok_or_else(|| "--connopt requires a value".to_string())?; + connection_options.extend(parse_connection_options(&value)?); + } + "--help" | "-h" => { + return Err( + "Usage: sidemantic-server [--models ] [--bind ] [--auth-token ] [--cors-origin ] [--max-request-body-bytes ] [--driver ] [--uri ] [--entrypoint ] [--dbopt ] [--connopt ]".to_string() + ); + } + unknown => { + return Err(format!( + "unknown argument: {unknown}. Use --help for usage." + )); + } + } + } + + let models_path = models_path + .or_else(|| env::var("SIDEMANTIC_SERVER_MODELS").ok()) + .unwrap_or_else(|| ".".to_string()); + + Ok(ServerConfig { + models_path, + bind, + auth_token, + cors_origins, + max_request_body_bytes, + adbc_driver, + adbc_uri, + adbc_entrypoint, + database_options, + connection_options, + }) +} + +fn parse_kv_pairs(input: &str, option_name: &str) -> Result, String> { + let mut pairs = Vec::new(); + for fragment in input + .split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + { + let (key, value) = fragment + .split_once('=') + .ok_or_else(|| format!("{option_name} expects key=value, got '{fragment}'"))?; + if key.trim().is_empty() { + return Err(format!("{option_name} key cannot be empty: '{fragment}'")); + } + pairs.push((key.trim().to_string(), value.to_string())); + } + if pairs.is_empty() { + return Err(format!("{option_name} expects key=value pairs")); + } + Ok(pairs) +} + +#[cfg(feature = "runtime-server-adbc")] +fn parse_option_value(value: &str) -> OptionValue { + if let Some(rest) = value.strip_prefix("int:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Int(parsed); + } + } + if let Some(rest) = value.strip_prefix("float:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Double(parsed); + } + } + if let Some(rest) = value.strip_prefix("str:") { + return OptionValue::String(rest.to_string()); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Int(parsed); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Double(parsed); + } + OptionValue::String(value.to_string()) +} + +#[cfg(not(feature = "runtime-server-adbc"))] +fn parse_option_value(value: &str) -> String { + value.to_string() +} + +fn parse_database_options(input: &str) -> Result, String> { + let mut parsed = Vec::new(); + for (key, raw_value) in parse_kv_pairs(input, "--dbopt")? { + #[cfg(feature = "runtime-server-adbc")] + parsed.push(( + OptionDatabase::from(key.as_str()), + parse_option_value(&raw_value), + )); + #[cfg(not(feature = "runtime-server-adbc"))] + parsed.push((key, parse_option_value(&raw_value))); + } + Ok(parsed) +} + +fn parse_connection_options(input: &str) -> Result, String> { + let mut parsed = Vec::new(); + for (key, raw_value) in parse_kv_pairs(input, "--connopt")? { + #[cfg(feature = "runtime-server-adbc")] + parsed.push(( + OptionConnection::from(key.as_str()), + parse_option_value(&raw_value), + )); + #[cfg(not(feature = "runtime-server-adbc"))] + parsed.push((key, parse_option_value(&raw_value))); + } + Ok(parsed) +} + +fn load_runtime(models_path: &str) -> Result { + let path = PathBuf::from(models_path); + if path.is_dir() { + return SidemanticRuntime::from_directory(path) + .map_err(|e| format!("failed to load models from directory '{models_path}': {e}")); + } + if path.is_file() { + return SidemanticRuntime::from_file(path) + .map_err(|e| format!("failed to load models from file '{models_path}': {e}")); + } + Err(format!( + "models path '{models_path}' is not a readable file or directory" + )) +} + +fn enum_json_name(value: &T) -> Option { + serde_json::to_value(value) + .ok() + .and_then(|value| value.as_str().map(ToString::to_string)) +} + +fn metric_json(metric: &Metric) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(metric.name)); + entry.insert("sql".to_string(), json!(metric.sql)); + if let Some(agg) = &metric.agg { + if let Some(name) = enum_json_name(agg) { + entry.insert("agg".to_string(), json!(name)); + } + } + if let Some(name) = enum_json_name(&metric.r#type) { + entry.insert("type".to_string(), json!(name)); + } + if let Some(description) = &metric.description { + entry.insert("description".to_string(), json!(description)); + } + if !metric.filters.is_empty() { + entry.insert("filters".to_string(), json!(metric.filters)); + } + JsonValue::Object(entry) +} + +fn relationship_json( + model: &Model, + relationship: &Relationship, + model_map: &HashMap<&str, &Model>, +) -> JsonValue { + let mut entry = JsonMap::new(); + entry.insert("name".to_string(), json!(relationship.name)); + entry.insert( + "type".to_string(), + json!(enum_json_name(&relationship.r#type).unwrap_or_else(|| "many_to_one".to_string())), + ); + if let Some(foreign_key) = &relationship.foreign_key { + entry.insert("foreign_key".to_string(), json!(foreign_key)); + } + if let Some(primary_key) = &relationship.primary_key { + entry.insert("primary_key".to_string(), json!(primary_key)); + } + if let Some(through) = &relationship.through { + entry.insert("through".to_string(), json!(through)); + } + if let Some(through_fk) = &relationship.through_foreign_key { + entry.insert("through_foreign_key".to_string(), json!(through_fk)); + } + if let Some(related_fk) = &relationship.related_foreign_key { + entry.insert("related_foreign_key".to_string(), json!(related_fk)); + } + + if let Some(join_condition) = format_join_condition(model, relationship, model_map) { + entry.insert("join_condition".to_string(), json!(join_condition)); + } + + JsonValue::Object(entry) +} + +fn format_join_condition( + model: &Model, + relationship: &Relationship, + model_map: &HashMap<&str, &Model>, +) -> Option { + let related_model = model_map.get(relationship.name.as_str())?; + let related_name = relationship.name.as_str(); + let model_name = model.name.as_str(); + + match relationship.r#type { + RelationshipType::ManyToOne => { + let fk = relationship + .foreign_key + .clone() + .unwrap_or_else(|| format!("{related_name}_id")); + let pk = relationship + .primary_key + .clone() + .unwrap_or_else(|| related_model.primary_key.clone()); + Some(format!("{model_name}.{fk} = {related_name}.{pk}")) + } + RelationshipType::OneToMany | RelationshipType::OneToOne => { + let fk = relationship.foreign_key.clone()?; + let pk = model.primary_key.clone(); + Some(format!("{related_name}.{fk} = {model_name}.{pk}")) + } + RelationshipType::ManyToMany => { + if let Some(through) = &relationship.through { + let _junction_model = model_map.get(through.as_str())?; + let (junction_self_fk, junction_related_fk) = relationship.junction_keys(); + let junction_self_fk = junction_self_fk?; + let junction_related_fk = junction_related_fk?; + let base_pk = model.primary_key.clone(); + let related_pk = relationship + .primary_key + .clone() + .unwrap_or_else(|| related_model.primary_key.clone()); + return Some(format!( + "{model_name}.{base_pk} = {through}.{junction_self_fk} AND {through}.{junction_related_fk} = {related_name}.{related_pk}" + )); + } + + relationship.foreign_key.clone().map(|foreign_key| { + format!( + "{model_name}.{} = {related_name}.{foreign_key}", + model.primary_key + ) + }) + } + } +} + +#[cfg(feature = "runtime-server-adbc")] +fn adbc_value_to_json(value: &AdbcValue) -> JsonValue { + match value { + AdbcValue::Null => JsonValue::Null, + AdbcValue::Bool(v) => json!(v), + AdbcValue::I64(v) => json!(v), + AdbcValue::U64(v) => json!(v), + AdbcValue::F64(v) => json!(v), + AdbcValue::String(v) => json!(v), + AdbcValue::Bytes(v) => json!(v), + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = parse_config()?; + let runtime = load_runtime(&config.models_path)?; + + let state = Arc::new(AppState { + runtime: Arc::new(runtime), + adbc_driver: config.adbc_driver, + adbc_uri: config.adbc_uri, + adbc_entrypoint: config.adbc_entrypoint, + database_options: config.database_options, + connection_options: config.connection_options, + }); + let controls = Arc::new(HttpControls { + auth_token: config.auth_token, + cors_origins: config.cors_origins, + max_request_body_bytes: config.max_request_body_bytes, + }); + + let app = Router::new() + .route("/readyz", get(readyz)) + .route("/health", get(health)) + .route("/graph", get(graph)) + .route("/models", get(list_models).post(get_models)) + .route("/models/{model}", get(get_model)) + .route("/compile", post(compile_query)) + .route("/query", post(run_query)) + .route("/query/compile", post(compile_query)) + .route("/query/run", post(run_query)) + .route("/sql/compile", post(compile_sql)) + .route("/sql", post(run_sql)) + .route("/raw", post(run_raw_sql)) + .with_state(state) + .layer(middleware::from_fn_with_state( + controls, + http_controls_middleware, + )); + + eprintln!("sidemantic-server listening on {}", config.bind); + let listener = tokio::net::TcpListener::bind(&config.bind).await?; + axum::serve(listener, app).await?; + Ok(()) +} diff --git a/sidemantic-rs/src/bin/sidemantic-workbench.rs b/sidemantic-rs/src/bin/sidemantic-workbench.rs new file mode 100644 index 00000000..327a3314 --- /dev/null +++ b/sidemantic-rs/src/bin/sidemantic-workbench.rs @@ -0,0 +1,21 @@ +//! Dedicated entrypoint for the Rust workbench. + +use std::env; +use std::process; + +type CliResult = std::result::Result; + +#[allow(dead_code)] +#[path = "../main.rs"] +mod cli; + +#[cfg(feature = "adbc-exec")] +pub(crate) use cli::parse_connection_url_to_adbc; + +fn main() { + let args = env::args().skip(1).collect::>(); + if let Err(err) = cli::workbench_command(&args) { + eprintln!("error: {err}"); + process::exit(1); + } +} diff --git a/sidemantic-rs/src/config/loader.rs b/sidemantic-rs/src/config/loader.rs index 905fb4c5..ad294942 100644 --- a/sidemantic-rs/src/config/loader.rs +++ b/sidemantic-rs/src/config/loader.rs @@ -1,15 +1,45 @@ -//! Config loader: loads semantic layer definitions from YAML files +//! Config loader: loads semantic layer definitions from YAML/SQL files -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fs; use std::path::Path; +use regex::Regex; +use serde::{Deserialize, Serialize}; + use crate::core::{ - resolve_model_inheritance, Model, Relationship, RelationshipType, SemanticGraph, + extract_dependencies, resolve_model_inheritance, Metric, Model, Parameter, Relationship, + RelationshipType, SemanticGraph, }; use crate::error::{Result, SidemanticError}; use super::schema::{CubeConfig, SidemanticConfig}; +use super::sql_parser::{ + parse_sql_definitions, parse_sql_graph_definitions_extended, parse_sql_model, +}; + +#[derive(Debug)] +struct ParsedConfig { + models: Vec, + extends_map: HashMap, + top_level_metrics: Vec, + top_level_parameters: Vec, +} + +#[derive(Debug)] +pub struct LoadedGraphMetadata { + pub graph: SemanticGraph, + pub model_order: Vec, + pub top_level_metrics: Vec, + pub original_model_metrics: HashMap>, + pub model_sources: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadedModelSource { + pub source_format: String, + pub source_file: Option, +} /// Detected config format #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -22,40 +52,286 @@ pub enum ConfigFormat { /// Load a semantic graph from a single YAML file pub fn load_from_file(path: impl AsRef) -> Result { + Ok(load_from_file_with_metadata(path)?.graph) +} + +/// Load a semantic graph from a single file with metadata. +/// +/// Supported file extensions: +/// - `.yml` / `.yaml` (native/cube YAML) +/// - `.sql` (MODEL statements or SQL + YAML frontmatter) +pub fn load_from_file_with_metadata(path: impl AsRef) -> Result { let path = path.as_ref(); let content = fs::read_to_string(path) .map_err(|e| SidemanticError::Validation(format!("Failed to read file: {e}")))?; - load_from_string(&content) + let ext = path + .extension() + .and_then(|value| value.to_str()) + .map(str::to_ascii_lowercase); + match ext.as_deref() { + Some("sql") => load_from_sql_string_with_metadata(&content), + _ => load_from_string_with_metadata(&content), + } } /// Load a semantic graph from a YAML string pub fn load_from_string(content: &str) -> Result { + Ok(load_from_string_with_metadata(content)?.graph) +} + +/// Load a semantic graph from YAML with parsing metadata used by Python bridge. +pub fn load_from_string_with_metadata(content: &str) -> Result { let format = detect_format(content); - let (models, extends_map) = parse_content_with_extends(content, format)?; + let parsed = parse_content_with_extends(content, format)?; + let ParsedConfig { + models, + extends_map, + top_level_metrics, + top_level_parameters, + } = parsed; + let model_order: Vec = models.iter().map(|model| model.name.clone()).collect(); + let original_model_metrics: HashMap> = models + .iter() + .map(|model| { + ( + model.name.clone(), + model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect::>(), + ) + }) + .collect(); // Resolve inheritance - let models_map: HashMap = - models.into_iter().map(|m| (m.name.clone(), m)).collect(); - let resolved_models = resolve_model_inheritance(models_map, &extends_map)?; + let models_map = collect_unique_models(models)?; + let mut resolved_models = resolve_model_inheritance(models_map, &extends_map)?; + if !resolved_models.is_empty() && !top_level_metrics.is_empty() { + assign_top_level_metrics(&mut resolved_models, top_level_metrics.clone())?; + } + + let mut graph = SemanticGraph::new(); + for model in resolved_models.into_values() { + graph.add_model(model)?; + } + for parameter in top_level_parameters { + graph.add_parameter(parameter)?; + } + + let model_sources = model_order + .iter() + .map(|model_name| { + ( + model_name.clone(), + LoadedModelSource { + source_format: match format { + ConfigFormat::Sidemantic => "Sidemantic".to_string(), + ConfigFormat::Cube => "Cube".to_string(), + }, + source_file: None, + }, + ) + }) + .collect(); + + Ok(LoadedGraphMetadata { + graph, + model_order, + top_level_metrics, + original_model_metrics, + model_sources, + }) +} + +fn parse_sql_frontmatter_and_body(content: &str) -> Result<(Option, String)> { + if !content.trim().starts_with("---") { + return Ok((None, content.to_string())); + } + + let parts: Vec<&str> = content.splitn(3, "---").collect(); + if parts.len() < 3 { + return Ok((None, content.to_string())); + } + + let frontmatter_text = parts[1].trim(); + let sql_body = parts[2].trim().to_string(); + if frontmatter_text.is_empty() { + return Ok((None, sql_body)); + } + + let frontmatter_value: serde_yaml::Value = + serde_yaml::from_str(frontmatter_text).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL frontmatter: {e}")) + })?; + + match frontmatter_value { + serde_yaml::Value::Null => Ok((None, sql_body)), + serde_yaml::Value::Mapping(mapping) => { + if mapping.is_empty() { + Ok((None, sql_body)) + } else { + Ok((Some(mapping), sql_body)) + } + } + _ => Err(SidemanticError::Validation( + "failed to parse SQL frontmatter: frontmatter must be a YAML mapping".to_string(), + )), + } +} + +fn model_from_sql_frontmatter(frontmatter: serde_yaml::Mapping) -> Result { + let mut wrapper = serde_yaml::Mapping::new(); + wrapper.insert( + serde_yaml::Value::String("models".to_string()), + serde_yaml::Value::Sequence(vec![serde_yaml::Value::Mapping(frontmatter)]), + ); + + let config: SidemanticConfig = serde_yaml::from_value(serde_yaml::Value::Mapping(wrapper)) + .map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL frontmatter model: {e}")) + })?; + let (models, _, _) = config.into_parts(); + models.into_iter().next().ok_or_else(|| { + SidemanticError::Validation( + "failed to parse SQL frontmatter: missing model definition".to_string(), + ) + }) +} + +fn parse_sql_content(content: &str) -> Result { + let has_model_statement = { + let upper = content.to_ascii_uppercase(); + upper.contains("MODEL") && upper.contains("MODEL (") + }; + + let mut models: Vec = Vec::new(); + let mut top_level_metrics: Vec = Vec::new(); + let mut top_level_parameters: Vec = Vec::new(); + + if has_model_statement { + let model = parse_sql_model(content).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL model statement: {e}")) + })?; + let model_metric_names: HashSet = model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect(); + models.push(model); + + let (sql_metrics, _, sql_parameters, _) = parse_sql_graph_definitions_extended(content) + .map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL graph definitions: {e}")) + })?; + for metric in sql_metrics { + if !model_metric_names.contains(&metric.name) { + top_level_metrics.push(metric); + } + } + top_level_parameters.extend(sql_parameters); + } else { + let (frontmatter, sql_body) = parse_sql_frontmatter_and_body(content)?; + let (sql_metrics, sql_segments, sql_parameters, sql_preaggs) = + parse_sql_graph_definitions_extended(&sql_body).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL graph definitions: {e}")) + })?; + top_level_parameters.extend(sql_parameters); + + if let Some(frontmatter) = frontmatter { + let mut model = model_from_sql_frontmatter(frontmatter)?; + model.metrics.extend(sql_metrics); + model.segments.extend(sql_segments); + model.pre_aggregations.extend(sql_preaggs); + models.push(model); + } else { + top_level_metrics.extend(sql_metrics); + } + } + + Ok(ParsedConfig { + models, + extends_map: HashMap::new(), + top_level_metrics, + top_level_parameters, + }) +} + +/// Load a semantic graph from SQL content with metadata. +pub fn load_from_sql_string_with_metadata(content: &str) -> Result { + let parsed = parse_sql_content(content)?; + let ParsedConfig { + models, + extends_map, + top_level_metrics, + top_level_parameters, + } = parsed; + let model_order: Vec = models.iter().map(|model| model.name.clone()).collect(); + let original_model_metrics: HashMap> = models + .iter() + .map(|model| { + ( + model.name.clone(), + model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect::>(), + ) + }) + .collect(); + + let models_map = collect_unique_models(models)?; + let mut resolved_models = resolve_model_inheritance(models_map, &extends_map)?; + if !resolved_models.is_empty() && !top_level_metrics.is_empty() { + assign_top_level_metrics(&mut resolved_models, top_level_metrics.clone())?; + } let mut graph = SemanticGraph::new(); for model in resolved_models.into_values() { graph.add_model(model)?; } + for parameter in top_level_parameters { + graph.add_parameter(parameter)?; + } - Ok(graph) + let model_sources = model_order + .iter() + .map(|model_name| { + ( + model_name.clone(), + LoadedModelSource { + source_format: "Sidemantic".to_string(), + source_file: None, + }, + ) + }) + .collect(); + + Ok(LoadedGraphMetadata { + graph, + model_order, + top_level_metrics, + original_model_metrics, + model_sources, + }) } -/// Load all YAML files from a directory into a semantic graph +/// Load all semantic model files from a directory into a semantic graph. /// /// This function: -/// 1. Recursively finds all .yml/.yaml files +/// 1. Recursively finds all `.yml`/`.yaml`/`.sql` files /// 2. Auto-detects format (Sidemantic vs Cube.js) /// 3. Parses and collects all models /// 4. Infers relationships from FK naming conventions /// 5. Returns a unified SemanticGraph pub fn load_from_directory(dir: impl AsRef) -> Result { + Ok(load_from_directory_with_metadata(dir)?.graph) +} + +/// Load all YAML files from a directory into a semantic graph with metadata. +pub fn load_from_directory_with_metadata(dir: impl AsRef) -> Result { let dir = dir.as_ref(); if !dir.is_dir() { @@ -66,36 +342,126 @@ pub fn load_from_directory(dir: impl AsRef) -> Result { } let mut all_models: HashMap = HashMap::new(); + let mut all_top_level_metrics: Vec = Vec::new(); + let mut all_top_level_parameters: Vec = Vec::new(); + let mut model_order: Vec = Vec::new(); + let mut model_sources: HashMap = HashMap::new(); - // Recursively find and parse all YAML files + // Recursively find and parse model files. for entry in walkdir(dir)? { let path = entry; - let ext = path.extension().and_then(|e| e.to_str()).unwrap_or(""); + let ext = path + .extension() + .and_then(|e| e.to_str()) + .map(str::to_ascii_lowercase); - if ext == "yml" || ext == "yaml" { - let content = fs::read_to_string(&path).map_err(|e| { - SidemanticError::Validation(format!("Failed to read {}: {}", path.display(), e)) - })?; + match ext.as_deref() { + Some("yml") | Some("yaml") => { + let content = fs::read_to_string(&path).map_err(|e| { + SidemanticError::Validation(format!("Failed to read {}: {}", path.display(), e)) + })?; - let format = detect_format(&content); - let models = parse_content(&content, format)?; + let format = detect_format(&content); + let parsed = parse_content(&content, format)?; + let source_format = match format { + ConfigFormat::Sidemantic => "Sidemantic", + ConfigFormat::Cube => "Cube", + }; + let source_file = path + .strip_prefix(dir) + .ok() + .map(|value| value.to_string_lossy().to_string()); - for model in models { - all_models.insert(model.name.clone(), model); + for model in parsed.models { + if all_models.contains_key(&model.name) { + return Err(SidemanticError::Validation(format!( + "Duplicate model '{}' found while loading directory", + model.name + ))); + } + model_order.push(model.name.clone()); + model_sources.insert( + model.name.clone(), + LoadedModelSource { + source_format: source_format.to_string(), + source_file: source_file.clone(), + }, + ); + all_models.insert(model.name.clone(), model); + } + all_top_level_metrics.extend(parsed.top_level_metrics); + all_top_level_parameters.extend(parsed.top_level_parameters); + } + Some("sql") => { + let content = fs::read_to_string(&path).map_err(|e| { + SidemanticError::Validation(format!("Failed to read {}: {}", path.display(), e)) + })?; + let parsed = parse_sql_content(&content)?; + let source_file = path + .strip_prefix(dir) + .ok() + .map(|value| value.to_string_lossy().to_string()); + + for model in parsed.models { + if all_models.contains_key(&model.name) { + return Err(SidemanticError::Validation(format!( + "Duplicate model '{}' found while loading directory", + model.name + ))); + } + model_order.push(model.name.clone()); + model_sources.insert( + model.name.clone(), + LoadedModelSource { + source_format: "Sidemantic".to_string(), + source_file: source_file.clone(), + }, + ); + all_models.insert(model.name.clone(), model); + } + all_top_level_metrics.extend(parsed.top_level_metrics); + all_top_level_parameters.extend(parsed.top_level_parameters); } + _ => {} } } + let original_model_metrics: HashMap> = all_models + .iter() + .map(|(model_name, model)| { + ( + model_name.clone(), + model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect::>(), + ) + }) + .collect(); + // Infer relationships from FK naming conventions infer_relationships(&mut all_models); + if !all_models.is_empty() && !all_top_level_metrics.is_empty() { + assign_top_level_metrics(&mut all_models, all_top_level_metrics.clone())?; + } // Build the graph let mut graph = SemanticGraph::new(); for (_, model) in all_models { graph.add_model(model)?; } + for parameter in all_top_level_parameters { + graph.add_parameter(parameter)?; + } - Ok(graph) + Ok(LoadedGraphMetadata { + graph, + model_order, + top_level_metrics: all_top_level_metrics, + original_model_metrics, + model_sources, + }) } /// Detect the config format from content @@ -110,39 +476,397 @@ fn detect_format(content: &str) -> ConfigFormat { } /// Parse content based on detected format -fn parse_content(content: &str, format: ConfigFormat) -> Result> { - let (models, _) = parse_content_with_extends(content, format)?; - Ok(models) +fn parse_content(content: &str, format: ConfigFormat) -> Result { + parse_content_with_extends(content, format) } -/// Parse content and return extends map for inheritance resolution -fn parse_content_with_extends( +fn yaml_mapping_get_str<'a>(mapping: &'a serde_yaml::Mapping, key: &str) -> Option<&'a str> { + mapping + .get(serde_yaml::Value::String(key.to_string())) + .and_then(serde_yaml::Value::as_str) +} + +fn apply_embedded_sql_definitions( content: &str, - format: ConfigFormat, -) -> Result<(Vec, HashMap)> { + models: &mut [Model], + top_level_metrics: &mut Vec, +) -> Result<()> { + let root: serde_yaml::Value = serde_yaml::from_str(content) + .map_err(|e| SidemanticError::Validation(format!("YAML parse error: {e}")))?; + let Some(root_mapping) = root.as_mapping() else { + return Ok(()); + }; + + if let Some(sql_metrics) = yaml_mapping_get_str(root_mapping, "sql_metrics") { + let (metrics, _) = parse_sql_definitions(sql_metrics).map_err(|e| { + SidemanticError::Validation(format!( + "Failed to parse top-level sql_metrics definitions: {e}" + )) + })?; + top_level_metrics.extend(metrics); + } + + if let Some(sql_segments) = yaml_mapping_get_str(root_mapping, "sql_segments") { + parse_sql_definitions(sql_segments).map_err(|e| { + SidemanticError::Validation(format!( + "Failed to parse top-level sql_segments definitions: {e}" + )) + })?; + } + + let model_lookup: HashMap = models + .iter() + .enumerate() + .map(|(index, model)| (model.name.clone(), index)) + .collect(); + + let Some(model_entries) = root_mapping + .get(serde_yaml::Value::String("models".to_string())) + .and_then(serde_yaml::Value::as_sequence) + else { + return Ok(()); + }; + + for model_entry in model_entries { + let Some(model_mapping) = model_entry.as_mapping() else { + continue; + }; + let Some(model_name) = yaml_mapping_get_str(model_mapping, "name") else { + continue; + }; + let Some(model_index) = model_lookup.get(model_name) else { + continue; + }; + + if let Some(sql_metrics) = yaml_mapping_get_str(model_mapping, "sql_metrics") { + let (metrics, _) = parse_sql_definitions(sql_metrics).map_err(|e| { + SidemanticError::Validation(format!( + "Failed to parse sql_metrics definitions for model '{model_name}': {e}" + )) + })?; + models[*model_index].metrics.extend(metrics); + } + + if let Some(sql_segments) = yaml_mapping_get_str(model_mapping, "sql_segments") { + let (_, segments) = parse_sql_definitions(sql_segments).map_err(|e| { + SidemanticError::Validation(format!( + "Failed to parse sql_segments definitions for model '{model_name}': {e}" + )) + })?; + models[*model_index].segments.extend(segments); + } + } + + Ok(()) +} + +fn substitute_env_vars(content: &str) -> String { + let brace_pattern = Regex::new(r"\$\{([^}]+)\}").expect("valid regex"); + let substituted = brace_pattern.replace_all(content, |caps: ®ex::Captures<'_>| { + let var_expr = &caps[1]; + if let Some((var_name, default_value)) = var_expr.split_once(":-") { + std::env::var(var_name).unwrap_or_else(|_| default_value.to_string()) + } else { + std::env::var(var_expr).unwrap_or_else(|_| caps[0].to_string()) + } + }); + + let simple_pattern = Regex::new(r"\$([A-Z_][A-Z0-9_]*)").expect("valid regex"); + simple_pattern + .replace_all(&substituted, |caps: ®ex::Captures<'_>| { + let var_name = &caps[1]; + std::env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) + }) + .into_owned() +} + +/// Parse content and return extends map for inheritance resolution +fn parse_content_with_extends(content: &str, format: ConfigFormat) -> Result { + let content = substitute_env_vars(content); + match format { ConfigFormat::Sidemantic => { - let config: SidemanticConfig = serde_yaml::from_str(content) + let config: SidemanticConfig = serde_yaml::from_str(&content) .map_err(|e| SidemanticError::Validation(format!("YAML parse error: {e}")))?; - - // Extract extends map before converting to models let extends_map: HashMap = config .models .iter() .filter_map(|m| m.extends.as_ref().map(|e| (m.name.clone(), e.clone()))) .collect(); + let (mut models, mut top_level_metrics, top_level_parameters) = config.into_parts(); + apply_embedded_sql_definitions(&content, &mut models, &mut top_level_metrics)?; - Ok((config.into_models(), extends_map)) + Ok(ParsedConfig { + models, + extends_map, + top_level_metrics, + top_level_parameters, + }) } ConfigFormat::Cube => { - let config: CubeConfig = serde_yaml::from_str(content) + let config: CubeConfig = serde_yaml::from_str(&content) .map_err(|e| SidemanticError::Validation(format!("YAML parse error: {e}")))?; // Cube.js doesn't support extends in the same way - Ok((config.into_models(), HashMap::new())) + Ok(ParsedConfig { + models: config.into_models(), + extends_map: HashMap::new(), + top_level_metrics: Vec::new(), + top_level_parameters: Vec::new(), + }) } } } +fn collect_unique_models(models: Vec) -> Result> { + let mut map = HashMap::new(); + for model in models { + if map.contains_key(&model.name) { + return Err(SidemanticError::Validation(format!( + "Duplicate model '{}' in config", + model.name + ))); + } + map.insert(model.name.clone(), model); + } + Ok(map) +} + +fn owners_from_dotted_reference( + reference: &str, + models: &HashMap, +) -> Option { + let (model_name, _) = reference.split_once('.')?; + models + .contains_key(model_name) + .then(|| model_name.to_string()) +} + +fn owners_from_sql_fragment( + fragment: &str, + models: &HashMap, +) -> Result> { + let model_ref_regex = Regex::new(r"\b([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_]*)\b") + .map_err(|e| SidemanticError::Validation(format!("Invalid ownership regex: {e}")))?; + Ok(model_ref_regex + .captures_iter(fragment) + .filter_map(|captures| captures.get(1).map(|model_name| model_name.as_str())) + .filter(|model_name| models.contains_key(*model_name)) + .map(ToString::to_string) + .collect()) +} + +fn model_owners_for_entity( + entity: Option<&str>, + models: &HashMap, +) -> HashSet { + let Some(entity) = entity else { + return HashSet::new(); + }; + if let Some(owner) = owners_from_dotted_reference(entity, models) { + return HashSet::from([owner]); + } + models + .iter() + .filter(|(_, model)| { + model + .dimensions + .iter() + .any(|dimension| dimension.name == entity) + }) + .map(|(model_name, _)| model_name.clone()) + .collect() +} + +fn metric_reference_strings(metric: &Metric) -> Vec<&str> { + let mut references = vec![ + metric.sql.as_deref(), + metric.base_metric.as_deref(), + metric.numerator.as_deref(), + metric.denominator.as_deref(), + metric.entity.as_deref(), + metric.base_event.as_deref(), + metric.conversion_event.as_deref(), + metric.cohort_event.as_deref(), + metric.activity_event.as_deref(), + metric.having.as_deref(), + ] + .into_iter() + .flatten() + .collect::>(); + + if let Some(steps) = metric.steps.as_ref() { + references.extend(steps.iter().map(String::as_str)); + } + if let Some(inner_metrics) = metric.inner_metrics.as_ref() { + references.extend( + inner_metrics + .iter() + .filter_map(|inner| inner.sql.as_deref()), + ); + } + if let Some(entity_dimensions) = metric.entity_dimensions.as_ref() { + references.extend(entity_dimensions.iter().map(String::as_str)); + } + + references +} + +fn resolve_metric_owners( + metric_name: &str, + top_level_metrics: &HashMap, + models: &HashMap, + cache: &mut HashMap>, + visiting: &mut HashSet, +) -> Result> { + if let Some(cached) = cache.get(metric_name) { + return Ok(cached.clone()); + } + if !visiting.insert(metric_name.to_string()) { + return Ok(HashSet::new()); + } + + let metric = top_level_metrics.get(metric_name).ok_or_else(|| { + SidemanticError::Validation(format!( + "Top-level metric '{}' not found while resolving ownership", + metric_name + )) + })?; + + let existing_model_owners: HashSet = models + .iter() + .filter_map(|(model_name, model)| { + if model.get_metric(metric_name).is_some() { + Some(model_name.clone()) + } else { + None + } + }) + .collect(); + if existing_model_owners.len() == 1 { + visiting.remove(metric_name); + cache.insert(metric_name.to_string(), existing_model_owners.clone()); + return Ok(existing_model_owners); + } + + let deps = extract_dependencies(metric, None); + let mut owners = HashSet::new(); + for dep in deps { + if let Some(owner) = owners_from_dotted_reference(&dep, models) { + owners.insert(owner); + continue; + } + + if top_level_metrics.contains_key(&dep) { + owners.extend(resolve_metric_owners( + &dep, + top_level_metrics, + models, + cache, + visiting, + )?); + continue; + } + + for (model_name, model) in models { + if model.get_metric(&dep).is_some() { + owners.insert(model_name.clone()); + } + } + } + + if owners.is_empty() { + for reference in metric_reference_strings(metric) { + if let Some(owner) = owners_from_dotted_reference(reference, models) { + owners.insert(owner); + } + owners.extend(owners_from_sql_fragment(reference, models)?); + } + } + + if owners.is_empty() { + owners.extend(model_owners_for_entity(metric.entity.as_deref(), models)); + } + + if owners.is_empty() && models.len() == 1 { + if let Some(single_model) = models.keys().next() { + owners.insert(single_model.clone()); + } + } + + visiting.remove(metric_name); + cache.insert(metric_name.to_string(), owners.clone()); + Ok(owners) +} + +fn assign_top_level_metrics( + models: &mut HashMap, + top_level_metrics: Vec, +) -> Result<()> { + if top_level_metrics.is_empty() { + return Ok(()); + } + + let mut metric_by_name = HashMap::new(); + for metric in &top_level_metrics { + if metric_by_name + .insert(metric.name.clone(), metric.clone()) + .is_some() + { + return Err(SidemanticError::Validation(format!( + "Duplicate top-level metric '{}'", + metric.name + ))); + } + } + + let mut owner_cache: HashMap> = HashMap::new(); + let mut ownership: HashMap = HashMap::new(); + for metric in &top_level_metrics { + let owners = resolve_metric_owners( + &metric.name, + &metric_by_name, + models, + &mut owner_cache, + &mut HashSet::new(), + )?; + + if owners.len() != 1 { + let mut owner_list: Vec = owners.into_iter().collect(); + owner_list.sort(); + return Err(SidemanticError::Validation(format!( + "Cannot determine single owning model for top-level metric '{}'; owners={:?}", + metric.name, owner_list + ))); + } + + let owner = owners + .into_iter() + .next() + .expect("owner set length checked to be exactly one"); + ownership.insert(metric.name.clone(), owner); + } + + for metric in top_level_metrics { + let owner = ownership.get(&metric.name).ok_or_else(|| { + SidemanticError::Validation(format!( + "Missing owner assignment for top-level metric '{}'", + metric.name + )) + })?; + let owner_model = models.get_mut(owner).ok_or_else(|| { + SidemanticError::Validation(format!( + "Owner model '{}' not found for top-level metric '{}'", + owner, metric.name + )) + })?; + if owner_model.get_metric(&metric.name).is_none() { + owner_model.metrics.push(metric); + } + } + + Ok(()) +} + /// Infer relationships between models based on FK naming conventions /// /// Looks for columns ending with `_id` and tries to match them to existing models. @@ -200,7 +924,12 @@ fn infer_relationships(models: &mut HashMap) { name: actual_target.clone(), r#type: RelationshipType::ManyToOne, foreign_key: Some(dim.name.clone()), + foreign_key_columns: Some(vec![dim.name.clone()]), primary_key: Some("id".to_string()), + primary_key_columns: Some(vec!["id".to_string()]), + through: None, + through_foreign_key: None, + related_foreign_key: None, sql: None, }, )); @@ -212,7 +941,12 @@ fn infer_relationships(models: &mut HashMap) { name: model_name.clone(), r#type: RelationshipType::OneToMany, foreign_key: Some(dim.name.clone()), + foreign_key_columns: Some(vec![dim.name.clone()]), primary_key: Some("id".to_string()), + primary_key_columns: Some(vec!["id".to_string()]), + through: None, + through_foreign_key: None, + related_foreign_key: None, sql: None, }, )); @@ -272,6 +1006,23 @@ mod tests { assert_eq!(detect_format(content), ConfigFormat::Cube); } + #[test] + fn test_substitute_env_vars_with_default() { + let content = "table: ${SIDEMANTIC_RS_MISSING_FOR_TEST:-orders_table}"; + let substituted = substitute_env_vars(content); + assert_eq!(substituted, "table: orders_table"); + } + + #[test] + fn test_substitute_env_vars_for_existing_var() { + let Some(home) = std::env::var("HOME").ok() else { + return; + }; + let content = "root: $HOME"; + let substituted = substitute_env_vars(content); + assert_eq!(substituted, format!("root: {home}")); + } + #[test] fn test_load_from_string_sidemantic() { let yaml = r#" @@ -315,6 +1066,226 @@ cubes: assert_eq!(model.dimensions[0].sql, Some("status".to_string())); } + #[test] + fn test_load_from_string_assigns_top_level_metrics() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count + +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count + - name: double_revenue_per_order + type: derived + sql: revenue_per_order * 2 +"#; + + let graph = load_from_string(yaml).unwrap(); + let orders = graph.get_model("orders").unwrap(); + assert!(orders.get_metric("revenue_per_order").is_some()); + assert!(orders.get_metric("double_revenue_per_order").is_some()); + } + + #[test] + fn test_load_from_string_rejects_ambiguous_top_level_metrics() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount + + - name: customers + table: customers + primary_key: customer_id + metrics: + - name: revenue + agg: sum + sql: lifetime_value + +metrics: + - name: blended_revenue + type: derived + sql: revenue +"#; + + let err = load_from_string(yaml).unwrap_err(); + assert!(err.to_string().contains( + "Cannot determine single owning model for top-level metric 'blended_revenue'" + )); + } + + #[test] + fn test_load_from_string_assigns_top_level_metric_by_existing_model_metric_name() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount + - name: revenue_yoy + type: time_comparison + base_metric: revenue + comparison_type: yoy + + - name: customers + table: customers + primary_key: customer_id + metrics: + - name: revenue + agg: sum + sql: lifetime_value + +metrics: + - name: revenue_yoy + type: time_comparison + base_metric: revenue + comparison_type: yoy +"#; + + let graph = load_from_string(yaml).unwrap(); + let orders = graph.get_model("orders").unwrap(); + assert_eq!( + orders + .metrics + .iter() + .filter(|metric| metric.name == "revenue_yoy") + .count(), + 1 + ); + + let customers = graph.get_model("customers").unwrap(); + assert!(customers.get_metric("revenue_yoy").is_none()); + } + + #[test] + fn test_load_from_string_assigns_complex_top_level_metrics_by_entity_dimension() { + let yaml = r#" +models: + - name: events + table: events + primary_key: event_id + dimensions: + - name: user_id + type: categorical + - name: platform + type: categorical + - name: event_type + type: categorical + + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_id + type: categorical + +metrics: + - name: signup_conversion + type: conversion + entity: user_id + base_event: event_type = 'signup' + conversion_event: event_type = 'purchase' + conversion_window: 7 days + - name: signup_retention + type: retention + entity: user_id + cohort_event: event_type = 'signup' + - name: multi_platform_users + type: cohort + entity: user_id + inner_metrics: + - name: platform_count + agg: count_distinct + sql: platform + having: platform_count >= 2 + agg: count +"#; + + let graph = load_from_string(yaml).unwrap(); + let events = graph.get_model("events").unwrap(); + assert!(events.get_metric("signup_conversion").is_some()); + assert!(events.get_metric("signup_retention").is_some()); + assert!(events.get_metric("multi_platform_users").is_some()); + + let orders = graph.get_model("orders").unwrap(); + assert!(orders.get_metric("signup_conversion").is_none()); + assert!(orders.get_metric("signup_retention").is_none()); + assert!(orders.get_metric("multi_platform_users").is_none()); + } + + #[test] + fn test_load_from_string_parses_model_embedded_sql_definitions() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: yaml_metric + agg: sum + sql: amount + sql_metrics: | + METRIC ( + name sql_metric, + agg count + ); + sql_segments: | + SEGMENT ( + name completed, + sql status = 'completed' + ); +"#; + + let graph = load_from_string(yaml).unwrap(); + let orders = graph.get_model("orders").unwrap(); + assert!(orders.get_metric("yaml_metric").is_some()); + assert!(orders.get_metric("sql_metric").is_some()); + assert!(orders.get_segment("completed").is_some()); + } + + #[test] + fn test_load_from_string_with_metadata_parses_graph_level_sql_metrics() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount +sql_metrics: | + METRIC ( + name total_revenue, + type derived, + sql orders.revenue + ); +"#; + + let loaded = load_from_string_with_metadata(yaml).unwrap(); + assert_eq!(loaded.top_level_metrics.len(), 1); + assert_eq!(loaded.top_level_metrics[0].name, "total_revenue"); + let orders = loaded.graph.get_model("orders").unwrap(); + assert!(orders.get_metric("total_revenue").is_some()); + } + #[test] fn test_infer_relationships() { let mut models = HashMap::new(); @@ -372,4 +1343,107 @@ models: assert!(us_orders.get_metric("revenue").is_some()); // inherited assert!(us_orders.get_metric("order_count").is_some()); // own } + + #[test] + fn test_load_from_string_parses_top_level_parameters() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id +parameters: + - name: status + type: string + default_value: pending + - name: start_date + type: date + default_to_today: true +"#; + + let graph = load_from_string(yaml).unwrap(); + assert!(graph.get_parameter("status").is_some()); + assert!(graph.get_parameter("start_date").is_some()); + } + + #[test] + fn test_load_from_string_rejects_duplicate_parameters() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id +parameters: + - name: status + type: string + - name: status + type: string +"#; + + let err = load_from_string(yaml).unwrap_err(); + assert!(err + .to_string() + .contains("Parameter 'status' already exists")); + } + + #[test] + fn test_load_from_string_rejects_duplicate_models() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + - name: orders + table: orders_v2 + primary_key: id +"#; + + let err = load_from_string(yaml).unwrap_err(); + assert!(err + .to_string() + .contains("Duplicate model 'orders' in config")); + } + + #[test] + fn test_load_from_string_rejects_invalid_parameter_type() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id +parameters: + - name: status + type: enum +"#; + + let err = load_from_string(yaml).unwrap_err(); + assert!(err.to_string().contains("YAML parse error")); + } + + #[test] + fn test_load_from_string_parses_many_to_many_through_relationship() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + relationships: + - name: products + type: many_to_many + through: order_items + through_foreign_key: order_id + related_foreign_key: product_id + primary_key: product_id + - name: order_items + table: order_items + - name: products + table: products + primary_key: product_id +"#; + + let graph = load_from_string(yaml).unwrap(); + let path = graph.find_join_path("orders", "products").unwrap(); + assert_eq!(path.steps.len(), 2); + assert_eq!(path.steps[0].to_model, "order_items"); + assert_eq!(path.steps[1].to_model, "products"); + } } diff --git a/sidemantic-rs/src/config/mod.rs b/sidemantic-rs/src/config/mod.rs index 9bf8635b..be596577 100644 --- a/sidemantic-rs/src/config/mod.rs +++ b/sidemantic-rs/src/config/mod.rs @@ -7,6 +7,13 @@ mod loader; mod schema; mod sql_parser; -pub use loader::{load_from_directory, load_from_file, load_from_string, ConfigFormat}; -pub use schema::{CubeConfig, SidemanticConfig}; -pub use sql_parser::{parse_sql_definitions, parse_sql_model}; +pub use loader::{ + load_from_directory, load_from_directory_with_metadata, load_from_file, + load_from_file_with_metadata, load_from_sql_string_with_metadata, load_from_string, + load_from_string_with_metadata, ConfigFormat, LoadedGraphMetadata, LoadedModelSource, +}; +pub use schema::{CubeConfig, ModelConfig, SidemanticConfig}; +pub use sql_parser::{ + parse_sql_definitions, parse_sql_graph_definitions, parse_sql_graph_definitions_extended, + parse_sql_model, parse_sql_statement_blocks, +}; diff --git a/sidemantic-rs/src/config/schema.rs b/sidemantic-rs/src/config/schema.rs index 3ff35945..009044c7 100644 --- a/sidemantic-rs/src/config/schema.rs +++ b/sidemantic-rs/src/config/schema.rs @@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize}; use crate::core::{ - Aggregation, Dimension, DimensionType, Metric, MetricType, Model, Relationship, - RelationshipType, Segment, + Aggregation, CohortInnerMetric, ComparisonCalculation, ComparisonType, Dimension, + DimensionType, Metric, MetricType, Model, Parameter, ParameterType, PreAggregation, + PreAggregationType, RefreshKey, Relationship, RelationshipType, Segment, TimeGrain, }; // ============================================================================= @@ -21,6 +22,8 @@ pub struct SidemanticConfig { /// Graph-level metrics (can reference model metrics) #[serde(default)] pub metrics: Vec, + #[serde(default)] + pub parameters: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -30,8 +33,13 @@ pub struct ModelConfig { pub extends: Option, pub table: Option, pub sql: Option, - #[serde(default = "default_primary_key")] - pub primary_key: String, + pub source_uri: Option, + #[serde(default = "default_primary_key_config")] + pub primary_key: KeyConfig, + #[serde(default)] + pub primary_key_columns: Option>, + #[serde(default)] + pub unique_keys: Option>>, pub description: Option, #[serde(default)] pub dimensions: Vec, @@ -41,12 +49,36 @@ pub struct ModelConfig { pub relationships: Vec, #[serde(default)] pub segments: Vec, + #[serde(default)] + pub pre_aggregations: Vec, + pub default_time_dimension: Option, + pub default_grain: Option, } fn default_primary_key() -> String { "id".to_string() } +fn default_primary_key_config() -> KeyConfig { + KeyConfig::Single(default_primary_key()) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum KeyConfig { + Single(String), + Multiple(Vec), +} + +impl KeyConfig { + fn into_columns(self) -> Vec { + match self { + Self::Single(value) => vec![value], + Self::Multiple(values) => values, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DimensionConfig { pub name: String, @@ -54,8 +86,13 @@ pub struct DimensionConfig { pub dim_type: Option, pub sql: Option, pub granularity: Option, + pub supported_granularities: Option>, pub description: Option, pub label: Option, + pub format: Option, + pub value_format_name: Option, + pub parent: Option, + pub window: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -67,19 +104,61 @@ pub struct MetricConfig { pub sql: Option, pub numerator: Option, pub denominator: Option, + pub offset_window: Option, + pub window: Option, + pub grain_to_date: Option, + pub window_expression: Option, + pub window_frame: Option, + pub window_order: Option, + pub base_metric: Option, + pub comparison_type: Option, + pub time_offset: Option, + pub calculation: Option, + pub entity: Option, + pub base_event: Option, + pub conversion_event: Option, + pub conversion_window: Option, + pub steps: Option>, + pub cohort_event: Option, + pub activity_event: Option, + pub periods: Option, + pub retention_granularity: Option, + pub granularity: Option, + pub inner_metrics: Option>, + pub entity_dimensions: Option>, + pub having: Option, + pub fill_nulls_with: Option, + pub format: Option, + pub value_format_name: Option, + pub drill_fields: Option>, + pub non_additive_dimension: Option, #[serde(default)] pub filters: Vec, pub description: Option, pub label: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CohortInnerMetricConfig { + pub name: String, + pub agg: Option, + pub sql: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RelationshipConfig { pub name: String, #[serde(default, rename = "type")] pub rel_type: Option, - pub foreign_key: Option, - pub primary_key: Option, + pub foreign_key: Option, + #[serde(default)] + pub foreign_key_columns: Option>, + pub primary_key: Option, + #[serde(default)] + pub primary_key_columns: Option>, + pub through: Option, + pub through_foreign_key: Option, + pub related_foreign_key: Option, /// Custom SQL join condition using {from} and {to} placeholders pub sql: Option, } @@ -91,6 +170,75 @@ pub struct SegmentConfig { pub description: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PreAggregationConfig { + pub name: String, + #[serde(default, rename = "type")] + pub preagg_type: Option, + #[serde(default)] + pub measures: Option>, + #[serde(default)] + pub dimensions: Option>, + #[serde(default)] + pub time_dimension: Option, + #[serde(default)] + pub granularity: Option, + #[serde(default)] + pub partition_granularity: Option, + #[serde(default)] + pub build_range_start: Option, + #[serde(default)] + pub build_range_end: Option, + #[serde(default = "default_scheduled_refresh")] + pub scheduled_refresh: bool, + #[serde(default)] + pub refresh_key: Option, + #[serde(default)] + pub indexes: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RefreshKeyConfig { + #[serde(default)] + pub every: Option, + #[serde(default)] + pub sql: Option, + #[serde(default)] + pub incremental: bool, + #[serde(default)] + pub update_window: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IndexConfig { + pub name: String, + #[serde(default)] + pub columns: Vec, + #[serde(default = "default_index_type", rename = "type")] + pub index_type: String, +} + +fn default_index_type() -> String { + "regular".to_string() +} + +fn default_scheduled_refresh() -> bool { + true +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParameterConfig { + pub name: String, + #[serde(rename = "type")] + pub parameter_type: ParameterType, + pub description: Option, + pub label: Option, + pub default_value: Option, + pub allowed_values: Option>, + #[serde(default)] + pub default_to_today: bool, +} + // ============================================================================= // Cube.js Format // ============================================================================= @@ -155,20 +303,45 @@ pub struct CubeSegment { // ============================================================================= impl SidemanticConfig { + /// Convert to core models, top-level metrics, and top-level parameters. + pub fn into_parts(self) -> (Vec, Vec, Vec) { + let models = self.models.into_iter().map(|m| m.into_model()).collect(); + let metrics = self.metrics.into_iter().map(|m| m.into_metric()).collect(); + let parameters = self + .parameters + .into_iter() + .map(|p| p.into_parameter()) + .collect(); + (models, metrics, parameters) + } + /// Convert to list of core Model types pub fn into_models(self) -> Vec { - self.models.into_iter().map(|m| m.into_model()).collect() + self.into_parts().0 } } impl ModelConfig { /// Convert to core Model type pub fn into_model(self) -> Model { + let primary_key_columns = self + .primary_key_columns + .filter(|columns| !columns.is_empty()) + .unwrap_or_else(|| self.primary_key.into_columns()); + let primary_key = primary_key_columns + .first() + .cloned() + .unwrap_or_else(default_primary_key); + Model { name: self.name, table: self.table, sql: self.sql, - primary_key: self.primary_key, + source_uri: self.source_uri, + extends: self.extends, + primary_key, + primary_key_columns, + unique_keys: self.unique_keys, dimensions: self .dimensions .into_iter() @@ -185,6 +358,13 @@ impl ModelConfig { .into_iter() .map(|s| s.into_segment()) .collect(), + pre_aggregations: self + .pre_aggregations + .into_iter() + .map(|p| p.into_pre_aggregation()) + .collect(), + default_time_dimension: self.default_time_dimension, + default_grain: self.default_grain, label: None, description: self.description, } @@ -205,8 +385,13 @@ impl DimensionConfig { r#type: dim_type, sql: self.sql, granularity: self.granularity, + supported_granularities: self.supported_granularities, label: self.label, description: self.description, + format: self.format, + value_format_name: self.value_format_name, + parent: self.parent, + window: self.window, } } } @@ -216,10 +401,40 @@ impl MetricConfig { let metric_type = match self.metric_type.as_deref() { Some("derived") => MetricType::Derived, Some("ratio") => MetricType::Ratio, - _ => MetricType::Simple, + Some("cumulative") => MetricType::Cumulative, + Some("time_comparison") => MetricType::TimeComparison, + Some("conversion") => MetricType::Conversion, + Some("retention") => MetricType::Retention, + Some("cohort") => MetricType::Cohort, + _ => { + if self.agg.is_none() && self.sql.is_some() { + MetricType::Derived + } else { + MetricType::Simple + } + } }; let agg = self.agg.as_deref().map(parse_aggregation); + let grain_to_date = self.grain_to_date.as_deref().and_then(parse_time_grain); + let comparison_type = self + .comparison_type + .as_deref() + .and_then(parse_comparison_type); + let calculation = self + .calculation + .as_deref() + .and_then(parse_comparison_calculation); + let inner_metrics = self.inner_metrics.map(|items| { + items + .into_iter() + .map(|item| CohortInnerMetric { + name: item.name, + agg: item.agg.as_deref().map(parse_aggregation), + sql: item.sql, + }) + .collect() + }); Metric { name: self.name, @@ -228,17 +443,36 @@ impl MetricConfig { sql: self.sql, numerator: self.numerator, denominator: self.denominator, + offset_window: self.offset_window, filters: self.filters, label: self.label, description: self.description, - window: None, - grain_to_date: None, - base_metric: None, - comparison_type: None, - time_offset: None, - calculation: None, - fill_nulls_with: None, - format: None, + window: self.window, + grain_to_date, + window_expression: self.window_expression, + window_frame: self.window_frame, + window_order: self.window_order, + base_metric: self.base_metric, + comparison_type, + time_offset: self.time_offset, + calculation, + entity: self.entity, + base_event: self.base_event, + conversion_event: self.conversion_event, + conversion_window: self.conversion_window, + steps: self.steps, + cohort_event: self.cohort_event, + activity_event: self.activity_event, + periods: self.periods, + retention_granularity: self.retention_granularity.or(self.granularity), + inner_metrics, + entity_dimensions: self.entity_dimensions, + having: self.having, + fill_nulls_with: self.fill_nulls_with, + format: self.format, + value_format_name: self.value_format_name, + drill_fields: self.drill_fields, + non_additive_dimension: self.non_additive_dimension, } } } @@ -252,11 +486,29 @@ impl RelationshipConfig { _ => RelationshipType::ManyToOne, }; + let foreign_key_columns = self + .foreign_key_columns + .filter(|columns| !columns.is_empty()) + .or_else(|| self.foreign_key.clone().map(KeyConfig::into_columns)); + let primary_key_columns = self + .primary_key_columns + .filter(|columns| !columns.is_empty()) + .or_else(|| self.primary_key.clone().map(KeyConfig::into_columns)); + Relationship { name: self.name, r#type: rel_type, - foreign_key: self.foreign_key, - primary_key: self.primary_key, + foreign_key: foreign_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + foreign_key_columns, + primary_key: primary_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + primary_key_columns, + through: self.through, + through_foreign_key: self.through_foreign_key, + related_foreign_key: self.related_foreign_key, sql: self.sql, } } @@ -273,6 +525,60 @@ impl SegmentConfig { } } +impl PreAggregationConfig { + fn into_pre_aggregation(self) -> PreAggregation { + let preagg_type = match self.preagg_type.as_deref() { + Some("original_sql") => PreAggregationType::OriginalSql, + Some("rollup_join") => PreAggregationType::RollupJoin, + Some("lambda") => PreAggregationType::Lambda, + _ => PreAggregationType::Rollup, + }; + + PreAggregation { + name: self.name, + preagg_type, + measures: self.measures, + dimensions: self.dimensions, + time_dimension: self.time_dimension, + granularity: self.granularity, + partition_granularity: self.partition_granularity, + build_range_start: self.build_range_start, + build_range_end: self.build_range_end, + scheduled_refresh: self.scheduled_refresh, + refresh_key: self.refresh_key.map(|r| RefreshKey { + every: r.every, + sql: r.sql, + incremental: r.incremental, + update_window: r.update_window, + }), + indexes: self.indexes.map(|indexes| { + indexes + .into_iter() + .map(|idx| crate::core::Index { + name: idx.name, + columns: idx.columns, + index_type: idx.index_type, + }) + .collect() + }), + } + } +} + +impl ParameterConfig { + fn into_parameter(self) -> Parameter { + Parameter { + name: self.name, + parameter_type: self.parameter_type, + description: self.description, + label: self.label, + default_value: self.default_value, + allowed_values: self.allowed_values, + default_to_today: self.default_to_today, + } + } +} + // Cube.js conversions impl CubeConfig { @@ -291,7 +597,11 @@ impl CubeDefinition { name: self.name, table: self.sql_table, sql: self.sql, - primary_key, + source_uri: None, + extends: None, + primary_key: primary_key.clone(), + primary_key_columns: vec![primary_key], + unique_keys: None, dimensions: self .dimensions .into_iter() @@ -304,6 +614,9 @@ impl CubeDefinition { .into_iter() .map(|s| s.into_segment()) .collect(), + pre_aggregations: Vec::new(), + default_time_dimension: None, + default_grain: None, label: None, description: self.description, } @@ -327,8 +640,13 @@ impl CubeDimension { r#type: dim_type, sql, granularity: None, + supported_granularities: None, label: self.title, description: self.description, + format: None, + value_format_name: None, + parent: None, + window: None, } } } @@ -366,17 +684,36 @@ impl CubeMeasure { sql, numerator: None, denominator: None, + offset_window: None, filters, label: self.title, description: self.description, window: None, grain_to_date: None, + window_expression: None, + window_frame: None, + window_order: None, base_metric: None, comparison_type: None, time_offset: None, calculation: None, + entity: None, + base_event: None, + conversion_event: None, + conversion_window: None, + steps: None, + cohort_event: None, + activity_event: None, + periods: None, + retention_granularity: None, + inner_metrics: None, + entity_dimensions: None, + having: None, fill_nulls_with: None, format: None, + value_format_name: None, + drill_fields: None, + non_additive_dimension: None, } } } @@ -412,6 +749,38 @@ fn parse_aggregation(s: &str) -> Aggregation { } } +fn parse_time_grain(s: &str) -> Option { + match s.to_lowercase().as_str() { + "day" => Some(TimeGrain::Day), + "week" => Some(TimeGrain::Week), + "month" => Some(TimeGrain::Month), + "quarter" => Some(TimeGrain::Quarter), + "year" => Some(TimeGrain::Year), + _ => None, + } +} + +fn parse_comparison_type(s: &str) -> Option { + match s.to_lowercase().as_str() { + "yoy" => Some(ComparisonType::Yoy), + "mom" => Some(ComparisonType::Mom), + "wow" => Some(ComparisonType::Wow), + "dod" => Some(ComparisonType::Dod), + "qoq" => Some(ComparisonType::Qoq), + "prior_period" => Some(ComparisonType::PriorPeriod), + _ => None, + } +} + +fn parse_comparison_calculation(s: &str) -> Option { + match s.to_lowercase().as_str() { + "difference" => Some(ComparisonCalculation::Difference), + "percent_change" => Some(ComparisonCalculation::PercentChange), + "ratio" => Some(ComparisonCalculation::Ratio), + _ => None, + } +} + /// Strip ${CUBE}. prefix from SQL expressions fn strip_cube_placeholder(sql: &str) -> String { sql.replace("${CUBE}.", "").replace("${CUBE}", "") @@ -441,17 +810,24 @@ models: segments: - name: completed sql: "{model}.status = 'completed'" +parameters: + - name: status + type: string + default_value: pending "#; let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); assert_eq!(config.models.len(), 1); + assert_eq!(config.parameters.len(), 1); - let models = config.into_models(); + let (models, _, parameters) = config.into_parts(); let orders = &models[0]; assert_eq!(orders.name, "orders"); assert_eq!(orders.dimensions.len(), 2); assert_eq!(orders.metrics.len(), 1); assert_eq!(orders.segments.len(), 1); + assert_eq!(parameters.len(), 1); + assert_eq!(parameters[0].name, "status"); } #[test] @@ -498,6 +874,84 @@ cubes: assert_eq!(orders.segments[0].sql, "{model}.status = 'completed'"); } + #[test] + fn test_parse_many_to_many_relationship_fields() { + let yaml = r#" +models: + - name: orders + table: orders + relationships: + - name: products + type: many_to_many + through: order_items + through_foreign_key: order_id + related_foreign_key: product_id + primary_key: product_id + - name: order_items + table: order_items + - name: products + table: products + primary_key: product_id +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (models, _, _) = config.into_parts(); + + let orders = models.iter().find(|m| m.name == "orders").unwrap(); + let rel = orders + .relationships + .iter() + .find(|r| r.name == "products") + .unwrap(); + assert_eq!(rel.r#type, RelationshipType::ManyToMany); + assert_eq!(rel.through.as_deref(), Some("order_items")); + assert_eq!(rel.through_foreign_key.as_deref(), Some("order_id")); + assert_eq!(rel.related_foreign_key.as_deref(), Some("product_id")); + assert_eq!(rel.primary_key.as_deref(), Some("product_id")); + } + + #[test] + fn test_parse_native_yaml_composite_keys() { + let yaml = r#" +models: + - name: order_items + table: order_items + primary_key: [order_id, item_id] + - name: shipments + table: shipments + primary_key: shipment_id + relationships: + - name: order_items + type: many_to_one + foreign_key: [order_id, item_id] + primary_key: [order_id, item_id] +"#; + + let config: SidemanticConfig = serde_yaml::from_str(yaml).unwrap(); + let (models, _, _) = config.into_parts(); + + let order_items = models.iter().find(|m| m.name == "order_items").unwrap(); + assert_eq!( + order_items.primary_key_columns, + vec!["order_id".to_string(), "item_id".to_string()] + ); + + let shipments = models.iter().find(|m| m.name == "shipments").unwrap(); + let rel = shipments + .relationships + .iter() + .find(|r| r.name == "order_items") + .unwrap(); + assert_eq!( + rel.foreign_key_columns.as_ref().unwrap(), + &vec!["order_id".to_string(), "item_id".to_string()] + ); + assert_eq!( + rel.primary_key_columns.as_ref().unwrap(), + &vec!["order_id".to_string(), "item_id".to_string()] + ); + } + #[test] fn test_strip_cube_placeholder() { assert_eq!(strip_cube_placeholder("${CUBE}.status"), "status"); diff --git a/sidemantic-rs/src/config/sql_parser.rs b/sidemantic-rs/src/config/sql_parser.rs index 3fb99ef9..d5feed3a 100644 --- a/sidemantic-rs/src/config/sql_parser.rs +++ b/sidemantic-rs/src/config/sql_parser.rs @@ -21,19 +21,30 @@ use nom::{ bytes::complete::{tag, tag_no_case, take_until, take_while, take_while1}, character::complete::{char, multispace0, multispace1}, combinator::{map, opt, recognize}, + error::{Error as NomError, ErrorKind}, multi::separated_list1, sequence::{delimited, pair, tuple}, IResult, }; -use polyglot_sql::{DialectType, Expression}; +use polyglot_sql::Expression; +#[cfg(not(target_arch = "wasm32"))] +use polyglot_sql::{parse as polyglot_parse, DialectType}; use crate::core::{ - Aggregation, Dimension, DimensionType, Metric, MetricType, Model, Relationship, - RelationshipType, Segment, + Aggregation, CohortInnerMetric, ComparisonCalculation, ComparisonType, Dimension, + DimensionType, Index, Metric, MetricType, Model, Parameter, ParameterType, PreAggregation, + PreAggregationType, RefreshKey, Relationship, RelationshipType, Segment, TimeGrain, }; use crate::error::{Result, SidemanticError}; +type SqlGraphDefinitionParts = ( + Vec, + Vec, + Vec, + Vec, +); + /// Property name aliases (SQL syntax -> Rust field name) fn resolve_alias(name: &str) -> &str { match name.to_lowercase().as_str() { @@ -52,6 +63,313 @@ enum Statement { Metric(HashMap), Segment(HashMap), Relationship(HashMap), + Parameter(HashMap), + PreAggregation(HashMap), +} + +/// Serializable SQL statement block payload for Python bridge consumers. +#[derive(Debug, Clone, serde::Serialize)] +pub struct SqlStatementBlock { + pub kind: String, + pub properties: HashMap, +} + +fn split_top_level(text: &str, delimiter: char) -> Vec { + let mut items = Vec::new(); + let mut depth: i32 = 0; + let mut in_quote: Option = None; + let mut escape = false; + let mut buf = String::new(); + + for c in text.chars() { + if let Some(q) = in_quote { + buf.push(c); + if escape { + escape = false; + continue; + } + if c == '\\' { + escape = true; + continue; + } + if c == q { + in_quote = None; + } + continue; + } + + if c == '\'' || c == '"' { + in_quote = Some(c); + buf.push(c); + continue; + } + + if c == '[' || c == '{' { + depth += 1; + } else if c == ']' || c == '}' { + depth = (depth - 1).max(0); + } + + if c == delimiter && depth == 0 { + let item = buf.trim().to_string(); + if !item.is_empty() { + items.push(item); + } + buf.clear(); + continue; + } + + buf.push(c); + } + + let trailing = buf.trim().to_string(); + if !trailing.is_empty() { + items.push(trailing); + } + + items +} + +fn split_key_value(text: &str) -> (String, String) { + let mut depth: i32 = 0; + let mut in_quote: Option = None; + + for (idx, c) in text.char_indices() { + if let Some(q) = in_quote { + if c == q { + in_quote = None; + } + continue; + } + + if c == '\'' || c == '"' { + in_quote = Some(c); + continue; + } + + if depth == 0 && (c == '[' || c == '{') && idx > 0 { + return ( + text[..idx].trim().to_string(), + text[idx..].trim().to_string(), + ); + } + + if c == '[' || c == '{' { + depth += 1; + continue; + } + if c == ']' || c == '}' { + depth = (depth - 1).max(0); + continue; + } + + if depth == 0 && (c == ':' || c == '=') { + return ( + text[..idx].trim().to_string(), + text[idx + 1..].trim().to_string(), + ); + } + + if depth == 0 && c.is_whitespace() { + return ( + text[..idx].trim().to_string(), + text[idx..].trim().to_string(), + ); + } + } + + (text.trim().to_string(), String::new()) +} + +fn parse_scalar_literal(value: &str) -> serde_json::Value { + if value.is_empty() { + return serde_json::Value::String(String::new()); + } + + if value.len() >= 2 { + let first = value.chars().next().unwrap_or_default(); + let last = value.chars().next_back().unwrap_or_default(); + if (first == '\'' || first == '"') && first == last { + let mut inner = value[1..value.len() - 1].to_string(); + if first == '\'' { + inner = inner.replace("''", "'"); + } + return serde_json::Value::String(inner); + } + } + + let lowered = value.to_lowercase(); + if lowered == "true" { + return serde_json::Value::Bool(true); + } + if lowered == "false" { + return serde_json::Value::Bool(false); + } + if lowered == "null" || lowered == "none" { + return serde_json::Value::Null; + } + + if let Ok(v) = value.parse::() { + return serde_json::json!(v); + } + if let Ok(v) = value.parse::() { + if let Some(num) = serde_json::Number::from_f64(v) { + return serde_json::Value::Number(num); + } + } + + serde_json::Value::String(value.to_string()) +} + +fn parse_literal(value: &str) -> serde_json::Value { + let raw = value.trim(); + if raw.is_empty() { + return serde_json::Value::String(String::new()); + } + + if raw.starts_with('[') && raw.ends_with(']') { + let inner = raw[1..raw.len() - 1].trim(); + let items = split_top_level(inner, ',') + .into_iter() + .map(|item| parse_literal(&item)) + .collect::>(); + return serde_json::Value::Array(items); + } + + if raw.starts_with('{') && raw.ends_with('}') { + let inner = raw[1..raw.len() - 1].trim(); + let mut object = serde_json::Map::new(); + for pair in split_top_level(inner, ',') { + if pair.is_empty() { + continue; + } + let (key, raw_value) = split_key_value(&pair); + if key.is_empty() { + continue; + } + let parsed_key = parse_scalar_literal(&key); + let key_str = match parsed_key { + serde_json::Value::String(s) => s, + _ => key, + }; + if raw_value.is_empty() { + object.insert(key_str, serde_json::Value::Bool(true)); + } else { + object.insert(key_str, parse_literal(&raw_value)); + } + } + return serde_json::Value::Object(object); + } + + parse_scalar_literal(raw) +} + +fn json_value_to_string(value: &serde_json::Value) -> Option { + match value { + serde_json::Value::Null => None, + serde_json::Value::String(v) => Some(v.clone()), + serde_json::Value::Bool(v) => Some(if *v { "true".into() } else { "false".into() }), + serde_json::Value::Number(v) => Some(v.to_string()), + _ => Some(value.to_string()), + } +} + +fn json_value_to_string_list(value: serde_json::Value) -> Vec { + match value { + serde_json::Value::Array(items) => items + .into_iter() + .filter_map(|item| json_value_to_string(&item)) + .collect(), + other => json_value_to_string(&other).into_iter().collect(), + } +} + +fn parse_key_columns(props: &HashMap, key: &str) -> Option> { + props.get(key).and_then(|value| { + let parsed = parse_literal(value); + let columns = json_value_to_string_list(parsed); + if columns.is_empty() { + None + } else { + Some(columns) + } + }) +} + +fn parse_metric_type(value: Option<&String>) -> MetricType { + match value.map(|s| s.to_lowercase()) { + Some(metric_type) if metric_type == "derived" => MetricType::Derived, + Some(metric_type) if metric_type == "ratio" => MetricType::Ratio, + Some(metric_type) if metric_type == "cumulative" => MetricType::Cumulative, + Some(metric_type) + if metric_type == "time_comparison" || metric_type == "timecomparison" => + { + MetricType::TimeComparison + } + Some(metric_type) if metric_type == "conversion" => MetricType::Conversion, + Some(metric_type) if metric_type == "retention" => MetricType::Retention, + Some(metric_type) if metric_type == "cohort" => MetricType::Cohort, + _ => MetricType::Simple, + } +} + +fn parse_metric_aggregation(value: Option<&String>) -> Option { + value.and_then(|agg_str| match agg_str.to_lowercase().as_str() { + "sum" => Some(Aggregation::Sum), + "count" => Some(Aggregation::Count), + "count_distinct" | "countdistinct" => Some(Aggregation::CountDistinct), + "avg" | "average" => Some(Aggregation::Avg), + "min" => Some(Aggregation::Min), + "max" => Some(Aggregation::Max), + "median" => Some(Aggregation::Median), + "expression" => Some(Aggregation::Expression), + _ => None, + }) +} + +fn parse_time_grain(value: Option<&String>) -> Option { + value.and_then(|grain| match grain.to_lowercase().as_str() { + "day" => Some(TimeGrain::Day), + "week" => Some(TimeGrain::Week), + "month" => Some(TimeGrain::Month), + "quarter" => Some(TimeGrain::Quarter), + "year" => Some(TimeGrain::Year), + _ => None, + }) +} + +fn parse_comparison_type(value: Option<&String>) -> Option { + value.and_then(|comparison| match comparison.to_lowercase().as_str() { + "yoy" => Some(ComparisonType::Yoy), + "mom" => Some(ComparisonType::Mom), + "wow" => Some(ComparisonType::Wow), + "dod" => Some(ComparisonType::Dod), + "qoq" => Some(ComparisonType::Qoq), + "prior_period" => Some(ComparisonType::PriorPeriod), + _ => None, + }) +} + +fn parse_comparison_calculation(value: Option<&String>) -> Option { + value.and_then(|calc| match calc.to_lowercase().as_str() { + "difference" => Some(ComparisonCalculation::Difference), + "percent_change" => Some(ComparisonCalculation::PercentChange), + "ratio" => Some(ComparisonCalculation::Ratio), + _ => None, + }) +} + +fn parse_parameter_type(value: Option<&String>) -> Option { + value.and_then( + |parameter_type| match parameter_type.to_lowercase().as_str() { + "string" => Some(ParameterType::String), + "number" => Some(ParameterType::Number), + "date" => Some(ParameterType::Date), + "unquoted" => Some(ParameterType::Unquoted), + "yesno" => Some(ParameterType::Yesno), + _ => None, + }, + ) } // ============================================================================ @@ -84,81 +402,56 @@ fn quoted_string(input: &str) -> IResult<&str, String> { ))(input) } -/// Parse expression with balanced parentheses: SUM(amount), CASE WHEN ... END -fn expression_with_parens(input: &str) -> IResult<&str, String> { - let (input, name) = identifier(input)?; - let (input, _) = multispace0(input)?; - let (input, _) = char('(')(input)?; - let (input, content) = parse_balanced_parens(input)?; - let (input, _) = char(')')(input)?; - Ok((input, format!("{name}({content})"))) -} - -/// Parse content with balanced parentheses -fn parse_balanced_parens(input: &str) -> IResult<&str, String> { - let mut result = String::new(); - let mut depth = 0; - let mut chars = input.char_indices().peekable(); - - while let Some((i, c)) = chars.next() { - match c { - '(' => { - depth += 1; - result.push(c); +/// Parse a raw value until the next top-level comma or ')' delimiter. +fn simple_value(input: &str) -> IResult<&str, String> { + let mut depth_paren: i32 = 0; + let mut depth_bracket: i32 = 0; + let mut depth_brace: i32 = 0; + let mut in_quote: Option = None; + let mut escape = false; + + for (idx, c) in input.char_indices() { + if let Some(q) = in_quote { + if escape { + escape = false; + continue; } - ')' => { - if depth == 0 { - return Ok((&input[i..], result)); - } - depth -= 1; - result.push(c); - } - '\'' => { - // Handle quoted string inside expression - result.push(c); - for (_, inner_c) in chars.by_ref() { - result.push(inner_c); - if inner_c == '\'' { - break; - } - } + if c == '\\' { + escape = true; + continue; + } + if c == q { + in_quote = None; } - _ => result.push(c), + continue; } - } - Ok(("", result)) -} - -/// Parse a simple value (identifier without following paren) -fn simple_value(input: &str) -> IResult<&str, String> { - // Take chars until we hit comma, closing paren, or whitespace followed by comma/paren - let mut result = String::new(); - let mut chars = input.char_indices().peekable(); + if c == '\'' || c == '"' { + in_quote = Some(c); + continue; + } - while let Some((i, c)) = chars.peek() { match c { - ',' | ')' => { - return Ok((&input[*i..], result.trim().to_string())); - } - ' ' | '\t' | '\n' | '\r' => { - // Look ahead to see if next non-ws is comma or paren - let rest = &input[*i..]; - let trimmed = rest.trim_start(); - if trimmed.starts_with(',') || trimmed.starts_with(')') { - return Ok((trimmed, result.trim().to_string())); + '(' => depth_paren += 1, + ')' => { + if depth_paren > 0 { + depth_paren -= 1; + } else if depth_bracket == 0 && depth_brace == 0 { + return Ok((&input[idx..], input[..idx].trim().to_string())); } - result.push(*c); - chars.next(); } - _ => { - result.push(*c); - chars.next(); + '[' => depth_bracket += 1, + ']' => depth_bracket = (depth_bracket - 1).max(0), + '{' => depth_brace += 1, + '}' => depth_brace = (depth_brace - 1).max(0), + ',' if depth_paren == 0 && depth_bracket == 0 && depth_brace == 0 => { + return Ok((&input[idx..], input[..idx].trim().to_string())); } + _ => {} } } - Ok(("", result.trim().to_string())) + Ok(("", input.trim().to_string())) } /// Parse a property value (quoted string, expression with parens, or simple value) @@ -170,12 +463,7 @@ fn property_value(input: &str) -> IResult<&str, String> { return Ok((rest, s)); } - // Try expression with parentheses - if let Ok((rest, expr)) = expression_with_parens(input) { - return Ok((rest, expr)); - } - - // Fall back to simple value + // Fall back to raw value parsing (supports nested (), [], and {}) simple_value(input) } @@ -237,7 +525,7 @@ fn simple_metric(input: &str) -> IResult<&str, Statement> { let (input, expr) = take_while(|c| c != ';')(input)?; let (input, _) = opt(char(';'))(input)?; - // Parse the expression using polyglot-sql to extract aggregation + // Parse the expression using sqlparser to extract aggregation let props = parse_metric_expression(name.trim(), expr.trim()); Ok((input, Statement::Metric(props))) } @@ -273,26 +561,12 @@ fn parse_metric_expression(name: &str, expr: &str) -> HashMap { let mut props = HashMap::new(); props.insert("name".to_string(), name.to_string()); - // Try to parse as SQL and extract aggregation - let sql = format!("SELECT {expr}"); - - if let Ok(expressions) = polyglot_sql::parse(&sql, DialectType::Generic) { - if let Some(Expression::Select(select)) = expressions.into_iter().next() { - if let Some(parsed_expr) = select.expressions.into_iter().next() { - // Unwrap alias if present - let parsed_expr = match parsed_expr { - Expression::Alias(a) => a.this, - other => other, - }; - if let Some((agg, inner_expr)) = extract_aggregation(&parsed_expr) { - props.insert("agg".to_string(), agg); - if !inner_expr.is_empty() { - props.insert("sql".to_string(), inner_expr); - } - return props; - } - } + if let Some((agg, inner_expr)) = extract_aggregation_with_polyglot(expr) { + props.insert("agg".to_string(), agg); + if !inner_expr.is_empty() { + props.insert("sql".to_string(), inner_expr); } + return props; } // Fall back to storing the whole expression as sql with "expression" type @@ -302,70 +576,116 @@ fn parse_metric_expression(name: &str, expr: &str) -> HashMap { props } -/// Generate a SQL string from a polyglot-sql expression -fn generate_expr_str(expr: &Expression) -> String { - polyglot_sql::generate(expr, DialectType::Generic).unwrap_or_default() +fn extract_aggregation_with_polyglot(expr: &str) -> Option<(String, String)> { + let sql = format!("SELECT {expr}"); + let statements = parse_polyglot_with_large_stack(sql)?; + let statement = statements.first()?; + let select = statement.as_select()?; + let parsed_expr = select.expressions.first()?; + extract_aggregation_polyglot_expr(parsed_expr) +} + +fn parse_polyglot_with_large_stack(sql: String) -> Option> { + #[cfg(target_arch = "wasm32")] + { + let _ = sql; + return None; + } + + #[cfg(not(target_arch = "wasm32"))] + { + std::thread::Builder::new() + .stack_size(16 * 1024 * 1024) + .spawn(move || polyglot_parse(&sql, DialectType::Generic).ok()) // stack-heavy parser path + .ok()? + .join() + .ok()? + } } -/// Extract aggregation function and inner expression from polyglot-sql AST -fn extract_aggregation(expr: &Expression) -> Option<(String, String)> { +fn extract_aggregation_polyglot_expr(expr: &Expression) -> Option<(String, String)> { match expr { - // Typed aggregate variants - Expression::Sum(f) => Some(("sum".to_string(), generate_expr_str(&f.this))), - Expression::Avg(f) => Some(("avg".to_string(), generate_expr_str(&f.this))), - Expression::Min(f) => Some(("min".to_string(), generate_expr_str(&f.this))), - Expression::Max(f) => Some(("max".to_string(), generate_expr_str(&f.this))), - Expression::Count(f) => { - if f.star { - Some(("count".to_string(), String::new())) - } else if let Some(ref inner) = f.this { - Some(("count".to_string(), generate_expr_str(inner))) - } else { - Some(("count".to_string(), String::new())) - } - } - // Generic aggregate function - Expression::AggregateFunction(f) => { - let func_name = f.name.to_lowercase(); - let agg = match func_name.as_str() { - "sum" => "sum", - "count" => "count", - "avg" | "average" => "avg", - "min" => "min", - "max" => "max", - "count_distinct" => "count_distinct", - _ => return None, - }; - let inner = if f.args.is_empty() { - String::new() - } else { - generate_expr_str(&f.args[0]) - }; - Some((agg.to_string(), inner)) - } - // Generic function (parser might use this for aggregate functions) - Expression::Function(f) => { - let func_name = f.name.to_lowercase(); - let agg = match func_name.as_str() { - "sum" => "sum", - "count" => "count", - "avg" | "average" => "avg", - "min" => "min", - "max" => "max", - "count_distinct" => "count_distinct", - _ => return None, - }; - let inner = if f.args.is_empty() || matches!(f.args[0], Expression::Star(_)) { - String::new() + Expression::Alias(alias) => extract_aggregation_polyglot_expr(&alias.this), + Expression::Sum(agg) => Some(( + "sum".to_string(), + extract_inner_expression_polyglot(&agg.this), + )), + Expression::Count(count) => { + if count.star { + return Some(("count".to_string(), String::new())); + } + + let agg_name = if count.distinct { + "count_distinct" } else { - generate_expr_str(&f.args[0]) + "count" }; - Some((agg.to_string(), inner)) + let inner = count + .this + .as_ref() + .map(extract_inner_expression_polyglot) + .unwrap_or_default(); + Some((agg_name.to_string(), inner)) + } + Expression::Avg(agg) => Some(( + "avg".to_string(), + extract_inner_expression_polyglot(&agg.this), + )), + Expression::Min(agg) => Some(( + "min".to_string(), + extract_inner_expression_polyglot(&agg.this), + )), + Expression::Max(agg) => Some(( + "max".to_string(), + extract_inner_expression_polyglot(&agg.this), + )), + Expression::AggregateFunction(func) => { + extract_aggregation_from_function_name_polyglot(&func.name, func.args.first()) + } + Expression::Function(func) => { + extract_aggregation_from_function_name_polyglot(&func.name, func.args.first()) } _ => None, } } +fn extract_aggregation_from_function_name_polyglot( + function_name: &str, + first_arg: Option<&Expression>, +) -> Option<(String, String)> { + let agg = match function_name.to_lowercase().as_str() { + "sum" => "sum", + "count" => "count", + "avg" | "average" => "avg", + "min" => "min", + "max" => "max", + "count_distinct" => "count_distinct", + _ => return None, + }; + + let inner = match first_arg { + Some(Expression::Star(_)) | None => String::new(), + Some(arg) => extract_inner_expression_polyglot(arg), + }; + + Some((agg.to_string(), inner)) +} + +fn extract_inner_expression_polyglot(expr: &Expression) -> String { + match expr { + Expression::Column(column) => { + if let Some(table) = &column.table { + format!("{}.{}", table.name, column.name.name) + } else { + column.name.name.clone() + } + } + Expression::Identifier(ident) => ident.name.clone(), + Expression::Star(_) => String::new(), + _ => String::new(), + } +} + /// Infer dimension type from expression fn infer_dimension_type(expr: &str) -> String { let expr_lower = expr.to_lowercase(); @@ -439,6 +759,8 @@ fn statement(input: &str) -> IResult<&str, Statement> { map(definition("METRIC"), Statement::Metric), map(definition("SEGMENT"), Statement::Segment), map(definition("RELATIONSHIP"), Statement::Relationship), + map(definition("PARAMETER"), Statement::Parameter), + map(definition("PRE_AGGREGATION"), Statement::PreAggregation), ))(input) } @@ -479,6 +801,9 @@ fn parse_file(input: &str) -> IResult<&str, Vec> { Err(_) => { // Skip unknown content until next statement or end if let Some(pos) = remaining.find(|c: char| c.is_alphabetic()) { + if pos == 0 { + return Err(nom::Err::Error(NomError::new(remaining, ErrorKind::Tag))); + } remaining = &remaining[pos..]; } else { break; @@ -504,6 +829,7 @@ pub fn parse_sql_model(sql: &str) -> Result { let mut metrics = Vec::new(); let mut segments = Vec::new(); let mut relationships = Vec::new(); + let mut pre_aggregations = Vec::new(); for stmt in statements { match stmt { @@ -530,6 +856,12 @@ pub fn parse_sql_model(sql: &str) -> Result { relationships.push(rel); } } + Statement::Parameter(_) => {} + Statement::PreAggregation(props) => { + if let Some(preagg) = build_pre_aggregation(&props) { + pre_aggregations.push(preagg); + } + } } } @@ -541,17 +873,58 @@ pub fn parse_sql_model(sql: &str) -> Result { model.metrics.extend(metrics); model.segments.extend(segments); model.relationships.extend(relationships); + model.pre_aggregations.extend(pre_aggregations); Ok(model) } +/// Parse SQL into statement blocks preserving high-level statement kinds/properties. +pub fn parse_sql_statement_blocks(sql: &str) -> Result> { + let (_, statements) = + parse_file(sql).map_err(|e| SidemanticError::Validation(format!("Parse error: {e}")))?; + + let mut blocks = Vec::with_capacity(statements.len()); + + for stmt in statements { + let (kind, properties) = match stmt { + Statement::Model(props) => ("model".to_string(), props), + Statement::Dimension(props) => ("dimension".to_string(), props), + Statement::Metric(props) => ("metric".to_string(), props), + Statement::Segment(props) => ("segment".to_string(), props), + Statement::Relationship(props) => ("relationship".to_string(), props), + Statement::Parameter(props) => ("parameter".to_string(), props), + Statement::PreAggregation(props) => ("pre_aggregation".to_string(), props), + }; + + blocks.push(SqlStatementBlock { kind, properties }); + } + + Ok(blocks) +} + /// Parse SQL definitions for metrics and segments only pub fn parse_sql_definitions(sql: &str) -> Result<(Vec, Vec)> { + let (metrics, segments, _) = parse_sql_graph_definitions(sql)?; + Ok((metrics, segments)) +} + +/// Parse SQL definitions for graph-level definitions (metrics, segments, parameters) +pub fn parse_sql_graph_definitions( + sql: &str, +) -> Result<(Vec, Vec, Vec)> { + let (metrics, segments, parameters, _) = parse_sql_graph_definitions_extended(sql)?; + Ok((metrics, segments, parameters)) +} + +/// Parse SQL definitions for graph-level definitions including pre-aggregations. +pub fn parse_sql_graph_definitions_extended(sql: &str) -> Result { let (_, statements) = parse_file(sql).map_err(|e| SidemanticError::Validation(format!("Parse error: {e}")))?; let mut metrics = Vec::new(); let mut segments = Vec::new(); + let mut parameters = Vec::new(); + let mut pre_aggregations = Vec::new(); for stmt in statements { match stmt { @@ -565,11 +938,21 @@ pub fn parse_sql_definitions(sql: &str) -> Result<(Vec, Vec)> { segments.push(seg); } } + Statement::Parameter(props) => { + if let Some(parameter) = build_parameter(&props) { + parameters.push(parameter); + } + } + Statement::PreAggregation(props) => { + if let Some(pre_aggregation) = build_pre_aggregation(&props) { + pre_aggregations.push(pre_aggregation); + } + } _ => {} } } - Ok((metrics, segments)) + Ok((metrics, segments, parameters, pre_aggregations)) } // ============================================================================ @@ -581,10 +964,16 @@ fn build_model(props: &HashMap) -> Result { .get("name") .ok_or_else(|| SidemanticError::Validation("MODEL requires 'name' property".into()))?; + let primary_key_columns = + parse_key_columns(props, "primary_key").unwrap_or_else(|| vec!["id".to_string()]); let mut model = Model::new( name, - props.get("primary_key").map(|s| s.as_str()).unwrap_or("id"), - ); + primary_key_columns + .first() + .map(|s| s.as_str()) + .unwrap_or("id"), + ) + .with_primary_key_columns(primary_key_columns); if let Some(table) = props.get("table") { model.table = Some(table.clone()); @@ -592,12 +981,44 @@ fn build_model(props: &HashMap) -> Result { if let Some(sql) = props.get("sql") { model.sql = Some(sql.clone()); } + if let Some(source_uri) = props.get("source_uri") { + model.source_uri = Some(source_uri.clone()); + } + if let Some(extends) = props.get("extends") { + model.extends = Some(extends.clone()); + } if let Some(desc) = props.get("description") { model.description = Some(desc.clone()); } if let Some(label) = props.get("label") { model.label = Some(label.clone()); } + if let Some(default_time_dimension) = props.get("default_time_dimension") { + model.default_time_dimension = Some(default_time_dimension.clone()); + } + if let Some(default_grain) = props.get("default_grain") { + model.default_grain = Some(default_grain.clone()); + } + if let Some(unique_keys) = props.get("unique_keys") { + let parsed = parse_literal(unique_keys); + if let serde_json::Value::Array(groups) = parsed { + let normalized = groups + .into_iter() + .filter_map(|group| match group { + serde_json::Value::Array(columns) => Some( + columns + .into_iter() + .filter_map(|column| json_value_to_string(&column)) + .collect::>(), + ), + _ => None, + }) + .collect::>(); + if !normalized.is_empty() { + model.unique_keys = Some(normalized); + } + } + } Ok(model) } @@ -622,12 +1043,31 @@ fn build_dimension(props: &HashMap) -> Option { if let Some(sql) = props.get("sql") { dim.sql = Some(sql.clone()); } + if let Some(granularity) = props.get("granularity") { + dim.granularity = Some(granularity.clone()); + } + if let Some(supported_granularities) = props.get("supported_granularities") { + let parsed = parse_literal(supported_granularities); + let values = json_value_to_string_list(parsed); + if !values.is_empty() { + dim.supported_granularities = Some(values); + } + } if let Some(desc) = props.get("description") { dim.description = Some(desc.clone()); } if let Some(label) = props.get("label") { dim.label = Some(label.clone()); } + if let Some(format) = props.get("format") { + dim.format = Some(format.clone()); + } + if let Some(value_format_name) = props.get("value_format_name") { + dim.value_format_name = Some(value_format_name.clone()); + } + if let Some(parent) = props.get("parent") { + dim.parent = Some(parent.clone()); + } Some(dim) } @@ -636,37 +1076,127 @@ fn build_metric(props: &HashMap) -> Option { let name = props.get("name")?; let mut metric = Metric::new(name); + metric.agg = None; + metric.r#type = parse_metric_type(props.get("type")); metric.sql = props.get("sql").cloned(); metric.numerator = props.get("numerator").cloned(); metric.denominator = props.get("denominator").cloned(); + metric.offset_window = props.get("offset_window").cloned(); + metric.window = props.get("window").cloned(); + metric.grain_to_date = parse_time_grain(props.get("grain_to_date")); + metric.window_expression = props.get("window_expression").cloned(); + metric.window_frame = props.get("window_frame").cloned(); + metric.window_order = props.get("window_order").cloned(); + metric.base_metric = props.get("base_metric").cloned(); + metric.comparison_type = parse_comparison_type(props.get("comparison_type")); + metric.time_offset = props.get("time_offset").cloned(); + metric.calculation = parse_comparison_calculation(props.get("calculation")); + metric.entity = props.get("entity").cloned(); + metric.base_event = props.get("base_event").cloned(); + metric.conversion_event = props.get("conversion_event").cloned(); + metric.conversion_window = props.get("conversion_window").cloned(); + if let Some(steps) = props.get("steps") { + let parsed = json_value_to_string_list(parse_literal(steps)); + if !parsed.is_empty() { + metric.steps = Some(parsed); + } + } + metric.cohort_event = props.get("cohort_event").cloned(); + metric.activity_event = props.get("activity_event").cloned(); + metric.periods = props.get("periods").and_then(|value| value.parse().ok()); + metric.retention_granularity = props + .get("retention_granularity") + .or_else(|| props.get("granularity")) + .cloned(); + if let Some(entity_dimensions) = props.get("entity_dimensions") { + let parsed = json_value_to_string_list(parse_literal(entity_dimensions)); + if !parsed.is_empty() { + metric.entity_dimensions = Some(parsed); + } + } + if let Some(inner_metrics) = props.get("inner_metrics") { + let parsed = parse_literal(inner_metrics); + if let serde_json::Value::Array(items) = parsed { + let inner = items + .into_iter() + .filter_map(|item| { + let serde_json::Value::Object(obj) = item else { + return None; + }; + let name = obj.get("name").and_then(json_value_to_string)?; + let agg = obj + .get("agg") + .and_then(json_value_to_string) + .and_then(|value| parse_metric_aggregation(Some(&value))); + let sql = obj.get("sql").and_then(json_value_to_string); + Some(CohortInnerMetric { name, agg, sql }) + }) + .collect::>(); + if !inner.is_empty() { + metric.inner_metrics = Some(inner); + } + } + } + metric.having = props.get("having").cloned(); metric.description = props.get("description").cloned(); metric.label = props.get("label").cloned(); metric.format = props.get("format").cloned(); + metric.value_format_name = props.get("value_format_name").cloned(); + metric.non_additive_dimension = props.get("non_additive_dimension").cloned(); - if let Some(agg_str) = props.get("agg") { - metric.agg = match agg_str.to_lowercase().as_str() { - "sum" => Some(Aggregation::Sum), - "count" => Some(Aggregation::Count), - "count_distinct" | "countdistinct" => Some(Aggregation::CountDistinct), - "avg" | "average" => Some(Aggregation::Avg), - "min" => Some(Aggregation::Min), - "max" => Some(Aggregation::Max), - "expression" => Some(Aggregation::Expression), - _ => None, - }; + if let Some(fill_nulls_with) = props.get("fill_nulls_with") { + let parsed = parse_literal(fill_nulls_with); + if !parsed.is_null() { + metric.fill_nulls_with = Some(parsed); + } + } + + if let Some(filters) = props.get("filters") { + metric.filters = json_value_to_string_list(parse_literal(filters)); + } + if let Some(drill_fields) = props.get("drill_fields") { + metric.drill_fields = Some(json_value_to_string_list(parse_literal(drill_fields))); } - if metric.numerator.is_some() && metric.denominator.is_some() { + metric.agg = parse_metric_aggregation(props.get("agg")); + + if metric.agg.is_none() + && matches!( + metric.r#type, + MetricType::Simple | MetricType::Cumulative | MetricType::Derived + ) + { + if let Some(sql) = metric.sql.as_deref() { + if let Some((agg, inner_expr)) = extract_aggregation_with_polyglot(sql) { + metric.agg = parse_metric_aggregation(Some(&agg)); + metric.sql = if inner_expr.is_empty() { + Some("*".to_string()) + } else { + Some(inner_expr) + }; + } + } + } + + if matches!(metric.r#type, MetricType::Simple) + && metric.numerator.is_some() + && metric.denominator.is_some() + { metric.r#type = MetricType::Ratio; - } else if metric - .sql - .as_ref() - .map(|s| s.contains('{')) - .unwrap_or(false) + } else if matches!(metric.r#type, MetricType::Simple) + && metric + .sql + .as_ref() + .map(|s| s.contains('{')) + .unwrap_or(false) { metric.r#type = MetricType::Derived; } + if !matches!(metric.r#type, MetricType::Simple | MetricType::Cohort) { + metric.agg = None; + } + Some(metric) } @@ -699,15 +1229,153 @@ fn build_relationship(props: &HashMap) -> Option { _ => RelationshipType::ManyToOne, }; + let foreign_key_columns = parse_key_columns(props, "foreign_key"); + let primary_key_columns = parse_key_columns(props, "primary_key"); + Some(Relationship { name: name.clone(), r#type: rtype, - foreign_key: props.get("foreign_key").cloned(), - primary_key: props.get("primary_key").cloned(), + foreign_key: foreign_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + foreign_key_columns, + primary_key: primary_key_columns + .as_ref() + .and_then(|columns| columns.first().cloned()), + primary_key_columns, + through: props.get("through").cloned(), + through_foreign_key: props.get("through_foreign_key").cloned(), + related_foreign_key: props.get("related_foreign_key").cloned(), sql: props.get("sql").cloned(), }) } +fn build_pre_aggregation(props: &HashMap) -> Option { + let name = props.get("name")?; + + let preagg_type = match props.get("type").map(|t| t.to_lowercase()) { + Some(kind) if kind == "original_sql" => PreAggregationType::OriginalSql, + Some(kind) if kind == "rollup_join" => PreAggregationType::RollupJoin, + Some(kind) if kind == "lambda" => PreAggregationType::Lambda, + _ => PreAggregationType::Rollup, + }; + + let measures = props + .get("measures") + .map(|value| json_value_to_string_list(parse_literal(value))); + let dimensions = props + .get("dimensions") + .map(|value| json_value_to_string_list(parse_literal(value))); + + let scheduled_refresh = props + .get("scheduled_refresh") + .map(|value| match parse_literal(value) { + serde_json::Value::Bool(v) => v, + serde_json::Value::String(v) => v.eq_ignore_ascii_case("true"), + _ => true, + }) + .unwrap_or(true); + + let refresh_key = props + .get("refresh_key") + .and_then(|value| match parse_literal(value) { + serde_json::Value::Object(map) => Some(RefreshKey { + every: map.get("every").and_then(json_value_to_string), + sql: map.get("sql").and_then(json_value_to_string), + incremental: map + .get("incremental") + .and_then(|v| match v { + serde_json::Value::Bool(b) => Some(*b), + serde_json::Value::String(s) => Some(s.eq_ignore_ascii_case("true")), + _ => None, + }) + .unwrap_or(false), + update_window: map.get("update_window").and_then(json_value_to_string), + }), + _ => None, + }); + + let indexes = props + .get("indexes") + .and_then(|value| match parse_literal(value) { + serde_json::Value::Array(items) => { + let parsed = items + .into_iter() + .filter_map(|item| match item { + serde_json::Value::Object(map) => { + let name = map.get("name").and_then(json_value_to_string)?; + let columns = map + .get("columns") + .cloned() + .map(json_value_to_string_list) + .unwrap_or_default(); + let index_type = map + .get("type") + .and_then(json_value_to_string) + .unwrap_or_else(|| "regular".to_string()); + Some(Index { + name, + columns, + index_type, + }) + } + _ => None, + }) + .collect::>(); + Some(parsed) + } + _ => None, + }); + + Some(PreAggregation { + name: name.clone(), + preagg_type, + measures, + dimensions, + time_dimension: props.get("time_dimension").cloned(), + granularity: props.get("granularity").cloned(), + partition_granularity: props.get("partition_granularity").cloned(), + build_range_start: props.get("build_range_start").cloned(), + build_range_end: props.get("build_range_end").cloned(), + scheduled_refresh, + refresh_key, + indexes, + }) +} + +fn build_parameter(props: &HashMap) -> Option { + let name = props.get("name")?; + let parameter_type = parse_parameter_type(props.get("type"))?; + + let default_value = props.get("default_value").map(|value| parse_literal(value)); + + let allowed_values = props + .get("allowed_values") + .map(|value| match parse_literal(value) { + serde_json::Value::Array(items) => items, + other => vec![other], + }); + + let default_to_today = props + .get("default_to_today") + .map(|value| match parse_literal(value) { + serde_json::Value::Bool(v) => v, + serde_json::Value::String(v) => v.eq_ignore_ascii_case("true"), + _ => false, + }) + .unwrap_or(false); + + Some(Parameter { + name: name.clone(), + parameter_type, + description: props.get("description").cloned(), + label: props.get("label").cloned(), + default_value, + allowed_values, + default_to_today, + }) +} + #[cfg(test)] mod tests { use super::*; @@ -767,7 +1435,30 @@ mod tests { let model = parse_sql_model(sql).unwrap(); let revenue = model.get_metric("revenue").unwrap(); - assert_eq!(revenue.sql, Some("SUM(amount)".to_string())); + assert_eq!(revenue.agg, Some(Aggregation::Sum)); + assert_eq!(revenue.sql, Some("amount".to_string())); + } + + #[test] + fn test_parse_cohort_metric_preserves_outer_aggregation() { + let sql = r#" + MODEL (name events, table events); + METRIC ( + name scored_users, + type cohort, + entity user_id, + inner_metrics [{ name total_score, agg sum, sql score }], + having total_score > 10, + agg avg, + sql cohort_sub.total_score + ); + "#; + + let model = parse_sql_model(sql).unwrap(); + let metric = model.get_metric("scored_users").unwrap(); + assert_eq!(metric.r#type, MetricType::Cohort); + assert_eq!(metric.agg, Some(Aggregation::Avg)); + assert_eq!(metric.sql, Some("cohort_sub.total_score".to_string())); } #[test] @@ -881,4 +1572,158 @@ mod tests { assert_eq!(model.metrics.len(), 2); assert_eq!(model.dimensions.len(), 2); } + + #[test] + fn test_parse_sql_graph_definitions_with_parameter() { + let sql = r#" + METRIC (name revenue, agg sum, sql amount); + SEGMENT (name completed, sql status = 'completed'); + PARAMETER ( + name region, + type string, + allowed_values [us, eu], + default_value 'us', + default_to_today false + ); + "#; + + let (metrics, segments, parameters) = parse_sql_graph_definitions(sql).unwrap(); + assert_eq!(metrics.len(), 1); + assert_eq!(segments.len(), 1); + assert_eq!(parameters.len(), 1); + + let parameter = ¶meters[0]; + assert_eq!(parameter.name, "region"); + assert_eq!(parameter.parameter_type, ParameterType::String); + assert_eq!( + parameter.allowed_values, + Some(vec![ + serde_json::Value::String("us".to_string()), + serde_json::Value::String("eu".to_string()) + ]) + ); + assert_eq!( + parameter.default_value, + Some(serde_json::Value::String("us".to_string())) + ); + } + + #[test] + fn test_parse_model_with_extended_metadata_fields() { + let sql = r#" + MODEL ( + name orders, + table orders, + source_uri s3://warehouse/orders, + extends base_orders, + default_time_dimension order_date, + default_grain day + ); + DIMENSION ( + name order_date, + type time, + sql order_date, + supported_granularities [day, week, month], + format yyyy-mm-dd, + value_format_name iso_date + ); + "#; + + let model = parse_sql_model(sql).unwrap(); + assert_eq!(model.source_uri.as_deref(), Some("s3://warehouse/orders")); + assert_eq!(model.extends.as_deref(), Some("base_orders")); + assert_eq!(model.default_time_dimension.as_deref(), Some("order_date")); + assert_eq!(model.default_grain.as_deref(), Some("day")); + + let order_date = model.get_dimension("order_date").unwrap(); + assert_eq!( + order_date.supported_granularities, + Some(vec![ + "day".to_string(), + "week".to_string(), + "month".to_string() + ]) + ); + assert_eq!(order_date.format.as_deref(), Some("yyyy-mm-dd")); + assert_eq!(order_date.value_format_name.as_deref(), Some("iso_date")); + } + + #[test] + fn test_parse_model_with_composite_primary_key() { + let sql = r#" + MODEL ( + name order_items, + table order_items, + primary_key [order_id, item_id] + ); + "#; + + let model = parse_sql_model(sql).unwrap(); + assert_eq!(model.primary_key, "order_id"); + assert_eq!( + model.primary_key_columns, + vec!["order_id".to_string(), "item_id".to_string()] + ); + } + + #[test] + fn test_parse_relationship_with_composite_keys() { + let sql = r#" + MODEL (name shipments, table shipments); + RELATIONSHIP ( + name order_items, + type many_to_one, + foreign_key [order_id, item_id], + primary_key [order_id, item_id] + ); + "#; + + let model = parse_sql_model(sql).unwrap(); + let rel = model.get_relationship("order_items").unwrap(); + assert_eq!(rel.foreign_key.as_deref(), Some("order_id")); + assert_eq!( + rel.foreign_key_columns.as_ref().unwrap(), + &vec!["order_id".to_string(), "item_id".to_string()] + ); + assert_eq!(rel.primary_key.as_deref(), Some("order_id")); + assert_eq!( + rel.primary_key_columns.as_ref().unwrap(), + &vec!["order_id".to_string(), "item_id".to_string()] + ); + } + + #[test] + fn test_parse_sql_statement_blocks() { + let sql = r#" + MODEL (name orders, table orders); + METRIC (name revenue, expression SUM(amount)); + SEGMENT (name completed, sql status = 'completed'); + "#; + + let blocks = parse_sql_statement_blocks(sql).unwrap(); + assert_eq!(blocks.len(), 3); + + assert_eq!(blocks[0].kind, "model"); + assert_eq!( + blocks[0].properties.get("name"), + Some(&"orders".to_string()) + ); + + assert_eq!(blocks[1].kind, "metric"); + assert_eq!( + blocks[1].properties.get("name"), + Some(&"revenue".to_string()) + ); + // Alias keys are resolved in rust parser payloads. + assert_eq!( + blocks[1].properties.get("sql"), + Some(&"SUM(amount)".to_string()) + ); + + assert_eq!(blocks[2].kind, "segment"); + assert_eq!( + blocks[2].properties.get("sql"), + Some(&"status = 'completed'".to_string()) + ); + } } diff --git a/sidemantic-rs/src/core/dependency.rs b/sidemantic-rs/src/core/dependency.rs index ee757ad8..5cd12910 100644 --- a/sidemantic-rs/src/core/dependency.rs +++ b/sidemantic-rs/src/core/dependency.rs @@ -4,7 +4,7 @@ use std::collections::HashSet; -use polyglot_sql::{DialectType, Expression, ExpressionWalk}; +use polyglot_sql::{parse, traversal, DialectType, Expression}; use super::model::{Metric, MetricType}; use super::SemanticGraph; @@ -15,6 +15,15 @@ use super::SemanticGraph; /// For qualified references (model.metric), returns the full reference. /// For unqualified references, attempts to resolve using the graph. pub fn extract_dependencies(metric: &Metric, graph: Option<&SemanticGraph>) -> HashSet { + extract_dependencies_with_context(metric, graph, None) +} + +/// Extract dependencies with optional model context for unqualified reference resolution. +pub fn extract_dependencies_with_context( + metric: &Metric, + graph: Option<&SemanticGraph>, + model_context: Option<&str>, +) -> HashSet { let mut deps = HashSet::new(); match metric.r#type { @@ -40,7 +49,7 @@ pub fn extract_dependencies(metric: &Metric, graph: Option<&SemanticGraph>) -> H // Resolve references using graph if available if let Some(g) = graph { for ref_name in refs { - let resolved = resolve_reference(&ref_name, g); + let resolved = resolve_reference(&ref_name, g, model_context); deps.insert(resolved); } } else { @@ -56,6 +65,8 @@ pub fn extract_dependencies(metric: &Metric, graph: Option<&SemanticGraph>) -> H // Cumulative metrics depend on the base metric in sql field if let Some(ref sql) = metric.sql { deps.insert(sql.clone()); + } else if let Some(ref base_metric) = metric.base_metric { + deps.insert(base_metric.clone()); } } MetricType::TimeComparison => { @@ -64,6 +75,9 @@ pub fn extract_dependencies(metric: &Metric, graph: Option<&SemanticGraph>) -> H deps.insert(base.clone()); } } + MetricType::Conversion | MetricType::Retention | MetricType::Cohort => { + // Complex event/cohort metrics are modeled via event filters, not metric dependencies. + } } deps @@ -87,43 +101,55 @@ fn has_operators(s: &str) -> bool { /// Uses polyglot-sql to parse the expression and find all column identifiers. fn extract_column_references(sql: &str) -> HashSet { let mut refs = HashSet::new(); + let normalized_sql = sql.replace("${CUBE}.", "").replace("${CUBE}", ""); + + // polyglot-sql traversal can recurse indefinitely on some PostgreSQL cast + // forms (expr::type). Fall back to the tokenizer path for these expressions. + if normalized_sql.contains("::") { + return extract_simple_references(&normalized_sql); + } // Wrap in SELECT to make it valid SQL - let wrapped = format!("SELECT {sql}"); + let wrapped = format!("SELECT {normalized_sql}"); - let Ok(expressions) = polyglot_sql::parse(&wrapped, DialectType::Generic) else { + let Ok(statements) = parse(&wrapped, DialectType::Generic) else { // If parsing fails, try simple extraction - return extract_simple_references(sql); + return extract_simple_references(&normalized_sql); }; - for expr in expressions { - if let Expression::Select(select) = expr { - for item in &select.expressions { - extract_refs_from_expr(item, &mut refs); + for statement in &statements { + if let Expression::Select(select) = statement { + for projection in &select.expressions { + for column_ref in traversal::get_columns(projection) { + if let Expression::Column(column) = column_ref { + let candidate = if let Some(table) = &column.table { + if table.name.is_empty() { + column.name.name.clone() + } else { + format!("{}.{}", table.name, column.name.name) + } + } else { + column.name.name.clone() + }; + if let Some(cleaned) = sanitize_reference(&candidate) { + refs.insert(cleaned); + } + } + } } } } + if refs.is_empty() { + return extract_simple_references(&normalized_sql); + } + refs } -/// Recursively extract column references from an expression using DFS -fn extract_refs_from_expr(expr: &Expression, refs: &mut HashSet) { - for node in expr.dfs() { - match node { - Expression::Identifier(ident) => { - refs.insert(ident.name.clone()); - } - Expression::Column(col) => { - if let Some(table) = &col.table { - refs.insert(format!("{}.{}", table.name, col.name.name)); - } else { - refs.insert(col.name.name.clone()); - } - } - _ => {} - } - } +/// Public wrapper used by language bindings for dependency analysis helpers. +pub fn extract_column_references_from_expr(sql: &str) -> HashSet { + extract_column_references(sql) } /// Simple fallback extraction for when parsing fails @@ -144,8 +170,11 @@ fn extract_simple_references(sql: &str) -> HashSet { if c.is_alphanumeric() || c == '_' || c == '.' { current.push(c); } else { - if !current.is_empty() && !is_keyword(¤t) && !is_number(¤t) { - refs.insert(current.clone()); + let is_function_call = c == '('; + if !is_function_call { + if let Some(cleaned) = sanitize_reference(¤t) { + refs.insert(cleaned); + } } current.clear(); } @@ -154,18 +183,35 @@ fn extract_simple_references(sql: &str) -> HashSet { prev_char = c; } - if !current.is_empty() && !is_keyword(¤t) && !is_number(¤t) { - refs.insert(current); + if let Some(cleaned) = sanitize_reference(¤t) { + refs.insert(cleaned); } refs } +fn sanitize_reference(raw: &str) -> Option { + let mut candidate = raw.trim(); + while let Some(stripped) = candidate.strip_prefix('.') { + candidate = stripped; + } + if candidate.is_empty() { + return None; + } + if is_keyword(candidate) || is_number(candidate) || is_cast_type(candidate) { + return None; + } + if candidate.eq_ignore_ascii_case("cube") { + return None; + } + Some(candidate.to_string()) +} + /// Check if string is a SQL keyword fn is_keyword(s: &str) -> bool { let keywords = [ - "SELECT", "FROM", "WHERE", "AND", "OR", "NOT", "NULL", "NULLIF", "CASE", "WHEN", "THEN", - "ELSE", "END", "AS", "SUM", "COUNT", "AVG", "MIN", "MAX", "DISTINCT", + "SELECT", "FROM", "WHERE", "AND", "OR", "NOT", "NULL", "CASE", "WHEN", "THEN", "ELSE", + "END", "AS", "DISTINCT", ]; keywords.iter().any(|k| k.eq_ignore_ascii_case(s)) } @@ -175,16 +221,46 @@ fn is_number(s: &str) -> bool { s.parse::().is_ok() } +fn is_cast_type(s: &str) -> bool { + let cast_types = [ + "float", + "double", + "decimal", + "numeric", + "integer", + "int", + "bigint", + "smallint", + "real", + "boolean", + "bool", + "date", + "time", + "timestamp", + "varchar", + "text", + ]; + cast_types.iter().any(|ty| ty.eq_ignore_ascii_case(s)) +} + /// Resolve a reference using the semantic graph /// /// If the reference is already qualified (model.metric), returns as-is. /// Otherwise, searches all models for a matching metric. -fn resolve_reference(ref_name: &str, graph: &SemanticGraph) -> String { +fn resolve_reference(ref_name: &str, graph: &SemanticGraph, model_context: Option<&str>) -> String { // Already qualified if ref_name.contains('.') { return ref_name.to_string(); } + if let Some(context_model_name) = model_context { + if let Some(model) = graph.get_model(context_model_name) { + if model.get_metric(ref_name).is_some() { + return format!("{context_model_name}.{ref_name}"); + } + } + } + // Search models for matching metric for model in graph.models() { if model.get_metric(ref_name).is_some() { @@ -297,4 +373,14 @@ mod tests { assert!(refs.contains("revenue")); assert!(refs.contains("cost")); } + + #[test] + fn test_extract_column_references_ignores_cube_placeholder_and_cast_type() { + let refs = extract_column_references( + "COUNT(CASE WHEN ${CUBE}.status = 'approved' THEN 1 END)::float / NULLIF(COUNT(*), 0)", + ); + assert!(refs.contains("status")); + assert!(!refs.contains("CUBE")); + assert!(!refs.contains("float")); + } } diff --git a/sidemantic-rs/src/core/graph.rs b/sidemantic-rs/src/core/graph.rs index 1d1d3e66..5c3cb344 100644 --- a/sidemantic-rs/src/core/graph.rs +++ b/sidemantic-rs/src/core/graph.rs @@ -3,6 +3,7 @@ use std::collections::{HashMap, HashSet, VecDeque}; use crate::core::model::{Model, RelationshipType}; +use crate::core::Parameter; use crate::error::{Result, SidemanticError}; /// A step in a join path @@ -12,6 +13,8 @@ pub struct JoinStep { pub to_model: String, pub from_key: String, pub to_key: String, + pub from_keys: Vec, + pub to_keys: Vec, pub relationship_type: RelationshipType, /// Custom SQL join condition (overrides FK/PK join) pub custom_condition: Option, @@ -63,13 +66,20 @@ impl JoinPath { } } -/// Edge in the adjacency list: (target_model, fk, pk, relationship_type, custom_sql) -type AdjacencyEdge = (String, String, String, RelationshipType, Option); +/// Edge in the adjacency list: (target_model, from_keys, to_keys, relationship_type, custom_sql) +type AdjacencyEdge = ( + String, + Vec, + Vec, + RelationshipType, + Option, +); /// The semantic graph holds all models and their relationships #[derive(Debug, Default)] pub struct SemanticGraph { models: HashMap, + parameters: HashMap, /// Adjacency list: model -> edges adjacency: HashMap>, } @@ -79,14 +89,25 @@ impl SemanticGraph { Self::default() } + fn validate_model(model: &Model) -> Result<()> { + if model.table.is_none() && model.sql.is_none() { + return Err(SidemanticError::Validation(format!( + "Model '{}' must have either 'table' or 'sql' defined", + model.name + ))); + } + + Ok(()) + } + /// Add a model to the graph pub fn add_model(&mut self, model: Model) -> Result<()> { let name = model.name.clone(); - // Validate model - if model.table.is_none() && model.sql.is_none() { + Self::validate_model(&model)?; + if self.models.contains_key(&name) { return Err(SidemanticError::Validation(format!( - "Model '{name}' must have either 'table' or 'sql' defined" + "Model '{name}' already exists" ))); } @@ -95,6 +116,16 @@ impl SemanticGraph { Ok(()) } + /// Add or replace a model in the graph. + pub fn replace_model(&mut self, model: Model) -> Result<()> { + let name = model.name.clone(); + + Self::validate_model(&model)?; + self.models.insert(name, model); + self.rebuild_adjacency(); + Ok(()) + } + /// Get a model by name pub fn get_model(&self, name: &str) -> Option<&Model> { self.models.get(name) @@ -105,18 +136,135 @@ impl SemanticGraph { self.models.values() } + /// Add a parameter to the graph + pub fn add_parameter(&mut self, parameter: Parameter) -> Result<()> { + if self.parameters.contains_key(¶meter.name) { + return Err(SidemanticError::Validation(format!( + "Parameter '{}' already exists", + parameter.name + ))); + } + self.parameters.insert(parameter.name.clone(), parameter); + Ok(()) + } + + /// Get a parameter by name + pub fn get_parameter(&self, name: &str) -> Option<&Parameter> { + self.parameters.get(name) + } + + /// Get all parameters + pub fn parameters(&self) -> impl Iterator { + self.parameters.values() + } + /// Rebuild the adjacency list from model relationships fn rebuild_adjacency(&mut self) { self.adjacency.clear(); for model in self.models.values() { - let edges = self.adjacency.entry(model.name.clone()).or_default(); + self.adjacency.entry(model.name.clone()).or_default(); for rel in &model.relationships { - edges.push(( + if rel.r#type == RelationshipType::ManyToMany { + if let Some(through_name) = &rel.through { + let through_model_exists = self.models.contains_key(through_name); + let target_model_exists = self.models.contains_key(&rel.name); + if !through_model_exists || !target_model_exists { + continue; + } + + let (source_fk_opt, target_fk_opt) = rel.junction_keys(); + let (Some(source_fk), Some(target_fk)) = (source_fk_opt, target_fk_opt) + else { + continue; + }; + + let source_pk = model.primary_keys(); + let target_pk = + if rel.primary_key.is_some() || rel.primary_key_columns.is_some() { + rel.primary_key_columns() + } else { + self.models + .get(&rel.name) + .map(|target_model| target_model.primary_keys()) + .unwrap_or_else(|| vec!["id".to_string()]) + }; + let source_pk_first = source_pk + .first() + .cloned() + .unwrap_or_else(|| "id".to_string()); + let target_pk_first = target_pk + .first() + .cloned() + .unwrap_or_else(|| "id".to_string()); + + // source -> through (one_to_many) + self.adjacency.entry(model.name.clone()).or_default().push(( + through_name.clone(), + vec![source_pk_first.clone()], + vec![source_fk.clone()], + RelationshipType::OneToMany, + None, + )); + // through -> source (many_to_one) + self.adjacency + .entry(through_name.clone()) + .or_default() + .push(( + model.name.clone(), + vec![source_fk], + vec![source_pk_first], + RelationshipType::ManyToOne, + None, + )); + + // through -> target (many_to_one) + self.adjacency + .entry(through_name.clone()) + .or_default() + .push(( + rel.name.clone(), + vec![target_fk.clone()], + vec![target_pk_first.clone()], + RelationshipType::ManyToOne, + None, + )); + // target -> through (one_to_many) + self.adjacency.entry(rel.name.clone()).or_default().push(( + through_name.clone(), + vec![target_pk_first], + vec![target_fk], + RelationshipType::OneToMany, + None, + )); + continue; + } + } + + let fk_keys = rel.foreign_key_columns(); + let pk_keys = if rel.primary_key.is_some() || rel.primary_key_columns.is_some() { + rel.primary_key_columns() + } else { + self.models + .get(&rel.name) + .map(|target_model| target_model.primary_keys()) + .unwrap_or_else(|| vec!["id".to_string()]) + }; + + let (from_keys, to_keys) = match rel.r#type { + RelationshipType::ManyToOne | RelationshipType::OneToOne => { + (fk_keys.clone(), pk_keys.clone()) + } + RelationshipType::OneToMany | RelationshipType::ManyToMany => { + (pk_keys.clone(), fk_keys.clone()) + } + }; + + self.adjacency.entry(model.name.clone()).or_default().push(( rel.name.clone(), - rel.fk(), - rel.pk(), + from_keys.clone(), + to_keys.clone(), rel.r#type.clone(), rel.sql.clone(), )); @@ -124,6 +272,22 @@ impl SemanticGraph { // Add reverse edges for relationships for rel in &model.relationships { + if rel.r#type == RelationshipType::ManyToMany && rel.through.is_some() { + continue; + } + + // If the target model already declares an explicit reverse relationship, + // don't synthesize another reverse edge. This avoids conflicting + // FK/PK directions when both sides are configured. + if self + .models + .get(&rel.name) + .and_then(|target| target.get_relationship(&model.name)) + .is_some() + { + continue; + } + let reverse_type = match rel.r#type { RelationshipType::ManyToOne => RelationshipType::OneToMany, RelationshipType::OneToMany => RelationshipType::ManyToOne, @@ -138,10 +302,29 @@ impl SemanticGraph { .replace("__TEMP__", "{to}") }); + let fk_keys = rel.foreign_key_columns(); + let pk_keys = if rel.primary_key.is_some() || rel.primary_key_columns.is_some() { + rel.primary_key_columns() + } else { + self.models + .get(&rel.name) + .map(|target_model| target_model.primary_keys()) + .unwrap_or_else(|| vec!["id".to_string()]) + }; + + let (reverse_from_keys, reverse_to_keys) = match rel.r#type { + RelationshipType::ManyToOne | RelationshipType::OneToOne => { + (pk_keys.clone(), fk_keys.clone()) + } + RelationshipType::OneToMany | RelationshipType::ManyToMany => { + (fk_keys.clone(), pk_keys.clone()) + } + }; + self.adjacency.entry(rel.name.clone()).or_default().push(( model.name.clone(), - rel.pk(), - rel.fk(), + reverse_from_keys, + reverse_to_keys, reverse_type, reverse_sql, )); @@ -173,14 +356,18 @@ impl SemanticGraph { while let Some((current, path)) = queue.pop_front() { if let Some(edges) = self.adjacency.get(¤t) { - for (target, fk, pk, rel_type, custom_sql) in edges { + for (target, from_keys, to_keys, rel_type, custom_sql) in edges { if !visited.contains(target) { let mut new_path = path.clone(); + let from_key = from_keys.first().cloned().unwrap_or_default(); + let to_key = to_keys.first().cloned().unwrap_or_default(); new_path.push(JoinStep { from_model: current.clone(), to_model: target.clone(), - from_key: fk.clone(), - to_key: pk.clone(), + from_key, + to_key, + from_keys: from_keys.clone(), + to_keys: to_keys.clone(), relationship_type: rel_type.clone(), custom_condition: custom_sql.clone(), }); @@ -236,6 +423,7 @@ impl SemanticGraph { mod tests { use super::*; use crate::core::model::{Dimension, Metric, Relationship}; + use crate::core::parameter::{Parameter, ParameterType}; fn create_test_graph() -> SemanticGraph { let mut graph = SemanticGraph::new(); @@ -266,6 +454,32 @@ mod tests { assert!(graph.get_model("nonexistent").is_none()); } + #[test] + fn test_replace_model_overwrites_existing_model() { + let mut graph = SemanticGraph::new(); + + graph + .add_model( + Model::new("orders", "order_id") + .with_table("orders") + .with_dimension(Dimension::categorical("status")), + ) + .unwrap(); + + graph + .replace_model( + Model::new("orders", "id") + .with_table("orders_v2") + .with_metric(Metric::count("order_count")), + ) + .unwrap(); + + let model = graph.get_model("orders").unwrap(); + assert_eq!(model.table.as_deref(), Some("orders_v2")); + assert!(model.get_dimension("status").is_none()); + assert!(model.get_metric("order_count").is_some()); + } + #[test] fn test_find_join_path() { let graph = create_test_graph(); @@ -281,6 +495,8 @@ mod tests { assert_eq!(path.steps[0].to_model, "customers"); assert_eq!(path.steps[0].from_key, "customers_id"); assert_eq!(path.steps[0].to_key, "id"); + assert_eq!(path.steps[0].from_keys, vec!["customers_id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["id".to_string()]); // Reverse relationship let path = graph.find_join_path("customers", "orders").unwrap(); @@ -346,4 +562,158 @@ mod tests { .unwrap() .contains("{from}.customer_id = {to}.id")); } + + #[test] + fn test_default_relationship_uses_target_primary_key() { + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_relationship(Relationship { + name: "customers".to_string(), + r#type: RelationshipType::ManyToOne, + foreign_key: Some("customer_id".to_string()), + foreign_key_columns: None, + primary_key: None, + primary_key_columns: None, + through: None, + through_foreign_key: None, + related_foreign_key: None, + sql: None, + }); + + let customers = Model::new("customers", "customer_id").with_table("customers"); + + graph.add_model(orders).unwrap(); + graph.add_model(customers).unwrap(); + + let path = graph.find_join_path("orders", "customers").unwrap(); + assert_eq!(path.steps.len(), 1); + assert_eq!(path.steps[0].from_key, "customer_id"); + assert_eq!(path.steps[0].to_key, "customer_id"); + assert_eq!(path.steps[0].from_keys, vec!["customer_id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["customer_id".to_string()]); + } + + #[test] + fn test_many_to_many_with_through_builds_two_hop_path() { + let mut graph = SemanticGraph::new(); + + let orders = Model::new("orders", "order_id") + .with_table("orders") + .with_relationship(Relationship { + name: "products".to_string(), + r#type: RelationshipType::ManyToMany, + foreign_key: None, + foreign_key_columns: None, + primary_key: Some("product_id".to_string()), + primary_key_columns: None, + through: Some("order_items".to_string()), + through_foreign_key: Some("order_id".to_string()), + related_foreign_key: Some("product_id".to_string()), + sql: None, + }); + let order_items = Model::new("order_items", "id").with_table("order_items"); + let products = Model::new("products", "product_id").with_table("products"); + + graph.add_model(orders).unwrap(); + graph.add_model(order_items).unwrap(); + graph.add_model(products).unwrap(); + + let path = graph.find_join_path("orders", "products").unwrap(); + assert_eq!(path.steps.len(), 2); + + // orders -> order_items + assert_eq!(path.steps[0].from_model, "orders"); + assert_eq!(path.steps[0].to_model, "order_items"); + assert_eq!(path.steps[0].from_key, "order_id"); + assert_eq!(path.steps[0].to_key, "order_id"); + assert_eq!(path.steps[0].from_keys, vec!["order_id".to_string()]); + assert_eq!(path.steps[0].to_keys, vec!["order_id".to_string()]); + assert_eq!(path.steps[0].relationship_type, RelationshipType::OneToMany); + + // order_items -> products + assert_eq!(path.steps[1].from_model, "order_items"); + assert_eq!(path.steps[1].to_model, "products"); + assert_eq!(path.steps[1].from_key, "product_id"); + assert_eq!(path.steps[1].to_key, "product_id"); + assert_eq!(path.steps[1].from_keys, vec!["product_id".to_string()]); + assert_eq!(path.steps[1].to_keys, vec!["product_id".to_string()]); + assert_eq!(path.steps[1].relationship_type, RelationshipType::ManyToOne); + } + + #[test] + fn test_find_join_path_with_composite_keys() { + let mut graph = SemanticGraph::new(); + + let order_items = Model::new("order_items", "order_id") + .with_primary_key_columns(vec!["order_id".to_string(), "item_id".to_string()]) + .with_table("order_items"); + let shipments = Model::new("shipments", "shipment_id") + .with_table("shipments") + .with_relationship(Relationship::many_to_one("order_items").with_key_columns( + vec!["order_id".to_string(), "item_id".to_string()], + vec!["order_id".to_string(), "item_id".to_string()], + )); + + graph.add_model(order_items).unwrap(); + graph.add_model(shipments).unwrap(); + + let path = graph.find_join_path("shipments", "order_items").unwrap(); + assert_eq!(path.steps.len(), 1); + assert_eq!(path.steps[0].from_key, "order_id"); + assert_eq!(path.steps[0].to_key, "order_id"); + assert_eq!( + path.steps[0].from_keys, + vec!["order_id".to_string(), "item_id".to_string()] + ); + assert_eq!( + path.steps[0].to_keys, + vec!["order_id".to_string(), "item_id".to_string()] + ); + } + + #[test] + fn test_add_model_duplicate_name() { + let mut graph = SemanticGraph::new(); + let orders_one = Model::new("orders", "order_id").with_table("orders"); + let orders_two = Model::new("orders", "id").with_table("orders_v2"); + + graph.add_model(orders_one).unwrap(); + let err = graph.add_model(orders_two).unwrap_err(); + assert!(err.to_string().contains("Model 'orders' already exists")); + } + + #[test] + fn test_add_parameter() { + let mut graph = create_test_graph(); + let parameter = Parameter { + name: "status".to_string(), + parameter_type: ParameterType::String, + description: None, + label: None, + default_value: Some(serde_json::Value::String("pending".to_string())), + allowed_values: None, + default_to_today: false, + }; + graph.add_parameter(parameter).unwrap(); + assert!(graph.get_parameter("status").is_some()); + } + + #[test] + fn test_add_parameter_duplicate() { + let mut graph = create_test_graph(); + let parameter = Parameter { + name: "status".to_string(), + parameter_type: ParameterType::String, + description: None, + label: None, + default_value: None, + allowed_values: None, + default_to_today: false, + }; + graph.add_parameter(parameter.clone()).unwrap(); + let err = graph.add_parameter(parameter).unwrap_err(); + assert!(err.to_string().contains("already exists")); + } } diff --git a/sidemantic-rs/src/core/inheritance.rs b/sidemantic-rs/src/core/inheritance.rs index f2a03f41..76ecf2c5 100644 --- a/sidemantic-rs/src/core/inheritance.rs +++ b/sidemantic-rs/src/core/inheritance.rs @@ -1,7 +1,7 @@ //! Model inheritance support //! //! Allows models to extend other models, inheriting dimensions, metrics, -//! relationships, and segments. Child values override parent values. +//! relationships, segments, and pre-aggregations. Child values override parent values. use std::collections::{HashMap, HashSet}; @@ -11,22 +11,58 @@ use crate::error::{Result, SidemanticError}; /// Merge a child model with its parent. /// /// Child inherits all fields from parent, with child values taking precedence. -/// List fields (dimensions, metrics, relationships, segments) are merged by name, +/// List fields (dimensions, metrics, relationships, segments, pre-aggregations) are merged by name, /// with child items overriding parent items of the same name. pub fn merge_model(child: &Model, parent: &Model) -> Model { // Start with parent's table/sql, override with child if set let table = child.table.clone().or_else(|| parent.table.clone()); let sql = child.sql.clone().or_else(|| parent.sql.clone()); - let primary_key = if child.primary_key != "id" { - child.primary_key.clone() + let source_uri = child + .source_uri + .clone() + .or_else(|| parent.source_uri.clone()); + let extends = child.extends.clone(); + let child_primary_key_columns = if child.primary_key_columns.is_empty() { + vec![child.primary_key.clone()] + } else { + child.primary_key_columns.clone() + }; + let parent_primary_key_columns = if parent.primary_key_columns.is_empty() { + vec![parent.primary_key.clone()] } else { - parent.primary_key.clone() + parent.primary_key_columns.clone() }; + let child_overrides_primary_key = child_primary_key_columns.len() > 1 + || child_primary_key_columns + .first() + .map(|value| value.as_str()) + != Some("id"); + let primary_key_columns = if child_overrides_primary_key { + child_primary_key_columns + } else { + parent_primary_key_columns + }; + let primary_key = primary_key_columns + .first() + .cloned() + .unwrap_or_else(|| "id".to_string()); + let unique_keys = child + .unique_keys + .clone() + .or_else(|| parent.unique_keys.clone()); let description = child .description .clone() .or_else(|| parent.description.clone()); let label = child.label.clone().or_else(|| parent.label.clone()); + let default_time_dimension = child + .default_time_dimension + .clone() + .or_else(|| parent.default_time_dimension.clone()); + let default_grain = child + .default_grain + .clone() + .or_else(|| parent.default_grain.clone()); // Merge dimensions by name (child overrides parent) let mut dimensions_map: HashMap = parent @@ -72,15 +108,33 @@ pub fn merge_model(child: &Model, parent: &Model) -> Model { } let segments: Vec<_> = segments_map.into_values().collect(); + // Merge pre-aggregations by name + let mut pre_aggs_map: HashMap = parent + .pre_aggregations + .iter() + .map(|p| (p.name.clone(), p.clone())) + .collect(); + for pre_agg in &child.pre_aggregations { + pre_aggs_map.insert(pre_agg.name.clone(), pre_agg.clone()); + } + let pre_aggregations: Vec<_> = pre_aggs_map.into_values().collect(); + Model { name: child.name.clone(), table, sql, + source_uri, + extends, primary_key, + primary_key_columns, + unique_keys, dimensions, metrics, relationships, segments, + pre_aggregations, + default_time_dimension, + default_grain, label, description, } diff --git a/sidemantic-rs/src/core/mod.rs b/sidemantic-rs/src/core/mod.rs index ebe6a327..6538db30 100644 --- a/sidemantic-rs/src/core/mod.rs +++ b/sidemantic-rs/src/core/mod.rs @@ -4,19 +4,28 @@ mod dependency; mod graph; mod inheritance; mod model; +mod parameter; mod relative_date; mod segment; pub mod symmetric_agg; mod table_calc; -pub use dependency::{check_circular_dependencies, extract_dependencies}; +pub use dependency::{ + check_circular_dependencies, extract_column_references_from_expr, extract_dependencies, + extract_dependencies_with_context, +}; pub use graph::{JoinPath, JoinStep, SemanticGraph}; pub use inheritance::{merge_model, resolve_model_inheritance}; pub use model::{ - Aggregation, Dimension, DimensionType, Metric, MetricType, Model, Relationship, - RelationshipType, + Aggregation, CohortInnerMetric, ComparisonCalculation, ComparisonType, Dimension, + DimensionType, Index, Metric, MetricType, Model, PreAggregation, PreAggregationType, + RefreshKey, Relationship, RelationshipType, TimeGrain, }; +pub use parameter::{Parameter, ParameterType}; pub use relative_date::RelativeDate; pub use segment::Segment; -pub use symmetric_agg::{build_symmetric_aggregate_sql, SqlDialect, SymmetricAggType}; +pub use symmetric_agg::{ + build_symmetric_aggregate_sql, build_symmetric_aggregate_sql_with_key_expr, SqlDialect, + SymmetricAggType, +}; pub use table_calc::{TableCalcType, TableCalculation}; diff --git a/sidemantic-rs/src/core/model.rs b/sidemantic-rs/src/core/model.rs index 5ef62e0c..92015e72 100644 --- a/sidemantic-rs/src/core/model.rs +++ b/sidemantic-rs/src/core/model.rs @@ -25,10 +25,25 @@ pub struct Dimension { pub sql: Option, /// Time granularity (for time dimensions) pub granularity: Option, + /// Supported granularities for time dimensions + #[serde(default)] + pub supported_granularities: Option>, /// Human-readable label pub label: Option, /// Description pub description: Option, + /// Display format string + #[serde(default)] + pub format: Option, + /// Named display format + #[serde(default)] + pub value_format_name: Option, + /// Parent dimension name for hierarchies + #[serde(default)] + pub parent: Option, + /// Window expression projected in model CTEs. + #[serde(default)] + pub window: Option, } impl Dimension { @@ -38,8 +53,13 @@ impl Dimension { r#type: DimensionType::Categorical, sql: None, granularity: None, + supported_granularities: None, label: None, description: None, + format: None, + value_format_name: None, + parent: None, + window: None, } } @@ -69,6 +89,11 @@ impl Dimension { self.sql.as_deref().unwrap_or(&self.name) } + /// Returns the window expression when configured, otherwise the row-level SQL expression. + pub fn window_sql_expr(&self) -> &str { + self.window.as_deref().unwrap_or_else(|| self.sql_expr()) + } + /// Returns SQL with time granularity applied (DATE_TRUNC) pub fn sql_with_granularity(&self, granularity: Option<&str>) -> String { let base_sql = self.sql_expr(); @@ -112,14 +137,18 @@ impl Aggregation { /// Metric type #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "snake_case")] pub enum MetricType { #[default] Simple, Derived, Ratio, Cumulative, + #[serde(alias = "timecomparison")] TimeComparison, + Conversion, + Retention, + Cohort, } /// Time comparison type @@ -155,6 +184,16 @@ pub enum TimeGrain { Year, } +/// Inner per-entity aggregate for cohort metrics. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CohortInnerMetric { + pub name: String, + #[serde(default)] + pub agg: Option, + #[serde(default)] + pub sql: Option, +} + /// A metric represents a business measure (aggregation) #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Metric { @@ -169,6 +208,9 @@ pub struct Metric { pub numerator: Option, /// Denominator metric (for ratio metrics) pub denominator: Option, + /// Time offset for ratio denominator (e.g. "1 month") + #[serde(default)] + pub offset_window: Option, /// Filters to apply #[serde(default)] pub filters: Vec, @@ -184,6 +226,15 @@ pub struct Metric { /// Grain for period-to-date (e.g., month for MTD) #[serde(default)] pub grain_to_date: Option, + /// Raw SQL expression used inside a window function + #[serde(default)] + pub window_expression: Option, + /// Explicit window frame clause + #[serde(default)] + pub window_frame: Option, + /// ORDER BY column override for window metrics + #[serde(default)] + pub window_order: Option, // Time comparison fields /// Base metric for time comparison @@ -199,6 +250,48 @@ pub struct Metric { #[serde(default)] pub calculation: Option, + // Conversion metric fields + /// Entity identifier expression (e.g. user_id) + #[serde(default)] + pub entity: Option, + /// Base event predicate/value + #[serde(default)] + pub base_event: Option, + /// Conversion event predicate/value + #[serde(default)] + pub conversion_event: Option, + /// Conversion window (e.g. "7 days") + #[serde(default)] + pub conversion_window: Option, + /// N-step funnel filter expressions. + #[serde(default)] + pub steps: Option>, + + // Retention metric fields + /// Cohort-defining event predicate. + #[serde(default)] + pub cohort_event: Option, + /// Activity event predicate. + #[serde(default)] + pub activity_event: Option, + /// Number of retention periods to include. + #[serde(default)] + pub periods: Option, + /// Retention time grain. + #[serde(default)] + pub retention_granularity: Option, + + // Cohort metric fields + /// Per-entity inner aggregations. + #[serde(default)] + pub inner_metrics: Option>, + /// Dimensions carried through from inner to outer aggregation. + #[serde(default)] + pub entity_dimensions: Option>, + /// HAVING predicate applied to the inner aggregation. + #[serde(default)] + pub having: Option, + // Display formatting /// Default value when result is NULL #[serde(default)] @@ -206,6 +299,15 @@ pub struct Metric { /// Display format string (e.g., "$#,##0.00", "0.00%") #[serde(default)] pub format: Option, + /// Named format (e.g., "usd", "percent", "decimal_2") + #[serde(default)] + pub value_format_name: Option, + /// Fields to include in drill-down results + #[serde(default)] + pub drill_fields: Option>, + /// Dimension across which this metric is non-additive + #[serde(default)] + pub non_additive_dimension: Option, } impl Metric { @@ -217,17 +319,36 @@ impl Metric { sql: None, numerator: None, denominator: None, + offset_window: None, filters: Vec::new(), label: None, description: None, window: None, grain_to_date: None, + window_expression: None, + window_frame: None, + window_order: None, base_metric: None, comparison_type: None, time_offset: None, calculation: None, + entity: None, + base_event: None, + conversion_event: None, + conversion_window: None, + steps: None, + cohort_event: None, + activity_event: None, + periods: None, + retention_granularity: None, + inner_metrics: None, + entity_dimensions: None, + having: None, fill_nulls_with: None, format: None, + value_format_name: None, + drill_fields: None, + non_additive_dimension: None, } } @@ -426,6 +547,9 @@ impl Metric { } } } + MetricType::Conversion | MetricType::Retention | MetricType::Cohort => { + "NULL /* complex metric */".to_string() + } } } @@ -455,8 +579,23 @@ pub struct Relationship { pub r#type: RelationshipType, /// Foreign key column (defaults to {name}_id) pub foreign_key: Option, + /// Foreign key columns for composite relationships + #[serde(default)] + pub foreign_key_columns: Option>, /// Primary key in related model (defaults to "id") pub primary_key: Option, + /// Primary key columns in related model for composite relationships + #[serde(default)] + pub primary_key_columns: Option>, + /// Junction model for many-to-many relationships + #[serde(default)] + pub through: Option, + /// Foreign key in junction model pointing to this model + #[serde(default)] + pub through_foreign_key: Option, + /// Foreign key in junction model pointing to related model + #[serde(default)] + pub related_foreign_key: Option, /// Custom SQL join condition (overrides FK/PK) /// Use {from} and {to} placeholders for table aliases #[serde(default)] @@ -469,7 +608,12 @@ impl Relationship { name: target.into(), r#type: RelationshipType::ManyToOne, foreign_key: None, + foreign_key_columns: None, primary_key: None, + primary_key_columns: None, + through: None, + through_foreign_key: None, + related_foreign_key: None, sql: None, } } @@ -490,8 +634,24 @@ impl Relationship { foreign_key: impl Into, primary_key: impl Into, ) -> Self { - self.foreign_key = Some(foreign_key.into()); - self.primary_key = Some(primary_key.into()); + let foreign_key = foreign_key.into(); + let primary_key = primary_key.into(); + self.foreign_key = Some(foreign_key.clone()); + self.foreign_key_columns = Some(vec![foreign_key]); + self.primary_key = Some(primary_key.clone()); + self.primary_key_columns = Some(vec![primary_key]); + self + } + + pub fn with_key_columns( + mut self, + foreign_keys: Vec, + primary_keys: Vec, + ) -> Self { + self.foreign_key = foreign_keys.first().cloned(); + self.foreign_key_columns = Some(foreign_keys); + self.primary_key = primary_keys.first().cloned(); + self.primary_key_columns = Some(primary_keys); self } @@ -505,20 +665,146 @@ impl Relationship { /// Returns the foreign key column name pub fn fk(&self) -> String { - self.foreign_key - .clone() + self.foreign_key_columns() + .into_iter() + .next() .unwrap_or_else(|| format!("{}_id", self.name)) } /// Returns the primary key column name in the related model pub fn pk(&self) -> String { - self.primary_key.clone().unwrap_or_else(|| "id".to_string()) + self.primary_key_columns() + .into_iter() + .next() + .unwrap_or_else(|| "id".to_string()) + } + + pub fn foreign_key_columns(&self) -> Vec { + self.foreign_key_columns + .clone() + .filter(|columns| !columns.is_empty()) + .or_else(|| self.foreign_key.clone().map(|key| vec![key])) + .unwrap_or_else(|| vec![format!("{}_id", self.name)]) + } + + pub fn primary_key_columns(&self) -> Vec { + self.primary_key_columns + .clone() + .filter(|columns| !columns.is_empty()) + .or_else(|| self.primary_key.clone().map(|key| vec![key])) + .unwrap_or_else(|| vec!["id".to_string()]) } /// Returns the custom SQL condition if set pub fn custom_condition(&self) -> Option<&str> { self.sql.as_deref() } + + /// Get junction keys for many-to-many relationships. + /// Returns (source_fk_in_through, target_fk_in_through). + pub fn junction_keys(&self) -> (Option, Option) { + if self.r#type != RelationshipType::ManyToMany { + return (None, None); + } + ( + self.through_foreign_key + .clone() + .or_else(|| self.foreign_key.clone()), + self.related_foreign_key.clone(), + ) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum PreAggregationType { + #[default] + Rollup, + OriginalSql, + RollupJoin, + Lambda, +} + +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +fn default_index_type() -> String { + "regular".to_string() +} + +/// Refresh strategy configuration for pre-aggregations. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RefreshKey { + #[serde(default)] + pub every: Option, + #[serde(default)] + pub sql: Option, + #[serde(default = "default_false")] + pub incremental: bool, + #[serde(default)] + pub update_window: Option, +} + +/// Index definition for pre-aggregation performance. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Index { + pub name: String, + #[serde(default)] + pub columns: Vec, + #[serde(default = "default_index_type", rename = "type")] + pub index_type: String, +} + +/// Pre-aggregation definition for query routing. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PreAggregation { + pub name: String, + #[serde(default, rename = "type")] + pub preagg_type: PreAggregationType, + #[serde(default)] + pub measures: Option>, + #[serde(default)] + pub dimensions: Option>, + #[serde(default)] + pub time_dimension: Option, + #[serde(default)] + pub granularity: Option, + #[serde(default)] + pub partition_granularity: Option, + #[serde(default)] + pub build_range_start: Option, + #[serde(default)] + pub build_range_end: Option, + #[serde(default = "default_true")] + pub scheduled_refresh: bool, + #[serde(default)] + pub refresh_key: Option, + #[serde(default)] + pub indexes: Option>, +} + +impl PreAggregation { + /// Returns pre-aggregation table name in [database.][schema.]model_preagg_name form. + pub fn table_name( + &self, + model_name: &str, + database: Option<&str>, + schema: Option<&str>, + ) -> String { + let mut table_name = format!("{model_name}_preagg_{}", self.name); + if let Some(schema) = schema { + table_name = format!("{schema}.{table_name}"); + } + if let Some(database) = database { + table_name = format!("{database}.{table_name}"); + } + table_name + } } /// A model represents a table or view with semantic definitions @@ -529,8 +815,20 @@ pub struct Model { pub table: Option, /// SQL expression for derived tables pub sql: Option, + /// Remote source URI + #[serde(default)] + pub source_uri: Option, + /// Parent model name for inheritance + #[serde(default)] + pub extends: Option, /// Primary key column pub primary_key: String, + /// Primary key columns for composite keys + #[serde(default)] + pub primary_key_columns: Vec, + /// Unique key constraints + #[serde(default)] + pub unique_keys: Option>>, /// Dimensions (grouping attributes) #[serde(default)] pub dimensions: Vec, @@ -543,6 +841,15 @@ pub struct Model { /// Segments (reusable filters) #[serde(default)] pub segments: Vec, + /// Pre-aggregations for query routing + #[serde(default)] + pub pre_aggregations: Vec, + /// Default time dimension auto-included when querying this model's metrics + #[serde(default)] + pub default_time_dimension: Option, + /// Default grain used with default_time_dimension + #[serde(default)] + pub default_grain: Option, /// Human-readable label pub label: Option, /// Description @@ -551,20 +858,36 @@ pub struct Model { impl Model { pub fn new(name: impl Into, primary_key: impl Into) -> Self { + let primary_key = primary_key.into(); Self { name: name.into(), table: None, sql: None, - primary_key: primary_key.into(), + source_uri: None, + extends: None, + primary_key: primary_key.clone(), + primary_key_columns: vec![primary_key], + unique_keys: None, dimensions: Vec::new(), metrics: Vec::new(), relationships: Vec::new(), segments: Vec::new(), + pre_aggregations: Vec::new(), + default_time_dimension: None, + default_grain: None, label: None, description: None, } } + pub fn with_primary_key_columns(mut self, primary_key_columns: Vec) -> Self { + if let Some(primary_key) = primary_key_columns.first() { + self.primary_key = primary_key.clone(); + } + self.primary_key_columns = primary_key_columns; + self + } + pub fn with_table(mut self, table: impl Into) -> Self { self.table = Some(table.into()); self @@ -595,11 +918,24 @@ impl Model { self } + pub fn with_pre_aggregation(mut self, pre_aggregation: PreAggregation) -> Self { + self.pre_aggregations.push(pre_aggregation); + self + } + /// Returns the table name or model name as fallback pub fn table_name(&self) -> &str { self.table.as_deref().unwrap_or(&self.name) } + pub fn primary_keys(&self) -> Vec { + if self.primary_key_columns.is_empty() { + vec![self.primary_key.clone()] + } else { + self.primary_key_columns.clone() + } + } + /// Returns the table source (table name or SQL subquery) pub fn table_source(&self) -> String { if let Some(sql) = &self.sql { @@ -628,6 +964,11 @@ impl Model { pub fn get_segment(&self, name: &str) -> Option<&Segment> { self.segments.iter().find(|s| s.name == name) } + + /// Find a pre-aggregation by name + pub fn get_pre_aggregation(&self, name: &str) -> Option<&PreAggregation> { + self.pre_aggregations.iter().find(|p| p.name == name) + } } #[cfg(test)] diff --git a/sidemantic-rs/src/core/parameter.rs b/sidemantic-rs/src/core/parameter.rs new file mode 100644 index 00000000..e1d19bfb --- /dev/null +++ b/sidemantic-rs/src/core/parameter.rs @@ -0,0 +1,30 @@ +//! Parameter definitions for dynamic query input. + +use serde::{Deserialize, Serialize}; + +/// Parameter type for query-time substitution. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ParameterType { + String, + Number, + Date, + Unquoted, + Yesno, +} + +/// Parameter metadata defined at graph level. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Parameter { + pub name: String, + #[serde(rename = "type")] + pub parameter_type: ParameterType, + pub description: Option, + pub label: Option, + #[serde(default)] + pub default_value: Option, + #[serde(default)] + pub allowed_values: Option>, + #[serde(default)] + pub default_to_today: bool, +} diff --git a/sidemantic-rs/src/core/symmetric_agg.rs b/sidemantic-rs/src/core/symmetric_agg.rs index 9ed09ccb..94d425ce 100644 --- a/sidemantic-rs/src/core/symmetric_agg.rs +++ b/sidemantic-rs/src/core/symmetric_agg.rs @@ -48,6 +48,8 @@ pub enum SymmetricAggType { Avg, Count, CountDistinct, + Min, + Max, } /// Build SQL for symmetric aggregate to prevent double-counting in fan-out joins. @@ -73,11 +75,28 @@ pub fn build_symmetric_aggregate_sql( model_alias: Option<&str>, dialect: SqlDialect, ) -> String { - // Add table prefix if provided - let pk_col = match model_alias { + let primary_key_expr = match model_alias { Some(alias) => format!("{alias}.{primary_key}"), None => primary_key.to_string(), }; + build_symmetric_aggregate_sql_with_key_expr( + measure_expr, + &primary_key_expr, + agg_type, + model_alias, + dialect, + ) +} + +/// Build SQL for symmetric aggregate when the deduplication key is already a SQL expression. +pub fn build_symmetric_aggregate_sql_with_key_expr( + measure_expr: &str, + primary_key_expr: &str, + agg_type: SymmetricAggType, + model_alias: Option<&str>, + dialect: SqlDialect, +) -> String { + let pk_col = primary_key_expr.to_string(); let measure_col = match model_alias { Some(alias) => format!("{alias}.{measure_expr}"), None => measure_expr.to_string(), @@ -87,43 +106,54 @@ pub fn build_symmetric_aggregate_sql( let (hash_expr, multiplier) = match dialect { SqlDialect::BigQuery => ( format!("FARM_FINGERPRINT(CAST({pk_col} AS STRING))"), - "1048576".to_string(), // 2^20 + "1000000000000".to_string(), ), SqlDialect::Postgres => ( - format!("hashtext({pk_col}::text)::bigint"), - "1024".to_string(), // 2^10 (smaller to avoid overflow) + format!("hashtext({pk_col}::text)::numeric"), + "1000000000000".to_string(), ), SqlDialect::Snowflake => ( - format!("(HASH({pk_col}) % 1000000000)"), // Modulo to constrain range - "100".to_string(), // Small multiplier + format!("HASH({pk_col})::NUMBER(38, 0)"), + "1000000000000".to_string(), ), SqlDialect::ClickHouse => ( format!("halfMD5(CAST({pk_col} AS String))"), - "1048576".to_string(), + "1000000000000".to_string(), ), SqlDialect::Databricks | SqlDialect::Spark => ( format!("xxhash64(CAST({pk_col} AS STRING))"), - "1048576".to_string(), + "1000000000000".to_string(), ), SqlDialect::DuckDB => ( format!("HASH({pk_col})::HUGEINT"), - "(1::HUGEINT << 20)".to_string(), + "(1::HUGEINT << 40)".to_string(), ), }; + let symmetric_base_expr = match dialect { + SqlDialect::DuckDB => { + format!("CAST(({hash_expr} * {multiplier}) AS DECIMAL(38, 6))") + } + _ => format!("({hash_expr} * {multiplier})"), + }; + let symmetric_measure_expr = match dialect { + SqlDialect::DuckDB => format!("CAST({measure_col} AS DECIMAL(38, 6))"), + _ => measure_col.clone(), + }; + match agg_type { SymmetricAggType::Sum => { // SUM(DISTINCT HASH(pk) * multiplier + value) - SUM(DISTINCT HASH(pk) * multiplier) format!( - "(SUM(DISTINCT ({hash_expr} * {multiplier}) + {measure_col}) - \ - SUM(DISTINCT ({hash_expr} * {multiplier})))" + "(SUM(DISTINCT {symmetric_base_expr} + {symmetric_measure_expr}) - \ + SUM(DISTINCT {symmetric_base_expr}))" ) } SymmetricAggType::Avg => { // Sum divided by distinct count let sum_expr = format!( - "(SUM(DISTINCT ({hash_expr} * {multiplier}) + {measure_col}) - \ - SUM(DISTINCT ({hash_expr} * {multiplier})))" + "(SUM(DISTINCT {symmetric_base_expr} + {symmetric_measure_expr}) - \ + SUM(DISTINCT {symmetric_base_expr}))" ); format!("{sum_expr} / NULLIF(COUNT(DISTINCT {pk_col}), 0)") } @@ -135,6 +165,8 @@ pub fn build_symmetric_aggregate_sql( // Count distinct on the measure itself - no symmetric aggregate needed format!("COUNT(DISTINCT {measure_col})") } + SymmetricAggType::Min => format!("MIN({measure_col})"), + SymmetricAggType::Max => format!("MAX({measure_col})"), } } @@ -163,7 +195,8 @@ mod tests { ); assert!(sql.contains("SUM(DISTINCT")); assert!(sql.contains("HASH(order_id)::HUGEINT")); - assert!(sql.contains("+ amount")); + assert!(sql.contains("+ CAST(amount AS DECIMAL(38, 6))")); + assert!(sql.contains("CAST((HASH(order_id)::HUGEINT * (1::HUGEINT << 40))")); } #[test] @@ -215,4 +248,41 @@ mod tests { ); assert!(sql.contains("hashtext")); } + + #[test] + fn test_min_passthrough() { + let sql = build_symmetric_aggregate_sql( + "amount", + "order_id", + SymmetricAggType::Min, + None, + SqlDialect::DuckDB, + ); + assert_eq!(sql, "MIN(amount)"); + } + + #[test] + fn test_max_passthrough_with_alias() { + let sql = build_symmetric_aggregate_sql( + "amount", + "order_id", + SymmetricAggType::Max, + Some("orders_cte"), + SqlDialect::DuckDB, + ); + assert_eq!(sql, "MAX(orders_cte.amount)"); + } + + #[test] + fn test_symmetric_sum_with_key_expression() { + let sql = build_symmetric_aggregate_sql_with_key_expr( + "amount", + "CONCAT(CAST(o.order_id AS VARCHAR), '|', CAST(o.item_id AS VARCHAR))", + SymmetricAggType::Sum, + Some("o"), + SqlDialect::DuckDB, + ); + assert!(sql.contains("HASH(CONCAT(")); + assert!(sql.contains("CAST(o.amount AS DECIMAL(38, 6))")); + } } diff --git a/sidemantic-rs/src/db/adbc.rs b/sidemantic-rs/src/db/adbc.rs new file mode 100644 index 00000000..d2865e04 --- /dev/null +++ b/sidemantic-rs/src/db/adbc.rs @@ -0,0 +1,594 @@ +use adbc_core::{ + options::{AdbcVersion, OptionConnection, OptionDatabase, OptionValue}, + Connection, Database, Driver, Statement, LOAD_FLAG_DEFAULT, +}; +use adbc_driver_manager::ManagedDriver; +use arrow_array::{ + Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, + LargeStringArray, RecordBatchReader, StringArray, Time32MillisecondArray, Time32SecondArray, + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, + UInt32Array, UInt64Array, UInt8Array, +}; +use arrow_ipc::writer::StreamWriter; +use arrow_schema::{DataType, TimeUnit}; +use std::io::Write; + +use crate::error::{Result, SidemanticError}; + +#[derive(Debug, Clone, PartialEq)] +pub enum AdbcValue { + Null, + Bool(bool), + I64(i64), + U64(u64), + F64(f64), + String(String), + Bytes(Vec), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AdbcExecutionResult { + pub columns: Vec, + pub rows: Vec>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AdbcArrowIpcResult { + pub bytes: Vec, + pub row_count: usize, +} + +#[derive(Debug, Clone)] +pub struct AdbcExecutionRequest { + pub driver: String, + pub sql: String, + pub uri: Option, + pub entrypoint: Option, + pub database_options: Vec<(OptionDatabase, OptionValue)>, + pub connection_options: Vec<(OptionConnection, OptionValue)>, +} + +fn adbc_error(context: &str, err: impl std::fmt::Display) -> SidemanticError { + SidemanticError::InvalidConfig(format!("{context}: {err}")) +} + +fn is_duckdb_driver(driver: &str, entrypoint: Option<&str>) -> bool { + driver.to_ascii_lowercase().contains("duckdb") + || entrypoint + .map(|entrypoint| entrypoint.to_ascii_lowercase().contains("duckdb")) + .unwrap_or(false) +} + +fn has_database_option(options: &[(OptionDatabase, OptionValue)], key: &str) -> bool { + options.iter().any(|(option, _)| option.as_ref() == key) +} + +fn database_options_with_uri( + driver: &str, + entrypoint: Option<&str>, + uri: Option, + mut database_options: Vec<(OptionDatabase, OptionValue)>, +) -> Vec<(OptionDatabase, OptionValue)> { + let Some(uri) = uri else { + return database_options; + }; + + if is_duckdb_driver(driver, entrypoint) { + if !has_database_option(&database_options, "path") { + database_options.push(( + OptionDatabase::Other("path".to_string()), + OptionValue::String(uri), + )); + } + } else if !has_database_option(&database_options, OptionDatabase::Uri.as_ref()) { + database_options.push((OptionDatabase::Uri, OptionValue::String(uri))); + } + + database_options +} + +pub fn execute_with_adbc(request: AdbcExecutionRequest) -> Result { + let AdbcExecutionRequest { + driver, + sql, + uri, + entrypoint, + database_options, + connection_options, + } = request; + + let database_options = + database_options_with_uri(&driver, entrypoint.as_deref(), uri, database_options); + let entrypoint_bytes = entrypoint.as_deref().map(str::as_bytes); + let mut managed_driver = ManagedDriver::load_from_name( + &driver, + entrypoint_bytes, + AdbcVersion::V110, + LOAD_FLAG_DEFAULT, + None, + ) + .or_else(|_| { + ManagedDriver::load_from_name( + &driver, + entrypoint_bytes, + AdbcVersion::V100, + LOAD_FLAG_DEFAULT, + None, + ) + }) + .map_err(|e| adbc_error("failed to load ADBC driver", e))?; + + let database = if database_options.is_empty() { + managed_driver.new_database() + } else { + managed_driver.new_database_with_opts(database_options) + } + .map_err(|e| adbc_error("failed to create ADBC database", e))?; + + let mut connection = if connection_options.is_empty() { + database.new_connection() + } else { + database.new_connection_with_opts(connection_options) + } + .map_err(|e| adbc_error("failed to create ADBC connection", e))?; + + let mut statement = connection + .new_statement() + .map_err(|e| adbc_error("failed to create ADBC statement", e))?; + statement + .set_sql_query(&sql) + .map_err(|e| adbc_error("failed to set SQL query", e))?; + let mut reader = statement + .execute() + .map_err(|e| adbc_error("failed to execute SQL query", e))?; + + let fields = reader.schema().fields().clone(); + let columns = fields + .iter() + .map(|field| field.name().to_string()) + .collect(); + + let mut rows: Vec> = Vec::new(); + for batch in &mut reader { + let batch = batch.map_err(|e| adbc_error("failed reading Arrow batch", e))?; + for row_index in 0..batch.num_rows() { + let mut values: Vec = Vec::with_capacity(batch.num_columns()); + for col_index in 0..batch.num_columns() { + let field = &fields[col_index]; + let array = batch.column(col_index); + values.push(array_cell_to_value( + array.as_ref(), + field.data_type(), + row_index, + )?); + } + rows.push(values); + } + } + + Ok(AdbcExecutionResult { columns, rows }) +} + +pub fn execute_with_adbc_arrow_ipc(request: AdbcExecutionRequest) -> Result { + let mut bytes = Vec::new(); + let row_count = write_adbc_arrow_ipc(request, &mut bytes)?; + Ok(AdbcArrowIpcResult { bytes, row_count }) +} + +pub fn write_adbc_arrow_ipc(request: AdbcExecutionRequest, writer: W) -> Result { + let AdbcExecutionRequest { + driver, + sql, + uri, + entrypoint, + database_options, + connection_options, + } = request; + + let database_options = + database_options_with_uri(&driver, entrypoint.as_deref(), uri, database_options); + let entrypoint_bytes = entrypoint.as_deref().map(str::as_bytes); + let mut managed_driver = ManagedDriver::load_from_name( + &driver, + entrypoint_bytes, + AdbcVersion::V110, + LOAD_FLAG_DEFAULT, + None, + ) + .or_else(|_| { + ManagedDriver::load_from_name( + &driver, + entrypoint_bytes, + AdbcVersion::V100, + LOAD_FLAG_DEFAULT, + None, + ) + }) + .map_err(|e| adbc_error("failed to load ADBC driver", e))?; + + let database = if database_options.is_empty() { + managed_driver.new_database() + } else { + managed_driver.new_database_with_opts(database_options) + } + .map_err(|e| adbc_error("failed to create ADBC database", e))?; + + let mut connection = if connection_options.is_empty() { + database.new_connection() + } else { + database.new_connection_with_opts(connection_options) + } + .map_err(|e| adbc_error("failed to create ADBC connection", e))?; + + let mut statement = connection + .new_statement() + .map_err(|e| adbc_error("failed to create ADBC statement", e))?; + statement + .set_sql_query(&sql) + .map_err(|e| adbc_error("failed to set SQL query", e))?; + let mut reader = statement + .execute() + .map_err(|e| adbc_error("failed to execute SQL query", e))?; + + let schema = reader.schema(); + let mut writer = StreamWriter::try_new(writer, &schema) + .map_err(|e| adbc_error("failed to create Arrow IPC writer", e))?; + let mut row_count = 0; + + for batch in &mut reader { + let batch = batch.map_err(|e| adbc_error("failed reading Arrow batch", e))?; + row_count += batch.num_rows(); + writer + .write(&batch) + .map_err(|e| adbc_error("failed writing Arrow IPC batch", e))?; + } + writer + .finish() + .map_err(|e| adbc_error("failed finishing Arrow IPC stream", e))?; + let _ = writer + .into_inner() + .map_err(|e| adbc_error("failed finishing Arrow IPC stream", e))?; + + Ok(row_count) +} + +fn decimal128_to_string(value: i128, scale: i8) -> String { + if scale <= 0 { + let multiplier = 10_i128.pow((-scale) as u32); + return (value * multiplier).to_string(); + } + + let negative = value < 0; + let digits = value.abs().to_string(); + let scale_usize = scale as usize; + + let rendered = if digits.len() <= scale_usize { + format!("0.{}{}", "0".repeat(scale_usize - digits.len()), digits) + } else { + let split = digits.len() - scale_usize; + format!("{}.{}", &digits[..split], &digits[split..]) + }; + + if negative { + format!("-{rendered}") + } else { + rendered + } +} + +fn downcast_error(ty: &str) -> SidemanticError { + SidemanticError::InvalidConfig(format!("failed to read {ty} column")) +} + +fn array_cell_to_value( + array: &dyn Array, + data_type: &DataType, + row_index: usize, +) -> Result { + if array.is_null(row_index) { + return Ok(AdbcValue::Null); + } + + match data_type { + DataType::Null => Ok(AdbcValue::Null), + DataType::Boolean => Ok(AdbcValue::Bool( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Boolean"))? + .value(row_index), + )), + DataType::Int8 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Int8"))? + .value(row_index) as i64, + )), + DataType::Int16 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Int16"))? + .value(row_index) as i64, + )), + DataType::Int32 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Int32"))? + .value(row_index) as i64, + )), + DataType::Int64 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Int64"))? + .value(row_index), + )), + DataType::UInt8 => Ok(AdbcValue::U64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("UInt8"))? + .value(row_index) as u64, + )), + DataType::UInt16 => Ok(AdbcValue::U64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("UInt16"))? + .value(row_index) as u64, + )), + DataType::UInt32 => Ok(AdbcValue::U64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("UInt32"))? + .value(row_index) as u64, + )), + DataType::UInt64 => Ok(AdbcValue::U64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("UInt64"))? + .value(row_index), + )), + DataType::Float16 => Ok(AdbcValue::F64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Float16"))? + .value(row_index) + .to_f32() as f64, + )), + DataType::Float32 => Ok(AdbcValue::F64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Float32"))? + .value(row_index) as f64, + )), + DataType::Float64 => Ok(AdbcValue::F64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Float64"))? + .value(row_index), + )), + DataType::Utf8 => Ok(AdbcValue::String( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Utf8"))? + .value(row_index) + .to_string(), + )), + DataType::LargeUtf8 => Ok(AdbcValue::String( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("LargeUtf8"))? + .value(row_index) + .to_string(), + )), + DataType::Binary => Ok(AdbcValue::Bytes( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Binary"))? + .value(row_index) + .to_vec(), + )), + DataType::LargeBinary => Ok(AdbcValue::Bytes( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("LargeBinary"))? + .value(row_index) + .to_vec(), + )), + DataType::Decimal128(_, scale) => { + let value = array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Decimal128"))? + .value(row_index); + Ok(AdbcValue::String(decimal128_to_string(value, *scale))) + } + DataType::Date32 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Date32"))? + .value(row_index) as i64, + )), + DataType::Date64 => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Date64"))? + .value(row_index), + )), + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Timestamp(second)"))? + .value(row_index), + )), + TimeUnit::Millisecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Timestamp(millisecond)"))? + .value(row_index), + )), + TimeUnit::Microsecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Timestamp(microsecond)"))? + .value(row_index), + )), + TimeUnit::Nanosecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Timestamp(nanosecond)"))? + .value(row_index), + )), + }, + DataType::Time32(unit) => match unit { + TimeUnit::Second => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Time32(second)"))? + .value(row_index) as i64, + )), + TimeUnit::Millisecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Time32(millisecond)"))? + .value(row_index) as i64, + )), + _ => Err(SidemanticError::InvalidConfig( + "unsupported Time32 unit in Rust ADBC executor".to_string(), + )), + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Time64(microsecond)"))? + .value(row_index), + )), + TimeUnit::Nanosecond => Ok(AdbcValue::I64( + array + .as_any() + .downcast_ref::() + .ok_or_else(|| downcast_error("Time64(nanosecond)"))? + .value(row_index), + )), + _ => Err(SidemanticError::InvalidConfig( + "unsupported Time64 unit in Rust ADBC executor".to_string(), + )), + }, + _ => Err(SidemanticError::InvalidConfig(format!( + "unsupported Arrow datatype in Rust ADBC executor: {data_type:?}" + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::DataType; + + fn assert_single_string_database_option( + options: &[(OptionDatabase, OptionValue)], + key: &str, + expected_value: &str, + ) { + assert_eq!(options.len(), 1); + assert_eq!(options[0].0.as_ref(), key); + let OptionValue::String(actual_value) = &options[0].1 else { + panic!("expected string option value, got {:?}", options[0].1); + }; + assert_eq!(actual_value, expected_value); + } + + #[test] + fn test_duckdb_uri_maps_to_path_database_option() { + let options = database_options_with_uri( + "/tmp/libduckdb.so", + Some("duckdb_adbc_init"), + Some("/tmp/warehouse.duckdb".to_string()), + Vec::new(), + ); + + assert_single_string_database_option(&options, "path", "/tmp/warehouse.duckdb"); + } + + #[test] + fn test_duckdb_uri_preserves_explicit_path_option() { + let options = database_options_with_uri( + "adbc_driver_duckdb", + None, + Some("/tmp/ignored.duckdb".to_string()), + vec![( + OptionDatabase::Other("path".to_string()), + OptionValue::String("/tmp/explicit.duckdb".to_string()), + )], + ); + + assert_single_string_database_option(&options, "path", "/tmp/explicit.duckdb"); + } + + #[test] + fn test_non_duckdb_uri_uses_canonical_uri_option() { + let options = database_options_with_uri( + "adbc_driver_sqlite", + None, + Some(":memory:".to_string()), + Vec::new(), + ); + + assert_single_string_database_option(&options, OptionDatabase::Uri.as_ref(), ":memory:"); + } + + #[test] + fn test_decimal128_to_string() { + assert_eq!(decimal128_to_string(12345, 2), "123.45"); + assert_eq!(decimal128_to_string(-12345, 2), "-123.45"); + assert_eq!(decimal128_to_string(15, 4), "0.0015"); + } + + #[test] + fn test_array_cell_to_value_int32_and_null() { + let array = Int32Array::from(vec![Some(7), None]); + assert_eq!( + array_cell_to_value(&array, &DataType::Int32, 0).unwrap(), + AdbcValue::I64(7) + ); + assert_eq!( + array_cell_to_value(&array, &DataType::Int32, 1).unwrap(), + AdbcValue::Null + ); + } + + #[test] + fn test_array_cell_to_value_binary() { + let array = BinaryArray::from(vec![Some(b"abc".as_slice())]); + assert_eq!( + array_cell_to_value(&array, &DataType::Binary, 0).unwrap(), + AdbcValue::Bytes(b"abc".to_vec()) + ); + } +} diff --git a/sidemantic-rs/src/db/mod.rs b/sidemantic-rs/src/db/mod.rs new file mode 100644 index 00000000..96904828 --- /dev/null +++ b/sidemantic-rs/src/db/mod.rs @@ -0,0 +1,8 @@ +#[cfg(feature = "adbc-exec")] +mod adbc; + +#[cfg(feature = "adbc-exec")] +pub use adbc::{ + execute_with_adbc, execute_with_adbc_arrow_ipc, write_adbc_arrow_ipc, AdbcArrowIpcResult, + AdbcExecutionRequest, AdbcExecutionResult, AdbcValue, +}; diff --git a/sidemantic-rs/src/ffi.rs b/sidemantic-rs/src/ffi.rs index fbef8fa7..68713137 100644 --- a/sidemantic-rs/src/ffi.rs +++ b/sidemantic-rs/src/ffi.rs @@ -6,6 +6,7 @@ //! Callers must ensure pointers are valid. Documented in header. #![allow(clippy::not_unsafe_ptr_arg_deref)] +use std::collections::{HashMap, HashSet}; use std::ffi::{CStr, CString}; use std::fs::{self, OpenOptions}; use std::io::Write; @@ -20,11 +21,17 @@ use crate::config::{load_from_directory, load_from_file, load_from_string, parse use crate::core::SemanticGraph; use crate::sql::QueryRewriter; -/// Global semantic graph state (thread-safe) -static SEMANTIC_GRAPH: Lazy> = Lazy::new(|| Mutex::new(SemanticGraph::new())); +const DEFAULT_CONTEXT_KEY: &str = "__sidemantic_default_context__"; -/// Active model for METRIC/DIMENSION/SEGMENT additions (set by CREATE MODEL or USE) -static ACTIVE_MODEL: Lazy>> = Lazy::new(|| Mutex::new(None)); +#[derive(Default)] +struct FfiState { + graph: SemanticGraph, + active_model: Option, +} + +/// Semantic graph state keyed by DuckDB database/session context. +static FFI_STATES: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); /// Result from rewrite operation #[repr(C)] @@ -37,29 +44,68 @@ pub struct SidemanticRewriteResult { pub was_rewritten: bool, } +fn c_string_arg(ptr: *const c_char, name: &str) -> std::result::Result { + if ptr.is_null() { + return Err(to_c_string(&format!("Error: null {name} pointer"))); + } + + unsafe { + CStr::from_ptr(ptr) + .to_str() + .map(str::to_string) + .map_err(|e| to_c_string(&format!("Error: invalid UTF-8: {e}"))) + } +} + +fn context_key(context: *const c_char) -> std::result::Result { + if context.is_null() { + return Ok(DEFAULT_CONTEXT_KEY.to_string()); + } + + let raw = unsafe { + CStr::from_ptr(context) + .to_str() + .map_err(|e| to_c_string(&format!("Error: invalid UTF-8: {e}")))? + }; + let trimmed = raw.trim(); + if trimmed.is_empty() { + Ok(DEFAULT_CONTEXT_KEY.to_string()) + } else { + Ok(trimmed.to_string()) + } +} + /// Load semantic models from YAML string /// /// Returns null on success, error message on failure. /// Caller must free the returned string with `sidemantic_free`. #[no_mangle] pub extern "C" fn sidemantic_load_yaml(yaml: *const c_char) -> *mut c_char { - if yaml.is_null() { - return to_c_string("Error: null yaml pointer"); - } + sidemantic_load_yaml_for_context(ptr::null(), yaml) +} - let yaml_str = unsafe { - match CStr::from_ptr(yaml).to_str() { - Ok(s) => s, - Err(e) => return to_c_string(&format!("Error: invalid UTF-8: {e}")), - } +/// Load semantic models from YAML string into a context-keyed graph. +#[no_mangle] +pub extern "C" fn sidemantic_load_yaml_for_context( + context: *const c_char, + yaml: *const c_char, +) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let yaml_str = match c_string_arg(yaml, "yaml") { + Ok(value) => value, + Err(error) => return error, }; - match load_from_string(yaml_str) { + match load_from_string(&yaml_str) { Ok(new_graph) => { - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); - // Merge new models into existing graph + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); + // Merge new models into existing graph, replacing same-name definitions. for model in new_graph.models() { - if let Err(e) = graph.add_model(model.clone()) { + if let Err(e) = state.graph.replace_model(model.clone()) { return to_c_string(&format!("Error adding model: {e}")); } } @@ -75,18 +121,25 @@ pub extern "C" fn sidemantic_load_yaml(yaml: *const c_char) -> *mut c_char { /// Caller must free the returned string with `sidemantic_free`. #[no_mangle] pub extern "C" fn sidemantic_load_file(path: *const c_char) -> *mut c_char { - if path.is_null() { - return to_c_string("Error: null path pointer"); - } + sidemantic_load_file_for_context(ptr::null(), path) +} - let path_str = unsafe { - match CStr::from_ptr(path).to_str() { - Ok(s) => s, - Err(e) => return to_c_string(&format!("Error: invalid UTF-8: {e}")), - } +/// Load semantic models from a file or directory into a context-keyed graph. +#[no_mangle] +pub extern "C" fn sidemantic_load_file_for_context( + context: *const c_char, + path: *const c_char, +) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let path_str = match c_string_arg(path, "path") { + Ok(value) => value, + Err(error) => return error, }; - let path = Path::new(path_str); + let path = Path::new(&path_str); // Check if path exists if !path.exists() { @@ -101,10 +154,11 @@ pub extern "C" fn sidemantic_load_file(path: *const c_char) -> *mut c_char { match result { Ok(new_graph) => { - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); - // Merge new models into existing graph + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); + // Merge new models into existing graph, replacing same-name definitions. for model in new_graph.models() { - if let Err(e) = graph.add_model(model.clone()) { + if let Err(e) = state.graph.replace_model(model.clone()) { return to_c_string(&format!("Error adding model: {e}")); } } @@ -117,8 +171,17 @@ pub extern "C" fn sidemantic_load_file(path: *const c_char) -> *mut c_char { /// Clear all loaded semantic models #[no_mangle] pub extern "C" fn sidemantic_clear() { - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); - *graph = SemanticGraph::new(); + sidemantic_clear_for_context(ptr::null()); +} + +/// Clear all loaded semantic models for one context. +#[no_mangle] +pub extern "C" fn sidemantic_clear_for_context(context: *const c_char) { + let Ok(key) = context_key(context) else { + return; + }; + let mut states = FFI_STATES.lock().unwrap(); + states.insert(key, FfiState::default()); } /// Define a semantic model from SQL definition format @@ -134,19 +197,28 @@ pub extern "C" fn sidemantic_define( db_path: *const c_char, replace: bool, ) -> *mut c_char { - if definition_sql.is_null() { - return to_c_string("Error: null definition_sql pointer"); - } + sidemantic_define_for_context(ptr::null(), definition_sql, db_path, replace) +} - let sql_str = unsafe { - match CStr::from_ptr(definition_sql).to_str() { - Ok(s) => s, - Err(e) => return to_c_string(&format!("Error: invalid UTF-8: {e}")), - } +/// Define a semantic model in a context-keyed graph. +#[no_mangle] +pub extern "C" fn sidemantic_define_for_context( + context: *const c_char, + definition_sql: *const c_char, + db_path: *const c_char, + replace: bool, +) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let sql_str = match c_string_arg(definition_sql, "definition_sql") { + Ok(value) => value, + Err(error) => return error, }; // Parse the definition to validate and get model name - let model = match parse_sql_model(sql_str) { + let model = match parse_sql_model(&sql_str) { Ok(m) => m, Err(e) => return to_c_string(&format!("Error parsing definition: {e}")), }; @@ -158,51 +230,60 @@ pub extern "C" fn sidemantic_define( // Handle OR REPLACE: read existing file, remove model if exists if replace { - if let Err(e) = remove_model_from_file(&definitions_path, &model_name) { - return to_c_string(&format!("Error removing existing model: {e}")); + if let Some(definitions_path) = definitions_path.as_ref() { + if let Err(e) = remove_model_from_file(definitions_path, &model_name) { + return to_c_string(&format!("Error removing existing model: {e}")); + } } } // Append definition to file - if let Err(e) = append_definition_to_file(&definitions_path, sql_str) { - return to_c_string(&format!("Error writing to definitions file: {e}")); + if let Some(definitions_path) = definitions_path.as_ref() { + if let Err(e) = append_definition_to_file(definitions_path, &sql_str) { + return to_c_string(&format!("Error writing to definitions file: {e}")); + } } // Load model into current session - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); - if let Err(e) = graph.add_model(model) { + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); + let result = if replace { + state.graph.replace_model(model) + } else { + state.graph.add_model(model) + }; + if let Err(e) = result { return to_c_string(&format!("Error adding model to session: {e}")); } // Set this model as the active model for subsequent METRIC/DIMENSION additions - *ACTIVE_MODEL.lock().unwrap() = Some(model_name); + state.active_model = Some(model_name); ptr::null_mut() // Success } /// Get the definitions file path based on database path -fn get_definitions_path(db_path: *const c_char) -> PathBuf { +fn get_definitions_path(db_path: *const c_char) -> Option { if db_path.is_null() { - // In-memory database: use current directory - return PathBuf::from("./sidemantic_definitions.sql"); + return None; } let path_str = unsafe { match CStr::from_ptr(db_path).to_str() { - Ok(s) => s, - Err(_) => return PathBuf::from("./sidemantic_definitions.sql"), + Ok(s) => s.trim(), + Err(_) => return None, } }; if path_str.is_empty() || path_str == ":memory:" { - return PathBuf::from("./sidemantic_definitions.sql"); + return None; } // Replace .duckdb extension with .sidemantic.sql let db_path = Path::new(path_str); let stem = db_path.file_stem().unwrap_or_default(); let parent = db_path.parent().unwrap_or(Path::new(".")); - parent.join(format!("{}.sidemantic.sql", stem.to_string_lossy())) + Some(parent.join(format!("{}.sidemantic.sql", stem.to_string_lossy()))) } /// Remove a model definition from the file by name @@ -213,53 +294,287 @@ fn remove_model_from_file(path: &Path, model_name: &str) -> std::io::Result<()> let content = fs::read_to_string(path)?; let mut result = String::new(); - let mut skip_until_next_model = false; - let model_pattern = "MODEL".to_string(); - let name_pattern = format!("name {model_name}"); - let name_pattern_comma = format!("name {model_name},"); - - for line in content.lines() { - let line_trimmed = line.trim().to_uppercase(); - - // Check if this is a MODEL statement - if line_trimmed.starts_with(&model_pattern) { - // Check if this model has the name we're looking for - let line_lower = line.to_lowercase(); - if line_lower.contains(&name_pattern.to_lowercase()) - || line_lower.contains(&name_pattern_comma.to_lowercase()) - { - skip_until_next_model = true; - continue; + + let mut cursor = 0; + for (start, end) in model_definition_ranges(&content) { + result.push_str(&content[cursor..start]); + + let block = &content[start..end]; + let should_remove = parse_sql_model(block) + .map(|model| model.name == model_name) + .unwrap_or(false); + + if !should_remove { + result.push_str(block); + } + + cursor = end; + } + result.push_str(&content[cursor..]); + + fs::write(path, result.trim_end())?; + Ok(()) +} + +fn model_definition_ranges(content: &str) -> Vec<(usize, usize)> { + let mut starts = Vec::new(); + let content_upper = content.to_uppercase(); + let mut search_start = 0; + + while let Some(pos) = content_upper[search_start..].find("MODEL") { + let actual_pos = search_start + pos; + let is_start = + actual_pos == 0 || !content.as_bytes()[actual_pos - 1].is_ascii_alphanumeric(); + let is_followed_by_boundary = actual_pos + 5 >= content.len() + || matches!( + content.as_bytes()[actual_pos + 5], + b' ' | b'(' | b'\t' | b'\n' + ); + + if is_start && is_followed_by_boundary { + starts.push(actual_pos); + } + + search_start = actual_pos + 1; + } + + starts + .iter() + .enumerate() + .map(|(index, start)| { + let end = starts.get(index + 1).copied().unwrap_or(content.len()); + (*start, end) + }) + .collect() +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum DefinitionKind { + Metric, + Dimension, + Segment, +} + +fn definition_kind(sql: &str) -> Option { + if starts_with_definition_keyword(sql, "METRIC") { + Some(DefinitionKind::Metric) + } else if starts_with_definition_keyword(sql, "DIMENSION") { + Some(DefinitionKind::Dimension) + } else if starts_with_definition_keyword(sql, "SEGMENT") { + Some(DefinitionKind::Segment) + } else { + None + } +} + +fn starts_with_definition_keyword(sql: &str, keyword: &str) -> bool { + let trimmed = sql.trim_start(); + if trimmed.len() < keyword.len() { + return false; + } + + let prefix = &trimmed[..keyword.len()]; + if !prefix.eq_ignore_ascii_case(keyword) { + return false; + } + + trimmed[keyword.len()..] + .chars() + .next() + .map(|ch| ch.is_whitespace()) + .unwrap_or(true) +} + +fn statement_ranges(block: &str) -> Vec<(usize, usize)> { + let mut ranges = Vec::new(); + let mut start = None; + let mut in_single_quote = false; + let mut in_double_quote = false; + + for (idx, ch) in block.char_indices() { + if start.is_none() && !ch.is_whitespace() { + start = Some(idx); + } + + if in_single_quote { + if ch == '\'' { + in_single_quote = false; } - skip_until_next_model = false; + continue; + } + if in_double_quote { + if ch == '"' { + in_double_quote = false; + } + continue; } - // If we encounter another statement type, stop skipping - if skip_until_next_model - && (line_trimmed.starts_with("MODEL") - || line_trimmed.starts_with("--") - || line_trimmed.is_empty()) - { - if line_trimmed.starts_with("MODEL") - && !line.to_lowercase().contains(&name_pattern.to_lowercase()) - { - skip_until_next_model = false; - } else if line_trimmed.is_empty() || line_trimmed.starts_with("--") { - // Skip empty lines and comments between removed statements - continue; + match ch { + '\'' => in_single_quote = true, + '"' => in_double_quote = true, + ';' => { + if let Some(statement_start) = start.take() { + ranges.push((statement_start, idx + ch.len_utf8())); + } } + _ => {} + } + } + + if let Some(statement_start) = start { + if !block[statement_start..].trim().is_empty() { + ranges.push((statement_start, block.len())); } + } + + ranges +} + +fn persist_model_item_definition_to_file( + path: &Path, + model_name: &str, + kind: DefinitionKind, + item_names: &[String], + definition: &str, + is_replace: bool, +) -> std::io::Result<()> { + if !path.exists() { + return append_definition_to_file(path, definition); + } + + let content = fs::read_to_string(path)?; + let item_names: HashSet<&str> = item_names.iter().map(String::as_str).collect(); + let (_, adjusted_definition) = extract_model_prefix(definition.trim()); + let mut result = String::new(); + let mut cursor = 0; + let mut inserted = false; + + for (start, end) in model_definition_ranges(&content) { + result.push_str(&content[cursor..start]); + + let block = &content[start..end]; + let block_model_name = parse_sql_model(block).ok().map(|model| model.name); + let cleaned = if is_replace { + remove_item_definitions_from_block( + block, + block_model_name.as_deref(), + model_name, + kind, + &item_names, + ) + } else { + block.to_string() + }; - if !skip_until_next_model { - result.push_str(line); - result.push('\n'); + if block_model_name.as_deref() == Some(model_name) { + result.push_str(&insert_definition_at_block_end( + &cleaned, + &adjusted_definition, + )); + inserted = true; + } else { + result.push_str(&cleaned); } + + cursor = end; } + result.push_str(&content[cursor..]); fs::write(path, result.trim_end())?; + + if !inserted { + append_definition_to_file(path, definition)?; + } + Ok(()) } +fn remove_item_definitions_from_block( + block: &str, + block_model_name: Option<&str>, + target_model_name: &str, + kind: DefinitionKind, + item_names: &HashSet<&str>, +) -> String { + if item_names.is_empty() { + return block.to_string(); + } + + let mut result = String::new(); + let mut cursor = 0; + + for (start, end) in statement_ranges(block) { + result.push_str(&block[cursor..start]); + let statement = &block[start..end]; + if !should_remove_item_statement( + statement, + block_model_name, + target_model_name, + kind, + item_names, + ) { + result.push_str(statement); + } + cursor = end; + } + result.push_str(&block[cursor..]); + + result +} + +fn should_remove_item_statement( + statement: &str, + block_model_name: Option<&str>, + target_model_name: &str, + kind: DefinitionKind, + item_names: &HashSet<&str>, +) -> bool { + if definition_kind(statement) != Some(kind) { + return false; + } + + let (explicit_model, adjusted_statement) = extract_model_prefix(statement.trim()); + let belongs_to_target = explicit_model + .as_deref() + .map(|model| model == target_model_name) + .unwrap_or(block_model_name == Some(target_model_name)); + if !belongs_to_target { + return false; + } + + let dummy_sql = format!("MODEL (name {target_model_name}, table dummy);\n{adjusted_statement}"); + let Ok(parsed) = parse_sql_model(&dummy_sql) else { + return false; + }; + + match kind { + DefinitionKind::Metric => parsed + .metrics + .iter() + .any(|metric| item_names.contains(metric.name.as_str())), + DefinitionKind::Dimension => parsed + .dimensions + .iter() + .any(|dimension| item_names.contains(dimension.name.as_str())), + DefinitionKind::Segment => parsed + .segments + .iter() + .any(|segment| item_names.contains(segment.name.as_str())), + } +} + +fn insert_definition_at_block_end(block: &str, definition: &str) -> String { + let trimmed_len = block.trim_end().len(); + let (body, trailing) = block.split_at(trimmed_len); + let trimmed_definition = definition.trim(); + + if body.is_empty() { + return format!("{trimmed_definition}{trailing}"); + } + + format!("{body}\n\n{trimmed_definition}{trailing}") +} + /// Append a definition to the file fn append_definition_to_file(path: &Path, definition: &str) -> std::io::Result<()> { let mut file = OpenOptions::new().create(true).append(true).open(path)?; @@ -280,7 +595,22 @@ fn append_definition_to_file(path: &Path, definition: &str) -> std::io::Result<( /// Caller must free the returned string with `sidemantic_free`. #[no_mangle] pub extern "C" fn sidemantic_autoload(db_path: *const c_char) -> *mut c_char { - let definitions_path = get_definitions_path(db_path); + sidemantic_autoload_for_context(ptr::null(), db_path) +} + +/// Load persisted definitions into a context-keyed graph if they exist. +#[no_mangle] +pub extern "C" fn sidemantic_autoload_for_context( + context: *const c_char, + db_path: *const c_char, +) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let Some(definitions_path) = get_definitions_path(db_path) else { + return ptr::null_mut(); + }; if !definitions_path.exists() { return ptr::null_mut(); // No file to load, success @@ -298,15 +628,18 @@ pub extern "C" fn sidemantic_autoload(db_path: *const c_char) -> *mut c_char { // Parse each model definition in the file // Split on MODEL keyword to handle multiple definitions - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); + let mut last_model_name = None; for block in split_definitions(&content) { if block.trim().is_empty() { continue; } match parse_sql_model(block) { Ok(model) => { - if let Err(e) = graph.add_model(model) { + last_model_name = Some(model.name.clone()); + if let Err(e) = state.graph.replace_model(model) { return to_c_string(&format!("Error loading model: {e}")); } } @@ -317,13 +650,15 @@ pub extern "C" fn sidemantic_autoload(db_path: *const c_char) -> *mut c_char { } } + state.active_model = last_model_name; + ptr::null_mut() // Success } /// Split content into individual model definitions fn split_definitions(content: &str) -> Vec<&str> { let mut definitions = Vec::new(); - let mut start = 0; + let mut start = None; // Find each MODEL keyword and split there let content_upper = content.to_uppercase(); @@ -342,17 +677,17 @@ fn split_definitions(content: &str) -> Vec<&str> { || content.as_bytes()[actual_pos + 5] == b'\n'); if is_start && is_followed_by_space { - if start < actual_pos && start > 0 { - definitions.push(&content[start..actual_pos]); + if let Some(previous_start) = start { + definitions.push(&content[previous_start..actual_pos]); } - start = actual_pos; + start = Some(actual_pos); } search_start = actual_pos + 1; } // Don't forget the last definition - if start < content.len() { + if let Some(start) = start { definitions.push(&content[start..]); } @@ -378,24 +713,34 @@ pub extern "C" fn sidemantic_add_definition( db_path: *const c_char, is_replace: bool, ) -> *mut c_char { - use crate::config::parse_sql_model; + sidemantic_add_definition_for_context(ptr::null(), definition_sql, db_path, is_replace) +} - if definition_sql.is_null() { - return to_c_string("Error: null definition_sql pointer"); - } +/// Add a metric/dimension/segment in a context-keyed graph. +#[no_mangle] +pub extern "C" fn sidemantic_add_definition_for_context( + context: *const c_char, + definition_sql: *const c_char, + db_path: *const c_char, + is_replace: bool, +) -> *mut c_char { + use crate::config::parse_sql_model; - let sql_str = unsafe { - match CStr::from_ptr(definition_sql).to_str() { - Ok(s) => s, - Err(e) => return to_c_string(&format!("Error: invalid UTF-8: {e}")), - } + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let sql_str = match c_string_arg(definition_sql, "definition_sql") { + Ok(value) => value, + Err(error) => return error, }; // Parse to determine what type it is and extract properties let sql_trimmed = sql_str.trim(); let sql_upper = sql_trimmed.to_uppercase(); - let mut graph = SEMANTIC_GRAPH.lock().unwrap(); + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); // Check for model.name syntax: "METRIC model.name (...)" or "DIMENSION model.name (...)" // Extract model name if present, otherwise use ACTIVE_MODEL @@ -403,18 +748,17 @@ pub extern "C" fn sidemantic_add_definition( let model_name = if let Some(explicit_model) = target_model_name { // Verify the model exists - if graph.get_model(&explicit_model).is_none() { + if state.graph.get_model(&explicit_model).is_none() { return to_c_string(&format!("Error: model '{explicit_model}' not found")); } explicit_model } else { // Use ACTIVE_MODEL or fall back to last model - let active = ACTIVE_MODEL.lock().unwrap(); - if let Some(ref name) = *active { + if let Some(ref name) = state.active_model { name.clone() } else { // Fall back to last model - let model_names: Vec = graph.models().map(|m| m.name.clone()).collect(); + let model_names: Vec = state.graph.models().map(|m| m.name.clone()).collect(); if model_names.is_empty() { return to_c_string("Error: no model defined yet. Create a model first with SEMANTIC CREATE MODEL, or use SEMANTIC USE ."); } @@ -423,7 +767,7 @@ pub extern "C" fn sidemantic_add_definition( }; // Get the model to modify - let model = match graph.get_model(&model_name) { + let model = match state.graph.get_model(&model_name) { Some(m) => m.clone(), None => return to_c_string(&format!("Error: could not find model '{model_name}'")), }; @@ -438,41 +782,63 @@ pub extern "C" fn sidemantic_add_definition( // Extract what was added and update the model let mut updated_model = model.clone(); + let mut persisted_kind = None; + let mut persisted_item_names = Vec::new(); + if sql_upper.starts_with("METRIC") { + persisted_kind = Some(DefinitionKind::Metric); for metric in parsed.metrics { if is_replace { // Remove existing metric with same name updated_model.metrics.retain(|m| m.name != metric.name); } + persisted_item_names.push(metric.name.clone()); updated_model.metrics.push(metric); } } else if sql_upper.starts_with("DIMENSION") { + persisted_kind = Some(DefinitionKind::Dimension); for dim in parsed.dimensions { if is_replace { // Remove existing dimension with same name updated_model.dimensions.retain(|d| d.name != dim.name); } + persisted_item_names.push(dim.name.clone()); updated_model.dimensions.push(dim); } } else if sql_upper.starts_with("SEGMENT") { + persisted_kind = Some(DefinitionKind::Segment); for seg in parsed.segments { if is_replace { // Remove existing segment with same name updated_model.segments.retain(|s| s.name != seg.name); } + persisted_item_names.push(seg.name.clone()); updated_model.segments.push(seg); } } - // add_model will overwrite since it uses HashMap::insert - if let Err(e) = graph.add_model(updated_model) { + if let Err(e) = state.graph.replace_model(updated_model) { return to_c_string(&format!("Error updating model: {e}")); } - // Append to definitions file - let definitions_path = get_definitions_path(db_path); - if let Err(e) = append_definition_to_file(&definitions_path, sql_str) { - return to_c_string(&format!("Error writing to definitions file: {e}")); + // Persist the definition with the owning model so autoload sees the same graph. + if let Some(definitions_path) = get_definitions_path(db_path) { + let result = if let Some(kind) = persisted_kind { + persist_model_item_definition_to_file( + &definitions_path, + &model_name, + kind, + &persisted_item_names, + &sql_str, + is_replace, + ) + } else { + append_definition_to_file(&definitions_path, &sql_str) + }; + + if let Err(e) = result { + return to_c_string(&format!("Error writing to definitions file: {e}")); + } } ptr::null_mut() // Success @@ -536,15 +902,22 @@ fn extract_model_prefix(sql: &str) -> (Option, String) { /// Returns null on success, error message on failure. #[no_mangle] pub extern "C" fn sidemantic_use(model_name: *const c_char) -> *mut c_char { - if model_name.is_null() { - return to_c_string("Error: null model_name pointer"); - } + sidemantic_use_for_context(ptr::null(), model_name) +} - let name_str = unsafe { - match CStr::from_ptr(model_name).to_str() { - Ok(s) => s, - Err(e) => return to_c_string(&format!("Error: invalid UTF-8: {e}")), - } +/// Set the active model for one context. +#[no_mangle] +pub extern "C" fn sidemantic_use_for_context( + context: *const c_char, + model_name: *const c_char, +) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let name_str = match c_string_arg(model_name, "model_name") { + Ok(value) => value, + Err(error) => return error, }; let name = name_str.trim(); @@ -553,9 +926,10 @@ pub extern "C" fn sidemantic_use(model_name: *const c_char) -> *mut c_char { } // Verify the model exists - let graph = SEMANTIC_GRAPH.lock().unwrap(); - if graph.get_model(name).is_none() { - let available: Vec<&str> = graph.models().map(|m| m.name.as_str()).collect(); + let mut states = FFI_STATES.lock().unwrap(); + let state = states.entry(key).or_default(); + if state.graph.get_model(name).is_none() { + let available: Vec<&str> = state.graph.models().map(|m| m.name.as_str()).collect(); return to_c_string(&format!( "Error: model '{}' not found. Available models: {}", name, @@ -566,10 +940,9 @@ pub extern "C" fn sidemantic_use(model_name: *const c_char) -> *mut c_char { } )); } - drop(graph); // Release lock before acquiring ACTIVE_MODEL lock // Set active model - *ACTIVE_MODEL.lock().unwrap() = Some(name.to_string()); + state.active_model = Some(name.to_string()); ptr::null_mut() // Success } @@ -577,19 +950,27 @@ pub extern "C" fn sidemantic_use(model_name: *const c_char) -> *mut c_char { /// Check if a table name is a registered semantic model #[no_mangle] pub extern "C" fn sidemantic_is_model(table_name: *const c_char) -> bool { - if table_name.is_null() { - return false; - } + sidemantic_is_model_for_context(ptr::null(), table_name) +} - let name = unsafe { - match CStr::from_ptr(table_name).to_str() { - Ok(s) => s, - Err(_) => return false, - } +/// Check if a table name is a registered semantic model in one context. +#[no_mangle] +pub extern "C" fn sidemantic_is_model_for_context( + context: *const c_char, + table_name: *const c_char, +) -> bool { + let Ok(key) = context_key(context) else { + return false; + }; + let Ok(name) = c_string_arg(table_name, "table_name") else { + return false; }; - let graph = SEMANTIC_GRAPH.lock().unwrap(); - graph.get_model(name).is_some() + let states = FFI_STATES.lock().unwrap(); + states + .get(&key) + .map(|state| state.graph.get_model(&name).is_some()) + .unwrap_or(false) } /// Get list of registered model names (comma-separated) @@ -597,8 +978,21 @@ pub extern "C" fn sidemantic_is_model(table_name: *const c_char) -> bool { /// Caller must free the returned string with `sidemantic_free`. #[no_mangle] pub extern "C" fn sidemantic_list_models() -> *mut c_char { - let graph = SEMANTIC_GRAPH.lock().unwrap(); - let names: Vec<&str> = graph.models().map(|m| m.name.as_str()).collect(); + sidemantic_list_models_for_context(ptr::null()) +} + +/// Get list of registered model names for one context. +#[no_mangle] +pub extern "C" fn sidemantic_list_models_for_context(context: *const c_char) -> *mut c_char { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => return error, + }; + let states = FFI_STATES.lock().unwrap(); + let names: Vec<&str> = states + .get(&key) + .map(|state| state.graph.models().map(|m| m.name.as_str()).collect()) + .unwrap_or_default(); to_c_string(&names.join(",")) } @@ -607,42 +1001,59 @@ pub extern "C" fn sidemantic_list_models() -> *mut c_char { /// Returns a SidemanticRewriteResult struct. Caller must free with `sidemantic_free_result`. #[no_mangle] pub extern "C" fn sidemantic_rewrite(sql: *const c_char) -> SidemanticRewriteResult { - if sql.is_null() { - return SidemanticRewriteResult { - sql: ptr::null_mut(), - error: to_c_string("Error: null sql pointer"), - was_rewritten: false, - }; - } + sidemantic_rewrite_for_context(ptr::null(), sql) +} - let sql_str = unsafe { - match CStr::from_ptr(sql).to_str() { - Ok(s) => s, - Err(e) => { - return SidemanticRewriteResult { - sql: ptr::null_mut(), - error: to_c_string(&format!("Error: invalid UTF-8: {e}")), - was_rewritten: false, - } +/// Rewrite a SQL query using semantic definitions from one context. +#[no_mangle] +pub extern "C" fn sidemantic_rewrite_for_context( + context: *const c_char, + sql: *const c_char, +) -> SidemanticRewriteResult { + let key = match context_key(context) { + Ok(key) => key, + Err(error) => { + return SidemanticRewriteResult { + sql: ptr::null_mut(), + error, + was_rewritten: false, + } + } + }; + + let sql_str = match c_string_arg(sql, "sql") { + Ok(value) => value, + Err(error) => { + return SidemanticRewriteResult { + sql: ptr::null_mut(), + error, + was_rewritten: false, } } }; - let graph = SEMANTIC_GRAPH.lock().unwrap(); + let states = FFI_STATES.lock().unwrap(); + let Some(state) = states.get(&key) else { + return SidemanticRewriteResult { + sql: to_c_string(&sql_str), + error: ptr::null_mut(), + was_rewritten: false, + }; + }; // Check if query references any semantic models - if !query_references_models(sql_str, &graph) { + if !query_references_models(&sql_str, &state.graph) { // Passthrough - not a semantic query return SidemanticRewriteResult { - sql: to_c_string(sql_str), + sql: to_c_string(&sql_str), error: ptr::null_mut(), was_rewritten: false, }; } // Rewrite the query - let rewriter = QueryRewriter::new(&graph); - match rewriter.rewrite(sql_str) { + let rewriter = QueryRewriter::new(&state.graph); + match rewriter.rewrite(&sql_str) { Ok(rewritten) => SidemanticRewriteResult { sql: to_c_string(&rewritten), error: ptr::null_mut(), @@ -707,9 +1118,56 @@ fn query_references_models(sql: &str, graph: &SemanticGraph) -> bool { mod tests { use super::*; use std::ffi::CString; + use std::sync::{Mutex, MutexGuard}; + use std::time::{SystemTime, UNIX_EPOCH}; + + static TEST_MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); + + fn test_lock() -> MutexGuard<'static, ()> { + TEST_MUTEX.lock().unwrap() + } + + fn assert_success(result: *mut c_char) { + if result.is_null() { + return; + } + + let message = unsafe { CStr::from_ptr(result).to_string_lossy().into_owned() }; + sidemantic_free(result); + panic!("{message}"); + } + + fn take_error(result: *mut c_char) -> String { + assert!(!result.is_null()); + let message = unsafe { CStr::from_ptr(result).to_string_lossy().into_owned() }; + sidemantic_free(result); + message + } + + fn take_rewrite_sql(result: SidemanticRewriteResult) -> String { + assert!(result.error.is_null()); + let sql = unsafe { CStr::from_ptr(result.sql).to_string_lossy().into_owned() }; + sidemantic_free_result(result); + sql + } + + fn unique_db_path(name: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + std::env::temp_dir().join(format!("sidemantic_{name}_{nanos}.duckdb")) + } + + fn remove_definitions_file(db_path: &CString) { + if let Some(definitions_path) = get_definitions_path(db_path.as_ptr()) { + let _ = fs::remove_file(definitions_path); + } + } #[test] fn test_load_and_rewrite() { + let _guard = test_lock(); // Clear any existing state sidemantic_clear(); @@ -753,6 +1211,7 @@ models: #[test] fn test_passthrough() { + let _guard = test_lock(); sidemantic_clear(); // Query without semantic models should pass through @@ -764,4 +1223,544 @@ models: sidemantic_free_result(result); } + + #[test] + fn test_define_replace_updates_in_memory_graph() { + let _guard = test_lock(); + sidemantic_clear(); + + let db_path = unique_db_path("define_replace"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path); + + let first = + CString::new("MODEL (name orders, table orders, primary_key order_id);").unwrap(); + assert_success(sidemantic_define(first.as_ptr(), db_path.as_ptr(), false)); + + let replacement = CString::new( + "MODEL (name orders, table orders_v2, primary_key order_id);\nMETRIC (name order_count, agg count);", + ) + .unwrap(); + assert_success(sidemantic_define( + replacement.as_ptr(), + db_path.as_ptr(), + true, + )); + + let result = sidemantic_rewrite( + CString::new("SELECT orders.order_count FROM orders") + .unwrap() + .as_ptr(), + ); + assert!(result.error.is_null()); + assert!(result.was_rewritten); + + let rewritten = unsafe { CStr::from_ptr(result.sql).to_string_lossy().into_owned() }; + assert!(rewritten.contains("orders_v2"), "{rewritten}"); + sidemantic_free_result(result); + + remove_definitions_file(&db_path); + } + + #[test] + fn test_add_definition_updates_existing_model() { + let _guard = test_lock(); + sidemantic_clear(); + + let db_path = unique_db_path("add_definition"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path); + + let model = + CString::new("MODEL (name orders, table orders, primary_key order_id);").unwrap(); + assert_success(sidemantic_define(model.as_ptr(), db_path.as_ptr(), false)); + + let metric = CString::new("METRIC (name revenue, agg sum, sql amount);").unwrap(); + assert_success(sidemantic_add_definition( + metric.as_ptr(), + db_path.as_ptr(), + false, + )); + + let result = sidemantic_rewrite( + CString::new("SELECT orders.revenue FROM orders") + .unwrap() + .as_ptr(), + ); + assert!(result.error.is_null()); + assert!(result.was_rewritten); + + let rewritten = unsafe { CStr::from_ptr(result.sql).to_string_lossy().into_owned() }; + assert!(rewritten.contains("SUM"), "{rewritten}"); + sidemantic_free_result(result); + + remove_definitions_file(&db_path); + } + + #[test] + fn test_replace_metric_dimension_and_segment_updates_persisted_definitions() { + let _guard = test_lock(); + sidemantic_clear(); + + let db_path = unique_db_path("replace_item_persistence"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + + let model = + CString::new("MODEL (name orders, table orders, primary_key order_id);").unwrap(); + assert_success(sidemantic_define(model.as_ptr(), db_path.as_ptr(), false)); + + let old_metric = CString::new("METRIC revenue AS SUM(gross_amount);").unwrap(); + assert_success(sidemantic_add_definition( + old_metric.as_ptr(), + db_path.as_ptr(), + false, + )); + let new_metric = CString::new("METRIC revenue AS SUM(net_amount);").unwrap(); + assert_success(sidemantic_add_definition( + new_metric.as_ptr(), + db_path.as_ptr(), + true, + )); + + let old_dimension = CString::new("DIMENSION status AS raw_status;").unwrap(); + assert_success(sidemantic_add_definition( + old_dimension.as_ptr(), + db_path.as_ptr(), + false, + )); + let new_dimension = CString::new("DIMENSION status AS clean_status;").unwrap(); + assert_success(sidemantic_add_definition( + new_dimension.as_ptr(), + db_path.as_ptr(), + true, + )); + + let old_segment = + CString::new("SEGMENT (name target_segment, sql old_flag = true);").unwrap(); + assert_success(sidemantic_add_definition( + old_segment.as_ptr(), + db_path.as_ptr(), + false, + )); + let new_segment = + CString::new("SEGMENT (name target_segment, sql new_flag = true);").unwrap(); + assert_success(sidemantic_add_definition( + new_segment.as_ptr(), + db_path.as_ptr(), + true, + )); + + let content = fs::read_to_string(&definitions_path).unwrap(); + assert!(!content.contains("gross_amount"), "{content}"); + assert!(!content.contains("raw_status"), "{content}"); + assert!(!content.contains("old_flag"), "{content}"); + assert!(content.contains("net_amount"), "{content}"); + assert!(content.contains("clean_status"), "{content}"); + assert!(content.contains("new_flag"), "{content}"); + assert_eq!(content.matches("METRIC revenue").count(), 1, "{content}"); + assert_eq!(content.matches("DIMENSION status").count(), 1, "{content}"); + assert_eq!( + content.matches("SEGMENT (name target_segment").count(), + 1, + "{content}" + ); + + let persisted_model = parse_sql_model(&content).unwrap(); + assert_eq!( + persisted_model.metrics[0].sql.as_deref(), + Some("net_amount") + ); + assert_eq!( + persisted_model.dimensions[0].sql.as_deref(), + Some("clean_status") + ); + assert_eq!(persisted_model.segments[0].sql, "new_flag = true"); + + sidemantic_clear(); + assert_success(sidemantic_autoload(db_path.as_ptr())); + + let metric_sql = take_rewrite_sql(sidemantic_rewrite( + CString::new("SELECT orders.revenue FROM orders") + .unwrap() + .as_ptr(), + )); + assert!(metric_sql.contains("net_amount"), "{metric_sql}"); + assert!(!metric_sql.contains("gross_amount"), "{metric_sql}"); + + let dimension_sql = take_rewrite_sql(sidemantic_rewrite( + CString::new("SELECT orders.status FROM orders") + .unwrap() + .as_ptr(), + )); + assert!(dimension_sql.contains("clean_status"), "{dimension_sql}"); + assert!(!dimension_sql.contains("raw_status"), "{dimension_sql}"); + + remove_definitions_file(&db_path); + } + + #[test] + fn test_prefixed_definition_persists_under_target_model_block() { + let _guard = test_lock(); + sidemantic_clear(); + + let db_path = unique_db_path("prefixed_item_persistence"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + + let orders = + CString::new("MODEL (name orders, table orders, primary_key order_id);").unwrap(); + let customers = + CString::new("MODEL (name customers, table customers, primary_key customer_id);") + .unwrap(); + assert_success(sidemantic_define(orders.as_ptr(), db_path.as_ptr(), false)); + assert_success(sidemantic_define( + customers.as_ptr(), + db_path.as_ptr(), + false, + )); + + let metric = CString::new("METRIC orders.revenue AS SUM(amount);").unwrap(); + assert_success(sidemantic_add_definition( + metric.as_ptr(), + db_path.as_ptr(), + false, + )); + + let content = fs::read_to_string(&definitions_path).unwrap(); + let blocks = split_definitions(&content); + let orders_model = blocks + .iter() + .find_map(|block| { + let model = parse_sql_model(block).ok()?; + (model.name == "orders").then_some(model) + }) + .unwrap(); + let customers_model = blocks + .iter() + .find_map(|block| { + let model = parse_sql_model(block).ok()?; + (model.name == "customers").then_some(model) + }) + .unwrap(); + + assert_eq!(orders_model.metrics.len(), 1); + assert_eq!(orders_model.metrics[0].name, "revenue"); + assert!(customers_model.metrics.is_empty()); + assert!(!content.contains("orders.revenue"), "{content}"); + + sidemantic_clear(); + assert_success(sidemantic_autoload(db_path.as_ptr())); + + let rewritten = take_rewrite_sql(sidemantic_rewrite( + CString::new("SELECT orders.revenue FROM orders") + .unwrap() + .as_ptr(), + )); + assert!(rewritten.contains("SUM"), "{rewritten}"); + assert!(rewritten.contains("amount"), "{rewritten}"); + + remove_definitions_file(&db_path); + } + + #[test] + fn test_clear_resets_active_model() { + let _guard = test_lock(); + sidemantic_clear(); + + let db_path = unique_db_path("clear_active"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path); + + let model = + CString::new("MODEL (name orders, table orders, primary_key order_id);").unwrap(); + assert_success(sidemantic_define(model.as_ptr(), db_path.as_ptr(), false)); + + sidemantic_clear(); + + let metric = CString::new("METRIC (name revenue, agg sum, sql amount);").unwrap(); + let error = take_error(sidemantic_add_definition( + metric.as_ptr(), + db_path.as_ptr(), + false, + )); + + assert!(error.contains("no model defined yet"), "{error}"); + + remove_definitions_file(&db_path); + } + + #[test] + fn test_context_keyed_state_isolates_models_and_active_model() { + let _guard = test_lock(); + + let context_a = CString::new("duckdb:a").unwrap(); + let context_b = CString::new("duckdb:b").unwrap(); + sidemantic_clear_for_context(context_a.as_ptr()); + sidemantic_clear_for_context(context_b.as_ptr()); + + let db_path_a = unique_db_path("context_a"); + let db_path_a = CString::new(db_path_a.to_string_lossy().to_string()).unwrap(); + let db_path_b = unique_db_path("context_b"); + let db_path_b = CString::new(db_path_b.to_string_lossy().to_string()).unwrap(); + remove_definitions_file(&db_path_a); + remove_definitions_file(&db_path_b); + + let model_a = + CString::new("MODEL (name orders, table orders_a, primary_key order_id);").unwrap(); + let model_b = + CString::new("MODEL (name orders, table orders_b, primary_key order_id);").unwrap(); + + assert_success(sidemantic_define_for_context( + context_a.as_ptr(), + model_a.as_ptr(), + db_path_a.as_ptr(), + false, + )); + assert_success(sidemantic_define_for_context( + context_b.as_ptr(), + model_b.as_ptr(), + db_path_b.as_ptr(), + false, + )); + + let metric_a = CString::new("METRIC (name revenue, agg sum, sql amount);").unwrap(); + let metric_b = CString::new("METRIC (name order_count, agg count);").unwrap(); + assert_success(sidemantic_add_definition_for_context( + context_a.as_ptr(), + metric_a.as_ptr(), + db_path_a.as_ptr(), + false, + )); + assert_success(sidemantic_add_definition_for_context( + context_b.as_ptr(), + metric_b.as_ptr(), + db_path_b.as_ptr(), + false, + )); + + let sql_a = CString::new("SELECT orders.revenue FROM orders").unwrap(); + let rewritten_a = take_rewrite_sql(sidemantic_rewrite_for_context( + context_a.as_ptr(), + sql_a.as_ptr(), + )); + assert!(rewritten_a.contains("orders_a"), "{rewritten_a}"); + assert!(rewritten_a.contains("SUM"), "{rewritten_a}"); + + let sql_b = CString::new("SELECT orders.order_count FROM orders").unwrap(); + let rewritten_b = take_rewrite_sql(sidemantic_rewrite_for_context( + context_b.as_ptr(), + sql_b.as_ptr(), + )); + assert!(rewritten_b.contains("orders_b"), "{rewritten_b}"); + assert!(rewritten_b.contains("COUNT"), "{rewritten_b}"); + + sidemantic_clear_for_context(context_a.as_ptr()); + + let passthrough = sidemantic_rewrite_for_context(context_a.as_ptr(), sql_a.as_ptr()); + assert!(passthrough.error.is_null()); + assert!(!passthrough.was_rewritten); + sidemantic_free_result(passthrough); + + let still_rewritten = sidemantic_rewrite_for_context(context_b.as_ptr(), sql_b.as_ptr()); + assert!(still_rewritten.error.is_null()); + assert!(still_rewritten.was_rewritten); + sidemantic_free_result(still_rewritten); + + remove_definitions_file(&db_path_a); + remove_definitions_file(&db_path_b); + } + + #[test] + fn test_load_yaml_replaces_same_name_model_in_context() { + let _guard = test_lock(); + + let context = CString::new("duckdb:replace-load").unwrap(); + sidemantic_clear_for_context(context.as_ptr()); + + let first = CString::new( + r#" +models: + - name: orders + table: orders_v1 + primary_key: order_id + metrics: + - name: order_count + agg: count +"#, + ) + .unwrap(); + assert_success(sidemantic_load_yaml_for_context( + context.as_ptr(), + first.as_ptr(), + )); + + let second = CString::new( + r#" +models: + - name: orders + table: orders_v2 + primary_key: order_id + metrics: + - name: order_count + agg: count +"#, + ) + .unwrap(); + assert_success(sidemantic_load_yaml_for_context( + context.as_ptr(), + second.as_ptr(), + )); + + let sql = CString::new("SELECT orders.order_count FROM orders").unwrap(); + let rewritten = take_rewrite_sql(sidemantic_rewrite_for_context( + context.as_ptr(), + sql.as_ptr(), + )); + assert!(rewritten.contains("orders_v2"), "{rewritten}"); + assert!(rewritten.contains("COUNT"), "{rewritten}"); + + sidemantic_clear_for_context(context.as_ptr()); + } + + #[test] + fn test_autoload_sets_active_model_to_last_loaded_model() { + let _guard = test_lock(); + + let context = CString::new("duckdb:autoload-active").unwrap(); + sidemantic_clear_for_context(context.as_ptr()); + + let db_path = unique_db_path("autoload_active"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + fs::write( + &definitions_path, + "MODEL (name orders, table orders, primary_key order_id);", + ) + .unwrap(); + + assert_success(sidemantic_autoload_for_context( + context.as_ptr(), + db_path.as_ptr(), + )); + + let metric = CString::new("METRIC (name revenue, agg sum, sql amount);").unwrap(); + assert_success(sidemantic_add_definition_for_context( + context.as_ptr(), + metric.as_ptr(), + db_path.as_ptr(), + false, + )); + + let sql = CString::new("SELECT orders.revenue FROM orders").unwrap(); + let rewritten = take_rewrite_sql(sidemantic_rewrite_for_context( + context.as_ptr(), + sql.as_ptr(), + )); + assert!(rewritten.contains("SUM"), "{rewritten}"); + + sidemantic_clear_for_context(context.as_ptr()); + let _ = fs::remove_file(definitions_path); + } + + #[test] + fn test_autoload_loads_all_persisted_models() { + let _guard = test_lock(); + + let context = CString::new("duckdb:autoload-all-models").unwrap(); + sidemantic_clear_for_context(context.as_ptr()); + + let db_path = unique_db_path("autoload_all_models"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + fs::write( + &definitions_path, + "MODEL (name orders, table orders, primary_key order_id);\nMETRIC revenue AS SUM(amount);\n\nMODEL (name customers, table customers, primary_key customer_id);\nMETRIC customer_count AS COUNT(*);", + ) + .unwrap(); + + assert_success(sidemantic_autoload_for_context( + context.as_ptr(), + db_path.as_ptr(), + )); + + let orders_sql = CString::new("SELECT orders.revenue FROM orders").unwrap(); + let orders_rewrite = take_rewrite_sql(sidemantic_rewrite_for_context( + context.as_ptr(), + orders_sql.as_ptr(), + )); + assert!(orders_rewrite.contains("amount"), "{orders_rewrite}"); + + let customers_sql = CString::new("SELECT customers.customer_count FROM customers").unwrap(); + let customers_rewrite = take_rewrite_sql(sidemantic_rewrite_for_context( + context.as_ptr(), + customers_sql.as_ptr(), + )); + assert!(customers_rewrite.contains("COUNT"), "{customers_rewrite}"); + + sidemantic_clear_for_context(context.as_ptr()); + let _ = fs::remove_file(definitions_path); + } + + #[test] + fn test_autoload_invalid_definition_is_best_effort() { + let _guard = test_lock(); + + let context = CString::new("duckdb:autoload-invalid-best-effort").unwrap(); + sidemantic_clear_for_context(context.as_ptr()); + + let db_path = unique_db_path("autoload_invalid_best_effort"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + fs::write( + &definitions_path, + "MODEL (name events, table events, primary_key event_id);\nMETRIC event_count AS COUNT(*);\n\nMODEL (", + ) + .unwrap(); + + assert_success(sidemantic_autoload_for_context( + context.as_ptr(), + db_path.as_ptr(), + )); + + let sql = CString::new("SELECT events.event_count FROM events").unwrap(); + let rewritten = take_rewrite_sql(sidemantic_rewrite_for_context( + context.as_ptr(), + sql.as_ptr(), + )); + assert!(rewritten.contains("COUNT"), "{rewritten}"); + + sidemantic_clear_for_context(context.as_ptr()); + let _ = fs::remove_file(definitions_path); + } + + #[test] + fn test_remove_model_from_file_handles_multiline_models() { + let _guard = test_lock(); + let db_path = unique_db_path("remove_multiline"); + let db_path = CString::new(db_path.to_string_lossy().to_string()).unwrap(); + let definitions_path = get_definitions_path(db_path.as_ptr()).unwrap(); + let content = r#" +MODEL ( + name orders, + table orders, + primary_key order_id +); +METRIC (name revenue, agg sum, sql amount); + +MODEL (name customers, table customers, primary_key customer_id); +"#; + fs::write(&definitions_path, content).unwrap(); + + remove_model_from_file(&definitions_path, "orders").unwrap(); + + let updated = fs::read_to_string(&definitions_path).unwrap(); + assert!(!updated.contains("name orders"), "{updated}"); + assert!(!updated.contains("name revenue"), "{updated}"); + assert!(updated.contains("name customers"), "{updated}"); + + let _ = fs::remove_file(definitions_path); + } } diff --git a/sidemantic-rs/src/lib.rs b/sidemantic-rs/src/lib.rs index 67fd168b..25b99c1d 100644 --- a/sidemantic-rs/src/lib.rs +++ b/sidemantic-rs/src/lib.rs @@ -36,17 +36,99 @@ pub mod config; pub mod core; +pub mod db; pub mod error; pub mod ffi; +#[cfg(feature = "python")] +mod python; +pub mod runtime; pub mod sql; +#[cfg(feature = "wasm")] +pub mod wasm; // Re-export commonly used types -pub use config::{load_from_directory, load_from_file, load_from_string}; +pub use config::{ + load_from_directory, load_from_directory_with_metadata, load_from_file, load_from_string, +}; pub use core::{ - build_symmetric_aggregate_sql, merge_model, resolve_model_inheritance, Aggregation, Dimension, - DimensionType, JoinPath, JoinStep, Metric, MetricType, Model, Relationship, RelationshipType, - RelativeDate, Segment, SemanticGraph, SqlDialect, SymmetricAggType, TableCalcType, - TableCalculation, + build_symmetric_aggregate_sql, merge_model, resolve_model_inheritance, Aggregation, + CohortInnerMetric, Dimension, DimensionType, JoinPath, JoinStep, Metric, MetricType, Model, + Parameter, ParameterType, Relationship, RelationshipType, RelativeDate, Segment, SemanticGraph, + SqlDialect, SymmetricAggType, TableCalcType, TableCalculation, }; pub use error::{Result, SidemanticError}; +pub use runtime::{ + analyze_migrator_query, build_preaggregation_refresh_statements, + calculate_preaggregation_benefit_score, chart_auto_detect_columns, chart_encoding_type, + chart_format_label, chart_select_type, compile_with_yaml_query, detect_adapter_kind, + dimension_sql_expr_with_yaml, dimension_with_granularity_with_yaml, + evaluate_table_calculation_expression, extract_column_references, + extract_metric_dependencies_from_yaml, extract_preaggregation_patterns, find_models_for_query, + find_relationship_path_with_yaml, format_parameter_value_with_yaml, + generate_catalog_metadata_with_yaml, generate_preaggregation_definition, + generate_preaggregation_materialization_sql_with_yaml, generate_preaggregation_name, + generate_time_comparison_sql, interpolate_sql_with_parameters_with_yaml, is_relative_date, + is_sql_template, load_graph_from_directory, load_graph_with_yaml, metric_is_simple_aggregation, + metric_sql_expr, metric_to_sql, model_find_dimension_index_with_yaml, + model_find_metric_index_with_yaml, model_find_pre_aggregation_index_with_yaml, + model_find_segment_index_with_yaml, model_get_drill_down_with_yaml, + model_get_drill_up_with_yaml, model_get_hierarchy_path_with_yaml, parse_reference_with_yaml, + parse_relative_date, parse_simple_metric_aggregation, parse_sql_definitions_payload, + parse_sql_graph_definitions_payload, parse_sql_model_payload, + plan_preaggregation_refresh_execution, recommend_preaggregation_patterns, + relationship_foreign_key_columns_with_yaml, relationship_primary_key_columns_with_yaml, + relationship_related_key_with_yaml, relationship_sql_expr_with_yaml, relative_date_to_range, + render_sql_template, resolve_metric_inheritance, resolve_model_inheritance_with_yaml, + resolve_preaggregation_refresh_mode, rewrite_with_yaml, segment_get_sql_with_yaml, + shape_preaggregation_refresh_result, summarize_preaggregation_patterns, + time_comparison_offset_interval, time_comparison_sql_offset, trailing_period_sql_interval, + validate_engine_refresh_sql_compatibility, validate_metric_payload, validate_model_payload, + validate_models_yaml, validate_parameter_payload, validate_preaggregation_refresh_request, + validate_query_references, validate_query_with_yaml, validate_table_calculation_payload, + validate_table_formula_expression, LoadedGraphPayload, PreaggregationRefreshExecutionPlan, + PreaggregationRefreshResultShape, QueryValidationContext, RelationshipPathError, + RelationshipPathStep, SidemanticRuntime, +}; pub use sql::{QueryRewriter, SemanticQuery, SqlGenerator}; +#[cfg(feature = "wasm")] +pub use wasm::{ + wasm_analyze_migrator_query, wasm_build_preaggregation_refresh_statements, + wasm_build_symmetric_aggregate_sql, wasm_calculate_preaggregation_benefit_score, + wasm_chart_auto_detect_columns, wasm_chart_encoding_type, wasm_chart_format_label, + wasm_chart_select_type, wasm_compile_with_yaml_query, wasm_detect_adapter_kind, + wasm_dimension_sql_expr_with_yaml, wasm_dimension_with_granularity_with_yaml, + wasm_evaluate_table_calculation_expression, wasm_extract_column_references, + wasm_extract_metric_dependencies_from_yaml, wasm_extract_preaggregation_patterns, + wasm_find_models_for_query, wasm_find_relationship_path_with_yaml, + wasm_format_parameter_value_with_yaml, wasm_generate_catalog_metadata_with_yaml, + wasm_generate_preaggregation_definition, + wasm_generate_preaggregation_materialization_sql_with_yaml, wasm_generate_preaggregation_name, + wasm_generate_time_comparison_sql, wasm_interpolate_sql_with_parameters_with_yaml, + wasm_is_relative_date, wasm_is_sql_template, wasm_load_graph_with_sql, + wasm_load_graph_with_yaml, wasm_metric_is_simple_aggregation, wasm_metric_sql_expr, + wasm_metric_to_sql, wasm_model_find_dimension_index_with_yaml, + wasm_model_find_metric_index_with_yaml, wasm_model_find_pre_aggregation_index_with_yaml, + wasm_model_find_segment_index_with_yaml, wasm_model_get_drill_down_with_yaml, + wasm_model_get_drill_up_with_yaml, wasm_model_get_hierarchy_path_with_yaml, + wasm_needs_symmetric_aggregate, wasm_parse_reference_with_yaml, wasm_parse_relative_date, + wasm_parse_simple_metric_aggregation, wasm_parse_sql_definitions_payload, + wasm_parse_sql_graph_definitions_payload, wasm_parse_sql_model_payload, + wasm_parse_sql_statement_blocks_payload, wasm_recommend_preaggregation_patterns, + wasm_relationship_foreign_key_columns_with_yaml, + wasm_relationship_primary_key_columns_with_yaml, wasm_relationship_related_key_with_yaml, + wasm_relationship_sql_expr_with_yaml, wasm_relative_date_to_range, wasm_render_sql_template, + wasm_resolve_metric_inheritance, wasm_resolve_model_inheritance_with_yaml, + wasm_rewrite_with_yaml, wasm_segment_get_sql_with_yaml, wasm_summarize_preaggregation_patterns, + wasm_time_comparison_offset_interval, wasm_time_comparison_sql_offset, + wasm_trailing_period_sql_interval, wasm_validate_engine_refresh_sql_compatibility, + wasm_validate_metric_payload, wasm_validate_model_payload, wasm_validate_models_yaml, + wasm_validate_parameter_payload, wasm_validate_query_references, + wasm_validate_query_references_with_yaml, wasm_validate_query_with_yaml, + wasm_validate_table_calculation_payload, wasm_validate_table_formula_expression, +}; + +#[cfg(feature = "adbc-exec")] +pub use db::{ + execute_with_adbc, execute_with_adbc_arrow_ipc, write_adbc_arrow_ipc, AdbcArrowIpcResult, + AdbcExecutionRequest, AdbcExecutionResult, AdbcValue, +}; diff --git a/sidemantic-rs/src/main.rs b/sidemantic-rs/src/main.rs index 820a5e4a..95e1accf 100644 --- a/sidemantic-rs/src/main.rs +++ b/sidemantic-rs/src/main.rs @@ -1,200 +1,2838 @@ -//! Example usage of sidemantic-rs +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::env; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process; +#[cfg(any( + feature = "mcp-server", + feature = "runtime-server", + feature = "runtime-lsp" +))] +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; -use sidemantic::sql::{QueryRewriter, SemanticQuery, SqlGenerator}; +#[cfg(feature = "adbc-exec")] +use adbc_core::options::{OptionConnection, OptionDatabase, OptionValue}; +use regex::Regex; +use serde::Deserialize; +#[cfg(feature = "adbc-exec")] +use serde::Serialize; +#[cfg(feature = "adbc-exec")] +use serde_json::Map as JsonMap; +#[cfg(feature = "adbc-exec")] +use serde_json::Value as JsonValue; use sidemantic::{ - load_from_string, Dimension, Metric, Model, Relationship, Segment, SemanticGraph, + build_preaggregation_refresh_statements, extract_preaggregation_patterns, + generate_preaggregation_definition, recommend_preaggregation_patterns, + summarize_preaggregation_patterns, SemanticQuery, SidemanticError, SidemanticRuntime, }; +#[cfg(feature = "adbc-exec")] +use sidemantic::{execute_with_adbc, AdbcExecutionRequest, AdbcValue}; + +#[cfg(feature = "workbench-tui")] +mod workbench; + +type CliResult = std::result::Result; +type ParsedOptions = HashMap>; +#[cfg(feature = "adbc-exec")] +pub(crate) type AdbcConnectionUrlParts = + (String, Option, Vec<(OptionDatabase, OptionValue)>); + +#[cfg(feature = "adbc-exec")] +#[derive(Debug, Default, Serialize)] +struct RunOutput { + sql: String, + columns: Vec, + rows: Vec, + row_count: usize, +} + +#[derive(Debug, Clone)] +struct RefreshPlan { + model_name: String, + preagg_name: String, + table_name: String, + mode: String, + statements: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct CliPreaggPattern { + model: String, + metrics: Vec, + dimensions: Vec, + granularities: Vec, + count: usize, +} + +#[derive(Debug, Clone, Deserialize)] +struct CliPreaggRecommendation { + pattern: CliPreaggPattern, + suggested_name: String, + query_count: usize, + estimated_benefit_score: f64, +} + +#[derive(Debug, Clone, Deserialize, Default)] +struct CliMigratorAnalysisPayload { + #[serde(default)] + column_references: Vec, + #[serde(default)] + group_by_columns: Vec<(String, String)>, +} + +#[derive(Debug, Clone)] +struct CliMigratorMetric { + name: String, + agg: String, + sql: String, +} + +#[derive(Debug, Clone, Default)] +struct CliGeneratedModel { + dimensions: BTreeSet, + metrics: Vec, +} + +#[cfg(feature = "adbc-exec")] +#[derive(Debug, Clone)] +struct AdbcCliConfig { + driver: String, + uri: Option, + entrypoint: Option, + database_options: Vec<(OptionDatabase, OptionValue)>, + connection_options: Vec<(OptionConnection, OptionValue)>, +} fn main() { - println!("=== Sidemantic-rs Demo ===\n"); + if let Err(err) = run() { + eprintln!("error: {err}"); + process::exit(1); + } +} - // Demo 1: Programmatic API - demo_programmatic_api(); - - // Demo 2: YAML Loading (native format) - demo_yaml_loading(); - - // Demo 3: Cube.js Format - demo_cube_format(); +fn run() -> CliResult<()> { + let args: Vec = env::args().skip(1).collect(); + if args.is_empty() { + print_help(); + return Ok(()); + } - // Demo 4: Segments - demo_segments(); + let command = args[0].as_str(); + let rest = &args[1..]; + match command { + "-h" | "--help" | "help" => { + print_help(); + Ok(()) + } + "compile" => compile_command(rest), + "rewrite" => rewrite_command(rest), + "validate" => validate_command(rest), + "migrator" => migrator_command(rest), + "info" => info_command(rest), + "query" => query_command(rest), + "run" => run_command(rest), + "preagg" => preagg_command(rest), + "workbench" => workbench_command(rest), + "tree" => tree_command(rest), + "mcp" => mcp_command(rest), + "mcp-serve" => mcp_command(rest), + "server" => server_command(rest), + "serve" => server_command(rest), + "lsp" => lsp_command(rest), + unknown => Err(format!( + "unknown command '{unknown}'. Use 'sidemantic --help' for usage." + )), + } +} - // Demo 5: Query Rewriter - demo_query_rewriter(); -} - -fn demo_programmatic_api() { - println!("--- 1. Programmatic API ---\n"); - - let mut graph = SemanticGraph::new(); - - let orders = Model::new("orders", "order_id") - .with_table("orders") - .with_dimension(Dimension::categorical("status")) - .with_dimension(Dimension::time("order_date").with_sql("created_at")) - .with_metric(Metric::sum("revenue", "amount")) - .with_metric(Metric::count("order_count")) - .with_metric(Metric::avg("avg_order_value", "amount")) - .with_relationship(Relationship::many_to_one("customers")); - - let customers = Model::new("customers", "id") - .with_table("customers") - .with_dimension(Dimension::categorical("name")) - .with_dimension(Dimension::categorical("country")); - - graph.add_model(orders).unwrap(); - graph.add_model(customers).unwrap(); - - let generator = SqlGenerator::new(&graph); - - let query = SemanticQuery::new() - .with_metrics(vec!["orders.revenue".into(), "orders.order_count".into()]) - .with_dimensions(vec!["orders.status".into()]); - - println!("Query: revenue and order_count by status"); - println!("{}\n", generator.generate(&query).unwrap()); -} - -fn demo_yaml_loading() { - println!("--- 2. YAML Loading (Native Format) ---\n"); - - let yaml = r#" -models: - - name: orders - table: orders - primary_key: order_id - dimensions: - - name: status - type: categorical - - name: order_date - type: time - sql: created_at - metrics: - - name: revenue - agg: sum - sql: amount - - name: order_count - agg: count -"#; - - let graph = load_from_string(yaml).unwrap(); - let generator = SqlGenerator::new(&graph); - - let query = SemanticQuery::new() - .with_metrics(vec!["orders.revenue".into()]) - .with_dimensions(vec!["orders.status".into()]); - - println!("Loaded from YAML:"); - println!("{}\n", generator.generate(&query).unwrap()); -} - -fn demo_cube_format() { - println!("--- 3. Cube.js Format ---\n"); - - let yaml = r#" -cubes: - - name: orders - sql_table: public.orders - - dimensions: - - name: status - sql: "${CUBE}.status" - type: string - - name: created_at - sql: "${CUBE}.created_at" - type: time - - measures: - - name: revenue - sql: "${CUBE}.amount" - type: sum - - name: order_count - type: count - - segments: - - name: completed - sql: "${CUBE}.status = 'completed'" -"#; - - let graph = load_from_string(yaml).unwrap(); - let model = graph.get_model("orders").unwrap(); - - println!("Converted from Cube.js format:"); - println!(" Table: {:?}", model.table); +fn print_help() { println!( - " Dimensions: {:?}", - model.dimensions.iter().map(|d| &d.name).collect::>() + "sidemantic (Rust CLI)\n\ + \n\ + Commands:\n\ + compile Compile semantic query to SQL\n\ + rewrite Rewrite SQL using semantic graph\n\ + validate Validate model/query references\n\ + migrator Analyze SQL coverage and bootstrap model files\n\ + info Show semantic layer model summary\n\ + query Rewrite SQL and optionally execute via ADBC\n\ + run Compile and execute query via ADBC\n\ + preagg Pre-aggregation helpers (materialize/recommend/refresh)\n\ + workbench Launch interactive workbench (ratatui)\n\ + mcp-serve Launch sidemantic-mcp passthrough\n\ + serve Launch sidemantic-server passthrough\n\ + lsp Launch sidemantic-lsp passthrough\n\ + \n\ + Examples:\n\ + sidemantic compile --models ./models --metric orders.revenue --dimension orders.status\n\ + sidemantic rewrite --models ./models --sql \"select orders.revenue from orders\"\n\ + sidemantic migrator --queries ./queries --generate-models ./out\n\ + sidemantic info --models ./models\n\ + sidemantic query --models ./models --sql \"select orders.revenue from orders\" --dry-run\n\ + sidemantic run --models ./models --metric orders.revenue --driver adbc_driver_duckdb --uri :memory:\n\ + sidemantic preagg refresh --models ./models --model orders --name daily_revenue --mode full\n\ + sidemantic serve --models ./models --bind 127.0.0.1:5544\n\ + \n\ + Use ' --help' for command-specific usage." ); - println!( - " Metrics: {:?}", - model.metrics.iter().map(|m| &m.name).collect::>() +} + +fn parse_options(args: &[String]) -> CliResult<(ParsedOptions, Vec)> { + let mut options: ParsedOptions = HashMap::new(); + let mut positionals = Vec::new(); + let mut index = 0usize; + + while index < args.len() { + let arg = &args[index]; + if arg.starts_with("--") { + if let Some((key, value)) = arg.split_once('=') { + options + .entry(key.to_string()) + .or_default() + .push(value.to_string()); + index += 1; + continue; + } + + let key = arg.to_string(); + if index + 1 < args.len() && !args[index + 1].starts_with("--") { + options + .entry(key) + .or_default() + .push(args[index + 1].clone()); + index += 2; + } else { + options.entry(key).or_default().push("true".to_string()); + index += 1; + } + } else { + positionals.push(arg.clone()); + index += 1; + } + } + + Ok((options, positionals)) +} + +fn option_values(options: &ParsedOptions, key: &str) -> Vec { + options.get(key).cloned().unwrap_or_default() +} + +fn option_value(options: &ParsedOptions, key: &str) -> Option { + options.get(key).and_then(|values| values.last().cloned()) +} + +fn option_value_any(options: &ParsedOptions, keys: &[&str]) -> Option { + for key in keys { + if let Some(value) = option_value(options, key) { + return Some(value); + } + } + None +} + +fn require_option(options: &ParsedOptions, key: &str) -> CliResult { + option_value(options, key).ok_or_else(|| format!("missing required option '{key}'")) +} + +fn option_flag(options: &ParsedOptions, key: &str) -> bool { + options.contains_key(key) +} + +fn option_usize(options: &ParsedOptions, key: &str) -> CliResult> { + match option_value(options, key) { + Some(value) => value + .parse::() + .map(Some) + .map_err(|_| format!("invalid usize value for {key}: {value}")), + None => Ok(None), + } +} + +fn option_f64(options: &ParsedOptions, key: &str) -> CliResult> { + match option_value(options, key) { + Some(value) => value + .parse::() + .map(Some) + .map_err(|_| format!("invalid f64 value for {key}: {value}")), + None => Ok(None), + } +} + +fn expect_no_positionals(positionals: &[String], context: &str) -> CliResult<()> { + if positionals.is_empty() { + Ok(()) + } else { + Err(format!( + "{context}: unexpected positional arguments: {}", + positionals.join(" ") + )) + } +} + +fn refresh_plan_error_message(err: SidemanticError) -> String { + match err { + SidemanticError::Validation(message) => message, + other => other.to_string(), + } +} + +#[cfg(feature = "adbc-exec")] +fn env_non_empty(key: &str) -> Option { + match env::var(key) { + Ok(value) if !value.trim().is_empty() => Some(value), + _ => None, + } +} + +fn load_runtime(models_path: &str) -> CliResult { + let path = PathBuf::from(models_path); + if path.is_dir() { + return SidemanticRuntime::from_directory(path) + .map_err(|e| format!("failed to load models directory '{models_path}': {e}")); + } + if path.is_file() { + return SidemanticRuntime::from_file(path) + .map_err(|e| format!("failed to load models file '{models_path}': {e}")); + } + Err(format!( + "models path '{models_path}' is not a readable file or directory" + )) +} + +fn build_query_from_options(options: &ParsedOptions) -> CliResult { + let metrics = option_values(options, "--metric"); + let dimensions = option_values(options, "--dimension"); + if metrics.is_empty() && dimensions.is_empty() { + return Err("query requires at least one --metric or --dimension".to_string()); + } + + let filters = option_values(options, "--filter"); + let segments = option_values(options, "--segment"); + let order_by = option_values(options, "--order-by"); + let limit = option_usize(options, "--limit")?; + let ungrouped = option_flag(options, "--ungrouped"); + let use_preaggregations = option_flag(options, "--use-preaggregations"); + let skip_default_time_dimensions = option_flag(options, "--skip-default-time-dimensions"); + let preagg_database = option_value(options, "--preagg-database"); + let preagg_schema = option_value(options, "--preagg-schema"); + + let mut query = SemanticQuery::new() + .with_metrics(metrics) + .with_dimensions(dimensions) + .with_filters(filters) + .with_segments(segments) + .with_order_by(order_by) + .with_ungrouped(ungrouped) + .with_use_preaggregations(use_preaggregations) + .with_skip_default_time_dimensions(skip_default_time_dimensions) + .with_preaggregation_qualifiers(preagg_database, preagg_schema); + + if let Some(limit) = limit { + query = query.with_limit(limit); + } + Ok(query) +} + +fn compile_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic compile --models [--metric ...] [--dimension ...] [--filter ...] [--segment ...] [--order-by ...] [--limit ] [--use-preaggregations] [--preagg-database ] [--preagg-schema ] [--skip-default-time-dimensions]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "compile")?; + let models = require_option(&options, "--models")?; + let runtime = load_runtime(&models)?; + let query = build_query_from_options(&options)?; + let sql = runtime + .compile(&query) + .map_err(|e| format!("failed to compile query: {e}"))?; + println!("{sql}"); + Ok(()) +} + +fn rewrite_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!("Usage: sidemantic rewrite --models (--sql | --sql-file )"); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "rewrite")?; + let models = require_option(&options, "--models")?; + let runtime = load_runtime(&models)?; + + let sql = if let Some(sql) = option_value(&options, "--sql") { + sql + } else if let Some(path) = option_value(&options, "--sql-file") { + fs::read_to_string(&path).map_err(|e| format!("failed to read SQL file '{path}': {e}"))? + } else { + return Err("rewrite requires --sql or --sql-file".to_string()); + }; + + let rewritten = runtime + .rewrite(&sql) + .map_err(|e| format!("failed to rewrite SQL: {e}"))?; + println!("{rewritten}"); + Ok(()) +} + +fn validate_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic validate [models_path] [--models ] [--verbose] [--metric ...] [--dimension ...]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + if positionals.len() > 1 { + return Err(format!( + "validate: unexpected positional arguments: {}", + positionals[1..].join(" ") + )); + } + + let models = option_value(&options, "--models") + .or_else(|| positionals.first().cloned()) + .unwrap_or_else(|| ".".to_string()); + let runtime = load_runtime(&models)?; + let metrics = option_values(&options, "--metric"); + let dimensions = option_values(&options, "--dimension"); + let verbose = option_flag(&options, "--verbose"); + + if !metrics.is_empty() || !dimensions.is_empty() { + let errors = runtime.validate_query_references(&metrics, &dimensions); + if errors.is_empty() { + println!("ok"); + return Ok(()); + } + + for err in errors { + eprintln!("{err}"); + } + return Err("query reference validation failed".to_string()); + } + + let models_list = runtime.graph().models().collect::>(); + if models_list.is_empty() { + return Err("No models found".to_string()); + } + + println!("Validation passed"); + if verbose { + println!("Models: {}", models_list.len()); + for model in &models_list { + println!( + "- {}: {} dimensions, {} metrics, {} relationships", + model.name, + model.dimensions.len(), + model.metrics.len(), + model.relationships.len() + ); + } + } else { + println!("ok"); + return Ok(()); + } + Ok(()) +} + +fn parse_queries_from_path(path: &str) -> CliResult> { + let source = PathBuf::from(path); + if source.is_file() { + let content = fs::read_to_string(&source) + .map_err(|e| format!("failed to read queries file '{}': {e}", source.display()))?; + let queries = split_queries(&content); + if queries.is_empty() { + return Err(format!("no SQL queries found in '{}'", source.display())); + } + return Ok(queries); + } + + if !source.is_dir() { + return Err(format!( + "queries path '{}' does not exist", + source.display() + )); + } + + let mut sql_files = Vec::new(); + let mut stack = vec![source.clone()]; + while let Some(dir) = stack.pop() { + let entries = fs::read_dir(&dir) + .map_err(|e| format!("failed to read queries directory '{}': {e}", dir.display()))?; + for entry in entries { + let entry = entry.map_err(|e| format!("failed to read query directory entry: {e}"))?; + let path = entry.path(); + if path.is_dir() { + stack.push(path); + continue; + } + if path + .extension() + .and_then(|value| value.to_str()) + .is_some_and(|ext| ext.eq_ignore_ascii_case("sql")) + { + sql_files.push(path); + } + } + } + sql_files.sort(); + + let mut queries = Vec::new(); + for file in sql_files { + let content = fs::read_to_string(&file) + .map_err(|e| format!("failed to read queries file '{}': {e}", file.display()))?; + queries.extend(split_queries(&content)); + } + if queries.is_empty() { + return Err(format!("no SQL queries found under '{}'", source.display())); + } + Ok(queries) +} + +fn split_column_reference(reference: &str) -> (String, String) { + let cleaned = reference + .trim() + .trim_matches('`') + .trim_matches('"') + .trim_matches('[') + .trim_matches(']'); + if let Some((table, column)) = cleaned.rsplit_once('.') { + return (table.trim().to_string(), column.trim().to_string()); + } + (String::new(), cleaned.to_string()) +} + +fn infer_query_models(payload: &CliMigratorAnalysisPayload) -> BTreeSet { + let mut models = BTreeSet::new(); + for reference in &payload.column_references { + let (table, _column) = split_column_reference(reference); + if !table.is_empty() { + models.insert(table); + } + } + for (table, _column) in &payload.group_by_columns { + if !table.is_empty() { + models.insert(table.clone()); + } + } + models +} + +fn is_time_dimension_name(name: &str) -> bool { + let lowered = name.to_ascii_lowercase(); + lowered.contains("date") + || lowered.contains("time") + || lowered.contains("timestamp") + || lowered == "created_at" + || lowered == "updated_at" +} + +fn normalize_migrator_agg_name(raw_agg: &str, raw_arg: &str) -> (String, String) { + let mut agg = raw_agg.to_ascii_lowercase(); + let mut arg = raw_arg.trim().to_string(); + if agg == "count" && arg.to_ascii_lowercase().starts_with("distinct ") { + agg = "count_distinct".to_string(); + arg = arg[8..].trim().to_string(); + } + if agg == "count" && arg.is_empty() { + arg = "*".to_string(); + } + (agg, arg) +} + +fn extract_aggregations_from_query(query: &str) -> Vec<(String, String, String)> { + let Ok(re) = Regex::new(r"(?i)\b(sum|avg|count|min|max|median)\s*\(([^)]*)\)") else { + return Vec::new(); + }; + let mut values = Vec::new(); + for capture in re.captures_iter(query) { + let Some(agg) = capture.get(1).map(|value| value.as_str()) else { + continue; + }; + let Some(arg) = capture.get(2).map(|value| value.as_str()) else { + continue; + }; + let (normalized_agg, normalized_arg) = normalize_migrator_agg_name(agg, arg); + let (table, column) = split_column_reference(&normalized_arg); + let effective_column = if column.is_empty() { + normalized_arg + } else { + column + }; + values.push((normalized_agg, effective_column, table)); + } + values +} + +fn build_metric_name(agg: &str, column: &str) -> String { + if agg == "count" && column == "*" { + "count".to_string() + } else if agg == "count" || agg == "count_distinct" { + format!("{}_count", column) + } else { + format!("{agg}_{column}") + } +} + +fn build_rewritten_query( + payload: &CliMigratorAnalysisPayload, + aggregations: &[(String, String, String)], + default_model: Option<&str>, +) -> Option { + let default = default_model?; + let mut dimensions = BTreeSet::new(); + for (table, column) in &payload.group_by_columns { + let target = if table.is_empty() { default } else { table }; + if target == default && !column.is_empty() { + dimensions.insert(format!("{default}.{column}")); + } + } + + let mut metrics = BTreeSet::new(); + for (agg, column, table) in aggregations { + let target = if table.is_empty() { default } else { table }; + if target == default { + metrics.insert(format!("{default}.{}", build_metric_name(agg, column))); + } + } + + if dimensions.is_empty() && metrics.is_empty() { + return None; + } + + let mut selections = Vec::new(); + selections.extend(dimensions); + selections.extend(metrics); + let body = selections + .iter() + .map(|value| format!(" {value}")) + .collect::>() + .join(",\n"); + Some(format!("SELECT\n{body}\nFROM {default}")) +} + +fn render_generated_model_yaml(model_name: &str, model: &CliGeneratedModel) -> CliResult { + let mut root = serde_yaml::Mapping::new(); + + let mut model_meta = serde_yaml::Mapping::new(); + model_meta.insert( + serde_yaml::Value::String("name".to_string()), + serde_yaml::Value::String(model_name.to_string()), + ); + model_meta.insert( + serde_yaml::Value::String("table".to_string()), + serde_yaml::Value::String(model_name.to_string()), + ); + model_meta.insert( + serde_yaml::Value::String("description".to_string()), + serde_yaml::Value::String("Auto-generated from query analysis".to_string()), + ); + root.insert( + serde_yaml::Value::String("model".to_string()), + serde_yaml::Value::Mapping(model_meta), + ); + + if !model.dimensions.is_empty() { + let mut dimensions = Vec::new(); + for dimension in &model.dimensions { + let mut item = serde_yaml::Mapping::new(); + item.insert( + serde_yaml::Value::String("name".to_string()), + serde_yaml::Value::String(dimension.to_string()), + ); + item.insert( + serde_yaml::Value::String("sql".to_string()), + serde_yaml::Value::String(dimension.to_string()), + ); + item.insert( + serde_yaml::Value::String("type".to_string()), + serde_yaml::Value::String(if is_time_dimension_name(dimension) { + "time".to_string() + } else { + "categorical".to_string() + }), + ); + dimensions.push(serde_yaml::Value::Mapping(item)); + } + root.insert( + serde_yaml::Value::String("dimensions".to_string()), + serde_yaml::Value::Sequence(dimensions), + ); + } + + if !model.metrics.is_empty() { + let mut metrics = model.metrics.clone(); + metrics.sort_by(|left, right| left.name.cmp(&right.name)); + let mut metric_values = Vec::new(); + for metric in metrics { + let mut item = serde_yaml::Mapping::new(); + item.insert( + serde_yaml::Value::String("name".to_string()), + serde_yaml::Value::String(metric.name), + ); + item.insert( + serde_yaml::Value::String("agg".to_string()), + serde_yaml::Value::String(metric.agg), + ); + item.insert( + serde_yaml::Value::String("sql".to_string()), + serde_yaml::Value::String(metric.sql), + ); + metric_values.push(serde_yaml::Value::Mapping(item)); + } + root.insert( + serde_yaml::Value::String("metrics".to_string()), + serde_yaml::Value::Sequence(metric_values), + ); + } + + serde_yaml::to_string(&serde_yaml::Value::Mapping(root)) + .map_err(|e| format!("failed to serialize generated model '{model_name}': {e}")) +} + +fn migrator_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic migrator [models_dir] --queries [--verbose] [--generate-models ] [--models ]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + if positionals.len() > 1 { + return Err(format!( + "migrator: unexpected positional arguments: {}", + positionals[1..].join(" ") + )); + } + + let queries_path = option_value_any(&options, &["--queries", "--queries-file"]) + .ok_or_else(|| "migrator requires --queries ".to_string())?; + let queries = parse_queries_from_path(&queries_path)?; + let verbose = option_flag(&options, "--verbose"); + let generate_models_dir = option_value(&options, "--generate-models"); + + if let Some(output_root) = generate_models_dir { + let models_output = PathBuf::from(&output_root).join("models"); + let rewritten_output = PathBuf::from(&output_root).join("rewritten_queries"); + fs::create_dir_all(&models_output).map_err(|e| { + format!( + "failed to create generated models directory '{}': {e}", + models_output.display() + ) + })?; + fs::create_dir_all(&rewritten_output).map_err(|e| { + format!( + "failed to create rewritten queries directory '{}': {e}", + rewritten_output.display() + ) + })?; + + let mut models: BTreeMap = BTreeMap::new(); + let mut rewritten_queries: Vec = Vec::new(); + let mut parseable_queries = 0usize; + + for query in &queries { + let analysis_json = sidemantic::analyze_migrator_query(query) + .map_err(|e| format!("failed to analyze query for migrator bootstrap: {e}"))?; + let payload: CliMigratorAnalysisPayload = serde_json::from_str(&analysis_json) + .map_err(|e| format!("failed to decode migrator analysis payload: {e}"))?; + parseable_queries += 1; + + let query_models = infer_query_models(&payload); + let default_model = query_models.iter().next().map(String::as_str); + let aggregations = extract_aggregations_from_query(query); + + for (table, column) in &payload.group_by_columns { + let resolved_model = if table.is_empty() { + default_model.map(str::to_string) + } else { + Some(table.to_string()) + }; + let Some(model_name) = resolved_model else { + continue; + }; + if column.is_empty() { + continue; + } + models + .entry(model_name) + .or_default() + .dimensions + .insert(column.to_string()); + } + + for (agg, column, table) in &aggregations { + let resolved_model = if table.is_empty() { + default_model.map(str::to_string) + } else { + Some(table.to_string()) + }; + let Some(model_name) = resolved_model else { + continue; + }; + let metric = CliMigratorMetric { + name: build_metric_name(agg, column), + agg: agg.to_string(), + sql: column.to_string(), + }; + let model_entry = models.entry(model_name).or_default(); + if !model_entry + .metrics + .iter() + .any(|existing| existing.name == metric.name) + { + model_entry.metrics.push(metric); + } + } + + rewritten_queries.push( + build_rewritten_query(&payload, &aggregations, default_model) + .unwrap_or_else(|| query.clone()), + ); + } + + for (model_name, model) in &models { + let file_path = models_output.join(format!("{model_name}.yml")); + let rendered = render_generated_model_yaml(model_name, model)?; + fs::write(&file_path, rendered).map_err(|e| { + format!( + "failed to write generated model file '{}': {e}", + file_path.display() + ) + })?; + } + for (index, rewritten_sql) in rewritten_queries.iter().enumerate() { + let file_path = rewritten_output.join(format!("query_{}.sql", index + 1)); + fs::write(&file_path, format!("{rewritten_sql}\n")).map_err(|e| { + format!( + "failed to write rewritten query file '{}': {e}", + file_path.display() + ) + })?; + } + + println!( + "Generated {} models and {} rewritten queries in {}", + models.len(), + rewritten_queries.len(), + output_root + ); + println!("Analyzed {parseable_queries} queries"); + return Ok(()); + } + + let models_path = option_value(&options, "--models") + .or_else(|| positionals.first().cloned()) + .unwrap_or_else(|| ".".to_string()); + let runtime = load_runtime(&models_path)?; + let available_models = runtime + .graph() + .models() + .map(|model| model.name.clone()) + .collect::>(); + + let mut parseable_queries = 0usize; + let mut rewritable_queries = 0usize; + let mut missing_models = BTreeSet::new(); + + for (index, query) in queries.iter().enumerate() { + let analysis_json = sidemantic::analyze_migrator_query(query) + .map_err(|e| format!("failed to analyze query in migrator coverage mode: {e}"))?; + let payload: CliMigratorAnalysisPayload = serde_json::from_str(&analysis_json) + .map_err(|e| format!("failed to decode migrator analysis payload: {e}"))?; + parseable_queries += 1; + + let query_models = infer_query_models(&payload); + let missing_for_query = query_models + .iter() + .filter(|model| !available_models.contains(*model)) + .cloned() + .collect::>(); + + if missing_for_query.is_empty() { + rewritable_queries += 1; + } else { + missing_models.extend(missing_for_query.iter().cloned()); + } + + if verbose { + println!("Query #{}:", index + 1); + if query_models.is_empty() { + println!(" models: (none inferred)"); + } else { + println!( + " models: {}", + query_models.iter().cloned().collect::>().join(", ") + ); + } + if missing_for_query.is_empty() { + println!(" rewritable: yes"); + } else { + println!( + " missing models: {}", + missing_for_query + .iter() + .cloned() + .collect::>() + .join(", ") + ); + } + } + } + + let coverage = if queries.is_empty() { + 0.0 + } else { + (rewritable_queries as f64 / queries.len() as f64) * 100.0 + }; + + println!("Total Queries: {}", queries.len()); + println!("Parseable: {}", parseable_queries); + println!("Rewritable: {}", rewritable_queries); + println!("Coverage: {:.1}%", coverage); + if !missing_models.is_empty() { + println!( + "Missing Models: {}", + missing_models.into_iter().collect::>().join(", ") + ); + } + + Ok(()) +} + +fn info_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!("Usage: sidemantic info [--models ]"); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "info")?; + let models = option_value(&options, "--models").unwrap_or_else(|| ".".to_string()); + let runtime = load_runtime(&models)?; + + let mut sorted_models = runtime.graph().models().collect::>(); + sorted_models.sort_by(|left, right| left.name.cmp(&right.name)); + + if sorted_models.is_empty() { + println!("No models found"); + return Ok(()); + } + + println!("\nSemantic Layer: {models}\n"); + for model in sorted_models { + println!("- {}", model.name); + let table = match model.table.as_deref() { + Some(value) if !value.is_empty() => value, + _ => "N/A", + }; + println!(" Table: {table}"); + println!(" Dimensions: {}", model.dimensions.len()); + println!(" Metrics: {}", model.metrics.len()); + println!(" Relationships: {}", model.relationships.len()); + if !model.relationships.is_empty() { + let connected = model + .relationships + .iter() + .map(|relationship| relationship.name.as_str()) + .collect::>() + .join(", "); + println!(" Connected to: {connected}"); + } + println!(); + } + + Ok(()) +} + +fn query_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic query [--models ] [--sql | --sql-file | ] [--dry-run] [--output ] [--connection | --db ] [--driver ] [--uri ] [--entrypoint ] [--dbopt ] [--connopt ] [--username ] [--password ]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + let models = option_value(&options, "--models").unwrap_or_else(|| ".".to_string()); + let runtime = load_runtime(&models)?; + let sql = if let Some(sql) = option_value(&options, "--sql") { + sql + } else if let Some(path) = option_value(&options, "--sql-file") { + fs::read_to_string(&path).map_err(|e| format!("failed to read SQL file '{path}': {e}"))? + } else if positionals.is_empty() { + return Err("query requires --sql, --sql-file, or a positional SQL query".to_string()); + } else { + positionals.join(" ") + }; + + let rewritten = runtime + .rewrite(&sql) + .map_err(|e| format!("failed to rewrite SQL: {e}"))?; + if option_flag(&options, "--dry-run") { + println!("{rewritten}"); + return Ok(()); + } + + #[cfg(not(feature = "adbc-exec"))] + { + let _ = rewritten; + Err("query execution requires the crate feature 'adbc-exec'".to_string()) + } + + #[cfg(feature = "adbc-exec")] + { + let adbc = parse_query_adbc_cli_config(&options, "query")?; + let result = execute_with_adbc(AdbcExecutionRequest { + driver: adbc.driver, + sql: rewritten, + uri: adbc.uri, + entrypoint: adbc.entrypoint, + database_options: adbc.database_options, + connection_options: adbc.connection_options, + }) + .map_err(|e| format!("failed to execute query via ADBC: {e}"))?; + write_csv_rows( + &result.columns, + &result.rows, + option_value(&options, "--output").as_deref(), + ) + } +} + +fn run_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic run --models [--driver ] [--uri ] [--entrypoint ] [--username ] [--password ] [--dbopt ...] [--connopt ...] [--catalog ] [--schema ] [--autocommit ] [--read-only ] [--isolation-level ] [query flags]" + ); + return Ok(()); + } + + #[cfg(not(feature = "adbc-exec"))] + { + let _ = args; + Err("run requires the crate feature 'adbc-exec'".to_string()) + } + + #[cfg(feature = "adbc-exec")] + { + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "run")?; + + let models = require_option(&options, "--models")?; + let runtime = load_runtime(&models)?; + let query = build_query_from_options(&options)?; + let sql = runtime + .compile(&query) + .map_err(|e| format!("failed to compile query: {e}"))?; + + let adbc = parse_adbc_cli_config(&options, "run")?; + let result = execute_with_adbc(AdbcExecutionRequest { + driver: adbc.driver, + sql: sql.clone(), + uri: adbc.uri, + entrypoint: adbc.entrypoint, + database_options: adbc.database_options, + connection_options: adbc.connection_options, + }) + .map_err(|e| format!("failed to execute query via ADBC: {e}"))?; + + let rows = result + .rows + .iter() + .map(|row| { + let mut row_map = JsonMap::new(); + for (idx, column) in result.columns.iter().enumerate() { + let value = row + .get(idx) + .map(adbc_value_to_json) + .unwrap_or(JsonValue::Null); + row_map.insert(column.clone(), value); + } + JsonValue::Object(row_map) + }) + .collect::>(); + let payload = RunOutput { + sql, + columns: result.columns, + row_count: rows.len(), + rows, + }; + println!( + "{}", + serde_json::to_string_pretty(&payload) + .map_err(|e| format!("failed to serialize run output: {e}"))? + ); + Ok(()) + } +} + +fn preagg_command(args: &[String]) -> CliResult<()> { + if args.is_empty() || args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic preagg [options]\n\ + \n\ + materialize: --models --model --name [--execute ...adbc opts]\n\ + recommend: (--queries-file | [--connection |--db ] [--days ] [--limit ] [adbc opts]) [--min-query-count ] [--min-benefit-score ] [--top-n ] [--json]\n\ + apply: --models (--queries-file | [--connection |--db ] [--days ] [--limit ] [adbc opts]) [--min-query-count ] [--min-benefit-score ] [--top-n ] [--dry-run]\n\ + refresh: --models [--model ] [--name |--preagg ] [--mode ] [--dialect ] [--refresh-every ] [--from-watermark ] [--lookback ] [--watermark-column ] [--execute ...adbc opts]\n\ + \n\ + ADBC opts: [--driver ] [--uri ] [--entrypoint ] [--username ] [--password ] [--dbopt ] [--connopt ]" + ); + return Ok(()); + } + + match args[0].as_str() { + "materialize" => preagg_materialize_command(&args[1..]), + "recommend" => preagg_recommend_command(&args[1..]), + "apply" => preagg_apply_command(&args[1..]), + "refresh" => preagg_refresh_command(&args[1..]), + other => Err(format!( + "unknown preagg subcommand '{other}'. Use 'preagg --help' for usage." + )), + } +} + +fn preagg_materialize_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic preagg materialize --models --model --name [--preagg-database ] [--preagg-schema ] [--execute ...adbc opts]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "preagg materialize")?; + let models = require_option(&options, "--models")?; + let model_name = require_option(&options, "--model")?; + let preagg_name = require_option(&options, "--name")?; + + let runtime = load_runtime(&models)?; + let sql = runtime + .generate_preaggregation_materialization_sql(&model_name, &preagg_name) + .map_err(|e| format!("failed to generate pre-aggregation SQL: {e}"))?; + println!("{sql}"); + + if !option_flag(&options, "--execute") { + return Ok(()); + } + + #[cfg(not(feature = "adbc-exec"))] + { + Err("preagg materialize --execute requires feature 'adbc-exec'".to_string()) + } + + #[cfg(feature = "adbc-exec")] + { + let adbc = parse_adbc_cli_config(&options, "preagg materialize")?; + execute_sql_statements(&adbc, &[sql])?; + println!("materialized"); + Ok(()) + } +} + +fn preagg_refresh_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic preagg refresh --models [--model ] [--name |--preagg ] [--mode ] [--dialect ] [--refresh-every ] [--from-watermark ] [--lookback ] [--watermark-column ] [--preagg-database ] [--preagg-schema ] [--execute ...adbc opts]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "preagg refresh")?; + + let models = require_option(&options, "--models")?; + let model_filter = option_value(&options, "--model"); + let preagg_filter = option_value_any(&options, &["--name", "--preagg"]); + let mode = option_value(&options, "--mode"); + if let Some(mode_value) = mode.as_deref() { + let mode_is_valid = matches!(mode_value, "full" | "incremental" | "merge" | "engine"); + if !mode_is_valid { + return Err(format!( + "invalid preagg refresh mode '{mode_value}'. Supported modes: full, incremental, merge, engine" + )); + } + } + let dialect = option_value(&options, "--dialect"); + let refresh_every = option_value(&options, "--refresh-every"); + if mode.as_deref() == Some("engine") && dialect.is_none() { + return Err("engine refresh mode requires --dialect".to_string()); + } + + let preagg_database = option_value(&options, "--preagg-database"); + let preagg_schema = option_value(&options, "--preagg-schema"); + let from_watermark = option_value(&options, "--from-watermark"); + let lookback = option_value(&options, "--lookback"); + let forced_watermark_column = option_value(&options, "--watermark-column"); + let execute = option_flag(&options, "--execute"); + + let runtime = load_runtime(&models)?; + let mut plans: Vec = Vec::new(); + + for model in runtime.graph().models() { + if let Some(target_model) = model_filter.as_deref() { + if model.name != target_model { + continue; + } + } + + for preagg in &model.pre_aggregations { + if let Some(target_preagg) = preagg_filter.as_deref() { + if preagg.name != target_preagg { + continue; + } + } + + let table_name = preagg.table_name( + &model.name, + preagg_database.as_deref(), + preagg_schema.as_deref(), + ); + let resolved_mode = mode.clone().unwrap_or_else(|| { + if preagg + .refresh_key + .as_ref() + .is_some_and(|refresh_key| refresh_key.incremental) + { + "incremental".to_string() + } else { + "full".to_string() + } + }); + let source_sql = runtime + .generate_preaggregation_materialization_sql(&model.name, &preagg.name) + .map_err(|e| { + format!( + "failed to generate materialization SQL for {}.{}: {e}", + model.name, preagg.name + ) + })?; + + let default_watermark = preagg + .time_dimension + .as_ref() + .zip(preagg.granularity.as_ref()) + .map(|(time_dimension, granularity)| format!("{time_dimension}_{granularity}")); + let watermark_column = forced_watermark_column + .clone() + .or(default_watermark) + .unwrap_or_default(); + let effective_refresh_every = refresh_every.clone().or_else(|| { + preagg + .refresh_key + .as_ref() + .and_then(|key| key.every.clone()) + }); + + let statements = build_preaggregation_refresh_statements( + &resolved_mode, + &table_name, + &source_sql, + if watermark_column.is_empty() { + None + } else { + Some(watermark_column.as_str()) + }, + from_watermark.as_deref(), + lookback.as_deref(), + dialect.as_deref(), + effective_refresh_every.as_deref(), + ) + .map_err(refresh_plan_error_message)?; + + plans.push(RefreshPlan { + model_name: model.name.clone(), + preagg_name: preagg.name.clone(), + table_name, + mode: resolved_mode, + statements, + }); + } + } + + if plans.is_empty() { + let mut scope_parts = Vec::new(); + if let Some(value) = model_filter { + scope_parts.push(format!("model={value}")); + } + if let Some(value) = preagg_filter { + scope_parts.push(format!("preagg={value}")); + } + let scope = if scope_parts.is_empty() { + "requested scope".to_string() + } else { + scope_parts.join(", ") + }; + return Err(format!("no pre-aggregations found for {scope}")); + } + + for plan in &plans { + println!( + "-- refresh {}.{} mode={} table={}", + plan.model_name, plan.preagg_name, plan.mode, plan.table_name + ); + for statement in &plan.statements { + println!("{statement};"); + } + println!(); + } + + if !execute { + println!("dry-run: generated refresh SQL only (add --execute to run statements)"); + return Ok(()); + } + + #[cfg(not(feature = "adbc-exec"))] + { + Err("preagg refresh --execute requires feature 'adbc-exec'".to_string()) + } + + #[cfg(feature = "adbc-exec")] + { + let adbc = parse_adbc_cli_config(&options, "preagg refresh")?; + let mut statement_count = 0usize; + for plan in &plans { + execute_sql_statements(&adbc, &plan.statements)?; + statement_count += plan.statements.len(); + } + println!( + "refreshed {} pre-aggregation(s) with {} SQL statement(s)", + plans.len(), + statement_count + ); + Ok(()) + } +} + +fn preagg_recommend_command(args: &[String]) -> CliResult<()> { + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "preagg recommend")?; + let queries = load_preagg_queries(&options, "preagg recommend")?; + let patterns_json = extract_preaggregation_patterns(queries) + .map_err(|e| format!("failed to extract pre-aggregation patterns: {e}"))?; + + let min_query_count = option_usize(&options, "--min-query-count")?.unwrap_or(10); + let min_benefit_score = option_f64(&options, "--min-benefit-score")?.unwrap_or(0.3); + let top_n = option_usize(&options, "--top-n")?; + let recommendations_json = recommend_preaggregation_patterns( + &patterns_json, + min_query_count, + min_benefit_score, + top_n, + ) + .map_err(|e| format!("failed to build recommendations: {e}"))?; + + if option_flag(&options, "--json") { + println!("{recommendations_json}"); + return Ok(()); + } + + let summary_json = summarize_preaggregation_patterns(&patterns_json, min_query_count) + .map_err(|e| format!("failed to summarize recommendations: {e}"))?; + let summary: serde_json::Value = serde_json::from_str(&summary_json) + .map_err(|e| format!("failed to parse recommendation summary payload: {e}"))?; + let recommendations: Vec = serde_json::from_str(&recommendations_json) + .map_err(|e| format!("failed to parse recommendations payload: {e}"))?; + + let total_queries = summary + .get("total_queries") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let unique_patterns = summary + .get("unique_patterns") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let patterns_above_threshold = summary + .get("patterns_above_threshold") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + + eprintln!("\n\u{2713} Analyzed {total_queries} queries"); + eprintln!( + " Found {unique_patterns} unique pattern{}", + if unique_patterns == 1 { "" } else { "s" } ); + eprintln!(" {patterns_above_threshold} patterns above threshold"); + + if let Some(models) = summary.get("models").and_then(serde_json::Value::as_object) { + if !models.is_empty() { + eprintln!("\n Models:"); + let mut model_counts: Vec<_> = models.iter().collect(); + model_counts.sort_by(|left, right| left.0.cmp(right.0)); + for (model_name, count) in model_counts { + let count = count.as_u64().unwrap_or(0); + eprintln!(" {model_name}: {count} queries"); + } + } + } + + if recommendations.is_empty() { + eprintln!("\nNo recommendations found above thresholds"); + eprintln!( + "Try lowering --min-count (currently {min_query_count}) or --min-score (currently {min_benefit_score})" + ); + return Ok(()); + } + + println!("\n{}", "=".repeat(80)); println!( - " Segments: {:?}\n", - model.segments.iter().map(|s| &s.name).collect::>() + "Pre-Aggregation Recommendations (found {})", + recommendations.len() + ); + println!("{}\n", "=".repeat(80)); + + for (index, recommendation) in recommendations.iter().enumerate() { + println!("{}. {}", index + 1, recommendation.suggested_name); + println!(" Model: {}", recommendation.pattern.model); + println!(" Query Count: {}", recommendation.query_count); + println!( + " Benefit Score: {:.2}", + recommendation.estimated_benefit_score + ); + println!(" Metrics: {}", recommendation.pattern.metrics.join(", ")); + if recommendation.pattern.dimensions.is_empty() { + println!(" Dimensions: (none)"); + } else { + println!( + " Dimensions: {}", + recommendation.pattern.dimensions.join(", ") + ); + } + if !recommendation.pattern.granularities.is_empty() { + println!( + " Granularities: {}", + recommendation.pattern.granularities.join(", ") + ); + } + println!(); + } + + eprintln!("Run 'sidemantic preagg apply' to add these to your models"); + Ok(()) +} + +fn collect_yaml_files(models_path: &str) -> CliResult> { + let path = PathBuf::from(models_path); + if path.is_file() { + let extension = path + .extension() + .and_then(|value| value.to_str()) + .map(|value| value.to_ascii_lowercase()) + .unwrap_or_default(); + if extension == "yml" || extension == "yaml" { + return Ok(vec![path]); + } + return Err(format!("models path '{models_path}' is not a YAML file")); + } + + if !path.is_dir() { + return Err(format!("models path '{models_path}' does not exist")); + } + + let mut files: Vec = Vec::new(); + let mut stack: Vec = vec![path]; + while let Some(dir) = stack.pop() { + let entries = fs::read_dir(&dir) + .map_err(|e| format!("failed to read models directory '{}': {e}", dir.display()))?; + for entry in entries { + let entry = entry.map_err(|e| format!("failed to read models directory entry: {e}"))?; + let entry_path = entry.path(); + if entry_path.is_dir() { + stack.push(entry_path); + continue; + } + if let Some(extension) = entry_path.extension().and_then(|value| value.to_str()) { + let extension = extension.to_ascii_lowercase(); + if extension == "yml" || extension == "yaml" { + files.push(entry_path); + } + } + } + } + files.sort(); + Ok(files) +} + +fn find_model_in_yaml(content: &str, model_name: &str) -> CliResult { + let yaml: serde_yaml::Value = + serde_yaml::from_str(content).map_err(|e| format!("failed to parse YAML: {e}"))?; + let Some(root) = yaml.as_mapping() else { + return Ok(false); + }; + + let models_key = serde_yaml::Value::String("models".to_string()); + let Some(models_value) = root.get(&models_key) else { + return Ok(false); + }; + let Some(models) = models_value.as_sequence() else { + return Ok(false); + }; + + for model in models { + let Some(model_mapping) = model.as_mapping() else { + continue; + }; + let name_key = serde_yaml::Value::String("name".to_string()); + let Some(name_value) = model_mapping.get(&name_key) else { + continue; + }; + if name_value.as_str() == Some(model_name) { + return Ok(true); + } + } + + Ok(false) +} + +fn preagg_value_name(preagg_value: &serde_yaml::Value) -> Option { + let mapping = preagg_value.as_mapping()?; + let key = serde_yaml::Value::String("name".to_string()); + mapping + .get(&key) + .and_then(serde_yaml::Value::as_str) + .map(str::to_string) +} + +fn parse_string_array( + value: &serde_json::Value, + key: &str, +) -> CliResult>> { + let Some(raw) = value.get(key) else { + return Ok(None); + }; + let Some(items) = raw.as_array() else { + return Err(format!( + "pre-aggregation definition field '{key}' must be an array" + )); + }; + if items.is_empty() { + return Ok(None); + } + let mut output = Vec::with_capacity(items.len()); + for item in items { + let Some(item_str) = item.as_str() else { + return Err(format!( + "pre-aggregation definition field '{key}' must contain strings" + )); + }; + output.push(serde_yaml::Value::String(item_str.to_string())); + } + Ok(Some(output)) +} + +fn build_preagg_yaml_value(definition_json: &str) -> CliResult { + let definition: serde_json::Value = serde_json::from_str(definition_json) + .map_err(|e| format!("failed to parse generated pre-aggregation definition JSON: {e}"))?; + let Some(name) = definition + .get("name") + .and_then(serde_json::Value::as_str) + .map(str::to_string) + else { + return Err("generated pre-aggregation definition missing name".to_string()); + }; + let Some(measures) = parse_string_array(&definition, "measures")? else { + return Err("generated pre-aggregation definition missing measures".to_string()); + }; + + let mut mapping = serde_yaml::Mapping::new(); + mapping.insert( + serde_yaml::Value::String("name".to_string()), + serde_yaml::Value::String(name), + ); + mapping.insert( + serde_yaml::Value::String("measures".to_string()), + serde_yaml::Value::Sequence(measures), + ); + + if let Some(dimensions) = parse_string_array(&definition, "dimensions")? { + mapping.insert( + serde_yaml::Value::String("dimensions".to_string()), + serde_yaml::Value::Sequence(dimensions), + ); + } + if let Some(time_dimension) = definition + .get("time_dimension") + .and_then(serde_json::Value::as_str) + .map(str::to_string) + { + mapping.insert( + serde_yaml::Value::String("time_dimension".to_string()), + serde_yaml::Value::String(time_dimension), + ); + } + if let Some(granularity) = definition + .get("granularity") + .and_then(serde_json::Value::as_str) + .map(str::to_string) + { + mapping.insert( + serde_yaml::Value::String("granularity".to_string()), + serde_yaml::Value::String(granularity), + ); + } + + Ok(serde_yaml::Value::Mapping(mapping)) +} + +fn preagg_apply_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic preagg apply --models (--queries-file | [--connection |--db ] [--days ] [--limit ] [adbc opts]) [--min-query-count ] [--min-benefit-score ] [--top-n ] [--dry-run]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + expect_no_positionals(&positionals, "preagg apply")?; + let models = require_option(&options, "--models")?; + let queries = load_preagg_queries(&options, "preagg apply")?; + let patterns_json = extract_preaggregation_patterns(queries) + .map_err(|e| format!("failed to extract pre-aggregation patterns: {e}"))?; + + let min_query_count = option_usize(&options, "--min-query-count")?.unwrap_or(10); + let min_benefit_score = option_f64(&options, "--min-benefit-score")?.unwrap_or(0.3); + let top_n = option_usize(&options, "--top-n")?; + let dry_run = option_flag(&options, "--dry-run"); + + let recommendations_json = recommend_preaggregation_patterns( + &patterns_json, + min_query_count, + min_benefit_score, + top_n, + ) + .map_err(|e| format!("failed to build recommendations: {e}"))?; + let recommendations: Vec = serde_json::from_str(&recommendations_json) + .map_err(|e| format!("failed to parse recommendations payload: {e}"))?; + if recommendations.is_empty() { + eprintln!("No recommendations found above thresholds"); + return Ok(()); + } + + let yaml_files = collect_yaml_files(&models)?; + if yaml_files.is_empty() { + return Err(format!("no YAML model files found under '{models}'")); + } + + eprintln!( + "\nFound {} recommendations to apply\n", + recommendations.len() ); + let mut by_model: BTreeMap> = BTreeMap::new(); + for recommendation in recommendations { + by_model + .entry(recommendation.pattern.model.clone()) + .or_default() + .push(recommendation); + } + + let mut updated_count = 0usize; + for (model_name, model_recommendations) in by_model { + let mut model_file: Option = None; + for yaml_file in &yaml_files { + let file_content = fs::read_to_string(yaml_file) + .map_err(|e| format!("failed to read model file '{}': {e}", yaml_file.display()))?; + if find_model_in_yaml(&file_content, &model_name)? { + model_file = Some(yaml_file.clone()); + break; + } + } + + let Some(model_file_path) = model_file else { + eprintln!("warning: Could not find YAML file for model '{model_name}'"); + continue; + }; + + let mut yaml_data: serde_yaml::Value = + serde_yaml::from_str(&fs::read_to_string(&model_file_path).map_err(|e| { + format!( + "failed to read model file '{}': {e}", + model_file_path.display() + ) + })?) + .map_err(|e| { + format!( + "failed to parse YAML file '{}': {e}", + model_file_path.display() + ) + })?; + let Some(root_mapping) = yaml_data.as_mapping_mut() else { + return Err(format!( + "YAML file '{}' must contain a mapping root", + model_file_path.display() + )); + }; + + let models_key = serde_yaml::Value::String("models".to_string()); + let Some(models_value) = root_mapping.get_mut(&models_key) else { + continue; + }; + let Some(models_seq) = models_value.as_sequence_mut() else { + return Err(format!( + "YAML file '{}' has non-sequence 'models' entry", + model_file_path.display() + )); + }; + + let mut file_modified = false; + for model in models_seq { + let Some(model_mapping) = model.as_mapping_mut() else { + continue; + }; + + let name_key = serde_yaml::Value::String("name".to_string()); + if model_mapping + .get(&name_key) + .and_then(serde_yaml::Value::as_str) + != Some(model_name.as_str()) + { + continue; + } + + let preaggs_key = serde_yaml::Value::String("pre_aggregations".to_string()); + if !model_mapping.contains_key(&preaggs_key) { + model_mapping.insert(preaggs_key.clone(), serde_yaml::Value::Sequence(Vec::new())); + } + let Some(preaggs_seq) = model_mapping + .get_mut(&preaggs_key) + .and_then(serde_yaml::Value::as_sequence_mut) + else { + return Err(format!( + "model '{model_name}' in '{}' has non-sequence pre_aggregations", + model_file_path.display() + )); + }; + + for recommendation in &model_recommendations { + let recommendation_payload = serde_json::json!({ + "pattern": { + "model": recommendation.pattern.model, + "metrics": recommendation.pattern.metrics, + "dimensions": recommendation.pattern.dimensions, + "granularities": recommendation.pattern.granularities, + "count": recommendation.pattern.count, + }, + "suggested_name": recommendation.suggested_name, + "query_count": recommendation.query_count, + "estimated_benefit_score": recommendation.estimated_benefit_score, + }); + let recommendation_json = serde_json::to_string(&recommendation_payload) + .map_err(|e| format!("failed to serialize recommendation payload: {e}"))?; + let definition_json = generate_preaggregation_definition(&recommendation_json) + .map_err(|e| format!("failed to generate pre-aggregation definition: {e}"))?; + let preagg_value = build_preagg_yaml_value(&definition_json)?; + let Some(preagg_name) = preagg_value_name(&preagg_value) else { + return Err("generated pre-aggregation definition missing name".to_string()); + }; + + if preaggs_seq.iter().any(|existing| { + preagg_value_name(existing).as_deref() == Some(preagg_name.as_str()) + }) { + continue; + } + + preaggs_seq.push(preagg_value); + file_modified = true; + updated_count += 1; + eprintln!( + " + {model_name}.{preagg_name} ({} queries)", + recommendation.query_count + ); + } + break; + } + + if file_modified && !dry_run { + let rendered = serde_yaml::to_string(&yaml_data).map_err(|e| { + format!( + "failed to serialize updated YAML for '{}': {e}", + model_file_path.display() + ) + })?; + fs::write(&model_file_path, rendered).map_err(|e| { + format!( + "failed to write updated YAML file '{}': {e}", + model_file_path.display() + ) + })?; + } + } + + if dry_run { + eprintln!("Dry run: Would add {updated_count} pre-aggregations"); + } else { + eprintln!("\u{2713} Added {updated_count} pre-aggregations to model files"); + } + + Ok(()) +} + +pub(crate) fn workbench_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!( + "Usage: sidemantic workbench [models_dir] [--demo] [--connection ] [--db ]" + ); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + if positionals.len() > 1 { + return Err(format!( + "workbench: unexpected positional arguments: {}", + positionals[1..].join(" ") + )); + } + + let demo_mode = option_flag(&options, "--demo"); + let directory = if demo_mode { + let candidates = [ + PathBuf::from("examples").join("multi_format_demo"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("examples") + .join("multi_format_demo"), + ]; + candidates + .iter() + .find(|path| path.exists()) + .cloned() + .ok_or_else(|| "Error: Demo models not found".to_string())? + } else { + PathBuf::from( + positionals + .first() + .cloned() + .unwrap_or_else(|| ".".to_string()), + ) + }; + + if !directory.exists() { + return Err(format!( + "Error: Directory {} does not exist", + directory.display() + )); + } + + let connection = if let Some(value) = option_value(&options, "--connection") { + Some(value) + } else if let Some(db_path) = option_value(&options, "--db") { + let db_path = PathBuf::from(db_path); + let absolute = if db_path.is_absolute() { + db_path + } else { + env::current_dir() + .map_err(|e| format!("failed to inspect current directory: {e}"))? + .join(db_path) + }; + Some(render_duckdb_connection_url(&absolute)) + } else if demo_mode { + prepare_workbench_demo_connection()? + } else { + discover_duckdb_connection_from_data_dir(&directory)? + }; + + let mut details = vec![format!("models={}", directory.display())]; + if demo_mode { + details.push("demo=true".to_string()); + } + if let Some(value) = connection.as_ref() { + details.push(format!("connection={value}")); + } + + launch_workbench_tui(directory.to_string_lossy().as_ref(), connection).map_err(|err| { + format!( + "{err} (resolved: {})", + details + .iter() + .map(String::as_str) + .collect::>() + .join(", ") + ) + }) +} + +fn discover_duckdb_connection_from_data_dir(directory: &Path) -> CliResult> { + let data_dir = directory.join("data"); + if !data_dir.is_dir() { + return Ok(None); + } + + let entries = fs::read_dir(&data_dir).map_err(|e| { + format!( + "failed to read workbench data directory '{}': {e}", + data_dir.display() + ) + })?; + + let mut db_files = Vec::new(); + for entry in entries { + let entry = entry.map_err(|e| { + format!( + "failed to read workbench data directory entry in '{}': {e}", + data_dir.display() + ) + })?; + let path = entry.path(); + if path.is_file() + && path + .extension() + .and_then(|value| value.to_str()) + .is_some_and(|ext| ext.eq_ignore_ascii_case("db")) + { + db_files.push(path); + } + } + db_files.sort(); + + let Some(selected) = db_files.into_iter().next() else { + return Ok(None); + }; + + let absolute = if selected.is_absolute() { + selected + } else { + env::current_dir() + .map_err(|e| format!("failed to inspect current directory: {e}"))? + .join(selected) + }; + Ok(Some(render_duckdb_connection_url(&absolute))) +} + +fn unique_demo_db_path() -> CliResult { + let suffix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| format!("failed to compute demo database timestamp: {e}"))? + .as_nanos(); + let mut path = env::temp_dir(); + path.push(format!("sidemantic_workbench_demo_{suffix}.db")); + Ok(path) +} + +fn render_duckdb_connection_url(path: &Path) -> String { + let rendered = path.to_string_lossy(); + if path.is_absolute() { + format!("duckdb://{rendered}") + } else { + format!("duckdb:///{rendered}") + } +} + +#[cfg(feature = "adbc-exec")] +fn prepare_workbench_demo_connection() -> CliResult> { + let db_path = unique_demo_db_path()?; + let db_uri = db_path.to_string_lossy().to_string(); + let connection_url = render_duckdb_connection_url(&db_path); + + let adbc = AdbcCliConfig { + driver: normalize_adbc_driver_name("duckdb"), + uri: Some(db_uri), + entrypoint: None, + database_options: Vec::new(), + connection_options: Vec::new(), + }; + if let Err(err) = execute_sql_statements(&adbc, &workbench_demo_seed_sql()) { + eprintln!( + "warning: failed to seed demo database via ADBC ({err}); continuing with demo connection" + ); + } + + Ok(Some(connection_url)) +} - let generator = SqlGenerator::new(&graph); - let query = SemanticQuery::new() - .with_metrics(vec!["orders.revenue".into()]) - .with_dimensions(vec!["orders.status".into()]); +#[cfg(not(feature = "adbc-exec"))] +fn prepare_workbench_demo_connection() -> CliResult> { + let db_path = unique_demo_db_path()?; + Ok(Some(render_duckdb_connection_url(&db_path))) +} + +#[cfg(feature = "adbc-exec")] +fn workbench_demo_seed_sql() -> Vec { + vec![ + r#" +create table customers ( + id integer primary key, + name varchar, + email varchar, + region varchar, + signup_date date +) +"# + .trim() + .to_string(), + r#" +insert into customers values + (1, 'Alice Johnson', 'alice@example.com', 'North', '2023-01-15'), + (2, 'Bob Smith', 'bob@example.com', 'South', '2023-02-20'), + (3, 'Carol Davis', 'carol@example.com', 'East', '2023-03-10'), + (4, 'David Wilson', 'david@example.com', 'West', '2023-04-05'), + (5, 'Eve Martinez', 'eve@example.com', 'North', '2023-05-18') +"# + .trim() + .to_string(), + r#" +create table products ( + id integer primary key, + name varchar, + category varchar, + price decimal(10,2), + cost decimal(10,2) +) +"# + .trim() + .to_string(), + r#" +insert into products values + (1, 'Laptop Pro', 'Electronics', 1299.99, 800.00), + (2, 'Wireless Mouse', 'Electronics', 29.99, 15.00), + (3, 'Desk Chair', 'Furniture', 249.99, 120.00), + (4, 'Standing Desk', 'Furniture', 599.99, 350.00), + (5, 'Notebook Set', 'Office Supplies', 12.99, 5.00) +"# + .trim() + .to_string(), + r#" +create table orders ( + id integer primary key, + customer_id integer, + product_id integer, + quantity integer, + amount decimal(10,2), + status varchar, + created_at timestamp +) +"# + .trim() + .to_string(), + r#" +insert into orders values + (1, 1, 1, 1, 1299.99, 'completed', '2025-01-10 10:30:00'), + (2, 2, 2, 2, 59.98, 'completed', '2025-01-11 11:45:00'), + (3, 3, 3, 1, 249.99, 'pending', '2025-01-12 09:15:00'), + (4, 4, 4, 1, 599.99, 'completed', '2025-01-13 14:20:00'), + (5, 5, 5, 3, 38.97, 'cancelled', '2025-01-14 16:05:00'), + (6, 1, 2, 1, 29.99, 'completed', '2025-01-15 13:10:00'), + (7, 2, 3, 2, 499.98, 'completed', '2025-01-16 15:40:00') +"# + .trim() + .to_string(), + ] +} + +fn tree_command(args: &[String]) -> CliResult<()> { + if args.iter().any(|arg| arg == "--help" || arg == "-h") { + println!("Usage: sidemantic tree "); + return Ok(()); + } + + let (options, positionals) = parse_options(args)?; + if !options.is_empty() { + return Err( + "tree does not accept options; use 'workbench' for --demo/--connection/--db" + .to_string(), + ); + } + if positionals.len() != 1 { + return Err("tree requires exactly one positional models directory".to_string()); + } + eprintln!("tree is deprecated; use 'workbench'."); + workbench_command(positionals.as_slice()) +} + +#[cfg(feature = "workbench-tui")] +fn launch_workbench_tui(models_path: &str, connection: Option) -> CliResult<()> { + workbench::launch(models_path, connection) +} + +#[cfg(not(feature = "workbench-tui"))] +fn launch_workbench_tui(_models_path: &str, _connection: Option) -> CliResult<()> { + Err("workbench requires the crate feature 'workbench-tui'".to_string()) +} + +fn mcp_command(args: &[String]) -> CliResult<()> { + #[cfg(not(feature = "mcp-server"))] + { + let _ = args; + Err("mcp requires the crate feature 'mcp-server'".to_string()) + } + + #[cfg(feature = "mcp-server")] + { + run_sibling_binary("sidemantic-mcp", args) + } +} + +fn server_command(args: &[String]) -> CliResult<()> { + #[cfg(not(feature = "runtime-server"))] + { + let _ = args; + Err("server requires the crate feature 'runtime-server'".to_string()) + } + + #[cfg(feature = "runtime-server")] + { + run_sibling_binary("sidemantic-server", args) + } +} + +fn lsp_command(args: &[String]) -> CliResult<()> { + #[cfg(not(feature = "runtime-lsp"))] + { + let _ = args; + Err("lsp requires the crate feature 'runtime-lsp'".to_string()) + } + + #[cfg(feature = "runtime-lsp")] + { + run_sibling_binary("sidemantic-lsp", args) + } +} + +#[cfg(any( + feature = "mcp-server", + feature = "runtime-server", + feature = "runtime-lsp" +))] +fn run_sibling_binary(binary_name: &str, args: &[String]) -> CliResult<()> { + let current = env::current_exe().map_err(|e| format!("failed to inspect executable: {e}"))?; + let mut candidate = current.clone(); + candidate.set_file_name(format!("{binary_name}{}", env::consts::EXE_SUFFIX)); + if !candidate.exists() { + return Err(format!( + "{binary_name} binary not found next to {}. Build/install sibling runtime binaries (e.g. cargo build --manifest-path sidemantic-rs/Cargo.toml --features mcp-server,runtime-server,runtime-lsp --bins).", + current.display() + )); + } + + let status = Command::new(&candidate) + .args(args) + .status() + .map_err(|e| format!("failed to launch {}: {e}", candidate.display()))?; + if status.success() { + Ok(()) + } else { + Err(format!("{binary_name} exited with status {status}")) + } +} + +fn split_queries(content: &str) -> Vec { + let chunks: Vec = content + .split(';') + .map(str::trim) + .filter(|chunk| !chunk.is_empty()) + .map(str::to_string) + .collect(); + if !chunks.is_empty() { + return chunks; + } + content + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .map(str::to_string) + .collect() +} + +fn load_preagg_queries(options: &ParsedOptions, context: &str) -> CliResult> { + let queries_file = option_value(options, "--queries-file"); + let has_connection_mode = option_value(options, "--connection").is_some() + || option_value(options, "--db").is_some() + || option_value(options, "--driver").is_some() + || option_value(options, "--uri").is_some(); + if queries_file.is_some() && has_connection_mode { + return Err(format!( + "{context}: provide either --queries-file or connection options (--connection/--db/--driver/--uri), not both" + )); + } + + if let Some(queries_file) = queries_file { + let content = fs::read_to_string(&queries_file) + .map_err(|e| format!("failed to read queries file '{queries_file}': {e}"))?; + return Ok(split_queries(&content)); + } + + if !has_connection_mode { + return Err(format!( + "{context}: missing query source; use --queries-file or --connection/--db" + )); + } + + let days_back = option_usize(options, "--days")?.unwrap_or(7); + let limit = option_usize(options, "--limit")?.unwrap_or(1000); + if days_back == 0 { + return Err(format!("{context}: --days must be greater than 0")); + } + if limit == 0 { + return Err(format!("{context}: --limit must be greater than 0")); + } + + #[cfg(not(feature = "adbc-exec"))] + { + let _ = days_back; + let _ = limit; + Err(format!( + "{context}: query-history mode requires the crate feature 'adbc-exec'" + )) + } + + #[cfg(feature = "adbc-exec")] + { + let adbc = parse_query_adbc_cli_config(options, context)?; + fetch_preagg_query_history(&adbc, days_back, limit, context) + } +} + +#[cfg(feature = "adbc-exec")] +fn parse_kv_pairs(input: &str, option_name: &str) -> CliResult> { + let mut pairs = Vec::new(); + for fragment in input + .split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + { + let (key, value) = fragment + .split_once('=') + .ok_or_else(|| format!("{option_name} expects key=value, got '{fragment}'"))?; + if key.trim().is_empty() { + return Err(format!("{option_name} key cannot be empty: '{fragment}'")); + } + pairs.push((key.trim().to_string(), value.to_string())); + } + if pairs.is_empty() { + return Err(format!("{option_name} expects key=value pairs")); + } + Ok(pairs) +} + +#[cfg(feature = "adbc-exec")] +fn infer_preagg_history_dialect(driver: &str, uri: Option<&str>) -> Option<&'static str> { + let driver_family = driver + .strip_prefix("adbc_driver_") + .unwrap_or(driver) + .to_ascii_lowercase(); + match driver_family.as_str() { + "bigquery" => Some("bigquery"), + "snowflake" => Some("snowflake"), + "clickhouse" => Some("clickhouse"), + "databricks" => Some("databricks"), + "spark" => { + if let Some(uri_value) = uri { + if uri_value.to_ascii_lowercase().starts_with("databricks://") { + return Some("databricks"); + } + } + None + } + _ => None, + } +} + +#[cfg(feature = "adbc-exec")] +fn is_safe_sql_identifier_fragment(value: &str) -> bool { + !value.is_empty() + && value + .chars() + .all(|ch| ch.is_ascii_alphanumeric() || ch == '_' || ch == '-') +} + +#[cfg(feature = "adbc-exec")] +fn parse_bigquery_history_target(uri: Option<&str>, context: &str) -> CliResult<(String, String)> { + let raw = uri.ok_or_else(|| { + format!( + "{context}: BigQuery query-history mode requires a --connection URL or --uri with project information" + ) + })?; + let normalized = raw.strip_prefix("bigquery://").unwrap_or(raw); + let (path_part, query_part) = normalized + .split_once('?') + .map_or((normalized, None), |(path, query)| (path, Some(query))); + let project = path_part + .split('/') + .next() + .map(str::trim) + .unwrap_or_default(); + if !is_safe_sql_identifier_fragment(project) { + return Err(format!( + "{context}: BigQuery project id must contain only letters, numbers, '_' or '-'" + )); + } + + let mut location = "us".to_string(); + if let Some(query) = query_part { + for (key, value) in parse_query_pairs(query) { + if (key == "location" || key == "region") && !value.trim().is_empty() { + location = value.trim().to_ascii_lowercase(); + } + } + } + if !is_safe_sql_identifier_fragment(&location) { + return Err(format!( + "{context}: BigQuery location must contain only letters, numbers, '_' or '-'" + )); + } + + Ok((project.to_string(), location)) +} - println!("Generated SQL:"); - println!("{}\n", generator.generate(&query).unwrap()); +#[cfg(feature = "adbc-exec")] +fn build_preagg_query_history_sql( + dialect: &str, + uri: Option<&str>, + days_back: usize, + limit: usize, + context: &str, +) -> CliResult { + match dialect { + "snowflake" => Ok(format!( + "SELECT query_text \ + FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(END_TIME_RANGE_START => DATEADD('day', -{days_back}, CURRENT_TIMESTAMP()))) \ + WHERE query_text LIKE '%-- sidemantic:%' \ + AND execution_status = 'SUCCESS' \ + ORDER BY start_time DESC \ + LIMIT {limit}" + )), + "bigquery" => { + let (project, location) = parse_bigquery_history_target(uri, context)?; + Ok(format!( + "SELECT query \ + FROM `{project}.region-{location}.INFORMATION_SCHEMA.JOBS_BY_PROJECT` \ + WHERE creation_time >= TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL {days_back} DAY) \ + AND job_type = 'QUERY' \ + AND state = 'DONE' \ + AND query LIKE '%-- sidemantic:%' \ + ORDER BY creation_time DESC \ + LIMIT {limit}" + )) + } + "databricks" => Ok(format!( + "SELECT statement_text \ + FROM system.query.history \ + WHERE start_time >= CURRENT_TIMESTAMP() - INTERVAL {days_back} DAYS \ + AND statement_text LIKE '%-- sidemantic:%' \ + AND status = 'FINISHED' \ + ORDER BY start_time DESC \ + LIMIT {limit}" + )), + "clickhouse" => Ok(format!( + "SELECT query \ + FROM system.query_log \ + WHERE event_time >= now() - INTERVAL {days_back} DAY \ + AND query LIKE '%-- sidemantic:%' \ + AND type = 'QueryFinish' \ + AND exception = '' \ + ORDER BY event_time DESC \ + LIMIT {limit}" + )), + _ => Err(format!( + "{context}: unsupported query-history dialect '{dialect}'" + )), + } } -fn demo_segments() { - println!("--- 4. Segments (Reusable Filters) ---\n"); +#[cfg(feature = "adbc-exec")] +fn adbc_value_text(value: &AdbcValue) -> Option { + match value { + AdbcValue::Null => None, + AdbcValue::Bool(v) => Some(v.to_string()), + AdbcValue::I64(v) => Some(v.to_string()), + AdbcValue::U64(v) => Some(v.to_string()), + AdbcValue::F64(v) => Some(v.to_string()), + AdbcValue::String(v) => Some(v.clone()), + AdbcValue::Bytes(v) => Some(String::from_utf8_lossy(v).to_string()), + } +} + +#[cfg(feature = "adbc-exec")] +fn fetch_preagg_query_history( + adbc: &AdbcCliConfig, + days_back: usize, + limit: usize, + context: &str, +) -> CliResult> { + let dialect = infer_preagg_history_dialect(&adbc.driver, adbc.uri.as_deref()).ok_or_else(|| { + format!( + "{context}: adapter does not support get_query_history(). Supported adapters: BigQueryAdapter, SnowflakeAdapter, DatabricksAdapter, ClickHouseAdapter" + ) + })?; + let history_sql = + build_preagg_query_history_sql(dialect, adbc.uri.as_deref(), days_back, limit, context)?; + let result = execute_with_adbc(AdbcExecutionRequest { + driver: adbc.driver.clone(), + sql: history_sql, + uri: adbc.uri.clone(), + entrypoint: adbc.entrypoint.clone(), + database_options: adbc.database_options.clone(), + connection_options: adbc.connection_options.clone(), + }) + .map_err(|e| format!("{context}: failed to fetch query history via ADBC: {e}"))?; + + let mut queries = Vec::new(); + for row in result.rows { + if let Some(first_col) = row.first() { + if let Some(text) = adbc_value_text(first_col) { + let trimmed = text.trim(); + if !trimmed.is_empty() { + queries.push(trimmed.to_string()); + } + } + } + } + Ok(queries) +} + +#[cfg(feature = "adbc-exec")] +fn parse_option_value(value: &str) -> OptionValue { + if let Some(rest) = value.strip_prefix("int:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Int(parsed); + } + } + if let Some(rest) = value.strip_prefix("float:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Double(parsed); + } + } + if let Some(rest) = value.strip_prefix("str:") { + return OptionValue::String(rest.to_string()); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Int(parsed); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Double(parsed); + } + OptionValue::String(value.to_string()) +} - let mut graph = SemanticGraph::new(); +#[cfg(feature = "adbc-exec")] +fn parse_bool_option(raw: &str, option_name: &str) -> CliResult { + match raw.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Ok(true), + "0" | "false" | "no" | "off" => Ok(false), + _ => Err(format!("{option_name} expects true/false, got '{raw}'")), + } +} + +#[cfg(feature = "adbc-exec")] +fn parse_database_options(values: &[String]) -> CliResult> { + let mut parsed = Vec::new(); + for value in values { + for (key, raw_value) in parse_kv_pairs(value, "--dbopt")? { + parsed.push(( + OptionDatabase::from(key.as_str()), + parse_option_value(&raw_value), + )); + } + } + Ok(parsed) +} + +#[cfg(feature = "adbc-exec")] +fn parse_connection_options(values: &[String]) -> CliResult> { + let mut parsed = Vec::new(); + for value in values { + for (key, raw_value) in parse_kv_pairs(value, "--connopt")? { + parsed.push(( + OptionConnection::from(key.as_str()), + parse_option_value(&raw_value), + )); + } + } + Ok(parsed) +} + +#[cfg(feature = "adbc-exec")] +fn normalize_adbc_driver_name(driver: &str) -> String { + if driver.starts_with("adbc_driver_") + || driver.contains('/') + || driver.contains('\\') + || driver.ends_with(".so") + || driver.ends_with(".dylib") + || driver.ends_with(".dll") + { + driver.to_string() + } else { + format!("adbc_driver_{driver}") + } +} + +#[cfg(feature = "adbc-exec")] +fn parse_query_pairs(query: &str) -> Vec<(String, String)> { + query + .split('&') + .filter(|fragment| !fragment.trim().is_empty()) + .filter_map(|fragment| { + let (key, value) = fragment.split_once('=').unwrap_or((fragment, "")); + if key.trim().is_empty() { + return None; + } + Some((key.trim().to_string(), value.to_string())) + }) + .collect() +} + +#[cfg(feature = "adbc-exec")] +pub(crate) fn parse_connection_url_to_adbc(connection: &str) -> CliResult { + let (scheme, remainder) = connection + .split_once("://") + .ok_or_else(|| format!("invalid connection URL '{connection}': expected scheme://..."))?; + let scheme = scheme.to_ascii_lowercase(); + let (path_part, query_part) = remainder + .split_once('?') + .map_or((remainder, None), |(path, query)| (path, Some(query))); + let query_pairs = query_part.map(parse_query_pairs).unwrap_or_default(); + + match scheme.as_str() { + "adbc" => { + let (driver_raw, path_uri) = path_part + .split_once('/') + .map_or((path_part, None), |(driver, path)| (driver, Some(path))); + if driver_raw.trim().is_empty() { + return Err( + "adbc:// URL must include a driver, e.g. adbc://duckdb?uri=:memory:" + .to_string(), + ); + } + + let mut uri = path_uri + .map(str::to_string) + .filter(|value| !value.trim().is_empty()); + let mut db_options: Vec<(OptionDatabase, OptionValue)> = Vec::new(); + for (key, raw_value) in query_pairs { + if key == "uri" { + if !raw_value.is_empty() { + uri = Some(raw_value); + } + continue; + } + db_options.push(( + OptionDatabase::from(key.as_str()), + parse_option_value(&raw_value), + )); + } + if driver_raw.eq_ignore_ascii_case("sqlite") && uri.is_none() { + uri = Some(":memory:".to_string()); + } + Ok((normalize_adbc_driver_name(driver_raw), uri, db_options)) + } + "duckdb" => { + let mut uri = if matches!(path_part, "" | "/" | ":memory:" | "/:memory:") { + ":memory:".to_string() + } else if path_part.starts_with("md:") || path_part.starts_with('/') { + path_part.to_string() + } else { + format!("//{path_part}") + }; + + if uri == "//" { + uri = ":memory:".to_string(); + } + + let db_options = query_pairs + .into_iter() + .map(|(key, raw_value)| { + ( + OptionDatabase::from(key.as_str()), + parse_option_value(&raw_value), + ) + }) + .collect::>(); + Ok((normalize_adbc_driver_name("duckdb"), Some(uri), db_options)) + } + "sqlite" => { + let trimmed = path_part.trim_start_matches('/'); + let uri = if trimmed.is_empty() { + ":memory:".to_string() + } else { + trimmed.to_string() + }; + Ok((normalize_adbc_driver_name("sqlite"), Some(uri), Vec::new())) + } + "databricks" | "spark" | "postgresql" | "postgres" | "mysql" | "snowflake" | "bigquery" + | "clickhouse" | "mssql" | "trino" | "redshift" => { + let driver = match scheme.as_str() { + "postgres" => "postgresql", + other => other, + }; + Ok(( + normalize_adbc_driver_name(driver), + Some(connection.to_string()), + Vec::new(), + )) + } + _ => Err(format!( + "unsupported connection URL scheme '{scheme}'. Use --driver/--uri for custom drivers." + )), + } +} - let orders = Model::new("orders", "order_id") - .with_table("orders") - .with_dimension(Dimension::categorical("status")) - .with_metric(Metric::sum("revenue", "amount")) - .with_segment(Segment::new("completed", "{model}.status = 'completed'")) - .with_segment(Segment::new("high_value", "{model}.amount > 100")); +#[cfg(feature = "adbc-exec")] +fn parse_adbc_cli_config_internal( + options: &ParsedOptions, + context: &str, + default_driver: Option, + default_uri: Option, + mut default_database_options: Vec<(OptionDatabase, OptionValue)>, +) -> CliResult { + let driver = option_value_any(options, &["--driver"]) + .or_else(|| env_non_empty("SIDEMANTIC_ADBC_DRIVER")) + .or_else(|| env_non_empty("SIDEMANTIC_MCP_ADBC_DRIVER")) + .or(default_driver) + .ok_or_else(|| { + format!("{context}: missing ADBC driver. Set --driver or SIDEMANTIC_ADBC_DRIVER.") + })?; + let uri = option_value_any(options, &["--uri", "--db-uri"]) + .or_else(|| env_non_empty("SIDEMANTIC_ADBC_URI")) + .or_else(|| env_non_empty("SIDEMANTIC_MCP_ADBC_URI")) + .or(default_uri); + let entrypoint = option_value(options, "--entrypoint") + .or_else(|| env_non_empty("SIDEMANTIC_ADBC_ENTRYPOINT")) + .or_else(|| env_non_empty("SIDEMANTIC_MCP_ADBC_ENTRYPOINT")); - graph.add_model(orders).unwrap(); + let mut database_options = std::mem::take(&mut default_database_options); + database_options.extend(parse_database_options(&option_values(options, "--dbopt"))?); + let mut connection_options = parse_connection_options(&option_values(options, "--connopt"))?; - let generator = SqlGenerator::new(&graph); + if let Some(env_opts) = env_non_empty("SIDEMANTIC_ADBC_DBOPTS") { + database_options.extend(parse_database_options(&[env_opts])?); + } + if let Some(env_opts) = env_non_empty("SIDEMANTIC_ADBC_CONNOPTS") { + connection_options.extend(parse_connection_options(&[env_opts])?); + } - // Query with segment - let query = SemanticQuery::new() - .with_metrics(vec!["orders.revenue".into()]) - .with_segments(vec!["orders.completed".into()]); + let username = option_value_any(options, &["--username", "--db-username"]) + .or_else(|| env_non_empty("SIDEMANTIC_ADBC_USERNAME")); + let password = option_value_any(options, &["--password", "--db-password"]) + .or_else(|| env_non_empty("SIDEMANTIC_ADBC_PASSWORD")); - println!("Query with 'completed' segment:"); - println!("{}\n", generator.generate(&query).unwrap()); + if let Some(username) = username { + database_options.push((OptionDatabase::Username, OptionValue::String(username))); + } + if let Some(password) = password { + database_options.push((OptionDatabase::Password, OptionValue::String(password))); + } - // Query with multiple segments - let query = SemanticQuery::new() - .with_metrics(vec!["orders.revenue".into()]) - .with_segments(vec!["orders.completed".into(), "orders.high_value".into()]); + if let Some(catalog) = option_value(options, "--catalog") { + connection_options.push(( + OptionConnection::CurrentCatalog, + OptionValue::String(catalog), + )); + } + if let Some(schema) = option_value(options, "--schema") { + connection_options.push((OptionConnection::CurrentSchema, OptionValue::String(schema))); + } + if let Some(raw) = option_value(options, "--autocommit") { + let value = parse_bool_option(&raw, "--autocommit")?; + connection_options.push(( + OptionConnection::AutoCommit, + OptionValue::String(value.to_string()), + )); + } + if let Some(raw) = option_value(options, "--read-only") { + let value = parse_bool_option(&raw, "--read-only")?; + connection_options.push(( + OptionConnection::ReadOnly, + OptionValue::String(value.to_string()), + )); + } + if let Some(level) = option_value(options, "--isolation-level") { + connection_options.push((OptionConnection::IsolationLevel, OptionValue::String(level))); + } - println!("Query with multiple segments:"); - println!("{}\n", generator.generate(&query).unwrap()); + Ok(AdbcCliConfig { + driver: normalize_adbc_driver_name(&driver), + uri, + entrypoint, + database_options, + connection_options, + }) } -fn demo_query_rewriter() { - println!("--- 5. Query Rewriter ---\n"); +#[cfg(feature = "adbc-exec")] +fn parse_adbc_cli_config(options: &ParsedOptions, context: &str) -> CliResult { + parse_adbc_cli_config_internal(options, context, None, None, Vec::new()) +} + +#[cfg(feature = "adbc-exec")] +fn parse_query_adbc_cli_config(options: &ParsedOptions, context: &str) -> CliResult { + let mut default_driver = None; + let mut default_uri = None; + let mut default_database_options = Vec::new(); + + if let Some(connection_url) = option_value(options, "--connection") { + let (driver, uri, db_options) = parse_connection_url_to_adbc(&connection_url)?; + default_driver = Some(driver); + default_uri = uri; + default_database_options = db_options; + } else if let Some(db_path) = option_value(options, "--db") { + default_driver = Some(normalize_adbc_driver_name("duckdb")); + default_uri = Some(db_path); + } + + parse_adbc_cli_config_internal( + options, + context, + default_driver, + default_uri, + default_database_options, + ) +} + +#[cfg(feature = "adbc-exec")] +fn execute_sql_statements(adbc: &AdbcCliConfig, statements: &[String]) -> CliResult<()> { + for statement in statements { + let _ = execute_with_adbc(AdbcExecutionRequest { + driver: adbc.driver.clone(), + sql: statement.clone(), + uri: adbc.uri.clone(), + entrypoint: adbc.entrypoint.clone(), + database_options: adbc.database_options.clone(), + connection_options: adbc.connection_options.clone(), + }) + .map_err(|e| format!("failed to execute statement via ADBC: {e}"))?; + } + Ok(()) +} + +#[cfg(feature = "adbc-exec")] +fn adbc_value_to_json(value: &AdbcValue) -> JsonValue { + match value { + AdbcValue::Null => JsonValue::Null, + AdbcValue::Bool(v) => JsonValue::Bool(*v), + AdbcValue::I64(v) => JsonValue::Number((*v).into()), + AdbcValue::U64(v) => JsonValue::Number(serde_json::Number::from(*v)), + AdbcValue::F64(v) => serde_json::Number::from_f64(*v) + .map(JsonValue::Number) + .unwrap_or(JsonValue::Null), + AdbcValue::String(v) => JsonValue::String(v.clone()), + AdbcValue::Bytes(v) => JsonValue::Array( + v.iter() + .map(|byte| JsonValue::Number(serde_json::Number::from(*byte))) + .collect(), + ), + } +} + +#[cfg(feature = "adbc-exec")] +fn csv_escape(value: &str) -> String { + if value.contains(',') || value.contains('"') || value.contains('\n') || value.contains('\r') { + format!("\"{}\"", value.replace('"', "\"\"")) + } else { + value.to_string() + } +} + +#[cfg(feature = "adbc-exec")] +fn adbc_value_to_csv_field(value: &AdbcValue) -> String { + match value { + AdbcValue::Null => String::new(), + AdbcValue::Bool(v) => v.to_string(), + AdbcValue::I64(v) => v.to_string(), + AdbcValue::U64(v) => v.to_string(), + AdbcValue::F64(v) => v.to_string(), + AdbcValue::String(v) => v.clone(), + AdbcValue::Bytes(v) => v.iter().map(|byte| format!("{byte:02x}")).collect(), + } +} + +#[cfg(feature = "adbc-exec")] +fn write_csv_rows( + columns: &[String], + rows: &[Vec], + output_path: Option<&str>, +) -> CliResult<()> { + let mut csv = String::new(); + csv.push_str( + &columns + .iter() + .map(|column| csv_escape(column)) + .collect::>() + .join(","), + ); + csv.push('\n'); + + for row in rows { + let mut serialized_row = Vec::with_capacity(columns.len()); + for (index, _) in columns.iter().enumerate() { + let raw = row + .get(index) + .map(adbc_value_to_csv_field) + .unwrap_or_else(String::new); + serialized_row.push(csv_escape(&raw)); + } + csv.push_str(&serialized_row.join(",")); + csv.push('\n'); + } + + if let Some(path) = output_path { + fs::write(path, csv).map_err(|e| format!("failed to write CSV output to '{path}': {e}"))?; + } else { + print!("{csv}"); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::split_queries; + #[cfg(feature = "adbc-exec")] + use super::{ + normalize_adbc_driver_name, parse_connection_url_to_adbc, render_duckdb_connection_url, + }; + #[cfg(all(unix, feature = "adbc-exec"))] + use std::path::Path; + + #[test] + fn test_split_queries_counts_semicolon_separated_instrumented_queries() { + let content = "\ +SELECT revenue FROM orders +-- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day; +SELECT revenue FROM orders +-- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day; +"; - let mut graph = SemanticGraph::new(); + let queries = split_queries(content); - let orders = Model::new("orders", "order_id") - .with_table("public.orders") - .with_dimension(Dimension::categorical("status")) - .with_metric(Metric::sum("revenue", "amount")) - .with_metric(Metric::count("order_count")); + assert_eq!(queries.len(), 2); + assert!(queries.iter().all(|query| query.contains("-- sidemantic:"))); + } - graph.add_model(orders).unwrap(); + #[cfg(feature = "adbc-exec")] + #[test] + fn normalize_adbc_driver_name_keeps_explicit_library_paths() { + assert_eq!( + normalize_adbc_driver_name("duckdb"), + "adbc_driver_duckdb".to_string() + ); + assert_eq!( + normalize_adbc_driver_name("adbc_driver_postgresql"), + "adbc_driver_postgresql".to_string() + ); + assert_eq!( + normalize_adbc_driver_name("/tmp/_duckdb.so"), + "/tmp/_duckdb.so".to_string() + ); + assert_eq!( + normalize_adbc_driver_name(".\\drivers\\duckdb.dll"), + ".\\drivers\\duckdb.dll".to_string() + ); + } - let rewriter = QueryRewriter::new(&graph); + #[cfg(all(unix, feature = "adbc-exec"))] + #[test] + fn duckdb_connection_url_preserves_absolute_path() { + let path = Path::new("/tmp/sidemantic-workbench.duckdb"); + let url = render_duckdb_connection_url(path); + assert_eq!(url, "duckdb:///tmp/sidemantic-workbench.duckdb"); - let sql = "SELECT orders.revenue, orders.status FROM orders WHERE orders.status = 'pending'"; - println!("Original SQL:"); - println!(" {sql}\n"); - println!("Rewritten SQL:"); - println!("{}\n", rewriter.rewrite(sql).unwrap()); + let (_, uri, _) = parse_connection_url_to_adbc(&url).unwrap(); + assert_eq!(uri.as_deref(), Some("/tmp/sidemantic-workbench.duckdb")); + } } diff --git a/sidemantic-rs/src/python.rs b/sidemantic-rs/src/python.rs new file mode 100644 index 00000000..e9e9d654 --- /dev/null +++ b/sidemantic-rs/src/python.rs @@ -0,0 +1,1395 @@ +//! Python bindings for sidemantic-rs via PyO3. + +#[cfg(feature = "python-adbc")] +use crate::db::{execute_with_adbc as execute_with_adbc_native, AdbcExecutionRequest, AdbcValue}; +use crate::error::SidemanticError; +use crate::runtime::{ + analyze_migrator_query as analyze_migrator_query_native, + build_preaggregation_refresh_statements as build_preaggregation_refresh_statements_native, + build_symmetric_aggregate_sql as build_symmetric_aggregate_sql_native, + calculate_preaggregation_benefit_score as calculate_preaggregation_benefit_score_native, + chart_auto_detect_columns as chart_auto_detect_columns_native, + chart_encoding_type as chart_encoding_type_native, + chart_format_label as chart_format_label_native, chart_select_type as chart_select_type_native, + compile_with_yaml_query as compile_with_yaml_query_native, + detect_adapter_kind as detect_adapter_kind_native, + dimension_sql_expr_with_yaml as dimension_sql_expr_with_yaml_native, + dimension_with_granularity_with_yaml as dimension_with_granularity_with_yaml_native, + evaluate_table_calculation_expression as evaluate_table_calculation_expression_native, + extract_column_references as extract_column_references_native, + extract_metric_dependencies_from_yaml as extract_metric_dependencies_from_yaml_native, + extract_preaggregation_patterns as extract_preaggregation_patterns_native, + find_models_for_query as find_models_for_query_native, + find_models_for_query_with_yaml as find_models_for_query_with_yaml_native, + find_relationship_path_with_yaml as find_relationship_path_with_yaml_native, + format_parameter_value_with_yaml as format_parameter_value_with_yaml_native, + generate_catalog_metadata_with_yaml as generate_catalog_metadata_with_yaml_native, + generate_preaggregation_definition as generate_preaggregation_definition_native, + generate_preaggregation_materialization_sql_with_yaml as generate_preaggregation_materialization_sql_with_yaml_native, + generate_preaggregation_name as generate_preaggregation_name_native, + generate_time_comparison_sql as generate_time_comparison_sql_native, + interpolate_sql_with_parameters_with_yaml as interpolate_sql_with_parameters_with_yaml_native, + is_relative_date as is_relative_date_native, is_sql_template as is_sql_template_native, + load_graph_from_directory as load_graph_from_directory_native, + load_graph_with_sql as load_graph_with_sql_native, + load_graph_with_yaml as load_graph_with_yaml_native, + metric_is_simple_aggregation as metric_is_simple_aggregation_native, + metric_sql_expr as metric_sql_expr_native, metric_to_sql as metric_to_sql_native, + model_find_dimension_index_with_yaml as model_find_dimension_index_with_yaml_native, + model_find_metric_index_with_yaml as model_find_metric_index_with_yaml_native, + model_find_pre_aggregation_index_with_yaml as model_find_pre_aggregation_index_with_yaml_native, + model_find_segment_index_with_yaml as model_find_segment_index_with_yaml_native, + model_get_drill_down_with_yaml as model_get_drill_down_with_yaml_native, + model_get_drill_up_with_yaml as model_get_drill_up_with_yaml_native, + model_get_hierarchy_path_with_yaml as model_get_hierarchy_path_with_yaml_native, + needs_symmetric_aggregate as needs_symmetric_aggregate_native, + parse_reference_with_yaml as parse_reference_with_yaml_native, + parse_relative_date as parse_relative_date_native, + parse_simple_metric_aggregation as parse_simple_metric_aggregation_native, + parse_sql_definitions_payload as parse_sql_definitions_payload_native, + parse_sql_graph_definitions_payload as parse_sql_graph_definitions_payload_native, + parse_sql_model_payload as parse_sql_model_payload_native, + parse_sql_statement_blocks_payload as parse_sql_statement_blocks_payload_native, + plan_preaggregation_refresh_execution as plan_preaggregation_refresh_execution_native, + recommend_preaggregation_patterns as recommend_preaggregation_patterns_native, + relationship_foreign_key_columns_with_yaml as relationship_foreign_key_columns_with_yaml_native, + relationship_primary_key_columns_with_yaml as relationship_primary_key_columns_with_yaml_native, + relationship_related_key_with_yaml as relationship_related_key_with_yaml_native, + relationship_sql_expr_with_yaml as relationship_sql_expr_with_yaml_native, + relative_date_to_range as relative_date_to_range_native, + render_sql_template as render_sql_template_native, + resolve_metric_inheritance as resolve_metric_inheritance_native, + resolve_model_inheritance_with_yaml as resolve_model_inheritance_with_yaml_native, + resolve_preaggregation_refresh_mode as resolve_preaggregation_refresh_mode_native, + rewrite_with_yaml as rewrite_with_yaml_native, + segment_get_sql_with_yaml as segment_get_sql_with_yaml_native, + shape_preaggregation_refresh_result as shape_preaggregation_refresh_result_native, + summarize_preaggregation_patterns as summarize_preaggregation_patterns_native, + time_comparison_offset_interval as time_comparison_offset_interval_native, + time_comparison_sql_offset as time_comparison_sql_offset_native, + trailing_period_sql_interval as trailing_period_sql_interval_native, + validate_engine_refresh_sql_compatibility as validate_engine_refresh_sql_compatibility_native, + validate_metric_payload as validate_metric_payload_native, + validate_model_payload as validate_model_payload_native, + validate_models_yaml as validate_models_yaml_native, + validate_parameter_payload as validate_parameter_payload_native, + validate_preaggregation_refresh_request as validate_preaggregation_refresh_request_native, + validate_query_references_with_yaml as validate_query_references_with_yaml_native, + validate_query_with_yaml as validate_query_with_yaml_native, + validate_table_calculation_payload as validate_table_calculation_payload_native, + validate_table_formula_expression as validate_table_formula_expression_native, + RelationshipPathError, +}; +#[cfg(feature = "python-adbc")] +use adbc_core::{ + constants, + options::{OptionConnection, OptionDatabase, OptionValue}, +}; +use pyo3::exceptions::PyKeyError; +use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; +#[cfg(feature = "python-adbc")] +use pyo3::types::{PyBool, PyBytes, PyString}; +use pyo3::types::{PyDict, PyList, PyTuple}; + +type PyRelationshipPath = Vec<(String, String, Vec, Vec, String)>; + +static REGISTRY_CONTEXTVAR: GILOnceCell> = GILOnceCell::new(); + +fn registry_contextvar(py: Python<'_>) -> PyResult<&Py> { + REGISTRY_CONTEXTVAR.get_or_try_init(py, || { + let contextvars = py.import("contextvars")?; + let contextvar_type = contextvars.getattr("ContextVar")?; + let kwargs = PyDict::new(py); + kwargs.set_item("default", py.None())?; + let contextvar = contextvar_type.call(("sidemantic_rs_current_layer",), Some(&kwargs))?; + Ok(contextvar.unbind()) + }) +} + +/// Rewrite SQL using semantic models provided as YAML. +/// +/// This is stateless on purpose so Python tests can call it safely across runs. +#[pyfunction] +fn rewrite_with_yaml(yaml: &str, sql: &str) -> PyResult { + rewrite_with_yaml_native(yaml, sql).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +#[pyfunction] +fn is_sql_template(sql: &str) -> bool { + is_sql_template_native(sql) +} + +#[pyfunction] +fn render_sql_template(template_str: &str, context_yaml: &str) -> PyResult { + render_sql_template_native(template_str, context_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn format_parameter_value_with_yaml(parameter_yaml: &str, value_yaml: &str) -> PyResult { + format_parameter_value_with_yaml_native(parameter_yaml, value_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn interpolate_sql_with_parameters( + sql: &str, + parameters_yaml: &str, + values_yaml: &str, +) -> PyResult { + interpolate_sql_with_parameters_with_yaml_native(sql, parameters_yaml, values_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn evaluate_table_calculation_expression(expr: &str) -> PyResult { + evaluate_table_calculation_expression_native(expr) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn validate_table_formula_expression(expression: &str) -> PyResult { + validate_table_formula_expression_native(expression) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +#[pyo3(signature = (measure_expr, primary_key, agg_type, model_alias = None, dialect = "duckdb"))] +fn build_symmetric_aggregate_sql( + measure_expr: &str, + primary_key: &str, + agg_type: &str, + model_alias: Option<&str>, + dialect: &str, +) -> PyResult { + build_symmetric_aggregate_sql_native(measure_expr, primary_key, agg_type, model_alias, dialect) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn needs_symmetric_aggregate(relationship: &str, is_base_model: bool) -> bool { + needs_symmetric_aggregate_native(relationship, is_base_model) +} + +#[pyfunction] +#[pyo3(signature = (expr, dialect = "duckdb"))] +fn parse_relative_date(expr: &str, dialect: &str) -> Option { + parse_relative_date_native(expr, dialect) +} + +#[pyfunction] +#[pyo3(signature = (expr, column = "date_col", dialect = "duckdb"))] +fn relative_date_to_range(expr: &str, column: &str, dialect: &str) -> Option { + relative_date_to_range_native(expr, column, dialect) +} + +#[pyfunction] +fn is_relative_date(expr: &str) -> bool { + is_relative_date_native(expr) +} + +#[pyfunction] +#[pyo3(signature = (comparison_type, offset = None, offset_unit = None))] +fn time_comparison_offset_interval( + comparison_type: &str, + offset: Option, + offset_unit: Option<&str>, +) -> PyResult<(i64, String)> { + time_comparison_offset_interval_native(comparison_type, offset, offset_unit) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +#[pyo3(signature = (comparison_type, offset = None, offset_unit = None))] +fn time_comparison_sql_offset( + comparison_type: &str, + offset: Option, + offset_unit: Option<&str>, +) -> PyResult { + time_comparison_sql_offset_native(comparison_type, offset, offset_unit) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +fn trailing_period_sql_interval(amount: i64, unit: &str) -> PyResult { + trailing_period_sql_interval_native(amount, unit) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction] +#[pyo3(signature = (comparison_type, calculation, current_metric_sql, time_dimension, offset = None, offset_unit = None))] +fn generate_time_comparison_sql( + comparison_type: &str, + calculation: &str, + current_metric_sql: &str, + time_dimension: &str, + offset: Option, + offset_unit: Option<&str>, +) -> PyResult { + generate_time_comparison_sql_native( + comparison_type, + calculation, + current_metric_sql, + time_dimension, + offset, + offset_unit, + ) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Compile a semantic query using sidemantic-rs SQL generator. +/// +/// `query_yaml` payload supports: +/// - metrics: [str] +/// - dimensions: [str] +/// - filters: [str] +/// - segments: [str] +/// - order_by: [str] +/// - limit: int | null +#[pyfunction] +fn compile_with_yaml(yaml: &str, query_yaml: &str) -> PyResult { + compile_with_yaml_query_native(yaml, query_yaml).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse native YAML definitions and return a serialized graph payload. +#[pyfunction] +fn load_graph_with_yaml(yaml: &str) -> PyResult { + load_graph_with_yaml_native(yaml).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse SQL file content definitions and return a serialized graph payload. +#[pyfunction] +fn load_graph_with_sql(sql_content: &str) -> PyResult { + load_graph_with_sql_native(sql_content).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse supported graph definitions from a directory and return a serialized graph payload. +#[pyfunction] +fn load_graph_from_directory(path: &str) -> PyResult { + load_graph_from_directory_native(path).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse SQL metric/segment definitions and return serialized payload. +#[pyfunction] +fn parse_sql_definitions_payload(sql: &str) -> PyResult { + parse_sql_definitions_payload_native(sql).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse SQL graph definitions and return serialized payload. +#[pyfunction] +fn parse_sql_graph_definitions_payload(sql: &str) -> PyResult { + parse_sql_graph_definitions_payload_native(sql).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse SQL model definition and return serialized model payload. +#[pyfunction] +fn parse_sql_model_payload(sql: &str) -> PyResult { + parse_sql_model_payload_native(sql).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Parse raw SQL statement blocks and return serialized payload. +#[pyfunction] +fn parse_sql_statement_blocks_payload(sql: &str) -> PyResult { + parse_sql_statement_blocks_payload_native(sql).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Set current layer in a Rust-owned ContextVar. +#[pyfunction] +#[pyo3(signature = (layer = None))] +fn registry_set_current_layer(py: Python<'_>, layer: Option>) -> PyResult<()> { + let contextvar = registry_contextvar(py)?.bind(py); + match layer { + Some(value) => { + contextvar.call_method1("set", (value.bind(py),))?; + } + None => { + contextvar.call_method1("set", (py.None(),))?; + } + } + Ok(()) +} + +/// Get current layer from a Rust-owned ContextVar. +#[pyfunction] +fn registry_get_current_layer(py: Python<'_>) -> PyResult>> { + let contextvar = registry_contextvar(py)?.bind(py); + let value = contextvar.call_method0("get")?; + if value.is_none() { + Ok(None) + } else { + Ok(Some(value.unbind())) + } +} + +/// Validate a semantic query using sidemantic-rs graph semantics. +/// +/// Returns a Python-compatible list of validation error strings. +#[pyfunction] +fn validate_query_with_yaml(yaml: &str, query_yaml: &str) -> PyResult> { + validate_query_with_yaml_native(yaml, query_yaml).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate query metric/dimension references against model YAML. +#[pyfunction] +fn validate_query_references( + yaml: &str, + metrics: Vec, + dimensions: Vec, +) -> PyResult> { + validate_query_references_with_yaml_native(yaml, &metrics, &dimensions).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Generate materialization SQL for a model pre-aggregation using sidemantic-rs schema. +#[pyfunction] +fn generate_preaggregation_materialization_sql( + yaml: &str, + model_name: &str, + preagg_name: &str, +) -> PyResult { + generate_preaggregation_materialization_sql_with_yaml_native(yaml, model_name, preagg_name) + .map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate engine refresh SQL compatibility with materialized-view restrictions. +#[pyfunction] +fn validate_engine_refresh_sql_compatibility( + source_sql: &str, + dialect: &str, +) -> (bool, Option) { + validate_engine_refresh_sql_compatibility_native(source_sql, dialect) +} + +/// Build SQL statements for a pre-aggregation refresh operation. +#[pyfunction] +#[pyo3(signature = ( + mode, + table_name, + source_sql, + watermark_column = None, + from_watermark = None, + lookback = None, + dialect = None, + refresh_every = None +))] +#[allow(clippy::too_many_arguments)] +fn build_preaggregation_refresh_statements( + mode: &str, + table_name: &str, + source_sql: &str, + watermark_column: Option<&str>, + from_watermark: Option<&str>, + lookback: Option<&str>, + dialect: Option<&str>, + refresh_every: Option<&str>, +) -> PyResult> { + build_preaggregation_refresh_statements_native( + mode, + table_name, + source_sql, + watermark_column, + from_watermark, + lookback, + dialect, + refresh_every, + ) + .map_err(map_refresh_planner_error) +} + +/// Resolve refresh mode from explicit value or incremental default flag. +#[pyfunction] +#[pyo3(signature = (mode = None, refresh_incremental = false))] +fn resolve_preaggregation_refresh_mode( + mode: Option<&str>, + refresh_incremental: bool, +) -> PyResult { + resolve_preaggregation_refresh_mode_native(mode, refresh_incremental) + .map_err(map_refresh_planner_error) +} + +/// Validate pre-aggregation refresh mode requirements. +#[pyfunction] +#[pyo3(signature = (mode, watermark_column = None, dialect = None))] +fn validate_preaggregation_refresh_request( + mode: &str, + watermark_column: Option<&str>, + dialect: Option<&str>, +) -> PyResult { + validate_preaggregation_refresh_request_native(mode, watermark_column, dialect) + .map_err(map_refresh_planner_error)?; + Ok(true) +} + +/// Plan pre-aggregation refresh execution mode and branch requirements. +#[pyfunction] +#[pyo3(signature = (mode = None, refresh_incremental = false, watermark_column = None, dialect = None))] +fn plan_preaggregation_refresh_execution( + py: Python<'_>, + mode: Option<&str>, + refresh_incremental: bool, + watermark_column: Option<&str>, + dialect: Option<&str>, +) -> PyResult> { + let plan = plan_preaggregation_refresh_execution_native( + mode, + refresh_incremental, + watermark_column, + dialect, + ) + .map_err(map_refresh_planner_error)?; + + let payload = PyDict::new(py); + payload.set_item("mode", &plan.mode)?; + payload.set_item("requires_prior_watermark", plan.requires_prior_watermark)?; + payload.set_item( + "requires_merge_table_existence_check", + plan.requires_merge_table_existence_check, + )?; + payload.set_item("include_new_watermark", plan.include_new_watermark)?; + Ok(payload.into_any().unbind()) +} + +/// Compute recommender benefit score for a serialized pattern payload. +#[pyfunction] +#[pyo3(signature = (pattern_json, count))] +fn calculate_preaggregation_benefit_score(pattern_json: &str, count: usize) -> PyResult { + calculate_preaggregation_benefit_score_native(pattern_json, count) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Generate recommender name for a serialized pattern payload. +#[pyfunction] +fn generate_preaggregation_name(pattern_json: &str) -> PyResult { + generate_preaggregation_name_native(pattern_json) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Parse instrumented query comments into pre-aggregation pattern counts. +#[pyfunction] +fn extract_preaggregation_patterns(queries: Vec) -> PyResult { + extract_preaggregation_patterns_native(queries) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Build pre-aggregation recommendations from known pattern counts. +#[pyfunction] +#[pyo3(signature = (patterns_json, min_query_count, min_benefit_score, top_n = None))] +fn recommend_preaggregation_patterns( + patterns_json: &str, + min_query_count: usize, + min_benefit_score: f64, + top_n: Option, +) -> PyResult { + recommend_preaggregation_patterns_native( + patterns_json, + min_query_count, + min_benefit_score, + top_n, + ) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Summarize pattern counts for recommendation reporting. +#[pyfunction] +fn summarize_preaggregation_patterns( + patterns_json: &str, + min_query_count: usize, +) -> PyResult { + summarize_preaggregation_patterns_native(patterns_json, min_query_count) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Convert a recommendation payload into a pre-aggregation definition payload. +#[pyfunction] +fn generate_preaggregation_definition(recommendation_json: &str) -> PyResult { + generate_preaggregation_definition_native(recommendation_json) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +fn execute_sql<'py>(connection: &Bound<'py, PyAny>, sql: &str) -> PyResult> { + connection.call_method1("execute", (sql,)) +} + +fn extract_first_row_value(row: &Bound<'_, PyAny>) -> PyResult>> { + if row.is_none() { + return Ok(None); + } + + if let Ok(tuple) = row.downcast::() { + if tuple.is_empty() { + return Ok(None); + } + let item = tuple.get_item(0)?; + if item.is_none() { + Ok(None) + } else { + Ok(Some(item.unbind())) + } + } else if let Ok(list) = row.downcast::() { + if list.is_empty() { + return Ok(None); + } + let item = list.get_item(0)?; + if item.is_none() { + Ok(None) + } else { + Ok(Some(item.unbind())) + } + } else { + let item = row.get_item(0)?; + if item.is_none() { + Ok(None) + } else { + Ok(Some(item.unbind())) + } + } +} + +fn get_current_watermark( + connection: &Bound<'_, PyAny>, + table_name: &str, + watermark_column: &str, +) -> Option> { + let sql = format!("SELECT MAX({watermark_column}) as max_watermark FROM {table_name}"); + let cursor = execute_sql(connection, &sql).ok()?; + let row = cursor.call_method0("fetchone").ok()?; + extract_first_row_value(&row).ok()? +} + +fn py_value_to_i64(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(v) = value.extract::() { + return Ok(v); + } + if let Ok(v) = value.extract::() { + return i64::try_from(v) + .map_err(|_| PyRuntimeError::new_err("count value exceeded i64 range")); + } + if let Ok(v) = value.extract::() { + return Ok(v as i64); + } + Err(PyRuntimeError::new_err( + "failed to parse numeric result from database cursor", + )) +} + +fn watermark_to_refresh_value(value: &Bound<'_, PyAny>) -> PyResult { + let rendered = value.str()?.to_str()?.to_string(); + if rendered.len() >= 2 && rendered.starts_with('\'') && rendered.ends_with('\'') { + return Ok(rendered[1..rendered.len() - 1].to_string()); + } + Ok(rendered) +} + +fn build_refresh_result_dict( + py: Python<'_>, + mode: &str, + rows_inserted: i64, + rows_updated: i64, + new_watermark: Option>, +) -> PyResult> { + let payload = PyDict::new(py); + payload.set_item("mode", mode)?; + payload.set_item("rows_inserted", rows_inserted)?; + payload.set_item("rows_updated", rows_updated)?; + if let Some(value) = new_watermark { + payload.set_item("new_watermark", value.bind(py))?; + } else { + payload.set_item("new_watermark", py.None())?; + } + Ok(payload.into_any().unbind()) +} + +fn map_refresh_planner_error(err: SidemanticError) -> PyErr { + match err { + SidemanticError::Validation(message) => { + if let Some(rest) = + message.strip_prefix("unsupported dialect for engine refresh mode: ") + { + let dialect = rest.split('.').next().unwrap_or(rest).trim(); + return PyValueError::new_err(format!( + "Unsupported dialect for engine mode: {dialect}" + )); + } + PyValueError::new_err(message) + } + other => PyRuntimeError::new_err(other.to_string()), + } +} + +fn table_exists(connection: &Bound<'_, PyAny>, table_name: &str) -> bool { + execute_sql(connection, &format!("SELECT 1 FROM {table_name} LIMIT 1")).is_ok() +} + +fn execute_refresh_statements( + connection: &Bound<'_, PyAny>, + statements: &[String], +) -> PyResult<()> { + for statement in statements { + execute_sql(connection, statement)?; + } + Ok(()) +} + +/// Execute a pre-aggregation refresh strategy using a Python DB connection. +#[pyfunction] +#[pyo3(signature = ( + connection, + source_sql, + table_name, + mode = None, + watermark_column = None, + lookback = None, + from_watermark = None, + to_watermark = None, + dialect = None, + refresh_incremental = false, + refresh_every = None +))] +#[allow(clippy::too_many_arguments)] +fn refresh_preaggregation( + py: Python<'_>, + connection: &Bound<'_, PyAny>, + source_sql: &str, + table_name: &str, + mode: Option<&str>, + watermark_column: Option<&str>, + lookback: Option<&str>, + from_watermark: Option>, + to_watermark: Option>, + dialect: Option<&str>, + refresh_incremental: bool, + refresh_every: Option<&str>, +) -> PyResult> { + let _ = to_watermark; + let refresh_plan = plan_preaggregation_refresh_execution_native( + mode, + refresh_incremental, + watermark_column, + dialect, + ) + .map_err(map_refresh_planner_error)?; + + let mut normalized_watermark: Option = None; + let mut merge_target_table_existed = false; + if refresh_plan.requires_prior_watermark { + let watermark_column = watermark_column.unwrap_or_default(); + let mut current_watermark = from_watermark; + if current_watermark.is_none() { + current_watermark = get_current_watermark(connection, table_name, watermark_column); + } + if let Some(value) = current_watermark { + normalized_watermark = Some(watermark_to_refresh_value(value.bind(py))?); + } + if refresh_plan.requires_merge_table_existence_check { + merge_target_table_existed = table_exists(connection, table_name); + } + } + + let statements = build_preaggregation_refresh_statements_native( + &refresh_plan.mode, + table_name, + source_sql, + watermark_column, + normalized_watermark.as_deref(), + lookback, + dialect, + refresh_every, + ) + .map_err(map_refresh_planner_error)?; + execute_refresh_statements(connection, &statements)?; + + let full_rows_inserted = if refresh_plan.mode == "full" { + let count_cursor = execute_sql(connection, &format!("SELECT COUNT(*) FROM {table_name}"))?; + let count_row = count_cursor.call_method0("fetchone")?; + if let Some(value) = extract_first_row_value(&count_row)? { + py_value_to_i64(value.bind(py))? + } else { + 0 + } + } else { + 0 + }; + + let result_shape = shape_preaggregation_refresh_result_native( + &refresh_plan.mode, + merge_target_table_existed, + full_rows_inserted, + ) + .map_err(map_refresh_planner_error)?; + let new_watermark = if refresh_plan.include_new_watermark { + let watermark_column = watermark_column.unwrap_or_default(); + get_current_watermark(connection, table_name, watermark_column) + } else { + None + }; + build_refresh_result_dict( + py, + &result_shape.mode, + result_shape.rows_inserted, + result_shape.rows_updated, + new_watermark, + ) +} + +/// Validate models YAML by loading it into a Rust semantic graph. +#[pyfunction] +fn validate_models_yaml(yaml: &str) -> PyResult { + validate_models_yaml_native(yaml).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate model payload shape in Rust. +#[pyfunction] +fn validate_model_payload(model_yaml: &str) -> PyResult { + validate_model_payload_native(model_yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Resolve model inheritance in Rust and return resolved models as YAML. +#[pyfunction] +fn resolve_model_inheritance(yaml: &str) -> PyResult { + resolve_model_inheritance_with_yaml_native(yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Resolve metric inheritance in Rust and return resolved metrics as YAML. +#[pyfunction] +fn resolve_metric_inheritance(metrics_yaml: &str) -> PyResult { + resolve_metric_inheritance_native(metrics_yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + SidemanticError::SqlGeneration(msg) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate metric payload shape in Rust. +#[pyfunction] +fn validate_metric_payload(metric_yaml: &str) -> PyResult { + validate_metric_payload_native(metric_yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate parameter payload shape in Rust. +#[pyfunction] +fn validate_parameter_payload(parameter_yaml: &str) -> PyResult { + validate_parameter_payload_native(parameter_yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Validate table-calculation payload shape in Rust. +#[pyfunction] +fn validate_table_calculation_payload(calculation_yaml: &str) -> PyResult { + validate_table_calculation_payload_native(calculation_yaml).map_err(|e| match e { + SidemanticError::Validation(msg) => PyValueError::new_err(msg), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Export semantic graph catalog metadata in Postgres-compatible format. +#[pyfunction] +#[pyo3(signature = (yaml, schema = "public"))] +fn generate_catalog_metadata(yaml: &str, schema: &str) -> PyResult { + generate_catalog_metadata_with_yaml_native(yaml, schema).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +#[cfg(feature = "python-adbc")] +fn adbc_error(context: &str, err: impl std::fmt::Display) -> PyErr { + PyRuntimeError::new_err(format!("{context}: {err}")) +} + +#[cfg(feature = "python-adbc")] +fn py_value_to_option_value(value: &Bound<'_, PyAny>) -> PyResult { + if value.is_instance_of::() { + let flag = value.extract::()?; + let text = if flag { "true" } else { "false" }; + return Ok(OptionValue::String(text.into())); + } + + if value.is_instance_of::() { + return Ok(OptionValue::String(value.extract::()?)); + } + + if value.is_instance_of::() { + return Ok(OptionValue::Bytes(value.extract::>()?)); + } + + if let Ok(number) = value.extract::() { + return Ok(OptionValue::Int(number)); + } + + if let Ok(number) = value.extract::() { + return Ok(OptionValue::Double(number)); + } + + let text = value.str()?.to_str()?.to_owned(); + Ok(OptionValue::String(text)) +} + +#[cfg(feature = "python-adbc")] +fn merge_database_options( + uri: Option<&str>, + db_kwargs: Option<&Bound<'_, PyDict>>, +) -> PyResult> { + let mut options: Vec<(OptionDatabase, OptionValue)> = Vec::new(); + + if let Some(kwargs) = db_kwargs { + for (key, value) in kwargs.iter() { + if value.is_none() { + continue; + } + let key = key.extract::()?; + if uri.is_some() && key == constants::ADBC_OPTION_URI { + continue; + } + options.push(( + OptionDatabase::from(key.as_str()), + py_value_to_option_value(&value)?, + )); + } + } + + Ok(options) +} + +#[cfg(feature = "python-adbc")] +fn merge_connection_options( + conn_kwargs: Option<&Bound<'_, PyDict>>, + autocommit: bool, +) -> PyResult> { + let mut options: Vec<(OptionConnection, OptionValue)> = Vec::new(); + + if let Some(kwargs) = conn_kwargs { + for (key, value) in kwargs.iter() { + if value.is_none() { + continue; + } + let key = key.extract::()?; + options.push(( + OptionConnection::from(key.as_str()), + py_value_to_option_value(&value)?, + )); + } + } + + options.push(( + OptionConnection::AutoCommit, + OptionValue::String(if autocommit { "true" } else { "false" }.to_owned()), + )); + Ok(options) +} + +#[cfg(feature = "python-adbc")] +fn adbc_value_to_py(py: Python<'_>, value: &AdbcValue) -> PyResult> { + match value { + AdbcValue::Null => Ok(py.None()), + AdbcValue::Bool(v) => Ok(PyBool::new(py, *v).to_owned().into_any().unbind()), + AdbcValue::I64(v) => Ok(v.into_pyobject(py)?.into_any().unbind()), + AdbcValue::U64(v) => Ok(v.into_pyobject(py)?.into_any().unbind()), + AdbcValue::F64(v) => Ok(v.into_pyobject(py)?.into_any().unbind()), + AdbcValue::String(v) => Ok(v.into_pyobject(py)?.into_any().unbind()), + AdbcValue::Bytes(v) => Ok(PyBytes::new(py, v).into_any().unbind()), + } +} + +/// Execute SQL via the Rust ADBC driver manager and return rows/columns. +#[cfg(feature = "python-adbc")] +#[pyfunction] +#[pyo3(signature = (driver, sql, uri=None, entrypoint=None, db_kwargs=None, conn_kwargs=None, autocommit=true))] +#[allow(clippy::too_many_arguments)] +fn execute_with_adbc( + py: Python<'_>, + driver: &str, + sql: &str, + uri: Option<&str>, + entrypoint: Option<&str>, + db_kwargs: Option<&Bound<'_, PyDict>>, + conn_kwargs: Option<&Bound<'_, PyDict>>, + autocommit: bool, +) -> PyResult> { + let database_options = merge_database_options(uri, db_kwargs)?; + let connection_options = merge_connection_options(conn_kwargs, autocommit)?; + let payload = execute_with_adbc_native(AdbcExecutionRequest { + driver: driver.to_string(), + sql: sql.to_string(), + uri: uri.map(str::to_owned), + entrypoint: entrypoint.map(str::to_owned), + database_options, + connection_options, + }) + .map_err(|e| adbc_error("rust ADBC execution failed", e))?; + + let mut rows: Vec> = Vec::with_capacity(payload.rows.len()); + for row in &payload.rows { + let mut values: Vec> = Vec::with_capacity(row.len()); + for value in row { + values.push(adbc_value_to_py(py, value)?); + } + let tuple = PyTuple::new(py, values)?; + rows.push(tuple.into_any().unbind()); + } + + let result = PyDict::new(py); + result.set_item("columns", payload.columns)?; + result.set_item("rows", rows)?; + Ok(result.into_any().unbind()) +} + +/// Report disabled ADBC execution in lightweight Python builds. +#[cfg(not(feature = "python-adbc"))] +#[pyfunction] +#[pyo3(signature = (driver, sql, uri=None, entrypoint=None, db_kwargs=None, conn_kwargs=None, autocommit=true))] +#[allow(unused_variables, clippy::too_many_arguments)] +fn execute_with_adbc( + driver: &str, + sql: &str, + uri: Option<&str>, + entrypoint: Option<&str>, + db_kwargs: Option<&Bound<'_, PyDict>>, + conn_kwargs: Option<&Bound<'_, PyDict>>, + autocommit: bool, +) -> PyResult<()> { + Err(PyRuntimeError::new_err( + "ADBC execution support is not enabled. Rebuild sidemantic-rs with feature 'python-adbc' to use execute_with_adbc.", + )) +} + +/// Detect adapter kind from file path and content. +#[pyfunction] +fn detect_adapter_kind(path: &str, content: &str) -> Option { + detect_adapter_kind_native(path, content) +} + +/// Extract column references from a SQL expression. +#[pyfunction] +fn extract_column_references(sql_expr: &str) -> Vec { + extract_column_references_native(sql_expr) +} + +/// Analyze query components for migrator helper extraction. +#[pyfunction] +fn analyze_migrator_query(sql_query: &str) -> PyResult { + analyze_migrator_query_native(sql_query).map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Auto-detect chart x/y columns from ordered columns + numeric flags. +#[pyfunction] +fn chart_auto_detect_columns( + columns: Vec, + numeric_flags: Vec, +) -> PyResult<(String, Vec)> { + chart_auto_detect_columns_native(&columns, &numeric_flags) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Select chart type from x-column semantics and y-count. +#[pyfunction] +fn chart_select_type(x: &str, x_value_kind: &str, y_count: usize) -> String { + chart_select_type_native(x, x_value_kind, y_count) +} + +/// Format chart labels using Python-compatible naming rules. +#[pyfunction] +fn chart_format_label(column: &str) -> String { + chart_format_label_native(column) +} + +/// Determine chart encoding type from column name. +#[pyfunction] +fn chart_encoding_type(column: &str) -> String { + chart_encoding_type_native(column) +} + +/// Extract metric dependencies from a metric payload, with optional graph/context resolution. +#[pyfunction] +#[pyo3(signature = (metric_yaml, models_yaml = None, model_context = None))] +fn extract_metric_dependencies( + metric_yaml: &str, + models_yaml: Option<&str>, + model_context: Option<&str>, +) -> PyResult> { + extract_metric_dependencies_from_yaml_native(metric_yaml, models_yaml, model_context) + .map_err(|e| PyValueError::new_err(format!("failed to extract metric dependencies: {e}"))) +} + +/// Parse a top-level simple metric aggregation expression. +/// +/// Returns `(agg, inner_sql)` for simple aggregations like: +/// - `SUM(amount)` -> `("sum", Some("amount"))` +/// - `COUNT(*)` -> `("count", None)` +/// - `COUNT(DISTINCT user_id)` -> `("count_distinct", Some("user_id"))` +/// +/// Returns `None` for non-simple/complex expressions. +#[pyfunction] +fn parse_simple_metric_aggregation(sql_expr: &str) -> Option<(String, Option)> { + parse_simple_metric_aggregation_native(sql_expr) +} + +/// Convert a metric payload to SQL aggregation expression. +#[pyfunction] +fn metric_to_sql(metric_yaml: &str) -> PyResult { + metric_to_sql_native(metric_yaml).map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve metric SQL expression with Python-compatible count fallback. +#[pyfunction] +fn metric_sql_expr(metric_yaml: &str) -> PyResult { + metric_sql_expr_native(metric_yaml).map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Determine whether metric payload represents a simple aggregation. +#[pyfunction] +fn metric_is_simple_aggregation(metric_yaml: &str) -> PyResult { + metric_is_simple_aggregation_native(metric_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve dimension SQL expression. +#[pyfunction] +fn dimension_sql_expr(dimension_yaml: &str) -> PyResult { + dimension_sql_expr_with_yaml_native(dimension_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Apply time granularity to a dimension SQL expression. +#[pyfunction] +fn dimension_with_granularity(dimension_yaml: &str, granularity: &str) -> PyResult { + dimension_with_granularity_with_yaml_native(dimension_yaml, granularity) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get full hierarchy path from root to a given dimension. +#[pyfunction] +fn model_get_hierarchy_path(model_yaml: &str, dimension_name: &str) -> PyResult> { + model_get_hierarchy_path_with_yaml_native(model_yaml, dimension_name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get first child dimension in hierarchy order for a given dimension. +#[pyfunction] +fn model_get_drill_down(model_yaml: &str, dimension_name: &str) -> PyResult> { + model_get_drill_down_with_yaml_native(model_yaml, dimension_name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get parent dimension for a given dimension. +#[pyfunction] +fn model_get_drill_up(model_yaml: &str, dimension_name: &str) -> PyResult> { + model_get_drill_up_with_yaml_native(model_yaml, dimension_name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get dimension index by name. +#[pyfunction] +fn model_find_dimension_index(model_yaml: &str, name: &str) -> PyResult> { + model_find_dimension_index_with_yaml_native(model_yaml, name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get metric index by name. +#[pyfunction] +fn model_find_metric_index(model_yaml: &str, name: &str) -> PyResult> { + model_find_metric_index_with_yaml_native(model_yaml, name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get segment index by name. +#[pyfunction] +fn model_find_segment_index(model_yaml: &str, name: &str) -> PyResult> { + model_find_segment_index_with_yaml_native(model_yaml, name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Get pre-aggregation index by name. +#[pyfunction] +fn model_find_pre_aggregation_index(model_yaml: &str, name: &str) -> PyResult> { + model_find_pre_aggregation_index_with_yaml_native(model_yaml, name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve Relationship.sql_expr via sidemantic-rs. +#[pyfunction] +fn relationship_sql_expr(relationship_yaml: &str) -> PyResult { + relationship_sql_expr_with_yaml_native(relationship_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve Relationship.related_key via sidemantic-rs. +#[pyfunction] +fn relationship_related_key(relationship_yaml: &str) -> PyResult { + relationship_related_key_with_yaml_native(relationship_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve Relationship.foreign_key_columns via sidemantic-rs. +#[pyfunction] +fn relationship_foreign_key_columns(relationship_yaml: &str) -> PyResult> { + relationship_foreign_key_columns_with_yaml_native(relationship_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve Relationship.primary_key_columns via sidemantic-rs. +#[pyfunction] +fn relationship_primary_key_columns(relationship_yaml: &str) -> PyResult> { + relationship_primary_key_columns_with_yaml_native(relationship_yaml) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Resolve Segment.get_sql via sidemantic-rs. +#[pyfunction] +fn segment_get_sql(segment_yaml: &str, model_alias: &str) -> PyResult { + segment_get_sql_with_yaml_native(segment_yaml, model_alias) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Find join path between two models using Python-compatible graph semantics. +#[pyfunction] +fn find_relationship_path_with_yaml( + graph_yaml: &str, + from_model: &str, + to_model: &str, +) -> PyResult { + find_relationship_path_with_yaml_native(graph_yaml, from_model, to_model).map_err(|e| match e { + RelationshipPathError::ModelNotFound(model_name) => { + PyKeyError::new_err(format!("Model {model_name} not found")) + } + RelationshipPathError::NoJoinPath { + from_model, + to_model, + } => PyValueError::new_err(format!( + "No join path found between {from_model} and {to_model}" + )), + RelationshipPathError::InvalidPayload(err) => { + PyValueError::new_err(format!("failed to parse graph payload: {err}")) + } + }) +} + +/// Parse a qualified semantic reference using Rust graph semantics. +#[pyfunction] +fn parse_reference_with_yaml( + yaml: &str, + reference: &str, +) -> PyResult<(String, String, Option)> { + parse_reference_with_yaml_native(yaml, reference).map_err(|e| match e { + SidemanticError::Validation(_) + | SidemanticError::YamlParse(_) + | SidemanticError::InvalidConfig(_) + | SidemanticError::FileNotFound(_) + | SidemanticError::Io(_) => PyValueError::new_err(e.to_string()), + _ => PyRuntimeError::new_err(e.to_string()), + }) +} + +/// Discover model names referenced by dimensions/measures. +#[pyfunction] +fn find_models_for_query(dimensions: Vec, measures: Vec) -> Vec { + find_models_for_query_native(&dimensions, &measures) + .into_iter() + .collect() +} + +/// Discover model names referenced by dimensions/measures using graph payload context. +#[pyfunction] +fn find_models_for_query_with_yaml( + yaml: &str, + dimensions: Vec, + measures: Vec, +) -> PyResult> { + find_models_for_query_with_yaml_native(yaml, &dimensions, &measures) + .map(|models| models.into_iter().collect()) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +/// Python module entrypoint. +#[pymodule] +fn sidemantic_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(rewrite_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(compile_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(load_graph_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(load_graph_with_sql, m)?)?; + m.add_function(wrap_pyfunction!(load_graph_from_directory, m)?)?; + m.add_function(wrap_pyfunction!(parse_sql_definitions_payload, m)?)?; + m.add_function(wrap_pyfunction!(parse_sql_graph_definitions_payload, m)?)?; + m.add_function(wrap_pyfunction!(parse_sql_model_payload, m)?)?; + m.add_function(wrap_pyfunction!(parse_sql_statement_blocks_payload, m)?)?; + m.add_function(wrap_pyfunction!(registry_set_current_layer, m)?)?; + m.add_function(wrap_pyfunction!(registry_get_current_layer, m)?)?; + m.add_function(wrap_pyfunction!(validate_query_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(validate_query_references, m)?)?; + m.add_function(wrap_pyfunction!( + generate_preaggregation_materialization_sql, + m + )?)?; + m.add_function(wrap_pyfunction!( + validate_engine_refresh_sql_compatibility, + m + )?)?; + m.add_function(wrap_pyfunction!( + build_preaggregation_refresh_statements, + m + )?)?; + m.add_function(wrap_pyfunction!(resolve_preaggregation_refresh_mode, m)?)?; + m.add_function(wrap_pyfunction!( + validate_preaggregation_refresh_request, + m + )?)?; + m.add_function(wrap_pyfunction!(plan_preaggregation_refresh_execution, m)?)?; + m.add_function(wrap_pyfunction!(refresh_preaggregation, m)?)?; + m.add_function(wrap_pyfunction!(validate_models_yaml, m)?)?; + m.add_function(wrap_pyfunction!(validate_model_payload, m)?)?; + m.add_function(wrap_pyfunction!(resolve_model_inheritance, m)?)?; + m.add_function(wrap_pyfunction!(resolve_metric_inheritance, m)?)?; + m.add_function(wrap_pyfunction!(validate_metric_payload, m)?)?; + m.add_function(wrap_pyfunction!(validate_parameter_payload, m)?)?; + m.add_function(wrap_pyfunction!(validate_table_calculation_payload, m)?)?; + m.add_function(wrap_pyfunction!(extract_preaggregation_patterns, m)?)?; + m.add_function(wrap_pyfunction!(recommend_preaggregation_patterns, m)?)?; + m.add_function(wrap_pyfunction!(summarize_preaggregation_patterns, m)?)?; + m.add_function(wrap_pyfunction!(calculate_preaggregation_benefit_score, m)?)?; + m.add_function(wrap_pyfunction!(generate_preaggregation_name, m)?)?; + m.add_function(wrap_pyfunction!(generate_preaggregation_definition, m)?)?; + m.add_function(wrap_pyfunction!(is_sql_template, m)?)?; + m.add_function(wrap_pyfunction!(render_sql_template, m)?)?; + m.add_function(wrap_pyfunction!(format_parameter_value_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(interpolate_sql_with_parameters, m)?)?; + m.add_function(wrap_pyfunction!(evaluate_table_calculation_expression, m)?)?; + m.add_function(wrap_pyfunction!(validate_table_formula_expression, m)?)?; + m.add_function(wrap_pyfunction!(build_symmetric_aggregate_sql, m)?)?; + m.add_function(wrap_pyfunction!(needs_symmetric_aggregate, m)?)?; + m.add_function(wrap_pyfunction!(parse_relative_date, m)?)?; + m.add_function(wrap_pyfunction!(relative_date_to_range, m)?)?; + m.add_function(wrap_pyfunction!(is_relative_date, m)?)?; + m.add_function(wrap_pyfunction!(time_comparison_offset_interval, m)?)?; + m.add_function(wrap_pyfunction!(time_comparison_sql_offset, m)?)?; + m.add_function(wrap_pyfunction!(trailing_period_sql_interval, m)?)?; + m.add_function(wrap_pyfunction!(generate_time_comparison_sql, m)?)?; + m.add_function(wrap_pyfunction!(execute_with_adbc, m)?)?; + m.add_function(wrap_pyfunction!(detect_adapter_kind, m)?)?; + m.add_function(wrap_pyfunction!(extract_column_references, m)?)?; + m.add_function(wrap_pyfunction!(analyze_migrator_query, m)?)?; + m.add_function(wrap_pyfunction!(chart_auto_detect_columns, m)?)?; + m.add_function(wrap_pyfunction!(chart_select_type, m)?)?; + m.add_function(wrap_pyfunction!(chart_format_label, m)?)?; + m.add_function(wrap_pyfunction!(chart_encoding_type, m)?)?; + m.add_function(wrap_pyfunction!(extract_metric_dependencies, m)?)?; + m.add_function(wrap_pyfunction!(parse_simple_metric_aggregation, m)?)?; + m.add_function(wrap_pyfunction!(metric_to_sql, m)?)?; + m.add_function(wrap_pyfunction!(metric_sql_expr, m)?)?; + m.add_function(wrap_pyfunction!(metric_is_simple_aggregation, m)?)?; + m.add_function(wrap_pyfunction!(dimension_sql_expr, m)?)?; + m.add_function(wrap_pyfunction!(dimension_with_granularity, m)?)?; + m.add_function(wrap_pyfunction!(model_get_hierarchy_path, m)?)?; + m.add_function(wrap_pyfunction!(model_get_drill_down, m)?)?; + m.add_function(wrap_pyfunction!(model_get_drill_up, m)?)?; + m.add_function(wrap_pyfunction!(model_find_dimension_index, m)?)?; + m.add_function(wrap_pyfunction!(model_find_metric_index, m)?)?; + m.add_function(wrap_pyfunction!(model_find_segment_index, m)?)?; + m.add_function(wrap_pyfunction!(model_find_pre_aggregation_index, m)?)?; + m.add_function(wrap_pyfunction!(relationship_sql_expr, m)?)?; + m.add_function(wrap_pyfunction!(relationship_related_key, m)?)?; + m.add_function(wrap_pyfunction!(relationship_foreign_key_columns, m)?)?; + m.add_function(wrap_pyfunction!(relationship_primary_key_columns, m)?)?; + m.add_function(wrap_pyfunction!(segment_get_sql, m)?)?; + m.add_function(wrap_pyfunction!(find_relationship_path_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(parse_reference_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(find_models_for_query, m)?)?; + m.add_function(wrap_pyfunction!(find_models_for_query_with_yaml, m)?)?; + m.add_function(wrap_pyfunction!(generate_catalog_metadata, m)?)?; + Ok(()) +} diff --git a/sidemantic-rs/src/runtime.rs b/sidemantic-rs/src/runtime.rs new file mode 100644 index 00000000..ff437379 --- /dev/null +++ b/sidemantic-rs/src/runtime.rs @@ -0,0 +1,8252 @@ +//! Runtime orchestration helpers for pure Rust consumers. +//! +//! This module exposes a high-level API for loading models, compiling queries, +//! rewriting SQL, and validating query references without requiring the Python +//! bridge. + +use std::borrow::Cow; +use std::collections::{BTreeSet, HashMap, HashSet}; +use std::path::Path; + +use chrono::Utc; +use minijinja::Environment; +use polyglot_sql::{parse_one as polyglot_parse_one, DialectType}; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::config::{ + load_from_directory_with_metadata, load_from_file_with_metadata, + load_from_sql_string_with_metadata, load_from_string_with_metadata, + parse_sql_definitions as parse_sql_definitions_config, + parse_sql_graph_definitions_extended as parse_sql_graph_definitions_extended_config, + parse_sql_model as parse_sql_model_config, + parse_sql_statement_blocks as parse_sql_statement_blocks_config, LoadedModelSource, +}; +use crate::core::symmetric_agg::needs_symmetric_aggregate as needs_symmetric_aggregate_core; +use crate::core::{ + build_symmetric_aggregate_sql as build_symmetric_aggregate_sql_core, + extract_column_references_from_expr, extract_dependencies_with_context, + resolve_model_inheritance as resolve_models_inheritance, Aggregation, Dimension, DimensionType, + JoinPath, Metric, MetricType, Model, Parameter, ParameterType, Relationship, RelationshipType, + SemanticGraph, SqlDialect, SymmetricAggType, +}; +#[cfg(any(target_arch = "wasm32", test))] +use crate::core::{TableCalcType, TableCalculation}; +use crate::error::{Result, SidemanticError}; +use crate::sql::{QueryRewriter, SemanticQuery, SqlGenerator}; + +/// Query-validation context for unqualified metric semantics. +#[derive(Debug, Clone, Default)] +pub struct QueryValidationContext { + /// Top-level metric names defined outside `models`. + pub top_level_metric_names: HashSet, + /// Optional `sql` references for top-level metrics. + pub top_level_metric_sql_refs: HashMap, +} + +impl QueryValidationContext { + pub fn new( + top_level_metric_names: HashSet, + top_level_metric_sql_refs: HashMap, + ) -> Self { + Self { + top_level_metric_names, + top_level_metric_sql_refs, + } + } + + pub fn from_top_level_metrics(metrics: &[Metric]) -> Self { + let mut names = HashSet::new(); + let mut sql_refs = HashMap::new(); + for metric in metrics { + names.insert(metric.name.clone()); + if let Some(sql) = metric.sql.as_ref() { + if !sql.is_empty() { + sql_refs.insert(metric.name.clone(), sql.clone()); + } + } + } + Self::new(names, sql_refs) + } +} + +/// High-level runtime wrapper for pure Rust orchestration. +#[derive(Debug)] +pub struct SidemanticRuntime { + graph: SemanticGraph, + query_validation: QueryValidationContext, + top_level_metrics: Vec, + model_order: Vec, + original_model_metrics: HashMap>, + model_sources: HashMap, +} + +/// Serialized graph payload compatible with Python bridge metadata consumers. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LoadedGraphPayload { + pub models: Vec, + pub parameters: Vec, + pub top_level_metrics: Vec, + pub model_order: Vec, + pub original_model_metrics: HashMap>, + pub model_sources: HashMap, +} + +/// Tuple shape returned to Python bridge for graph path steps. +pub type RelationshipPathStep = (String, String, Vec, Vec, String); + +/// Relationship path discovery errors that preserve Python-compatible exception semantics. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum RelationshipPathError { + #[error("Model {0} not found")] + ModelNotFound(String), + #[error("No join path found between {from_model} and {to_model}")] + NoJoinPath { + from_model: String, + to_model: String, + }, + #[error("{0}")] + InvalidPayload(String), +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum GraphPathKeyPayload { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Deserialize)] +struct GraphPathPayload { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct GraphPathModelPayload { + name: String, + #[serde(default)] + primary_key_columns: Vec, + #[serde(default)] + primary_key: Option, + #[serde(default)] + relationships: Vec, +} + +#[derive(Debug, Deserialize)] +struct GraphPathRelationshipPayload { + name: String, + #[serde(default, rename = "type")] + relationship_type: Option, + #[serde(default)] + foreign_key: Option, + #[serde(default)] + primary_key: Option, + #[serde(default)] + foreign_key_columns: Vec, + #[serde(default)] + primary_key_columns: Vec, + #[serde(default)] + has_foreign_key: bool, + #[serde(default)] + has_primary_key: bool, + through: Option, + through_foreign_key: Option, + related_foreign_key: Option, +} + +#[derive(Debug, Deserialize)] +struct MetricDependencyPayload { + name: String, + #[serde(default, rename = "type")] + metric_type: Option, + #[serde(default)] + agg: Option, + #[serde(default)] + sql: Option, + #[serde(default)] + numerator: Option, + #[serde(default)] + denominator: Option, + #[serde(default)] + base_metric: Option, +} + +#[derive(Debug, Deserialize)] +struct RuntimeQueryPayload { + #[serde(default)] + metrics: Vec, + #[serde(default)] + dimensions: Vec, + #[serde(default)] + filters: Vec, + #[serde(default)] + segments: Vec, + #[serde(default)] + order_by: Vec, + limit: Option, + offset: Option, + #[serde(default)] + ungrouped: bool, + #[serde(default)] + use_preaggregations: bool, + #[serde(default)] + skip_default_time_dimensions: bool, + preagg_database: Option, + preagg_schema: Option, + #[serde(default)] + parameter_values: HashMap, +} + +#[derive(Debug, Deserialize)] +struct RuntimeQueryValidationPayload { + #[serde(default)] + metrics: Vec, + #[serde(default)] + dimensions: Vec, +} + +#[derive(Debug, Serialize)] +struct ParsedSqlDefinitionsPayload { + metrics: Vec, + segments: Vec, +} + +#[derive(Debug, Serialize)] +struct ParsedSqlGraphDefinitionsPayload { + metrics: Vec, + segments: Vec, + parameters: Vec, + pre_aggregations: Vec, +} + +#[derive(Debug, Deserialize)] +struct DimensionHelperPayload { + name: String, + #[serde(default, rename = "type")] + dimension_type: Option, + #[serde(default)] + sql: Option, + #[serde(default)] + supported_granularities: Option>, +} + +#[derive(Debug, Deserialize)] +struct ModelHierarchyPayload { + #[serde(default)] + dimensions: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelHierarchyDimensionPayload { + name: String, + #[serde(default)] + parent: Option, +} + +#[derive(Debug, Deserialize)] +struct ModelLookupPayload { + #[serde(default)] + dimensions: Vec, + #[serde(default)] + metrics: Vec, + #[serde(default)] + segments: Vec, + #[serde(default)] + pre_aggregations: Vec, +} + +#[derive(Debug, Deserialize)] +struct ModelLookupItemPayload { + name: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum RelationshipKeyPayload { + Single(String), + Multiple(Vec), +} + +#[derive(Debug, Deserialize)] +struct RelationshipHelperPayload { + name: String, + #[serde(default, rename = "type")] + relationship_type: Option, + #[serde(default)] + foreign_key: Option, + #[serde(default)] + primary_key: Option, +} + +#[derive(Debug, Deserialize)] +struct SegmentHelperPayload { + sql: String, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +struct PreaggPatternRecord { + model: String, + metrics: Vec, + dimensions: Vec, + granularities: Vec, + count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PreaggRecommendationRecord { + pattern: PreaggPatternRecord, + suggested_name: String, + query_count: usize, + estimated_benefit_score: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +struct MigratorAnalysisPayload { + #[serde(default)] + column_references: Vec, + #[serde(default)] + group_by_columns: Vec<(String, String)>, + #[serde(default)] + derived_metrics: Vec, + #[serde(default)] + cumulative_metrics: Vec, + #[serde(default)] + aggregations_in_derived: Vec<(String, String, String)>, + #[serde(default)] + aggregations_in_cumulative: Vec<(String, String, String)>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MigratorDerivedMetricRecord { + name: String, + sql_expression: String, + table: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MigratorCumulativeMetricRecord { + name: String, + base_metric: String, + table: String, + #[serde(skip_serializing_if = "Option::is_none")] + window: Option, + #[serde(skip_serializing_if = "Option::is_none")] + grain_to_date: Option, + agg_type: String, + agg_column: String, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +struct PatternKey { + model: String, + metrics: Vec, + dimensions: Vec, + granularities: Vec, +} + +fn parse_metric_helper_payload(metric_yaml: &str) -> Result { + serde_yaml::from_str(metric_yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to parse metric payload: {e}"))) +} + +fn normalize_set(values: Vec) -> Vec { + let set: BTreeSet = values + .into_iter() + .filter(|value| !value.is_empty()) + .collect::>(); + set.into_iter().collect() +} + +fn extract_pattern_key(query: &str, instrumentation_re: &Regex) -> Option { + let captures = instrumentation_re.captures(query)?; + let metadata = captures.get(1)?.as_str(); + + let mut parts: HashMap = HashMap::new(); + for part in metadata.split_whitespace() { + if let Some((key, value)) = part.split_once('=') { + parts.insert(key.to_string(), value.to_string()); + } + } + + let models = normalize_set( + parts + .get("models") + .map(|value| { + value + .split(',') + .map(str::trim) + .map(str::to_string) + .collect::>() + }) + .unwrap_or_default(), + ); + let metrics = normalize_set( + parts + .get("metrics") + .map(|value| { + value + .split(',') + .map(str::trim) + .map(str::to_string) + .collect::>() + }) + .unwrap_or_default(), + ); + let dimensions = normalize_set( + parts + .get("dimensions") + .map(|value| { + value + .split(',') + .map(str::trim) + .map(str::to_string) + .collect::>() + }) + .unwrap_or_default(), + ); + let granularities = normalize_set( + parts + .get("granularities") + .map(|value| { + value + .split(',') + .map(str::trim) + .map(str::to_string) + .collect::>() + }) + .unwrap_or_default(), + ); + + if models.len() != 1 || metrics.is_empty() { + return None; + } + + Some(PatternKey { + model: models[0].clone(), + metrics, + dimensions, + granularities, + }) +} + +fn benefit_score(pattern: &PreaggPatternRecord, count: usize) -> f64 { + let query_score = (count as f64 + 1.0).log10() / 6.0; + let dim_count = pattern.dimensions.len() as f64; + let dim_score = (1.0 - (dim_count * 0.1)).max(0.0); + let metric_count = pattern.metrics.len() as f64; + let metric_score = (0.25 + (metric_count * 0.25)).min(1.0); + ((query_score * 0.5) + (dim_score * 0.25) + (metric_score * 0.25)).min(1.0) +} + +fn granularity_rank(granularity: &str) -> usize { + match granularity { + "second" => 0, + "minute" => 1, + "hour" => 2, + "day" => 3, + "week" => 4, + "month" => 5, + "quarter" => 6, + "year" => 7, + _ => 99, + } +} + +fn pattern_name(pattern: &PreaggPatternRecord) -> String { + let mut parts: Vec = Vec::new(); + + if !pattern.granularities.is_empty() { + let mut granularities = pattern.granularities.clone(); + granularities.sort_by_key(|value| granularity_rank(value)); + if let Some(first) = granularities.first() { + parts.push(first.clone()); + } + } + + if !pattern.dimensions.is_empty() { + let mut dimensions = pattern.dimensions.clone(); + dimensions.sort(); + if dimensions.len() <= 2 { + parts.extend( + dimensions + .into_iter() + .map(|value| value.rsplit('.').next().unwrap_or(&value).to_string()), + ); + } else { + parts.push(format!("{}dims", dimensions.len())); + } + } + + if pattern.metrics.len() == 1 { + if let Some(metric) = pattern.metrics.first() { + parts.push(metric.rsplit('.').next().unwrap_or(metric).to_string()); + } + } else if !pattern.metrics.is_empty() { + parts.push(format!("{}metrics", pattern.metrics.len())); + } + + if parts.is_empty() { + "rollup".to_string() + } else { + parts.join("_") + } +} + +fn parse_pattern_payload(pattern_json: &str, context: &str) -> Result { + serde_json::from_str(pattern_json).map_err(|e| { + SidemanticError::Validation(format!( + "failed to parse pattern payload for {context}: {e}" + )) + }) +} + +fn parse_patterns_payload(patterns_json: &str, context: &str) -> Result> { + serde_json::from_str(patterns_json).map_err(|e| { + SidemanticError::Validation(format!( + "failed to parse pattern payload for {context}: {e}" + )) + }) +} + +fn serialize_json_payload(value: &serde_json::Value, context: &str) -> Result { + serde_json::to_string(value).map_err(|e| { + SidemanticError::SqlGeneration(format!("failed to serialize {context} payload: {e}")) + }) +} + +fn strip_model_prefix(value: &str) -> String { + value.rsplit('.').next().unwrap_or(value).to_string() +} + +fn graph_path_key_to_columns(key: &GraphPathKeyPayload) -> Vec { + match key { + GraphPathKeyPayload::Single(value) => vec![value.clone()], + GraphPathKeyPayload::Multiple(values) => values.clone(), + } +} + +fn graph_model_primary_keys(model: &GraphPathModelPayload) -> Vec { + if !model.primary_key_columns.is_empty() { + model.primary_key_columns.clone() + } else if let Some(primary_key) = model.primary_key.as_ref() { + graph_path_key_to_columns(primary_key) + } else { + vec!["id".to_string()] + } +} + +fn relationship_type_name(relationship: &GraphPathRelationshipPayload) -> &str { + relationship + .relationship_type + .as_deref() + .unwrap_or("many_to_one") +} + +fn parse_relationship_type_label(relationship_type: &str) -> RelationshipType { + match relationship_type { + "one_to_one" => RelationshipType::OneToOne, + "one_to_many" => RelationshipType::OneToMany, + "many_to_many" => RelationshipType::ManyToMany, + _ => RelationshipType::ManyToOne, + } +} + +fn relationship_foreign_keys(relationship: &GraphPathRelationshipPayload) -> Vec { + if !relationship.foreign_key_columns.is_empty() { + return relationship.foreign_key_columns.clone(); + } + if let Some(foreign_key) = relationship.foreign_key.as_ref() { + return graph_path_key_to_columns(foreign_key); + } + + if relationship_type_name(relationship) == "many_to_one" { + return vec![format!("{}_id", relationship.name)]; + } + + vec!["id".to_string()] +} + +fn relationship_primary_keys(relationship: &GraphPathRelationshipPayload) -> Vec { + if !relationship.primary_key_columns.is_empty() { + return relationship.primary_key_columns.clone(); + } + if let Some(primary_key) = relationship.primary_key.as_ref() { + return graph_path_key_to_columns(primary_key); + } + vec!["id".to_string()] +} + +fn relationship_has_foreign_key(relationship: &GraphPathRelationshipPayload) -> bool { + relationship.has_foreign_key + || !relationship.foreign_key_columns.is_empty() + || relationship.foreign_key.is_some() +} + +fn relationship_has_primary_key(relationship: &GraphPathRelationshipPayload) -> bool { + relationship.has_primary_key + || !relationship.primary_key_columns.is_empty() + || relationship.primary_key.is_some() +} + +fn relationship_first_foreign_key(relationship: &GraphPathRelationshipPayload) -> Option { + relationship_foreign_keys(relationship).first().cloned() +} + +fn parse_metric_type_for_dependencies(payload: &MetricDependencyPayload) -> MetricType { + match payload.metric_type.as_deref() { + Some("derived") => MetricType::Derived, + Some("ratio") => MetricType::Ratio, + Some("cumulative") => MetricType::Cumulative, + Some("time_comparison" | "timecomparison") => MetricType::TimeComparison, + Some("conversion") => MetricType::Conversion, + Some("retention") => MetricType::Retention, + Some("cohort") => MetricType::Cohort, + _ => { + if payload.agg.is_none() && payload.sql.is_some() { + MetricType::Derived + } else { + MetricType::Simple + } + } + } +} + +fn parse_metric_agg_for_dependencies(agg: Option<&str>) -> Option { + match agg { + Some("sum") => Some(Aggregation::Sum), + Some("count") => Some(Aggregation::Count), + Some("count_distinct") => Some(Aggregation::CountDistinct), + Some("avg") => Some(Aggregation::Avg), + Some("min") => Some(Aggregation::Min), + Some("max") => Some(Aggregation::Max), + Some("median") => Some(Aggregation::Median), + Some("expression") => Some(Aggregation::Expression), + _ => None, + } +} + +fn format_python_string_list(values: &[String]) -> String { + let rendered = values + .iter() + .map(|value| format!("'{}'", value.replace('\'', "\\'"))) + .collect::>() + .join(", "); + format!("[{rendered}]") +} + +fn has_inline_aggregation(sql: &str) -> bool { + Regex::new(r"(?i)\b(sum|avg|count|min|max|median)\s*\(") + .ok() + .is_some_and(|re| re.is_match(sql)) +} + +fn yaml_value_to_python_str(value: &serde_yaml::Value) -> String { + match value { + serde_yaml::Value::Null => "None".to_string(), + serde_yaml::Value::Bool(v) => { + if *v { + "True".to_string() + } else { + "False".to_string() + } + } + serde_yaml::Value::Number(v) => v.to_string(), + serde_yaml::Value::String(v) => v.clone(), + serde_yaml::Value::Sequence(v) => { + serde_json::to_string(v).unwrap_or_else(|_| "[]".to_string()) + } + serde_yaml::Value::Mapping(v) => { + serde_json::to_string(v).unwrap_or_else(|_| "{}".to_string()) + } + serde_yaml::Value::Tagged(v) => yaml_value_to_python_str(&v.value), + } +} + +fn yaml_value_type_name(value: &serde_yaml::Value) -> &'static str { + match value { + serde_yaml::Value::Null => "NoneType", + serde_yaml::Value::Bool(_) => "bool", + serde_yaml::Value::Number(_) => "number", + serde_yaml::Value::String(_) => "str", + serde_yaml::Value::Sequence(_) => "list", + serde_yaml::Value::Mapping(_) => "dict", + serde_yaml::Value::Tagged(v) => yaml_value_type_name(&v.value), + } +} + +fn value_is_truthy(value: &serde_yaml::Value) -> bool { + match value { + serde_yaml::Value::Null => false, + serde_yaml::Value::Bool(v) => *v, + serde_yaml::Value::Number(v) => { + if let Some(i) = v.as_i64() { + return i != 0; + } + if let Some(u) = v.as_u64() { + return u != 0; + } + if let Some(f) = v.as_f64() { + return f != 0.0; + } + true + } + serde_yaml::Value::String(v) => !v.is_empty(), + serde_yaml::Value::Sequence(v) => !v.is_empty(), + serde_yaml::Value::Mapping(v) => !v.is_empty(), + serde_yaml::Value::Tagged(v) => value_is_truthy(&v.value), + } +} + +fn parameter_runtime_value( + parameter: &Parameter, + parameter_values: &HashMap, +) -> serde_yaml::Value { + if let Some(value) = parameter_values.get(¶meter.name) { + return value.clone(); + } + + if parameter.default_to_today && parameter.parameter_type == ParameterType::Date { + return serde_yaml::Value::String(Utc::now().date_naive().to_string()); + } + + if let Some(default_value) = ¶meter.default_value { + return serde_yaml::to_value(default_value).unwrap_or(serde_yaml::Value::Null); + } + + serde_yaml::Value::Null +} + +fn format_float_like_python(value: f64) -> String { + if value.fract() == 0.0 { + format!("{value:.1}") + } else { + value.to_string() + } +} + +fn format_parameter_value( + parameter: &Parameter, + value: &serde_yaml::Value, +) -> std::result::Result { + match parameter.parameter_type { + ParameterType::String => { + let escaped = yaml_value_to_python_str(value).replace('\'', "''"); + Ok(format!("'{escaped}'")) + } + ParameterType::Date => Ok(format!("'{}'", yaml_value_to_python_str(value))), + ParameterType::Number => match value { + serde_yaml::Value::Number(v) => Ok(v.to_string()), + serde_yaml::Value::String(v) => { + let parsed = v + .parse::() + .map_err(|_| format!("Invalid numeric parameter value: {v}"))?; + if !parsed.is_finite() { + return Err(format!("Invalid numeric parameter value: {v}")); + } + Ok(format_float_like_python(parsed)) + } + serde_yaml::Value::Tagged(v) => format_parameter_value(parameter, &v.value), + other => Err(format!( + "Numeric parameter must be int, float, or numeric string, got {}", + yaml_value_type_name(other) + )), + }, + ParameterType::Unquoted => { + let rendered = yaml_value_to_python_str(value); + let is_safe = rendered + .chars() + .filter(|c| *c != '_' && *c != '.') + .all(char::is_alphanumeric); + if !is_safe { + return Err(format!( + "Unquoted parameter must be alphanumeric with underscores/dots only: {}", + yaml_value_to_python_str(value) + )); + } + Ok(rendered) + } + ParameterType::Yesno => { + if value_is_truthy(value) { + Ok("TRUE".to_string()) + } else { + Ok("FALSE".to_string()) + } + } + } +} + +pub fn is_sql_template(sql: &str) -> bool { + sql.contains("{{") || sql.contains("{%") || sql.contains("{#") +} + +fn has_jinja_control_markers(sql: &str) -> bool { + sql.contains("{%") || sql.contains("{#") +} + +fn yaml_to_json_value(value: &serde_yaml::Value) -> serde_json::Value { + match value { + serde_yaml::Value::Null => serde_json::Value::Null, + serde_yaml::Value::Bool(v) => serde_json::Value::Bool(*v), + serde_yaml::Value::Number(v) => { + if let Some(i) = v.as_i64() { + return serde_json::json!(i); + } + if let Some(u) = v.as_u64() { + return serde_json::json!(u); + } + if let Some(f) = v.as_f64() { + return serde_json::json!(f); + } + serde_json::Value::Null + } + serde_yaml::Value::String(v) => serde_json::Value::String(v.clone()), + serde_yaml::Value::Sequence(values) => { + serde_json::Value::Array(values.iter().map(yaml_to_json_value).collect()) + } + serde_yaml::Value::Mapping(values) => { + let mut object = serde_json::Map::new(); + for (key, value) in values { + let key = match key { + serde_yaml::Value::String(v) => v.clone(), + other => yaml_value_to_python_str(other), + }; + object.insert(key, yaml_to_json_value(value)); + } + serde_json::Value::Object(object) + } + serde_yaml::Value::Tagged(v) => yaml_to_json_value(&v.value), + } +} + +fn parse_string_keyed_yaml_mapping( + payload_yaml: &str, + error_context: &str, +) -> std::result::Result, String> { + let parsed: serde_yaml::Value = serde_yaml::from_str(payload_yaml) + .map_err(|e| format!("failed to parse {error_context} payload: {e}"))?; + + let mut result = HashMap::new(); + let Some(mapping) = parsed.as_mapping() else { + if parsed.is_null() { + return Ok(result); + } + return Err(format!("{error_context} payload must be a YAML mapping")); + }; + + for (key, value) in mapping { + let key = match key { + serde_yaml::Value::String(v) => v.clone(), + other => yaml_value_to_python_str(other), + }; + result.insert(key, value.clone()); + } + + Ok(result) +} + +fn build_runtime_context( + parameters_by_name: &HashMap, + parameter_values: &HashMap, +) -> HashMap { + let mut context = HashMap::new(); + for (name, parameter) in parameters_by_name { + context.insert( + name.clone(), + parameter_runtime_value(parameter, parameter_values), + ); + } + context +} + +fn render_template_with_context( + template_str: &str, + context: &HashMap, +) -> std::result::Result { + let env = Environment::new(); + let render_context = context + .iter() + .map(|(key, value)| (key.clone(), yaml_to_json_value(value))) + .collect::>(); + let render_once = |candidate: &str| -> std::result::Result { + let template = env + .template_from_str(candidate) + .map_err(|e| format!("Template syntax error: {e}"))?; + template + .render(render_context.clone()) + .map_err(|e| format!("Template rendering error: {e}")) + }; + + match render_once(template_str) { + Ok(rendered) => Ok(rendered), + Err(err) if err.contains("method named items") => { + let re = Regex::new(r"([A-Za-z_][A-Za-z0-9_\.]*)\.items\(\)") + .expect("valid template items() compatibility regex"); + let rewritten = re.replace_all(template_str, "$1|items").into_owned(); + if rewritten == template_str { + return Err(err); + } + render_once(&rewritten) + } + Err(err) => Err(err), + } +} + +fn interpolate_simple_filter( + filter: &str, + parameters_by_name: &HashMap, + parameter_values: &HashMap, +) -> std::result::Result { + let pattern = Regex::new(r"\{\{\s*(\w+)\s*\}\}").expect("valid parameter regex"); + let mut interpolation_error: Option = None; + + let rendered = pattern + .replace_all(filter, |captures: ®ex::Captures<'_>| { + let Some(param_name_match) = captures.get(1) else { + return Cow::Owned( + captures + .get(0) + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(), + ); + }; + let param_name = param_name_match.as_str(); + let Some(parameter) = parameters_by_name.get(param_name) else { + return Cow::Owned( + captures + .get(0) + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(), + ); + }; + + let value = parameter_runtime_value(parameter, parameter_values); + match format_parameter_value(parameter, &value) { + Ok(formatted) => Cow::Owned(formatted), + Err(err) => { + interpolation_error = Some(err); + Cow::Owned( + captures + .get(0) + .map(|m| m.as_str()) + .unwrap_or("") + .to_string(), + ) + } + } + }) + .into_owned(); + + if let Some(err) = interpolation_error { + return Err(err); + } + + Ok(rendered) +} + +fn interpolate_sql_with_parameters_impl( + sql: &str, + parameters_by_name: &HashMap, + parameter_values: &HashMap, +) -> std::result::Result { + if is_sql_template(sql) && has_jinja_control_markers(sql) { + let context = build_runtime_context(parameters_by_name, parameter_values); + return render_template_with_context(sql, &context); + } + interpolate_simple_filter(sql, parameters_by_name, parameter_values) +} + +pub fn interpolate_query_filters( + graph: &SemanticGraph, + filters: Vec, + parameter_values: &HashMap, +) -> std::result::Result, String> { + let parameters_by_name: HashMap = graph + .parameters() + .map(|parameter| (parameter.name.clone(), parameter)) + .collect(); + + filters + .into_iter() + .map(|filter| { + interpolate_sql_with_parameters_impl(&filter, ¶meters_by_name, parameter_values) + }) + .collect() +} + +/// Render SQL template using YAML context payload. +pub fn render_sql_template(template_str: &str, context_yaml: &str) -> Result { + let context = parse_string_keyed_yaml_mapping(context_yaml, "template context") + .map_err(SidemanticError::Validation)?; + render_template_with_context(template_str, &context).map_err(SidemanticError::Validation) +} + +/// Format a parameter value from YAML payloads. +pub fn format_parameter_value_with_yaml(parameter_yaml: &str, value_yaml: &str) -> Result { + let parameter: Parameter = serde_yaml::from_str(parameter_yaml).map_err(|e| { + SidemanticError::Validation(format!("failed to parse parameter payload: {e}")) + })?; + let value: serde_yaml::Value = serde_yaml::from_str(value_yaml).map_err(|e| { + SidemanticError::Validation(format!("failed to parse parameter value: {e}")) + })?; + format_parameter_value(¶meter, &value).map_err(SidemanticError::Validation) +} + +/// Interpolate SQL with parameter definitions and values from YAML payloads. +pub fn interpolate_sql_with_parameters_with_yaml( + sql: &str, + parameters_yaml: &str, + values_yaml: &str, +) -> Result { + let parameters: Vec = serde_yaml::from_str(parameters_yaml).map_err(|e| { + SidemanticError::Validation(format!("failed to parse parameter definitions: {e}")) + })?; + let values = parse_string_keyed_yaml_mapping(values_yaml, "parameter values") + .map_err(SidemanticError::Validation)?; + + let parameters_by_name: HashMap = parameters + .iter() + .map(|parameter| (parameter.name.clone(), parameter)) + .collect(); + + interpolate_sql_with_parameters_impl(sql, ¶meters_by_name, &values) + .map_err(SidemanticError::Validation) +} + +/// Compile a semantic query by loading graph YAML and parsing query YAML payload. +pub fn compile_with_yaml_query(yaml: &str, query_yaml: &str) -> Result { + let runtime = SidemanticRuntime::from_yaml(yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to load YAML models: {e}")))?; + let payload: RuntimeQueryPayload = serde_yaml::from_str(query_yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to parse query payload: {e}")))?; + let filters = + interpolate_query_filters(runtime.graph(), payload.filters, &payload.parameter_values) + .map_err(|e| { + SidemanticError::Validation(format!("failed to interpolate query parameters: {e}")) + })?; + + let mut query = SemanticQuery::new() + .with_metrics(payload.metrics) + .with_dimensions(payload.dimensions) + .with_filters(filters) + .with_segments(payload.segments) + .with_order_by(payload.order_by) + .with_ungrouped(payload.ungrouped) + .with_use_preaggregations(payload.use_preaggregations) + .with_skip_default_time_dimensions(payload.skip_default_time_dimensions) + .with_preaggregation_qualifiers(payload.preagg_database, payload.preagg_schema); + + if let Some(limit) = payload.limit { + query = query.with_limit(limit); + } + if let Some(offset) = payload.offset { + query = query.with_offset(offset); + } + + runtime + .compile(&query) + .map_err(|e| SidemanticError::SqlGeneration(format!("failed to compile SQL: {e}"))) +} + +/// Validate query references using graph and query YAML payloads. +pub fn validate_query_references_with_yaml( + yaml: &str, + metrics: &[String], + dimensions: &[String], +) -> Result> { + let runtime = SidemanticRuntime::from_yaml(yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to load YAML models: {e}")))?; + Ok(runtime.validate_query_references(metrics, dimensions)) +} + +/// Validate query references using graph and query YAML payloads. +pub fn validate_query_with_yaml(yaml: &str, query_yaml: &str) -> Result> { + let payload: RuntimeQueryValidationPayload = serde_yaml::from_str(query_yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to parse query payload: {e}")))?; + validate_query_references_with_yaml(yaml, &payload.metrics, &payload.dimensions) +} + +/// Load graph YAML and serialize runtime payload. +pub fn load_graph_with_yaml(yaml: &str) -> Result { + let runtime = SidemanticRuntime::from_yaml(yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to load YAML models: {e}")))?; + let payload = runtime.loaded_graph_payload(); + serde_json::to_string(&payload) + .map_err(|e| SidemanticError::Validation(format!("failed to serialize graph payload: {e}"))) +} + +/// Load graph definitions from a directory and serialize runtime payload. +pub fn load_graph_from_directory(path: &str) -> Result { + let runtime = SidemanticRuntime::from_directory(path).map_err(|e| { + SidemanticError::Validation(format!("failed to load directory models: {e}")) + })?; + serde_json::to_string(&runtime.loaded_graph_payload()) + .map_err(|e| SidemanticError::Validation(format!("failed to serialize graph payload: {e}"))) +} + +fn build_runtime_with_metadata( + graph: SemanticGraph, + top_level_metrics: Vec, + model_order: Vec, + original_model_metrics: HashMap>, + model_sources: HashMap, +) -> SidemanticRuntime { + let query_validation = QueryValidationContext::from_top_level_metrics(&top_level_metrics); + SidemanticRuntime { + graph, + query_validation, + top_level_metrics, + model_order, + original_model_metrics, + model_sources, + } +} + +/// Load graph SQL content (.sql definitions with optional YAML frontmatter) and serialize runtime payload. +pub fn load_graph_with_sql(sql_content: &str) -> Result { + let loaded = load_from_sql_string_with_metadata(sql_content)?; + let runtime = build_runtime_with_metadata( + loaded.graph, + loaded.top_level_metrics, + loaded.model_order, + loaded.original_model_metrics, + loaded.model_sources, + ); + serde_json::to_string(&runtime.loaded_graph_payload()) + .map_err(|e| SidemanticError::Validation(format!("failed to serialize graph payload: {e}"))) +} + +/// Parse SQL metric/segment definitions and return serialized payload. +pub fn parse_sql_definitions_payload(sql: &str) -> Result { + let (metrics, segments) = parse_sql_definitions_config(sql).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL definitions: {e}")) + })?; + + let payload = ParsedSqlDefinitionsPayload { metrics, segments }; + serde_json::to_string(&payload).map_err(|e| { + SidemanticError::SqlGeneration(format!("failed to serialize SQL definitions payload: {e}")) + }) +} + +/// Parse SQL graph definitions and return serialized payload. +pub fn parse_sql_graph_definitions_payload(sql: &str) -> Result { + let (metrics, segments, parameters, pre_aggregations) = + parse_sql_graph_definitions_extended_config(sql).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL graph definitions: {e}")) + })?; + + let payload = ParsedSqlGraphDefinitionsPayload { + metrics, + segments, + parameters, + pre_aggregations, + }; + serde_json::to_string(&payload).map_err(|e| { + SidemanticError::SqlGeneration(format!( + "failed to serialize SQL graph definitions payload: {e}" + )) + }) +} + +/// Parse SQL model definition and return serialized model payload. +pub fn parse_sql_model_payload(sql: &str) -> Result { + let model = parse_sql_model_config(sql) + .map_err(|e| SidemanticError::Validation(format!("failed to parse SQL model: {e}")))?; + serde_json::to_string(&model).map_err(|e| { + SidemanticError::SqlGeneration(format!("failed to serialize SQL model payload: {e}")) + }) +} + +/// Parse raw SQL statement blocks and return serialized payload. +pub fn parse_sql_statement_blocks_payload(sql: &str) -> Result { + let blocks = parse_sql_statement_blocks_config(sql).map_err(|e| { + SidemanticError::Validation(format!("failed to parse SQL statement blocks: {e}")) + })?; + serde_json::to_string(&blocks).map_err(|e| { + SidemanticError::SqlGeneration(format!( + "failed to serialize SQL statement blocks payload: {e}" + )) + }) +} + +/// Validate model graph payload parses and composes into a runtime graph. +pub fn validate_models_yaml(yaml: &str) -> Result { + SidemanticRuntime::from_yaml(yaml) + .map(|_| true) + .map_err(|e| SidemanticError::Validation(format!("failed to load YAML models: {e}"))) +} + +/// Parse reference from graph and reference payloads. +pub fn parse_reference_with_yaml( + yaml: &str, + reference: &str, +) -> Result<(String, String, Option)> { + let runtime = SidemanticRuntime::from_yaml(yaml) + .map_err(|e| SidemanticError::Validation(format!("failed to load YAML models: {e}")))?; + runtime + .parse_reference(reference) + .map_err(|e| SidemanticError::Validation(format!("failed to parse reference: {e}"))) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn split_top_level_csv(input: &str) -> Vec { + let mut parts = Vec::new(); + let mut current = String::new(); + let mut depth: i32 = 0; + let mut in_quote: Option = None; + let mut escape = false; + + for ch in input.chars() { + if let Some(quote) = in_quote { + current.push(ch); + if escape { + escape = false; + continue; + } + if ch == '\\' { + escape = true; + continue; + } + if ch == quote { + in_quote = None; + } + continue; + } + + match ch { + '\'' | '"' => { + in_quote = Some(ch); + current.push(ch); + } + '(' => { + depth += 1; + current.push(ch); + } + ')' => { + depth = (depth - 1).max(0); + current.push(ch); + } + ',' if depth == 0 => { + let trimmed = current.trim(); + if !trimmed.is_empty() { + parts.push(trimmed.to_string()); + } + current.clear(); + } + _ => current.push(ch), + } + } + + let trimmed = current.trim(); + if !trimmed.is_empty() { + parts.push(trimmed.to_string()); + } + parts +} + +#[cfg(any(target_arch = "wasm32", test))] +fn normalize_ident(ident: &str) -> String { + ident + .trim() + .trim_matches('"') + .trim_matches('`') + .trim_matches('[') + .trim_matches(']') + .to_string() +} + +#[cfg(any(target_arch = "wasm32", test))] +fn split_alias_suffix(expr: &str) -> (String, Option) { + let trimmed = expr.trim(); + let with_as = Regex::new(r"(?is)^(?P.+?)\s+as\s+(?P[A-Za-z_][A-Za-z0-9_]*)\s*$") + .expect("valid alias regex"); + if let Some(captures) = with_as.captures(trimmed) { + let body = captures + .name("body") + .map(|m| m.as_str().trim().to_string()) + .unwrap_or_else(|| trimmed.to_string()); + let alias = captures.name("alias").map(|m| normalize_ident(m.as_str())); + return (body, alias); + } + let without_as = Regex::new(r"(?is)^(?P.+?)\s+(?P[A-Za-z_][A-Za-z0-9_]*)\s*$") + .expect("valid bare alias regex"); + if let Some(captures) = without_as.captures(trimmed) { + let body = captures + .name("body") + .map(|m| m.as_str().trim().to_string()) + .unwrap_or_else(|| trimmed.to_string()); + let alias = captures.name("alias").map(|m| normalize_ident(m.as_str())); + return (body, alias); + } + (trimmed.to_string(), None) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn normalize_alias_key(value: &str) -> String { + normalize_ident(value).to_ascii_lowercase() +} + +#[cfg(any(target_arch = "wasm32", test))] +fn split_order_suffix(order_item: &str) -> (String, String) { + let trimmed = order_item.trim(); + let order_re = Regex::new( + r"(?is)^(?P.+?)\s+(?Pasc|desc)(?:\s+nulls\s+(?Pfirst|last))?\s*$", + ) + .expect("valid order suffix regex"); + if let Some(captures) = order_re.captures(trimmed) { + let expr = captures + .name("expr") + .map(|m| m.as_str().trim().to_string()) + .unwrap_or_else(|| trimmed.to_string()); + let mut suffix = captures + .name("dir") + .map(|m| m.as_str().to_ascii_uppercase()) + .unwrap_or_default(); + if let Some(nulls) = captures.name("nulls") { + if !suffix.is_empty() { + suffix.push(' '); + } + suffix.push_str("NULLS "); + suffix.push_str(&nulls.as_str().to_ascii_uppercase()); + } + return (expr, suffix); + } + (trimmed.to_string(), String::new()) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn resolve_wasm_semantic_ref( + runtime: &SidemanticRuntime, + aliases: &HashMap, + default_model: &str, + reference: &str, +) -> Result<(String, bool)> { + let (model_ref, field_ref) = if let Some((model_part, field_part)) = reference.split_once('.') { + (normalize_ident(model_part), normalize_ident(field_part)) + } else { + (default_model.to_string(), normalize_ident(reference)) + }; + + let resolved_model = aliases + .get(&model_ref) + .cloned() + .unwrap_or(model_ref.clone()); + let model = runtime.graph.get_model(&resolved_model).ok_or_else(|| { + SidemanticError::Validation(format!( + "wasm rewrite fallback model '{resolved_model}' not found" + )) + })?; + + if model.get_metric(&field_ref).is_some() { + Ok((format!("{resolved_model}.{field_ref}"), true)) + } else if model.get_dimension(&field_ref).is_some() { + Ok((format!("{resolved_model}.{field_ref}"), false)) + } else { + Err(SidemanticError::Validation(format!( + "wasm rewrite fallback field '{resolved_model}.{field_ref}' not found" + ))) + } +} + +#[cfg(any(target_arch = "wasm32", test))] +fn normalize_wasm_match_key(value: &str) -> String { + value + .trim() + .trim_matches('"') + .trim_matches('`') + .trim_matches('[') + .trim_matches(']') + .to_ascii_lowercase() +} + +#[cfg(any(target_arch = "wasm32", test))] +fn normalize_wasm_projection_expression(expr: &str, aliases: &HashMap) -> String { + let mut normalized = expr.to_string(); + for (alias, model) in aliases { + if alias == model { + continue; + } + let pattern = format!(r"(?i)\b{}\s*\.\s*", regex::escape(alias)); + if let Ok(re) = Regex::new(&pattern) { + normalized = re.replace_all(&normalized, format!("{model}.")).to_string(); + } + } + normalized = normalized.replace('`', ""); + normalized = normalized.replace('"', ""); + normalized = normalized.replace('[', ""); + normalized = normalized.replace(']', ""); + if let Ok(re) = Regex::new(r"\s+") { + normalized = re.replace_all(&normalized, "").to_string(); + } + normalized.to_ascii_lowercase() +} + +#[cfg(any(target_arch = "wasm32", test))] +fn strip_wasm_model_qualifier(expr: &str, model_name: &str) -> String { + let pattern = format!(r"(?i)\b{}\s*\.\s*", regex::escape(model_name)); + if let Ok(re) = Regex::new(&pattern) { + re.replace_all(expr, "").to_string() + } else { + expr.to_string() + } +} + +#[cfg(any(target_arch = "wasm32", test))] +fn resolve_wasm_expression_projection_metric( + runtime: &SidemanticRuntime, + aliases: &HashMap, + projection: &str, +) -> Result { + let projection_norm = normalize_wasm_projection_expression(projection, aliases); + if projection_norm.is_empty() { + return Err(SidemanticError::Validation( + "wasm rewrite fallback received empty projection expression".into(), + )); + } + + let mut matches: Vec<(usize, String)> = Vec::new(); + + for model in runtime.graph.models() { + let model_name = model.name.clone(); + let projection_without_model = strip_wasm_model_qualifier(&projection_norm, &model_name); + + for metric in &model.metrics { + if metric.sql.is_none() { + continue; + } + + let metric_sql_norm = normalize_wasm_projection_expression(metric.sql_expr(), aliases); + let metric_sql_qualified_norm = normalize_wasm_projection_expression( + &format!("{}.{}", model.name, metric.sql_expr()), + aliases, + ); + let metric_sql_without_model = + strip_wasm_model_qualifier(&metric_sql_norm, &model_name); + + let rank = if projection_norm == metric_sql_qualified_norm { + Some(0usize) + } else if projection_norm == metric_sql_norm { + Some(1usize) + } else if projection_without_model == metric_sql_without_model { + Some(2usize) + } else { + None + }; + + if let Some(rank) = rank { + matches.push((rank, format!("{}.{}", model.name, metric.name))); + } + } + } + + if matches.is_empty() { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback projection '{projection}' requires a matching metric expression" + ))); + } + + let best_rank = matches + .iter() + .map(|(rank, _)| *rank) + .min() + .unwrap_or(usize::MAX); + let mut best_matches: Vec = matches + .into_iter() + .filter_map(|(rank, reference)| (rank == best_rank).then_some(reference)) + .collect(); + best_matches.sort(); + best_matches.dedup(); + + if best_matches.len() == 1 { + return Ok(best_matches[0].clone()); + } + + Err(SidemanticError::Validation(format!( + "wasm rewrite fallback projection '{projection}' is ambiguous; matches: {}", + best_matches.join(", ") + ))) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn build_wasm_formula_projection_from_aggregates( + runtime: &SidemanticRuntime, + aliases: &HashMap, + default_model: &str, + projection: &str, +) -> Result<(Vec, String)> { + let aggregate_pattern = Regex::new( + r"(?is)(?:count\s*\(\s*distinct\s+[^()]+\s*\)|(?:sum|avg|min|max|median|count)\s*\(\s*(?:[^()]+|\*)\s*\))", + ) + .expect("valid aggregate token regex"); + + let mut replacements: Vec<(usize, usize, String)> = Vec::new(); + let mut metric_refs: Vec = Vec::new(); + + for capture in aggregate_pattern.find_iter(projection) { + let aggregate_sql = capture.as_str(); + let semantic_ref = + resolve_wasm_aggregate_projection(runtime, aliases, default_model, aggregate_sql)?; + let metric_alias = semantic_ref + .split('.') + .next_back() + .map(str::to_string) + .unwrap_or_else(|| semantic_ref.clone()); + metric_refs.push(semantic_ref); + replacements.push((capture.start(), capture.end(), metric_alias)); + } + + if replacements.is_empty() { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback projection '{projection}' requires a matching metric expression" + ))); + } + + metric_refs.sort(); + metric_refs.dedup(); + + let mut rewritten = projection.to_string(); + for (start, end, replacement) in replacements.into_iter().rev() { + rewritten.replace_range(start..end, &replacement); + } + + Ok((metric_refs, rewritten)) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn parse_wasm_field_reference( + aliases: &HashMap, + default_model: &str, + reference: &str, +) -> Result<(String, String)> { + let trimmed = reference.trim(); + if trimmed.is_empty() { + return Err(SidemanticError::Validation( + "wasm rewrite fallback received empty field reference".into(), + )); + } + + let identifier_re = + Regex::new(r"(?i)^[A-Za-z_][A-Za-z0-9_]*$").expect("valid wasm identifier regex"); + + if let Some((model_part, field_part)) = trimmed.split_once('.') { + let model_key = normalize_wasm_match_key(model_part); + let field_key = normalize_wasm_match_key(field_part); + if !identifier_re.is_match(&field_key) { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback only supports simple field references in aggregation inputs: {trimmed}" + ))); + } + let resolved_model = aliases + .get(&normalize_ident(model_part)) + .cloned() + .unwrap_or(model_key); + return Ok((resolved_model, field_key)); + } + + let field_key = normalize_wasm_match_key(trimmed); + if !identifier_re.is_match(&field_key) { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback only supports simple field references in aggregation inputs: {trimmed}" + ))); + } + Ok((default_model.to_string(), field_key)) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn parse_wasm_count_input( + aliases: &HashMap, + default_model: &str, + inner_sql: &str, +) -> Result<(String, bool, Option)> { + let trimmed = inner_sql.trim(); + if trimmed == "*" || trimmed == "1" { + return Ok((default_model.to_string(), true, None)); + } + + if let Some((model_part, field_part)) = trimmed.split_once('.') { + let model_key = normalize_ident(model_part); + let field_key = normalize_wasm_match_key(field_part); + if field_key == "*" || field_key == "1" { + let resolved_model = aliases + .get(&model_key) + .cloned() + .unwrap_or_else(|| normalize_wasm_match_key(model_part)); + return Ok((resolved_model, true, None)); + } + } + + let (model_name, field_name) = parse_wasm_field_reference(aliases, default_model, trimmed)?; + Ok((model_name, false, Some(field_name))) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn resolve_wasm_aggregate_projection( + runtime: &SidemanticRuntime, + aliases: &HashMap, + default_model: &str, + projection: &str, +) -> Result { + let Some((agg_name, inner_expr)) = parse_simple_metric_aggregation(projection) else { + return resolve_wasm_expression_projection_metric(runtime, aliases, projection); + }; + + let target_agg = match agg_name.as_str() { + "sum" => Aggregation::Sum, + "avg" => Aggregation::Avg, + "min" => Aggregation::Min, + "max" => Aggregation::Max, + "median" => Aggregation::Median, + "count" => Aggregation::Count, + "count_distinct" => Aggregation::CountDistinct, + _ => { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback does not support function/aggregate projection '{projection}'" + ))); + } + }; + + let inner_expr = inner_expr.unwrap_or_default(); + let (model_name, is_count_star, field_name) = if target_agg == Aggregation::Count { + parse_wasm_count_input(aliases, default_model, &inner_expr)? + } else { + let (model_name, field_name) = + parse_wasm_field_reference(aliases, default_model, &inner_expr)?; + (model_name, false, Some(field_name)) + }; + + let model = runtime.graph.get_model(&model_name).ok_or_else(|| { + SidemanticError::Validation(format!( + "wasm rewrite fallback model '{model_name}' not found" + )) + })?; + + for metric in &model.metrics { + if metric.r#type != MetricType::Simple { + continue; + } + + let metric_agg = metric.agg.as_ref().unwrap_or(&Aggregation::Sum); + if metric_agg != &target_agg { + continue; + } + + if is_count_star { + let metric_sql = normalize_wasm_match_key(metric.sql_expr()); + if metric_sql == "*" + || metric_sql.is_empty() + || metric.name.eq_ignore_ascii_case("count") + { + return Ok(format!("{}.{}", model_name, metric.name)); + } + continue; + } + + let Some(field_name) = field_name.as_ref() else { + continue; + }; + let metric_sql = normalize_wasm_match_key(metric.sql_expr()); + let metric_name = normalize_wasm_match_key(&metric.name); + let qualified_field = format!("{}.{}", model_name, field_name); + + if metric_sql == *field_name || metric_sql == qualified_field || metric_name == *field_name + { + return Ok(format!("{}.{}", model_name, metric.name)); + } + } + + if is_count_star && target_agg == Aggregation::Count { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback count(*) requires '{model_name}.count' metric" + ))); + } + + Err(SidemanticError::Validation(format!( + "wasm rewrite fallback projection '{projection}' requires a matching metric on model '{model_name}'" + ))) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn parse_simple_from_aliases(from_clause: &str) -> Result<(String, HashMap)> { + let lower = from_clause.to_ascii_lowercase(); + for unsupported in [ + " join ", " left ", " right ", " inner ", " outer ", " cross ", " full ", " union ", + " group ", " order ", " having ", " limit ", " offset ", " with ", + ] { + if lower.contains(unsupported) { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback only supports SELECT ... FROM ... [WHERE ...] [ORDER BY ...] [LIMIT ...]; unsupported clause in FROM: {from_clause}" + ))); + } + } + if from_clause.contains(',') || from_clause.contains('(') || from_clause.contains(')') { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback only supports a single table in FROM: {from_clause}" + ))); + } + + let tokens: Vec<&str> = from_clause.split_whitespace().collect(); + let (model, alias) = match tokens.as_slice() { + [model] => (normalize_ident(model), None), + [model, alias] => (normalize_ident(model), Some(normalize_ident(alias))), + [model, as_kw, alias] if as_kw.eq_ignore_ascii_case("as") => { + (normalize_ident(model), Some(normalize_ident(alias))) + } + _ => { + return Err(SidemanticError::Validation(format!( + "wasm rewrite fallback could not parse FROM clause: {from_clause}" + ))); + } + }; + + let mut aliases = HashMap::new(); + aliases.insert(model.clone(), model.clone()); + if let Some(alias) = alias { + aliases.insert(alias, model.clone()); + } + Ok((model, aliases)) +} + +#[cfg(any(target_arch = "wasm32", test))] +fn rewrite_where_aliases(where_sql: &str, aliases: &HashMap) -> String { + let mut output = where_sql.to_string(); + for (alias, model) in aliases { + if alias == model { + continue; + } + let pattern = format!(r"\b{}\s*\.", regex::escape(alias)); + if let Ok(re) = Regex::new(&pattern) { + output = re.replace_all(&output, format!("{model}.")).to_string(); + } + } + output +} + +#[cfg(any(target_arch = "wasm32", test))] +fn rewrite_with_yaml_wasm_fallback(runtime: &SidemanticRuntime, sql: &str) -> Result { + let select_re = Regex::new( + r"(?is)^\s*select\s+(?P { + fn rewrite_select(&self, mut select: Select) -> Result { + if model_refs.len() != 1 + || select.with.is_some() + || !select.joins.is_empty() + || select.having.is_some() + || select.order_by.is_some() + || select.limit.is_some() + || select.offset.is_some() + || select.distinct + || select.group_by.is_none() + { + return Ok(select); + } - // Add GROUP BY if we have aggregations and dimensions - let has_aggregations = self.has_aggregations(&expressions); - let has_dimensions = self.has_non_aggregated_columns(&expressions); + let (model_name, model_alias) = (&model_refs[0].0, &model_refs[0].1); + let model = self.graph.get_model(model_name).ok_or_else(|| { + SidemanticError::Validation(format!("Model '{model_name}' not found")) + })?; - let group_by = if has_aggregations && has_dimensions { - Some(self.build_group_by(&expressions)) - } else { - select.group_by - }; + let mut dimensions: Vec<(Expression, String)> = Vec::new(); + let mut metrics: Vec<(Expression, String)> = Vec::new(); + + for projection in &select.expressions { + let Expression::Alias(alias) = projection else { + return Ok(select); + }; + + if self.is_aggregation(&alias.this) { + metrics.push((alias.this.clone(), alias.alias.name.clone())); + } else { + dimensions.push((alias.this.clone(), alias.alias.name.clone())); + } + } - Ok(Select { - expressions, - from, - joins, - where_clause, - group_by, - ..select - }) + if dimensions.is_empty() || metrics.is_empty() { + return Ok(select); + } + + let mut cte_select = Select::new(); + cte_select.from = select.from.clone(); + cte_select.where_clause = select.where_clause.clone(); + for primary_key in model.primary_keys() { + cte_select.expressions.push( + Expression::qualified_column(model_alias.clone(), primary_key.clone()) + .alias(primary_key), + ); + } + + for (dimension_expr, alias_name) in &dimensions { + cte_select + .expressions + .push(dimension_expr.clone().alias(alias_name.clone())); + } + + for (metric_expr, alias_name) in &metrics { + let Some(raw_expr) = extract_aggregate_input(metric_expr) else { + return Ok(select); + }; + cte_select + .expressions + .push(raw_expr.alias(format!("{alias_name}_raw"))); + } + + let cte_name = format!("{model_name}_cte"); + let cte_alias = cte_name.clone(); + + let mut outer_select = Select::new(); + outer_select.with = Some(With { + ctes: vec![Cte { + alias: Identifier::new(cte_name), + this: Expression::Select(Box::new(cte_select)), + columns: vec![], + materialized: None, + key_expressions: vec![], + alias_first: false, + }], + recursive: false, + leading_comments: vec![], + search: None, + }); + outer_select.from = Some(From { + expressions: vec![Expression::Table(table_ref_for(&cte_alias, None))], + }); + + for (_, alias_name) in &dimensions { + outer_select.expressions.push( + Expression::qualified_column(cte_alias.clone(), alias_name.clone()) + .alias(alias_name.clone()), + ); + } + + for (metric_expr, alias_name) in &metrics { + let raw_col = + Expression::qualified_column(cte_alias.clone(), format!("{alias_name}_raw")); + let Some(outer_metric_expr) = rebuild_aggregate_with_input(metric_expr, raw_col) else { + return Ok(select); + }; + outer_select + .expressions + .push(outer_metric_expr.alias(alias_name.clone())); + } + + outer_select.group_by = Some(GroupBy { + expressions: (1..=dimensions.len()) + .map(|i| Expression::number(i as i64)) + .collect(), + all: None, + totals: false, + }); + + Ok(outer_select) + } + + fn rewrite_nested_query_expr(&self, expr: Expression) -> Result { + match expr { + Expression::Select(select) => { + let rewritten = self.rewrite_select(*select)?; + Ok(Expression::Select(Box::new(rewritten))) + } + Expression::Subquery(mut subquery) => { + subquery.this = self.rewrite_nested_query_expr(subquery.this)?; + Ok(Expression::Subquery(subquery)) + } + Expression::Alias(mut alias) => { + alias.this = self.rewrite_nested_query_expr(alias.this)?; + Ok(Expression::Alias(alias)) + } + Expression::Paren(mut paren) => { + paren.this = self.rewrite_nested_query_expr(paren.this)?; + Ok(Expression::Paren(paren)) + } + Expression::JoinedTable(mut joined) => { + joined.left = self.rewrite_nested_query_expr(joined.left)?; + for join in &mut joined.joins { + join.this = self.rewrite_nested_query_expr(join.this.clone())?; + } + Ok(Expression::JoinedTable(joined)) + } + other => Ok(other), + } } /// Find semantic model references in FROM clause - fn find_model_references(&self, from: &Option) -> Vec<(String, String)> { + fn find_model_references( + &self, + from: Option<&polyglot_sql::expressions::From>, + ) -> Vec<(String, String)> { let mut refs = Vec::new(); - let Some(from) = from else { - return refs; - }; - - for expr in &from.expressions { - if let Expression::Table(table_ref) = expr { - let table_name = &table_ref.name.name; - - if self.graph.get_model(table_name).is_some() { - let alias_name = table_ref - .alias - .as_ref() - .map(|a| a.name.clone()) - .unwrap_or_else(|| table_name.clone()); - refs.push((table_name.clone(), alias_name)); + if let Some(from_clause) = from { + for source in &from_clause.expressions { + if let Some((table_name, alias)) = table_name_and_alias(source) { + if self.graph.get_model(&table_name).is_some() { + refs.push((table_name.clone(), alias.unwrap_or(table_name))); + } } } } @@ -153,26 +390,17 @@ impl<'a> QueryRewriter<'a> { let mut models = HashSet::new(); for item in projection { - match item { - Expression::Alias(alias) => { - self.collect_model_refs_from_expr(&alias.this, &mut models); - } - _ => { - self.collect_model_refs_from_expr(item, &mut models); - } - } + self.collect_model_refs_from_expr(item, &mut models); } models } - /// Recursively collect model references from an expression + /// Collect model references from an expression fn collect_model_refs_from_expr(&self, expr: &Expression, models: &mut HashSet) { - use polyglot_sql::ExpressionWalk; - - for node in expr.dfs() { - if let Expression::Column(col) = node { - if let Some(table) = &col.table { + for column_ref in traversal::get_columns(expr) { + if let Expression::Column(column) = column_ref { + if let Some(table) = &column.table { if self.graph.get_model(&table.name).is_some() { models.insert(table.name.clone()); } @@ -181,6 +409,66 @@ impl<'a> QueryRewriter<'a> { } } + fn collect_order_by_model_refs( + &self, + order_by: Option<&polyglot_sql::expressions::OrderBy>, + models: &mut HashSet, + ) { + if let Some(order_by) = order_by { + for ordered in &order_by.expressions { + self.collect_model_refs_from_expr(&ordered.this, models); + } + } + } + + fn projection_alias_lookup( + &self, + projection: &[Expression], + model_refs: &[(String, String)], + ) -> HashMap { + let mut aliases = HashMap::new(); + + for item in projection { + match item { + Expression::Star(_) if model_refs.len() == 1 => { + let (model_name, _) = (&model_refs[0].0, &model_refs[0].1); + if let Some(model) = self.graph.get_model(model_name) { + for dimension in &model.dimensions { + aliases.insert( + semantic_field_key(model_name, &dimension.name, None), + dimension.name.clone(), + ); + } + for metric in &model.metrics { + aliases.insert( + semantic_field_key(model_name, &metric.name, None), + metric.name.clone(), + ); + } + } + } + Expression::Alias(alias) => { + if let Some(key) = semantic_field_key_for_expr(&alias.this, model_refs) { + aliases.insert(key, alias.alias.name.clone()); + } + } + Expression::Column(column) => { + if let Some((model_name, _, base_field, granularity)) = + resolve_model_field(column, model_refs) + { + aliases.insert( + semantic_field_key(model_name, base_field, granularity), + column.name.name.clone(), + ); + } + } + _ => {} + } + } + + aliases + } + /// Rewrite SELECT projection items fn rewrite_projection( &self, @@ -191,20 +479,68 @@ impl<'a> QueryRewriter<'a> { for item in projection { match item { + Expression::Star(_) => { + if model_refs.len() != 1 { + return Err(SidemanticError::Validation( + "SELECT * requires a FROM clause with a single table".into(), + )); + } + + let (model_name, alias) = (&model_refs[0].0, &model_refs[0].1); + let model = self.graph.get_model(model_name).ok_or_else(|| { + SidemanticError::Validation(format!("Model '{model_name}' not found")) + })?; + + for dimension in &model.dimensions { + result.push( + Expression::qualified_column( + alias.clone(), + dimension.sql_expr().to_string(), + ) + .alias(dimension.name.clone()), + ); + } + for metric in &model.metrics { + result.push( + self.metric_to_expr(metric, alias) + .alias(metric.name.clone()), + ); + } + } Expression::Alias(alias) => { - let rewritten_inner = - self.rewrite_select_expr(alias.this.clone(), model_refs)?; - result.push(Expression::Alias(Box::new(Alias { - this: rewritten_inner, - alias: alias.alias.clone(), - column_aliases: alias.column_aliases.clone(), - pre_alias_comments: alias.pre_alias_comments.clone(), - trailing_comments: alias.trailing_comments.clone(), - }))); + let mut new_alias = alias.as_ref().clone(); + new_alias.this = self.rewrite_select_expr(new_alias.this, model_refs)?; + result.push(Expression::Alias(Box::new(new_alias))); } - Expression::Star(_) => result.push(item.clone()), - other => { - let rewritten = self.rewrite_select_expr(other.clone(), model_refs)?; + Expression::Column(column) => { + let rewritten = self.rewrite_select_expr(item.clone(), model_refs)?; + let alias_name = column_alias_name(column, model_refs) + .unwrap_or_else(|| column.name.name.clone()); + result.push(rewritten.alias(alias_name)); + } + Expression::Count(_) + | Expression::Sum(_) + | Expression::Avg(_) + | Expression::Min(_) + | Expression::Max(_) + | Expression::Median(_) + | Expression::AggregateFunction(_) => { + return Err(SidemanticError::Validation( + "Aggregate functions must be defined as a metric".into(), + )); + } + Expression::Function(_) => { + return Err(SidemanticError::Validation( + "Aggregate functions must be defined as a metric".into(), + )); + } + _ => { + let rewritten = self.rewrite_select_expr(item.clone(), model_refs)?; + if matches!(rewritten, Expression::Identifier(_)) { + return Err(SidemanticError::Validation( + "Query must select at least one metric or dimension".into(), + )); + } result.push(rewritten); } } @@ -219,38 +555,50 @@ impl<'a> QueryRewriter<'a> { expr: Expression, model_refs: &[(String, String)], ) -> Result { - match &expr { - Expression::Column(col) if col.table.is_some() => { - let table_ident = col.table.as_ref().unwrap(); - let model_name = &table_ident.name; - let field_name = &col.name.name; - - // Find the model - if let Some((actual_model, alias)) = model_refs - .iter() - .find(|(m, a)| m == model_name || a == model_name) - { - let model = self.graph.get_model(actual_model).unwrap(); - - // Check if it's a metric - if let Some(metric) = model.get_metric(field_name) { - return Ok(self.metric_to_expr(metric, alias)); - } + if let Expression::Column(column) = &expr { + if let Some((model_name, alias_name, base_field, granularity)) = + resolve_model_field(column, model_refs) + { + let model = self.graph.get_model(model_name).ok_or_else(|| { + SidemanticError::Validation(format!("Model '{model_name}' not found")) + })?; + + if let Some(metric) = model.get_metric(base_field) { + return Ok(self.metric_to_expr(metric, alias_name)); + } - // Check if it's a dimension - if let Some(dimension) = model.get_dimension(field_name) { - return Ok(Expression::qualified_column( - alias.as_str(), - dimension.sql_expr(), - )); - } + if let Some(dimension) = model.get_dimension(base_field) { + return Ok(dimension_to_expr(alias_name, dimension, granularity)); } - // Not a semantic reference, return as-is - Ok(expr) + return Err(SidemanticError::Validation(format!( + "Field '{model_name}.{base_field}' not found" + ))); } - _ => self.rewrite_expr(expr, model_refs), + + return Err(SidemanticError::Validation(format!( + "Cannot resolve column: {}", + column.name.name + ))); } + + if matches!( + expr, + Expression::Count(_) + | Expression::Sum(_) + | Expression::Avg(_) + | Expression::Min(_) + | Expression::Max(_) + | Expression::Median(_) + | Expression::AggregateFunction(_) + | Expression::Function(_) + ) { + return Err(SidemanticError::Validation( + "Aggregate functions must be defined as a metric".into(), + )); + } + + self.rewrite_expr(expr, model_refs) } /// Convert a metric to an expression @@ -259,217 +607,134 @@ impl<'a> QueryRewriter<'a> { MetricType::Simple => { // Handle Expression type: sql field contains the full expression if let Some(crate::core::Aggregation::Expression) = &metric.agg { - return self.parse_sql_fragment(metric.sql_expr()); + if let Some(expr) = parse_select_expr(metric.sql_expr()) { + return expr; + } + // Fallback: return as identifier + return Expression::identifier(metric.name.clone()); } - let agg = metric.agg.as_ref().unwrap(); - // COUNT without explicit sql defaults to COUNT(*) - let use_wildcard = metric.sql.as_deref() == Some("*") - || (*agg == crate::core::Aggregation::Count && metric.sql.is_none()); - - if use_wildcard { - return Expression::Count(Box::new(CountFunc { - this: None, - star: true, - distinct: false, - filter: None, - ignore_nulls: None, - original_name: None, - })); + if let Some(expr) = parse_select_expr(&metric.to_sql(Some(alias))) { + return expr; } - let col_expr = Expression::qualified_column(alias, metric.sql_expr().to_string()); - - let make_agg = |this: Expression| AggFunc { - this, - distinct: false, - filter: None, - order_by: vec![], - name: None, - ignore_nulls: None, - having_max: None, - limit: None, - }; - - match agg { - crate::core::Aggregation::Sum => Expression::Sum(Box::new(make_agg(col_expr))), - crate::core::Aggregation::Count => Expression::Count(Box::new(CountFunc { - this: Some(col_expr), - star: false, - distinct: false, - filter: None, - ignore_nulls: None, - original_name: None, - })), - crate::core::Aggregation::CountDistinct => { - Expression::Count(Box::new(CountFunc { - this: Some(col_expr), - star: false, - distinct: true, - filter: None, - ignore_nulls: None, - original_name: None, - })) - } - crate::core::Aggregation::Avg => Expression::Avg(Box::new(make_agg(col_expr))), - crate::core::Aggregation::Min => Expression::Min(Box::new(make_agg(col_expr))), - crate::core::Aggregation::Max => Expression::Max(Box::new(make_agg(col_expr))), - crate::core::Aggregation::Median => { - Expression::Median(Box::new(make_agg(col_expr))) - } - crate::core::Aggregation::Expression => unreachable!(), - } - } - MetricType::Derived | MetricType::Ratio => { - // For derived/ratio metrics, parse the SQL expression - self.parse_sql_fragment(metric.sql_expr()) + // Fallback: return as identifier + Expression::identifier(metric.name.clone()) } - MetricType::Cumulative | MetricType::TimeComparison => { - // Complex metric types require special handling with window functions - self.parse_sql_fragment(&metric.to_sql(Some(alias))) - } - } - } - - /// Parse a SQL expression fragment by wrapping in SELECT - fn parse_sql_fragment(&self, expr_sql: &str) -> Expression { - let sql = format!("SELECT {expr_sql}"); - if let Ok(expressions) = polyglot_sql::parse(&sql, DIALECT) { - if let Some(Expression::Select(select)) = expressions.into_iter().next() { - if let Some(expr) = select.expressions.into_iter().next() { - // Unwrap Alias if present - if let Expression::Alias(alias) = expr { - return alias.this; - } + MetricType::Derived + | MetricType::Ratio + | MetricType::Cumulative + | MetricType::TimeComparison + | MetricType::Conversion + | MetricType::Retention + | MetricType::Cohort => { + if let Some(expr) = parse_select_expr(&metric.to_sql(Some(alias))) { return expr; } + // Fallback: return as identifier + Expression::identifier(metric.name.clone()) } } - // Fallback: return as identifier - Expression::identifier(expr_sql) } /// Rewrite FROM clause with JOINs for cross-model references fn rewrite_from_with_joins( &self, - from: &Option, - existing_joins: &[Join], + select: &mut Select, model_refs: &[(String, String)], base_model: Option<&str>, models_to_join: &[String], - ) -> Result<(Option, Vec)> { - let Some(from) = from else { - return Ok((None, existing_joins.to_vec())); - }; - - let mut new_from_exprs = Vec::new(); - let mut new_joins = existing_joins.to_vec(); - - for expr in &from.expressions { - if let Expression::Table(table_ref) = expr { - let table_name = &table_ref.name.name; - - if let Some(model) = self.graph.get_model(table_name) { - // Build JOINs for models referenced but not in FROM - if Some(table_name.as_str()) == base_model { - for target_model_name in models_to_join { - if let Ok(join_path) = - self.graph.find_join_path(table_name, target_model_name) - { - for step in &join_path.steps { - let target_model = - self.graph.get_model(&step.to_model).unwrap(); - - // Find the alias for this model - let to_alias = model_refs - .iter() - .find(|(m, _)| m == &step.to_model) - .map(|(_, a)| a.clone()) - .unwrap_or_else(|| step.to_model.clone()); - - let from_alias = model_refs - .iter() - .find(|(m, _)| m == &step.from_model) - .map(|(_, a)| a.clone()) - .unwrap_or_else(|| step.from_model.clone()); - - // Build JOIN condition - let join_condition = - if let Some(custom) = &step.custom_condition { - let condition_sql = custom - .replace("{from}", &from_alias) - .replace("{to}", &to_alias); - self.parse_where_fragment(&condition_sql) - .unwrap_or_else(|| { - self.build_default_join_condition( - &from_alias, - &step.from_key, - &to_alias, - &step.to_key, - ) - }) - } else { - self.build_default_join_condition( - &from_alias, - &step.from_key, - &to_alias, - &step.to_key, - ) - }; - - let join_table = make_table_ref_with_alias( - target_model.table_name(), - &to_alias, - ); - - new_joins.push(Join { - this: Expression::Table(join_table), - on: Some(join_condition), - using: vec![], - kind: JoinKind::Left, - use_inner_keyword: false, - use_outer_keyword: true, - deferred_condition: false, - join_hint: None, - match_condition: None, - pivots: vec![], - comments: vec![], - nesting_group: 0, - directed: false, - }); - } - } - } - } + ) -> Result<()> { + // Rewrite base FROM semantic model table names + if let Some(from_clause) = &mut select.from { + for source in &mut from_clause.expressions { + self.rewrite_from_source(source); + } + } - let mut new_table = make_table_ref(model.table_name()); - new_table.alias = table_ref.alias.clone(); - new_table.alias_explicit_as = table_ref.alias_explicit_as; - new_from_exprs.push(Expression::Table(new_table)); - } else { - new_from_exprs.push(expr.clone()); + // Add auto-joins for referenced models + if let Some(base) = base_model { + for target_model_name in models_to_join { + if let Ok(join_path) = self.graph.find_join_path(base, target_model_name) { + for step in &join_path.steps { + let target_model = self.graph.get_model(&step.to_model).unwrap(); + + // Find aliases for this join step + let to_alias = model_refs + .iter() + .find(|(m, _)| m == &step.to_model) + .map(|(_, a)| a.clone()) + .unwrap_or_else(|| step.to_model.clone()); + + let from_alias = model_refs + .iter() + .find(|(m, _)| m == &step.from_model) + .map(|(_, a)| a.clone()) + .unwrap_or_else(|| step.from_model.clone()); + + let join_condition = if let Some(custom) = &step.custom_condition { + let condition_sql = custom + .replace("{from}", &from_alias) + .replace("{to}", &to_alias); + parse_where_expr(&condition_sql).unwrap_or_else(|| { + self.build_default_join_condition( + &from_alias, + &step.from_keys, + &to_alias, + &step.to_keys, + ) + .expect("join path keys already validated") + }) + } else { + self.build_default_join_condition( + &from_alias, + &step.from_keys, + &to_alias, + &step.to_keys, + )? + }; + + select.joins.push(Join { + this: Expression::Table(table_ref_for( + target_model.table_name(), + Some(&to_alias), + )), + on: Some(join_condition), + using: vec![], + kind: JoinKind::Left, + use_inner_keyword: false, + use_outer_keyword: false, + deferred_condition: false, + join_hint: None, + match_condition: None, + pivots: vec![], + }); + } } - } else { - new_from_exprs.push(expr.clone()); } } - Ok(( - Some(From { - expressions: new_from_exprs, - }), - new_joins, - )) + Ok(()) } - /// Parse a WHERE condition fragment - fn parse_where_fragment(&self, condition_sql: &str) -> Option { - let sql = format!("SELECT 1 WHERE {condition_sql}"); - let exprs = polyglot_sql::parse(&sql, DIALECT).ok()?; - if let Some(Expression::Select(s)) = exprs.into_iter().next() { - s.where_clause.map(|w| w.this) - } else { - None + fn rewrite_from_source(&self, source: &mut Expression) { + match source { + Expression::Table(table) => { + let model_name = table.name.name.clone(); + if let Some(model) = self.graph.get_model(&model_name) { + rewrite_table_ref_name(table, model.table_name()); + if table.alias.is_none() && model.table_name() != model_name { + table.alias = Some(Identifier::new(model_name)); + } + } + } + Expression::Alias(alias) => { + self.rewrite_from_source(&mut alias.this); + } + Expression::Paren(paren) => { + self.rewrite_from_source(&mut paren.this); + } + _ => {} } } @@ -477,106 +742,233 @@ impl<'a> QueryRewriter<'a> { fn build_default_join_condition( &self, from_alias: &str, - from_key: &str, + from_keys: &[String], to_alias: &str, - to_key: &str, - ) -> Expression { - Expression::Eq(Box::new(BinaryOp { - left: Expression::qualified_column(from_alias, from_key), - right: Expression::qualified_column(to_alias, to_key), - left_comments: vec![], - operator_comments: vec![], - trailing_comments: vec![], + to_keys: &[String], + ) -> Result { + if from_keys.is_empty() || to_keys.is_empty() { + return Err(SidemanticError::Validation( + "Join path is missing join key columns".to_string(), + )); + } + if from_keys.len() != to_keys.len() { + return Err(SidemanticError::Validation(format!( + "Join key column count mismatch: {} vs {}", + from_keys.len(), + to_keys.len() + ))); + } + + let mut conditions = from_keys + .iter() + .zip(to_keys.iter()) + .map(|(from_key, to_key)| { + Expression::Eq(Box::new(BinaryOp::new( + Expression::qualified_column(from_alias.to_string(), from_key.to_string()), + Expression::qualified_column(to_alias.to_string(), to_key.to_string()), + ))) + }); + + let first = conditions.next().expect("checked non-empty key lists"); + Ok(conditions.fold(first, |left, right| { + Expression::And(Box::new(BinaryOp::new(left, right))) })) } - /// Rewrite general expressions (WHERE clause, etc.) + /// Rewrite general expressions fn rewrite_expr( &self, expr: Expression, model_refs: &[(String, String)], ) -> Result { - let model_refs_vec = model_refs.to_vec(); - let graph = self.graph; - - polyglot_sql::transform_map(expr, &|node| { - if let Expression::Column(ref col) = node { - if let Some(ref table_ident) = col.table { - let table_name = &table_ident.name; - let field_name = &col.name.name; - - if let Some((actual_model, alias)) = model_refs_vec - .iter() - .find(|(m, a)| m.as_str() == table_name || a.as_str() == table_name) - { - if let Some(model) = graph.get_model(actual_model) { - if let Some(dimension) = model.get_dimension(field_name) { - return Ok(Expression::qualified_column( - alias.as_str(), - dimension.sql_expr(), - )); - } - } + match expr { + Expression::Column(column) => { + if let Some((model_name, alias_name, base_field, granularity)) = + resolve_model_field(&column, model_refs) + { + let model = self.graph.get_model(model_name).ok_or_else(|| { + SidemanticError::Validation(format!("Model '{model_name}' not found")) + })?; + + if let Some(metric) = model.get_metric(base_field) { + return Ok(self.metric_to_expr(metric, alias_name)); + } + + if let Some(dimension) = model.get_dimension(base_field) { + return Ok(dimension_to_expr(alias_name, dimension, granularity)); } } + + Ok(Expression::Column(column)) + } + Expression::Alias(mut alias) => { + alias.this = self.rewrite_expr(alias.this, model_refs)?; + Ok(Expression::Alias(alias)) + } + Expression::Paren(mut paren) => { + paren.this = self.rewrite_expr(paren.this, model_refs)?; + Ok(Expression::Paren(paren)) + } + + Expression::And(binary) => Ok(Expression::And(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Or(binary) => Ok(Expression::Or(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Add(binary) => Ok(Expression::Add(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Sub(binary) => Ok(Expression::Sub(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Mul(binary) => Ok(Expression::Mul(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Div(binary) => Ok(Expression::Div(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Mod(binary) => Ok(Expression::Mod(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Eq(binary) => Ok(Expression::Eq(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Neq(binary) => Ok(Expression::Neq(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Lt(binary) => Ok(Expression::Lt(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Lte(binary) => Ok(Expression::Lte(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Gt(binary) => Ok(Expression::Gt(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Gte(binary) => Ok(Expression::Gte(Box::new( + self.rewrite_binary_op(*binary, model_refs)?, + ))), + Expression::Like(mut like_op) => { + like_op.left = self.rewrite_expr(like_op.left, model_refs)?; + like_op.right = self.rewrite_expr(like_op.right, model_refs)?; + if let Some(escape) = like_op.escape.take() { + like_op.escape = Some(self.rewrite_expr(escape, model_refs)?); + } + Ok(Expression::Like(like_op)) + } + Expression::ILike(mut like_op) => { + like_op.left = self.rewrite_expr(like_op.left, model_refs)?; + like_op.right = self.rewrite_expr(like_op.right, model_refs)?; + if let Some(escape) = like_op.escape.take() { + like_op.escape = Some(self.rewrite_expr(escape, model_refs)?); + } + Ok(Expression::ILike(like_op)) + } + Expression::Not(mut unary) => { + unary.this = self.rewrite_expr(unary.this, model_refs)?; + Ok(Expression::Not(unary)) + } + Expression::Neg(mut unary) => { + unary.this = self.rewrite_expr(unary.this, model_refs)?; + Ok(Expression::Neg(unary)) + } + Expression::BitwiseNot(mut unary) => { + unary.this = self.rewrite_expr(unary.this, model_refs)?; + Ok(Expression::BitwiseNot(unary)) } - Ok(node) - }) - .map_err(|e| SidemanticError::SqlParse(e.to_string())) + Expression::In(mut in_expr) => { + in_expr.this = self.rewrite_expr(in_expr.this, model_refs)?; + in_expr.expressions = in_expr + .expressions + .into_iter() + .map(|e| self.rewrite_expr(e, model_refs)) + .collect::>>()?; + if let Some(query_expr) = in_expr.query.take() { + in_expr.query = Some(self.rewrite_expr(query_expr, model_refs)?); + } + if let Some(unnest_expr) = in_expr.unnest.take() { + in_expr.unnest = Some(Box::new(self.rewrite_expr(*unnest_expr, model_refs)?)); + } + Ok(Expression::In(in_expr)) + } + Expression::Between(mut between) => { + between.this = self.rewrite_expr(between.this, model_refs)?; + between.low = self.rewrite_expr(between.low, model_refs)?; + between.high = self.rewrite_expr(between.high, model_refs)?; + Ok(Expression::Between(between)) + } + Expression::IsNull(mut is_null) => { + is_null.this = self.rewrite_expr(is_null.this, model_refs)?; + Ok(Expression::IsNull(is_null)) + } + other => Ok(other), + } } - /// Check if projection has any aggregation functions - fn has_aggregations(&self, projection: &[Expression]) -> bool { - projection.iter().any(|item| { - let expr = match item { - Expression::Alias(a) => &a.this, - other => other, - }; - self.is_aggregation(expr) - }) + fn rewrite_binary_op( + &self, + mut binary: BinaryOp, + model_refs: &[(String, String)], + ) -> Result { + binary.left = self.rewrite_expr(binary.left, model_refs)?; + binary.right = self.rewrite_expr(binary.right, model_refs)?; + Ok(binary) } - /// Check if expression contains an aggregation function anywhere in its tree. - /// Uses polyglot-sql's own classification via DFS traversal rather than - /// maintaining a hardcoded list of function names. + fn rewrite_order_by_expr( + &self, + expr: Expression, + model_refs: &[(String, String)], + projection_aliases: &HashMap, + output_aliases: &HashSet, + ) -> Result { + if let Expression::Column(column) = &expr { + if column.table.is_none() && output_aliases.contains(&column.name.name) { + return Ok(expr); + } + + if let Some((model_name, _, base_field, granularity)) = + resolve_model_field(column, model_refs) + { + let key = semantic_field_key(model_name, base_field, granularity); + if let Some(alias) = projection_aliases.get(&key) { + return Ok(Expression::identifier(alias.clone())); + } + } + } + + self.rewrite_expr(expr, model_refs) + } + + /// Check if expression is an aggregation function fn is_aggregation(&self, expr: &Expression) -> bool { - expr.dfs().any(|node| { - matches!( - node, - Expression::Sum(_) - | Expression::Count(_) - | Expression::Avg(_) - | Expression::Min(_) - | Expression::Max(_) - | Expression::Median(_) - | Expression::AggregateFunction(_) - ) - }) + match expr { + Expression::Alias(alias) => self.is_aggregation(&alias.this), + Expression::Count(_) + | Expression::Sum(_) + | Expression::Avg(_) + | Expression::Min(_) + | Expression::Max(_) + | Expression::Median(_) + | Expression::AggregateFunction(_) => true, + _ => false, + } } /// Check if projection has non-aggregated columns fn has_non_aggregated_columns(&self, projection: &[Expression]) -> bool { - projection.iter().any(|item| { - let expr = match item { - Expression::Alias(a) => &a.this, - other => other, - }; - !self.is_aggregation(expr) - }) + projection.iter().any(|expr| !self.is_aggregation(expr)) } /// Build GROUP BY clause from non-aggregated columns fn build_group_by(&self, projection: &[Expression]) -> GroupBy { let mut group_by_exprs = Vec::new(); - for (i, item) in projection.iter().enumerate() { - let expr = match item { - Expression::Alias(a) => &a.this, - other => other, - }; + for (i, expr) in projection.iter().enumerate() { if !self.is_aggregation(expr) { // Use positional reference - group_by_exprs.push(Expression::Literal(Literal::Number((i + 1).to_string()))); + group_by_exprs.push(Expression::number((i + 1) as i64)); } } @@ -584,27 +976,311 @@ impl<'a> QueryRewriter<'a> { expressions: group_by_exprs, all: None, totals: false, - comments: vec![], } } } -/// Build a TableRef from a possibly-qualified table name -fn make_table_ref(full_name: &str) -> TableRef { - let parts: Vec<&str> = full_name.split('.').collect(); - match parts.len() { - 2 => TableRef::new_with_schema(parts[1], parts[0]), - 3 => TableRef::new_with_catalog(parts[2], parts[1], parts[0]), - _ => TableRef::new(full_name), +fn parse_sql_with_large_stack(sql: &str) -> Result> { + #[cfg(target_arch = "wasm32")] + { + let _ = sql; + return Err(SidemanticError::SqlParse( + "operation not supported on this platform".to_string(), + )); + } + + #[cfg(not(target_arch = "wasm32"))] + { + let sql_owned = sql.to_string(); + let handle = std::thread::Builder::new() + .stack_size(16 * 1024 * 1024) + .spawn(move || { + polyglot_parse(&sql_owned, DialectType::Generic).map_err(|e| e.to_string()) + }) + .map_err(|e| SidemanticError::SqlParse(e.to_string()))?; + + let parse_result = handle + .join() + .map_err(|_| SidemanticError::SqlParse("Polyglot parser thread panicked".into()))?; + + parse_result.map_err(SidemanticError::SqlParse) } } -/// Build a TableRef with an alias -fn make_table_ref_with_alias(full_name: &str, alias: &str) -> TableRef { - let mut table = make_table_ref(full_name); - table.alias = Some(Identifier::new(alias)); - table.alias_explicit_as = true; - table +fn parse_select_expr(expr_sql: &str) -> Option { + let sql = format!("SELECT {expr_sql}"); + let statements = parse_sql_with_large_stack(&sql).ok()?; + let statement = statements.first()?; + let select = statement.as_select()?; + select.expressions.first().cloned() +} + +fn parse_where_expr(condition_sql: &str) -> Option { + let sql = format!("SELECT 1 WHERE {condition_sql}"); + let statements = parse_sql_with_large_stack(&sql).ok()?; + let statement = statements.first()?; + let select = statement.as_select()?; + select.where_clause.as_ref().map(|w| w.this.clone()) +} + +fn expr_to_sql(expr: &Expression) -> Result { + polyglot_generate(expr, DialectType::Generic) + .map_err(|e| SidemanticError::SqlGeneration(e.to_string())) +} + +fn resolve_model_ref<'a>( + table_or_alias: &str, + model_refs: &'a [(String, String)], +) -> Option<(&'a str, &'a str)> { + model_refs + .iter() + .find(|(model, alias)| model == table_or_alias || alias == table_or_alias) + .map(|(model, alias)| (model.as_str(), alias.as_str())) +} + +fn resolve_model_field<'a>( + column: &'a polyglot_sql::expressions::Column, + model_refs: &'a [(String, String)], +) -> Option<(&'a str, &'a str, &'a str, Option<&'a str>)> { + let (base_field, granularity) = split_granularity(&column.name.name); + + if let Some(table) = &column.table { + if let Some((model, alias)) = resolve_model_ref(&table.name, model_refs) { + return Some((model, alias, base_field, granularity)); + } + return None; + } + + if model_refs.len() == 1 { + let (model, alias) = (&model_refs[0].0, &model_refs[0].1); + return Some((model.as_str(), alias.as_str(), base_field, granularity)); + } + + None +} + +type SemanticFieldKey = (String, String, Option); + +fn semantic_field_key( + model_name: &str, + base_field: &str, + granularity: Option<&str>, +) -> SemanticFieldKey { + ( + model_name.to_string(), + base_field.to_string(), + granularity.map(ToString::to_string), + ) +} + +fn semantic_field_key_for_expr( + expr: &Expression, + model_refs: &[(String, String)], +) -> Option { + let Expression::Column(column) = expr else { + return None; + }; + let (model_name, _, base_field, granularity) = resolve_model_field(column, model_refs)?; + Some(semantic_field_key(model_name, base_field, granularity)) +} + +fn split_granularity(field: &str) -> (&str, Option<&str>) { + const VALID_GRANULARITIES: [&str; 8] = [ + "year", "quarter", "month", "week", "day", "hour", "minute", "second", + ]; + + if let Some((base, gran)) = field.rsplit_once("__") { + if VALID_GRANULARITIES.contains(&gran) { + return (base, Some(gran)); + } + } + + (field, None) +} + +fn column_alias_name( + column: &polyglot_sql::expressions::Column, + model_refs: &[(String, String)], +) -> Option { + resolve_model_field(column, model_refs).map(|(_, _, _, _)| column.name.name.clone()) +} + +fn has_star_projection(projection: &[Expression]) -> bool { + projection + .iter() + .any(|expr| matches!(expr, Expression::Star(_))) +} + +fn extract_aggregate_input(expr: &Expression) -> Option { + match expr { + Expression::Sum(agg) + | Expression::Avg(agg) + | Expression::Min(agg) + | Expression::Max(agg) + | Expression::Median(agg) => Some(agg.this.clone()), + Expression::AggregateFunction(func) => { + let func_name = func.name.to_uppercase(); + match func_name.as_str() { + "SUM" | "AVG" | "MIN" | "MAX" | "MEDIAN" => func.args.first().cloned(), + "COUNT" => { + if let Some(arg) = func.args.first() { + if matches!(arg, Expression::Star(_)) { + Some(Expression::number(1)) + } else { + Some(arg.clone()) + } + } else { + Some(Expression::number(1)) + } + } + _ => None, + } + } + Expression::Count(count) => { + if count.star || count.this.is_none() { + Some(Expression::number(1)) + } else { + count.this.clone() + } + } + _ => None, + } +} + +fn rebuild_aggregate_with_input(expr: &Expression, input: Expression) -> Option { + match expr.clone() { + Expression::Sum(mut agg) => { + agg.this = input; + Some(Expression::Sum(agg)) + } + Expression::Avg(mut agg) => { + agg.this = input; + Some(Expression::Avg(agg)) + } + Expression::Min(mut agg) => { + agg.this = input; + Some(Expression::Min(agg)) + } + Expression::Max(mut agg) => { + agg.this = input; + Some(Expression::Max(agg)) + } + Expression::Median(mut agg) => { + agg.this = input; + Some(Expression::Median(agg)) + } + Expression::Count(mut count) => { + count.this = Some(input); + count.star = false; + Some(Expression::Count(count)) + } + Expression::AggregateFunction(mut func) => { + let func_name = func.name.to_uppercase(); + match func_name.as_str() { + "SUM" | "AVG" | "MIN" | "MAX" | "MEDIAN" | "COUNT" => { + func.args = vec![input]; + Some(Expression::AggregateFunction(func)) + } + _ => None, + } + } + _ => None, + } +} + +fn is_from_metrics(from: Option<&From>) -> bool { + let Some(from_clause) = from else { + return false; + }; + if from_clause.expressions.len() != 1 { + return false; + } + matches!( + &from_clause.expressions[0], + Expression::Table(table) if table.name.name == "metrics" + ) +} + +fn unique_alias(model_name: &str, existing: &HashSet) -> String { + let base = model_name.chars().next().unwrap_or('t').to_string(); + if !existing.contains(&base) { + return base; + } + + let mut i = 2; + loop { + let candidate = format!("{base}{i}"); + if !existing.contains(&candidate) { + return candidate; + } + i += 1; + } +} + +fn dimension_to_expr( + alias_name: &str, + dimension: &crate::core::Dimension, + granularity: Option<&str>, +) -> Expression { + if let Some(gran) = granularity.filter(|_| dimension.r#type == DimensionType::Time) { + let sql = format!( + "DATE_TRUNC('{gran}', {}.{})", + alias_name, + dimension.sql_expr() + ); + if let Some(expr) = parse_select_expr(&sql) { + return expr; + } + } + + Expression::qualified_column(alias_name.to_string(), dimension.sql_expr().to_string()) +} + +fn table_name_and_alias(source: &Expression) -> Option<(String, Option)> { + match source { + Expression::Table(table) => Some(( + table.name.name.clone(), + table.alias.as_ref().map(|a| a.name.clone()), + )), + Expression::Alias(alias) => { + if let Expression::Table(table) = &alias.this { + Some((table.name.name.clone(), Some(alias.alias.name.clone()))) + } else { + None + } + } + Expression::Paren(paren) => table_name_and_alias(&paren.this), + _ => None, + } +} + +fn table_ref_for(table_name: &str, alias: Option<&str>) -> TableRef { + let mut table_ref = TableRef::new(""); + rewrite_table_ref_name(&mut table_ref, table_name); + table_ref.alias = alias.map(Identifier::new); + table_ref +} + +fn rewrite_table_ref_name(table_ref: &mut TableRef, table_name: &str) { + let parts: Vec<&str> = table_name.split('.').collect(); + match parts.len() { + 0 => {} + 1 => { + table_ref.catalog = None; + table_ref.schema = None; + table_ref.name = Identifier::new(parts[0]); + } + 2 => { + table_ref.catalog = None; + table_ref.schema = Some(Identifier::new(parts[0])); + table_ref.name = Identifier::new(parts[1]); + } + _ => { + table_ref.catalog = Some(Identifier::new(parts[parts.len() - 3])); + table_ref.schema = Some(Identifier::new(parts[parts.len() - 2])); + table_ref.name = Identifier::new(parts[parts.len() - 1]); + } + } } #[cfg(test)] @@ -642,9 +1318,9 @@ mod tests { let sql = "SELECT orders.revenue, orders.status FROM orders"; let rewritten = rewriter.rewrite(sql).unwrap(); - assert!(rewritten.contains("public") || rewritten.contains("orders")); - assert!(rewritten.to_uppercase().contains("SUM(")); - assert!(rewritten.to_uppercase().contains("GROUP BY")); + assert!(rewritten.contains("public.orders")); + assert!(rewritten.contains("SUM(")); + assert!(rewritten.contains("GROUP BY")); } #[test] @@ -655,7 +1331,7 @@ mod tests { let sql = "SELECT o.revenue, o.status FROM orders AS o"; let rewritten = rewriter.rewrite(sql).unwrap(); - assert!(rewritten.contains("public") || rewritten.contains("orders")); + assert!(rewritten.contains("public.orders")); } #[test] @@ -666,7 +1342,7 @@ mod tests { let sql = "SELECT orders.revenue FROM orders WHERE orders.status = 'completed'"; let rewritten = rewriter.rewrite(sql).unwrap(); - assert!(rewritten.to_uppercase().contains("WHERE")); + assert!(rewritten.contains("WHERE")); assert!(rewritten.contains("status")); } @@ -690,6 +1366,34 @@ mod tests { ); } + #[test] + fn test_cross_model_composite_join() { + let mut graph = SemanticGraph::new(); + + let shipments = Model::new("shipments", "shipment_id") + .with_table("public.shipments") + .with_metric(Metric::count("shipment_count")) + .with_relationship(Relationship::many_to_one("order_items").with_key_columns( + vec!["order_id".to_string(), "item_id".to_string()], + vec!["order_id".to_string(), "item_id".to_string()], + )); + let order_items = Model::new("order_items", "order_id") + .with_primary_key_columns(vec!["order_id".to_string(), "item_id".to_string()]) + .with_table("public.order_items") + .with_dimension(Dimension::categorical("sku")); + + graph.add_model(shipments).unwrap(); + graph.add_model(order_items).unwrap(); + + let rewriter = QueryRewriter::new(&graph); + let sql = "SELECT shipments.shipment_count, order_items.sku FROM shipments"; + let rewritten = rewriter.rewrite(sql).unwrap(); + + assert!(rewritten.contains("shipments.order_id = o.order_id")); + assert!(rewritten.contains("shipments.item_id = o.item_id")); + assert!(rewritten.to_uppercase().contains(" AND ")); + } + #[test] fn test_cross_model_join_in_where() { let graph = create_test_graph(); @@ -733,12 +1437,83 @@ mod tests { // Should be COUNT(*) not COUNT(order_count) assert!( - rewritten.to_uppercase().contains("COUNT(*)"), + rewritten.contains("COUNT(*)"), "Expected COUNT(*) but got: {rewritten}" ); assert!( - !rewritten.contains("order_count"), - "Should not contain order_count in COUNT: {rewritten}" + !rewritten.contains("COUNT(order_count)"), + "Should not count order_count column directly: {rewritten}" + ); + } + + #[test] + fn test_wrap_simple_select_with_cte_preserves_composite_primary_keys() { + let mut graph = SemanticGraph::new(); + + let order_items = Model::new("order_items", "order_id") + .with_primary_key_columns(vec!["order_id".to_string(), "item_id".to_string()]) + .with_table("public.order_items") + .with_dimension(Dimension::categorical("sku")) + .with_metric(Metric::sum("item_revenue", "amount")); + + graph.add_model(order_items).unwrap(); + let rewriter = QueryRewriter::new(&graph); + + let sql = "SELECT order_items.sku, order_items.item_revenue FROM order_items"; + let rewritten = rewriter.rewrite(sql).unwrap(); + + assert!(rewritten.contains("order_items.order_id AS order_id")); + assert!(rewritten.contains("order_items.item_id AS item_id")); + assert!(rewritten.contains("GROUP BY 1")); + } + + #[test] + fn test_order_by_projected_semantic_refs_rewrite_to_output_aliases() { + let graph = create_test_graph(); + let rewriter = QueryRewriter::new(&graph); + + let sql = "SELECT orders.revenue AS total, orders.status FROM orders ORDER BY orders.revenue DESC, orders.status ASC"; + let rewritten = rewriter.rewrite(sql).unwrap(); + + assert!( + rewritten.contains("ORDER BY total DESC, status ASC"), + "expected ORDER BY aliases, got: {rewritten}" + ); + assert!( + !rewritten.contains("ORDER BY orders.revenue"), + "semantic metric reference leaked into ORDER BY: {rewritten}" + ); + } + + #[test] + fn test_order_by_projected_alias_is_preserved() { + let graph = create_test_graph(); + let rewriter = QueryRewriter::new(&graph); + + let sql = "SELECT orders.revenue AS total, orders.status FROM orders ORDER BY total DESC"; + let rewritten = rewriter.rewrite(sql).unwrap(); + + assert!( + rewritten.contains("ORDER BY total DESC"), + "expected projected alias ORDER BY, got: {rewritten}" + ); + } + + #[test] + fn test_order_by_order_only_semantic_metric_rewrites_to_aggregate() { + let graph = create_test_graph(); + let rewriter = QueryRewriter::new(&graph); + + let sql = "SELECT orders.status FROM orders ORDER BY orders.revenue DESC"; + let rewritten = rewriter.rewrite(sql).unwrap(); + + assert!( + rewritten.contains("ORDER BY SUM(orders.amount) DESC"), + "expected aggregate ORDER BY, got: {rewritten}" + ); + assert!( + !rewritten.contains("ORDER BY orders.revenue"), + "semantic metric reference leaked into ORDER BY: {rewritten}" ); } } diff --git a/sidemantic-rs/src/wasm.rs b/sidemantic-rs/src/wasm.rs new file mode 100644 index 00000000..a5f6e397 --- /dev/null +++ b/sidemantic-rs/src/wasm.rs @@ -0,0 +1,597 @@ +//! wasm-bindgen exports for pure Rust runtime orchestration. + +use wasm_bindgen::prelude::*; + +use crate::runtime::{ + analyze_migrator_query, build_preaggregation_refresh_statements, + build_symmetric_aggregate_sql as build_symmetric_aggregate_sql_runtime, + calculate_preaggregation_benefit_score, chart_auto_detect_columns, chart_encoding_type, + chart_format_label, chart_select_type, compile_with_yaml_query, detect_adapter_kind, + dimension_sql_expr_with_yaml, dimension_with_granularity_with_yaml, + evaluate_table_calculation_expression, extract_column_references, + extract_metric_dependencies_from_yaml, extract_preaggregation_patterns, find_models_for_query, + find_relationship_path_with_yaml, format_parameter_value_with_yaml, + generate_catalog_metadata_with_yaml, generate_preaggregation_definition, + generate_preaggregation_materialization_sql_with_yaml, generate_preaggregation_name, + generate_time_comparison_sql, interpolate_sql_with_parameters_with_yaml, is_relative_date, + is_sql_template, load_graph_with_sql, load_graph_with_yaml, metric_is_simple_aggregation, + metric_sql_expr, metric_to_sql, model_find_dimension_index_with_yaml, + model_find_metric_index_with_yaml, model_find_pre_aggregation_index_with_yaml, + model_find_segment_index_with_yaml, model_get_drill_down_with_yaml, + model_get_drill_up_with_yaml, model_get_hierarchy_path_with_yaml, + needs_symmetric_aggregate as needs_symmetric_aggregate_runtime, parse_reference_with_yaml, + parse_relative_date, parse_simple_metric_aggregation, parse_sql_definitions_payload, + parse_sql_graph_definitions_payload, parse_sql_model_payload, + parse_sql_statement_blocks_payload, recommend_preaggregation_patterns, + relationship_foreign_key_columns_with_yaml, relationship_primary_key_columns_with_yaml, + relationship_related_key_with_yaml, relationship_sql_expr_with_yaml, relative_date_to_range, + render_sql_template, resolve_metric_inheritance, resolve_model_inheritance_with_yaml, + rewrite_with_yaml, segment_get_sql_with_yaml, summarize_preaggregation_patterns, + time_comparison_offset_interval, time_comparison_sql_offset, trailing_period_sql_interval, + validate_engine_refresh_sql_compatibility, validate_metric_payload, validate_model_payload, + validate_models_yaml, validate_parameter_payload, validate_query_references_with_yaml, + validate_query_with_yaml, validate_table_calculation_payload, + validate_table_formula_expression, +}; + +fn wasm_error(err: impl std::fmt::Display) -> JsValue { + JsValue::from_str(&err.to_string()) +} + +#[wasm_bindgen] +pub fn wasm_compile_with_yaml_query(yaml: &str, query_yaml: &str) -> Result { + compile_with_yaml_query(yaml, query_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_build_symmetric_aggregate_sql( + measure_expr: &str, + primary_key: &str, + agg_type: &str, + model_alias: Option, + dialect: &str, +) -> Result { + build_symmetric_aggregate_sql_runtime( + measure_expr, + primary_key, + agg_type, + model_alias.as_deref(), + dialect, + ) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_needs_symmetric_aggregate(relationship_type: &str, is_base_model: bool) -> bool { + needs_symmetric_aggregate_runtime(relationship_type, is_base_model) +} + +#[wasm_bindgen] +pub fn wasm_rewrite_with_yaml(yaml: &str, sql: &str) -> Result { + rewrite_with_yaml(yaml, sql).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_query_with_yaml(yaml: &str, query_yaml: &str) -> Result { + let errors = validate_query_with_yaml(yaml, query_yaml).map_err(wasm_error)?; + serde_json::to_string(&errors).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_load_graph_with_yaml(yaml: &str) -> Result { + load_graph_with_yaml(yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_load_graph_with_sql(sql_content: &str) -> Result { + load_graph_with_sql(sql_content).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_query_references_with_yaml( + yaml: &str, + metrics_json: &str, + dimensions_json: &str, +) -> Result { + let metrics: Vec = serde_json::from_str(metrics_json).map_err(wasm_error)?; + let dimensions: Vec = serde_json::from_str(dimensions_json).map_err(wasm_error)?; + let errors = + validate_query_references_with_yaml(yaml, &metrics, &dimensions).map_err(wasm_error)?; + serde_json::to_string(&errors).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_query_references( + yaml: &str, + metrics_json: &str, + dimensions_json: &str, +) -> Result { + wasm_validate_query_references_with_yaml(yaml, metrics_json, dimensions_json) +} + +#[wasm_bindgen] +pub fn wasm_generate_catalog_metadata_with_yaml( + yaml: &str, + schema: &str, +) -> Result { + generate_catalog_metadata_with_yaml(yaml, schema).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_generate_preaggregation_materialization_sql_with_yaml( + yaml: &str, + model_name: &str, + preagg_name: &str, +) -> Result { + generate_preaggregation_materialization_sql_with_yaml(yaml, model_name, preagg_name) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_extract_column_references(sql_expr: &str) -> String { + serde_json::to_string(&extract_column_references(sql_expr)).unwrap_or_else(|_| "[]".to_string()) +} + +#[wasm_bindgen] +pub fn wasm_find_models_for_query( + dimensions_json: &str, + metrics_json: &str, +) -> Result { + let dimensions: Vec = serde_json::from_str(dimensions_json).map_err(wasm_error)?; + let metrics: Vec = serde_json::from_str(metrics_json).map_err(wasm_error)?; + let models = find_models_for_query(&dimensions, &metrics); + serde_json::to_string(&models).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_find_relationship_path_with_yaml( + yaml: &str, + from_model: &str, + to_model: &str, +) -> Result { + let path = find_relationship_path_with_yaml(yaml, from_model, to_model).map_err(wasm_error)?; + serde_json::to_string(&path).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_render_sql_template(template: &str, context_yaml: &str) -> Result { + render_sql_template(template, context_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_sql_definitions_payload(sql: &str) -> Result { + parse_sql_definitions_payload(sql).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_sql_graph_definitions_payload(sql: &str) -> Result { + parse_sql_graph_definitions_payload(sql).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_sql_model_payload(sql: &str) -> Result { + parse_sql_model_payload(sql).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_sql_statement_blocks_payload(sql: &str) -> Result { + parse_sql_statement_blocks_payload(sql).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_analyze_migrator_query(sql_query: &str) -> Result { + analyze_migrator_query(sql_query).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_chart_auto_detect_columns( + columns_json: &str, + numeric_flags_json: &str, +) -> Result { + let columns: Vec = serde_json::from_str(columns_json).map_err(wasm_error)?; + let numeric_flags: Vec = serde_json::from_str(numeric_flags_json).map_err(wasm_error)?; + let (x, y) = chart_auto_detect_columns(&columns, &numeric_flags).map_err(wasm_error)?; + serde_json::to_string(&serde_json::json!({ + "x": x, + "y": y, + })) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_chart_select_type(x: &str, x_value_kind: &str, y_count: usize) -> String { + chart_select_type(x, x_value_kind, y_count) +} + +#[wasm_bindgen] +pub fn wasm_chart_encoding_type(column: &str) -> String { + chart_encoding_type(column) +} + +#[wasm_bindgen] +pub fn wasm_chart_format_label(column: &str) -> String { + chart_format_label(column) +} + +#[wasm_bindgen] +pub fn wasm_validate_model_payload(model_yaml: &str) -> Result { + validate_model_payload(model_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_metric_payload(metric_yaml: &str) -> Result { + validate_metric_payload(metric_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_parameter_payload(parameter_yaml: &str) -> Result { + validate_parameter_payload(parameter_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_table_calculation_payload(calculation_yaml: &str) -> Result { + validate_table_calculation_payload(calculation_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_relative_date(expr: &str, dialect: &str) -> Result { + serde_json::to_string(&parse_relative_date(expr, dialect)).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_relative_date_to_range( + expr: &str, + column: &str, + dialect: &str, +) -> Result { + serde_json::to_string(&relative_date_to_range(expr, column, dialect)).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_is_relative_date(expr: &str) -> bool { + is_relative_date(expr) +} + +#[wasm_bindgen] +pub fn wasm_time_comparison_offset_interval( + comparison_type: &str, + offset: Option, + offset_unit: Option, +) -> Result { + let (amount, unit) = + time_comparison_offset_interval(comparison_type, offset, offset_unit.as_deref()) + .map_err(wasm_error)?; + serde_json::to_string(&serde_json::json!({ + "amount": amount, + "unit": unit, + })) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_time_comparison_sql_offset( + comparison_type: &str, + offset: Option, + offset_unit: Option, +) -> Result { + time_comparison_sql_offset(comparison_type, offset, offset_unit.as_deref()).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_trailing_period_sql_interval(amount: i64, unit: &str) -> Result { + trailing_period_sql_interval(amount, unit).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_generate_time_comparison_sql( + comparison_type: &str, + calculation: &str, + current_metric_sql: &str, + time_dimension: &str, + offset: Option, + offset_unit: Option, +) -> Result { + generate_time_comparison_sql( + comparison_type, + calculation, + current_metric_sql, + time_dimension, + offset, + offset_unit.as_deref(), + ) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_is_sql_template(sql: &str) -> bool { + is_sql_template(sql) +} + +#[wasm_bindgen] +pub fn wasm_format_parameter_value_with_yaml( + parameter_yaml: &str, + value_yaml: &str, +) -> Result { + format_parameter_value_with_yaml(parameter_yaml, value_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_interpolate_sql_with_parameters_with_yaml( + sql_template: &str, + parameters_yaml: &str, + values_yaml: &str, +) -> Result { + interpolate_sql_with_parameters_with_yaml(sql_template, parameters_yaml, values_yaml) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_detect_adapter_kind(path: &str, content: &str) -> Result { + serde_json::to_string(&detect_adapter_kind(path, content)).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_simple_metric_aggregation(sql_expr: &str) -> Result { + serde_json::to_string(&parse_simple_metric_aggregation(sql_expr)).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_metric_to_sql(metric_yaml: &str) -> Result { + metric_to_sql(metric_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_metric_sql_expr(metric_yaml: &str) -> Result { + metric_sql_expr(metric_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_metric_is_simple_aggregation(metric_yaml: &str) -> Result { + metric_is_simple_aggregation(metric_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_extract_preaggregation_patterns(queries_json: &str) -> Result { + let queries: Vec = serde_json::from_str(queries_json).map_err(wasm_error)?; + extract_preaggregation_patterns(queries).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_recommend_preaggregation_patterns( + patterns_json: &str, + min_count: usize, + min_benefit_score: f64, + max_recommendations: Option, +) -> Result { + recommend_preaggregation_patterns( + patterns_json, + min_count, + min_benefit_score, + max_recommendations, + ) + .map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_calculate_preaggregation_benefit_score( + pattern_json: &str, + count: usize, +) -> Result { + calculate_preaggregation_benefit_score(pattern_json, count).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_generate_preaggregation_name(pattern_json: &str) -> Result { + generate_preaggregation_name(pattern_json).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_generate_preaggregation_definition( + recommendation_json: &str, +) -> Result { + generate_preaggregation_definition(recommendation_json).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_summarize_preaggregation_patterns( + patterns_json: &str, + min_count: usize, +) -> Result { + summarize_preaggregation_patterns(patterns_json, min_count).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_models_yaml(yaml: &str) -> Result { + validate_models_yaml(yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_parse_reference_with_yaml(yaml: &str, reference: &str) -> Result { + let parsed = parse_reference_with_yaml(yaml, reference).map_err(wasm_error)?; + serde_json::to_string(&parsed).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_resolve_model_inheritance_with_yaml(yaml: &str) -> Result { + resolve_model_inheritance_with_yaml(yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_resolve_metric_inheritance(metrics_yaml: &str) -> Result { + resolve_metric_inheritance(metrics_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_dimension_sql_expr_with_yaml(dimension_yaml: &str) -> Result { + dimension_sql_expr_with_yaml(dimension_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_dimension_with_granularity_with_yaml( + dimension_yaml: &str, + granularity: &str, +) -> Result { + dimension_with_granularity_with_yaml(dimension_yaml, granularity).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_get_hierarchy_path_with_yaml( + model_yaml: &str, + dimension_name: &str, +) -> Result { + let path = + model_get_hierarchy_path_with_yaml(model_yaml, dimension_name).map_err(wasm_error)?; + serde_json::to_string(&path).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_get_drill_down_with_yaml( + model_yaml: &str, + dimension_name: &str, +) -> Result { + let drill = model_get_drill_down_with_yaml(model_yaml, dimension_name).map_err(wasm_error)?; + serde_json::to_string(&drill).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_get_drill_up_with_yaml( + model_yaml: &str, + dimension_name: &str, +) -> Result { + let drill = model_get_drill_up_with_yaml(model_yaml, dimension_name).map_err(wasm_error)?; + serde_json::to_string(&drill).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_find_dimension_index_with_yaml( + model_yaml: &str, + name: &str, +) -> Result { + let index = model_find_dimension_index_with_yaml(model_yaml, name).map_err(wasm_error)?; + serde_json::to_string(&index).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_find_metric_index_with_yaml( + model_yaml: &str, + name: &str, +) -> Result { + let index = model_find_metric_index_with_yaml(model_yaml, name).map_err(wasm_error)?; + serde_json::to_string(&index).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_find_segment_index_with_yaml( + model_yaml: &str, + name: &str, +) -> Result { + let index = model_find_segment_index_with_yaml(model_yaml, name).map_err(wasm_error)?; + serde_json::to_string(&index).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_model_find_pre_aggregation_index_with_yaml( + model_yaml: &str, + name: &str, +) -> Result { + let index = model_find_pre_aggregation_index_with_yaml(model_yaml, name).map_err(wasm_error)?; + serde_json::to_string(&index).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_relationship_sql_expr_with_yaml(relationship_yaml: &str) -> Result { + relationship_sql_expr_with_yaml(relationship_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_relationship_related_key_with_yaml(relationship_yaml: &str) -> Result { + relationship_related_key_with_yaml(relationship_yaml).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_relationship_foreign_key_columns_with_yaml( + relationship_yaml: &str, +) -> Result { + let cols = relationship_foreign_key_columns_with_yaml(relationship_yaml).map_err(wasm_error)?; + serde_json::to_string(&cols).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_relationship_primary_key_columns_with_yaml( + relationship_yaml: &str, +) -> Result { + let cols = relationship_primary_key_columns_with_yaml(relationship_yaml).map_err(wasm_error)?; + serde_json::to_string(&cols).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_segment_get_sql_with_yaml( + segment_yaml: &str, + model_alias: &str, +) -> Result { + segment_get_sql_with_yaml(segment_yaml, model_alias).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_extract_metric_dependencies_from_yaml( + metric_yaml: &str, + models_yaml: Option, + model_context: Option, +) -> Result { + let deps = extract_metric_dependencies_from_yaml( + metric_yaml, + models_yaml.as_deref(), + model_context.as_deref(), + ) + .map_err(wasm_error)?; + serde_json::to_string(&deps).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_evaluate_table_calculation_expression(expr: &str) -> Result { + evaluate_table_calculation_expression(expr).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_table_formula_expression(expression: &str) -> Result { + validate_table_formula_expression(expression).map_err(wasm_error) +} + +#[wasm_bindgen] +pub fn wasm_validate_engine_refresh_sql_compatibility( + source_sql: &str, + dialect: &str, +) -> Result { + let (is_valid, error) = validate_engine_refresh_sql_compatibility(source_sql, dialect); + serde_json::to_string(&serde_json::json!({ + "is_valid": is_valid, + "error": error, + })) + .map_err(wasm_error) +} + +#[wasm_bindgen] +#[allow(clippy::too_many_arguments)] +pub fn wasm_build_preaggregation_refresh_statements( + mode: &str, + table_name: &str, + source_sql: &str, + watermark_column: Option, + from_watermark: Option, + lookback: Option, + dialect: Option, + refresh_every: Option, +) -> Result { + let statements = build_preaggregation_refresh_statements( + mode, + table_name, + source_sql, + watermark_column.as_deref(), + from_watermark.as_deref(), + lookback.as_deref(), + dialect.as_deref(), + refresh_every.as_deref(), + ) + .map_err(wasm_error)?; + serde_json::to_string(&statements).map_err(wasm_error) +} diff --git a/sidemantic-rs/src/workbench.rs b/sidemantic-rs/src/workbench.rs new file mode 100644 index 00000000..7cb4550f --- /dev/null +++ b/sidemantic-rs/src/workbench.rs @@ -0,0 +1,1257 @@ +use std::io::{self, IsTerminal}; +use std::time::Duration; + +use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; +use crossterm::execute; +use crossterm::terminal::{ + disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen, +}; +use ratatui::backend::CrosstermBackend; +use ratatui::layout::{Constraint, Direction, Layout}; +use ratatui::style::{Color, Modifier, Style}; +use ratatui::text::Line; +use ratatui::widgets::{Block, Borders, Cell, List, ListItem, Paragraph, Row, Table, Wrap}; +use ratatui::Terminal; +use sidemantic::SidemanticRuntime; +#[cfg(feature = "workbench-adbc")] +use sidemantic::{execute_with_adbc, AdbcExecutionRequest, AdbcValue}; + +use crate::CliResult; + +const PREVIEW_ROW_LIMIT: usize = 25; +const PREVIEW_CELL_WIDTH: usize = 48; + +#[derive(Debug, Clone)] +struct ModelSummary { + name: String, + table: String, + dimensions: usize, + metrics: usize, + relationships: usize, + dimension_names: Vec, + metric_names: Vec, + relationship_names: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FocusPanel { + Models, + Sql, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OutputView { + Sql, + Table, + Chart, +} + +impl OutputView { + fn next(self) -> Self { + match self { + OutputView::Sql => OutputView::Table, + OutputView::Table => OutputView::Chart, + OutputView::Chart => OutputView::Sql, + } + } + + fn label(self) -> &'static str { + match self { + OutputView::Sql => "SQL", + OutputView::Table => "TABLE", + OutputView::Chart => "CHART", + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ChartRenderMode { + Bar, + Dot, +} + +impl ChartRenderMode { + fn next(self) -> Self { + match self { + ChartRenderMode::Bar => ChartRenderMode::Dot, + ChartRenderMode::Dot => ChartRenderMode::Bar, + } + } + + fn label(self) -> &'static str { + match self { + ChartRenderMode::Bar => "BAR", + ChartRenderMode::Dot => "DOT", + } + } +} + +#[derive(Debug, Clone)] +struct ExecutionPreview { + rewritten_sql: String, + columns: Vec, + rows: Vec>, +} + +#[cfg_attr(not(feature = "workbench-adbc"), allow(dead_code))] +#[derive(Debug, Clone, PartialEq)] +enum WorkbenchValue { + Null, + Bool(bool), + I64(i64), + U64(u64), + F64(f64), + String(String), + Bytes(Vec), +} + +#[cfg(feature = "workbench-adbc")] +impl From for WorkbenchValue { + fn from(value: AdbcValue) -> Self { + match value { + AdbcValue::Null => Self::Null, + AdbcValue::Bool(value) => Self::Bool(value), + AdbcValue::I64(value) => Self::I64(value), + AdbcValue::U64(value) => Self::U64(value), + AdbcValue::F64(value) => Self::F64(value), + AdbcValue::String(value) => Self::String(value), + AdbcValue::Bytes(value) => Self::Bytes(value), + } + } +} + +#[derive(Debug)] +struct WorkbenchApp { + runtime: SidemanticRuntime, + models: Vec, + selected_model_index: usize, + sql_input: String, + output: String, + output_view: OutputView, + execution_preview: Option, + status: String, + focus: FocusPanel, + should_quit: bool, + connection: Option, + chart_mode: ChartRenderMode, + chart_value_column: Option, + chart_label_column: Option, +} + +impl WorkbenchApp { + fn new(runtime: SidemanticRuntime, connection: Option) -> Self { + let mut models = runtime + .graph() + .models() + .map(|model| ModelSummary { + name: model.name.clone(), + table: model.table_name().to_string(), + dimensions: model.dimensions.len(), + metrics: model.metrics.len(), + relationships: model.relationships.len(), + dimension_names: model + .dimensions + .iter() + .map(|dimension| dimension.name.clone()) + .collect(), + metric_names: model + .metrics + .iter() + .map(|metric| metric.name.clone()) + .collect(), + relationship_names: model + .relationships + .iter() + .map(|relationship| relationship.name.clone()) + .collect(), + }) + .collect::>(); + models.sort_by(|left, right| left.name.cmp(&right.name)); + + let sql_input = models.first().map_or_else( + || "select 1".to_string(), + |model| format!("select * from {}", model.name), + ); + let status = if connection.is_some() { + "Ready. F5 rewrite, F6 execute, F7 view cycle, Ctrl+M/V/L chart controls, Tab switch focus, Esc quit." + .to_string() + } else { + "Ready. F5 rewrite. F6 execute requires --connection/--db. F7 view cycle, Ctrl+M/V/L chart controls, Tab switch focus, Esc quit." + .to_string() + }; + + let mut app = Self { + runtime, + models, + selected_model_index: 0, + sql_input, + output: String::new(), + output_view: OutputView::Sql, + execution_preview: None, + status, + focus: FocusPanel::Sql, + should_quit: false, + connection, + chart_mode: ChartRenderMode::Bar, + chart_value_column: None, + chart_label_column: None, + }; + app.run_rewrite(); + app + } + + fn run_rewrite(&mut self) { + match self.runtime.rewrite(&self.sql_input) { + Ok(sql) => { + self.output = sql; + self.status = "Rewrite ok".to_string(); + } + Err(err) => { + self.output = err.to_string(); + self.status = "Rewrite failed".to_string(); + } + } + } + + fn selected_model_name(&self) -> Option<&str> { + self.models + .get(self.selected_model_index) + .map(|model| model.name.as_str()) + } + + fn apply_model_template(&mut self) { + if let Some(model_name) = self.selected_model_name().map(str::to_string) { + self.sql_input = format!("select * from {model_name}"); + self.status = format!("Loaded template for model '{model_name}'"); + } + } + + fn run_execute(&mut self) { + let rewritten = match self.runtime.rewrite(&self.sql_input) { + Ok(sql) => sql, + Err(err) => { + self.execution_preview = None; + self.output = err.to_string(); + self.status = "Execute failed (rewrite)".to_string(); + return; + } + }; + + let Some(connection) = self.connection.as_deref() else { + self.execution_preview = None; + self.output = rewritten; + self.status = "Execute skipped: no connection configured".to_string(); + return; + }; + + #[cfg(not(feature = "workbench-adbc"))] + { + let _ = connection; + self.execution_preview = None; + self.output = rewritten; + self.status = "Execute unavailable: build without workbench-adbc".to_string(); + self.output.push_str( + "\n\nADBC execution support is not enabled. Rebuild with feature 'workbench-adbc' to run database-backed queries.", + ); + return; + } + + #[cfg(feature = "workbench-adbc")] + { + let (driver, uri, database_options) = + match crate::parse_connection_url_to_adbc(connection) { + Ok(payload) => payload, + Err(err) => { + self.execution_preview = None; + self.output = rewritten; + self.status = "Execute failed (connection)".to_string(); + self.output.push_str("\n\nConnection parsing failed:\n"); + self.output.push_str(&err); + return; + } + }; + + match execute_with_adbc(AdbcExecutionRequest { + driver, + sql: rewritten.clone(), + uri, + entrypoint: None, + database_options, + connection_options: Vec::new(), + }) { + Ok(result) => { + let rows = result + .rows + .into_iter() + .map(|row| { + row.into_iter() + .map(WorkbenchValue::from) + .collect::>() + }) + .collect::>(); + let row_count = rows.len(); + self.execution_preview = Some(ExecutionPreview { + rewritten_sql: rewritten.clone(), + columns: result.columns.clone(), + rows: rows.clone(), + }); + self.chart_value_column = None; + self.chart_label_column = None; + self.output = format_execution_output(&rewritten, &result.columns, &rows); + self.status = format!( + "Execute ok ({} row{})", + row_count, + if row_count == 1 { "" } else { "s" } + ); + } + Err(err) => { + self.execution_preview = None; + self.output = rewritten; + self.status = "Execute failed".to_string(); + self.output.push_str("\n\nExecution failed:\n"); + self.output.push_str(&err.to_string()); + } + } + } + } + + fn cycle_output_view(&mut self) { + self.output_view = self.output_view.next(); + self.status = format!("Output view: {}", self.output_view.label()); + } + + fn set_output_view(&mut self, view: OutputView) { + self.output_view = view; + self.status = format!("Output view: {}", self.output_view.label()); + } + + fn cycle_chart_mode(&mut self) { + self.chart_mode = self.chart_mode.next(); + self.output_view = OutputView::Chart; + self.status = format!("Chart mode: {}", self.chart_mode.label()); + } + + fn cycle_chart_value_column(&mut self) { + let Some(preview) = self.execution_preview.as_ref() else { + self.status = "Chart value column unavailable: run execute first".to_string(); + return; + }; + + let numeric_candidates = chart_numeric_column_indices(preview); + if numeric_candidates.is_empty() { + self.status = "Chart value column unavailable: no numeric columns".to_string(); + return; + } + + let current_position = self.chart_value_column.and_then(|current| { + numeric_candidates + .iter() + .position(|index| *index == current) + }); + let next_position = match current_position { + Some(position) => (position + 1) % numeric_candidates.len(), + None => 0, + }; + let next_index = numeric_candidates[next_position]; + self.chart_value_column = Some(next_index); + if self.chart_label_column == Some(next_index) { + self.chart_label_column = None; + } + self.output_view = OutputView::Chart; + self.status = format!("Chart value column: {}", preview.columns[next_index]); + } + + fn cycle_chart_label_column(&mut self) { + let Some(preview) = self.execution_preview.as_ref() else { + self.status = "Chart label column unavailable: run execute first".to_string(); + return; + }; + if preview.columns.is_empty() { + self.status = "Chart label column unavailable: no columns".to_string(); + return; + } + + let value_index = self + .chart_value_column + .and_then(|index| preview.columns.get(index).map(|_| index)) + .or_else(|| first_numeric_column_index(preview)) + .unwrap_or(0); + + let mut label_candidates = (0..preview.columns.len()).collect::>(); + if label_candidates.len() > 1 { + label_candidates.retain(|index| *index != value_index); + } + if label_candidates.is_empty() { + self.status = "Chart label column unavailable: no label candidates".to_string(); + return; + } + + let default_label = default_chart_label_index(preview, value_index); + let current_label = self + .chart_label_column + .and_then(|index| preview.columns.get(index).map(|_| index)) + .unwrap_or(default_label); + let current_position = label_candidates + .iter() + .position(|index| *index == current_label) + .unwrap_or(0); + let next_position = (current_position + 1) % label_candidates.len(); + let next_index = label_candidates[next_position]; + self.chart_label_column = Some(next_index); + self.output_view = OutputView::Chart; + self.status = format!("Chart label column: {}", preview.columns[next_index]); + } + + fn next_focus(&mut self) { + self.focus = match self.focus { + FocusPanel::Models => FocusPanel::Sql, + FocusPanel::Sql => FocusPanel::Models, + }; + } + + fn handle_models_key(&mut self, key: KeyEvent) { + match key.code { + KeyCode::Down if !self.models.is_empty() => { + self.selected_model_index = + (self.selected_model_index + 1).min(self.models.len() - 1); + } + KeyCode::Up if self.selected_model_index > 0 => { + self.selected_model_index -= 1; + } + KeyCode::Enter => self.apply_model_template(), + _ => {} + } + } + + fn handle_sql_key(&mut self, key: KeyEvent) { + match key.code { + KeyCode::Char(ch) => { + if key.modifiers.contains(KeyModifiers::CONTROL) { + return; + } + self.sql_input.push(ch); + } + KeyCode::Backspace => { + self.sql_input.pop(); + } + KeyCode::Enter => { + self.sql_input.push('\n'); + } + KeyCode::Tab => self.next_focus(), + _ => {} + } + } + + fn handle_key(&mut self, key: KeyEvent) { + if key.kind != KeyEventKind::Press { + return; + } + + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('c') { + self.should_quit = true; + return; + } + + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('r') { + self.run_rewrite(); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('e') { + self.run_execute(); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('1') { + self.set_output_view(OutputView::Sql); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('2') { + self.set_output_view(OutputView::Table); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('3') { + self.set_output_view(OutputView::Chart); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('m') { + self.cycle_chart_mode(); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('v') { + self.cycle_chart_value_column(); + return; + } + if key.modifiers.contains(KeyModifiers::CONTROL) && key.code == KeyCode::Char('l') { + self.cycle_chart_label_column(); + return; + } + + match key.code { + KeyCode::Esc => self.should_quit = true, + KeyCode::F(5) => self.run_rewrite(), + KeyCode::F(6) => self.run_execute(), + KeyCode::F(7) => self.cycle_output_view(), + KeyCode::Tab => self.next_focus(), + _ => match self.focus { + FocusPanel::Models => self.handle_models_key(key), + FocusPanel::Sql => self.handle_sql_key(key), + }, + } + } +} + +fn selected_model_details(model: Option<&ModelSummary>) -> String { + let Some(model) = model else { + return "No model selected".to_string(); + }; + let mut details = String::new(); + details.push_str(&format!("table: {}\n", model.table)); + details.push_str(&format!( + "dimensions ({}): {}\n", + model.dimensions, + summarize_name_list(&model.dimension_names, 8) + )); + details.push_str(&format!( + "metrics ({}): {}\n", + model.metrics, + summarize_name_list(&model.metric_names, 8) + )); + details.push_str(&format!( + "relationships ({}): {}\n", + model.relationships, + summarize_name_list(&model.relationship_names, 6) + )); + details +} + +fn summarize_name_list(names: &[String], limit: usize) -> String { + if names.is_empty() { + return "-".to_string(); + } + let shown = names.iter().take(limit).cloned().collect::>(); + if names.len() > limit { + format!("{}, +{} more", shown.join(", "), names.len() - limit) + } else { + shown.join(", ") + } +} + +fn workbench_value_as_f64(value: &WorkbenchValue) -> Option { + match value { + WorkbenchValue::Bool(value) => Some(if *value { 1.0 } else { 0.0 }), + WorkbenchValue::I64(value) => Some(*value as f64), + WorkbenchValue::U64(value) => Some(*value as f64), + WorkbenchValue::F64(value) => Some(*value), + WorkbenchValue::Null | WorkbenchValue::String(_) | WorkbenchValue::Bytes(_) => None, + } +} + +fn chart_numeric_column_indices(preview: &ExecutionPreview) -> Vec { + preview + .columns + .iter() + .enumerate() + .filter_map(|(index, _)| { + preview + .rows + .iter() + .take(PREVIEW_ROW_LIMIT) + .find_map(|row| row.get(index)) + .and_then(workbench_value_as_f64) + .map(|_| index) + }) + .collect() +} + +fn first_numeric_column_index(preview: &ExecutionPreview) -> Option { + chart_numeric_column_indices(preview).into_iter().next() +} + +fn default_chart_label_index(preview: &ExecutionPreview, value_index: usize) -> usize { + if value_index == 0 && preview.columns.len() > 1 { + 1 + } else { + 0 + } +} + +fn build_chart_lines( + preview: Option<&ExecutionPreview>, + chart_mode: ChartRenderMode, + value_column_override: Option, + label_column_override: Option, +) -> Vec { + let Some(preview) = preview else { + return vec!["No execution results yet. Press F6 to run query.".to_string()]; + }; + if preview.rows.is_empty() || preview.columns.is_empty() { + return vec!["Execution returned no rows.".to_string()]; + } + + let numeric_candidates = chart_numeric_column_indices(preview); + let Some(default_value_index) = numeric_candidates.first().copied() else { + return vec!["Chart unavailable: no numeric columns in execution result.".to_string()]; + }; + let value_index = value_column_override + .filter(|index| numeric_candidates.contains(index)) + .unwrap_or(default_value_index); + let default_label_index = default_chart_label_index(preview, value_index); + let label_index = label_column_override + .filter(|index| preview.columns.get(*index).is_some()) + .unwrap_or(default_label_index); + + let mut points = Vec::new(); + for (row_index, row) in preview.rows.iter().take(12).enumerate() { + let Some(value) = row.get(value_index).and_then(workbench_value_as_f64) else { + continue; + }; + let label = row + .get(label_index) + .map(format_workbench_value) + .filter(|value| !value.trim().is_empty() && value != "NULL") + .unwrap_or_else(|| format!("row {}", row_index + 1)); + points.push((label, value)); + } + if points.is_empty() { + return vec!["Chart unavailable: no numeric rows in preview subset.".to_string()]; + } + + let max_abs = points + .iter() + .map(|(_, value)| value.abs()) + .fold(0.0_f64, f64::max) + .max(1.0); + let min_value = points + .iter() + .map(|(_, value)| *value) + .fold(f64::INFINITY, f64::min); + let max_value = points + .iter() + .map(|(_, value)| *value) + .fold(f64::NEG_INFINITY, f64::max); + let value_span = (max_value - min_value).abs().max(1e-9); + + let mut lines = vec![ + format!( + "chart source: value={} label={} mode={}", + preview.columns[value_index], + preview.columns[label_index], + chart_mode.label() + ), + "chart controls: Ctrl+M mode, Ctrl+V value column, Ctrl+L label column".to_string(), + ]; + for (label, value) in points { + match chart_mode { + ChartRenderMode::Bar => { + let bar_units = ((value.abs() / max_abs) * 24.0).round().max(1.0) as usize; + let bar_symbol = if value < 0.0 { '-' } else { '#' }; + let bar = std::iter::repeat_n(bar_symbol, bar_units).collect::(); + lines.push(format!("{label:>20} | {bar:<24} {value:.3}")); + } + ChartRenderMode::Dot => { + let marker_position = (((value - min_value) / value_span) * 23.0).round() as usize; + let mut markers = [' '; 24]; + markers[marker_position.min(23)] = 'o'; + let track = markers.iter().collect::(); + lines.push(format!("{label:>20} | {track} {value:.3}")); + } + } + } + lines +} + +#[cfg_attr(not(feature = "workbench-adbc"), allow(dead_code))] +fn format_execution_output( + rewritten: &str, + columns: &[String], + rows: &[Vec], +) -> String { + let shown = rows.len().min(PREVIEW_ROW_LIMIT); + let mut output = String::new(); + output.push_str("Rendered SQL:\n"); + output.push_str(rewritten); + output.push_str("\n\nResult preview:\n"); + output.push_str(&format!( + "rows={} shown={} limit={}\n", + rows.len(), + shown, + PREVIEW_ROW_LIMIT + )); + + if columns.is_empty() { + output.push_str("(no columns)\n"); + return output; + } + + output.push_str(&columns.join(" | ")); + output.push('\n'); + output.push_str( + &columns + .iter() + .map(|_| "--------") + .collect::>() + .join("-+-"), + ); + output.push('\n'); + + for row in rows.iter().take(shown) { + let rendered_row = columns + .iter() + .enumerate() + .map(|(index, _)| { + let value = row.get(index).unwrap_or(&WorkbenchValue::Null); + format_workbench_value(value) + }) + .collect::>(); + output.push_str(&rendered_row.join(" | ")); + output.push('\n'); + } + + if rows.len() > shown { + output.push('\n'); + output.push_str(&format!("... {} more rows", rows.len() - shown)); + } + + output +} + +fn format_workbench_value(value: &WorkbenchValue) -> String { + let raw = match value { + WorkbenchValue::Null => "NULL".to_string(), + WorkbenchValue::Bool(value) => value.to_string(), + WorkbenchValue::I64(value) => value.to_string(), + WorkbenchValue::U64(value) => value.to_string(), + WorkbenchValue::F64(value) => value.to_string(), + WorkbenchValue::String(value) => value.clone(), + WorkbenchValue::Bytes(value) => { + let hex = value + .iter() + .take(16) + .map(|byte| format!("{byte:02x}")) + .collect::(); + if value.len() > 16 { + format!("0x{hex}...") + } else { + format!("0x{hex}") + } + } + }; + let normalized = raw.replace('\n', " "); + if normalized.chars().count() <= PREVIEW_CELL_WIDTH { + normalized + } else { + let prefix = normalized + .chars() + .take(PREVIEW_CELL_WIDTH.saturating_sub(3)) + .collect::(); + format!("{prefix}...") + } +} + +fn draw_app(frame: &mut ratatui::Frame<'_>, app: &WorkbenchApp) { + let layout = Layout::default() + .direction(Direction::Vertical) + .constraints([ + Constraint::Length(3), + Constraint::Min(10), + Constraint::Length(3), + ]) + .split(frame.area()); + + let connection_text = app + .connection + .as_deref() + .map_or("connection=none".to_string(), |value| { + format!("connection={value}") + }); + let selected_text = app + .selected_model_name() + .map_or("selected=none".to_string(), |name| { + format!("selected={name}") + }); + let header = Paragraph::new(format!( + "Sidemantic Workbench (ratatui) | {selected_text} | {connection_text}" + )) + .block(Block::default().borders(Borders::ALL)); + frame.render_widget(header, layout[0]); + + let body = Layout::default() + .direction(Direction::Horizontal) + .constraints([Constraint::Percentage(35), Constraint::Percentage(65)]) + .split(layout[1]); + + let model_items = app + .models + .iter() + .map(|model| { + ListItem::new(Line::from(format!( + "{} [table={} dims={} metrics={} rels={}]", + model.name, model.table, model.dimensions, model.metrics, model.relationships + ))) + }) + .collect::>(); + let left = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Percentage(52), Constraint::Percentage(48)]) + .split(body[0]); + let model_title = if app.focus == FocusPanel::Models { + "Models [focus]" + } else { + "Models" + }; + let model_list = + List::new(model_items).block(Block::default().title(model_title).borders(Borders::ALL)); + let mut model_state = ratatui::widgets::ListState::default(); + if !app.models.is_empty() { + model_state.select(Some(app.selected_model_index)); + } + frame.render_stateful_widget(model_list, left[0], &mut model_state); + + let details = selected_model_details(app.models.get(app.selected_model_index)); + let details_panel = Paragraph::new(details) + .block( + Block::default() + .title("Model Details (Enter loads template)") + .borders(Borders::ALL), + ) + .wrap(Wrap { trim: false }); + frame.render_widget(details_panel, left[1]); + + let right = Layout::default() + .direction(Direction::Vertical) + .constraints([Constraint::Percentage(45), Constraint::Percentage(55)]) + .split(body[1]); + + let sql_style = if app.focus == FocusPanel::Sql { + Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::BOLD) + } else { + Style::default() + }; + let sql_editor = Paragraph::new(app.sql_input.as_str()) + .block( + Block::default() + .title( + "SQL Input [Tab switch, F5 rewrite, F6 execute, F7 view cycle, Ctrl+M/V/L chart config]", + ) + .borders(Borders::ALL) + .border_style(sql_style), + ) + .wrap(Wrap { trim: false }); + frame.render_widget(sql_editor, right[0]); + + let output_title = format!( + "Output [{}] (Ctrl+1 SQL / Ctrl+2 TABLE / Ctrl+3 CHART)", + app.output_view.label() + ); + match app.output_view { + OutputView::Sql => { + let output = Paragraph::new(app.output.as_str()) + .block(Block::default().title(output_title).borders(Borders::ALL)) + .wrap(Wrap { trim: false }); + frame.render_widget(output, right[1]); + } + OutputView::Table => { + if let Some(preview) = app.execution_preview.as_ref() { + let row_count = preview.rows.len(); + let shown_count = row_count.min(PREVIEW_ROW_LIMIT); + if preview.columns.is_empty() { + let output = Paragraph::new("Execution returned no columns.") + .block(Block::default().title(output_title).borders(Borders::ALL)) + .wrap(Wrap { trim: false }); + frame.render_widget(output, right[1]); + } else { + let header = Row::new(preview.columns.iter().map(|column| { + Cell::from(column.clone()).style( + Style::default() + .fg(Color::Yellow) + .add_modifier(Modifier::BOLD), + ) + })); + let rows = preview + .rows + .iter() + .take(PREVIEW_ROW_LIMIT) + .map(|row| { + Row::new(preview.columns.iter().enumerate().map(|(index, _)| { + let value = row.get(index).unwrap_or(&WorkbenchValue::Null); + Cell::from(format_workbench_value(value)) + })) + }) + .collect::>(); + let widths = preview + .columns + .iter() + .map(|_| Constraint::Min(12)) + .collect::>(); + let table = Table::new(rows, widths) + .header(header) + .column_spacing(1) + .block( + Block::default() + .title(format!( + "{output_title} rows={row_count} shown={shown_count} sql={}", + preview.rewritten_sql.lines().next().unwrap_or("") + )) + .borders(Borders::ALL), + ); + frame.render_widget(table, right[1]); + } + } else { + let output = Paragraph::new("No execution results. Press F6 to execute query.") + .block(Block::default().title(output_title).borders(Borders::ALL)) + .wrap(Wrap { trim: false }); + frame.render_widget(output, right[1]); + } + } + OutputView::Chart => { + let chart_lines = build_chart_lines( + app.execution_preview.as_ref(), + app.chart_mode, + app.chart_value_column, + app.chart_label_column, + ); + let output = Paragraph::new(chart_lines.join("\n")) + .block(Block::default().title(output_title).borders(Borders::ALL)) + .wrap(Wrap { trim: false }); + frame.render_widget(output, right[1]); + } + } + + let footer = Paragraph::new(app.status.as_str()).block(Block::default().borders(Borders::ALL)); + frame.render_widget(footer, layout[2]); +} + +pub fn launch(models_path: &str, connection: Option) -> CliResult<()> { + if !io::stdin().is_terminal() || !io::stdout().is_terminal() { + return Err("workbench requires an interactive terminal (TTY)".to_string()); + } + + let runtime = super::load_runtime(models_path)?; + let mut app = WorkbenchApp::new(runtime, connection); + + enable_raw_mode().map_err(|e| format!("failed to enable raw mode: {e}"))?; + let mut stdout = io::stdout(); + execute!(stdout, EnterAlternateScreen) + .map_err(|e| format!("failed to enter alternate screen: {e}"))?; + + let backend = CrosstermBackend::new(stdout); + let mut terminal = + Terminal::new(backend).map_err(|e| format!("failed to initialize terminal: {e}"))?; + + let mut loop_result: CliResult<()> = Ok(()); + while !app.should_quit { + if let Err(err) = terminal.draw(|frame| draw_app(frame, &app)) { + loop_result = Err(format!("failed to draw workbench UI: {err}")); + break; + } + + match event::poll(Duration::from_millis(100)) { + Ok(true) => match event::read() { + Ok(Event::Key(key)) => app.handle_key(key), + Ok(_) => {} + Err(err) => { + loop_result = Err(format!("failed to read terminal event: {err}")); + break; + } + }, + Ok(false) => {} + Err(err) => { + loop_result = Err(format!("failed to poll terminal events: {err}")); + break; + } + } + } + + let mut restore_error = None; + if let Err(err) = disable_raw_mode() { + restore_error = Some(format!("failed to disable raw mode: {err}")); + } + if let Err(err) = execute!(terminal.backend_mut(), LeaveAlternateScreen) { + restore_error = Some(format!("failed to leave alternate screen: {err}")); + } + if let Err(err) = terminal.show_cursor() { + restore_error = Some(format!("failed to restore cursor: {err}")); + } + + if let Some(err) = restore_error { + return Err(err); + } + + loop_result +} + +#[cfg(test)] +mod tests { + use super::*; + use ratatui::backend::TestBackend; + + #[test] + fn format_workbench_value_truncates_and_normalizes_newlines() { + let value = WorkbenchValue::String( + "line1\nline2 this is a very long trailing value that should exceed the preview cell width" + .to_string(), + ); + let rendered = format_workbench_value(&value); + assert!(!rendered.contains('\n')); + assert!(rendered.contains("line1 line2")); + assert!(rendered.ends_with("...")); + } + + #[test] + fn format_execution_output_includes_row_counts_and_preview() { + let columns = vec!["id".to_string(), "name".to_string()]; + let rows = vec![ + vec![ + WorkbenchValue::I64(1), + WorkbenchValue::String("alice".to_string()), + ], + vec![ + WorkbenchValue::I64(2), + WorkbenchValue::String("bob".to_string()), + ], + ]; + let rendered = format_execution_output("select * from users", &columns, &rows); + assert!(rendered.contains("Rendered SQL:")); + assert!(rendered.contains("rows=2 shown=2")); + assert!(rendered.contains("id | name")); + assert!(rendered.contains("alice")); + assert!(rendered.contains("bob")); + } + + #[test] + fn output_view_cycles_in_expected_order() { + assert_eq!(OutputView::Sql.next(), OutputView::Table); + assert_eq!(OutputView::Table.next(), OutputView::Chart); + assert_eq!(OutputView::Chart.next(), OutputView::Sql); + } + + #[test] + fn build_chart_lines_uses_numeric_column_when_available() { + let preview = ExecutionPreview { + rewritten_sql: "select name, amount from t".to_string(), + columns: vec!["name".to_string(), "amount".to_string()], + rows: vec![ + vec![ + WorkbenchValue::String("alpha".to_string()), + WorkbenchValue::F64(10.0), + ], + vec![ + WorkbenchValue::String("beta".to_string()), + WorkbenchValue::F64(20.0), + ], + ], + }; + let lines = build_chart_lines(Some(&preview), ChartRenderMode::Bar, None, None); + assert!(lines + .first() + .is_some_and(|line| line.contains("chart source"))); + assert!(lines.iter().any(|line| line.contains("alpha"))); + assert!(lines.iter().any(|line| line.contains("beta"))); + } + + #[test] + fn build_chart_lines_reports_missing_numeric_data() { + let preview = ExecutionPreview { + rewritten_sql: "select name from t".to_string(), + columns: vec!["name".to_string()], + rows: vec![vec![WorkbenchValue::String("alpha".to_string())]], + }; + let lines = build_chart_lines(Some(&preview), ChartRenderMode::Bar, None, None); + assert_eq!( + lines, + vec!["Chart unavailable: no numeric columns in execution result.".to_string()] + ); + } + + #[test] + fn build_chart_lines_honors_column_overrides() { + let preview = ExecutionPreview { + rewritten_sql: "select label, amount_a, amount_b from t".to_string(), + columns: vec![ + "label".to_string(), + "amount_a".to_string(), + "amount_b".to_string(), + ], + rows: vec![ + vec![ + WorkbenchValue::String("first".to_string()), + WorkbenchValue::F64(10.0), + WorkbenchValue::F64(3.0), + ], + vec![ + WorkbenchValue::String("second".to_string()), + WorkbenchValue::F64(20.0), + WorkbenchValue::F64(7.0), + ], + ], + }; + + let lines = build_chart_lines(Some(&preview), ChartRenderMode::Bar, Some(2), Some(0)); + assert!(lines + .first() + .is_some_and(|line| line.contains("value=amount_b label=label mode=BAR"))); + } + + #[test] + fn build_chart_lines_dot_mode_renders_dot_track() { + let preview = ExecutionPreview { + rewritten_sql: "select name, amount from t".to_string(), + columns: vec!["name".to_string(), "amount".to_string()], + rows: vec![ + vec![ + WorkbenchValue::String("alpha".to_string()), + WorkbenchValue::F64(10.0), + ], + vec![ + WorkbenchValue::String("beta".to_string()), + WorkbenchValue::F64(20.0), + ], + ], + }; + + let lines = build_chart_lines(Some(&preview), ChartRenderMode::Dot, None, None); + assert!(lines.first().is_some_and(|line| line.contains("mode=DOT"))); + assert!(lines.iter().any(|line| line.contains("o"))); + } + + fn fixture_runtime() -> SidemanticRuntime { + SidemanticRuntime::from_yaml( + r#" +models: + - name: z_orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +"#, + ) + .expect("fixture runtime should load") + } + + fn key(code: KeyCode) -> KeyEvent { + KeyEvent::new(code, KeyModifiers::NONE) + } + + fn ctrl(ch: char) -> KeyEvent { + KeyEvent::new(KeyCode::Char(ch), KeyModifiers::CONTROL) + } + + #[test] + fn workbench_app_startup_sorts_models_and_rewrites_initial_query() { + let app = WorkbenchApp::new(fixture_runtime(), None); + assert_eq!(app.models[0].name, "customers"); + assert_eq!(app.models[1].name, "z_orders"); + assert_eq!(app.sql_input, "select * from customers"); + assert_eq!(app.output_view, OutputView::Sql); + assert_eq!(app.focus, FocusPanel::Sql); + assert_eq!(app.status, "Rewrite ok"); + assert!(!app.output.trim().is_empty()); + + let with_connection = + WorkbenchApp::new(fixture_runtime(), Some("duckdb:///tmp/demo.db".to_string())); + assert_eq!(with_connection.status, "Rewrite ok"); + } + + #[test] + fn workbench_key_handling_updates_state_deterministically() { + let mut app = WorkbenchApp::new(fixture_runtime(), None); + + app.handle_key(key(KeyCode::Char('x'))); + assert!(app.sql_input.ends_with('x')); + app.handle_key(key(KeyCode::Backspace)); + assert!(!app.sql_input.ends_with('x')); + app.handle_key(key(KeyCode::Enter)); + assert!(app.sql_input.ends_with('\n')); + app.handle_key(ctrl('z')); + assert!(!app.sql_input.ends_with('z')); + + app.handle_key(key(KeyCode::Tab)); + assert_eq!(app.focus, FocusPanel::Models); + app.handle_key(key(KeyCode::Down)); + assert_eq!(app.selected_model_index, 1); + app.handle_key(key(KeyCode::Down)); + assert_eq!(app.selected_model_index, 1); + app.handle_key(key(KeyCode::Up)); + assert_eq!(app.selected_model_index, 0); + app.handle_key(key(KeyCode::Enter)); + assert_eq!(app.sql_input, "select * from customers"); + + app.handle_key(key(KeyCode::F(5))); + assert_eq!(app.status, "Rewrite ok"); + app.handle_key(key(KeyCode::F(6))); + assert_eq!(app.status, "Execute skipped: no connection configured"); + app.handle_key(key(KeyCode::F(7))); + assert_eq!(app.output_view, OutputView::Table); + app.handle_key(ctrl('3')); + assert_eq!(app.output_view, OutputView::Chart); + app.handle_key(ctrl('m')); + assert_eq!(app.chart_mode, ChartRenderMode::Dot); + app.handle_key(key(KeyCode::Esc)); + assert!(app.should_quit); + } + + #[test] + fn draw_app_renders_main_panels_and_small_viewports() { + let mut app = WorkbenchApp::new(fixture_runtime(), None); + let backend = TestBackend::new(100, 32); + let mut terminal = Terminal::new(backend).expect("test terminal should initialize"); + terminal + .draw(|frame| draw_app(frame, &app)) + .expect("workbench should draw"); + let rendered = terminal.backend().to_string(); + for expected in [ + "Sidemantic Workbench", + "Models", + "Model Details", + "SQL Input", + "Output [SQL]", + "connection=none", + ] { + assert!( + rendered.contains(expected), + "missing {expected}\n{rendered}" + ); + } + + app.execution_preview = Some(ExecutionPreview { + rewritten_sql: "select country, count(*) from customers group by 1".to_string(), + columns: vec!["country".to_string(), "count".to_string()], + rows: vec![vec![ + WorkbenchValue::String("US".to_string()), + WorkbenchValue::I64(3), + ]], + }); + app.output_view = OutputView::Table; + terminal + .draw(|frame| draw_app(frame, &app)) + .expect("table view should draw"); + let rendered = terminal.backend().to_string(); + assert!(rendered.contains("country")); + assert!(rendered.contains("US")); + + app.output_view = OutputView::Chart; + terminal + .draw(|frame| draw_app(frame, &app)) + .expect("chart view should draw"); + let rendered = terminal.backend().to_string(); + assert!(rendered.contains("chart source")); + + let backend = TestBackend::new(60, 20); + let mut small_terminal = Terminal::new(backend).expect("small terminal should initialize"); + small_terminal + .draw(|frame| draw_app(frame, &app)) + .expect("small viewport should draw"); + } +} diff --git a/sidemantic-rs/tests/adbc_driver_matrix.rs b/sidemantic-rs/tests/adbc_driver_matrix.rs new file mode 100644 index 00000000..b5fac556 --- /dev/null +++ b/sidemantic-rs/tests/adbc_driver_matrix.rs @@ -0,0 +1,198 @@ +#![cfg(feature = "adbc-exec")] + +use adbc_core::options::{OptionConnection, OptionDatabase, OptionValue}; +use sidemantic::{execute_with_adbc, AdbcExecutionRequest}; + +struct DriverProbe { + name: &'static str, + env_prefix: &'static str, + default_entrypoint: Option<&'static str>, + default_dbopts: &'static [(&'static str, &'static str)], +} + +const DRIVER_PROBES: &[DriverProbe] = &[ + DriverProbe { + name: "duckdb", + env_prefix: "SIDEMANTIC_TEST_ADBC_DUCKDB", + default_entrypoint: Some("duckdb_adbc_init"), + default_dbopts: &[("path", ":memory:")], + }, + DriverProbe { + name: "sqlite", + env_prefix: "SIDEMANTIC_TEST_ADBC_SQLITE", + default_entrypoint: None, + default_dbopts: &[], + }, + DriverProbe { + name: "postgres", + env_prefix: "SIDEMANTIC_TEST_ADBC_POSTGRES", + default_entrypoint: None, + default_dbopts: &[], + }, + DriverProbe { + name: "bigquery", + env_prefix: "SIDEMANTIC_TEST_ADBC_BIGQUERY", + default_entrypoint: None, + default_dbopts: &[], + }, + DriverProbe { + name: "snowflake", + env_prefix: "SIDEMANTIC_TEST_ADBC_SNOWFLAKE", + default_entrypoint: None, + default_dbopts: &[], + }, + DriverProbe { + name: "clickhouse", + env_prefix: "SIDEMANTIC_TEST_ADBC_CLICKHOUSE", + default_entrypoint: None, + default_dbopts: &[], + }, +]; + +fn env_value(name: &str) -> Option { + std::env::var(name) + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) +} + +fn parse_option_value(value: &str) -> OptionValue { + if let Some(rest) = value.strip_prefix("int:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Int(parsed); + } + } + if let Some(rest) = value.strip_prefix("float:") { + if let Ok(parsed) = rest.parse::() { + return OptionValue::Double(parsed); + } + } + if let Some(rest) = value.strip_prefix("str:") { + return OptionValue::String(rest.to_string()); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Int(parsed); + } + if let Ok(parsed) = value.parse::() { + return OptionValue::Double(parsed); + } + OptionValue::String(value.to_string()) +} + +fn parse_database_options(prefix: &str) -> Vec<(OptionDatabase, OptionValue)> { + let mut options = Vec::new(); + if let Some(raw) = env_value(&format!("{prefix}_DBOPTS")) { + for pair in raw + .split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + { + let Some((key, value)) = pair.split_once('=') else { + panic!("{prefix}_DBOPTS expects comma-separated key=value pairs, got {pair}"); + }; + options.push((OptionDatabase::from(key.trim()), parse_option_value(value))); + } + } + options +} + +fn parse_connection_options(prefix: &str) -> Vec<(OptionConnection, OptionValue)> { + let mut options = Vec::new(); + if let Some(raw) = env_value(&format!("{prefix}_CONNOPTS")) { + for pair in raw + .split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + { + let Some((key, value)) = pair.split_once('=') else { + panic!("{prefix}_CONNOPTS expects comma-separated key=value pairs, got {pair}"); + }; + options.push(( + OptionConnection::from(key.trim()), + parse_option_value(value), + )); + } + } + options +} + +fn default_database_options(probe: &DriverProbe) -> Vec<(OptionDatabase, OptionValue)> { + probe + .default_dbopts + .iter() + .map(|(key, value)| (OptionDatabase::from(*key), parse_option_value(value))) + .collect() +} + +fn required_drivers() -> Vec { + env_value("SIDEMANTIC_TEST_ADBC_REQUIRE") + .map(|raw| { + raw.split(',') + .map(str::trim) + .filter(|item| !item.is_empty()) + .map(str::to_string) + .collect() + }) + .unwrap_or_default() +} + +#[test] +fn rust_adbc_driver_matrix_executes_configured_drivers() { + let required = required_drivers(); + let mut executed = Vec::new(); + + for probe in DRIVER_PROBES { + let driver_var = format!("{}_DRIVER", probe.env_prefix); + let Some(driver) = env_value(&driver_var) else { + if required.iter().any(|name| name == probe.name) { + panic!( + "{} ADBC probe is required but {} is not set", + probe.name, driver_var + ); + } + eprintln!( + "skipping {} ADBC probe; {} is not set", + probe.name, driver_var + ); + continue; + }; + + let query = env_value(&format!("{}_QUERY", probe.env_prefix)) + .unwrap_or_else(|| "select 1 as sidemantic_adbc_probe".to_string()); + let uri = env_value(&format!("{}_URI", probe.env_prefix)); + let entrypoint = env_value(&format!("{}_ENTRYPOINT", probe.env_prefix)) + .or_else(|| probe.default_entrypoint.map(ToString::to_string)); + let mut database_options = default_database_options(probe); + database_options.extend(parse_database_options(probe.env_prefix)); + let connection_options = parse_connection_options(probe.env_prefix); + + let result = execute_with_adbc(AdbcExecutionRequest { + driver, + sql: query, + uri, + entrypoint, + database_options, + connection_options, + }) + .unwrap_or_else(|err| panic!("{} ADBC probe failed: {err}", probe.name)); + + assert!( + !result.columns.is_empty(), + "{} ADBC probe should return at least one column", + probe.name + ); + assert!( + !result.rows.is_empty(), + "{} ADBC probe should return at least one row", + probe.name + ); + executed.push(probe.name); + } + + for required_driver in required { + assert!( + executed.iter().any(|name| *name == required_driver), + "required ADBC probe {required_driver} did not execute" + ); + } +} diff --git a/sidemantic-rs/tests/adbc_duckdb_e2e.rs b/sidemantic-rs/tests/adbc_duckdb_e2e.rs new file mode 100644 index 00000000..8e9d07af --- /dev/null +++ b/sidemantic-rs/tests/adbc_duckdb_e2e.rs @@ -0,0 +1,605 @@ +#![cfg(all(feature = "mcp-adbc", feature = "runtime-server-adbc"))] + +mod common; + +use std::fs; +use std::io::{BufRead, BufReader, Cursor, Read, Write}; +use std::net::TcpStream; +use std::path::{Path, PathBuf}; +use std::process::{ChildStdin, Command, Stdio}; +use std::sync::mpsc::{self, Receiver}; +use std::thread; +use std::time::Duration; + +use adbc_core::options::{OptionDatabase, OptionValue}; +use arrow_ipc::reader::StreamReader; +use common::{command_with_clean_sidemantic_env, free_loopback_addr, unique_temp_dir, ChildGuard}; +use serde_json::{json, Value}; +use sidemantic::{execute_with_adbc, AdbcExecutionRequest}; + +const DUCKDB_ENTRYPOINT: &str = "duckdb_adbc_init"; +const ARROW_STREAM_MEDIA_TYPE: &str = "application/vnd.apache.arrow.stream"; + +fn duckdb_driver_path() -> Option { + std::env::var("SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER") + .ok() + .filter(|value| !value.trim().is_empty()) +} + +fn duckdb_database_options(db_path: &Path) -> Vec<(OptionDatabase, OptionValue)> { + vec![( + OptionDatabase::from("path"), + OptionValue::String(db_path.to_string_lossy().to_string()), + )] +} + +fn execute_duckdb_sql(driver: &str, db_path: &Path, sql: &str) { + execute_with_adbc(AdbcExecutionRequest { + driver: driver.to_string(), + sql: sql.to_string(), + uri: None, + entrypoint: Some(DUCKDB_ENTRYPOINT.to_string()), + database_options: duckdb_database_options(db_path), + connection_options: Vec::new(), + }) + .unwrap_or_else(|err| panic!("DuckDB ADBC SQL failed: {sql}\n{err}")); +} + +fn seed_duckdb(driver: &str, db_path: &Path) { + execute_duckdb_sql(driver, db_path, "drop table if exists orders"); + execute_duckdb_sql( + driver, + db_path, + "create table orders(order_id integer, status varchar, customer_id integer, amount double)", + ); + execute_duckdb_sql( + driver, + db_path, + "insert into orders values (1, 'complete', 10, 10.5), (2, 'complete', 11, 20.0), (3, 'cancelled', 10, 7.0)", + ); +} + +fn assert_revenue_rows(rows: &[Value]) { + assert_eq!(rows.len(), 2, "{rows:?}"); + let revenue_for = |status: &str| -> f64 { + rows.iter() + .find(|row| row["status"] == status) + .and_then(|row| row["revenue"].as_f64()) + .unwrap_or_else(|| panic!("missing revenue row for {status}: {rows:?}")) + }; + assert!((revenue_for("complete") - 30.5).abs() < f64::EPSILON); + assert!((revenue_for("cancelled") - 7.0).abs() < f64::EPSILON); +} + +fn run_cli(driver: &str, models_path: &Path, db_path: &Path) { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("run") + .arg("--models") + .arg(models_path) + .arg("--dimension") + .arg("orders.status") + .arg("--metric") + .arg("orders.revenue") + .arg("--driver") + .arg(driver) + .arg("--entrypoint") + .arg(DUCKDB_ENTRYPOINT) + .arg("--dbopt") + .arg(format!("path={}", db_path.to_string_lossy())) + .output() + .expect("sidemantic run should execute"); + + assert!( + output.status.success(), + "sidemantic run failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let payload: Value = + serde_json::from_slice(&output.stdout).expect("CLI run output should be JSON"); + assert_eq!(payload["row_count"], 2); + assert_revenue_rows(payload["rows"].as_array().expect("rows should be an array")); +} + +fn http_request(addr: &str, method: &str, path: &str, body: Option) -> (u16, Value) { + let (status, _head, body) = http_request_raw(addr, method, path, body, "application/json"); + let json_body = serde_json::from_slice(&body).unwrap_or_else(|err| { + panic!( + "response body should be JSON: {err}\nraw response:\n{}", + String::from_utf8_lossy(&body) + ) + }); + (status, json_body) +} + +fn http_arrow_request( + addr: &str, + method: &str, + path: &str, + body: Option, +) -> (u16, String, Vec) { + http_request_raw(addr, method, path, body, ARROW_STREAM_MEDIA_TYPE) +} + +fn http_request_raw( + addr: &str, + method: &str, + path: &str, + body: Option, + accept: &str, +) -> (u16, String, Vec) { + let body = body.map(|value| value.to_string()); + let mut request = format!( + "{method} {path} HTTP/1.1\r\nHost: {addr}\r\nConnection: close\r\nAccept: {accept}\r\n" + ); + if let Some(body) = body.as_ref() { + request.push_str("Content-Type: application/json\r\n"); + request.push_str(&format!("Content-Length: {}\r\n", body.len())); + } + request.push_str("\r\n"); + if let Some(body) = body.as_ref() { + request.push_str(body); + } + + let mut stream = TcpStream::connect(addr).expect("server should accept connections"); + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .expect("read timeout should be set"); + stream + .write_all(request.as_bytes()) + .expect("request should be written"); + + let mut response = Vec::new(); + stream + .read_to_end(&mut response) + .expect("response should be read"); + let separator = response + .windows(4) + .position(|window| window == b"\r\n\r\n") + .expect("response should contain headers and body"); + let head = String::from_utf8_lossy(&response[..separator]).to_string(); + let body = response[separator + 4..].to_vec(); + let status = head + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .and_then(|code| code.parse::().ok()) + .expect("status code should be parsed"); + (status, head, body) +} + +fn assert_arrow_response( + status: u16, + head: &str, + body: &[u8], + expected_row_count: usize, + expected_fields: &[&str], +) { + assert_eq!(status, 200, "{}", String::from_utf8_lossy(body)); + let normalized_head = head.to_ascii_lowercase(); + assert!( + normalized_head.contains("content-type: application/vnd.apache.arrow.stream"), + "{head}" + ); + assert!( + normalized_head.contains(&format!("x-sidemantic-row-count: {expected_row_count}")), + "{head}" + ); + assert!( + normalized_head.contains("x-sidemantic-arrow-transport: buffered"), + "{head}" + ); + + assert_arrow_payload(body, expected_row_count, expected_fields); +} + +fn assert_chunked_arrow_response( + status: u16, + head: &str, + body: &[u8], + expected_row_count: usize, + expected_fields: &[&str], +) { + assert_eq!(status, 200, "{}", String::from_utf8_lossy(body)); + let normalized_head = head.to_ascii_lowercase(); + assert!( + normalized_head.contains("content-type: application/vnd.apache.arrow.stream"), + "{head}" + ); + assert!( + normalized_head.contains("transfer-encoding: chunked"), + "{head}" + ); + assert!( + normalized_head.contains("x-sidemantic-arrow-transport: chunked"), + "{head}" + ); + assert!( + !normalized_head.contains("x-sidemantic-row-count"), + "streaming responses should not buffer to compute row-count headers: {head}" + ); + + let decoded_body = decode_chunked_body(body); + assert_arrow_payload(&decoded_body, expected_row_count, expected_fields); +} + +fn assert_arrow_payload(body: &[u8], expected_row_count: usize, expected_fields: &[&str]) { + let mut reader = + StreamReader::try_new(Cursor::new(body), None).expect("Arrow IPC stream should decode"); + let schema = reader.schema(); + let field_names = schema + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + assert_eq!(field_names, expected_fields); + + let mut row_count = 0; + for batch in &mut reader { + let batch = batch.expect("Arrow record batch should decode"); + row_count += batch.num_rows(); + } + assert_eq!(row_count, expected_row_count); +} + +fn decode_chunked_body(body: &[u8]) -> Vec { + let mut decoded = Vec::new(); + let mut offset = 0; + loop { + let line_end = body[offset..] + .windows(2) + .position(|window| window == b"\r\n") + .map(|position| offset + position) + .expect("chunk size line should end with CRLF"); + let size_line = + std::str::from_utf8(&body[offset..line_end]).expect("chunk size should be UTF-8"); + let size_hex = size_line.split(';').next().unwrap_or(size_line).trim(); + let size = usize::from_str_radix(size_hex, 16).unwrap_or_else(|err| { + panic!("chunk size should be hex, got '{size_hex}': {err}"); + }); + offset = line_end + 2; + if size == 0 { + break; + } + + let data_end = offset + size; + assert!( + data_end + 2 <= body.len(), + "chunk declares more bytes than response body contains" + ); + decoded.extend_from_slice(&body[offset..data_end]); + assert_eq!(&body[data_end..data_end + 2], b"\r\n"); + offset = data_end + 2; + } + decoded +} + +fn wait_for_server(addr: &str) { + for _ in 0..100 { + if TcpStream::connect(addr).is_ok() { + let (status, body) = http_request(addr, "GET", "/readyz", None); + if status == 200 && body == json!({ "status": "ok" }) { + return; + } + } + thread::sleep(Duration::from_millis(50)); + } + panic!("server did not become ready at {addr}"); +} + +fn run_http(driver: &str, models_path: &Path, db_path: &Path) { + let bind = free_loopback_addr(); + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-server")); + command + .arg("--models") + .arg(models_path) + .arg("--bind") + .arg(&bind) + .arg("--driver") + .arg(driver) + .arg("--entrypoint") + .arg(DUCKDB_ENTRYPOINT) + .arg("--dbopt") + .arg(format!("path={}", db_path.to_string_lossy())); + let mut child = ChildGuard::new( + command + .spawn() + .expect("sidemantic-server child should spawn"), + ); + wait_for_server(&bind); + + let (status, body) = http_request( + &bind, + "POST", + "/query", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + })), + ); + assert_eq!(status, 200, "{body}"); + assert_eq!(body["row_count"], 2); + assert_revenue_rows(body["rows"].as_array().expect("rows should be an array")); + + let (status, body) = http_request( + &bind, + "POST", + "/sql", + Some(json!({ + "query": "select orders.status, orders.revenue from orders" + })), + ); + assert_eq!(status, 200, "{body}"); + assert_eq!(body["row_count"], 2); + assert_revenue_rows(body["rows"].as_array().expect("rows should be an array")); + assert_eq!( + body["original_sql"], + "select orders.status, orders.revenue from orders" + ); + + let (status, body) = http_request( + &bind, + "POST", + "/raw", + Some(json!({ + "query": "select status, amount from orders order by order_id limit 1" + })), + ); + assert_eq!(status, 200, "{body}"); + assert_eq!(body["row_count"], 1); + assert_eq!(body["rows"][0]["status"], "complete"); + assert_eq!(body["rows"][0]["amount"], 10.5); + + let (status, head, body) = http_arrow_request( + &bind, + "POST", + "/query?format=arrow", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + })), + ); + assert_arrow_response(status, &head, &body, 2, &["status", "revenue"]); + + let (status, head, body) = http_arrow_request( + &bind, + "POST", + "/query?format=arrow&transport=chunked", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + })), + ); + assert_chunked_arrow_response(status, &head, &body, 2, &["status", "revenue"]); + + let (status, head, body) = http_arrow_request( + &bind, + "POST", + "/sql", + Some(json!({ + "query": "select orders.status, orders.revenue from orders" + })), + ); + assert_arrow_response(status, &head, &body, 2, &["status", "revenue"]); + + let (status, head, body) = http_arrow_request( + &bind, + "POST", + "/raw?format=arrow", + Some(json!({ + "query": "select status, amount from orders order by order_id limit 1" + })), + ); + assert_arrow_response(status, &head, &body, 1, &["status", "amount"]); + + let (status, head, body) = http_arrow_request( + &bind, + "POST", + "/raw?format=arrow&stream=true", + Some(json!({ + "query": "select status, amount from orders order by order_id limit 1" + })), + ); + assert_chunked_arrow_response(status, &head, &body, 1, &["status", "amount"]); + + child.kill_and_wait(); +} + +struct McpClient { + child: ChildGuard, + stdin: ChildStdin, + responses: Receiver, +} + +impl McpClient { + fn spawn(models_path: &Path, driver: &str, db_path: &Path) -> Self { + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-mcp")); + command + .arg("--models") + .arg(models_path) + .arg("--driver") + .arg(driver) + .arg("--entrypoint") + .arg(DUCKDB_ENTRYPOINT) + .arg("--dbopt") + .arg(format!("path={}", db_path.to_string_lossy())) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()); + let mut child = command.spawn().expect("sidemantic-mcp should spawn"); + let stdin = child.stdin.take().expect("child stdin should be piped"); + let stdout = child.stdout.take().expect("child stdout should be piped"); + let (tx, rx) = mpsc::channel(); + thread::spawn(move || { + let reader = BufReader::new(stdout); + for line in reader.lines() { + let line = match line { + Ok(line) => line, + Err(_) => break, + }; + if line.trim().is_empty() { + continue; + } + if let Ok(value) = serde_json::from_str::(&line) { + let _ = tx.send(value); + } + } + }); + + Self { + child: ChildGuard::new(child), + stdin, + responses: rx, + } + } + + fn send(&mut self, message: Value) { + writeln!(self.stdin, "{message}").expect("mcp message should be written"); + self.stdin.flush().expect("mcp stdin should flush"); + } + + fn request(&mut self, id: u64, method: &str, params: Value) -> Value { + self.send(json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params + })); + self.read_response(id) + } + + fn notify(&mut self, method: &str, params: Value) { + self.send(json!({ + "jsonrpc": "2.0", + "method": method, + "params": params + })); + } + + fn read_response(&mut self, id: u64) -> Value { + for _ in 0..20 { + let value = self + .responses + .recv_timeout(Duration::from_secs(5)) + .expect("mcp response should arrive"); + if value["id"] == json!(id) { + return value; + } + } + panic!("mcp response id {id} was not observed"); + } + + fn shutdown(mut self) { + self.child.kill_and_wait(); + } +} + +fn structured_content(response: &Value) -> Value { + if let Some(value) = response["result"]["structuredContent"].as_object() { + return Value::Object(value.clone()); + } + let text = response["result"]["content"][0]["text"] + .as_str() + .expect("tool response should include text content"); + serde_json::from_str(text).expect("tool text content should be JSON") +} + +fn run_mcp(driver: &str, models_path: &Path, db_path: &Path) { + let mut client = McpClient::spawn(models_path, driver, db_path); + + let init = client.request( + 1, + "initialize", + json!({ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "sidemantic-adbc-test", "version": "0.0.0" } + }), + ); + assert_eq!(init["jsonrpc"], "2.0"); + client.notify("notifications/initialized", json!({})); + + let response = client.request( + 2, + "tools/call", + json!({ + "name": "run_query", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + } + }), + ); + let payload = structured_content(&response); + assert_eq!(payload["row_count"], 2); + assert_revenue_rows(payload["rows"].as_array().expect("rows should be an array")); + + let sql_response = client.request( + 3, + "tools/call", + json!({ + "name": "run_sql", + "arguments": { + "query": "select orders.status, orders.revenue from orders" + } + }), + ); + let sql_payload = structured_content(&sql_response); + assert_eq!(sql_payload["row_count"], 2); + assert_revenue_rows( + sql_payload["rows"] + .as_array() + .expect("SQL rows should be an array"), + ); + assert_eq!( + sql_payload["original_sql"], + "select orders.status, orders.revenue from orders" + ); + + let chart_response = client.request( + 4, + "tools/call", + json!({ + "name": "create_chart", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "chart_type": "bar", + "width": 80, + "height": 60 + } + }), + ); + let chart_payload = structured_content(&chart_response); + assert_eq!(chart_payload["row_count"], 2); + assert_eq!(chart_payload["vega_spec"]["title"], "Revenue by Status"); + assert!(chart_payload["vega_spec"]["data"]["values"] + .as_array() + .expect("chart should embed data") + .iter() + .any(|row| row["status"] == "complete")); + assert!(chart_payload["png_base64"] + .as_str() + .expect("chart should include PNG data URL") + .starts_with("data:image/png;base64,")); + + client.shutdown(); +} + +#[test] +fn duckdb_adbc_executes_cli_http_and_mcp() { + let Some(driver) = duckdb_driver_path() else { + eprintln!("skipping DuckDB ADBC E2E; SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER is not set"); + return; + }; + + let dir = unique_temp_dir("sidemantic_adbc_duckdb_e2e"); + let models_path = common::write_retail_fixture(&dir); + let db_path: PathBuf = dir.join("warehouse.duckdb"); + seed_duckdb(&driver, &db_path); + + run_cli(&driver, &models_path, &db_path); + run_http(&driver, &models_path, &db_path); + run_mcp(&driver, &models_path, &db_path); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} diff --git a/sidemantic-rs/tests/c_abi_smoke.c b/sidemantic-rs/tests/c_abi_smoke.c new file mode 100644 index 00000000..b38fe81a --- /dev/null +++ b/sidemantic-rs/tests/c_abi_smoke.c @@ -0,0 +1,152 @@ +#include "sidemantic.h" + +#include +#include +#include +#include + +static void fail(const char *message) { + fprintf(stderr, "%s\n", message); + exit(1); +} + +static void expect_success(char *error) { + if (error != NULL) { + fprintf(stderr, "unexpected sidemantic error: %s\n", error); + sidemantic_free(error); + exit(1); + } +} + +static void expect_error(char *error, const char *fragment) { + if (error == NULL) { + fail("expected sidemantic error, got success"); + } + if (strstr(error, fragment) == NULL) { + fprintf(stderr, "expected error containing '%s', got '%s'\n", fragment, error); + sidemantic_free(error); + exit(1); + } + sidemantic_free(error); +} + +static void expect_rewrite_contains(SidemanticRewriteResult result, const char *fragment) { + if (result.error != NULL) { + fprintf(stderr, "unexpected rewrite error: %s\n", result.error); + sidemantic_free_result(result); + exit(1); + } + if (!result.was_rewritten) { + sidemantic_free_result(result); + fail("expected query to be rewritten"); + } + if (result.sql == NULL || strstr(result.sql, fragment) == NULL) { + fprintf(stderr, "expected rewritten SQL containing '%s', got '%s'\n", fragment, + result.sql == NULL ? "" : result.sql); + sidemantic_free_result(result); + exit(1); + } + sidemantic_free_result(result); +} + +static void expect_passthrough(SidemanticRewriteResult result, const char *sql) { + if (result.error != NULL) { + fprintf(stderr, "unexpected rewrite error: %s\n", result.error); + sidemantic_free_result(result); + exit(1); + } + if (result.was_rewritten) { + sidemantic_free_result(result); + fail("expected passthrough rewrite result"); + } + if (result.sql == NULL || strcmp(result.sql, sql) != 0) { + fprintf(stderr, "expected passthrough SQL '%s', got '%s'\n", sql, + result.sql == NULL ? "" : result.sql); + sidemantic_free_result(result); + exit(1); + } + sidemantic_free_result(result); +} + +int main(void) { + const char *context_a = "c-abi:a"; + const char *context_b = "c-abi:b"; + const char *yaml_a = + "models:\n" + " - name: orders\n" + " table: orders_a\n" + " primary_key: order_id\n" + " metrics:\n" + " - name: revenue\n" + " agg: sum\n" + " sql: amount\n"; + const char *yaml_b = + "models:\n" + " - name: orders\n" + " table: orders_b\n" + " primary_key: order_id\n" + " metrics:\n" + " - name: order_count\n" + " agg: count\n"; + + sidemantic_clear_for_context(context_a); + sidemantic_clear_for_context(context_b); + + expect_success(sidemantic_load_yaml_for_context(context_a, yaml_a)); + expect_success(sidemantic_load_yaml_for_context(context_b, yaml_b)); + + char *models = sidemantic_list_models_for_context(context_a); + if (models == NULL || strstr(models, "orders") == NULL) { + sidemantic_free(models); + fail("expected context A model list to include orders"); + } + sidemantic_free(models); + + if (!sidemantic_is_model_for_context(context_a, "orders")) { + fail("expected context A orders model"); + } + + expect_rewrite_contains( + sidemantic_rewrite_for_context(context_a, "SELECT orders.revenue FROM orders"), + "orders_a"); + expect_rewrite_contains( + sidemantic_rewrite_for_context(context_b, "SELECT orders.order_count FROM orders"), + "orders_b"); + + sidemantic_clear_for_context(context_a); + expect_passthrough( + sidemantic_rewrite_for_context(context_a, "SELECT orders.revenue FROM orders"), + "SELECT orders.revenue FROM orders"); + expect_rewrite_contains( + sidemantic_rewrite_for_context(context_b, "SELECT orders.order_count FROM orders"), + "orders_b"); + + sidemantic_clear(); + expect_success(sidemantic_load_yaml(yaml_a)); + expect_rewrite_contains(sidemantic_rewrite("SELECT orders.revenue FROM orders"), "orders_a"); + + sidemantic_clear_for_context("c-abi:memory"); + expect_success(sidemantic_define_for_context( + "c-abi:memory", "MODEL (name events, table events, primary_key event_id);", NULL, + false)); + expect_success(sidemantic_add_definition_for_context( + "c-abi:memory", "METRIC event_count AS COUNT(*)", NULL, false)); + expect_rewrite_contains( + sidemantic_rewrite_for_context("c-abi:memory", "SELECT events.event_count FROM events"), + "COUNT"); + + expect_error(sidemantic_load_yaml_for_context(context_a, NULL), "null yaml pointer"); + + SidemanticRewriteResult null_sql = sidemantic_rewrite_for_context(context_a, NULL); + if (null_sql.error == NULL || strstr(null_sql.error, "null sql pointer") == NULL) { + sidemantic_free_result(null_sql); + fail("expected null SQL rewrite error"); + } + sidemantic_free_result(null_sql); + + sidemantic_clear_for_context(context_b); + sidemantic_clear_for_context("c-abi:memory"); + sidemantic_clear(); + + return 0; +} diff --git a/sidemantic-rs/tests/cli_smoke.rs b/sidemantic-rs/tests/cli_smoke.rs new file mode 100644 index 00000000..e7bc6218 --- /dev/null +++ b/sidemantic-rs/tests/cli_smoke.rs @@ -0,0 +1,1163 @@ +use std::fs; +use std::path::PathBuf; +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; + +fn unique_temp_dir(prefix: &str) -> PathBuf { + let mut dir = std::env::temp_dir(); + let suffix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be valid") + .as_nanos(); + dir.push(format!("{prefix}_{suffix}")); + fs::create_dir_all(&dir).expect("temp dir should be created"); + dir +} + +fn is_expected_workbench_unavailable_error(stderr: &str) -> bool { + stderr.contains("workbench requires the crate feature 'workbench-tui'") + || stderr.contains("workbench requires an interactive terminal (TTY)") +} + +fn write_retail_fixture(dir: &std::path::Path) -> PathBuf { + let models_path = dir.join("retail.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: customer_id + type: numeric + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count +"#, + ) + .expect("retail fixture should be written"); + models_path +} + +#[test] +fn cli_help_lists_core_commands() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("--help") + .output() + .expect("sidemantic binary should run"); + + assert!( + output.status.success(), + "help command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("compile")); + assert!(stdout.contains("rewrite")); + assert!(stdout.contains("validate")); + assert!(stdout.contains("run")); + assert!(stdout.contains("preagg")); + assert!(stdout.contains("workbench")); + assert!(stdout.contains("serve")); + assert!(stdout.contains("mcp-serve")); + assert!(stdout.contains("lsp")); +} + +#[test] +fn cli_workbench_reports_expected_unavailable_status() { + let dir = unique_temp_dir("sidemantic_cli_workbench_unavailable"); + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("workbench") + .arg(&dir) + .output() + .expect("workbench command should run"); + + assert!( + !output.status.success(), + "workbench command should fail in non-interactive smoke test mode" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + is_expected_workbench_unavailable_error(&stderr), + "unexpected stderr: {stderr}" + ); + assert!( + stderr.contains(&format!("models={}", dir.display())), + "unexpected stderr: {stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_workbench_autodiscovers_data_db_connection() { + let dir = unique_temp_dir("sidemantic_cli_workbench_auto_db"); + let data_dir = dir.join("data"); + fs::create_dir_all(&data_dir).expect("data dir should be created"); + let db_path = data_dir.join("warehouse.db"); + fs::write(&db_path, []).expect("placeholder db file should be created"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("workbench") + .arg(&dir) + .output() + .expect("workbench command should run"); + + assert!( + !output.status.success(), + "workbench command should fail in non-interactive smoke test mode" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + is_expected_workbench_unavailable_error(&stderr), + "unexpected stderr: {stderr}" + ); + let expected_connection = format!("connection=duckdb://{}", db_path.to_string_lossy()); + assert!( + stderr.contains(&expected_connection), + "unexpected stderr: {stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_workbench_demo_resolves_connection() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("workbench") + .arg("--demo") + .output() + .expect("workbench demo command should run"); + + assert!( + !output.status.success(), + "workbench command should fail in non-interactive smoke test mode" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + is_expected_workbench_unavailable_error(&stderr), + "unexpected stderr: {stderr}" + ); + assert!(stderr.contains("demo=true"), "unexpected stderr: {stderr}"); + assert!( + stderr.contains("connection=duckdb:///"), + "unexpected stderr: {stderr}" + ); +} + +#[test] +fn cli_tree_alias_requires_directory_positional() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("tree") + .output() + .expect("tree command should run"); + + assert!( + !output.status.success(), + "tree command should fail without required directory" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("tree requires exactly one positional models directory"), + "unexpected stderr: {stderr}" + ); +} + +#[test] +fn cli_tree_alias_forwards_to_workbench_unavailable_status() { + let dir = unique_temp_dir("sidemantic_cli_tree_alias_unavailable"); + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("tree") + .arg(&dir) + .output() + .expect("tree command should run"); + + assert!( + !output.status.success(), + "tree command should fail in non-interactive smoke test mode" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("tree is deprecated; use 'workbench'"), + "unexpected stderr: {stderr}" + ); + assert!( + is_expected_workbench_unavailable_error(&stderr), + "unexpected stderr: {stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_serve_alias_is_recognized() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("serve") + .output() + .expect("serve alias should run"); + + assert!( + !output.status.success(), + "serve should fail without runtime-server feature in default test build" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("server requires the crate feature 'runtime-server'"), + "unexpected stderr: {stderr}" + ); +} + +#[test] +fn cli_mcp_serve_alias_is_recognized() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("mcp-serve") + .output() + .expect("mcp-serve alias should run"); + + assert!( + !output.status.success(), + "mcp-serve should fail without mcp-server feature in default test build" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("mcp requires the crate feature 'mcp-server'"), + "unexpected stderr: {stderr}" + ); +} + +#[test] +fn cli_compile_accepts_yaml_model_file() { + let dir = unique_temp_dir("sidemantic_cli_smoke"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("compile") + .arg("--models") + .arg(&models_path) + .arg("--metric") + .arg("orders.revenue") + .arg("--dimension") + .arg("orders.status") + .output() + .expect("compile command should run"); + + assert!( + output.status.success(), + "compile command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("SELECT"), "unexpected SQL output: {stdout}"); + assert!(stdout.contains("SUM("), "unexpected SQL output: {stdout}"); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_info_lists_model_summary() { + let dir = unique_temp_dir("sidemantic_cli_info"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("info") + .arg("--models") + .arg(&models_path) + .output() + .expect("info command should run"); + + assert!( + output.status.success(), + "info command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Semantic Layer:")); + assert!(stdout.contains("- orders")); + assert!(stdout.contains("Dimensions: 1")); + assert!(stdout.contains("Metrics: 1")); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_query_dry_run_rewrites_sql() { + let dir = unique_temp_dir("sidemantic_cli_query_dry_run"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("query") + .arg("--models") + .arg(&models_path) + .arg("--sql") + .arg("SELECT orders.revenue, orders.status FROM orders") + .arg("--dry-run") + .output() + .expect("query command should run"); + + assert!( + output.status.success(), + "query command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("SELECT"), "unexpected SQL output: {stdout}"); + assert!(stdout.contains("SUM("), "unexpected SQL output: {stdout}"); + assert!( + stdout.contains("GROUP BY"), + "unexpected SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_rewrite_covers_representative_relationship_query() { + let dir = unique_temp_dir("sidemantic_cli_rewrite_retail"); + let models_path = write_retail_fixture(&dir); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("rewrite") + .arg("--models") + .arg(&models_path) + .arg("--sql") + .arg( + "SELECT orders.revenue, customers.country FROM orders \ + WHERE customers.country = 'US' \ + ORDER BY orders.revenue DESC LIMIT 5", + ) + .output() + .expect("rewrite command should run"); + + assert!( + output.status.success(), + "rewrite command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("SUM("), "unexpected SQL output: {stdout}"); + assert!(stdout.contains("JOIN"), "unexpected SQL output: {stdout}"); + assert!( + stdout.contains("country"), + "unexpected SQL output: {stdout}" + ); + assert!( + stdout.contains("ORDER BY"), + "unexpected SQL output: {stdout}" + ); + assert!( + stdout.contains("LIMIT 5"), + "unexpected SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_compile_covers_top_level_ratio_metric_fixture() { + let dir = unique_temp_dir("sidemantic_cli_compile_ratio"); + let models_path = write_retail_fixture(&dir); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("compile") + .arg("--models") + .arg(&models_path) + .arg("--metric") + .arg("revenue_per_order") + .arg("--dimension") + .arg("customers.country") + .arg("--order-by") + .arg("revenue_per_order DESC") + .arg("--limit") + .arg("10") + .output() + .expect("compile command should run"); + + assert!( + output.status.success(), + "compile command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("SUM("), "unexpected SQL output: {stdout}"); + assert!(stdout.contains("COUNT("), "unexpected SQL output: {stdout}"); + assert!(stdout.contains("JOIN"), "unexpected SQL output: {stdout}"); + assert!( + stdout.contains("ORDER BY"), + "unexpected SQL output: {stdout}" + ); + assert!( + stdout.contains("LIMIT 10"), + "unexpected SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_apply_writes_recommendations_to_model_files() { + let dir = unique_temp_dir("sidemantic_cli_preagg_apply"); + let models_dir = dir.join("models"); + fs::create_dir_all(&models_dir).expect("models dir should be created"); + let models_path = models_dir.join("orders.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: created_at + type: time + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + let queries_path = dir.join("queries.sql"); + fs::write( + &queries_path, + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.created_at,orders.status granularities=day\n\ +select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.created_at,orders.status granularities=day\n", + ) + .expect("queries file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("apply") + .arg("--models") + .arg(&models_dir) + .arg("--queries-file") + .arg(&queries_path) + .arg("--min-query-count") + .arg("1") + .arg("--min-benefit-score") + .arg("0") + .output() + .expect("preagg apply command should run"); + + assert!( + output.status.success(), + "preagg apply command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let updated = fs::read_to_string(&models_path).expect("updated model file should be readable"); + assert!( + updated.contains("pre_aggregations"), + "missing pre_aggregations in updated model file: {updated}" + ); + assert!( + updated.contains("day_created_at_status_revenue"), + "missing expected generated pre-aggregation name: {updated}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_apply_dry_run_does_not_modify_model_files() { + let dir = unique_temp_dir("sidemantic_cli_preagg_apply_dry_run"); + let models_dir = dir.join("models"); + fs::create_dir_all(&models_dir).expect("models dir should be created"); + let models_path = models_dir.join("orders.yml"); + let original = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: created_at + type: time + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#; + fs::write(&models_path, original).expect("models file should be written"); + let queries_path = dir.join("queries.sql"); + fs::write( + &queries_path, + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.created_at,orders.status granularities=day\n\ +select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.created_at,orders.status granularities=day\n", + ) + .expect("queries file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("apply") + .arg("--models") + .arg(&models_dir) + .arg("--queries-file") + .arg(&queries_path) + .arg("--min-query-count") + .arg("1") + .arg("--min-benefit-score") + .arg("0") + .arg("--dry-run") + .output() + .expect("preagg apply dry-run command should run"); + + assert!( + output.status.success(), + "preagg apply dry-run command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let updated = fs::read_to_string(&models_path).expect("model file should be readable"); + assert_eq!(updated, original); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_recommend_connection_mode_requires_adbc_feature() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("recommend") + .arg("--connection") + .arg("snowflake://account/db/schema") + .output() + .expect("preagg recommend command should run"); + + assert!( + !output.status.success(), + "preagg recommend connection mode should fail in smoke environment" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + if cfg!(feature = "adbc-exec") { + assert!( + stderr.contains("preagg recommend: failed to fetch query history via ADBC"), + "unexpected stderr: {stderr}" + ); + } else { + assert!( + stderr.contains( + "preagg recommend: query-history mode requires the crate feature 'adbc-exec'" + ), + "unexpected stderr: {stderr}" + ); + } +} + +#[test] +fn cli_preagg_apply_connection_mode_requires_adbc_feature() { + let dir = unique_temp_dir("sidemantic_cli_preagg_apply_connection_mode"); + let models_dir = dir.join("models"); + fs::create_dir_all(&models_dir).expect("models dir should be created"); + let models_path = models_dir.join("orders.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("apply") + .arg("--models") + .arg(&models_dir) + .arg("--connection") + .arg("snowflake://account/db/schema") + .output() + .expect("preagg apply command should run"); + + assert!( + !output.status.success(), + "preagg apply connection mode should fail in smoke environment" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + if cfg!(feature = "adbc-exec") { + assert!( + stderr.contains("preagg apply: failed to fetch query history via ADBC"), + "unexpected stderr: {stderr}" + ); + } else { + assert!( + stderr.contains( + "preagg apply: query-history mode requires the crate feature 'adbc-exec'" + ), + "unexpected stderr: {stderr}" + ); + } + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_migrator_generate_models_bootstrap_writes_outputs() { + let dir = unique_temp_dir("sidemantic_cli_migrator_bootstrap"); + let queries_path = dir.join("queries.sql"); + fs::write( + &queries_path, + "select orders.status, sum(orders.amount) as revenue from orders group by orders.status\n", + ) + .expect("queries file should be written"); + let output_dir = dir.join("generated"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("migrator") + .arg("--queries") + .arg(&queries_path) + .arg("--generate-models") + .arg(&output_dir) + .output() + .expect("migrator bootstrap command should run"); + + assert!( + output.status.success(), + "migrator bootstrap command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + + let model_file = output_dir.join("models").join("orders.yml"); + assert!(model_file.exists(), "expected generated model file"); + let model_content = fs::read_to_string(&model_file).expect("model file should be readable"); + assert!( + model_content.contains("model:"), + "unexpected model YAML: {model_content}" + ); + assert!( + model_content.contains("name: orders"), + "unexpected model YAML: {model_content}" + ); + assert!( + model_content.contains("sum_amount"), + "missing expected inferred metric: {model_content}" + ); + + let rewritten_file = output_dir.join("rewritten_queries").join("query_1.sql"); + assert!(rewritten_file.exists(), "expected rewritten query file"); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_migrator_coverage_mode_reports_totals() { + let dir = unique_temp_dir("sidemantic_cli_migrator_coverage"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + let queries_path = dir.join("queries.sql"); + fs::write( + &queries_path, + "select orders.status, sum(orders.amount) as revenue from orders group by orders.status\n", + ) + .expect("queries file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("migrator") + .arg("--models") + .arg(&models_path) + .arg("--queries") + .arg(&queries_path) + .output() + .expect("migrator coverage command should run"); + + assert!( + output.status.success(), + "migrator coverage command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("Total Queries: 1"), + "unexpected output: {stdout}" + ); + assert!( + stdout.contains("Rewritable: 1"), + "unexpected output: {stdout}" + ); + assert!( + stdout.contains("Coverage: 100.0%"), + "unexpected output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_validate_definition_mode_with_verbose_uses_positional_models_path() { + let dir = unique_temp_dir("sidemantic_cli_validate_verbose"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("validate") + .arg(&models_path) + .arg("--verbose") + .output() + .expect("validate command should run"); + + assert!( + output.status.success(), + "validate command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("Validation passed"), + "unexpected output: {stdout}" + ); + assert!(stdout.contains("Models: 1"), "unexpected output: {stdout}"); + assert!(stdout.contains("- orders"), "unexpected output: {stdout}"); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_validate_query_reference_mode_remains_supported() { + let dir = unique_temp_dir("sidemantic_cli_validate_refs"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("validate") + .arg("--models") + .arg(&models_path) + .arg("--metric") + .arg("orders.revenue") + .arg("--dimension") + .arg("orders.status") + .output() + .expect("validate command should run"); + + assert!( + output.status.success(), + "validate reference mode failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("ok"), "unexpected output: {stdout}"); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_refresh_generates_sql_plan() { + let dir = unique_temp_dir("sidemantic_cli_preagg_refresh"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("refresh") + .arg("--models") + .arg(&models_path) + .arg("--model") + .arg("orders") + .arg("--name") + .arg("daily_revenue") + .arg("--mode") + .arg("incremental") + .arg("--from-watermark") + .arg("2026-01-01") + .output() + .expect("preagg refresh command should run"); + + assert!( + output.status.success(), + "preagg refresh command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("INSERT INTO orders_preagg_daily_revenue"), + "unexpected refresh SQL output: {stdout}" + ); + assert!( + stdout.contains("order_date_day >= '2026-01-01'"), + "unexpected refresh SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_refresh_engine_mode_generates_bigquery_ddl() { + let dir = unique_temp_dir("sidemantic_cli_preagg_engine"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("refresh") + .arg("--models") + .arg(&models_path) + .arg("--model") + .arg("orders") + .arg("--name") + .arg("daily_revenue") + .arg("--mode") + .arg("engine") + .arg("--dialect") + .arg("bigquery") + .arg("--refresh-every") + .arg("2 hours") + .output() + .expect("preagg refresh engine command should run"); + + assert!( + output.status.success(), + "preagg refresh engine command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("CREATE MATERIALIZED VIEW IF NOT EXISTS orders_preagg_daily_revenue"), + "unexpected engine refresh SQL output: {stdout}" + ); + assert!( + stdout.contains("refresh_interval_minutes = 120"), + "unexpected engine refresh SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_refresh_engine_mode_requires_dialect() { + let dir = unique_temp_dir("sidemantic_cli_preagg_engine_missing_dialect"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("refresh") + .arg("--models") + .arg(&models_path) + .arg("--model") + .arg("orders") + .arg("--name") + .arg("daily_revenue") + .arg("--mode") + .arg("engine") + .output() + .expect("preagg refresh engine command should run"); + + assert!( + !output.status.success(), + "preagg refresh engine command should fail without dialect" + ); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("engine refresh mode requires --dialect"), + "unexpected stderr: {stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_refresh_defaults_to_incremental_when_refresh_key_is_incremental() { + let dir = unique_temp_dir("sidemantic_cli_preagg_default_incremental"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] + refresh_key: + incremental: true +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("refresh") + .arg("--models") + .arg(&models_path) + .arg("--model") + .arg("orders") + .arg("--name") + .arg("daily_revenue") + .output() + .expect("preagg refresh command should run"); + + assert!( + output.status.success(), + "preagg refresh command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("mode=incremental"), + "unexpected refresh mode output: {stdout}" + ); + assert!( + stdout.contains("INSERT INTO orders_preagg_daily_revenue"), + "unexpected refresh SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn cli_preagg_refresh_defaults_to_full_without_incremental_refresh_key() { + let dir = unique_temp_dir("sidemantic_cli_preagg_default_full"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#, + ) + .expect("models file should be written"); + + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic")) + .arg("preagg") + .arg("refresh") + .arg("--models") + .arg(&models_path) + .arg("--model") + .arg("orders") + .arg("--name") + .arg("daily_revenue") + .output() + .expect("preagg refresh command should run"); + + assert!( + output.status.success(), + "preagg refresh command failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!( + stdout.contains("mode=full"), + "unexpected refresh mode output: {stdout}" + ); + assert!( + stdout.contains("DELETE FROM orders_preagg_daily_revenue"), + "unexpected refresh SQL output: {stdout}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} diff --git a/sidemantic-rs/tests/common/mod.rs b/sidemantic-rs/tests/common/mod.rs new file mode 100644 index 00000000..bc4417f8 --- /dev/null +++ b/sidemantic-rs/tests/common/mod.rs @@ -0,0 +1,121 @@ +#![allow(dead_code)] + +use std::fs; +use std::net::TcpListener; +use std::path::{Path, PathBuf}; +use std::process::{Child, Command, Stdio}; +use std::time::{SystemTime, UNIX_EPOCH}; + +pub fn unique_temp_dir(prefix: &str) -> PathBuf { + let mut dir = std::env::temp_dir(); + let suffix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be valid") + .as_nanos(); + dir.push(format!("{prefix}_{suffix}")); + fs::create_dir_all(&dir).expect("temp dir should be created"); + dir +} + +pub fn write_retail_fixture(dir: &Path) -> PathBuf { + let models_path = dir.join("retail.yml"); + fs::write( + &models_path, + r#" +parameters: + - name: status_param + type: string + - name: min_customer_id + type: number +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: customer_id + type: numeric + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count + segments: + - name: completed + sql: "{model}.status = 'complete'" + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count +"#, + ) + .expect("retail fixture should be written"); + models_path +} + +#[allow(dead_code)] +pub fn free_loopback_addr() -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("free port should be allocated"); + listener + .local_addr() + .expect("local addr should be available") + .to_string() +} + +pub struct ChildGuard { + child: Child, +} + +impl ChildGuard { + pub fn new(child: Child) -> Self { + Self { child } + } + + pub fn kill_and_wait(&mut self) { + let _ = self.child.kill(); + let _ = self.child.wait(); + } +} + +impl Drop for ChildGuard { + fn drop(&mut self) { + self.kill_and_wait(); + } +} + +pub fn command_with_clean_sidemantic_env(binary: &str) -> Command { + let mut command = Command::new(binary); + command + .env_remove("SIDEMANTIC_ADBC_DBOPTS") + .env_remove("SIDEMANTIC_ADBC_CONNOPTS") + .env_remove("SIDEMANTIC_MCP_ADBC_DRIVER") + .env_remove("SIDEMANTIC_MCP_ADBC_URI") + .env_remove("SIDEMANTIC_MCP_ADBC_ENTRYPOINT") + .env_remove("SIDEMANTIC_MCP_ADBC_DBOPTS") + .env_remove("SIDEMANTIC_MCP_ADBC_CONNOPTS") + .env_remove("SIDEMANTIC_SERVER_ADBC_DRIVER") + .env_remove("SIDEMANTIC_SERVER_ADBC_URI") + .env_remove("SIDEMANTIC_SERVER_ADBC_ENTRYPOINT") + .env_remove("SIDEMANTIC_SERVER_ADBC_DBOPTS") + .env_remove("SIDEMANTIC_SERVER_ADBC_CONNOPTS") + .env_remove("SIDEMANTIC_SERVER_AUTH_TOKEN") + .env_remove("SIDEMANTIC_SERVER_CORS_ORIGINS") + .env_remove("SIDEMANTIC_SERVER_MAX_REQUEST_BODY_BYTES") + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::piped()); + command +} diff --git a/sidemantic-rs/tests/http_server.rs b/sidemantic-rs/tests/http_server.rs new file mode 100644 index 00000000..ac5759f3 --- /dev/null +++ b/sidemantic-rs/tests/http_server.rs @@ -0,0 +1,452 @@ +#![cfg(feature = "runtime-server")] + +mod common; + +use std::fs; +use std::io::{Read, Write}; +use std::net::TcpStream; +use std::thread; +use std::time::Duration; + +use common::{command_with_clean_sidemantic_env, free_loopback_addr, unique_temp_dir}; +use serde_json::{json, Value}; + +fn http_request(addr: &str, method: &str, path: &str, body: Option) -> (u16, Value, String) { + http_request_with_headers(addr, method, path, body, &[]) +} + +fn http_request_with_headers( + addr: &str, + method: &str, + path: &str, + body: Option, + headers: &[(&str, &str)], +) -> (u16, Value, String) { + let body = body.map(|value| value.to_string()); + let mut request = format!( + "{method} {path} HTTP/1.1\r\nHost: {addr}\r\nConnection: close\r\nAccept: application/json\r\n" + ); + for (name, value) in headers { + request.push_str(&format!("{name}: {value}\r\n")); + } + if let Some(body) = body.as_ref() { + request.push_str("Content-Type: application/json\r\n"); + request.push_str(&format!("Content-Length: {}\r\n", body.len())); + } + request.push_str("\r\n"); + if let Some(body) = body.as_ref() { + request.push_str(body); + } + + let mut stream = TcpStream::connect(addr).expect("server should accept connections"); + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .expect("read timeout should be set"); + stream + .write_all(request.as_bytes()) + .expect("request should be written"); + + let mut response = String::new(); + stream + .read_to_string(&mut response) + .expect("response should be read"); + let (head, body) = response + .split_once("\r\n\r\n") + .expect("response should contain headers and body"); + let status = head + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .and_then(|code| code.parse::().ok()) + .expect("status code should be parsed"); + let json_body = serde_json::from_str(body.trim()).unwrap_or_else(|err| { + panic!("response body should be JSON: {err}\nraw response:\n{response}") + }); + (status, json_body, response) +} + +fn wait_for_server(addr: &str) { + for _ in 0..100 { + if TcpStream::connect(addr).is_ok() { + let (status, body, _) = http_request(addr, "GET", "/readyz", None); + if status == 200 && body == json!({ "status": "ok" }) { + return; + } + } + thread::sleep(Duration::from_millis(50)); + } + panic!("server did not become ready at {addr}"); +} + +#[test] +fn http_server_fails_fast_for_missing_models_path() { + let dir = unique_temp_dir("sidemantic_http_server_bad_models"); + let missing = dir.join("missing.yml"); + let bind = free_loopback_addr(); + + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-server")); + let output = command + .arg("--models") + .arg(&missing) + .arg("--bind") + .arg(&bind) + .output() + .expect("sidemantic-server should run to a startup error"); + + assert!(!output.status.success(), "missing models path should fail"); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("models path") && stderr.contains("is not a readable file or directory"), + "{stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn http_server_exercises_real_endpoints_and_errors() { + let dir = unique_temp_dir("sidemantic_http_server"); + let models_path = common::write_retail_fixture(&dir); + let bind = free_loopback_addr(); + + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-server")); + command + .arg("--models") + .arg(&models_path) + .arg("--bind") + .arg(&bind); + let mut child = common::ChildGuard::new( + command + .spawn() + .expect("sidemantic-server child should spawn"), + ); + wait_for_server(&bind); + + let (status, body, _) = http_request(&bind, "GET", "/health", None); + assert_eq!(status, 200); + assert_eq!(body["status"], "ok"); + assert_eq!(body["model_count"], 2); + + let (status, body, _) = http_request(&bind, "GET", "/readyz", None); + assert_eq!(status, 200); + assert_eq!(body, json!({ "status": "ok" })); + + let (status, body, _) = http_request(&bind, "GET", "/models", None); + assert_eq!(status, 200); + assert!(body + .as_array() + .expect("models response should be an array") + .iter() + .any(|model| model["name"] == "orders")); + + let (status, body, _) = http_request(&bind, "GET", "/graph", None); + assert_eq!(status, 200); + assert!(body["models"] + .as_array() + .expect("graph models should be an array") + .iter() + .any(|model| model["name"] == "orders")); + assert!(body["joinable_pairs"] + .as_array() + .expect("joinable pairs should be an array") + .iter() + .any(|pair| pair["from"] == "orders" && pair["to"] == "customers")); + + let (status, body, _) = http_request(&bind, "GET", "/models/orders", None); + assert_eq!(status, 200); + assert_eq!(body["name"], "orders"); + assert!(body["relationships"].to_string().contains("customers")); + + let (status, body, _) = http_request(&bind, "GET", "/models/missing", None); + assert_eq!(status, 404); + assert_eq!(body["error"], "model not found: missing"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/models", + Some(json!({ "model_names": ["orders"] })), + ); + assert_eq!(status, 200); + assert_eq!( + body.as_array() + .expect("filtered models should be array") + .len(), + 1 + ); + assert_eq!(body[0]["name"], "orders"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/query/compile", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "limit": 5 + })), + ); + assert_eq!(status, 200); + let sql = body["sql"] + .as_str() + .expect("compile response should include sql"); + assert!(sql.contains("SUM"), "{sql}"); + assert!(sql.contains("GROUP BY"), "{sql}"); + assert!(sql.contains("LIMIT 5"), "{sql}"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/compile", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "filters": ["orders.status = 'complete'"], + "where": "orders.customer_id > 0", + "order_by": ["orders.revenue desc"], + "limit": 5, + "offset": 2, + "ungrouped": false + })), + ); + assert_eq!(status, 200); + let sql = body["sql"] + .as_str() + .expect("compile alias response should include sql"); + assert!(sql.contains("WHERE"), "{sql}"); + assert!( + sql.to_ascii_uppercase().contains("ORDER BY REVENUE DESC"), + "{sql}" + ); + assert!(sql.contains("LIMIT 5"), "{sql}"); + assert!(sql.contains("OFFSET 2"), "{sql}"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/compile", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "segments": ["orders.completed"], + "filters": ["orders.status = {{ status_param }}"], + "parameters": { "status_param": "complete" }, + "order_by": ["orders.status asc"], + "use_preaggregations": true, + "limit": 3 + })), + ); + assert_eq!(status, 200); + let sql = body["sql"] + .as_str() + .expect("compile response should include segment/parameter sql"); + assert!(sql.contains("status = 'complete'"), "{sql}"); + assert!( + sql.to_ascii_uppercase().contains("ORDER BY STATUS ASC"), + "{sql}" + ); + assert!(sql.contains("LIMIT 3"), "{sql}"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/compile", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "ungrouped": true + })), + ); + assert_eq!(status, 200); + let sql = body["sql"] + .as_str() + .expect("ungrouped compile response should include sql"); + assert!(!sql.to_ascii_uppercase().contains("GROUP BY"), "{sql}"); + + let (status, body, _) = http_request( + &bind, + "POST", + "/query/compile", + Some(json!({ "metrics": ["orders.unknown_metric"] })), + ); + assert_eq!(status, 400); + assert!(body["error"] + .as_str() + .unwrap_or("") + .contains("failed to compile query")); + + let (status, body, _) = http_request( + &bind, + "POST", + "/query", + Some(json!({ "metrics": ["orders.revenue"] })), + ); + assert_eq!(status, 400); + assert!(body["error"] + .as_str() + .unwrap_or("") + .contains("runtime-server-adbc")); + + let (status, body, _) = http_request( + &bind, + "POST", + "/query?format=xml", + Some(json!({ "metrics": ["orders.revenue"] })), + ); + assert_eq!(status, 400); + assert!(body["error"] + .as_str() + .unwrap_or("") + .contains("unsupported response format 'xml'")); + + let (status, body, _) = http_request( + &bind, + "POST", + "/sql/compile", + Some(json!({ "query": "select orders.status, orders.revenue from orders" })), + ); + assert_eq!(status, 200); + assert!(body["sql"].as_str().unwrap_or("").contains("orders_cte")); + + let (status, body, _) = http_request( + &bind, + "POST", + "/raw", + Some(json!({ "query": "delete from orders" })), + ); + assert_eq!(status, 400); + assert!(body["error"].as_str().unwrap_or("").contains("SELECT")); + + child.kill_and_wait(); + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn http_server_enforces_auth_cors_and_body_limit() { + let dir = unique_temp_dir("sidemantic_http_server_controls"); + let models_path = common::write_retail_fixture(&dir); + let bind = free_loopback_addr(); + + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-server")); + command + .arg("--models") + .arg(&models_path) + .arg("--bind") + .arg(&bind) + .arg("--auth-token") + .arg("secret") + .arg("--cors-origin") + .arg("https://app.example.com") + .arg("--max-request-body-bytes") + .arg("32"); + let mut child = common::ChildGuard::new( + command + .spawn() + .expect("sidemantic-server child should spawn"), + ); + wait_for_server(&bind); + + let (status, body, _) = http_request(&bind, "GET", "/readyz", None); + assert_eq!(status, 200); + assert_eq!(body, json!({ "status": "ok" })); + + let (status, body, response) = http_request(&bind, "GET", "/health", None); + assert_eq!(status, 401); + assert_eq!(body["error"], "Unauthorized"); + assert!(response.to_ascii_lowercase().contains("www-authenticate")); + + let (status, body, response) = http_request_with_headers( + &bind, + "GET", + "/health", + None, + &[ + ("Authorization", "Bearer secret"), + ("Origin", "https://app.example.com"), + ], + ); + assert_eq!(status, 200); + assert_eq!(body["status"], "ok"); + assert!(response + .to_ascii_lowercase() + .contains("access-control-allow-origin: https://app.example.com")); + + let (status, body, _) = http_request_with_headers( + &bind, + "POST", + "/compile", + Some(json!({ + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "filters": ["orders.status = 'complete'"] + })), + &[("Authorization", "Bearer secret")], + ); + assert_eq!(status, 413); + assert!(body["error"] + .as_str() + .unwrap_or("") + .contains("Request body exceeds 32 bytes")); + + child.kill_and_wait(); + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn http_server_interpolates_query_parameters() { + let dir = unique_temp_dir("sidemantic_http_server_parameters"); + let models_path = dir.join("models.yml"); + fs::write( + &models_path, + r#" +parameters: + - name: status + type: string +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +"#, + ) + .expect("parameter fixture should be written"); + let bind = free_loopback_addr(); + + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-server")); + command + .arg("--models") + .arg(&models_path) + .arg("--bind") + .arg(&bind); + let mut child = common::ChildGuard::new( + command + .spawn() + .expect("sidemantic-server child should spawn"), + ); + wait_for_server(&bind); + + let (status, body, _) = http_request( + &bind, + "POST", + "/compile", + Some(json!({ + "metrics": ["orders.revenue"], + "filters": ["orders.status = {{ status }}"], + "parameters": { "status": "complete" } + })), + ); + assert_eq!(status, 200, "{body}"); + let sql = body["sql"] + .as_str() + .expect("compile response should include sql"); + assert!(sql.contains("status = 'complete'"), "{sql}"); + + child.kill_and_wait(); + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} diff --git a/sidemantic-rs/tests/lsp_protocol_smoke.py b/sidemantic-rs/tests/lsp_protocol_smoke.py new file mode 100644 index 00000000..7e023156 --- /dev/null +++ b/sidemantic-rs/tests/lsp_protocol_smoke.py @@ -0,0 +1,355 @@ +import json +import queue +import subprocess +import sys +import threading +import time + + +def send_frame(proc, message): + body = json.dumps(message, separators=(",", ":")).encode() + proc.stdin.write(f"Content-Length: {len(body)}\r\n\r\n".encode()) + proc.stdin.write(body) + proc.stdin.flush() + + +def read_frames(proc, messages): + stream = proc.stdout + while True: + content_length = None + while True: + line = stream.readline() + if not line: + return + if line in (b"\r\n", b"\n"): + break + header, _, value = line.decode().partition(":") + if header.lower() == "content-length": + content_length = int(value.strip()) + if content_length is None: + continue + body = stream.read(content_length) + if not body: + return + messages.put(json.loads(body)) + + +class LspClient: + def __init__(self, binary): + self.proc = subprocess.Popen( + [binary], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + self.messages = queue.Queue() + self.reader = threading.Thread(target=read_frames, args=(self.proc, self.messages), daemon=True) + self.reader.start() + + def close(self): + if self.proc.poll() is None: + self.proc.kill() + self.proc.wait(timeout=5) + + def request(self, request_id, method, params): + message = {"jsonrpc": "2.0", "id": request_id, "method": method} + if params is not None: + message["params"] = params + send_frame(self.proc, message) + return self.wait_for(lambda msg: msg.get("id") == request_id, f"id {request_id}") + + def notify(self, method, params): + message = {"jsonrpc": "2.0", "method": method} + if params is not None: + message["params"] = params + send_frame(self.proc, message) + + def wait_for_notification(self, method): + return self.wait_for(lambda msg: msg.get("method") == method, method) + + def wait_for(self, predicate, label): + deadline = time.time() + 5 + seen = [] + while time.time() < deadline: + try: + msg = self.messages.get(timeout=0.25) + except queue.Empty: + if self.proc.poll() is not None: + stderr = self.proc.stderr.read().decode(errors="replace") + raise AssertionError( + f"LSP process exited while waiting for {label}; stderr={stderr!r}; seen={seen!r}" + ) + continue + if predicate(msg): + return msg + seen.append(msg) + raise AssertionError(f"timed out waiting for {label}; seen={seen!r}") + + +def completion_labels(response): + return [item["label"] for item in response["result"]] + + +def hover_value(response): + contents = response["result"]["contents"] + assert contents["kind"] == "markdown", response + return contents["value"] + + +def main(): + if len(sys.argv) != 2: + raise SystemExit("usage: lsp_protocol_smoke.py ") + client = LspClient(sys.argv[1]) + uri = "file:///tmp/sidemantic-lsp-smoke.semantic.sql" + try: + initialized = client.request( + 1, + "initialize", + { + "processId": None, + "rootUri": None, + "capabilities": {"textDocument": {"publishDiagnostics": {}}}, + }, + ) + capabilities = initialized["result"]["capabilities"] + assert capabilities["textDocumentSync"] == 1, capabilities + assert isinstance(capabilities["completionProvider"], dict), capabilities + assert capabilities["hoverProvider"] is True, capabilities + assert capabilities["documentFormattingProvider"] is True, capabilities + assert capabilities["documentSymbolProvider"] is True, capabilities + assert capabilities["definitionProvider"] is True, capabilities + assert capabilities["referencesProvider"] is True, capabilities + assert capabilities["renameProvider"] is True, capabilities + assert isinstance(capabilities["signatureHelpProvider"], dict), capabilities + assert capabilities["codeActionProvider"] is True, capabilities + client.notify("initialized", {}) + client.wait_for_notification("window/logMessage") + + client.notify( + "textDocument/didOpen", + { + "textDocument": { + "uri": uri, + "languageId": "sidemantic", + "version": 1, + "text": "MODEL (name orders, table orders, primary_key order_id);", + } + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + assert diagnostics["params"]["uri"] == uri, diagnostics + assert diagnostics["params"]["diagnostics"] == [], diagnostics + + keyword_hover = client.request( + 2, + "textDocument/hover", + {"textDocument": {"uri": uri}, "position": {"line": 0, "character": 2}}, + ) + assert "Top-level model definition" in hover_value(keyword_hover), keyword_hover + + property_hover = client.request( + 3, + "textDocument/hover", + {"textDocument": {"uri": uri}, "position": {"line": 0, "character": 9}}, + ) + assert "Unique model name" in hover_value(property_hover), property_hover + + symbols = client.request( + 9, + "textDocument/documentSymbol", + {"textDocument": {"uri": uri}}, + ) + assert [symbol["name"] for symbol in symbols["result"]] == ["orders"], symbols + + signature = client.request( + 10, + "textDocument/signatureHelp", + {"textDocument": {"uri": uri}, "position": {"line": 0, "character": 2}}, + ) + assert signature["result"]["signatures"][0]["label"].startswith("MODEL("), signature + + formatting = client.request( + 11, + "textDocument/formatting", + {"textDocument": {"uri": uri}, "options": {"tabSize": 4, "insertSpaces": True}}, + ) + assert formatting["result"], formatting + assert "MODEL (\n name orders," in formatting["result"][0]["newText"], formatting + + client.notify( + "textDocument/didChange", + { + "textDocument": {"uri": uri, "version": 2}, + "contentChanges": [{"text": "MODEL ("}], + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + diagnostic = diagnostics["params"]["diagnostics"][0] + assert diagnostic["severity"] == 1, diagnostic + assert diagnostic["source"] == "sidemantic-rs", diagnostic + assert diagnostic["message"].startswith("Parse error:"), diagnostic + + client.notify("textDocument/didSave", {"textDocument": {"uri": uri}}) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + diagnostic = diagnostics["params"]["diagnostics"][0] + assert diagnostic["message"].startswith("Parse error:"), diagnostic + + client.notify( + "textDocument/didChange", + { + "textDocument": {"uri": uri, "version": 3}, + "contentChanges": [{"text": "MODEL (name orders, table orders, primary_key order_id);\n\n"}], + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + assert diagnostics["params"]["diagnostics"] == [], diagnostics + + rich_text = ( + "MODEL (name orders, table order_items, primary_key order_id);\n\n" + "METRIC (name revenue, model orders, sql amount, agg sum);\n" + ) + client.notify( + "textDocument/didChange", + { + "textDocument": {"uri": uri, "version": 31}, + "contentChanges": [{"text": rich_text}], + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + assert diagnostics["params"]["diagnostics"] == [], diagnostics + + definition = client.request( + 12, + "textDocument/definition", + {"textDocument": {"uri": uri}, "position": {"line": 2, "character": 28}}, + ) + assert definition["result"]["range"]["start"]["line"] == 0, definition + + refs = client.request( + 13, + "textDocument/references", + { + "textDocument": {"uri": uri}, + "position": {"line": 0, "character": 12}, + "context": {"includeDeclaration": False}, + }, + ) + assert len(refs["result"]) == 1, refs + assert refs["result"][0]["range"]["start"]["line"] == 2, refs + + rename = client.request( + 14, + "textDocument/rename", + { + "textDocument": {"uri": uri}, + "position": {"line": 0, "character": 12}, + "newName": "sales_orders", + }, + ) + edits = rename["result"]["changes"][uri] + assert len(edits) == 2, rename + assert all(edit["newText"] == "sales_orders" for edit in edits), rename + + client.notify( + "textDocument/didChange", + { + "textDocument": {"uri": uri, "version": 32}, + "contentChanges": [{"text": "MODEL (name orders, table orders, primary_key order_id);\n\n"}], + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + assert diagnostics["params"]["diagnostics"] == [], diagnostics + + top_level = client.request( + 4, + "textDocument/completion", + {"textDocument": {"uri": uri}, "position": {"line": 2, "character": 0}}, + ) + labels = completion_labels(top_level) + for expected in ["MODEL", "DIMENSION", "METRIC", "RELATIONSHIP", "SEGMENT"]: + assert expected in labels, labels + + client.notify( + "textDocument/didChange", + { + "textDocument": {"uri": uri, "version": 4}, + "contentChanges": [{"text": "MODEL (\n \n);"}], + }, + ) + client.wait_for_notification("textDocument/publishDiagnostics") + inside_model = client.request( + 5, + "textDocument/completion", + {"textDocument": {"uri": uri}, "position": {"line": 1, "character": 4}}, + ) + labels = completion_labels(inside_model) + for expected in ["name", "table", "primary_key", "default_time_dimension", "sql"]: + assert expected in labels, labels + + code_actions = client.request( + 15, + "textDocument/codeAction", + { + "textDocument": {"uri": uri}, + "range": {"start": {"line": 0, "character": 0}, "end": {"line": 1, "character": 0}}, + "context": { + "diagnostics": [ + { + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 0, "character": 10}, + }, + "message": "name field required", + "severity": 1, + "source": "sidemantic-rs", + } + ] + }, + }, + ) + assert code_actions["result"][0]["title"] == "Add missing name property", code_actions + + unopened = client.request( + 6, + "textDocument/completion", + { + "textDocument": {"uri": "file:///tmp/unopened.semantic.sql"}, + "position": {"line": 0, "character": 0}, + }, + ) + assert unopened["result"] == [], unopened + + yaml_uri = "file:///tmp/sidemantic-lsp-smoke.yml" + client.notify( + "textDocument/didOpen", + { + "textDocument": { + "uri": yaml_uri, + "languageId": "yaml", + "version": 1, + "text": "models:\n - name: orders\n", + } + }, + ) + diagnostics = client.wait_for_notification("textDocument/publishDiagnostics") + assert diagnostics["params"]["uri"] == yaml_uri, diagnostics + diagnostic = diagnostics["params"]["diagnostics"][0] + assert diagnostic["message"].startswith("Parse error:"), diagnostic + + invalid_method = client.request(7, "sidemantic/unsupported", {}) + assert invalid_method["error"]["code"] == -32601, invalid_method + + shutdown = client.request(8, "shutdown", None) + assert "result" in shutdown, shutdown + client.notify("exit", None) + client.proc.stdin.close() + deadline = time.time() + 5 + while time.time() < deadline and client.proc.poll() is None: + time.sleep(0.05) + assert client.proc.poll() is not None, "LSP process did not exit after shutdown" + finally: + client.close() + + +if __name__ == "__main__": + main() diff --git a/sidemantic-rs/tests/lsp_protocol_smoke.rs b/sidemantic-rs/tests/lsp_protocol_smoke.rs new file mode 100644 index 00000000..c284a666 --- /dev/null +++ b/sidemantic-rs/tests/lsp_protocol_smoke.rs @@ -0,0 +1,24 @@ +#![cfg(feature = "runtime-lsp")] + +use std::path::PathBuf; +use std::process::Command; + +#[test] +fn lsp_stdio_protocol_exercises_diagnostics_completion_and_shutdown() { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let script = manifest_dir.join("tests/lsp_protocol_smoke.py"); + let output = Command::new("uv") + .arg("run") + .arg("--no-project") + .arg(&script) + .arg(env!("CARGO_BIN_EXE_sidemantic-lsp")) + .output() + .expect("uv should run the LSP protocol smoke harness"); + + assert!( + output.status.success(), + "LSP protocol smoke failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); +} diff --git a/sidemantic-rs/tests/mcp_protocol.rs b/sidemantic-rs/tests/mcp_protocol.rs new file mode 100644 index 00000000..af4e3b85 --- /dev/null +++ b/sidemantic-rs/tests/mcp_protocol.rs @@ -0,0 +1,473 @@ +#![cfg(feature = "mcp-server")] + +mod common; + +use std::fs; +use std::io::{BufRead, BufReader, Write}; +use std::process::{ChildStdin, Command, Stdio}; +use std::sync::mpsc::{self, Receiver}; +use std::thread; +use std::time::Duration; + +use common::{command_with_clean_sidemantic_env, unique_temp_dir, ChildGuard}; +use serde_json::{json, Value}; + +struct McpClient { + child: ChildGuard, + stdin: ChildStdin, + responses: Receiver, +} + +impl McpClient { + fn spawn(models_path: &std::path::Path) -> Self { + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-mcp")); + command + .arg("--models") + .arg(models_path) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()); + Self::spawn_command(command) + } + + fn spawn_positional(models_path: &std::path::Path) -> Self { + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-mcp")); + command + .arg(models_path) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()); + Self::spawn_command(command) + } + + fn spawn_command(mut command: Command) -> Self { + let mut child = command.spawn().expect("sidemantic-mcp should spawn"); + let stdin = child.stdin.take().expect("child stdin should be piped"); + let stdout = child.stdout.take().expect("child stdout should be piped"); + let (tx, rx) = mpsc::channel(); + thread::spawn(move || { + let reader = BufReader::new(stdout); + for line in reader.lines() { + let line = match line { + Ok(line) => line, + Err(_) => break, + }; + if line.trim().is_empty() { + continue; + } + if let Ok(value) = serde_json::from_str::(&line) { + let _ = tx.send(value); + } + } + }); + + Self { + child: ChildGuard::new(child), + stdin, + responses: rx, + } + } + + fn send(&mut self, message: Value) { + writeln!(self.stdin, "{message}").expect("mcp message should be written"); + self.stdin.flush().expect("mcp stdin should flush"); + } + + fn request(&mut self, id: u64, method: &str, params: Value) -> Value { + self.send(json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params + })); + self.read_response(id) + } + + fn notify(&mut self, method: &str, params: Value) { + self.send(json!({ + "jsonrpc": "2.0", + "method": method, + "params": params + })); + } + + fn read_response(&mut self, id: u64) -> Value { + for _ in 0..20 { + let value = self + .responses + .recv_timeout(Duration::from_secs(5)) + .expect("mcp response should arrive"); + if value["id"] == json!(id) { + return value; + } + } + panic!("mcp response id {id} was not observed"); + } + + fn shutdown(mut self) { + self.child.kill_and_wait(); + } +} + +fn structured_content(response: &Value) -> Value { + if let Some(value) = response["result"]["structuredContent"].as_object() { + return Value::Object(value.clone()); + } + let text = response["result"]["content"][0]["text"] + .as_str() + .expect("tool response should include text content"); + serde_json::from_str(text).expect("tool text content should be JSON") +} + +#[test] +fn mcp_server_accepts_positional_models_path() { + let dir = unique_temp_dir("sidemantic_mcp_positional_models"); + let models_path = common::write_retail_fixture(&dir); + let mut client = McpClient::spawn_positional(&models_path); + + let init = client.request( + 1, + "initialize", + json!({ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "sidemantic-test", "version": "0.0.0" } + }), + ); + assert_eq!(init["jsonrpc"], "2.0"); + + client.shutdown(); + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn mcp_server_fails_fast_for_missing_models_path() { + let dir = unique_temp_dir("sidemantic_mcp_bad_models"); + let missing = dir.join("missing.yml"); + + let mut command = command_with_clean_sidemantic_env(env!("CARGO_BIN_EXE_sidemantic-mcp")); + let output = command + .arg("--models") + .arg(&missing) + .output() + .expect("sidemantic-mcp should run to a startup error"); + + assert!(!output.status.success(), "missing models path should fail"); + let stderr = String::from_utf8_lossy(&output.stderr); + assert!( + stderr.contains("models path") && stderr.contains("is not a readable file or directory"), + "{stderr}" + ); + + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} + +#[test] +fn mcp_server_exercises_tool_protocol_and_errors() { + let dir = unique_temp_dir("sidemantic_mcp_protocol"); + let models_path = common::write_retail_fixture(&dir); + let mut client = McpClient::spawn(&models_path); + + let init = client.request( + 1, + "initialize", + json!({ + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "sidemantic-test", "version": "0.0.0" } + }), + ); + assert_eq!(init["jsonrpc"], "2.0"); + assert!(init["result"]["capabilities"].to_string().contains("tools")); + assert!(init["result"]["capabilities"] + .to_string() + .contains("resources")); + client.notify("notifications/initialized", json!({})); + + let tools = client.request(2, "tools/list", json!({})); + let tool_names = tools["result"]["tools"] + .as_array() + .expect("tools/list should return tools") + .iter() + .map(|tool| tool["name"].as_str().unwrap_or_default()) + .collect::>(); + for expected in [ + "list_models", + "get_models", + "get_semantic_graph", + "validate_query", + "compile_query", + "run_query", + "run_sql", + "create_chart", + ] { + assert!( + tool_names.contains(&expected), + "missing tool {expected}: {tools}" + ); + } + + let list = client.request( + 3, + "tools/call", + json!({ "name": "list_models", "arguments": {} }), + ); + let list_payload = structured_content(&list); + assert!(list_payload["models"] + .as_array() + .expect("list_models should return model array") + .iter() + .any(|model| model["name"] == "orders")); + + let details = client.request( + 4, + "tools/call", + json!({ "name": "get_models", "arguments": { "model_names": ["orders"] } }), + ); + let details_payload = structured_content(&details); + assert_eq!(details_payload["models"][0]["name"], "orders"); + assert!(details_payload["models"][0]["relationships"] + .to_string() + .contains("customers")); + + let graph = client.request( + 9, + "tools/call", + json!({ "name": "get_semantic_graph", "arguments": {} }), + ); + let graph_payload = structured_content(&graph); + assert!(graph_payload["models"] + .as_array() + .expect("graph should include model array") + .iter() + .any(|model| model["name"] == "orders")); + assert!(graph_payload["joinable_pairs"] + .as_array() + .expect("graph should include joinable pairs") + .iter() + .any(|pair| pair["from"] == "orders" && pair["to"] == "customers")); + + let compiled = client.request( + 5, + "tools/call", + json!({ + "name": "compile_query", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "filters": ["orders.status = {{ status_param }}"], + "parameters": { "status_param": "complete" }, + "segments": ["orders.completed"], + "order_by": ["orders.revenue desc"], + "use_preaggregations": true, + "limit": 5 + } + }), + ); + let compiled_payload = structured_content(&compiled); + let sql = compiled_payload["sql"] + .as_str() + .expect("compile tool should return sql"); + assert!(sql.contains("SUM"), "{sql}"); + assert!(sql.contains("status = 'complete'"), "{sql}"); + assert!(sql.contains("GROUP BY"), "{sql}"); + assert!( + sql.to_ascii_uppercase().contains("ORDER BY REVENUE DESC"), + "{sql}" + ); + assert!(sql.contains("LIMIT 5"), "{sql}"); + + let dry_run = client.request( + 10, + "tools/call", + json!({ + "name": "run_query", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "where": "orders.customer_id > 0", + "filters": ["orders.status = {{ status_param }}"], + "parameters": { "status_param": "complete" }, + "offset": 2, + "dry_run": true + } + }), + ); + let dry_run_payload = structured_content(&dry_run); + let sql = dry_run_payload["sql"] + .as_str() + .expect("dry run should return sql"); + assert!(sql.contains("WHERE"), "{sql}"); + assert!(sql.contains("status = 'complete'"), "{sql}"); + assert!(sql.contains("OFFSET 2"), "{sql}"); + + let ungrouped = client.request( + 18, + "tools/call", + json!({ + "name": "compile_query", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "ungrouped": true + } + }), + ); + let ungrouped_payload = structured_content(&ungrouped); + let sql = ungrouped_payload["sql"] + .as_str() + .expect("ungrouped compile tool should return sql"); + assert!(!sql.to_ascii_uppercase().contains("GROUP BY"), "{sql}"); + + let valid = client.request( + 11, + "tools/call", + json!({ + "name": "validate_query", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + } + }), + ); + let valid_payload = structured_content(&valid); + assert_eq!(valid_payload["valid"], true); + assert_eq!( + valid_payload["errors"] + .as_array() + .expect("valid errors should be an array") + .len(), + 0 + ); + + let invalid = client.request( + 12, + "tools/call", + json!({ + "name": "validate_query", + "arguments": { "metrics": ["orders.unknown_metric"] } + }), + ); + let invalid_payload = structured_content(&invalid); + assert_eq!(invalid_payload["valid"], false); + assert!(invalid_payload["errors"] + .to_string() + .contains("unknown_metric")); + + let invalid_metric = client.request( + 6, + "tools/call", + json!({ + "name": "compile_query", + "arguments": { "metrics": ["orders.unknown_metric"] } + }), + ); + assert_eq!(invalid_metric["error"]["code"], -32602); + assert!(invalid_metric["error"]["message"] + .as_str() + .unwrap_or("") + .contains("failed to compile query")); + + let missing_adbc = client.request( + 7, + "tools/call", + json!({ + "name": "run_query", + "arguments": { "metrics": ["orders.revenue"] } + }), + ); + assert_eq!(missing_adbc["error"]["code"], -32602); + assert!(missing_adbc["error"]["message"] + .as_str() + .unwrap_or("") + .contains("mcp-adbc")); + + let missing_adbc_sql = client.request( + 13, + "tools/call", + json!({ + "name": "run_sql", + "arguments": { "query": "select orders.status, orders.revenue from orders" } + }), + ); + assert_eq!(missing_adbc_sql["error"]["code"], -32602); + assert!(missing_adbc_sql["error"]["message"] + .as_str() + .unwrap_or("") + .contains("mcp-adbc")); + + let missing_adbc_chart = client.request( + 14, + "tools/call", + json!({ + "name": "create_chart", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"] + } + }), + ); + assert_eq!(missing_adbc_chart["error"]["code"], -32602); + assert!(missing_adbc_chart["error"]["message"] + .as_str() + .unwrap_or("") + .contains("mcp-adbc")); + + let chart_parameter_error = client.request( + 19, + "tools/call", + json!({ + "name": "create_chart", + "arguments": { + "dimensions": ["orders.status"], + "metrics": ["orders.revenue"], + "filters": ["orders.customer_id > {{ min_customer_id }}"], + "parameters": { "min_customer_id": "not-a-number" } + } + }), + ); + assert_eq!(chart_parameter_error["error"]["code"], -32602); + assert!(chart_parameter_error["error"]["message"] + .as_str() + .unwrap_or("") + .contains("failed to interpolate query parameters")); + + let resources = client.request(15, "resources/list", json!({})); + let resource_items = resources["result"]["resources"] + .as_array() + .expect("resources/list should return resources"); + assert!(resource_items + .iter() + .any(|resource| resource["uri"] == "semantic://catalog")); + + let catalog = client.request(16, "resources/read", json!({ "uri": "semantic://catalog" })); + let catalog_text = catalog["result"]["contents"][0]["text"] + .as_str() + .expect("catalog resource should return text"); + let catalog_payload: Value = + serde_json::from_str(catalog_text).expect("catalog text should be JSON"); + assert!(catalog_payload + .get("tables") + .and_then(Value::as_array) + .expect("catalog should include table metadata") + .iter() + .any(|table| table["table_name"] == "orders")); + assert!(catalog_payload + .get("columns") + .and_then(Value::as_array) + .expect("catalog should include column metadata") + .iter() + .any(|column| column["table_name"] == "orders" && column["column_name"] == "revenue")); + + let missing_resource = + client.request(17, "resources/read", json!({ "uri": "semantic://missing" })); + assert_eq!(missing_resource["error"]["code"], -32002); + + let unknown_tool = client.request( + 8, + "tools/call", + json!({ "name": "does_not_exist", "arguments": {} }), + ); + assert!(unknown_tool.get("error").is_some(), "{unknown_tool}"); + + client.shutdown(); + fs::remove_dir_all(&dir).expect("temp dir should be removed"); +} diff --git a/sidemantic-rs/tests/package_metadata.rs b/sidemantic-rs/tests/package_metadata.rs new file mode 100644 index 00000000..634970f4 --- /dev/null +++ b/sidemantic-rs/tests/package_metadata.rs @@ -0,0 +1,50 @@ +fn toml_string_value(contents: &str, key: &str) -> Option { + let prefix = format!("{key} = "); + contents.lines().find_map(|line| { + let value = line.trim().strip_prefix(&prefix)?; + Some(value.trim().trim_matches('"').to_string()) + }) +} + +#[test] +fn rust_crate_and_python_extension_versions_match() { + let pyproject = include_str!("../pyproject.toml"); + let pyproject_version = + toml_string_value(pyproject, "version").expect("pyproject.toml project.version"); + + assert_eq!(env!("CARGO_PKG_VERSION"), pyproject_version); +} + +#[test] +fn python_extension_metadata_targets_the_expected_module_and_feature() { + let pyproject = include_str!("../pyproject.toml"); + + assert!(pyproject.contains("name = \"sidemantic-rs\"")); + assert!(pyproject.contains("module-name = \"sidemantic_rs\"")); + assert!(pyproject.contains("features = [\"python-adbc\"]")); + assert!(pyproject.contains("license = \"AGPL-3.0-only\"")); +} + +#[test] +fn library_crate_types_cover_rust_c_abi_and_python_wasm_artifacts() { + let cargo_toml = include_str!("../Cargo.toml"); + + assert!(cargo_toml.contains("crate-type = [\"rlib\", \"staticlib\", \"cdylib\"]")); +} + +#[test] +fn cargo_metadata_and_feature_split_are_explicit() { + let cargo_toml = include_str!("../Cargo.toml"); + + for expected in [ + "license = \"AGPL-3.0-only\"", + "repository = \"https://github.com/sidequery/sidemantic\"", + "homepage = \"https://sidemantic.com\"", + "python-adbc = [\"python\", \"adbc-exec\"]", + "mcp-adbc = [\"mcp-server\", \"adbc-exec\"]", + "runtime-server-adbc = [\"runtime-server\", \"adbc-exec\", \"dep:tokio-stream\"]", + "workbench-adbc = [\"workbench-tui\", \"adbc-exec\"]", + ] { + assert!(cargo_toml.contains(expected), "missing {expected}"); + } +} diff --git a/sidemantic-rs/tests/python_wheel_adbc_smoke.py b/sidemantic-rs/tests/python_wheel_adbc_smoke.py new file mode 100644 index 00000000..fc0b02c8 --- /dev/null +++ b/sidemantic-rs/tests/python_wheel_adbc_smoke.py @@ -0,0 +1,40 @@ +import os + +import sidemantic_rs + +execute_with_adbc = getattr(sidemantic_rs, "execute_with_adbc", None) +if not callable(execute_with_adbc): + raise AssertionError("python-adbc wheel should expose execute_with_adbc") + +try: + execute_with_adbc("adbc_driver_missing_for_sidemantic_smoke", "select 1") +except RuntimeError as exc: + message = str(exc) + if "rust ADBC execution failed" not in message: + raise AssertionError(f"unexpected ADBC error: {message}") from exc +else: + raise AssertionError("missing test driver should fail through the Rust ADBC execution path") + +duckdb_driver = os.environ.get("SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER") +if duckdb_driver: + result = execute_with_adbc( + duckdb_driver, + "select 42 as answer", + entrypoint="duckdb_adbc_init", + db_kwargs={"path": ":memory:"}, + ) + if result.get("columns") != ["answer"]: + raise AssertionError(f"unexpected DuckDB columns: {result!r}") + if result.get("rows") != [(42,)]: + raise AssertionError(f"unexpected DuckDB rows: {result!r}") + + uri_result = execute_with_adbc( + duckdb_driver, + "select 7 as answer", + uri=":memory:", + entrypoint="duckdb_adbc_init", + ) + if uri_result.get("columns") != ["answer"]: + raise AssertionError(f"unexpected DuckDB URI columns: {uri_result!r}") + if uri_result.get("rows") != [(7,)]: + raise AssertionError(f"unexpected DuckDB URI rows: {uri_result!r}") diff --git a/sidemantic-rs/tests/python_wheel_python_smoke.py b/sidemantic-rs/tests/python_wheel_python_smoke.py new file mode 100644 index 00000000..817ca03f --- /dev/null +++ b/sidemantic-rs/tests/python_wheel_python_smoke.py @@ -0,0 +1,51 @@ +# /// script +# requires-python = ">=3.11" +# /// + +import importlib.metadata +import importlib.util + +import sidemantic_rs + +root_python_package = importlib.util.find_spec("sidemantic") +if root_python_package is not None: + raise AssertionError("isolated sidemantic_rs python-feature wheel unexpectedly found root sidemantic package") + +if importlib.metadata.version("sidemantic-rs") != "0.1.0": + raise AssertionError("unexpected sidemantic-rs wheel version") + +models_yaml = """ +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +""" + +query_yaml = """ +metrics: [orders.revenue] +dimensions: [orders.status] +""" + +compiled = sidemantic_rs.compile_with_yaml(models_yaml, query_yaml) +if "SUM(" not in compiled or "GROUP BY" not in compiled: + raise AssertionError(f"unexpected compiled SQL: {compiled}") + +execute_with_adbc = getattr(sidemantic_rs, "execute_with_adbc", None) +if not callable(execute_with_adbc): + raise AssertionError("python-feature wheel should expose execute_with_adbc disabled stub") + +try: + execute_with_adbc("adbc_driver_duckdb", "select 1") +except RuntimeError as exc: + message = str(exc) + if "python-adbc" not in message or "not enabled" not in message: + raise AssertionError(f"unexpected disabled ADBC error: {message}") from exc +else: + raise AssertionError("python-feature execute_with_adbc should fail with a feature guidance error") diff --git a/sidemantic-rs/tests/python_wheel_smoke.py b/sidemantic-rs/tests/python_wheel_smoke.py new file mode 100644 index 00000000..74eeeda4 --- /dev/null +++ b/sidemantic-rs/tests/python_wheel_smoke.py @@ -0,0 +1,189 @@ +# /// script +# requires-python = ">=3.11" +# /// + +import importlib.metadata +import importlib.util +import json + +import sidemantic_rs + + +def assert_contains(text: str, needle: str) -> None: + if needle not in text: + raise AssertionError(f"expected {needle!r} in {text!r}") + + +def expect_raises(exc_type: type[BaseException], func, *args) -> str: + try: + func(*args) + except exc_type as exc: + return str(exc) + raise AssertionError(f"expected {exc_type.__name__} from {func.__name__}") + + +root_python_package = importlib.util.find_spec("sidemantic") +if root_python_package is not None: + raise AssertionError("isolated sidemantic_rs wheel smoke unexpectedly found root sidemantic package") + +if importlib.metadata.version("sidemantic-rs") != "0.1.0": + raise AssertionError("unexpected sidemantic-rs wheel version") + +models_yaml = """ +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: customer_id + type: numeric + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count +""" + +query_yaml = """ +metrics: [revenue_per_order] +dimensions: [orders.status] +order_by: [revenue_per_order DESC] +limit: 5 +""" + +compiled = sidemantic_rs.compile_with_yaml(models_yaml, query_yaml) +assert_contains(compiled, "SUM(") +assert_contains(compiled, "COUNT(") +assert_contains(compiled, "ORDER BY") +assert_contains(compiled, "LIMIT 5") + +rewritten = sidemantic_rs.rewrite_with_yaml( + models_yaml, + "SELECT orders.revenue, orders.status FROM orders ORDER BY orders.revenue DESC LIMIT 3", +) +assert_contains(rewritten, "SUM(") +assert_contains(rewritten, "ORDER BY") +assert_contains(rewritten, "LIMIT 3") + +payload = json.loads(sidemantic_rs.load_graph_with_yaml(models_yaml)) +assert "orders" in json.dumps(payload) +assert "revenue_per_order" in json.dumps(payload) + +errors = sidemantic_rs.validate_query_with_yaml(models_yaml, query_yaml) +if errors: + raise AssertionError(f"unexpected validation errors: {errors!r}") + +reference_error = sidemantic_rs.validate_query_references( + models_yaml, + ["orders.missing_metric"], + [], +) +if not reference_error: + raise AssertionError("missing metric reference should produce validation errors") + +reference_errors = sidemantic_rs.validate_query_references( + models_yaml, + ["revenue_per_order"], + ["orders.status"], +) +if reference_errors: + raise AssertionError(f"unexpected reference errors: {reference_errors!r}") + +sql_payload = json.loads(sidemantic_rs.load_graph_with_sql("MODEL (name events, table events, primary_key event_id);")) +assert "events" in json.dumps(sql_payload) + +statement_blocks = json.loads( + sidemantic_rs.parse_sql_statement_blocks_payload("MODEL (name events, table events, primary_key event_id);") +) +if statement_blocks[0]["kind"] != "model": + raise AssertionError(f"unexpected statement blocks: {statement_blocks!r}") + +expect_raises(ValueError, sidemantic_rs.parse_sql_statement_blocks_payload, "MODEL (") +expect_raises(ValueError, sidemantic_rs.load_graph_with_yaml, "models: [") +expect_raises(ValueError, sidemantic_rs.compile_with_yaml, models_yaml, "metrics: [") + +catalog = json.loads(sidemantic_rs.generate_catalog_metadata(models_yaml, "semantic")) +assert "orders" in json.dumps(catalog) + +if sidemantic_rs.detect_adapter_kind("cube.yml", "cubes: []") != "cube": + raise AssertionError("adapter detection failed for Cube YAML") + +chart_x, chart_y = sidemantic_rs.chart_auto_detect_columns(["status", "revenue"], [True]) +if chart_x != "status" or chart_y != ["revenue"]: + raise AssertionError(f"unexpected chart columns: {(chart_x, chart_y)!r}") + +expect_raises( + ValueError, + sidemantic_rs.chart_auto_detect_columns, + ["status", "revenue"], + [False, True], +) + +if not sidemantic_rs.is_relative_date("last 7 days"): + raise AssertionError("relative date detection failed") +assert_contains(sidemantic_rs.relative_date_to_range("last 7 days", "created_at"), "created_at") + +models_for_query = sidemantic_rs.find_models_for_query(["orders.status"], ["orders.revenue"]) +if models_for_query != ["orders"]: + raise AssertionError(f"unexpected query models: {models_for_query!r}") + +parsed_ref = sidemantic_rs.parse_reference_with_yaml(models_yaml, "orders.revenue") +if parsed_ref[:2] != ("orders", "revenue"): + raise AssertionError(f"unexpected parsed reference: {parsed_ref!r}") + +path = sidemantic_rs.find_relationship_path_with_yaml(models_yaml, "orders", "customers") +if not path or path[0][0] != "orders": + raise AssertionError(f"unexpected relationship path: {path!r}") + +expect_raises( + KeyError, + sidemantic_rs.find_relationship_path_with_yaml, + models_yaml, + "missing", + "orders", +) + +relationship_yaml = """ +name: customers +type: many_to_one +foreign_key: customer_id +""" +if sidemantic_rs.relationship_foreign_key_columns(relationship_yaml) != ["customer_id"]: + raise AssertionError("relationship foreign-key helper failed") + +refresh_statements = sidemantic_rs.build_preaggregation_refresh_statements("full", "agg_orders", "SELECT 1 AS n") +if not refresh_statements or "agg_orders" not in "\n".join(refresh_statements): + raise AssertionError(f"unexpected refresh statements: {refresh_statements!r}") + +expect_raises( + ValueError, + sidemantic_rs.build_preaggregation_refresh_statements, + "incremental", + "agg_orders", + "SELECT 1 AS n", +) + +sidemantic_rs.registry_set_current_layer({"name": "wheel-smoke"}) +if sidemantic_rs.registry_get_current_layer() != {"name": "wheel-smoke"}: + raise AssertionError("registry ContextVar roundtrip failed") + +if not callable(getattr(sidemantic_rs, "execute_with_adbc", None)): + raise AssertionError("default python wheel should expose execute_with_adbc") diff --git a/sidemantic-rs/tests/wasm_bindgen_runtime.rs b/sidemantic-rs/tests/wasm_bindgen_runtime.rs new file mode 100644 index 00000000..9682f7e6 --- /dev/null +++ b/sidemantic-rs/tests/wasm_bindgen_runtime.rs @@ -0,0 +1,723 @@ +#![cfg(all(target_arch = "wasm32", feature = "wasm"))] + +use sidemantic::{ + wasm_analyze_migrator_query, wasm_build_preaggregation_refresh_statements, + wasm_build_symmetric_aggregate_sql, wasm_calculate_preaggregation_benefit_score, + wasm_chart_auto_detect_columns, wasm_chart_encoding_type, wasm_chart_format_label, + wasm_chart_select_type, wasm_compile_with_yaml_query, wasm_detect_adapter_kind, + wasm_dimension_sql_expr_with_yaml, wasm_dimension_with_granularity_with_yaml, + wasm_evaluate_table_calculation_expression, wasm_extract_column_references, + wasm_extract_metric_dependencies_from_yaml, wasm_extract_preaggregation_patterns, + wasm_find_models_for_query, wasm_find_relationship_path_with_yaml, + wasm_format_parameter_value_with_yaml, wasm_generate_catalog_metadata_with_yaml, + wasm_generate_preaggregation_definition, + wasm_generate_preaggregation_materialization_sql_with_yaml, wasm_generate_preaggregation_name, + wasm_generate_time_comparison_sql, wasm_interpolate_sql_with_parameters_with_yaml, + wasm_is_relative_date, wasm_is_sql_template, wasm_load_graph_with_sql, + wasm_load_graph_with_yaml, wasm_metric_is_simple_aggregation, wasm_metric_sql_expr, + wasm_metric_to_sql, wasm_model_find_dimension_index_with_yaml, + wasm_model_find_metric_index_with_yaml, wasm_model_find_pre_aggregation_index_with_yaml, + wasm_model_find_segment_index_with_yaml, wasm_model_get_drill_down_with_yaml, + wasm_model_get_drill_up_with_yaml, wasm_model_get_hierarchy_path_with_yaml, + wasm_needs_symmetric_aggregate, wasm_parse_reference_with_yaml, wasm_parse_relative_date, + wasm_parse_simple_metric_aggregation, wasm_parse_sql_definitions_payload, + wasm_parse_sql_graph_definitions_payload, wasm_parse_sql_model_payload, + wasm_parse_sql_statement_blocks_payload, wasm_recommend_preaggregation_patterns, + wasm_relationship_foreign_key_columns_with_yaml, + wasm_relationship_primary_key_columns_with_yaml, wasm_relationship_related_key_with_yaml, + wasm_relationship_sql_expr_with_yaml, wasm_relative_date_to_range, wasm_render_sql_template, + wasm_resolve_metric_inheritance, wasm_resolve_model_inheritance_with_yaml, + wasm_rewrite_with_yaml, wasm_segment_get_sql_with_yaml, wasm_summarize_preaggregation_patterns, + wasm_time_comparison_offset_interval, wasm_time_comparison_sql_offset, + wasm_trailing_period_sql_interval, wasm_validate_engine_refresh_sql_compatibility, + wasm_validate_metric_payload, wasm_validate_model_payload, wasm_validate_models_yaml, + wasm_validate_parameter_payload, wasm_validate_query_references, + wasm_validate_query_references_with_yaml, wasm_validate_query_with_yaml, + wasm_validate_table_calculation_payload, wasm_validate_table_formula_expression, +}; +use wasm_bindgen::JsValue; +use wasm_bindgen_test::*; + +const SIMPLE_MODELS_YAML: &str = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + - name: count + agg: count +"#; + +const PREAGG_MODELS_YAML: &str = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#; + +const TOP_LEVEL_METRIC_YAML: &str = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count +"#; + +fn js_err_text(err: &JsValue) -> String { + err.as_string().unwrap_or_else(|| format!("{err:?}")) +} + +fn js_err_contains(err: &JsValue, needle: &str) -> bool { + js_err_text(err) + .to_ascii_lowercase() + .contains(&needle.to_ascii_lowercase()) +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_compile_and_rewrite() { + let query_yaml = r#" +metrics: [orders.revenue] +dimensions: [orders.status] +"#; + + let compiled = wasm_compile_with_yaml_query(SIMPLE_MODELS_YAML, query_yaml).unwrap(); + assert!(compiled.contains("SUM(")); + + let rewritten = + wasm_rewrite_with_yaml(SIMPLE_MODELS_YAML, "SELECT orders.revenue FROM orders").unwrap(); + assert!(rewritten.contains("SUM(")); + assert!(rewritten.contains("revenue")); + + let rewritten_with_order = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT orders.revenue AS rev, orders.status FROM orders ORDER BY rev DESC LIMIT 3", + ) + .unwrap(); + assert!(rewritten_with_order.contains("ORDER BY")); + assert!(rewritten_with_order.contains("LIMIT 3")); + + let rewritten_with_positional_order = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT orders.revenue, orders.status FROM orders ORDER BY 1 DESC LIMIT 2", + ) + .unwrap(); + assert!(rewritten_with_positional_order.contains("ORDER BY")); + assert!(rewritten_with_positional_order.contains("LIMIT 2")); + + let rewritten_with_aggregate = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT SUM(orders.amount) AS total_revenue, orders.status FROM orders ORDER BY total_revenue DESC LIMIT 2", + ) + .unwrap(); + assert!(rewritten_with_aggregate.contains("SUM(")); + assert!(rewritten_with_aggregate.contains("ORDER BY")); + assert!(rewritten_with_aggregate.contains("LIMIT 2")); + + let rewritten_with_expression = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT SUM(amount) / COUNT(*) AS aov, status FROM orders ORDER BY aov DESC LIMIT 1", + ) + .unwrap(); + assert!(rewritten_with_expression + .to_ascii_uppercase() + .contains("COUNT(")); + assert!(rewritten_with_expression.contains("revenue / count AS aov")); + assert!(rewritten_with_expression.contains("ORDER BY")); + assert!(rewritten_with_expression.contains("LIMIT 1")); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_validate_and_load() { + let query_yaml = r#" +metrics: [orders.revenue] +dimensions: [orders.status] +"#; + let errors_json = wasm_validate_query_with_yaml(SIMPLE_MODELS_YAML, query_yaml).unwrap(); + assert_eq!(errors_json, "[]"); + + let payload = wasm_load_graph_with_yaml(SIMPLE_MODELS_YAML).unwrap(); + assert!(payload.contains("\"models\"")); + assert!(payload.contains("\"orders\"")); + let sql_payload = wasm_load_graph_with_sql( + "MODEL (name orders, table orders, primary_key order_id);\nMETRIC (name order_count, agg count);\n", + ) + .unwrap(); + assert!(sql_payload.contains("\"models\"")); + assert!(sql_payload.contains("\"order_count\"")); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_dependency_helpers() { + let refs = wasm_extract_column_references("(revenue - cost) / revenue"); + assert_eq!(refs, "[\"cost\",\"revenue\"]"); + + let models = wasm_find_models_for_query("[\"orders.status\"]", "[\"orders.revenue\"]").unwrap(); + assert_eq!(models, "[\"orders\"]"); + let ref_errors = wasm_validate_query_references_with_yaml( + SIMPLE_MODELS_YAML, + "[\"orders.revenue\"]", + "[\"orders.status\"]", + ) + .unwrap(); + assert_eq!(ref_errors, "[]"); + let ref_errors_alias = wasm_validate_query_references( + SIMPLE_MODELS_YAML, + "[\"orders.revenue\"]", + "[\"orders.status\"]", + ) + .unwrap(); + assert_eq!(ref_errors_alias, "[]"); + + let relationship_yaml = r#" +models: + - name: customers + primary_key_columns: [id] + relationships: [] + - name: orders + primary_key_columns: [order_id] + relationships: + - name: customers + type: many_to_one + foreign_key_columns: [customer_id] + has_foreign_key: true +"#; + let path_json = + wasm_find_relationship_path_with_yaml(relationship_yaml, "orders", "customers").unwrap(); + assert!(path_json.contains("orders")); + assert!(path_json.contains("customers")); + + let rendered = wasm_render_sql_template( + "select {{ col }} from {{ table }}", + "col: amount\ntable: orders\n", + ) + .unwrap(); + assert_eq!(rendered, "select amount from orders"); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_catalog_and_preaggregation_helpers() { + let catalog_json = + wasm_generate_catalog_metadata_with_yaml(PREAGG_MODELS_YAML, "analytics").unwrap(); + assert!(catalog_json.contains("\"table_name\":\"orders\"")); + + let preagg_sql = wasm_generate_preaggregation_materialization_sql_with_yaml( + PREAGG_MODELS_YAML, + "orders", + "daily_revenue", + ) + .unwrap(); + assert!(preagg_sql.contains("DATE_TRUNC('day', order_date)")); + assert!(preagg_sql.contains("SUM(amount) as revenue_raw")); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_top_level_metric_validation_and_load() { + let query_yaml = r#" +metrics: [revenue_per_order] +dimensions: [orders.status] +"#; + let errors_json = wasm_validate_query_with_yaml(TOP_LEVEL_METRIC_YAML, query_yaml).unwrap(); + assert_eq!(errors_json, "[]"); + + let payload = wasm_load_graph_with_yaml(TOP_LEVEL_METRIC_YAML).unwrap(); + assert!(payload.contains("\"top_level_metrics\"")); + assert!(payload.contains("\"revenue_per_order\"")); + + assert!(wasm_validate_model_payload("name: orders\ntable: orders\nprimary_key: id\n").unwrap()); + assert!(wasm_validate_metric_payload("name: revenue\ntype: derived\nsql: amount\n").unwrap()); + assert!(wasm_validate_parameter_payload("name: region\ntype: string\n").unwrap()); + assert!( + wasm_validate_table_calculation_payload("name: pct\ntype: percent_of_total\n").unwrap() + ); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_sql_definition_parsers() { + let definitions_payload = wasm_parse_sql_definitions_payload( + "METRIC (name revenue, agg sum, sql amount);\nSEGMENT (name completed, sql status = 'completed');\n", + ) + .unwrap(); + assert!(definitions_payload.contains("\"metrics\"")); + assert!(definitions_payload.contains("\"segments\"")); + assert!(definitions_payload.contains("\"revenue\"")); + + let graph_payload = wasm_parse_sql_graph_definitions_payload( + "PARAMETER (name region, type string);\nPRE_AGGREGATION (name daily_rollup, measures [revenue], dimensions [status]);\n", + ) + .unwrap(); + assert!(graph_payload.contains("\"parameters\"")); + assert!(graph_payload.contains("\"pre_aggregations\"")); + + let model_payload = wasm_parse_sql_model_payload( + "MODEL (name orders, table orders, primary_key order_id);\nDIMENSION (name status, type categorical);\n", + ) + .unwrap(); + assert!(model_payload.contains("\"name\":\"orders\"")); + assert!(model_payload.contains("\"dimensions\"")); + + let statement_blocks_payload = wasm_parse_sql_statement_blocks_payload( + "MODEL (name orders, table orders);\nMETRIC (name revenue, expression SUM(amount));\n", + ) + .unwrap(); + assert!(statement_blocks_payload.contains("\"kind\":\"model\"")); + assert!(statement_blocks_payload.contains("\"kind\":\"metric\"")); + assert!(statement_blocks_payload.contains("\"sql\":\"SUM(amount)\"")); + + let migrator_payload = wasm_analyze_migrator_query( + "\nSELECT\n status,\n SUM(amount) / COUNT(*) AS avg_order_value\nFROM orders\nGROUP BY status\n", + ) + .unwrap(); + assert!(migrator_payload.contains("\"group_by_columns\"")); + assert!(migrator_payload.contains("\"avg_order_value\"")); + let chart_columns: serde_json::Value = serde_json::from_str( + &wasm_chart_auto_detect_columns("[\"created_at\",\"revenue\",\"region\"]", "[true,false]") + .unwrap(), + ) + .unwrap(); + assert_eq!(chart_columns["x"], "created_at"); + assert_eq!(chart_columns["y"], serde_json::json!(["revenue"])); + assert_eq!(wasm_chart_select_type("created_at", "string", 1), "area"); + assert_eq!(wasm_chart_encoding_type("order_date"), "temporal"); + assert_eq!( + wasm_chart_format_label("created_at__month"), + "Created At (Month)" + ); + + assert!(wasm_is_sql_template("select {{ col }} from orders")); + assert!(!wasm_is_sql_template("select col from orders")); + + let formatted_param = + wasm_format_parameter_value_with_yaml("name: status\ntype: string\n", "\"complete\"\n") + .unwrap(); + assert_eq!(formatted_param, "'complete'"); + + let interpolated = wasm_interpolate_sql_with_parameters_with_yaml( + "status = {{ status }} and amount >= {{ min_amount }}", + "- name: status\n type: string\n- name: min_amount\n type: number\n", + "status: complete\nmin_amount: 100\n", + ) + .unwrap(); + assert!(interpolated.contains("status = 'complete'")); + assert!(interpolated.contains("amount >= 100")); + + let parsed_today: Option = + serde_json::from_str(&wasm_parse_relative_date("today", "duckdb").unwrap()).unwrap(); + assert_eq!(parsed_today, Some("CURRENT_DATE".to_string())); + let range_today: Option = serde_json::from_str( + &wasm_relative_date_to_range("today", "event_date", "duckdb").unwrap(), + ) + .unwrap(); + assert_eq!(range_today, Some("event_date = CURRENT_DATE".to_string())); + assert!(wasm_is_relative_date("last 7 days")); + assert!(!wasm_is_relative_date("2024-01-01")); + + let offset_json = wasm_time_comparison_offset_interval("yoy", None, None).unwrap(); + let offset_value: serde_json::Value = serde_json::from_str(&offset_json).unwrap(); + assert_eq!(offset_value["amount"], 1); + assert_eq!(offset_value["unit"], "year"); + assert_eq!( + wasm_time_comparison_sql_offset("mom", None, None).unwrap(), + "INTERVAL '1 month'" + ); + assert_eq!( + wasm_trailing_period_sql_interval(3, "month").unwrap(), + "INTERVAL '3 month'" + ); + let comparison_sql = wasm_generate_time_comparison_sql( + "mom", + "percent_change", + "SUM(amount)", + "order_date", + None, + None, + ) + .unwrap(); + assert!(comparison_sql.contains("LAG(")); + assert!(comparison_sql.contains("ORDER BY order_date")); + assert!(comparison_sql.contains("/ NULLIF")); + + let adapter_kind: Option = + serde_json::from_str(&wasm_detect_adapter_kind("orders.lkml", "").unwrap()).unwrap(); + assert_eq!(adapter_kind, Some("lookml".to_string())); + + let parsed_simple_agg: serde_json::Value = serde_json::from_str( + &wasm_parse_simple_metric_aggregation("COUNT(DISTINCT customer_id)").unwrap(), + ) + .unwrap(); + assert_eq!(parsed_simple_agg[0], "count_distinct"); + assert_eq!(parsed_simple_agg[1], "customer_id"); + + let metric_yaml = "name: revenue\nagg: sum\nsql: amount\n"; + assert_eq!(wasm_metric_to_sql(metric_yaml).unwrap(), "SUM(amount)"); + assert_eq!(wasm_metric_sql_expr(metric_yaml).unwrap(), "amount"); + assert!(wasm_metric_is_simple_aggregation(metric_yaml).unwrap()); + + let queries_json = serde_json::to_string(&vec![ + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.count dimensions=orders.region".to_string(), + ]) + .unwrap(); + let patterns_json = wasm_extract_preaggregation_patterns(&queries_json).unwrap(); + let summary_json = wasm_summarize_preaggregation_patterns(&patterns_json, 2).unwrap(); + assert!(summary_json.contains("\"total_queries\":3")); + assert!(summary_json.contains("\"patterns_above_threshold\":1")); + let recommendations_json = + wasm_recommend_preaggregation_patterns(&patterns_json, 1, 0.0, Some(1)).unwrap(); + assert!(recommendations_json.contains("\"query_count\":2")); + let score = wasm_calculate_preaggregation_benefit_score( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + 2, + ) + .unwrap(); + assert!(score > 0.0); + let name = wasm_generate_preaggregation_name( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + ) + .unwrap(); + assert_eq!(name, "day_status_revenue"); + let definition_json = wasm_generate_preaggregation_definition( + r#"{"pattern":{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.created_at","orders.status"],"granularities":["day"],"count":2},"suggested_name":"day_created_at_status_revenue","query_count":2,"estimated_benefit_score":0.5}"#, + ) + .unwrap(); + assert!(definition_json.contains("\"name\":\"day_created_at_status_revenue\"")); + assert!(definition_json.contains("\"time_dimension\":\"created_at\"")); + assert!(definition_json.contains("\"granularity\":\"day\"")); + + assert!(wasm_validate_models_yaml(SIMPLE_MODELS_YAML).unwrap()); + let parsed_reference: (String, String, Option) = serde_json::from_str( + &wasm_parse_reference_with_yaml(SIMPLE_MODELS_YAML, "orders.revenue").unwrap(), + ) + .unwrap(); + assert_eq!( + parsed_reference, + ("orders".to_string(), "revenue".to_string(), None) + ); + + let model_inheritance_yaml = r#" +models: + - name: base + table: orders + primary_key: id + - name: child + extends: base + table: child_orders + primary_key: id +"#; + let resolved_models = wasm_resolve_model_inheritance_with_yaml(model_inheritance_yaml).unwrap(); + assert!(resolved_models.contains("name: base")); + assert!(resolved_models.contains("name: child")); + + let metric_inheritance_yaml = r#" +- name: base + agg: sum + sql: amount +- name: child + extends: base +"#; + let resolved_metrics = wasm_resolve_metric_inheritance(metric_inheritance_yaml).unwrap(); + assert!(resolved_metrics.contains("name: base")); + assert!(resolved_metrics.contains("name: child")); + + let dimension_yaml = r#" +name: created_at +type: time +sql: created_at +supported_granularities: [day, month] +"#; + assert_eq!( + wasm_dimension_sql_expr_with_yaml(dimension_yaml).unwrap(), + "created_at" + ); + assert_eq!( + wasm_dimension_with_granularity_with_yaml(dimension_yaml, "month").unwrap(), + "DATE_TRUNC('month', created_at)" + ); + + let model_hierarchy_yaml = r#" +dimensions: + - name: country + - name: state + parent: country + - name: city + parent: state +"#; + let hierarchy: Vec = serde_json::from_str( + &wasm_model_get_hierarchy_path_with_yaml(model_hierarchy_yaml, "city").unwrap(), + ) + .unwrap(); + assert_eq!(hierarchy, vec!["country", "state", "city"]); + + let drill_down: Option = serde_json::from_str( + &wasm_model_get_drill_down_with_yaml(model_hierarchy_yaml, "country").unwrap(), + ) + .unwrap(); + assert_eq!(drill_down, Some("state".to_string())); + let drill_up: Option = serde_json::from_str( + &wasm_model_get_drill_up_with_yaml(model_hierarchy_yaml, "city").unwrap(), + ) + .unwrap(); + assert_eq!(drill_up, Some("state".to_string())); + + let lookup_yaml = r#" +dimensions: + - name: status + - name: region +metrics: + - name: revenue + - name: count +segments: + - name: active + - name: priority +pre_aggregations: + - name: daily + - name: monthly +"#; + let dim_idx: Option = serde_json::from_str( + &wasm_model_find_dimension_index_with_yaml(lookup_yaml, "region").unwrap(), + ) + .unwrap(); + assert_eq!(dim_idx, Some(1)); + let metric_idx: Option = serde_json::from_str( + &wasm_model_find_metric_index_with_yaml(lookup_yaml, "count").unwrap(), + ) + .unwrap(); + assert_eq!(metric_idx, Some(1)); + let segment_idx: Option = serde_json::from_str( + &wasm_model_find_segment_index_with_yaml(lookup_yaml, "priority").unwrap(), + ) + .unwrap(); + assert_eq!(segment_idx, Some(1)); + let preagg_idx: Option = serde_json::from_str( + &wasm_model_find_pre_aggregation_index_with_yaml(lookup_yaml, "monthly").unwrap(), + ) + .unwrap(); + assert_eq!(preagg_idx, Some(1)); + + let relationship_yaml = r#" +name: customers +type: many_to_one +foreign_key: customer_id +"#; + assert_eq!( + wasm_relationship_sql_expr_with_yaml(relationship_yaml).unwrap(), + "customer_id" + ); + assert_eq!( + wasm_relationship_related_key_with_yaml(relationship_yaml).unwrap(), + "id" + ); + let fk_cols: Vec = serde_json::from_str( + &wasm_relationship_foreign_key_columns_with_yaml(relationship_yaml).unwrap(), + ) + .unwrap(); + assert_eq!(fk_cols, vec!["customer_id"]); + let pk_cols: Vec = serde_json::from_str( + &wasm_relationship_primary_key_columns_with_yaml(relationship_yaml).unwrap(), + ) + .unwrap(); + assert_eq!(pk_cols, vec!["id"]); + + assert_eq!( + wasm_segment_get_sql_with_yaml("sql: \"{model}.status = 'completed'\"\n", "orders_cte") + .unwrap(), + "orders_cte.status = 'completed'" + ); + + let metric_dependency_models_yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount + - name: cost + agg: sum + sql: cost +"#; + let metric_dependency_yaml = r#" +name: margin +type: derived +sql: revenue / cost +"#; + let deps: Vec = serde_json::from_str( + &wasm_extract_metric_dependencies_from_yaml( + metric_dependency_yaml, + Some(metric_dependency_models_yaml.to_string()), + Some("orders".to_string()), + ) + .unwrap(), + ) + .unwrap(); + assert_eq!(deps, vec!["orders.cost", "orders.revenue"]); + + assert_eq!( + wasm_evaluate_table_calculation_expression("1 + 2 * 3").unwrap(), + 7.0 + ); + assert!(wasm_validate_table_formula_expression("${a} + ${b}").unwrap()); + let refresh_valid: serde_json::Value = serde_json::from_str( + &wasm_validate_engine_refresh_sql_compatibility("SELECT 1", "snowflake").unwrap(), + ) + .unwrap(); + assert_eq!(refresh_valid["is_valid"], true); + assert_eq!(refresh_valid["error"], serde_json::Value::Null); + let refresh_statements: Vec = serde_json::from_str( + &wasm_build_preaggregation_refresh_statements( + "incremental", + "orders_preagg_daily_revenue", + "SELECT order_date, SUM(revenue) AS total_revenue FROM orders GROUP BY order_date", + Some("order_date".to_string()), + Some("2026-01-01".to_string()), + None, + None, + None, + ) + .unwrap(), + ) + .unwrap(); + assert_eq!(refresh_statements.len(), 2); + assert!(refresh_statements[1].contains("INSERT INTO orders_preagg_daily_revenue")); + let symmetric_sql = wasm_build_symmetric_aggregate_sql( + "amount", + "order_id", + "sum", + Some("orders_cte".to_string()), + "duckdb", + ) + .unwrap(); + assert!(symmetric_sql.contains("SUM(DISTINCT")); + assert!(symmetric_sql.contains("orders_cte.order_id")); + assert!(wasm_needs_symmetric_aggregate("one_to_many", true)); + assert!(!wasm_needs_symmetric_aggregate("many_to_one", true)); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_error_paths_cross_wasm_boundary() { + let invalid_yaml = wasm_load_graph_with_yaml("models: [").unwrap_err(); + assert!(js_err_contains(&invalid_yaml, "yaml")); + + let invalid_query_yaml = + wasm_compile_with_yaml_query(SIMPLE_MODELS_YAML, "metrics: [").unwrap_err(); + let invalid_query_text = js_err_text(&invalid_query_yaml); + assert!( + invalid_query_text.to_ascii_lowercase().contains("parse") + || invalid_query_text.to_ascii_lowercase().contains("eof") + || invalid_query_text.to_ascii_lowercase().contains("expected"), + "{invalid_query_text}" + ); + + let invalid_sql_definition = wasm_parse_sql_statement_blocks_payload("MODEL (").unwrap_err(); + assert!( + js_err_text(&invalid_sql_definition).contains("Validation") + || js_err_text(&invalid_sql_definition).contains("parse") + ); + + let invalid_json = wasm_find_models_for_query("not-json", "[]").unwrap_err(); + assert!(js_err_text(&invalid_json).contains("expected")); + + let invalid_chart_json = + wasm_chart_auto_detect_columns("[\"status\",\"revenue\"]", "not-json").unwrap_err(); + assert!(js_err_text(&invalid_chart_json).contains("expected")); + + let mismatched_chart_flags = + wasm_chart_auto_detect_columns("[\"status\",\"revenue\"]", "[false,true]").unwrap_err(); + assert!(js_err_text(&mismatched_chart_flags).contains("numeric flag count mismatch")); + + let missing_ref_errors_json = wasm_validate_query_references_with_yaml( + SIMPLE_MODELS_YAML, + "[\"orders.missing_metric\"]", + "[]", + ) + .unwrap(); + let missing_ref_errors: Vec = serde_json::from_str(&missing_ref_errors_json).unwrap(); + assert_eq!(missing_ref_errors.len(), 1); + assert!(missing_ref_errors[0].contains("missing_metric")); + + let missing_join_model = + wasm_find_relationship_path_with_yaml(SIMPLE_MODELS_YAML, "orders", "missing").unwrap_err(); + assert!(js_err_text(&missing_join_model).contains("missing")); + + let incremental_without_watermark = wasm_build_preaggregation_refresh_statements( + "incremental", + "orders_preagg", + "SELECT 1", + None, + None, + None, + None, + None, + ) + .unwrap_err(); + assert!(js_err_text(&incremental_without_watermark).contains("watermark")); +} + +#[wasm_bindgen_test] +fn wasm_bindgen_runtime_rewrite_fallback_rejects_unsupported_sql_shapes() { + let select_star = + wasm_rewrite_with_yaml(SIMPLE_MODELS_YAML, "SELECT * FROM orders").unwrap_err(); + assert!(js_err_text(&select_star).contains("SELECT *")); + + let explicit_join = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT orders.revenue FROM orders JOIN customers ON orders.customer_id = customers.id", + ) + .unwrap_err(); + assert!(js_err_text(&explicit_join).contains("unsupported clause")); + + let cte_query = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "WITH base AS (SELECT * FROM orders) SELECT * FROM base", + ) + .unwrap_err(); + assert!(js_err_text(&cte_query).contains("only supports SELECT")); + + let grouped_query = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT orders.revenue FROM orders GROUP BY orders.status", + ) + .unwrap_err(); + assert!(js_err_text(&grouped_query).contains("unsupported clause")); + + let subquery_from = wasm_rewrite_with_yaml( + SIMPLE_MODELS_YAML, + "SELECT orders.revenue FROM (SELECT * FROM orders) orders", + ) + .unwrap_err(); + assert!(js_err_text(&subquery_from).contains("single table")); +} diff --git a/sidemantic-rs/tests/wasm_parity_subset.rs b/sidemantic-rs/tests/wasm_parity_subset.rs new file mode 100644 index 00000000..34e09111 --- /dev/null +++ b/sidemantic-rs/tests/wasm_parity_subset.rs @@ -0,0 +1,1000 @@ +use sidemantic::runtime::{ + calculate_preaggregation_benefit_score, compile_with_yaml_query, detect_adapter_kind, + dimension_with_granularity_with_yaml, extract_column_references, + extract_preaggregation_patterns, find_models_for_query, find_relationship_path_with_yaml, + generate_preaggregation_definition, generate_preaggregation_name, generate_time_comparison_sql, + interpolate_sql_with_parameters_with_yaml, is_relative_date, metric_is_simple_aggregation, + metric_sql_expr, metric_to_sql, model_get_hierarchy_path_with_yaml, + parse_simple_metric_aggregation, parse_sql_definitions_payload, + parse_sql_graph_definitions_payload, parse_sql_model_payload, + recommend_preaggregation_patterns, relationship_sql_expr_with_yaml, render_sql_template, + resolve_metric_inheritance, resolve_model_inheritance_with_yaml, segment_get_sql_with_yaml, + summarize_preaggregation_patterns, trailing_period_sql_interval, validate_metric_payload, + validate_model_payload, validate_parameter_payload, validate_table_calculation_payload, + validate_table_formula_expression, SidemanticRuntime, +}; +use sidemantic::sql::SemanticQuery; +#[cfg(feature = "wasm")] +use sidemantic::{ + wasm_analyze_migrator_query, wasm_build_preaggregation_refresh_statements, + wasm_build_symmetric_aggregate_sql, wasm_calculate_preaggregation_benefit_score, + wasm_chart_auto_detect_columns, wasm_chart_encoding_type, wasm_chart_format_label, + wasm_chart_select_type, wasm_compile_with_yaml_query, wasm_detect_adapter_kind, + wasm_dimension_sql_expr_with_yaml, wasm_dimension_with_granularity_with_yaml, + wasm_evaluate_table_calculation_expression, wasm_extract_column_references, + wasm_extract_metric_dependencies_from_yaml, wasm_extract_preaggregation_patterns, + wasm_find_models_for_query, wasm_find_relationship_path_with_yaml, + wasm_format_parameter_value_with_yaml, wasm_generate_catalog_metadata_with_yaml, + wasm_generate_preaggregation_definition, + wasm_generate_preaggregation_materialization_sql_with_yaml, wasm_generate_preaggregation_name, + wasm_generate_time_comparison_sql, wasm_interpolate_sql_with_parameters_with_yaml, + wasm_is_relative_date, wasm_is_sql_template, wasm_load_graph_with_sql, + wasm_load_graph_with_yaml, wasm_metric_is_simple_aggregation, wasm_metric_sql_expr, + wasm_metric_to_sql, wasm_model_find_dimension_index_with_yaml, + wasm_model_find_metric_index_with_yaml, wasm_model_find_pre_aggregation_index_with_yaml, + wasm_model_find_segment_index_with_yaml, wasm_model_get_drill_down_with_yaml, + wasm_model_get_drill_up_with_yaml, wasm_model_get_hierarchy_path_with_yaml, + wasm_needs_symmetric_aggregate, wasm_parse_reference_with_yaml, wasm_parse_relative_date, + wasm_parse_simple_metric_aggregation, wasm_parse_sql_definitions_payload, + wasm_parse_sql_graph_definitions_payload, wasm_parse_sql_model_payload, + wasm_parse_sql_statement_blocks_payload, wasm_recommend_preaggregation_patterns, + wasm_relationship_foreign_key_columns_with_yaml, + wasm_relationship_primary_key_columns_with_yaml, wasm_relationship_related_key_with_yaml, + wasm_relationship_sql_expr_with_yaml, wasm_relative_date_to_range, wasm_render_sql_template, + wasm_resolve_metric_inheritance, wasm_resolve_model_inheritance_with_yaml, + wasm_rewrite_with_yaml, wasm_segment_get_sql_with_yaml, wasm_summarize_preaggregation_patterns, + wasm_time_comparison_offset_interval, wasm_time_comparison_sql_offset, + wasm_trailing_period_sql_interval, wasm_validate_engine_refresh_sql_compatibility, + wasm_validate_metric_payload, wasm_validate_model_payload, wasm_validate_models_yaml, + wasm_validate_parameter_payload, wasm_validate_query_references, + wasm_validate_query_references_with_yaml, wasm_validate_query_with_yaml, + wasm_validate_table_calculation_payload, wasm_validate_table_formula_expression, +}; + +#[test] +fn wasm_parity_subset_runtime_compile_and_rewrite() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + - name: customer_id + type: numeric + metrics: + - name: revenue + agg: sum + sql: amount + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +"#; + + let runtime = SidemanticRuntime::from_yaml(yaml).unwrap(); + let query = SemanticQuery::new() + .with_metrics(vec!["orders.revenue".to_string()]) + .with_dimensions(vec!["customers.country".to_string()]) + .with_filters(vec!["customers.country = 'US'".to_string()]) + .with_order_by(vec!["orders.revenue DESC".to_string()]) + .with_limit(10); + + let compiled = runtime.compile(&query).unwrap(); + assert!(compiled.contains("SUM(")); + assert!(compiled.contains("JOIN")); + assert!(compiled.contains("country")); + + let rewritten = runtime + .rewrite("SELECT orders.revenue, customers.country FROM orders") + .unwrap(); + assert!(rewritten.contains("SUM(")); + assert!(rewritten.contains("country")); +} + +#[test] +fn wasm_parity_subset_runtime_validation_and_join_path() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: customer_id + type: numeric + - name: status + type: categorical + metrics: + - name: order_count + agg: count + relationships: + - name: customers + type: many_to_one + foreign_key: customer_id + - name: customers + table: customers + primary_key: id + dimensions: + - name: country + type: categorical +metrics: + - name: scoped_order_count + type: cumulative + base_metric: order_count +"#; + + let runtime = SidemanticRuntime::from_yaml(yaml).unwrap(); + let errors = runtime.validate_query_references( + &["scoped_order_count".to_string()], + &["customers.country".to_string()], + ); + assert!( + errors.is_empty(), + "unexpected validation errors: {errors:?}" + ); + + let join_path = runtime.find_join_path("orders", "customers").unwrap(); + assert_eq!(join_path.steps.len(), 1); + assert_eq!(join_path.steps[0].from_model, "orders"); + assert_eq!(join_path.steps[0].to_model, "customers"); +} + +#[test] +fn wasm_parity_subset_runtime_helpers() { + let models = find_models_for_query( + &["orders.status".to_string(), "customers.country".to_string()], + &["orders.revenue".to_string()], + ); + assert_eq!( + models.into_iter().collect::>(), + vec!["customers".to_string(), "orders".to_string()] + ); + + let graph_yaml = r#" +models: + - name: customers + primary_key_columns: [id] + relationships: [] + - name: orders + primary_key_columns: [order_id] + relationships: + - name: customers + type: many_to_one + foreign_key_columns: [customer_id] + has_foreign_key: true +"#; + let path = find_relationship_path_with_yaml(graph_yaml, "orders", "customers").unwrap(); + assert_eq!(path.len(), 1); + assert_eq!(path[0].0, "orders"); + assert_eq!(path[0].1, "customers"); + assert_eq!(path[0].2, vec!["customer_id".to_string()]); + assert_eq!(path[0].3, vec!["id".to_string()]); + assert_eq!(path[0].4, "many_to_one"); +} + +#[test] +fn wasm_parity_subset_compile_with_yaml_query() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +parameters: + - name: status + type: string + default_value: pending +"#; + let query_yaml = r#" +metrics: [orders.revenue] +dimensions: [orders.status] +filters: + - "orders.status = {{ status }}" +parameter_values: + status: complete +"#; + + let sql = compile_with_yaml_query(yaml, query_yaml).unwrap(); + assert!(sql.contains("SUM(")); + assert!(sql.contains("'complete'")); +} + +#[test] +fn wasm_parity_subset_template_parameter_helpers() { + let rendered = render_sql_template( + "select {{ col }} from {{ table }}", + "col: amount\ntable: orders\n", + ) + .unwrap(); + assert_eq!(rendered, "select amount from orders"); + + let interpolated = interpolate_sql_with_parameters_with_yaml( + "status = {{ status }}", + "- name: status\n type: string\n", + "status: complete\n", + ) + .unwrap(); + assert_eq!(interpolated, "status = 'complete'"); +} + +#[test] +fn wasm_parity_subset_core_helper_entrypoints() { + let dimension_yaml = r#" +name: created_at +type: time +sql: created_at +supported_granularities: [day, month] +"#; + let with_granularity = dimension_with_granularity_with_yaml(dimension_yaml, "month").unwrap(); + assert_eq!(with_granularity, "DATE_TRUNC('month', created_at)"); + + let model_yaml = r#" +dimensions: + - name: country + - name: state + parent: country + - name: city + parent: state +"#; + let hierarchy = model_get_hierarchy_path_with_yaml(model_yaml, "city").unwrap(); + assert_eq!( + hierarchy, + vec![ + "country".to_string(), + "state".to_string(), + "city".to_string() + ] + ); + + let relationship_yaml = r#" +name: customers +type: many_to_one +foreign_key: customer_id +"#; + assert_eq!( + relationship_sql_expr_with_yaml(relationship_yaml).unwrap(), + "customer_id" + ); + + let segment_yaml = "sql: \"{model}.status = 'completed'\"\n"; + assert_eq!( + segment_get_sql_with_yaml(segment_yaml, "orders_cte").unwrap(), + "orders_cte.status = 'completed'" + ); +} + +#[test] +fn wasm_parity_subset_extract_column_references_entrypoint() { + let refs = extract_column_references("(revenue - cost) / revenue"); + assert_eq!(refs, vec!["cost".to_string(), "revenue".to_string()]); +} + +#[test] +fn wasm_parity_subset_time_and_date_helper_entrypoints() { + assert!(validate_table_formula_expression("${a} + ${b}").unwrap()); + assert_eq!( + trailing_period_sql_interval(3, "month").unwrap(), + "INTERVAL '3 month'" + ); + assert!(is_relative_date("last 7 days")); + + let comparison_sql = + generate_time_comparison_sql("mom", "difference", "SUM(amount)", "order_date", None, None) + .unwrap(); + assert!(comparison_sql.contains("LAG(")); + assert!(comparison_sql.contains("ORDER BY order_date")); +} + +#[test] +fn wasm_parity_subset_metric_helper_entrypoints() { + assert_eq!( + parse_simple_metric_aggregation("COUNT(DISTINCT customer_id)"), + Some(( + "count_distinct".to_string(), + Some("customer_id".to_string()) + )) + ); + + let metric_yaml = r#" +name: revenue +agg: sum +sql: amount +"#; + assert_eq!(metric_to_sql(metric_yaml).unwrap(), "SUM(amount)"); + assert_eq!(metric_sql_expr(metric_yaml).unwrap(), "amount"); + assert!(metric_is_simple_aggregation(metric_yaml).unwrap()); +} + +#[test] +fn wasm_parity_subset_preaggregation_recommender_entrypoints() { + let queries = vec![ + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.count dimensions=orders.region".to_string(), + ]; + let patterns_json = extract_preaggregation_patterns(queries).unwrap(); + let summary_json = summarize_preaggregation_patterns(&patterns_json, 2).unwrap(); + assert!(summary_json.contains("\"total_queries\":3")); + assert!(summary_json.contains("\"patterns_above_threshold\":1")); + + let recommendations_json = + recommend_preaggregation_patterns(&patterns_json, 1, 0.0, Some(1)).unwrap(); + assert!(recommendations_json.contains("\"query_count\":2")); + + let score = calculate_preaggregation_benefit_score( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + 2, + ) + .unwrap(); + assert!(score > 0.0); + + let name = generate_preaggregation_name( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + ) + .unwrap(); + assert_eq!(name, "day_status_revenue"); + + let definition_json = generate_preaggregation_definition( + r#"{"pattern":{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.created_at","orders.status"],"granularities":["day"],"count":2},"suggested_name":"day_created_at_status_revenue","query_count":2,"estimated_benefit_score":0.5}"#, + ) + .unwrap(); + assert!(definition_json.contains("\"name\":\"day_created_at_status_revenue\"")); + assert!(definition_json.contains("\"time_dimension\":\"created_at\"")); + assert!(definition_json.contains("\"granularity\":\"day\"")); +} + +#[test] +fn wasm_parity_subset_metric_inheritance_entrypoint() { + let metrics_yaml = r#" +- name: base + agg: sum + sql: amount +- name: child + extends: base +"#; + let resolved_yaml = resolve_metric_inheritance(metrics_yaml).unwrap(); + assert!(resolved_yaml.contains("name: base")); + assert!(resolved_yaml.contains("name: child")); + assert!(resolved_yaml.contains("agg: sum")); + assert!(resolved_yaml.contains("sql: amount")); +} + +#[test] +fn wasm_parity_subset_model_inheritance_entrypoint() { + let models_yaml = r#" +models: + - name: base + table: orders + primary_key: id + - name: child + extends: base + table: child_orders + primary_key: id +"#; + let resolved_yaml = resolve_model_inheritance_with_yaml(models_yaml).unwrap(); + assert!(resolved_yaml.contains("name: base")); + assert!(resolved_yaml.contains("name: child")); +} + +#[test] +fn wasm_parity_subset_payload_validation_entrypoints() { + assert!(validate_model_payload("name: orders\ntable: orders\nprimary_key: id\n").unwrap()); + assert!(validate_metric_payload("name: revenue\ntype: derived\nsql: amount\n").unwrap()); + assert!(validate_parameter_payload("name: region\ntype: string\n").unwrap()); + assert!(validate_table_calculation_payload("name: pct\ntype: percent_of_total\n").unwrap()); +} + +#[test] +fn wasm_parity_subset_adapter_autodetection_entrypoint() { + assert_eq!( + detect_adapter_kind("orders.lkml", ""), + Some("lookml".to_string()) + ); + assert_eq!( + detect_adapter_kind("orders.yml", "models:\n - name: orders\n"), + Some("sidemantic".to_string()) + ); + assert_eq!( + detect_adapter_kind("orders.yml", "dimensions:\n status: _.status\n"), + Some("bsl".to_string()) + ); + assert_eq!(detect_adapter_kind("orders.yml", "foo: bar\n"), None,); +} + +#[test] +fn wasm_parity_subset_sql_definitions_parser_entrypoints() { + let definitions_payload = parse_sql_definitions_payload( + "METRIC (name revenue, agg sum, sql amount);\nSEGMENT (name completed, sql status = 'completed');\n", + ) + .unwrap(); + assert!(definitions_payload.contains("\"metrics\"")); + assert!(definitions_payload.contains("\"segments\"")); + assert!(definitions_payload.contains("\"revenue\"")); + + let graph_payload = parse_sql_graph_definitions_payload( + "METRIC (name revenue, agg sum, sql amount);\nPARAMETER (name region, type string);\nPRE_AGGREGATION (name daily_rollup, measures [revenue], dimensions [status]);\n", + ) + .unwrap(); + assert!(graph_payload.contains("\"parameters\"")); + assert!(graph_payload.contains("\"pre_aggregations\"")); + assert!(graph_payload.contains("\"daily_rollup\"")); + + let model_payload = parse_sql_model_payload( + "MODEL (name orders, table orders, primary_key order_id);\nDIMENSION (name status, type categorical);\n", + ) + .unwrap(); + assert!(model_payload.contains("\"name\":\"orders\"")); + assert!(model_payload.contains("\"dimensions\"")); +} + +#[cfg(feature = "wasm")] +#[test] +fn wasm_parity_subset_wasm_api_entrypoints() { + let yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + - name: count + agg: count +"#; + + let query_yaml = r#" +metrics: [orders.revenue] +dimensions: [orders.status] +"#; + + let compiled = wasm_compile_with_yaml_query(yaml, query_yaml).unwrap(); + assert!(compiled.contains("SUM(")); + + let rewritten = wasm_rewrite_with_yaml(yaml, "SELECT orders.revenue FROM orders").unwrap(); + assert!(rewritten.contains("SUM(")); + + let rewritten_with_order = wasm_rewrite_with_yaml( + yaml, + "SELECT orders.revenue AS rev, orders.status FROM orders ORDER BY rev DESC LIMIT 5", + ) + .unwrap(); + assert!(rewritten_with_order.contains("ORDER BY")); + assert!(rewritten_with_order.contains("LIMIT 5")); + + let rewritten_with_positional_order = wasm_rewrite_with_yaml( + yaml, + "SELECT orders.revenue, orders.status FROM orders ORDER BY 1 DESC LIMIT 2", + ) + .unwrap(); + assert!(rewritten_with_positional_order.contains("ORDER BY")); + assert!(rewritten_with_positional_order.contains("LIMIT 2")); + + #[cfg(target_arch = "wasm32")] + { + let rewritten_with_aggregate = wasm_rewrite_with_yaml( + yaml, + "SELECT SUM(orders.amount) AS total_revenue, orders.status FROM orders ORDER BY total_revenue DESC LIMIT 4", + ) + .unwrap(); + assert!(rewritten_with_aggregate.contains("SUM(")); + assert!(rewritten_with_aggregate.contains("ORDER BY")); + assert!(rewritten_with_aggregate.contains("LIMIT 4")); + } + + #[cfg(target_arch = "wasm32")] + { + let rewritten_with_expression = wasm_rewrite_with_yaml( + yaml, + "SELECT SUM(amount) / COUNT(*) AS aov, status FROM orders ORDER BY aov DESC LIMIT 1", + ) + .unwrap(); + assert!(rewritten_with_expression + .to_ascii_uppercase() + .contains("COUNT(")); + assert!(rewritten_with_expression.contains("revenue / count AS aov")); + assert!(rewritten_with_expression.contains("ORDER BY")); + assert!(rewritten_with_expression.contains("LIMIT 1")); + } + + let errors_json = wasm_validate_query_with_yaml(yaml, query_yaml).unwrap(); + assert_eq!(errors_json, "[]"); + + let payload = wasm_load_graph_with_yaml(yaml).unwrap(); + assert!(payload.contains("\"models\"")); + assert!(payload.contains("\"orders\"")); + let sql_payload = wasm_load_graph_with_sql( + "MODEL (name orders, table orders, primary_key order_id);\nMETRIC (name order_count, agg count);\n", + ) + .unwrap(); + assert!(sql_payload.contains("\"models\"")); + assert!(sql_payload.contains("\"order_count\"")); + + let refs = wasm_extract_column_references("(revenue - cost) / revenue"); + assert_eq!(refs, "[\"cost\",\"revenue\"]"); + + let models = wasm_find_models_for_query("[\"orders.status\"]", "[\"orders.revenue\"]").unwrap(); + assert_eq!(models, "[\"orders\"]"); + let ref_errors = wasm_validate_query_references_with_yaml( + yaml, + "[\"orders.revenue\"]", + "[\"orders.status\"]", + ) + .unwrap(); + assert_eq!(ref_errors, "[]"); + let ref_errors_alias = + wasm_validate_query_references(yaml, "[\"orders.revenue\"]", "[\"orders.status\"]") + .unwrap(); + assert_eq!(ref_errors_alias, "[]"); + + let relationship_yaml = r#" +models: + - name: customers + primary_key_columns: [id] + relationships: [] + - name: orders + primary_key_columns: [order_id] + relationships: + - name: customers + type: many_to_one + foreign_key_columns: [customer_id] + has_foreign_key: true +"#; + let path_json = + wasm_find_relationship_path_with_yaml(relationship_yaml, "orders", "customers").unwrap(); + assert!(path_json.contains("orders")); + assert!(path_json.contains("customers")); + assert!(path_json.contains("customer_id")); + + let rendered = wasm_render_sql_template( + "select {{ col }} from {{ table }}", + "col: amount\ntable: orders\n", + ) + .unwrap(); + assert_eq!(rendered, "select amount from orders"); + + let preagg_yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: order_date + type: time + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + pre_aggregations: + - name: daily_revenue + time_dimension: order_date + granularity: day + measures: [revenue] +"#; + let catalog_json = wasm_generate_catalog_metadata_with_yaml(preagg_yaml, "analytics").unwrap(); + assert!(catalog_json.contains("\"table_name\":\"orders\"")); + + let preagg_sql = wasm_generate_preaggregation_materialization_sql_with_yaml( + preagg_yaml, + "orders", + "daily_revenue", + ) + .unwrap(); + assert!(preagg_sql.contains("DATE_TRUNC('day', order_date)")); + assert!(preagg_sql.contains("SUM(amount) as revenue_raw")); + + let top_level_metric_yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount + - name: order_count + agg: count +metrics: + - name: revenue_per_order + type: ratio + numerator: revenue + denominator: order_count +"#; + let top_level_query_yaml = r#" +metrics: [revenue_per_order] +dimensions: [orders.status] +"#; + let top_level_errors = + wasm_validate_query_with_yaml(top_level_metric_yaml, top_level_query_yaml).unwrap(); + assert_eq!(top_level_errors, "[]"); + let top_level_payload = wasm_load_graph_with_yaml(top_level_metric_yaml).unwrap(); + assert!(top_level_payload.contains("\"top_level_metrics\"")); + assert!(top_level_payload.contains("\"revenue_per_order\"")); + assert!(wasm_validate_model_payload("name: orders\ntable: orders\nprimary_key: id\n").unwrap()); + assert!(wasm_validate_metric_payload("name: revenue\ntype: derived\nsql: amount\n").unwrap()); + assert!(wasm_validate_parameter_payload("name: region\ntype: string\n").unwrap()); + assert!( + wasm_validate_table_calculation_payload("name: pct\ntype: percent_of_total\n").unwrap() + ); + + let definitions_payload = wasm_parse_sql_definitions_payload( + "METRIC (name revenue, agg sum, sql amount);\nSEGMENT (name completed, sql status = 'completed');\n", + ) + .unwrap(); + assert!(definitions_payload.contains("\"metrics\"")); + assert!(definitions_payload.contains("\"segments\"")); + assert!(definitions_payload.contains("\"revenue\"")); + + let graph_payload = wasm_parse_sql_graph_definitions_payload( + "PARAMETER (name region, type string);\nPRE_AGGREGATION (name daily_rollup, measures [revenue], dimensions [status]);\n", + ) + .unwrap(); + assert!(graph_payload.contains("\"parameters\"")); + assert!(graph_payload.contains("\"pre_aggregations\"")); + + let model_payload = wasm_parse_sql_model_payload( + "MODEL (name orders, table orders, primary_key order_id);\nDIMENSION (name status, type categorical);\n", + ) + .unwrap(); + assert!(model_payload.contains("\"name\":\"orders\"")); + assert!(model_payload.contains("\"dimensions\"")); + + let statement_blocks_payload = wasm_parse_sql_statement_blocks_payload( + "MODEL (name orders, table orders);\nMETRIC (name revenue, expression SUM(amount));\n", + ) + .unwrap(); + assert!(statement_blocks_payload.contains("\"kind\":\"model\"")); + assert!(statement_blocks_payload.contains("\"kind\":\"metric\"")); + assert!(statement_blocks_payload.contains("\"sql\":\"SUM(amount)\"")); + + let migrator_payload = wasm_analyze_migrator_query( + "\nSELECT\n status,\n SUM(amount) / COUNT(*) AS avg_order_value\nFROM orders\nGROUP BY status\n", + ) + .unwrap(); + assert!(migrator_payload.contains("\"group_by_columns\"")); + assert!(migrator_payload.contains("\"avg_order_value\"")); + let chart_columns: serde_json::Value = serde_json::from_str( + &wasm_chart_auto_detect_columns("[\"created_at\",\"revenue\",\"region\"]", "[true,false]") + .unwrap(), + ) + .unwrap(); + assert_eq!(chart_columns["x"], "created_at"); + assert_eq!(chart_columns["y"], serde_json::json!(["revenue"])); + assert_eq!(wasm_chart_select_type("created_at", "string", 1), "area"); + assert_eq!(wasm_chart_encoding_type("order_date"), "temporal"); + assert_eq!( + wasm_chart_format_label("created_at__month"), + "Created At (Month)" + ); + + assert!(wasm_is_sql_template("select {{ col }} from orders")); + assert!(!wasm_is_sql_template("select col from orders")); + + let formatted_param = + wasm_format_parameter_value_with_yaml("name: status\ntype: string\n", "\"complete\"\n") + .unwrap(); + assert_eq!(formatted_param, "'complete'"); + + let interpolated = wasm_interpolate_sql_with_parameters_with_yaml( + "status = {{ status }} and amount >= {{ min_amount }}", + "- name: status\n type: string\n- name: min_amount\n type: number\n", + "status: complete\nmin_amount: 100\n", + ) + .unwrap(); + assert!(interpolated.contains("status = 'complete'")); + assert!(interpolated.contains("amount >= 100")); + + let parsed_today: Option = + serde_json::from_str(&wasm_parse_relative_date("today", "duckdb").unwrap()).unwrap(); + assert_eq!(parsed_today, Some("CURRENT_DATE".to_string())); + let range_today: Option = serde_json::from_str( + &wasm_relative_date_to_range("today", "event_date", "duckdb").unwrap(), + ) + .unwrap(); + assert_eq!(range_today, Some("event_date = CURRENT_DATE".to_string())); + assert!(wasm_is_relative_date("last 7 days")); + assert!(!wasm_is_relative_date("2024-01-01")); + + let offset_json = wasm_time_comparison_offset_interval("yoy", None, None).unwrap(); + let offset_value: serde_json::Value = serde_json::from_str(&offset_json).unwrap(); + assert_eq!(offset_value["amount"], 1); + assert_eq!(offset_value["unit"], "year"); + assert_eq!( + wasm_time_comparison_sql_offset("mom", None, None).unwrap(), + "INTERVAL '1 month'" + ); + assert_eq!( + wasm_trailing_period_sql_interval(3, "month").unwrap(), + "INTERVAL '3 month'" + ); + let comparison_sql = wasm_generate_time_comparison_sql( + "mom", + "percent_change", + "SUM(amount)", + "order_date", + None, + None, + ) + .unwrap(); + assert!(comparison_sql.contains("LAG(")); + assert!(comparison_sql.contains("ORDER BY order_date")); + assert!(comparison_sql.contains("/ NULLIF")); + + let adapter_kind: Option = + serde_json::from_str(&wasm_detect_adapter_kind("orders.lkml", "").unwrap()).unwrap(); + assert_eq!(adapter_kind, Some("lookml".to_string())); + + let parsed_simple_agg: serde_json::Value = serde_json::from_str( + &wasm_parse_simple_metric_aggregation("COUNT(DISTINCT customer_id)").unwrap(), + ) + .unwrap(); + assert_eq!(parsed_simple_agg[0], "count_distinct"); + assert_eq!(parsed_simple_agg[1], "customer_id"); + + let metric_yaml = "name: revenue\nagg: sum\nsql: amount\n"; + assert_eq!(wasm_metric_to_sql(metric_yaml).unwrap(), "SUM(amount)"); + assert_eq!(wasm_metric_sql_expr(metric_yaml).unwrap(), "amount"); + assert!(wasm_metric_is_simple_aggregation(metric_yaml).unwrap()); + + let queries_json = serde_json::to_string(&vec![ + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.revenue dimensions=orders.status granularities=day".to_string(), + "select * from orders -- sidemantic: models=orders metrics=orders.count dimensions=orders.region".to_string(), + ]) + .unwrap(); + let patterns_json = wasm_extract_preaggregation_patterns(&queries_json).unwrap(); + let summary_json = wasm_summarize_preaggregation_patterns(&patterns_json, 2).unwrap(); + assert!(summary_json.contains("\"total_queries\":3")); + assert!(summary_json.contains("\"patterns_above_threshold\":1")); + let recommendations_json = + wasm_recommend_preaggregation_patterns(&patterns_json, 1, 0.0, Some(1)).unwrap(); + assert!(recommendations_json.contains("\"query_count\":2")); + let score = wasm_calculate_preaggregation_benefit_score( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + 2, + ) + .unwrap(); + assert!(score > 0.0); + let name = wasm_generate_preaggregation_name( + r#"{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.status"],"granularities":["day"],"count":2}"#, + ) + .unwrap(); + assert_eq!(name, "day_status_revenue"); + let definition_json = wasm_generate_preaggregation_definition( + r#"{"pattern":{"model":"orders","metrics":["orders.revenue"],"dimensions":["orders.created_at","orders.status"],"granularities":["day"],"count":2},"suggested_name":"day_created_at_status_revenue","query_count":2,"estimated_benefit_score":0.5}"#, + ) + .unwrap(); + assert!(definition_json.contains("\"name\":\"day_created_at_status_revenue\"")); + assert!(definition_json.contains("\"time_dimension\":\"created_at\"")); + assert!(definition_json.contains("\"granularity\":\"day\"")); + + assert!(wasm_validate_models_yaml(yaml).unwrap()); + + let parsed_reference: (String, String, Option) = + serde_json::from_str(&wasm_parse_reference_with_yaml(yaml, "orders.revenue").unwrap()) + .unwrap(); + assert_eq!( + parsed_reference, + ("orders".to_string(), "revenue".to_string(), None) + ); + + let model_inheritance_yaml = r#" +models: + - name: base + table: orders + primary_key: id + - name: child + extends: base + table: child_orders + primary_key: id +"#; + let resolved_models = wasm_resolve_model_inheritance_with_yaml(model_inheritance_yaml).unwrap(); + assert!(resolved_models.contains("name: base")); + assert!(resolved_models.contains("name: child")); + + let metric_inheritance_yaml = r#" +- name: base + agg: sum + sql: amount +- name: child + extends: base +"#; + let resolved_metrics = wasm_resolve_metric_inheritance(metric_inheritance_yaml).unwrap(); + assert!(resolved_metrics.contains("name: base")); + assert!(resolved_metrics.contains("name: child")); + + let dimension_yaml = r#" +name: created_at +type: time +sql: created_at +supported_granularities: [day, month] +"#; + assert_eq!( + wasm_dimension_sql_expr_with_yaml(dimension_yaml).unwrap(), + "created_at" + ); + assert_eq!( + wasm_dimension_with_granularity_with_yaml(dimension_yaml, "month").unwrap(), + "DATE_TRUNC('month', created_at)" + ); + + let model_hierarchy_yaml = r#" +dimensions: + - name: country + - name: state + parent: country + - name: city + parent: state +"#; + let hierarchy: Vec = serde_json::from_str( + &wasm_model_get_hierarchy_path_with_yaml(model_hierarchy_yaml, "city").unwrap(), + ) + .unwrap(); + assert_eq!(hierarchy, vec!["country", "state", "city"]); + + let drill_down: Option = serde_json::from_str( + &wasm_model_get_drill_down_with_yaml(model_hierarchy_yaml, "country").unwrap(), + ) + .unwrap(); + assert_eq!(drill_down, Some("state".to_string())); + let drill_up: Option = serde_json::from_str( + &wasm_model_get_drill_up_with_yaml(model_hierarchy_yaml, "city").unwrap(), + ) + .unwrap(); + assert_eq!(drill_up, Some("state".to_string())); + + let lookup_yaml = r#" +dimensions: + - name: status + - name: region +metrics: + - name: revenue + - name: count +segments: + - name: active + - name: priority +pre_aggregations: + - name: daily + - name: monthly +"#; + let dim_idx: Option = serde_json::from_str( + &wasm_model_find_dimension_index_with_yaml(lookup_yaml, "region").unwrap(), + ) + .unwrap(); + assert_eq!(dim_idx, Some(1)); + let metric_idx: Option = serde_json::from_str( + &wasm_model_find_metric_index_with_yaml(lookup_yaml, "count").unwrap(), + ) + .unwrap(); + assert_eq!(metric_idx, Some(1)); + let segment_idx: Option = serde_json::from_str( + &wasm_model_find_segment_index_with_yaml(lookup_yaml, "priority").unwrap(), + ) + .unwrap(); + assert_eq!(segment_idx, Some(1)); + let preagg_idx: Option = serde_json::from_str( + &wasm_model_find_pre_aggregation_index_with_yaml(lookup_yaml, "monthly").unwrap(), + ) + .unwrap(); + assert_eq!(preagg_idx, Some(1)); + + let relationship_yaml = r#" +name: customers +type: many_to_one +foreign_key: customer_id +"#; + assert_eq!( + wasm_relationship_sql_expr_with_yaml(relationship_yaml).unwrap(), + "customer_id" + ); + assert_eq!( + wasm_relationship_related_key_with_yaml(relationship_yaml).unwrap(), + "id" + ); + let fk_cols: Vec = serde_json::from_str( + &wasm_relationship_foreign_key_columns_with_yaml(relationship_yaml).unwrap(), + ) + .unwrap(); + assert_eq!(fk_cols, vec!["customer_id"]); + let pk_cols: Vec = serde_json::from_str( + &wasm_relationship_primary_key_columns_with_yaml(relationship_yaml).unwrap(), + ) + .unwrap(); + assert_eq!(pk_cols, vec!["id"]); + + assert_eq!( + wasm_segment_get_sql_with_yaml("sql: \"{model}.status = 'completed'\"\n", "orders_cte") + .unwrap(), + "orders_cte.status = 'completed'" + ); + + let metric_dependency_models_yaml = r#" +models: + - name: orders + table: orders + primary_key: order_id + metrics: + - name: revenue + agg: sum + sql: amount + - name: cost + agg: sum + sql: cost +"#; + let metric_dependency_yaml = r#" +name: margin +type: derived +sql: revenue / cost +"#; + let deps: Vec = serde_json::from_str( + &wasm_extract_metric_dependencies_from_yaml( + metric_dependency_yaml, + Some(metric_dependency_models_yaml.to_string()), + Some("orders".to_string()), + ) + .unwrap(), + ) + .unwrap(); + assert_eq!(deps, vec!["orders.cost", "orders.revenue"]); + + assert_eq!( + wasm_evaluate_table_calculation_expression("1 + 2 * 3").unwrap(), + 7.0 + ); + assert!(wasm_validate_table_formula_expression("${a} + ${b}").unwrap()); + let refresh_valid: serde_json::Value = serde_json::from_str( + &wasm_validate_engine_refresh_sql_compatibility("SELECT 1", "snowflake").unwrap(), + ) + .unwrap(); + assert_eq!(refresh_valid["is_valid"], true); + assert_eq!(refresh_valid["error"], serde_json::Value::Null); + let refresh_statements: Vec = serde_json::from_str( + &wasm_build_preaggregation_refresh_statements( + "incremental", + "orders_preagg_daily_revenue", + "SELECT order_date, SUM(revenue) AS total_revenue FROM orders GROUP BY order_date", + Some("order_date".to_string()), + Some("2026-01-01".to_string()), + None, + None, + None, + ) + .unwrap(), + ) + .unwrap(); + assert_eq!(refresh_statements.len(), 2); + assert!(refresh_statements[1].contains("INSERT INTO orders_preagg_daily_revenue")); + let symmetric_sql = wasm_build_symmetric_aggregate_sql( + "amount", + "order_id", + "sum", + Some("orders_cte".to_string()), + "duckdb", + ) + .unwrap(); + assert!(symmetric_sql.contains("SUM(DISTINCT")); + assert!(symmetric_sql.contains("orders_cte.order_id")); + assert!(wasm_needs_symmetric_aggregate("one_to_many", true)); + assert!(!wasm_needs_symmetric_aggregate("many_to_one", true)); +} diff --git a/sidemantic-rs/tests/workbench_pty_smoke.py b/sidemantic-rs/tests/workbench_pty_smoke.py new file mode 100644 index 00000000..32d7d498 --- /dev/null +++ b/sidemantic-rs/tests/workbench_pty_smoke.py @@ -0,0 +1,222 @@ +import fcntl +import os +import pty +import select +import struct +import subprocess +import sys +import tempfile +import termios +import time +from pathlib import Path + +MODEL_YAML = """ +models: + - name: orders + table: orders + primary_key: order_id + dimensions: + - name: status + type: categorical + metrics: + - name: revenue + agg: sum + sql: amount +""" + + +def write_models(root: Path) -> Path: + path = root / "models.yml" + path.write_text(MODEL_YAML) + return path + + +def spawn_pty(args, env=None): + master_fd, slave_fd = pty.openpty() + winsize = struct.pack("HHHH", 32, 120, 0, 0) + fcntl.ioctl(slave_fd, termios.TIOCSWINSZ, winsize) + proc = subprocess.Popen( + args, + stdin=slave_fd, + stdout=slave_fd, + stderr=subprocess.PIPE, + env=env, + close_fds=True, + ) + os.close(slave_fd) + os.set_blocking(master_fd, False) + return proc, master_fd + + +def read_available(fd): + chunks = [] + while True: + ready, _, _ = select.select([fd], [], [], 0) + if not ready: + break + try: + chunk = os.read(fd, 65536) + except BlockingIOError: + break + except OSError: + break + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks).decode(errors="replace") + + +def wait_for(fd, needle, timeout=5): + deadline = time.time() + timeout + output = "" + while time.time() < deadline: + output += read_available(fd) + if needle in output: + return output + time.sleep(0.05) + raise AssertionError(f"timed out waiting for {needle!r}; output={output!r}") + + +def send_key(fd, data): + os.write(fd, data) + time.sleep(0.1) + + +def close_pty_process(proc, fd): + try: + if proc.poll() is None: + send_key(fd, b"\x1b") + try: + proc.wait(timeout=2) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait(timeout=5) + finally: + try: + os.close(fd) + except OSError: + pass + + +def run_bad_model_path(binary: str, root: Path): + missing = root / "missing-models" + proc, fd = spawn_pty([binary, "workbench", str(missing)]) + try: + proc.wait(timeout=5) + output = read_available(fd) + stderr = proc.stderr.read().decode(errors="replace") + assert proc.returncode != 0, "bad model path should fail" + assert "does not exist" in stderr or "does not exist" in output, (output, stderr) + finally: + close_pty_process(proc, fd) + + +def run_no_db_workbench(binary: str, models_path: Path): + proc, fd = spawn_pty([binary, "workbench", str(models_path)]) + try: + first = wait_for(fd, "Workbench") + assert "connection=none" in first, first + assert "SQL Input" in first, first + + send_key(fd, b"\x05") + wait_for(fd, "configured") + + send_key(fd, b"\x1b[18~") + wait_for(fd, "TABLE") + + send_key(fd, b"\x1b") + proc.wait(timeout=5) + assert proc.returncode == 0, f"workbench quit failed: {proc.returncode}" + finally: + close_pty_process(proc, fd) + + +def seed_duckdb(binary: str, models_path: Path, db_path: Path, driver: str): + for sql in [ + "drop table if exists orders", + "create table orders(order_id integer, status varchar, amount double)", + "insert into orders values (1, 'complete', 10.5), (2, 'complete', 20.0), (3, 'cancelled', 7.0)", + ]: + result = subprocess.run( + [ + binary, + "query", + "--models", + str(models_path), + "--sql", + sql, + "--driver", + driver, + "--entrypoint", + "duckdb_adbc_init", + "--uri", + str(db_path), + ], + check=False, + capture_output=True, + text=True, + ) + assert result.returncode == 0, result.stderr + + +def write_driver_manifest(root: Path, driver: str) -> Path: + driver_dir = root / "adbc-drivers" + driver_dir.mkdir() + (driver_dir / "adbc_driver_duckdb.toml").write_text( + f""" +manifest_version = 1 +name = "DuckDB" +version = "1.0.0" + +[Driver] +entrypoint = "duckdb_adbc_init" +shared = "{driver}" +""" + ) + return driver_dir + + +def run_db_workbench(binary: str, models_path: Path, root: Path): + driver = os.environ.get("SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER") + if not driver: + raise AssertionError("SIDEMANTIC_TEST_ADBC_DUCKDB_DRIVER is required for workbench-adbc PTY test") + + db_path = root / "warehouse.duckdb" + seed_duckdb(binary, models_path, db_path, driver) + driver_dir = write_driver_manifest(root, driver) + env = os.environ.copy() + existing = env.get("ADBC_DRIVER_PATH") + env["ADBC_DRIVER_PATH"] = str(driver_dir) if not existing else f"{driver_dir}{os.pathsep}{existing}" + + proc, fd = spawn_pty([binary, "workbench", str(models_path), "--db", str(db_path)], env=env) + try: + first = wait_for(fd, "Workbench") + assert "connection=duckdb:///" in first, first + + send_key(fd, b"\x05") + output = wait_for(fd, "complete") + assert "complete" in output or "cancelled" in output, output + + send_key(fd, b"\x1b") + proc.wait(timeout=5) + assert proc.returncode == 0, f"workbench DB quit failed: {proc.returncode}" + finally: + close_pty_process(proc, fd) + + +def main(): + if len(sys.argv) != 2: + raise SystemExit("usage: workbench_pty_smoke.py ") + + binary = sys.argv[1] + with tempfile.TemporaryDirectory(prefix="sidemantic-workbench-pty-") as tmp: + root = Path(tmp) + models_path = write_models(root) + run_bad_model_path(binary, root) + run_no_db_workbench(binary, models_path) + if os.environ.get("SIDEMANTIC_WORKBENCH_PTY_EXPECT_ADBC") == "1": + run_db_workbench(binary, models_path, root) + + +if __name__ == "__main__": + main() diff --git a/sidemantic-rs/tests/workbench_pty_smoke.rs b/sidemantic-rs/tests/workbench_pty_smoke.rs new file mode 100644 index 00000000..5e68e5e8 --- /dev/null +++ b/sidemantic-rs/tests/workbench_pty_smoke.rs @@ -0,0 +1,53 @@ +#![cfg(feature = "workbench-tui")] + +use std::path::PathBuf; +use std::process::Command; + +#[test] +fn dedicated_workbench_binary_handles_help_directly() { + let output = Command::new(env!("CARGO_BIN_EXE_sidemantic-workbench")) + .arg("--help") + .output() + .expect("sidemantic-workbench should run directly"); + + assert!( + output.status.success(), + "sidemantic-workbench --help failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + assert!( + String::from_utf8_lossy(&output.stdout).contains("Usage: sidemantic workbench"), + "unexpected stdout:\n{}", + String::from_utf8_lossy(&output.stdout) + ); +} + +#[test] +fn workbench_pty_exercises_launch_keys_quit_and_execution_states() { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let script = manifest_dir.join("tests/workbench_pty_smoke.py"); + let mut command = Command::new("uv"); + command + .arg("run") + .arg("--no-project") + .arg(&script) + .arg(env!("CARGO_BIN_EXE_sidemantic")); + + if cfg!(feature = "workbench-adbc") { + command.env("SIDEMANTIC_WORKBENCH_PTY_EXPECT_ADBC", "1"); + } else { + command.env_remove("SIDEMANTIC_WORKBENCH_PTY_EXPECT_ADBC"); + } + + let output = command + .output() + .expect("uv should run the workbench PTY smoke harness"); + + assert!( + output.status.success(), + "workbench PTY smoke failed\nstdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); +} diff --git a/sidemantic/core/semantic_graph.py b/sidemantic/core/semantic_graph.py index 6e91f336..1b4b15b8 100644 --- a/sidemantic/core/semantic_graph.py +++ b/sidemantic/core/semantic_graph.py @@ -43,10 +43,16 @@ def __init__(self): self.metrics: dict[str, Metric] = {} self.table_calculations: dict[str, TableCalculation] = {} self.parameters: dict[str, Parameter] = {} + self._version = 0 + self._adjacency_dirty = True self._adjacency: dict[ str, list[tuple[str, list[str], list[str], str]] ] = {} # model -> [(to_model, from_keys, to_keys, rel_type)] + def _mark_dirty(self) -> None: + self._version += 1 + self._adjacency_dirty = True + def add_model(self, model: Model) -> None: """Add a model to the graph. @@ -68,7 +74,7 @@ def add_model(self, model: Model) -> None: if metric.name not in self.metrics: self.metrics[metric.name] = metric - self._adjacency_dirty = True + self._mark_dirty() def add_metric(self, measure: Metric) -> None: """Add a measure to the graph. @@ -80,6 +86,7 @@ def add_metric(self, measure: Metric) -> None: raise ValueError(f"Measure {measure.name} already exists") self.metrics[measure.name] = measure + self._mark_dirty() def add_table_calculation(self, calc: TableCalculation) -> None: """Add a table calculation to the graph. @@ -91,6 +98,7 @@ def add_table_calculation(self, calc: TableCalculation) -> None: raise ValueError(f"Table calculation {calc.name} already exists") self.table_calculations[calc.name] = calc + self._mark_dirty() def get_table_calculation(self, name: str) -> TableCalculation: """Get a table calculation by name. @@ -122,6 +130,7 @@ def add_parameter(self, param: Parameter) -> None: raise ValueError(f"Parameter {param.name} already exists") self.parameters[param.name] = param + self._mark_dirty() def get_parameter(self, name: str) -> Parameter: """Get a parameter by name. diff --git a/sidemantic/core/semantic_layer.py b/sidemantic/core/semantic_layer.py index e27a69f2..9a821dc1 100644 --- a/sidemantic/core/semantic_layer.py +++ b/sidemantic/core/semantic_layer.py @@ -2,13 +2,21 @@ from __future__ import annotations +import os +from collections.abc import Callable from pathlib import Path +import yaml + from sidemantic.core.metric import Metric from sidemantic.core.model import Model from sidemantic.core.semantic_graph import SemanticGraph +from sidemantic.rust_bridge import get_rust_module, graph_to_rust_yaml +from sidemantic.rust_parity import is_strict_for from sidemantic.sql.generator import SQLGenerator +_RUST_SQL_OUTPUT_DIALECT = "duckdb" + class SemanticLayer: """Main semantic layer interface. @@ -52,9 +60,31 @@ def __init__( from sidemantic.db.base import BaseDatabaseAdapter self.graph = SemanticGraph() + self._sql_rewrite_cache: dict[tuple[object, ...], str] = {} + self._sql_rewrite_cache_limit = 256 self.use_preaggregations = use_preaggregations self.preagg_database = preagg_database self.preagg_schema = preagg_schema + self._strict_rust_sql_generator_entrypoint = is_strict_for("sql_generator_entrypoint") + self._strict_rust_query_validation = is_strict_for("semantic_core_query_validation") + self._use_rust_sql_generator = ( + os.getenv("SIDEMANTIC_RS_SQL_GENERATOR", "0") == "1" or self._strict_rust_sql_generator_entrypoint + ) + self._use_rust_query_validation = ( + os.getenv("SIDEMANTIC_RS_QUERY_VALIDATION", "0") == "1" or self._strict_rust_query_validation + ) + self._rust_sql_verify = ( + os.getenv("SIDEMANTIC_RS_SQL_GENERATOR_VERIFY", "1") == "1" + and not self._strict_rust_sql_generator_entrypoint + ) + self._rust_no_fallback = os.getenv("SIDEMANTIC_RS_NO_FALLBACK", "0") == "1" + self._rust_module = None + if self._use_rust_sql_generator: + try: + self._rust_module = get_rust_module() + except Exception: + if self._rust_no_fallback or self._strict_rust_sql_generator_entrypoint: + raise # Initialize adapter from connection string or use provided adapter if isinstance(connection, BaseDatabaseAdapter): @@ -207,6 +237,7 @@ def add_model(self, model: Model) -> None: ) self.graph.add_model(model) + self._sql_rewrite_cache.clear() def _normalize_model_table(self, model: Model) -> None: """Normalize model.table for the active dialect when needed.""" @@ -423,6 +454,7 @@ def add_metric(self, measure: Metric) -> None: ) self.graph.add_metric(measure) + self._sql_rewrite_cache.clear() def query( self, @@ -515,13 +547,99 @@ def compile( dimensions = dimensions or [] # Validate query - errors = validate_query(metrics, dimensions, self.graph) + errors = self._validate_query(metrics, dimensions, validate_query) if errors: raise QueryValidationError("Query validation failed:\n" + "\n".join(f" - {e}" for e in errors)) # Determine if pre-aggregations should be used use_preaggs = use_preaggregations if use_preaggregations is not None else self.use_preaggregations + inner_sql = None + if self._use_rust_sql_generator: + inner_sql = self._compile_with_rust( + metrics=metrics, + dimensions=dimensions, + filters=filters, + segments=segments, + order_by=order_by, + limit=limit, + offset=offset, + dialect=dialect, + ungrouped=ungrouped, + parameters=parameters, + use_preaggregations=use_preaggs, + ) + if inner_sql is None and self._strict_rust_sql_generator_entrypoint: + raise ValueError("Rust SQL generator returned no SQL in strict mode") + if inner_sql is not None and self._rust_sql_verify: + python_sql = self._compile_with_python( + metrics=metrics, + dimensions=dimensions, + filters=filters, + segments=segments, + order_by=order_by, + limit=limit, + offset=offset, + dialect=dialect, + ungrouped=ungrouped, + parameters=parameters, + use_preaggregations=use_preaggs, + ) + if inner_sql.strip() != python_sql.strip(): + if self._rust_no_fallback or self._strict_rust_sql_generator_entrypoint: + raise ValueError("Rust SQL generator output mismatch with Python SQL generator") + inner_sql = python_sql + + if inner_sql is None: + inner_sql = self._compile_with_python( + metrics=metrics, + dimensions=dimensions, + filters=filters, + segments=segments, + order_by=order_by, + limit=limit, + offset=offset, + dialect=dialect, + ungrouped=ungrouped, + parameters=parameters, + use_preaggregations=use_preaggs, + ) + + return self._apply_post_process(inner_sql, post_process) + + def _validate_query( + self, + metrics: list[str], + dimensions: list[str], + python_validate_query: Callable[[list[str], list[str], SemanticGraph], list[str]], + ) -> list[str]: + from sidemantic.validation import QueryValidationError + + if self._use_rust_query_validation: + try: + from sidemantic.rust_bridge import validate_query_with_rust + + return validate_query_with_rust(self.graph, metrics, dimensions) + except Exception as e: + if self._strict_rust_query_validation or self._rust_no_fallback: + raise QueryValidationError(f"Rust query validation failed: {e}") from e + + return python_validate_query(metrics, dimensions, self.graph) + + def _compile_with_python( + self, + metrics: list[str] | None, + dimensions: list[str] | None, + filters: list[str] | None, + segments: list[str] | None, + order_by: list[str] | None, + limit: int | None, + offset: int | None, + dialect: str | None, + ungrouped: bool, + parameters: dict[str, any] | None, + use_preaggregations: bool, + ) -> str: generator = SQLGenerator( self.graph, dialect=dialect or self.dialect, @@ -529,7 +647,7 @@ def compile( preagg_schema=self.preagg_schema, ) - inner_sql = generator.generate( + return generator.generate( metrics=metrics, dimensions=dimensions, filters=filters, @@ -539,9 +657,84 @@ def compile( offset=offset, ungrouped=ungrouped, parameters=parameters, - use_preaggregations=use_preaggs, + use_preaggregations=use_preaggregations, ) + def _compile_with_rust( + self, + metrics: list[str] | None, + dimensions: list[str] | None, + filters: list[str] | None, + segments: list[str] | None, + order_by: list[str] | None, + limit: int | None, + offset: int | None, + dialect: str | None, + ungrouped: bool, + parameters: dict[str, any] | None, + use_preaggregations: bool, + ) -> str | None: + if not self._rust_module: + if self._rust_no_fallback or self._strict_rust_sql_generator_entrypoint: + raise ValueError("Rust SQL generator backend is not initialized") + return None + + payload = { + "metrics": metrics or [], + "dimensions": dimensions or [], + "filters": list(filters or []), + "parameter_values": parameters or {}, + "segments": segments or [], + "order_by": order_by or [], + "limit": limit, + "offset": offset, + "ungrouped": ungrouped, + "use_preaggregations": bool(use_preaggregations), + "preagg_database": self.preagg_database, + "preagg_schema": self.preagg_schema, + } + + try: + models_yaml = graph_to_rust_yaml(self.graph) + query_yaml = yaml.safe_dump(payload, sort_keys=False) + sql = self._rust_module.compile_with_yaml(models_yaml, query_yaml) + + target_dialect = dialect or self.dialect + if target_dialect != _RUST_SQL_OUTPUT_DIALECT: + import sqlglot + + sql = sqlglot.transpile(sql, read=_RUST_SQL_OUTPUT_DIALECT, write=target_dialect)[0] + if target_dialect == "bigquery": + sql = sql.replace("TIMESTAMP_TRUNC(", "DATE_TRUNC(") + + if "-- sidemantic:" not in sql: + generator = SQLGenerator( + self.graph, + dialect=dialect or self.dialect, + preagg_database=self.preagg_database, + preagg_schema=self.preagg_schema, + ) + segment_filters = generator._resolve_segments(segments or []) + all_filters = list(filters or []) + segment_filters + model_names = generator._find_required_models(metrics or [], dimensions or [], all_filters) + sql = ( + sql + + "\n" + + generator._generate_instrumentation_comment( + model_names, + metrics or [], + dimensions or [], + used_preagg=False, + ) + ) + + return sql + except Exception as e: + if self._rust_no_fallback or self._strict_rust_sql_generator_entrypoint: + raise ValueError(f"Rust SQL generator failed: {e}") from e + return None + + def _apply_post_process(self, inner_sql: str, post_process: str | None) -> str: if post_process is not None: if "{inner}" not in post_process: raise ValueError("post_process must contain a {inner} placeholder") @@ -1012,8 +1205,20 @@ def sql(self, query: str): """ from sidemantic.sql.query_rewriter import QueryRewriter - rewriter = QueryRewriter(self.graph, dialect=self.dialect) - rewritten_sql = rewriter.rewrite(query) + cache_key = ( + getattr(self.graph, "_version", 0), + self.dialect, + os.getenv("SIDEMANTIC_RS_REWRITER", "0"), + os.getenv("SIDEMANTIC_RS_NO_FALLBACK", "0"), + query, + ) + rewritten_sql = self._sql_rewrite_cache.get(cache_key) + if rewritten_sql is None: + rewriter = QueryRewriter(self.graph, dialect=self.dialect) + rewritten_sql = rewriter.rewrite(query) + if len(self._sql_rewrite_cache) >= self._sql_rewrite_cache_limit: + self._sql_rewrite_cache.pop(next(iter(self._sql_rewrite_cache))) + self._sql_rewrite_cache[cache_key] = rewritten_sql return self.adapter.execute(rewritten_sql) diff --git a/sidemantic/rust_bridge.py b/sidemantic/rust_bridge.py new file mode 100644 index 00000000..26b84d27 --- /dev/null +++ b/sidemantic/rust_bridge.py @@ -0,0 +1,1540 @@ +"""Helpers for calling sidemantic-rs from Python.""" + +from __future__ import annotations + +import json +import re +from pathlib import Path + +import yaml + +from sidemantic.core.semantic_graph import SemanticGraph + + +def get_rust_module() -> object: + """Import and return sidemantic_rs extension module.""" + try: + import sidemantic_rs + except ImportError as e: + raise ValueError( + "Rust backend requires the sidemantic_rs Python extension. " + "Build it with: uv run --with maturin maturin develop " + "--manifest-path sidemantic-rs/Cargo.toml --features python-adbc" + ) from e + return sidemantic_rs + + +def graph_to_rust_yaml(graph: SemanticGraph) -> str: + """Serialize semantic graph to sidemantic-rs YAML schema.""" + extra_metrics_by_model, remaining_top_level_metrics = _assign_top_level_metrics_for_rust(graph) + return models_to_rust_yaml( + list(graph.models.values()), + extra_metrics_by_model=extra_metrics_by_model, + top_level_metrics=remaining_top_level_metrics, + top_level_parameters=list(graph.parameters.values()), + ) + + +def _assign_top_level_metrics_for_rust(graph: SemanticGraph) -> tuple[dict[str, list], list]: + """Assign Python graph-level metrics to model payloads for sidemantic-rs. + + Python can keep derived/ratio metrics at graph scope even when dependencies span + models. The Rust YAML loader requires a single owner before it builds the graph, + so choose one deterministically in the bridge and leave only unresolvable metrics + at top level. + """ + top_level_metrics = list(graph.metrics.values()) + top_level_by_name = {metric.name: metric for metric in top_level_metrics} + cache: dict[str, str | None] = {} + + def model_metric_owners(metric_name: str) -> set[str]: + return { + model.name + for model in graph.models.values() + if any(model_metric.name == metric_name for model_metric in model.metrics) + } + + def owner_from_dotted_reference(reference: str) -> str | None: + if "." not in reference: + return None + model_name = reference.split(".", 1)[0] + return model_name if model_name in graph.models else None + + def owners_from_sql_fragment(fragment: str) -> list[str]: + owners = [] + for model_name, _field_name in re.findall(r"\b([A-Za-z_][A-Za-z0-9_]*)\.([A-Za-z_][A-Za-z0-9_]*)\b", fragment): + if model_name in graph.models and model_name not in owners: + owners.append(model_name) + return owners + + def model_owners_for_entity(entity: str | None) -> set[str]: + if not isinstance(entity, str): + return set() + dotted_owner = owner_from_dotted_reference(entity) + if dotted_owner: + return {dotted_owner} + return { + model.name + for model in graph.models.values() + if any(dimension.name == entity for dimension in model.dimensions) + } + + def metric_reference_strings(metric) -> list[str]: + references = [] + for attr in ( + "sql", + "base_metric", + "numerator", + "denominator", + "entity", + "base_event", + "conversion_event", + "cohort_event", + "activity_event", + "having", + ): + value = getattr(metric, attr, None) + if isinstance(value, str): + references.append(value) + steps = getattr(metric, "steps", None) + if isinstance(steps, list): + references.extend(step for step in steps if isinstance(step, str)) + for inner in getattr(metric, "inner_metrics", None) or []: + if isinstance(inner, dict): + value = inner.get("sql") + if isinstance(value, str): + references.append(value) + entity_dimensions = getattr(metric, "entity_dimensions", None) + if isinstance(entity_dimensions, list): + references.extend(value for value in entity_dimensions if isinstance(value, str)) + return references + + def preferred_ratio_owner(metric) -> str | None: + if isinstance(metric.denominator, str): + owner = owner_from_dotted_reference(metric.denominator) + if owner: + return owner + if isinstance(metric.sql, str) and "/" in metric.sql: + denominator_sql = metric.sql.split("/", 1)[1] + owners = owners_from_sql_fragment(denominator_sql) + if owners: + return owners[0] + return None + + def resolve_owner(metric) -> str | None: + cached = cache.get(metric.name) + if metric.name in cache: + return cached + + cache[metric.name] = None + owners = model_metric_owners(metric.name) + + try: + dependencies = metric.get_dependencies(graph) + except Exception: + dependencies = set() + + for dep in sorted(dependencies): + dotted_owner = owner_from_dotted_reference(dep) + if dotted_owner: + owners.add(dotted_owner) + continue + if dep in top_level_by_name: + dep_owner = resolve_owner(top_level_by_name[dep]) + if dep_owner: + owners.add(dep_owner) + continue + owners.update(model_metric_owners(dep)) + + for reference in metric_reference_strings(metric): + dotted_owner = owner_from_dotted_reference(reference) + if dotted_owner: + owners.add(dotted_owner) + owners.update(owners_from_sql_fragment(reference)) + + if not owners: + owners.update(model_owners_for_entity(getattr(metric, "entity", None))) + + if not owners and len(graph.models) == 1: + owners.add(next(iter(graph.models))) + + preferred_owner = preferred_ratio_owner(metric) + owner = preferred_owner if preferred_owner in owners else (sorted(owners)[0] if owners else None) + cache[metric.name] = owner + return owner + + assigned: dict[str, list] = {} + remaining = [] + for metric in top_level_metrics: + owner = resolve_owner(metric) + if owner: + assigned.setdefault(owner, []).append(metric) + else: + remaining.append(metric) + + return assigned, remaining + + +def _normalize_metric_type(metric_payload: dict, *, empty_filters_to_none: bool = False) -> dict: + normalized = dict(metric_payload) + metric_type = normalized.get("type") + if metric_type == "simple": + normalized["type"] = None + elif metric_type == "timecomparison": + normalized["type"] = "time_comparison" + if empty_filters_to_none and normalized.get("filters") == []: + normalized["filters"] = None + return normalized + + +def load_graph_from_yaml_with_rust(yaml_content: str) -> SemanticGraph: + """Parse native YAML definitions via sidemantic-rs and build a Python SemanticGraph.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.load_graph_with_yaml(yaml_content)) + return _graph_from_loaded_payload(payload) + + +def load_graph_from_sql_with_rust(sql_content: str) -> SemanticGraph: + """Parse SQL file content definitions via sidemantic-rs and build a Python SemanticGraph.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.load_graph_with_sql(sql_content)) + return _graph_from_loaded_payload(payload) + + +def load_graph_from_directory_with_rust(directory: str | Path) -> SemanticGraph: + """Parse supported directory definitions via sidemantic-rs and build a Python SemanticGraph.""" + rust_module = get_rust_module() + if not hasattr(rust_module, "load_graph_from_directory"): + raise ValueError( + "Rust backend requires a sidemantic_rs build with load_graph_from_directory. " + "Rebuild it with: uv run --with maturin maturin develop " + "--manifest-path sidemantic-rs/Cargo.toml --features python-adbc" + ) + payload = json.loads(rust_module.load_graph_from_directory(str(directory))) + return _graph_from_loaded_payload(payload) + + +def _graph_from_loaded_payload(payload: dict) -> SemanticGraph: + """Build a Python graph from sidemantic-rs loaded graph payload JSON.""" + + from sidemantic.core.metric import Metric + from sidemantic.core.model import Model + from sidemantic.core.parameter import Parameter + + graph = SemanticGraph() + top_level_metric_names = {metric["name"] for metric in payload.get("top_level_metrics") or []} + original_model_metrics = payload.get("original_model_metrics") or {} + model_sources = payload.get("model_sources") or {} + + for model_data in payload.get("models") or []: + normalized_model = dict(model_data) + original_metric_names = set(original_model_metrics.get(normalized_model.get("name"), [])) + normalized_metrics = [] + for metric_data in normalized_model.get("metrics") or []: + normalized_metric = _normalize_metric_type(metric_data) + metric_name = normalized_metric.get("name") + if metric_name in top_level_metric_names and metric_name not in original_metric_names: + continue + normalized_metrics.append(normalized_metric) + normalized_model["metrics"] = normalized_metrics + model = Model(**normalized_model) + source_metadata = model_sources.get(model.name) or {} + source_format = source_metadata.get("source_format") + source_file = source_metadata.get("source_file") + if source_format and not hasattr(model, "_source_format"): + model._source_format = source_format + if source_file and not hasattr(model, "_source_file"): + model._source_file = source_file + graph.add_model(model) + + for metric_data in payload.get("top_level_metrics") or []: + metric = Metric(**_normalize_metric_type(metric_data)) + graph.add_metric(metric) + + for parameter_data in payload.get("parameters") or []: + parameter = Parameter(**parameter_data) + graph.add_parameter(parameter) + + return graph + + +def parse_sql_definitions_with_rust(sql: str) -> tuple[list, list]: + """Parse SQL metric/segment definitions via sidemantic-rs.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.parse_sql_definitions_payload(sql)) + + from sidemantic.core.metric import Metric + from sidemantic.core.segment import Segment + + metrics = [ + Metric(**_normalize_metric_type(metric_payload, empty_filters_to_none=True)) + for metric_payload in payload.get("metrics") or [] + ] + segments = [Segment(**segment_payload) for segment_payload in payload.get("segments") or []] + return metrics, segments + + +def parse_sql_graph_definitions_with_rust(sql: str) -> tuple[list, list, list]: + """Parse SQL graph definitions (metrics, segments, parameters) via sidemantic-rs.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.parse_sql_graph_definitions_payload(sql)) + + from sidemantic.core.metric import Metric + from sidemantic.core.parameter import Parameter + from sidemantic.core.segment import Segment + + metrics = [ + Metric(**_normalize_metric_type(metric_payload, empty_filters_to_none=True)) + for metric_payload in payload.get("metrics") or [] + ] + segments = [Segment(**segment_payload) for segment_payload in payload.get("segments") or []] + parameters = [Parameter(**parameter_payload) for parameter_payload in payload.get("parameters") or []] + return metrics, segments, parameters + + +def parse_sql_graph_definitions_extended_with_rust(sql: str) -> tuple[list, list, list, list]: + """Parse SQL graph definitions including pre-aggregations via sidemantic-rs.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.parse_sql_graph_definitions_payload(sql)) + + from sidemantic.core.metric import Metric + from sidemantic.core.parameter import Parameter + from sidemantic.core.pre_aggregation import PreAggregation + from sidemantic.core.segment import Segment + + metrics = [ + Metric(**_normalize_metric_type(metric_payload, empty_filters_to_none=True)) + for metric_payload in payload.get("metrics") or [] + ] + segments = [Segment(**segment_payload) for segment_payload in payload.get("segments") or []] + parameters = [Parameter(**parameter_payload) for parameter_payload in payload.get("parameters") or []] + pre_aggregations = [PreAggregation(**preagg_payload) for preagg_payload in payload.get("pre_aggregations") or []] + return metrics, segments, parameters, pre_aggregations + + +def parse_sql_model_with_rust(sql: str): + """Parse SQL MODEL/DIMENSION/METRIC/SEGMENT/RELATIONSHIP definitions via sidemantic-rs.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.parse_sql_model_payload(sql)) + + from sidemantic.core.model import Model + + normalized_model = dict(payload) + normalized_model["metrics"] = [ + _normalize_metric_type(metric_payload, empty_filters_to_none=True) + for metric_payload in normalized_model.get("metrics") or [] + ] + return Model(**normalized_model) + + +def parse_sql_statement_blocks_with_rust(sql: str) -> list[dict]: + """Parse raw SQL statement blocks via sidemantic-rs.""" + rust_module = get_rust_module() + payload = json.loads(rust_module.parse_sql_statement_blocks_payload(sql)) + if isinstance(payload, list): + return payload + return [] + + +def models_to_rust_yaml( + models: list, + *, + extra_metrics_by_model: dict[str, list] | None = None, + top_level_metrics: list | None = None, + top_level_parameters: list | None = None, + include_extends: bool = False, +) -> str: + """Serialize model list to sidemantic-rs YAML schema.""" + serialized_models = [] + extra_metrics_by_model = extra_metrics_by_model or {} + top_level_metrics = top_level_metrics or [] + top_level_parameters = top_level_parameters or [] + models_by_name = {m.name: m for m in models} + + for model in models: + primary_key_columns = model.primary_key if isinstance(model.primary_key, list) else [model.primary_key] + model_data = { + "name": model.name, + "extends": model.extends if include_extends else None, + "table": model.table or (model.name if not model.sql else None), + "sql": model.sql, + "source_uri": model.source_uri, + "primary_key": primary_key_columns[0] if primary_key_columns else "id", + "primary_key_columns": primary_key_columns, + "unique_keys": model.unique_keys, + "description": model.description, + "label": None, + "default_time_dimension": model.default_time_dimension, + "default_grain": model.default_grain, + "dimensions": [], + "metrics": [], + "relationships": [], + "segments": [], + "pre_aggregations": [], + } + + for dimension in model.dimensions: + model_data["dimensions"].append( + { + "name": dimension.name, + "type": dimension.type, + "sql": dimension.sql, + "granularity": dimension.granularity, + "supported_granularities": dimension.supported_granularities, + "description": dimension.description, + "label": dimension.label, + "format": dimension.format, + "value_format_name": dimension.value_format_name, + "parent": dimension.parent, + "window": dimension.window, + } + ) + + serialized_metric_names = set() + for metric in [*model.metrics, *extra_metrics_by_model.get(model.name, [])]: + if metric.name in serialized_metric_names: + continue + serialized_metric_names.add(metric.name) + model_data["metrics"].append(_serialize_metric(metric, primary_key_columns=primary_key_columns)) + + for relationship in model.relationships: + rel_payload = _serialize_relationship( + relationship, + source_model=model, + target_model=models_by_name.get(relationship.name), + ) + if rel_payload: + model_data["relationships"].append(rel_payload) + + for segment in model.segments: + model_data["segments"].append( + { + "name": segment.name, + "sql": segment.sql, + "description": segment.description, + "public": segment.public, + } + ) + + for pre_aggregation in model.pre_aggregations: + model_data["pre_aggregations"].append(_serialize_pre_aggregation(pre_aggregation)) + + serialized_models.append(model_data) + + payload = {"models": serialized_models} + if top_level_metrics: + payload["metrics"] = [_serialize_metric(metric, primary_key_columns=None) for metric in top_level_metrics] + if top_level_parameters: + payload["parameters"] = [_serialize_parameter(parameter) for parameter in top_level_parameters] + + return yaml.safe_dump(payload, sort_keys=False) + + +def validate_query_with_rust(graph: SemanticGraph, metrics: list[str], dimensions: list[str]) -> list[str]: + """Validate query references and join reachability via sidemantic-rs.""" + rust_module = get_rust_module() + graph_yaml = graph_to_rust_yaml(graph) + metric_refs = list(metrics or []) + dimension_refs = list(dimensions or []) + + def _validate_with_legacy_payload() -> list[str]: + payload = { + "metrics": metric_refs, + "dimensions": dimension_refs, + } + return rust_module.validate_query_with_yaml( + graph_yaml, + yaml.safe_dump(payload, sort_keys=False), + ) + + if hasattr(rust_module, "validate_query_references"): + try: + errors = rust_module.validate_query_references( + graph_yaml, + metric_refs, + dimension_refs, + ) + except TypeError: + # Compatibility fallback for older sidemantic-rs extension builds. + errors = _validate_with_legacy_payload() + else: + # Compatibility fallback for older sidemantic-rs extension builds. + errors = _validate_with_legacy_payload() + return [str(error) for error in errors] + + +def validate_models_payload_with_rust( + models: list, + *, + top_level_metrics: list | None = None, + top_level_parameters: list | None = None, + include_extends: bool = True, +) -> bool: + """Validate model payload set via sidemantic-rs loader/graph semantics.""" + rust_module = get_rust_module() + models_yaml = models_to_rust_yaml( + models, + top_level_metrics=top_level_metrics or [], + top_level_parameters=top_level_parameters or [], + include_extends=include_extends, + ) + return bool(rust_module.validate_models_yaml(models_yaml)) + + +def validate_model_payload_with_rust(model_obj) -> bool: + """Validate model payload shape via sidemantic-rs.""" + rust_module = get_rust_module() + payload = yaml.safe_load(models_to_rust_yaml([model_obj], include_extends=False)) or {} + model_payload = (payload.get("models") or [{}])[0] + model_yaml = yaml.safe_dump(model_payload, sort_keys=False) + return bool(rust_module.validate_model_payload(model_yaml)) + + +def validate_metric_payload_with_rust(metric_obj) -> bool: + """Validate metric payload shape via sidemantic-rs.""" + rust_module = get_rust_module() + metric_yaml = yaml.safe_dump(metric_obj.model_dump(exclude_none=True), sort_keys=False) + return bool(rust_module.validate_metric_payload(metric_yaml)) + + +def validate_parameter_payload_with_rust(parameter_obj) -> bool: + """Validate parameter payload shape via sidemantic-rs.""" + rust_module = get_rust_module() + parameter_yaml = yaml.safe_dump(_serialize_parameter(parameter_obj), sort_keys=False) + return bool(rust_module.validate_parameter_payload(parameter_yaml)) + + +def validate_table_calculation_payload_with_rust(calculation_obj) -> bool: + """Validate table-calculation payload shape via sidemantic-rs.""" + rust_module = get_rust_module() + calculation_yaml = yaml.safe_dump(calculation_obj.model_dump(exclude_none=True), sort_keys=False) + return bool(rust_module.validate_table_calculation_payload(calculation_yaml)) + + +def resolve_model_inheritance_with_rust(models: dict[str, object]) -> dict[str, object]: + """Resolve model inheritance via sidemantic-rs.""" + rust_module = get_rust_module() + from sidemantic.core.model import Model + + passthrough_fields = ("source_uri", "unique_keys", "meta", "extends") + passthrough_cache: dict[str, dict] = {} + + def inherited_passthrough(model_name: str) -> dict: + cached = passthrough_cache.get(model_name) + if cached is not None: + return cached + + model = models.get(model_name) + if model is None: + values = {field: None for field in passthrough_fields} + passthrough_cache[model_name] = values + return values + + parent_name = getattr(model, "extends", None) + if parent_name and parent_name in models: + values = inherited_passthrough(parent_name).copy() + else: + values = {field: getattr(model, field, None) for field in passthrough_fields} + passthrough_cache[model_name] = values + return values + + models_yaml = models_to_rust_yaml(list(models.values()), include_extends=True) + has_extends = any(getattr(model, "extends", None) for model in models.values()) + if not has_extends: + # Keep exact Python model payloads for non-inheritance cases. + # This avoids lossy schema conversion while still exercising the Rust path. + rust_module.resolve_model_inheritance(models_yaml) + return dict(models) + + resolved_yaml = rust_module.resolve_model_inheritance(models_yaml) + resolved_payload = yaml.safe_load(resolved_yaml) or [] + resolved_models = {} + for model_data in resolved_payload: + normalized_data = dict(model_data) + normalized_metrics = [] + for metric_data in normalized_data.get("metrics") or []: + metric_payload = dict(metric_data) + metric_type = metric_payload.get("type") + if metric_type == "simple": + metric_payload["type"] = None + elif metric_type == "timecomparison": + metric_payload["type"] = "time_comparison" + normalized_metrics.append(metric_payload) + normalized_data["metrics"] = normalized_metrics + + model_name = normalized_data.get("name") + if isinstance(model_name, str): + for field, value in inherited_passthrough(model_name).items(): + normalized_data.setdefault(field, value) + + model = Model(**normalized_data) + resolved_models[model.name] = model + return resolved_models + + +def resolve_metric_inheritance_with_rust(metrics: dict[str, object]) -> dict[str, object]: + """Resolve metric inheritance via sidemantic-rs.""" + rust_module = get_rust_module() + from sidemantic.core.metric import Metric + + has_extends = any(getattr(metric, "extends", None) for metric in metrics.values()) + serialized_metrics = [] + for metric in metrics.values(): + child_fields = metric.model_fields_set - {"extends"} + metric_data = metric.model_dump(include=child_fields) + metric_data["name"] = metric.name + if metric.extends is not None: + metric_data["extends"] = metric.extends + serialized_metrics.append(metric_data) + + metrics_yaml = yaml.safe_dump(serialized_metrics, sort_keys=False) + if not has_extends: + # Keep exact Python metric payloads for non-inheritance cases. + # This avoids lossy schema conversion while still exercising the Rust path. + rust_module.resolve_metric_inheritance(metrics_yaml) + return dict(metrics) + + resolved_yaml = rust_module.resolve_metric_inheritance(metrics_yaml) + resolved_payload = yaml.safe_load(resolved_yaml) or [] + resolved_metrics = {} + for metric_data in resolved_payload: + normalized_data = dict(metric_data) + metric_type = normalized_data.get("type") + if metric_type == "simple": + normalized_data["type"] = None + elif metric_type == "timecomparison": + normalized_data["type"] = "time_comparison" + + metric = Metric(**normalized_data) + resolved_metrics[metric.name] = metric + return resolved_metrics + + +def parse_reference_with_rust(graph: SemanticGraph, reference: str) -> tuple[str, str, str | None]: + """Parse a qualified semantic reference via sidemantic-rs.""" + rust_module = get_rust_module() + parsed = rust_module.parse_reference_with_yaml( + graph_to_rust_yaml(graph), + reference, + ) + return _normalize_parsed_reference(parsed) + + +def find_relationship_path_with_rust(graph: SemanticGraph, from_model: str, to_model: str) -> list: + """Find join path between models via sidemantic-rs.""" + rust_module = get_rust_module() + rust_steps = rust_module.find_relationship_path_with_yaml( + graph_to_rust_yaml(graph), + from_model, + to_model, + ) + rust_steps = _deserialize_json_payload(rust_steps) + + from sidemantic.core.semantic_graph import JoinPath + + path = [] + for step in rust_steps: + from_name, to_name, from_columns, to_columns, relationship = _normalize_relationship_path_step(step) + path.append( + JoinPath( + from_model=str(from_name), + to_model=str(to_name), + from_columns=[str(column) for column in from_columns], + to_columns=[str(column) for column in to_columns], + relationship=str(relationship), + ) + ) + return path + + +def _normalize_parsed_reference(parsed: object) -> tuple[str, str, str | None]: + parsed = _deserialize_json_payload(parsed) + if isinstance(parsed, dict): + model_name = parsed.get("model_name", parsed.get("model")) + field_name = parsed.get("field_name", parsed.get("field")) + granularity = parsed.get("granularity") + else: + try: + model_name, field_name, granularity = parsed + except (TypeError, ValueError) as exc: + raise TypeError("unexpected parse_reference_with_yaml result shape") from exc + + if model_name is None or field_name is None: + raise TypeError("unexpected parse_reference_with_yaml result payload") + + return str(model_name), str(field_name), (str(granularity) if granularity is not None else None) + + +def _normalize_relationship_path_step(step: object) -> tuple[str, str, list[str], list[str], str]: + if isinstance(step, dict): + from_name = step.get("from_model") + to_name = step.get("to_model") + relationship = step.get("relationship") + from_columns = step.get("from_columns") + to_columns = step.get("to_columns") + + # Older payloads may expose only single-column aliases. + if from_columns is None: + from_entity = step.get("from_entity") + from_columns = [from_entity] if from_entity is not None else None + if to_columns is None: + to_entity = step.get("to_entity") + to_columns = [to_entity] if to_entity is not None else None + else: + try: + from_name, to_name, from_columns, to_columns, relationship = step + except (TypeError, ValueError) as exc: + raise TypeError("unexpected find_relationship_path_with_yaml step shape") from exc + + if from_name is None or to_name is None or relationship is None: + raise TypeError("unexpected find_relationship_path_with_yaml step payload") + + normalized_from_columns = _normalize_relationship_columns(from_columns) + normalized_to_columns = _normalize_relationship_columns(to_columns) + return str(from_name), str(to_name), normalized_from_columns, normalized_to_columns, str(relationship) + + +def _normalize_relationship_columns(columns: object) -> list[str]: + columns = _deserialize_json_payload(columns) + if columns is None: + return [] + if isinstance(columns, str): + return [columns] + try: + return [str(column) for column in columns] + except TypeError as exc: + raise TypeError("unexpected relationship column payload") from exc + + +def _deserialize_json_payload(payload: object) -> object: + if not isinstance(payload, str): + return payload + + text = payload.strip() + if not text or text[0] not in "[{": + return payload + + try: + return json.loads(text) + except json.JSONDecodeError: + return payload + + +def _base_dimension_reference(reference: str) -> str: + if "__" not in reference: + return reference + base_ref, granularity = reference.rsplit("__", 1) + if granularity in {"hour", "day", "week", "month", "quarter", "year"}: + return base_ref + return reference + + +def _infer_models_for_unqualified_references( + graph: SemanticGraph, + dimensions: list[str], + measures: list[str], +) -> set[str]: + inferred_models: set[str] = set() + + for dimension_ref in dimensions: + base_ref = _base_dimension_reference(dimension_ref) + if "." in base_ref: + continue + for model in graph.models.values(): + if any(dimension.name == base_ref for dimension in model.dimensions): + inferred_models.add(model.name) + + for metric_ref in measures: + if "." in metric_ref: + continue + + metric = graph.metrics.get(metric_ref) + sql_ref = getattr(metric, "sql", None) if metric is not None else None + if isinstance(sql_ref, str) and "." in sql_ref: + model_name = sql_ref.split(".", 1)[0].strip() + if model_name and model_name in graph.models: + inferred_models.add(model_name) + + for model in graph.models.values(): + if any(model_metric.name == metric_ref for model_metric in model.metrics): + inferred_models.add(model.name) + + return inferred_models + + +def find_models_for_query_with_rust( + graph: SemanticGraph, + dimensions: list[str], + measures: list[str], +) -> set[str]: + """Discover model names referenced by dimensions/measures via sidemantic-rs.""" + rust_module = get_rust_module() + dimension_refs = list(dimensions or []) + measure_refs = list(measures or []) + + if hasattr(rust_module, "find_models_for_query_with_yaml"): + try: + models = rust_module.find_models_for_query_with_yaml( + graph_to_rust_yaml(graph), + dimension_refs, + measure_refs, + ) + except TypeError: + # Compatibility fallback for older sidemantic-rs extension builds. + models = rust_module.find_models_for_query(dimension_refs, measure_refs) + else: + # Compatibility fallback for older sidemantic-rs extension builds. + models = rust_module.find_models_for_query(dimension_refs, measure_refs) + + resolved_models = {str(model_name) for model_name in models} + resolved_models.update(_infer_models_for_unqualified_references(graph, dimension_refs, measure_refs)) + return resolved_models + + +def _as_list(value: str | list[str] | None) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def _serialize_metric(metric, *, primary_key_columns: list[str] | None) -> dict: + metric_sql = metric.sql + if metric.agg == "count_distinct" and not metric_sql and primary_key_columns: + if len(primary_key_columns) == 1: + metric_sql = primary_key_columns[0] + else: + casts = ", '|', ".join(f"CAST({col} AS VARCHAR)" for col in primary_key_columns) + metric_sql = f"CONCAT({casts})" + + return { + "name": metric.name, + "extends": metric.extends, + "type": metric.type, + "agg": metric.agg, + "sql": metric_sql, + "numerator": metric.numerator, + "denominator": metric.denominator, + "offset_window": metric.offset_window, + "window": metric.window, + "grain_to_date": metric.grain_to_date, + "window_expression": metric.window_expression, + "window_frame": metric.window_frame, + "window_order": metric.window_order, + "base_metric": metric.base_metric, + "comparison_type": metric.comparison_type, + "time_offset": metric.time_offset, + "calculation": metric.calculation, + "entity": metric.entity, + "base_event": metric.base_event, + "conversion_event": metric.conversion_event, + "conversion_window": metric.conversion_window, + "steps": metric.steps, + "cohort_event": metric.cohort_event, + "activity_event": metric.activity_event, + "periods": metric.periods, + "retention_granularity": metric.retention_granularity, + "inner_metrics": metric.inner_metrics, + "entity_dimensions": metric.entity_dimensions, + "having": metric.having, + "fill_nulls_with": metric.fill_nulls_with, + "format": metric.format, + "value_format_name": metric.value_format_name, + "drill_fields": metric.drill_fields, + "non_additive_dimension": metric.non_additive_dimension, + "filters": metric.filters or [], + "description": metric.description, + "label": metric.label, + } + + +def _serialize_parameter(parameter) -> dict: + return { + "name": parameter.name, + "type": parameter.type, + "description": parameter.description, + "label": parameter.label, + "default_value": parameter.default_value, + "allowed_values": parameter.allowed_values, + "default_to_today": parameter.default_to_today, + } + + +def is_sql_template_with_rust(sql: str) -> bool: + """Check template marker presence via sidemantic-rs.""" + rust_module = get_rust_module() + return bool(rust_module.is_sql_template(sql)) + + +def render_sql_template_with_rust(template_str: str, context: dict) -> str: + """Render a SQL template via sidemantic-rs.""" + rust_module = get_rust_module() + context_yaml = yaml.safe_dump(context or {}, sort_keys=False) + return rust_module.render_sql_template(template_str, context_yaml) + + +def format_parameter_value_with_rust(parameter, value) -> str: + """Format a parameter value via sidemantic-rs.""" + rust_module = get_rust_module() + parameter_yaml = yaml.safe_dump(_serialize_parameter(parameter), sort_keys=False) + value_yaml = yaml.safe_dump(value, sort_keys=False) + return rust_module.format_parameter_value_with_yaml(parameter_yaml, value_yaml) + + +def interpolate_sql_with_parameters_with_rust( + sql: str, + parameters: dict, + values: dict | None = None, +) -> str: + """Interpolate SQL placeholders/templates via sidemantic-rs.""" + rust_module = get_rust_module() + payload = [_serialize_parameter(parameter) for parameter in parameters.values()] + parameters_yaml = yaml.safe_dump(payload, sort_keys=False) + values_yaml = yaml.safe_dump(values or {}, sort_keys=False) + return rust_module.interpolate_sql_with_parameters(sql, parameters_yaml, values_yaml) + + +def evaluate_table_calculation_expression_with_rust(expr: str) -> float: + """Evaluate a table-calculation arithmetic expression via sidemantic-rs.""" + rust_module = get_rust_module() + return float(rust_module.evaluate_table_calculation_expression(expr)) + + +def chart_auto_detect_columns_with_rust(columns: list[str], numeric_flags: list[bool]) -> tuple[str, list[str]]: + """Auto-detect chart x/y columns via sidemantic-rs.""" + rust_module = get_rust_module() + x_col, y_cols = rust_module.chart_auto_detect_columns(list(columns), list(numeric_flags)) + return str(x_col), [str(col) for col in y_cols] + + +def chart_select_type_with_rust(x: str, x_value_kind: str, y_count: int) -> str: + """Select chart type via sidemantic-rs.""" + rust_module = get_rust_module() + return str(rust_module.chart_select_type(x, x_value_kind, int(y_count))) + + +def chart_format_label_with_rust(column: str) -> str: + """Format chart label via sidemantic-rs.""" + rust_module = get_rust_module() + return str(rust_module.chart_format_label(column)) + + +def chart_encoding_type_with_rust(column: str) -> str: + """Determine chart encoding type via sidemantic-rs.""" + rust_module = get_rust_module() + return str(rust_module.chart_encoding_type(column)) + + +def extract_column_references_with_rust(sql_expr: str) -> set[str]: + """Extract column references from a SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + refs = rust_module.extract_column_references(sql_expr) + return {str(ref) for ref in refs} + + +def analyze_migrator_query_with_rust(sql_query: str) -> dict: + """Analyze migrator query extraction payload via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.analyze_migrator_query(sql_query) + if isinstance(payload, str): + return json.loads(payload) + return dict(payload) + + +def extract_metric_dependencies_with_rust(metric_obj, graph=None, model_context: str | None = None) -> set[str]: + """Extract metric dependencies via sidemantic-rs.""" + rust_module = get_rust_module() + metric_yaml = yaml.safe_dump(metric_obj.model_dump(exclude_none=True), sort_keys=False) + models_yaml = graph_to_rust_yaml(graph) if graph is not None else None + refs = rust_module.extract_metric_dependencies(metric_yaml, models_yaml, model_context) + return {str(ref) for ref in refs} + + +def parse_simple_metric_aggregation_with_rust(sql_expr: str) -> tuple[str, str | None] | None: + """Parse a top-level simple metric aggregation via sidemantic-rs.""" + rust_module = get_rust_module() + parsed = rust_module.parse_simple_metric_aggregation(sql_expr) + if parsed is None: + return None + agg, inner = parsed + return str(agg), (str(inner) if inner is not None else None) + + +def metric_to_sql_with_rust(metric_obj) -> str: + """Render metric SQL aggregation via sidemantic-rs.""" + rust_module = get_rust_module() + metric_yaml = yaml.safe_dump(metric_obj.model_dump(exclude_none=True), sort_keys=False) + return rust_module.metric_to_sql(metric_yaml) + + +def metric_sql_expr_with_rust(metric_obj) -> str: + """Resolve metric SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + metric_yaml = yaml.safe_dump(metric_obj.model_dump(exclude_none=True), sort_keys=False) + return rust_module.metric_sql_expr(metric_yaml) + + +def metric_is_simple_aggregation_with_rust(metric_obj) -> bool: + """Check metric simple-aggregation status via sidemantic-rs.""" + rust_module = get_rust_module() + metric_yaml = yaml.safe_dump(metric_obj.model_dump(exclude_none=True), sort_keys=False) + return bool(rust_module.metric_is_simple_aggregation(metric_yaml)) + + +def dimension_sql_expr_with_rust(dimension_obj) -> str: + """Resolve dimension SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + dimension_yaml = yaml.safe_dump(dimension_obj.model_dump(exclude_none=True), sort_keys=False) + return rust_module.dimension_sql_expr(dimension_yaml) + + +def dimension_with_granularity_with_rust(dimension_obj, granularity: str) -> str: + """Apply time granularity to a dimension SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + dimension_yaml = yaml.safe_dump(dimension_obj.model_dump(exclude_none=True), sort_keys=False) + return rust_module.dimension_with_granularity(dimension_yaml, granularity) + + +def model_get_hierarchy_path_with_rust(model_obj, dimension_name: str) -> list[str]: + """Get model hierarchy path via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + return [str(item) for item in rust_module.model_get_hierarchy_path(model_yaml, dimension_name)] + + +def model_get_drill_down_with_rust(model_obj, dimension_name: str) -> str | None: + """Get model drill-down target via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_get_drill_down(model_yaml, dimension_name) + return str(result) if result is not None else None + + +def model_get_drill_up_with_rust(model_obj, dimension_name: str) -> str | None: + """Get model drill-up target via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_get_drill_up(model_yaml, dimension_name) + return str(result) if result is not None else None + + +def model_find_dimension_index_with_rust(model_obj, name: str) -> int | None: + """Find model dimension index by name via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_find_dimension_index(model_yaml, name) + return int(result) if result is not None else None + + +def model_find_metric_index_with_rust(model_obj, name: str) -> int | None: + """Find model metric index by name via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_find_metric_index(model_yaml, name) + return int(result) if result is not None else None + + +def model_find_segment_index_with_rust(model_obj, name: str) -> int | None: + """Find model segment index by name via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_find_segment_index(model_yaml, name) + return int(result) if result is not None else None + + +def model_find_pre_aggregation_index_with_rust(model_obj, name: str) -> int | None: + """Find model pre-aggregation index by name via sidemantic-rs.""" + rust_module = get_rust_module() + model_yaml = yaml.safe_dump(model_obj.model_dump(exclude_none=True), sort_keys=False) + result = rust_module.model_find_pre_aggregation_index(model_yaml, name) + return int(result) if result is not None else None + + +def relationship_sql_expr_with_rust(relationship_obj) -> str: + """Resolve relationship SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + relationship_yaml = yaml.safe_dump(relationship_obj.model_dump(), sort_keys=False) + return rust_module.relationship_sql_expr(relationship_yaml) + + +def relationship_related_key_with_rust(relationship_obj) -> str: + """Resolve relationship related key via sidemantic-rs.""" + rust_module = get_rust_module() + relationship_yaml = yaml.safe_dump(relationship_obj.model_dump(), sort_keys=False) + return rust_module.relationship_related_key(relationship_yaml) + + +def relationship_foreign_key_columns_with_rust(relationship_obj) -> list[str]: + """Resolve relationship foreign-key columns via sidemantic-rs.""" + rust_module = get_rust_module() + relationship_yaml = yaml.safe_dump(relationship_obj.model_dump(), sort_keys=False) + return [str(column) for column in rust_module.relationship_foreign_key_columns(relationship_yaml)] + + +def relationship_primary_key_columns_with_rust(relationship_obj) -> list[str]: + """Resolve relationship primary-key columns via sidemantic-rs.""" + rust_module = get_rust_module() + relationship_yaml = yaml.safe_dump(relationship_obj.model_dump(), sort_keys=False) + return [str(column) for column in rust_module.relationship_primary_key_columns(relationship_yaml)] + + +def segment_get_sql_with_rust(segment_obj, model_alias: str = "model") -> str: + """Resolve segment SQL placeholder interpolation via sidemantic-rs.""" + rust_module = get_rust_module() + segment_yaml = yaml.safe_dump(segment_obj.model_dump(), sort_keys=False) + return rust_module.segment_get_sql(segment_yaml, model_alias) + + +def validate_table_formula_expression_with_rust(expression: str) -> bool: + """Validate table-calculation formula syntax via sidemantic-rs.""" + rust_module = get_rust_module() + return bool(rust_module.validate_table_formula_expression(expression)) + + +def build_symmetric_aggregate_sql_with_rust( + measure_expr: str, + primary_key: str, + agg_type: str, + model_alias: str | None = None, + dialect: str = "duckdb", +) -> str: + """Build symmetric aggregate SQL via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.build_symmetric_aggregate_sql(measure_expr, primary_key, agg_type, model_alias, dialect) + + +def needs_symmetric_aggregate_with_rust(relationship: str, is_base_model: bool) -> bool: + """Evaluate symmetric aggregate need via sidemantic-rs.""" + rust_module = get_rust_module() + return bool(rust_module.needs_symmetric_aggregate(relationship, is_base_model)) + + +def parse_relative_date_with_rust(expr: str, dialect: str = "duckdb") -> str | None: + """Parse a relative date expression into SQL via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.parse_relative_date(expr, dialect) + + +def relative_date_to_range_with_rust( + expr: str, + column: str = "date_col", + dialect: str = "duckdb", +) -> str | None: + """Convert a relative date expression to a SQL range filter via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.relative_date_to_range(expr, column, dialect) + + +def is_relative_date_with_rust(expr: str) -> bool: + """Check relative-date expression recognition via sidemantic-rs.""" + rust_module = get_rust_module() + return bool(rust_module.is_relative_date(expr)) + + +def time_comparison_offset_interval_with_rust( + comparison_type: str, + offset: int | None = None, + offset_unit: str | None = None, +) -> tuple[int, str]: + """Resolve time comparison offset interval via sidemantic-rs.""" + rust_module = get_rust_module() + amount, unit = rust_module.time_comparison_offset_interval(comparison_type, offset, offset_unit) + return int(amount), str(unit) + + +def time_comparison_sql_offset_with_rust( + comparison_type: str, + offset: int | None = None, + offset_unit: str | None = None, +) -> str: + """Render SQL INTERVAL text for a time comparison via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.time_comparison_sql_offset(comparison_type, offset, offset_unit) + + +def trailing_period_sql_interval_with_rust(amount: int, unit: str) -> str: + """Render SQL INTERVAL text for a trailing period via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.trailing_period_sql_interval(amount, unit) + + +def generate_time_comparison_sql_with_rust( + *, + comparison_type: str, + calculation: str, + current_metric_sql: str, + time_dimension: str, + offset: int | None = None, + offset_unit: str | None = None, +) -> str: + """Generate time-comparison SQL expression via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.generate_time_comparison_sql( + comparison_type, + calculation, + current_metric_sql, + time_dimension, + offset, + offset_unit, + ) + + +def generate_catalog_metadata_with_rust(graph: SemanticGraph, schema: str = "public") -> dict: + """Generate Postgres-compatible catalog metadata via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.generate_catalog_metadata(graph_to_rust_yaml(graph), schema) + return json.loads(payload) + + +def _serialize_pre_aggregation(pre_aggregation) -> dict: + refresh_key = getattr(pre_aggregation, "refresh_key", None) + refresh_key_payload = None + if refresh_key is not None: + refresh_key_payload = { + "every": refresh_key.every, + "sql": refresh_key.sql, + "incremental": refresh_key.incremental, + "update_window": refresh_key.update_window, + } + + indexes = getattr(pre_aggregation, "indexes", None) + index_payload = None + if indexes is not None: + index_payload = [ + { + "name": idx.name, + "columns": list(idx.columns), + "type": idx.type, + } + for idx in indexes + ] + + return { + "name": pre_aggregation.name, + "type": pre_aggregation.type, + "measures": pre_aggregation.measures, + "dimensions": pre_aggregation.dimensions, + "time_dimension": pre_aggregation.time_dimension, + "granularity": pre_aggregation.granularity, + "partition_granularity": pre_aggregation.partition_granularity, + "build_range_start": pre_aggregation.build_range_start, + "build_range_end": pre_aggregation.build_range_end, + "scheduled_refresh": pre_aggregation.scheduled_refresh, + "refresh_key": refresh_key_payload, + "indexes": index_payload, + } + + +def match_preaggregation_with_rust( + model, + *, + metrics: list[str] | None = None, + dimensions: list[str] | None = None, + time_granularity: str | None = None, + filters: list[str] | None = None, +) -> str | None: + """Return matching pre-aggregation name using sidemantic-rs compile routing. + + This is a thin compatibility bridge for Python PreAggregationMatcher: + - compile a single-model query with `use_preaggregations=true` + - detect routed pre-aggregation from generated SQL table reference + """ + rust_module = get_rust_module() + metrics = metrics or [] + dimensions = list(dimensions or []) + filters = filters or [] + + model_for_match = model.model_copy(deep=True) if hasattr(model, "model_copy") else model.copy(deep=True) + existing_dim_names = {dimension.name for dimension in model_for_match.dimensions} + requested_dimension_names = [] + for dimension in dimensions: + bare = dimension.split(".", 1)[1] if "." in dimension else dimension + bare = bare.split("__", 1)[0] + requested_dimension_names.append(bare) + + for dim_name in requested_dimension_names: + if dim_name not in existing_dim_names: + from sidemantic.core.dimension import Dimension + + model_for_match.dimensions.append( + Dimension(name=dim_name, type="categorical", sql=dim_name), + ) + existing_dim_names.add(dim_name) + + qualified_metrics = [m if "." in m else f"{model_for_match.name}.{m}" for m in metrics] + qualified_dimensions = [d if "." in d else f"{model_for_match.name}.{d}" for d in dimensions] + + if time_granularity and not any("__" in d for d in qualified_dimensions): + time_dim = None + preagg_time_dims = { + p.time_dimension for p in model_for_match.pre_aggregations if getattr(p, "time_dimension", None) + } + if len(preagg_time_dims) == 1: + time_dim = next(iter(preagg_time_dims)) + if not time_dim: + time_dims = [d.name for d in model_for_match.dimensions if d.type == "time"] + if len(time_dims) == 1: + time_dim = time_dims[0] + if time_dim: + if time_dim not in existing_dim_names: + from sidemantic.core.dimension import Dimension + + model_for_match.dimensions.append( + Dimension(name=time_dim, type="time", sql=time_dim), + ) + qualified_dimensions.append(f"{model_for_match.name}.{time_dim}__{time_granularity}") + + models_yaml = models_to_rust_yaml([model_for_match]) + payload = { + "metrics": qualified_metrics, + "dimensions": qualified_dimensions, + "filters": filters, + "segments": [], + "order_by": [], + "limit": None, + "ungrouped": False, + "use_preaggregations": True, + "preagg_database": None, + "preagg_schema": None, + } + query_yaml = yaml.safe_dump(payload, sort_keys=False) + sql = rust_module.compile_with_yaml(models_yaml, query_yaml) + + for pre_aggregation in sorted(model.pre_aggregations, key=lambda p: len(p.name), reverse=True): + table_name = pre_aggregation.get_table_name(model_for_match.name) + if table_name in sql: + return pre_aggregation.name + return None + + +def generate_preaggregation_materialization_sql_with_rust(model, pre_aggregation) -> str: + """Generate pre-aggregation materialization SQL using sidemantic-rs.""" + rust_module = get_rust_module() + model_for_materialization = model.model_copy(deep=True) if hasattr(model, "model_copy") else model.copy(deep=True) + preagg_name = pre_aggregation.name + retained = [p for p in model_for_materialization.pre_aggregations if p.name != preagg_name] + retained.append(pre_aggregation) + model_for_materialization.pre_aggregations = retained + + models_yaml = models_to_rust_yaml([model_for_materialization]) + return rust_module.generate_preaggregation_materialization_sql( + models_yaml, + model_for_materialization.name, + preagg_name, + ) + + +def validate_engine_refresh_sql_compatibility_with_rust(source_sql: str, dialect: str) -> tuple[bool, str | None]: + """Validate engine refresh SQL compatibility via sidemantic-rs.""" + rust_module = get_rust_module() + is_valid, error_msg = rust_module.validate_engine_refresh_sql_compatibility(source_sql, dialect) + return bool(is_valid), (str(error_msg) if error_msg is not None else None) + + +def build_preaggregation_refresh_statements_with_rust( + *, + mode: str, + table_name: str, + source_sql: str, + watermark_column: str | None = None, + from_watermark: str | None = None, + lookback: str | None = None, + dialect: str | None = None, + refresh_every: str | None = None, +) -> list[str]: + """Build pre-aggregation refresh SQL statements via sidemantic-rs planner.""" + rust_module = get_rust_module() + statements = rust_module.build_preaggregation_refresh_statements( + mode=mode, + table_name=table_name, + source_sql=source_sql, + watermark_column=watermark_column, + from_watermark=from_watermark, + lookback=lookback, + dialect=dialect, + refresh_every=refresh_every, + ) + return [str(statement) for statement in statements] + + +def refresh_preaggregation_with_rust( + *, + pre_aggregation, + connection, + source_sql: str, + table_name: str, + mode: str | None, + watermark_column: str | None, + lookback: str | None, + from_watermark, + to_watermark, + dialect: str | None, +) -> dict: + """Execute pre-aggregation refresh via sidemantic-rs.""" + rust_module = get_rust_module() + + refresh_key = getattr(pre_aggregation, "refresh_key", None) + refresh_incremental = bool(refresh_key and refresh_key.incremental) + refresh_every = refresh_key.every if refresh_key else None + + def _refresh_with_mode(resolved_mode: str | None) -> dict: + return rust_module.refresh_preaggregation( + connection=connection, + source_sql=source_sql, + table_name=table_name, + mode=resolved_mode, + watermark_column=watermark_column, + lookback=lookback, + from_watermark=from_watermark, + to_watermark=to_watermark, + dialect=dialect, + refresh_incremental=refresh_incremental, + refresh_every=refresh_every, + ) + + try: + return _refresh_with_mode(mode) + except TypeError: + if mode is not None: + raise + + resolved_mode: str | None = None + if hasattr(rust_module, "plan_preaggregation_refresh_execution"): + try: + refresh_plan = rust_module.plan_preaggregation_refresh_execution( + mode, + refresh_incremental, + watermark_column, + dialect, + ) + planner_mode = refresh_plan.get("mode") if hasattr(refresh_plan, "get") else None + if planner_mode is not None: + resolved_mode = str(planner_mode) + except (TypeError, KeyError, AttributeError): + resolved_mode = None + + if resolved_mode is None: + if hasattr(rust_module, "resolve_preaggregation_refresh_mode"): + resolved_mode = rust_module.resolve_preaggregation_refresh_mode(mode, refresh_incremental) + else: + resolved_mode = "incremental" if refresh_incremental else "full" + + return _refresh_with_mode(resolved_mode) + + +def extract_preaggregation_patterns_with_rust(queries: list[str]) -> list[dict]: + """Extract grouped query patterns from instrumented SQL via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.extract_preaggregation_patterns(queries) + return json.loads(payload) + + +def recommend_preaggregation_patterns_with_rust( + patterns: list[dict], + *, + min_query_count: int, + min_benefit_score: float, + top_n: int | None = None, +) -> list[dict]: + """Build pre-aggregation recommendations from pattern counts via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.recommend_preaggregation_patterns( + json.dumps(patterns), + min_query_count, + min_benefit_score, + top_n, + ) + return json.loads(payload) + + +def summarize_preaggregation_patterns_with_rust( + patterns: list[dict], + *, + min_query_count: int, +) -> dict: + """Summarize pattern counts via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.summarize_preaggregation_patterns( + json.dumps(patterns), + min_query_count, + ) + return json.loads(payload) + + +def calculate_preaggregation_benefit_score_with_rust( + pattern: dict, + *, + count: int, +) -> float: + """Calculate recommender benefit score via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.calculate_preaggregation_benefit_score( + json.dumps(pattern), + count, + ) + + +def generate_preaggregation_name_with_rust(pattern: dict) -> str: + """Generate recommender name via sidemantic-rs.""" + rust_module = get_rust_module() + return rust_module.generate_preaggregation_name( + json.dumps(pattern), + ) + + +def generate_preaggregation_definition_with_rust(recommendation: dict) -> dict: + """Generate a pre-aggregation definition payload via sidemantic-rs.""" + rust_module = get_rust_module() + payload = rust_module.generate_preaggregation_definition( + json.dumps(recommendation), + ) + return json.loads(payload) + + +def _normalize_filter_sql(filter_sql: str) -> str: + sql = filter_sql.replace("{model}.", "") + sql = sql.replace("{model}", "") + sql = re.sub(r"\b\w+_cte\.", "", sql) + return sql + + +def _serialize_relationship(relationship, source_model, target_model) -> dict | None: + foreign_keys = _as_list(relationship.foreign_key) + if not foreign_keys: + foreign_keys = [f"{relationship.name}_id"] if relationship.type == "many_to_one" else ["id"] + + if relationship.primary_key is not None: + primary_keys = _as_list(relationship.primary_key) + elif target_model: + primary_keys = target_model.primary_key_columns + else: + primary_keys = ["id"] + + sql = None + if len(foreign_keys) > 1 and len(foreign_keys) == len(primary_keys): + sql = " AND ".join(f"{{from}}.{fk} = {{to}}.{pk}" for fk, pk in zip(foreign_keys, primary_keys, strict=False)) + if getattr(relationship, "sql", None): + sql = relationship.sql + + through_foreign_key = getattr(relationship, "through_foreign_key", None) + related_foreign_key = getattr(relationship, "related_foreign_key", None) + if relationship.type == "many_to_many": + junction_keys_fn = getattr(relationship, "junction_keys", None) + if callable(junction_keys_fn): + junction_self_fk, junction_related_fk = junction_keys_fn() + through_foreign_key = through_foreign_key or junction_self_fk + related_foreign_key = related_foreign_key or junction_related_fk + + return { + "name": relationship.name, + "type": relationship.type, + "foreign_key": foreign_keys[0] if foreign_keys else None, + "primary_key": primary_keys[0] if primary_keys else None, + "foreign_key_columns": foreign_keys, + "primary_key_columns": primary_keys, + "through": getattr(relationship, "through", None), + "through_foreign_key": through_foreign_key, + "related_foreign_key": related_foreign_key, + "sql": sql, + } diff --git a/sidemantic/rust_parity.py b/sidemantic/rust_parity.py new file mode 100644 index 00000000..e878ef83 --- /dev/null +++ b/sidemantic/rust_parity.py @@ -0,0 +1,69 @@ +"""Rust parity gating utilities. + +Used to enforce strict no-Python-fallback behavior during migration. +""" + +from __future__ import annotations + +import json +import os +from functools import lru_cache +from pathlib import Path + + +def _repo_root() -> Path: + return Path(__file__).resolve().parent.parent + + +@lru_cache(maxsize=1) +def _load_parity_matrix() -> dict: + matrix_path = _repo_root() / "docs" / "rust-parity-matrix.json" + if not matrix_path.exists(): + return {"subsystems": {}} + try: + return json.loads(matrix_path.read_text()) + except Exception: + return {"subsystems": {}} + + +@lru_cache(maxsize=1) +def strict_targets() -> set[str]: + raw = os.getenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "").strip() + if not raw: + return set() + return {part.strip() for part in raw.split(",") if part.strip()} + + +def is_strict_mode() -> bool: + return bool(strict_targets()) + + +def is_strict_for(subsystem: str) -> bool: + targets = strict_targets() + if not targets: + return False + if "all" in targets: + return True + return subsystem in targets + + +def subsystem_status(subsystem: str) -> str: + matrix = _load_parity_matrix() + subsystems = matrix.get("subsystems", {}) + record = subsystems.get(subsystem, {}) + return record.get("status", "python_only") + + +def require_rust_subsystem(subsystem: str, feature: str) -> None: + """Raise when strict mode requires a subsystem that is not rust-backed.""" + if not is_strict_for(subsystem): + return + + status = subsystem_status(subsystem) + if status == "rust_backed": + return + + raise RuntimeError( + f"[rust-strict:{subsystem}] Feature '{feature}' is not rust-backed (status={status}). " + "Set SIDEMANTIC_RS_STRICT_SUBSYSTEMS to a narrower scope, or implement this subsystem in sidemantic-rs." + ) diff --git a/sidemantic/sql/generator.py b/sidemantic/sql/generator.py index 73845412..049a51cd 100644 --- a/sidemantic/sql/generator.py +++ b/sidemantic/sql/generator.py @@ -33,6 +33,65 @@ def __init__( self.dialect = dialect self.preagg_database = preagg_database self.preagg_schema = preagg_schema + self._generate_cache: dict[tuple[object, ...], str] = {} + self._generate_cache_limit = 256 + + @staticmethod + def _freeze_cache_value(value): + if isinstance(value, dict): + items = ( + (SQLGenerator._freeze_cache_value(k), SQLGenerator._freeze_cache_value(v)) for k, v in value.items() + ) + return tuple(sorted(items, key=repr)) + if isinstance(value, (list, tuple)): + return tuple(SQLGenerator._freeze_cache_value(item) for item in value) + if isinstance(value, set): + return tuple(sorted((SQLGenerator._freeze_cache_value(item) for item in value), key=repr)) + try: + hash(value) + return value + except TypeError: + return repr(value) + + def _generate_cache_key( + self, + metrics, + dimensions, + filters, + segments, + order_by, + limit, + offset, + parameters, + ungrouped, + use_preaggregations, + aliases, + skip_default_time_dimensions, + ) -> tuple[object, ...]: + return ( + getattr(self.graph, "_version", 0), + self.dialect, + self.preagg_database, + self.preagg_schema, + self._freeze_cache_value(metrics), + self._freeze_cache_value(dimensions), + self._freeze_cache_value(filters), + self._freeze_cache_value(segments), + self._freeze_cache_value(order_by), + limit, + offset, + self._freeze_cache_value(parameters), + ungrouped, + use_preaggregations, + self._freeze_cache_value(aliases), + skip_default_time_dimensions, + ) + + def _cache_generate_result(self, cache_key: tuple[object, ...], sql: str) -> str: + if len(self._generate_cache) >= self._generate_cache_limit: + self._generate_cache.pop(next(iter(self._generate_cache))) + self._generate_cache[cache_key] = sql + return sql def _date_trunc(self, granularity: str, column_expr: str) -> str: """Generate dialect-specific DATE_TRUNC expression. @@ -332,6 +391,23 @@ def generate( segments = segments or [] parameters = parameters or {} aliases = aliases or {} + cache_key = self._generate_cache_key( + metrics, + dimensions, + filters, + segments, + order_by, + limit, + offset, + parameters, + ungrouped, + use_preaggregations, + aliases, + skip_default_time_dimensions, + ) + cached = self._generate_cache.get(cache_key) + if cached is not None: + return cached # Auto-include default_time_dimension from metrics if not already present if not skip_default_time_dimensions: @@ -438,7 +514,10 @@ def metric_needs_window(m): needs_window_functions = any(metric_needs_window(m) for m in metrics) if needs_window_functions: - return self._generate_with_window_functions(metrics, dimensions, filters, order_by, limit, offset, aliases) + return self._cache_generate_result( + cache_key, + self._generate_with_window_functions(metrics, dimensions, filters, order_by, limit, offset, aliases), + ) # Parse dimension references and extract granularities parsed_dims = self._parse_dimension_refs(dimensions) @@ -449,15 +528,18 @@ def metric_needs_window(m): # Check if we need symmetric aggregation (pre-aggregation approach) # This is needed when metrics come from different models at different join levels if self._needs_preaggregation_for_fanout(metrics, dimensions): - return self._generate_with_preaggregation( - metrics=metrics, - dimensions=dimensions, - filters=filters, - segments=None, # Already resolved into filters above - order_by=order_by, - limit=limit, - offset=offset, - aliases=aliases, + return self._cache_generate_result( + cache_key, + self._generate_with_preaggregation( + metrics=metrics, + dimensions=dimensions, + filters=filters, + segments=None, # Already resolved into filters above + order_by=order_by, + limit=limit, + offset=offset, + aliases=aliases, + ), ) # Try to use pre-aggregation if enabled (single model queries only) @@ -476,7 +558,7 @@ def metric_needs_window(m): instrumentation = self._generate_instrumentation_comment( models=[model_names[0]], metrics=metrics, dimensions=dimensions, used_preagg=True ) - return preagg_sql + "\n" + instrumentation + return self._cache_generate_result(cache_key, preagg_sql + "\n" + instrumentation) if not model_names: raise ValueError("No models found for query") @@ -591,7 +673,7 @@ def metric_needs_window(m): ) full_sql = full_sql + "\n" + instrumentation - return full_sql + return self._cache_generate_result(cache_key, full_sql) def _parse_dimension_refs(self, dimensions: list[str]) -> list[tuple[str, str | None]]: """Parse dimension references to extract granularities. diff --git a/sidemantic/sql/query_rewriter.py b/sidemantic/sql/query_rewriter.py index f791e522..7f78c458 100644 --- a/sidemantic/sql/query_rewriter.py +++ b/sidemantic/sql/query_rewriter.py @@ -3,6 +3,7 @@ Parses user SQL and rewrites it to use the semantic layer. """ +import os from dataclasses import dataclass import sqlglot @@ -10,6 +11,7 @@ from sqlglot.tokens import TokenType from sidemantic.core.semantic_graph import SemanticGraph +from sidemantic.rust_bridge import get_rust_module, graph_to_rust_yaml from sidemantic.sql.aggregation_detection import sql_has_aggregate from sidemantic.sql.generator import SQLGenerator @@ -35,6 +37,18 @@ def __init__(self, graph: SemanticGraph, dialect: str = "duckdb"): self.graph = graph self.dialect = dialect self.generator = SQLGenerator(graph, dialect=dialect) + self._use_rust_rewriter = os.getenv("SIDEMANTIC_RS_REWRITER", "0") == "1" + self._rust_no_fallback = os.getenv("SIDEMANTIC_RS_NO_FALLBACK", "0") == "1" + self._rust_module = None + self._rust_models_yaml: str | None = None + + if self._use_rust_rewriter: + try: + self._rust_module = get_rust_module() + self._rust_models_yaml = graph_to_rust_yaml(self.graph) + except Exception: + if self._rust_no_fallback: + raise def rewrite(self, sql: str, strict: bool = True) -> str: """Rewrite user SQL to use semantic layer. @@ -115,6 +129,18 @@ def rewrite(self, sql: str, strict: bool = True) -> str: return sql return sql + references_semantic_model = self._select_tree_references_semantic_model(parsed) + if not references_semantic_model: + return sql + + self._raise_on_user_cte_name_collision(parsed) + + if self._use_rust_rewriter: + rust_sql = self._prepare_sql_for_rust(parsed, sql) + rust_rewritten = self._rewrite_with_rust(rust_sql, strict=strict) + if rust_rewritten is not None: + return rust_rewritten + # Check if this is a CTE-based query or has subqueries has_ctes = parsed.args.get("with") is not None has_subquery_in_from = self._has_subquery_in_from(parsed) @@ -124,9 +150,6 @@ def rewrite(self, sql: str, strict: bool = True) -> str: # Handle CTEs and subqueries return self._rewrite_with_ctes_or_subqueries(parsed) - if not self._references_semantic_model(parsed): - return sql - # Otherwise, treat as simple semantic layer query return self._rewrite_simple_query(parsed) @@ -1898,6 +1921,108 @@ def _rewrite_with_ctes_or_subqueries(self, parsed: exp.Select) -> str: return parsed.sql(dialect=self.dialect) + def _select_tree_references_semantic_model(self, select: exp.Select) -> bool: + if self._references_semantic_model(select): + return True + + for nested_select in select.find_all(exp.Select): + if nested_select is select: + continue + if self._references_semantic_model(nested_select): + return True + + return False + + def _raise_on_user_cte_name_collision(self, select: exp.Select) -> None: + with_clause = select.args.get("with") + if not with_clause: + return + + reserved_names = self._generated_cte_names_for_select_tree(select) + for cte in with_clause.expressions: + if cte.alias in reserved_names: + raise ValueError( + f"CTE name '{cte.alias}' conflicts with an internally " + f"generated name. Please choose a different CTE name." + ) + + def _generated_cte_names_for_select_tree(self, select: exp.Select) -> set[str]: + reserved_names: set[str] = set() + for nested_select in select.find_all(exp.Select): + if not self._references_semantic_model(nested_select): + continue + reserved_names.update(self._generated_cte_names_for_semantic_select(nested_select)) + return reserved_names + + def _generated_cte_names_for_semantic_select(self, select: exp.Select) -> set[str]: + had_inferred_table = hasattr(self, "inferred_table") + previous_inferred_table = getattr(self, "inferred_table", None) + try: + self.inferred_table = self._extract_from_table(select) + metrics, dimensions, _aliases = self._extract_metrics_and_dimensions(select) + filters = self._extract_filters(select) + model_names = self.generator._find_required_models(metrics, dimensions, filters) + return { + self.generator._cte_name(model_name) for model_name in model_names if model_name in self.graph.models + } + except Exception: + return set() + finally: + if had_inferred_table: + self.inferred_table = previous_inferred_table + elif hasattr(self, "inferred_table"): + del self.inferred_table + + def _rewrite_with_rust(self, sql: str, strict: bool = True) -> str | None: + """Rewrite using sidemantic-rs bindings, returning None to allow Python fallback.""" + if not self._rust_module: + if strict and self._rust_no_fallback: + raise ValueError("Rust rewriter backend is not initialized") + return None + + try: + models_yaml = self._rust_models_yaml + if models_yaml is None: + models_yaml = graph_to_rust_yaml(self.graph) + self._rust_models_yaml = models_yaml + return self._rust_module.rewrite_with_yaml(models_yaml, sql) + except Exception as e: + if self._rust_no_fallback: + raise ValueError(f"Rust rewriter failed: {e}") from e + return None + + def _prepare_sql_for_rust(self, parsed: exp.Select, original_sql: str) -> str: + """Normalize Python-only graph metric shorthand to SQL sidemantic-rs can rewrite.""" + if self._extract_from_table(parsed) != "metrics": + return original_sql + + changed = False + rewritten_projections: list[exp.Expression] = [] + + for projection in parsed.expressions: + alias_name = projection.alias_or_name if isinstance(projection, exp.Alias) else None + node = projection.this if isinstance(projection, exp.Alias) else projection + + if isinstance(node, exp.Column) and not node.table and node.name in self.graph.metrics: + graph_metric = self.graph.metrics[node.name] + if graph_metric.sql: + try: + metric_expr = sqlglot.parse_one(graph_metric.sql, dialect=self.dialect) + except Exception: + return original_sql + rewritten_projections.append(exp.alias_(metric_expr, alias_name or node.name, copy=False)) + changed = True + continue + + rewritten_projections.append(projection) + + if not changed: + return original_sql + + rewritten = parsed.copy() + rewritten.set("expressions", rewritten_projections) + return rewritten.sql(dialect=self.dialect) + def _rewrite_select_tree(self, select: exp.Select): """Recursively rewrite semantic subqueries and CTEs (bottom-up). diff --git a/tests/core/test_rust_bridge_yaml_serialization.py b/tests/core/test_rust_bridge_yaml_serialization.py new file mode 100644 index 00000000..0039c293 --- /dev/null +++ b/tests/core/test_rust_bridge_yaml_serialization.py @@ -0,0 +1,140 @@ +"""Regression coverage for Python->Rust YAML bridge serialization fidelity.""" + +import yaml + +from sidemantic.core.dimension import Dimension +from sidemantic.core.metric import Metric +from sidemantic.core.model import Model +from sidemantic.core.pre_aggregation import Index, PreAggregation, RefreshKey +from sidemantic.core.semantic_graph import SemanticGraph +from sidemantic.rust_bridge import graph_to_rust_yaml, models_to_rust_yaml + + +def test_models_to_rust_yaml_preserves_extended_core_metadata(): + model = Model( + name="orders", + table="orders", + primary_key=["order_id", "tenant_id"], + source_uri="s3://warehouse/orders", + extends="base_orders", + unique_keys=[["order_id", "tenant_id"]], + default_time_dimension="order_date", + default_grain="day", + dimensions=[ + Dimension( + name="order_date", + type="time", + sql="order_date", + granularity="day", + supported_granularities=["day", "week", "month"], + format="yyyy-mm-dd", + value_format_name="iso_date", + parent="order_month", + ) + ], + metrics=[ + Metric( + name="revenue", + agg="sum", + sql="amount", + value_format_name="usd", + drill_fields=["order_id"], + non_additive_dimension="order_date", + ) + ], + pre_aggregations=[ + PreAggregation( + name="daily_rollup", + measures=["revenue"], + dimensions=["status"], + time_dimension="order_date", + granularity="day", + refresh_key=RefreshKey(every="1 hour", incremental=True, update_window="7 day"), + indexes=[Index(name="idx_status", columns=["status"], type="regular")], + ) + ], + ) + + payload = yaml.safe_load(models_to_rust_yaml([model], include_extends=True)) + model_payload = payload["models"][0] + dimension_payload = model_payload["dimensions"][0] + metric_payload = model_payload["metrics"][0] + preagg_payload = model_payload["pre_aggregations"][0] + + assert model_payload["source_uri"] == "s3://warehouse/orders" + assert model_payload["extends"] == "base_orders" + assert model_payload["unique_keys"] == [["order_id", "tenant_id"]] + + assert dimension_payload["supported_granularities"] == ["day", "week", "month"] + assert dimension_payload["format"] == "yyyy-mm-dd" + assert dimension_payload["value_format_name"] == "iso_date" + assert dimension_payload["parent"] == "order_month" + + assert metric_payload["value_format_name"] == "usd" + assert metric_payload["drill_fields"] == ["order_id"] + assert metric_payload["non_additive_dimension"] == "order_date" + + assert preagg_payload["refresh_key"]["every"] == "1 hour" + assert preagg_payload["refresh_key"]["incremental"] is True + assert preagg_payload["indexes"] == [{"name": "idx_status", "columns": ["status"], "type": "regular"}] + + +def test_graph_to_rust_yaml_assigns_complex_metrics_by_entity_dimension(): + graph = SemanticGraph() + graph.add_model( + Model( + name="events", + table="events", + primary_key="event_id", + dimensions=[ + Dimension(name="user_id", type="categorical"), + Dimension(name="event_type", type="categorical"), + Dimension(name="platform", type="categorical"), + ], + ) + ) + graph.add_model( + Model( + name="orders", + table="orders", + primary_key="order_id", + dimensions=[Dimension(name="order_id", type="categorical")], + ) + ) + graph.add_metric( + Metric( + name="signup_conversion", + type="conversion", + entity="user_id", + base_event="event_type = 'signup'", + conversion_event="event_type = 'purchase'", + conversion_window="7 days", + ) + ) + graph.add_metric( + Metric( + name="signup_retention", + type="retention", + entity="user_id", + cohort_event="event_type = 'signup'", + ) + ) + graph.add_metric( + Metric( + name="multi_platform_users", + type="cohort", + entity="user_id", + inner_metrics=[{"name": "platform_count", "agg": "count_distinct", "sql": "platform"}], + having="platform_count >= 2", + agg="count", + ) + ) + + payload = yaml.safe_load(graph_to_rust_yaml(graph)) + models = {model["name"]: model for model in payload["models"]} + event_metric_names = {metric["name"] for metric in models["events"]["metrics"]} + order_metric_names = {metric["name"] for metric in models["orders"].get("metrics", [])} + + assert {"signup_conversion", "signup_retention", "multi_platform_users"} <= event_metric_names + assert not {"signup_conversion", "signup_retention", "multi_platform_users"} & order_metric_names + assert payload.get("metrics") in (None, []) diff --git a/tests/core/test_rust_parity.py b/tests/core/test_rust_parity.py new file mode 100644 index 00000000..c561fdde --- /dev/null +++ b/tests/core/test_rust_parity.py @@ -0,0 +1,113 @@ +"""Tests for Rust strict-mode parity matrix gating.""" + +import json + +import pytest + +import sidemantic.rust_parity as rust_parity + + +@pytest.fixture(autouse=True) +def _reset_rust_parity_caches(): + rust_parity._load_parity_matrix.cache_clear() + rust_parity.strict_targets.cache_clear() + yield + rust_parity._load_parity_matrix.cache_clear() + rust_parity.strict_targets.cache_clear() + + +def _write_matrix(tmp_path, payload: dict) -> None: + docs = tmp_path / "docs" + docs.mkdir() + (docs / "rust-parity-matrix.json").write_text(json.dumps(payload)) + + +def _use_tmp_matrix(monkeypatch, tmp_path) -> None: + monkeypatch.setattr(rust_parity, "_repo_root", lambda: tmp_path) + rust_parity._load_parity_matrix.cache_clear() + + +def test_require_rust_subsystem_passes_for_rust_backed_target(monkeypatch, tmp_path): + _write_matrix( + tmp_path, + {"subsystems": {"sql_generator_entrypoint": {"status": "rust_backed"}}}, + ) + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "sql_generator_entrypoint") + rust_parity.strict_targets.cache_clear() + + rust_parity.require_rust_subsystem("sql_generator_entrypoint", "compile") + + +def test_require_rust_subsystem_ignores_non_strict_targets(monkeypatch, tmp_path): + _write_matrix( + tmp_path, + {"subsystems": {"semantic_sql_rewriter": {"status": "rust_backed_opt_in"}}}, + ) + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "sql_generator_entrypoint") + rust_parity.strict_targets.cache_clear() + + rust_parity.require_rust_subsystem("semantic_sql_rewriter", "rewrite") + + +def test_require_rust_subsystem_fails_for_non_rust_backed_strict_target(monkeypatch, tmp_path): + _write_matrix( + tmp_path, + {"subsystems": {"semantic_sql_rewriter": {"status": "rust_backed_opt_in"}}}, + ) + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "semantic_sql_rewriter") + rust_parity.strict_targets.cache_clear() + + with pytest.raises(RuntimeError, match=r"\[rust-strict:semantic_sql_rewriter\].*rust_backed_opt_in"): + rust_parity.require_rust_subsystem("semantic_sql_rewriter", "rewrite") + + +def test_require_rust_subsystem_all_targets_every_subsystem(monkeypatch, tmp_path): + _write_matrix( + tmp_path, + {"subsystems": {"sql_generator_entrypoint": {"status": "rust_backed"}}}, + ) + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "all") + rust_parity.strict_targets.cache_clear() + + with pytest.raises(RuntimeError, match=r"\[rust-strict:semantic_sql_rewriter\].*python_only"): + rust_parity.require_rust_subsystem("semantic_sql_rewriter", "rewrite") + + +def test_strict_targets_parses_comma_separated_values(monkeypatch): + monkeypatch.setenv( + "SIDEMANTIC_RS_STRICT_SUBSYSTEMS", + " sql_generator_entrypoint, semantic_core_query_validation ,,", + ) + rust_parity.strict_targets.cache_clear() + + assert rust_parity.strict_targets() == { + "sql_generator_entrypoint", + "semantic_core_query_validation", + } + assert rust_parity.is_strict_for("semantic_core_query_validation") is True + assert rust_parity.is_strict_for("semantic_sql_rewriter") is False + + +def test_missing_matrix_falls_back_to_python_only(monkeypatch, tmp_path): + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "sql_generator_entrypoint") + rust_parity.strict_targets.cache_clear() + + with pytest.raises(RuntimeError, match="python_only"): + rust_parity.require_rust_subsystem("sql_generator_entrypoint", "compile") + + +def test_invalid_matrix_falls_back_to_python_only(monkeypatch, tmp_path): + docs = tmp_path / "docs" + docs.mkdir() + (docs / "rust-parity-matrix.json").write_text("{") + _use_tmp_matrix(monkeypatch, tmp_path) + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "sql_generator_entrypoint") + rust_parity.strict_targets.cache_clear() + + with pytest.raises(RuntimeError, match="python_only"): + rust_parity.require_rust_subsystem("sql_generator_entrypoint", "compile") diff --git a/tests/core/test_rust_query_validation.py b/tests/core/test_rust_query_validation.py new file mode 100644 index 00000000..61ffc14c --- /dev/null +++ b/tests/core/test_rust_query_validation.py @@ -0,0 +1,207 @@ +"""Strict and env-gated tests for Rust-backed query validation.""" + +import pytest + +import sidemantic.rust_parity as rust_parity +import sidemantic.validation as validation_module +from sidemantic.core.dimension import Dimension +from sidemantic.core.metric import Metric +from sidemantic.core.model import Model +from sidemantic.core.semantic_layer import SemanticLayer + + +@pytest.fixture(autouse=True) +def _reset_strict_targets_cache(): + rust_parity.strict_targets.cache_clear() + yield + rust_parity.strict_targets.cache_clear() + + +def _clear_strict_cache() -> None: + rust_parity.strict_targets.cache_clear() + + +def _build_layer() -> SemanticLayer: + layer = SemanticLayer(auto_register=False) + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="order_id", + dimensions=[ + Dimension(name="status", type="categorical"), + Dimension(name="created_at", type="time", granularity="second"), + ], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + return layer + + +def test_query_validation_routes_to_rust(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + monkeypatch.setenv("SIDEMANTIC_RS_QUERY_VALIDATION", "1") + monkeypatch.delenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", raising=False) + monkeypatch.delenv("SIDEMANTIC_RS_NO_FALLBACK", raising=False) + _clear_strict_cache() + + layer = _build_layer() + monkeypatch.setattr(rust_bridge, "validate_query_with_rust", lambda *_args, **_kwargs: []) + monkeypatch.setattr( + validation_module, + "validate_query", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("python validation should not run")), + ) + + sql = layer.compile(metrics=["orders.revenue"], dimensions=["orders.status"]) + assert "SELECT" in sql + + +def test_query_validation_strict_raises_without_fallback(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "semantic_core_query_validation") + monkeypatch.setenv("SIDEMANTIC_RS_NO_FALLBACK", "1") + monkeypatch.delenv("SIDEMANTIC_RS_QUERY_VALIDATION", raising=False) + _clear_strict_cache() + + layer = _build_layer() + monkeypatch.setattr( + rust_bridge, + "validate_query_with_rust", + lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("rust validation failure")), + ) + + with pytest.raises( + validation_module.QueryValidationError, match="Rust query validation failed: rust validation failure" + ): + layer.compile(metrics=["orders.revenue"], dimensions=["orders.status"]) + + +def test_query_validation_with_rust_matches_python_error_text(monkeypatch): + monkeypatch.setenv("SIDEMANTIC_RS_QUERY_VALIDATION", "1") + monkeypatch.delenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", raising=False) + monkeypatch.delenv("SIDEMANTIC_RS_NO_FALLBACK", raising=False) + _clear_strict_cache() + + layer = _build_layer() + + with pytest.raises(validation_module.QueryValidationError) as exc_info: + layer.compile(metrics=["missing_metric"], dimensions=["orders.status"]) + + assert "Metric 'missing_metric' not found" in str(exc_info.value) + + +def test_query_validation_with_rust_accepts_subhour_time_granularities(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + monkeypatch.setenv("SIDEMANTIC_RS_QUERY_VALIDATION", "1") + monkeypatch.delenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", raising=False) + monkeypatch.delenv("SIDEMANTIC_RS_NO_FALLBACK", raising=False) + _clear_strict_cache() + monkeypatch.setattr( + rust_bridge, + "validate_query_with_rust", + lambda _graph, metrics, dimensions: [] + if metrics == ["orders.revenue"] and dimensions == ["orders.created_at__minute"] + else ["unexpected query"], + ) + + layer = _build_layer() + sql = layer.compile(metrics=["orders.revenue"], dimensions=["orders.created_at__minute"]) + + assert "DATE_TRUNC('MINUTE'" in sql + + +def test_validate_query_with_rust_prefers_reference_entrypoint(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + calls = {"references": 0} + + class _FakeRustModule: + def validate_query_references(self, models_yaml, metrics, dimensions): + calls["references"] += 1 + assert models_yaml == "models: []" + assert metrics == ["orders.revenue"] + assert dimensions == ["orders.status"] + return [] + + def validate_query_with_yaml(self, _models_yaml, _query_yaml): + raise AssertionError("legacy validation entrypoint should not be used") + + monkeypatch.setattr(rust_bridge, "get_rust_module", lambda: _FakeRustModule()) + monkeypatch.setattr(rust_bridge, "graph_to_rust_yaml", lambda _graph: "models: []") + + layer = _build_layer() + errors = rust_bridge.validate_query_with_rust(layer.graph, ["orders.revenue"], ["orders.status"]) + assert errors == [] + assert calls["references"] == 1 + + +def test_validate_query_with_rust_falls_back_to_legacy_payload_entrypoint(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + calls = {"legacy": 0} + + class _LegacyRustModule: + def validate_query_with_yaml(self, models_yaml, query_yaml): + calls["legacy"] += 1 + assert models_yaml == "models: []" + assert "metrics:" in query_yaml + assert "dimensions:" in query_yaml + return ["legacy error"] + + monkeypatch.setattr(rust_bridge, "get_rust_module", lambda: _LegacyRustModule()) + monkeypatch.setattr(rust_bridge, "graph_to_rust_yaml", lambda _graph: "models: []") + + layer = _build_layer() + errors = rust_bridge.validate_query_with_rust(layer.graph, ["orders.revenue"], ["orders.status"]) + assert errors == ["legacy error"] + assert calls["legacy"] == 1 + + +def test_validate_query_with_rust_falls_back_on_reference_signature_incompatibility(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + calls = {"legacy": 0} + + class _IncompatibleReferenceRustModule: + def validate_query_references(self, _models_yaml, _query_yaml): + return [] + + def validate_query_with_yaml(self, models_yaml, query_yaml): + calls["legacy"] += 1 + assert models_yaml == "models: []" + assert "metrics:" in query_yaml + assert "dimensions:" in query_yaml + return ["legacy signature fallback"] + + monkeypatch.setattr(rust_bridge, "get_rust_module", lambda: _IncompatibleReferenceRustModule()) + monkeypatch.setattr(rust_bridge, "graph_to_rust_yaml", lambda _graph: "models: []") + + layer = _build_layer() + errors = rust_bridge.validate_query_with_rust(layer.graph, ["orders.revenue"], ["orders.status"]) + assert errors == ["legacy signature fallback"] + assert calls["legacy"] == 1 + + +def test_validate_query_with_rust_propagates_reference_validation_error(monkeypatch): + import sidemantic.rust_bridge as rust_bridge + + class _ReferenceValidationErrorRustModule: + def validate_query_references(self, models_yaml, metrics, dimensions): + assert models_yaml == "models: []" + assert metrics == ["orders.revenue"] + assert dimensions == ["orders.status"] + raise ValueError("reference validation failure") + + def validate_query_with_yaml(self, _models_yaml, _query_yaml): + raise AssertionError("legacy entrypoint should not be used for validation errors") + + monkeypatch.setattr(rust_bridge, "get_rust_module", lambda: _ReferenceValidationErrorRustModule()) + monkeypatch.setattr(rust_bridge, "graph_to_rust_yaml", lambda _graph: "models: []") + + layer = _build_layer() + with pytest.raises(ValueError, match="reference validation failure"): + rust_bridge.validate_query_with_rust(layer.graph, ["orders.revenue"], ["orders.status"]) diff --git a/tests/core/test_rust_strict_sql_generator.py b/tests/core/test_rust_strict_sql_generator.py new file mode 100644 index 00000000..82f6e952 --- /dev/null +++ b/tests/core/test_rust_strict_sql_generator.py @@ -0,0 +1,178 @@ +"""Strict-mode behavior tests for Rust SQL generator entrypoint.""" + +import pytest +import yaml + +import sidemantic.core.semantic_layer as semantic_layer_module +import sidemantic.rust_parity as rust_parity +from sidemantic.core.metric import Metric +from sidemantic.core.model import Model +from sidemantic.core.semantic_layer import SemanticLayer + + +@pytest.fixture(autouse=True) +def _reset_strict_targets_cache(): + rust_parity.strict_targets.cache_clear() + yield + rust_parity.strict_targets.cache_clear() + + +def _configure_strict_sql_entrypoint(monkeypatch) -> None: + monkeypatch.setenv("SIDEMANTIC_RS_STRICT_SUBSYSTEMS", "sql_generator_entrypoint") + monkeypatch.delenv("SIDEMANTIC_RS_SQL_GENERATOR", raising=False) + monkeypatch.delenv("SIDEMANTIC_RS_SQL_GENERATOR_VERIFY", raising=False) + rust_parity.strict_targets.cache_clear() + + +def _build_layer(monkeypatch) -> SemanticLayer: + monkeypatch.setattr(semantic_layer_module, "get_rust_module", lambda: object()) + layer = SemanticLayer(auto_register=False) + layer.add_model( + Model( + name="orders", + table="orders", + primary_key="order_id", + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + return layer + + +def test_strict_sql_entrypoint_forces_rust_compile(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + layer = _build_layer(monkeypatch) + + calls = {"rust": 0, "python": 0} + + def fake_rust_compile(**_kwargs): + calls["rust"] += 1 + return "SELECT 1" + + def fake_python_compile(**_kwargs): + calls["python"] += 1 + return "SELECT 2" + + monkeypatch.setattr(layer, "_compile_with_rust", fake_rust_compile) + monkeypatch.setattr(layer, "_compile_with_python", fake_python_compile) + + sql = layer.compile(metrics=["orders.revenue"]) + assert sql.startswith("SELECT 1") + assert calls["rust"] == 1 + assert calls["python"] == 0 + assert layer._use_rust_sql_generator is True + + +def test_strict_sql_entrypoint_rejects_python_fallback(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + layer = _build_layer(monkeypatch) + + monkeypatch.setattr(layer, "_compile_with_rust", lambda **_kwargs: None) + monkeypatch.setattr( + layer, + "_compile_with_python", + lambda **_kwargs: (_ for _ in ()).throw(AssertionError("python fallback should not run")), + ) + + with pytest.raises(ValueError, match="returned no SQL in strict mode"): + layer.compile(metrics=["orders.revenue"]) + + +def test_strict_sql_entrypoint_disables_python_verify(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + layer = _build_layer(monkeypatch) + assert layer._rust_sql_verify is False + + +def test_rust_compile_payload_includes_preaggregation_flags(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + layer = _build_layer(monkeypatch) + captured = {} + + class FakeRustModule: + def compile_with_yaml(self, _models_yaml, query_yaml): + captured.update(yaml.safe_load(query_yaml)) + return "SELECT 1" + + layer._rust_module = FakeRustModule() + sql = layer.compile(metrics=["orders.revenue"], offset=10, use_preaggregations=True) + + assert sql.startswith("SELECT 1") + assert "OFFSET" not in sql + assert captured["offset"] == 10 + assert captured["use_preaggregations"] is True + assert "preagg_database" in captured + assert "preagg_schema" in captured + + +def test_rust_compile_transpiles_from_rust_output_dialect(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + layer = _build_layer(monkeypatch) + layer.dialect = "bigquery" + + class FakeRustModule: + def compile_with_yaml(self, _models_yaml, _query_yaml): + return "SELECT DATE_TRUNC('month', order_date) AS order_month FROM orders_cte" + + layer._rust_module = FakeRustModule() + sql = layer.compile(metrics=["orders.revenue"], dialect=None) + + assert "DATE_TRUNC(order_date, MONTH)" in sql + assert "DATE_TRUNC('month', order_date)" not in sql + + +def test_rust_compile_payload_includes_complex_metric_fields(monkeypatch): + _configure_strict_sql_entrypoint(monkeypatch) + monkeypatch.setattr(semantic_layer_module, "get_rust_module", lambda: object()) + layer = SemanticLayer(auto_register=False) + layer.add_model( + Model( + name="events", + table="events", + primary_key="event_id", + metrics=[ + Metric( + name="signup_funnel", + type="conversion", + entity="user_id", + steps=["event_type = 'signup'", "event_type = 'purchase'"], + ), + Metric( + name="signup_retention", + type="retention", + entity="user_id", + cohort_event="event_type = 'signup'", + activity_event="event_type = 'active'", + periods=7, + retention_granularity="day", + ), + Metric( + name="multi_platform_users", + type="cohort", + entity="user_id", + inner_metrics=[{"name": "platform_count", "agg": "count_distinct", "sql": "platform"}], + having="platform_count >= 2", + agg="count", + ), + ], + ) + ) + captured = {} + + class FakeRustModule: + def compile_with_yaml(self, models_yaml, _query_yaml): + captured.update(yaml.safe_load(models_yaml)) + return "SELECT 1" + + layer._rust_module = FakeRustModule() + sql = layer.compile(metrics=["events.signup_funnel"]) + + assert sql.startswith("SELECT 1") + metrics = {metric["name"]: metric for metric in captured["models"][0]["metrics"]} + assert metrics["signup_funnel"]["steps"] == ["event_type = 'signup'", "event_type = 'purchase'"] + assert metrics["signup_retention"]["cohort_event"] == "event_type = 'signup'" + assert metrics["signup_retention"]["periods"] == 7 + assert metrics["signup_retention"]["retention_granularity"] == "day" + assert metrics["multi_platform_users"]["inner_metrics"] == [ + {"name": "platform_count", "agg": "count_distinct", "sql": "platform"} + ] + assert metrics["multi_platform_users"]["having"] == "platform_count >= 2" diff --git a/tests/db/test_adbc_ci_smoke.py b/tests/db/test_adbc_ci_smoke.py index 8878e3c4..662b92fc 100644 --- a/tests/db/test_adbc_ci_smoke.py +++ b/tests/db/test_adbc_ci_smoke.py @@ -60,8 +60,7 @@ def _target_uri(db: str) -> str: if db == "clickhouse": host = os.getenv("CLICKHOUSE_HOST", "localhost") port = os.getenv("CLICKHOUSE_PORT", "8123") - password = os.getenv("CLICKHOUSE_PASSWORD", "clickhouse") - return f"clickhouse://default:{password}@{host}:{port}/default" + return f"http://{host}:{port}/" pytest.skip(f"Unsupported ADBC_DB={db!r}") diff --git a/tests/queries/test_rust_query_rewriter_route.py b/tests/queries/test_rust_query_rewriter_route.py new file mode 100644 index 00000000..01b70546 --- /dev/null +++ b/tests/queries/test_rust_query_rewriter_route.py @@ -0,0 +1,89 @@ +import pytest + +from sidemantic.core.dimension import Dimension +from sidemantic.core.metric import Metric +from sidemantic.core.model import Model +from sidemantic.core.semantic_graph import SemanticGraph +from sidemantic.sql.query_rewriter import QueryRewriter + + +def _graph() -> SemanticGraph: + graph = SemanticGraph() + graph.add_model( + Model( + name="orders", + table="orders", + primary_key="id", + dimensions=[Dimension(name="status", type="categorical", sql="status")], + metrics=[Metric(name="revenue", agg="sum", sql="amount")], + ) + ) + return graph + + +def test_query_rewriter_routes_to_rust_when_enabled(monkeypatch): + class FakeRustModule: + def __init__(self): + self.calls = [] + + def rewrite_with_yaml(self, yaml_text: str, sql_text: str) -> str: + self.calls.append((yaml_text, sql_text)) + return "SELECT 1 AS from_rust" + + fake = FakeRustModule() + monkeypatch.setenv("SIDEMANTIC_RS_REWRITER", "1") + monkeypatch.delenv("SIDEMANTIC_RS_NO_FALLBACK", raising=False) + monkeypatch.setattr("sidemantic.sql.query_rewriter.get_rust_module", lambda: fake) + monkeypatch.setattr("sidemantic.sql.query_rewriter.graph_to_rust_yaml", lambda _graph: "models: []") + + rewritten = QueryRewriter(_graph()).rewrite("SELECT orders.revenue FROM orders") + + assert rewritten == "SELECT 1 AS from_rust" + assert fake.calls == [("models: []", "SELECT orders.revenue FROM orders")] + + +def test_query_rewriter_no_fallback_raises_when_rust_fails(monkeypatch): + class FailingRustModule: + def rewrite_with_yaml(self, _yaml_text: str, _sql_text: str) -> str: + raise RuntimeError("boom") + + monkeypatch.setenv("SIDEMANTIC_RS_REWRITER", "1") + monkeypatch.setenv("SIDEMANTIC_RS_NO_FALLBACK", "1") + monkeypatch.setattr("sidemantic.sql.query_rewriter.get_rust_module", lambda: FailingRustModule()) + monkeypatch.setattr("sidemantic.sql.query_rewriter.graph_to_rust_yaml", lambda _graph: "models: []") + + with pytest.raises(ValueError, match="Rust rewriter failed: boom"): + QueryRewriter(_graph()).rewrite("SELECT orders.revenue FROM orders") + + +def test_rust_rewriter_falls_back_when_graph_metric_sql_cannot_be_prepared(monkeypatch): + class FailingRustModule: + def __init__(self): + self.calls = [] + + def rewrite_with_yaml(self, yaml_text: str, sql_text: str) -> str: + self.calls.append((yaml_text, sql_text)) + raise RuntimeError("boom") + + graph = _graph() + graph.add_metric( + Metric( + name="placeholder_metric", + type="derived", + sql="${TABLE}.amount", + ) + ) + + fake = FailingRustModule() + monkeypatch.setenv("SIDEMANTIC_RS_REWRITER", "1") + monkeypatch.delenv("SIDEMANTIC_RS_NO_FALLBACK", raising=False) + monkeypatch.setattr("sidemantic.sql.query_rewriter.get_rust_module", lambda: fake) + monkeypatch.setattr("sidemantic.sql.query_rewriter.graph_to_rust_yaml", lambda _graph: "models: []") + + rewriter = QueryRewriter(graph) + monkeypatch.setattr(rewriter, "_rewrite_simple_query", lambda _parsed: "SELECT 42 AS fallback") + + rewritten = rewriter.rewrite("SELECT placeholder_metric FROM metrics") + + assert rewritten == "SELECT 42 AS fallback" + assert fake.calls == [("models: []", "SELECT placeholder_metric FROM metrics")] diff --git a/tests/queries/test_sql_rewriter.py b/tests/queries/test_sql_rewriter.py index 121fd3fb..5574e803 100644 --- a/tests/queries/test_sql_rewriter.py +++ b/tests/queries/test_sql_rewriter.py @@ -1397,6 +1397,23 @@ def test_semantic_root_with_user_cte_preserved(semantic_layer): assert rows[0]["revenue"] == 250.00 +def test_semantic_root_allows_unrelated_generated_cte_name(semantic_layer): + """User CTE names are rejected only when this query actually generates the same CTE.""" + sql = """ + WITH customers_cte AS ( + SELECT 'completed' AS status + ) + SELECT orders.revenue + FROM orders + WHERE orders.status IN (SELECT status FROM customers_cte) + """ + result = semantic_layer.sql(sql) + rows = _rows(result) + + assert len(rows) == 1 + assert rows[0]["revenue"] == 250.00 + + def test_semantic_root_with_recursive_cte_preserved(semantic_layer): """WITH RECURSIVE flag is preserved when merging user CTEs.""" sql = """