diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index a17a0fa94..b8d44d9e0 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -13,7 +13,6 @@ package api // - github.GetUser(github.Client) // - github.GetRepository(Client) // - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents -// - github.AuthorizedKeys(Client, user) // - codespaces.Create(Client, user, repo, sku, branch, location) // - codespaces.Delete(Client, user, token, name) // - codespaces.Get(Client, token, owner, name) @@ -1060,44 +1059,6 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod return decoded, nil } -// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys -// format) registered by the specified GitHub user. -func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]string, error) { - url := fmt.Sprintf("%s/%s.keys", a.githubServer, user) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } - resp, err := a.do(ctx, req, "/user.keys") - if err != nil { - return nil, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("server returned %s", resp.Status) - } - - b, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("error reading response body: %w", err) - } - - allKeys := string(b) - - var splitKeys []string - for _, key := range strings.Split(allKeys, "\n") { - key = strings.TrimSpace(key) - if key == "" { - continue - } - - splitKeys = append(splitKeys, key) - } - - return splitKeys, nil -} - // do executes the given request and returns the response. It creates an // opentracing span to track the length of the request. func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) { diff --git a/pkg/cmd/codespace/code_test.go b/pkg/cmd/codespace/code_test.go index be888a361..26aa05d4c 100644 --- a/pkg/cmd/codespace/code_test.go +++ b/pkg/cmd/codespace/code_test.go @@ -101,7 +101,6 @@ func testingCodeApp() *App { } func testCodeApiMock() *apiClientMock { - user := &api.User{Login: "monalisa"} testingCodespace := &api.Codespace{ Name: "monalisa-cli-cli-abcdef", WebURL: "https://monalisa-cli-cli-abcdef.github.dev", @@ -118,11 +117,5 @@ func testCodeApiMock() *apiClientMock { } return testingCodespace, nil }, - GetUserFunc: func(_ context.Context) (*api.User, error) { - return user, nil - }, - AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) { - return []string{}, nil - }, } } diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index c48548361..46e69d97c 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -70,13 +70,6 @@ type liveshareSession interface { // Connects to a codespace using Live Share and returns that session func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session liveshareSession, err error) { - // While connecting, ensure in the background that the user has keys installed. - // That lets us report a more useful error message if they don't. - authkeys := make(chan error, 1) - go func() { - authkeys <- checkAuthorizedKeys(ctx, a.apiClient) - }() - liveshareLogger := noopLogger() if debug { debugLogger, err := newFileLogger(debugFile) @@ -91,9 +84,6 @@ func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App session, err = codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace) if err != nil { - if authErr := <-authkeys; authErr != nil { - return nil, fmt.Errorf("failed to fetch authorization keys: %w", authErr) - } return nil, fmt.Errorf("failed to connect to Live Share: %w", err) } @@ -102,7 +92,6 @@ func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App //go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient type apiClient interface { - GetUser(ctx context.Context) (*api.User, error) GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error) GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*api.Codespace, error) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) @@ -112,7 +101,6 @@ type apiClient interface { CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) - AuthorizedKeys(ctx context.Context, user string) ([]string, error) GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error) @@ -237,25 +225,6 @@ func ask(qs []*survey.Question, response interface{}) error { return err } -// checkAuthorizedKeys reports an error if the user has not registered any SSH keys; -// see https://github.com/cli/cli/v2/issues/166#issuecomment-921769703. -// The check is not required for security but it improves the error message. -func checkAuthorizedKeys(ctx context.Context, client apiClient) error { - user, err := client.GetUser(ctx) - if err != nil { - return fmt.Errorf("error getting user: %w", err) - } - - keys, err := client.AuthorizedKeys(ctx, user.Login) - if err != nil { - return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err) - } - if len(keys) == 0 { - return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user) - } - return nil // success -} - var ErrTooManyArgs = errors.New("the command accepts no arguments") func noArgsConstraint(cmd *cobra.Command, args []string) error { diff --git a/pkg/cmd/codespace/delete_test.go b/pkg/cmd/codespace/delete_test.go index 3555a2a34..c9309f960 100644 --- a/pkg/cmd/codespace/delete_test.go +++ b/pkg/cmd/codespace/delete_test.go @@ -15,7 +15,6 @@ import ( ) func TestDelete(t *testing.T) { - user := &api.User{Login: "hubot"} now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z") daysAgo := func(n int) string { return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339) @@ -207,9 +206,6 @@ func TestDelete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { apiMock := &apiClientMock{ - GetUserFunc: func(_ context.Context) (*api.User, error) { - return user, nil - }, DeleteCodespaceFunc: func(_ context.Context, name string, orgName string, userName string) error { if tt.deleteErr != nil { return tt.deleteErr diff --git a/pkg/cmd/codespace/logs_test.go b/pkg/cmd/codespace/logs_test.go index 49a97c47d..bd4ea02f8 100644 --- a/pkg/cmd/codespace/logs_test.go +++ b/pkg/cmd/codespace/logs_test.go @@ -21,7 +21,6 @@ func TestPendingOperationDisallowsLogs(t *testing.T) { } func testingLogsApp() *App { - user := &api.User{Login: "monalisa"} disabledCodespace := &api.Codespace{ Name: "disabledCodespace", PendingOperation: true, @@ -34,12 +33,6 @@ func testingLogsApp() *App { } return nil, nil }, - GetUserFunc: func(_ context.Context) (*api.User, error) { - return user, nil - }, - AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) { - return []string{}, nil - }, } ios, _, _, _ := iostreams.Test() diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index 739e2c28b..5727591d1 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -16,9 +16,6 @@ import ( // // // make and configure a mocked apiClient // mockedapiClient := &apiClientMock{ -// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]string, error) { -// panic("mock out the AuthorizedKeys method") -// }, // CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { // panic("mock out the CreateCodespace method") // }, @@ -49,9 +46,6 @@ import ( // GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) { // panic("mock out the GetRepository method") // }, -// GetUserFunc: func(ctx context.Context) (*api.User, error) { -// panic("mock out the GetUser method") -// }, // ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { // panic("mock out the ListCodespaces method") // }, @@ -71,9 +65,6 @@ import ( // // } type apiClientMock struct { - // AuthorizedKeysFunc mocks the AuthorizedKeys method. - AuthorizedKeysFunc func(ctx context.Context, user string) ([]string, error) - // CreateCodespaceFunc mocks the CreateCodespace method. CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) @@ -104,9 +95,6 @@ type apiClientMock struct { // GetRepositoryFunc mocks the GetRepository method. GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error) - // GetUserFunc mocks the GetUser method. - GetUserFunc func(ctx context.Context) (*api.User, error) - // ListCodespacesFunc mocks the ListCodespaces method. ListCodespacesFunc func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) @@ -121,13 +109,6 @@ type apiClientMock struct { // calls tracks calls to the methods. calls struct { - // AuthorizedKeys holds details about calls to the AuthorizedKeys method. - AuthorizedKeys []struct { - // Ctx is the ctx argument value. - Ctx context.Context - // User is the user argument value. - User string - } // CreateCodespace holds details about calls to the CreateCodespace method. CreateCodespace []struct { // Ctx is the ctx argument value. @@ -218,11 +199,6 @@ type apiClientMock struct { // Nwo is the nwo argument value. Nwo string } - // GetUser holds details about calls to the GetUser method. - GetUser []struct { - // Ctx is the ctx argument value. - Ctx context.Context - } // ListCodespaces holds details about calls to the ListCodespaces method. ListCodespaces []struct { // Ctx is the ctx argument value. @@ -260,7 +236,6 @@ type apiClientMock struct { UserName string } } - lockAuthorizedKeys sync.RWMutex lockCreateCodespace sync.RWMutex lockDeleteCodespace sync.RWMutex lockEditCodespace sync.RWMutex @@ -271,48 +246,12 @@ type apiClientMock struct { lockGetCodespacesMachines sync.RWMutex lockGetOrgMemberCodespace sync.RWMutex lockGetRepository sync.RWMutex - lockGetUser sync.RWMutex lockListCodespaces sync.RWMutex lockListDevContainers sync.RWMutex lockStartCodespace sync.RWMutex lockStopCodespace sync.RWMutex } -// AuthorizedKeys calls AuthorizedKeysFunc. -func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]string, error) { - if mock.AuthorizedKeysFunc == nil { - panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") - } - callInfo := struct { - Ctx context.Context - User string - }{ - Ctx: ctx, - User: user, - } - mock.lockAuthorizedKeys.Lock() - mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo) - mock.lockAuthorizedKeys.Unlock() - return mock.AuthorizedKeysFunc(ctx, user) -} - -// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys. -// Check the length with: -// len(mockedapiClient.AuthorizedKeysCalls()) -func (mock *apiClientMock) AuthorizedKeysCalls() []struct { - Ctx context.Context - User string -} { - var calls []struct { - Ctx context.Context - User string - } - mock.lockAuthorizedKeys.RLock() - calls = mock.calls.AuthorizedKeys - mock.lockAuthorizedKeys.RUnlock() - return calls -} - // CreateCodespace calls CreateCodespaceFunc. func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) { if mock.CreateCodespaceFunc == nil { @@ -703,37 +642,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct { return calls } -// GetUser calls GetUserFunc. -func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) { - if mock.GetUserFunc == nil { - panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called") - } - callInfo := struct { - Ctx context.Context - }{ - Ctx: ctx, - } - mock.lockGetUser.Lock() - mock.calls.GetUser = append(mock.calls.GetUser, callInfo) - mock.lockGetUser.Unlock() - return mock.GetUserFunc(ctx) -} - -// GetUserCalls gets all the calls that were made to GetUser. -// Check the length with: -// len(mockedapiClient.GetUserCalls()) -func (mock *apiClientMock) GetUserCalls() []struct { - Ctx context.Context -} { - var calls []struct { - Ctx context.Context - } - mock.lockGetUser.RLock() - calls = mock.calls.GetUser - mock.lockGetUser.RUnlock() - return calls -} - // ListCodespaces calls ListCodespacesFunc. func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) { if mock.ListCodespacesFunc == nil { diff --git a/pkg/cmd/codespace/ports_test.go b/pkg/cmd/codespace/ports_test.go index f35c78363..3d3c87d95 100644 --- a/pkg/cmd/codespace/ports_test.go +++ b/pkg/cmd/codespace/ports_test.go @@ -247,7 +247,6 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) { } func testingPortsApp() *App { - user := &api.User{Login: "monalisa"} disabledCodespace := &api.Codespace{ Name: "disabledCodespace", PendingOperation: true, @@ -260,12 +259,6 @@ func testingPortsApp() *App { } return nil, nil }, - GetUserFunc: func(_ context.Context) (*api.User, error) { - return user, nil - }, - AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) { - return []string{}, nil - }, } ios, _, _, _ := iostreams.Test() diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index b79c6c13e..4b8d9bfc9 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -34,6 +34,8 @@ import ( const automaticPrivateKeyNameOld = "codespaces" const automaticPrivateKeyName = "codespaces.auto" +var errKeyFileNotFound = errors.New("SSH key file does not exist") + type sshOptions struct { codespace string profile string @@ -148,18 +150,14 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e sshContext := ssh.Context{} startSSHOptions := liveshare.StartSSHServerOptions{} - useAutoKeys, err := useAutomaticSSHKeys(ctx, sshContext, a.apiClient, args, opts) + keyPair, shouldAddArg, err := selectSSHKeys(ctx, sshContext, args, opts) if err != nil { - return fmt.Errorf("checking ssh key configuration: %w", err) + return fmt.Errorf("selecting ssh keys: %w", err) } - if useAutoKeys { - keyPair, err := setupAutomaticSSHKeys(sshContext) - if err != nil { - return fmt.Errorf("failed to generate ssh keys: %w", err) - } + startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath - startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath + if shouldAddArg { // For both cp and ssh, flags need to come first in the args (before a command in ssh and files in cp), so prepend this flag args = append([]string{"-i", keyPair.PrivateKeyPath}, args...) } @@ -236,48 +234,71 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } } -func useAutomaticSSHKeys( +// selectSSHKeys evaluates available key pairs and select which should be used to connect to the codespace +// using the precedence rules below. If there is no error, a keypair is always returned and additionally a +// bool flag is returned to specify if the private key need be appended to the ssh arguments (it doesn't need +// to be if the key was selected from a -i argument). +// +// Precedence rules: +// 1. Key which is specified by -i +// 2. Automatic key, if it already exists +// 3. First valid keypair in ssh config (according to ssh -G) +// 4. Automatic key, newly created +func selectSSHKeys( ctx context.Context, sshContext ssh.Context, - apiClient apiClient, args []string, opts sshOptions, -) (bool, error) { +) (*ssh.KeyPair, bool, error) { customConfigPath := "" for i := 0; i < len(args); i += 1 { arg := args[i] if arg == "-i" { - if i+1 < len(args) && path.Base(args[i+1]) == automaticPrivateKeyName { - return true, nil + if i+1 >= len(args) { + return nil, false, errors.New("missing value to -i argument") } - // User specified a custom identity file so just trust it is correct - return false, nil + // User manually specified an identity file so just trust it is correct + return &ssh.KeyPair{ + PrivateKeyPath: args[i+1], + PublicKeyPath: args[i+1] + ".pub", + }, false, nil } - if arg == "-F" && i < len(args)-1 { + if arg == "-F" && i+1 < len(args) { // ssh only pays attention to that last specified -F value, so it's correct to overwrite here customConfigPath = args[i+1] } } - if automaticKeySSHKeysExist(sshContext) { + if autoKeyPair := automaticSSHKeyPair(sshContext); autoKeyPair != nil { // If the automatic keys already exist, just use them - return true, nil + return autoKeyPair, true, nil } - // If there is a public key uploaded which matches a configured - // private key, there is no need for automatic key generation - hasPublicKey, err := hasUploadedPublicKeyForConfig(ctx, apiClient, customConfigPath, opts.profile) + keyPair, err := firstConfiguredKeyPair(ctx, customConfigPath, opts.profile) + if err != nil { + if !errors.Is(err, errKeyFileNotFound) { + return nil, false, fmt.Errorf("checking configured keys: %w", err) + } - return !hasPublicKey, err + // no valid key in ssh config, generate one + keyPair, err = generateAutomaticSSHKeys(sshContext) + if err != nil { + return nil, false, fmt.Errorf("generating automatic keypair: %w", err) + } + } + + return keyPair, true, nil } -func automaticKeySSHKeysExist(sshContext ssh.Context) bool { +// automaticSSHKeyPair returns the paths to the automatic key pair files, if they both exist +func automaticSSHKeyPair(sshContext ssh.Context) *ssh.KeyPair { publicKeys, err := sshContext.LocalPublicKeys() if err != nil { - return false + // The error would be that the .ssh dir doesn't exist, which just means that the keypair also doesn't exist + return nil } for _, publicKey := range publicKeys { @@ -288,13 +309,18 @@ func automaticKeySSHKeysExist(sshContext ssh.Context) bool { privateKey := strings.TrimSuffix(publicKey, ".pub") _, err := os.Stat(privateKey) - return err == nil + if err == nil { + return &ssh.KeyPair{ + PrivateKeyPath: privateKey, + PublicKeyPath: publicKey, + } + } } - return false + return nil } -func setupAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) { +func generateAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) { keyPair := checkAndUpdateOldKeyPair(sshContext) if keyPair != nil { return keyPair, nil @@ -355,76 +381,13 @@ func checkAndUpdateOldKeyPair(sshContext ssh.Context) *ssh.KeyPair { return nil } -// hasUploadedValidCustomKey checks which private keys are used in ssh configration and -// compares them to uploaded public keys to see if there is a match -func hasUploadedPublicKeyForConfig( - ctx context.Context, - apiClient apiClient, - customConfigFile string, - customHost string, -) (bool, error) { - configuredPrivateKeyPaths, err := getConfiguredPrivateKeys(ctx, customConfigFile, customHost) - if err != nil { - return false, fmt.Errorf("getting local ssh keys: %w", err) - } - - configuredPublicKeys := make(map[string]bool) - for _, privateKeyPath := range configuredPrivateKeyPaths { - publicKeyPath := privateKeyPath + ".pub" - - publicKeyContent, err := os.ReadFile(publicKeyPath) - if err != nil { - // The default configuration includes standard keys like id_rsa or id_ed25519, - // but these may not actually exist so just skip them - continue - } - - parts := strings.SplitN(string(publicKeyContent), " ", 3) - if len(parts) < 2 { - // Unexpected format, skip it - continue - } - - publicKeyWithoutComment := strings.Join(parts[:2], " ") - - configuredPublicKeys[publicKeyWithoutComment] = true - } - - if len(configuredPublicKeys) == 0 { - // There are no local private keys which ssh would use - return false, nil - } - - user, err := apiClient.GetUser(ctx) - if err != nil { - return false, fmt.Errorf("fetching user account: %w", err) - } - - uploadedPublicKeys, err := apiClient.AuthorizedKeys(ctx, user.Login) - if err != nil { - return false, fmt.Errorf("fetching known ssh keys: %w", err) - } - - if len(uploadedPublicKeys) == 0 { - return false, nil - } - - for _, uploadedPublicKey := range uploadedPublicKeys { - if configuredPublicKeys[uploadedPublicKey] { - return true, nil - } - } - - return false, nil -} - -// getConfiguredPrivateKeys reads the effective configuration for a localhost -// connection and returns all private keys which would be tried for authentication -func getConfiguredPrivateKeys( +// firstConfiguredKeyPair reads the effective configuration for a localhost +// connection and returns the first valid key pair which would be tried for authentication +func firstConfiguredKeyPair( ctx context.Context, customConfigFile string, customHost string, -) ([]string, error) { +) (*ssh.KeyPair, error) { sshExe, err := safeexec.LookPath("ssh") if err != nil { return nil, fmt.Errorf("could not find ssh executable: %w", err) @@ -449,27 +412,55 @@ func getConfiguredPrivateKeys( return nil, fmt.Errorf("could not load ssh configuration: %w", err) } - var privateKeyPaths []string - - userHomeDir, _ := os.UserHomeDir() - configLines := strings.Split(string(configBytes), "\n") for _, line := range configLines { line = strings.TrimSpace(line) if strings.HasPrefix(line, "identityfile ") { - path := strings.SplitN(line, " ", 2)[1] + privateKeyPath := strings.SplitN(line, " ", 2)[1] - if strings.HasPrefix(path, "~") { - // os.Stat can't handle ~, so convert it to the real path - path = strings.Replace(path, "~", userHomeDir, 1) + keypair, err := keypairForPrivateKey(privateKeyPath) + if errors.Is(err, errKeyFileNotFound) { + continue + } + if err != nil { + return nil, fmt.Errorf("loading ssh config: %w", err) } - privateKeyPaths = append(privateKeyPaths, path) + return keypair, nil } } - return privateKeyPaths, nil + return nil, errKeyFileNotFound +} + +// keypairForPrivateKey returns the KeyPair with the specified private key if it and the public key both exist +func keypairForPrivateKey(privateKeyPath string) (*ssh.KeyPair, error) { + if strings.HasPrefix(privateKeyPath, "~") { + userHomeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("getting home dir: %w", err) + } + + // os.Stat can't handle ~, so convert it to the real path + privateKeyPath = strings.Replace(privateKeyPath, "~", userHomeDir, 1) + } + + // The default configuration includes standard keys like id_rsa or id_ed25519, + // but these may not actually exist + if _, err := os.Stat(privateKeyPath); err != nil { + return nil, errKeyFileNotFound + } + + publicKeyPath := privateKeyPath + ".pub" + if _, err := os.Stat(publicKeyPath); err != nil { + return nil, errKeyFileNotFound + } + + return &ssh.KeyPair{ + PrivateKeyPath: privateKeyPath, + PublicKeyPath: publicKeyPath, + }, nil } func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) { @@ -535,12 +526,6 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro close(sshUsers) }() - // While the above fetches are running, ensure that the user has keys installed. - // That lets us report a more useful error message if they don't. - if err = checkAuthorizedKeys(ctx, a.apiClient); err != nil { - return err - } - t, err := template.New("ssh_config").Parse(heredoc.Doc(` Host cs.{{.Name}}.{{.EscapedRef}} User {{.SSHUser}} diff --git a/pkg/cmd/codespace/ssh_test.go b/pkg/cmd/codespace/ssh_test.go index 7e0137641..fdbca1974 100644 --- a/pkg/cmd/codespace/ssh_test.go +++ b/pkg/cmd/codespace/ssh_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "os" - "path" "path/filepath" "strings" "testing" @@ -27,11 +26,11 @@ func TestPendingOperationDisallowsSSH(t *testing.T) { } } -func TestAutomaticSSHKeyPairs(t *testing.T) { +func TestGenerateAutomaticSSHKeys(t *testing.T) { tests := []struct { - // These files exist when calling setupAutomaticSSHKeys + // These files exist when calling generateAutomaticSSHKeys existingFiles []string - // These files should exist after setupAutomaticSSHKeys finishes + // These files should exist after generateAutomaticSSHKeys finishes wantFinalFiles []string }{ // Basic case: no existing keys, they should be created @@ -82,12 +81,12 @@ func TestAutomaticSSHKeyPairs(t *testing.T) { f.Close() } - keyPair, err := setupAutomaticSSHKeys(sshContext) + keyPair, err := generateAutomaticSSHKeys(sshContext) if err != nil { - t.Errorf("Unexpected error from setupAutomaticSSHKeys: %v", err) + t.Errorf("Unexpected error from generateAutomaticSSHKeys: %v", err) } if keyPair == nil { - t.Fatal("Unexpected nil KeyPair from setupAutomaticSSHKeys") + t.Fatal("Unexpected nil KeyPair from generateAutomaticSSHKeys") } if !strings.HasSuffix(keyPair.PrivateKeyPath, automaticPrivateKeyName) { t.Errorf("Expected private key path %v, got %v", automaticPrivateKeyName, keyPair.PrivateKeyPath) @@ -99,7 +98,7 @@ func TestAutomaticSSHKeyPairs(t *testing.T) { // Check that all the expected files are present for _, file := range tt.wantFinalFiles { if _, err := os.Stat(filepath.Join(dir, file)); err != nil { - t.Errorf("Want file %q to exist after setupAutomaticSSHKeys but it doesn't", file) + t.Errorf("Want file %q to exist after generateAutomaticSSHKeys but it doesn't", file) } } @@ -119,177 +118,152 @@ func TestAutomaticSSHKeyPairs(t *testing.T) { } if !isWantedFile { - t.Errorf("Unexpected file %q exists after setupAutomaticSSHKeys", filename) + t.Errorf("Unexpected file %q exists after generateAutomaticSSHKeys", filename) } } } } -func TestUseAutomaticSSHKeysIdentityFileArg(t *testing.T) { +func TestSelectSSHKeys(t *testing.T) { tests := []struct { - identityFileArg string - wantResult bool + sshDirFiles []string + sshConfigKeys []string + sshArgs []string + profileOpt string + wantKeyPair *ssh.KeyPair + wantShouldAddArg bool }{ - {"custom-private-key", false}, - {automaticPrivateKeyName, true}, - {"", false}, // Edge case check for missing arg value - } - - for _, tt := range tests { - t.Logf("%v", tt.identityFileArg) - - args := []string{"-i"} - if tt.identityFileArg != "" { - args = append(args, tt.identityFileArg) - } - - result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{}) - - if err != nil { - t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err) - } - - if result != tt.wantResult { - t.Errorf("Want useAutomaticSSHKeys to be %v, got %v", tt.wantResult, result) - } - } -} - -func TestUseAutomaticSSHKeysWithAutoKeysExist(t *testing.T) { - dir := t.TempDir() - - f, err := os.Create(path.Join(dir, automaticPrivateKeyName)) - if err != nil { - t.Errorf("Failed to create test private key: %v", err) - } - f.Close() - - f, err = os.Create(path.Join(dir, automaticPrivateKeyName+".pub")) - if err != nil { - t.Errorf("Failed to create test public key: %v", err) - } - f.Close() - - sshContext := ssh.Context{ - ConfigDir: dir, - } - result, err := useAutomaticSSHKeys(context.Background(), sshContext, nil, nil, sshOptions{}) - - if err != nil { - t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err) - } - - if result != true { - t.Errorf("Want useAutomaticSSHKeys to be true, got false") - } -} - -func TestHasUploadedPublicKeyForConfig(t *testing.T) { - type testLocalKeyPair struct { - privateKeyFile string - publicKeyContent string - } - - tests := []struct { - apiAuthorizedPublicKeys []string - localKeyPairs []testLocalKeyPair - wantResult bool - }{ - // Failure tests + // -i tests { - // No API keys and no local keys - wantResult: false, + sshArgs: []string{"-i", "custom-private-key"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key", PublicKeyPath: "custom-private-key.pub"}, }, { - // Has API keys, but no local keys - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key"}, - wantResult: false, + sshArgs: []string{"-i", automaticPrivateKeyName}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, }, { - // No API keys, but has local keys - localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-key"}}, - wantResult: false, - }, - { - // API keys and local keys, but not matching - apiAuthorizedPublicKeys: []string{"ssh-rsa test-api-key"}, - localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-local-key"}}, - wantResult: false, + // Edge case check for missing arg value + sshArgs: []string{"-i"}, }, - // Successful tests + // Auto key exists tests { - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key"}, - localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-key"}}, - wantResult: true, + sshDirFiles: []string{automaticPrivateKeyName, automaticPrivateKeyName + ".pub"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, + wantShouldAddArg: true, }, { - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key-2"}, - localKeyPairs: []testLocalKeyPair{{"keyfile1", "ssh-rsa test-key-1"}}, - wantResult: true, - }, - { - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1"}, - localKeyPairs: []testLocalKeyPair{{"keyfile1", "ssh-rsa test-key-1"}, {"keyfile2", "ssh-rsa test-key-2"}}, - wantResult: true, - }, - { - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key-2"}, - localKeyPairs: []testLocalKeyPair{{"keyfile3", "ssh-rsa test-key-3"}, {"keyfile2", "ssh-rsa test-key-2"}}, - wantResult: true, + sshDirFiles: []string{automaticPrivateKeyName, automaticPrivateKeyName + ".pub", "custom-private-key", "custom-private-key.pub"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, + wantShouldAddArg: true, }, - // Extra case - local key contain comments + // SSH config tests { - apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key"}, - localKeyPairs: []testLocalKeyPair{{"keyfile3", "ssh-rsa test-key a comment on the key"}}, - wantResult: true, + sshDirFiles: []string{"custom-private-key", "custom-private-key.pub"}, + sshConfigKeys: []string{"custom-private-key"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key", PublicKeyPath: "custom-private-key.pub"}, + wantShouldAddArg: true, + }, + { + // 2 pairs, but only 1 is configured + sshDirFiles: []string{"custom-private-key", "custom-private-key.pub", "custom-private-key-2", "custom-private-key-2.pub"}, + sshConfigKeys: []string{"custom-private-key-2"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key-2", PublicKeyPath: "custom-private-key-2.pub"}, + wantShouldAddArg: true, + }, + { + // 2 pairs, but only 1 has both public and private + sshDirFiles: []string{"custom-private-key", "custom-private-key-2", "custom-private-key-2.pub"}, + sshConfigKeys: []string{"custom-private-key", "custom-private-key-2"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key-2", PublicKeyPath: "custom-private-key-2.pub"}, + wantShouldAddArg: true, + }, + + // Automatic key tests + { + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, + wantShouldAddArg: true, + }, + { + // Renames old key pair to new + sshDirFiles: []string{automaticPrivateKeyNameOld, automaticPrivateKeyNameOld + ".pub"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, + wantShouldAddArg: true, + }, + { + // Other key is configured, but doesn't exist + sshConfigKeys: []string{"custom-private-key"}, + wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"}, + wantShouldAddArg: true, }, } for _, tt := range tests { - t.Logf("%+v", tt) + sshDir := t.TempDir() + sshContext := ssh.Context{ConfigDir: sshDir} - mockApi := &apiClientMock{ - GetUserFunc: func(ctx context.Context) (*api.User, error) { - return &api.User{Login: "test"}, nil - }, - AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) { - return tt.apiAuthorizedPublicKeys, nil - }, - } - - dir := t.TempDir() - configPath := path.Join(dir, "test-config") - - configContent := "" - for _, pair := range tt.localKeyPairs { - configContent += fmt.Sprintf("IdentityFile %s\n", path.Join(dir, pair.privateKeyFile)) - - err := os.WriteFile(path.Join(dir, pair.privateKeyFile+".pub"), []byte(pair.publicKeyContent), 0666) + for _, file := range tt.sshDirFiles { + f, err := os.Create(filepath.Join(sshDir, file)) if err != nil { - t.Fatalf("could not write test public key file %v", err) + t.Errorf("Failed to create test ssh dir file %q: %v", file, err) } + f.Close() + } + + if tt.sshConfigKeys != nil { + configPath := filepath.Join(sshDir, "test-config") + + configContent := "" + for _, key := range tt.sshConfigKeys { + configContent += fmt.Sprintf("IdentityFile %s\n", filepath.Join(sshDir, key)) + } + + err := os.WriteFile(configPath, []byte(configContent), 0666) + if err != nil { + t.Fatalf("could not write test config %v", err) + } + + tt.sshArgs = append(tt.sshArgs, "-F", configPath) + } + + gotKeyPair, gotShouldAddArg, err := selectSSHKeys(context.Background(), sshContext, tt.sshArgs, sshOptions{profile: tt.profileOpt}) + + if tt.wantKeyPair == nil { + if err == nil { + t.Errorf("Expected error from selectSSHKeys but got nil") + } + + continue } - err := os.WriteFile(configPath, []byte(configContent), 0666) if err != nil { - t.Fatalf("could not write test config %v", err) + t.Errorf("Unexpected error from selectSSHKeys: %v", err) + continue } - result, err := hasUploadedPublicKeyForConfig(context.Background(), mockApi, configPath, "") - if err != nil { - t.Errorf("Unexpected error from hasUploadedPublicKeyForConfig: %v", err) + if gotKeyPair == nil { + t.Errorf("Expected non-nil result from selectSSHKeys but got nil") + continue } - if result != tt.wantResult { - t.Errorf("Want hasUploadedPublicKeyForConfig to be %v, got %v", tt.wantResult, result) + if gotShouldAddArg != tt.wantShouldAddArg { + t.Errorf("Got wrong shouldAddArg value from selectSSHKeys, wanted %v got %v", tt.wantShouldAddArg, gotShouldAddArg) + continue + } + + // Strip the dir (sshDir) from the gotKeyPair paths so that they match wantKeyPair (which doesn't know the directory) + gotKeyPair.PrivateKeyPath = filepath.Base(gotKeyPair.PrivateKeyPath) + gotKeyPair.PublicKeyPath = filepath.Base(gotKeyPair.PublicKeyPath) + + if fmt.Sprintf("%v", gotKeyPair) != fmt.Sprintf("%v", tt.wantKeyPair) { + t.Errorf("Want selectSSHKeys result to be %v, got %v", tt.wantKeyPair, gotKeyPair) } } } func testingSSHApp() *App { - user := &api.User{Login: "monalisa"} disabledCodespace := &api.Codespace{ Name: "disabledCodespace", PendingOperation: true, @@ -302,12 +276,6 @@ func testingSSHApp() *App { } return nil, nil }, - GetUserFunc: func(_ context.Context) (*api.User, error) { - return user, nil - }, - AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) { - return []string{}, nil - }, } ios, _, _, _ := iostreams.Test()