diff --git a/cmd/ghcs/logs.go b/cmd/ghcs/logs.go index b8bc462fc..a5e9235bc 100644 --- a/cmd/ghcs/logs.go +++ b/cmd/ghcs/logs.go @@ -59,7 +59,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { return fmt.Errorf("connecting to liveshare: %v", err) } - tunnelPort, connClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", 0) + port, err := codespaces.UnusedPort() + if err != nil { + return err + } + + tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", port) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -71,23 +76,29 @@ func logs(ctx context.Context, tail bool, codespaceName string) error { dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace)) cmd := codespaces.NewRemoteCommand( - ctx, tunnelPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), + ctx, port, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType), ) - // Channel is buffered to avoid a goroutine leak when connClosed occurs before done. - done := make(chan error, 1) - go func() { done <- cmd.Run() }() + // Error channels are buffered so that neither sending goroutine gets stuck. + + tunnelClosed := make(chan error, 1) + go func() { + tunnelClosed <- tunnel.Start(ctx) // error is non-nil + }() + + cmdDone := make(chan error, 1) + go func() { + cmdDone <- cmd.Run() + }() select { - case err := <-connClosed: - if err != nil { - return fmt.Errorf("connection closed: %v", err) - } - case err := <-done: + case err := <-tunnelClosed: + return fmt.Errorf("connection closed: %v", err) + + case err := <-cmdDone: if err != nil { return fmt.Errorf("error retrieving logs: %v", err) } + return nil // success } - - return nil } diff --git a/cmd/ghcs/ssh.go b/cmd/ghcs/ssh.go index 3de123dfa..2b51c4cc1 100644 --- a/cmd/ghcs/ssh.go +++ b/cmd/ghcs/ssh.go @@ -80,7 +80,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in } log.Print("\n") - tunnelPort, tunnelClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", sshServerPort) + usingCustomPort := true + if sshServerPort == 0 { + usingCustomPort = false // suppress log of command line in Shell + port, err := codespaces.UnusedPort() + if err != nil { + return err + } + sshServerPort = port + } + + tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", sshServerPort) if err != nil { return fmt.Errorf("make ssh tunnel: %v", err) } @@ -90,26 +100,27 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace)) } - usingCustomPort := tunnelPort == sshServerPort + tunnelClosed := make(chan error) + go func() { + tunnelClosed <- tunnel.Start(ctx) // error is always non-nil + }() shellClosed := make(chan error) go func() { - shellClosed <- codespaces.Shell(ctx, log, tunnelPort, connectDestination, usingCustomPort) + shellClosed <- codespaces.Shell(ctx, log, sshServerPort, connectDestination, usingCustomPort) }() log.Println("Ready...") select { case err := <-tunnelClosed: - if err != nil { - return fmt.Errorf("tunnel closed: %v", err) - } + return fmt.Errorf("tunnel closed: %v", err) + case err := <-shellClosed: if err != nil { return fmt.Errorf("shell closed: %v", err) } + return nil // success } - - return nil } func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) { diff --git a/internal/codespaces/ssh.go b/internal/codespaces/ssh.go index ac68f13ca..5118ad91c 100644 --- a/internal/codespaces/ssh.go +++ b/internal/codespaces/ssh.go @@ -3,7 +3,7 @@ package codespaces import ( "context" "fmt" - "math/rand" + "net" "os" "os/exec" "strconv" @@ -12,57 +12,47 @@ import ( "github.com/github/go-liveshare" ) -// StartPortForwarding starts LiveShare port forwarding of traffic -// between the LiveShare client and the specified local port, or, if -// zero, a port chosen at random; the effective port number is -// returned. Forwarding continues in the background until an error is -// encountered (including cancellation of the context). Therefore -// clients must cancel the context. +// UnusedPort returns the number of a local TCP port that is currently +// unbound, or an error if none was available. +// +// Use of this function carries an inherent risk of a time-of-check to +// time-of-use race against other processes. +func UnusedPort() (int, error) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + return 0, fmt.Errorf("internal error while choosing port: %v", err) + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return 0, fmt.Errorf("choosing available port: %v", err) + } + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil +} + +// NewPortForwarder returns a new port forwarder for traffic between +// the Live Share client and the specified local port (which must be +// available). // // The session name is used (along with the port) to generate // names for streams, and may appear in error messages. -// -// TODO(adonovan): simplify API concurrency from API. Either: -// 1) return a stop function so that clients don't forget to stop forwarding. -// 2) avoid creating a goroutine and returning a channel. Use approach of -// http.ListenAndServe, which runs until it encounters an error -// (incl. cancellation). But this means we can't return the port. -// Can we make the client responsible for supplying it? -// 3) return a PortForwarding object that encapsulates the port, -// and has NewRemoteCommand as a method. It will need a Stop method, -// and an Error method for querying whether the session has failed -// asynchronously. -func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, sessionName string, localPort int) (int, <-chan error, error) { - server, err := liveshare.NewServer(lsclient) - if err != nil { - return 0, nil, fmt.Errorf("new liveshare server: %v", err) +func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localPort int) (*liveshare.PortForwarder, error) { + if localPort == 0 { + return nil, fmt.Errorf("a local port must be provided") } - if localPort == 0 { - localPort = rand.Intn(9999-2000) + 2000 - // TODO(adonovan): retry if port is taken? + server, err := liveshare.NewServer(client) + if err != nil { + return nil, fmt.Errorf("new liveshare server: %v", err) } // TODO(josebalius): This port won't always be 2222 if err := server.StartSharing(ctx, sessionName, 2222); err != nil { - return 0, nil, fmt.Errorf("sharing sshd port: %v", err) + return nil, fmt.Errorf("sharing sshd port: %v", err) } - tunnelClosed := make(chan error) - go func() { - // TODO(adonovan): simplify liveshare API to combine NewPortForwarder and Start - // methods into a single ForwardPort call, like http.ListenAndServe. - // (Start is a misnomer: it runs the complete session.) - // Also document that it never returns a nil error. - portForwarder := liveshare.NewPortForwarder(lsclient, server, localPort) - if err := portForwarder.Start(ctx); err != nil { - tunnelClosed <- fmt.Errorf("forwarding port: %v", err) - return - } - tunnelClosed <- nil - }() - - return localPort, tunnelClosed, nil + return liveshare.NewPortForwarder(client, server, localPort), nil } // Shell runs an interactive secure shell over an existing @@ -78,8 +68,8 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC return cmd.Run() } -// NewRemoteCommand returns a partially populated exec.Cmd that will -// securely run a shell command on the remote machine. +// NewRemoteCommand returns an exec.Cmd that will securely run a shell +// command on the remote machine. func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd { cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command) return cmd @@ -92,7 +82,6 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm // TODO(adonovan): eliminate X11 and X11Trust flags where unneeded. cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression - // An empty command enables port forwarding but not execution. if command != "" { cmdArgs = append(cmdArgs, command) } diff --git a/internal/codespaces/states.go b/internal/codespaces/states.go index fe34f5486..870840304 100644 --- a/internal/codespaces/states.go +++ b/internal/codespaces/states.go @@ -45,22 +45,34 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u return fmt.Errorf("connect to liveshare: %v", err) } - tunnelPort, connClosed, err := StartPortForwarding(ctx, lsclient, "sshd", 0) + port, err := UnusedPort() if err != nil { - return fmt.Errorf("make ssh tunnel: %v", err) + return err } + fwd, err := NewPortForwarder(ctx, lsclient, "sshd", port) + if err != nil { + return fmt.Errorf("creating port forwarder: %v", err) + } + + tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness + go func() { + tunnelClosed <- fwd.Start(ctx) // error is non-nil + }() + t := time.NewTicker(1 * time.Second) defer t.Stop() for { select { case <-ctx.Done(): - return nil - case err := <-connClosed: - return fmt.Errorf("connection closed: %v", err) + return nil // canceled + + case err := <-tunnelClosed: + return fmt.Errorf("connection failed: %v", err) + case <-t.C: - states, err := getPostCreateOutput(ctx, tunnelPort, codespace) + states, err := getPostCreateOutput(ctx, port, codespace) if err != nil { return fmt.Errorf("get post create output: %v", err) }