move port choice, and PortForwarder.Start call, into clients

This commit is contained in:
Alan Donovan 2021-08-31 13:52:37 -04:00
parent 509e037a5e
commit c0aae52289
4 changed files with 93 additions and 70 deletions

View file

@ -59,7 +59,12 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
return fmt.Errorf("connecting to liveshare: %v", err)
}
tunnelPort, connClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", 0)
port, err := codespaces.UnusedPort()
if err != nil {
return err
}
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", port)
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
}
@ -71,23 +76,29 @@ func logs(ctx context.Context, tail bool, codespaceName string) error {
dst := fmt.Sprintf("%s@localhost", getSSHUser(codespace))
cmd := codespaces.NewRemoteCommand(
ctx, tunnelPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
ctx, port, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
)
// Channel is buffered to avoid a goroutine leak when connClosed occurs before done.
done := make(chan error, 1)
go func() { done <- cmd.Run() }()
// Error channels are buffered so that neither sending goroutine gets stuck.
tunnelClosed := make(chan error, 1)
go func() {
tunnelClosed <- tunnel.Start(ctx) // error is non-nil
}()
cmdDone := make(chan error, 1)
go func() {
cmdDone <- cmd.Run()
}()
select {
case err := <-connClosed:
if err != nil {
return fmt.Errorf("connection closed: %v", err)
}
case err := <-done:
case err := <-tunnelClosed:
return fmt.Errorf("connection closed: %v", err)
case err := <-cmdDone:
if err != nil {
return fmt.Errorf("error retrieving logs: %v", err)
}
return nil // success
}
return nil
}

View file

