diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 104d6bc98a..a97afa845f 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -1,11 +1,9 @@ package ssh import ( - "fmt" "time" "github.com/databricks/cli/cmd/root" - "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/setup" "github.com/databricks/cli/libs/cmdctx" "github.com/spf13/cobra" @@ -57,17 +55,6 @@ an SSH host configuration to your SSH config file. Profile: wsClient.Config.Profile, AutoApprove: autoApprove, } - clientOpts := client.ClientOptions{ - ClusterID: setupOpts.ClusterID, - AutoStartCluster: setupOpts.AutoStartCluster, - ShutdownDelay: setupOpts.ShutdownDelay, - Profile: setupOpts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - if err != nil { - return fmt.Errorf("failed to generate ProxyCommand: %w", err) - } - setupOpts.ProxyCommand = proxyCommand return setup.Setup(ctx, wsClient, setupOpts) } diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index c2645a6379..43056c61a7 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + sshclient "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" @@ -28,8 +29,6 @@ type SetupOptions struct { SSHKeysDir string // Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand Profile string - // Proxy command to use for the SSH connection - ProxyCommand string // Skip confirmation prompts (e.g. recreate existing host config without asking) AutoApprove bool } @@ -45,17 +44,20 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func generateHostConfig(ctx context.Context, opts SetupOptions) (string, error) { +func generateHostConfig(ctx context.Context, opts SetupOptions, proxyCommand string) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(ctx, opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, proxyCommand) return hostConfig, nil } -func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { +// clusterSelectionPrompt is a package-level var so tests can replace it with a mock. +var clusterSelectionPrompt = defaultClusterSelectionPrompt + +func defaultClusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") clusters, err := client.Clusters.ClusterDetailsClusterNameToClusterIdMap(ctx, compute.ListClustersRequest{ @@ -92,6 +94,20 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } + // Build the ProxyCommand after the cluster ID is resolved. When the user + // omits --cluster, the ID is only known after the interactive picker above, + // so building it earlier would serialize an empty --cluster= flag. + clientOpts := sshclient.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + configPath, err := sshconfig.GetMainConfigPathOrDefault(ctx, opts.SSHConfigPath) if err != nil { return err @@ -102,7 +118,7 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - hostConfig, err := generateHostConfig(ctx, opts) + hostConfig, err := generateHostConfig(ctx, opts, proxyCommand) if err != nil { return err } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 4cd9970fee..f59b2e2b3a 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -1,6 +1,7 @@ package setup import ( + "context" "errors" "fmt" "os" @@ -10,6 +11,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/stretchr/testify/assert" @@ -134,10 +136,9 @@ func TestGenerateHostConfig_Valid(t *testing.T) { SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", - ProxyCommand: proxyCommand, } - result, err := generateHostConfig(t.Context(), opts) + result, err := generateHostConfig(t.Context(), opts, proxyCommand) assert.NoError(t, err) assert.Contains(t, result, "Host test-host") @@ -169,10 +170,9 @@ func TestGenerateHostConfig_WithoutProfile(t *testing.T) { SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "", - ProxyCommand: proxyCommand, } - result, err := generateHostConfig(t.Context(), opts) + result, err := generateHostConfig(t.Context(), opts, proxyCommand) assert.NoError(t, err) assert.NotContains(t, result, "--profile=") @@ -193,7 +193,7 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { ShutdownDelay: 30 * time.Second, } - result, err := generateHostConfig(t.Context(), opts) + result, err := generateHostConfig(t.Context(), opts, "") assert.NoError(t, err) // Check that quotes are properly escaped @@ -225,17 +225,7 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - AutoStartCluster: opts.AutoStartCluster, - ShutdownDelay: opts.ShutdownDelay, - Profile: opts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - require.NoError(t, err) - opts.ProxyCommand = proxyCommand - - err = Setup(ctx, m.WorkspaceClient, opts) + err := Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) // Check that main config has Include directive @@ -285,15 +275,7 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) { AutoApprove: true, } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - ShutdownDelay: opts.ShutdownDelay, - } - proxyCommand, err := clientOpts.ToProxyCommand() - require.NoError(t, err) - opts.ProxyCommand = proxyCommand - - err = Setup(ctx, m.WorkspaceClient, opts) + err := Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) // Host config should be recreated (no longer contains the stale User). @@ -304,6 +286,50 @@ func TestSetup_AutoApproveRecreatesExistingHost(t *testing.T) { assert.Contains(t, s, "--cluster=cluster-123") } +func TestSetup_PromptsForClusterWhenNotProvided(t *testing.T) { + ctx := cmdio.MockDiscard(t.Context()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, "ssh_config") + + // Replace the cluster picker with a stub returning a fixed ID. This lets the + // test exercise the empty-ClusterID path of Setup without driving promptui. + origPrompt := clusterSelectionPrompt + t.Cleanup(func() { clusterSelectionPrompt = origPrompt }) + promptCalled := false + clusterSelectionPrompt = func(_ context.Context, _ *databricks.WorkspaceClient) (string, error) { + promptCalled = true + return "picked-cluster", nil + } + + m := mocks.NewMockWorkspaceClient(t) + clustersAPI := m.GetMockClustersAPI() + clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "picked-cluster"}).Return(&compute.ClusterDetails{ + DataSecurityMode: compute.DataSecurityModeSingleUser, + }, nil) + + opts := SetupOptions{ + HostName: "test-host", + SSHConfigPath: configPath, + SSHKeysDir: tmpDir, + ShutdownDelay: 30 * time.Second, + } + + err := Setup(ctx, m.WorkspaceClient, opts) + require.NoError(t, err) + assert.True(t, promptCalled, "cluster picker should run when ClusterID is empty") + + // The picked ID must be serialized into the ProxyCommand's --cluster= flag. + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") + hostContent, err := os.ReadFile(hostConfigPath) + require.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "--cluster=picked-cluster") + assert.NotContains(t, hostConfigStr, "--cluster= ") +} + func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(t.Context()) tmpDir := t.TempDir() @@ -332,16 +358,6 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - AutoStartCluster: opts.AutoStartCluster, - ShutdownDelay: opts.ShutdownDelay, - Profile: opts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - require.NoError(t, err) - opts.ProxyCommand = proxyCommand - err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err)