Initial pass
This commit is contained in:
parent
954689ea9f
commit
f313953642
5 changed files with 131 additions and 24 deletions
|
|
@ -1062,7 +1062,7 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod
|
|||
|
||||
// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys
|
||||
// format) registered by the specified GitHub user.
|
||||
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
|
||||
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
|
||||
url := fmt.Sprintf("%s/%s.keys", a.githubServer, user)
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
|
|
@ -1082,7 +1082,19 @@ func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
|
||||
allKeys := string(b)
|
||||
|
||||
splitKeys := []string{}
|
||||
for _, key := range strings.Split(allKeys, "\n") {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
splitKeys = append(splitKeys, strings.TrimSpace(key))
|
||||
}
|
||||
|
||||
return splitKeys, nil
|
||||
}
|
||||
|
||||
// do executes the given request and returns the response. It creates an
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ type apiClient interface {
|
|||
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
|
||||
EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error)
|
||||
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
|
||||
AuthorizedKeys(ctx context.Context, user string) ([]byte, error)
|
||||
AuthorizedKeys(ctx context.Context, user string) ([]string, error)
|
||||
GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error)
|
||||
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
|
||||
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
|
||||
|
|
|
|||
|
|
@ -279,7 +279,7 @@ type apiClientMock struct {
|
|||
}
|
||||
|
||||
// AuthorizedKeys calls AuthorizedKeysFunc.
|
||||
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]byte, error) {
|
||||
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
|
||||
if mock.AuthorizedKeysFunc == nil {
|
||||
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,12 +133,12 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
sshContext := ssh.Context{}
|
||||
startSSHOptions := liveshare.StartSSHServerOptions{}
|
||||
|
||||
if shouldUseAutomaticSSHKeys(args, opts) {
|
||||
keyPair, err := setupAutomaticSSHKeys(sshContext)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate ssh keys: %w", err)
|
||||
}
|
||||
keyPair, err := setupAutomaticSSHKeys(ctx, sshContext, a.apiClient, args, opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate ssh keys: %w", err)
|
||||
}
|
||||
|
||||
if keyPair != nil {
|
||||
startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath
|
||||
// For both cp and ssh, flags need to come first in the args (before a command in ssh and files in cp), so prepend this flag
|
||||
args = append([]string{"-i", keyPair.PrivateKeyPath}, args...)
|
||||
|
|
@ -216,7 +216,37 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
}
|
||||
}
|
||||
|
||||
func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool {
|
||||
func setupAutomaticSSHKeys(
|
||||
ctx context.Context,
|
||||
sshContext ssh.Context,
|
||||
apiClient apiClient,
|
||||
args []string,
|
||||
opts sshOptions,
|
||||
) (*ssh.KeyPair, error) {
|
||||
if isUsingCustomIdentityOrProfile(args, opts) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keyPair := checkAndUpdateOldKeyPair(sshContext)
|
||||
if keyPair != nil {
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
if !sshContext.HasPrivateKey(automaticPrivateKeyName) {
|
||||
if hasAnyPrivateKeyForUploadedPublicKeys(ctx, sshContext, apiClient) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "")
|
||||
if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
func isUsingCustomIdentityOrProfile(args []string, opts sshOptions) bool {
|
||||
if opts.profile != "" {
|
||||
// The profile may specify the identity file so cautiously don't override anything with that option
|
||||
return false
|
||||
|
|
@ -237,20 +267,6 @@ func shouldUseAutomaticSSHKeys(args []string, opts sshOptions) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func setupAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) {
|
||||
keyPair := checkAndUpdateOldKeyPair(sshContext)
|
||||
if keyPair != nil {
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
keyPair, err := sshContext.GenerateSSHKey(automaticPrivateKeyName, "")
|
||||
if err != nil && !errors.Is(err, ssh.ErrKeyAlreadyExists) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
// checkAndUpdateOldKeyPair handles backward compatibility with the old keypair names.
|
||||
// If the old public and private keys both exist they are renamed to the new name.
|
||||
// The return value is non-nil only if the rename happens.
|
||||
|
|
@ -298,6 +314,25 @@ func checkAndUpdateOldKeyPair(sshContext ssh.Context) *ssh.KeyPair {
|
|||
return nil
|
||||
}
|
||||
|
||||
func hasAnyPrivateKeyForUploadedPublicKeys(
|
||||
ctx context.Context,
|
||||
sshContext ssh.Context,
|
||||
apiClient apiClient,
|
||||
) bool {
|
||||
user, err := apiClient.GetUser(ctx)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
publicKeys, err := apiClient.AuthorizedKeys(ctx, user.Login)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hasKey, _ := sshContext.HasAnyMatchingPrivateKey(publicKeys)
|
||||
return hasKey
|
||||
}
|
||||
|
||||
func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ package ssh
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/cli/cli/v2/internal/config"
|
||||
"github.com/cli/cli/v2/internal/run"
|
||||
|
|
@ -34,6 +36,64 @@ func (c *Context) LocalPublicKeys() ([]string, error) {
|
|||
return filepath.Glob(filepath.Join(sshDir, "*.pub"))
|
||||
}
|
||||
|
||||
func (c *Context) HasPrivateKey(keyName string) bool {
|
||||
sshDir, err := c.sshDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
keyFile := filepath.Join(sshDir, keyName)
|
||||
_, err = os.Stat(keyFile)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (c *Context) HasAnyMatchingPrivateKey(publicKeys []string) (bool, error) {
|
||||
keygenExe, err := c.findKeygen()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not find ssh-keygen: %w", err)
|
||||
}
|
||||
|
||||
sshDir, err := c.sshDir()
|
||||
if err != nil {
|
||||
// If .ssh doesn't exist it's not an error, there are just no keys to match
|
||||
return false, nil
|
||||
}
|
||||
|
||||
files, err := ioutil.ReadDir(sshDir)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("could not read .ssh directory: %w", err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
fullPath := filepath.Join(sshDir, file.Name())
|
||||
keygenCmd := exec.Command(keygenExe, "-y", "-f", fullPath)
|
||||
|
||||
keygenRunnable := run.PrepareCmd(keygenCmd)
|
||||
outputBytes, err := keygenRunnable.Output()
|
||||
if err != nil {
|
||||
// Just move on, it probably wasn't a key file
|
||||
continue
|
||||
}
|
||||
|
||||
fullPublicKey := string(outputBytes)
|
||||
typeAndKey := strings.SplitN(fullPublicKey, " ", 3)[:2]
|
||||
publicKeyWithoutComment := strings.Join(typeAndKey, " ")
|
||||
|
||||
for _, publicKey := range publicKeys {
|
||||
if publicKeyWithoutComment == publicKey {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Context) HasKeygen() bool {
|
||||
_, err := c.findKeygen()
|
||||
return err == nil
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue