Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 79 additions & 9 deletions internal/mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"

Expand All @@ -32,6 +33,26 @@ type Client struct {
Tokens TokenSource

reqID atomic.Int64

// The Streamable HTTP transport may run in stateful mode: it returns an
// Mcp-Session-Id on the initialize response that must be echoed on every
// subsequent request (else the server 400s with "No valid session ID"). We
// capture it from any response and replay it; in stateless mode the header is
// simply absent and this stays empty (harmless).
sessionMu sync.Mutex
sessionID string
}

func (c *Client) getSession() string {
c.sessionMu.Lock()
defer c.sessionMu.Unlock()
return c.sessionID
}

func (c *Client) setSession(id string) {
c.sessionMu.Lock()
c.sessionID = id
c.sessionMu.Unlock()
}

type rpcRequest struct {
Expand Down Expand Up @@ -76,10 +97,18 @@ func (c *Client) Call(ctx context.Context, method string, params any) (json.RawM
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
// The Streamable HTTP transport REQUIRES the client to accept both JSON and
// SSE; sending only application/json gets a 406 ("Client must accept both
// application/json and text/event-stream"). The server then replies with
// whichever it prefers — often a one-message text/event-stream, which the
// response path below unwraps.
req.Header.Set("Accept", "application/json, text/event-stream")
if c.UserAgent != "" {
req.Header.Set("User-Agent", c.UserAgent)
}
if sid := c.getSession(); sid != "" {
req.Header.Set("Mcp-Session-Id", sid)
}
if c.Tokens != nil {
tok, terr := c.Tokens.AccessToken(ctx)
if terr != nil {
Expand All @@ -90,26 +119,30 @@ func (c *Client) Call(ctx context.Context, method string, params any) (json.RawM
return httpClient.Do(req)
}

// attempt issues the request once, fully drains the body, closes it, and
// returns the captured status + bytes. Splitting it out lets us retry on
// 401 without ever holding two open response bodies at once.
attempt := func() (int, []byte, error) {
// attempt issues the request once, captures the Mcp-Session-Id (if any) for
// later requests, fully drains the body, closes it, and returns the captured
// status + content-type + bytes. Splitting it out lets us retry on 401 without
// ever holding two open response bodies at once.
attempt := func() (int, string, []byte, error) {
resp, derr := do()
if derr != nil {
return 0, nil, derr
return 0, "", nil, derr
}
defer resp.Body.Close()
if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" {
c.setSession(sid)
}
body, _ := io.ReadAll(resp.Body)
return resp.StatusCode, body, nil
return resp.StatusCode, resp.Header.Get("Content-Type"), body, nil
}

status, raw, err := attempt()
status, contentType, raw, err := attempt()
if err != nil {
return nil, sterr.Wrap("transport", "mcp request failed", sterr.ExitTransport, err)
}
if status == http.StatusUnauthorized && c.Tokens != nil {
c.Tokens.ForceRefresh()
status, raw, err = attempt()
status, contentType, raw, err = attempt()
if err != nil {
return nil, sterr.Wrap("transport", "mcp request failed after refresh", sterr.ExitTransport, err)
}
Expand All @@ -130,6 +163,11 @@ func (c *Client) Call(ctx context.Context, method string, params any) (json.RawM
return nil, sterr.New("transport", fmt.Sprintf("%d: %s", status, strings.TrimSpace(string(raw))), sterr.ExitTransport)
}

// A text/event-stream response wraps the JSON-RPC message in SSE framing;
// unwrap it to the raw JSON before decoding. A plain application/json response
// passes through untouched.
raw = extractRPCPayload(contentType, raw)

var rr rpcResponse
if err := json.Unmarshal(raw, &rr); err != nil {
return nil, sterr.Wrap("transport", "decode rpc response", sterr.ExitTransport, err)
Expand All @@ -142,6 +180,38 @@ func (c *Client) Call(ctx context.Context, method string, params any) (json.RawM
return rr.Result, nil
}

// extractRPCPayload returns the JSON-RPC bytes from a /mcp response body. The
// Streamable HTTP transport may answer a POST with Content-Type
// text/event-stream, framing the single JSON-RPC reply as one SSE event:
//
// event: message
// data: {"jsonrpc":"2.0","id":1,"result":{...}}
//
// We pull the `data:` payload of the first complete event (joining multiple
// data: lines per the SSE spec). For a plain application/json response — or if
// no data line is found — the body is returned unchanged.
func extractRPCPayload(contentType string, body []byte) []byte {
if !strings.Contains(strings.ToLower(contentType), "text/event-stream") {
return body
}
var data []string
for _, line := range strings.Split(string(body), "\n") {
line = strings.TrimRight(line, "\r")
if strings.HasPrefix(line, "data:") {
data = append(data, strings.TrimSpace(line[len("data:"):]))
continue
}
// A blank line dispatches the event; stop at the first one that has data.
if line == "" && len(data) > 0 {
break
}
}
if len(data) == 0 {
return body
}
return []byte(strings.Join(data, "\n"))
}

// classifyRPCError maps a JSON-RPC error to a CodedError.
//
// Structured numeric codes win over English substring matching: a -32602
Expand Down
151 changes: 151 additions & 0 deletions internal/mcp/streamhttp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package mcp

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)

// TestExtractRPCPayload covers the SSE-unwrapping helper in isolation.
func TestExtractRPCPayload(t *testing.T) {
cases := []struct {
name string
contentType string
body string
want string
}{
{"plain json passthrough", "application/json", `{"jsonrpc":"2.0","id":1,"result":{}}`, `{"jsonrpc":"2.0","id":1,"result":{}}`},
{"sse single event", "text/event-stream", "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n\n", `{"jsonrpc":"2.0","id":1,"result":{}}`},
{"sse no event prefix", "text/event-stream", "data: {\"ok\":true}\n\n", `{"ok":true}`},
{"sse crlf line endings", "text/event-stream", "event: message\r\ndata: {\"ok\":true}\r\n\r\n", `{"ok":true}`},
{"sse multi data lines joined", "text/event-stream", "data: {\"a\":1,\ndata: \"b\":2}\n\n", "{\"a\":1,\n\"b\":2}"},
{"sse charset suffix", "text/event-stream; charset=utf-8", "data: {\"ok\":true}\n\n", `{"ok":true}`},
{"sse no data falls back to body", "text/event-stream", "event: ping\n\n", "event: ping\n\n"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := string(extractRPCPayload(tc.contentType, []byte(tc.body)))
if got != tc.want {
t.Errorf("extractRPCPayload() = %q, want %q", got, tc.want)
}
})
}
}

// streamHTTPHandler mimics the GitVelocity Streamable HTTP transport: it 406s
// when the Accept header omits text/event-stream, issues an Mcp-Session-Id on
// the first response, and replies with an SSE-framed JSON-RPC message.
type streamHTTPHandler struct {
mu sync.Mutex
sawSessions []string // Mcp-Session-Id sent by the client per request
acceptSeen []string
}

func (h *streamHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var rpc struct {
ID int64 `json:"id"`
Method string `json:"method"`
}
_ = json.NewDecoder(r.Body).Decode(&rpc)

h.mu.Lock()
h.acceptSeen = append(h.acceptSeen, r.Header.Get("Accept"))
h.sawSessions = append(h.sawSessions, r.Header.Get("Mcp-Session-Id"))
h.mu.Unlock()

accept := r.Header.Get("Accept")
if !strings.Contains(accept, "application/json") || !strings.Contains(accept, "text/event-stream") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotAcceptable)
_, _ = w.Write([]byte(`{"jsonrpc":"2.0","error":{"code":-32000,"message":"Not Acceptable: Client must accept both application/json and text/event-stream"},"id":null}`))
return
}

var result string
switch rpc.Method {
case "initialize":
w.Header().Set("Mcp-Session-Id", "sess-123") // hand out a session
result = `{"protocolVersion":"2024-11-05","serverInfo":{"name":"GitVelocity","version":"9.9.9"}}`
case "tools/list":
result = `{"tools":[{"name":"get_current_user","description":"who am i","inputSchema":{"type":"object","properties":{}}}]}`
default:
result = `{}`
}
// Reply as SSE, the way the real transport does.
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":" + itoa(rpc.ID) + ",\"result\":" + result + "}\n\n"))
}

