diff --git a/pkg/utils/vault.go b/pkg/utils/vault.go index 975cd320f..62b6a3758 100644 --- a/pkg/utils/vault.go +++ b/pkg/utils/vault.go @@ -1,10 +1,14 @@ package utils import ( + "bytes" + "context" "encoding/json" "fmt" "os" "os/exec" + "strings" + "time" ocmutils "github.com/openshift/ocm-container/pkg/utils" @@ -12,6 +16,12 @@ import ( "github.com/spf13/viper" ) +const ( + vaultCallbackPortFile = "/tmp/vault_callback_port" + defaultVaultOIDCPort = "8250" + vaultLoginTimeout = 5 * time.Minute +) + const ( VaultAddrKey string = "vault_address" ) @@ -41,9 +51,32 @@ func GetVaultRef(vaultPathKey string) (VaultRef, error) { return VaultRef{Addr: vaultAddr, Path: vaultPath}, nil } +// readFileFunc is the function used to read files. Replaced in tests +// to avoid filesystem access. +var readFileFunc = os.ReadFile + +// readCallbackPort reads the dynamically-assigned host port from the +// portmap file written by ocm-container's ports feature. +func readCallbackPort() string { + data, err := readFileFunc(vaultCallbackPortFile) + if err != nil { + log.Debugf("could not read vault callback port file: %v", err) + return "" + } + port := strings.TrimSpace(string(data)) + if port == "" { + log.Debugf("vault callback port file is empty") + return "" + } + log.Debugf("read vault callback port: %s", port) + return port +} + // setupVaultToken ensures a valid Vault token exists by checking the current // token and requesting a new one via OIDC if needed. In container environments, -// it configures authentication to work without browser auto-launch. +// it uses the dynamically-assigned callback port from ocm-container's ports +// feature and falls back to in-process token capture if the token file is +// not writable. func setupVaultToken(vaultAddr string) error { err := os.Setenv("VAULT_ADDR", vaultAddr) if err != nil { @@ -51,7 +84,6 @@ func setupVaultToken(vaultAddr string) error { } versionCheckCmd := exec.Command("vault", "version") - versionCheckCmd.Stdout = os.Stderr versionCheckCmd.Stderr = os.Stderr @@ -62,51 +94,139 @@ func setupVaultToken(vaultAddr string) error { tokenCheckCmd := exec.Command("vault", "token", "lookup") tokenCheckCmd.Stdout = nil tokenCheckCmd.Stderr = nil - // get new token since old token has expired if err = tokenCheckCmd.Run(); err != nil { log.Infoln("Vault token no longer valid, requesting new token") - // Check if we're in a container environment (OCM_CONTAINER env var is set) - // If so, skip automatic browser launch and print the URL for manual authentication - loginArgs := []string{"login", "-method=oidc", "-no-print"} if ocmutils.IsRunningInOcmContainer() { - log.Infoln("\nNOTE: Running in container mode - OIDC authentication requires port forwarding.") - log.Infoln("Ensure port 8250 is exposed in your ocm-container configuration:") - log.Infoln(" Add 'launch-opts: \"-p 8250:8250\"' to ~/.config/ocm-container/ocm-container.yaml") - log.Infoln("Then restart your ocm-container for the change to take effect.") - - // In container: skip browser launch and listen on all interfaces (0.0.0.0) - // so the callback can be reached from the host browser via localhost:8250 - loginArgs = []string{"login", "-method=oidc", "skip_browser=true", "listenaddress=0.0.0.0"} + return setupVaultTokenContainer() } - loginCmd := exec.Command("vault", loginArgs...) - // Show output when using skip_browser so user can see the authentication URL - if ocmutils.IsRunningInOcmContainer() { - loginCmd.Stdout = os.Stderr - loginCmd.Stderr = os.Stderr - } else { - loginCmd.Stdout = nil - loginCmd.Stderr = nil - } + return setupVaultTokenLocal() + } + + return nil +} + +// setupVaultTokenLocal handles vault OIDC login outside of a container. +func setupVaultTokenLocal() error { + ctx, cancel := context.WithTimeout(context.Background(), vaultLoginTimeout) + defer cancel() + + loginCmd := exec.CommandContext(ctx, "vault", "login", "-method=oidc", "-no-print") + loginCmd.Stdout = nil + loginCmd.Stderr = nil - if err = loginCmd.Run(); err != nil { - if ocmutils.IsRunningInOcmContainer() { - return fmt.Errorf("vault login failed: %v\n\n"+ - "If authentication timed out or the callback failed, this is likely because:\n"+ - " 1. Port 8250 is not exposed in your ocm-container configuration\n"+ - " 2. Your ocm-container was not restarted after adding the port\n\n"+ - "To fix:\n"+ - " - Add 'launch-opts: \"-p 8250:8250\"' to ~/.config/ocm-container/ocm-container.yaml\n"+ - " - Exit and restart your ocm-container\n"+ - " - Try the authentication again", err) - } - return fmt.Errorf("error running 'vault login': %v", err) + if err := loginCmd.Run(); err != nil { + if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("vault login timed out after %s", vaultLoginTimeout) } + return fmt.Errorf("error running 'vault login': %v", err) + } + log.Infoln("Acquired vault token") + return nil +} + +// buildOIDCArgs builds the vault login args for container mode, +// including the dynamic callback port if available. +func buildOIDCArgs(noStore bool, callbackPort string) []string { + args := []string{"login", "-method=oidc"} + if noStore { + args = append(args, "-no-store", "-field=token") + } + args = append(args, "skip_browser=true", "listenaddress=0.0.0.0") + + if callbackPort != "" { + args = append(args, + fmt.Sprintf("port=%s", defaultVaultOIDCPort), + fmt.Sprintf("callbackport=%s", callbackPort), + ) + log.Infof("Using dynamic vault OIDC callback port: %s", callbackPort) + } else { + log.Infoln("No dynamic callback port found, using default port 8250.") + log.Infoln("If running multiple containers, ensure ocm-container has the ports feature enabled.") + } + + return args +} + +// isTokenFileError checks whether a vault login error is related to +// writing the token file (e.g., bind mount rename failures). +func isTokenFileError(err error) bool { + msg := err.Error() + return strings.Contains(msg, "rename") || + strings.Contains(msg, "device or resource busy") || + strings.Contains(msg, "read-only file system") || + strings.Contains(msg, "permission denied") +} + +// setupVaultTokenContainer handles vault OIDC login inside an ocm-container. +// It first tries a normal vault login that writes ~/.vault-token. If that +// fails due to a token file write error (e.g., read-only bind mount), it +// falls back to capturing the token in-process via VAULT_TOKEN. +func setupVaultTokenContainer() error { + callbackPort := readCallbackPort() + + ctx, cancel := context.WithTimeout(context.Background(), vaultLoginTimeout) + defer cancel() + + log.Infoln("Complete the login via the URL printed below.") + + // First attempt: normal login that writes ~/.vault-token + loginArgs := buildOIDCArgs(false, callbackPort) + loginCmd := exec.CommandContext(ctx, "vault", loginArgs...) + loginCmd.Stdout = os.Stderr + loginCmd.Stderr = os.Stderr + + if err := loginCmd.Run(); err == nil { log.Infoln("Acquired vault token") + return nil + } else if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("vault login timed out after %s", vaultLoginTimeout) + } else if !isTokenFileError(err) { + return fmt.Errorf("vault login failed: %v\n\n"+ + "If authentication timed out or the callback failed:\n"+ + " 1. Ensure ocm-container has the vault port enabled in the ports feature\n"+ + " 2. Restart your ocm-container for the change to take effect\n"+ + " 3. Try the authentication again", err) + } + + // Token file write failed (bind mount rename issue). Fall back to + // capturing the token in-process. + log.Infof("Token file write failed (%v), capturing token in-process instead.", ctx.Err()) + log.Infoln("Complete the login via the URL printed below.") + + ctx2, cancel2 := context.WithTimeout(context.Background(), vaultLoginTimeout) + defer cancel2() + + loginArgs = buildOIDCArgs(true, callbackPort) + loginCmd = exec.CommandContext(ctx2, "vault", loginArgs...) + + var tokenBuf bytes.Buffer + loginCmd.Stdout = &tokenBuf + loginCmd.Stderr = os.Stderr + + if err := loginCmd.Run(); err != nil { + if ctx2.Err() == context.DeadlineExceeded { + return fmt.Errorf("vault login timed out after %s", vaultLoginTimeout) + } + return fmt.Errorf("vault login failed: %v\n\n"+ + "If authentication timed out or the callback failed:\n"+ + " 1. Ensure ocm-container has the vault port enabled in the ports feature\n"+ + " 2. Restart your ocm-container for the change to take effect\n"+ + " 3. Try the authentication again", err) + } + + token := strings.TrimSpace(tokenBuf.String()) + if token == "" { + return fmt.Errorf("vault login succeeded but returned empty token") + } + + if err := os.Setenv("VAULT_TOKEN", token); err != nil { + return fmt.Errorf("error setting VAULT_TOKEN: %v", err) } + log.Infoln("Acquired vault token (in-process)") return nil } diff --git a/pkg/utils/vault_test.go b/pkg/utils/vault_test.go index 66e0c76ec..6941500fc 100644 --- a/pkg/utils/vault_test.go +++ b/pkg/utils/vault_test.go @@ -1,156 +1,212 @@ package utils import ( - "os" - "os/exec" + "fmt" "testing" - - ocmutils "github.com/openshift/ocm-container/pkg/utils" ) -// TestSetupVaultToken_ContainerEnvironment tests that the vault login command -// uses the correct flags when running inside a container (IO_OPENSHIFT_MANAGED_NAME env var set) -func TestSetupVaultToken_ContainerEnvironment(t *testing.T) { +func TestReadCallbackPort(t *testing.T) { tests := []struct { - name string - containerEnvValue string - expectNoBrowser bool + name string + mockData []byte + mockErr error + expected string }{ { - name: "Container environment with IO_OPENSHIFT_MANAGED_NAME=ocm-container", - containerEnvValue: "ocm-container", - expectNoBrowser: true, + name: "reads port from file", + mockData: []byte("43210\n"), + expected: "43210", + }, + { + name: "trims whitespace", + mockData: []byte(" 12345 \n"), + expected: "12345", + }, + { + name: "returns empty for missing file", + mockErr: fmt.Errorf("no such file"), + expected: "", + }, + { + name: "returns empty for empty file", + mockData: []byte(""), + expected: "", }, { - name: "Non-container environment (empty)", - containerEnvValue: "", - expectNoBrowser: false, + name: "returns empty for whitespace-only file", + mockData: []byte(" \n"), + expected: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Set environment using t.Setenv - automatically cleaned up after test - // Always call t.Setenv to ensure clean environment, even for empty case - t.Setenv("IO_OPENSHIFT_MANAGED_NAME", tt.containerEnvValue) - - // Build the command args as the code does - loginArgs := []string{"login", "-method=oidc", "-no-print"} - if ocmutils.IsRunningInOcmContainer() { - loginArgs = []string{"login", "-method=oidc", "skip_browser=true", "listenaddress=0.0.0.0"} - } - - // Verify the correct parameter is used - cmd := exec.Command("vault", loginArgs...) - cmdArgs := cmd.Args[1:] // Skip the "vault" binary name - - if tt.expectNoBrowser { - // Should have skip_browser=true and listenaddress=0.0.0.0 parameters - hasSkipBrowser := false - hasNoPrint := false - hasListenAddress := false - for _, arg := range cmdArgs { - if arg == "skip_browser=true" { - hasSkipBrowser = true - } - if arg == "-no-print" { - hasNoPrint = true - } - if arg == "listenaddress=0.0.0.0" { - hasListenAddress = true - } - } + orig := readFileFunc + defer func() { readFileFunc = orig }() - if !hasSkipBrowser { - t.Errorf("Expected skip_browser=true parameter in container environment, got args: %v", cmdArgs) - } - if hasNoPrint { - t.Errorf("Did not expect -no-print flag in container environment, got args: %v", cmdArgs) - } - if !hasListenAddress { - t.Errorf("Expected listenaddress=0.0.0.0 parameter in container environment, got args: %v", cmdArgs) - } - } else { - // Should have -no-print flag - hasSkipBrowser := false - hasNoPrint := false - for _, arg := range cmdArgs { - if arg == "skip_browser=true" { - hasSkipBrowser = true - } - if arg == "-no-print" { - hasNoPrint = true - } + readFileFunc = func(_ string) ([]byte, error) { + if tt.mockErr != nil { + return nil, tt.mockErr } + return tt.mockData, nil + } - if hasSkipBrowser { - t.Errorf("Did not expect skip_browser=true parameter in non-container environment, got args: %v", cmdArgs) - } - if !hasNoPrint { - t.Errorf("Expected -no-print flag in non-container environment, got args: %v", cmdArgs) - } + port := readCallbackPort() + if port != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, port) } }) } } -// TestSetupVaultToken_OutputRedirection tests that stdout/stderr are properly -// redirected based on the environment -func TestSetupVaultToken_OutputRedirection(t *testing.T) { +func TestBuildOIDCArgs(t *testing.T) { tests := []struct { - name string - containerEnvValue string - expectOutput bool + name string + noStore bool + callbackPort string + expectNoStore bool + expectFieldToken bool + expectPort bool + expectCallback bool }{ { - name: "Container environment shows output", - containerEnvValue: "ocm-container", - expectOutput: true, + name: "with store, no callback port", + noStore: false, + callbackPort: "", + expectNoStore: false, + expectPort: false, + expectCallback: false, + }, + { + name: "without store, no callback port", + noStore: true, + callbackPort: "", + expectNoStore: true, + expectFieldToken: true, + expectPort: false, + expectCallback: false, }, { - name: "Non-container environment hides output", - containerEnvValue: "", - expectOutput: false, + name: "with store, with callback port", + noStore: false, + callbackPort: "43210", + expectNoStore: false, + expectPort: true, + expectCallback: true, + }, + { + name: "without store, with callback port", + noStore: true, + callbackPort: "43210", + expectNoStore: true, + expectFieldToken: true, + expectPort: true, + expectCallback: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Set environment using t.Setenv - automatically cleaned up after test - // Always call t.Setenv to ensure clean environment, even for empty case - t.Setenv("IO_OPENSHIFT_MANAGED_NAME", tt.containerEnvValue) + args := buildOIDCArgs(tt.noStore, tt.callbackPort) - // Build the command as the code does - loginArgs := []string{"login", "-method=oidc", "-no-print"} - if ocmutils.IsRunningInOcmContainer() { - loginArgs = []string{"login", "-method=oidc", "skip_browser=true", "listenaddress=0.0.0.0"} + argSet := map[string]bool{} + for _, arg := range args { + argSet[arg] = true } - loginCmd := exec.Command("vault", loginArgs...) - // Set output redirection as the code does - if ocmutils.IsRunningInOcmContainer() { - loginCmd.Stdout = os.Stdout - loginCmd.Stderr = os.Stderr - } else { - loginCmd.Stdout = nil - loginCmd.Stderr = nil + if !argSet["skip_browser=true"] { + t.Errorf("expected skip_browser=true, got args: %v", args) + } + if !argSet["listenaddress=0.0.0.0"] { + t.Errorf("expected listenaddress=0.0.0.0, got args: %v", args) + } + if !argSet["login"] { + t.Errorf("expected login, got args: %v", args) + } + if !argSet["-method=oidc"] { + t.Errorf("expected -method=oidc, got args: %v", args) } - // Verify output redirection is correct - if tt.expectOutput { - if loginCmd.Stdout != os.Stdout { - t.Error("Expected Stdout to be os.Stdout in container environment") - } - if loginCmd.Stderr != os.Stderr { - t.Error("Expected Stderr to be os.Stderr in container environment") - } - } else { - if loginCmd.Stdout != nil { - t.Error("Expected Stdout to be nil in non-container environment") - } - if loginCmd.Stderr != nil { - t.Error("Expected Stderr to be nil in non-container environment") - } + if tt.expectNoStore && !argSet["-no-store"] { + t.Errorf("expected -no-store, got args: %v", args) + } + if !tt.expectNoStore && argSet["-no-store"] { + t.Errorf("did not expect -no-store, got args: %v", args) + } + if tt.expectFieldToken && !argSet["-field=token"] { + t.Errorf("expected -field=token, got args: %v", args) + } + if !tt.expectFieldToken && argSet["-field=token"] { + t.Errorf("did not expect -field=token, got args: %v", args) + } + + expectPortArg := fmt.Sprintf("port=%s", defaultVaultOIDCPort) + expectCallbackArg := fmt.Sprintf("callbackport=%s", tt.callbackPort) + + if tt.expectPort && !argSet[expectPortArg] { + t.Errorf("expected %s, got args: %v", expectPortArg, args) + } + if !tt.expectPort && argSet[expectPortArg] { + t.Errorf("did not expect %s, got args: %v", expectPortArg, args) + } + if tt.expectCallback && !argSet[expectCallbackArg] { + t.Errorf("expected %s, got args: %v", expectCallbackArg, args) + } + if !tt.expectCallback && argSet[expectCallbackArg] { + t.Errorf("did not expect %s, got args: %v", expectCallbackArg, args) + } + }) + } +} + +func TestIsTokenFileError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "rename error", + err: fmt.Errorf("rename /root/.vault-token.tmp /root/.vault-token: device or resource busy"), + expected: true, + }, + { + name: "device or resource busy", + err: fmt.Errorf("device or resource busy"), + expected: true, + }, + { + name: "read-only file system", + err: fmt.Errorf("read-only file system"), + expected: true, + }, + { + name: "permission denied", + err: fmt.Errorf("permission denied"), + expected: true, + }, + { + name: "auth timeout", + err: fmt.Errorf("context deadline exceeded"), + expected: false, + }, + { + name: "network error", + err: fmt.Errorf("dial tcp: connection refused"), + expected: false, + }, + { + name: "generic vault error", + err: fmt.Errorf("Error making API request"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isTokenFileError(tt.err) + if result != tt.expected { + t.Errorf("isTokenFileError(%q) = %v, want %v", tt.err, result, tt.expected) } }) }