Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"

"github.com/RandomCodeSpace/docsiq/internal/api"
"github.com/RandomCodeSpace/docsiq/internal/config"
"github.com/RandomCodeSpace/docsiq/internal/embedder"
"github.com/RandomCodeSpace/docsiq/internal/llm"
"github.com/RandomCodeSpace/docsiq/internal/project"
"github.com/RandomCodeSpace/docsiq/internal/sqlitevec"
"github.com/RandomCodeSpace/docsiq/internal/vectorindex"
"github.com/RandomCodeSpace/docsiq/internal/workq"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -140,11 +144,25 @@ var serveCmd = &cobra.Command{
}
}

workers := cfg.Server.WorkqWorkers
if workers <= 0 {
workers = runtime.NumCPU()
}
depth := cfg.Server.WorkqDepth
if depth <= 0 {
depth = 64
}
pool := workq.New(workq.Config{Workers: workers, QueueDepth: depth})

router := api.NewRouter(prov, emb, cfg, registry,
api.WithProjectStores(stores),
api.WithVectorIndexes(vecIndexes),
api.WithWorkq(pool),
)

if err := validateServeSecurity(cfg); err != nil {
return err
}
addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
ln, err := net.Listen("tcp", addr)
if err != nil {
Expand Down Expand Up @@ -177,6 +195,17 @@ var serveCmd = &cobra.Command{
slog.Error("❌ shutdown error", "err", err)
return err
}

// Drain workq within its own 30s deadline. Server.Shutdown has already
// stopped accepting new HTTP requests, so no new jobs can be submitted;
// all that remains is letting in-flight pipelines finish or honour the
// cancelled ctx.
drainCtx, drainCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer drainCancel()
if err := pool.Close(drainCtx); err != nil {
slog.Warn("⚠️ workq drain timeout; some indexing jobs were cancelled mid-flight", "err", err)
}

slog.Info("✅ shutdown complete")
return nil
},
Expand All @@ -187,3 +216,33 @@ func init() {
serveCmd.Flags().StringVar(&serveHost, "host", "", "Server host (overrides config)")
serveCmd.Flags().IntVar(&servePort, "port", 0, "Server port (overrides config)")
}

// validateServeSecurity refuses to start the server when the API key is
// empty AND the bind host is not loopback. An unauthenticated service
// exposed on the network is almost never intentional; make it explicit.
// Loopback with empty key gets a prominent warning at boot instead.
func validateServeSecurity(cfg *config.Config) error {
if cfg.Server.APIKey != "" {
return nil
}
host := strings.ToLower(strings.TrimSpace(cfg.Server.Host))
if host == "" {
return fmt.Errorf(
"server.api_key is empty and server.host is unset (binds all interfaces); refusing to start. " +
"Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev",
)
}
loopback := host == "localhost"
if ip := net.ParseIP(strings.Trim(host, "[]")); ip != nil {
loopback = loopback || ip.IsLoopback()
}
if !loopback {
return fmt.Errorf(
"server.api_key is empty and server.host=%q is not loopback; refusing to start. "+
"Set DOCSIQ_SERVER_API_KEY or bind to 127.0.0.1/localhost for dev",
cfg.Server.Host,
)
}
slog.Warn("⚠️ auth disabled (empty server.api_key); only loopback bind allowed", "host", host)
return nil
}
87 changes: 87 additions & 0 deletions cmd/serve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package cmd

import (
"strings"
"testing"

"github.com/RandomCodeSpace/docsiq/internal/config"
)

func TestValidateServeSecurity_RefusesNonLoopbackWithEmptyKey(t *testing.T) {
t.Parallel()
cases := []struct {
name string
host string
mustContain []string
}{
{
name: "empty host binds all interfaces",
host: "",
mustContain: []string{"binds all interfaces", "DOCSIQ_SERVER_API_KEY"},
},
{
name: "0.0.0.0 is wildcard",
host: "0.0.0.0",
mustContain: []string{"api_key", "DOCSIQ_SERVER_API_KEY"},
},
{
name: "public IPv4",
host: "10.0.0.5",
mustContain: []string{"api_key", "DOCSIQ_SERVER_API_KEY"},
},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
cfg := &config.Config{}
cfg.Server.Host = tc.host
cfg.Server.Port = 8080
cfg.Server.APIKey = ""

err := validateServeSecurity(cfg)
if err == nil {
t.Fatalf("host=%q: expected error for empty api_key on non-loopback bind; got nil", tc.host)
}
for _, want := range tc.mustContain {
if !strings.Contains(err.Error(), want) {
t.Errorf("host=%q: error should contain %q; got %v", tc.host, want, err)
}
}
})
}
}

