diff --git a/command_run.go b/command_run.go index 87fa7e7fa5..c509b0172b 100644 --- a/command_run.go +++ b/command_run.go @@ -158,7 +158,12 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context tracef("using post-parse arguments %[1]q (cmd=%[2]q)", args, cmd.Name) - if checkCompletions(ctx, cmd) { + if shouldRunCompletion(cmd) { + var beforeErr error + if ctx, beforeErr = runBefore(ctx, commandChain(cmd)); beforeErr != nil { + return ctx, beforeErr + } + runCompletion(ctx, cmd) return ctx, nil } @@ -297,23 +302,12 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context // perform the command action. // // First, resolve the chain of nested commands up to the parent. - var cmdChain []*Command - for p := cmd; p != nil; p = p.parent { - cmdChain = append(cmdChain, p) - } - slices.Reverse(cmdChain) + cmdChain := commandChain(cmd) // Run Before actions in order. - for _, cmd := range cmdChain { - if cmd.Before == nil { - continue - } - if bctx, err := cmd.Before(ctx, cmd); err != nil { - deferErr = cmd.handleExitCoder(ctx, err) - return ctx, deferErr - } else if bctx != nil { - ctx = bctx - } + if ctx, err = runBefore(ctx, cmdChain); err != nil { + deferErr = err + return ctx, deferErr } // Run flag actions in order. @@ -363,3 +357,26 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context tracef("returning deferErr (cmd=%[1]q) %[2]q", cmd.Name, deferErr) return ctx, deferErr } + +func commandChain(cmd *Command) []*Command { + var cmdChain []*Command + for p := cmd; p != nil; p = p.parent { + cmdChain = append(cmdChain, p) + } + slices.Reverse(cmdChain) + return cmdChain +} + +func runBefore(ctx context.Context, cmdChain []*Command) (context.Context, error) { + for _, cmd := range cmdChain { + if cmd.Before == nil { + continue + } + if bctx, err := cmd.Before(ctx, cmd); err != nil { + return ctx, cmd.handleExitCoder(ctx, err) + } else if bctx != nil { + ctx = bctx + } + } + return ctx, nil +} diff --git a/command_test.go b/command_test.go index 24ce5e4739..338565e7f6 100644 --- a/command_test.go +++ b/command_test.go @@ -2098,8 +2098,9 @@ func TestCommand_OrderOfOperations(t *testing.T) { r := require.New(t) _ = cmd.Run(buildTestContext(t), []string{"command", completionFlag}) - r.Equal(1, counts.ShellComplete) - r.Equal(1, counts.Total) + r.Equal(1, counts.Before) + r.Equal(2, counts.ShellComplete) + r.Equal(2, counts.Total) }) t.Run("nil on usage error", func(t *testing.T) { diff --git a/completion_test.go b/completion_test.go index 2da7f622e2..26b7cf995f 100644 --- a/completion_test.go +++ b/completion_test.go @@ -3,7 +3,9 @@ package cli import ( "bytes" "context" + "errors" "fmt" + "io" "strings" "testing" @@ -393,6 +395,58 @@ func TestCompletionSubcommandCustomShellComplete(t *testing.T) { r.Equal("custom-index\n", out.String()) } +func TestCompletionRunsBeforeChain(t *testing.T) { + type contextKey struct{} + + out := &bytes.Buffer{} + cmd := &Command{ + EnableShellCompletion: true, + Writer: out, + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return context.WithValue(ctx, contextKey{}, "ready"), nil + }, + Commands: []*Command{ + { + Name: "index", + Commands: []*Command{ + { + Name: "show", + ShellComplete: func(ctx context.Context, cmd *Command) { + fmt.Fprintln(cmd.Root().Writer, ctx.Value(contextKey{})) + }, + Action: func(ctx context.Context, cmd *Command) error { return nil }, + }, + }, + }, + }, + } + + r := require.New(t) + r.NoError(cmd.Run(buildTestContext(t), []string{"foo", "index", "show", completionFlag})) + r.Equal("ready\n", out.String()) +} + +func TestCompletionReturnsBeforeError(t *testing.T) { + beforeErr := errors.New("load config") + completed := false + + cmd := &Command{ + EnableShellCompletion: true, + Writer: io.Discard, + Before: func(ctx context.Context, cmd *Command) (context.Context, error) { + return nil, beforeErr + }, + ShellComplete: func(ctx context.Context, cmd *Command) { + completed = true + }, + } + + err := cmd.Run(buildTestContext(t), []string{"foo", completionFlag}) + + require.ErrorIs(t, err, beforeErr) + assert.False(t, completed) +} + func TestCompletionInvalidShell(t *testing.T) { cmd := &Command{ EnableShellCompletion: true, diff --git a/help.go b/help.go index 1fba6edf0c..d901ba770a 100644 --- a/help.go +++ b/help.go @@ -492,7 +492,7 @@ func checkShellCompleteFlag(c *Command, arguments []string) (bool, []string) { return true, arguments[:pos] } -func checkCompletions(ctx context.Context, cmd *Command) bool { +func shouldRunCompletion(cmd *Command) bool { tracef("checking completions on command %[1]q", cmd.Name) if !cmd.Root().shellCompletion { @@ -509,13 +509,14 @@ func checkCompletions(ctx context.Context, cmd *Command) bool { } tracef("no subcommand found for completion %[1]q", cmd.Name) + return true +} +func runCompletion(ctx context.Context, cmd *Command) { if cmd.ShellComplete != nil { tracef("running shell completion func for command %[1]q", cmd.Name) cmd.ShellComplete(ctx, cmd) } - - return true } func subtract(a, b int) int {