@ -80,7 +80,17 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in
}
log.Print("\n")
tunnelPort, tunnelClosed, err := codespaces.StartPortForwarding(ctx, lsclient, "sshd", sshServerPort)
usingCustomPort := true
if sshServerPort == 0 {
usingCustomPort = false // suppress log of command line in Shell
port, err := codespaces.UnusedPort()
if err != nil {
return err
}
sshServerPort = port
}
tunnel, err := codespaces.NewPortForwarder(ctx, lsclient, "sshd", sshServerPort)
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
}
@ -90,26 +100,27 @@ func ssh(ctx context.Context, sshProfile, codespaceName string, sshServerPort in
connectDestination = fmt.Sprintf("%s@localhost", getSSHUser(codespace))
}
usingCustomPort := tunnelPort == sshServerPort
tunnelClosed := make(chan error)
go func() {
tunnelClosed <- tunnel.Start(ctx) // error is always non-nil
}()
shellClosed := make(chan error)
go func() {
shellClosed <- codespaces.Shell(ctx, log, tunnelPort, connectDestination, usingCustomPort)
shellClosed <- codespaces.Shell(ctx, log, sshServerPort, connectDestination, usingCustomPort)
}()
log.Println("Ready...")
select {
case err := <-tunnelClosed:
if err != nil {
return fmt.Errorf("tunnel closed: %v", err)
}
return fmt.Errorf("tunnel closed: %v", err)
case err := <-shellClosed:
if err != nil {
return fmt.Errorf("shell closed: %v", err)
}
return nil // success
}
return nil
}
func getContainerID(ctx context.Context, logger *output.Logger, terminal *liveshare.Terminal) (string, error) {

View file

@ -3,7 +3,7 @@ package codespaces
import (
"context"
"fmt"
"math/rand"
"net"
"os"
"os/exec"
"strconv"
@ -12,57 +12,47 @@ import (
"github.com/github/go-liveshare"
)
// StartPortForwarding starts LiveShare port forwarding of traffic
// between the LiveShare client and the specified local port, or, if
// zero, a port chosen at random; the effective port number is
// returned. Forwarding continues in the background until an error is
// encountered (including cancellation of the context). Therefore
// clients must cancel the context.
// UnusedPort returns the number of a local TCP port that is currently
// unbound, or an error if none was available.
//
// Use of this function carries an inherent risk of a time-of-check to
// time-of-use race against other processes.
func UnusedPort() (int, error) {
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
return 0, fmt.Errorf("internal error while choosing port: %v", err)
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return 0, fmt.Errorf("choosing available port: %v", err)
}
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
// NewPortForwarder returns a new port forwarder for traffic between
// the Live Share client and the specified local port (which must be
// available).
//
// The session name is used (along with the port) to generate
// names for streams, and may appear in error messages.
//
// TODO(adonovan): simplify API concurrency from API. Either:
// 1) return a stop function so that clients don't forget to stop forwarding.
// 2) avoid creating a goroutine and returning a channel. Use approach of
// http.ListenAndServe, which runs until it encounters an error
// (incl. cancellation). But this means we can't return the port.
// Can we make the client responsible for supplying it?
// 3) return a PortForwarding object that encapsulates the port,
// and has NewRemoteCommand as a method. It will need a Stop method,
// and an Error method for querying whether the session has failed
// asynchronously.
func StartPortForwarding(ctx context.Context, lsclient *liveshare.Client, sessionName string, localPort int) (int, <-chan error, error) {
server, err := liveshare.NewServer(lsclient)
if err != nil {
return 0, nil, fmt.Errorf("new liveshare server: %v", err)
func NewPortForwarder(ctx context.Context, client *liveshare.Client, sessionName string, localPort int) (*liveshare.PortForwarder, error) {
if localPort == 0 {
return nil, fmt.Errorf("a local port must be provided")
}
if localPort == 0 {
localPort = rand.Intn(9999-2000) + 2000
// TODO(adonovan): retry if port is taken?
server, err := liveshare.NewServer(client)
if err != nil {
return nil, fmt.Errorf("new liveshare server: %v", err)
}
// TODO(josebalius): This port won't always be 2222
if err := server.StartSharing(ctx, sessionName, 2222); err != nil {
return 0, nil, fmt.Errorf("sharing sshd port: %v", err)
return nil, fmt.Errorf("sharing sshd port: %v", err)
}
tunnelClosed := make(chan error)
go func() {
// TODO(adonovan): simplify liveshare API to combine NewPortForwarder and Start
// methods into a single ForwardPort call, like http.ListenAndServe.
// (Start is a misnomer: it runs the complete session.)
// Also document that it never returns a nil error.
portForwarder := liveshare.NewPortForwarder(lsclient, server, localPort)
if err := portForwarder.Start(ctx); err != nil {
tunnelClosed <- fmt.Errorf("forwarding port: %v", err)
return
}
tunnelClosed <- nil
}()
return localPort, tunnelClosed, nil
return liveshare.NewPortForwarder(client, server, localPort), nil
}
// Shell runs an interactive secure shell over an existing
@ -78,8 +68,8 @@ func Shell(ctx context.Context, log logger, port int, destination string, usingC
return cmd.Run()
}
// NewRemoteCommand returns a partially populated exec.Cmd that will
// securely run a shell command on the remote machine.
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
// command on the remote machine.
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
return cmd
@ -92,7 +82,6 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm
// TODO(adonovan): eliminate X11 and X11Trust flags where unneeded.
cmdArgs := append([]string{dst, "-X", "-Y", "-C"}, connArgs...) // X11, X11Trust, Compression
// An empty command enables port forwarding but not execution.
if command != "" {
cmdArgs = append(cmdArgs, command)
}

View file

@ -45,22 +45,34 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
return fmt.Errorf("connect to liveshare: %v", err)
}
tunnelPort, connClosed, err := StartPortForwarding(ctx, lsclient, "sshd", 0)
port, err := UnusedPort()
if err != nil {
return fmt.Errorf("make ssh tunnel: %v", err)
return err
}
fwd, err := NewPortForwarder(ctx, lsclient, "sshd", port)
if err != nil {
return fmt.Errorf("creating port forwarder: %v", err)
}
tunnelClosed := make(chan error, 1) // buffered to avoid sender stuckness
go func() {
tunnelClosed <- fwd.Start(ctx) // error is non-nil
}()
t := time.NewTicker(1 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return nil
case err := <-connClosed:
return fmt.Errorf("connection closed: %v", err)
return nil // canceled
case err := <-tunnelClosed:
return fmt.Errorf("connection failed: %v", err)
case <-t.C:
states, err := getPostCreateOutput(ctx, tunnelPort, codespace)
states, err := getPostCreateOutput(ctx, port, codespace)
if err != nil {
return fmt.Errorf("get post create output: %v", err)
}