Skip to content

Commit d7e23d4

Browse files
mesutoezdilEItanya
andauthored
fix(go-adk): forward allowedHeaders from incoming request to MCP tool calls (#1733)
Closes #1679 The allowedHeaders field on McpServer was being populated by the translator and stored in HttpMcpServerConfig/SseMcpServerConfig, but the Go runtime never used it. Headers from the incoming A2A request were silently dropped even when the operator explicitly listed them in allowedHeaders. The Python runtime handled this correctly already. What changed: - allowedRequestHeaders helper added to registry.go. It reads the a2asrv.CallContext already present in the Go context and returns only the header values whose names appear in the configured allowedHeaders list (case-insensitive). No intermediate context key, no extra copy. - headerRoundTripper gets an allowedHeaders field. On each outbound MCP HTTP request the transport calls allowedRequestHeaders and sets the matching headers. Static headers from the server spec are applied last so they always take precedence, mirroring the Python runtime behaviour. - CreateToolsets now passes AllowedHeaders from each MCP server config into the transport params. - executor.go no longer needs to touch headers at all. The A2A server already stores the CallContext in the Go context before Execute is called, so the round tripper picks it up automatically. The change is fully backwards compatible. If allowedHeaders is empty, no extra headers are forwarded. Tests use a2asrv.NewRequestMeta and a2asrv.WithCallContext to build the test context, matching what the real A2A server does. Coverage includes: normal forwarding, static-header-wins override, no A2A context, non-allowed header filter, and empty allowed list. Signed-off-by: mesutoezdil <mesudozdil@gmail.com> --------- Signed-off-by: mesutoezdil <mesudozdil@gmail.com> Signed-off-by: Eitan Yarmush <eitan.yarmush@solo.io> Co-authored-by: Eitan Yarmush <eitan.yarmush@solo.io>
1 parent 318b19b commit d7e23d4

2 files changed

Lines changed: 318 additions & 7 deletions

File tree

go/adk/pkg/mcp/registry.go

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"time"
1111

12+
"github.com/a2aproject/a2a-go/a2asrv"
1213
"github.com/go-logr/logr"
1314
"github.com/kagent-dev/kagent/go/api/adk"
1415
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
@@ -21,12 +22,47 @@ const (
2122
defaultTimeout = 30 * time.Minute
2223
)
2324

25+
// allowedRequestHeaders reads the incoming A2A request metadata from ctx and
26+
// returns only the header key/value pairs whose names appear in allowed.
27+
// It reads directly from the A2A CallContext that is already present in the Go
28+
// context, avoiding a redundant copy.
29+
//
30+
// Lookup relies on RequestMeta.Get which already does a case-insensitive O(1)
31+
// lookup (NewRequestMeta lowercases keys at construction). Keys in the result
32+
// preserve the casing from the allowed list so the MCP server sees the header
33+
// names the operator configured. When a header has multiple values only the
34+
// first one is forwarded; additional values are intentionally dropped.
35+
func allowedRequestHeaders(ctx context.Context, allowed []string) map[string]string {
36+
if len(allowed) == 0 {
37+
return nil
38+
}
39+
callCtx, ok := a2asrv.CallContextFrom(ctx)
40+
if !ok {
41+
return nil
42+
}
43+
meta := callCtx.RequestMeta()
44+
if meta == nil {
45+
return nil
46+
}
47+
result := make(map[string]string)
48+
for _, name := range allowed {
49+
if vals, ok := meta.Get(name); ok && len(vals) > 0 && vals[0] != "" {
50+
result[name] = vals[0]
51+
}
52+
}
53+
if len(result) == 0 {
54+
return nil
55+
}
56+
return result
57+
}
58+
2459
// mcpServerParams groups connection parameters for an MCP server,
2560
// reducing parameter sprawl across createTransport / initializeToolSet.
2661
type mcpServerParams struct {
2762
URL string
2863
Headers map[string]string
29-
ServerType string // "http" or "sse"
64+
AllowedHeaders []string // header names to forward from incoming request
65+
ServerType string // "http" or "sse"
3066
Timeout *float64
3167
SseReadTimeout *float64
3268
TLSInsecureSkipVerify *bool
@@ -46,6 +82,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
4682
params := mcpServerParams{
4783
URL: httpTool.Params.Url,
4884
Headers: httpTool.Params.Headers,
85+
AllowedHeaders: httpTool.AllowedHeaders,
4986
ServerType: "http",
5087
Timeout: httpTool.Params.Timeout,
5188
SseReadTimeout: httpTool.Params.SseReadTimeout,
@@ -65,6 +102,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
65102
params := mcpServerParams{
66103
URL: sseTool.Params.Url,
67104
Headers: sseTool.Params.Headers,
105+
AllowedHeaders: sseTool.AllowedHeaders,
68106
ServerType: "sse",
69107
Timeout: sseTool.Params.Timeout,
70108
SseReadTimeout: sseTool.Params.SseReadTimeout,
@@ -162,10 +200,11 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
162200
}
163201

164202
var httpTransport http.RoundTripper = baseTransport
165-
if len(params.Headers) > 0 {
203+
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 {
166204
httpTransport = &headerRoundTripper{
167-
base: baseTransport,
168-
headers: params.Headers,
205+
base: baseTransport,
206+
headers: params.Headers,
207+
allowedHeaders: params.AllowedHeaders,
169208
}
170209
}
171210

@@ -190,17 +229,35 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
190229
return mcpTransport, nil
191230
}
192231

193-
// headerRoundTripper wraps an http.RoundTripper to add custom headers to all requests.
232+
// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
233+
// requests. It supports two sources of headers:
234+
// - headers: static key/value pairs configured on the MCP server spec
235+
// - allowedHeaders: header names to forward from the incoming A2A request;
236+
// values are read on each call via allowedRequestHeaders directly from the
237+
// A2A CallContext that is already present in the Go context.
238+
//
239+
// Static headers take precedence: if an allowed header has the same name as a
240+
// static header, the static value wins.
194241
type headerRoundTripper struct {
195-
base http.RoundTripper
196-
headers map[string]string
242+
base http.RoundTripper
243+
headers map[string]string
244+
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
197245
}
198246

199247
func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
200248
req = req.Clone(req.Context())
249+
250+
// Forward allowed headers from the incoming A2A request first so that
251+
// static headers can override them if there is a name collision.
252+
for k, v := range allowedRequestHeaders(req.Context(), rt.allowedHeaders) {
253+
req.Header.Set(k, v)
254+
}
255+
256+
// Apply static headers (override any dynamic ones with the same name).
201257
for key, value := range rt.headers {
202258
req.Header.Set(key, value)
203259
}
260+
204261
return rt.base.RoundTrip(req)
205262
}
206263

go/adk/pkg/mcp/registry_test.go

Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package mcp
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/a2aproject/a2a-go/a2asrv"
10+
)
11+
12+
// a2aCtx builds a context that carries an A2A CallContext with the given headers.
13+
// Keys are stored case-insensitively by NewRequestMeta, matching the behaviour
14+
// of a real A2A server.
15+
func a2aCtx(headers map[string][]string) context.Context {
16+
meta := a2asrv.NewRequestMeta(headers)
17+
ctx, _ := a2asrv.WithCallContext(context.Background(), meta)
18+
return ctx
19+
}
20+
21+
// TestAllowedRequestHeaders_ForwardsMatchingHeaders verifies that headers listed
22+
// in allowedHeaders are forwarded when they are present in the A2A CallContext.
23+
func TestAllowedRequestHeaders_ForwardsMatchingHeaders(t *testing.T) {
24+
t.Parallel()
25+
var capturedAuth, capturedCustom, capturedStatic string
26+
27+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
28+
capturedAuth = r.Header.Get("Authorization")
29+
capturedCustom = r.Header.Get("X-Custom")
30+
capturedStatic = r.Header.Get("X-Static")
31+
w.WriteHeader(http.StatusOK)
32+
}))
33+
defer srv.Close()
34+
35+
ctx := a2aCtx(map[string][]string{
36+
"Authorization": {"Bearer token123"},
37+
"X-Custom": {"custom-value"},
38+
"X-Ignored": {"should-not-appear"},
39+
})
40+
41+
rt := &headerRoundTripper{
42+
base: http.DefaultTransport,
43+
headers: map[string]string{"X-Static": "static-value"},
44+
allowedHeaders: []string{"Authorization", "X-Custom"},
45+
}
46+
47+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
48+
resp, err := rt.RoundTrip(req)
49+
if err != nil {
50+
t.Fatalf("RoundTrip failed: %v", err)
51+
}
52+
resp.Body.Close()
53+
54+
if capturedAuth != "Bearer token123" {
55+
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer token123")
56+
}
57+
if capturedCustom != "custom-value" {
58+
t.Errorf("X-Custom: got %q, want %q", capturedCustom, "custom-value")
59+
}
60+
if capturedStatic != "static-value" {
61+
t.Errorf("X-Static: got %q, want %q", capturedStatic, "static-value")
62+
}
63+
}
64+
65+
// TestAllowedRequestHeaders_StaticOverridesDynamic verifies that a statically
66+
// configured header wins over the same header forwarded from the A2A request.
67+
func TestAllowedRequestHeaders_StaticOverridesDynamic(t *testing.T) {
68+
t.Parallel()
69+
var capturedAuth string
70+
71+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
72+
capturedAuth = r.Header.Get("Authorization")
73+
w.WriteHeader(http.StatusOK)
74+
}))
75+
defer srv.Close()
76+
77+
ctx := a2aCtx(map[string][]string{
78+
"Authorization": {"Bearer incoming"},
79+
})
80+
81+
rt := &headerRoundTripper{
82+
base: http.DefaultTransport,
83+
headers: map[string]string{"Authorization": "Bearer static"},
84+
allowedHeaders: []string{"Authorization"},
85+
}
86+
87+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
88+
resp, err := rt.RoundTrip(req)
89+
if err != nil {
90+
t.Fatalf("RoundTrip failed: %v", err)
91+
}
92+
resp.Body.Close()
93+
94+
if capturedAuth != "Bearer static" {
95+
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer static")
96+
}
97+
}
98+
99+
// TestAllowedRequestHeaders_NoA2AContext verifies that no headers are forwarded
100+
// when the context does not carry an A2A CallContext.
101+
func TestAllowedRequestHeaders_NoA2AContext(t *testing.T) {
102+
t.Parallel()
103+
var capturedAuth string
104+
105+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
106+
capturedAuth = r.Header.Get("Authorization")
107+
w.WriteHeader(http.StatusOK)
108+
}))
109+
defer srv.Close()
110+
111+
rt := &headerRoundTripper{
112+
base: http.DefaultTransport,
113+
allowedHeaders: []string{"Authorization"},
114+
}
115+
116+
req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
117+
resp, err := rt.RoundTrip(req)
118+
if err != nil {
119+
t.Fatalf("RoundTrip failed: %v", err)
120+
}
121+
resp.Body.Close()
122+
123+
if capturedAuth != "" {
124+
t.Errorf("Authorization should be empty without A2A context, got %q", capturedAuth)
125+
}
126+
}
127+
128+
// TestAllowedRequestHeaders_IgnoresNonAllowed verifies that headers not listed
129+
// in allowedHeaders are not forwarded even if they appear in the A2A request.
130+
func TestAllowedRequestHeaders_IgnoresNonAllowed(t *testing.T) {
131+
t.Parallel()
132+
var capturedIgnored string
133+
134+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
135+
capturedIgnored = r.Header.Get("X-Ignored")
136+
w.WriteHeader(http.StatusOK)
137+
}))
138+
defer srv.Close()
139+
140+
ctx := a2aCtx(map[string][]string{
141+
"X-Ignored": {"should-not-appear"},
142+
})
143+
144+
rt := &headerRoundTripper{
145+
base: http.DefaultTransport,
146+
allowedHeaders: []string{"Authorization"},
147+
}
148+
149+
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
150+
resp, err := rt.RoundTrip(req)
151+
if err != nil {
152+
t.Fatalf("RoundTrip failed: %v", err)
153+
}
154+
resp.Body.Close()
155+
156+
if capturedIgnored != "" {
157+
t.Errorf("X-Ignored should not be forwarded, got %q", capturedIgnored)
158+
}
159+
}
160+
161+
// TestAllowedRequestHeaders_EmptyAllowedList verifies that allowedRequestHeaders
162+
// returns nil immediately when the allowed list is empty.
163+
func TestAllowedRequestHeaders_EmptyAllowedList(t *testing.T) {
164+
t.Parallel()
165+
ctx := a2aCtx(map[string][]string{
166+
"Authorization": {"Bearer token"},
167+
})
168+
169+
got := allowedRequestHeaders(ctx, nil)
170+
if got != nil {
171+
t.Errorf("expected nil for empty allowed list, got %v", got)
172+
}
173+
174+
got = allowedRequestHeaders(ctx, []string{})
175+
if got != nil {
176+
t.Errorf("expected nil for empty allowed list, got %v", got)
177+
}
178+
}
179+
180+
// TestAllowedRequestHeaders_CaseInsensitiveLookup verifies that matching between
181+
// the configured allowedHeaders and the incoming request headers is case-insensitive
182+
// regardless of which side is lowercased or uppercased.
183+
func TestAllowedRequestHeaders_CaseInsensitiveLookup(t *testing.T) {
184+
t.Parallel()
185+
186+
cases := []struct {
187+
name string
188+
incoming map[string][]string
189+
allowed []string
190+
wantKey string
191+
wantVal string
192+
}{
193+
{
194+
name: "allowed lowercase, incoming capitalized",
195+
incoming: map[string][]string{"Authorization": {"Bearer x"}},
196+
allowed: []string{"authorization"},
197+
wantKey: "authorization",
198+
wantVal: "Bearer x",
199+
},
200+
{
201+
name: "allowed capitalized, incoming lowercase",
202+
incoming: map[string][]string{"authorization": {"Bearer y"}},
203+
allowed: []string{"Authorization"},
204+
wantKey: "Authorization",
205+
wantVal: "Bearer y",
206+
},
207+
{
208+
name: "mixed case both sides",
209+
incoming: map[string][]string{"X-Trace-Id": {"abc"}},
210+
allowed: []string{"x-trace-id"},
211+
wantKey: "x-trace-id",
212+
wantVal: "abc",
213+
},
214+
}
215+
216+
for _, tc := range cases {
217+
t.Run(tc.name, func(t *testing.T) {
218+
t.Parallel()
219+
ctx := a2aCtx(tc.incoming)
220+
got := allowedRequestHeaders(ctx, tc.allowed)
221+
if got[tc.wantKey] != tc.wantVal {
222+
t.Errorf("got[%q] = %q, want %q (full map: %v)", tc.wantKey, got[tc.wantKey], tc.wantVal, got)
223+
}
224+
})
225+
}
226+
}
227+
228+
// TestAllowedRequestHeaders_MultiValueFirstWins documents the behaviour for headers
229+
// that arrive with multiple values: only the first one is forwarded. If a use case
230+
// ever needs all values, the helper signature will have to change.
231+
func TestAllowedRequestHeaders_MultiValueFirstWins(t *testing.T) {
232+
t.Parallel()
233+
ctx := a2aCtx(map[string][]string{
234+
"X-Forwarded-For": {"1.2.3.4", "5.6.7.8", "9.10.11.12"},
235+
})
236+
got := allowedRequestHeaders(ctx, []string{"X-Forwarded-For"})
237+
if got["X-Forwarded-For"] != "1.2.3.4" {
238+
t.Errorf("expected first value 1.2.3.4, got %q", got["X-Forwarded-For"])
239+
}
240+
}
241+
242+
// TestAllowedRequestHeaders_ReturnsNilWhenNoMatches verifies that the helper returns
243+
// nil rather than an empty map when the allowed list has entries but none of them
244+
// appear in the request metadata.
245+
func TestAllowedRequestHeaders_ReturnsNilWhenNoMatches(t *testing.T) {
246+
t.Parallel()
247+
ctx := a2aCtx(map[string][]string{
248+
"X-Something-Else": {"value"},
249+
})
250+
got := allowedRequestHeaders(ctx, []string{"Authorization", "X-Trace-Id"})
251+
if got != nil {
252+
t.Errorf("expected nil when no allowed headers are present, got %v", got)
253+
}
254+
}

0 commit comments

Comments
 (0)