diff --git a/README.md b/README.md index 19be087..8924401 100644 --- a/README.md +++ b/README.md @@ -111,32 +111,6 @@ rejected rather than silently returning wrong data. Each opened connection makes one small request to determine the file size, then fetches frames on demand. Frames are cached per connection, so repeated reads do not re-hit the network. -#### Coalescing reads for remote (S3/CDN) databases - -By default each frame is fetched in its own Range GET. For high-latency stores -like S3 a query can fire many small GETs. Enable an in-memory, page-aligned read -cache to coalesce the contiguous run of missing pages behind a read into a -single GET (default page size 64 KiB) and to serve adjacent frames from cache: - -```go -import sqlitezstd "github.com/jtarchie/sqlitezstd" - -// Register a cache-enabled VFS once (e.g. at startup). DSN query params are -// stripped before the VFS sees the path, so configuration lives on the named -// VFS, not the URL. -err := sqlitezstd.Register("zstdcache", - sqlitezstd.WithHTTPCacheSize(64<<20), // ~64 MiB of coalesced pages per open -) - -db, _ := sql.Open("sqlite3", "https://bucket.example.com/segment.sqlite.zst?vfs=zstdcache") -``` - -In practice this collapses a remote query's request count by an order of -magnitude — a full-table-scan test issues **125 Range GETs without the cache vs -9 with it (~14× fewer)**. The cache is per opened file and bounded by the -configured byte cap (LRU eviction), so memory stays bounded. Tune the page size -with `WithHTTPPageSize`. - For authenticated buckets, supply a signing transport with `WithRoundTripper`/`WithHTTPClient`; the library still wraps it with timeout, retry, and range-validation. @@ -153,8 +127,7 @@ go build -tags fts5 ./... ### Configuration Importing the package registers a `zstd` VFS with sensible defaults. To tune the -frame-cache size, HTTP timeout, retry count, HTTP read cache -(`WithHTTPCacheSize`/`WithHTTPPageSize`), transport +frame-cache size, HTTP timeout, retry count, transport (`WithRoundTripper`/`WithHTTPClient`), or logger, register your own named VFS and reference it via `?vfs=`: diff --git a/benchmark_test.go b/benchmark_test.go index 79d94f5..91f20a4 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,10 +18,9 @@ import ( _ "github.com/mattn/go-sqlite3" // ensure you import the SQLite3 driver ) -// minCacheVFS registers (once) a VFS whose frame cache holds a single frame, -// approximating the pre-cache behavior where the upstream reader kept only one -// decompressed frame. Benchmarking against it on the same fixture isolates the -// effect of the frame cache. +// minCacheVFS registers (once) a VFS whose frame cache holds a single frame. +// Benchmarking against it on the same fixture isolates the effect of the +// default frame-cache size. func minCacheVFS(b *testing.B) string { b.Helper() @@ -37,7 +36,7 @@ func minCacheVFS(b *testing.B) string { // BenchmarkReadCompressedSQLiteFTS5PorterMinCache mirrors // BenchmarkReadCompressedSQLiteFTS5Porter but with a single-frame cache, so the -// two together show the frame cache's impact (allocs/op and B/op in particular). +// two together show the frame cache's impact. func BenchmarkReadCompressedSQLiteFTS5PorterMinCache(b *testing.B) { _, zstPath := setupDB(b) @@ -340,55 +339,6 @@ func BenchmarkReadCompressedHTTPSQLite(b *testing.B) { }) } -// cacheHTTPVFS registers (once) a cache-enabled VFS for the HTTP benchmark. -func cacheHTTPVFS(b *testing.B) string { - b.Helper() - - const name = "zstd-httpcache-bench" - - if err := sqlitezstd.Register(name, sqlitezstd.WithHTTPCacheSize(64<<20)); err != nil && - !strings.Contains(err.Error(), "already") { - b.Fatalf("Failed to register http-cache vfs: %v", err) - } - - return name -} - -// BenchmarkReadCompressedHTTPSQLiteCached mirrors BenchmarkReadCompressedHTTPSQLite -// but through the coalescing HTTP cache. Over a local httptest server the latency -// win is small; the real benefit (far fewer Range GETs) is asserted by -// TestHTTPCacheCoalescesGETs. -func BenchmarkReadCompressedHTTPSQLiteCached(b *testing.B) { - _, zstPath := setupDB(b) - - zstDir := filepath.Dir(zstPath) - - server := httptest.NewServer(http.FileServer(http.Dir(zstDir))) - defer server.Close() - - vfs := cacheHTTPVFS(b) - - client, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=%s", server.URL, filepath.Base(zstPath), vfs)) - if err != nil { - b.Fatalf("Query failed: %v", err) - } - defer client.Close() //nolint: errcheck - - client.SetMaxOpenConns(max(4, runtime.NumCPU())) - - b.ResetTimer() - - b.RunParallel(func(pb *testing.PB) { - var count int - for pb.Next() { - err = client.QueryRow("SELECT MAX(value) FROM entries").Scan(&count) - if err != nil { - b.Fatalf("Query failed: %v", err) - } - } - }) -} - func BenchmarkReadCompressedRtreeSQLite(b *testing.B) { _, zstPath := setupDB(b) diff --git a/cache_internal_test.go b/cache_internal_test.go deleted file mode 100644 index 6926270..0000000 --- a/cache_internal_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package sqlitezstd - -import ( - "bytes" - "io" - "sync/atomic" - "testing" - - "github.com/klauspost/compress/zstd" -) - -// countingReaderAt records how many times the underlying source is read. -type countingReaderAt struct { - data []byte - reads atomic.Int64 -} - -func (c *countingReaderAt) ReadAt(p []byte, off int64) (int, error) { - c.reads.Add(1) - - if off >= int64(len(c.data)) { - return 0, io.EOF - } - - n := copy(p, c.data[off:]) - if n < len(p) { - return n, io.EOF - } - - return n, nil -} - -func TestFrameReaderCachesCompressedReads(t *testing.T) { - t.Parallel() - - data := make([]byte, 1024) - for i := range data { - data[i] = byte(i) - } - - src := &countingReaderAt{data: data} - - reader, err := newFrameReader(src, int64(len(data)), 8) - if err != nil { - t.Fatalf("newFrameReader: %v", err) - } - - p := make([]byte, 128) - - if _, err := reader.ReadAt(p, 256); err != nil { - t.Fatalf("first ReadAt: %v", err) - } - if got := src.reads.Load(); got != 1 { - t.Fatalf("want 1 source read after first ReadAt, got %d", got) - } - - // An identical read must be served from the cache, not the source. - if _, err := reader.ReadAt(p, 256); err != nil { - t.Fatalf("second ReadAt: %v", err) - } - if got := src.reads.Load(); got != 1 { - t.Fatalf("want still 1 source read after cached ReadAt, got %d", got) - } - - if !bytes.Equal(p, data[256:256+len(p)]) { - t.Fatal("cached ReadAt returned wrong bytes") - } -} - -// countingDecoder records how many times the underlying decoder is invoked. -type countingDecoder struct { - dec zstdDecoder - decodes atomic.Int64 -} - -func (c *countingDecoder) DecodeAll(input, dst []byte) ([]byte, error) { - c.decodes.Add(1) - - return c.dec.DecodeAll(input, dst) -} - -func TestCachingDecoderCachesDecompression(t *testing.T) { - t.Parallel() - - enc, err := zstd.NewWriter(nil) - if err != nil { - t.Fatalf("new encoder: %v", err) - } - - raw := bytes.Repeat([]byte("hello world "), 1000) - compressed := enc.EncodeAll(raw, nil) - _ = enc.Close() - - real, err := zstd.NewReader(nil) - if err != nil { - t.Fatalf("new decoder: %v", err) - } - defer real.Close() - - counter := &countingDecoder{dec: real} - - decoder, err := newCachingDecoder(counter, 8) - if err != nil { - t.Fatalf("newCachingDecoder: %v", err) - } - - out1, err := decoder.DecodeAll(compressed, nil) - if err != nil { - t.Fatalf("first DecodeAll: %v", err) - } - if !bytes.Equal(out1, raw) { - t.Fatal("first DecodeAll produced wrong output") - } - if got := counter.decodes.Load(); got != 1 { - t.Fatalf("want 1 underlying decode, got %d", got) - } - - // An identical input must be served from the cache. - out2, err := decoder.DecodeAll(compressed, nil) - if err != nil { - t.Fatalf("second DecodeAll: %v", err) - } - if !bytes.Equal(out2, raw) { - t.Fatal("second DecodeAll produced wrong output") - } - if got := counter.decodes.Load(); got != 1 { - t.Fatalf("want still 1 underlying decode after cache hit, got %d", got) - } -} diff --git a/decoder.go b/decoder.go deleted file mode 100644 index 12579d3..0000000 --- a/decoder.go +++ /dev/null @@ -1,64 +0,0 @@ -package sqlitezstd - -import ( - "sync" - - "github.com/cespare/xxhash/v2" - lru "github.com/hashicorp/golang-lru/v2" - "github.com/klauspost/compress/zstd" -) - -// sharedDecoder is a single, process-wide zstd decoder shared by every opened -// file. The seekable reader only ever calls DecodeAll, which is safe for -// concurrent use, so one decoder replaces the per-Open decoder pool that was -// allocated (and never closed) for each connection. It is intentionally never -// closed because it lives for the lifetime of the process. -// -// nolint: gochecknoglobals -var sharedDecoder = sync.OnceValues(func() (*zstd.Decoder, error) { - return zstd.NewReader(nil) -}) - -// zstdDecoder is the subset of *zstd.Decoder used here (matching the seekable -// ZSTDDecoder interface). It is an interface so the cache can be unit-tested. -type zstdDecoder interface { - DecodeAll(input, dst []byte) ([]byte, error) -} - -// cachingDecoder wraps a zstd decoder with an LRU of decompressed frames keyed -// by the hash of the compressed input. The upstream seekable reader keeps only -// a single decompressed frame, so SQLite's scattered page reads otherwise force -// the same frames to be decompressed (and freshly allocated) over and over — -// the dominant cost in the FTS5/trigram benchmarks. -type cachingDecoder struct { - dec zstdDecoder - cache *lru.Cache[uint64, []byte] -} - -func newCachingDecoder(dec zstdDecoder, size int) (*cachingDecoder, error) { - cache, err := lru.New[uint64, []byte](size) - if err != nil { - return nil, err - } - - return &cachingDecoder{dec: dec, cache: cache}, nil -} - -// DecodeAll implements the seekable ZSTDDecoder interface. The seekable reader -// only ever reads from (never mutates) the returned slice and always passes a -// nil dst, so returning a shared cached slice is safe for concurrent readers. -func (c *cachingDecoder) DecodeAll(input, dst []byte) ([]byte, error) { - key := xxhash.Sum64(input) - if cached, ok := c.cache.Get(key); ok { - return cached, nil - } - - out, err := c.dec.DecodeAll(input, dst) - if err != nil { - return nil, err - } - - _ = c.cache.Add(key, out) - - return out, nil -} diff --git a/file.go b/file.go index bafee8f..fa34894 100644 --- a/file.go +++ b/file.go @@ -54,8 +54,8 @@ func (z *ZstdFile) SectorSize() int64 { // A whole zstd frame must be decompressed to serve any byte within it, but // SQLite reads a read-only immutable database page-by-page regardless of the // reported sector size — the intra-frame locality win is captured by the - // frame cache (see decoder.go / readerat.go), not by this value. Reporting 0 - // keeps SQLite on its default behavior. + // frame cache, not by this value. Reporting 0 keeps SQLite on its default + // behavior. return 0 } diff --git a/go.mod b/go.mod index 29235a0..8ed1255 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,7 @@ go 1.25.0 require ( github.com/SaveTheRbtz/zstd-seekable-format-go/pkg v0.10.0 github.com/brianvoe/gofakeit/v7 v7.7.3 - github.com/cespare/xxhash/v2 v2.3.0 github.com/georgysavva/scany/v2 v2.1.4 - github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/klauspost/compress v1.18.6 github.com/mattn/go-sqlite3 v1.14.32 github.com/onsi/ginkgo/v2 v2.26.0 @@ -18,6 +16,7 @@ require ( require ( github.com/Masterminds/semver/v3 v3.4.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/google/go-cmp v0.7.0 // indirect diff --git a/go.sum b/go.sum index 143a4a2..d50bff8 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,6 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20251002213607-436353cc1ee6 h1:/WHh/1k4thM/w+PAZEIiZK9NwCMFahw5tUzKUCnUtds= github.com/google/pprof v0.0.0-20251002213607-436353cc1ee6/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg= diff --git a/httpcache.go b/httpcache.go deleted file mode 100644 index 8de9a93..0000000 --- a/httpcache.go +++ /dev/null @@ -1,179 +0,0 @@ -package sqlitezstd - -import ( - "errors" - "io" - "sync" - - lru "github.com/hashicorp/golang-lru/v2" -) - -// DefaultHTTPPageSize is the coalescing/read-ahead page size used when caching -// is enabled without an explicit page size. A small frame read pulls the whole -// page, so adjacent frames become cache hits served without another GET. -const DefaultHTTPPageSize = 64 << 10 // 64 KiB - -// httpReadCache implements httpreadat.CacheHandler with an in-memory, -// page-aligned, bounded cache. It coalesces the contiguous run of missing pages -// covering a read into a single underlying fetch (one HTTP Range GET), which is -// what collapses a remote query's many one-frame GETs into far fewer requests. -// -// It is created per opened file, so under SQLite's per-connection serialization -// there is no real contention; the mutex is correctness insurance (and makes it -// safe under go test -race). -type httpReadCache struct { - pageSize int64 - fileSize int64 - - mu sync.Mutex - pages *lru.Cache[int64, []byte] - - hits int - misses int -} - -func newHTTPReadCache(fileSize int64, pageSize int, maxBytes int64) (*httpReadCache, error) { - if pageSize <= 0 { - pageSize = DefaultHTTPPageSize - } - - if maxBytes <= 0 { - return nil, errors.New("sqlitezstd: http cache size must be positive") - } - - maxPages := int(maxBytes / int64(pageSize)) - if maxPages < 1 { - maxPages = 1 - } - - pages, err := lru.New[int64, []byte](maxPages) - if err != nil { - return nil, err - } - - return &httpReadCache{ - pageSize: int64(pageSize), - fileSize: fileSize, - pages: pages, - }, nil -} - -// Get implements httpreadat.CacheHandler. -func (c *httpReadCache) Get(p []byte, off int64, fetcher io.ReaderAt) (int, error) { - if len(p) == 0 { - return 0, nil - } - - c.mu.Lock() - defer c.mu.Unlock() - - startPage := off / c.pageSize - endPage := (off + int64(len(p)) - 1) / c.pageSize - - // Collect the page slices covering the request, recording the contiguous run - // of missing pages. Hits are gathered as references so that adding the - // fetched pages below cannot disturb the assembly of this read. - pageData := make(map[int64][]byte, endPage-startPage+1) - - firstMissing, lastMissing := int64(-1), int64(-1) - - for i := startPage; i <= endPage; i++ { - if b, ok := c.pages.Get(i); ok { - pageData[i] = b - - continue - } - - if firstMissing < 0 { - firstMissing = i - } - - lastMissing = i - } - - if firstMissing >= 0 { - c.misses++ - - if err := c.fetchPages(firstMissing, lastMissing, fetcher, pageData); err != nil { - return 0, err - } - } else { - c.hits++ - } - - return c.assemble(p, off, pageData), nil -} - -// fetchPages reads [firstMissing, lastMissing] in a single ReadAt, slicing the -// result into per-page entries that are both returned (via pageData) and stored -// in the LRU for future reads. -func (c *httpReadCache) fetchPages(firstMissing, lastMissing int64, fetcher io.ReaderAt, pageData map[int64][]byte) error { - fetchStart := firstMissing * c.pageSize - - fetchEnd := (lastMissing + 1) * c.pageSize - if fetchEnd > c.fileSize { - fetchEnd = c.fileSize - } - - buf := make([]byte, fetchEnd-fetchStart) - - n, err := fetcher.ReadAt(buf, fetchStart) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - buf = buf[:n] - - for i := firstMissing; i <= lastMissing; i++ { - lo := (i - firstMissing) * c.pageSize - if lo >= int64(len(buf)) { - break - } - - hi := lo + c.pageSize - if hi > int64(len(buf)) { - hi = int64(len(buf)) - } - - page := make([]byte, hi-lo) - copy(page, buf[lo:hi]) - pageData[i] = page - _ = c.pages.Add(i, page) - } - - return nil -} - -// assemble copies the requested range out of the gathered page slices, stopping -// at the end of available data (a short read at EOF). -func (c *httpReadCache) assemble(p []byte, off int64, pageData map[int64][]byte) int { - copied := 0 - - for copied < len(p) { - cur := off + int64(copied) - pageIndex := cur / c.pageSize - - page, ok := pageData[pageIndex] - if !ok { - break - } - - within := int(cur - pageIndex*c.pageSize) - if within >= len(page) { - break - } - - copied += copy(p[copied:], page[within:]) - } - - return copied -} - -// stats reports cache hit/miss counts and the number of resident pages. Used by -// tests. -func (c *httpReadCache) stats() (hits, misses, pages int) { - c.mu.Lock() - defer c.mu.Unlock() - - return c.hits, c.misses, c.pages.Len() -} diff --git a/httpcache_internal_test.go b/httpcache_internal_test.go deleted file mode 100644 index 3750d4d..0000000 --- a/httpcache_internal_test.go +++ /dev/null @@ -1,164 +0,0 @@ -package sqlitezstd - -import ( - "bytes" - "io" - "sync/atomic" - "testing" -) - -// countingFetcher is an io.ReaderAt that records how many times it is called and -// how many bytes it serves. -type countingFetcher struct { - data []byte - calls atomic.Int64 - bytesRead atomic.Int64 -} - -func (f *countingFetcher) ReadAt(p []byte, off int64) (int, error) { - f.calls.Add(1) - - if off >= int64(len(f.data)) { - return 0, io.EOF - } - - n := copy(p, f.data[off:]) - f.bytesRead.Add(int64(n)) - - if n < len(p) { - return n, io.EOF - } - - return n, nil -} - -func makeData(size int) []byte { - data := make([]byte, size) - for i := range data { - data[i] = byte(i) - } - - return data -} - -// nolint: cyclop -func TestHTTPReadCacheCoalescesAndCaches(t *testing.T) { - t.Parallel() - - data := makeData(300 * 1024) - fetcher := &countingFetcher{data: data} - - cache, err := newHTTPReadCache(int64(len(data)), 64*1024, 8*1024*1024) - if err != nil { - t.Fatalf("newHTTPReadCache: %v", err) - } - - // A small 4 KiB read should pull the whole 64 KiB page in one fetch. - p := make([]byte, 4096) - - n, err := cache.Get(p, 1000, fetcher) - if err != nil || n != len(p) { - t.Fatalf("Get: n=%d err=%v", n, err) - } - if !bytes.Equal(p, data[1000:1000+len(p)]) { - t.Fatal("first Get returned wrong bytes") - } - if got := fetcher.calls.Load(); got != 1 { - t.Fatalf("want 1 fetch, got %d", got) - } - if got := fetcher.bytesRead.Load(); got != 64*1024 { - t.Fatalf("want a 64 KiB read-ahead fetch, got %d bytes", got) - } - - // Another read within the same page is served from cache: no new fetch. - n, err = cache.Get(p, 5000, fetcher) - if err != nil || n != len(p) { - t.Fatalf("second Get: n=%d err=%v", n, err) - } - if !bytes.Equal(p, data[5000:5000+len(p)]) { - t.Fatal("second Get returned wrong bytes") - } - if got := fetcher.calls.Load(); got != 1 { - t.Fatalf("want still 1 fetch after cache hit, got %d", got) - } - - hits, misses, _ := cache.stats() - if hits != 1 || misses != 1 { - t.Fatalf("stats: hits=%d misses=%d (want 1/1)", hits, misses) - } -} - -func TestHTTPReadCacheCrossPage(t *testing.T) { - t.Parallel() - - data := makeData(300 * 1024) - fetcher := &countingFetcher{data: data} - - cache, err := newHTTPReadCache(int64(len(data)), 64*1024, 8*1024*1024) - if err != nil { - t.Fatalf("newHTTPReadCache: %v", err) - } - - // A read spanning three pages must be assembled correctly. - p := make([]byte, 100*1024) - - n, err := cache.Get(p, 30*1024, fetcher) - if err != nil || n != len(p) { - t.Fatalf("Get: n=%d err=%v", n, err) - } - if !bytes.Equal(p, data[30*1024:30*1024+len(p)]) { - t.Fatal("cross-page Get returned wrong bytes") - } -} - -func TestHTTPReadCacheBounded(t *testing.T) { - t.Parallel() - - data := makeData(1024 * 1024) - fetcher := &countingFetcher{data: data} - - // Cap at 2 pages (128 KiB) with a 64 KiB page size. - cache, err := newHTTPReadCache(int64(len(data)), 64*1024, 128*1024) - if err != nil { - t.Fatalf("newHTTPReadCache: %v", err) - } - - p := make([]byte, 4096) - - for off := int64(0); off+int64(len(p)) <= int64(len(data)); off += 64 * 1024 { - n, err := cache.Get(p, off, fetcher) - if err != nil || n != len(p) { - t.Fatalf("Get @%d: n=%d err=%v", off, n, err) - } - if !bytes.Equal(p, data[off:off+int64(len(p))]) { - t.Fatalf("wrong bytes @%d", off) - } - } - - if _, _, pages := cache.stats(); pages > 2 { - t.Fatalf("cache exceeded its 2-page cap: %d pages resident", pages) - } -} - -func TestHTTPReadCacheReadToEOF(t *testing.T) { - t.Parallel() - - // File size not a multiple of the page size, to exercise the partial last page. - data := makeData(70 * 1024) - fetcher := &countingFetcher{data: data} - - cache, err := newHTTPReadCache(int64(len(data)), 64*1024, 8*1024*1024) - if err != nil { - t.Fatalf("newHTTPReadCache: %v", err) - } - - p := make([]byte, 4096) - - n, err := cache.Get(p, int64(len(data))-2048, fetcher) - if n != 2048 { - t.Fatalf("want 2048 bytes at EOF, got %d (err=%v)", n, err) - } - if !bytes.Equal(p[:n], data[len(data)-2048:]) { - t.Fatal("EOF read returned wrong bytes") - } -} diff --git a/httpcache_test.go b/httpcache_test.go deleted file mode 100644 index ff3f89b..0000000 --- a/httpcache_test.go +++ /dev/null @@ -1,190 +0,0 @@ -package sqlitezstd_test - -import ( - "database/sql" - "fmt" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" - "sync" - "sync/atomic" - "testing" - - sqlitezstd "github.com/jtarchie/sqlitezstd" - _ "github.com/mattn/go-sqlite3" -) - -// countingFileServer serves dir over HTTP (with Range support) and counts the -// number of Range GET requests it receives. -func countingFileServer(dir string) (*httptest.Server, *int64) { - var rangeGETs int64 - - fileServer := http.FileServer(http.Dir(dir)) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Range") != "" { - atomic.AddInt64(&rangeGETs, 1) - } - - fileServer.ServeHTTP(w, r) - })) - - return server, &rangeGETs -} - -// registerCacheVFS registers (once) a cache-enabled VFS under name. -func registerCacheVFS(t *testing.T, name string, opts ...sqlitezstd.Option) { - t.Helper() - - if err := sqlitezstd.Register(name, opts...); err != nil && !strings.Contains(err.Error(), "already") { - t.Fatalf("register %q: %v", name, err) - } -} - -func TestHTTPCacheCoalescesGETs(t *testing.T) { - t.Parallel() - - zstPath := buildCompressedDB(t, 100_000) - dir := filepath.Dir(zstPath) - base := filepath.Base(zstPath) - - const cacheVFS = "zstdcache-coalesce" - registerCacheVFS(t, cacheVFS, sqlitezstd.WithHTTPCacheSize(16<<20)) - - // A full table scan (value is not indexed) touches every data page, so it - // spans many frames — the case coalescing helps most. - const query = "SELECT COUNT(*) FROM entries WHERE value LIKE 'value-1%'" - - run := func(vfs string) (int64, int64) { - server, gets := countingFileServer(dir) - defer server.Close() - - db, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=%s", server.URL, base, vfs)) - if err != nil { - t.Fatalf("open (%s): %v", vfs, err) - } - defer db.Close() //nolint: errcheck - - db.SetMaxOpenConns(1) - - var count int64 - if err := db.QueryRow(query).Scan(&count); err != nil { - t.Fatalf("query (%s): %v", vfs, err) - } - - return count, atomic.LoadInt64(gets) - } - - wantCount, defaultGETs := run("zstd") - gotCount, cacheGETs := run(cacheVFS) - - t.Logf("range GETs: default(no cache)=%d cache=%d (%.1fx fewer)", - defaultGETs, cacheGETs, float64(defaultGETs)/float64(max64(cacheGETs, 1))) - - if gotCount != wantCount { - t.Fatalf("cache changed query result: got %d, want %d", gotCount, wantCount) - } - if cacheGETs == 0 { - t.Fatal("expected at least one range GET through the cache") - } - if cacheGETs*2 > defaultGETs { - t.Fatalf("cache did not substantially reduce GETs: default=%d cache=%d", defaultGETs, cacheGETs) - } -} - -func TestHTTPCachePerURLCorrectness(t *testing.T) { - t.Parallel() - - zstA := buildCompressedDB(t, 1_000) - zstB := buildCompressedDB(t, 7_777) - - const cacheVFS = "zstdcache-perurl" - registerCacheVFS(t, cacheVFS, sqlitezstd.WithHTTPCacheSize(8<<20)) - - count := func(zstPath string) int64 { - server, _ := countingFileServer(filepath.Dir(zstPath)) - defer server.Close() - - db, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=%s", server.URL, filepath.Base(zstPath), cacheVFS)) - if err != nil { - t.Fatalf("open: %v", err) - } - defer db.Close() //nolint: errcheck - - var n int64 - if err := db.QueryRow("SELECT COUNT(*) FROM entries").Scan(&n); err != nil { - t.Fatalf("query: %v", err) - } - - return n - } - - if got := count(zstA); got != 1_000 { - t.Fatalf("db A: got %d, want 1000", got) - } - if got := count(zstB); got != 7_777 { - t.Fatalf("db B: got %d, want 7777", got) - } -} - -func TestHTTPCacheConcurrentSameURL(t *testing.T) { - t.Parallel() - - zstPath := buildCompressedDB(t, 5_000) - - const cacheVFS = "zstdcache-concurrent" - registerCacheVFS(t, cacheVFS, sqlitezstd.WithHTTPCacheSize(8<<20)) - - server, _ := countingFileServer(filepath.Dir(zstPath)) - defer server.Close() - - db, err := sql.Open("sqlite3", fmt.Sprintf("%s/%s?vfs=%s", server.URL, filepath.Base(zstPath), cacheVFS)) - if err != nil { - t.Fatalf("open: %v", err) - } - defer db.Close() //nolint: errcheck - - db.SetMaxOpenConns(8) - - const ( - goroutines = 8 - iterations = 50 - ) - - var wg sync.WaitGroup - - errs := make(chan error, goroutines) - - for range goroutines { - wg.Add(1) - - go func() { - defer wg.Done() - - for i := range iterations { - var count int64 - if err := db.QueryRow("SELECT COUNT(*) FROM entries WHERE id > ?", i%100).Scan(&count); err != nil { - errs <- err - - return - } - } - }() - } - - wg.Wait() - close(errs) - - for err := range errs { - t.Fatalf("concurrent cached read failed: %v", err) - } -} - -func max64(a, b int64) int64 { - if a > b { - return a - } - - return b -} diff --git a/options.go b/options.go index 0d219f8..fcd032e 100644 --- a/options.go +++ b/options.go @@ -9,11 +9,8 @@ import ( // Default option values. These are applied by [Register] and by the default // "zstd" VFS registered in init(). const ( - // DefaultFrameCacheSize is the number of zstd frames cached per opened file - // (both the compressed bytes and the decompressed output). The upstream - // seekable reader only caches a single frame, so this cache is what keeps - // SQLite's scattered page reads from repeatedly re-fetching and - // re-decompressing the same frames. + // DefaultFrameCacheSize is the number of decoded zstd frames cached per + // opened file. DefaultFrameCacheSize = 64 // DefaultHTTPTimeout bounds dialing and waiting for response headers on the // HTTP(S) path so a hung server cannot block a query indefinitely. @@ -31,8 +28,6 @@ type Options struct { frameCacheSize int httpTimeout time.Duration httpMaxRetries int - httpCacheBytes int64 - httpPageSize int roundTripper http.RoundTripper logger *slog.Logger } @@ -41,7 +36,7 @@ type Options struct { // [WithHTTPRetries], and [WithLogger]. type Option func(*Options) -// WithFrameCacheSize sets the number of zstd frames cached per opened file. +// WithFrameCacheSize sets the number of decoded zstd frames cached per opened file. // Values <= 0 are ignored (the default is kept). func WithFrameCacheSize(frames int) Option { return func(o *Options) { @@ -71,29 +66,6 @@ func WithHTTPRetries(n int) Option { } } -// WithHTTPCacheSize enables an in-memory, page-aligned read cache for the -// HTTP(S) path, bounded to roughly maxBytes. It coalesces the contiguous run of -// missing pages covering a read into a single Range GET and serves adjacent -// frames from cache, drastically cutting request count for remote scans. A -// value <= 0 (the default) disables the cache, preserving the original -// one-GET-per-frame behavior. Has no effect on the local-file path. -func WithHTTPCacheSize(maxBytes int64) Option { - return func(o *Options) { - o.httpCacheBytes = maxBytes - } -} - -// WithHTTPPageSize sets the coalescing/read-ahead page size for the HTTP cache -// (see [WithHTTPCacheSize]). Values <= 0 keep the default of -// [DefaultHTTPPageSize]. -func WithHTTPPageSize(bytes int) Option { - return func(o *Options) { - if bytes > 0 { - o.httpPageSize = bytes - } - } -} - // WithRoundTripper sets the base http.RoundTripper used for the HTTP(S) path. // The library still wraps it with retry and Range-response validation, so a // caller can supply, for example, a request-signing transport for authenticated @@ -132,7 +104,6 @@ func defaultOptions() *Options { frameCacheSize: DefaultFrameCacheSize, httpTimeout: DefaultHTTPTimeout, httpMaxRetries: DefaultHTTPMaxRetries, - httpPageSize: DefaultHTTPPageSize, logger: slog.Default(), } } diff --git a/readerat.go b/readerat.go index 283b5fc..02ea787 100644 --- a/readerat.go +++ b/readerat.go @@ -17,8 +17,6 @@ import ( "errors" "io" "sync" - - lru "github.com/hashicorp/golang-lru/v2" ) var ( @@ -29,22 +27,18 @@ var ( // frameReader adapts a fixed-size io.ReaderAt (a local *os.File or the HTTP // range reader) into the io.ReadSeeker the seekable reader requires, while also -// exposing a cached io.ReaderAt. +// exposing io.ReaderAt directly. // -// Exposing ReadAt is important for two reasons: the seekable reader only takes -// its concurrency-safe, positional fast-path when the underlying reader -// implements io.ReaderAt (otherwise it falls back to a mutex-guarded Seek+Read); -// and the LRU here caches the compressed bytes of each frame keyed by offset, so -// frames the upstream single-frame cache has evicted are not re-fetched from -// disk or — far more importantly — re-fetched over the network. +// Exposing ReadAt is important because the seekable reader only takes its +// concurrency-safe, positional fast-path when the underlying reader implements +// io.ReaderAt; otherwise it falls back to a mutex-guarded Seek+Read. // // ReadAt is safe for concurrent use. The sequential Read/Seek methods (used only // for the seek-table footer at open time) are guarded by mu and must not be // called concurrently with each other. type frameReader struct { - src io.ReaderAt - size int64 - cache *lru.Cache[int64, []byte] + src io.ReaderAt + size int64 mu sync.Mutex offset int64 @@ -56,18 +50,7 @@ var ( _ io.Closer = (*frameReader)(nil) ) -func newFrameReader(src io.ReaderAt, size int64, cacheSize int) (*frameReader, error) { - cache, err := lru.New[int64, []byte](cacheSize) - if err != nil { - return nil, err - } - - return &frameReader{src: src, size: size, cache: cache}, nil -} - -// ReadAt implements io.ReaderAt and is safe for concurrent use. Compressed frame -// bytes are served from the LRU when present (the seekable reader always -// requests a given frame at the same offset and length). +// ReadAt implements io.ReaderAt and is safe for concurrent use. func (r *frameReader) ReadAt(p []byte, off int64) (int, error) { if r.size < 0 { return 0, errInvalidSize @@ -76,20 +59,7 @@ func (r *frameReader) ReadAt(p []byte, off int64) (int, error) { return 0, nil } - if cached, ok := r.cache.Get(off); ok && len(cached) == len(p) { - copy(p, cached) - - return len(p), nil - } - - n, err := r.src.ReadAt(p, off) - if (err == nil || errors.Is(err, io.EOF)) && n == len(p) { - buf := make([]byte, n) - copy(buf, p[:n]) - _ = r.cache.Add(off, buf) - } - - return n, err + return r.src.ReadAt(p, off) } // Read implements io.Reader. It is not safe for concurrent use. diff --git a/vfs.go b/vfs.go index 3c83dda..8cf3048 100644 --- a/vfs.go +++ b/vfs.go @@ -10,6 +10,8 @@ import ( "sync" seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg" + "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg/framecache" + "github.com/klauspost/compress/zstd" _ "github.com/mattn/go-sqlite3" "github.com/psanford/httpreadat" "github.com/psanford/sqlite3vfs" @@ -22,6 +24,17 @@ type ZstdVFS struct { var _ sqlite3vfs.VFS = &ZstdVFS{} +// sharedDecoder is a single, process-wide zstd decoder shared by every opened +// file. The seekable reader only ever calls DecodeAll, which is safe for +// concurrent use, so one decoder replaces the per-Open decoder pool that was +// allocated (and never closed) for each connection. It is intentionally never +// closed because it lives for the lifetime of the process. +// +// nolint: gochecknoglobals +var sharedDecoder = sync.OnceValues(func() (*zstd.Decoder, error) { + return zstd.NewReader(nil) +}) + // Register registers a zstd VFS under the given name with the supplied options. // Open a database against it with the "?vfs=" query parameter. The default // "zstd" VFS (registered automatically on import) uses default options. @@ -84,26 +97,13 @@ func (z *ZstdVFS) resolveSource(name string) (io.ReaderAt, int64, error) { return nil, 0, fmt.Errorf("parse url: %w", err) } - rangerOpts := []httpreadat.Option{httpreadat.WithRoundTripper(newRangeRoundTripper(z.opts))} - - ranger := httpreadat.New(uri.String(), rangerOpts...) + ranger := httpreadat.New(uri.String(), httpreadat.WithRoundTripper(newRangeRoundTripper(z.opts))) size, err := ranger.Size() if err != nil { return nil, 0, fmt.Errorf("determine remote size: %w", err) } - if z.opts.httpCacheBytes > 0 { - cache, err := newHTTPReadCache(size, z.opts.httpPageSize, z.opts.httpCacheBytes) - if err != nil { - return nil, 0, fmt.Errorf("create http cache: %w", err) - } - - // Re-create the ranger with the coalescing cache handler installed, - // reusing the already-built transport. - ranger = httpreadat.New(uri.String(), append(rangerOpts, httpreadat.WithCacheHandler(cache))...) - } - return ranger, size, nil } @@ -131,13 +131,9 @@ func (z *ZstdVFS) open(name string) (_ *ZstdFile, err error) { return nil, err } - reader, err := newFrameReader(src, size, z.opts.frameCacheSize) - if err != nil { - if closer, ok := src.(io.Closer); ok { - _ = closer.Close() - } - - return nil, fmt.Errorf("create frame reader: %w", err) + reader := &frameReader{ + src: src, + size: size, } defer func() { @@ -151,12 +147,13 @@ func (z *ZstdVFS) open(name string) (_ *ZstdFile, err error) { return nil, fmt.Errorf("create zstd decoder: %w", err) } - cachingDec, err := newCachingDecoder(decoder, z.opts.frameCacheSize) - if err != nil { - return nil, fmt.Errorf("create caching decoder: %w", err) - } - - sr, err := seekable.NewReader(reader, cachingDec) + sr, err := seekable.NewReader( + reader, + decoder, + seekable.WithReaderFrameCache(framecache.NewSieve(framecache.Limits{ + MaxFrames: z.opts.frameCacheSize, + })), + ) if err != nil { return nil, fmt.Errorf("create seekable reader: %w", err) }