Initial pass

This commit is contained in:
Caleb Brose 2022-07-16 20:38:23 +00:00 committed by GitHub
parent 954689ea9f
commit f313953642
5 changed files with 131 additions and 24 deletions

View file

@ -1062,7 +1062,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod
// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys
// format) registered by the specified GitHub user.
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
url := fmt.Sprintf("%s/%s.keys", a.githubServer, user)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
@ -1082,7 +1082,19 @@ func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
return b, nil
allKeys := string(b)
splitKeys := []string{}
for _, key := range strings.Split(allKeys, "\n") {
if key == "" {
continue
}
splitKeys = append(splitKeys, strings.TrimSpace(key))
}
return splitKeys, nil
}
// do executes the given request and returns the response. It creates an

View file

@ -115,7 +115,7 @@ type apiClient interface {
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
AuthorizedKeys(ctx context.Context, user string) ([]string, error)
GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)

View file

@ -279,7 +279,7 @@ type apiClientMock struct {
}
// AuthorizedKeys calls AuthorizedKeysFunc.
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
if mock.AuthorizedKeysFunc == nil {
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
}

View file

@ -133,12 +133,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
sshContext := ssh.Context{}
startSSHOptions := liveshare.StartSSHServerOptions{}
if shouldUseAutomaticSSHKeys(args, opts) {
keyPair, err := setupAutomaticSSHKeys(sshContext)
if err != nil {
return fmt.Errorf("failed to generate ssh keys: %w", err)
}
keyPair, err := setupAutomaticSSHKeys(ctx, sshContext, a.apiClient, args, opts)
if err != nil {
return fmt.Errorf("failed to generate ssh keys: %w", err)
}
if keyPair != nil {
startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath
// For both cp and ssh, flags need to come first in the args (before a command in ssh and files in cp), so prepend this flag
args = append([]string{"-i", keyPair.PrivateKeyPath}, args...)
@ -216,7 +216,37 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
}
}
func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool {
func setupAutomaticSSHKeys(
ctx context.Context,
sshContext ssh.Context,
apiClient apiClient,
args []string,
opts sshOptions,
) (*ssh.KeyPair, error) {
if isUsingCustomIdentityOrProfile(args, opts) {
return nil, nil
}
keyPair := checkAndUpdateOldKeyPair(sshContext)
if keyPair != nil {
return keyPair, nil
}
if !sshContext.HasPrivateKey(automaticPrivateKeyName) {
if hasAnyPrivateKeyForUploadedPublicKeys(ctx, sshContext, apiClient) {
return nil, nil
}
}
keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "")
if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) {
return nil, err
}
return keyPair, nil
}
func isUsingCustomIdentityOrProfile(args []string, opts sshOptions) bool {
if opts.profile != "" {
// The profile may specify the identity file so cautiously don't override anything with that option
return false
@ -237,20 +267,6 @@ func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool {
return true
}
func setupAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) {
keyPair := checkAndUpdateOldKeyPair(sshContext)
if keyPair != nil {
return keyPair, nil
}
keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "")
if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) {
return nil, err
}
return keyPair, nil
}
// checkAndUpdateOldKeyPair handles backward compatibility with the old keypair names.
// If the old public and private keys both exist they are renamed to the new name.
// The return value is non-nil only if the rename happens.
@ -298,6 +314,25 @@ func checkAndUpdateOldKeyPair(sshContext ssh.Context) *ssh.KeyPair {
return nil
}
func hasAnyPrivateKeyForUploadedPublicKeys(
ctx context.Context,
sshContext ssh.Context,
apiClient apiClient,
) bool {
user, err := apiClient.GetUser(ctx)
if err != nil {
return false
}
publicKeys, err := apiClient.AuthorizedKeys(ctx, user.Login)
if err != nil {
return false
}
hasKey, _ := sshContext.HasAnyMatchingPrivateKey(publicKeys)
return hasKey
}
func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

View file

@ -3,10 +3,12 @@ package ssh
import (
"errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"github.com/cli/cli/v2/internal/config"
"github.com/cli/cli/v2/internal/run"
@ -34,6 +36,64 @@ func (c *Context) LocalPublicKeys() ([]string, error) {
return filepath.Glob(filepath.Join(sshDir, "*.pub"))
}
func (c *Context) HasPrivateKey(keyName string) bool {
sshDir, err := c.sshDir()
if err != nil {
return false
}
keyFile := filepath.Join(sshDir, keyName)
_, err = os.Stat(keyFile)
return err == nil
}
func (c *Context) HasAnyMatchingPrivateKey(publicKeys []string) (bool, error) {
keygenExe, err := c.findKeygen()
if err != nil {
return false, fmt.Errorf("could not find ssh-keygen: %w", err)
}
sshDir, err := c.sshDir()
if err != nil {
// If .ssh doesn't exist it's not an error, there are just no keys to match
return false, nil
}
files, err := ioutil.ReadDir(sshDir)
if err != nil {
return false, fmt.Errorf("could not read .ssh directory: %w", err)
}
for _, file := range files {
if file.IsDir() {
continue
}
fullPath := filepath.Join(sshDir, file.Name())
keygenCmd := exec.Command(keygenExe, "-y", "-f", fullPath)
keygenRunnable := run.PrepareCmd(keygenCmd)
outputBytes, err := keygenRunnable.Output()
if err != nil {
// Just move on, it probably wasn't a key file
continue
}
fullPublicKey := string(outputBytes)
typeAndKey := strings.SplitN(fullPublicKey, " ", 3)[:2]
publicKeyWithoutComment := strings.Join(typeAndKey, " ")
for _, publicKey := range publicKeys {
if publicKeyWithoutComment == publicKey {
return true, nil
}
}
}
return false, nil
}
func (c *Context) HasKeygen() bool {
_, err := c.findKeygen()
return err == nil