Skip to content
13 changes: 13 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
panic(fmt.Errorf("AddTool %q: missing input schema", t.Name))
}
if s, ok := t.InputSchema.(*jsonschema.Schema); ok {
if s == nil {
panic(fmt.Errorf("AddTool %q: input schema is nil", t.Name))
}
if s.Type != "object" {
panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name))
}
Expand All @@ -261,6 +264,9 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
}
if t.OutputSchema != nil {
if s, ok := t.OutputSchema.(*jsonschema.Schema); ok {
if s == nil {
panic(fmt.Errorf("AddTool %q: output schema is nil", t.Name))
}
if s.Type != "object" {
panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name))
}
Expand Down Expand Up @@ -437,6 +443,9 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa
if err != nil {
return zero, err
}
if internalSchema == nil {
return zero, fmt.Errorf("schema is nil for type %v", rt)
}
*sfield = internalSchema

resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
Expand Down Expand Up @@ -466,6 +475,10 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, cache *SchemaCa
}
}

if internalSchema == nil {
return zero, fmt.Errorf("schema is nil for type %v", rt)
}

resolved, err := internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
if err != nil {
return zero, err
Expand Down
47 changes: 47 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"log/slog"
"slices"
Expand Down Expand Up @@ -604,6 +605,52 @@ func TestAddToolNameValidation(t *testing.T) {
}
}

func TestAddToolNilSchema(t *testing.T) {
var nilSchema *jsonschema.Schema

panicMsg := func(f func()) (msg string) {
defer func() {
if r := recover(); r != nil {
msg = fmt.Sprintf("%v", r)
}
}()
f()
return msg
}

// Call s.AddTool directly to exercise the typed-nil checks added to that method.
tests := []struct {
name string
tool *Tool
wantContain string
}{
{
name: "typed nil InputSchema",
tool: &Tool{Name: "T", InputSchema: nilSchema},
wantContain: "input schema is nil",
},
{
name: "typed nil OutputSchema",
tool: &Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}, OutputSchema: nilSchema},
wantContain: "output schema is nil",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
s := NewServer(testImpl, nil)

msg := panicMsg(func() {
s.AddTool(tc.tool, nil)
})
if msg == "" {
t.Error("expected panic")
} else if !strings.Contains(msg, tc.wantContain) {
t.Errorf("panic message %q does not contain %q", msg, tc.wantContain)
}
})
}
}

type schema = jsonschema.Schema

func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut any, wantErrContaining string) {
Expand Down