func TestValidateServeSecurity_AllowsLoopbackWithEmptyKey(t *testing.T) {
t.Parallel()
hosts := []string{
"127.0.0.1",
"localhost",
"::1",
"[::1]", // bracketed IPv6
"127.0.0.2", // other address in 127.0.0.0/8
"::ffff:127.0.0.1", // IPv6-mapped IPv4
}
for _, host := range hosts {
host := host
t.Run(host, func(t *testing.T) {
t.Parallel()
cfg := &config.Config{}
cfg.Server.Host = host
cfg.Server.APIKey = ""
if err := validateServeSecurity(cfg); err != nil {
t.Fatalf("host=%s: expected nil; got %v", host, err)
}
})
}
}

func TestValidateServeSecurity_AllowsNonLoopbackWithKey(t *testing.T) {
t.Parallel()
cfg := &config.Config{}
cfg.Server.Host = "0.0.0.0"
cfg.Server.APIKey = "s3cret"
if err := validateServeSecurity(cfg); err != nil {
t.Fatalf("expected nil; got %v", err)
}
}
41 changes: 36 additions & 5 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,31 @@ func bearerAuthMiddleware(apiKey string, next http.Handler) http.Handler {
return
}

raw := strings.TrimSpace(r.Header.Get("Authorization"))
const prefix = "Bearer "
if !strings.HasPrefix(raw, prefix) {
// /api/session is the auth boundary itself — always public.
if path == "/api/session" {
next.ServeHTTP(w, r)
return
}

// Defense-in-depth: reject immediately if the server has no key
// configured. This mirrors newSessionHandler's guard and keeps the
// middleware correct under future refactors (rather than relying on
// the no_token branch firing because keyBytes would also be empty).
if apiKey == "" {
slog.Warn("🔒 auth failure", "path", path, "remote_addr", r.RemoteAddr, "reason", "server_misconfigured")
writeJSON401(w)
return
}

token := extractToken(r)
if token == "" {
slog.Warn("🔒 auth failure",
"path", path,
"remote_addr", r.RemoteAddr,
"reason", "no_bearer_prefix")
"reason", "no_token")
writeJSON401(w)
return
}
token := raw[len(prefix):]
if subtle.ConstantTimeCompare([]byte(token), keyBytes) != 1 {
slog.Warn("🔒 auth failure",
"path", path,
Expand All @@ -82,3 +96,20 @@ func writeJSON401(w http.ResponseWriter) {
w.WriteHeader(http.StatusUnauthorized)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
}

// extractToken returns the bearer token from either the Authorization
// header (preferred, for machine clients) or the session cookie (for
// browser clients after POST /api/session). Returns "" if neither.
func extractToken(r *http.Request) string {
raw := strings.TrimSpace(r.Header.Get("Authorization"))
const prefix = "Bearer "
if strings.HasPrefix(raw, prefix) {
return raw[len(prefix):]
}
if c, err := r.Cookie(sessionCookieName); err == nil {
if v := strings.TrimSpace(c.Value); v != "" {
return v
}
}
return ""
}
23 changes: 23 additions & 0 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,29 @@ func TestBearerAuthMiddleware(t *testing.T) {
})
}

func TestAuth_AcceptsValidCookie(t *testing.T) {
t.Parallel()
h := buildAuthHandler("s3cret")
req := httptest.NewRequest(http.MethodGet, "/api/ping", nil)
req.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "s3cret"})
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("want 200 with valid cookie, got %d", rr.Code)
}
}

func TestAuth_RejectsMissingBothHeaderAndCookie(t *testing.T) {
t.Parallel()
h := buildAuthHandler("s3cret")
req := httptest.NewRequest(http.MethodGet, "/api/ping", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
if rr.Code != http.StatusUnauthorized {
t.Fatalf("want 401, got %d", rr.Code)
}
}

// BenchmarkBearerAuth_WrongKey ensures the constant-time compare path runs
// under the benchmark harness. It is NOT a hard timing assertion — if the
// code ever switches to a non-constant-time compare (==, bytes.Equal), the
Expand Down
Loading
Loading