diff --git a/internal/auth/auth.go b/internal/auth/auth.go index b9bbe49..dd7c4f1 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -6,16 +6,13 @@ import ( "errors" "fmt" "io" - "net" "net/http" - "net/url" "strings" "time" "github.com/Kong/volcano-cli/internal/api" "github.com/Kong/volcano-cli/internal/apiclient" "github.com/Kong/volcano-cli/internal/config" - "github.com/Kong/volcano-cli/internal/localmode" cliruntime "github.com/Kong/volcano-cli/internal/runtime" clisession "github.com/Kong/volcano-cli/internal/session" ) @@ -60,7 +57,7 @@ func (s Service) LoginWithToken(ctx context.Context, cfg *config.Config, token s // LoginWithBrowser runs the OAuth device flow and returns credentials to persist. func (s Service) LoginWithBrowser(ctx context.Context, cfg *config.Config, w io.Writer) (Credentials, error) { apiURL := s.apiURL(cfg) - clientID, err := resolveDeviceClientID(apiURL) + clientID, err := cfg.DeviceClientID() if err != nil { return Credentials{}, err } @@ -78,36 +75,6 @@ func (s Service) LoginWithBrowser(ctx context.Context, cfg *config.Config, w io. return s.completeBrowserLogin(ctx, client, clientID, deviceAuth, w) } -// resolveDeviceClientID returns the device OAuth client id for the login flow. -// When the CLI is pointed at a loopback address the local server only knows the -// deterministic local device client, so issue it directly; otherwise defer to -// the configured id. -func resolveDeviceClientID(apiURL string) (string, error) { - if isLocalAPIURL(apiURL) { - return localmode.DeviceClientID, nil - } - return config.FirstPartyDeviceClientID() -} - -// isLocalAPIURL reports whether apiURL points at a loopback address. -func isLocalAPIURL(apiURL string) bool { - u, err := url.Parse(strings.TrimSpace(apiURL)) - if err != nil { - return false - } - host := u.Hostname() - if host == "" { - return false - } - if strings.EqualFold(host, "localhost") { - return true - } - if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() { - return true - } - return false -} - // Logout deletes local authentication state. func (s Service) Logout() error { return config.Delete() diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 0557804..3c20ede 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -15,11 +15,12 @@ import ( "github.com/stretchr/testify/require" "github.com/Kong/volcano-cli/internal/config" - "github.com/Kong/volcano-cli/internal/localmode" cliruntime "github.com/Kong/volcano-cli/internal/runtime" ) const authAlphaProjectID = "11111111-1111-4111-8111-111111111111" +const testProdDeviceClientID = "devcli_94e247237984b85cfd58d37e" +const testDevDeviceClientID = "devcli_dcc913b9786f9ef2825b861c" func TestLoginWithTokenSuccess(t *testing.T) { cfg := testAuthConfig(t) @@ -60,8 +61,6 @@ func TestLoginWithTokenInvalid(t *testing.T) { func TestLoginWithBrowserDeviceFlow(t *testing.T) { cfg := testAuthConfig(t) - // httptest binds to 127.0.0.1, so resolveDeviceClientID issues the - // deterministic local-mode client id regardless of any configured value. var pollCount atomic.Int32 var openedURL string @@ -75,7 +74,7 @@ func TestLoginWithBrowserDeviceFlow(t *testing.T) { case "/auth/device/authorize": var body map[string]string require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) - assert.Equal(t, localmode.DeviceClientID, body["client_id"]) + assert.Equal(t, testProdDeviceClientID, body["client_id"]) writeAuthJSON(t, w, http.StatusOK, map[string]any{ "device_code": "device-code", "user_code": "ABCD-EFGH", @@ -310,58 +309,58 @@ func (t *authFakeTicker) tick() { t.ch <- time.Now() } -func TestResolveDeviceClientIDLocalURL(t *testing.T) { - t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "env-device-client") - - for _, apiURL := range []string{ - "http://localhost:8000", - "http://127.0.0.1:8000", - "http://[::1]:8000", - "http://LOCALHOST:8000", - } { - t.Run(apiURL, func(t *testing.T) { - got, err := resolveDeviceClientID(apiURL) - require.NoError(t, err) - assert.Equal(t, localmode.DeviceClientID, got) - }) - } -} - -func TestResolveDeviceClientIDCloudEnv(t *testing.T) { - t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "env-device-client") - - got, err := resolveDeviceClientID("https://api.volcano.dev") - require.NoError(t, err) - assert.Equal(t, "env-device-client", got) -} +func TestLoginWithBrowserUsesSelectedContextDeviceClient(t *testing.T) { + cfg := testAuthConfig(t) + cfg.SetDefaultContext(config.ContextDev) -func TestResolveDeviceClientIDMissing(t *testing.T) { - t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "") + timeoutTimer := newAuthFakeTicker() + pollTicker := newAuthFakeTicker() + dotTicker := newAuthFakeTicker() - _, err := resolveDeviceClientID("https://api.volcano.dev") - require.Error(t, err) -} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch strings.TrimSuffix(r.URL.Path, "/") { + case "/auth/device/authorize": + var body map[string]string + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + assert.Equal(t, testDevDeviceClientID, body["client_id"]) + writeAuthJSON(t, w, http.StatusOK, map[string]any{ + "device_code": "device-code", + "user_code": "ABCD-EFGH", + "verification_uri": "https://volcano.dev/device", + "expires_in": 120, + "interval": 1, + }) + case "/auth/device/token": + writeAuthJSON(t, w, http.StatusOK, map[string]any{"access_token": "auth-access-token"}) + case "/auth/platform/exchange": + writeAuthJSON(t, w, http.StatusOK, map[string]any{"token": "platform-token"}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() -func TestIsLocalAPIURL(t *testing.T) { - for _, tc := range []struct { - url string - want bool - }{ - {"http://localhost:8000", true}, - {"https://localhost", true}, - {"http://127.0.0.1:8000", true}, - {"http://127.7.7.7:8000", true}, - {"http://[::1]:8000", true}, - {"https://api.volcano.dev", false}, - {"https://api.staging.volcano.dev", false}, - {"http://192.168.1.10:8000", false}, - {"", false}, - {"::not a url::", false}, - } { - t.Run(tc.url, func(t *testing.T) { - assert.Equal(t, tc.want, isLocalAPIURL(tc.url)) - }) + deps := cliruntime.Deps{ + HTTPClient: server.Client(), + APIBaseURL: server.URL, + OpenBrowser: func(string) error { return nil }, + NewTimer: func(time.Duration) cliruntime.Timer { return timeoutTimer }, + NewTicker: func(time.Duration) cliruntime.Ticker { + if pollTicker.created.CompareAndSwap(false, true) { + return pollTicker + } + return dotTicker + }, } + + var out bytes.Buffer + done := make(chan error, 1) + go func() { + _, err := NewService(deps).LoginWithBrowser(context.Background(), cfg, &out) + done <- err + }() + pollTicker.tick() + require.NoError(t, <-done) } func testAuthConfig(t *testing.T) *config.Config { diff --git a/internal/cmd/auth/auth.go b/internal/cmd/auth/auth.go index a2cc07f..fccb5a1 100644 --- a/internal/cmd/auth/auth.go +++ b/internal/cmd/auth/auth.go @@ -55,6 +55,9 @@ func runLogin(ctx context.Context, opts loginOptions) error { if err != nil { return fmt.Errorf("failed to load config: %w", err) } + if opts.deps.ContextName != nil && strings.TrimSpace(*opts.deps.ContextName) != "" { + cfg.SetContextOverride(*opts.deps.ContextName) + } service := cliauth.NewService(opts.deps) var credentials cliauth.Credentials @@ -73,8 +76,7 @@ func runLogin(ctx context.Context, opts loginOptions) error { } } - cfg.UserToken = credentials.Token - cfg.UserID = credentials.UserID + cfg.SetCredentials(credentials.Token, credentials.UserID) if err := cfg.Save(); err != nil { return fmt.Errorf("failed to save credentials: %w", err) } diff --git a/internal/cmd/context/context.go b/internal/cmd/context/context.go new file mode 100644 index 0000000..6d5fab3 --- /dev/null +++ b/internal/cmd/context/context.go @@ -0,0 +1,166 @@ +// Package context wires the volcano context command tree. +package context + +import ( + "fmt" + "io" + "sort" + "strings" + + "github.com/spf13/cobra" + + "github.com/Kong/volcano-cli/internal/config" + "github.com/Kong/volcano-cli/internal/output" +) + +type setOptions struct { + name string + apiURL string + deviceClientID string + out io.Writer +} + +// New returns the context command tree. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "context", + Short: "Manage Volcano API contexts", + Long: "Manage named Volcano API contexts used by login and top-level commands.", + } + cmd.AddCommand(newList()) + cmd.AddCommand(newUse()) + cmd.AddCommand(newSet()) + cmd.AddCommand(newDelete()) + return cmd +} + +func newList() *cobra.Command { + return &cobra.Command{ + Use: "list", + Short: "List contexts", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + printContexts(cmd.OutOrStdout(), cfg) + return nil + }, + } +} + +func newUse() *cobra.Command { + return &cobra.Command{ + Use: "use ", + Short: "Set the default context", + Long: "Set the context used by default when --context and VOLCANO_CONTEXT are not set.", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + name := config.NormalizeContextName(args[0]) + cfg.SetDefaultContext(name) + if err := cfg.Save(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + output.Success(cmd.OutOrStdout(), "Now using context: %s", cfg.DefaultContext) + return nil + }, + } +} + +func newSet() *cobra.Command { + var opts setOptions + cmd := &cobra.Command{ + Use: "set ", + Short: "Create or update a context", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + opts.name = args[0] + opts.out = cmd.OutOrStdout() + return runSet(opts) + }, + } + cmd.Flags().StringVar(&opts.apiURL, "api-url", "", "API URL for the context") + cmd.Flags().StringVar(&opts.deviceClientID, "device-client-id", "", "OAuth device client ID for browser login") + return cmd +} + +func runSet(opts setOptions) error { + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + name := config.NormalizeContextName(opts.name) + if name == "" { + return fmt.Errorf("context name cannot be empty") + } + ctx := cfg.EnsureContext(name) + if apiURL := strings.TrimSpace(opts.apiURL); apiURL != "" { + ctx.APIBaseURL = apiURL + } + if clientID := strings.TrimSpace(opts.deviceClientID); clientID != "" { + ctx.DeviceClientID = clientID + } + if err := cfg.Save(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + output.Success(opts.out, "Context saved: %s", name) + return nil +} + +func newDelete() *cobra.Command { + return &cobra.Command{ + Use: "delete ", + Short: "Delete a custom context", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + name := config.NormalizeContextName(args[0]) + if config.IsBuiltInContext(name) { + return fmt.Errorf("cannot delete built-in context %q", name) + } + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + cfg.DeleteContext(name) + if err := cfg.Save(); err != nil { + return fmt.Errorf("failed to save config: %w", err) + } + output.Success(cmd.OutOrStdout(), "Context deleted: %s", name) + return nil + }, + } +} + +func printContexts(w io.Writer, cfg *config.Config) { + names := map[string]struct{}{} + for _, name := range config.BuiltInContextNames() { + names[name] = struct{}{} + } + for name := range cfg.Contexts { + names[config.NormalizeContextName(name)] = struct{}{} + } + sorted := make([]string, 0, len(names)) + for name := range names { + if name != "" { + sorted = append(sorted, name) + } + } + sort.Strings(sorted) + + active := cfg.ActiveContextName() + fmt.Fprintf(w, "%-8s %-7s %s\n", "CURRENT", "NAME", "API URL") + fmt.Fprintln(w, strings.Repeat("-", 72)) + for _, name := range sorted { + marker := "" + if name == active { + marker = "*" + } + resolved := cfg.ResolvedContext(name) + fmt.Fprintf(w, "%-8s %-7s %s\n", marker, name, resolved.APIBaseURL) + } +} diff --git a/internal/cmd/context/context_test.go b/internal/cmd/context/context_test.go new file mode 100644 index 0000000..1934087 --- /dev/null +++ b/internal/cmd/context/context_test.go @@ -0,0 +1,79 @@ +package context + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/Kong/volcano-cli/internal/config" +) + +func TestContextUseCreatesPresetAndSetsDefault(t *testing.T) { + setContextTestHome(t) + + out, err := executeContextCommand(t, "use", "production") + require.NoError(t, err) + assert.Contains(t, out, "Now using context: prod") + + cfg, err := config.Load() + require.NoError(t, err) + assert.Equal(t, config.ContextProd, cfg.DefaultContext) + assert.Equal(t, "https://api.volcano.dev", cfg.ResolvedContext(config.ContextProd).APIBaseURL) +} + +func TestContextSetCustom(t *testing.T) { + setContextTestHome(t) + + _, err := executeContextCommand(t, "set", "qa", "--api-url", "https://qa.example", "--device-client-id", "devcli_custom") + require.NoError(t, err) + + cfg, err := config.Load() + require.NoError(t, err) + ctx := cfg.ResolvedContext("qa") + assert.Equal(t, "https://qa.example", ctx.APIBaseURL) + assert.Equal(t, "devcli_custom", ctx.DeviceClientID) +} + +func TestContextListShowsBuiltInsAndActive(t *testing.T) { + setContextTestHome(t) + cfg := config.Default() + cfg.SetDefaultContext(config.ContextStage) + require.NoError(t, cfg.Save()) + + out, err := executeContextCommand(t, "list") + require.NoError(t, err) + assert.Contains(t, out, "dev") + assert.Contains(t, out, "stage") + assert.Contains(t, out, "prod") + assert.Contains(t, out, "* stage") +} + +func TestContextDeleteRejectsBuiltIn(t *testing.T) { + setContextTestHome(t) + + _, err := executeContextCommand(t, "delete", "dev") + require.ErrorContains(t, err, "cannot delete built-in context") +} + +func executeContextCommand(t *testing.T, args ...string) (string, error) { + t.Helper() + cmd := New() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs(args) + err := cmd.Execute() + return out.String(), err +} + +func setContextTestHome(t *testing.T) { + t.Helper() + t.Setenv("HOME", t.TempDir()) + t.Setenv("VOLCANO_TOKEN", "") + t.Setenv("VOLCANO_PROJECT_ID", "") + t.Setenv("VOLCANO_API_URL", "") + t.Setenv("VOLCANO_CONTEXT", "") + t.Setenv("VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID", "") +} diff --git a/internal/cmd/root/root.go b/internal/cmd/root/root.go index ed3b26b..d6c15e6 100644 --- a/internal/cmd/root/root.go +++ b/internal/cmd/root/root.go @@ -9,6 +9,7 @@ import ( authcmd "github.com/Kong/volcano-cli/internal/cmd/auth" configcmd "github.com/Kong/volcano-cli/internal/cmd/config" + contextcmd "github.com/Kong/volcano-cli/internal/cmd/context" databasescmd "github.com/Kong/volcano-cli/internal/cmd/databases" frontendscmd "github.com/Kong/volcano-cli/internal/cmd/frontends" functionscmd "github.com/Kong/volcano-cli/internal/cmd/functions" @@ -26,6 +27,8 @@ import ( // New returns the root Volcano command. func New(deps cliruntime.Deps) *cobra.Command { var showVersion bool + contextName := "" + deps.ContextName = &contextName root := &cobra.Command{ Use: "volcano", Short: "Volcano CLI", @@ -43,11 +46,13 @@ func New(deps cliruntime.Deps) *cobra.Command { return cmd.Help() }, } + root.PersistentFlags().StringVar(&contextName, "context", "", "Context to use for this command") root.Flags().BoolVarP(&showVersion, "version", "v", false, "Print CLI version") root.AddCommand(newVersionCmd()) root.AddCommand(upgradecmd.New(deps)) root.AddCommand(authcmd.NewLogin(deps)) root.AddCommand(authcmd.NewLogout()) + root.AddCommand(contextcmd.New()) root.AddCommand(initcmd.New()) root.AddCommand(projectcmd.NewProjects(deps)) root.AddCommand(projectcmd.NewUse(deps)) diff --git a/internal/cmd/root/root_test.go b/internal/cmd/root/root_test.go index b88111f..a419dd5 100644 --- a/internal/cmd/root/root_test.go +++ b/internal/cmd/root/root_test.go @@ -15,6 +15,7 @@ func TestRootHelp(t *testing.T) { out, err := executeRootCommand(t, "--help") require.NoError(t, err) assert.Contains(t, out, "volcano") + assert.Contains(t, out, "context") assert.Contains(t, out, "databases") assert.Contains(t, out, "functions") assert.Contains(t, out, "init") diff --git a/internal/config/config.go b/internal/config/config.go index 66c1d7d..9c685d3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,12 +14,27 @@ const ( envToken = "VOLCANO_TOKEN" envProjectID = "VOLCANO_PROJECT_ID" envAPIURL = "VOLCANO_API_URL" + envContext = "VOLCANO_CONTEXT" envFirstPartyDeviceID = "VOLCANO_FIRST_PARTY_DEVICE_CLIENT_ID" defaultConfigDirName = ".volcano" defaultConfigFileName = "config.json" defaultConfigDirMode = 0o700 defaultConfigFileMode = 0o600 defaultCompiledAPIURL = "https://api.volcano.dev" + + // ContextDev is the built-in context for localhost development APIs. + ContextDev = "dev" + // ContextStage is the built-in context for staging APIs. + ContextStage = "stage" + // ContextProd is the built-in context for production APIs. + ContextProd = "prod" + + devAPIURL = "http://localhost:8000" + stageAPIURL = "https://api.staging.volcano.dev" + prodAPIURL = "https://api.volcano.dev" + + devDeviceClientID = "devcli_dcc913b9786f9ef2825b861c" + prodDeviceClientID = "devcli_94e247237984b85cfd58d37e" ) var ( @@ -39,12 +54,27 @@ var ( type Config struct { // APIBaseURL overrides the compiled API URL for synthetic command configs. // It is intentionally not persisted to the user's cloud config file. - APIBaseURL string `json:"-"` + APIBaseURL string `json:"-"` + // ContextOverride is set by --context and intentionally not persisted. + ContextOverride string `json:"-"` + // Legacy flat fields are still read and mirrored for compatibility with + // older callers/tests. New writes persist them under Contexts instead. + UserToken string `json:"user_token,omitempty"` + UserID string `json:"user_id,omitempty"` + CurrentProject *ProjectConfig `json:"current_project,omitempty"` + DefaultContext string `json:"default_context,omitempty"` + Contexts map[string]*ContextConfig `json:"contexts,omitempty"` + // IgnoreEnv disables environment overrides for synthetic command configs. + IgnoreEnv bool `json:"-"` +} + +// ContextConfig represents one named Volcano API target and its credentials. +type ContextConfig struct { + APIBaseURL string `json:"api_url,omitempty"` + DeviceClientID string `json:"device_client_id,omitempty"` UserToken string `json:"user_token,omitempty"` UserID string `json:"user_id,omitempty"` CurrentProject *ProjectConfig `json:"current_project,omitempty"` - // IgnoreEnv disables environment overrides for synthetic command configs. - IgnoreEnv bool `json:"-"` } // ProjectConfig represents the currently selected Volcano project. @@ -55,7 +85,7 @@ type ProjectConfig struct { // Default returns an empty config. func Default() *Config { - return &Config{} + return &Config{DefaultContext: ContextProd} } // Path returns the default on-disk config path. @@ -86,6 +116,7 @@ func Load() (*Config, error) { if err := json.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("failed to parse config file: %w", err) } + cfg.applyLoadedDefaults() return &cfg, nil } @@ -104,7 +135,8 @@ func (c *Config) Save() error { return fmt.Errorf("failed to set config directory permissions: %w", err) } - data, err := json.MarshalIndent(c, "", " ") + persisted := c.persisted() + data, err := json.MarshalIndent(persisted, "", " ") if err != nil { return fmt.Errorf("failed to marshal config: %w", err) } @@ -146,6 +178,9 @@ func (c *Config) Token() string { if token := os.Getenv(envToken); !c.IgnoreEnv && token != "" { return token } + if token := c.ResolvedActiveContext().UserToken; token != "" { + return token + } return c.UserToken } @@ -154,6 +189,9 @@ func (c *Config) ProjectID() string { if projectID := os.Getenv(envProjectID); !c.IgnoreEnv && projectID != "" { return projectID } + if project := c.ResolvedActiveContext().CurrentProject; project != nil { + return project.ID + } if c.CurrentProject != nil { return c.CurrentProject.ID } @@ -161,7 +199,7 @@ func (c *Config) ProjectID() string { } // APIURL returns the API URL with VOLCANO_API_URL taking precedence unless env overrides are disabled. -// Precedence: env > runtime override (APIBaseURL) > compiled default. +// Precedence: env > runtime override (APIBaseURL) > active context > compiled default. func (c *Config) APIURL() string { if apiURL := strings.TrimSpace(os.Getenv(envAPIURL)); !c.IgnoreEnv && apiURL != "" { return apiURL @@ -169,6 +207,9 @@ func (c *Config) APIURL() string { if c.APIBaseURL != "" { return c.APIBaseURL } + if apiURL := strings.TrimSpace(c.ResolvedActiveContext().APIBaseURL); apiURL != "" { + return apiURL + } return compiledDefaultAPIURL } @@ -198,3 +239,300 @@ func FirstPartyDeviceClientID() (string, error) { } return "", fmt.Errorf("%s is required", envFirstPartyDeviceID) } + +// DeviceClientID resolves the OAuth device client ID for the active context. +func (c *Config) DeviceClientID() (string, error) { + if clientID := strings.TrimSpace(os.Getenv(envFirstPartyDeviceID)); !c.IgnoreEnv && clientID != "" { + return clientID, nil + } + if clientID := strings.TrimSpace(c.ResolvedActiveContext().DeviceClientID); clientID != "" { + return clientID, nil + } + return "", fmt.Errorf("device_client_id is required for context %q. Run 'volcano context set %s --device-client-id '", c.ActiveContextName(), c.ActiveContextName()) +} + +// ActiveContextName returns the selected context name. +func (c *Config) ActiveContextName() string { + if name := NormalizeContextName(c.ContextOverride); name != "" { + return name + } + if name := strings.TrimSpace(os.Getenv(envContext)); !c.IgnoreEnv && name != "" { + return NormalizeContextName(name) + } + if name := NormalizeContextName(c.DefaultContext); name != "" { + return name + } + return ContextProd +} + +// SetContextOverride selects a context for this in-memory config only. +func (c *Config) SetContextOverride(name string) { + c.ContextOverride = NormalizeContextName(name) + c.refreshLegacyMirror() +} + +// ResolvedActiveContext returns active context values with built-in presets applied. +func (c *Config) ResolvedActiveContext() ContextConfig { + return c.ResolvedContext(c.ActiveContextName()) +} + +// ResolvedContext returns context values with built-in presets applied. +func (c *Config) ResolvedContext(name string) ContextConfig { + name = NormalizeContextName(name) + if name == "" { + name = ContextProd + } + resolved := PresetContext(name) + if c.Contexts == nil { + return resolved + } + stored := c.Contexts[name] + if stored == nil { + return resolved + } + if strings.TrimSpace(stored.APIBaseURL) != "" { + resolved.APIBaseURL = stored.APIBaseURL + } + if strings.TrimSpace(stored.DeviceClientID) != "" { + resolved.DeviceClientID = stored.DeviceClientID + } + if stored.UserToken != "" { + resolved.UserToken = stored.UserToken + } + if stored.UserID != "" { + resolved.UserID = stored.UserID + } + if stored.CurrentProject != nil { + resolved.CurrentProject = cloneProject(stored.CurrentProject) + } + return resolved +} + +// EnsureContext returns a mutable context, creating and presetting it when needed. +func (c *Config) EnsureContext(name string) *ContextConfig { + name = NormalizeContextName(name) + if name == "" { + name = ContextProd + } + if c.Contexts == nil { + c.Contexts = make(map[string]*ContextConfig) + } + ctx := c.Contexts[name] + if ctx == nil { + preset := PresetContext(name) + ctx = &preset + c.Contexts[name] = ctx + return ctx + } + applyPresetDefaults(name, ctx) + return ctx +} + +// SetDefaultContext selects the persisted default context. +func (c *Config) SetDefaultContext(name string) { + name = NormalizeContextName(name) + if name == "" { + name = ContextProd + } + c.EnsureContext(name) + c.DefaultContext = name + c.refreshLegacyMirror() +} + +// SetCredentials stores credentials in the active context. +func (c *Config) SetCredentials(token, userID string) { + ctx := c.EnsureContext(c.ActiveContextName()) + ctx.UserToken = token + ctx.UserID = userID + c.refreshLegacyMirror() +} + +// SetCurrentProject stores the selected project in the active context. +func (c *Config) SetCurrentProject(project *ProjectConfig) { + ctx := c.EnsureContext(c.ActiveContextName()) + ctx.CurrentProject = cloneProject(project) + c.refreshLegacyMirror() +} + +// DeleteContext removes persisted values for a context. Built-in presets still resolve. +func (c *Config) DeleteContext(name string) { + name = NormalizeContextName(name) + if c.Contexts != nil { + delete(c.Contexts, name) + } + if c.ActiveContextName() == name { + c.DefaultContext = ContextProd + } + c.refreshLegacyMirror() +} + +// NormalizeContextName normalizes context aliases. +func NormalizeContextName(name string) string { + name = strings.TrimSpace(strings.ToLower(name)) + if name == "production" { + return ContextProd + } + return name +} + +// BuiltInContextNames returns canonical built-in context names. +func BuiltInContextNames() []string { + return []string{ContextDev, ContextStage, ContextProd} +} + +// IsBuiltInContext reports whether name is a built-in preset context. +func IsBuiltInContext(name string) bool { + switch NormalizeContextName(name) { + case ContextDev, ContextStage, ContextProd: + return true + default: + return false + } +} + +// PresetContext returns the built-in preset for name, or an empty custom context. +func PresetContext(name string) ContextConfig { + switch NormalizeContextName(name) { + case ContextDev: + return ContextConfig{APIBaseURL: devAPIURL, DeviceClientID: devDeviceClientID} + case ContextStage: + return ContextConfig{APIBaseURL: stageAPIURL, DeviceClientID: devDeviceClientID} + case ContextProd: + return ContextConfig{APIBaseURL: prodAPIURL, DeviceClientID: prodDeviceClientID} + default: + return ContextConfig{} + } +} + +func (c *Config) applyLoadedDefaults() { + if c.DefaultContext == "" { + c.DefaultContext = ContextProd + } else { + c.DefaultContext = NormalizeContextName(c.DefaultContext) + } + if len(c.Contexts) == 0 && hasLegacyConfig(c) { + legacyContext := c.legacyMigrationContextName() + ctx := PresetContext(legacyContext) + ctx.UserToken = c.UserToken + ctx.UserID = c.UserID + ctx.CurrentProject = cloneProject(c.CurrentProject) + c.Contexts = map[string]*ContextConfig{legacyContext: &ctx} + c.DefaultContext = legacyContext + } + c.refreshLegacyMirror() +} + +func (c *Config) legacyMigrationContextName() string { + if name := NormalizeContextName(os.Getenv(envContext)); !c.IgnoreEnv && name != "" { + return name + } + if name := contextNameForAPIURL(os.Getenv(envAPIURL)); !c.IgnoreEnv && name != "" { + return name + } + if name := NormalizeContextName(c.DefaultContext); name != "" { + return name + } + return ContextProd +} + +func contextNameForAPIURL(apiURL string) string { + switch strings.TrimRight(strings.TrimSpace(apiURL), "/") { + case strings.TrimRight(devAPIURL, "/"): + return ContextDev + case strings.TrimRight(stageAPIURL, "/"): + return ContextStage + case strings.TrimRight(prodAPIURL, "/"): + return ContextProd + default: + return "" + } +} + +func (c *Config) persisted() Config { + persisted := Config{ + DefaultContext: NormalizeContextName(c.DefaultContext), + Contexts: cloneContexts(c.Contexts), + } + if persisted.DefaultContext == "" { + persisted.DefaultContext = ContextProd + } + if hasLegacyConfig(c) { + ctx := ensureContextInMap(persisted.Contexts, c.ActiveContextName()) + applyPresetDefaults(c.ActiveContextName(), ctx) + if c.UserToken != "" { + ctx.UserToken = c.UserToken + } + if c.UserID != "" { + ctx.UserID = c.UserID + } + if c.CurrentProject != nil { + ctx.CurrentProject = cloneProject(c.CurrentProject) + } + } + ensureContextInMap(persisted.Contexts, persisted.DefaultContext) + applyPresetDefaults(persisted.DefaultContext, persisted.Contexts[persisted.DefaultContext]) + return persisted +} + +func (c *Config) refreshLegacyMirror() { + ctx := c.ResolvedActiveContext() + c.UserToken = ctx.UserToken + c.UserID = ctx.UserID + c.CurrentProject = cloneProject(ctx.CurrentProject) +} + +func hasLegacyConfig(c *Config) bool { + return c.UserToken != "" || c.UserID != "" || c.CurrentProject != nil +} + +func cloneContexts(contexts map[string]*ContextConfig) map[string]*ContextConfig { + cloned := make(map[string]*ContextConfig, len(contexts)) + for name, ctx := range contexts { + if ctx == nil { + continue + } + ctxCopy := *ctx + ctxCopy.CurrentProject = cloneProject(ctx.CurrentProject) + cloned[NormalizeContextName(name)] = &ctxCopy + } + return cloned +} + +func ensureContextInMap(contexts map[string]*ContextConfig, name string) *ContextConfig { + name = NormalizeContextName(name) + if name == "" { + name = ContextProd + } + if contexts == nil { + // All current callers pass a non-nil map from cloneContexts. + return nil + } + ctx := contexts[name] + if ctx == nil { + preset := PresetContext(name) + ctx = &preset + contexts[name] = ctx + } + return ctx +} + +func applyPresetDefaults(name string, ctx *ContextConfig) { + if ctx == nil { + return + } + preset := PresetContext(name) + if ctx.APIBaseURL == "" { + ctx.APIBaseURL = preset.APIBaseURL + } + if ctx.DeviceClientID == "" { + ctx.DeviceClientID = preset.DeviceClientID + } +} + +func cloneProject(project *ProjectConfig) *ProjectConfig { + if project == nil { + return nil + } + cloned := *project + return &cloned +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d224df2..b28fd47 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -35,7 +35,7 @@ func TestLoadSaveDeleteAndPermissions(t *testing.T) { loaded, err := Load() require.NoError(t, err) - assert.Equal(t, cfg.UserToken, loaded.UserToken) + assert.Equal(t, cfg.UserToken, loaded.Token()) assert.Equal(t, cfg.UserID, loaded.UserID) require.NotNil(t, loaded.CurrentProject) assert.Equal(t, cfg.CurrentProject.ID, loaded.CurrentProject.ID) @@ -52,7 +52,67 @@ func TestLoadSaveDeleteAndPermissions(t *testing.T) { assert.Nil(t, empty.CurrentProject) } -func TestSaveOmitsRuntimeOnlyAPIURL(t *testing.T) { +func TestContextPresetsAndDefaultSelection(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv(envToken, "") + t.Setenv(envProjectID, "") + t.Setenv(envAPIURL, "") + t.Setenv(envContext, "") + t.Setenv(envFirstPartyDeviceID, "") + + cfg := Default() + assert.Equal(t, ContextProd, cfg.ActiveContextName()) + assert.Equal(t, "https://api.volcano.dev", cfg.APIURL()) + clientID, err := cfg.DeviceClientID() + require.NoError(t, err) + assert.Equal(t, prodDeviceClientID, clientID) + + cfg.SetDefaultContext("production") + assert.Equal(t, ContextProd, cfg.ActiveContextName()) + + cfg.SetDefaultContext(ContextDev) + assert.Equal(t, ContextDev, cfg.ActiveContextName()) + assert.Equal(t, "http://localhost:8000", cfg.APIURL()) + clientID, err = cfg.DeviceClientID() + require.NoError(t, err) + assert.Equal(t, devDeviceClientID, clientID) + + cfg.SetDefaultContext(ContextStage) + assert.Equal(t, "https://api.staging.volcano.dev", cfg.APIURL()) + clientID, err = cfg.DeviceClientID() + require.NoError(t, err) + assert.Equal(t, devDeviceClientID, clientID) +} + +func TestContextSelectionPrecedence(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv(envAPIURL, "") + t.Setenv(envContext, ContextStage) + + cfg := Default() + cfg.SetDefaultContext(ContextDev) + assert.Equal(t, ContextStage, cfg.ActiveContextName()) + assert.Equal(t, "https://api.staging.volcano.dev", cfg.APIURL()) + + cfg.SetContextOverride(ContextProd) + assert.Equal(t, ContextProd, cfg.ActiveContextName()) + assert.Equal(t, "https://api.volcano.dev", cfg.APIURL()) +} + +func TestCustomContextRequiresDeviceClientID(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv(envFirstPartyDeviceID, "") + t.Setenv(envContext, "") + + cfg := Default() + cfg.SetDefaultContext("custom") + cfg.EnsureContext("custom").APIBaseURL = "https://custom.example" + + _, err := cfg.DeviceClientID() + require.ErrorContains(t, err, "device_client_id is required") +} + +func TestSaveOmitsRuntimeOnlyAPIURLOverride(t *testing.T) { t.Setenv("HOME", t.TempDir()) cfg := &Config{ @@ -65,7 +125,6 @@ func TestSaveOmitsRuntimeOnlyAPIURL(t *testing.T) { require.NoError(t, err) data, err := os.ReadFile(configPath) require.NoError(t, err) - assert.NotContains(t, string(data), "api_url") assert.NotContains(t, string(data), "http://localhost:8000") loaded, err := Load() @@ -98,6 +157,40 @@ func TestSaveRepairsExistingConfigPermissions(t *testing.T) { assert.Equal(t, "new-token", loaded.UserToken) } +func TestLoadMigratesLegacyConfigUsingAPIURLEnv(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv(envContext, "") + t.Setenv(envAPIURL, stageAPIURL) + + configPath, err := Path() + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(`{"user_token":"stage-token"}`), 0o600)) + + cfg, err := Load() + require.NoError(t, err) + assert.Equal(t, ContextStage, cfg.DefaultContext) + assert.Equal(t, "stage-token", cfg.ResolvedContext(ContextStage).UserToken) + assert.Empty(t, cfg.ResolvedContext(ContextProd).UserToken) +} + +func TestLoadMigratesLegacyConfigUsingContextEnvBeforeAPIURL(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv(envContext, ContextDev) + t.Setenv(envAPIURL, stageAPIURL) + + configPath, err := Path() + require.NoError(t, err) + require.NoError(t, os.MkdirAll(filepath.Dir(configPath), 0o700)) + require.NoError(t, os.WriteFile(configPath, []byte(`{"user_token":"dev-token"}`), 0o600)) + + cfg, err := Load() + require.NoError(t, err) + assert.Equal(t, ContextDev, cfg.DefaultContext) + assert.Equal(t, "dev-token", cfg.ResolvedContext(ContextDev).UserToken) + assert.Empty(t, cfg.ResolvedContext(ContextStage).UserToken) +} + func TestEnvPrecedence(t *testing.T) { t.Setenv("HOME", t.TempDir()) t.Setenv(envToken, "env-token") @@ -180,7 +273,7 @@ func TestIgnoreEnvUsesConfigValues(t *testing.T) { assert.Equal(t, "http://localhost:8000", cfg.APIURL()) } -func TestCompiledDefaults(t *testing.T) { +func TestProdPresetAndCompiledDeviceClientFallback(t *testing.T) { t.Setenv("HOME", t.TempDir()) t.Setenv(envAPIURL, "") t.Setenv(envFirstPartyDeviceID, "") @@ -195,7 +288,7 @@ func TestCompiledDefaults(t *testing.T) { }) cfg := &Config{} - assert.Equal(t, "https://compiled.example", cfg.APIURL()) + assert.Equal(t, "https://api.volcano.dev", cfg.APIURL()) got, err := FirstPartyDeviceClientID() require.NoError(t, err) diff --git a/internal/project/project.go b/internal/project/project.go index 60ec878..92e43c9 100644 --- a/internal/project/project.go +++ b/internal/project/project.go @@ -149,10 +149,10 @@ func resolveProject(ctx context.Context, client *api.Client, identifier string) } func saveCurrentProject(cfg *config.Config, project *apiclient.Project) error { - cfg.CurrentProject = &config.ProjectConfig{ + cfg.SetCurrentProject(&config.ProjectConfig{ ID: project.Id.String(), Name: project.Name, - } + }) if err := cfg.Save(); err != nil { return fmt.Errorf("failed to save config: %w", err) diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 1eb1203..7da0108 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -17,6 +17,7 @@ type Deps struct { // APIBaseURL overrides the compiled cloud API URL for tests. Synthetic // local configs supply their API URL through ConfigLoader instead. APIBaseURL string + ContextName *string OpenBrowser func(string) error NewTimer func(time.Duration) Timer NewTicker func(time.Duration) Ticker diff --git a/internal/session/session.go b/internal/session/session.go index 39d578e..ac1e38f 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -58,6 +58,9 @@ func (f Factory) Config() (*config.Config, error) { if err != nil { return nil, fmt.Errorf("failed to load config: %w", err) } + if f.deps.ContextName != nil && strings.TrimSpace(*f.deps.ContextName) != "" { + cfg.SetContextOverride(*f.deps.ContextName) + } return cfg, nil }