Merge pull request #6304 from cli/cmbrose/no-ssh-api-call

Remove need for uploaded SSH keys in `cs ssh` and `cs cp`
This commit is contained in:
Caleb Brose 2022-09-20 17:16:52 -05:00 committed by GitHub
commit 7ae7550e83
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 209 additions and 443 deletions

View file

@ -13,7 +13,6 @@ package api
// - github.GetUser(github.Client)
// - github.GetRepository(Client)
// - github.ReadFile(Client, nwo, branch, path) // was GetCodespaceRepositoryContents
// - github.AuthorizedKeys(Client, user)
// - codespaces.Create(Client, user, repo, sku, branch, location)
// - codespaces.Delete(Client, user, token, name)
// - codespaces.Get(Client, token, owner, name)
@ -1060,44 +1059,6 @@ func (a *API) GetCodespaceRepositoryContents(ctx context.Context, codespace *Cod
return decoded, nil
}
// AuthorizedKeys returns the public keys (in ~/.ssh/authorized_keys
// format) registered by the specified GitHub user.
func (a *API) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
url := fmt.Sprintf("%s/%s.keys", a.githubServer, user)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := a.do(ctx, req, "/user.keys")
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("server returned %s", resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
allKeys := string(b)
var splitKeys []string
for _, key := range strings.Split(allKeys, "\n") {
key = strings.TrimSpace(key)
if key == "" {
continue
}
splitKeys = append(splitKeys, key)
}
return splitKeys, nil
}
// do executes the given request and returns the response. It creates an
// opentracing span to track the length of the request.
func (a *API) do(ctx context.Context, req *http.Request, spanName string) (*http.Response, error) {

View file

@ -101,7 +101,6 @@ func testingCodeApp() *App {
}
func testCodeApiMock() *apiClientMock {
user := &api.User{Login: "monalisa"}
testingCodespace := &api.Codespace{
Name: "monalisa-cli-cli-abcdef",
WebURL: "https://monalisa-cli-cli-abcdef.github.dev",
@ -118,11 +117,5 @@ func testCodeApiMock() *apiClientMock {
}
return testingCodespace, nil
},
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
}
}

View file

