diff --git a/cmd/thv/app/proxy.go b/cmd/thv/app/proxy.go index 80ca864a23..cc2a03c0ae 100644 --- a/cmd/thv/app/proxy.go +++ b/cmd/thv/app/proxy.go @@ -24,6 +24,7 @@ import ( "github.com/stacklok/toolhive/pkg/oauthproto/tokenexchange" "github.com/stacklok/toolhive/pkg/transport" "github.com/stacklok/toolhive/pkg/transport/middleware" + "github.com/stacklok/toolhive/pkg/transport/middleware/origin" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -110,9 +111,10 @@ Dynamic client registration (automatic OAuth client setup): } var ( - proxyHost string - proxyPort int - proxyTargetURI string + proxyHost string + proxyPort int + proxyTargetURI string + proxyAllowedOrigins []string resourceURL string // Explicit resource URL for OAuth discovery endpoint (RFC 9728) @@ -133,6 +135,10 @@ const ( func init() { proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)") + proxyCmd.Flags().StringArrayVar(&proxyAllowedOrigins, "allowed-origins", nil, + "Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; "+ + "loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+ + "no value is supplied. Example: https://my-mcp.example.com") proxyCmd.Flags().StringVar( &proxyTargetURI, "target-uri", @@ -226,6 +232,22 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error { // Create middlewares slice for incoming request authentication var middlewares []types.NamedMiddleware + // Origin-header validation (DNS-rebinding protection per MCP 2025-11-25 + // §"Security Warning"). Added first so disallowed Origins are rejected + // before authentication or any outbound token acquisition runs. 
+ if allowed := origin.ResolveAllowedOrigins(proxyHost, port, proxyAllowedOrigins); len(allowed) > 0 { + middlewares = append(middlewares, types.NamedMiddleware{ + Name: origin.MiddlewareType, + Function: origin.NewHandler(allowed), + }) + } else { + slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind", + "host", proxyHost, + "port", port, + "hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection", + ) + } + // Get OIDC configuration if enabled (for protecting the proxy endpoint) oidcConfig := getProxyOIDCConfig(cmd) diff --git a/cmd/thv/app/run_flags.go b/cmd/thv/app/run_flags.go index 4a4ae37504..46aa4b9753 100644 --- a/cmd/thv/app/run_flags.go +++ b/cmd/thv/app/run_flags.go @@ -141,6 +141,12 @@ type RunFlags struct { RemoteForwardHeaders []string RemoteForwardHeadersSecret []string + // AllowedOrigins is the HTTP Origin-header allowlist for DNS-rebinding protection + // (MCP 2025-11-25 §"Security Warning"). Empty with a loopback host auto-derives + // loopback-only defaults; empty with a non-loopback host disables the check + // (operator must supply explicit origins for public bind). + AllowedOrigins []string + // Runtime configuration RuntimeImage string RuntimeAddPackages []string @@ -160,6 +166,10 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) { cmd.Flags().StringVar(&config.Name, "name", "", "Name of the MCP server (default to auto-generated from image)") cmd.Flags().StringVar(&config.Group, "group", "default", "Name of the group this workload should belong to") cmd.Flags().StringVar(&config.Host, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)") + cmd.Flags().StringArrayVar(&config.AllowedOrigins, "allowed-origins", nil, + "Exact-match allowlist for the HTTP Origin header (repeatable). 
Recommended when binding publicly; "+ + "loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+ + "no value is supplied. Example: https://my-mcp.example.com") cmd.Flags().IntVar(&config.ProxyPort, "proxy-port", 0, "Port for the HTTP proxy to listen on (host port)") cmd.Flags().IntVar(&config.TargetPort, "target-port", 0, "Port for the container to expose (only applicable to SSE or Streamable HTTP transport)") @@ -685,6 +695,7 @@ func buildRunnerConfig( PrintOverlays: runFlags.PrintOverlays, }), runner.WithPublish(runFlags.Publish), + runner.WithAllowedOrigins(runFlags.AllowedOrigins), } opts = append(opts, extraOpts...) diff --git a/docs/cli/thv_proxy.md b/docs/cli/thv_proxy.md index be2e8d92d2..6cbc09c22e 100644 --- a/docs/cli/thv_proxy.md +++ b/docs/cli/thv_proxy.md @@ -97,6 +97,7 @@ thv proxy [flags] SERVER_NAME ### Options ``` + --allowed-origins stringArray Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; loopback binds derive a default allowlist automatically, non-loopback binds log a warning when no value is supplied. Example: https://my-mcp.example.com -h, --help help for proxy --host string Host for the HTTP proxy to listen on (IP or hostname) (default "127.0.0.1") --oidc-audience string Expected audience for the token diff --git a/docs/cli/thv_run.md b/docs/cli/thv_run.md index b69d7b6cd1..03be31ffa9 100644 --- a/docs/cli/thv_run.md +++ b/docs/cli/thv_run.md @@ -112,6 +112,7 @@ thv run [flags] SERVER_OR_IMAGE_OR_PROTOCOL [-- ARGS...] ``` --allow-docker-gateway Allow outbound connections to Docker gateway addresses (host.docker.internal, gateway.docker.internal, 172.17.0.1). Only applies when --isolate-network is set. These are blocked by default even when insecure_allow_all is enabled. + --allowed-origins stringArray Exact-match allowlist for the HTTP Origin header (repeatable). 
Recommended when binding publicly; loopback binds derive a default allowlist automatically, non-loopback binds log a warning when no value is supplied. Example: https://my-mcp.example.com --audit-config string Path to the audit configuration file --authz-config string Path to the authorization configuration file --ca-cert string Path to a custom CA certificate file to use for container builds diff --git a/docs/server/docs.go b/docs/server/docs.go index 9ca280736f..06e0ab5631 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -1237,6 +1237,14 @@ const docTemplate = `{ "description": "AllowDockerGateway permits outbound connections to Docker gateway addresses\n(host.docker.internal, gateway.docker.internal, 172.17.0.1). These are\nblocked by default in the egress proxy even when InsecureAllowAll is set.\nOnly applicable to Docker deployments with network isolation enabled.", "type": "boolean" }, + "allowed_origins": { + "description": "AllowedOrigins is the allowlist of values accepted on the HTTP Origin header,\nused for DNS-rebinding protection per MCP 2025-11-25 §\"Security Warning\".\nWhen empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default\nloopback-only allowlist is derived at middleware-wiring time.\nWhen empty and Host is non-loopback, the middleware is disabled — operators\nexposing the proxy publicly must configure an explicit allowlist.", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "audit_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config" }, diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 8127a52c45..62c0ca2f79 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -1230,6 +1230,14 @@ "description": "AllowDockerGateway permits outbound connections to Docker gateway addresses\n(host.docker.internal, gateway.docker.internal, 172.17.0.1). 
These are\nblocked by default in the egress proxy even when InsecureAllowAll is set.\nOnly applicable to Docker deployments with network isolation enabled.", "type": "boolean" }, + "allowed_origins": { + "description": "AllowedOrigins is the allowlist of values accepted on the HTTP Origin header,\nused for DNS-rebinding protection per MCP 2025-11-25 §\"Security Warning\".\nWhen empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default\nloopback-only allowlist is derived at middleware-wiring time.\nWhen empty and Host is non-loopback, the middleware is disabled — operators\nexposing the proxy publicly must configure an explicit allowlist.", + "items": { + "type": "string" + }, + "type": "array", + "uniqueItems": false + }, "audit_config": { "$ref": "#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config" }, diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6bbded4bce..f09ab6e5ff 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -1254,6 +1254,18 @@ components: blocked by default in the egress proxy even when InsecureAllowAll is set. Only applicable to Docker deployments with network isolation enabled. type: boolean + allowed_origins: + description: |- + AllowedOrigins is the allowlist of values accepted on the HTTP Origin header, + used for DNS-rebinding protection per MCP 2025-11-25 §"Security Warning". + When empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default + loopback-only allowlist is derived at middleware-wiring time. + When empty and Host is non-loopback, the middleware is disabled — operators + exposing the proxy publicly must configure an explicit allowlist. 
+ items: + type: string + type: array + uniqueItems: false audit_config: $ref: '#/components/schemas/github_com_stacklok_toolhive_pkg_audit.Config' audit_config_path: diff --git a/pkg/runner/config.go b/pkg/runner/config.go index 399cff3729..50183f8ba0 100644 --- a/pkg/runner/config.go +++ b/pkg/runner/config.go @@ -103,6 +103,14 @@ type RunConfig struct { // TargetHost is the host to forward traffic to (only applicable to SSE transport) TargetHost string `json:"target_host,omitempty" yaml:"target_host,omitempty"` + // AllowedOrigins is the allowlist of values accepted on the HTTP Origin header, + // used for DNS-rebinding protection per MCP 2025-11-25 §"Security Warning". + // When empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default + // loopback-only allowlist is derived at middleware-wiring time. + // When empty and Host is non-loopback, the middleware is disabled — operators + // exposing the proxy publicly must configure an explicit allowlist. + AllowedOrigins []string `json:"allowed_origins,omitempty" yaml:"allowed_origins,omitempty"` + // Publish lists ports to publish to the host in format "hostPort:containerPort" Publish []string `json:"publish,omitempty" yaml:"publish,omitempty"` diff --git a/pkg/runner/config_builder.go b/pkg/runner/config_builder.go index 97d9ebd5b5..b5df34ba78 100644 --- a/pkg/runner/config_builder.go +++ b/pkg/runner/config_builder.go @@ -331,6 +331,18 @@ func WithAllowDockerGateway(allow bool) RunConfigBuilderOption { } } +// WithAllowedOrigins sets the HTTP Origin-header allowlist used for +// DNS-rebinding protection (MCP 2025-11-25 §"Security Warning"). +// An empty slice defers the choice to middleware wiring, which derives a +// loopback-only default when the bind host is loopback and otherwise leaves +// the middleware disabled. 
+func WithAllowedOrigins(origins []string) RunConfigBuilderOption { + return func(b *runConfigBuilder) error { + b.config.AllowedOrigins = origins + return nil + } +} + // WithTrustProxyHeaders sets whether to trust X-Forwarded-* headers from reverse proxies func WithTrustProxyHeaders(trust bool) RunConfigBuilderOption { return func(b *runConfigBuilder) error { diff --git a/pkg/runner/middleware.go b/pkg/runner/middleware.go index 068d35a455..4164024461 100644 --- a/pkg/runner/middleware.go +++ b/pkg/runner/middleware.go @@ -5,6 +5,7 @@ package runner import ( "fmt" + "log/slog" "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" @@ -21,6 +22,7 @@ import ( "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/telemetry" headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware" + "github.com/stacklok/toolhive/pkg/transport/middleware/origin" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/usagemetrics" "github.com/stacklok/toolhive/pkg/webhook/mutating" @@ -45,6 +47,7 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory { audit.MiddlewareType: audit.CreateMiddleware, recovery.MiddlewareType: recovery.CreateMiddleware, headerfwd.HeaderForwardMiddlewareName: headerfwd.CreateMiddleware, + origin.MiddlewareType: origin.CreateMiddleware, validating.MiddlewareType: validating.CreateMiddleware, mutating.MiddlewareType: mutating.CreateMiddleware, } @@ -57,14 +60,20 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory { func PopulateMiddlewareConfigs(config *RunConfig) error { var middlewareConfigs []types.MiddlewareConfig // TODO: Consider extracting other middleware setup into helper functions like addUsageMetricsMiddleware + // + // NOTE: Origin-validation middleware is intentionally NOT added here. 
It is + // wired centrally in runner.Run (via prependOriginMiddleware) for both the + // operator/proxyrunner path (this function) and the CLI path + // (WithMiddlewareFromFlags), because that is the only place where the + // effective Host/Port/AllowedOrigins are fully resolved. // Authentication middleware (always present) authParams := auth.MiddlewareParams{ OIDCConfig: config.OIDCConfig, } - authConfig, err := types.NewMiddlewareConfig(auth.MiddlewareType, authParams) - if err != nil { - return fmt.Errorf("failed to create auth middleware config: %w", err) + authConfig, authErr := types.NewMiddlewareConfig(auth.MiddlewareType, authParams) + if authErr != nil { + return fmt.Errorf("failed to create auth middleware config: %w", authErr) } middlewareConfigs = append(middlewareConfigs, *authConfig) @@ -72,7 +81,7 @@ func PopulateMiddlewareConfigs(config *RunConfig) error { // This exchanges ToolHive JWTs for upstream IdP tokens when embedded auth server is used. // IMPORTANT: Must run BEFORE token exchange middleware so it can read the `tsid` claim // from the original ToolHive JWT before any token modification occurs. - middlewareConfigs, err = addUpstreamSwapMiddleware(middlewareConfigs, config) + middlewareConfigs, err := addUpstreamSwapMiddleware(middlewareConfigs, config) if err != nil { return err } @@ -421,6 +430,45 @@ func addAWSStsMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig return append(middlewares, *awsStsMwConfig), nil } +// prependOriginMiddleware prepends Origin-header validation middleware for +// DNS-rebind protection per MCP 2025-11-25 §"Security Warning". It is placed at +// the front of the chain so disallowed Origin values are rejected before +// authentication or any business logic runs. Default-derivation logic lives in +// origin.ResolveAllowedOrigins so the standalone `thv proxy` command and the +// runner path agree on behavior. 
+// +// This is called from runner.Run after both middleware-population paths +// (PopulateMiddlewareConfigs and WithMiddlewareFromFlags) have run, because +// that is the only point where the effective Host/Port/AllowedOrigins are +// fully resolved — the CLI builder defers port resolution to validateConfig. +// +// When the effective allowlist is empty — which happens when the operator +// binds to a non-loopback host without supplying --allowed-origins — the +// middleware is skipped entirely and a WARN is logged so the security-disabled +// state is visible in operator logs. A follow-up PR hardens the non-loopback +// path by requiring an explicit opt-in flag (see audit row 22). +func prependOriginMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) { + allowed := origin.ResolveAllowedOrigins(config.Host, config.Port, config.AllowedOrigins) + if len(allowed) == 0 { + slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind", + "host", config.Host, + "port", config.Port, + "hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection", + ) + return middlewares, nil + } + + params := origin.MiddlewareParams{AllowedOrigins: allowed} + mwCfg, err := types.NewMiddlewareConfig(origin.MiddlewareType, params) + if err != nil { + return nil, fmt.Errorf("failed to create origin middleware config: %w", err) + } + // Prepend so Origin validation is the outermost wrapper (runs first at + // request time). Build a new slice to avoid mutating the caller's backing + // array. + return append([]types.MiddlewareConfig{*mwCfg}, middlewares...), nil +} + // addRateLimitMiddleware adds rate limit middleware if configured. 
func addRateLimitMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) { if config.RateLimitConfig == nil { diff --git a/pkg/runner/middleware_test.go b/pkg/runner/middleware_test.go index 19ec497c6a..37b2265a6d 100644 --- a/pkg/runner/middleware_test.go +++ b/pkg/runner/middleware_test.go @@ -29,6 +29,7 @@ import ( "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/telemetry" headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware" + "github.com/stacklok/toolhive/pkg/transport/middleware/origin" "github.com/stacklok/toolhive/pkg/transport/types" "github.com/stacklok/toolhive/pkg/webhook" "github.com/stacklok/toolhive/pkg/webhook/mutating" @@ -123,6 +124,65 @@ func TestAddHeaderForwardMiddleware(t *testing.T) { } } +func TestPrependOriginMiddleware(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *RunConfig + wantPrepended bool + wantAllowedCount int + }{ + { + name: "non-loopback bind without explicit allowlist skips middleware", + config: &RunConfig{Host: "0.0.0.0", Port: 8080}, + wantPrepended: false, + }, + { + name: "zero port skips middleware", + config: &RunConfig{Host: "127.0.0.1", Port: 0}, + wantPrepended: false, + }, + { + name: "loopback bind derives default allowlist and prepends", + config: &RunConfig{Host: "127.0.0.1", Port: 8080}, + wantPrepended: true, + wantAllowedCount: 3, // localhost + 127.0.0.1 + [::1] + }, + { + name: "explicit allowlist on non-loopback bind prepends", + config: &RunConfig{Host: "0.0.0.0", Port: 8080, AllowedOrigins: []string{"https://app.example.com"}}, + wantPrepended: true, + wantAllowedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Seed with an existing entry so we can prove origin is prepended, + // not appended — the security intent requires it to run first. 
+ initial := []types.MiddlewareConfig{{Type: auth.MiddlewareType}} + got, err := prependOriginMiddleware(initial, tt.config) + require.NoError(t, err) + + if !tt.wantPrepended { + assert.Equal(t, initial, got, "middleware slice should be unchanged") + return + } + + require.Len(t, got, len(initial)+1) + assert.Equal(t, origin.MiddlewareType, got[0].Type, "origin middleware must be first in the chain") + assert.Equal(t, auth.MiddlewareType, got[1].Type, "pre-existing middleware must follow origin") + + var params origin.MiddlewareParams + require.NoError(t, json.Unmarshal(got[0].Parameters, &params)) + assert.Len(t, params.AllowedOrigins, tt.wantAllowedCount) + }) + } +} + func TestPopulateMiddlewareConfigs_HeaderForward(t *testing.T) { t.Parallel() diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 079a7d88db..c6ace9507d 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -265,6 +265,20 @@ func (r *Runner) Run(ctx context.Context) error { } } + // Origin-header validation (DNS-rebinding protection per MCP 2025-11-25 + // §"Security Warning") is wired here, after both middleware-population + // paths, because it is the single place where Host/Port/AllowedOrigins are + // fully resolved: the CLI builder (WithMiddlewareFromFlags) defers port + // resolution to validateConfig, so the effective port is not known at + // builder time. Prepending keeps Origin validation at the front of the + // chain so disallowed Origins are rejected before authentication or any + // business logic runs. + var err error + r.Config.MiddlewareConfigs, err = prependOriginMiddleware(r.Config.MiddlewareConfigs, r.Config) + if err != nil { + return fmt.Errorf("failed to add origin middleware: %w", err) + } + // Initialize embedded auth server if configured. // This must happen before middleware creation so that the upstream token // service is available to middleware factories (e.g., upstreamswap). 
diff --git a/pkg/transport/middleware/origin/origin.go b/pkg/transport/middleware/origin/origin.go new file mode 100644 index 0000000000..302264751d --- /dev/null +++ b/pkg/transport/middleware/origin/origin.go @@ -0,0 +1,277 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package origin provides HTTP middleware that enforces MCP Origin header +// validation (DNS-rebinding protection) per MCP 2025-11-25 §"Security Warning" +// (https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#security-warning). +// +// When the Origin header is present on an inbound request, it MUST exactly +// match one of the configured allowed origins. Otherwise the middleware +// responds with HTTP 403 and a JSON-RPC error body. Requests without an +// Origin header (typical for non-browser clients) are permitted through. +package origin + +import ( + "encoding/json" + "fmt" + "log/slog" + "maps" + "net" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/stacklok/toolhive/pkg/transport/types" +) + +const ( + // MiddlewareType is the type identifier registered in the middleware factory map. + MiddlewareType = "origin" + + // jsonRPCCodeInvalidRequest is the JSON-RPC 2.0 error code for an invalid + // request. We reuse it for rejected Origin values because the request is + // not well-formed from the server's security policy perspective. + jsonRPCCodeInvalidRequest int64 = -32600 + + // forbiddenBodyFallback is returned if JSON marshalling of the error body + // fails (should never happen with simple map types). + forbiddenBodyFallback = `{"jsonrpc":"2.0","error":{"code":-32600,"message":"Origin not allowed"},"id":null}` +) + +// MiddlewareParams holds the parameters for the origin middleware factory. +type MiddlewareParams struct { + // AllowedOrigins is the exact-match allowlist of acceptable Origin values. + // An empty list disables the middleware (requests pass through unchanged). 
+ AllowedOrigins []string `json:"allowed_origins"` +} + +// FactoryMiddleware wraps origin-validation as a factory-pattern middleware. +type FactoryMiddleware struct { + handler types.MiddlewareFunction +} + +// Handler returns the middleware function used by the proxy. +func (m *FactoryMiddleware) Handler() types.MiddlewareFunction { + return m.handler +} + +// Close releases any resources held by the middleware. +func (*FactoryMiddleware) Close() error { + return nil +} + +// CreateMiddleware is the factory function registered in +// runner.GetSupportedMiddlewareFactories. +// +// If params.AllowedOrigins is empty the factory still registers a pass-through +// handler so the middleware slot is occupied, but logs at Warn level to make +// the security-disabled state visible in operator logs. Callers that want to +// avoid registration entirely should skip calling this factory (see +// pkg/runner.prependOriginMiddleware). +func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRunner) error { + var params MiddlewareParams + if err := json.Unmarshal(config.Parameters, &params); err != nil { + return fmt.Errorf("failed to unmarshal origin middleware parameters: %w", err) + } + + if len(params.AllowedOrigins) == 0 { + slog.Warn("origin middleware registered with empty allowlist; Origin validation disabled") + } + + handler := NewHandler(params.AllowedOrigins) + runner.AddMiddleware(MiddlewareType, &FactoryMiddleware{handler: handler}) + return nil +} + +// NewHandler returns a middleware function that enforces Origin header +// validation against the provided allowlist. It is the single entry point used +// by both the factory path (CreateMiddleware) and callers that build their +// middleware chain directly (e.g. `thv proxy`). +// +// What this solves: DNS-rebinding protection per MCP 2025-11-25 §"Security +// Warning" — requests whose Origin header is present and not in allowedOrigins +// receive HTTP 403 with a JSON-RPC error body. 
+// +// What this does NOT solve: CORS, CSRF token validation, authentication, or +// Origin-header injection via trusted reverse proxies (the caller's reverse +// proxy must deduplicate Origin headers upstream). +// +// An empty allowedOrigins slice produces a pass-through handler — the caller +// is responsible for deciding whether that is acceptable (e.g. when bind is +// loopback-only and the caller derived an allowlist via ResolveAllowedOrigins). +// +// Matching rules: exact match on byte representation except that the scheme +// and host portions of the Origin value are lowercased (RFC 6454 §4: scheme +// and host are ASCII-case-insensitive). Configured allowlist entries are +// canonicalized once at construction time. +func NewHandler(allowedOrigins []string) types.MiddlewareFunction { + if len(allowedOrigins) == 0 { + return func(next http.Handler) http.Handler { return next } + } + + // Build a set for O(1) lookups. Entries are canonicalized so that + // case-variant Origin values (RFC 6454 §4 makes scheme + host case- + // insensitive) match predictably. Preserve the sorted list for logging. + allowedSet := make(map[string]struct{}, len(allowedOrigins)) + for _, o := range allowedOrigins { + allowedSet[canonicalizeOrigin(o)] = struct{}{} + } + slog.Debug("origin middleware configured", + "allowed_origin_count", len(allowedSet), + "allowed_origins", slices.Sorted(maps.Keys(allowedSet)), + ) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Reject requests with multiple Origin headers outright — the + // Fetch spec defines Origin as a single-value header and browsers + // never legitimately send more than one. Splitting / merging at an + // upstream proxy is the only way this fires. 
+ if values := r.Header.Values("Origin"); len(values) > 1 { + slog.Warn("rejecting request with multiple Origin headers", + "count", len(values), + "method", r.Method, + "path", r.URL.Path, + "remote", r.RemoteAddr, + ) + writeForbidden(w) + return + } + + origin := r.Header.Get("Origin") + if origin == "" { + // MCP spec §"Security Warning" only mandates validation when + // the header is present. Non-browser clients (stdio bridges, + // SDK clients) typically omit Origin entirely. + next.ServeHTTP(w, r) + return + } + if _, ok := allowedSet[canonicalizeOrigin(origin)]; !ok { + slog.Warn("rejecting request with disallowed Origin", + "origin", origin, + "method", r.Method, + "path", r.URL.Path, + "remote", r.RemoteAddr, + ) + writeForbidden(w) + return + } + next.ServeHTTP(w, r) + }) + } +} + +// canonicalizeOrigin normalizes an Origin value for exact-match comparison. +// It parses the value with net/url.Parse and rebuilds it as +// "scheme://host[:port]" with the scheme and host lowercased (RFC 6454 §4 +// makes both ASCII-case-insensitive) and the port preserved verbatim. Using +// the standard parser handles IPv6 bracket literals, percent-encoding, and the +// host/port split correctly instead of hand-rolling them. +// +// Per RFC 6454 §6, a serialized origin has no userinfo, path, query, or +// fragment. Any input that carries userinfo (e.g. "https://user:pass@host"), +// or that net/url cannot parse, or that lacks a scheme or host, is returned +// with a "\x00invalid:" sentinel prefix so it can never collide with a +// legitimate allowlist entry. This makes such values fail closed: a malformed +// configured entry will not match anything, and a malformed request Origin +// will not match the allowlist. 
+func canonicalizeOrigin(raw string) string { + if raw == "" { + return raw + } + const invalid = "\x00invalid:" // sentinel that no real Origin can produce + u, err := url.Parse(raw) + if err != nil { + return invalid + raw + } + // A serialized origin (RFC 6454 §6) carries no userinfo and must have both + // a scheme and a host. Reject anything else to avoid ambiguous matches. + if u.Scheme == "" || u.Host == "" || u.User != nil { + return invalid + raw + } + scheme := strings.ToLower(u.Scheme) + // u.Hostname() strips IPv6 brackets and the port; u.Port() returns the port + // (possibly empty). Re-add brackets for IPv6 literals (those containing ":"). + host := strings.ToLower(u.Hostname()) + if strings.Contains(host, ":") { + host = "[" + host + "]" + } + if port := u.Port(); port != "" { + return scheme + "://" + host + ":" + port + } + return scheme + "://" + host +} + +// ResolveAllowedOrigins picks the effective Origin allowlist for a proxy +// listener. Resolution order: +// 1. If explicit is non-empty, use it verbatim. +// 2. Otherwise, if host is a loopback IP or the string "localhost", and port +// is valid, return loopback-only defaults +// (http://localhost:PORT, http://127.0.0.1:PORT, http://[::1]:PORT). +// 3. Otherwise, return nil — operators exposing the proxy publicly must +// configure an explicit allowlist. +// +// Shared by the runner middleware-config helper (pkg/runner) and the +// standalone `thv proxy` command to keep the default-derivation logic in one +// place; exported because the `thv proxy` call site is outside the runner +// package and cannot reach an internal helper. +// +// What this does NOT solve: it does not validate that `explicit` entries are +// well-formed Origin values. Callers that pass operator-supplied slices must +// rely on the middleware's canonical matching to either accept or reject +// malformed entries at request time (they will simply fail to match). 
+func ResolveAllowedOrigins(host string, port int, explicit []string) []string { + if len(explicit) > 0 { + return explicit + } + if port <= 0 { + return nil + } + if !isLoopbackHost(host) { + return nil + } + return []string{ + fmt.Sprintf("http://localhost:%d", port), + fmt.Sprintf("http://127.0.0.1:%d", port), + fmt.Sprintf("http://[::1]:%d", port), + } +} + +// isLoopbackHost reports whether host refers to a loopback address. Accepts +// the literal string "localhost" plus any IP literal that net.ParseIP +// classifies as loopback (e.g. 127.0.0.0/8, ::1). IPv6 is currently rejected +// by cmd/thv/app/run.go:ValidateAndNormaliseHostFlag; this helper nevertheless +// handles it so future IPv6 support does not silently lose default Origin +// protection. +func isLoopbackHost(host string) bool { + if host == "localhost" { + return true + } + // Strip bracket form for IPv6 literals: "[::1]" → "::1". + trimmed := strings.TrimSuffix(strings.TrimPrefix(host, "["), "]") + if ip := net.ParseIP(trimmed); ip != nil { + return ip.IsLoopback() + } + return false +} + +// writeForbidden emits a 403 response with a JSON-RPC error body (id: null). +func writeForbidden(w http.ResponseWriter) { + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "error": map[string]any{ + "code": jsonRPCCodeInvalidRequest, + "message": "Origin not allowed", + }, + "id": nil, + }) + if err != nil { + // Marshal of a static map should never fail; fall back to a literal. 
+ body = []byte(forbiddenBodyFallback) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + //nolint:gosec // G104: writing a static JSON error response to an HTTP client + _, _ = w.Write(body) +} diff --git a/pkg/transport/middleware/origin/origin_test.go b/pkg/transport/middleware/origin/origin_test.go new file mode 100644 index 0000000000..805ddb1ae7 --- /dev/null +++ b/pkg/transport/middleware/origin/origin_test.go @@ -0,0 +1,379 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package origin + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/transport/types" + typesmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks" +) + +// runMiddleware applies the middleware to a stub handler, issues a request +// with the given Origin header (skipped when empty), and returns the response. 
+func runMiddleware( + t *testing.T, + allowedOrigins []string, + origin string, +) (*httptest.ResponseRecorder, bool) { + t.Helper() + var nextCalled bool + mw := NewHandler(allowedOrigins) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + if origin != "" { + req.Header.Set("Origin", origin) + } + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + return rec, nextCalled +} + +func TestOriginMiddleware_RequestPermitted(t *testing.T) { + t.Parallel() + tests := []struct { + name string + allowedOrigins []string + origin string + }{ + { + name: "empty allowlist disables middleware", + allowedOrigins: nil, + origin: "http://evil.example", + }, + { + name: "missing Origin header passes", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "", + }, + { + name: "exact match passes", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://localhost:8080", + }, + { + name: "match against second entry", + allowedOrigins: []string{"http://localhost:8080", "http://127.0.0.1:8080"}, + origin: "http://127.0.0.1:8080", + }, + { + name: "case-insensitive scheme match (RFC 6454)", + allowedOrigins: []string{"http://app.example.com"}, + origin: "HTTP://app.example.com", + }, + { + name: "case-insensitive host match (RFC 6454)", + allowedOrigins: []string{"https://App.Example.com"}, + origin: "https://app.example.com", + }, + { + name: "mixed-case allowlist entry matches lowercase Origin", + allowedOrigins: []string{"HTTPS://App.Example.com:443"}, + origin: "https://app.example.com:443", + }, + { + name: "IPv6 bracket literal matches", + allowedOrigins: []string{"http://[::1]:8080"}, + origin: "http://[::1]:8080", + }, + { + name: "IPv6 uppercase hex folds to match (RFC 6454 host case-insensitive)", + allowedOrigins: []string{"http://[fe80::1]:8080"}, + origin: "http://[FE80::1]:8080", + }, + } + for 
_, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + rec, nextCalled := runMiddleware(t, tc.allowedOrigins, tc.origin) + assert.True(t, nextCalled, "next handler must be invoked") + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +func TestOriginMiddleware_RequestRejected(t *testing.T) { + t.Parallel() + tests := []struct { + name string + allowedOrigins []string + origin string + }{ + { + name: "different host rejected", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://evil.example", + }, + { + name: "different port rejected (exact match required)", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://localhost:9090", + }, + { + name: "different scheme rejected", + allowedOrigins: []string{"https://app.example.com"}, + origin: "http://app.example.com", + }, + { + // RFC 6454 §6: a serialized origin carries no userinfo. An attacker + // must not be able to smuggle a trusted host into the userinfo + // component (https://localhost:8080@evil.example). 
+ name: "Origin with userinfo rejected even if host looks allowed", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://localhost:8080@evil.example", + }, + { + name: "allowlist entry with userinfo never matches", + allowedOrigins: []string{"http://user:pass@localhost:8080"}, + origin: "http://localhost:8080", + }, + { + name: "malformed Origin (control char) rejected", + allowedOrigins: []string{"http://localhost:8080"}, + origin: "http://local\x7fhost:8080", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + rec, nextCalled := runMiddleware(t, tc.allowedOrigins, tc.origin) + assertForbiddenJSONRPC(t, rec, nextCalled) + }) + } +} + +func TestOriginMiddleware_MultipleOriginHeadersRejected(t *testing.T) { + t.Parallel() + + var nextCalled bool + mw := NewHandler([]string{"http://localhost:8080"}) + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Add("Origin", "http://localhost:8080") + req.Header.Add("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assertForbiddenJSONRPC(t, rec, nextCalled) +} + +// assertForbiddenJSONRPC validates that rec carries a 403 with a canonical +// JSON-RPC error body and that the inner handler was never invoked. 
+func assertForbiddenJSONRPC(t *testing.T, rec *httptest.ResponseRecorder, nextCalled bool) { + t.Helper() + assert.False(t, nextCalled, "next handler must NOT be invoked") + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + var parsed struct { + JSONRPC string `json:"jsonrpc"` + Error struct { + Code int64 `json:"code"` + Message string `json:"message"` + } `json:"error"` + ID any `json:"id"` + } + require.NoError(t, json.Unmarshal(body, &parsed)) + assert.Equal(t, "2.0", parsed.JSONRPC) + assert.Equal(t, jsonRPCCodeInvalidRequest, parsed.Error.Code) + assert.Equal(t, "Origin not allowed", parsed.Error.Message) + assert.Nil(t, parsed.ID) +} + +func TestNewHandler_RejectsDisallowedOrigin(t *testing.T) { + t.Parallel() + mw := NewHandler([]string{"http://localhost:8080"}) + require.NotNil(t, mw) + + handler := mw(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req.Header.Set("Origin", "http://evil.example") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestCreateMiddleware_Factory(t *testing.T) { + t.Parallel() + + t.Run("valid parameters register middleware", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + params := MiddlewareParams{AllowedOrigins: []string{"http://localhost:8080"}} + cfg, err := types.NewMiddlewareConfig(MiddlewareType, params) + require.NoError(t, err) + + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&FactoryMiddleware{})). 
+ Times(1) + + require.NoError(t, CreateMiddleware(cfg, runner)) + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + cfg := &types.MiddlewareConfig{ + Type: MiddlewareType, + Parameters: json.RawMessage(`{not json}`), + } + + err := CreateMiddleware(cfg, runner) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal origin middleware parameters") + }) + + t.Run("empty allowlist still registers pass-through", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + runner := typesmocks.NewMockMiddlewareRunner(ctrl) + + cfg, err := types.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{AllowedOrigins: nil}) + require.NoError(t, err) + + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&FactoryMiddleware{})). + Times(1) + + require.NoError(t, CreateMiddleware(cfg, runner)) + }) +} + +func TestFactoryMiddleware_Lifecycle(t *testing.T) { + t.Parallel() + + mw := &FactoryMiddleware{handler: NewHandler([]string{"http://localhost:8080"})} + require.NotNil(t, mw.Handler()) + require.NoError(t, mw.Close()) +} + +func TestResolveAllowedOrigins(t *testing.T) { + t.Parallel() + tests := []struct { + name string + host string + port int + explicit []string + want []string + }{ + { + name: "explicit list wins over loopback derivation", + host: "127.0.0.1", + port: 8080, + explicit: []string{"https://app.example.com"}, + want: []string{"https://app.example.com"}, + }, + { + name: "loopback IPv4 auto-derives localhost defaults", + host: "127.0.0.1", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "non-standard loopback IPv4 auto-derives defaults", + host: "127.0.0.2", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "localhost string 
auto-derives defaults", + host: "localhost", + port: 8080, + want: []string{ + "http://localhost:8080", + "http://127.0.0.1:8080", + "http://[::1]:8080", + }, + }, + { + name: "IPv6 loopback ::1 auto-derives defaults", + host: "::1", + port: 9090, + want: []string{ + "http://localhost:9090", + "http://127.0.0.1:9090", + "http://[::1]:9090", + }, + }, + { + name: "IPv6 loopback in bracket form auto-derives defaults", + host: "[::1]", + port: 9090, + want: []string{ + "http://localhost:9090", + "http://127.0.0.1:9090", + "http://[::1]:9090", + }, + }, + { + name: "non-loopback host with empty explicit returns nil", + host: "0.0.0.0", + port: 8080, + want: nil, + }, + { + name: "public host with empty explicit returns nil", + host: "192.168.1.10", + port: 8080, + want: nil, + }, + { + name: "garbage host returns nil", + host: "not-a-host", + port: 8080, + want: nil, + }, + { + name: "zero port disables derivation", + host: "127.0.0.1", + port: 0, + want: nil, + }, + { + name: "negative port disables derivation", + host: "127.0.0.1", + port: -1, + want: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := ResolveAllowedOrigins(tc.host, tc.port, tc.explicit) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 3eb6943975..87ddc06869 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -373,7 +373,10 @@ func (p *HTTPSSEProxy) handleSSEConnection(w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") + // CORS headers deliberately omitted: the origin middleware + // (pkg/transport/middleware/origin) enforces Origin validation per + // MCP 2025-11-25 §"Security Warning". 
Reflecting Origin or emitting + // `*` here would bypass that protection. // Create a unique client ID clientID := uuid.New().String()