cli/pkg/cmd/codespace/ssh.go

172 lines
4.7 KiB
Go

package codespace
import (
"context"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"github.com/cli/cli/v2/internal/codespaces"
"github.com/cli/cli/v2/pkg/liveshare"
"github.com/spf13/cobra"
)
// sshOptions holds the flag values collected by the `codespace ssh` command
// and consumed by App.SSH.
type sshOptions struct {
	// codespace is the codespace name handed to getOrChooseCodespace.
	codespace string
	// profile, when non-empty, is used as the ssh connect destination in
	// place of "<sshUser>@localhost".
	profile string
	// serverPort is the local TCP port to listen on for the ssh client;
	// 0 lets the operating system pick an unused port.
	serverPort int
	// debug turns on Live Share debug logging to a file.
	debug bool
	// debugFile is the path of the debug log; an empty value selects a
	// temporary file. Only consulted when debug is set.
	debugFile string
}
// newSSHCmd builds the `ssh` subcommand. Its flags populate an sshOptions
// value which, together with the remaining positional arguments (passed
// through to ssh), is handed to App.SSH when the command runs.
func newSSHCmd(app *App) *cobra.Command {
	var opts sshOptions

	sshCmd := &cobra.Command{
		Use:   "ssh [flags] [--] [ssh-flags] [command]",
		Short: "SSH into a codespace",
		RunE: func(cmd *cobra.Command, args []string) error {
			return app.SSH(cmd.Context(), args, opts)
		},
	}

	// Flags without a one-letter shorthand use the plain *Var registrars;
	// this is equivalent to the *VarP forms with an empty shorthand.
	flags := sshCmd.Flags()
	flags.StringVar(&opts.profile, "profile", "", "Name of the SSH profile to use")
	flags.IntVar(&opts.serverPort, "server-port", 0, "SSH server port number (0 => pick unused)")
	flags.StringVarP(&opts.codespace, "codespace", "c", "", "Name of the codespace")
	flags.BoolVarP(&opts.debug, "debug", "d", false, "Log debug data to a file")
	flags.StringVar(&opts.debugFile, "debug-file", "", "Path of the file to log to")

	return sshCmd
}
// SSH opens an ssh session or runs an ssh command in a codespace.
//
// It connects to the chosen codespace over Live Share, starts the remote SSH
// server, forwards it to a local TCP listener on 127.0.0.1 (opts.serverPort,
// or an OS-assigned port when that is 0), and runs the local ssh client via
// codespaces.Shell with sshArgs. It returns when either the client exits
// (nil on success) or the port-forwarding tunnel closes (always an error).
func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err error) {
	// Ensure all child tasks (e.g. port forwarding) terminate before return.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	user, err := a.apiClient.GetUser(ctx)
	if err != nil {
		return fmt.Errorf("error getting user: %w", err)
	}

	// Run checkAuthorizedKeys concurrently with codespace selection. The
	// channel is buffered so the goroutine can always send and exit, even if
	// this function returns before reading the result.
	authkeys := make(chan error, 1)
	go func() {
		authkeys <- checkAuthorizedKeys(ctx, a.apiClient, user.Login)
	}()

	codespace, err := getOrChooseCodespace(ctx, a.apiClient, opts.codespace)
	if err != nil {
		return fmt.Errorf("get or choose codespace: %w", err)
	}

	liveshareLogger := noopLogger()
	if opts.debug {
		// Assign to the named result err rather than declaring a new err
		// with :=. With a shadowed err, the deferred safeClose would receive
		// a pointer to a block-local variable that is dead by the time the
		// defer runs, so nothing it stored could ever reach the caller.
		var debugLogger *fileLogger
		debugLogger, err = newFileLogger(opts.debugFile)
		if err != nil {
			return fmt.Errorf("error creating debug logger: %w", err)
		}
		defer safeClose(debugLogger, &err)

		liveshareLogger = debugLogger.Logger
		a.logger.Println("Debug file located at: " + debugLogger.Name())
	}

	session, err := codespaces.ConnectToLiveshare(ctx, a.logger, liveshareLogger, a.apiClient, codespace)
	if err != nil {
		return fmt.Errorf("error connecting to Live Share: %w", err)
	}
	defer safeClose(session, &err)

	if err := <-authkeys; err != nil {
		return err
	}

	a.logger.Println("Fetching SSH Details...")
	remoteSSHServerPort, sshUser, err := session.StartSSHServer(ctx)
	if err != nil {
		return fmt.Errorf("error getting ssh server details: %w", err)
	}

	localSSHServerPort := opts.serverPort
	usingCustomPort := localSSHServerPort != 0 // suppress log of command line in Shell

	// Ensure local port is listening before client (Shell) connects.
	// Unless the user specifies a server port, localSSHServerPort is 0
	// and thus the client will pick a random port.
	listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", localSSHServerPort))
	if err != nil {
		return fmt.Errorf("error listening on local port %d: %w", localSSHServerPort, err)
	}
	defer listen.Close()
	localSSHServerPort = listen.Addr().(*net.TCPAddr).Port

	connectDestination := opts.profile
	if connectDestination == "" {
		connectDestination = fmt.Sprintf("%s@localhost", sshUser)
	}

	a.logger.Println("Ready...")

	// Both channels are buffered so whichever goroutine finishes second can
	// still deliver its result and exit after this function has returned.
	tunnelClosed := make(chan error, 1)
	go func() {
		fwd := liveshare.NewPortForwarder(session, "sshd", remoteSSHServerPort, true)
		tunnelClosed <- fwd.ForwardToListener(ctx, listen) // always non-nil
	}()

	shellClosed := make(chan error, 1)
	go func() {
		shellClosed <- codespaces.Shell(ctx, a.logger, sshArgs, localSSHServerPort, connectDestination, usingCustomPort)
	}()

	select {
	case err := <-tunnelClosed:
		return fmt.Errorf("tunnel closed: %w", err)
	case err := <-shellClosed:
		if err != nil {
			return fmt.Errorf("shell closed: %w", err)
		}
		return nil // success
	}
}
// fileLogger is a *log.Logger whose output goes to an *os.File. In addition
// to the embedded Logger's methods it offers Name, which reports the path of
// the log file, and Close, which releases the file once logging is finished.
type fileLogger struct {
	// Logger is configured by newFileLogger to write to f with log.LstdFlags.
	*log.Logger
	// f is the file backing Logger.
	f *os.File
}

// newFileLogger opens a log destination and returns a fileLogger writing to
// it. A non-empty file is created (or truncated) at that path; an empty file
// selects a new file in the default temporary directory. If the file cannot
// be created, the error is returned together with a nil *fileLogger.
func newFileLogger(file string) (*fileLogger, error) {
	var (
		out *os.File
		err error
	)

	switch file {
	case "":
		out, err = ioutil.TempFile("", "")
		if err != nil {
			return nil, fmt.Errorf("failed to create tmp file: %w", err)
		}
	default:
		out, err = os.Create(file)
		if err != nil {
			return nil, err
		}
	}

	return &fileLogger{Logger: log.New(out, "", log.LstdFlags), f: out}, nil
}

// Name reports the path of the file the logger writes to.
func (l *fileLogger) Name() string {
	return l.f.Name()
}

// Close closes the underlying log file and returns any error from doing so.
func (l *fileLogger) Close() error {
	return l.f.Close()
}