Skip to content
Open
84 changes: 84 additions & 0 deletions pkg/mcp/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
package mcp

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math"
)

// ParsedMCPResponse contains the result of inspecting a JSON-RPC response
Expand Down Expand Up @@ -60,3 +65,82 @@ func ParseMCPResponse(body []byte) *ParsedMCPResponse {
ErrorMessage: envelope.Error.Message,
}
}

// ValidateJSONRPCResponseBody strictly validates a JSON-RPC 2.0 response body.
// Unlike ParseMCPResponse, this is an enforcement helper: malformed bodies return
// an error instead of being treated as "no application error".
func ValidateJSONRPCResponseBody(body []byte) error {
var payload any
dec := json.NewDecoder(bytes.NewReader(body))
if err := dec.Decode(&payload); err != nil {
return fmt.Errorf("invalid JSON body: %w", err)
}
if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) {
return fmt.Errorf("JSON-RPC response must contain a single JSON value")
}

switch value := payload.(type) {
case map[string]any:
return validateJSONRPCResponseObject(value)
case []any:
if len(value) == 0 {
return fmt.Errorf("JSON-RPC batch response must not be empty")
}
for i, item := range value {
obj, ok := item.(map[string]any)
if !ok {
return fmt.Errorf("JSON-RPC batch item %d must be an object", i)
}
if err := validateJSONRPCResponseObject(obj); err != nil {
return fmt.Errorf("JSON-RPC batch item %d is invalid: %w", i, err)
}
}
return nil
default:
return fmt.Errorf("JSON-RPC response must be an object or array")
}
}

func validateJSONRPCResponseObject(obj map[string]any) error {
if obj["jsonrpc"] != "2.0" {
return fmt.Errorf(`JSON-RPC response must include "jsonrpc":"2.0"`)
}

if _, ok := obj["id"]; !ok {
return fmt.Errorf("JSON-RPC response must include id")
}
if !isValidJSONRPCID(obj["id"]) {
return fmt.Errorf("JSON-RPC response id must be string, number, or null")
}

_, hasResult := obj["result"]
_, hasError := obj["error"]
if hasResult == hasError {
return fmt.Errorf("JSON-RPC response must include exactly one of result or error")
}
if hasError {
if errObj, ok := obj["error"].(map[string]any); !ok || !isValidJSONRPCError(errObj) {
return fmt.Errorf("JSON-RPC error response must include error.code and error.message")
}
}

return nil
}

func isValidJSONRPCID(id any) bool {
switch id.(type) {
case nil, string, float64:
return true
default:
return false
}
}

func isValidJSONRPCError(errObj map[string]any) bool {
code, codeOK := errObj["code"].(float64)
if !codeOK || math.Trunc(code) != code {
return false
}
_, messageOK := errObj["message"].(string)
return messageOK
}
76 changes: 76 additions & 0 deletions pkg/mcp/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,79 @@ func TestParseMCPResponse(t *testing.T) {
})
}
}

func TestValidateJSONRPCResponseBody(t *testing.T) {
t.Parallel()

tests := []struct {
name string
body string
wantErr string
}{
{
name: "valid result response",
body: `{"jsonrpc":"2.0","id":1,"result":{"ok":true}}`,
},
{
name: "valid error response",
body: `{"jsonrpc":"2.0","id":"abc","error":{"code":-32601,"message":"Method not found"}}`,
},
{
name: "valid batch response",
body: `[{"jsonrpc":"2.0","id":1,"result":{}},{"jsonrpc":"2.0","id":null,"result":null}]`,
},
{
name: "invalid JSON",
body: `not json`,
wantErr: "invalid JSON body",
},
{
name: "missing jsonrpc",
body: `{"id":1,"result":{"ok":true}}`,
wantErr: `JSON-RPC response must include "jsonrpc":"2.0"`,
},
{
name: "invalid id type",
body: `{"jsonrpc":"2.0","id":{"nested":true},"result":{}}`,
wantErr: "JSON-RPC response id must be string, number, or null",
},
{
name: "empty batch",
body: `[]`,
wantErr: "JSON-RPC batch response must not be empty",
},
{
name: "result and error together",
body: `{"jsonrpc":"2.0","id":1,"result":{},"error":{"code":-32603,"message":"boom"}}`,
wantErr: "JSON-RPC response must include exactly one of result or error",
},
{
name: "trailing JSON value",
body: `{"jsonrpc":"2.0","id":1,"result":{}} {"jsonrpc":"2.0","id":2,"result":{}}`,
wantErr: "JSON-RPC response must contain a single JSON value",
},
{
name: "trailing delimiter",
body: `{"jsonrpc":"2.0","id":1,"result":{}}]`,
wantErr: "JSON-RPC response must contain a single JSON value",
},
{
name: "fractional error code",
body: `{"jsonrpc":"2.0","id":1,"error":{"code":1.5,"message":"nope"}}`,
wantErr: "JSON-RPC error response must include error.code and error.message",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

err := ValidateJSONRPCResponseBody([]byte(tt.body))
if tt.wantErr == "" {
assert.NoError(t, err)
return
}
assert.ErrorContains(t, err, tt.wantErr)
})
}
}
123 changes: 119 additions & 4 deletions pkg/transport/proxy/transparent/response_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,29 @@
package transparent

import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime"
"net/http"
"strings"

mcpparser "github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/transport/types"
)

// maxJSONRPCResponseBytes caps how much of an upstream JSON-RPC response the proxy
// will buffer for structural validation. Matches existing streamable-HTTP body
// limits elsewhere in the codebase (pkg/vmcp/client, pkg/vmcp/session/internal/backend).
const maxJSONRPCResponseBytes = 100 << 20 // 100 MiB

// JSON-RPC error code returned to clients when the proxy rejects a malformed
// upstream response. -32000..-32099 is the implementation-defined server-error
// range in the JSON-RPC 2.0 spec; -32603 is reserved for internal JSON-RPC
// implementation errors and is not appropriate for a policy-level rejection.
const jsonRPCInvalidUpstreamCode = -32000

// ResponseProcessor defines the interface for processing and modifying HTTP responses
// based on transport-specific requirements.
type ResponseProcessor interface {
Expand All @@ -22,12 +40,38 @@ type ResponseProcessor interface {
ShouldProcess(resp *http.Response) bool
}

// NoOpResponseProcessor is a processor that does nothing.
// Used for transports that don't require response processing (e.g., streamable-http).
// NoOpResponseProcessor is the default processor for non-SSE transports.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scoping to non-SSE makes sense for now — SSE validation would need a fundamentally different approach. By the time ProcessResponse runs on an SSE response, the 200 OK and Content-Type: text/event-stream headers are already committed downstream, so there is no way to rewrite to 502 for a malformed mid-stream event. A proper fix would need a per-event streaming interceptor inside the SSE processor that can either synthesize an error event or close the stream — a meaningfully larger scope than this PR.

Worth opening a follow-up issue to track it explicitly, so it is clear the gap is deferred rather than forgotten.

// It validates JSON-RPC responses for streamable HTTP and otherwise leaves responses unchanged.
type NoOpResponseProcessor struct{}

// ProcessResponse is a no-op implementation.
func (*NoOpResponseProcessor) ProcessResponse(_ *http.Response) error {
// ProcessResponse validates JSON-RPC responses when applicable.
func (*NoOpResponseProcessor) ProcessResponse(resp *http.Response) error {
if !shouldValidateJSONRPCResponse(resp) {
return nil
}

// Read one byte past the cap so we can detect oversize without allocating beyond it.
body, err := io.ReadAll(io.LimitReader(resp.Body, maxJSONRPCResponseBytes+1))
if err != nil {
return fmt.Errorf("failed to read upstream response body: %w", err)
}
_ = resp.Body.Close()

if len(body) > maxJSONRPCResponseBytes {
writeInvalidUpstreamJSONRPCResponse(resp, fmt.Errorf(
"upstream JSON-RPC response exceeds maximum allowed size of %d bytes", maxJSONRPCResponseBytes))
return nil
}

if err := mcpparser.ValidateJSONRPCResponseBody(body); err != nil {
writeInvalidUpstreamJSONRPCResponse(resp, err)
return nil
}

// The reverse proxy still needs a readable body after validation.
resp.Body = io.NopCloser(bytes.NewReader(body))
resp.ContentLength = int64(len(body))
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
return nil
}

Expand All @@ -36,6 +80,77 @@ func (*NoOpResponseProcessor) ShouldProcess(_ *http.Response) bool {
return false
}

func shouldValidateJSONRPCResponse(resp *http.Response) bool {
if resp == nil || resp.Body == nil || resp.Request == nil {
return false
}
if resp.Request.Method != http.MethodPost || resp.StatusCode != http.StatusOK {
return false
}
if !hasIdentityContentEncoding(resp.Header.Get("Content-Encoding")) {
// Content-Encoding semantics (RFC 9110): media-type rules apply after decoding.
// Validating a still-encoded body would mis-classify legitimate gzip JSON-RPC
// frames as invalid. Skip rather than introduce decompression here.
return false
}
if !requestLooksLikeMCP(resp.Request) {
// Narrow validation to traffic that carries an MCP streamable-HTTP signal,
// so non-MCP application/json POSTs flowing through the catch-all are not
// rewritten. Backward-compat clients omitting MCP-Protocol-Version on the
// initial initialize will pass through unchanged.
return false
}
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return false
}
return mediaType == "application/json" || mediaType == "application/json-rpc"
}

func hasIdentityContentEncoding(value string) bool {
v := strings.TrimSpace(strings.ToLower(value))
return v == "" || v == "identity"
}

func requestLooksLikeMCP(req *http.Request) bool {
if req == nil {
return false
}
if mcpparser.GetParsedMCPRequest(req.Context()) != nil {
return true
}
return req.Header.Get("MCP-Protocol-Version") != "" || req.Header.Get("Mcp-Session-Id") != ""
}

func writeInvalidUpstreamJSONRPCResponse(resp *http.Response, validationErr error) {
body, err := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"error": map[string]any{
"code": jsonRPCInvalidUpstreamCode,
"message": "Invalid upstream JSON-RPC response",
"data": validationErr.Error(),
},
"id": nil,
})
if err != nil {
body = []byte(`{"jsonrpc":"2.0","error":{"code":-32000,"message":"Invalid upstream JSON-RPC response"},"id":null}`)
}

resp.StatusCode = http.StatusBadGateway
resp.Status = fmt.Sprintf("%d %s", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))
resp.Body = io.NopCloser(bytes.NewReader(body))
resp.ContentLength = int64(len(body))

// Replace headers wholesale so upstream session/cookie/cache metadata is not
// smuggled into the proxy-generated error. Only carry the fields needed to
// describe this synthetic body.
resp.Header = http.Header{}
resp.Header.Set("Content-Type", "application/json")
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(body)))
resp.Trailer = nil
}

// createResponseProcessor is a factory function that creates the appropriate
// response processor based on transport type.
func createResponseProcessor(
Expand Down
Loading