@ -70,13 +70,6 @@ type liveshareSession interface {
// Connects to a codespace using Live Share and returns that session
func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App, debug bool, debugFile string) (session liveshareSession, err error) {
// While connecting, ensure in the background that the user has keys installed.
// That lets us report a more useful error message if they don't.
authkeys := make(chan error, 1)
go func() {
authkeys <- checkAuthorizedKeys(ctx, a.apiClient)
}()
liveshareLogger := noopLogger()
if debug {
debugLogger, err := newFileLogger(debugFile)
@ -91,9 +84,6 @@ func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App
session, err = codespaces.ConnectToLiveshare(ctx, a, liveshareLogger, a.apiClient, codespace)
if err != nil {
if authErr := <-authkeys; authErr != nil {
return nil, fmt.Errorf("failed to fetch authorization keys: %w", authErr)
}
return nil, fmt.Errorf("failed to connect to Live Share: %w", err)
}
@ -102,7 +92,6 @@ func startLiveShareSession(ctx context.Context, codespace *api.Codespace, a *App
//go:generate moq -fmt goimports -rm -skip-ensure -out mock_api.go . apiClient
type apiClient interface {
GetUser(ctx context.Context) (*api.User, error)
GetCodespace(ctx context.Context, name string, includeConnection bool) (*api.Codespace, error)
GetOrgMemberCodespace(ctx context.Context, orgName string, userName string, codespaceName string) (*api.Codespace, error)
ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error)
@ -112,7 +101,6 @@ type apiClient interface {
CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
EditCodespace(ctx context.Context, codespaceName string, params *api.EditCodespaceParams) (*api.Codespace, error)
GetRepository(ctx context.Context, nwo string) (*api.Repository, error)
AuthorizedKeys(ctx context.Context, user string) ([]string, error)
GetCodespacesMachines(ctx context.Context, repoID int, branch, location string) ([]*api.Machine, error)
GetCodespaceRepositoryContents(ctx context.Context, codespace *api.Codespace, path string) ([]byte, error)
ListDevContainers(ctx context.Context, repoID int, branch string, limit int) (devcontainers []api.DevContainerEntry, err error)
@ -237,25 +225,6 @@ func ask(qs []*survey.Question, response interface{}) error {
return err
}
// checkAuthorizedKeys reports an error if the user has not registered any SSH keys;
// see https://github.com/cli/cli/v2/issues/166#issuecomment-921769703.
// The check is not required for security but it improves the error message.
func checkAuthorizedKeys(ctx context.Context, client apiClient) error {
user, err := client.GetUser(ctx)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
keys, err := client.AuthorizedKeys(ctx, user.Login)
if err != nil {
return fmt.Errorf("failed to read GitHub-authorized SSH keys for %s: %w", user, err)
}
if len(keys) == 0 {
return fmt.Errorf("user %s has no GitHub-authorized SSH keys", user)
}
return nil // success
}
var ErrTooManyArgs = errors.New("the command accepts no arguments")
func noArgsConstraint(cmd *cobra.Command, args []string) error {

View file

@ -15,7 +15,6 @@ import (
)
func TestDelete(t *testing.T) {
user := &api.User{Login: "hubot"}
now, _ := time.Parse(time.RFC3339, "2021-09-22T00:00:00Z")
daysAgo := func(n int) string {
return now.Add(time.Hour * -time.Duration(24*n)).Format(time.RFC3339)
@ -207,9 +206,6 @@ func TestDelete(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
apiMock := &apiClientMock{
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
DeleteCodespaceFunc: func(_ context.Context, name string, orgName string, userName string) error {
if tt.deleteErr != nil {
return tt.deleteErr

View file

@ -21,7 +21,6 @@ func TestPendingOperationDisallowsLogs(t *testing.T) {
}
func testingLogsApp() *App {
user := &api.User{Login: "monalisa"}
disabledCodespace := &api.Codespace{
Name: "disabledCodespace",
PendingOperation: true,
@ -34,12 +33,6 @@ func testingLogsApp() *App {
}
return nil, nil
},
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
}
ios, _, _, _ := iostreams.Test()

View file

@ -16,9 +16,6 @@ import (
//
// // make and configure a mocked apiClient
// mockedapiClient := &apiClientMock{
// AuthorizedKeysFunc: func(ctx context.Context, user string) ([]string, error) {
// panic("mock out the AuthorizedKeys method")
// },
// CreateCodespaceFunc: func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
// panic("mock out the CreateCodespace method")
// },
@ -49,9 +46,6 @@ import (
// GetRepositoryFunc: func(ctx context.Context, nwo string) (*api.Repository, error) {
// panic("mock out the GetRepository method")
// },
// GetUserFunc: func(ctx context.Context) (*api.User, error) {
// panic("mock out the GetUser method")
// },
// ListCodespacesFunc: func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) {
// panic("mock out the ListCodespaces method")
// },
@ -71,9 +65,6 @@ import (
//
// }
type apiClientMock struct {
// AuthorizedKeysFunc mocks the AuthorizedKeys method.
AuthorizedKeysFunc func(ctx context.Context, user string) ([]string, error)
// CreateCodespaceFunc mocks the CreateCodespace method.
CreateCodespaceFunc func(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error)
@ -104,9 +95,6 @@ type apiClientMock struct {
// GetRepositoryFunc mocks the GetRepository method.
GetRepositoryFunc func(ctx context.Context, nwo string) (*api.Repository, error)
// GetUserFunc mocks the GetUser method.
GetUserFunc func(ctx context.Context) (*api.User, error)
// ListCodespacesFunc mocks the ListCodespaces method.
ListCodespacesFunc func(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error)
@ -121,13 +109,6 @@ type apiClientMock struct {
// calls tracks calls to the methods.
calls struct {
// AuthorizedKeys holds details about calls to the AuthorizedKeys method.
AuthorizedKeys []struct {
// Ctx is the ctx argument value.
Ctx context.Context
// User is the user argument value.
User string
}
// CreateCodespace holds details about calls to the CreateCodespace method.
CreateCodespace []struct {
// Ctx is the ctx argument value.
@ -218,11 +199,6 @@ type apiClientMock struct {
// Nwo is the nwo argument value.
Nwo string
}
// GetUser holds details about calls to the GetUser method.
GetUser []struct {
// Ctx is the ctx argument value.
Ctx context.Context
}
// ListCodespaces holds details about calls to the ListCodespaces method.
ListCodespaces []struct {
// Ctx is the ctx argument value.
@ -260,7 +236,6 @@ type apiClientMock struct {
UserName string
}
}
lockAuthorizedKeys sync.RWMutex
lockCreateCodespace sync.RWMutex
lockDeleteCodespace sync.RWMutex
lockEditCodespace sync.RWMutex
@ -271,48 +246,12 @@ type apiClientMock struct {
lockGetCodespacesMachines sync.RWMutex
lockGetOrgMemberCodespace sync.RWMutex
lockGetRepository sync.RWMutex
lockGetUser sync.RWMutex
lockListCodespaces sync.RWMutex
lockListDevContainers sync.RWMutex
lockStartCodespace sync.RWMutex
lockStopCodespace sync.RWMutex
}
// AuthorizedKeys calls AuthorizedKeysFunc.
func (mock *apiClientMock) AuthorizedKeys(ctx context.Context, user string) ([]string, error) {
if mock.AuthorizedKeysFunc == nil {
panic("apiClientMock.AuthorizedKeysFunc: method is nil but apiClient.AuthorizedKeys was just called")
}
callInfo := struct {
Ctx context.Context
User string
}{
Ctx: ctx,
User: user,
}
mock.lockAuthorizedKeys.Lock()
mock.calls.AuthorizedKeys = append(mock.calls.AuthorizedKeys, callInfo)
mock.lockAuthorizedKeys.Unlock()
return mock.AuthorizedKeysFunc(ctx, user)
}
// AuthorizedKeysCalls gets all the calls that were made to AuthorizedKeys.
// Check the length with:
// len(mockedapiClient.AuthorizedKeysCalls())
func (mock *apiClientMock) AuthorizedKeysCalls() []struct {
Ctx context.Context
User string
} {
var calls []struct {
Ctx context.Context
User string
}
mock.lockAuthorizedKeys.RLock()
calls = mock.calls.AuthorizedKeys
mock.lockAuthorizedKeys.RUnlock()
return calls
}
// CreateCodespace calls CreateCodespaceFunc.
func (mock *apiClientMock) CreateCodespace(ctx context.Context, params *api.CreateCodespaceParams) (*api.Codespace, error) {
if mock.CreateCodespaceFunc == nil {
@ -703,37 +642,6 @@ func (mock *apiClientMock) GetRepositoryCalls() []struct {
return calls
}
// GetUser calls GetUserFunc.
func (mock *apiClientMock) GetUser(ctx context.Context) (*api.User, error) {
if mock.GetUserFunc == nil {
panic("apiClientMock.GetUserFunc: method is nil but apiClient.GetUser was just called")
}
callInfo := struct {
Ctx context.Context
}{
Ctx: ctx,
}
mock.lockGetUser.Lock()
mock.calls.GetUser = append(mock.calls.GetUser, callInfo)
mock.lockGetUser.Unlock()
return mock.GetUserFunc(ctx)
}
// GetUserCalls gets all the calls that were made to GetUser.
// Check the length with:
// len(mockedapiClient.GetUserCalls())
func (mock *apiClientMock) GetUserCalls() []struct {
Ctx context.Context
} {
var calls []struct {
Ctx context.Context
}
mock.lockGetUser.RLock()
calls = mock.calls.GetUser
mock.lockGetUser.RUnlock()
return calls
}
// ListCodespaces calls ListCodespacesFunc.
func (mock *apiClientMock) ListCodespaces(ctx context.Context, opts api.ListCodespacesOptions) ([]*api.Codespace, error) {
if mock.ListCodespacesFunc == nil {

View file

@ -247,7 +247,6 @@ func TestPendingOperationDisallowsForwardPorts(t *testing.T) {
}
func testingPortsApp() *App {
user := &api.User{Login: "monalisa"}
disabledCodespace := &api.Codespace{
Name: "disabledCodespace",
PendingOperation: true,
@ -260,12 +259,6 @@ func testingPortsApp() *App {
}
return nil, nil
},
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
}
ios, _, _, _ := iostreams.Test()

View file

@ -34,6 +34,8 @@ import (
const automaticPrivateKeyNameOld = "codespaces"
const automaticPrivateKeyName = "codespaces.auto"
var errKeyFileNotFound = errors.New("SSH key file does not exist")
type sshOptions struct {
codespace string
profile string
@ -148,18 +150,14 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
sshContext := ssh.Context{}
startSSHOptions := liveshare.StartSSHServerOptions{}
useAutoKeys, err := useAutomaticSSHKeys(ctx, sshContext, a.apiClient, args, opts)
keyPair, shouldAddArg, err := selectSSHKeys(ctx, sshContext, args, opts)
if err != nil {
return fmt.Errorf("checking ssh key configuration: %w", err)
return fmt.Errorf("selecting ssh keys: %w", err)
}
if useAutoKeys {
keyPair, err := setupAutomaticSSHKeys(sshContext)
if err != nil {
return fmt.Errorf("failed to generate ssh keys: %w", err)
}
startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath
startSSHOptions.UserPublicKeyFile = keyPair.PublicKeyPath
if shouldAddArg {
// For both cp and ssh, flags need to come first in the args (before a command in ssh and files in cp), so prepend this flag
args = append([]string{"-i", keyPair.PrivateKeyPath}, args...)
}
@ -236,48 +234,71 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
}
}
func useAutomaticSSHKeys(
// selectSSHKeys evaluates available key pairs and select which should be used to connect to the codespace
// using the precedence rules below. If there is no error, a keypair is always returned and additionally a
// bool flag is returned to specify if the private key need be appended to the ssh arguments (it doesn't need
// to be if the key was selected from a -i argument).
//
// Precedence rules:
// 1. Key which is specified by -i
// 2. Automatic key, if it already exists
// 3. First valid keypair in ssh config (according to ssh -G)
// 4. Automatic key, newly created
func selectSSHKeys(
ctx context.Context,
sshContext ssh.Context,
apiClient apiClient,
args []string,
opts sshOptions,
) (bool, error) {
) (*ssh.KeyPair, bool, error) {
customConfigPath := ""
for i := 0; i < len(args); i += 1 {
arg := args[i]
if arg == "-i" {
if i+1 < len(args) && path.Base(args[i+1]) == automaticPrivateKeyName {
return true, nil
if i+1 >= len(args) {
return nil, false, errors.New("missing value to -i argument")
}
// User specified a custom identity file so just trust it is correct
return false, nil
// User manually specified an identity file so just trust it is correct
return &ssh.KeyPair{
PrivateKeyPath: args[i+1],
PublicKeyPath: args[i+1] + ".pub",
}, false, nil
}
if arg == "-F" && i < len(args)-1 {
if arg == "-F" && i+1 < len(args) {
// ssh only pays attention to that last specified -F value, so it's correct to overwrite here
customConfigPath = args[i+1]
}
}
if automaticKeySSHKeysExist(sshContext) {
if autoKeyPair := automaticSSHKeyPair(sshContext); autoKeyPair != nil {
// If the automatic keys already exist, just use them
return true, nil
return autoKeyPair, true, nil
}
// If there is a public key uploaded which matches a configured
// private key, there is no need for automatic key generation
hasPublicKey, err := hasUploadedPublicKeyForConfig(ctx, apiClient, customConfigPath, opts.profile)
keyPair, err := firstConfiguredKeyPair(ctx, customConfigPath, opts.profile)
if err != nil {
if !errors.Is(err, errKeyFileNotFound) {
return nil, false, fmt.Errorf("checking configured keys: %w", err)
}
return !hasPublicKey, err
// no valid key in ssh config, generate one
keyPair, err = generateAutomaticSSHKeys(sshContext)
if err != nil {
return nil, false, fmt.Errorf("generating automatic keypair: %w", err)
}
}
return keyPair, true, nil
}
func automaticKeySSHKeysExist(sshContext ssh.Context) bool {
// automaticSSHKeyPair returns the paths to the automatic key pair files, if they both exist
func automaticSSHKeyPair(sshContext ssh.Context) *ssh.KeyPair {
publicKeys, err := sshContext.LocalPublicKeys()
if err != nil {
return false
// The error would be that the .ssh dir doesn't exist, which just means that the keypair also doesn't exist
return nil
}
for _, publicKey := range publicKeys {
@ -288,13 +309,18 @@ func automaticKeySSHKeysExist(sshContext ssh.Context) bool {
privateKey := strings.TrimSuffix(publicKey, ".pub")
_, err := os.Stat(privateKey)
return err == nil
if err == nil {
return &ssh.KeyPair{
PrivateKeyPath: privateKey,
PublicKeyPath: publicKey,
}
}
}
return false
return nil
}
func setupAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) {
func generateAutomaticSSHKeys(sshContext ssh.Context) (*ssh.KeyPair, error) {
keyPair := checkAndUpdateOldKeyPair(sshContext)
if keyPair != nil {
return keyPair, nil
@ -355,76 +381,13 @@ func checkAndUpdateOldKeyPair(sshContext ssh.Context) *ssh.KeyPair {
return nil
}
// hasUploadedValidCustomKey checks which private keys are used in ssh configration and
// compares them to uploaded public keys to see if there is a match
func hasUploadedPublicKeyForConfig(
ctx context.Context,
apiClient apiClient,
customConfigFile string,
customHost string,
) (bool, error) {
configuredPrivateKeyPaths, err := getConfiguredPrivateKeys(ctx, customConfigFile, customHost)
if err != nil {
return false, fmt.Errorf("getting local ssh keys: %w", err)
}
configuredPublicKeys := make(map[string]bool)
for _, privateKeyPath := range configuredPrivateKeyPaths {
publicKeyPath := privateKeyPath + ".pub"
publicKeyContent, err := os.ReadFile(publicKeyPath)
if err != nil {
// The default configuration includes standard keys like id_rsa or id_ed25519,
// but these may not actually exist so just skip them
continue
}
parts := strings.SplitN(string(publicKeyContent), " ", 3)
if len(parts) < 2 {
// Unexpected format, skip it
continue
}
publicKeyWithoutComment := strings.Join(parts[:2], " ")
configuredPublicKeys[publicKeyWithoutComment] = true
}
if len(configuredPublicKeys) == 0 {
// There are no local private keys which ssh would use
return false, nil
}
user, err := apiClient.GetUser(ctx)
if err != nil {
return false, fmt.Errorf("fetching user account: %w", err)
}
uploadedPublicKeys, err := apiClient.AuthorizedKeys(ctx, user.Login)
if err != nil {
return false, fmt.Errorf("fetching known ssh keys: %w", err)
}
if len(uploadedPublicKeys) == 0 {
return false, nil
}
for _, uploadedPublicKey := range uploadedPublicKeys {
if configuredPublicKeys[uploadedPublicKey] {
return true, nil
}
}
return false, nil
}
// getConfiguredPrivateKeys reads the effective configuration for a localhost
// connection and returns all private keys which would be tried for authentication
func getConfiguredPrivateKeys(
// firstConfiguredKeyPair reads the effective configuration for a localhost
// connection and returns the first valid key pair which would be tried for authentication
func firstConfiguredKeyPair(
ctx context.Context,
customConfigFile string,
customHost string,
) ([]string, error) {
) (*ssh.KeyPair, error) {
sshExe, err := safeexec.LookPath("ssh")
if err != nil {
return nil, fmt.Errorf("could not find ssh executable: %w", err)
@ -449,27 +412,55 @@ func getConfiguredPrivateKeys(
return nil, fmt.Errorf("could not load ssh configuration: %w", err)
}
var privateKeyPaths []string
userHomeDir, _ := os.UserHomeDir()
configLines := strings.Split(string(configBytes), "\n")
for _, line := range configLines {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "identityfile ") {
path := strings.SplitN(line, " ", 2)[1]
privateKeyPath := strings.SplitN(line, " ", 2)[1]
if strings.HasPrefix(path, "~") {
// os.Stat can't handle ~, so convert it to the real path
path = strings.Replace(path, "~", userHomeDir, 1)
keypair, err := keypairForPrivateKey(privateKeyPath)
if errors.Is(err, errKeyFileNotFound) {
continue
}
if err != nil {
return nil, fmt.Errorf("loading ssh config: %w", err)
}
privateKeyPaths = append(privateKeyPaths, path)
return keypair, nil
}
}
return privateKeyPaths, nil
return nil, errKeyFileNotFound
}
// keypairForPrivateKey returns the KeyPair with the specified private key if it and the public key both exist
func keypairForPrivateKey(privateKeyPath string) (*ssh.KeyPair, error) {
if strings.HasPrefix(privateKeyPath, "~") {
userHomeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("getting home dir: %w", err)
}
// os.Stat can't handle ~, so convert it to the real path
privateKeyPath = strings.Replace(privateKeyPath, "~", userHomeDir, 1)
}
// The default configuration includes standard keys like id_rsa or id_ed25519,
// but these may not actually exist
if _, err := os.Stat(privateKeyPath); err != nil {
return nil, errKeyFileNotFound
}
publicKeyPath := privateKeyPath + ".pub"
if _, err := os.Stat(publicKeyPath); err != nil {
return nil, errKeyFileNotFound
}
return &ssh.KeyPair{
PrivateKeyPath: privateKeyPath,
PublicKeyPath: publicKeyPath,
}, nil
}
func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err error) {
@ -535,12 +526,6 @@ func (a *App) printOpenSSHConfig(ctx context.Context, opts sshOptions) (err erro
close(sshUsers)
}()
// While the above fetches are running, ensure that the user has keys installed.
// That lets us report a more useful error message if they don't.
if err = checkAuthorizedKeys(ctx, a.apiClient); err != nil {
return err
}
t, err := template.New("ssh_config").Parse(heredoc.Doc(`
Host cs.{{.Name}}.{{.EscapedRef}}
User {{.SSHUser}}

View file

@ -5,7 +5,6 @@ import (
"fmt"
"io/ioutil"
"os"
"path"
"path/filepath"
"strings"
"testing"
@ -27,11 +26,11 @@ func TestPendingOperationDisallowsSSH(t *testing.T) {
}
}
func TestAutomaticSSHKeyPairs(t *testing.T) {
func TestGenerateAutomaticSSHKeys(t *testing.T) {
tests := []struct {
// These files exist when calling setupAutomaticSSHKeys
// These files exist when calling generateAutomaticSSHKeys
existingFiles []string
// These files should exist after setupAutomaticSSHKeys finishes
// These files should exist after generateAutomaticSSHKeys finishes
wantFinalFiles []string
}{
// Basic case: no existing keys, they should be created
@ -82,12 +81,12 @@ func TestAutomaticSSHKeyPairs(t *testing.T) {
f.Close()
}
keyPair, err := setupAutomaticSSHKeys(sshContext)
keyPair, err := generateAutomaticSSHKeys(sshContext)
if err != nil {
t.Errorf("Unexpected error from setupAutomaticSSHKeys: %v", err)
t.Errorf("Unexpected error from generateAutomaticSSHKeys: %v", err)
}
if keyPair == nil {
t.Fatal("Unexpected nil KeyPair from setupAutomaticSSHKeys")
t.Fatal("Unexpected nil KeyPair from generateAutomaticSSHKeys")
}
if !strings.HasSuffix(keyPair.PrivateKeyPath, automaticPrivateKeyName) {
t.Errorf("Expected private key path %v, got %v", automaticPrivateKeyName, keyPair.PrivateKeyPath)
@ -99,7 +98,7 @@ func TestAutomaticSSHKeyPairs(t *testing.T) {
// Check that all the expected files are present
for _, file := range tt.wantFinalFiles {
if _, err := os.Stat(filepath.Join(dir, file)); err != nil {
t.Errorf("Want file %q to exist after setupAutomaticSSHKeys but it doesn't", file)
t.Errorf("Want file %q to exist after generateAutomaticSSHKeys but it doesn't", file)
}
}
@ -119,177 +118,152 @@ func TestAutomaticSSHKeyPairs(t *testing.T) {
}
if !isWantedFile {
t.Errorf("Unexpected file %q exists after setupAutomaticSSHKeys", filename)
t.Errorf("Unexpected file %q exists after generateAutomaticSSHKeys", filename)
}
}
}
}
func TestUseAutomaticSSHKeysIdentityFileArg(t *testing.T) {
func TestSelectSSHKeys(t *testing.T) {
tests := []struct {
identityFileArg string
wantResult bool
sshDirFiles []string
sshConfigKeys []string
sshArgs []string
profileOpt string
wantKeyPair *ssh.KeyPair
wantShouldAddArg bool
}{
{"custom-private-key", false},
{automaticPrivateKeyName, true},
{"", false}, // Edge case check for missing arg value
}
for _, tt := range tests {
t.Logf("%v", tt.identityFileArg)
args := []string{"-i"}
if tt.identityFileArg != "" {
args = append(args, tt.identityFileArg)
}
result, err := useAutomaticSSHKeys(context.Background(), ssh.Context{}, nil, args, sshOptions{})
if err != nil {
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
}
if result != tt.wantResult {
t.Errorf("Want useAutomaticSSHKeys to be %v, got %v", tt.wantResult, result)
}
}
}
func TestUseAutomaticSSHKeysWithAutoKeysExist(t *testing.T) {
dir := t.TempDir()
f, err := os.Create(path.Join(dir, automaticPrivateKeyName))
if err != nil {
t.Errorf("Failed to create test private key: %v", err)
}
f.Close()
f, err = os.Create(path.Join(dir, automaticPrivateKeyName+".pub"))
if err != nil {
t.Errorf("Failed to create test public key: %v", err)
}
f.Close()
sshContext := ssh.Context{
ConfigDir: dir,
}
result, err := useAutomaticSSHKeys(context.Background(), sshContext, nil, nil, sshOptions{})
if err != nil {
t.Errorf("Unexpected error from useAutomaticSSHKeys: %v", err)
}
if result != true {
t.Errorf("Want useAutomaticSSHKeys to be true, got false")
}
}
func TestHasUploadedPublicKeyForConfig(t *testing.T) {
type testLocalKeyPair struct {
privateKeyFile string
publicKeyContent string
}
tests := []struct {
apiAuthorizedPublicKeys []string
localKeyPairs []testLocalKeyPair
wantResult bool
}{
// Failure tests
// -i tests
{
// No API keys and no local keys
wantResult: false,
sshArgs: []string{"-i", "custom-private-key"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key", PublicKeyPath: "custom-private-key.pub"},
},
{
// Has API keys, but no local keys
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key"},
wantResult: false,
sshArgs: []string{"-i", automaticPrivateKeyName},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
},
{
// No API keys, but has local keys
localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-key"}},
wantResult: false,
},
{
// API keys and local keys, but not matching
apiAuthorizedPublicKeys: []string{"ssh-rsa test-api-key"},
localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-local-key"}},
wantResult: false,
// Edge case check for missing arg value
sshArgs: []string{"-i"},
},
// Successful tests
// Auto key exists tests
{
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key"},
localKeyPairs: []testLocalKeyPair{{"keyfile", "ssh-rsa test-key"}},
wantResult: true,
sshDirFiles: []string{automaticPrivateKeyName, automaticPrivateKeyName + ".pub"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
wantShouldAddArg: true,
},
{
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key-2"},
localKeyPairs: []testLocalKeyPair{{"keyfile1", "ssh-rsa test-key-1"}},
wantResult: true,
},
{
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1"},
localKeyPairs: []testLocalKeyPair{{"keyfile1", "ssh-rsa test-key-1"}, {"keyfile2", "ssh-rsa test-key-2"}},
wantResult: true,
},
{
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key-2"},
localKeyPairs: []testLocalKeyPair{{"keyfile3", "ssh-rsa test-key-3"}, {"keyfile2", "ssh-rsa test-key-2"}},
wantResult: true,
sshDirFiles: []string{automaticPrivateKeyName, automaticPrivateKeyName + ".pub", "custom-private-key", "custom-private-key.pub"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
wantShouldAddArg: true,
},
// Extra case - local key contain comments
// SSH config tests
{
apiAuthorizedPublicKeys: []string{"ssh-rsa test-key-1", "ssh-rsa test-key"},
localKeyPairs: []testLocalKeyPair{{"keyfile3", "ssh-rsa test-key a comment on the key"}},
wantResult: true,
sshDirFiles: []string{"custom-private-key", "custom-private-key.pub"},
sshConfigKeys: []string{"custom-private-key"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key", PublicKeyPath: "custom-private-key.pub"},
wantShouldAddArg: true,
},
{
// 2 pairs, but only 1 is configured
sshDirFiles: []string{"custom-private-key", "custom-private-key.pub", "custom-private-key-2", "custom-private-key-2.pub"},
sshConfigKeys: []string{"custom-private-key-2"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key-2", PublicKeyPath: "custom-private-key-2.pub"},
wantShouldAddArg: true,
},
{
// 2 pairs, but only 1 has both public and private
sshDirFiles: []string{"custom-private-key", "custom-private-key-2", "custom-private-key-2.pub"},
sshConfigKeys: []string{"custom-private-key", "custom-private-key-2"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: "custom-private-key-2", PublicKeyPath: "custom-private-key-2.pub"},
wantShouldAddArg: true,
},
// Automatic key tests
{
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
wantShouldAddArg: true,
},
{
// Renames old key pair to new
sshDirFiles: []string{automaticPrivateKeyNameOld, automaticPrivateKeyNameOld + ".pub"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
wantShouldAddArg: true,
},
{
// Other key is configured, but doesn't exist
sshConfigKeys: []string{"custom-private-key"},
wantKeyPair: &ssh.KeyPair{PrivateKeyPath: automaticPrivateKeyName, PublicKeyPath: automaticPrivateKeyName + ".pub"},
wantShouldAddArg: true,
},
}
for _, tt := range tests {
t.Logf("%+v", tt)
sshDir := t.TempDir()
sshContext := ssh.Context{ConfigDir: sshDir}
mockApi := &apiClientMock{
GetUserFunc: func(ctx context.Context) (*api.User, error) {
return &api.User{Login: "test"}, nil
},
AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) {
return tt.apiAuthorizedPublicKeys, nil
},
}
dir := t.TempDir()
configPath := path.Join(dir, "test-config")
configContent := ""
for _, pair := range tt.localKeyPairs {
configContent += fmt.Sprintf("IdentityFile %s\n", path.Join(dir, pair.privateKeyFile))
err := os.WriteFile(path.Join(dir, pair.privateKeyFile+".pub"), []byte(pair.publicKeyContent), 0666)
for _, file := range tt.sshDirFiles {
f, err := os.Create(filepath.Join(sshDir, file))
if err != nil {
t.Fatalf("could not write test public key file %v", err)
t.Errorf("Failed to create test ssh dir file %q: %v", file, err)
}
f.Close()
}
if tt.sshConfigKeys != nil {
configPath := filepath.Join(sshDir, "test-config")
configContent := ""
for _, key := range tt.sshConfigKeys {
configContent += fmt.Sprintf("IdentityFile %s\n", filepath.Join(sshDir, key))
}
err := os.WriteFile(configPath, []byte(configContent), 0666)
if err != nil {
t.Fatalf("could not write test config %v", err)
}
tt.sshArgs = append(tt.sshArgs, "-F", configPath)
}
gotKeyPair, gotShouldAddArg, err := selectSSHKeys(context.Background(), sshContext, tt.sshArgs, sshOptions{profile: tt.profileOpt})
if tt.wantKeyPair == nil {
if err == nil {
t.Errorf("Expected error from selectSSHKeys but got nil")
}
continue
}
err := os.WriteFile(configPath, []byte(configContent), 0666)
if err != nil {
t.Fatalf("could not write test config %v", err)
t.Errorf("Unexpected error from selectSSHKeys: %v", err)
continue
}
result, err := hasUploadedPublicKeyForConfig(context.Background(), mockApi, configPath, "")
if err != nil {
t.Errorf("Unexpected error from hasUploadedPublicKeyForConfig: %v", err)
if gotKeyPair == nil {
t.Errorf("Expected non-nil result from selectSSHKeys but got nil")
continue
}
if result != tt.wantResult {
t.Errorf("Want hasUploadedPublicKeyForConfig to be %v, got %v", tt.wantResult, result)
if gotShouldAddArg != tt.wantShouldAddArg {
t.Errorf("Got wrong shouldAddArg value from selectSSHKeys, wanted %v got %v", tt.wantShouldAddArg, gotShouldAddArg)
continue
}
// Strip the dir (sshDir) from the gotKeyPair paths so that they match wantKeyPair (which doesn't know the directory)
gotKeyPair.PrivateKeyPath = filepath.Base(gotKeyPair.PrivateKeyPath)
gotKeyPair.PublicKeyPath = filepath.Base(gotKeyPair.PublicKeyPath)
if fmt.Sprintf("%v", gotKeyPair) != fmt.Sprintf("%v", tt.wantKeyPair) {
t.Errorf("Want selectSSHKeys result to be %v, got %v", tt.wantKeyPair, gotKeyPair)
}
}
}
func testingSSHApp() *App {
user := &api.User{Login: "monalisa"}
disabledCodespace := &api.Codespace{
Name: "disabledCodespace",
PendingOperation: true,
@ -302,12 +276,6 @@ func testingSSHApp() *App {
}
return nil, nil
},
GetUserFunc: func(_ context.Context) (*api.User, error) {
return user, nil
},
AuthorizedKeysFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
}
ios, _, _, _ := iostreams.Test()