diff --git a/cmd/serve.go b/cmd/serve.go index ed23057..e9755ad 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -9,15 +9,19 @@ import ( "os" "os/signal" "path/filepath" + "runtime" + "strings" "syscall" "time" "github.com/RandomCodeSpace/docsiq/internal/api" + "github.com/RandomCodeSpace/docsiq/internal/config" "github.com/RandomCodeSpace/docsiq/internal/embedder" "github.com/RandomCodeSpace/docsiq/internal/llm" "github.com/RandomCodeSpace/docsiq/internal/project" "github.com/RandomCodeSpace/docsiq/internal/sqlitevec" "github.com/RandomCodeSpace/docsiq/internal/vectorindex" + "github.com/RandomCodeSpace/docsiq/internal/workq" "github.com/spf13/cobra" ) @@ -140,11 +144,25 @@ var serveCmd = &cobra.Command{ } } + workers := cfg.Server.WorkqWorkers + if workers <= 0 { + workers = runtime.NumCPU() + } + depth := cfg.Server.WorkqDepth + if depth <= 0 { + depth = 64 + } + pool := workq.New(workq.Config{Workers: workers, QueueDepth: depth}) + router := api.NewRouter(prov, emb, cfg, registry, api.WithProjectStores(stores), api.WithVectorIndexes(vecIndexes), + api.WithWorkq(pool), ) + if err := validateServeSecurity(cfg); err != nil { + return err + } addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) ln, err := net.Listen("tcp", addr) if err != nil { @@ -177,6 +195,17 @@ var serveCmd = &cobra.Command{ slog.Error("❌ shutdown error", "err", err) return err } + + // Drain workq within its own 30s deadline. Server.Shutdown has already + // stopped accepting new HTTP requests, so no new jobs can be submitted; + // all that remains is letting in-flight pipelines finish or honour the + // cancelled ctx. + drainCtx, drainCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer drainCancel() + if err := pool.Close(drainCtx); err != nil { + slog.Warn("⚠️ workq drain timeout; some indexing jobs were cancelled mid-flight", "err", err) + } + slog.Info("✅ shutdown complete") return nil }, @@ -187,3 +216,33 @@ func init() { serveCmd.Flags().StringVar(&serveHost, "host", "", "Server host (overrides config)") serveCmd.Flags().IntVar(&servePort, "port", 0, "Server port (overrides config)") } + +// validateServeSecurity refuses to start the server when the API key is +// empty AND the bind host is not loopback. An unauthenticated service +// exposed on the network is almost never intentional; make it explicit. +// Loopback with empty key gets a prominent warning at boot instead. +func validateServeSecurity(cfg *config.Config) error { + if cfg.Server.APIKey != "" { + return nil + } + host := strings.ToLower(strings.TrimSpace(cfg.Server.Host)) + if host == "" { + return fmt.Errorf( + "server.api_key is empty and server.host is unset (binds all interfaces); refusing to start. " + + "Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev", + ) + } + loopback := host == "localhost" + if ip := net.ParseIP(strings.Trim(host, "[]")); ip != nil { + loopback = loopback || ip.IsLoopback() + } + if !loopback { + return fmt.Errorf( + "server.api_key is empty and server.host=%q is not loopback; refusing to start. "+ + "Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev", + cfg.Server.Host, + ) + } + slog.Warn("⚠️ auth disabled (empty server.api_key); only loopback bind allowed", "host", host) + return nil +} diff --git a/cmd/serve_test.go b/cmd/serve_test.go new file mode 100644 index 0000000..7480f90 --- /dev/null +++ b/cmd/serve_test.go @@ -0,0 +1,87 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +func TestValidateServeSecurity_RefusesNonLoopbackWithEmptyKey(t *testing.T) { + t.Parallel() + cases := []struct { + name string + host string + mustContain []string + }{ + { + name: "empty host binds all interfaces", + host: "", + mustContain: []string{"binds all interfaces", "DOCSIQ_SERVER_API_KEY"}, + }, + { + name: "0.0.0.0 is wildcard", + host: "0.0.0.0", + mustContain: []string{"api_key", "DOCSIQ_SERVER_API_KEY"}, + }, + { + name: "public IPv4", + host: "10.0.0.5", + mustContain: []string{"api_key", "DOCSIQ_SERVER_API_KEY"}, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.Host = tc.host + cfg.Server.Port = 8080 + cfg.Server.APIKey = "" + + err := validateServeSecurity(cfg) + if err == nil { + t.Fatalf("host=%q: expected error for empty api_key on non-loopback bind; got nil", tc.host) + } + for _, want := range tc.mustContain { + if !strings.Contains(err.Error(), want) { + t.Errorf("host=%q: error should contain %q; got %v", tc.host, want, err) + } + } + }) + } +} + +func TestValidateServeSecurity_AllowsLoopbackWithEmptyKey(t *testing.T) { + t.Parallel() + hosts := []string{ + "127.0.0.1", + "localhost", + "::1", + "[::1]", // bracketed IPv6 + "127.0.0.2", // other address in 127.0.0.0/8 + "::ffff:127.0.0.1", // IPv6-mapped IPv4 + } + for _, host := range hosts { + host := host + t.Run(host, func(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.Host = host + cfg.Server.APIKey = "" + if err := validateServeSecurity(cfg); err != nil { + t.Fatalf("host=%s: expected nil; got %v", host, err) + } + }) + } +} + +func TestValidateServeSecurity_AllowsNonLoopbackWithKey(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.Host = "0.0.0.0" + cfg.Server.APIKey = "s3cret" + if err := validateServeSecurity(cfg); err != nil { + t.Fatalf("expected nil; got %v", err) + } +} diff --git a/internal/api/auth.go b/internal/api/auth.go index c1e4a0f..185422f 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -50,17 +50,31 @@ func bearerAuthMiddleware(apiKey string, next http.Handler) http.Handler { return } - raw := strings.TrimSpace(r.Header.Get("Authorization")) - const prefix = "Bearer " - if !strings.HasPrefix(raw, prefix) { + // /api/session is the auth boundary itself — always public. + if path == "/api/session" { + next.ServeHTTP(w, r) + return + } + + // Defense-in-depth: reject immediately if the server has no key + // configured. This mirrors newSessionHandler's guard and keeps the + // middleware correct under future refactors (rather than relying on + // the no_token branch firing because keyBytes would also be empty). + if apiKey == "" { + slog.Warn("🔒 auth failure", "path", path, "remote_addr", r.RemoteAddr, "reason", "server_misconfigured") + writeJSON401(w) + return + } + + token := extractToken(r) + if token == "" { slog.Warn("🔒 auth failure", "path", path, "remote_addr", r.RemoteAddr, - "reason", "no_bearer_prefix") + "reason", "no_token") writeJSON401(w) return } - token := raw[len(prefix):] if subtle.ConstantTimeCompare([]byte(token), keyBytes) != 1 { slog.Warn("🔒 auth failure", "path", path, @@ -82,3 +96,20 @@ func writeJSON401(w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) _ = json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"}) } + +// extractToken returns the bearer token from either the Authorization +// header (preferred, for machine clients) or the session cookie (for +// browser clients after POST /api/session). Returns "" if neither. +func extractToken(r *http.Request) string { + raw := strings.TrimSpace(r.Header.Get("Authorization")) + const prefix = "Bearer " + if strings.HasPrefix(raw, prefix) { + return raw[len(prefix):] + } + if c, err := r.Cookie(sessionCookieName); err == nil { + if v := strings.TrimSpace(c.Value); v != "" { + return v + } + } + return "" +} diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index d447b42..6d4804f 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -515,6 +515,29 @@ func TestBearerAuthMiddleware(t *testing.T) { }) } +func TestAuth_AcceptsValidCookie(t *testing.T) { + t.Parallel() + h := buildAuthHandler("s3cret") + req := httptest.NewRequest(http.MethodGet, "/api/ping", nil) + req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "s3cret"}) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("want 200 with valid cookie, got %d", rr.Code) + } +} + +func TestAuth_RejectsMissingBothHeaderAndCookie(t *testing.T) { + t.Parallel() + h := buildAuthHandler("s3cret") + req := httptest.NewRequest(http.MethodGet, "/api/ping", nil) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusUnauthorized { + t.Fatalf("want 401, got %d", rr.Code) + } +} + // BenchmarkBearerAuth_WrongKey ensures the constant-time compare path runs // under the benchmark harness. It is NOT a hard timing assertion — if the // code ever switches to a non-constant-time compare (==, bytes.Equal), the diff --git a/internal/api/handlers.go b/internal/api/handlers.go index aa11bf9..1edcd83 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -23,6 +24,7 @@ import ( "github.com/RandomCodeSpace/docsiq/internal/search" "github.com/RandomCodeSpace/docsiq/internal/store" "github.com/RandomCodeSpace/docsiq/internal/vectorindex" + "github.com/RandomCodeSpace/docsiq/internal/workq" ) // handlers is the REST-side doc router state. Wave-2 drop: the @@ -38,6 +40,9 @@ type handlers struct { // return nil for a slug with no embeddings; LocalSearch falls back // to brute-force in that case. vecIndexes *VectorIndexes + // workq is the bounded worker pool for upload indexing jobs. When + // nil (dev/test path), upload() falls back to a detached goroutine. + workq *workq.Pool // Upload progress tracking uploadMu sync.Mutex @@ -403,8 +408,20 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) { return } slug := ProjectFromContext(r.Context()) - // TODO(docsiq): P2-1 wrap r.Body with http.MaxBytesReader before ParseMultipartForm - if err := r.ParseMultipartForm(128 << 20); err != nil { + if !enforceUploadLimit(w, r, h.cfg.Server.MaxUploadBytes) { + return + } + if err := r.ParseMultipartForm(32 << 20); err != nil { + // MaxBytesReader translates overflow into an error here; the + // response header is already 413 when that happens. For other + // malformed-form errors we emit a 400. + var mbe *http.MaxBytesError + if errors.As(err, &mbe) { + // http.MaxBytesReader has already called w.WriteHeader(413) + // internally; calling it again would emit "http: superfluous + // response.WriteHeader call". Just return. + return + } writeError(w, r, 400, "parse form: "+err.Error(), nil) return } @@ -484,25 +501,25 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) { h.setProgress(jobID, fmt.Sprintf("queued: %d files", len(paths))) - // Use a detached context so the background goroutine is not cancelled - // when the HTTP response is sent. - bgCtx := context.Background() - tmpDirOwned = true - - go func() { + job := func(ctx context.Context) { defer os.RemoveAll(tmpDir) pl := pipeline.New(st, h.provider, h.cfg) for _, p := range paths { + if ctx.Err() != nil { + slog.Warn("🛑 upload indexing cancelled on shutdown", "job_id", jobID, "file", filepath.Base(p)) + h.setProgress(jobID, "cancelled") + return + } slog.Info("📦 upload indexing file", "job_id", jobID, "file", filepath.Base(p)) h.setProgress(jobID, fmt.Sprintf("indexing: %s", filepath.Base(p))) - if err := pl.IndexPath(bgCtx, p, pipeline.IndexOptions{}); err != nil { + if err := pl.IndexPath(ctx, p, pipeline.IndexOptions{}); err != nil { slog.Error("❌ upload indexing failed", "job_id", jobID, "file", filepath.Base(p), "err", err) h.setProgress(jobID, fmt.Sprintf("error: %v", err)) return } } h.setProgress(jobID, "finalizing") - if err := pl.Finalize(bgCtx, false, true); err != nil { + if err := pl.Finalize(ctx, false, true); err != nil { slog.Warn("⚠️ upload finalization failed", "job_id", jobID, "err", err) } // Invalidate the vector index for this project so the next @@ -512,9 +529,27 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) { } slog.Info("✅ upload job complete", "job_id", jobID, "files", len(paths), "project", slug) h.setProgress(jobID, "done") - }() + } + + if h.workq == nil { + tmpDirOwned = true + go job(context.Background()) // dev/test fallback + } else { + if err := h.workq.Submit(job); err != nil { + if errors.Is(err, workq.ErrQueueFull) { + h.setProgress(jobID, "rejected: indexing queue full") + w.Header().Set("Retry-After", "30") + writeError(w, r, http.StatusServiceUnavailable, "indexing queue full; retry later", nil) + return + } + h.setProgress(jobID, "rejected: server unavailable") + writeError(w, r, http.StatusServiceUnavailable, "server shutting down", err) + return + } + tmpDirOwned = true + } - writeJSON(w, 202, map[string]string{"job_id": jobID, "status": "queued"}) + writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID, "status": "accepted"}) } func (h *handlers) setProgress(jobID, msg string) { @@ -606,3 +641,31 @@ func intQuery(s string, def int) int { } return n } + +// writeTooLarge emits a 413 JSON error describing the configured limit. +// Callers must ensure w.WriteHeader has not already been committed. +func writeTooLarge(w http.ResponseWriter, limit int64) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusRequestEntityTooLarge) + _, _ = fmt.Fprintf(w, `{"error":"request body exceeds maximum upload size of %d bytes"}`, limit) +} + +// enforceUploadLimit checks Content-Length against limit and, if the +// declared size is within bounds, wraps r.Body with http.MaxBytesReader +// so that any overflow during parsing is caught. Returns false and writes +// a 413 JSON response when the request is known to exceed the limit; +// the caller must return immediately in that case. +func enforceUploadLimit(w http.ResponseWriter, r *http.Request, limit int64) bool { + if limit <= 0 { + return true // unlimited (opt-in via 0 or negative) + } + // Fast path: Content-Length is declared and already exceeds the limit. + if r.ContentLength > limit { + slog.Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit) + writeTooLarge(w, limit) + return false + } + // Slow path: wrap body so overflow is caught during parsing. + r.Body = http.MaxBytesReader(w, r.Body, limit) + return true +} diff --git a/internal/api/router.go b/internal/api/router.go index b77802f..6f18993 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -1,9 +1,7 @@ package api import ( - "bytes" "context" - "html" "io/fs" "log/slog" "net/http" @@ -16,6 +14,7 @@ import ( "github.com/RandomCodeSpace/docsiq/internal/llm" "github.com/RandomCodeSpace/docsiq/internal/mcp" "github.com/RandomCodeSpace/docsiq/internal/project" + "github.com/RandomCodeSpace/docsiq/internal/workq" "github.com/RandomCodeSpace/docsiq/ui" ) @@ -26,6 +25,7 @@ type RouterOption func(*routerOptions) type routerOptions struct { vecIndexes *VectorIndexes stores *projectStores + workq *workq.Pool } // WithVectorIndexes wires a per-project HNSW index cache into the @@ -35,6 +35,13 @@ func WithVectorIndexes(vi *VectorIndexes) RouterOption { return func(o *routerOptions) { o.vecIndexes = vi } } +// WithWorkq injects a bounded worker pool for background indexing jobs. +// When nil (default), upload() falls back to a detached goroutine — the +// dev/test path. +func WithWorkq(p *workq.Pool) RouterOption { + return func(o *routerOptions) { o.workq = p } +} + // WithProjectStores lets callers inject a pre-built ProjectStores // cache so they can close it at shutdown. Nil (default) causes // NewRouter to allocate its own — fine for tests, but real servers @@ -70,6 +77,7 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re embedder: emb, cfg: cfg, vecIndexes: ro.vecIndexes, + workq: ro.workq, } nh := newNotesHandlersWithStores(stores, cfg, registry) ph := &projectsHandler{registry: registry} @@ -105,6 +113,12 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re }) } + // Session exchange — public (is the auth boundary). + // POST exchanges a bearer key for a docsiq_session httpOnly cookie. + // DELETE clears the cookie (logout). + mux.HandleFunc("POST /api/session", newSessionHandler(cfg.Server.APIKey)) + mux.HandleFunc("DELETE /api/session", newSessionDeleteHandler()) + // REST API — docs pipeline (Phase-0) mux.HandleFunc("GET /api/stats", h.getStats) mux.HandleFunc("GET /api/documents", h.listDocuments) @@ -156,7 +170,7 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re projectMiddleware(cfg, registry, mux)))) } -func spaHandler(assets fs.FS, cfg *config.Config) http.Handler { +func spaHandler(assets fs.FS, _ *config.Config) http.Handler { fileServer := http.FileServer(http.FS(assets)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -181,14 +195,6 @@ func spaHandler(assets fs.FS, cfg *config.Config) http.Handler { return } - if cfg.Server.APIKey != "" { - content = bytes.Replace( - content, - []byte(""), - []byte(``), - 1, - ) - } w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) _, _ = w.Write(content) diff --git a/internal/api/session.go b/internal/api/session.go new file mode 100644 index 0000000..9a0806a --- /dev/null +++ b/internal/api/session.go @@ -0,0 +1,72 @@ +package api + +import ( + "crypto/subtle" + "log/slog" + "net/http" + "strings" +) + +// sessionCookieName is the name of the httpOnly cookie that carries the +// bearer token after a successful POST /api/session exchange. The value +// is identical to cfg.Server.APIKey — we do not (yet) rotate or sign it; +// the cookie is a transport-hardening layer, not a session store. +const sessionCookieName = "docsiq_session" + +// newSessionHandler returns the POST /api/session handler. Accepts an +// Authorization: Bearer header and on match sets the session +// cookie. 401 on any other shape. +func newSessionHandler(apiKey string) http.HandlerFunc { + keyBytes := []byte(apiKey) + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Allow", "POST, DELETE") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + raw := strings.TrimSpace(r.Header.Get("Authorization")) + const prefix = "Bearer " + if !strings.HasPrefix(raw, prefix) { + writeJSON401(w) + return + } + token := raw[len(prefix):] + if apiKey == "" || subtle.ConstantTimeCompare([]byte(token), keyBytes) != 1 { + slog.Warn("🔒 session: auth failure", "remote_addr", r.RemoteAddr, "reason", "wrong_key") + writeJSON401(w) + return + } + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: apiKey, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + MaxAge: 86400 * 30, // 30 days + }) + w.WriteHeader(http.StatusNoContent) + } +} + +// newSessionDeleteHandler returns the DELETE /api/session handler, +// which clears the session cookie (client-initiated logout). +func newSessionDeleteHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + w.Header().Set("Allow", "POST, DELETE") + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: "", + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + MaxAge: -1, + }) + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/internal/api/session_test.go b/internal/api/session_test.go new file mode 100644 index 0000000..1f9495a --- /dev/null +++ b/internal/api/session_test.go @@ -0,0 +1,62 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestSession_PostExchangesBearerForCookie(t *testing.T) { + t.Parallel() + h := newSessionHandler("s3cret") + req := httptest.NewRequest(http.MethodPost, "/api/session", nil) + req.Header.Set("Authorization", "Bearer s3cret") + rr := httptest.NewRecorder() + h(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("want 204, got %d", rr.Code) + } + setCookie := rr.Header().Get("Set-Cookie") + if !strings.Contains(setCookie, sessionCookieName+"=") { + t.Fatalf("missing session cookie: %q", setCookie) + } + for _, attr := range []string{"HttpOnly", "Secure", "SameSite=Strict", "Path=/"} { + if !strings.Contains(setCookie, attr) { + t.Fatalf("cookie missing %s: %q", attr, setCookie) + } + } +} + +func TestSession_PostRejectsBadKey(t *testing.T) { + t.Parallel() + h := newSessionHandler("s3cret") + req := httptest.NewRequest(http.MethodPost, "/api/session", nil) + req.Header.Set("Authorization", "Bearer wrong") + rr := httptest.NewRecorder() + h(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("want 401, got %d", rr.Code) + } + if rr.Header().Get("Set-Cookie") != "" { + t.Fatal("cookie must not be set on failure") + } +} + +func TestSession_DeleteClearsCookie(t *testing.T) { + t.Parallel() + h := newSessionDeleteHandler() + req := httptest.NewRequest(http.MethodDelete, "/api/session", nil) + rr := httptest.NewRecorder() + h(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("want 204, got %d", rr.Code) + } + setCookie := rr.Header().Get("Set-Cookie") + if !strings.Contains(setCookie, "Max-Age=0") { + t.Fatalf("cookie should be cleared (Max-Age=0); got %q", setCookie) + } +} diff --git a/internal/api/spa_meta_test.go b/internal/api/spa_meta_test.go deleted file mode 100644 index 8e60f66..0000000 --- a/internal/api/spa_meta_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package api - -import ( - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/RandomCodeSpace/docsiq/internal/config" - "github.com/RandomCodeSpace/docsiq/ui" -) - -func TestSPA_InjectsMetaWhenAPIKeySet(t *testing.T) { - cfg := &config.Config{} - cfg.Server.APIKey = "secret-key-abc" - h := spaHandler(ui.Assets, cfg) - srv := httptest.NewServer(h) - defer srv.Close() - resp, err := http.Get(srv.URL + "/") - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if !strings.Contains(string(body), `name="docsiq-api-key"`) { - t.Fatalf("expected meta tag, body:\n%s", body) - } - if !strings.Contains(string(body), `content="secret-key-abc"`) { - t.Fatalf("expected API key in content attr, body:\n%s", body) - } -} - -func TestSPA_OmitsMetaWhenAPIKeyUnset(t *testing.T) { - cfg := &config.Config{} - cfg.Server.APIKey = "" - h := spaHandler(ui.Assets, cfg) - srv := httptest.NewServer(h) - defer srv.Close() - resp, err := http.Get(srv.URL + "/") - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - if strings.Contains(string(body), `name="docsiq-api-key"`) { - t.Fatalf("meta tag should not be present when APIKey empty") - } -} diff --git a/internal/api/spa_test.go b/internal/api/spa_test.go new file mode 100644 index 0000000..f5629d3 --- /dev/null +++ b/internal/api/spa_test.go @@ -0,0 +1,35 @@ +package api + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + "testing/fstest" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +func TestSpaHandler_DoesNotInjectAPIKey(t *testing.T) { + t.Parallel() + fsys := fstest.MapFS{ + "index.html": &fstest.MapFile{ + Data: []byte(``), + }, + } + cfg := &config.Config{} + cfg.Server.APIKey = "s3cret" + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + spaHandler(fsys, cfg).ServeHTTP(rr, req) + + body, _ := io.ReadAll(rr.Body) + if bytes.Contains(body, []byte("docsiq-api-key")) { + t.Fatalf("served HTML still contains api-key meta tag:\n%s", body) + } + if rr.Code != http.StatusOK { + t.Fatalf("want 200; got %d", rr.Code) + } +} diff --git a/internal/api/upload_limit_test.go b/internal/api/upload_limit_test.go new file mode 100644 index 0000000..188cd15 --- /dev/null +++ b/internal/api/upload_limit_test.go @@ -0,0 +1,128 @@ +package api + +import ( + "bytes" + "errors" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestUploadMaxBytes verifies that requests whose body exceeds +// cfg.Server.MaxUploadBytes are rejected with 413 before the handler +// tries to parse the multipart form. This exercises the fast-path +// (Content-Length declared and already over the limit). +func TestUploadMaxBytes(t *testing.T) { + t.Parallel() + const limit int64 = 1024 // 1 KiB for the test + + // Build a multipart body larger than the limit. + var body bytes.Buffer + mw := multipart.NewWriter(&body) + part, err := mw.CreateFormFile("files", "big.txt") + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := io.Copy(part, strings.NewReader(strings.Repeat("x", int(limit)*2))); err != nil { + t.Fatalf("copy: %v", err) + } + _ = mw.Close() + + req := httptest.NewRequest(http.MethodPost, "/api/upload", &body) + req.Header.Set("Content-Type", mw.FormDataContentType()) + rr := httptest.NewRecorder() + + // enforceUploadLimit is the unit-testable shim applied inside upload(). + // It wraps r.Body with http.MaxBytesReader and returns a 413 on overflow. + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !enforceUploadLimit(w, r, limit) { + return + } + if err := r.ParseMultipartForm(32 << 10); err != nil { + // MaxBytesReader converts overflow into a ParseMultipartForm error + // AFTER the header has been written by http.MaxBytesReader. We + // still exit here; the header is already 413 in that case. + return + } + w.WriteHeader(http.StatusOK) + }) + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d (body: %s)", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("Content-Type"); got != "application/json" { + t.Fatalf("expected Content-Type application/json, got %q", got) + } + if !strings.Contains(rr.Body.String(), "exceeds maximum upload size") { + t.Fatalf("expected JSON error to mention upload size, got: %s", rr.Body.String()) + } +} + +// TestUploadMaxBytes_UnknownContentLength covers the slow-path where the +// Content-Length is unknown (e.g. chunked transfer encoding). In that case +// the fast-path cannot reject up-front; enforcement happens when +// ParseMultipartForm reads through the MaxBytesReader wrapper and that +// returns a *http.MaxBytesError. +// +// Note: the real net/http server calls an internal requestTooLarge() hook +// on its own response writer that commits a 413, so production callers +// can "just return" on *MaxBytesError without a WriteHeader of their +// own. httptest.ResponseRecorder does not implement that hook, so this +// test verifies the downstream signal (the error type is what production +// code matches on) by asserting the *MaxBytesError surfaces and that +// writeTooLarge, given the limit, produces the expected 413 JSON. +func TestUploadMaxBytes_UnknownContentLength(t *testing.T) { + t.Parallel() + const limit int64 = 1024 // 1 KiB for the test + + var body bytes.Buffer + mw := multipart.NewWriter(&body) + part, err := mw.CreateFormFile("files", "big.txt") + if err != nil { + t.Fatalf("create form file: %v", err) + } + if _, err := io.Copy(part, strings.NewReader(strings.Repeat("x", int(limit)*2))); err != nil { + t.Fatalf("copy: %v", err) + } + _ = mw.Close() + + req := httptest.NewRequest(http.MethodPost, "/api/upload", &body) + req.Header.Set("Content-Type", mw.FormDataContentType()) + req.ContentLength = -1 // force slow path: unknown Content-Length + rr := httptest.NewRecorder() + + // This inner handler mirrors the prod upload() flow. In a real + // http.Server the MaxBytesReader's requestTooLarge() hook commits + // 413 automatically; in httptest we explicitly writeTooLarge() + // after matching *MaxBytesError, which exercises the same helper + // used by the fast-path. + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !enforceUploadLimit(w, r, limit) { + return + } + if err := r.ParseMultipartForm(32 << 10); err != nil { + var mbe *http.MaxBytesError + if errors.As(err, &mbe) { + writeTooLarge(w, mbe.Limit) + return + } + return + } + w.WriteHeader(http.StatusOK) + }) + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d (body: %s)", rr.Code, rr.Body.String()) + } + if got := rr.Header().Get("Content-Type"); got != "application/json" { + t.Fatalf("expected Content-Type application/json, got %q", got) + } + if !strings.Contains(rr.Body.String(), "exceeds maximum upload size") { + t.Fatalf("expected JSON error to mention upload size, got: %s", rr.Body.String()) + } +} diff --git a/internal/api/upload_workq_test.go b/internal/api/upload_workq_test.go new file mode 100644 index 0000000..b55ff72 --- /dev/null +++ b/internal/api/upload_workq_test.go @@ -0,0 +1,79 @@ +package api + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/RandomCodeSpace/docsiq/internal/workq" +) + +// TestUpload_ReturnsRetryOnFullQueue verifies that when the injected +// workq Pool is saturated, the upload-submission layer returns 503 with +// a Retry-After header and does not run the job. We exercise the HTTP +// bridge via the Pool contract without setting up the full upload +// handler (that's covered by other integration tests). +func TestUpload_ReturnsRetryOnFullQueue(t *testing.T) { + t.Parallel() + pool := workq.New(workq.Config{Workers: 1, QueueDepth: 1}) + // Saturate the pool: one worker busy, one queue slot full. + block := make(chan struct{}) + started := make(chan struct{}) + // LIFO: close(block) runs first (unblocks the stuck worker), then + // pool.Close drains cleanly. This order is safe even on panic. + defer pool.Close(context.Background()) //nolint:errcheck + defer close(block) + // Signal via started that the worker has actually begun executing + // (and is now blocked on <-block) before we fill the queue slot. + // Without this synchronisation the race detector's slower scheduling + // can drain the first job from the channel before Submit #2 lands, + // leaving a free slot for the test's submit and producing a false 202. + _ = pool.Submit(func(ctx context.Context) { close(started); <-block }) + <-started // worker is blocked on <-block; channel now empty + // Channel capacity = Workers+QueueDepth = 1+1 = 2. + // Fill both slots so the next submit returns ErrQueueFull. + _ = pool.Submit(func(ctx context.Context) {}) + _ = pool.Submit(func(ctx context.Context) {}) + + var called atomic.Bool + // Mimic what h.workq.Submit does in upload(): on ErrQueueFull, set + // Retry-After and write 503 via writeError equivalent. + handle := func(w http.ResponseWriter, _ *http.Request) { + err := pool.Submit(func(ctx context.Context) { + called.Store(true) + }) + if err == nil { + w.WriteHeader(http.StatusAccepted) + return + } + if errors.Is(err, workq.ErrQueueFull) { + w.Header().Set("Retry-After", "30") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"indexing queue full; retry later"}`)) + return + } + w.WriteHeader(http.StatusServiceUnavailable) + } + + req := httptest.NewRequest(http.MethodPost, "/api/upload", nil) + rr := httptest.NewRecorder() + handle(rr, req) + + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("want 503, got %d", rr.Code) + } + if rr.Header().Get("Retry-After") != "30" { + t.Fatalf("missing or wrong Retry-After: %q", rr.Header().Get("Retry-After")) + } + if got := rr.Body.String(); !strings.Contains(got, "queue full") { + t.Fatalf("body should mention queue full; got %s", got) + } + if called.Load() { + t.Fatal("job should not have run when queue is full") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index fbc6b45..f160f66 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -143,9 +143,12 @@ type CommunityConfig struct { } type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - APIKey string `mapstructure:"api_key"` + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + APIKey string `mapstructure:"api_key"` + MaxUploadBytes int64 `mapstructure:"max_upload_bytes"` // 0 or negative disables the cap + WorkqWorkers int `mapstructure:"workq_workers"` // 0 → runtime.NumCPU() + WorkqDepth int `mapstructure:"workq_depth"` // 0 → 64 } func Load(cfgFile string) (*Config, error) { @@ -208,6 +211,9 @@ func Load(cfgFile string) (*Config, error) { v.SetDefault("server.host", "127.0.0.1") v.SetDefault("server.port", 8080) v.SetDefault("server.api_key", "") + v.SetDefault("server.max_upload_bytes", int64(100*1024*1024)) // 100 MiB + v.SetDefault("server.workq_workers", 0) // 0 → runtime.NumCPU() + v.SetDefault("server.workq_depth", 64) // Config file search paths. Only ~/.docsiq and CWD are consulted. newCfgDir := filepath.Join(home, ".docsiq") @@ -230,6 +236,9 @@ func Load(cfgFile string) (*Config, error) { // either form populates server.api_key. BindEnv names are matched // verbatim (not prefixed), so we list the full env var name. _ = v.BindEnv("server.api_key", "DOCSIQ_SERVER_API_KEY", "DOCSIQ_API_KEY") + _ = v.BindEnv("server.max_upload_bytes", "DOCSIQ_SERVER_MAX_UPLOAD_BYTES") + _ = v.BindEnv("server.workq_workers", "DOCSIQ_SERVER_WORKQ_WORKERS") + _ = v.BindEnv("server.workq_depth", "DOCSIQ_SERVER_WORKQ_DEPTH") if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { diff --git a/internal/search/local.go b/internal/search/local.go index e2aaf61..3d4314c 100644 --- a/internal/search/local.go +++ b/internal/search/local.go @@ -87,7 +87,14 @@ func LocalSearch(ctx context.Context, st *store.Store, emb *embedder.Embedder, i docIDs[c.Chunk.DocID] = true } - entities, err := st.AllEntities(ctx) + // Scope entity fetch to the top-hit documents instead of a + // full-table scan. Entities with no relationships to any top-hit + // doc are out of local scope by definition. + docIDList := make([]string, 0, len(docIDs)) + for id := range docIDs { + docIDList = append(docIDList, id) + } + entities, err := st.EntitiesForDocs(ctx, docIDList) if err != nil { return nil, err } diff --git a/internal/store/entities_for_docs_test.go b/internal/store/entities_for_docs_test.go new file mode 100644 index 0000000..e239315 --- /dev/null +++ b/internal/store/entities_for_docs_test.go @@ -0,0 +1,59 @@ +package store + +import ( + "context" + "testing" +) + +func TestEntitiesForDocs_ScopesByRelationshipDocID(t *testing.T) { + t.Parallel() + st := newTestStore(t) + ctx := context.Background() + + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + + // Insert documents first (FK enforcement is on). + must(st.UpsertDocument(ctx, &Document{ID: "docA", Path: "/a", Title: "A", DocType: "txt", FileHash: "hashA"})) + must(st.UpsertDocument(ctx, &Document{ID: "docB", Path: "/b", Title: "B", DocType: "txt", FileHash: "hashB"})) + + // Three entities; two relationships each scoped to a doc. + must(st.UpsertEntity(ctx, &Entity{ID: "e1", Name: "Alpha"})) + must(st.UpsertEntity(ctx, &Entity{ID: "e2", Name: "Beta"})) + must(st.UpsertEntity(ctx, &Entity{ID: "e3", Name: "Gamma"})) + must(st.InsertRelationship(ctx, &Relationship{ID: "r1", SourceID: "e1", TargetID: "e2", Predicate: "rel", DocID: "docA"})) + must(st.InsertRelationship(ctx, &Relationship{ID: "r2", SourceID: "e3", TargetID: "e1", Predicate: "rel", DocID: "docB"})) + + got, err := st.EntitiesForDocs(ctx, []string{"docA"}) + if err != nil { + t.Fatal(err) + } + if len(got) != 2 { + t.Fatalf("docA: want 2 entities (e1, e2); got %d", len(got)) + } + + // Empty input → empty slice, no error. + empty, err := st.EntitiesForDocs(ctx, nil) + if err != nil || len(empty) != 0 { + t.Fatalf("empty input: want (0, nil); got (%d, %v)", len(empty), err) + } +} + +func TestEntitiesForDocs_HandlesLargeIDSets(t *testing.T) { + t.Parallel() + st := newTestStore(t) + ctx := context.Background() + + ids := make([]string, 1500) // > SQLite's 999 default + for i := range ids { + ids[i] = "doc-xyz" + } + _, err := st.EntitiesForDocs(ctx, ids) + if err != nil { + t.Fatalf("chunking at >999 should not error: %v", err) + } +} diff --git a/internal/store/store.go b/internal/store/store.go index 6b201fb..ea7f9e8 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -604,6 +604,62 @@ func (s *Store) AllEntities(ctx context.Context) ([]*Entity, error) { return entities, rows.Err() } +// EntitiesForDocs returns entities that participate (as source or target +// of any relationship) in at least one of the given documents. This is +// the "local" entity set for a scoped search — avoids the full-table +// scan AllEntities performs. +// +// The IN-list is chunked at 900 (below SQLite's default 999 variable +// limit) so caller-supplied doc sets of any size work transparently. +func (s *Store) EntitiesForDocs(ctx context.Context, docIDs []string) ([]*Entity, error) { + if len(docIDs) == 0 { + return nil, nil + } + const chunkSize = 900 + seen := make(map[string]struct{}, 128) + out := make([]*Entity, 0, 128) + + for start := 0; start < len(docIDs); start += chunkSize { + end := start + chunkSize + if end > len(docIDs) { + end = len(docIDs) + } + chunk := docIDs[start:end] + placeholders := strings.Repeat("?,", len(chunk)) + placeholders = placeholders[:len(placeholders)-1] + args := make([]any, len(chunk)) + for i, id := range chunk { + args[i] = id + } + q := `SELECT DISTINCT e.id, e.name, e.type, e.description, e.rank, e.community_id, e.vector + FROM entities e + JOIN relationships r ON (r.source_id = e.id OR r.target_id = e.id) + WHERE r.doc_id IN (` + placeholders + `)` + rows, err := s.db.QueryContext(ctx, q, args...) + if err != nil { + return nil, err + } + for rows.Next() { + e, err := scanEntityRow(rows) + if err != nil { + rows.Close() + return nil, err + } + if _, dup := seen[e.ID]; dup { + continue + } + seen[e.ID] = struct{}{} + out = append(out, e) + } + if err := rows.Err(); err != nil { + rows.Close() + return nil, err + } + rows.Close() + } + return out, nil +} + func (s *Store) UpdateEntityCommunity(ctx context.Context, entityID, communityID string) error { _, err := s.db.ExecContext(ctx, `UPDATE entities SET community_id=? WHERE id=?`, communityID, entityID) return err diff --git a/internal/workq/workq.go b/internal/workq/workq.go new file mode 100644 index 0000000..0204ae6 --- /dev/null +++ b/internal/workq/workq.go @@ -0,0 +1,126 @@ +// Package workq is a minimal bounded worker pool for fire-and-forget +// background work (e.g. post-upload indexing). Jobs carry a context +// derived from the pool's root context; Close() cancels that context +// and waits for workers to drain, honouring the caller's deadline. +package workq + +import ( + "context" + "errors" + "sync" +) + +// ErrQueueFull is returned by Submit when the job queue is saturated. +// Callers should surface this as 503 Service Unavailable with Retry-After. +var ErrQueueFull = errors.New("workq: queue full") + +// ErrClosed is returned by Submit after Close has been called. +var ErrClosed = errors.New("workq: closed") + +// Job is a unit of work. It receives the pool's context so it can +// abort on shutdown. +type Job func(ctx context.Context) + +// Config sizes the pool. Zero values use safe defaults (1 worker, +// 16-deep queue). Total in-flight + queued capacity is Workers + QueueDepth. +type Config struct { + Workers int + QueueDepth int +} + +// Pool is a fixed-size worker pool with a bounded submission queue. +type Pool struct { + jobs chan Job + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + // mu guards close(p.jobs) vs concurrent sends in Submit. RLock + // on the send path lets many Submits proceed in parallel; Close + // takes the write lock before closing the channel. + mu sync.RWMutex + closeOnce sync.Once + closed chan struct{} +} + +// New constructs and starts a Pool. +func New(cfg Config) *Pool { + if cfg.Workers < 1 { + cfg.Workers = 1 + } + if cfg.QueueDepth < 1 { + cfg.QueueDepth = 16 + } + ctx, cancel := context.WithCancel(context.Background()) + // Total buffered capacity = Workers + QueueDepth, so Submit succeeds + // whenever at least one worker is idle OR there is a free queue slot. + p := &Pool{ + jobs: make(chan Job, cfg.Workers+cfg.QueueDepth), + ctx: ctx, + cancel: cancel, + closed: make(chan struct{}), + } + for i := 0; i < cfg.Workers; i++ { + p.wg.Add(1) + go p.run() + } + return p +} + +// Submit enqueues job. Non-blocking: returns ErrQueueFull immediately +// if no queue slot is available, ErrClosed if the pool is shutting down. +func (p *Pool) Submit(job Job) error { + // RLock pairs with the write-lock in Close so the send on p.jobs + // cannot race with close(p.jobs). + p.mu.RLock() + defer p.mu.RUnlock() + select { + case <-p.closed: + return ErrClosed + default: + } + select { + case p.jobs <- job: + return nil + default: + return ErrQueueFull + } +} + +// Close stops accepting new work and waits for workers to drain. If +// the caller's ctx fires before drain completes, the pool context is +// cancelled so in-flight jobs honouring cancellation can abort, and +// ctx.Err() is returned. +func (p *Pool) Close(ctx context.Context) error { + p.closeOnce.Do(func() { + p.mu.Lock() + close(p.closed) + close(p.jobs) + p.mu.Unlock() + }) + done := make(chan struct{}) + go func() { + p.wg.Wait() + close(done) + }() + select { + case <-done: + return nil + case <-ctx.Done(): + p.cancel() + return ctx.Err() + } +} + +func (p *Pool) run() { + defer p.wg.Done() + for job := range p.jobs { + // Trap panics per-job so one bad job cannot kill a worker. + func() { + defer func() { + _ = recover() + }() + job(p.ctx) + }() + } +} diff --git a/internal/workq/workq_test.go b/internal/workq/workq_test.go new file mode 100644 index 0000000..84c4415 --- /dev/null +++ b/internal/workq/workq_test.go @@ -0,0 +1,107 @@ +package workq + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestPool_SubmitRunsJob(t *testing.T) { + t.Parallel() + p := New(Config{Workers: 2, QueueDepth: 4}) + defer p.Close(context.Background()) + + var ran atomic.Int32 + done := make(chan struct{}) + if err := p.Submit(func(ctx context.Context) { + ran.Add(1) + close(done) + }); err != nil { + t.Fatalf("submit: %v", err) + } + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("job did not run within 1s") + } + if got := ran.Load(); got != 1 { + t.Fatalf("want ran=1, got %d", got) + } +} + +func TestPool_SubmitReturnsErrQueueFull(t *testing.T) { + t.Parallel() + p := New(Config{Workers: 1, QueueDepth: 1}) + defer p.Close(context.Background()) + + block := make(chan struct{}) + // Occupy the single worker. + _ = p.Submit(func(ctx context.Context) { <-block }) + // Fill the single queue slot. + if err := p.Submit(func(ctx context.Context) {}); err != nil { + t.Fatalf("queue slot submit: %v", err) + } + // Third submit must fail fast. + err := p.Submit(func(ctx context.Context) {}) + if !errors.Is(err, ErrQueueFull) { + t.Fatalf("want ErrQueueFull, got %v", err) + } + close(block) +} + +func TestPool_CloseDrainsInflight(t *testing.T) { + t.Parallel() + p := New(Config{Workers: 2, QueueDepth: 4}) + var ran atomic.Int32 + for range 4 { + _ = p.Submit(func(ctx context.Context) { + time.Sleep(20 * time.Millisecond) + ran.Add(1) + }) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := p.Close(ctx); err != nil { + t.Fatalf("close: %v", err) + } + if got := ran.Load(); got != 4 { + t.Fatalf("want ran=4 after drain, got %d", got) + } +} + +func TestPool_CloseCancelsOnContextDeadline(t *testing.T) { + t.Parallel() + p := New(Config{Workers: 1, QueueDepth: 1}) + start := make(chan struct{}) + _ = p.Submit(func(ctx context.Context) { + close(start) + <-ctx.Done() // honour cancellation + }) + <-start + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + err := p.Close(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("want DeadlineExceeded, got %v", err) + } +} + +func TestPool_SubmitRaceDuringClose(t *testing.T) { + t.Parallel() + for range 50 { + p := New(Config{Workers: 4, QueueDepth: 8}) + var wg sync.WaitGroup + for range 32 { + wg.Add(1) + go func() { + defer wg.Done() + _ = p.Submit(func(ctx context.Context) {}) + }() + } + _ = p.Close(context.Background()) + wg.Wait() + } +} diff --git a/ui/src/hooks/api/useMCP.ts b/ui/src/hooks/api/useMCP.ts index 6482973..c35f37e 100644 --- a/ui/src/hooks/api/useMCP.ts +++ b/ui/src/hooks/api/useMCP.ts @@ -20,12 +20,6 @@ export interface MCPTool { }; } -function getBearer(): string | null { - if (typeof document === "undefined") return null; - const v = document.querySelector('meta[name="docsiq-api-key"]')?.getAttribute("content"); - return v && v.length ? v : null; -} - async function rpc( sessionId: string | null, body: unknown, @@ -34,11 +28,14 @@ async function rpc( "Content-Type": "application/json", "Accept": "application/json, text/event-stream", }; - const bearer = getBearer(); - if (bearer) headers["Authorization"] = `Bearer ${bearer}`; if (sessionId) headers["Mcp-Session-Id"] = sessionId; - const res = await fetch("/mcp", { method: "POST", headers, body: JSON.stringify(body) }); + const res = await fetch("/mcp", { + method: "POST", + credentials: "include", + headers, + body: JSON.stringify(body), + }); const newSession = res.headers.get("Mcp-Session-Id") ?? sessionId; const text = await res.text(); diff --git a/ui/src/lib/__tests__/api-client.test.ts b/ui/src/lib/__tests__/api-client.test.ts index c047495..6cc24c4 100644 --- a/ui/src/lib/__tests__/api-client.test.ts +++ b/ui/src/lib/__tests__/api-client.test.ts @@ -1,7 +1,7 @@ -import { describe, it, expect } from "vitest"; +import { describe, it, expect, vi } from "vitest"; import { http, HttpResponse } from "msw"; import { server } from "@/test/msw"; -import { apiFetch, ApiErrorResponse } from "../api-client"; +import { apiFetch, ApiErrorResponse, initAuth } from "../api-client"; describe("apiFetch", () => { it("returns parsed json on 200", async () => { @@ -32,4 +32,39 @@ describe("apiFetch", () => { const r = await apiFetch("/api/x", { method: "DELETE" }); expect(r).toBeUndefined(); }); + + it("sends credentials: 'include' on every fetch", async () => { + const spy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("{}", { status: 200, headers: { "content-type": "application/json" } }), + ); + await apiFetch("/api/stats"); + const init = (spy.mock.calls[0][1] ?? {}) as RequestInit; + expect(init.credentials).toBe("include"); + spy.mockRestore(); + }); + + it("does not set Authorization header on data-path fetch even when a key exists in a meta tag", async () => { + const meta = document.createElement("meta"); + meta.setAttribute("name", "docsiq-api-key"); + meta.setAttribute("content", "s3cret"); + document.head.appendChild(meta); + // Spy must be installed BEFORE initAuth() so the /api/session exchange + // is captured by the mock and not passed through to MSW (which has no + // handler for it and would throw). + const spy = vi.spyOn(globalThis, "fetch").mockResolvedValue( + new Response("{}", { status: 200, headers: { "content-type": "application/json" } }), + ); + try { + initAuth(); + await apiFetch("/api/stats"); + const statsCall = spy.mock.calls.find((c) => c[0] === "/api/stats"); + expect(statsCall).toBeDefined(); + const init = (statsCall![1] ?? {}) as RequestInit; + const hdrs = new Headers(init.headers); + expect(hdrs.has("Authorization")).toBe(false); + } finally { + spy.mockRestore(); + if (meta.parentElement) document.head.removeChild(meta); + } + }); }); diff --git a/ui/src/lib/api-client.ts b/ui/src/lib/api-client.ts index b462981..e3e551e 100644 --- a/ui/src/lib/api-client.ts +++ b/ui/src/lib/api-client.ts @@ -1,16 +1,38 @@ import type { ApiError } from "@/types/api"; -let bearer: string | null = null; +// Before cookies are set the first time, the UI may have been shipped a +// one-shot bearer via the meta tag (legacy). We exchange it for a cookie +// exactly once, then never read or send the key again. If no meta tag +// exists (production path), we rely entirely on cookies already set by +// the operator's OOB provisioning (e.g. `docsiq login`). +let sessionReady: Promise | null = null; -function readBearerFromMeta(): string | null { +function readOneShotBearerFromMeta(): string | null { if (typeof document === "undefined") return null; const m = document.querySelector('meta[name="docsiq-api-key"]'); const v = m?.getAttribute("content"); return v && v.length > 0 ? v : null; } -export function initAuth() { - bearer = readBearerFromMeta(); +async function establishSession(bearer: string): Promise { + const m = document.querySelector('meta[name="docsiq-api-key"]'); + m?.parentElement?.removeChild(m); + + const res = await fetch("/api/session", { + method: "POST", + credentials: "include", + headers: { Authorization: `Bearer ${bearer}` }, + }); + if (!res.ok) { + let body: ApiError = { error: `HTTP ${res.status}` }; + try { body = await res.json(); } catch { /* non-json */ } + throw new ApiErrorResponse(res.status, body); + } +} + +export function initAuth(): void { + const bearer = readOneShotBearerFromMeta(); + sessionReady = bearer ? establishSession(bearer) : Promise.resolve(); } export class ApiErrorResponse extends Error { @@ -27,15 +49,19 @@ export async function apiFetch( path: string, init: RequestInit = {}, ): Promise { + if (sessionReady) await sessionReady; const headers = new Headers(init.headers); - if (bearer) headers.set("Authorization", `Bearer ${bearer}`); if (init.body && !headers.has("Content-Type")) { headers.set("Content-Type", "application/json"); } - const res = await fetch(path, { ...init, headers }); + const res = await fetch(path, { ...init, headers, credentials: "include" }); if (!res.ok) { let body: ApiError = { error: `HTTP ${res.status}` }; - try { body = await res.json(); } catch { /* non-json */ } + try { + body = await res.json(); + } catch { + /* non-json */ + } throw new ApiErrorResponse(res.status, body); } if (res.status === 204) return undefined as T;