Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 1 addition & 34 deletions internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"

"github.com/Kong/volcano-cli/internal/api"
"github.com/Kong/volcano-cli/internal/apiclient"
"github.com/Kong/volcano-cli/internal/config"
"github.com/Kong/volcano-cli/internal/localmode"
cliruntime "github.com/Kong/volcano-cli/internal/runtime"
clisession "github.com/Kong/volcano-cli/internal/session"
)
Expand Down Expand Up @@ -60,7 +57,7 @@ func (s Service) LoginWithToken(ctx context.Context, cfg *config.Config, token s
// LoginWithBrowser runs the OAuth device flow and returns credentials to persist.
func (s Service) LoginWithBrowser(ctx context.Context, cfg *config.Config, w io.Writer) (Credentials, error) {
apiURL := s.apiURL(cfg)
clientID, err := resolveDeviceClientID(apiURL)
clientID, err := cfg.DeviceClientID()
if err != nil {
return Credentials{}, err
}
Expand All @@ -78,36 +75,6 @@ func (s Service) LoginWithBrowser(ctx context.Context, cfg *config.Config, w io.
return s.completeBrowserLogin(ctx, client, clientID, deviceAuth, w)
}

// resolveDeviceClientID returns the device OAuth client id for the login flow.
// When the CLI is pointed at a loopback address the local server only knows the
// deterministic local device client, so issue it directly; otherwise defer to
// the configured id.
func resolveDeviceClientID(apiURL string) (string, error) {
if isLocalAPIURL(apiURL) {
return localmode.DeviceClientID, nil
}
return config.FirstPartyDeviceClientID()
}

// isLocalAPIURL reports whether apiURL points at a loopback address.
func isLocalAPIURL(apiURL string) bool {
u, err := url.Parse(strings.TrimSpace(apiURL))
if err != nil {
return false
}
host := u.Hostname()
if host == "" {
return false
}
if strings.EqualFold(host, "localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
return true
}
return false
}

// Logout deletes local authentication state.
func (s Service) Logout() error {
return config.Delete()
Expand Down
103 changes: 51 additions & 52 deletions internal/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
"github.com/stretchr/testify/require"

"github.com/Kong/volcano-cli/internal/config"
"github.com/Kong/volcano-cli/internal/localmode"
cliruntime "github.com/Kong/volcano-cli/internal/runtime"
)

const authAlphaProjectID = "11111111-1111-4111-8111-111111111111"

Check failure on line 21 in internal/auth/auth_test.go

View workflow job for this annotation

GitHub Actions / check / check

File is not properly formatted (gofumpt)
const testProdDeviceClientID = "devcli_94e247237984b85cfd58d37e"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI is currently failing gofumpt on this const section. Grouping the three const declarations should satisfy the formatter.

const testDevDeviceClientID = "devcli_dcc913b9786f9ef2825b861c"

func TestLoginWithTokenSuccess(t *testing.T) {
cfg := testAuthConfig(t)
Expand Down Expand Up @@ -60,8 +61,6 @@

func TestLoginWithBrowserDeviceFlow(t *testing.T) {
cfg := testAuthConfig(t)
// httptest binds to 127.0.0.1, so resolveDeviceClientID issues the
// deterministic local-mode client id regardless of any configured value.

var pollCount atomic.Int32
var openedURL string
Expand All @@ -75,7 +74,7 @@
case "/auth/device/authorize":
var body map[string]string
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
assert.Equal(t, localmode.DeviceClientID, body["client_id"])
assert.Equal(t, testProdDeviceClientID, body["client_id"])
writeAuthJSON(t, w, http.StatusOK, map[string]any{
"device_code": "device-code",
"user_code": "ABCD-EFGH",
Expand Down Expand Up @@ -310,58 +309,58 @@
t.ch <- time.Now()
}

func TestResolveDeviceClientIDLocalURL(t *testing.T) {
t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "env-device-client")

for _, apiURL := range []string{
"http://localhost:8000",
"http://127.0.0.1:8000",
"http://[::1]:8000",
"http://LOCALHOST:8000",
} {
t.Run(apiURL, func(t *testing.T) {
got, err := resolveDeviceClientID(apiURL)
require.NoError(t, err)
assert.Equal(t, localmode.DeviceClientID, got)
})
}
}

func TestResolveDeviceClientIDCloudEnv(t *testing.T) {
t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "env-device-client")

got, err := resolveDeviceClientID("https://api.volcano.dev")
require.NoError(t, err)
assert.Equal(t, "env-device-client", got)
}
func TestLoginWithBrowserUsesSelectedContextDeviceClient(t *testing.T) {
cfg := testAuthConfig(t)
cfg.SetDefaultContext(config.ContextDev)

func TestResolveDeviceClientIDMissing(t *testing.T) {
t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "")
timeoutTimer := newAuthFakeTicker()
pollTicker := newAuthFakeTicker()
dotTicker := newAuthFakeTicker()

_, err := resolveDeviceClientID("https://api.volcano.dev")
require.Error(t, err)
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch strings.TrimSuffix(r.URL.Path, "/") {
case "/auth/device/authorize":
var body map[string]string
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
assert.Equal(t, testDevDeviceClientID, body["client_id"])
writeAuthJSON(t, w, http.StatusOK, map[string]any{
"device_code": "device-code",
"user_code": "ABCD-EFGH",
"verification_uri": "https://volcano.dev/device",
"expires_in": 120,
"interval": 1,
})
case "/auth/device/token":
writeAuthJSON(t, w, http.StatusOK, map[string]any{"access_token": "auth-access-token"})
case "/auth/platform/exchange":
writeAuthJSON(t, w, http.StatusOK, map[string]any{"token": "platform-token"})
default:
http.NotFound(w, r)
}
}))
defer server.Close()

func TestIsLocalAPIURL(t *testing.T) {
for _, tc := range []struct {
url string
want bool
}{
{"http://localhost:8000", true},
{"https://localhost", true},
{"http://127.0.0.1:8000", true},
{"http://127.7.7.7:8000", true},
{"http://[::1]:8000", true},
{"https://api.volcano.dev", false},
{"https://api.staging.volcano.dev", false},
{"http://192.168.1.10:8000", false},
{"", false},
{"::not a url::", false},
} {
t.Run(tc.url, func(t *testing.T) {
assert.Equal(t, tc.want, isLocalAPIURL(tc.url))
})
deps := cliruntime.Deps{
HTTPClient: server.Client(),
APIBaseURL: server.URL,
OpenBrowser: func(string) error { return nil },
NewTimer: func(time.Duration) cliruntime.Timer { return timeoutTimer },
NewTicker: func(time.Duration) cliruntime.Ticker {
if pollTicker.created.CompareAndSwap(false, true) {
return pollTicker
}
return dotTicker
},
}

var out bytes.Buffer
done := make(chan error, 1)
go func() {
_, err := NewService(deps).LoginWithBrowser(context.Background(), cfg, &out)
done <- err
}()
pollTicker.tick()
require.NoError(t, <-done)
}

func testAuthConfig(t *testing.T) *config.Config {
Expand Down
6 changes: 4 additions & 2 deletions internal/cmd/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func runLogin(ctx context.Context, opts loginOptions) error {
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if opts.deps.ContextName != nil && strings.TrimSpace(*opts.deps.ContextName) != "" {
cfg.SetContextOverride(*opts.deps.ContextName)
}

service := cliauth.NewService(opts.deps)
var credentials cliauth.Credentials
Expand All @@ -73,8 +76,7 @@ func runLogin(ctx context.Context, opts loginOptions) error {
}
}

cfg.UserToken = credentials.Token
cfg.UserID = credentials.UserID
cfg.SetCredentials(credentials.Token, credentials.UserID)
if err := cfg.Save(); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
Expand Down
166 changes: 166 additions & 0 deletions internal/cmd/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Package context wires the volcano context command tree.
package context

import (
"fmt"
"io"
"sort"
"strings"

"github.com/spf13/cobra"

"github.com/Kong/volcano-cli/internal/config"
"github.com/Kong/volcano-cli/internal/output"
)

type setOptions struct {
name string
apiURL string
deviceClientID string
out io.Writer
}

// New returns the context command tree.
func New() *cobra.Command {
cmd := &cobra.Command{
Use: "context",
Short: "Manage Volcano API contexts",
Long: "Manage named Volcano API contexts used by login and top-level commands.",
}
cmd.AddCommand(newList())
cmd.AddCommand(newUse())
cmd.AddCommand(newSet())
cmd.AddCommand(newDelete())
return cmd
}

func newList() *cobra.Command {
return &cobra.Command{
Use: "list",
Short: "List contexts",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
cfg, err := config.Load()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This command loads config directly, so it never sees the root deps.ContextName pointer populated by the global --context flag. As a result, volcano --context stage context list still marks prod as current. Can the context command accept/apply the same override, or otherwise keep the advertised global context selection consistent here?

if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
printContexts(cmd.OutOrStdout(), cfg)
return nil
},
}
}

func newUse() *cobra.Command {
return &cobra.Command{
Use: "use <name>",
Short: "Set the default context",
Long: "Set the context used by default when --context and VOLCANO_CONTEXT are not set.",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
name := config.NormalizeContextName(args[0])
cfg.SetDefaultContext(name)
if err := cfg.Save(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
output.Success(cmd.OutOrStdout(), "Now using context: %s", cfg.DefaultContext)
return nil
},
}
}

func newSet() *cobra.Command {
var opts setOptions
cmd := &cobra.Command{
Use: "set <name>",
Short: "Create or update a context",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
opts.name = args[0]
opts.out = cmd.OutOrStdout()
return runSet(opts)
},
}
cmd.Flags().StringVar(&opts.apiURL, "api-url", "", "API URL for the context")
cmd.Flags().StringVar(&opts.deviceClientID, "device-client-id", "", "OAuth device client ID for browser login")
return cmd
}

func runSet(opts setOptions) error {
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
name := config.NormalizeContextName(opts.name)
if name == "" {
return fmt.Errorf("context name cannot be empty")

Check failure on line 99 in internal/cmd/context/context.go

View workflow job for this annotation

GitHub Actions / check / check

error-format: fmt.Errorf can be replaced with errors.New (perfsprint)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI is currently failing perfsprint here because this static error can be errors.New(...) instead of fmt.Errorf(...).

}
ctx := cfg.EnsureContext(name)
if apiURL := strings.TrimSpace(opts.apiURL); apiURL != "" {
ctx.APIBaseURL = apiURL
}
if clientID := strings.TrimSpace(opts.deviceClientID); clientID != "" {
ctx.DeviceClientID = clientID
}
if err := cfg.Save(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
output.Success(opts.out, "Context saved: %s", name)
return nil
}

func newDelete() *cobra.Command {
return &cobra.Command{
Use: "delete <name>",
Short: "Delete a custom context",
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
name := config.NormalizeContextName(args[0])
if config.IsBuiltInContext(name) {
return fmt.Errorf("cannot delete built-in context %q", name)
}
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
cfg.DeleteContext(name)
if err := cfg.Save(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
output.Success(cmd.OutOrStdout(), "Context deleted: %s", name)
return nil
},
}
}

func printContexts(w io.Writer, cfg *config.Config) {
names := map[string]struct{}{}
for _, name := range config.BuiltInContextNames() {
names[name] = struct{}{}
}
for name := range cfg.Contexts {
names[config.NormalizeContextName(name)] = struct{}{}
}
sorted := make([]string, 0, len(names))
for name := range names {
if name != "" {
sorted = append(sorted, name)
}
}
sort.Strings(sorted)

active := cfg.ActiveContextName()
fmt.Fprintf(w, "%-8s %-7s %s\n", "CURRENT", "NAME", "API URL")
fmt.Fprintln(w, strings.Repeat("-", 72))
for _, name := range sorted {
marker := ""
if name == active {
marker = "*"
}
resolved := cfg.ResolvedContext(name)
fmt.Fprintf(w, "%-8s %-7s %s\n", marker, name, resolved.APIBaseURL)
}
}
Loading
Loading