diff --git a/internal/codespaces/api/api.go b/internal/codespaces/api/api.go index 8c2bb80e2..eab4db8dd 100644 --- a/internal/codespaces/api/api.go +++ b/internal/codespaces/api/api.go @@ -1062,7 +1062,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod // AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys // format) registered by the specified GitHub user. -func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { +func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]string, error) { url := fmt.Sprintf("%s/%s.keys", a.githubServer, user) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -1082,7 +1082,19 @@ func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("error reading response body: %w", err) } - return b, nil + + allKeys := string(b) + + splitKeys := []string{} + for _, key := range strings.Split(allKeys, "\n") { + if key == "" { + continue + } + + splitKeys = append(splitKeys, strings.TrimSpace(key)) + } + + return splitKeys, nil } // do executes the given request and returns the response. It creates an diff --git a/pkg/cmd/codespace/common.go b/pkg/cmd/codespace/common.go index 759cec3ce..2ec150268 100644 --- a/pkg/cmd/codespace/common.go +++ b/pkg/cmd/codespace/common.go @@ -115,7 +115,7 @@ type apiClient interface { CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error) GetRepository(ctx context.Context, nwo string) (*api.Repository, error) - AuthorizedKeys(ctx context.Context, user string) ([]byte, error) + AuthorizedKeys(ctx context.Context, user string) ([]string, error) GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error) GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error) ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error) diff --git a/pkg/cmd/codespace/mock_api.go b/pkg/cmd/codespace/mock_api.go index fe3b61f00..86c29039b 100644 --- a/pkg/cmd/codespace/mock_api.go +++ b/pkg/cmd/codespace/mock_api.go @@ -279,7 +279,7 @@ type apiClientMock struct { } // AuthorizedKeys calls AuthorizedKeysFunc. -func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) { +func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]string, error) { if mock.AuthorizedKeysFunc == nil { panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called") } diff --git a/pkg/cmd/codespace/ssh.go b/pkg/cmd/codespace/ssh.go index e52422e51..eff81fd7d 100644 --- a/pkg/cmd/codespace/ssh.go +++ b/pkg/cmd/codespace/ssh.go @@ -133,12 +133,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e sshContext := ssh.Context{} startSSHOptions := liveshare.StartSSHServerOptions{} - if shouldUseAutomaticSSHKeys(args, opts) { - keyPair, err := setupAutomaticSSHKeys(sshContext) - if err != nil { - return fmt.Errorf("failed to generate ssh keys: %w", err) - } + keyPair, err := setupAutomaticSSHKeys(ctx, sshContext, a.apiClient, args, opts) + if err != nil { + return fmt.Errorf("failed to generate ssh keys: %w", err) + } + if keyPair != nil { startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath // For both cp and ssh, flags need to come first in the args (before a command in ssh and files in cp), so prepend this flag args = append([]string{"-i", keyPair.PrivateKeyPath}, args...) @@ -216,7 +216,37 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e } } -func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool { +func setupAutomaticSSHKeys( + ctx context.Context, + sshContext ssh.Context, + apiClient apiClient, + args []string, + opts sshOptions, +) (*ssh.KeyPair, error) { + if isUsingCustomIdentityOrProfile(args, opts) { + return nil, nil + } + + keyPair := checkAndUpdateOldKeyPair(sshContext) + if keyPair != nil { + return keyPair, nil + } + + if !sshContext.HasPrivateKey(automaticPrivateKeyName) { + if hasAnyPrivateKeyForUploadedPublicKeys(ctx, sshContext, apiClient) { + return nil, nil + } + } + + keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "") + if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) { + return nil, err + } + + return keyPair, nil +} + +func isUsingCustomIdentityOrProfile(args []string, opts sshOptions) bool { if opts.profile != "" { // The profile may specify the identity file so cautiously don't override anything with that option return false @@ -237,20 +267,6 @@ func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool { return true } -func setupAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) { - keyPair := checkAndUpdateOldKeyPair(sshContext) - if keyPair != nil { - return keyPair, nil - } - - keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "") - if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) { - return nil, err - } - - return keyPair, nil -} - // checkAndUpdateOldKeyPair handles backward compatibility with the old keypair names. // If the old public and private keys both exist they are renamed to the new name. // The return value is non-nil only if the rename happens. @@ -298,6 +314,25 @@ func checkAndUpdateOldKeyPair(sshContext ssh.Context) *ssh.KeyPair { return nil } +func hasAnyPrivateKeyForUploadedPublicKeys( + ctx context.Context, + sshContext ssh.Context, + apiClient apiClient, +) bool { + user, err := apiClient.GetUser(ctx) + if err != nil { + return false + } + + publicKeys, err := apiClient.AuthorizedKeys(ctx, user.Login) + if err != nil { + return false + } + + hasKey, _ := sshContext.HasAnyMatchingPrivateKey(publicKeys) + return hasKey +} + func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/pkg/ssh/ssh_keys.go b/pkg/ssh/ssh_keys.go index c750b608a..e903c74e9 100644 --- a/pkg/ssh/ssh_keys.go +++ b/pkg/ssh/ssh_keys.go @@ -3,10 +3,12 @@ package ssh import ( "errors" "fmt" + "io/ioutil" "os" "os/exec" "path/filepath" "runtime" + "strings" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/run" @@ -34,6 +36,64 @@ func (c *Context) LocalPublicKeys() ([]string, error) { return filepath.Glob(filepath.Join(sshDir, "*.pub")) } +func (c *Context) HasPrivateKey(keyName string) bool { + sshDir, err := c.sshDir() + if err != nil { + return false + } + + keyFile := filepath.Join(sshDir, keyName) + _, err = os.Stat(keyFile) + + return err == nil +} + +func (c *Context) HasAnyMatchingPrivateKey(publicKeys []string) (bool, error) { + keygenExe, err := c.findKeygen() + if err != nil { + return false, fmt.Errorf("could not find ssh-keygen: %w", err) + } + + sshDir, err := c.sshDir() + if err != nil { + // If .ssh doesn't exist it's not an error, there are just no keys to match + return false, nil + } + + files, err := ioutil.ReadDir(sshDir) + if err != nil { + return false, fmt.Errorf("could not read .ssh directory: %w", err) + } + + for _, file := range files { + if file.IsDir() { + continue + } + + fullPath := filepath.Join(sshDir, file.Name()) + keygenCmd := exec.Command(keygenExe, "-y", "-f", fullPath) + + keygenRunnable := run.PrepareCmd(keygenCmd) + outputBytes, err := keygenRunnable.Output() + if err != nil { + // Just move on, it probably wasn't a key file + continue + } + + fullPublicKey := string(outputBytes) + typeAndKey := strings.SplitN(fullPublicKey, " ", 3)[:2] + publicKeyWithoutComment := strings.Join(typeAndKey, " ") + + for _, publicKey := range publicKeys { + if publicKeyWithoutComment == publicKey { + return true, nil + } + } + } + + return false, nil +} + func (c *Context) HasKeygen() bool { _, err := c.findKeygen() return err == nil