From 42e47a98d7b51d0ea70ee1bcc5a09392a49080fc Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 16 Sep 2021 15:22:47 -0400 Subject: [PATCH] add docs, simplify map, error on invalid args --- cmd/ghcs/logs.go | 5 ++- internal/codespaces/ssh.go | 58 ++++++++++++++++----------------- internal/codespaces/ssh_test.go | 13 +++++++- internal/codespaces/states.go | 6 +++- 4 files changed, 50 insertions(+), 32 deletions(-) diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index ccfb46236..725a243b0 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -82,9 +82,12 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow } dst := fmt.Sprintf("%s@localhost", sshUser) - cmd := codespaces.NewRemoteCommand( + cmd, err := codespaces.NewRemoteCommand( ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) + if err != nil { + return fmt.Errorf("remote command: %w", err) + } tunnelClosed := make(chan error, 1) go func() { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 0eee21286..b58741e34 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -2,6 +2,7 @@ package codespaces import ( "context" + "fmt" "os" "os/exec" "strconv" @@ -12,7 +13,10 @@ import ( // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs) + cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs) + if err != nil { + return fmt.Errorf("failed to create ssh command: %w", err) + } if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,17 +27,27 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs) - return cmd +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) { + cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs) + return cmd, err } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs, command := parseSSHArgs(cmdArgs) + // The ssh command syntax is: ssh [flags] user@host command [args...] + // There is no way to specify the user@host destination as a flag. + // Unfortunately, that means we need to know which user-provided words are + // SSH flags and which are command arguments so that we can place + // them before or after the destination, and that means we need to know all + // the flags and their arities. + cmdArgs, command, err := parseSSHArgs(cmdArgs) + if err != nil { + return nil, []string{}, err + } + cmdArgs = append(cmdArgs, connArgs...) cmdArgs = append(cmdArgs, "-C") // Compression cmdArgs = append(cmdArgs, dst) // user@host @@ -47,30 +61,12 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd, connArgs + return cmd, connArgs, nil } -var sshArgumentFlags = map[string]bool{ - "-b": true, - "-c": true, - "-D": true, - "-e": true, - "-F": true, - "-I": true, - "-i": true, - "-L": true, - "-l": true, - "-m": true, - "-O": true, - "-o": true, - "-p": true, - "-R": true, - "-S": true, - "-W": true, - "-w": true, -} +var sshArgumentFlags = "-b-c-D-e-F-I-i-L-l-m-O-o-p-R-S-W-w" -func parseSSHArgs(sshArgs []string) ([]string, string) { +func parseSSHArgs(sshArgs []string) ([]string, string, error) { var ( cmdArgs []string command []string @@ -80,8 +76,12 @@ func parseSSHArgs(sshArgs []string) ([]string, string) { for _, arg := range sshArgs { switch { case strings.HasPrefix(arg, "-"): + if len(command) > 0 { + return []string{}, "", fmt.Errorf("invalid flag after command: %s", arg) + } + cmdArgs = append(cmdArgs, arg) - if _, ok := sshArgumentFlags[arg]; ok { + if strings.Contains(sshArgumentFlags, arg) { flagArgument = true } case flagArgument: @@ -92,5 +92,5 @@ func parseSSHArgs(sshArgs []string) ([]string, string) { } } - return cmdArgs, strings.Join(command, " ") + return cmdArgs, strings.Join(command, " "), nil } diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go index cd92b39a6..2847ffc9f 100644 --- a/internal/codespaces/ssh_test.go +++ b/internal/codespaces/ssh_test.go @@ -48,7 +48,11 @@ func TestParseSSHArgs(t *testing.T) { } for _, tcase := range testCases { - args, command := parseSSHArgs(tcase.Args) + args, command, err := parseSSHArgs(tcase.Args) + if err != nil { + t.Errorf("received unexpected error: %w", err) + } + if len(args) != len(tcase.ParsedArgs) { t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs)) } @@ -62,3 +66,10 @@ func TestParseSSHArgs(t *testing.T) { } } } + +func TestParseSSHArgsError(t *testing.T) { + _, _, err := parseSSHArgs([]string{"-X", "test", "-Y"}) + if err == nil { + t.Error("expected an error for invalid args") + } +} diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index 99683b51d..408f11941 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -89,10 +89,14 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u } func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) { - cmd := NewRemoteCommand( + cmd, err := NewRemoteCommand( ctx, tunnelPort, fmt.Sprintf("%s@localhost", user), "cat /workspaces/.codespaces/shared/postCreateOutput.json", ) + if err != nil { + return nil, fmt.Errorf("remote command: %w", err) + } + stdout := new(bytes.Buffer) cmd.Stdout = stdout if err := cmd.Run(); err != nil {