func itoa(n int64) string {
if n == 0 {
return "0"
}
neg := n < 0
if neg {
n = -n
}
var b []byte
for n > 0 {
b = append([]byte{byte('0' + n%10)}, b...)
n /= 10
}
if neg {
b = append([]byte{'-'}, b...)
}
return string(b)
}

// TestStreamableHTTP_AcceptAndSSE is the regression test for the 406 the real
// server returned: the client must send a dual Accept header and unwrap the SSE
// reply. It also pins session propagation: the Mcp-Session-Id handed out on
// initialize must ride along on the subsequent tools/list.
func TestStreamableHTTP_AcceptAndSSE(t *testing.T) {
h := &streamHTTPHandler{}
srv := httptest.NewServer(h)
defer srv.Close()

c := &Client{ServerURL: srv.URL, HTTPClient: srv.Client()}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

init, err := c.Initialize(ctx)
if err != nil {
t.Fatalf("Initialize: %v", err)
}
if init.ServerInfo.Version != "9.9.9" {
t.Errorf("server version = %q, want 9.9.9 (SSE body not unwrapped?)", init.ServerInfo.Version)
}

tools, err := c.ListTools(ctx)
if err != nil {
t.Fatalf("ListTools: %v", err)
}
if len(tools.Tools) != 1 || tools.Tools[0].Name != "get_current_user" {
t.Errorf("tools = %+v, want one get_current_user", tools.Tools)
}

h.mu.Lock()
defer h.mu.Unlock()
// Every request must have advertised the dual Accept.
for i, a := range h.acceptSeen {
if !strings.Contains(a, "application/json") || !strings.Contains(a, "text/event-stream") {
t.Errorf("request %d Accept = %q, want both json and event-stream", i, a)
}
}
// First request (initialize) carries no session; the second must echo sess-123.
if len(h.sawSessions) != 2 {
t.Fatalf("expected 2 requests, got %d", len(h.sawSessions))
}
if h.sawSessions[0] != "" {
t.Errorf("initialize should carry no session, got %q", h.sawSessions[0])
}
if h.sawSessions[1] != "sess-123" {
t.Errorf("tools/list session = %q, want sess-123 (not propagated?)", h.sawSessions[1])
}
}
Loading