diff --git a/mcp/server.go b/mcp/server.go index 28504376..16d06ca8 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -247,6 +247,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) } if s, ok := t.InputSchema.(*jsonschema.Schema); ok { + if s == nil { + panic(fmt.Errorf("AddTool %q: input schema is nil", t.Name)) + } if s.Type != "object" { panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) } @@ -261,6 +264,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { } if t.OutputSchema != nil { if s, ok := t.OutputSchema.(*jsonschema.Schema); ok { + if s == nil { + panic(fmt.Errorf("AddTool %q: output schema is nil", t.Name)) + } if s.Type != "object" { panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) } @@ -437,6 +443,9 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa if err != nil { return zero, err } + if internalSchema == nil { + return zero, fmt.Errorf("schema is nil for type %v", rt) + } *sfield = internalSchema resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) @@ -466,6 +475,10 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa } } + if internalSchema == nil { + return zero, fmt.Errorf("schema is nil for type %v", rt) + } + resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) if err != nil { return zero, err diff --git a/mcp/server_test.go b/mcp/server_test.go index 1312e1d9..2937ea2b 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "log" "log/slog" "slices" @@ -604,6 +605,52 @@ func TestAddToolNameValidation(t *testing.T) { } } +func TestAddToolNilSchema(t *testing.T) { + var nilSchema *jsonschema.Schema + + panicMsg := func(f func()) (msg string) { + defer func() { + if r := recover(); r != nil { + msg = fmt.Sprintf("%v", r) + } + }() + f() + return msg + } + + // Call s.AddTool directly to exercise the typed-nil checks added to that method. + tests := []struct { + name string + tool *Tool + wantContain string + }{ + { + name: "typed nil InputSchema", + tool: &Tool{Name: "T", InputSchema: nilSchema}, + wantContain: "input schema is nil", + }, + { + name: "typed nil OutputSchema", + tool: &Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}, OutputSchema: nilSchema}, + wantContain: "output schema is nil", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s := NewServer(testImpl, nil) + + msg := panicMsg(func() { + s.AddTool(tc.tool, nil) + }) + if msg == "" { + t.Error("expected panic") + } else if !strings.Contains(msg, tc.wantContain) { + t.Errorf("panic message %q does not contain %q", msg, tc.wantContain) + } + }) + } +} + type schema = jsonschema.Schema func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut any, wantErrContaining string) {