From 373e09f6049250dc5bb23ab9bd3a859e9fcd37d8 Mon Sep 17 00:00:00 2001 From: Maxence Maireaux Date: Thu, 4 Jun 2026 13:26:42 +0200 Subject: [PATCH] fix: reject parse errors before execution --- internal/mcp_impl/handlers.go | 70 +++++++++++++++--------------- internal/mcp_impl/handlers_test.go | 28 ++++++++++++ numscript.go | 4 ++ numscript_test.go | 9 ++++ 4 files changed, 77 insertions(+), 34 deletions(-) create mode 100644 internal/mcp_impl/handlers_test.go diff --git a/internal/mcp_impl/handlers.go b/internal/mcp_impl/handlers.go index 93139eb9..abe76b7b 100644 --- a/internal/mcp_impl/handlers.go +++ b/internal/mcp_impl/handlers.go @@ -95,45 +95,47 @@ func addEvalTool(s *server.MCPServer) { `), ), ) - s.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - script, err := request.RequireString("script") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } + s.AddTool(tool, handleEvalTool) +} - parsed := parser.Parse(script) - if len(parsed.Errors) != 0 { - out := make([]string, len(parsed.Errors)) - for index, err := range parsed.Errors { - out[index] = err.Msg - } - mcp.NewToolResultError(strings.Join(out, ", ")) - } +func handleEvalTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + script, err := request.RequireString("script") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - balances, mcpErr := parseBalancesJson(request.GetArguments()["balances"]) - if mcpErr != nil { - return mcpErr, nil + parsed := parser.Parse(script) + if len(parsed.Errors) != 0 { + out := make([]string, len(parsed.Errors)) + for index, err := range parsed.Errors { + out[index] = err.Msg } + return mcp.NewToolResultError(strings.Join(out, ", ")), nil + } - vars, mcpErr := parseVarsJson(request.GetArguments()["vars"]) - if mcpErr != nil { - return mcpErr, nil - } + balances, mcpErr := parseBalancesJson(request.GetArguments()["balances"]) + if mcpErr != nil { + return mcpErr, nil + } - out, iErr := interpreter.RunProgram( - ctx, - parsed.Value, - vars, - interpreter.StaticStore{ - Balances: balances, - }, - map[string]struct{}{}, - ) - if iErr != nil { - return mcp.NewToolResultError(iErr.Error()), nil - } - return mcp.NewToolResultJSON(*out) - }) + vars, mcpErr := parseVarsJson(request.GetArguments()["vars"]) + if mcpErr != nil { + return mcpErr, nil + } + + out, iErr := interpreter.RunProgram( + ctx, + parsed.Value, + vars, + interpreter.StaticStore{ + Balances: balances, + }, + map[string]struct{}{}, + ) + if iErr != nil { + return mcp.NewToolResultError(iErr.Error()), nil + } + return mcp.NewToolResultJSON(*out) } func addCheckTool(s *server.MCPServer) { diff --git a/internal/mcp_impl/handlers_test.go b/internal/mcp_impl/handlers_test.go new file mode 100644 index 00000000..7398d96e --- /dev/null +++ b/internal/mcp_impl/handlers_test.go @@ -0,0 +1,28 @@ +package mcp_impl + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +func TestHandleEvalToolRejectsParseErrors(t *testing.T) { + result, err := handleEvalTool(context.Background(), mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "script": "send [COIN 100] (", + "balances": map[string]any{}, + "vars": map[string]any{}, + }, + }, + }) + + require.NoError(t, err) + require.True(t, result.IsError) + require.NotEmpty(t, result.Content) + text, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok) + require.Contains(t, text.Text, "mismatched input") +} diff --git a/numscript.go b/numscript.go index e7fa0437..08cd5bdd 100644 --- a/numscript.go +++ b/numscript.go @@ -86,6 +86,10 @@ func (p ParseResult) RunWithFeatureFlags( store Store, featureFlags map[string]struct{}, ) (ExecutionResult, InterpreterError) { + if len(p.parseResult.Errors) != 0 { + return ExecutionResult{}, p.parseResult.Errors[0] + } + if featureFlags == nil { featureFlags = make(map[string]struct{}) } diff --git a/numscript_test.go b/numscript_test.go index b4da92d6..aef0c1da 100644 --- a/numscript_test.go +++ b/numscript_test.go @@ -58,6 +58,15 @@ func TestGetVarsNovars(t *testing.T) { ) } +func TestRunRejectsParseErrors(t *testing.T) { + parseResult := numscript.Parse(`send [COIN 100] (`) + require.NotEmpty(t, parseResult.GetParsingErrors()) + + _, err := parseResult.Run(context.Background(), nil, interpreter.StaticStore{}) + require.Error(t, err) + require.Equal(t, parseResult.GetParsingErrors()[0].Error(), err.Error()) +} + func TestDoNotGetWorldBalance(t *testing.T) { parseResult := numscript.Parse(`send [COIN 100] ( source = @world