From a7badd95ef801a065988abe91b029858ed739d9e Mon Sep 17 00:00:00 2001 From: bishnubista Date: Thu, 14 May 2026 16:55:11 -0700 Subject: [PATCH 1/3] Validate upstream JSON-RPC responses in transparent proxy The transparent proxy forwarded malformed upstream MCP frames to clients with HTTP 200 even when the response violated JSON-RPC 2.0 structure. This adds a boundary check in NoOpResponseProcessor that rejects structurally invalid upstream frames and returns a synthetic 502 carrying a JSON-RPC error to the client, so the proxy stops being a silent amplifier for malformed (or adversarial) upstream servers. Validation runs only for streamable-http POST/200 responses that carry an MCP request signal (MCP-Protocol-Version or Mcp-Session-Id) and an application/json content type, with non-identity Content-Encoding traffic passed through untouched. Body reads are bounded to 100 MiB to match existing streamable-HTTP limits in pkg/vmcp. Rewritten error responses replace headers wholesale so upstream session/cookie/cache metadata is not smuggled into the proxy-generated error. SSE traffic is unaffected. Closes #5247 Signed-off-by: bishnubista --- .../proxy/transparent/response_processor.go | 200 ++++++++- .../transparent/response_processor_test.go | 391 ++++++++++++++++++ 2 files changed, 587 insertions(+), 4 deletions(-) create mode 100644 pkg/transport/proxy/transparent/response_processor_test.go diff --git a/pkg/transport/proxy/transparent/response_processor.go b/pkg/transport/proxy/transparent/response_processor.go index a6e0ca765f..e7fd14f844 100644 --- a/pkg/transport/proxy/transparent/response_processor.go +++ b/pkg/transport/proxy/transparent/response_processor.go @@ -6,11 +6,29 @@ package transparent import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math" + "mime" "net/http" + "strings" "github.com/stacklok/toolhive/pkg/transport/types" ) +// maxJSONRPCResponseBytes caps how much of an upstream JSON-RPC response the proxy +// will buffer for structural validation. Matches existing streamable-HTTP body +// limits elsewhere in the codebase (pkg/vmcp/client, pkg/vmcp/session/internal/backend). +const maxJSONRPCResponseBytes = 100 << 20 // 100 MiB + +// JSON-RPC error code returned to clients when the proxy rejects a malformed +// upstream response. -32000..-32099 is the implementation-defined server-error +// range in the JSON-RPC 2.0 spec; -32603 is reserved for internal JSON-RPC +// implementation errors and is not appropriate for a policy-level rejection. +const jsonRPCInvalidUpstreamCode = -32000 + // ResponseProcessor defines the interface for processing and modifying HTTP responses // based on transport-specific requirements. type ResponseProcessor interface { @@ -22,12 +40,38 @@ type ResponseProcessor interface { ShouldProcess(resp *http.Response) bool } -// NoOpResponseProcessor is a processor that does nothing. -// Used for transports that don't require response processing (e.g., streamable-http). +// NoOpResponseProcessor is the default processor for non-SSE transports. +// It validates JSON-RPC responses for streamable HTTP and otherwise leaves responses unchanged. type NoOpResponseProcessor struct{} -// ProcessResponse is a no-op implementation. -func (*NoOpResponseProcessor) ProcessResponse(_ *http.Response) error { +// ProcessResponse validates JSON-RPC responses when applicable. +func (*NoOpResponseProcessor) ProcessResponse(resp *http.Response) error { + if !shouldValidateJSONRPCResponse(resp) { + return nil + } + + // Read one byte past the cap so we can detect oversize without allocating beyond it. + body, err := io.ReadAll(io.LimitReader(resp.Body, maxJSONRPCResponseBytes+1)) + if err != nil { + return fmt.Errorf("failed to read upstream response body: %w", err) + } + _ = resp.Body.Close() + + if len(body) > maxJSONRPCResponseBytes { + writeInvalidUpstreamJSONRPCResponse(resp, fmt.Errorf( + "upstream JSON-RPC response exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes)) + return nil + } + + if err := validateJSONRPCResponse(body); err != nil { + writeInvalidUpstreamJSONRPCResponse(resp, err) + return nil + } + + // The reverse proxy still needs a readable body after validation. + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) return nil } @@ -36,6 +80,154 @@ func (*NoOpResponseProcessor) ShouldProcess(_ *http.Response) bool { return false } +func shouldValidateJSONRPCResponse(resp *http.Response) bool { + if resp == nil || resp.Body == nil || resp.Request == nil { + return false + } + if resp.Request.Method != http.MethodPost || resp.StatusCode != http.StatusOK { + return false + } + if !hasIdentityContentEncoding(resp.Header.Get("Content-Encoding")) { + // Content-Encoding semantics (RFC 9110): media-type rules apply after decoding. + // Validating a still-encoded body would mis-classify legitimate gzip JSON-RPC + // frames as invalid. Skip rather than introduce decompression here. + return false + } + if !requestLooksLikeMCP(resp.Request) { + // Narrow validation to traffic that carries an MCP streamable-HTTP signal, + // so non-MCP application/json POSTs flowing through the catch-all are not + // rewritten. Backward-compat clients omitting MCP-Protocol-Version on the + // initial initialize will pass through unchanged. + return false + } + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return false + } + return mediaType == "application/json" || mediaType == "application/json-rpc" +} + +func hasIdentityContentEncoding(value string) bool { + v := strings.TrimSpace(strings.ToLower(value)) + return v == "" || v == "identity" +} + +func requestLooksLikeMCP(req *http.Request) bool { + if req == nil { + return false + } + return req.Header.Get("MCP-Protocol-Version") != "" || req.Header.Get("Mcp-Session-Id") != "" +} + +func validateJSONRPCResponse(body []byte) error { + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("invalid JSON body: %w", err) + } + if dec.More() { + return fmt.Errorf("JSON-RPC response must contain a single JSON value") + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + return fmt.Errorf("JSON-RPC response must contain a single JSON value") + } + + switch value := payload.(type) { + case map[string]any: + return validateJSONRPCResponseObject(value) + case []any: + if len(value) == 0 { + return fmt.Errorf("JSON-RPC batch response must not be empty") + } + for i, item := range value { + obj, ok := item.(map[string]any) + if !ok { + return fmt.Errorf("JSON-RPC batch item %d must be an object", i) + } + if err := validateJSONRPCResponseObject(obj); err != nil { + return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", i, err) + } + } + return nil + default: + return fmt.Errorf("JSON-RPC response must be an object or array") + } +} + +func validateJSONRPCResponseObject(obj map[string]any) error { + if obj["jsonrpc"] != "2.0" { + return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`) + } + + if _, ok := obj["id"]; !ok { + return fmt.Errorf("JSON-RPC response must include id") + } + if !isValidJSONRPCID(obj["id"]) { + return fmt.Errorf("JSON-RPC response id must be string, number, or null") + } + + _, hasResult := obj["result"] + _, hasError := obj["error"] + if hasResult == hasError { + return fmt.Errorf("JSON-RPC response must include exactly one of result or error") + } + if hasError { + if errObj, ok := obj["error"].(map[string]any); !ok || !isValidJSONRPCError(errObj) { + return fmt.Errorf("JSON-RPC error response must include error.code and error.message") + } + } + + return nil +} + +func isValidJSONRPCID(id any) bool { + switch id.(type) { + case nil, string, float64: + return true + default: + return false + } +} + +func isValidJSONRPCError(errObj map[string]any) bool { + code, codeOK := errObj["code"].(float64) + if !codeOK || math.Trunc(code) != code { + // JSON-RPC 2.0 requires error.code to be an integer. + return false + } + _, messageOK := errObj["message"].(string) + return messageOK +} + +func writeInvalidUpstreamJSONRPCResponse(resp *http.Response, validationErr error) { + body, err := json.Marshal(map[string]any{ + "jsonrpc": "2.0", + "error": map[string]any{ + "code": jsonRPCInvalidUpstreamCode, + "message": "Invalid upstream JSON-RPC response", + "data": validationErr.Error(), + }, + "id": nil, + }) + if err != nil { + body = []byte(`{"jsonrpc":"2.0","error":{"code":-32000,"message":"Invalid upstream JSON-RPC response"},"id":null}`) + } + + resp.StatusCode = http.StatusBadGateway + resp.Status = fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)) + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + + // Replace headers wholesale so upstream session/cookie/cache metadata is not + // smuggled into the proxy-generated error. Only carry the fields needed to + // describe this synthetic body. + resp.Header = http.Header{} + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body))) + resp.Trailer = nil +} + // createResponseProcessor is a factory function that creates the appropriate // response processor based on transport type. func createResponseProcessor( diff --git a/pkg/transport/proxy/transparent/response_processor_test.go b/pkg/transport/proxy/transparent/response_processor_test.go new file mode 100644 index 0000000000..fd53560fb7 --- /dev/null +++ b/pkg/transport/proxy/transparent/response_processor_test.go @@ -0,0 +1,391 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package transparent + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNoOpResponseProcessorValidatesJSONRPCResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + wantStatus int + wantBody string + }{ + { + name: "valid result response passes through", + body: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`, + }, + { + name: "valid error response passes through", + body: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`, + }, + { + name: "valid batch response passes through", + body: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":"two","result":{}}]`, + wantStatus: http.StatusOK, + wantBody: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":"two","result":{}}]`, + }, + { + name: "valid null result response passes through", + body: `{"jsonrpc":"2.0","id":1,"result":null}`, + wantStatus: http.StatusOK, + wantBody: `{"jsonrpc":"2.0","id":1,"result":null}`, + }, + { + name: "missing jsonrpc is rejected", + body: `{"id":1,"result":{"ok":true}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"Invalid upstream JSON-RPC response"`, + }, + { + name: "invalid id type is rejected", + body: `{"jsonrpc":"2.0","id":{"nested":true},"result":{}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response id must be string, number, or null"`, + }, + { + name: "non-object body is rejected", + body: `"not an object"`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must be an object or array"`, + }, + { + name: "result and error together are rejected", + body: `{"jsonrpc":"2.0","id":1,"result":{},"error":{"code":-32603,"message":"boom"}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must include exactly one of result or error"`, + }, + { + name: "trailing JSON value is rejected", + body: `{"jsonrpc":"2.0","id":1,"result":{}} {"jsonrpc":"2.0","id":2,"result":{}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must contain a single JSON value"`, + }, + { + name: "fractional error code is rejected", + body: `{"jsonrpc":"2.0","id":1,"error":{"code":1.5,"message":"nope"}}`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC error response must include error.code and error.message"`, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp := jsonResponse(tt.body) + if tt.wantStatus == http.StatusBadGateway { + // These sensitive headers must not survive a rewrite. Content-Encoding + // is covered separately by TestNoOpResponseProcessorSkipsCompressedResponses; + // setting it here would route through the pass-through gate instead. + resp.Header.Set("Mcp-Session-Id", "upstream-session-leak") + resp.Header.Set("Set-Cookie", "leak=1") + resp.Header.Set("Etag", "\"upstream-etag\"") + resp.Header.Set("Cache-Control", "private, max-age=60") + } + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tt.wantStatus, resp.StatusCode) + assert.Contains(t, string(gotBody), tt.wantBody) + assert.Equal(t, int64(len(gotBody)), resp.ContentLength) + assert.Equal(t, len(gotBody), int(resp.ContentLength)) + if tt.wantStatus == http.StatusBadGateway { + // Wholesale header replacement: only Content-Type and Content-Length remain. + assert.Empty(t, resp.Header.Get("Mcp-Session-Id")) + assert.Empty(t, resp.Header.Get("Set-Cookie")) + assert.Empty(t, resp.Header.Get("Etag")) + assert.Empty(t, resp.Header.Get("Cache-Control")) + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + assert.Nil(t, resp.Trailer) + } + }) + } +} + +func TestNoOpResponseProcessorAcceptsJSONContentTypeParameters(t *testing.T) { + t.Parallel() + + resp := jsonResponse(`{"jsonrpc":"2.0","id":1,"result":{}}`) + resp.Header.Set("Content-Type", "application/json; charset=utf-8") + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, `{"jsonrpc":"2.0","id":1,"result":{}}`, string(gotBody)) +} + +func TestNoOpResponseProcessorSkipsNonJSONRPCResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + method string + status int + contentType string + body string + }{ + { + name: "non-post response", + method: http.MethodGet, + status: http.StatusOK, + contentType: "application/json", + body: `{"resource":"https://example.com"}`, + }, + { + name: "non-200 response", + method: http.MethodPost, + status: http.StatusAccepted, + contentType: "application/json", + body: ``, + }, + { + name: "non-json response", + method: http.MethodPost, + status: http.StatusOK, + contentType: "text/plain", + body: `not json`, + }, + { + name: "post response with event stream", + method: http.MethodPost, + status: http.StatusOK, + contentType: "text/event-stream", + body: "event: message\ndata: {}\n\n", + }, + { + name: "content type containing application/json is not enough", + method: http.MethodPost, + status: http.StatusOK, + contentType: "application/jsonsomethingelse", + body: `not json`, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req := mcpRequest(tt.method) + resp := &http.Response{ + StatusCode: tt.status, + Status: http.StatusText(tt.status), + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(tt.body)), + ContentLength: int64(len(tt.body)), + Request: req, + } + resp.Header.Set("Content-Type", tt.contentType) + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, tt.status, resp.StatusCode) + assert.Equal(t, tt.body, string(gotBody)) + }) + } +} + +// TestNoOpResponseProcessorSkipsCompressedResponses verifies that responses +// carrying a non-identity Content-Encoding are passed through unchanged. +// Decoding here would either reject legitimate compressed JSON-RPC frames or +// open a decompression-bomb amplification path. +func TestNoOpResponseProcessorSkipsCompressedResponses(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + contentEncoding string + body string + }{ + { + name: "gzip valid json is left alone", + contentEncoding: "gzip", + body: gzipBytes(t, `{"jsonrpc":"2.0","id":1,"result":{}}`), + }, + { + name: "gzip malformed body is left alone (no false reject)", + contentEncoding: "gzip", + body: "not really gzip, but encoding header is set", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp := jsonResponse(tt.body) + resp.Header.Set("Content-Encoding", tt.contentEncoding) + + err := (&NoOpResponseProcessor{}).ProcessResponse(resp) + require.NoError(t, err) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, tt.body, string(gotBody)) + }) + } +} + +// TestNoOpResponseProcessorValidatesUnderIdentityEncoding proves that an +// explicit Content-Encoding: identity does not bypass validation: a malformed +// JSON-RPC body must still produce a 502 rewrite. +func TestNoOpResponseProcessorValidatesUnderIdentityEncoding(t *testing.T) { + t.Parallel() + + resp := jsonResponse(`{"id":1,"result":{"ok":true}}`) // missing jsonrpc → invalid + resp.Header.Set("Content-Encoding", "identity") + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), `"Invalid upstream JSON-RPC response"`) +} + +// TestNoOpResponseProcessorRequiresMCPSignal narrows validation to traffic that +// carries an MCP streamable-HTTP signal on the request. application/json POST +// 200 responses from non-MCP traffic flowing through the catch-all proxy must +// not be rewritten. +func TestNoOpResponseProcessorRequiresMCPSignal(t *testing.T) { + t.Parallel() + + body := `{"id":1,"result":{"ok":true}}` // missing jsonrpc — would be rejected if validated + + tests := []struct { + name string + headers map[string]string + validate bool + }{ + { + name: "no MCP headers — pass through", + headers: nil, + validate: false, + }, + { + name: "MCP-Protocol-Version header — validated", + headers: map[string]string{"MCP-Protocol-Version": "2025-06-18"}, + validate: true, + }, + { + name: "Mcp-Session-Id header — validated", + headers: map[string]string{"Mcp-Session-Id": "session-abc"}, + validate: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodPost, "http://example.com/mcp", nil) + require.NoError(t, err) + for k, v := range tt.headers { + req.Header.Set(k, v) + } + resp := &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + Request: req, + } + resp.Header.Set("Content-Type", "application/json") + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + if tt.validate { + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), `"Invalid upstream JSON-RPC response"`) + } else { + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, body, string(gotBody)) + } + }) + } +} + +// TestNoOpResponseProcessorRejectsOversizeResponse verifies the bounded read. +// The proxy is a security boundary; an unbounded io.ReadAll on attacker- +// controlled upstream bodies would amplify a malicious server into a memory +// DoS against the proxy. +func TestNoOpResponseProcessorRejectsOversizeResponse(t *testing.T) { + t.Parallel() + + // Produce a body strictly larger than the cap. Content does not need to be + // valid JSON-RPC — the size check fires before validation. + oversize := strings.Repeat("a", maxJSONRPCResponseBytes+1) + resp := jsonResponse(oversize) + + require.NoError(t, (&NoOpResponseProcessor{}).ProcessResponse(resp)) + + gotBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, http.StatusBadGateway, resp.StatusCode) + assert.Contains(t, string(gotBody), fmt.Sprintf("exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes)) +} + +func jsonResponse(body string) *http.Response { + resp := &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + Request: mcpRequest(http.MethodPost), + } + resp.Header.Set("Content-Type", "application/json") + return resp +} + +func mcpRequest(method string) *http.Request { + req, _ := http.NewRequest(method, "http://example.com/mcp", nil) + req.Header.Set("MCP-Protocol-Version", "2025-06-18") + return req +} + +func gzipBytes(t *testing.T, payload string) string { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + _, err := gw.Write([]byte(payload)) + require.NoError(t, err) + require.NoError(t, gw.Close()) + return buf.String() +} From f70c56c62b3298eac0c49b5dac3f6b5753d1a56c Mon Sep 17 00:00:00 2001 From: bishnubista Date: Tue, 26 May 2026 00:42:13 -0700 Subject: [PATCH 2/3] Tighten upstream JSON-RPC response validation Remove the top-level json.Decoder.More check from upstream JSON-RPC response validation. More is intended for array and object iteration, so using it as a top-level single-value guard can miss malformed trailing delimiter bytes. Keep the second Decode call as the exact-single-value check and compare its result with io.EOF via errors.Is. This rejects trailing JSON values and trailing syntax junk while accepting clean EOF after the first decoded response. Add a regression case for a valid JSON-RPC response followed by a stray closing delimiter so malformed trailing bytes cannot pass validation. --- pkg/transport/proxy/transparent/response_processor.go | 6 ++---- pkg/transport/proxy/transparent/response_processor_test.go | 6 ++++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/transport/proxy/transparent/response_processor.go b/pkg/transport/proxy/transparent/response_processor.go index e7fd14f844..319e427713 100644 --- a/pkg/transport/proxy/transparent/response_processor.go +++ b/pkg/transport/proxy/transparent/response_processor.go @@ -8,6 +8,7 @@ package transparent import ( "bytes" "encoding/json" + "errors" "fmt" "io" "math" @@ -126,10 +127,7 @@ func validateJSONRPCResponse(body []byte) error { if err := dec.Decode(&payload); err != nil { return fmt.Errorf("invalid JSON body: %w", err) } - if dec.More() { - return fmt.Errorf("JSON-RPC response must contain a single JSON value") - } - if err := dec.Decode(&struct{}{}); err != io.EOF { + if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) { return fmt.Errorf("JSON-RPC response must contain a single JSON value") } diff --git a/pkg/transport/proxy/transparent/response_processor_test.go b/pkg/transport/proxy/transparent/response_processor_test.go index fd53560fb7..9709969a47 100644 --- a/pkg/transport/proxy/transparent/response_processor_test.go +++ b/pkg/transport/proxy/transparent/response_processor_test.go @@ -79,6 +79,12 @@ func TestNoOpResponseProcessorValidatesJSONRPCResponses(t *testing.T) { wantStatus: http.StatusBadGateway, wantBody: `"JSON-RPC response must contain a single JSON value"`, }, + { + name: "trailing delimiter is rejected", + body: `{"jsonrpc":"2.0","id":1,"result":{}}]`, + wantStatus: http.StatusBadGateway, + wantBody: `"JSON-RPC response must contain a single JSON value"`, + }, { name: "fractional error code is rejected", body: `{"jsonrpc":"2.0","id":1,"error":{"code":1.5,"message":"nope"}}`, From c0a8b6c9c970caf7352b4d2d050ae1b3058c310c Mon Sep 17 00:00:00 2001 From: bishnubista Date: Tue, 26 May 2026 08:32:02 -0700 Subject: [PATCH 3/3] fix(proxy): reuse MCP response validation helper Move strict JSON-RPC response body validation into pkg/mcp so the transparent proxy reuses shared MCP parsing code instead of carrying a private copy. Use parsed MCP request context from the parser middleware as the primary signal for response validation, while preserving protocol/session header fallback for batch and compatibility cases that the request parser does not currently cover. Keep the transparent proxy's ModifyResponse path responsible for rejecting malformed upstream responses with the existing bounded read and 502 rewrite. --- pkg/mcp/response.go | 84 +++++++++++++++++++ pkg/mcp/response_test.go | 76 +++++++++++++++++ .../proxy/transparent/response_processor.go | 83 +----------------- .../transparent/response_processor_test.go | 37 ++++++-- 4 files changed, 196 insertions(+), 84 deletions(-) diff --git a/pkg/mcp/response.go b/pkg/mcp/response.go index 9a50b75161..a5c502c622 100644 --- a/pkg/mcp/response.go +++ b/pkg/mcp/response.go @@ -4,7 +4,12 @@ package mcp import ( + "bytes" "encoding/json" + "errors" + "fmt" + "io" + "math" ) // ParsedMCPResponse contains the result of inspecting a JSON-RPC response @@ -60,3 +65,82 @@ func ParseMCPResponse(body []byte) *ParsedMCPResponse { ErrorMessage: envelope.Error.Message, } } + +// ValidateJSONRPCResponseBody strictly validates a JSON-RPC 2.0 response body. +// Unlike ParseMCPResponse, this is an enforcement helper: malformed bodies return +// an error instead of being treated as "no application error". +func ValidateJSONRPCResponseBody(body []byte) error { + var payload any + dec := json.NewDecoder(bytes.NewReader(body)) + if err := dec.Decode(&payload); err != nil { + return fmt.Errorf("invalid JSON body: %w", err) + } + if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) { + return fmt.Errorf("JSON-RPC response must contain a single JSON value") + } + + switch value := payload.(type) { + case map[string]any: + return validateJSONRPCResponseObject(value) + case []any: + if len(value) == 0 { + return fmt.Errorf("JSON-RPC batch response must not be empty") + } + for i, item := range value { + obj, ok := item.(map[string]any) + if !ok { + return fmt.Errorf("JSON-RPC batch item %d must be an object", i) + } + if err := validateJSONRPCResponseObject(obj); err != nil { + return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", i, err) + } + } + return nil + default: + return fmt.Errorf("JSON-RPC response must be an object or array") + } +} + +func validateJSONRPCResponseObject(obj map[string]any) error { + if obj["jsonrpc"] != "2.0" { + return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`) + } + + if _, ok := obj["id"]; !ok { + return fmt.Errorf("JSON-RPC response must include id") + } + if !isValidJSONRPCID(obj["id"]) { + return fmt.Errorf("JSON-RPC response id must be string, number, or null") + } + + _, hasResult := obj["result"] + _, hasError := obj["error"] + if hasResult == hasError { + return fmt.Errorf("JSON-RPC response must include exactly one of result or error") + } + if hasError { + if errObj, ok := obj["error"].(map[string]any); !ok || !isValidJSONRPCError(errObj) { + return fmt.Errorf("JSON-RPC error response must include error.code and error.message") + } + } + + return nil +} + +func isValidJSONRPCID(id any) bool { + switch id.(type) { + case nil, string, float64: + return true + default: + return false + } +} + +func isValidJSONRPCError(errObj map[string]any) bool { + code, codeOK := errObj["code"].(float64) + if !codeOK || math.Trunc(code) != code { + return false + } + _, messageOK := errObj["message"].(string) + return messageOK +} diff --git a/pkg/mcp/response_test.go b/pkg/mcp/response_test.go index aa963b9ed4..59b0416071 100644 --- a/pkg/mcp/response_test.go +++ b/pkg/mcp/response_test.go @@ -90,3 +90,79 @@ func TestParseMCPResponse(t *testing.T) { }) } } + +func TestValidateJSONRPCResponseBody(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + wantErr string + }{ + { + name: "valid result response", + body: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`, + }, + { + name: "valid error response", + body: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`, + }, + { + name: "valid batch response", + body: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":null,"result":null}]`, + }, + { + name: "invalid JSON", + body: `not json`, + wantErr: "invalid JSON body", + }, + { + name: "missing jsonrpc", + body: `{"id":1,"result":{"ok":true}}`, + wantErr: `JSON-RPC response must include "jsonrpc":"2.0"`, + }, + { + name: "invalid id type", + body: `{"jsonrpc":"2.0","id":{"nested":true},"result":{}}`, + wantErr: "JSON-RPC response id must be string, number, or null", + }, + { + name: "empty batch", + body: `[]`, + wantErr: "JSON-RPC batch response must not be empty", + }, + { + name: "result and error together", + body: `{"jsonrpc":"2.0","id":1,"result":{},"error":{"code":-32603,"message":"boom"}}`, + wantErr: "JSON-RPC response must include exactly one of result or error", + }, + { + name: "trailing JSON value", + body: `{"jsonrpc":"2.0","id":1,"result":{}} {"jsonrpc":"2.0","id":2,"result":{}}`, + wantErr: "JSON-RPC response must contain a single JSON value", + }, + { + name: "trailing delimiter", + body: `{"jsonrpc":"2.0","id":1,"result":{}}]`, + wantErr: "JSON-RPC response must contain a single JSON value", + }, + { + name: "fractional error code", + body: `{"jsonrpc":"2.0","id":1,"error":{"code":1.5,"message":"nope"}}`, + wantErr: "JSON-RPC error response must include error.code and error.message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := ValidateJSONRPCResponseBody([]byte(tt.body)) + if tt.wantErr == "" { + assert.NoError(t, err) + return + } + assert.ErrorContains(t, err, tt.wantErr) + }) + } +} diff --git a/pkg/transport/proxy/transparent/response_processor.go b/pkg/transport/proxy/transparent/response_processor.go index 319e427713..789943c868 100644 --- a/pkg/transport/proxy/transparent/response_processor.go +++ b/pkg/transport/proxy/transparent/response_processor.go @@ -8,14 +8,13 @@ package transparent import ( "bytes" "encoding/json" - "errors" "fmt" "io" - "math" "mime" "net/http" "strings" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" "github.com/stacklok/toolhive/pkg/transport/types" ) @@ -64,7 +63,7 @@ func (*NoOpResponseProcessor) ProcessResponse(resp *http.Response) error { return nil } - if err := validateJSONRPCResponse(body); err != nil { + if err := mcpparser.ValidateJSONRPCResponseBody(body); err != nil { writeInvalidUpstreamJSONRPCResponse(resp, err) return nil } @@ -118,84 +117,10 @@ func requestLooksLikeMCP(req *http.Request) bool { if req == nil { return false } - return req.Header.Get("MCP-Protocol-Version") != "" || req.Header.Get("Mcp-Session-Id") != "" -} - -func validateJSONRPCResponse(body []byte) error { - var payload any - dec := json.NewDecoder(bytes.NewReader(body)) - if err := dec.Decode(&payload); err != nil { - return fmt.Errorf("invalid JSON body: %w", err) - } - if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) { - return fmt.Errorf("JSON-RPC response must contain a single JSON value") - } - - switch value := payload.(type) { - case map[string]any: - return validateJSONRPCResponseObject(value) - case []any: - if len(value) == 0 { - return fmt.Errorf("JSON-RPC batch response must not be empty") - } - for i, item := range value { - obj, ok := item.(map[string]any) - if !ok { - return fmt.Errorf("JSON-RPC batch item %d must be an object", i) - } - if err := validateJSONRPCResponseObject(obj); err != nil { - return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", i, err) - } - } - return nil - default: - return fmt.Errorf("JSON-RPC response must be an object or array") - } -} - -func validateJSONRPCResponseObject(obj map[string]any) error { - if obj["jsonrpc"] != "2.0" { - return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`) - } - - if _, ok := obj["id"]; !ok { - return fmt.Errorf("JSON-RPC response must include id") - } - if !isValidJSONRPCID(obj["id"]) { - return fmt.Errorf("JSON-RPC response id must be string, number, or null") - } - - _, hasResult := obj["result"] - _, hasError := obj["error"] - if hasResult == hasError { - return fmt.Errorf("JSON-RPC response must include exactly one of result or error") - } - if hasError { - if errObj, ok := obj["error"].(map[string]any); !ok || !isValidJSONRPCError(errObj) { - return fmt.Errorf("JSON-RPC error response must include error.code and error.message") - } - } - - return nil -} - -func isValidJSONRPCID(id any) bool { - switch id.(type) { - case nil, string, float64: + if mcpparser.GetParsedMCPRequest(req.Context()) != nil { return true - default: - return false } -} - -func isValidJSONRPCError(errObj map[string]any) bool { - code, codeOK := errObj["code"].(float64) - if !codeOK || math.Trunc(code) != code { - // JSON-RPC 2.0 requires error.code to be an integer. - return false - } - _, messageOK := errObj["message"].(string) - return messageOK + return req.Header.Get("MCP-Protocol-Version") != "" || req.Header.Get("Mcp-Session-Id") != "" } func writeInvalidUpstreamJSONRPCResponse(resp *http.Response, validationErr error) { diff --git a/pkg/transport/proxy/transparent/response_processor_test.go b/pkg/transport/proxy/transparent/response_processor_test.go index 9709969a47..9f32fecb2d 100644 --- a/pkg/transport/proxy/transparent/response_processor_test.go +++ b/pkg/transport/proxy/transparent/response_processor_test.go @@ -6,6 +6,7 @@ package transparent import ( "bytes" "compress/gzip" + "context" "fmt" "io" "net/http" @@ -14,6 +15,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + mcpparser "github.com/stacklok/toolhive/pkg/mcp" ) func TestNoOpResponseProcessorValidatesJSONRPCResponses(t *testing.T) { @@ -292,15 +295,20 @@ func TestNoOpResponseProcessorRequiresMCPSignal(t *testing.T) { body := `{"id":1,"result":{"ok":true}}` // missing jsonrpc — would be rejected if validated tests := []struct { - name string - headers map[string]string - validate bool + name string + headers map[string]string + parsedContext bool + validate bool }{ { - name: "no MCP headers — pass through", - headers: nil, + name: "no parsed context or MCP headers — pass through", validate: false, }, + { + name: "parsed MCP context without headers — validated", + parsedContext: true, + validate: true, + }, { name: "MCP-Protocol-Version header — validated", headers: map[string]string{"MCP-Protocol-Version": "2025-06-18"}, @@ -311,6 +319,12 @@ func TestNoOpResponseProcessorRequiresMCPSignal(t *testing.T) { headers: map[string]string{"Mcp-Session-Id": "session-abc"}, validate: true, }, + { + name: "parsed MCP context and header — validated", + headers: map[string]string{"MCP-Protocol-Version": "2025-06-18"}, + parsedContext: true, + validate: true, + }, } for _, tt := range tests { @@ -323,6 +337,9 @@ func TestNoOpResponseProcessorRequiresMCPSignal(t *testing.T) { for k, v := range tt.headers { req.Header.Set(k, v) } + if tt.parsedContext { + req = withParsedMCPRequest(req) + } resp := &http.Response{ StatusCode: http.StatusOK, Status: "200 OK", @@ -386,6 +403,16 @@ func mcpRequest(method string) *http.Request { return req } +func withParsedMCPRequest(req *http.Request) *http.Request { + parsed := &mcpparser.ParsedMCPRequest{ + Method: "tools/list", + ID: 1, + IsRequest: true, + } + ctx := context.WithValue(req.Context(), mcpparser.MCPRequestContextKey, parsed) + return req.WithContext(ctx) +} + func gzipBytes(t *testing.T, payload string) string { t.Helper() var buf bytes.Buffer