Parse SSH args before creating the shell
This commit is contained in:
parent
8c77e53c35
commit
2cb044caf5
3 changed files with 32 additions and 32 deletions
|
|
@ -19,9 +19,9 @@ type printer interface {
|
|||
// port-forwarding session. It runs until the shell is terminated
|
||||
// (including by cancellation of the context).
|
||||
func Shell(
|
||||
ctx context.Context, keepAliveOverride chan (bool), p printer, sshArgs []string, port int, destination string, printConnDetails bool,
|
||||
ctx context.Context, p printer, sshArgs []string, command []string, port int, destination string, printConnDetails bool,
|
||||
) error {
|
||||
cmd, connArgs, err := newSSHCommand(ctx, keepAliveOverride, port, destination, sshArgs)
|
||||
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs, command)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ssh command: %w", err)
|
||||
}
|
||||
|
|
@ -51,42 +51,30 @@ func Copy(ctx context.Context, scpArgs []string, port int, destination string) e
|
|||
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
|
||||
// command on the remote machine.
|
||||
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) {
|
||||
cmd, _, err := newSSHCommand(ctx, nil, tunnelPort, destination, sshArgs)
|
||||
sshArgs, command, err := ParseSSHArgs(sshArgs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs, command)
|
||||
return cmd, err
|
||||
}
|
||||
|
||||
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
|
||||
// an interactive shell) over ssh.
|
||||
func newSSHCommand(ctx context.Context, keepAliveOverride chan (bool), port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) {
|
||||
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string, command []string) (*exec.Cmd, []string, error) {
|
||||
connArgs := []string{
|
||||
"-p", strconv.Itoa(port),
|
||||
"-o", "NoHostAuthenticationForLocalhost=yes",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
}
|
||||
|
||||
// The ssh command syntax is: ssh [flags] user@host command [args...]
|
||||
// There is no way to specify the user@host destination as a flag.
|
||||
// Unfortunately, that means we need to know which user-provided words are
|
||||
// SSH flags and which are command arguments so that we can place
|
||||
// them before or after the destination, and that means we need to know all
|
||||
// the flags and their arities.
|
||||
cmdArgs, command, err := parseSSHArgs(cmdArgs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cmdArgs = append(cmdArgs, connArgs...)
|
||||
cmdArgs = append(cmdArgs, "-C") // Compression
|
||||
cmdArgs = append(cmdArgs, dst) // user@host
|
||||
|
||||
if command != nil {
|
||||
cmdArgs = append(cmdArgs, command...)
|
||||
|
||||
// If the user specified a command to run non-interactively,
|
||||
// make sure we send activity signals to keep the connection alive
|
||||
if keepAliveOverride != nil {
|
||||
keepAliveOverride <- true
|
||||
}
|
||||
}
|
||||
|
||||
exe, err := safeexec.LookPath("ssh")
|
||||
|
|
@ -102,7 +90,14 @@ func newSSHCommand(ctx context.Context, keepAliveOverride chan (bool), port int,
|
|||
return cmd, connArgs, nil
|
||||
}
|
||||
|
||||
func parseSSHArgs(args []string) (cmdArgs, command []string, err error) {
|
||||
// ParseSSHArgs parses the given array of arguments into two distinct slices of flags and command.
|
||||
// The ssh command syntax is: ssh [flags] user@host command [args...]
|
||||
// There is no way to specify the user@host destination as a flag.
|
||||
// Unfortunately, that means we need to know which user-provided words are
|
||||
// SSH flags and which are command arguments so that we can place
|
||||
// them before or after the destination, and that means we need to know all
|
||||
// the flags and their arities.
|
||||
func ParseSSHArgs(args []string) (cmdArgs, command []string, err error) {
|
||||
return parseArgs(args, "bcDeFIiLlmOopRSWw")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func TestParseSSHArgs(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, tcase := range testCases {
|
||||
args, command, err := parseSSHArgs(tcase.Args)
|
||||
args, command, err := ParseSSHArgs(tcase.Args)
|
||||
|
||||
checkParseResult(t, tcase, args, command, err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -281,17 +281,22 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
|
|||
// args is the correct variable to use here, we just use scpArgs as the check for which command to run
|
||||
err = codespaces.Copy(ctx, args, localSSHServerPort, connectDestination)
|
||||
} else {
|
||||
// Create a channel to send down to the shell to keep it alive
|
||||
keepAliveOverride := make(chan bool, 1)
|
||||
go func() {
|
||||
// If we receive true on the channel, ignore the timeout
|
||||
if <-keepAliveOverride {
|
||||
invoker.KeepAlive()
|
||||
}
|
||||
}()
|
||||
// Parse the ssh args to determine if the user specified a command
|
||||
args, command, err := codespaces.ParseSSHArgs(args)
|
||||
if err != nil {
|
||||
shellClosed <- err
|
||||
return
|
||||
}
|
||||
|
||||
// If the user specified a command, we need to keep the shell alive
|
||||
// since it will be non-interactive and the codespace might shut down
|
||||
// before the command finishes
|
||||
if command != nil {
|
||||
invoker.KeepAlive()
|
||||
}
|
||||
|
||||
err = codespaces.Shell(
|
||||
ctx, keepAliveOverride, a.errLogger, args, localSSHServerPort, connectDestination, opts.printConnDetails,
|
||||
ctx, a.errLogger, args, command, localSSHServerPort, connectDestination, opts.printConnDetails,
|
||||
)
|
||||
}
|
||||
shellClosed <- err
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue