From 8a0f8b6d1c1834186ca4fcb6da650d34df89b1eb Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 16 Sep 2021 10:32:27 -0400 Subject: [PATCH] parse ssh args and command --- internal/codespaces/ssh.go | 63 ++++++++++++++++++++++++++------ internal/codespaces/ssh_test.go | 64 +++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 10 deletions(-) create mode 100644 internal/codespaces/ssh_test.go diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index 661caecdc..0eee21286 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -12,7 +12,7 @@ import ( // port-forwarding session. It runs until the shell is terminated // (including by cancellation of the context). func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error { - cmd, connArgs := newSSHCommand(ctx, port, destination, "") + cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs) if usingCustomPort { log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " ")) @@ -23,23 +23,21 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat // NewRemoteCommand returns an exec.Cmd that will securely run a shell // command on the remote machine. -func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { - cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) +func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd { + cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs) return cmd } // newSSHCommand populates an exec.Cmd to run a command (or if blank, // an interactive shell) over ssh. -func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) { +func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) { connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"} - cmdArgs := []string{dst, "-C"} // Always use Compression - if command == "" { - // if we are in a shell send X11 and X11Trust - cmdArgs = append(cmdArgs, "-X", "-Y") - } - + cmdArgs, command := parseSSHArgs(cmdArgs) cmdArgs = append(cmdArgs, connArgs...) + cmdArgs = append(cmdArgs, "-C") // Compression + cmdArgs = append(cmdArgs, dst) // user@host + if command != "" { cmdArgs = append(cmdArgs, command) } @@ -51,3 +49,48 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm return cmd, connArgs } + +var sshArgumentFlags = map[string]bool{ + "-b": true, + "-c": true, + "-D": true, + "-e": true, + "-F": true, + "-I": true, + "-i": true, + "-L": true, + "-l": true, + "-m": true, + "-O": true, + "-o": true, + "-p": true, + "-R": true, + "-S": true, + "-W": true, + "-w": true, +} + +func parseSSHArgs(sshArgs []string) ([]string, string) { + var ( + cmdArgs []string + command []string + flagArgument bool + ) + + for _, arg := range sshArgs { + switch { + case strings.HasPrefix(arg, "-"): + cmdArgs = append(cmdArgs, arg) + if _, ok := sshArgumentFlags[arg]; ok { + flagArgument = true + } + case flagArgument: + cmdArgs = append(cmdArgs, arg) + flagArgument = false + default: + command = append(command, arg) + } + } + + return cmdArgs, strings.Join(command, " ") +} diff --git a/internal/codespaces/ssh_test.go b/internal/codespaces/ssh_test.go new file mode 100644 index 000000000..cd92b39a6 --- /dev/null +++ b/internal/codespaces/ssh_test.go @@ -0,0 +1,64 @@ +package codespaces + +import "testing" + +func TestParseSSHArgs(t *testing.T) { + type testCase struct { + Args []string + ParsedArgs []string + Command string + } + + testCases := []testCase{ + { + Args: []string{"-X", "-Y"}, + ParsedArgs: []string{"-X", "-Y"}, + Command: "", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "somecommand", + }, + { + Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"}, + ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"}, + Command: "echo test", + }, + { + Args: []string{"somecommand"}, + ParsedArgs: []string{}, + Command: "somecommand", + }, + { + Args: []string{"echo", "test"}, + ParsedArgs: []string{}, + Command: "echo test", + }, + { + Args: []string{"-v", "echo", "hello", "world"}, + ParsedArgs: []string{"-v"}, + Command: "echo hello world", + }, + } + + for _, tcase := range testCases { + args, command := parseSSHArgs(tcase.Args) + if len(args) != len(tcase.ParsedArgs) { + t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs)) + } + for i, arg := range args { + if arg != tcase.ParsedArgs[i] { + t.Fatalf("arg does not match expected parsed arg. %v, got '%s', expected: '%s'", tcase, arg, tcase.ParsedArgs[i]) + } + } + if command != tcase.Command { + t.Fatalf("command does not match expected command. %v, got: '%s', expected: '%s'", tcase, command, tcase.Command) + } + } +}