diff --git a/pkg/ratelimit/middleware_test.go b/pkg/ratelimit/middleware_test.go index ed76e72e0c..b2adf34624 100644 --- a/pkg/ratelimit/middleware_test.go +++ b/pkg/ratelimit/middleware_test.go @@ -13,11 +13,17 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/mcp" + transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks" ) // dummyLimiter is a test double for the Limiter interface. @@ -208,3 +214,44 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) { assert.Equal(t, "echo", recorder.toolName) assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID") } + +func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) { + t.Parallel() + + expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}}) + mw := &rateLimitMiddleware{handler: expected} + + assert.NotNil(t, mw.Handler()) +} + +func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{ + Namespace: "default", + ServerName: "server", + RedisAddr: mr.Addr(), + Config: &v1beta1.RateLimitConfig{ + Shared: &v1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + runner := transportmocks.NewMockMiddlewareRunner(ctrl) + var registered transporttypes.Middleware + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})). + Do(func(_ string, middleware transporttypes.Middleware) { + registered = middleware + }) + + require.NoError(t, CreateMiddleware(cfg, runner)) + require.NotNil(t, registered) + require.NotNil(t, registered.Handler()) + require.NoError(t, registered.Close()) +} diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index a962f52f27..cd84370a8f 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -36,13 +36,14 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" - "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" + authfactory "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + ratelimitfactory "github.com/stacklok/toolhive/pkg/vmcp/ratelimit/factory" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -367,13 +368,31 @@ func Serve(ctx context.Context, cfg ServeConfig) error { } authMiddleware, authzMiddleware, authInfoHandler, err := - factory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider) + authfactory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } slog.Info(fmt.Sprintf("Incoming authentication configured: %s", vmcpCfg.IncomingAuth.Type)) + namespace := vmcpNamespace() + rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{ + Namespace: namespace, + ServerName: vmcpCfg.Name, + RateLimiting: vmcpCfg.RateLimiting, + SessionStorage: vmcpCfg.SessionStorage, + }) + if err != nil { + return fmt.Errorf("failed to create rate limit middleware: %w", err) + } + if rateLimitCleanup != nil { + defer func() { + if closeErr := rateLimitCleanup(context.Background()); closeErr != nil { + slog.Error(fmt.Sprintf("failed to close rate limit middleware: %v", closeErr)) + } + }() + } + serverCfg := &vmcpserver.Config{ Name: vmcpCfg.Name, Version: versions.Version, @@ -384,6 +403,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error { AuthMiddleware: authMiddleware, AuthzMiddleware: authzMiddleware, AuthInfoHandler: authInfoHandler, + RateLimitMiddleware: rateLimitMiddleware, AuthServer: embeddedAuthServer, TelemetryProvider: telemetryProvider, AuditConfig: vmcpCfg.Audit, @@ -529,6 +549,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) { return cfg, nil } +func vmcpNamespace() string { + namespace := os.Getenv("VMCP_NAMESPACE") + if namespace == "" { + return "local" + } + return namespace +} + // loadAuthServerConfig loads the auth server RunConfig from a sibling file // alongside the main config. The operator serializes authserver.RunConfig as a // separate ConfigMap key (authserver-config.yaml). @@ -560,7 +588,7 @@ func discoverBackends( ) ([]vmcp.Backend, vmcp.BackendClient, vmcpauth.OutgoingAuthRegistry, error) { slog.Info("initializing outgoing authentication") envReader := &env.OSReader{} - outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, envReader) + outgoingRegistry, err := authfactory.NewOutgoingAuthRegistry(ctx, envReader) if err != nil { return nil, nil, nil, fmt.Errorf("failed to create outgoing authentication registry: %w", err) } diff --git a/pkg/vmcp/cli/serve_test.go b/pkg/vmcp/cli/serve_test.go index 667b285779..22b0260295 100644 --- a/pkg/vmcp/cli/serve_test.go +++ b/pkg/vmcp/cli/serve_test.go @@ -337,6 +337,20 @@ func TestValidateQuickModeHost(t *testing.T) { } } +func TestVMCPNamespace(t *testing.T) { + t.Run("defaults to local", func(t *testing.T) { + t.Setenv("VMCP_NAMESPACE", "") + + assert.Equal(t, "local", vmcpNamespace()) + }) + + t.Run("uses environment value", func(t *testing.T) { + t.Setenv("VMCP_NAMESPACE", "toolhive-system") + + assert.Equal(t, "toolhive-system", vmcpNamespace()) + }) +} + // TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the // discoverer succeeds but returns no backends. The function must return a // non-error, an empty (non-nil) backend slice, and pass through the client and diff --git a/pkg/vmcp/ratelimit/factory/middleware.go b/pkg/vmcp/ratelimit/factory/middleware.go new file mode 100644 index 0000000000..9280011428 --- /dev/null +++ b/pkg/vmcp/ratelimit/factory/middleware.go @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package factory builds vMCP-specific rate-limit middleware. +package factory + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math" + "net/http" + "os" + "time" + + "github.com/redis/go-redis/v9" + + "github.com/stacklok/toolhive/pkg/auth" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/ratelimit" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +const redisPingTimeout = 5 * time.Second + +// Config contains the vMCP rate-limit middleware inputs. +type Config struct { + Namespace string + ServerName string + RateLimiting *ratelimittypes.RateLimitConfig + SessionStorage *vmcpconfig.SessionStorageConfig +} + +// NewMiddleware creates Redis-backed rate-limit middleware for vMCP. +func NewMiddleware( + ctx context.Context, + cfg Config, +) (func(http.Handler) http.Handler, func(context.Context) error, error) { + if cfg.RateLimiting == nil { + return nil, nil, nil + } + if cfg.SessionStorage == nil || cfg.SessionStorage.Provider != "redis" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage") + } + if cfg.SessionStorage.Address == "" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address") + } + + client := redis.NewClient(&redis.Options{ + Addr: cfg.SessionStorage.Address, + DB: int(cfg.SessionStorage.DB), + Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar), + }) + + pingCtx, cancel := context.WithTimeout(ctx, redisPingTimeout) + defer cancel() + if err := client.Ping(pingCtx).Err(); err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w", + cfg.SessionStorage.Address, err) + } + + limiter, err := ratelimit.NewLimiter(client, cfg.Namespace, cfg.ServerName, cfg.RateLimiting) + if err != nil { + _ = client.Close() + return nil, nil, fmt.Errorf("failed to create rate limiter: %w", err) + } + + cleanup := func(context.Context) error { + return client.Close() + } + return rateLimitHandler(limiter), cleanup, nil +} + +func rateLimitHandler(limiter ratelimit.Limiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + parsed := mcpparser.GetParsedMCPRequest(r.Context()) + if parsed == nil || parsed.Method != "tools/call" { + next.ServeHTTP(w, r) + return + } + + var userID string + if identity, ok := auth.IdentityFromContext(r.Context()); ok { + userID = identity.Subject + } + decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID) + if err != nil { + slog.Warn("rate limit check failed, allowing request", "error", err) + next.ServeHTTP(w, r) + return + } + if !decision.Allowed { + writeRateLimited(w, parsed.ID, decision.RetryAfter) + return + } + next.ServeHTTP(w, r) + }) + } +} + +func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) { + retrySeconds := int(math.Ceil(retryAfter.Seconds())) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", fmt.Sprintf("%d", retrySeconds)) + w.WriteHeader(http.StatusTooManyRequests) + //nolint:gosec // G104: writing a static JSON error response to an HTTP client + _, _ = w.Write(rateLimitedBody(requestID, retryAfter)) +} + +func rateLimitedBody(requestID any, retryAfter time.Duration) []byte { + retrySeconds := math.Ceil(retryAfter.Seconds()) + resp := map[string]any{ + "jsonrpc": "2.0", + "error": map[string]any{ + "code": ratelimit.CodeRateLimited, + "message": ratelimit.MessageRateLimited, + "data": map[string]any{ + "retryAfterSeconds": retrySeconds, + }, + }, + "id": requestID, + } + data, err := json.Marshal(resp) + if err != nil { + return []byte(fmt.Sprintf( + `{"jsonrpc":"2.0","error":{"code":-32029,"message":"Rate limit exceeded","data":{"retryAfterSeconds":%.0f}},"id":null}`, + retrySeconds, + )) + } + return data +} diff --git a/pkg/vmcp/ratelimit/factory/middleware_test.go b/pkg/vmcp/ratelimit/factory/middleware_test.go new file mode 100644 index 0000000000..0b4d64b93b --- /dev/null +++ b/pkg/vmcp/ratelimit/factory/middleware_test.go @@ -0,0 +1,264 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package factory + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/stacklok/toolhive/pkg/auth" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/ratelimit" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +type recordingLimiter struct { + toolName string + userID string +} + +func (r *recordingLimiter) Allow(_ context.Context, toolName, userID string) (*ratelimit.Decision, error) { + r.toolName = toolName + r.userID = userID + return &ratelimit.Decision{Allowed: true}, nil +} + +func TestNewMiddlewareDisabledWithoutConfig(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + }) + + require.NoError(t, err) + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRequiresRedisSessionStorage(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires Redis session storage") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRequiresRedisAddress(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires Redis session storage address") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRedisPingFailure(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + middleware, cleanup, err := NewMiddleware(ctx, Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: "127.0.0.1:1", + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to connect to Redis") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareInvalidRateLimitConfig(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: &ratelimittypes.RateLimitConfig{ + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 0, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: mr.Addr(), + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create rate limiter") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestRateLimitMiddlewarePerUserSharedAcrossTools(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &ratelimittypes.RateLimitConfig{ + PerUser: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "alice") + assert.Equal(t, http.StatusOK, first.Code) + + second := serveToolCall(t, handler, "backend_b_echo", "alice") + assert.Equal(t, http.StatusTooManyRequests, second.Code) + assertRateLimitedBody(t, second) +} + +func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: "backend_a_echo", + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "") + assert.Equal(t, http.StatusOK, first.Code) + + otherTool := serveToolCall(t, handler, "backend_b_echo", "") + assert.Equal(t, http.StatusOK, otherTool.Code) + + secondMatchingTool := serveToolCall(t, handler, "backend_a_echo", "") + assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code) +} + +func TestRateLimitHandlerPassesParsedResourceIDAndUserID(t *testing.T) { + t.Parallel() + + recorder := &recordingLimiter{} + handler := rateLimitHandler(recorder)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = withParsedMCPRequest(req, "tools/call", "backend_a_echo", 1) + req = withIdentity(req, "alice@example.com") + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "backend_a_echo", recorder.toolName) + assert.Equal(t, "alice@example.com", recorder.userID) +} + +func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) http.Handler { + t.Helper() + + mr := miniredis.RunT(t) + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: cfg, + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: mr.Addr(), + }, + }) + require.NoError(t, err) + require.NotNil(t, middleware) + require.NotNil(t, cleanup) + t.Cleanup(func() { + require.NoError(t, cleanup(context.Background())) + }) + + return middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) +} + +func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string) *httptest.ResponseRecorder { + t.Helper() + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = withParsedMCPRequest(req, "tools/call", toolName, 1) + if userID != "" { + req = withIdentity(req, userID) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + return w +} + +func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request { + parsed := &mcpparser.ParsedMCPRequest{ + Method: method, + ResourceID: resourceID, + ID: id, + IsRequest: true, + } + ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed) + return r.WithContext(ctx) +} + +func withIdentity(r *http.Request, subject string) *http.Request { + identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: subject}} + ctx := auth.WithIdentity(r.Context(), identity) + return r.WithContext(ctx) +} + +func sharedRateLimitConfig(maxTokens int32) *ratelimittypes.RateLimitConfig { + return &ratelimittypes.RateLimitConfig{ + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: maxTokens, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + } +} + +func assertRateLimitedBody(t *testing.T, recorder *httptest.ResponseRecorder) { + t.Helper() + + var resp map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + errObj := resp["error"].(map[string]any) + assert.Equal(t, float64(ratelimit.CodeRateLimited), errObj["code"]) + assert.Equal(t, ratelimit.MessageRateLimited, errObj["message"]) +} diff --git a/pkg/vmcp/server/middleware_test.go b/pkg/vmcp/server/middleware_test.go new file mode 100644 index 0000000000..b2044cd31f --- /dev/null +++ b/pkg/vmcp/server/middleware_test.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestApplyRateLimitingWrapsConfiguredMiddleware(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{ + RateLimitMiddleware: func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Rate-Limit-Test", "wrapped") + next.ServeHTTP(w, r) + }) + }, + }} + handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusAccepted, rec.Code) + assert.Equal(t, "wrapped", rec.Header().Get("X-Rate-Limit-Test")) +} + +func TestApplyRateLimitingPassesThroughWhenDisabled(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{}} + handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusAccepted, rec.Code) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index d070166600..dcdc0f856b 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -128,6 +128,10 @@ type Config struct { // Exposes OIDC discovery information about the protected resource. AuthInfoHandler http.Handler + // RateLimitMiddleware is the optional rate-limit middleware to apply after + // authentication and MCP request parsing. + RateLimitMiddleware func(http.Handler) http.Handler + // AuthServer is the optional embedded authorization server. // When non-nil, the routes returned by Routes() are registered on the mux // alongside the protected resource metadata endpoint. @@ -572,9 +576,9 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { } // MCP endpoint - apply middleware chain (wrapping order, execution happens in reverse): - // Code wraps: auth+parser → audit → discovery → annotation-enrichment → + // Code wraps: auth+parser → rate-limit → audit → discovery → annotation-enrichment → // authz → backend-enrichment → MCP-parsing → telemetry - // Execution order: recovery → header-val → auth+parser → audit → + // Execution order: recovery → header-val → auth+parser → rate-limit → audit → // discovery → annotation-enrichment → authz → backend-enrichment → // MCP-parsing → telemetry → handler @@ -652,6 +656,8 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { slog.Info("audit middleware enabled for MCP endpoints") } + mcpHandler = s.applyRateLimiting(mcpHandler) + // Apply authentication middleware if configured (runs first in chain) if s.config.AuthMiddleware != nil { mcpHandler = s.config.AuthMiddleware(mcpHandler) @@ -677,6 +683,14 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { return mux, nil } +func (s *Server) applyRateLimiting(next http.Handler) http.Handler { + if s.config.RateLimitMiddleware == nil { + return next + } + slog.Info("rate limit middleware enabled for MCP endpoints") + return s.config.RateLimitMiddleware(next) +} + // Start starts the Virtual MCP Server and begins serving requests. // //nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go new file mode 100644 index 0000000000..3a4f89f85b --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go @@ -0,0 +1,271 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package virtualmcp + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "os/exec" + "strings" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + mcpv1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +var _ = ginkgo.Describe("VirtualMCPServer Rate Limiting", ginkgo.Ordered, func() { + const ( + timeout = 5 * time.Minute + pollInterval = 2 * time.Second + oidcAudience = "vmcp-audience" + ) + + var ( + mcpGroupName string + backendName string + vmcpName string + redisName string + oidcName string + vmcpLocalPort int + oidcLocalPort int + vmcpPortForwardCleanup func() + oidcPortForwardCleanup func() + oidcCleanup func() + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpGroupName = fmt.Sprintf("e2e-rl-group-%d", ts) + backendName = fmt.Sprintf("e2e-rl-backend-%d", ts) + vmcpName = fmt.Sprintf("e2e-rl-vmcp-%d", ts) + redisName = fmt.Sprintf("e2e-rl-redis-%d", ts) + oidcName = fmt.Sprintf("e2e-rl-oidc-%d", ts) + + ginkgo.By("Deploying Redis") + deployRedis(redisName) + + ginkgo.By("Deploying parameterized OIDC server") + oidcIssuer, _, cleanup := DeployParameterizedOIDCServer( + ctx, k8sClient, oidcName, defaultNamespace, timeout, pollInterval, + ) + oidcCleanup = cleanup + var err error + oidcLocalPort, oidcPortForwardCleanup, err = startRateLimitServicePortForward(oidcName, 80) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + ginkgo.By("Creating MCPOIDCConfig") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPOIDCConfigSpec{ + Type: mcpv1beta1.MCPOIDCConfigTypeInline, + Inline: &mcpv1beta1.InlineOIDCSharedConfig{ + Issuer: oidcIssuer, + InsecureAllowHTTP: true, + JWKSAllowPrivateIP: true, + ProtectedResourceAllowPrivateIP: true, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Creating MCPGroup") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, defaultNamespace, + "E2E vMCP rate limiting group", timeout, pollInterval) + + ginkgo.By("Creating backend MCPServer") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for backend MCPServer to be ready") + gomega.Eventually(func() error { + server := &mcpv1beta1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: defaultNamespace, + }, server); err != nil { + return err + } + if server.Status.Phase != mcpv1beta1.MCPServerPhaseReady { + return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + ginkgo.By("Creating VirtualMCPServer with per-user rate limiting") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.VirtualMCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + RateLimiting: &mcpv1beta1.RateLimitConfig{ + PerUser: &mcpv1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + IncomingAuth: &mcpv1beta1.IncomingAuthConfig{ + Type: "oidc", + OIDCConfigRef: &mcpv1beta1.MCPOIDCConfigReference{ + Name: oidcName, + Audience: oidcAudience, + }, + }, + SessionStorage: &mcpv1beta1.SessionStorageConfig{ + Provider: mcpv1beta1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Port-forwarding VirtualMCPServer service") + vmcpLocalPort, vmcpPortForwardCleanup, err = startRateLimitServicePortForward(VMCPServiceName(vmcpName), 4483) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }) + + ginkgo.AfterAll(func() { + if vmcpPortForwardCleanup != nil { + vmcpPortForwardCleanup() + } + if oidcPortForwardCleanup != nil { + oidcPortForwardCleanup() + } + if oidcCleanup != nil { + oidcCleanup() + } + _ = k8sClient.Delete(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{Name: mcpGroupName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + }) + cleanupRedis(redisName) + }) + + ginkgo.It("rejects tools/call after the per-user limit is exceeded", func() { + token := fetchRateLimitOIDCToken(oidcLocalPort, "alice") + mcpClient := newRateLimitMCPClient(vmcpLocalPort, token) + defer mcpClient.Close() + + tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + toolName := firstEchoToolName(tools.Tools) + gomega.Expect(toolName).ToNot(gomega.BeEmpty()) + + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = map[string]any{"input": "ratelimittest"} + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).To(gomega.HaveOccurred()) + gomega.Expect(err.Error()).To(gomega.Or( + gomega.ContainSubstring("429"), + gomega.ContainSubstring("-32029"), + gomega.ContainSubstring("Rate limit exceeded"), + )) + }) +}) + +func fetchRateLimitOIDCToken(oidcPort int, subject string) string { + url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcPort, subject) + resp, err := http.Post(url, "application/x-www-form-urlencoded", nil) //nolint:noctx + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK)) + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed()) + gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty()) + return tokenResp.AccessToken +} + +func newRateLimitMCPClient(vmcpPort int, token string) *mcpclient.Client { + httpClient := &http.Client{ + Transport: &authRoundTripper{token: token, transport: http.DefaultTransport}, + Timeout: 30 * time.Second, + } + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpPort) + return InitializeMCPClientWithRetries(serverURL, 2*time.Minute, transport.WithHTTPBasicClient(httpClient)) +} + +func startRateLimitServicePortForward(serviceName string, servicePort int32) (int, func(), error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, nil, fmt.Errorf("failed to find free local port: %w", err) + } + localPort := listener.Addr().(*net.TCPAddr).Port + _ = listener.Close() + + kubeconfigArg := fmt.Sprintf("--kubeconfig=%s", kubeconfig) + //nolint:gosec // kubeconfig, serviceName, and ports are test-controlled values. + cmd := exec.Command("kubectl", kubeconfigArg, + "-n", defaultNamespace, "port-forward", + fmt.Sprintf("svc/%s", serviceName), + fmt.Sprintf("%d:%d", localPort, servicePort)) + if err := cmd.Start(); err != nil { + return 0, nil, fmt.Errorf("failed to start port-forward to service %s: %w", serviceName, err) + } + + cleanup := func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + } + + for range 30 { + conn, dialErr := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", localPort), 500*time.Millisecond) + if dialErr == nil { + _ = conn.Close() + return localPort, cleanup, nil + } + time.Sleep(500 * time.Millisecond) + } + + cleanup() + return 0, nil, fmt.Errorf("port-forward to service %s never became ready on localhost:%d", serviceName, localPort) +} + +func firstEchoToolName(tools []mcp.Tool) string { + for _, tool := range tools { + if tool.Name == "echo" || strings.HasSuffix(tool.Name, "_echo") { + return tool.Name + } + } + return "" +}