diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index cbdfcc8..9a6f2db 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -8,14 +8,11 @@ permissions: env: ARTIFACT-LINUX-X86_64-EXTENSION: sqlite-vss-linux-x86_64 ARTIFACT-MACOS-X86_64-EXTENSION: sqlite-vss-macos-x86_64 - ARTIFACT-MACOS-AARCH64-EXTENSION: sqlite-vss-macos-aarch64 ARTIFACT-WINDOWS-X86_64-EXTENSION: sqlite-vss-windows-x86_64 - ARTIFACT-LINUX-X86_64-WHEELS: sqlite-vss-linux-x86_64-wheels - ARTIFACT-MACOS-X86_64-WHEELS: sqlite-vss-macos-x86_64-wheels - ARTIFACT-MACOS-AARCH64-WHEELS: sqlite-vss-macos-aarch64-wheels + ARTIFACT-MACOS-AARCH64-EXTENSION: sqlite-vss-macos-aarch64 jobs: - build-linux-x86_64-extension: - runs-on: ubuntu-20.04 + build-macos-aarch64-extension: + runs-on: flyci-macos-large-latest-m1 steps: - uses: actions/checkout@v3 with: @@ -30,33 +27,26 @@ jobs: - if: steps.cache-sqlite-build.outputs.cache-hit != 'true' working-directory: vendor/sqlite run: ./configure && make - - # TODO how cache this? - - run: sudo apt-get install -y cmake libgomp1 + - run: brew install llvm + - id: cache-cmake-build + uses: actions/cache@v3 + with: + path: build + key: ${{ runner.os }}-build + - run: make patch-openmp - run: make loadable-release static-release + env: + # `brew info libomp` gives the correct one, with .a file for static openmp builds + CC: /opt/homebrew/opt/llvm/bin/clang + CXX: /opt/homebrew/opt/llvm/bin/clang++ + LDFLAGS: "-L/usr/local/opt/libomp/lib/" + CPPFLAGS: "-I/usr/local/opt/libomp/include/" - uses: actions/upload-artifact@v3 with: - name: ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }} + name: ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }} path: dist/release/* - build-linux-x86_64-python: + build-linux-x86_64-extension: runs-on: ubuntu-20.04 - needs: [build-linux-x86_64-extension] - steps: - - uses: actions/checkout@v3 - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }} - path: dist/release/ - - uses: actions/setup-python@v3 - - run: pip install wheel - - run: make python-release - - run: make datasette-release - - uses: actions/upload-artifact@v3 - with: - name: ${{ env.ARTIFACT-LINUX-X86_64-WHEELS }} - path: dist/release/wheels/*.whl - build-macos-x86_64-extension: - runs-on: macos-latest steps: - uses: actions/checkout@v3 with: @@ -71,47 +61,20 @@ jobs: - if: steps.cache-sqlite-build.outputs.cache-hit != 'true' working-directory: vendor/sqlite run: ./configure && make - - run: brew install llvm - - id: cache-cmake-build - uses: actions/cache@v3 - with: - path: build - key: ${{ runner.os }}-build - - run: make patch-openmp + + # TODO how cache this? + - run: sudo apt-get install -y cmake libgomp1 - run: make loadable-release static-release - env: - # `brew info libomp` gives the correct one, with .a file for static openmp builds - CC: /usr/local/opt/llvm/bin/clang - CXX: /usr/local/opt/llvm/bin/clang++ - LDFLAGS: "-L/usr/local/opt/libomp/lib/" - CPPFLAGS: "-I/usr/local/opt/libomp/include/" - uses: actions/upload-artifact@v3 with: - name: ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }} + name: ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }} path: dist/release/* - build-macos-x86_64-python: + build-macos-x86_64-extension: runs-on: macos-latest - needs: [build-macos-x86_64-extension] - steps: - - uses: actions/checkout@v3 - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }} - path: dist/release/ - - uses: actions/setup-python@v3 - - run: pip install wheel - - run: make python-release - - run: make datasette-release - - uses: actions/upload-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-X86_64-WHEELS }} - path: dist/release/wheels/*.whl - build-macos-aarch64-extension: - runs-on: [self-hosted, mm1] steps: - uses: actions/checkout@v3 with: - submodules: "recursive" + submodules: recursive - id: cache-sqlite-build uses: actions/cache@v3 with: @@ -122,61 +85,24 @@ jobs: - if: steps.cache-sqlite-build.outputs.cache-hit != 'true' working-directory: vendor/sqlite run: ./configure && make + - run: brew install llvm + - id: cache-cmake-build + uses: actions/cache@v3 + with: + path: build + key: ${{ runner.os }}-build - run: make patch-openmp - run: make loadable-release static-release env: # `brew info libomp` gives the correct one, with .a file for static openmp builds - CC: /opt/homebrew/opt/llvm/bin/clang - CXX: /opt/homebrew/opt/llvm/bin/clang++ - LDFLAGS: "-L/opt/homebrew/opt/libomp/lib" - CPPFLAGS: "-I/opt/homebrew/opt/libomp/include" + CC: /usr/local/opt/llvm/bin/clang + CXX: /usr/local/opt/llvm/bin/clang++ + LDFLAGS: "-L/usr/local/opt/libomp/lib/" + CPPFLAGS: "-I/usr/local/opt/libomp/include/" - uses: actions/upload-artifact@v3 with: - name: ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }} + name: ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }} path: dist/release/* - build-macos-aarch64-python: - runs-on: [self-hosted, mm1] - needs: [build-macos-aarch64-extension] - steps: - - uses: actions/checkout@v3 - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }} - path: dist/release/ - - run: pip3 install wheel - - run: make python-release IS_MACOS_ARM=1 - - run: make datasette-release - - uses: actions/upload-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-AARCH64-WHEELS }} - path: dist/release/wheels/*.whl - upload-deno: - needs: - [ - build-macos-x86_64-extension, - build-macos-aarch64-extension, - build-linux-x86_64-extension, - ] - permissions: - contents: write - runs-on: ubuntu-latest - outputs: - deno-checksums: ${{ steps.deno-assets.outputs.result }} - steps: - - uses: actions/checkout@v3 - - uses: actions/download-artifact@v2 - - id: deno-assets - uses: actions/github-script@v6 - env: - ARTIFACT-LINUX-X86_64-EXTENSION: ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }} - ARTIFACT-MACOS-X86_64-EXTENSION: ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }} - ARTIFACT-MACOS-AARCH64-EXTENSION: ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }} - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - result-encoding: string - script: | - const script = require('.github/workflows/upload-deno-assets.js') - return await script({github, context}) upload-extensions: needs: [ @@ -201,164 +127,3 @@ jobs: macos-x86_64: ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }}/* macos-aarch64: ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }}/* linux-x86_64: ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }}/* - upload-checksums: - needs: [upload-extensions, upload-deno] - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - uses: actions/github-script@v6 - env: - CHECKSUMS: "${{ needs.upload-extensions.outputs.checksums }}\n${{ needs.upload-deno.outputs.deno-checksums }}" - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - const { owner, repo } = context.repo; - const release = await github.rest.repos.getReleaseByTag({ - owner, - repo, - tag: process.env.GITHUB_REF.replace("refs/tags/", ""), - }); - const release_id = release.data.id; - github.rest.repos.uploadReleaseAsset({ - owner, - repo, - release_id, - name: "checksums.txt", - data: process.env.CHECKSUMS, - }); - upload-hex: - runs-on: ubuntu-latest - needs: [upload-extensions] - steps: - - uses: actions/checkout@v2 - - uses: erlef/setup-beam@v1 - with: - otp-version: "24" - rebar3-version: "3.16.1" - elixir-version: "1.14" - - run: ./scripts/elixir_generate_checksum.sh "${{ needs.upload-extensions.outputs.checksums }}" - - run: mix deps.get - working-directory: ./bindings/elixir - - run: mix compile --docs - working-directory: ./bindings/elixir - - run: mix hex.publish --yes - working-directory: ./bindings/elixir - env: - HEX_API_KEY: ${{ secrets.HEX_API_KEY }} - upload-npm: - needs: - [ - build-macos-x86_64-extension, - build-macos-aarch64-extension, - build-linux-x86_64-extension, - ] - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: actions/download-artifact@v2 - - run: | - cp ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }}/*.so bindings/node/sqlite-vss-linux-x64/lib/ - cp ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }}/*.dylib bindings/node/sqlite-vss-darwin-x64/lib/ - cp ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }}/*.dylib bindings/node/sqlite-vss-darwin-arm64/lib/ - - uses: actions/setup-node@v3 - with: - node-version: "16" - registry-url: "https://registry.npmjs.org" - - name: Publish NPM sqlite-vss-linux-x64 - working-directory: bindings/node/sqlite-vss-linux-x64 - run: npm publish --access public - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - - name: Publish NPM sqlite-vss-darwin-x64 - working-directory: bindings/node/sqlite-vss-darwin-x64 - run: npm publish --access public - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - - name: Publish NPM sqlite-vss-darwin-arm64 - working-directory: bindings/node/sqlite-vss-darwin-arm64 - run: npm publish --access public - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - - name: Publish NPM sqlite-vss - working-directory: bindings/node/sqlite-vss - run: npm publish --access public - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - upload-pypi: - needs: - [ - build-linux-x86_64-python, - build-macos-x86_64-python, - build-macos-aarch64-python, - ] - runs-on: ubuntu-latest - steps: - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-LINUX-X86_64-WHEELS }} - path: dist - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-X86_64-WHEELS }} - path: dist - - uses: actions/download-artifact@v3 - with: - name: ${{ env.ARTIFACT-MACOS-AARCH64-WHEELS }} - path: dist - - uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - skip-existing: true - upload-gem: - needs: - [ - build-macos-x86_64-extension, - build-macos-aarch64-extension, - build-linux-x86_64-extension, - ] - permissions: - contents: write - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/download-artifact@v2 - - uses: ruby/setup-ruby@v1 - with: - ruby-version: 3.2 - - run: | - rm bindings/ruby/lib/*.{dylib,so,dll} || true - cp ${{ env.ARTIFACT-MACOS-X86_64-EXTENSION }}/*.dylib bindings/ruby/lib - gem -C bindings/ruby build -o x86_64-darwin.gem sqlite_vss.gemspec - env: - PLATFORM: x86_64-darwin - - run: | - rm bindings/ruby/lib/*.{dylib,so,dll} || true - cp ${{ env.ARTIFACT-MACOS-AARCH64-EXTENSION }}/*.dylib bindings/ruby/lib - gem -C bindings/ruby build -o arm64-darwin.gem sqlite_vss.gemspec - env: - PLATFORM: arm64-darwin - - run: | - rm bindings/ruby/lib/*.{dylib,so,dll} || true - cp ${{ env.ARTIFACT-LINUX-X86_64-EXTENSION }}/*.so bindings/ruby/lib - gem -C bindings/ruby build -o x86_64-linux.gem sqlite_vss.gemspec - env: - PLATFORM: x86_64-linux - - run: | - gem push bindings/ruby/x86_64-darwin.gem - gem push bindings/ruby/arm64-darwin.gem - gem push bindings/ruby/x86_64-linux.gem - env: - GEM_HOST_API_KEY: ${{ secrets.GEM_HOST_API_KEY }} - upload-crate: - runs-on: ubuntu-latest - needs: [upload-extensions] - steps: - - uses: actions/checkout@v2 - - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - - run: cargo publish --no-verify - working-directory: ./bindings/rust - env: - CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} \ No newline at end of file diff --git a/src/sqlite-vector.cpp b/src/sqlite-vector.cpp index a780fcf..818c4b1 100644 --- a/src/sqlite-vector.cpp +++ b/src/sqlite-vector.cpp @@ -14,398 +14,19 @@ SQLITE_EXTENSION_INIT1 using namespace std; - typedef unique_ptr> vec_ptr; -// https://github.com/sqlite/sqlite/blob/master/src/json.c#L88-L89 -#define JSON_SUBTYPE 74 /* Ascii for "J" */ - -#include -using json = nlohmann::json; - -char VECTOR_BLOB_HEADER_BYTE = 'v'; -char VECTOR_BLOB_HEADER_TYPE = 1; -const char *VECTOR_FLOAT_POINTER_NAME = "vectorf32v0"; - -#pragma endregion - -#pragma region Generic - -void delVectorFloat(void *p) { - - auto vx = static_cast(p); - sqlite3_free(vx->data); - delete vx; -} - -void resultVector(sqlite3_context *context, vector *vecIn) { - - auto vecRes = new VectorFloat(); - - vecRes->size = vecIn->size(); - vecRes->data = (float *)sqlite3_malloc(vecIn->size() * sizeof(float)); - - memcpy(vecRes->data, vecIn->data(), vecIn->size() * sizeof(float)); - - sqlite3_result_pointer(context, vecRes, VECTOR_FLOAT_POINTER_NAME, delVectorFloat); -} - -vec_ptr vectorFromBlobValue(sqlite3_value *value, const char **pzErrMsg) { - - int size = sqlite3_value_bytes(value); - char header; - char type; - - if (size < (2)) { - *pzErrMsg = "Vector blob size less than header length"; - return nullptr; - } - - const void *pBlob = sqlite3_value_blob(value); - memcpy(&header, ((char *)pBlob + 0), sizeof(char)); - memcpy(&type, ((char *)pBlob + 1), sizeof(char)); - - if (header != VECTOR_BLOB_HEADER_BYTE) { - *pzErrMsg = "Blob not well-formatted vector blob"; - return nullptr; - } - - if (type != VECTOR_BLOB_HEADER_TYPE) { - *pzErrMsg = "Blob type not right"; - return nullptr; - } - - int numElements = (size - 2) / sizeof(float); - float *vec = (float *)((char *)pBlob + 2); - return vec_ptr(new vector(vec, vec + numElements)); -} - -vec_ptr vectorFromRawBlobValue(sqlite3_value *value, const char **pzErrMsg) { - - int size = sqlite3_value_bytes(value); - - // Must be divisible by 4 - if (size % 4) { - *pzErrMsg = "Invalid raw blob length, blob must be divisible by 4"; - return nullptr; - } - const void *pBlob = sqlite3_value_blob(value); - - float *vec = (float *)((char *)pBlob); - return vec_ptr(new vector(vec, vec + (size / 4))); -} - -vec_ptr vectorFromTextValue(sqlite3_value *value) { - - try { - - json json = json::parse(sqlite3_value_text(value)); - vec_ptr pVec(new vector()); - json.get_to(*pVec); - return pVec; - - } catch (const json::exception &) { - return nullptr; - } - - return nullptr; -} - -static vec_ptr valueAsVector(sqlite3_value *value) { - - // Option 1: If the value is a "vectorf32v0" pointer, create vector from - // that - auto vec = (VectorFloat *)sqlite3_value_pointer(value, VECTOR_FLOAT_POINTER_NAME); - - if (vec != nullptr) - return vec_ptr(new vector(vec->data, vec->data + vec->size)); - - vec_ptr pVec; - - // Option 2: value is a blob in vector format - if (sqlite3_value_type(value) == SQLITE_BLOB) { - - const char *pzErrMsg = nullptr; - - if ((pVec = vectorFromBlobValue(value, &pzErrMsg)) != nullptr) - return pVec; - - if ((pVec = vectorFromRawBlobValue(value, &pzErrMsg)) != nullptr) - return pVec; - } - - // Option 3: if value is a JSON array coercible to float vector, use that - if (sqlite3_value_type(value) == SQLITE_TEXT) { - - if ((pVec = vectorFromTextValue(value)) != nullptr) - return pVec; - else - return nullptr; - } - - // Else, value isn't a vector - return nullptr; -} - -#pragma endregion - -#pragma region Meta - -static void vector_version(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); -} - -static void vector_debug(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vec_ptr pVec = valueAsVector(argv[0]); - - if (pVec == nullptr) { - - sqlite3_result_error(context, "Value not a vector", -1); - return; - } - - sqlite3_str *str = sqlite3_str_new(0); - sqlite3_str_appendf(str, "size: %lld [", pVec->size()); - - for (int i = 0; i < pVec->size(); i++) { - - if (i == 0) - sqlite3_str_appendf(str, "%f", pVec->at(i)); - else - sqlite3_str_appendf(str, ", %f", pVec->at(i)); - } - - sqlite3_str_appendchar(str, 1, ']'); - sqlite3_result_text(context, sqlite3_str_finish(str), -1, sqlite3_free); -} - -#pragma endregion - -#pragma region Vector generation - -// TODO should return fvec, ivec, or bvec depending on input. How do bvec, -// though? -static void vector_from(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vector vec; - vec.reserve(argc); - for (int i = 0; i < argc; i++) { - vec.push_back(sqlite3_value_double(argv[i])); - } - - resultVector(context, &vec); -} - -#pragma endregion - -#pragma region Vector general - -static void vector_value_at(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vec_ptr pVec = valueAsVector(argv[0]); - - if (pVec == nullptr) - return; - - int pos = sqlite3_value_int(argv[1]); - - try { - - float result = pVec->at(pos); - sqlite3_result_double(context, result); - - } catch (const out_of_range &oor) { - - char *errmsg = sqlite3_mprintf("%d out of range: %s", pos, oor.what()); - - if (errmsg != nullptr) { - sqlite3_result_error(context, errmsg, -1); - sqlite3_free(errmsg); - } else { - sqlite3_result_error_nomem(context); - } - } -} - -static void vector_length(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - auto pVec = (VectorFloat *)sqlite3_value_pointer(argv[0], VECTOR_FLOAT_POINTER_NAME); - if (pVec == nullptr) - return; - - sqlite3_result_int64(context, pVec->size); -} - -#pragma endregion - -#pragma region Json - -static void vector_to_json(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vec_ptr pVec = valueAsVector(argv[0]); - if (pVec == nullptr) - return; - - json j = json(*pVec); - - sqlite3_result_text(context, j.dump().c_str(), -1, SQLITE_TRANSIENT); - sqlite3_result_subtype(context, JSON_SUBTYPE); -} - -static void vector_from_json(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - const char *text = (const char *)sqlite3_value_text(argv[0]); - vec_ptr pVec = vectorFromTextValue(argv[0]); - - if (pVec == nullptr) { - sqlite3_result_error( - context, "input not valid json, or contains non-float data", -1); - } else { - resultVector(context, pVec.get()); - } -} - -#pragma endregion - -#pragma region Blob - -/* - -|Offset | Size | Description -|-|-|- -|a|a|A -*/ -static void vector_to_blob(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vec_ptr pVec = valueAsVector(argv[0]); - if (pVec == nullptr) - return; - - int size = pVec->size(); - int memSize = (sizeof(char)) + (sizeof(char)) + (size * 4); - void *pBlob = sqlite3_malloc(memSize); - memset(pBlob, 0, memSize); - - memcpy((void *)((char *)pBlob + 0), (void *)&VECTOR_BLOB_HEADER_BYTE, sizeof(char)); - memcpy((void *)((char *)pBlob + 1), (void *)&VECTOR_BLOB_HEADER_TYPE, sizeof(char)); - memcpy((void *)((char *)pBlob + 2), (void *)pVec->data(), size * 4); - - sqlite3_result_blob64(context, pBlob, memSize, sqlite3_free); -} - -static void vector_from_blob(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - const char *pzErrMsg; - - vec_ptr pVec = vectorFromBlobValue(argv[0], &pzErrMsg); - if (pVec == nullptr) - sqlite3_result_error(context, pzErrMsg, -1); - else - resultVector(context, pVec.get()); -} - -static void vector_to_raw(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - vec_ptr pVec = valueAsVector(argv[0]); - if (pVec == nullptr) - return; - - int size = pVec->size(); - int n = size * sizeof(float); - void *pBlob = sqlite3_malloc(n); - memset(pBlob, 0, n); - memcpy((void *)((char *)pBlob), (void *)pVec->data(), n); - sqlite3_result_blob64(context, pBlob, n, sqlite3_free); -} - -static void vector_from_raw(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - const char *pzErrMsg; // TODO: Shouldn't we have like error messages here? - - vec_ptr pVec = vectorFromRawBlobValue(argv[0], &pzErrMsg); - if (pVec == nullptr) - sqlite3_result_error(context, pzErrMsg, -1); - else - resultVector(context, pVec.get()); -} +#include "vec/functions.h" +#include "vec/fvecsEach_cursor.h" +#include "vec/fvecsEach_vtab.h" #pragma endregion #pragma region fvecs vtab -struct fvecsEach_vtab : public sqlite3_vtab { - - fvecsEach_vtab() { - - pModule = nullptr; - nRef = 0; - zErrMsg = nullptr; - } - - ~fvecsEach_vtab() { - - if (zErrMsg != nullptr) { - sqlite3_free(zErrMsg); - } - } -}; - -struct fvecsEach_cursor : public sqlite3_vtab_cursor { - - fvecsEach_cursor(sqlite3_vtab *pVtab) { - - this->pVtab = pVtab; - iRowid = 0; - pBlob = nullptr; - iBlobN = 0; - p = 0; - iCurrentD = 0; - } - - ~fvecsEach_cursor() { - if (pBlob != nullptr) - sqlite3_free(pBlob); - } - - sqlite3_int64 iRowid; - - // Copy of fvecs input blob - void *pBlob; - - // Total size of pBlob in bytes - sqlite3_int64 iBlobN; - sqlite3_int64 p; - - // Current dimensions - int iCurrentD; - - // Pointer to current vector being read in - vec_ptr pCurrentVector; -}; +#define FVECS_EACH_DIMENSIONS 0 +#define FVECS_EACH_VECTOR 1 +#define FVECS_EACH_INPUT 2 static int fvecsEachConnect(sqlite3 *db, void *pAux, @@ -414,18 +35,12 @@ static int fvecsEachConnect(sqlite3 *db, sqlite3_vtab **ppVtab, char **pzErr) { - int rc; - - rc = sqlite3_declare_vtab(db, "create table x(dimensions, vector, input hidden)"); - -#define FVECS_EACH_DIMENSIONS 0 -#define FVECS_EACH_VECTOR 1 -#define FVECS_EACH_INPUT 2 + auto rc = sqlite3_declare_vtab(db, "create table x(dimensions, vector, input hidden)"); if (rc == SQLITE_OK) { auto pNew = new fvecsEach_vtab(); - if (pNew == 0) + if (pNew == nullptr) return SQLITE_NOMEM; *ppVtab = pNew; @@ -435,25 +50,25 @@ static int fvecsEachConnect(sqlite3 *db, static int fvecsEachDisconnect(sqlite3_vtab *pVtab) { - auto pTable = static_cast(pVtab); - delete pTable; + auto table = static_cast(pVtab); + delete table; return SQLITE_OK; } static int fvecsEachOpen(sqlite3_vtab *p, sqlite3_vtab_cursor **ppCursor) { - auto pCur = new fvecsEach_cursor(p); - if (pCur == nullptr) + auto cursor = new fvecsEach_cursor(p); + if (cursor == nullptr) return SQLITE_NOMEM; - *ppCursor = pCur; + *ppCursor = cursor; return SQLITE_OK; } -static int fvecsEachClose(sqlite3_vtab_cursor *cur) { +static int fvecsEachClose(sqlite3_vtab_cursor *pCursor) { - auto pCur = static_cast(cur); - delete pCur; + auto cursor = static_cast(pCursor); + delete cursor; return SQLITE_OK; } @@ -480,83 +95,80 @@ static int fvecsEachBestIndex(sqlite3_vtab *tab, sqlite3_index_info *pIdxInfo) { return SQLITE_OK; } -static int fvecsEachFilter(sqlite3_vtab_cursor *pVtabCursor, +static int fvecsEachFilter(sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) { - auto pCur = static_cast(pVtabCursor); + auto cursor = static_cast(pCursor); int size = sqlite3_value_bytes(argv[0]); const void *blob = sqlite3_value_blob(argv[0]); - if (pCur->pBlob) - sqlite3_free(pCur->pBlob); + cursor->setBlob(sqlite3_malloc(size)); + cursor->iBlobN = size; + cursor->iRowid = 1; + memcpy(cursor->getBlob(), blob, size); - pCur->pBlob = sqlite3_malloc(size); - pCur->iBlobN = size; - pCur->iRowid = 1; - memcpy(pCur->pBlob, blob, size); - - memcpy(&pCur->iCurrentD, pCur->pBlob, sizeof(int)); - float *vecBegin = (float *)((char *)pCur->pBlob + sizeof(int)); + memcpy(&cursor->iCurrentD, cursor->getBlob(), sizeof(int)); + float *vecBegin = (float *)((char *)cursor->getBlob() + sizeof(int)); // TODO: Shouldn't this multiply by sizeof(float)? - pCur->pCurrentVector = vec_ptr(new vector(vecBegin, vecBegin + pCur->iCurrentD)); + cursor->pCurrentVector = vec_ptr(new vector(vecBegin, vecBegin + cursor->iCurrentD)); - pCur->p = sizeof(int) + (pCur->iCurrentD * sizeof(float)); + cursor->p = sizeof(int) + (cursor->iCurrentD * sizeof(float)); return SQLITE_OK; } -static int fvecsEachNext(sqlite3_vtab_cursor *cur) { +static int fvecsEachNext(sqlite3_vtab_cursor *pCursor) { - auto pCur = static_cast(cur); + auto cursor = static_cast(pCursor); // TODO: Shouldn't this multiply by sizeof(float)? - memcpy(&pCur->iCurrentD, ((char *)pCur->pBlob + pCur->p), sizeof(int)); - float *vecBegin = (float *)(((char *)pCur->pBlob + pCur->p) + sizeof(int)); + memcpy(&cursor->iCurrentD, ((char *)cursor->getBlob() + cursor->p), sizeof(int)); + float *vecBegin = (float *)(((char *)cursor->getBlob() + cursor->p) + sizeof(int)); - pCur->pCurrentVector->clear(); - pCur->pCurrentVector->shrink_to_fit(); - pCur->pCurrentVector->reserve(pCur->iCurrentD); - pCur->pCurrentVector->insert(pCur->pCurrentVector->begin(), + cursor->pCurrentVector->clear(); + cursor->pCurrentVector->shrink_to_fit(); + cursor->pCurrentVector->reserve(cursor->iCurrentD); + cursor->pCurrentVector->insert(cursor->pCurrentVector->begin(), vecBegin, - vecBegin + pCur->iCurrentD); + vecBegin + cursor->iCurrentD); - pCur->p += (sizeof(int) + (pCur->iCurrentD * sizeof(float))); - pCur->iRowid++; + cursor->p += (sizeof(int) + (cursor->iCurrentD * sizeof(float))); + cursor->iRowid++; return SQLITE_OK; } -static int fvecsEachEof(sqlite3_vtab_cursor *cur) { +static int fvecsEachEof(sqlite3_vtab_cursor *pCursor) { - auto pCur = (fvecsEach_cursor *)cur; - return pCur->p > pCur->iBlobN; + auto cursor = (fvecsEach_cursor *)pCursor; + return cursor->p > cursor->iBlobN; } -static int fvecsEachRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { +static int fvecsEachRowid(sqlite3_vtab_cursor *pCursor, sqlite_int64 *pRowid) { - fvecsEach_cursor *pCur = (fvecsEach_cursor *)cur; - *pRowid = pCur->iRowid; + fvecsEach_cursor *cursor = (fvecsEach_cursor *)pCursor; + *pRowid = cursor->iRowid; return SQLITE_OK; } -static int fvecsEachColumn(sqlite3_vtab_cursor *cur, +static int fvecsEachColumn(sqlite3_vtab_cursor *pCursor, sqlite3_context *context, int i) { - auto pCur = static_cast(cur); + auto cursor = static_cast(pCursor); switch (i) { case FVECS_EACH_DIMENSIONS: - sqlite3_result_int(context, pCur->iCurrentD); + sqlite3_result_int(context, cursor->iCurrentD); break; case FVECS_EACH_VECTOR: - resultVector(context, pCur->pCurrentVector.get()); + resultVector(context, cursor->pCurrentVector.get()); break; case FVECS_EACH_INPUT: @@ -584,17 +196,18 @@ static sqlite3_module fvecsEachModule = { /* xEof */ fvecsEachEof, /* xColumn */ fvecsEachColumn, /* xRowid */ fvecsEachRowid, - /* xUpdate */ 0, - /* xBegin */ 0, - /* xSync */ 0, - /* xCommit */ 0, - /* xRollback */ 0, - /* xFindMethod */ 0, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ 0}; + /* xUpdate */ nullptr, + /* xBegin */ nullptr, + /* xSync */ nullptr, + /* xCommit */ nullptr, + /* xRollback */ nullptr, + /* xFindMethod */ nullptr, + /* xRename */ nullptr, + /* xSavepoint */ nullptr, + /* xRelease */ nullptr, + /* xRollbackTo */ nullptr, + /* xShadowName */ nullptr +}; #pragma endregion @@ -675,9 +288,9 @@ __declspec(dllexport) aFunc[i].flags, aFunc[i].pAux, aFunc[i].xFunc, - 0, - 0, - 0); + nullptr, + nullptr, + nullptr); if (rc != SQLITE_OK) { diff --git a/src/sqlite-vss.cpp b/src/sqlite-vss.cpp index b9f8c6a..27b2343 100644 --- a/src/sqlite-vss.cpp +++ b/src/sqlite-vss.cpp @@ -1,224 +1,14 @@ + #include "sqlite-vss.h" -#include -#include - -#include "sqlite3ext.h" -SQLITE_EXTENSION_INIT1 - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "vss/inclusions.h" #include "sqlite-vector.h" - -using namespace std; - -typedef unique_ptr> vec_ptr; - -#pragma region Meta - -static void vss_version(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); -} - -static void vss_debug(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - auto resTxt = sqlite3_mprintf( - "version: %s\nfaiss version: %d.%d.%d\nfaiss compile options: %s", - SQLITE_VSS_VERSION, - FAISS_VERSION_MAJOR, - FAISS_VERSION_MINOR, - FAISS_VERSION_PATCH, - faiss::get_compile_options().c_str()); - - sqlite3_result_text(context, resTxt, -1, SQLITE_TRANSIENT); - sqlite3_free(resTxt); -} - -#pragma endregion - -#pragma region Distances - -static void vss_distance_l1(sqlite3_context *context, - int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", - -1); - return; - } - - sqlite3_result_double(context, faiss::fvec_L1(lhs->data(), rhs->data(), lhs->size())); -} - -static void vss_distance_l2(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", - -1); - return; - } - - sqlite3_result_double(context, faiss::fvec_L2sqr(lhs->data(), rhs->data(), lhs->size())); -} - -static void vss_distance_linf(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", - -1); - return; - } - - sqlite3_result_double(context, faiss::fvec_Linf(lhs->data(), rhs->data(), lhs->size())); -} - -static void vss_inner_product(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", - -1); - return; - } - - sqlite3_result_double(context, - faiss::fvec_inner_product(lhs->data(), rhs->data(), lhs->size())); -} - -static void vss_fvec_add(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", - -1); - return; - } - - auto size = lhs->size(); - vec_ptr c(new vector(size)); - faiss::fvec_add(size, lhs->data(), rhs->data(), c->data()); - - vector_api->xResultVector(context, c.get()); -} - -static void vss_fvec_sub(sqlite3_context *context, int argc, - sqlite3_value **argv) { - - auto vector_api = (vector0_api *)sqlite3_user_data(context); - - vec_ptr lhs = vector_api->xValueAsVector(argv[0]); - if (lhs == nullptr) { - sqlite3_result_error(context, "LHS is not a vector", -1); - return; - } - - vec_ptr rhs = vector_api->xValueAsVector(argv[1]); - if (rhs == nullptr) { - sqlite3_result_error(context, "RHS is not a vector", -1); - return; - } - - if (lhs->size() != rhs->size()) { - sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", -1); - return; - } - - int size = lhs->size(); - vec_ptr c = vec_ptr(new vector(size)); - faiss::fvec_sub(size, lhs->data(), rhs->data(), c->data()); - vector_api->xResultVector(context, c.get()); -} - -#pragma endregion +#include "vss/sql-statement.h" +#include "vss/meta-methods.h" +#include "vss/calculations.h" +#include "vss/vss-index.h" +#include "vss/vss-index-vtab.h" +#include "vss/vss-index-cursor.h" #pragma region Structs and cleanup functions @@ -228,18 +18,25 @@ struct VssSearchParams { sqlite3_int64 k; }; -void delVssSearchParams(void *p) { - - VssSearchParams *self = (VssSearchParams *)p; - delete self; -} - struct VssRangeSearchParams { vec_ptr vector; float distance; }; +struct VssIndexColumn { + + string name; + sqlite3_int64 dimensions; + string factory; +}; + +void delVssSearchParams(void *p) { + + VssSearchParams *self = (VssSearchParams *)p; + delete self; +} + void delVssRangeSearchParams(void *p) { auto self = (VssRangeSearchParams *)p; @@ -248,7 +45,7 @@ void delVssRangeSearchParams(void *p) { #pragma endregion -#pragma region Vtab +#pragma region Virtual table implementation static void vssSearchParamsFunc(sqlite3_context *context, int argc, @@ -269,7 +66,8 @@ static void vssSearchParamsFunc(sqlite3_context *context, sqlite3_result_pointer(context, params, "vss0_searchparams", delVssSearchParams); } -static void vssRangeSearchParamsFunc(sqlite3_context *context, int argc, +static void vssRangeSearchParamsFunc(sqlite3_context *context, + int argc, sqlite3_value **argv) { auto vector_api = (vector0_api *)sqlite3_user_data(context); @@ -288,147 +86,28 @@ static void vssRangeSearchParamsFunc(sqlite3_context *context, int argc, sqlite3_result_pointer(context, params, "vss0_rangesearchparams", delVssRangeSearchParams); } -static int write_index_insert(faiss::Index *index, - sqlite3 *db, +static int shadow_data_insert(sqlite3 *db, char *schema, char *name, - int rowId) { - - faiss::VectorIOWriter writer; - faiss::write_index(index, &writer); - sqlite3_int64 indexSize = writer.data.size(); - - // First try to insert into xyz_index. If that fails with a rowid constraint - // error, that means the index is already on disk, we just have to UPDATE - // instead. - - sqlite3_stmt *stmt; - char *sql = sqlite3_mprintf( - "insert into \"%w\".\"%w_index\"(rowid, idx) values (?, ?)", - schema, - name); - - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt == nullptr) { - sqlite3_free(sql); - return SQLITE_ERROR; - } - - rc = sqlite3_bind_int64(stmt, 1, rowId); - if (rc != SQLITE_OK) { - sqlite3_finalize(stmt); - sqlite3_free(sql); - return SQLITE_ERROR; - } - - rc = sqlite3_bind_blob64(stmt, 2, writer.data.data(), indexSize, SQLITE_TRANSIENT); - if (rc != SQLITE_OK) { - sqlite3_finalize(stmt); - sqlite3_free(sql); - return SQLITE_ERROR; - } - - int result = sqlite3_step(stmt); - sqlite3_finalize(stmt); - sqlite3_free(sql); - - if (result == SQLITE_DONE) { - - // INSERT was success, index wasn't written yet, all good to exit - return SQLITE_OK; + sqlite3_int64 rowid) { - } else if (sqlite3_extended_errcode(db) != SQLITE_CONSTRAINT_ROWID) { + SqlStatement insert(db, + sqlite3_mprintf("insert into \"%w\".\"%w_data\"(rowid, x) values (?, ?);", + schema, + name)); - // INSERT failed for another unknown reason, bad, return error + if (insert.prepare() != SQLITE_OK) return SQLITE_ERROR; - } - - // INSERT failed because index already is on disk, so we do an UPDATE instead - sql = sqlite3_mprintf( - "update \"%w\".\"%w_index\" set idx = ? where rowid = ?", schema, name); - - rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt == nullptr) { - sqlite3_free(sql); + if (insert.bind_int64(1, rowid) != SQLITE_OK) return SQLITE_ERROR; - } - rc = sqlite3_bind_blob64(stmt, 1, writer.data.data(), indexSize, SQLITE_TRANSIENT); - if (rc != SQLITE_OK) { - sqlite3_finalize(stmt); - sqlite3_free(sql); + if (insert.bind_null(2) != SQLITE_OK) return SQLITE_ERROR; - } - rc = sqlite3_bind_int64(stmt, 2, rowId); - if (rc != SQLITE_OK) { - sqlite3_finalize(stmt); - sqlite3_free(sql); + if (insert.step() != SQLITE_DONE) return SQLITE_ERROR; - } - - result = sqlite3_step(stmt); - sqlite3_finalize(stmt); - sqlite3_free(sql); - - if (result == SQLITE_DONE) { - return SQLITE_OK; - } - - return result; -} - -static int shadow_data_insert(sqlite3 *db, - char *schema, - char *name, - sqlite3_int64 *rowid, - sqlite3_int64 *retRowid) { - - sqlite3_stmt *stmt; - - if (rowid == nullptr) { - - auto sql = sqlite3_mprintf( - "insert into \"%w\".\"%w_data\"(x) values (?)", schema, name); - - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - sqlite3_free(sql); - - if (rc != SQLITE_OK || stmt == nullptr) { - return SQLITE_ERROR; - } - - sqlite3_bind_null(stmt, 1); - if (sqlite3_step(stmt) != SQLITE_DONE) { - sqlite3_finalize(stmt); - return SQLITE_ERROR; - } - - } else { - - auto sql = sqlite3_mprintf( - "insert into \"%w\".\"%w_data\"(rowid, x) values (?, ?);", schema, - name); - - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - sqlite3_free(sql); - if (rc != SQLITE_OK || stmt == nullptr) - return SQLITE_ERROR; - - sqlite3_bind_int64(stmt, 1, *rowid); - sqlite3_bind_null(stmt, 2); - if (sqlite3_step(stmt) != SQLITE_DONE) { - sqlite3_finalize(stmt); - return SQLITE_ERROR; - } - - if (retRowid != nullptr) - *retRowid = sqlite3_last_insert_rowid(db); - } - - sqlite3_finalize(stmt); return SQLITE_OK; } @@ -436,219 +115,45 @@ static int shadow_data_delete(sqlite3 *db, char *schema, char *name, sqlite3_int64 rowid) { - sqlite3_stmt *stmt; - - // TODO: We should strive to use only one concept and idea while creating - // SQL statements. - auto query = sqlite3_str_new(0); - sqlite3_str_appendf(query, "delete from \"%w\".\"%w_data\" where rowid = ?", - schema, name); + SqlStatement del(db, + sqlite3_mprintf("delete from \"%w\".\"%w_data\" where rowid = ?", + schema, + name)); - auto sql = sqlite3_str_finish(query); + if (del.prepare() != SQLITE_OK) + return SQLITE_ERROR; - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt == nullptr) + if (del.bind_int64(1, rowid) != SQLITE_OK) return SQLITE_ERROR; - sqlite3_bind_int64(stmt, 1, rowid); - if (sqlite3_step(stmt) != SQLITE_DONE) { - sqlite3_finalize(stmt); + if (del.step() != SQLITE_DONE) return SQLITE_ERROR; - } - sqlite3_free(sql); - sqlite3_finalize(stmt); return SQLITE_OK; } -static faiss::Index *read_index_select(sqlite3 *db, const char *name, int indexId) { - - sqlite3_stmt *stmt; - auto sql = sqlite3_mprintf("select idx from \"%w_index\" where rowid = ?", name); - - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr); - if (rc != SQLITE_OK || stmt == nullptr) { - sqlite3_finalize(stmt); - sqlite3_free(sql); - return nullptr; - } - - sqlite3_bind_int64(stmt, 1, indexId); - if (sqlite3_step(stmt) != SQLITE_ROW) { - sqlite3_finalize(stmt); - sqlite3_free(sql); - return nullptr; - } - - auto index_data = sqlite3_column_blob(stmt, 0); - int64_t size = sqlite3_column_bytes(stmt, 0); - - faiss::VectorIOReader reader; - copy((const uint8_t *)index_data, - ((const uint8_t *)index_data) + size, - back_inserter(reader.data)); - - sqlite3_free(sql); - sqlite3_finalize(stmt); - - return faiss::read_index(&reader); -} - -static int create_shadow_tables(sqlite3 *db, - const char *schema, - const char *name, - int n) { - - auto sql = sqlite3_mprintf("create table \"%w\".\"%w_index\"(idx)", - schema, - name); - - auto rc = sqlite3_exec(db, sql, 0, 0, 0); - sqlite3_free(sql); - if (rc != SQLITE_OK) - return rc; - - sql = sqlite3_mprintf("create table \"%w\".\"%w_data\"(x);", - schema, - name); - - rc = sqlite3_exec(db, sql, nullptr, nullptr, nullptr); - sqlite3_free(sql); - return rc; -} - static int drop_shadow_tables(sqlite3 *db, char *name) { + // Dropping both x_index and x_data shadow tables. const char *drops[2] = {"drop table \"%w_index\";", "drop table \"%w_data\";"}; for (int i = 0; i < 2; i++) { - auto curSql = drops[i]; - - sqlite3_stmt *stmt; + SqlStatement cur(db, + sqlite3_mprintf(drops[i], + name)); - // TODO: Use of one construct to create SQL statements. - sqlite3_str *query = sqlite3_str_new(0); - sqlite3_str_appendf(query, curSql, name); - char *sql = sqlite3_str_finish(query); - - int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, 0); - if (rc != SQLITE_OK || stmt == nullptr) { - sqlite3_free(sql); + if (cur.prepare() != SQLITE_OK) return SQLITE_ERROR; - } - if (sqlite3_step(stmt) != SQLITE_DONE) { - sqlite3_free(sql); - sqlite3_finalize(stmt); + if (cur.step() != SQLITE_DONE) return SQLITE_ERROR; - } - - sqlite3_free(sql); - sqlite3_finalize(stmt); } return SQLITE_OK; } -#define VSS_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION -#define VSS_RANGE_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION + 1 - -// Wrapper around a single faiss index, with training data, insert records, and -// delete records. -struct vss_index { - - explicit vss_index(faiss::Index *index) : index(index) {} - - ~vss_index() { - if (index != nullptr) { - delete index; - } - } - - faiss::Index *index; - vector trainings; - vector insert_data; - vector insert_ids; - vector delete_ids; -}; - -struct vss_index_vtab : public sqlite3_vtab { - - vss_index_vtab(sqlite3 *db, vector0_api *vector_api, char *schema, char *name) - : db(db), - vector_api(vector_api), - schema(schema), - name(name) { } - - ~vss_index_vtab() { - - if (name) - sqlite3_free(name); - if (schema) - sqlite3_free(schema); - for (auto iter = indexes.begin(); iter != indexes.end(); ++iter) { - delete (*iter); - } - } - - sqlite3 *db; - vector0_api *vector_api; - - // Name of the virtual table. Must be freed during disconnect - char *name; - - // Name of the schema the virtual table exists in. Must be freed during - // disconnect - char *schema; - - // Vector holding all the faiss Indices the vtab uses, and their state, - // implying which items are to be deleted and inserted. - vector indexes; -}; - -enum QueryType { search, range_search, fullscan }; - -struct vss_index_cursor : public sqlite3_vtab_cursor { - - explicit vss_index_cursor(vss_index_vtab *table) - : table(table), - sqlite3_vtab_cursor({0}), - stmt(nullptr) { } - - ~vss_index_cursor() { - if (stmt != nullptr) - sqlite3_finalize(stmt); - } - - vss_index_vtab *table; - - sqlite3_int64 iCurrent; - sqlite3_int64 iRowid; - - QueryType query_type; - - // For query_type == QueryType::search - sqlite3_int64 limit; - vector search_ids; - vector search_distances; - - // For query_type == QueryType::range_search - unique_ptr range_search_result; - - // For query_type == QueryType::fullscan - sqlite3_stmt *stmt; - int step_result; -}; - -struct VssIndexColumn { - - string name; - sqlite3_int64 dimensions; - string factory; -}; - unique_ptr> parse_constructor(int argc, const char *const *argv) { @@ -696,6 +201,10 @@ unique_ptr> parse_constructor(int argc, return columns; } +#define VSS_INDEX_COLUMN_DISTANCE 0 +#define VSS_INDEX_COLUMN_OPERATION 1 +#define VSS_INDEX_COLUMN_VECTORS 2 + static int init(sqlite3 *db, void *pAux, int argc, @@ -705,31 +214,23 @@ static int init(sqlite3 *db, bool isCreate) { sqlite3_vtab_config(db, SQLITE_VTAB_CONSTRAINT_SUPPORT, 1); - int rc; - - sqlite3_str *str = sqlite3_str_new(nullptr); - sqlite3_str_appendall(str, - "create table x(distance hidden, operation hidden"); auto columns = parse_constructor(argc, argv); - if (columns == nullptr) { - *pzErr = sqlite3_mprintf("Error parsing constructor"); - return rc; + *pzErr = sqlite3_mprintf("Error parsing VSS index factory constructor"); + return SQLITE_ERROR; } - for (auto column = columns->begin(); column != columns->end(); ++column) { - sqlite3_str_appendf(str, ", \"%w\"", column->name.c_str()); + string sql = "create table x(distance hidden, operation hidden"; + for (auto colIter = columns->begin(); colIter != columns->end(); ++colIter) { + sql += ", \"" + colIter->name + "\""; } + sql += ")"; - sqlite3_str_appendall(str, ")"); - auto sql = sqlite3_str_finish(str); - rc = sqlite3_declare_vtab(db, sql); - sqlite3_free(sql); + SqlStatement create(db, + sqlite3_mprintf(sql.c_str())); -#define VSS_INDEX_COLUMN_DISTANCE 0 -#define VSS_INDEX_COLUMN_OPERATION 1 -#define VSS_INDEX_COLUMN_VECTORS 2 + auto rc = create.declare_vtab(); if (rc != SQLITE_OK) return rc; @@ -738,68 +239,43 @@ static int init(sqlite3 *db, (vector0_api *)pAux, sqlite3_mprintf("%s", argv[1]), sqlite3_mprintf("%s", argv[2])); - *ppVtab = pTable; - if (isCreate) { - - for (auto iter = columns->begin(); iter != columns->end(); ++iter) { + *ppVtab = pTable; - try { + try { - auto index = faiss::index_factory(iter->dimensions, iter->factory.c_str()); - pTable->indexes.push_back(new vss_index(index)); + if (isCreate) { - } catch (faiss::FaissException &e) { + auto idxNo = 0; + for (auto iter = columns->begin(); iter != columns->end(); ++iter, idxNo++) { - *pzErr = sqlite3_mprintf("Error building index factory for %s: %s", - iter->name.c_str(), - e.msg.c_str()); + pTable->getIndexes().push_back( + vss_index::factory(db, + argv[1], + argv[2], + idxNo, + iter->factory, + iter->dimensions)); - return SQLITE_ERROR; } - } - rc = create_shadow_tables(db, argv[1], argv[2], columns->size()); - if (rc != SQLITE_OK) - return rc; - - // Shadow tables were successully created. - // After shadow tables are created, write the initial index state to - // shadow _index. - auto i = 0; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { - - try { - - int rc = write_index_insert((*iter)->index, - pTable->db, - pTable->schema, - pTable->name, - i); - - if (rc != SQLITE_OK) - return rc; + } else { - } catch (faiss::FaissException &e) { + for (int idxNo = 0; idxNo < columns->size(); idxNo++) { - return SQLITE_ERROR; + pTable->getIndexes().push_back( + vss_index::factory(db, + argv[2], + idxNo)); } } - } else { - - for (int i = 0; i < columns->size(); i++) { + } catch (exception & e) { - auto index = read_index_select(db, argv[2], i); + *pzErr = sqlite3_mprintf("Error building index factory, exception was: %s", + e.what()); - // Index in shadow table should always be available, integrity check - // to avoid null pointer - if (index == nullptr) { - *pzErr = sqlite3_mprintf("Could not read index at position %d", i); - return SQLITE_ERROR; - } - pTable->indexes.push_back(new vss_index(index)); - } + return SQLITE_ERROR; } return SQLITE_OK; @@ -833,7 +309,8 @@ static int vssIndexDisconnect(sqlite3_vtab *pVtab) { static int vssIndexDestroy(sqlite3_vtab *pVtab) { auto pTable = static_cast(pVtab); - drop_shadow_tables(pTable->db, pTable->name); + drop_shadow_tables(pTable->getDb(), pTable->getName()); + vssIndexDisconnect(pVtab); return SQLITE_OK; } @@ -931,85 +408,83 @@ static int vssIndexFilter(sqlite3_vtab_cursor *pVtabCursor, if (strcmp(idxStr, "search") == 0) { - pCursor->query_type = QueryType::search; + pCursor->setQuery_type(QueryType::search); vec_ptr query_vector; + int nq = 1; + auto index = pCursor->getTable()->getIndexes().at(idxNum); + auto params = static_cast(sqlite3_value_pointer(argv[0], "vss0_searchparams")); if (params != nullptr) { - pCursor->limit = params->k; + pCursor->setLimit(params->k); query_vector = vec_ptr(new vector(*params->vector)); } else if (sqlite3_libversion_number() < 3041000) { // https://sqlite.org/forum/info/6b32f818ba1d97ef - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( - "vss_search() only support vss_search_params() as a " - "2nd parameter for SQLite versions below 3.41.0"); + auto ptrVtab = static_cast(pCursor->pVtab); + ptrVtab->setError( + sqlite3_mprintf( + "vss_search() only support vss_search_params() as a " + "2nd parameter for SQLite versions below 3.41.0")); return SQLITE_ERROR; - } else if ((query_vector = pCursor->table->vector_api->xValueAsVector( + } else if ((query_vector = pCursor->getTable()->getVector0_api()->xValueAsVector( argv[0])) != nullptr) { if (argc > 1) { - pCursor->limit = sqlite3_value_int(argv[1]); + + pCursor->setLimit(sqlite3_value_int(argv[1])); } else { - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = - sqlite3_mprintf("LIMIT required on vss_search() queries"); - return SQLITE_ERROR; + + auto ptrVtab = static_cast(pCursor->pVtab); + pCursor->setLimit(index->size()); } } else { - if (pVtabCursor->pVtab->zErrMsg != nullptr) - sqlite3_free(pVtabCursor->pVtab->zErrMsg); + auto ptrVtab = static_cast(pCursor->pVtab); + ptrVtab->setError(sqlite3_mprintf("2nd argument to vss_search() must be a vector")); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( - "2nd argument to vss_search() must be a vector"); return SQLITE_ERROR; } - int nq = 1; - auto index = pCursor->table->indexes.at(idxNum)->index; + if (!index->canQuery(query_vector)) { - if (query_vector->size() != index->d) { - - // TODO: To support index that transforms vectors - // (to conserve spage, eg?), we should probably - // have some logic in place that transforms the vectors here? - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( + auto ptrVtab = static_cast(pCursor->pVtab); + ptrVtab->setError(sqlite3_mprintf( "Input query size doesn't match index dimensions: %ld != %ld", query_vector->size(), - index->d); + index->dimensions())); + return SQLITE_ERROR; } - if (pCursor->limit <= 0) { + if (pCursor->getLimit() <= 0) { + + auto ptrVtab = static_cast(pCursor->pVtab); + ptrVtab->setError(sqlite3_mprintf( + "Limit must be greater than 0, got %ld", + pCursor->getLimit())); - sqlite3_free(pVtabCursor->pVtab->zErrMsg); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( - "Limit must be greater than 0, got %ld", pCursor->limit); return SQLITE_ERROR; } // To avoid trying to select more records than number of records in index. - auto searchMax = min(static_cast(pCursor->limit) * nq, index->ntotal * nq); + auto searchMax = min(static_cast(pCursor->getLimit()) * nq, index->size() * nq); - pCursor->search_distances = vector(searchMax, 0); - pCursor->search_ids = vector(searchMax, 0); + pCursor->resetSearch(searchMax); index->search(nq, - query_vector->data(), + query_vector, searchMax, - pCursor->search_distances.data(), - pCursor->search_ids.data()); + pCursor->getSearch_distances(), + pCursor->getSearch_ids()); } else if (strcmp(idxStr, "range_search") == 0) { - pCursor->query_type = QueryType::range_search; + pCursor->setQuery_type(QueryType::range_search); auto params = static_cast( sqlite3_value_pointer(argv[0], "vss0_rangesearchparams")); @@ -1017,41 +492,43 @@ static int vssIndexFilter(sqlite3_vtab_cursor *pVtabCursor, int nq = 1; vector nns(params->distance * nq); - pCursor->range_search_result = unique_ptr(new faiss::RangeSearchResult(nq, true)); + pCursor->getRange_search_result() = unique_ptr(new faiss::RangeSearchResult(nq, true)); - auto index = pCursor->table->indexes.at(idxNum)->index; + auto index = pCursor->getTable()->getIndexes().at(idxNum); index->range_search(nq, - params->vector->data(), + params->vector, params->distance, - pCursor->range_search_result.get()); + pCursor->getRange_search_result()); } else if (strcmp(idxStr, "fullscan") == 0) { - pCursor->query_type = QueryType::fullscan; - sqlite3_stmt *stmt; + pCursor->setQuery_type(QueryType::fullscan); + pCursor->setSql( + sqlite3_mprintf("select rowid from \"%w_data\"", + pCursor->getTable()->getName())); - int res = sqlite3_prepare_v2( - pCursor->table->db, - sqlite3_mprintf("select rowid from \"%w_data\"", pCursor->table->name), - -1, &pCursor->stmt, nullptr); + int res = sqlite3_prepare_v2(pCursor->getTable()->getDb(), + pCursor->getSql(), + -1, + &pCursor->stmt, + nullptr); if (res != SQLITE_OK) return res; - pCursor->step_result = sqlite3_step(pCursor->stmt); + pCursor->setStep_result(sqlite3_step(pCursor->getStmt())); } else { - if (pVtabCursor->pVtab->zErrMsg != 0) - sqlite3_free(pVtabCursor->pVtab->zErrMsg); + auto ptrVtab = static_cast(pCursor->pVtab); + ptrVtab->setError(sqlite3_mprintf( + "%s %s", "vssIndexFilter error: unhandled idxStr", idxStr)); - pVtabCursor->pVtab->zErrMsg = sqlite3_mprintf( - "%s %s", "vssIndexFilter error: unhandled idxStr", idxStr); return SQLITE_ERROR; } - pCursor->iCurrent = 0; + pCursor->setICurrent(0); return SQLITE_OK; } @@ -1059,15 +536,15 @@ static int vssIndexNext(sqlite3_vtab_cursor *cur) { auto pCursor = static_cast(cur); - switch (pCursor->query_type) { + switch (pCursor->getQuery_type()) { case QueryType::search: case QueryType::range_search: - pCursor->iCurrent++; + pCursor->incrementICurrent(); break; case QueryType::fullscan: - pCursor->step_result = sqlite3_step(pCursor->stmt); + pCursor->setStep_result(sqlite3_step(pCursor->getStmt())); } return SQLITE_OK; @@ -1077,18 +554,18 @@ static int vssIndexRowid(sqlite3_vtab_cursor *cur, sqlite_int64 *pRowid) { auto pCursor = static_cast(cur); - switch (pCursor->query_type) { + switch (pCursor->getQuery_type()) { case QueryType::search: - *pRowid = pCursor->search_ids.at(pCursor->iCurrent); + *pRowid = pCursor->getSearch_ids().at(pCursor->getICurrent()); break; case QueryType::range_search: - *pRowid = pCursor->range_search_result->labels[pCursor->iCurrent]; + *pRowid = pCursor->getRange_search_result()->labels[pCursor->getICurrent()]; break; case QueryType::fullscan: - *pRowid = sqlite3_column_int64(pCursor->stmt, 0); + *pRowid = sqlite3_column_int64(pCursor->getStmt(), 0); break; } return SQLITE_OK; @@ -1098,18 +575,18 @@ static int vssIndexEof(sqlite3_vtab_cursor *cur) { auto pCursor = static_cast(cur); - switch (pCursor->query_type) { + switch (pCursor->getQuery_type()) { case QueryType::search: - return pCursor->iCurrent >= pCursor->limit || - pCursor->iCurrent >= pCursor->search_ids.size() - || (pCursor->search_ids.at(pCursor->iCurrent) == -1); + return pCursor->getICurrent() >= pCursor->getLimit() || + pCursor->getICurrent() >= pCursor->getSearch_ids().size() + || (pCursor->getSearch_ids().at(pCursor->getICurrent()) == -1); case QueryType::range_search: - return pCursor->iCurrent >= pCursor->range_search_result->lims[1]; + return pCursor->getICurrent() >= pCursor->getRange_search_result()->lims[1]; case QueryType::fullscan: - return pCursor->step_result != SQLITE_ROW; + return pCursor->getStep_result() != SQLITE_ROW; } return 1; } @@ -1122,16 +599,16 @@ static int vssIndexColumn(sqlite3_vtab_cursor *cur, if (i == VSS_INDEX_COLUMN_DISTANCE) { - switch (pCursor->query_type) { + switch (pCursor->getQuery_type()) { case QueryType::search: sqlite3_result_double(ctx, - pCursor->search_distances.at(pCursor->iCurrent)); + pCursor->getSearch_distances().at(pCursor->getICurrent())); break; case QueryType::range_search: sqlite3_result_double(ctx, - pCursor->range_search_result->distances[pCursor->iCurrent]); + pCursor->getRange_search_result()->distances[pCursor->getICurrent()]); break; case QueryType::fullscan: @@ -1141,14 +618,15 @@ static int vssIndexColumn(sqlite3_vtab_cursor *cur, } else if (i >= VSS_INDEX_COLUMN_VECTORS) { auto index = - pCursor->table->indexes.at(i - VSS_INDEX_COLUMN_VECTORS)->index; + pCursor->getTable()->getIndexes().at(i - VSS_INDEX_COLUMN_VECTORS); - vector vec(index->d); + vector vec(index->dimensions()); sqlite3_int64 rowId; vssIndexRowid(cur, &rowId); try { - index->reconstruct(rowId, vec.data()); + + index->reconstruct(rowId, vec); } catch (faiss::FaissException &e) { @@ -1162,7 +640,7 @@ static int vssIndexColumn(sqlite3_vtab_cursor *cur, sqlite3_free(errmsg); return SQLITE_ERROR; } - pCursor->table->vector_api->xResultVector(ctx, &vec); + pCursor->getTable()->getVector0_api()->xResultVector(ctx, &vec); } return SQLITE_OK; } @@ -1178,71 +656,32 @@ static int vssIndexSync(sqlite3_vtab *pVTab) { try { - bool needsWriting = false; - - auto idxCol = 0; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, idxCol++) { - - // Checking if index needs training. - if (!(*iter)->trainings.empty()) { - - (*iter)->index->train( - (*iter)->trainings.size() / (*iter)->index->d, - (*iter)->trainings.data()); - - (*iter)->trainings.clear(); - (*iter)->trainings.shrink_to_fit(); - - needsWriting = true; - } - - // Checking if we're deleting records from the index. - if (!(*iter)->delete_ids.empty()) { - - faiss::IDSelectorBatch selector((*iter)->delete_ids.size(), - (*iter)->delete_ids.data()); - - (*iter)->index->remove_ids(selector); - (*iter)->delete_ids.clear(); - (*iter)->delete_ids.shrink_to_fit(); - - needsWriting = true; - } - - // Checking if we're inserting records to the index. - if (!(*iter)->insert_data.empty()) { - - (*iter)->index->add_with_ids( - (*iter)->insert_ids.size(), - (*iter)->insert_data.data(), - (faiss::idx_t *)(*iter)->insert_ids.data()); - - (*iter)->insert_ids.clear(); - (*iter)->insert_ids.shrink_to_fit(); + auto i = 0; + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter, i++) { - (*iter)->insert_data.clear(); - (*iter)->insert_data.shrink_to_fit(); + // Synchronizing index, implying deleting, training, and inserting records according to needs. + if ((*iter)->synchronize()) { - needsWriting = true; - } - } + /* + * If the above invocation returned true, we've got updates to currently iterated index, + * hence writing to db. + */ + int rc = (*iter)->write_index(pTable->getDb(), + pTable->getSchema(), + pTable->getName(), + i); - if (needsWriting) { - - int i = 0; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + if (rc != SQLITE_OK) { - int rc = write_index_insert((*iter)->index, - pTable->db, - pTable->schema, - pTable->name, - i); + pTable->setError(sqlite3_mprintf("Error saving index (%d): %s", + rc, + sqlite3_errmsg(pTable->getDb()))); - if (rc != SQLITE_OK) { + // Clearing all indexes to cleanup after ourselves. + for (auto iter2 = pTable->getIndexes().begin(); iter2 != pTable->getIndexes().end(); ++iter2) { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = sqlite3_mprintf("Error saving index (%d): %s", - rc, sqlite3_errmsg(pTable->db)); + (*iter2)->reset(); + } return rc; } } @@ -1252,24 +691,13 @@ static int vssIndexSync(sqlite3_vtab *pVTab) { } catch (faiss::FaissException &e) { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = - sqlite3_mprintf("Error during synchroning index. Full error: %s", - e.msg.c_str()); - - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { - - (*iter)->insert_ids.clear(); - (*iter)->insert_ids.shrink_to_fit(); - - (*iter)->insert_data.clear(); - (*iter)->insert_data.shrink_to_fit(); + pTable->setError(sqlite3_mprintf("Error during synchroning index. Full error: %s", + e.msg.c_str())); - (*iter)->delete_ids.clear(); - (*iter)->delete_ids.shrink_to_fit(); + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter) { - (*iter)->trainings.clear(); - (*iter)->trainings.shrink_to_fit(); + // Cleanups in case we've got hanging data. + (*iter)->reset(); } return SQLITE_ERROR; @@ -1282,19 +710,9 @@ static int vssIndexRollback(sqlite3_vtab *pVTab) { auto pTable = static_cast(pVTab); - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter) { - (*iter)->trainings.clear(); - (*iter)->trainings.shrink_to_fit(); - - (*iter)->insert_data.clear(); - (*iter)->insert_data.shrink_to_fit(); - - (*iter)->insert_ids.clear(); - (*iter)->insert_ids.shrink_to_fit(); - - (*iter)->delete_ids.clear(); - (*iter)->delete_ids.shrink_to_fit(); + (*iter)->reset(); } return SQLITE_OK; } @@ -1311,15 +729,15 @@ static int vssIndexUpdate(sqlite3_vtab *pVTab, // DELETE operation sqlite3_int64 rowid_to_delete = sqlite3_value_int64(argv[0]); - auto rc = shadow_data_delete(pTable->db, - pTable->schema, - pTable->name, + auto rc = shadow_data_delete(pTable->getDb(), + pTable->getSchema(), + pTable->getName(), rowid_to_delete); if (rc != SQLITE_OK) return rc; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter) { - (*iter)->delete_ids.push_back(rowid_to_delete); + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter) { + (*iter)->addDelete(rowid_to_delete); } } else if (argc > 1 && sqlite3_value_type(argv[0]) == SQLITE_NULL) { @@ -1334,45 +752,40 @@ static int vssIndexUpdate(sqlite3_vtab *pVTab, vec_ptr vec; sqlite3_int64 rowid = sqlite3_value_int64(argv[1]); + + // Needed to make sure we insert null record into x_data table. bool inserted_rowid = false; auto i = 0; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter, i++) { - if ((vec = pTable->vector_api->xValueAsVector( + if ((vec = pTable->getVector0_api()->xValueAsVector( argv[2 + VSS_INDEX_COLUMN_VECTORS + i])) != nullptr) { // Make sure the index is already trained, if it's needed - if (!(*iter)->index->is_trained) { + if (!(*iter)->isTrained()) { - sqlite3_free(pVTab->zErrMsg); - pVTab->zErrMsg = - sqlite3_mprintf("Index at i=%d requires training " - "before inserting data.", - i); + pTable->setError(sqlite3_mprintf("Index at i=%d requires training " + "before inserting data.", + i)); return SQLITE_ERROR; } if (!inserted_rowid) { - sqlite_int64 retrowid; - auto rc = shadow_data_insert(pTable->db, pTable->schema, pTable->name, - &rowid, &retrowid); + auto rc = shadow_data_insert(pTable->getDb(), + pTable->getSchema(), + pTable->getName(), + rowid); + if (rc != SQLITE_OK) return rc; inserted_rowid = true; } - (*iter)->insert_data.reserve((*iter)->insert_data.size() + vec->size()); - (*iter)->insert_data.insert( - (*iter)->insert_data.end(), - vec->begin(), - vec->end()); - - (*iter)->insert_ids.push_back(rowid); - + (*iter)->addInsertData(rowid, vec); *pRowid = rowid; } } @@ -1384,17 +797,11 @@ static int vssIndexUpdate(sqlite3_vtab *pVTab, if (operation.compare("training") == 0) { auto i = 0; - for (auto iter = pTable->indexes.begin(); iter != pTable->indexes.end(); ++iter, i++) { - - vec_ptr vec = pTable->vector_api->xValueAsVector(argv[2 + VSS_INDEX_COLUMN_VECTORS + i]); - if (vec != nullptr) { + for (auto iter = pTable->getIndexes().begin(); iter != pTable->getIndexes().end(); ++iter, i++) { - (*iter)->trainings.reserve((*iter)->trainings.size() + vec->size()); - (*iter)->trainings.insert( - (*iter)->trainings.end(), - vec->begin(), - vec->end()); - } + vec_ptr vec = pTable->getVector0_api()->xValueAsVector(argv[2 + VSS_INDEX_COLUMN_VECTORS + i]); + if (vec != nullptr) + (*iter)->addTrainings(vec); } } else { @@ -1406,11 +813,7 @@ static int vssIndexUpdate(sqlite3_vtab *pVTab, } else { // TODO: Implement - UPDATE operation - sqlite3_free(pVTab->zErrMsg); - - pVTab->zErrMsg = - sqlite3_mprintf("UPDATE statements on vss0 virtual tables not supported yet."); - + pTable->setError(sqlite3_mprintf("UPDATE statements on vss0 virtual tables not supported yet.")); return SQLITE_ERROR; } @@ -1465,7 +868,7 @@ static int vssIndexShadowName(const char *zName) { } static sqlite3_module vssIndexModule = { - /* iVersion */ 3, + /* iVersion */ 3, // TODO: Shouldn't this be the same as the version for sqlite-vector.cpp? /* xCreate */ vssIndexCreate, /* xConnect */ vssIndexConnect, /* xBestIndex */ vssIndexBestIndex, @@ -1484,11 +887,12 @@ static sqlite3_module vssIndexModule = { /* xCommit */ vssIndexCommit, /* xRollback */ vssIndexRollback, /* xFindMethod */ vssIndexFindFunction, - /* xRename */ 0, - /* xSavepoint */ 0, - /* xRelease */ 0, - /* xRollbackTo */ 0, - /* xShadowName */ vssIndexShadowName}; + /* xRename */ nullptr, + /* xSavepoint */ nullptr, + /* xRelease */ nullptr, + /* xRollbackTo */ nullptr, + /* xShadowName */ vssIndexShadowName +}; #pragma endregion @@ -1496,28 +900,17 @@ static sqlite3_module vssIndexModule = { vector0_api *vector0_api_from_db(sqlite3 *db) { - vector0_api *pRet = nullptr; - sqlite3_stmt *pStmt = nullptr; - - auto rc = sqlite3_prepare(db, "select vector0(?1)", -1, &pStmt, nullptr); - if (rc != SQLITE_OK) + SqlStatement select(db, sqlite3_mprintf("select vector0(?1)")); + if (select.prepare() != SQLITE_OK) return nullptr; - rc = sqlite3_bind_pointer(pStmt, 1, (void *)&pRet, "vector0_api_ptr", nullptr); - if (rc != SQLITE_OK) { - - sqlite3_finalize(pStmt); + vector0_api *pRet = nullptr; + if (select.bind_pointer(1, (void *)&pRet, "vector0_api_ptr") != SQLITE_OK) return nullptr; - } - rc = sqlite3_step(pStmt); - if (rc != SQLITE_ROW) { - - sqlite3_finalize(pStmt); + if (select.step() != SQLITE_ROW) return nullptr; - } - sqlite3_finalize(pStmt); return pRet; } @@ -1547,7 +940,7 @@ __declspec(dllexport) SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, 0, vss_version, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_debug", @@ -1555,7 +948,7 @@ __declspec(dllexport) SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, 0, vss_debug, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_distance_l1", @@ -1563,49 +956,49 @@ __declspec(dllexport) SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_distance_l1, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_distance_l2", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_distance_l2, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_distance_linf", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_distance_linf, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_inner_product", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_inner_product, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_fvec_add", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_fvec_add, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_fvec_sub", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vss_fvec_sub, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_search", 2, SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vssSearchFunc, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_search_params", @@ -1613,7 +1006,7 @@ __declspec(dllexport) 0, vector_api, vssSearchParamsFunc, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_range_search", @@ -1621,7 +1014,7 @@ __declspec(dllexport) SQLITE_UTF8 | SQLITE_DETERMINISTIC | SQLITE_INNOCUOUS, vector_api, vssRangeSearchFunc, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_range_search_params", @@ -1629,7 +1022,7 @@ __declspec(dllexport) 0, vector_api, vssRangeSearchParamsFunc, - 0, 0, 0); + nullptr, nullptr, nullptr); sqlite3_create_function_v2(db, "vss_memory_usage", @@ -1637,7 +1030,7 @@ __declspec(dllexport) 0, nullptr, faissMemoryUsageFunc, - 0, 0, 0); + nullptr, nullptr, nullptr); auto rc = sqlite3_create_module_v2(db, "vss0", &vssIndexModule, vector_api, nullptr); if (rc != SQLITE_OK) { diff --git a/src/vec/functions.h b/src/vec/functions.h new file mode 100644 index 0000000..a96c403 --- /dev/null +++ b/src/vec/functions.h @@ -0,0 +1,311 @@ + +#ifndef TRANSFORMERS_H +#define TRANSFORMERS_H + +char VECTOR_BLOB_HEADER_BYTE = 'v'; +char VECTOR_BLOB_HEADER_TYPE = 1; +const char *VECTOR_FLOAT_POINTER_NAME = "vectorf32v0"; + +// https://github.com/sqlite/sqlite/blob/master/src/json.c#L88-L89 +#define JSON_SUBTYPE 74 /* Ascii for "J" */ + +#include +using json = nlohmann::json; + +void delVectorFloat(void *p) { + + auto vx = static_cast(p); + sqlite3_free(vx->data); + delete vx; +} + +void resultVector(sqlite3_context *context, vector *vecIn) { + + auto vecRes = new VectorFloat(); + + vecRes->size = vecIn->size(); + vecRes->data = (float *)sqlite3_malloc(vecIn->size() * sizeof(float)); + + memcpy(vecRes->data, vecIn->data(), vecIn->size() * sizeof(float)); + + sqlite3_result_pointer(context, vecRes, VECTOR_FLOAT_POINTER_NAME, delVectorFloat); +} + +vec_ptr vectorFromBlobValue(sqlite3_value *value, const char **pzErrMsg) { + + int size = sqlite3_value_bytes(value); + char header; + char type; + + if (size < (2)) { + *pzErrMsg = "Vector blob size less than header length"; + return nullptr; + } + + const void *pBlob = sqlite3_value_blob(value); + memcpy(&header, ((char *)pBlob + 0), sizeof(char)); + memcpy(&type, ((char *)pBlob + 1), sizeof(char)); + + if (header != VECTOR_BLOB_HEADER_BYTE) { + *pzErrMsg = "Blob not well-formatted vector blob"; + return nullptr; + } + + if (type != VECTOR_BLOB_HEADER_TYPE) { + *pzErrMsg = "Blob type not right"; + return nullptr; + } + + int numElements = (size - 2) / sizeof(float); + float *vec = (float *)((char *)pBlob + 2); + return vec_ptr(new vector(vec, vec + numElements)); +} + +vec_ptr vectorFromRawBlobValue(sqlite3_value *value, const char **pzErrMsg) { + + int size = sqlite3_value_bytes(value); + + // Must be divisible by 4 + if (size % 4) { + *pzErrMsg = "Invalid raw blob length, blob must be divisible by 4"; + return nullptr; + } + const void *pBlob = sqlite3_value_blob(value); + + float *vec = (float *)((char *)pBlob); + return vec_ptr(new vector(vec, vec + (size / 4))); +} + +vec_ptr vectorFromTextValue(sqlite3_value *value) { + + try { + + json json = json::parse(sqlite3_value_text(value)); + vec_ptr pVec(new vector()); + json.get_to(*pVec); + return pVec; + + } catch (const json::exception &) { + return nullptr; + } + + return nullptr; +} + +static vec_ptr valueAsVector(sqlite3_value *value) { + + // Option 1: If the value is a "vectorf32v0" pointer, create vector from + // that + auto vec = (VectorFloat *)sqlite3_value_pointer(value, VECTOR_FLOAT_POINTER_NAME); + + if (vec != nullptr) + return vec_ptr(new vector(vec->data, vec->data + vec->size)); + + vec_ptr pVec; + + // Option 2: value is a blob in vector format + if (sqlite3_value_type(value) == SQLITE_BLOB) { + + const char *pzErrMsg = nullptr; + + if ((pVec = vectorFromBlobValue(value, &pzErrMsg)) != nullptr) + return pVec; + + if ((pVec = vectorFromRawBlobValue(value, &pzErrMsg)) != nullptr) + return pVec; + } + + // Option 3: if value is a JSON array coercible to float vector, use that + if (sqlite3_value_type(value) == SQLITE_TEXT) { + + if ((pVec = vectorFromTextValue(value)) != nullptr) + return pVec; + else + return nullptr; + } + + // Else, value isn't a vector + return nullptr; +} + +// TODO should return fvec, ivec, or bvec depending on input. How do bvec, +// though? +static void vector_from(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vector vec; + vec.reserve(argc); + for (int i = 0; i < argc; i++) { + vec.push_back(sqlite3_value_double(argv[i])); + } + + resultVector(context, &vec); +} + +static void vector_value_at(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + + if (pVec == nullptr) + return; + + int pos = sqlite3_value_int(argv[1]); + + try { + + float result = pVec->at(pos); + sqlite3_result_double(context, result); + + } catch (const out_of_range &oor) { + + char *errmsg = sqlite3_mprintf("%d out of range: %s", pos, oor.what()); + + if (errmsg != nullptr) { + sqlite3_result_error(context, errmsg, -1); + sqlite3_free(errmsg); + } else { + sqlite3_result_error_nomem(context); + } + } +} + +static void vector_length(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto pVec = (VectorFloat *)sqlite3_value_pointer(argv[0], VECTOR_FLOAT_POINTER_NAME); + if (pVec == nullptr) + return; + + sqlite3_result_int64(context, pVec->size); +} + +static void vector_to_json(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; + + json j = json(*pVec); + + sqlite3_result_text(context, j.dump().c_str(), -1, SQLITE_TRANSIENT); + sqlite3_result_subtype(context, JSON_SUBTYPE); +} + +static void vector_from_json(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + const char *text = (const char *)sqlite3_value_text(argv[0]); + vec_ptr pVec = vectorFromTextValue(argv[0]); + + if (pVec == nullptr) { + sqlite3_result_error( + context, "input not valid json, or contains non-float data", -1); + } else { + resultVector(context, pVec.get()); + } +} + +static void vector_to_blob(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; + + int size = pVec->size(); + int memSize = (sizeof(char)) + (sizeof(char)) + (size * 4); + void *pBlob = sqlite3_malloc(memSize); + memset(pBlob, 0, memSize); + + memcpy((void *)((char *)pBlob + 0), (void *)&VECTOR_BLOB_HEADER_BYTE, sizeof(char)); + memcpy((void *)((char *)pBlob + 1), (void *)&VECTOR_BLOB_HEADER_TYPE, sizeof(char)); + memcpy((void *)((char *)pBlob + 2), (void *)pVec->data(), size * 4); + + sqlite3_result_blob64(context, pBlob, memSize, sqlite3_free); +} + +static void vector_from_blob(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + const char *pzErrMsg; + + vec_ptr pVec = vectorFromBlobValue(argv[0], &pzErrMsg); + if (pVec == nullptr) + sqlite3_result_error(context, pzErrMsg, -1); + else + resultVector(context, pVec.get()); +} + +static void vector_to_raw(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + if (pVec == nullptr) + return; + + int size = pVec->size(); + int n = size * sizeof(float); + void *pBlob = sqlite3_malloc(n); + memset(pBlob, 0, n); + memcpy((void *)((char *)pBlob), (void *)pVec->data(), n); + sqlite3_result_blob64(context, pBlob, n, sqlite3_free); +} + +static void vector_from_raw(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + const char *pzErrMsg; // TODO: Shouldn't we have like error messages here? + + vec_ptr pVec = vectorFromRawBlobValue(argv[0], &pzErrMsg); + if (pVec == nullptr) + sqlite3_result_error(context, pzErrMsg, -1); + else + resultVector(context, pVec.get()); +} + +static void vector_version(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); +} + +static void vector_debug(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + vec_ptr pVec = valueAsVector(argv[0]); + + if (pVec == nullptr) { + + sqlite3_result_error(context, "Value not a vector", -1); + return; + } + + sqlite3_str *str = sqlite3_str_new(0); + sqlite3_str_appendf(str, "size: %lld [", pVec->size()); + + for (int i = 0; i < pVec->size(); i++) { + + if (i == 0) + sqlite3_str_appendf(str, "%f", pVec->at(i)); + else + sqlite3_str_appendf(str, ", %f", pVec->at(i)); + } + + sqlite3_str_appendchar(str, 1, ']'); + sqlite3_result_text(context, sqlite3_str_finish(str), -1, sqlite3_free); +} + +#endif // TRANSFORMERS_H diff --git a/src/vec/fvecsEach_cursor.h b/src/vec/fvecsEach_cursor.h new file mode 100644 index 0000000..13ae749 --- /dev/null +++ b/src/vec/fvecsEach_cursor.h @@ -0,0 +1,56 @@ + +#ifndef FVECSEACH_CURSOR_H +#define FVECSEACH_CURSOR_H + +class fvecsEach_cursor : public sqlite3_vtab_cursor { + +public: + + fvecsEach_cursor(sqlite3_vtab *pVtab) { + + this->pVtab = pVtab; + iRowid = 0; + pBlob = nullptr; + iBlobN = 0; + p = 0; + iCurrentD = 0; + } + + ~fvecsEach_cursor() { + + if (pBlob != nullptr) + sqlite3_free(pBlob); + } + + void * getBlob() { + + return pBlob; + } + + void setBlob(void * blob) { + + if (pBlob != nullptr) + sqlite3_free(pBlob); + + pBlob = blob; + } + + sqlite3_int64 iRowid; + + // Total size of pBlob in bytes + sqlite3_int64 iBlobN; + sqlite3_int64 p; + + // Current dimensions + int iCurrentD; + + // Pointer to current vector being read in + vec_ptr pCurrentVector; + +private: + + // Copy of fvecs input blob + void *pBlob; +}; + +#endif // FVECSEACH_CURSOR_H diff --git a/src/vec/fvecsEach_vtab.h b/src/vec/fvecsEach_vtab.h new file mode 100644 index 0000000..009273c --- /dev/null +++ b/src/vec/fvecsEach_vtab.h @@ -0,0 +1,22 @@ + +#ifndef FVECSEACH_VTAB_H +#define FVECSEACH_VTAB_H + +struct fvecsEach_vtab : public sqlite3_vtab { + + fvecsEach_vtab() { + + pModule = nullptr; + nRef = 0; + zErrMsg = nullptr; + } + + ~fvecsEach_vtab() { + + if (zErrMsg != nullptr) { + sqlite3_free(zErrMsg); + } + } +}; + +#endif // FVECSEACH_VTAB_H diff --git a/src/vss/calculations.h b/src/vss/calculations.h new file mode 100644 index 0000000..8d55e95 --- /dev/null +++ b/src/vss/calculations.h @@ -0,0 +1,171 @@ + +#ifndef VSS_CALCULATIONS_H +#define VSS_CALCULATIONS_H + +#include "inclusions.h" + +static void vss_distance_l1(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_L1(lhs->data(), rhs->data(), lhs->size())); +} + +static void vss_distance_l2(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_L2sqr(lhs->data(), rhs->data(), lhs->size())); +} + +static void vss_distance_linf(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, faiss::fvec_Linf(lhs->data(), rhs->data(), lhs->size())); +} + +static void vss_inner_product(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + sqlite3_result_double(context, + faiss::fvec_inner_product(lhs->data(), rhs->data(), lhs->size())); +} + +static void vss_fvec_add(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", + -1); + return; + } + + auto size = lhs->size(); + vec_ptr c(new vector(size)); + faiss::fvec_add(size, lhs->data(), rhs->data(), c->data()); + + vector_api->xResultVector(context, c.get()); +} + +static void vss_fvec_sub(sqlite3_context *context, int argc, + sqlite3_value **argv) { + + auto vector_api = (vector0_api *)sqlite3_user_data(context); + + vec_ptr lhs = vector_api->xValueAsVector(argv[0]); + if (lhs == nullptr) { + sqlite3_result_error(context, "LHS is not a vector", -1); + return; + } + + vec_ptr rhs = vector_api->xValueAsVector(argv[1]); + if (rhs == nullptr) { + sqlite3_result_error(context, "RHS is not a vector", -1); + return; + } + + if (lhs->size() != rhs->size()) { + sqlite3_result_error(context, "LHS and RHS are not vectors of the same size", -1); + return; + } + + int size = lhs->size(); + vec_ptr c = vec_ptr(new vector(size)); + faiss::fvec_sub(size, lhs->data(), rhs->data(), c->data()); + vector_api->xResultVector(context, c.get()); +} + +#endif // VSS_CALCULATIONS_H diff --git a/src/vss/inclusions.h b/src/vss/inclusions.h new file mode 100644 index 0000000..006d432 --- /dev/null +++ b/src/vss/inclusions.h @@ -0,0 +1,35 @@ + +#ifndef VSS_INCLUSIONS_H +#define VSS_INCLUSIONS_H + +#include +#include + +#include "sqlite3ext.h" +SQLITE_EXTENSION_INIT1 + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +typedef unique_ptr> vec_ptr; + +enum QueryType { search, range_search, fullscan }; + +#define VSS_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION +#define VSS_RANGE_SEARCH_FUNCTION SQLITE_INDEX_CONSTRAINT_FUNCTION + 1 + +#endif // VSS_INCLUSIONS_H diff --git a/src/vss/meta-methods.h b/src/vss/meta-methods.h new file mode 100644 index 0000000..05a70f7 --- /dev/null +++ b/src/vss/meta-methods.h @@ -0,0 +1,31 @@ + +#ifndef META_METHODS_H +#define META_METHODS_H + +#include "inclusions.h" + +static void vss_version(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + sqlite3_result_text(context, SQLITE_VSS_VERSION, -1, SQLITE_STATIC); +} + +static void vss_debug(sqlite3_context *context, + int argc, + sqlite3_value **argv) { + + auto resTxt = sqlite3_mprintf( + "version: %s\nfaiss version: %d.%d.%d\nfaiss compile options: %s", + SQLITE_VSS_VERSION, + FAISS_VERSION_MAJOR, + FAISS_VERSION_MINOR, + FAISS_VERSION_PATCH, + faiss::get_compile_options().c_str()); + + sqlite3_result_text(context, resTxt, -1, SQLITE_TRANSIENT); + sqlite3_free(resTxt); +} + + +#endif // META_METHODS_H diff --git a/src/vss/sql-statement.h b/src/vss/sql-statement.h new file mode 100644 index 0000000..cf4a076 --- /dev/null +++ b/src/vss/sql-statement.h @@ -0,0 +1,111 @@ + +#ifndef SQL_STATEMENT_H +#define SQL_STATEMENT_H + +#include "inclusions.h" + +/* + * Helper class encapsulating an SQL statement towards SQLite, with automatic and deterministic destruction + * and cleanup of any heap memory, etc. + */ +class SqlStatement { + +public: + + SqlStatement(sqlite3 *db, const char * sql) : db(db), sql(sql), stmt(nullptr) { + + this->sql = sql; + } + + ~SqlStatement() { + + if (stmt != nullptr) + sqlite3_finalize(stmt); + if (sql != nullptr) + sqlite3_free((void *)sql); + } + + int prepare() { + + auto res = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr); + if (res != SQLITE_OK || stmt == nullptr) { + + stmt = nullptr; + return SQLITE_ERROR; + } + return res; + } + + int bind_int64(int colNo, sqlite3_int64 value) { + + return sqlite3_bind_int64(stmt, colNo, value); + } + + int bind_blob64(int colNo, const void * data, int size) { + + return sqlite3_bind_blob64(stmt, colNo, data, size, SQLITE_TRANSIENT); + } + + int bind_null(int colNo) { + + return sqlite3_bind_null(stmt, colNo); + } + + int bind_pointer(int paramNo, void *ptr, const char * name) { + + return sqlite3_bind_pointer(stmt, paramNo, ptr, name, nullptr); + } + + int step() { + + return sqlite3_step(stmt); + } + + int exec() { + + return sqlite3_exec(db, sql, nullptr, nullptr, nullptr); + } + + int declare_vtab() { + + return sqlite3_declare_vtab(db, sql); + } + + const void * column_blob(int colNo) { + + return sqlite3_column_blob(stmt, colNo); + } + + int column_bytes(int colNo) { + + return sqlite3_column_bytes(stmt, colNo); + } + + int column_int64(int colNo) { + + return sqlite3_column_int64(stmt, colNo); + } + + int last_insert_rowid() { + + return sqlite3_last_insert_rowid(db); + } + + void finalize() { + + if (stmt != nullptr) + sqlite3_finalize(stmt); + stmt = nullptr; + if (sql != nullptr) + sqlite3_free((void *)sql); + sql = nullptr; + } + +private: + + sqlite3 *db; + sqlite3_stmt *stmt; + const char * sql; +}; + +#endif // SQL_STATEMENT_H diff --git a/src/vss/vss-index-cursor.h b/src/vss/vss-index-cursor.h new file mode 100644 index 0000000..b4f58fd --- /dev/null +++ b/src/vss/vss-index-cursor.h @@ -0,0 +1,144 @@ + +#ifndef VSS_INDEX_CURSOR_H +#define VSS_INDEX_CURSOR_H + +#include "inclusions.h" + +class vss_index_cursor : public sqlite3_vtab_cursor { + +public: + + explicit vss_index_cursor(vss_index_vtab *table) + : table(table), + sqlite3_vtab_cursor({0}), + stmt(nullptr), + sql(nullptr) { } + + ~vss_index_cursor() { + + if (stmt != nullptr) + sqlite3_finalize(stmt); + + if (sql != nullptr) + sqlite3_free(sql); + } + + vss_index_vtab * getTable() { + + return table; + } + + sqlite3_int64 getICurrent() { + + return iCurrent; + } + + sqlite3_int64 getIRowid() { + + return iRowid; + } + + QueryType getQuery_type() { + + return query_type; + } + + sqlite3_int64 getLimit() { + + return limit; + } + + vector & getSearch_ids() { + + return search_ids; + } + + vector & getSearch_distances() { + + return search_distances; + } + + unique_ptr & getRange_search_result() { + + return range_search_result; + } + + sqlite3_stmt *getStmt() { + + return stmt; + } + + int getStep_result() { + + return step_result; + } + + void setStep_result(int value) { + + step_result = value; + } + + void incrementICurrent() { + + iCurrent += 1; + } + + void setICurrent(sqlite3_int64 value) { + + iCurrent = value; + } + + void resetSearch(long noItems) { + + search_distances = vector(noItems, 0); + search_ids = vector(noItems, 0); + } + + void setQuery_type(QueryType value) { + + query_type = value; + } + + void setSql(char * value) { + + if (sql != nullptr) + sqlite3_free(sql); + sql = value; + } + + char * getSql() { + + return sql; + } + + void setLimit(sqlite3_int64 value) { + + limit = value; + } + + // TODO: Parts of our logic requires the address to the pointer such that we can assign what it's pointing at + sqlite3_stmt *stmt; + +private: + + vss_index_vtab *table; + + sqlite3_int64 iCurrent; + sqlite3_int64 iRowid; + + QueryType query_type; + + // For query_type == QueryType::search + sqlite3_int64 limit; + vector search_ids; + vector search_distances; + + // For query_type == QueryType::range_search + unique_ptr range_search_result; + + // For query_type == QueryType::fullscan + char *sql; + int step_result; +}; + +#endif // VSS_INDEX_CURSOR_H diff --git a/src/vss/vss-index-vtab.h b/src/vss/vss-index-vtab.h new file mode 100644 index 0000000..8d179c3 --- /dev/null +++ b/src/vss/vss-index-vtab.h @@ -0,0 +1,88 @@ + +#ifndef VSS_INDEX_VTAB_H +#define VSS_INDEX_VTAB_H + +#include "inclusions.h" + +class vss_index_vtab : public sqlite3_vtab { + +public: + + vss_index_vtab(sqlite3 *db, vector0_api *vector_api, char *schema, char *name) + : db(db), + vector_api(vector_api), + schema(schema), + name(name) { + + this->zErrMsg = nullptr; + } + + ~vss_index_vtab() { + + if (name) + sqlite3_free(name); + + if (schema) + sqlite3_free(schema); + + if (this->zErrMsg != nullptr) + delete this->zErrMsg; + + // Deleting all indexes associated with table. + for (auto iter = indexes.begin(); iter != indexes.end(); ++iter) { + delete (*iter); + } + } + + void setError(char *error) { + + if (this->zErrMsg != nullptr) { + delete this->zErrMsg; + } + + this->zErrMsg = error; + } + + sqlite3 * getDb() { + + return db; + } + + vector0_api * getVector0_api() { + + return vector_api; + } + + vector & getIndexes() { + + return indexes; + } + + char * getName() { + + return name; + } + + char * getSchema() { + + return schema; + } + +private: + + sqlite3 *db; + vector0_api *vector_api; + + // Name of the virtual table. Must be freed during disconnect + char *name; + + // Name of the schema the virtual table exists in. Must be freed during + // disconnect + char *schema; + + // Vector holding all the faiss Indices the vtab uses, and their state, + // implying which items are to be deleted and inserted. + vector indexes; +}; + +#endif // VSS_INDEX_VTAB_H diff --git a/src/vss/vss-index.h b/src/vss/vss-index.h new file mode 100644 index 0000000..d4636e4 --- /dev/null +++ b/src/vss/vss-index.h @@ -0,0 +1,372 @@ + +#ifndef VSS_INDEX_H +#define VSS_INDEX_H + +#include "inclusions.h" +/* + * Wrapper around a single faiss index, with training data, insert records, and + * delete records. + * + * An attempt at encapsulating everything related to faiss::Index instances, such as + * training, inserting, deleting, etc. + */ +class vss_index { + +public: + + ~vss_index() { + + if (index != nullptr) { + delete index; + } + } + + // Returns false if index requires training before inserting items to it. + bool isTrained() { + + return index->is_trained; + } + + // Reconstructs the original vector, requires IDMap2 string in index factory to work. + void reconstruct(sqlite3_int64 rowid, vector & vector) { + + index->reconstruct(rowid, vector.data()); + } + + // Returns true if specified vector is allowed to query index. + bool canQuery(vec_ptr & vec) { + + return vec->size() == index->d; + } + + // Queries the index for matches matching the specified vector + void search(int nq, + vec_ptr & vec, + faiss::idx_t max, + vector & distances, + vector & ids) { + + index->search(nq, vec->data(), max, distances.data(), ids.data()); + } + + // Queries the index for a range of items. + void range_search(int nq, vec_ptr & vec, float distance, unique_ptr & result) { + + index->range_search(nq, vec->data(), distance, result.get()); + } + + // Returns dimensions of index. + faiss::idx_t dimensions() { + + return index->d; + } + + // Returns the size of index. + faiss::idx_t size() { + + return index->ntotal; + } + + /* + * Adds the specified vector to the index' training material. + * + * Notice, needs to invoke synchronize() later to actually perform training of index. + */ + void addTrainings(vec_ptr & vec) { + + trainings.reserve(trainings.size() + vec->size()); + trainings.insert(trainings.end(), vec->begin(), vec->end()); + } + + /* + * Adds the specified vector to the index' temporary insert data. + * + * Notice, needs to invoke synchronize() later to actually add data to index. + */ + void addInsertData(faiss::idx_t rowId, vec_ptr & vec) { + + insert_data.reserve(insert_data.size() + vec->size()); + insert_data.insert(insert_data.end(), vec->begin(), vec->end()); + + insert_ids.push_back(rowId); + } + + /* + * Adds the specified rowid to the index' temporary delete data. + * + * Notice, needs to invoke synchronize() later to actually delete data from index. + */ + void addDelete(faiss::idx_t rowid) { + + delete_ids.push_back(rowid); + } + + /* + * Synchronizes index by updating index according to trainings, inserts and deletes. + */ + bool synchronize() { + + auto result = tryTrain(); + result = tryDelete() || result; + result = tryInsert() || result; + + return result; + } + + /* + * Resets all temporary training data to free memory. + */ + void reset() { + + trainings.clear(); + trainings.shrink_to_fit(); + + insert_data.clear(); + insert_data.shrink_to_fit(); + + insert_ids.clear(); + insert_ids.shrink_to_fit(); + + delete_ids.clear(); + delete_ids.shrink_to_fit(); + } + + int write_index(sqlite3 *db, + const char *schema, + const char *name, + int rowId) { + + // Writing our index + faiss::VectorIOWriter writer; + faiss::write_index(index, &writer); + + // First trying to insert index, if that fails with ROW constraing error, we try to update existing index. + if (write_index_insert(writer, db, schema, name, rowId) == SQLITE_OK) + return SQLITE_OK; + + if (sqlite3_extended_errcode(db) != SQLITE_CONSTRAINT_ROWID) + return SQLITE_ERROR; // Insert failed for unknown error + + // Insert failed because index already existed, updating existing index. + return write_index_update(writer, db, schema, name, rowId); + } + + /* + * Creates a new vss_index as a virtual table and stores + * its initial (empty) state. + */ + static vss_index * factory(sqlite3 *db, + const char *schema, + const char *name, + bool indexNo, + string & factoryArgs, + int dimensions) { + + // Creating a new index and storing in cache. + auto newIndex = new vss_index(faiss::index_factory(dimensions, factoryArgs.c_str())); + + // Checking if this is our first index for table, at which point we create our shadow tables. + if (indexNo == 0) { + + auto rc = create_shadow_tables(db, schema, name); + if (rc != SQLITE_OK) + throw domain_error("Couldn't create shadow tables"); + } + + // Writing its initial (empty) state. + int rc = newIndex->write_index(db, + schema, + name, + indexNo); + + // Returning index to caller. + return newIndex; + } + + /* + * Creates a new vss_index by reading existing data from db, + * or returns a cached index to caller. + */ + static vss_index * factory(sqlite3 *db, + const char *name, + int indexNo) { + + // Reading index from db. + auto newIndex = new vss_index(read_index_select(db, name, indexNo)); + + // Returning index to caller. + return newIndex; + } + +private: + + explicit vss_index(faiss::Index *index) : index(index) { } + + static int create_shadow_tables(sqlite3 *db, + const char *schema, + const char *name) { + + SqlStatement create1(db, + sqlite3_mprintf("create table \"%w\".\"%w_index\"(idx)", + schema, + name)); + + auto rc = create1.exec(); + if (rc != SQLITE_OK) + return rc; + + /* + * Notice, we'll need to explicitly finalize this object since we can only + * have one open statement at the same time to the same connetion. + */ + create1.finalize(); + + SqlStatement create2(db, + sqlite3_mprintf("create table \"%w\".\"%w_data\"(x);", + schema, + name)); + + rc = create2.exec(); + return rc; + } + + static faiss::Index * read_index_select(sqlite3 *db, + const char *name, + int indexId) { + + SqlStatement select(db, + sqlite3_mprintf("select idx from \"%w_index\" where rowid = ?", + name)); + + if (select.prepare() != SQLITE_OK) + return nullptr; + + if (select.bind_int64(1, indexId) != SQLITE_OK) + return nullptr; + + if (select.step() != SQLITE_ROW) + return nullptr; + + auto index_data = select.column_blob(0); + auto size = select.column_bytes(0); + + faiss::VectorIOReader reader; + copy((const uint8_t *)index_data, + ((const uint8_t *)index_data) + size, + back_inserter(reader.data)); + + return faiss::read_index(&reader); + } + + int write_index_insert(faiss::VectorIOWriter &writer, + sqlite3 *db, + const char *schema, + const char *name, + int rowId) { + + // If inserts fails it means index already exists. + SqlStatement insert(db, + sqlite3_mprintf("insert into \"%w\".\"%w_index\"(rowid, idx) values (?, ?)", + schema, + name)); + + if (insert.prepare() != SQLITE_OK) + return SQLITE_ERROR; + + if (insert.bind_int64(1, rowId) != SQLITE_OK) + return SQLITE_ERROR; + + if (insert.bind_blob64(2, writer.data.data(), writer.data.size()) != SQLITE_OK) + return SQLITE_ERROR; + + auto rc = insert.step(); + if (rc == SQLITE_DONE) + return SQLITE_OK; // Index did not exist, and we successfully inserted it. + + return rc; + } + + int write_index_update(faiss::VectorIOWriter &writer, + sqlite3 *db, + const char *schema, + const char *name, + int rowId) { + + // Updating existing index. + SqlStatement update(db, + sqlite3_mprintf("update \"%w\".\"%w_index\" set idx = ? where rowid = ?", + schema, + name)); + + if (update.prepare() != SQLITE_OK) + return SQLITE_ERROR; + + if (update.bind_blob64(1, writer.data.data(), writer.data.size()) != SQLITE_OK) + return SQLITE_ERROR; + + if (update.bind_int64(2, rowId) != SQLITE_OK) + return SQLITE_ERROR; + + auto rc = update.step(); + if (rc == SQLITE_DONE) + return SQLITE_OK; // We successfully updated existing index. + + return rc; + } + + bool tryTrain() { + + if (trainings.empty()) + return false; + + index->train(trainings.size() / index->d, trainings.data()); + + trainings.clear(); + trainings.shrink_to_fit(); + + return true; + } + + bool tryInsert() { + + if (insert_ids.empty()) + return false; + + index->add_with_ids( + insert_ids.size(), + insert_data.data(), + insert_ids.data()); + + insert_ids.clear(); + insert_ids.shrink_to_fit(); + + insert_data.clear(); + insert_data.shrink_to_fit(); + + return true; + } + + bool tryDelete() { + + if (delete_ids.empty()) + return false; + + faiss::IDSelectorBatch selector(delete_ids.size(), + delete_ids.data()); + + index->remove_ids(selector); + + delete_ids.clear(); + delete_ids.shrink_to_fit(); + + return true; + } + + faiss::Index * index; + vector trainings; + vector insert_data; + vector insert_ids; + vector delete_ids; +}; + +#endif // VSS_INDEX_H