diff --git a/cmd/serve.go b/cmd/serve.go
index ed23057..e9755ad 100644
--- a/cmd/serve.go
+++ b/cmd/serve.go
@@ -9,15 +9,19 @@ import (
"os"
"os/signal"
"path/filepath"
+ "runtime"
+ "strings"
"syscall"
"time"
"github.com/RandomCodeSpace/docsiq/internal/api"
+ "github.com/RandomCodeSpace/docsiq/internal/config"
"github.com/RandomCodeSpace/docsiq/internal/embedder"
"github.com/RandomCodeSpace/docsiq/internal/llm"
"github.com/RandomCodeSpace/docsiq/internal/project"
"github.com/RandomCodeSpace/docsiq/internal/sqlitevec"
"github.com/RandomCodeSpace/docsiq/internal/vectorindex"
+ "github.com/RandomCodeSpace/docsiq/internal/workq"
"github.com/spf13/cobra"
)
@@ -140,11 +144,25 @@ var serveCmd = &cobra.Command{
}
}
+ workers := cfg.Server.WorkqWorkers
+ if workers <= 0 {
+ workers = runtime.NumCPU()
+ }
+ depth := cfg.Server.WorkqDepth
+ if depth <= 0 {
+ depth = 64
+ }
+ pool := workq.New(workq.Config{Workers: workers, QueueDepth: depth})
+
router := api.NewRouter(prov, emb, cfg, registry,
api.WithProjectStores(stores),
api.WithVectorIndexes(vecIndexes),
+ api.WithWorkq(pool),
)
+ if err := validateServeSecurity(cfg); err != nil {
+ return err
+ }
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
ln, err := net.Listen("tcp", addr)
if err != nil {
@@ -177,6 +195,17 @@ var serveCmd = &cobra.Command{
slog.Error("❌ shutdown error", "err", err)
return err
}
+
+ // Drain workq within its own 30s deadline. Server.Shutdown has already
+ // stopped accepting new HTTP requests, so no new jobs can be submitted;
+ // all that remains is letting in-flight pipelines finish or honour the
+ // cancelled ctx.
+ drainCtx, drainCancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer drainCancel()
+ if err := pool.Close(drainCtx); err != nil {
+ slog.Warn("⚠️ workq drain timeout; some indexing jobs were cancelled mid-flight", "err", err)
+ }
+
slog.Info("✅ shutdown complete")
return nil
},
@@ -187,3 +216,33 @@ func init() {
serveCmd.Flags().StringVar(&serveHost, "host", "", "Server host (overrides config)")
serveCmd.Flags().IntVar(&servePort, "port", 0, "Server port (overrides config)")
}
+
+// validateServeSecurity is the boot-time guard against running an
+// unauthenticated server that is reachable from the network.
+//
+// It returns nil when an API key is configured (any bind host is fine),
+// or when the key is empty but the bind host is "localhost" or an IP for
+// which net.IP.IsLoopback reports true (brackets around an IPv6 literal
+// are tolerated). In the latter case a warning is logged. An empty host
+// or any other host with an empty key yields an error explaining how to
+// fix the configuration.
+func validateServeSecurity(cfg *config.Config) error {
+	if cfg.Server.APIKey != "" {
+		return nil
+	}
+
+	// Compare on a trimmed, lower-cased copy; the error below still
+	// reports the host exactly as configured.
+	normalized := strings.ToLower(strings.TrimSpace(cfg.Server.Host))
+	if normalized == "" {
+		return fmt.Errorf(
+			"server.api_key is empty and server.host is unset (binds all interfaces); refusing to start. " +
+				"Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev",
+		)
+	}
+
+	isLoopback := normalized == "localhost"
+	if !isLoopback {
+		// Not the well-known name: accept only literal loopback IPs.
+		ip := net.ParseIP(strings.Trim(normalized, "[]"))
+		isLoopback = ip != nil && ip.IsLoopback()
+	}
+
+	if isLoopback {
+		slog.Warn("⚠️ auth disabled (empty server.api_key); only loopback bind allowed", "host", normalized)
+		return nil
+	}
+	return fmt.Errorf(
+		"server.api_key is empty and server.host=%q is not loopback; refusing to start. "+
+			"Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev",
+		cfg.Server.Host,
+	)
+}
diff --git a/cmd/serve_test.go b/cmd/serve_test.go
new file mode 100644
index 0000000..7480f90
--- /dev/null
+++ b/cmd/serve_test.go
@@ -0,0 +1,87 @@
+package cmd
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/RandomCodeSpace/docsiq/internal/config"
+)
+
+// TestValidateServeSecurity_RefusesNonLoopbackWithEmptyKey checks that an
+// empty API key combined with any bind host that is not loopback (unset,
+// wildcard, LAN, routable, or a hostname) is rejected, and that the error
+// tells the operator which setting to change.
+func TestValidateServeSecurity_RefusesNonLoopbackWithEmptyKey(t *testing.T) {
+	t.Parallel()
+	// Fragments every "not loopback" refusal is expected to carry.
+	notLoopback := []string{"api_key", "DOCSIQ_SERVER_API_KEY"}
+	cases := []struct {
+		name        string
+		host        string
+		mustContain []string
+	}{
+		{
+			name:        "empty host binds all interfaces",
+			host:        "",
+			mustContain: []string{"binds all interfaces", "DOCSIQ_SERVER_API_KEY"},
+		},
+		{
+			name:        "0.0.0.0 is wildcard",
+			host:        "0.0.0.0",
+			mustContain: notLoopback,
+		},
+		{
+			// 10.0.0.0/8 is private address space, not public.
+			name:        "private LAN IPv4",
+			host:        "10.0.0.5",
+			mustContain: notLoopback,
+		},
+		{
+			// 203.0.113.0/24 is reserved for documentation (TEST-NET-3).
+			name:        "routable IPv4",
+			host:        "203.0.113.10",
+			mustContain: notLoopback,
+		},
+		{
+			name:        "IPv6 wildcard",
+			host:        "::",
+			mustContain: notLoopback,
+		},
+		{
+			name:        "bracketed IPv6 wildcard",
+			host:        "[::]",
+			mustContain: notLoopback,
+		},
+		{
+			name:        "hostname other than localhost",
+			host:        "example.com",
+			mustContain: notLoopback,
+		},
+	}
+	for _, tc := range cases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			cfg := &config.Config{}
+			cfg.Server.Host = tc.host
+			cfg.Server.Port = 8080
+			cfg.Server.APIKey = ""
+
+			err := validateServeSecurity(cfg)
+			if err == nil {
+				t.Fatalf("host=%q: expected error for empty api_key on non-loopback bind; got nil", tc.host)
+			}
+			for _, want := range tc.mustContain {
+				if !strings.Contains(err.Error(), want) {
+					t.Errorf("host=%q: error should contain %q; got %v", tc.host, want, err)
+				}
+			}
+		})
+	}
+}
+
+// TestValidateServeSecurity_AllowsLoopbackWithEmptyKey checks that each
+// spelling of a loopback bind address passes validation when no API key
+// is configured.
+func TestValidateServeSecurity_AllowsLoopbackWithEmptyKey(t *testing.T) {
+	t.Parallel()
+	for _, loopbackHost := range []string{
+		"127.0.0.1",
+		"localhost",
+		"::1",
+		"[::1]",            // bracketed IPv6
+		"127.0.0.2",        // other address in 127.0.0.0/8
+		"::ffff:127.0.0.1", // IPv6-mapped IPv4
+	} {
+		loopbackHost := loopbackHost
+		t.Run(loopbackHost, func(t *testing.T) {
+			t.Parallel()
+			var cfg config.Config
+			cfg.Server.Host = loopbackHost
+			cfg.Server.APIKey = ""
+
+			err := validateServeSecurity(&cfg)
+			if err != nil {
+				t.Fatalf("host=%s: expected nil; got %v", loopbackHost, err)
+			}
+		})
+	}
+}
+
+// TestValidateServeSecurity_AllowsNonLoopbackWithKey checks that a
+// wildcard bind is accepted once an API key is configured.
+func TestValidateServeSecurity_AllowsNonLoopbackWithKey(t *testing.T) {
+	t.Parallel()
+
+	var cfg config.Config
+	cfg.Server.APIKey = "s3cret"
+	cfg.Server.Host = "0.0.0.0"
+
+	err := validateServeSecurity(&cfg)
+	if err != nil {
+		t.Fatalf("expected nil; got %v", err)
+	}
+}
diff --git a/internal/api/auth.go b/internal/api/auth.go
index c1e4a0f..185422f 100644
--- a/internal/api/auth.go
+++ b/internal/api/auth.go
@@ -50,17 +50,31 @@ func bearerAuthMiddleware(apiKey string, next http.Handler) http.Handler {
return
}
- raw := strings.TrimSpace(r.Header.Get("Authorization"))
- const prefix = "Bearer "
- if !strings.HasPrefix(raw, prefix) {
+ // /api/session is the auth boundary itself — always public.
+ if path == "/api/session" {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ // Defense-in-depth: reject immediately if the server has no key
+ // configured. This mirrors newSessionHandler's guard and keeps the
+ // middleware correct under future refactors (rather than relying on
+ // the no_token branch firing because keyBytes would also be empty).
+ if apiKey == "" {
+ slog.Warn("🔒 auth failure", "path", path, "remote_addr", r.RemoteAddr, "reason", "server_misconfigured")
+ writeJSON401(w)
+ return
+ }
+
+ token := extractToken(r)
+ if token == "" {
slog.Warn("🔒 auth failure",
"path", path,
"remote_addr", r.RemoteAddr,
- "reason", "no_bearer_prefix")
+ "reason", "no_token")
writeJSON401(w)
return
}
- token := raw[len(prefix):]
if subtle.ConstantTimeCompare([]byte(token), keyBytes) != 1 {
slog.Warn("🔒 auth failure",
"path", path,
@@ -82,3 +96,20 @@ func writeJSON401(w http.ResponseWriter) {
w.WriteHeader(http.StatusUnauthorized)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
}
+
+// extractToken pulls the caller's credential out of the request.
+//
+// A whitespace-trimmed Authorization header starting with "Bearer " wins
+// (machine clients); whatever follows the prefix is returned as-is.
+// Otherwise the session cookie set by POST /api/session is consulted
+// (browser clients) and its trimmed value returned. The result is ""
+// when neither source yields a token.
+func extractToken(r *http.Request) string {
+	const prefix = "Bearer "
+	header := strings.TrimSpace(r.Header.Get("Authorization"))
+	if token, ok := strings.CutPrefix(header, prefix); ok {
+		return token
+	}
+
+	cookie, err := r.Cookie(sessionCookieName)
+	if err != nil {
+		// No session cookie on the request.
+		return ""
+	}
+	// A blank cookie value trims down to "", i.e. "no token".
+	return strings.TrimSpace(cookie.Value)
+}
diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go
index d447b42..6d4804f 100644
--- a/internal/api/auth_test.go
+++ b/internal/api/auth_test.go
@@ -515,6 +515,29 @@ func TestBearerAuthMiddleware(t *testing.T) {
})
}
+// TestAuth_AcceptsValidCookie checks that a request carrying the session
+// cookie with the configured key, and no Authorization header, is let
+// through by the auth handler.
+func TestAuth_AcceptsValidCookie(t *testing.T) {
+	t.Parallel()
+	handler := buildAuthHandler("s3cret")
+
+	rec := httptest.NewRecorder()
+	r := httptest.NewRequest(http.MethodGet, "/api/ping", nil)
+	r.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "s3cret"})
+	handler.ServeHTTP(rec, r)
+
+	if got := rec.Code; got != http.StatusOK {
+		t.Fatalf("want 200 with valid cookie, got %d", got)
+	}
+}
+
+// TestAuth_RejectsMissingBothHeaderAndCookie checks that a request with
+// neither an Authorization header nor a session cookie gets a 401.
+func TestAuth_RejectsMissingBothHeaderAndCookie(t *testing.T) {
+	t.Parallel()
+	handler := buildAuthHandler("s3cret")
+
+	rec := httptest.NewRecorder()
+	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/ping", nil))
+
+	if got := rec.Code; got != http.StatusUnauthorized {
+		t.Fatalf("want 401, got %d", got)
+	}
+}
+
// BenchmarkBearerAuth_WrongKey ensures the constant-time compare path runs
// under the benchmark harness. It is NOT a hard timing assertion — if the
// code ever switches to a non-constant-time compare (==, bytes.Equal), the
diff --git a/internal/api/handlers.go b/internal/api/handlers.go
index aa11bf9..1edcd83 100644
--- a/internal/api/handlers.go
+++ b/internal/api/handlers.go
@@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"log/slog"
@@ -23,6 +24,7 @@ import (
"github.com/RandomCodeSpace/docsiq/internal/search"
"github.com/RandomCodeSpace/docsiq/internal/store"
"github.com/RandomCodeSpace/docsiq/internal/vectorindex"
+ "github.com/RandomCodeSpace/docsiq/internal/workq"
)
// handlers is the REST-side doc router state. Wave-2 drop: the
@@ -38,6 +40,9 @@ type handlers struct {
// return nil for a slug with no embeddings; LocalSearch falls back
// to brute-force in that case.
vecIndexes *VectorIndexes
+ // workq is the bounded worker pool for upload indexing jobs. When
+ // nil (dev/test path), upload() falls back to a detached goroutine.
+ workq *workq.Pool
// Upload progress tracking
uploadMu sync.Mutex
@@ -403,8 +408,20 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) {
return
}
slug := ProjectFromContext(r.Context())
- // TODO(docsiq): P2-1 wrap r.Body with http.MaxBytesReader before ParseMultipartForm
- if err := r.ParseMultipartForm(128 << 20); err != nil {
+ if !enforceUploadLimit(w, r, h.cfg.Server.MaxUploadBytes) {
+ return
+ }
+ if err := r.ParseMultipartForm(32 << 20); err != nil {
+		// http.MaxBytesReader surfaces an overflow as *http.MaxBytesError
+		// from the read performed by ParseMultipartForm. It does NOT
+		// write a response status itself (on a real server it only asks
+		// for the connection to be closed after the reply), so the 413
+		// has to be emitted here; returning silently would send an
+		// implicit 200 with an empty body. Any other malformed-form
+		// error falls through to the 400 below.
+		var mbe *http.MaxBytesError
+		if errors.As(err, &mbe) {
+			slog.Warn("⚠️ upload: rejected oversize request", "limit", mbe.Limit)
+			writeTooLarge(w, mbe.Limit)
+			return
+		}
writeError(w, r, 400, "parse form: "+err.Error(), nil)
return
}
@@ -484,25 +501,25 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) {
h.setProgress(jobID, fmt.Sprintf("queued: %d files", len(paths)))
- // Use a detached context so the background goroutine is not cancelled
- // when the HTTP response is sent.
- bgCtx := context.Background()
- tmpDirOwned = true
-
- go func() {
+ job := func(ctx context.Context) {
defer os.RemoveAll(tmpDir)
pl := pipeline.New(st, h.provider, h.cfg)
for _, p := range paths {
+ if ctx.Err() != nil {
+ slog.Warn("🛑 upload indexing cancelled on shutdown", "job_id", jobID, "file", filepath.Base(p))
+ h.setProgress(jobID, "cancelled")
+ return
+ }
slog.Info("📦 upload indexing file", "job_id", jobID, "file", filepath.Base(p))
h.setProgress(jobID, fmt.Sprintf("indexing: %s", filepath.Base(p)))
- if err := pl.IndexPath(bgCtx, p, pipeline.IndexOptions{}); err != nil {
+ if err := pl.IndexPath(ctx, p, pipeline.IndexOptions{}); err != nil {
slog.Error("❌ upload indexing failed", "job_id", jobID, "file", filepath.Base(p), "err", err)
h.setProgress(jobID, fmt.Sprintf("error: %v", err))
return
}
}
h.setProgress(jobID, "finalizing")
- if err := pl.Finalize(bgCtx, false, true); err != nil {
+ if err := pl.Finalize(ctx, false, true); err != nil {
slog.Warn("⚠️ upload finalization failed", "job_id", jobID, "err", err)
}
// Invalidate the vector index for this project so the next
@@ -512,9 +529,27 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) {
}
slog.Info("✅ upload job complete", "job_id", jobID, "files", len(paths), "project", slug)
h.setProgress(jobID, "done")
- }()
+ }
+
+ if h.workq == nil {
+ tmpDirOwned = true
+ go job(context.Background()) // dev/test fallback
+ } else {
+ if err := h.workq.Submit(job); err != nil {
+ if errors.Is(err, workq.ErrQueueFull) {
+ h.setProgress(jobID, "rejected: indexing queue full")
+ w.Header().Set("Retry-After", "30")
+ writeError(w, r, http.StatusServiceUnavailable, "indexing queue full; retry later", nil)
+ return
+ }
+ h.setProgress(jobID, "rejected: server unavailable")
+ writeError(w, r, http.StatusServiceUnavailable, "server shutting down", err)
+ return
+ }
+ tmpDirOwned = true
+ }
- writeJSON(w, 202, map[string]string{"job_id": jobID, "status": "queued"})
+ writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID, "status": "accepted"})
}
func (h *handlers) setProgress(jobID, msg string) {
@@ -606,3 +641,31 @@ func intQuery(s string, def int) int {
}
return n
}
+
+// writeTooLarge answers with 413 Request Entity Too Large and a small
+// JSON object whose "error" message names the configured byte limit.
+// It sets the Content-Type and commits the status line, so it must only
+// be called while the response header is still unwritten.
+func writeTooLarge(w http.ResponseWriter, limit int64) {
+	const bodyFormat = `{"error":"request body exceeds maximum upload size of %d bytes"}`
+
+	header := w.Header()
+	header.Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusRequestEntityTooLarge)
+
+	// Best effort: nothing useful can be done if the client is gone.
+	_, _ = fmt.Fprintf(w, bodyFormat, limit)
+}
+
+// enforceUploadLimit applies the upload size cap to a request.
+//
+// A limit of zero or less means "no cap": the request is left untouched
+// and true is returned. When the declared Content-Length already exceeds
+// limit, a 413 JSON response is written and false is returned; the caller
+// must stop handling the request. In every other case (including an
+// unknown Content-Length) r.Body is wrapped with http.MaxBytesReader so
+// that reading past limit fails later, and true is returned.
+func enforceUploadLimit(w http.ResponseWriter, r *http.Request, limit int64) bool {
+	switch {
+	case limit <= 0:
+		// Cap disabled by configuration.
+		return true
+	case r.ContentLength > limit:
+		// Declared size is already too big: reject without reading.
+		slog.Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit)
+		writeTooLarge(w, limit)
+		return false
+	}
+
+	// Size is within bounds or unknown: enforce while the body is read.
+	r.Body = http.MaxBytesReader(w, r.Body, limit)
+	return true
+}
diff --git a/internal/api/router.go b/internal/api/router.go
index b77802f..6f18993 100644
--- a/internal/api/router.go
+++ b/internal/api/router.go
@@ -1,9 +1,7 @@
package api
import (
- "bytes"
"context"
- "html"
"io/fs"
"log/slog"
"net/http"
@@ -16,6 +14,7 @@ import (
"github.com/RandomCodeSpace/docsiq/internal/llm"
"github.com/RandomCodeSpace/docsiq/internal/mcp"
"github.com/RandomCodeSpace/docsiq/internal/project"
+ "github.com/RandomCodeSpace/docsiq/internal/workq"
"github.com/RandomCodeSpace/docsiq/ui"
)
@@ -26,6 +25,7 @@ type RouterOption func(*routerOptions)
type routerOptions struct {
vecIndexes *VectorIndexes
stores *projectStores
+ // workq is the optional worker pool supplied through WithWorkq;
+ // NewRouter copies it onto the REST handlers. It stays nil when the
+ // option is not used.
+ workq *workq.Pool
}
// WithVectorIndexes wires a per-project HNSW index cache into the
@@ -35,6 +35,13 @@ func WithVectorIndexes(vi *VectorIndexes) RouterOption {
return func(o *routerOptions) { o.vecIndexes = vi }
}
+// WithWorkq returns a RouterOption that hands the router a bounded worker
+// pool for background indexing jobs. Without it (or with a nil pool)
+// upload() falls back to a detached goroutine, which is the dev/test
+// path.
+func WithWorkq(p *workq.Pool) RouterOption {
+	return func(opts *routerOptions) {
+		opts.workq = p
+	}
+}
+
// WithProjectStores lets callers inject a pre-built ProjectStores
// cache so they can close it at shutdown. Nil (default) causes
// NewRouter to allocate its own — fine for tests, but real servers
@@ -70,6 +77,7 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re
embedder: emb,
cfg: cfg,
vecIndexes: ro.vecIndexes,
+ workq: ro.workq,
}
nh := newNotesHandlersWithStores(stores, cfg, registry)
ph := &projectsHandler{registry: registry}
@@ -105,6 +113,12 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re
})
}
+ // Session exchange — public (is the auth boundary).
+ // POST exchanges a bearer key for a docsiq_session httpOnly cookie.
+ // DELETE clears the cookie (logout).
+ mux.HandleFunc("POST /api/session", newSessionHandler(cfg.Server.APIKey))
+ mux.HandleFunc("DELETE /api/session", newSessionDeleteHandler())
+
// REST API — docs pipeline (Phase-0)
mux.HandleFunc("GET /api/stats", h.getStats)
mux.HandleFunc("GET /api/documents", h.listDocuments)
@@ -156,7 +170,7 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re
projectMiddleware(cfg, registry, mux))))
}
-func spaHandler(assets fs.FS, cfg *config.Config) http.Handler {
+func spaHandler(assets fs.FS, _ *config.Config) http.Handler {
fileServer := http.FileServer(http.FS(assets))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -181,14 +195,6 @@ func spaHandler(assets fs.FS, cfg *config.Config) http.Handler {
return
}
- if cfg.Server.APIKey != "" {
- content = bytes.Replace(
- content,
- []byte(""),
- []byte(``),
- 1,
- )
- }
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(content)
diff --git a/internal/api/session.go b/internal/api/session.go
new file mode 100644
index 0000000..9a0806a
--- /dev/null
+++ b/internal/api/session.go
@@ -0,0 +1,72 @@
+package api
+
+import (
+ "crypto/subtle"
+ "log/slog"
+ "net/http"
+ "strings"
+)
+
+// sessionCookieName is the name of the httpOnly cookie that carries the
+// bearer token after a successful POST /api/session exchange. The value
+// is identical to cfg.Server.APIKey — we do not (yet) rotate or sign it;
+// the cookie is a transport-hardening layer, not a session store.
+// Because the auth middleware compares the cookie value against the
+// configured key, changing the key invalidates previously issued cookies.
+const sessionCookieName = "docsiq_session"
+
+// newSessionHandler returns the POST /api/session handler. It accepts an
+// "Authorization: Bearer <key>" header and, when <key> matches apiKey,
+// sets the session cookie and replies 204. Every other shape is a 401:
+// a missing or non-Bearer header, a wrong key, or a server that has no
+// key configured at all. Non-POST methods get a 405.
+//
+// Each 401 is logged with a distinct "reason" so operators can tell a
+// misconfigured server apart from a client presenting a bad credential.
+func newSessionHandler(apiKey string) http.HandlerFunc {
+	keyBytes := []byte(apiKey)
+	return func(w http.ResponseWriter, r *http.Request) {
+		if r.Method != http.MethodPost {
+			w.Header().Set("Allow", "POST, DELETE")
+			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+			return
+		}
+		// No key configured: nothing can ever match, and an empty
+		// token must never be treated as valid.
+		if apiKey == "" {
+			slog.Warn("🔒 session: auth failure", "remote_addr", r.RemoteAddr, "reason", "server_misconfigured")
+			writeJSON401(w)
+			return
+		}
+		raw := strings.TrimSpace(r.Header.Get("Authorization"))
+		const prefix = "Bearer "
+		if !strings.HasPrefix(raw, prefix) {
+			slog.Warn("🔒 session: auth failure", "remote_addr", r.RemoteAddr, "reason", "no_bearer_prefix")
+			writeJSON401(w)
+			return
+		}
+		token := raw[len(prefix):]
+		// Constant-time comparison so timing does not leak key bytes.
+		if subtle.ConstantTimeCompare([]byte(token), keyBytes) != 1 {
+			slog.Warn("🔒 session: auth failure", "remote_addr", r.RemoteAddr, "reason", "wrong_key")
+			writeJSON401(w)
+			return
+		}
+		http.SetCookie(w, &http.Cookie{
+			Name:     sessionCookieName,
+			Value:    apiKey,
+			Path:     "/",
+			HttpOnly: true,
+			Secure:   true,
+			SameSite: http.SameSiteStrictMode,
+			MaxAge:   86400 * 30, // 30 days
+		})
+		w.WriteHeader(http.StatusNoContent)
+	}
+}
+
+// newSessionDeleteHandler builds the DELETE /api/session handler used for
+// client-initiated logout: it overwrites the session cookie with an
+// empty, already-expired one and replies 204. Any other method is
+// answered with 405 and an Allow header.
+func newSessionDeleteHandler() http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		if r.Method == http.MethodDelete {
+			// Same attributes as the cookie set on login; the negative
+			// MaxAge tells the browser to drop it.
+			expired := http.Cookie{
+				Name:     sessionCookieName,
+				Value:    "",
+				Path:     "/",
+				HttpOnly: true,
+				Secure:   true,
+				SameSite: http.SameSiteStrictMode,
+				MaxAge:   -1,
+			}
+			http.SetCookie(w, &expired)
+			w.WriteHeader(http.StatusNoContent)
+			return
+		}
+
+		w.Header().Set("Allow", "POST, DELETE")
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+	}
+}
diff --git a/internal/api/session_test.go b/internal/api/session_test.go
new file mode 100644
index 0000000..1f9495a
--- /dev/null
+++ b/internal/api/session_test.go
@@ -0,0 +1,62 @@
+package api
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+// TestSession_PostExchangesBearerForCookie checks the happy path of the
+// session exchange: a correct bearer key yields 204 plus a session
+// cookie carrying the hardening attributes.
+func TestSession_PostExchangesBearerForCookie(t *testing.T) {
+	t.Parallel()
+	handler := newSessionHandler("s3cret")
+
+	rec := httptest.NewRecorder()
+	r := httptest.NewRequest(http.MethodPost, "/api/session", nil)
+	r.Header.Set("Authorization", "Bearer s3cret")
+	handler.ServeHTTP(rec, r)
+
+	if got := rec.Code; got != http.StatusNoContent {
+		t.Fatalf("want 204, got %d", got)
+	}
+
+	cookieHeader := rec.Header().Get("Set-Cookie")
+	if !strings.Contains(cookieHeader, sessionCookieName+"=") {
+		t.Fatalf("missing session cookie: %q", cookieHeader)
+	}
+	required := []string{"HttpOnly", "Secure", "SameSite=Strict", "Path=/"}
+	for i := range required {
+		if strings.Contains(cookieHeader, required[i]) {
+			continue
+		}
+		t.Fatalf("cookie missing %s: %q", required[i], cookieHeader)
+	}
+}
+
+// TestSession_PostRejectsBadKey checks that a wrong bearer key is
+// answered with 401 and that no cookie is issued.
+func TestSession_PostRejectsBadKey(t *testing.T) {
+	t.Parallel()
+	handler := newSessionHandler("s3cret")
+
+	rec := httptest.NewRecorder()
+	r := httptest.NewRequest(http.MethodPost, "/api/session", nil)
+	r.Header.Set("Authorization", "Bearer wrong")
+	handler.ServeHTTP(rec, r)
+
+	if got := rec.Code; got != http.StatusUnauthorized {
+		t.Fatalf("want 401, got %d", got)
+	}
+	if issued := rec.Header().Get("Set-Cookie"); issued != "" {
+		t.Fatal("cookie must not be set on failure")
+	}
+}
+
+// TestSession_DeleteClearsCookie checks that logout replies 204 and
+// sends a Set-Cookie header that expires the session cookie.
+func TestSession_DeleteClearsCookie(t *testing.T) {
+	t.Parallel()
+	handler := newSessionDeleteHandler()
+
+	rec := httptest.NewRecorder()
+	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/api/session", nil))
+
+	if got := rec.Code; got != http.StatusNoContent {
+		t.Fatalf("want 204, got %d", got)
+	}
+	cookieHeader := rec.Header().Get("Set-Cookie")
+	if !strings.Contains(cookieHeader, "Max-Age=0") {
+		t.Fatalf("cookie should be cleared (Max-Age=0); got %q", cookieHeader)
+	}
+}
diff --git a/internal/api/spa_meta_test.go b/internal/api/spa_meta_test.go
deleted file mode 100644
index 8e60f66..0000000
--- a/internal/api/spa_meta_test.go
+++ /dev/null
@@ -1,49 +0,0 @@
-package api
-
-import (
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
-
- "github.com/RandomCodeSpace/docsiq/internal/config"
- "github.com/RandomCodeSpace/docsiq/ui"
-)
-
-func TestSPA_InjectsMetaWhenAPIKeySet(t *testing.T) {
- cfg := &config.Config{}
- cfg.Server.APIKey = "secret-key-abc"
- h := spaHandler(ui.Assets, cfg)
- srv := httptest.NewServer(h)
- defer srv.Close()
- resp, err := http.Get(srv.URL + "/")
- if err != nil {
- t.Fatal(err)
- }
- defer resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
- if !strings.Contains(string(body), `name="docsiq-api-key"`) {
- t.Fatalf("expected meta tag, body:\n%s", body)
- }
- if !strings.Contains(string(body), `content="secret-key-abc"`) {
- t.Fatalf("expected API key in content attr, body:\n%s", body)
- }
-}
-
-func TestSPA_OmitsMetaWhenAPIKeyUnset(t *testing.T) {
- cfg := &config.Config{}
- cfg.Server.APIKey = ""
- h := spaHandler(ui.Assets, cfg)
- srv := httptest.NewServer(h)
- defer srv.Close()
- resp, err := http.Get(srv.URL + "/")
- if err != nil {
- t.Fatal(err)
- }
- defer resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
- if strings.Contains(string(body), `name="docsiq-api-key"`) {
- t.Fatalf("meta tag should not be present when APIKey empty")
- }
-}
diff --git a/internal/api/spa_test.go b/internal/api/spa_test.go
new file mode 100644
index 0000000..f5629d3
--- /dev/null
+++ b/internal/api/spa_test.go
@@ -0,0 +1,35 @@
+package api
+
+import (
+ "bytes"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "testing/fstest"
+
+ "github.com/RandomCodeSpace/docsiq/internal/config"
+)
+
+// TestSpaHandler_DoesNotInjectAPIKey serves "/" from an in-memory
+// filesystem with an API key configured and asserts that the response is
+// a 200 whose body never mentions "docsiq-api-key".
+func TestSpaHandler_DoesNotInjectAPIKey(t *testing.T) {
+ t.Parallel()
+ fsys := fstest.MapFS{
+ "index.html": &fstest.MapFile{
+ // NOTE(review): the fixture is an empty document here; presumably
+ // a minimal HTML page was intended — confirm the markup was not lost.
+ Data: []byte(``),
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Server.APIKey = "s3cret"
+
+ rr := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ spaHandler(fsys, cfg).ServeHTTP(rr, req)
+
+ // The key must never be embedded in the served page.
+ body, _ := io.ReadAll(rr.Body)
+ if bytes.Contains(body, []byte("docsiq-api-key")) {
+ t.Fatalf("served HTML still contains api-key meta tag:\n%s", body)
+ }
+ if rr.Code != http.StatusOK {
+ t.Fatalf("want 200; got %d", rr.Code)
+ }
+}
diff --git a/internal/api/upload_limit_test.go b/internal/api/upload_limit_test.go
new file mode 100644
index 0000000..188cd15
--- /dev/null
+++ b/internal/api/upload_limit_test.go
@@ -0,0 +1,128 @@
+package api
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
+// TestUploadMaxBytes verifies that requests whose body exceeds
+// cfg.Server.MaxUploadBytes are rejected with 413 before the handler
+// tries to parse the multipart form. This exercises the fast-path
+// (Content-Length declared and already over the limit).
+func TestUploadMaxBytes(t *testing.T) {
+ t.Parallel()
+ const limit int64 = 1024 // 1 KiB for the test
+
+ // Build a multipart body larger than the limit.
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+ part, err := mw.CreateFormFile("files", "big.txt")
+ if err != nil {
+ t.Fatalf("create form file: %v", err)
+ }
+ if _, err := io.Copy(part, strings.NewReader(strings.Repeat("x", int(limit)*2))); err != nil {
+ t.Fatalf("copy: %v", err)
+ }
+ _ = mw.Close()
+
+ // httptest.NewRequest derives Content-Length from the *bytes.Buffer,
+ // so the declared size is known and over the limit.
+ req := httptest.NewRequest(http.MethodPost, "/api/upload", &body)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ rr := httptest.NewRecorder()
+
+ // enforceUploadLimit is the unit-testable shim applied inside upload().
+ // With a declared Content-Length over the limit it writes the 413 JSON
+ // itself and returns false; otherwise it wraps r.Body with
+ // http.MaxBytesReader.
+ h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if !enforceUploadLimit(w, r, limit) {
+ return
+ }
+ if err := r.ParseMultipartForm(32 << 10); err != nil {
+ // Not reached in this test: the fast path rejects first. Note
+ // that http.MaxBytesReader never writes a status on overflow, so
+ // a handler that merely returns here has not sent a 413.
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+ h.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusRequestEntityTooLarge {
+ t.Fatalf("expected 413, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+ if got := rr.Header().Get("Content-Type"); got != "application/json" {
+ t.Fatalf("expected Content-Type application/json, got %q", got)
+ }
+ if !strings.Contains(rr.Body.String(), "exceeds maximum upload size") {
+ t.Fatalf("expected JSON error to mention upload size, got: %s", rr.Body.String())
+ }
+}
+
+// TestUploadMaxBytes_UnknownContentLength covers the slow-path where the
+// Content-Length is unknown (e.g. chunked transfer encoding). In that case
+// the fast-path cannot reject up-front; enforcement happens when
+// ParseMultipartForm reads through the MaxBytesReader wrapper and that
+// returns a *http.MaxBytesError.
+//
+// Note: http.MaxBytesReader never writes a response status. On a real
+// net/http server its internal requestTooLarge() hook only arranges for
+// the connection to be closed after the reply; the handler remains
+// responsible for sending the 413. This test therefore asserts that the
+// *MaxBytesError surfaces through ParseMultipartForm and that
+// writeTooLarge, given the error's limit, produces the expected 413 JSON.
+func TestUploadMaxBytes_UnknownContentLength(t *testing.T) {
+ t.Parallel()
+ const limit int64 = 1024 // 1 KiB for the test
+
+ var body bytes.Buffer
+ mw := multipart.NewWriter(&body)
+ part, err := mw.CreateFormFile("files", "big.txt")
+ if err != nil {
+ t.Fatalf("create form file: %v", err)
+ }
+ if _, err := io.Copy(part, strings.NewReader(strings.Repeat("x", int(limit)*2))); err != nil {
+ t.Fatalf("copy: %v", err)
+ }
+ _ = mw.Close()
+
+ req := httptest.NewRequest(http.MethodPost, "/api/upload", &body)
+ req.Header.Set("Content-Type", mw.FormDataContentType())
+ req.ContentLength = -1 // force slow path: unknown Content-Length
+ rr := httptest.NewRecorder()
+
+ // This inner handler performs the handling the slow path requires:
+ // match *MaxBytesError and explicitly call writeTooLarge(), the same
+ // helper used by the fast-path. Returning without writing would leave
+ // the recorder (and a real client) with an implicit 200.
+ h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if !enforceUploadLimit(w, r, limit) {
+ return
+ }
+ if err := r.ParseMultipartForm(32 << 10); err != nil {
+ var mbe *http.MaxBytesError
+ if errors.As(err, &mbe) {
+ writeTooLarge(w, mbe.Limit)
+ return
+ }
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ })
+ h.ServeHTTP(rr, req)
+
+ if rr.Code != http.StatusRequestEntityTooLarge {
+ t.Fatalf("expected 413, got %d (body: %s)", rr.Code, rr.Body.String())
+ }
+ if got := rr.Header().Get("Content-Type"); got != "application/json" {
+ t.Fatalf("expected Content-Type application/json, got %q", got)
+ }
+ if !strings.Contains(rr.Body.String(), "exceeds maximum upload size") {
+ t.Fatalf("expected JSON error to mention upload size, got: %s", rr.Body.String())
+ }
+}
diff --git a/internal/api/upload_workq_test.go b/internal/api/upload_workq_test.go
new file mode 100644
index 0000000..b55ff72
--- /dev/null
+++ b/internal/api/upload_workq_test.go
@@ -0,0 +1,79 @@
+package api
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync/atomic"
+ "testing"
+
+ "github.com/RandomCodeSpace/docsiq/internal/workq"
+)
+
+// TestUpload_ReturnsRetryOnFullQueue verifies that when the injected
+// workq Pool is saturated, the upload-submission layer returns 503 with
+// a Retry-After header and does not run the job. We exercise the HTTP
+// bridge via the Pool contract without setting up the full upload
+// handler (that's covered by other integration tests). The handler below
+// is a local stand-in; upload() itself is not invoked.
+func TestUpload_ReturnsRetryOnFullQueue(t *testing.T) {
+ t.Parallel()
+ pool := workq.New(workq.Config{Workers: 1, QueueDepth: 1})
+ // Saturate the pool: one worker busy, remaining queue slots full.
+ block := make(chan struct{})
+ started := make(chan struct{})
+ // LIFO: close(block) runs first (unblocks the stuck worker), then
+ // pool.Close drains cleanly. This order is safe even on panic.
+ defer pool.Close(context.Background()) //nolint:errcheck
+ defer close(block)
+ // Signal via started that the worker has actually begun executing
+ // (and is now blocked on <-block) before we fill the queue slot.
+ // Without this synchronisation the race detector's slower scheduling
+ // can drain the first job from the channel before Submit #2 lands,
+ // leaving a free slot for the test's submit and producing a false 202.
+ _ = pool.Submit(func(ctx context.Context) { close(started); <-block })
+ <-started // worker is blocked on <-block; channel now empty
+ // Presumably channel capacity = Workers+QueueDepth = 1+1 = 2 (workq
+ // internals are not visible from this test — TODO confirm).
+ // Fill both slots so the next submit returns ErrQueueFull. The Submit
+ // errors are ignored on purpose: these calls only set up saturation.
+ _ = pool.Submit(func(ctx context.Context) {})
+ _ = pool.Submit(func(ctx context.Context) {})
+
+ var called atomic.Bool
+ // Mimic what h.workq.Submit does in upload(): on ErrQueueFull, set
+ // Retry-After and write 503 via writeError equivalent.
+ handle := func(w http.ResponseWriter, _ *http.Request) {
+ err := pool.Submit(func(ctx context.Context) {
+ called.Store(true)
+ })
+ if err == nil {
+ w.WriteHeader(http.StatusAccepted)
+ return
+ }
+ if errors.Is(err, workq.ErrQueueFull) {
+ w.Header().Set("Retry-After", "30")
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusServiceUnavailable)
+ _, _ = w.Write([]byte(`{"error":"indexing queue full; retry later"}`))
+ return
+ }
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }
+
+ req := httptest.NewRequest(http.MethodPost, "/api/upload", nil)
+ rr := httptest.NewRecorder()
+ handle(rr, req)
+
+ if rr.Code != http.StatusServiceUnavailable {
+ t.Fatalf("want 503, got %d", rr.Code)
+ }
+ if rr.Header().Get("Retry-After") != "30" {
+ t.Fatalf("missing or wrong Retry-After: %q", rr.Header().Get("Retry-After"))
+ }
+ if got := rr.Body.String(); !strings.Contains(got, "queue full") {
+ t.Fatalf("body should mention queue full; got %s", got)
+ }
+ // The rejected submit must never have been handed to a worker.
+ if called.Load() {
+ t.Fatal("job should not have run when queue is full")
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index fbc6b45..f160f66 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -143,9 +143,12 @@ type CommunityConfig struct {
}
type ServerConfig struct {
- Host string `mapstructure:"host"`
- Port int `mapstructure:"port"`
- APIKey string `mapstructure:"api_key"`
+ Host string `mapstructure:"host"`
+ Port int `mapstructure:"port"`
+ APIKey string `mapstructure:"api_key"`
+ MaxUploadBytes int64 `mapstructure:"max_upload_bytes"` // 0 or negative disables the cap
+ WorkqWorkers int `mapstructure:"workq_workers"` // 0 → runtime.NumCPU()
+ WorkqDepth int `mapstructure:"workq_depth"` // 0 → 64
}
func Load(cfgFile string) (*Config, error) {
@@ -208,6 +211,9 @@ func Load(cfgFile string) (*Config, error) {
v.SetDefault("server.host", "127.0.0.1")
v.SetDefault("server.port", 8080)
v.SetDefault("server.api_key", "")
+ v.SetDefault("server.max_upload_bytes", int64(100*1024*1024)) // 100 MiB
+ v.SetDefault("server.workq_workers", 0) // 0 → runtime.NumCPU()
+ v.SetDefault("server.workq_depth", 64)
// Config file search paths. Only ~/.docsiq and CWD are consulted.
newCfgDir := filepath.Join(home, ".docsiq")
@@ -230,6 +236,9 @@ func Load(cfgFile string) (*Config, error) {
// either form populates server.api_key. BindEnv names are matched
// verbatim (not prefixed), so we list the full env var name.
_ = v.BindEnv("server.api_key", "DOCSIQ_SERVER_API_KEY", "DOCSIQ_API_KEY")
+ _ = v.BindEnv("server.max_upload_bytes", "DOCSIQ_SERVER_MAX_UPLOAD_BYTES")
+ _ = v.BindEnv("server.workq_workers", "DOCSIQ_SERVER_WORKQ_WORKERS")
+ _ = v.BindEnv("server.workq_depth", "DOCSIQ_SERVER_WORKQ_DEPTH")
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
diff --git a/internal/search/local.go b/internal/search/local.go
index e2aaf61..3d4314c 100644
--- a/internal/search/local.go
+++ b/internal/search/local.go
@@ -87,7 +87,14 @@ func LocalSearch(ctx context.Context, st *store.Store, emb *embedder.Embedder, i
docIDs[c.Chunk.DocID] = true
}
- entities, err := st.AllEntities(ctx)
+ // Scope entity fetch to the top-hit documents instead of a
+ // full-table scan. Entities with no relationships to any top-hit
+ // doc are out of local scope by definition.
+ docIDList := make([]string, 0, len(docIDs))
+ for id := range docIDs {
+ docIDList = append(docIDList, id)
+ }
+ entities, err := st.EntitiesForDocs(ctx, docIDList)
if err != nil {
return nil, err
}
diff --git a/internal/store/entities_for_docs_test.go b/internal/store/entities_for_docs_test.go
new file mode 100644
index 0000000..e239315
--- /dev/null
+++ b/internal/store/entities_for_docs_test.go
@@ -0,0 +1,59 @@
+package store
+
+import (
+ "context"
+ "testing"
+)
+
+// TestEntitiesForDocs_ScopesByRelationshipDocID checks that only entities
+// attached to a relationship of the requested document are returned.
+func TestEntitiesForDocs_ScopesByRelationshipDocID(t *testing.T) {
+	t.Parallel()
+	st := newTestStore(t)
+	ctx := context.Background()
+	must := func(err error) {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	// Insert documents first (FK enforcement is on).
+	must(st.UpsertDocument(ctx, &Document{ID: "docA", Path: "/a", Title: "A", DocType: "txt", FileHash: "hashA"}))
+	must(st.UpsertDocument(ctx, &Document{ID: "docB", Path: "/b", Title: "B", DocType: "txt", FileHash: "hashB"}))
+	// Three entities: r1 (docA) links e1 and e2, r2 (docB) links e3 and e1.
+	must(st.UpsertEntity(ctx, &Entity{ID: "e1", Name: "Alpha"}))
+	must(st.UpsertEntity(ctx, &Entity{ID: "e2", Name: "Beta"}))
+	must(st.UpsertEntity(ctx, &Entity{ID: "e3", Name: "Gamma"}))
+	must(st.InsertRelationship(ctx, &Relationship{ID: "r1", SourceID: "e1", TargetID: "e2", Predicate: "rel", DocID: "docA"}))
+	must(st.InsertRelationship(ctx, &Relationship{ID: "r2", SourceID: "e3", TargetID: "e1", Predicate: "rel", DocID: "docB"}))
+	got, err := st.EntitiesForDocs(ctx, []string{"docA"})
+	must(err)
+	// Check identities, not just the count: e3 only occurs in docB.
+	ids := make(map[string]bool, len(got))
+	for _, e := range got {
+		ids[e.ID] = true
+	}
+	if len(got) != 2 || !ids["e1"] || !ids["e2"] {
+		t.Fatalf("docA: want exactly e1 and e2; got %d entities %v", len(got), ids)
+	}
+	// Empty input → empty slice, no error.
+	empty, err := st.EntitiesForDocs(ctx, nil)
+	if err != nil || len(empty) != 0 {
+		t.Fatalf("empty input: want (0, nil); got (%d, %v)", len(empty), err)
+	}
+}
+
+// TestEntitiesForDocs_HandlesLargeIDSets passes more IDs than SQLite's
+// default bind-variable limit and expects the call to succeed.
+func TestEntitiesForDocs_HandlesLargeIDSets(t *testing.T) {
+	t.Parallel()
+	st := newTestStore(t)
+	const n = 1500 // > SQLite's 999 default
+	docIDs := make([]string, 0, n)
+	for len(docIDs) < n {
+		docIDs = append(docIDs, "doc-xyz")
+	}
+	if _, err := st.EntitiesForDocs(context.Background(), docIDs); err != nil {
+		t.Fatalf("chunking at >999 should not error: %v", err)
+	}
+}
diff --git a/internal/store/store.go b/internal/store/store.go
index 6b201fb..ea7f9e8 100644
--- a/internal/store/store.go
+++ b/internal/store/store.go
@@ -604,6 +604,62 @@ func (s *Store) AllEntities(ctx context.Context) ([]*Entity, error) {
return entities, rows.Err()
}
+// EntitiesForDocs returns the entities that are the source or target of at
+// least one relationship recorded for any of the given documents: the
+// "local" entity set for a scoped search, without the full-table scan that
+// AllEntities performs. Empty docIDs yields (nil, nil). IDs are queried in
+// chunks of 900 (below SQLite's default 999 variable limit); an entity
+// matched by several chunks is returned once.
+func (s *Store) EntitiesForDocs(ctx context.Context, docIDs []string) ([]*Entity, error) {
+	if len(docIDs) == 0 {
+		return nil, nil
+	}
+	const chunkSize = 900
+	seen := make(map[string]struct{}, 128)
+	out := make([]*Entity, 0, 128)
+	for start := 0; start < len(docIDs); start += chunkSize {
+		chunk := docIDs[start:min(start+chunkSize, len(docIDs))]
+		found, err := s.entitiesForDocChunk(ctx, chunk)
+		if err != nil {
+			return nil, err
+		}
+		for _, e := range found {
+			if _, dup := seen[e.ID]; !dup {
+				seen[e.ID] = struct{}{}
+				out = append(out, e)
+			}
+		}
+	}
+	return out, nil
+}
+
+// entitiesForDocChunk runs the relationship join for one chunk of document IDs.
+func (s *Store) entitiesForDocChunk(ctx context.Context, chunk []string) ([]*Entity, error) {
+	args := make([]any, 0, len(chunk))
+	for _, id := range chunk {
+		args = append(args, id)
+	}
+	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(chunk)), ",")
+	q := `SELECT DISTINCT e.id, e.name, e.type, e.description, e.rank, e.community_id, e.vector
+ FROM entities e
+ JOIN relationships r ON (r.source_id = e.id OR r.target_id = e.id)
+ WHERE r.doc_id IN (` + placeholders + `)`
+	rows, err := s.db.QueryContext(ctx, q, args...)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var found []*Entity
+	for rows.Next() {
+		e, err := scanEntityRow(rows)
+		if err != nil {
+			return nil, err
+		}
+		found = append(found, e)
+	}
+	return found, rows.Err()
+}
+
func (s *Store) UpdateEntityCommunity(ctx context.Context, entityID, communityID string) error {
_, err := s.db.ExecContext(ctx, `UPDATE entities SET community_id=? WHERE id=?`, communityID, entityID)
return err
diff --git a/internal/workq/workq.go b/internal/workq/workq.go
new file mode 100644
index 0000000..0204ae6
--- /dev/null
+++ b/internal/workq/workq.go
@@ -0,0 +1,126 @@
+// Package workq is a minimal bounded worker pool for fire-and-forget
+// background work (e.g. post-upload indexing). Jobs receive the pool's root
+// context; Close() stops intake and waits for workers to drain, cancelling
+// that context if the caller's deadline expires first.
+package workq
+
+import (
+ "context"
+ "errors"
+ "sync"
+)
+
+// ErrQueueFull is returned by Submit when the job queue has no free slot.
+// Callers should surface this as 503 Service Unavailable with Retry-After.
+var ErrQueueFull = errors.New("workq: queue full")
+
+// ErrClosed is returned by Submit once Close has been called on the pool.
+var ErrClosed = errors.New("workq: closed")
+
+// Job is a unit of work. ctx is the pool's root context, which Close cancels
+// when its deadline expires before the drain finishes, so jobs should watch it.
+type Job func(ctx context.Context)
+
+// Config sizes the pool. Values below 1 select the defaults (1 worker,
+// 16-deep queue). The job buffer holds Workers + QueueDepth pending jobs.
+type Config struct {
+ Workers int // worker goroutines started by New
+ QueueDepth int // buffer slots added on top of one slot per worker
+}
+
+// Pool is a fixed-size worker pool with a bounded submission queue.
+type Pool struct {
+ jobs chan Job // buffered queue: Submit sends, workers range over it, Close closes it
+ ctx context.Context // root context handed to every job
+ cancel context.CancelFunc // cancels ctx; Close calls it when its deadline expires
+ wg sync.WaitGroup // one count per worker goroutine; Close waits on it
+
+ // mu guards close(p.jobs) vs concurrent sends in Submit. RLock
+ // on the send path lets many Submits proceed in parallel; Close
+ // takes the write lock before closing the channel.
+ mu sync.RWMutex
+ closeOnce sync.Once // makes the channel closes in Close run at most once
+ closed chan struct{} // closed by Close; Submit checks it to return ErrClosed
+}
+
+// New builds a Pool from cfg and starts its worker goroutines immediately.
+// A Workers value below 1 becomes 1; a QueueDepth below 1 becomes 16. Call
+// Close to stop the workers.
+func New(cfg Config) *Pool {
+	workers, depth := max(cfg.Workers, 1), cfg.QueueDepth
+	if depth < 1 {
+		depth = 16
+	}
+	root, stop := context.WithCancel(context.Background())
+	pool := &Pool{
+		// Buffer Workers + QueueDepth jobs, so Submit succeeds whenever at
+		// least one worker is idle OR there is a free queue slot.
+		jobs:   make(chan Job, workers+depth),
+		ctx:    root,
+		cancel: stop,
+		closed: make(chan struct{}),
+	}
+	pool.wg.Add(workers) // registered up front so Close can wait for every worker
+	for range workers {
+		go pool.run()
+	}
+	return pool
+}
+
+// Submit offers job to the queue without blocking. It returns ErrClosed once
+// Close has begun and ErrQueueFull when the buffer has no free slot.
+func (p *Pool) Submit(job Job) error {
+	// Hold the read lock across the send: Close takes the write lock before
+	// close(p.jobs), so the send below can never hit a closed channel.
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+	select {
+	case <-p.closed:
+		return ErrClosed
+	default:
+	}
+	select {
+	case p.jobs <- job:
+		return nil
+	default:
+	}
+	return ErrQueueFull
+}
+
+// Close stops accepting new work and waits for queued and running jobs to
+// finish. If ctx fires first, ctx.Err() is returned and running jobs see the
+// pool context cancelled. The pool context is cancelled after a clean drain
+// too, so it is always released. Safe to call more than once.
+func (p *Pool) Close(ctx context.Context) error {
+	p.closeOnce.Do(func() {
+		p.mu.Lock() // excludes Submit's send while the channels are closed
+		close(p.closed)
+		close(p.jobs)
+		p.mu.Unlock()
+	})
+	done := make(chan struct{})
+	go func() {
+		p.wg.Wait()
+		close(done)
+	}()
+	defer p.cancel() // timeout: aborts running jobs; clean drain: releases the context
+	select {
+	case <-done:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// run is the worker loop: it executes jobs until p.jobs is closed and drained.
+func (p *Pool) run() {
+	defer p.wg.Done()
+	// Trap panics per job so one bad job cannot kill a worker; the recovered value is dropped.
+	guarded := func(job Job) {
+		defer func() { _ = recover() }()
+		job(p.ctx)
+	}
+	for job := range p.jobs {
+		guarded(job)
+	}
+}
diff --git a/internal/workq/workq_test.go b/internal/workq/workq_test.go
new file mode 100644
index 0000000..84c4415
--- /dev/null
+++ b/internal/workq/workq_test.go
@@ -0,0 +1,107 @@
+package workq
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// TestPool_SubmitRunsJob checks that an accepted job runs exactly once.
+func TestPool_SubmitRunsJob(t *testing.T) {
+	t.Parallel()
+	pool := New(Config{Workers: 2, QueueDepth: 4})
+	defer pool.Close(context.Background())
+
+	var runs atomic.Int32
+	finished := make(chan struct{})
+	job := func(context.Context) { runs.Add(1); close(finished) }
+	if err := pool.Submit(job); err != nil {
+		t.Fatalf("submit: %v", err)
+	}
+
+	select {
+	case <-finished:
+	case <-time.After(time.Second):
+		t.Fatal("job did not run within 1s")
+	}
+	if got := runs.Load(); got != 1 {
+		t.Fatalf("want ran=1, got %d", got)
+	}
+}
+
+// TestPool_SubmitReturnsErrQueueFull parks the only worker, fills the
+// Workers+QueueDepth buffer, and expects the next Submit to fail fast.
+func TestPool_SubmitReturnsErrQueueFull(t *testing.T) {
+	t.Parallel()
+	p := New(Config{Workers: 1, QueueDepth: 1})
+	defer p.Close(context.Background())
+	block, started := make(chan struct{}), make(chan struct{})
+	defer close(block) // runs before Close, so a failed assertion cannot hang the drain
+	_ = p.Submit(func(ctx context.Context) { close(started); <-block })
+	<-started // the worker has dequeued its job, so the buffer is empty again
+	for range 2 { // buffer capacity = Workers + QueueDepth = 2
+		if err := p.Submit(func(ctx context.Context) {}); err != nil {
+			t.Fatalf("queue slot submit: %v", err)
+		}
+	}
+	if err := p.Submit(func(ctx context.Context) {}); !errors.Is(err, ErrQueueFull) {
+		t.Fatalf("want ErrQueueFull, got %v", err)
+	}
+}
+
+// TestPool_CloseDrainsInflight checks that Close waits for every accepted job.
+func TestPool_CloseDrainsInflight(t *testing.T) {
+	t.Parallel()
+	pool := New(Config{Workers: 2, QueueDepth: 4})
+
+	var completed atomic.Int32
+	slowJob := func(context.Context) { time.Sleep(20 * time.Millisecond); completed.Add(1) }
+	for i := 0; i < 4; i++ {
+		_ = pool.Submit(slowJob)
+	}
+	drainCtx, stop := context.WithTimeout(context.Background(), time.Second)
+	defer stop()
+	if err := pool.Close(drainCtx); err != nil {
+		t.Fatalf("close: %v", err)
+	}
+	if got := completed.Load(); got != 4 {
+		t.Fatalf("want ran=4 after drain, got %d", got)
+	}
+}
+
+// TestPool_CloseCancelsOnContextDeadline checks that Close returns the deadline error while a job still runs.
+func TestPool_CloseCancelsOnContextDeadline(t *testing.T) {
+	t.Parallel()
+	pool := New(Config{Workers: 1, QueueDepth: 1})
+	running := make(chan struct{})
+	_ = pool.Submit(func(jobCtx context.Context) {
+		close(running)
+		<-jobCtx.Done() // honour cancellation
+	})
+	<-running
+	deadline, stop := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer stop()
+	if err := pool.Close(deadline); !errors.Is(err, context.DeadlineExceeded) {
+		t.Fatalf("want DeadlineExceeded, got %v", err)
+	}
+}
+
+// TestPool_SubmitRaceDuringClose runs Submit from many goroutines while Close
+// executes. It asserts nothing; it exists to surface panics and data races.
+func TestPool_SubmitRaceDuringClose(t *testing.T) {
+	t.Parallel()
+	for round := 0; round < 50; round++ {
+		pool := New(Config{Workers: 4, QueueDepth: 8})
+		var submitters sync.WaitGroup
+		submitters.Add(32)
+		for i := 0; i < 32; i++ {
+			go func() { defer submitters.Done(); _ = pool.Submit(func(context.Context) {}) }()
+		}
+		// Close is not synchronised with the submitters on purpose.
+		_ = pool.Close(context.Background())
+		submitters.Wait()
+	}
+}
diff --git a/ui/src/hooks/api/useMCP.ts b/ui/src/hooks/api/useMCP.ts
index 6482973..c35f37e 100644
--- a/ui/src/hooks/api/useMCP.ts
+++ b/ui/src/hooks/api/useMCP.ts
@@ -20,12 +20,6 @@ export interface MCPTool {
};
}
-function getBearer(): string | null {
- if (typeof document === "undefined") return null;
- const v = document.querySelector('meta[name="docsiq-api-key"]')?.getAttribute("content");
- return v && v.length ? v : null;
-}
-
async function rpc(
sessionId: string | null,
body: unknown,
@@ -34,11 +28,14 @@ async function rpc(
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
};
- const bearer = getBearer();
- if (bearer) headers["Authorization"] = `Bearer ${bearer}`;
if (sessionId) headers["Mcp-Session-Id"] = sessionId;
- const res = await fetch("/mcp", { method: "POST", headers, body: JSON.stringify(body) });
+ const res = await fetch("/mcp", {
+ method: "POST",
+ credentials: "include",
+ headers,
+ body: JSON.stringify(body),
+ });
const newSession = res.headers.get("Mcp-Session-Id") ?? sessionId;
const text = await res.text();
diff --git a/ui/src/lib/__tests__/api-client.test.ts b/ui/src/lib/__tests__/api-client.test.ts
index c047495..6cc24c4 100644
--- a/ui/src/lib/__tests__/api-client.test.ts
+++ b/ui/src/lib/__tests__/api-client.test.ts
@@ -1,7 +1,7 @@
-import { describe, it, expect } from "vitest";
+import { describe, it, expect, vi } from "vitest";
import { http, HttpResponse } from "msw";
import { server } from "@/test/msw";
-import { apiFetch, ApiErrorResponse } from "../api-client";
+import { apiFetch, ApiErrorResponse, initAuth } from "../api-client";
describe("apiFetch", () => {
it("returns parsed json on 200", async () => {
@@ -32,4 +32,39 @@ describe("apiFetch", () => {
const r = await apiFetch("/api/x", { method: "DELETE" });
expect(r).toBeUndefined();
});
+
+ it("sends credentials: 'include' on every fetch", async () => {
+   const ok = new Response("{}", { status: 200, headers: { "content-type": "application/json" } });
+   const spy = vi.spyOn(globalThis, "fetch").mockResolvedValue(ok);
+   // Restore the real fetch even when an assertion throws, so the mock cannot leak into later tests.
+   try {
+     await apiFetch("/api/stats");
+     expect(((spy.mock.calls[0][1] ?? {}) as RequestInit).credentials).toBe("include");
+   } finally { spy.mockRestore(); }
+ });
+
+ it("does not set Authorization header on data-path fetch even when a key exists in a meta tag", async () => {
+   const tag = document.createElement("meta");
+   tag.setAttribute("name", "docsiq-api-key");
+   tag.setAttribute("content", "s3cret");
+   document.head.appendChild(tag);
+   // Mock fetch BEFORE initAuth(): the /api/session exchange it starts must
+   // land on the mock, because MSW has no handler for that route and would
+   // throw on the unhandled request.
+   const fetchSpy = vi.spyOn(globalThis, "fetch").mockResolvedValue(
+     new Response("{}", { status: 200, headers: { "content-type": "application/json" } }),
+   );
+   try {
+     initAuth();
+     await apiFetch("/api/stats");
+     const statsCall = fetchSpy.mock.calls.find(([target]) => target === "/api/stats");
+     expect(statsCall).toBeDefined();
+     const sentInit = (statsCall![1] ?? {}) as RequestInit;
+     const sentHeaders = new Headers(sentInit.headers);
+     expect(sentHeaders.has("Authorization")).toBe(false);
+   } finally {
+     fetchSpy.mockRestore();
+     if (tag.parentElement) document.head.removeChild(tag);
+   }
+ });
});
diff --git a/ui/src/lib/api-client.ts b/ui/src/lib/api-client.ts
index b462981..e3e551e 100644
--- a/ui/src/lib/api-client.ts
+++ b/ui/src/lib/api-client.ts
@@ -1,16 +1,38 @@
import type { ApiError } from "@/types/api";
-let bearer: string | null = null;
+// Before cookies are set the first time, the UI may have been shipped a
+// one-shot bearer via the meta tag (legacy). We exchange it for a cookie
+// exactly once, then never read or send the key again. If no meta tag
+// exists (production path), we rely entirely on cookies already set by
+// the operator's OOB provisioning (e.g. `docsiq login`). Null until initAuth() runs.
+let sessionReady: Promise<void> | null = null;
-function readBearerFromMeta(): string | null {
+function readOneShotBearerFromMeta(): string | null {
if (typeof document === "undefined") return null;
const m = document.querySelector('meta[name="docsiq-api-key"]');
const v = m?.getAttribute("content");
return v && v.length > 0 ? v : null;
}
-export function initAuth() {
- bearer = readBearerFromMeta();
+// Remove the legacy meta tag, then trade its bearer for a cookie via POST /api/session.
+async function establishSession(bearer: string): Promise<void> {
+  const m = document.querySelector('meta[name="docsiq-api-key"]');
+  m?.parentElement?.removeChild(m);
+  const res = await fetch("/api/session", {
+    method: "POST",
+    credentials: "include",
+    headers: { Authorization: `Bearer ${bearer}` },
+  });
+  if (!res.ok) {
+    let body: ApiError = { error: `HTTP ${res.status}` };
+    try { body = await res.json(); } catch { /* non-json */ }
+    throw new ApiErrorResponse(res.status, body);
+  }
+}
+
+export function initAuth(): void {
+  const oneShot = readOneShotBearerFromMeta();
+  sessionReady = oneShot === null ? Promise.resolve() : establishSession(oneShot); // no tag: ready at once
+}
export class ApiErrorResponse extends Error {
@@ -27,15 +49,19 @@ export async function apiFetch(
path: string,
init: RequestInit = {},
): Promise {
+ if (sessionReady) await sessionReady;
const headers = new Headers(init.headers);
- if (bearer) headers.set("Authorization", `Bearer ${bearer}`);
if (init.body && !headers.has("Content-Type")) {
headers.set("Content-Type", "application/json");
}
- const res = await fetch(path, { ...init, headers });
+ const res = await fetch(path, { ...init, headers, credentials: "include" });
if (!res.ok) {
let body: ApiError = { error: `HTTP ${res.status}` };
- try { body = await res.json(); } catch { /* non-json */ }
+ try {
+ body = await res.json();
+ } catch {
+ /* non-json */
+ }
throw new ApiErrorResponse(res.status, body);
}
if (res.status === 204) return undefined as T;